feat: Implement comprehensive access control for shopping list and item WebSocket handlers and refine country code validation.

This commit is contained in:
thekiwismarthome
2026-02-28 10:34:40 +13:00
parent 2a8b12a07e
commit ae133ae59b
@@ -48,6 +48,62 @@ from .. import get_storage
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# =============================================================================
# ACCESS-CHECK HELPERS
# =============================================================================
def _user_can_access_list(lst, user) -> bool:
"""Return True if the user may read or write to this list.
Global lists (owner_id=None) are accessible to everyone.
Private lists are accessible to their owner, anyone in allowed_users, and admins.
"""
if lst.owner_id is None:
return True
if user is None:
return False
if user.is_admin or user.id == lst.owner_id:
return True
return user.id in (lst.allowed_users or [])
def _check_list_access(storage, connection, msg, list_id, require_owner=False):
"""Verify the connected user may access list_id.
Sends the appropriate WebSocket error if access is denied.
Returns the ShoppingList object on success, or None if an error was sent.
Args:
require_owner: When True, only the list owner (or an admin) is allowed.
Use for destructive/administrative operations.
"""
lst = storage.get_list(list_id)
if lst is None:
connection.send_error(msg["id"], "not_found", "List not found")
return None
user = connection.user
if require_owner:
if lst.owner_id is not None and not (user and (user.is_admin or user.id == lst.owner_id)):
connection.send_error(msg["id"], "forbidden", "Only the list owner can perform this action")
return None
else:
if not _user_can_access_list(lst, user):
connection.send_error(msg["id"], "forbidden", "You do not have access to this list")
return None
return lst
def _find_item_list_id(storage, item_id):
"""Return the list_id that contains item_id, or None if not found."""
for list_id, items in storage._items.items():
for item in items:
if item.id == item_id:
return list_id
return None
# ============================================================================= # =============================================================================
# LIST HANDLERS # LIST HANDLERS
# ============================================================================= # =============================================================================
@@ -62,16 +118,28 @@ async def websocket_subscribe(
msg: dict, msg: dict,
) -> None: ) -> None:
"""Subscribe to shopping list manager events via WebSocket.""" """Subscribe to shopping list manager events via WebSocket."""
storage = get_storage(hass)
@callback @callback
def forward_event(event): def forward_event(event):
"""Forward HA bus event to WebSocket connection.""" """Forward HA bus event to WebSocket connection.
Events that reference a list_id are only forwarded if the connected
user has access to that list, preventing cross-user data leakage.
"""
data = event.data
list_id = data.get("list_id")
if list_id:
lst = storage.get_list(list_id)
if lst and not _user_can_access_list(lst, connection.user):
return # skip — user cannot see this list
connection.send_message( connection.send_message(
websocket_api.event_message( websocket_api.event_message(
msg["id"], msg["id"],
{ {
"event_type": event.event_type, "event_type": event.event_type,
"data": event.data, "data": data,
} }
) )
) )
@@ -249,6 +317,9 @@ async def websocket_update_list(
storage = get_storage(hass) storage = get_storage(hass)
list_id = msg["list_id"] list_id = msg["list_id"]
if _check_list_access(storage, connection, msg, list_id, require_owner=True) is None:
return
# Build update kwargs # Build update kwargs
update_data = {} update_data = {}
if "name" in msg: if "name" in msg:
@@ -335,6 +406,9 @@ async def websocket_set_active_list(
storage = get_storage(hass) storage = get_storage(hass)
list_id = msg["list_id"] list_id = msg["list_id"]
if _check_list_access(storage, connection, msg, list_id) is None:
return
success = await storage.set_active_list(list_id) success = await storage.set_active_list(list_id)
if not success: if not success:
@@ -370,6 +444,9 @@ def websocket_get_items(
storage = get_storage(hass) storage = get_storage(hass)
list_id = msg["list_id"] list_id = msg["list_id"]
if _check_list_access(storage, connection, msg, list_id) is None:
return
items = storage.get_items(list_id) items = storage.get_items(list_id)
connection.send_result( connection.send_result(
@@ -405,6 +482,9 @@ async def websocket_add_item(
storage = get_storage(hass) storage = get_storage(hass)
list_id = msg["list_id"] list_id = msg["list_id"]
if _check_list_access(storage, connection, msg, list_id) is None:
return
# Build item data # Build item data
item_data = { item_data = {
"name": msg["name"], "name": msg["name"],
@@ -464,6 +544,13 @@ async def websocket_update_item(
storage = get_storage(hass) storage = get_storage(hass)
item_id = msg["item_id"] item_id = msg["item_id"]
list_id = _find_item_list_id(storage, item_id)
if list_id is None:
connection.send_error(msg["id"], "not_found", "Item not found")
return
if _check_list_access(storage, connection, msg, list_id) is None:
return
# Build update data # Build update data
update_data = {} update_data = {}
update_fields = ["name", "quantity", "unit", "note", "price", "category_id", "image_url"] update_fields = ["name", "quantity", "unit", "note", "price", "category_id", "image_url"]
@@ -549,6 +636,13 @@ async def websocket_delete_item(
storage = get_storage(hass) storage = get_storage(hass)
item_id = msg["item_id"] item_id = msg["item_id"]
list_id = _find_item_list_id(storage, item_id)
if list_id is None:
connection.send_error(msg["id"], "not_found", "Item not found")
return
if _check_list_access(storage, connection, msg, list_id) is None:
return
success = await storage.delete_item(item_id) success = await storage.delete_item(item_id)
if not success: if not success:
@@ -931,10 +1025,12 @@ def websocket_get_integration_settings(
) )
_VALID_COUNTRIES = ["NZ", "AU", "US", "GB", "CA"]
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "shopping_list_manager/set_country", vol.Required("type"): "shopping_list_manager/set_country",
vol.Required("country"): str, vol.Required("country"): vol.In(_VALID_COUNTRIES),
} }
) )
@websocket_api.async_response @websocket_api.async_response