diff --git a/custom_components/shopping_list_manager/websocket/handlers.py b/custom_components/shopping_list_manager/websocket/handlers.py index 57a696f..f00c527 100644 --- a/custom_components/shopping_list_manager/websocket/handlers.py +++ b/custom_components/shopping_list_manager/websocket/handlers.py @@ -48,6 +48,62 @@ from .. import get_storage _LOGGER = logging.getLogger(__name__) +# ============================================================================= +# ACCESS-CHECK HELPERS +# ============================================================================= + +def _user_can_access_list(lst, user) -> bool: + """Return True if the user may read or write to this list. + + Global lists (owner_id=None) are accessible to everyone. + Private lists are accessible to their owner, anyone in allowed_users, and admins. + """ + if lst.owner_id is None: + return True + if user is None: + return False + if user.is_admin or user.id == lst.owner_id: + return True + return user.id in (lst.allowed_users or []) + + +def _check_list_access(storage, connection, msg, list_id, require_owner=False): + """Verify the connected user may access list_id. + + Sends the appropriate WebSocket error if access is denied. + Returns the ShoppingList object on success, or None if an error was sent. + + Args: + require_owner: When True, only the list owner (or an admin) is allowed. + Use for destructive/administrative operations. + """ + lst = storage.get_list(list_id) + if lst is None: + connection.send_error(msg["id"], "not_found", "List not found") + return None + + user = connection.user + if require_owner: + if lst.owner_id is not None and not (user and (user.is_admin or user.id == lst.owner_id)): + connection.send_error(msg["id"], "forbidden", "Only the list owner can perform this action") + return None + else: + if not _user_can_access_list(lst, user): + connection.send_error(msg["id"], "forbidden", "You do not have access to this list") + return None + + return lst + + +def _find_item_list_id(storage, item_id): + """Return the list_id that contains item_id, or None if not found.""" + for list_id, items in storage._items.items(): + for item in items: + if item.id == item_id: + return list_id + return None + + # ============================================================================= # LIST HANDLERS # ============================================================================= @@ -62,16 +118,28 @@ async def websocket_subscribe( msg: dict, ) -> None: """Subscribe to shopping list manager events via WebSocket.""" - + storage = get_storage(hass) + @callback def forward_event(event): - """Forward HA bus event to WebSocket connection.""" + """Forward HA bus event to WebSocket connection. + + Events that reference a list_id are only forwarded if the connected + user has access to that list, preventing cross-user data leakage. + """ + data = event.data + list_id = data.get("list_id") + if list_id: + lst = storage.get_list(list_id) + if lst and not _user_can_access_list(lst, connection.user): + return # skip — user cannot see this list + connection.send_message( websocket_api.event_message( msg["id"], { "event_type": event.event_type, - "data": event.data, + "data": data, } ) ) @@ -248,7 +316,10 @@ async def websocket_update_list( """Handle update list command.""" storage = get_storage(hass) list_id = msg["list_id"] - + + if _check_list_access(storage, connection, msg, list_id, require_owner=True) is None: + return + # Build update kwargs update_data = {} if "name" in msg: @@ -334,7 +405,10 @@ async def websocket_set_active_list( """Handle set active list command.""" storage = get_storage(hass) list_id = msg["list_id"] - + + if _check_list_access(storage, connection, msg, list_id) is None: + return + success = await storage.set_active_list(list_id) if not success: @@ -369,7 +443,10 @@ def websocket_get_items( """Handle get items command.""" storage = get_storage(hass) list_id = msg["list_id"] - + + if _check_list_access(storage, connection, msg, list_id) is None: + return + items = storage.get_items(list_id) connection.send_result( @@ -404,7 +481,10 @@ async def websocket_add_item( """Handle add item command.""" storage = get_storage(hass) list_id = msg["list_id"] - + + if _check_list_access(storage, connection, msg, list_id) is None: + return + # Build item data item_data = { "name": msg["name"], @@ -463,7 +543,14 @@ async def websocket_update_item( """Handle update item command.""" storage = get_storage(hass) item_id = msg["item_id"] - + + list_id = _find_item_list_id(storage, item_id) + if list_id is None: + connection.send_error(msg["id"], "not_found", "Item not found") + return + if _check_list_access(storage, connection, msg, list_id) is None: + return + # Build update data update_data = {} update_fields = ["name", "quantity", "unit", "note", "price", "category_id", "image_url"] @@ -548,7 +635,14 @@ async def websocket_delete_item( """Handle delete item command.""" storage = get_storage(hass) item_id = msg["item_id"] - + + list_id = _find_item_list_id(storage, item_id) + if list_id is None: + connection.send_error(msg["id"], "not_found", "Item not found") + return + if _check_list_access(storage, connection, msg, list_id) is None: + return + success = await storage.delete_item(item_id) if not success: @@ -931,10 +1025,12 @@ def websocket_get_integration_settings( ) +_VALID_COUNTRIES = ["NZ", "AU", "US", "GB", "CA"] + @websocket_api.websocket_command( { vol.Required("type"): "shopping_list_manager/set_country", - vol.Required("country"): str, + vol.Required("country"): vol.In(_VALID_COUNTRIES), } ) @websocket_api.async_response