mirror of
https://github.com/thekiwismarthome/shopping-list-manager.git
synced 2026-05-01 11:46:30 +00:00
Merge pull request #4 from thekiwismarthome/v2.0.5-Country-Specific-Product-Catalogues
V2.0.5 country specific product catalogues
This commit is contained in:
@@ -202,7 +202,61 @@ async def _async_register_websocket_handlers(
|
||||
hass,
|
||||
handlers.websocket_get_categories,
|
||||
)
|
||||
|
||||
|
||||
# Integration settings handlers
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_get_integration_settings,
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_set_country,
|
||||
)
|
||||
|
||||
# Backup / Restore handlers
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_export_data,
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_import_data,
|
||||
)
|
||||
|
||||
# List members handler
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_update_list_members,
|
||||
)
|
||||
|
||||
# HA users handler
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_get_ha_users,
|
||||
)
|
||||
|
||||
# Loyalty card handlers
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_get_loyalty_cards,
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_add_loyalty_card,
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_update_loyalty_card,
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_delete_loyalty_card,
|
||||
)
|
||||
websocket_api.async_register_command(
|
||||
hass,
|
||||
handlers.websocket_update_loyalty_card_members,
|
||||
)
|
||||
|
||||
_LOGGER.debug("WebSocket handlers registered")
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ STORAGE_KEY_LISTS = f"{DOMAIN}.lists"
|
||||
STORAGE_KEY_ITEMS = f"{DOMAIN}.items"
|
||||
STORAGE_KEY_PRODUCTS = f"{DOMAIN}.products"
|
||||
STORAGE_KEY_CATEGORIES = f"{DOMAIN}.categories"
|
||||
STORAGE_KEY_LOYALTY_CARDS = f"{DOMAIN}.loyalty_cards"
|
||||
|
||||
# WebSocket Commands - Lists
|
||||
WS_TYPE_LISTS_GET_ALL = f"{DOMAIN}/lists/get_all"
|
||||
@@ -16,6 +17,10 @@ WS_TYPE_LISTS_CREATE = f"{DOMAIN}/lists/create"
|
||||
WS_TYPE_LISTS_UPDATE = f"{DOMAIN}/lists/update"
|
||||
WS_TYPE_LISTS_DELETE = f"{DOMAIN}/lists/delete"
|
||||
WS_TYPE_LISTS_SET_ACTIVE = f"{DOMAIN}/lists/set_active"
|
||||
WS_TYPE_LISTS_UPDATE_MEMBERS = f"{DOMAIN}/lists/update_members"
|
||||
|
||||
# WebSocket Commands - Users
|
||||
WS_TYPE_USERS_GET_ALL = f"{DOMAIN}/users/get_all"
|
||||
|
||||
# WebSocket Commands - Items
|
||||
WS_TYPE_ITEMS_GET = f"{DOMAIN}/items/get"
|
||||
@@ -39,6 +44,13 @@ WS_TYPE_PRODUCTS_DELETE = f"{DOMAIN}/products/delete"
|
||||
WS_TYPE_CATEGORIES_GET_ALL = f"{DOMAIN}/categories/get_all"
|
||||
WS_TYPE_CATEGORIES_REORDER = f"{DOMAIN}/categories/reorder"
|
||||
|
||||
# WebSocket Commands - Loyalty Cards
|
||||
WS_TYPE_LOYALTY_GET_ALL = f"{DOMAIN}/loyalty/get_all"
|
||||
WS_TYPE_LOYALTY_ADD = f"{DOMAIN}/loyalty/add"
|
||||
WS_TYPE_LOYALTY_UPDATE = f"{DOMAIN}/loyalty/update"
|
||||
WS_TYPE_LOYALTY_DELETE = f"{DOMAIN}/loyalty/delete"
|
||||
WS_TYPE_LOYALTY_UPDATE_MEMBERS = f"{DOMAIN}/loyalty/update_members"
|
||||
|
||||
# WebSocket Commands - Subscriptions
|
||||
WS_TYPE_SUBSCRIBE = f"{DOMAIN}/subscribe"
|
||||
WS_TYPE_UNSUBSCRIBE = f"{DOMAIN}/unsubscribe"
|
||||
|
||||
@@ -96,6 +96,28 @@ class Item:
|
||||
self.estimated_total = self.quantity * self.price
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoyaltyCard:
|
||||
"""Loyalty card model."""
|
||||
id: str
|
||||
name: str
|
||||
number: str
|
||||
barcode: str = ""
|
||||
barcode_type: str = "barcode" # "barcode" or "qrcode"
|
||||
logo: str = ""
|
||||
notes: str = ""
|
||||
color: str = "#9fa8da"
|
||||
created_at: str = field(default_factory=current_timestamp)
|
||||
updated_at: str = field(default_factory=current_timestamp)
|
||||
# Ownership: None = visible to all users; set = private to owner + allowed_users
|
||||
owner_id: Optional[str] = None
|
||||
allowed_users: List[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShoppingList:
|
||||
"""Shopping list model."""
|
||||
@@ -107,6 +129,9 @@ class ShoppingList:
|
||||
item_order: List[str] = field(default_factory=list)
|
||||
category_order: List[str] = field(default_factory=list)
|
||||
active: bool = False
|
||||
# Ownership: None = visible to all users; set = private to owner + allowed_users
|
||||
owner_id: Optional[str] = None
|
||||
allowed_users: List[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""Storage management for Shopping List Manager."""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Any
|
||||
from .utils.search import ProductSearch
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -11,9 +14,10 @@ from .const import (
|
||||
STORAGE_KEY_ITEMS,
|
||||
STORAGE_KEY_PRODUCTS,
|
||||
STORAGE_KEY_CATEGORIES,
|
||||
STORAGE_KEY_LOYALTY_CARDS,
|
||||
)
|
||||
from .data.catalog_loader import load_product_catalog
|
||||
from .models import ShoppingList, Item, Product, Category, generate_id
|
||||
from .models import ShoppingList, Item, Product, Category, LoyaltyCard, generate_id
|
||||
from .data.category_loader import load_categories
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -37,11 +41,13 @@ class ShoppingListStorage:
|
||||
self._store_items = Store(hass, STORAGE_VERSION, STORAGE_KEY_ITEMS)
|
||||
self._store_products = Store(hass, STORAGE_VERSION, STORAGE_KEY_PRODUCTS)
|
||||
self._store_categories = Store(hass, STORAGE_VERSION, STORAGE_KEY_CATEGORIES)
|
||||
|
||||
self._store_loyalty_cards = Store(hass, STORAGE_VERSION, STORAGE_KEY_LOYALTY_CARDS)
|
||||
|
||||
self._lists: Dict[str, ShoppingList] = {}
|
||||
self._items: Dict[str, List[Item]] = {}
|
||||
self._products: Dict[str, Product] = {}
|
||||
self._categories: List[Category] = []
|
||||
self._loyalty_cards: Dict[str, LoyaltyCard] = {}
|
||||
self._search_engine: Optional[ProductSearch] = None
|
||||
|
||||
async def async_load(self) -> None:
|
||||
@@ -141,6 +147,15 @@ class ShoppingListStorage:
|
||||
await self._save_products()
|
||||
_LOGGER.info("Successfully imported %d products from catalog", len(self._products))
|
||||
|
||||
# Load loyalty cards
|
||||
loyalty_data = await self._store_loyalty_cards.async_load()
|
||||
if loyalty_data:
|
||||
self._loyalty_cards = {
|
||||
card_id: LoyaltyCard(**card_data)
|
||||
for card_id, card_data in loyalty_data.items()
|
||||
}
|
||||
_LOGGER.debug("Loaded %d loyalty cards", len(self._loyalty_cards))
|
||||
|
||||
# Initialize search engine after products are loaded
|
||||
if self._products:
|
||||
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
|
||||
@@ -156,9 +171,21 @@ class ShoppingListStorage:
|
||||
data = {list_id: lst.to_dict() for list_id, lst in self._lists.items()}
|
||||
await self._store_lists.async_save(data)
|
||||
|
||||
def get_lists(self) -> List[ShoppingList]:
|
||||
"""Get all lists."""
|
||||
return list(self._lists.values())
|
||||
def get_lists(self, user_id: str = None, is_admin: bool = False) -> List[ShoppingList]:
|
||||
"""Get lists visible to the specified user.
|
||||
|
||||
Global lists (owner_id=None) are visible to everyone.
|
||||
Private lists are visible to their owner, anyone in allowed_users, and admins.
|
||||
"""
|
||||
all_lists = list(self._lists.values())
|
||||
if is_admin or user_id is None:
|
||||
return all_lists
|
||||
return [
|
||||
lst for lst in all_lists
|
||||
if lst.owner_id is None
|
||||
or lst.owner_id == user_id
|
||||
or user_id in (lst.allowed_users or [])
|
||||
]
|
||||
|
||||
def get_list(self, list_id: str) -> Optional[ShoppingList]:
|
||||
"""Get a specific list."""
|
||||
@@ -171,17 +198,19 @@ class ShoppingListStorage:
|
||||
return lst
|
||||
return None
|
||||
|
||||
async def create_list(self, name: str, icon: str = "mdi:cart") -> ShoppingList:
|
||||
"""Create a new list."""
|
||||
async def create_list(self, name: str, icon: str = "mdi:cart", owner_id: str = None) -> ShoppingList:
|
||||
"""Create a new list. Pass owner_id to make the list private to that user."""
|
||||
new_list = ShoppingList(
|
||||
id=generate_id(),
|
||||
name=name,
|
||||
icon=icon,
|
||||
category_order=[cat.id for cat in self._categories]
|
||||
category_order=[cat.id for cat in self._categories],
|
||||
owner_id=owner_id,
|
||||
)
|
||||
self._lists[new_list.id] = new_list
|
||||
self._items[new_list.id] = []
|
||||
await self._save_lists()
|
||||
await self._write_config_backup()
|
||||
_LOGGER.info("Created new list: %s", name)
|
||||
return new_list
|
||||
|
||||
@@ -202,6 +231,18 @@ class ShoppingListStorage:
|
||||
_LOGGER.debug("Updated list: %s", list_id)
|
||||
return lst
|
||||
|
||||
async def update_list_members(self, list_id: str, allowed_users: List[str]) -> Optional[ShoppingList]:
|
||||
"""Update the allowed_users for a private list."""
|
||||
if list_id not in self._lists:
|
||||
return None
|
||||
lst = self._lists[list_id]
|
||||
lst.allowed_users = allowed_users
|
||||
from .models import current_timestamp
|
||||
lst.updated_at = current_timestamp()
|
||||
await self._save_lists()
|
||||
_LOGGER.debug("Updated members for list: %s", list_id)
|
||||
return lst
|
||||
|
||||
async def delete_list(self, list_id: str) -> bool:
|
||||
"""Delete a list."""
|
||||
if list_id not in self._lists:
|
||||
@@ -460,25 +501,167 @@ class ShoppingListStorage:
|
||||
)
|
||||
self._products[new_product.id] = new_product
|
||||
await self._save_products()
|
||||
await self._write_config_backup()
|
||||
# Rebuild search engine so the new product is immediately searchable
|
||||
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
|
||||
self._search_engine = ProductSearch(products_dict)
|
||||
_LOGGER.debug("Added product: %s", new_product.name)
|
||||
return new_product
|
||||
|
||||
async def reload_catalog(self, country_code: str) -> int:
|
||||
"""Replace catalog-sourced products with those from a new country's catalog.
|
||||
Products with source='user' are preserved."""
|
||||
catalog_ids = [
|
||||
pid for pid, p in self._products.items()
|
||||
if getattr(p, 'source', 'user') == 'catalog'
|
||||
]
|
||||
for pid in catalog_ids:
|
||||
del self._products[pid]
|
||||
|
||||
self._country = country_code
|
||||
catalog_products = await load_product_catalog(self._component_path, country_code)
|
||||
count = 0
|
||||
for prod_data in catalog_products:
|
||||
try:
|
||||
product = Product(
|
||||
id=prod_data.get("id", generate_id()),
|
||||
name=prod_data["name"],
|
||||
category_id=prod_data.get("category_id", "other"),
|
||||
aliases=prod_data.get("aliases", []),
|
||||
default_unit=prod_data.get("default_unit", "units"),
|
||||
default_quantity=prod_data.get("default_quantity", 1),
|
||||
price=prod_data.get("price") or prod_data.get("typical_price"),
|
||||
currency=self.hass.config.currency,
|
||||
barcode=prod_data.get("barcode"),
|
||||
brands=prod_data.get("brands", []),
|
||||
image_url=prod_data.get("image_url", ""),
|
||||
custom=False,
|
||||
source="catalog",
|
||||
tags=prod_data.get("tags", []),
|
||||
collections=prod_data.get("collections", []),
|
||||
taxonomy=prod_data.get("taxonomy", {}),
|
||||
allergens=prod_data.get("allergens", []),
|
||||
substitution_group=prod_data.get("substitution_group", ""),
|
||||
priority_level=prod_data.get("priority_level", 0),
|
||||
image_hint=prod_data.get("image_hint", "")
|
||||
)
|
||||
self._products[product.id] = product
|
||||
count += 1
|
||||
except Exception as err:
|
||||
_LOGGER.error("Failed to import product %s: %s", prod_data.get("name"), err)
|
||||
|
||||
await self._save_products()
|
||||
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
|
||||
self._search_engine = ProductSearch(products_dict)
|
||||
_LOGGER.info("Reloaded catalog for %s: %d products imported", country_code, count)
|
||||
return count
|
||||
|
||||
async def update_product(self, product_id: str, **kwargs) -> Optional[Product]:
|
||||
"""Update a product."""
|
||||
if product_id not in self._products:
|
||||
return None
|
||||
|
||||
|
||||
product = self._products[product_id]
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(product, key):
|
||||
setattr(product, key, value)
|
||||
|
||||
|
||||
await self._save_products()
|
||||
await self._write_config_backup()
|
||||
_LOGGER.debug("Updated product: %s", product_id)
|
||||
return product
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backup / Restore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def export_user_data(self) -> dict:
|
||||
"""Return a serialisable snapshot of all user-created data."""
|
||||
user_products = [
|
||||
p.to_dict() for p in self._products.values()
|
||||
if getattr(p, "source", "user") == "user"
|
||||
]
|
||||
lists = [lst.to_dict() for lst in self._lists.values()]
|
||||
items = {
|
||||
list_id: [item.to_dict() for item in items_list]
|
||||
for list_id, items_list in self._items.items()
|
||||
}
|
||||
return {
|
||||
"slm_backup_version": "1.0",
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"country": self._country,
|
||||
"user_products": user_products,
|
||||
"lists": lists,
|
||||
"items": items,
|
||||
}
|
||||
|
||||
async def import_user_data(self, data: dict) -> dict:
|
||||
"""Merge a backup into live storage. Skips anything already present by ID."""
|
||||
imported_products = 0
|
||||
imported_lists = 0
|
||||
imported_items = 0
|
||||
|
||||
for prod_data in data.get("user_products", []):
|
||||
prod_id = prod_data.get("id")
|
||||
if prod_id and prod_id not in self._products:
|
||||
try:
|
||||
self._products[prod_id] = Product(**prod_data)
|
||||
imported_products += 1
|
||||
except Exception as err:
|
||||
_LOGGER.warning("Skipped product during import: %s", err)
|
||||
|
||||
if imported_products:
|
||||
await self._save_products()
|
||||
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
|
||||
self._search_engine = ProductSearch(products_dict)
|
||||
|
||||
for list_data in data.get("lists", []):
|
||||
list_id = list_data.get("id")
|
||||
if list_id and list_id not in self._lists:
|
||||
try:
|
||||
lst = ShoppingList(**list_data)
|
||||
lst.active = False
|
||||
self._lists[list_id] = lst
|
||||
imported_lists += 1
|
||||
except Exception as err:
|
||||
_LOGGER.warning("Skipped list during import: %s", err)
|
||||
|
||||
backup_items = data.get("items", {})
|
||||
for list_id, items_list in backup_items.items():
|
||||
if list_id in self._lists and list_id not in self._items:
|
||||
try:
|
||||
self._items[list_id] = [Item(**d) for d in items_list]
|
||||
imported_items += len(self._items[list_id])
|
||||
except Exception as err:
|
||||
_LOGGER.warning("Skipped items for list %s: %s", list_id, err)
|
||||
|
||||
if imported_lists or imported_items:
|
||||
await self._save_lists()
|
||||
await self._save_items()
|
||||
|
||||
_LOGGER.info(
|
||||
"Import complete: %d products, %d lists, %d items",
|
||||
imported_products, imported_lists, imported_items,
|
||||
)
|
||||
return {"products": imported_products, "lists": imported_lists, "items": imported_items}
|
||||
|
||||
async def _write_config_backup(self) -> None:
|
||||
"""Silently write a backup JSON to the HA config directory."""
|
||||
try:
|
||||
backup_path = os.path.join(
|
||||
self.hass.config.config_dir,
|
||||
"shopping_list_manager_backup.json",
|
||||
)
|
||||
data = await self.export_user_data()
|
||||
|
||||
def _write() -> None:
|
||||
with open(backup_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
await self.hass.async_add_executor_job(_write)
|
||||
_LOGGER.debug("Auto-backup written to %s", backup_path)
|
||||
except Exception as err:
|
||||
_LOGGER.warning("Failed to write config backup: %s", err)
|
||||
|
||||
# Categories methods
|
||||
async def _save_categories(self) -> None:
|
||||
@@ -489,3 +672,81 @@ class ShoppingListStorage:
|
||||
def get_categories(self) -> List[Category]:
|
||||
"""Get all categories."""
|
||||
return self._categories
|
||||
|
||||
# Loyalty card methods
|
||||
async def _save_loyalty_cards(self) -> None:
|
||||
"""Save loyalty cards to storage."""
|
||||
data = {card_id: card.to_dict() for card_id, card in self._loyalty_cards.items()}
|
||||
await self._store_loyalty_cards.async_save(data)
|
||||
|
||||
def get_loyalty_cards(self, user_id: str = None, is_admin: bool = False) -> List[LoyaltyCard]:
|
||||
"""Get loyalty cards visible to the specified user.
|
||||
|
||||
Global cards (owner_id=None) are visible to everyone.
|
||||
Private cards are visible to their owner, anyone in allowed_users, and admins.
|
||||
"""
|
||||
all_cards = list(self._loyalty_cards.values())
|
||||
if is_admin or user_id is None:
|
||||
return all_cards
|
||||
return [
|
||||
card for card in all_cards
|
||||
if card.owner_id is None
|
||||
or card.owner_id == user_id
|
||||
or user_id in (card.allowed_users or [])
|
||||
]
|
||||
|
||||
def get_loyalty_card(self, card_id: str) -> Optional[LoyaltyCard]:
|
||||
"""Get a specific loyalty card."""
|
||||
return self._loyalty_cards.get(card_id)
|
||||
|
||||
async def create_loyalty_card(self, owner_id: str = None, **kwargs) -> LoyaltyCard:
|
||||
"""Create a new loyalty card."""
|
||||
from .models import current_timestamp
|
||||
new_card = LoyaltyCard(
|
||||
id=generate_id(),
|
||||
owner_id=owner_id,
|
||||
**kwargs
|
||||
)
|
||||
self._loyalty_cards[new_card.id] = new_card
|
||||
await self._save_loyalty_cards()
|
||||
_LOGGER.debug("Created loyalty card: %s", new_card.name)
|
||||
return new_card
|
||||
|
||||
async def update_loyalty_card(self, card_id: str, **kwargs) -> Optional[LoyaltyCard]:
|
||||
"""Update a loyalty card."""
|
||||
if card_id not in self._loyalty_cards:
|
||||
return None
|
||||
|
||||
card = self._loyalty_cards[card_id]
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(card, key):
|
||||
setattr(card, key, value)
|
||||
|
||||
from .models import current_timestamp
|
||||
card.updated_at = current_timestamp()
|
||||
await self._save_loyalty_cards()
|
||||
_LOGGER.debug("Updated loyalty card: %s", card_id)
|
||||
return card
|
||||
|
||||
async def delete_loyalty_card(self, card_id: str) -> bool:
|
||||
"""Delete a loyalty card."""
|
||||
if card_id not in self._loyalty_cards:
|
||||
return False
|
||||
|
||||
del self._loyalty_cards[card_id]
|
||||
await self._save_loyalty_cards()
|
||||
_LOGGER.debug("Deleted loyalty card: %s", card_id)
|
||||
return True
|
||||
|
||||
async def update_loyalty_card_members(self, card_id: str, allowed_users: List[str]) -> Optional[LoyaltyCard]:
|
||||
"""Update the allowed_users for a private loyalty card."""
|
||||
if card_id not in self._loyalty_cards:
|
||||
return None
|
||||
|
||||
card = self._loyalty_cards[card_id]
|
||||
card.allowed_users = allowed_users
|
||||
from .models import current_timestamp
|
||||
card.updated_at = current_timestamp()
|
||||
await self._save_loyalty_cards()
|
||||
_LOGGER.debug("Updated members for loyalty card: %s", card_id)
|
||||
return card
|
||||
|
||||
@@ -14,6 +14,8 @@ from ..const import (
|
||||
WS_TYPE_LISTS_UPDATE,
|
||||
WS_TYPE_LISTS_DELETE,
|
||||
WS_TYPE_LISTS_SET_ACTIVE,
|
||||
WS_TYPE_LISTS_UPDATE_MEMBERS,
|
||||
WS_TYPE_USERS_GET_ALL,
|
||||
WS_TYPE_ITEMS_GET,
|
||||
WS_TYPE_ITEMS_ADD,
|
||||
WS_TYPE_ITEMS_UPDATE,
|
||||
@@ -28,6 +30,11 @@ from ..const import (
|
||||
WS_TYPE_PRODUCTS_ADD,
|
||||
WS_TYPE_PRODUCTS_UPDATE,
|
||||
WS_TYPE_CATEGORIES_GET_ALL,
|
||||
WS_TYPE_LOYALTY_GET_ALL,
|
||||
WS_TYPE_LOYALTY_ADD,
|
||||
WS_TYPE_LOYALTY_UPDATE,
|
||||
WS_TYPE_LOYALTY_DELETE,
|
||||
WS_TYPE_LOYALTY_UPDATE_MEMBERS,
|
||||
WS_TYPE_SUBSCRIBE,
|
||||
EVENT_ITEM_ADDED,
|
||||
EVENT_ITEM_UPDATED,
|
||||
@@ -171,8 +178,11 @@ def websocket_get_lists(
|
||||
) -> None:
|
||||
"""Handle get all lists command."""
|
||||
storage = get_storage(hass)
|
||||
lists = storage.get_lists()
|
||||
|
||||
user = connection.user
|
||||
user_id = user.id if user else None
|
||||
is_admin = user.is_admin if user else False
|
||||
lists = storage.get_lists(user_id=user_id, is_admin=is_admin)
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
@@ -186,6 +196,7 @@ def websocket_get_lists(
|
||||
vol.Required("type"): WS_TYPE_LISTS_CREATE,
|
||||
vol.Required("name"): str,
|
||||
vol.Optional("icon", default="mdi:cart"): str,
|
||||
vol.Optional("private", default=True): bool,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
@@ -196,10 +207,15 @@ async def websocket_create_list(
|
||||
) -> None:
|
||||
"""Handle create list command."""
|
||||
storage = get_storage(hass)
|
||||
|
||||
|
||||
# Private lists are owned by the creating user; global lists have no owner.
|
||||
is_private = msg.get("private", True)
|
||||
owner_id = connection.user.id if is_private and connection.user else None
|
||||
|
||||
new_list = await storage.create_list(
|
||||
name=msg["name"],
|
||||
icon=msg.get("icon", "mdi:cart")
|
||||
icon=msg.get("icon", "mdi:cart"),
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
# Fire event
|
||||
@@ -275,9 +291,21 @@ async def websocket_delete_list(
|
||||
"""Handle delete list command."""
|
||||
storage = get_storage(hass)
|
||||
list_id = msg["list_id"]
|
||||
|
||||
|
||||
lst = storage.get_list(list_id)
|
||||
if lst is None:
|
||||
connection.send_error(msg["id"], "not_found", "List not found")
|
||||
return
|
||||
|
||||
# Only the owner or an admin may delete a private list
|
||||
if lst.owner_id is not None:
|
||||
user = connection.user
|
||||
if not (user and (user.is_admin or user.id == lst.owner_id)):
|
||||
connection.send_error(msg["id"], "forbidden", "Only the list owner can delete this list")
|
||||
return
|
||||
|
||||
success = await storage.delete_list(list_id)
|
||||
|
||||
|
||||
if not success:
|
||||
connection.send_error(msg["id"], "not_found", "List not found")
|
||||
return
|
||||
@@ -869,3 +897,321 @@ def websocket_get_categories(
|
||||
"categories": [cat.to_dict() for cat in categories]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# INTEGRATION SETTINGS HANDLERS
|
||||
# =============================================================================
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "shopping_list_manager/get_integration_settings",
|
||||
}
|
||||
)
|
||||
@callback
|
||||
def websocket_get_integration_settings(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Return current country and available country options."""
|
||||
country = hass.data[DOMAIN].get("country", "NZ")
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
"country": country,
|
||||
"available_countries": {
|
||||
"NZ": "New Zealand",
|
||||
"AU": "Australia",
|
||||
"US": "United States",
|
||||
"GB": "United Kingdom",
|
||||
"CA": "Canada",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "shopping_list_manager/set_country",
|
||||
vol.Required("country"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_set_country(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Switch to a different country catalog. Preserves user-added products."""
|
||||
country = msg["country"].upper()
|
||||
storage = get_storage(hass)
|
||||
|
||||
count = await storage.reload_catalog(country)
|
||||
|
||||
# Persist to HA config entry so country survives restart
|
||||
entries = hass.config_entries.async_entries(DOMAIN)
|
||||
if entries:
|
||||
entry = entries[0]
|
||||
hass.config_entries.async_update_entry(entry, options={**entry.options, "country": country})
|
||||
|
||||
hass.data[DOMAIN]["country"] = country
|
||||
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{"success": True, "country": country, "products_loaded": count}
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# BACKUP / RESTORE HANDLERS
|
||||
# =============================================================================
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "shopping_list_manager/export_data",
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_export_data(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Export all user-created data as a JSON-serialisable dict."""
|
||||
storage = get_storage(hass)
|
||||
data = await storage.export_user_data()
|
||||
connection.send_result(msg["id"], data)
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "shopping_list_manager/import_data",
|
||||
vol.Required("data"): dict,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_import_data(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Import user data from a backup payload."""
|
||||
storage = get_storage(hass)
|
||||
counts = await storage.import_user_data(msg["data"])
|
||||
connection.send_result(msg["id"], {"success": True, "imported": counts})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LIST MEMBERS HANDLER
|
||||
# =============================================================================
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): WS_TYPE_LISTS_UPDATE_MEMBERS,
|
||||
vol.Required("list_id"): str,
|
||||
vol.Required("allowed_users"): [str],
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_update_list_members(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Update the allowed_users for a private list."""
|
||||
storage = get_storage(hass)
|
||||
list_id = msg["list_id"]
|
||||
|
||||
lst = storage.get_list(list_id)
|
||||
if lst is None:
|
||||
connection.send_error(msg["id"], "not_found", "List not found")
|
||||
return
|
||||
|
||||
# Only the owner or an admin may manage members
|
||||
user = connection.user
|
||||
if lst.owner_id is not None and not (user and (user.is_admin or user.id == lst.owner_id)):
|
||||
connection.send_error(msg["id"], "forbidden", "Only the list owner can manage members")
|
||||
return
|
||||
|
||||
updated = await storage.update_list_members(list_id, msg["allowed_users"])
|
||||
hass.bus.async_fire(
|
||||
EVENT_LIST_UPDATED,
|
||||
{"list_id": list_id, "action": "members_updated"}
|
||||
)
|
||||
connection.send_result(msg["id"], {"list": updated.to_dict()})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HA USERS HANDLER
|
||||
# =============================================================================
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): WS_TYPE_USERS_GET_ALL,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_get_ha_users(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Return all active, non-system HA users."""
|
||||
users = await hass.auth.async_get_users()
|
||||
result = [
|
||||
{"id": u.id, "name": u.name}
|
||||
for u in users
|
||||
if not u.system_generated and u.is_active
|
||||
]
|
||||
connection.send_result(msg["id"], {"users": result})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LOYALTY CARD HANDLERS
|
||||
# =============================================================================
|
||||
|
||||
@websocket_api.websocket_command({
|
||||
vol.Required("type"): WS_TYPE_LOYALTY_GET_ALL,
|
||||
})
|
||||
@websocket_api.async_response
|
||||
async def websocket_get_loyalty_cards(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Return all loyalty cards visible to the current user."""
|
||||
storage = get_storage(hass)
|
||||
user = connection.user
|
||||
user_id = user.id if user else None
|
||||
is_admin = user.is_admin if user else False
|
||||
cards = storage.get_loyalty_cards(user_id=user_id, is_admin=is_admin)
|
||||
connection.send_result(msg["id"], {"cards": [c.to_dict() for c in cards]})
|
||||
|
||||
|
||||
@websocket_api.websocket_command({
|
||||
vol.Required("type"): WS_TYPE_LOYALTY_ADD,
|
||||
vol.Required("name"): str,
|
||||
vol.Required("number"): str,
|
||||
vol.Optional("barcode", default=""): str,
|
||||
vol.Optional("barcode_type", default="barcode"): str,
|
||||
vol.Optional("logo", default=""): str,
|
||||
vol.Optional("notes", default=""): str,
|
||||
vol.Optional("color", default="#9fa8da"): str,
|
||||
vol.Optional("private", default=True): bool,
|
||||
})
|
||||
@websocket_api.async_response
|
||||
async def websocket_add_loyalty_card(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Add a new loyalty card."""
|
||||
storage = get_storage(hass)
|
||||
user = connection.user
|
||||
owner_id = user.id if (user and msg.get("private")) else None
|
||||
|
||||
card = await storage.create_loyalty_card(
|
||||
owner_id=owner_id,
|
||||
name=msg["name"],
|
||||
number=msg["number"],
|
||||
barcode=msg.get("barcode", ""),
|
||||
barcode_type=msg.get("barcode_type", "barcode"),
|
||||
logo=msg.get("logo", ""),
|
||||
notes=msg.get("notes", ""),
|
||||
color=msg.get("color", "#9fa8da"),
|
||||
)
|
||||
connection.send_result(msg["id"], {"card": card.to_dict()})
|
||||
|
||||
|
||||
@websocket_api.websocket_command({
|
||||
vol.Required("type"): WS_TYPE_LOYALTY_UPDATE,
|
||||
vol.Required("card_id"): str,
|
||||
vol.Optional("name"): str,
|
||||
vol.Optional("number"): str,
|
||||
vol.Optional("barcode"): str,
|
||||
vol.Optional("barcode_type"): str,
|
||||
vol.Optional("logo"): str,
|
||||
vol.Optional("notes"): str,
|
||||
vol.Optional("color"): str,
|
||||
})
|
||||
@websocket_api.async_response
|
||||
async def websocket_update_loyalty_card(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Update an existing loyalty card."""
|
||||
storage = get_storage(hass)
|
||||
card_id = msg["card_id"]
|
||||
|
||||
card = storage.get_loyalty_card(card_id)
|
||||
if card is None:
|
||||
connection.send_error(msg["id"], "not_found", "Loyalty card not found")
|
||||
return
|
||||
|
||||
user = connection.user
|
||||
if card.owner_id is not None and not (user and (user.is_admin or user.id == card.owner_id)):
|
||||
connection.send_error(msg["id"], "forbidden", "Only the card owner can update it")
|
||||
return
|
||||
|
||||
fields = {k: v for k, v in msg.items() if k not in ("type", "id", "card_id")}
|
||||
updated = await storage.update_loyalty_card(card_id, **fields)
|
||||
connection.send_result(msg["id"], {"card": updated.to_dict()})
|
||||
|
||||
|
||||
@websocket_api.websocket_command({
|
||||
vol.Required("type"): WS_TYPE_LOYALTY_DELETE,
|
||||
vol.Required("card_id"): str,
|
||||
})
|
||||
@websocket_api.async_response
|
||||
async def websocket_delete_loyalty_card(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Delete a loyalty card."""
|
||||
storage = get_storage(hass)
|
||||
card_id = msg["card_id"]
|
||||
|
||||
card = storage.get_loyalty_card(card_id)
|
||||
if card is None:
|
||||
connection.send_error(msg["id"], "not_found", "Loyalty card not found")
|
||||
return
|
||||
|
||||
user = connection.user
|
||||
if card.owner_id is not None and not (user and (user.is_admin or user.id == card.owner_id)):
|
||||
connection.send_error(msg["id"], "forbidden", "Only the card owner can delete it")
|
||||
return
|
||||
|
||||
await storage.delete_loyalty_card(card_id)
|
||||
connection.send_result(msg["id"], {"success": True})
|
||||
|
||||
|
||||
@websocket_api.websocket_command({
|
||||
vol.Required("type"): WS_TYPE_LOYALTY_UPDATE_MEMBERS,
|
||||
vol.Required("card_id"): str,
|
||||
vol.Required("allowed_users"): [str],
|
||||
})
|
||||
@websocket_api.async_response
|
||||
async def websocket_update_loyalty_card_members(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Update the allowed_users for a private loyalty card."""
|
||||
storage = get_storage(hass)
|
||||
card_id = msg["card_id"]
|
||||
|
||||
card = storage.get_loyalty_card(card_id)
|
||||
if card is None:
|
||||
connection.send_error(msg["id"], "not_found", "Loyalty card not found")
|
||||
return
|
||||
|
||||
user = connection.user
|
||||
if card.owner_id is not None and not (user and (user.is_admin or user.id == card.owner_id)):
|
||||
connection.send_error(msg["id"], "forbidden", "Only the card owner can manage members")
|
||||
return
|
||||
|
||||
updated = await storage.update_loyalty_card_members(card_id, msg["allowed_users"])
|
||||
connection.send_result(msg["id"], {"card": updated.to_dict()})
|
||||
|
||||
Reference in New Issue
Block a user