mirror of
https://github.com/thekiwismarthome/shopping-list-manager.git
synced 2026-05-01 11:46:30 +00:00
feat: Implement comprehensive access control for shopping list and item WebSocket handlers and refine country code validation.
This commit is contained in:
@@ -48,6 +48,62 @@ from .. import get_storage
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ACCESS-CHECK HELPERS
|
||||
# =============================================================================
|
||||
|
||||
def _user_can_access_list(lst, user) -> bool:
|
||||
"""Return True if the user may read or write to this list.
|
||||
|
||||
Global lists (owner_id=None) are accessible to everyone.
|
||||
Private lists are accessible to their owner, anyone in allowed_users, and admins.
|
||||
"""
|
||||
if lst.owner_id is None:
|
||||
return True
|
||||
if user is None:
|
||||
return False
|
||||
if user.is_admin or user.id == lst.owner_id:
|
||||
return True
|
||||
return user.id in (lst.allowed_users or [])
|
||||
|
||||
|
||||
def _check_list_access(storage, connection, msg, list_id, require_owner=False):
|
||||
"""Verify the connected user may access list_id.
|
||||
|
||||
Sends the appropriate WebSocket error if access is denied.
|
||||
Returns the ShoppingList object on success, or None if an error was sent.
|
||||
|
||||
Args:
|
||||
require_owner: When True, only the list owner (or an admin) is allowed.
|
||||
Use for destructive/administrative operations.
|
||||
"""
|
||||
lst = storage.get_list(list_id)
|
||||
if lst is None:
|
||||
connection.send_error(msg["id"], "not_found", "List not found")
|
||||
return None
|
||||
|
||||
user = connection.user
|
||||
if require_owner:
|
||||
if lst.owner_id is not None and not (user and (user.is_admin or user.id == lst.owner_id)):
|
||||
connection.send_error(msg["id"], "forbidden", "Only the list owner can perform this action")
|
||||
return None
|
||||
else:
|
||||
if not _user_can_access_list(lst, user):
|
||||
connection.send_error(msg["id"], "forbidden", "You do not have access to this list")
|
||||
return None
|
||||
|
||||
return lst
|
||||
|
||||
|
||||
def _find_item_list_id(storage, item_id):
|
||||
"""Return the list_id that contains item_id, or None if not found."""
|
||||
for list_id, items in storage._items.items():
|
||||
for item in items:
|
||||
if item.id == item_id:
|
||||
return list_id
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LIST HANDLERS
|
||||
# =============================================================================
|
||||
@@ -62,16 +118,28 @@ async def websocket_subscribe(
|
||||
msg: dict,
|
||||
) -> None:
|
||||
"""Subscribe to shopping list manager events via WebSocket."""
|
||||
|
||||
storage = get_storage(hass)
|
||||
|
||||
@callback
|
||||
def forward_event(event):
|
||||
"""Forward HA bus event to WebSocket connection."""
|
||||
"""Forward HA bus event to WebSocket connection.
|
||||
|
||||
Events that reference a list_id are only forwarded if the connected
|
||||
user has access to that list, preventing cross-user data leakage.
|
||||
"""
|
||||
data = event.data
|
||||
list_id = data.get("list_id")
|
||||
if list_id:
|
||||
lst = storage.get_list(list_id)
|
||||
if lst and not _user_can_access_list(lst, connection.user):
|
||||
return # skip — user cannot see this list
|
||||
|
||||
connection.send_message(
|
||||
websocket_api.event_message(
|
||||
msg["id"],
|
||||
{
|
||||
"event_type": event.event_type,
|
||||
"data": event.data,
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -248,7 +316,10 @@ async def websocket_update_list(
|
||||
"""Handle update list command."""
|
||||
storage = get_storage(hass)
|
||||
list_id = msg["list_id"]
|
||||
|
||||
|
||||
if _check_list_access(storage, connection, msg, list_id, require_owner=True) is None:
|
||||
return
|
||||
|
||||
# Build update kwargs
|
||||
update_data = {}
|
||||
if "name" in msg:
|
||||
@@ -334,7 +405,10 @@ async def websocket_set_active_list(
|
||||
"""Handle set active list command."""
|
||||
storage = get_storage(hass)
|
||||
list_id = msg["list_id"]
|
||||
|
||||
|
||||
if _check_list_access(storage, connection, msg, list_id) is None:
|
||||
return
|
||||
|
||||
success = await storage.set_active_list(list_id)
|
||||
|
||||
if not success:
|
||||
@@ -369,7 +443,10 @@ def websocket_get_items(
|
||||
"""Handle get items command."""
|
||||
storage = get_storage(hass)
|
||||
list_id = msg["list_id"]
|
||||
|
||||
|
||||
if _check_list_access(storage, connection, msg, list_id) is None:
|
||||
return
|
||||
|
||||
items = storage.get_items(list_id)
|
||||
|
||||
connection.send_result(
|
||||
@@ -404,7 +481,10 @@ async def websocket_add_item(
|
||||
"""Handle add item command."""
|
||||
storage = get_storage(hass)
|
||||
list_id = msg["list_id"]
|
||||
|
||||
|
||||
if _check_list_access(storage, connection, msg, list_id) is None:
|
||||
return
|
||||
|
||||
# Build item data
|
||||
item_data = {
|
||||
"name": msg["name"],
|
||||
@@ -463,7 +543,14 @@ async def websocket_update_item(
|
||||
"""Handle update item command."""
|
||||
storage = get_storage(hass)
|
||||
item_id = msg["item_id"]
|
||||
|
||||
|
||||
list_id = _find_item_list_id(storage, item_id)
|
||||
if list_id is None:
|
||||
connection.send_error(msg["id"], "not_found", "Item not found")
|
||||
return
|
||||
if _check_list_access(storage, connection, msg, list_id) is None:
|
||||
return
|
||||
|
||||
# Build update data
|
||||
update_data = {}
|
||||
update_fields = ["name", "quantity", "unit", "note", "price", "category_id", "image_url"]
|
||||
@@ -548,7 +635,14 @@ async def websocket_delete_item(
|
||||
"""Handle delete item command."""
|
||||
storage = get_storage(hass)
|
||||
item_id = msg["item_id"]
|
||||
|
||||
|
||||
list_id = _find_item_list_id(storage, item_id)
|
||||
if list_id is None:
|
||||
connection.send_error(msg["id"], "not_found", "Item not found")
|
||||
return
|
||||
if _check_list_access(storage, connection, msg, list_id) is None:
|
||||
return
|
||||
|
||||
success = await storage.delete_item(item_id)
|
||||
|
||||
if not success:
|
||||
@@ -931,10 +1025,12 @@ def websocket_get_integration_settings(
|
||||
)
|
||||
|
||||
|
||||
_VALID_COUNTRIES = ["NZ", "AU", "US", "GB", "CA"]
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "shopping_list_manager/set_country",
|
||||
vol.Required("country"): str,
|
||||
vol.Required("country"): vol.In(_VALID_COUNTRIES),
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
|
||||
Reference in New Issue
Block a user