Merge pull request #5 from thekiwismarthome/v2.1.0---Themes

V2.1.0   themes
This commit is contained in:
thekiwismarthome
2026-02-28 22:16:05 +13:00
committed by GitHub
13 changed files with 106 additions and 10 deletions
@@ -48,6 +48,62 @@ from .. import get_storage
_LOGGER = logging.getLogger(__name__)
# =============================================================================
# ACCESS-CHECK HELPERS
# =============================================================================
def _user_can_access_list(lst, user) -> bool:
"""Return True if the user may read or write to this list.
Global lists (owner_id=None) are accessible to everyone.
Private lists are accessible to their owner, anyone in allowed_users, and admins.
"""
if lst.owner_id is None:
return True
if user is None:
return False
if user.is_admin or user.id == lst.owner_id:
return True
return user.id in (lst.allowed_users or [])
def _check_list_access(storage, connection, msg, list_id, require_owner=False):
"""Verify the connected user may access list_id.
Sends the appropriate WebSocket error if access is denied.
Returns the ShoppingList object on success, or None if an error was sent.
Args:
require_owner: When True, only the list owner (or an admin) is allowed.
Use for destructive/administrative operations.
"""
lst = storage.get_list(list_id)
if lst is None:
connection.send_error(msg["id"], "not_found", "List not found")
return None
user = connection.user
if require_owner:
if lst.owner_id is not None and not (user and (user.is_admin or user.id == lst.owner_id)):
connection.send_error(msg["id"], "forbidden", "Only the list owner can perform this action")
return None
else:
if not _user_can_access_list(lst, user):
connection.send_error(msg["id"], "forbidden", "You do not have access to this list")
return None
return lst
def _find_item_list_id(storage, item_id):
"""Return the list_id that contains item_id, or None if not found."""
for list_id, items in storage._items.items():
for item in items:
if item.id == item_id:
return list_id
return None
# =============================================================================
# LIST HANDLERS
# =============================================================================
@@ -62,16 +118,28 @@ async def websocket_subscribe(
msg: dict,
) -> None:
"""Subscribe to shopping list manager events via WebSocket."""
storage = get_storage(hass)
@callback
def forward_event(event):
"""Forward HA bus event to WebSocket connection."""
"""Forward HA bus event to WebSocket connection.
Events that reference a list_id are only forwarded if the connected
user has access to that list, preventing cross-user data leakage.
"""
data = event.data
list_id = data.get("list_id")
if list_id:
lst = storage.get_list(list_id)
if lst and not _user_can_access_list(lst, connection.user):
return # skip — user cannot see this list
connection.send_message(
websocket_api.event_message(
msg["id"],
{
"event_type": event.event_type,
"data": event.data,
"data": data,
}
)
)
@@ -248,7 +316,10 @@ async def websocket_update_list(
"""Handle update list command."""
storage = get_storage(hass)
list_id = msg["list_id"]
if _check_list_access(storage, connection, msg, list_id, require_owner=True) is None:
return
# Build update kwargs
update_data = {}
if "name" in msg:
@@ -334,7 +405,10 @@ async def websocket_set_active_list(
"""Handle set active list command."""
storage = get_storage(hass)
list_id = msg["list_id"]
if _check_list_access(storage, connection, msg, list_id) is None:
return
success = await storage.set_active_list(list_id)
if not success:
@@ -369,7 +443,10 @@ def websocket_get_items(
"""Handle get items command."""
storage = get_storage(hass)
list_id = msg["list_id"]
if _check_list_access(storage, connection, msg, list_id) is None:
return
items = storage.get_items(list_id)
connection.send_result(
@@ -404,7 +481,10 @@ async def websocket_add_item(
"""Handle add item command."""
storage = get_storage(hass)
list_id = msg["list_id"]
if _check_list_access(storage, connection, msg, list_id) is None:
return
# Build item data
item_data = {
"name": msg["name"],
@@ -463,7 +543,14 @@ async def websocket_update_item(
"""Handle update item command."""
storage = get_storage(hass)
item_id = msg["item_id"]
list_id = _find_item_list_id(storage, item_id)
if list_id is None:
connection.send_error(msg["id"], "not_found", "Item not found")
return
if _check_list_access(storage, connection, msg, list_id) is None:
return
# Build update data
update_data = {}
update_fields = ["name", "quantity", "unit", "note", "price", "category_id", "image_url"]
@@ -548,7 +635,14 @@ async def websocket_delete_item(
"""Handle delete item command."""
storage = get_storage(hass)
item_id = msg["item_id"]
list_id = _find_item_list_id(storage, item_id)
if list_id is None:
connection.send_error(msg["id"], "not_found", "Item not found")
return
if _check_list_access(storage, connection, msg, list_id) is None:
return
success = await storage.delete_item(item_id)
if not success:
@@ -931,10 +1025,12 @@ def websocket_get_integration_settings(
)
_VALID_COUNTRIES = ["NZ", "AU", "US", "GB", "CA"]
@websocket_api.websocket_command(
{
vol.Required("type"): "shopping_list_manager/set_country",
vol.Required("country"): str,
vol.Required("country"): vol.In(_VALID_COUNTRIES),
}
)
@websocket_api.async_response