Merge pull request #4 from thekiwismarthome/v2.0.5-Country-Specific-Product-Catalogues

V2.0.5 country specific product catalogues
This commit is contained in:
thekiwismarthome
2026-02-27 17:57:50 +13:00
committed by GitHub
5 changed files with 715 additions and 17 deletions
@@ -202,7 +202,61 @@ async def _async_register_websocket_handlers(
hass,
handlers.websocket_get_categories,
)
# Integration settings handlers
websocket_api.async_register_command(
hass,
handlers.websocket_get_integration_settings,
)
websocket_api.async_register_command(
hass,
handlers.websocket_set_country,
)
# Backup / Restore handlers
websocket_api.async_register_command(
hass,
handlers.websocket_export_data,
)
websocket_api.async_register_command(
hass,
handlers.websocket_import_data,
)
# List members handler
websocket_api.async_register_command(
hass,
handlers.websocket_update_list_members,
)
# HA users handler
websocket_api.async_register_command(
hass,
handlers.websocket_get_ha_users,
)
# Loyalty card handlers
websocket_api.async_register_command(
hass,
handlers.websocket_get_loyalty_cards,
)
websocket_api.async_register_command(
hass,
handlers.websocket_add_loyalty_card,
)
websocket_api.async_register_command(
hass,
handlers.websocket_update_loyalty_card,
)
websocket_api.async_register_command(
hass,
handlers.websocket_delete_loyalty_card,
)
websocket_api.async_register_command(
hass,
handlers.websocket_update_loyalty_card_members,
)
_LOGGER.debug("WebSocket handlers registered")
@@ -9,6 +9,7 @@ STORAGE_KEY_LISTS = f"{DOMAIN}.lists"
STORAGE_KEY_ITEMS = f"{DOMAIN}.items"
STORAGE_KEY_PRODUCTS = f"{DOMAIN}.products"
STORAGE_KEY_CATEGORIES = f"{DOMAIN}.categories"
STORAGE_KEY_LOYALTY_CARDS = f"{DOMAIN}.loyalty_cards"
# WebSocket Commands - Lists
WS_TYPE_LISTS_GET_ALL = f"{DOMAIN}/lists/get_all"
@@ -16,6 +17,10 @@ WS_TYPE_LISTS_CREATE = f"{DOMAIN}/lists/create"
WS_TYPE_LISTS_UPDATE = f"{DOMAIN}/lists/update"
WS_TYPE_LISTS_DELETE = f"{DOMAIN}/lists/delete"
WS_TYPE_LISTS_SET_ACTIVE = f"{DOMAIN}/lists/set_active"
WS_TYPE_LISTS_UPDATE_MEMBERS = f"{DOMAIN}/lists/update_members"
# WebSocket Commands - Users
WS_TYPE_USERS_GET_ALL = f"{DOMAIN}/users/get_all"
# WebSocket Commands - Items
WS_TYPE_ITEMS_GET = f"{DOMAIN}/items/get"
@@ -39,6 +44,13 @@ WS_TYPE_PRODUCTS_DELETE = f"{DOMAIN}/products/delete"
WS_TYPE_CATEGORIES_GET_ALL = f"{DOMAIN}/categories/get_all"
WS_TYPE_CATEGORIES_REORDER = f"{DOMAIN}/categories/reorder"
# WebSocket Commands - Loyalty Cards
WS_TYPE_LOYALTY_GET_ALL = f"{DOMAIN}/loyalty/get_all"
WS_TYPE_LOYALTY_ADD = f"{DOMAIN}/loyalty/add"
WS_TYPE_LOYALTY_UPDATE = f"{DOMAIN}/loyalty/update"
WS_TYPE_LOYALTY_DELETE = f"{DOMAIN}/loyalty/delete"
WS_TYPE_LOYALTY_UPDATE_MEMBERS = f"{DOMAIN}/loyalty/update_members"
# WebSocket Commands - Subscriptions
WS_TYPE_SUBSCRIBE = f"{DOMAIN}/subscribe"
WS_TYPE_UNSUBSCRIBE = f"{DOMAIN}/unsubscribe"
@@ -96,6 +96,28 @@ class Item:
self.estimated_total = self.quantity * self.price
@dataclass
class LoyaltyCard:
"""Loyalty card model."""
id: str
name: str
number: str
barcode: str = ""
barcode_type: str = "barcode" # "barcode" or "qrcode"
logo: str = ""
notes: str = ""
color: str = "#9fa8da"
created_at: str = field(default_factory=current_timestamp)
updated_at: str = field(default_factory=current_timestamp)
# Ownership: None = visible to all users; set = private to owner + allowed_users
owner_id: Optional[str] = None
allowed_users: List[str] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class ShoppingList:
"""Shopping list model."""
@@ -107,6 +129,9 @@ class ShoppingList:
item_order: List[str] = field(default_factory=list)
category_order: List[str] = field(default_factory=list)
active: bool = False
# Ownership: None = visible to all users; set = private to owner + allowed_users
owner_id: Optional[str] = None
allowed_users: List[str] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
@@ -1,5 +1,8 @@
"""Storage management for Shopping List Manager."""
import json
import logging
import os
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any
from .utils.search import ProductSearch
from homeassistant.core import HomeAssistant
@@ -11,9 +14,10 @@ from .const import (
STORAGE_KEY_ITEMS,
STORAGE_KEY_PRODUCTS,
STORAGE_KEY_CATEGORIES,
STORAGE_KEY_LOYALTY_CARDS,
)
from .data.catalog_loader import load_product_catalog
from .models import ShoppingList, Item, Product, Category, generate_id
from .models import ShoppingList, Item, Product, Category, LoyaltyCard, generate_id
from .data.category_loader import load_categories
_LOGGER = logging.getLogger(__name__)
@@ -37,11 +41,13 @@ class ShoppingListStorage:
self._store_items = Store(hass, STORAGE_VERSION, STORAGE_KEY_ITEMS)
self._store_products = Store(hass, STORAGE_VERSION, STORAGE_KEY_PRODUCTS)
self._store_categories = Store(hass, STORAGE_VERSION, STORAGE_KEY_CATEGORIES)
self._store_loyalty_cards = Store(hass, STORAGE_VERSION, STORAGE_KEY_LOYALTY_CARDS)
self._lists: Dict[str, ShoppingList] = {}
self._items: Dict[str, List[Item]] = {}
self._products: Dict[str, Product] = {}
self._categories: List[Category] = []
self._loyalty_cards: Dict[str, LoyaltyCard] = {}
self._search_engine: Optional[ProductSearch] = None
async def async_load(self) -> None:
@@ -141,6 +147,15 @@ class ShoppingListStorage:
await self._save_products()
_LOGGER.info("Successfully imported %d products from catalog", len(self._products))
# Load loyalty cards
loyalty_data = await self._store_loyalty_cards.async_load()
if loyalty_data:
self._loyalty_cards = {
card_id: LoyaltyCard(**card_data)
for card_id, card_data in loyalty_data.items()
}
_LOGGER.debug("Loaded %d loyalty cards", len(self._loyalty_cards))
# Initialize search engine after products are loaded
if self._products:
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
@@ -156,9 +171,21 @@ class ShoppingListStorage:
data = {list_id: lst.to_dict() for list_id, lst in self._lists.items()}
await self._store_lists.async_save(data)
def get_lists(self) -> List[ShoppingList]:
"""Get all lists."""
return list(self._lists.values())
def get_lists(self, user_id: str = None, is_admin: bool = False) -> List[ShoppingList]:
"""Get lists visible to the specified user.
Global lists (owner_id=None) are visible to everyone.
Private lists are visible to their owner, anyone in allowed_users, and admins.
"""
all_lists = list(self._lists.values())
if is_admin or user_id is None:
return all_lists
return [
lst for lst in all_lists
if lst.owner_id is None
or lst.owner_id == user_id
or user_id in (lst.allowed_users or [])
]
def get_list(self, list_id: str) -> Optional[ShoppingList]:
"""Get a specific list."""
@@ -171,17 +198,19 @@ class ShoppingListStorage:
return lst
return None
async def create_list(self, name: str, icon: str = "mdi:cart") -> ShoppingList:
"""Create a new list."""
async def create_list(self, name: str, icon: str = "mdi:cart", owner_id: str = None) -> ShoppingList:
"""Create a new list. Pass owner_id to make the list private to that user."""
new_list = ShoppingList(
id=generate_id(),
name=name,
icon=icon,
category_order=[cat.id for cat in self._categories]
category_order=[cat.id for cat in self._categories],
owner_id=owner_id,
)
self._lists[new_list.id] = new_list
self._items[new_list.id] = []
await self._save_lists()
await self._write_config_backup()
_LOGGER.info("Created new list: %s", name)
return new_list
@@ -202,6 +231,18 @@ class ShoppingListStorage:
_LOGGER.debug("Updated list: %s", list_id)
return lst
async def update_list_members(self, list_id: str, allowed_users: List[str]) -> Optional[ShoppingList]:
"""Update the allowed_users for a private list."""
if list_id not in self._lists:
return None
lst = self._lists[list_id]
lst.allowed_users = allowed_users
from .models import current_timestamp
lst.updated_at = current_timestamp()
await self._save_lists()
_LOGGER.debug("Updated members for list: %s", list_id)
return lst
async def delete_list(self, list_id: str) -> bool:
"""Delete a list."""
if list_id not in self._lists:
@@ -460,25 +501,167 @@ class ShoppingListStorage:
)
self._products[new_product.id] = new_product
await self._save_products()
await self._write_config_backup()
# Rebuild search engine so the new product is immediately searchable
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
self._search_engine = ProductSearch(products_dict)
_LOGGER.debug("Added product: %s", new_product.name)
return new_product
async def reload_catalog(self, country_code: str) -> int:
"""Replace catalog-sourced products with those from a new country's catalog.
Products with source='user' are preserved."""
catalog_ids = [
pid for pid, p in self._products.items()
if getattr(p, 'source', 'user') == 'catalog'
]
for pid in catalog_ids:
del self._products[pid]
self._country = country_code
catalog_products = await load_product_catalog(self._component_path, country_code)
count = 0
for prod_data in catalog_products:
try:
product = Product(
id=prod_data.get("id", generate_id()),
name=prod_data["name"],
category_id=prod_data.get("category_id", "other"),
aliases=prod_data.get("aliases", []),
default_unit=prod_data.get("default_unit", "units"),
default_quantity=prod_data.get("default_quantity", 1),
price=prod_data.get("price") or prod_data.get("typical_price"),
currency=self.hass.config.currency,
barcode=prod_data.get("barcode"),
brands=prod_data.get("brands", []),
image_url=prod_data.get("image_url", ""),
custom=False,
source="catalog",
tags=prod_data.get("tags", []),
collections=prod_data.get("collections", []),
taxonomy=prod_data.get("taxonomy", {}),
allergens=prod_data.get("allergens", []),
substitution_group=prod_data.get("substitution_group", ""),
priority_level=prod_data.get("priority_level", 0),
image_hint=prod_data.get("image_hint", "")
)
self._products[product.id] = product
count += 1
except Exception as err:
_LOGGER.error("Failed to import product %s: %s", prod_data.get("name"), err)
await self._save_products()
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
self._search_engine = ProductSearch(products_dict)
_LOGGER.info("Reloaded catalog for %s: %d products imported", country_code, count)
return count
async def update_product(self, product_id: str, **kwargs) -> Optional[Product]:
"""Update a product."""
if product_id not in self._products:
return None
product = self._products[product_id]
for key, value in kwargs.items():
if hasattr(product, key):
setattr(product, key, value)
await self._save_products()
await self._write_config_backup()
_LOGGER.debug("Updated product: %s", product_id)
return product
# ---------------------------------------------------------------------------
# Backup / Restore
# ---------------------------------------------------------------------------
async def export_user_data(self) -> dict:
"""Return a serialisable snapshot of all user-created data."""
user_products = [
p.to_dict() for p in self._products.values()
if getattr(p, "source", "user") == "user"
]
lists = [lst.to_dict() for lst in self._lists.values()]
items = {
list_id: [item.to_dict() for item in items_list]
for list_id, items_list in self._items.items()
}
return {
"slm_backup_version": "1.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"country": self._country,
"user_products": user_products,
"lists": lists,
"items": items,
}
async def import_user_data(self, data: dict) -> dict:
"""Merge a backup into live storage. Skips anything already present by ID."""
imported_products = 0
imported_lists = 0
imported_items = 0
for prod_data in data.get("user_products", []):
prod_id = prod_data.get("id")
if prod_id and prod_id not in self._products:
try:
self._products[prod_id] = Product(**prod_data)
imported_products += 1
except Exception as err:
_LOGGER.warning("Skipped product during import: %s", err)
if imported_products:
await self._save_products()
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
self._search_engine = ProductSearch(products_dict)
for list_data in data.get("lists", []):
list_id = list_data.get("id")
if list_id and list_id not in self._lists:
try:
lst = ShoppingList(**list_data)
lst.active = False
self._lists[list_id] = lst
imported_lists += 1
except Exception as err:
_LOGGER.warning("Skipped list during import: %s", err)
backup_items = data.get("items", {})
for list_id, items_list in backup_items.items():
if list_id in self._lists and list_id not in self._items:
try:
self._items[list_id] = [Item(**d) for d in items_list]
imported_items += len(self._items[list_id])
except Exception as err:
_LOGGER.warning("Skipped items for list %s: %s", list_id, err)
if imported_lists or imported_items:
await self._save_lists()
await self._save_items()
_LOGGER.info(
"Import complete: %d products, %d lists, %d items",
imported_products, imported_lists, imported_items,
)
return {"products": imported_products, "lists": imported_lists, "items": imported_items}
async def _write_config_backup(self) -> None:
"""Silently write a backup JSON to the HA config directory."""
try:
backup_path = os.path.join(
self.hass.config.config_dir,
"shopping_list_manager_backup.json",
)
data = await self.export_user_data()
def _write() -> None:
with open(backup_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
await self.hass.async_add_executor_job(_write)
_LOGGER.debug("Auto-backup written to %s", backup_path)
except Exception as err:
_LOGGER.warning("Failed to write config backup: %s", err)
# Categories methods
async def _save_categories(self) -> None:
@@ -489,3 +672,81 @@ class ShoppingListStorage:
def get_categories(self) -> List[Category]:
"""Get all categories."""
return self._categories
# Loyalty card methods
async def _save_loyalty_cards(self) -> None:
"""Save loyalty cards to storage."""
data = {card_id: card.to_dict() for card_id, card in self._loyalty_cards.items()}
await self._store_loyalty_cards.async_save(data)
def get_loyalty_cards(self, user_id: str = None, is_admin: bool = False) -> List[LoyaltyCard]:
"""Get loyalty cards visible to the specified user.
Global cards (owner_id=None) are visible to everyone.
Private cards are visible to their owner, anyone in allowed_users, and admins.
"""
all_cards = list(self._loyalty_cards.values())
if is_admin or user_id is None:
return all_cards
return [
card for card in all_cards
if card.owner_id is None
or card.owner_id == user_id
or user_id in (card.allowed_users or [])
]
def get_loyalty_card(self, card_id: str) -> Optional[LoyaltyCard]:
"""Get a specific loyalty card."""
return self._loyalty_cards.get(card_id)
async def create_loyalty_card(self, owner_id: str = None, **kwargs) -> LoyaltyCard:
"""Create a new loyalty card."""
from .models import current_timestamp
new_card = LoyaltyCard(
id=generate_id(),
owner_id=owner_id,
**kwargs
)
self._loyalty_cards[new_card.id] = new_card
await self._save_loyalty_cards()
_LOGGER.debug("Created loyalty card: %s", new_card.name)
return new_card
async def update_loyalty_card(self, card_id: str, **kwargs) -> Optional[LoyaltyCard]:
"""Update a loyalty card."""
if card_id not in self._loyalty_cards:
return None
card = self._loyalty_cards[card_id]
for key, value in kwargs.items():
if hasattr(card, key):
setattr(card, key, value)
from .models import current_timestamp
card.updated_at = current_timestamp()
await self._save_loyalty_cards()
_LOGGER.debug("Updated loyalty card: %s", card_id)
return card
async def delete_loyalty_card(self, card_id: str) -> bool:
"""Delete a loyalty card."""
if card_id not in self._loyalty_cards:
return False
del self._loyalty_cards[card_id]
await self._save_loyalty_cards()
_LOGGER.debug("Deleted loyalty card: %s", card_id)
return True
async def update_loyalty_card_members(self, card_id: str, allowed_users: List[str]) -> Optional[LoyaltyCard]:
"""Update the allowed_users for a private loyalty card."""
if card_id not in self._loyalty_cards:
return None
card = self._loyalty_cards[card_id]
card.allowed_users = allowed_users
from .models import current_timestamp
card.updated_at = current_timestamp()
await self._save_loyalty_cards()
_LOGGER.debug("Updated members for loyalty card: %s", card_id)
return card
@@ -14,6 +14,8 @@ from ..const import (
WS_TYPE_LISTS_UPDATE,
WS_TYPE_LISTS_DELETE,
WS_TYPE_LISTS_SET_ACTIVE,
WS_TYPE_LISTS_UPDATE_MEMBERS,
WS_TYPE_USERS_GET_ALL,
WS_TYPE_ITEMS_GET,
WS_TYPE_ITEMS_ADD,
WS_TYPE_ITEMS_UPDATE,
@@ -28,6 +30,11 @@ from ..const import (
WS_TYPE_PRODUCTS_ADD,
WS_TYPE_PRODUCTS_UPDATE,
WS_TYPE_CATEGORIES_GET_ALL,
WS_TYPE_LOYALTY_GET_ALL,
WS_TYPE_LOYALTY_ADD,
WS_TYPE_LOYALTY_UPDATE,
WS_TYPE_LOYALTY_DELETE,
WS_TYPE_LOYALTY_UPDATE_MEMBERS,
WS_TYPE_SUBSCRIBE,
EVENT_ITEM_ADDED,
EVENT_ITEM_UPDATED,
@@ -171,8 +178,11 @@ def websocket_get_lists(
) -> None:
"""Handle get all lists command."""
storage = get_storage(hass)
lists = storage.get_lists()
user = connection.user
user_id = user.id if user else None
is_admin = user.is_admin if user else False
lists = storage.get_lists(user_id=user_id, is_admin=is_admin)
connection.send_result(
msg["id"],
{
@@ -186,6 +196,7 @@ def websocket_get_lists(
vol.Required("type"): WS_TYPE_LISTS_CREATE,
vol.Required("name"): str,
vol.Optional("icon", default="mdi:cart"): str,
vol.Optional("private", default=True): bool,
}
)
@websocket_api.async_response
@@ -196,10 +207,15 @@ async def websocket_create_list(
) -> None:
"""Handle create list command."""
storage = get_storage(hass)
# Private lists are owned by the creating user; global lists have no owner.
is_private = msg.get("private", True)
owner_id = connection.user.id if is_private and connection.user else None
new_list = await storage.create_list(
name=msg["name"],
icon=msg.get("icon", "mdi:cart")
icon=msg.get("icon", "mdi:cart"),
owner_id=owner_id,
)
# Fire event
@@ -275,9 +291,21 @@ async def websocket_delete_list(
"""Handle delete list command."""
storage = get_storage(hass)
list_id = msg["list_id"]
lst = storage.get_list(list_id)
if lst is None:
connection.send_error(msg["id"], "not_found", "List not found")
return
# Only the owner or an admin may delete a private list
if lst.owner_id is not None:
user = connection.user
if not (user and (user.is_admin or user.id == lst.owner_id)):
connection.send_error(msg["id"], "forbidden", "Only the list owner can delete this list")
return
success = await storage.delete_list(list_id)
if not success:
connection.send_error(msg["id"], "not_found", "List not found")
return
@@ -869,3 +897,321 @@ def websocket_get_categories(
"categories": [cat.to_dict() for cat in categories]
}
)
# =============================================================================
# INTEGRATION SETTINGS HANDLERS
# =============================================================================
@websocket_api.websocket_command(
{
vol.Required("type"): "shopping_list_manager/get_integration_settings",
}
)
@callback
def websocket_get_integration_settings(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Return current country and available country options."""
country = hass.data[DOMAIN].get("country", "NZ")
connection.send_result(
msg["id"],
{
"country": country,
"available_countries": {
"NZ": "New Zealand",
"AU": "Australia",
"US": "United States",
"GB": "United Kingdom",
"CA": "Canada",
},
}
)
@websocket_api.websocket_command(
{
vol.Required("type"): "shopping_list_manager/set_country",
vol.Required("country"): str,
}
)
@websocket_api.async_response
async def websocket_set_country(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Switch to a different country catalog. Preserves user-added products."""
country = msg["country"].upper()
storage = get_storage(hass)
count = await storage.reload_catalog(country)
# Persist to HA config entry so country survives restart
entries = hass.config_entries.async_entries(DOMAIN)
if entries:
entry = entries[0]
hass.config_entries.async_update_entry(entry, options={**entry.options, "country": country})
hass.data[DOMAIN]["country"] = country
connection.send_result(
msg["id"],
{"success": True, "country": country, "products_loaded": count}
)
# =============================================================================
# BACKUP / RESTORE HANDLERS
# =============================================================================
@websocket_api.websocket_command(
{
vol.Required("type"): "shopping_list_manager/export_data",
}
)
@websocket_api.async_response
async def websocket_export_data(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Export all user-created data as a JSON-serialisable dict."""
storage = get_storage(hass)
data = await storage.export_user_data()
connection.send_result(msg["id"], data)
@websocket_api.websocket_command(
{
vol.Required("type"): "shopping_list_manager/import_data",
vol.Required("data"): dict,
}
)
@websocket_api.async_response
async def websocket_import_data(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Import user data from a backup payload."""
storage = get_storage(hass)
counts = await storage.import_user_data(msg["data"])
connection.send_result(msg["id"], {"success": True, "imported": counts})
# =============================================================================
# LIST MEMBERS HANDLER
# =============================================================================
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_LISTS_UPDATE_MEMBERS,
vol.Required("list_id"): str,
vol.Required("allowed_users"): [str],
}
)
@websocket_api.async_response
async def websocket_update_list_members(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Update the allowed_users for a private list."""
storage = get_storage(hass)
list_id = msg["list_id"]
lst = storage.get_list(list_id)
if lst is None:
connection.send_error(msg["id"], "not_found", "List not found")
return
# Only the owner or an admin may manage members
user = connection.user
if lst.owner_id is not None and not (user and (user.is_admin or user.id == lst.owner_id)):
connection.send_error(msg["id"], "forbidden", "Only the list owner can manage members")
return
updated = await storage.update_list_members(list_id, msg["allowed_users"])
hass.bus.async_fire(
EVENT_LIST_UPDATED,
{"list_id": list_id, "action": "members_updated"}
)
connection.send_result(msg["id"], {"list": updated.to_dict()})
# =============================================================================
# HA USERS HANDLER
# =============================================================================
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_USERS_GET_ALL,
}
)
@websocket_api.async_response
async def websocket_get_ha_users(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Return all active, non-system HA users."""
users = await hass.auth.async_get_users()
result = [
{"id": u.id, "name": u.name}
for u in users
if not u.system_generated and u.is_active
]
connection.send_result(msg["id"], {"users": result})
# =============================================================================
# LOYALTY CARD HANDLERS
# =============================================================================
@websocket_api.websocket_command({
vol.Required("type"): WS_TYPE_LOYALTY_GET_ALL,
})
@websocket_api.async_response
async def websocket_get_loyalty_cards(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Return all loyalty cards visible to the current user."""
storage = get_storage(hass)
user = connection.user
user_id = user.id if user else None
is_admin = user.is_admin if user else False
cards = storage.get_loyalty_cards(user_id=user_id, is_admin=is_admin)
connection.send_result(msg["id"], {"cards": [c.to_dict() for c in cards]})
@websocket_api.websocket_command({
vol.Required("type"): WS_TYPE_LOYALTY_ADD,
vol.Required("name"): str,
vol.Required("number"): str,
vol.Optional("barcode", default=""): str,
vol.Optional("barcode_type", default="barcode"): str,
vol.Optional("logo", default=""): str,
vol.Optional("notes", default=""): str,
vol.Optional("color", default="#9fa8da"): str,
vol.Optional("private", default=True): bool,
})
@websocket_api.async_response
async def websocket_add_loyalty_card(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Add a new loyalty card."""
storage = get_storage(hass)
user = connection.user
owner_id = user.id if (user and msg.get("private")) else None
card = await storage.create_loyalty_card(
owner_id=owner_id,
name=msg["name"],
number=msg["number"],
barcode=msg.get("barcode", ""),
barcode_type=msg.get("barcode_type", "barcode"),
logo=msg.get("logo", ""),
notes=msg.get("notes", ""),
color=msg.get("color", "#9fa8da"),
)
connection.send_result(msg["id"], {"card": card.to_dict()})
@websocket_api.websocket_command({
vol.Required("type"): WS_TYPE_LOYALTY_UPDATE,
vol.Required("card_id"): str,
vol.Optional("name"): str,
vol.Optional("number"): str,
vol.Optional("barcode"): str,
vol.Optional("barcode_type"): str,
vol.Optional("logo"): str,
vol.Optional("notes"): str,
vol.Optional("color"): str,
})
@websocket_api.async_response
async def websocket_update_loyalty_card(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Update an existing loyalty card."""
storage = get_storage(hass)
card_id = msg["card_id"]
card = storage.get_loyalty_card(card_id)
if card is None:
connection.send_error(msg["id"], "not_found", "Loyalty card not found")
return
user = connection.user
if card.owner_id is not None and not (user and (user.is_admin or user.id == card.owner_id)):
connection.send_error(msg["id"], "forbidden", "Only the card owner can update it")
return
fields = {k: v for k, v in msg.items() if k not in ("type", "id", "card_id")}
updated = await storage.update_loyalty_card(card_id, **fields)
connection.send_result(msg["id"], {"card": updated.to_dict()})
@websocket_api.websocket_command({
vol.Required("type"): WS_TYPE_LOYALTY_DELETE,
vol.Required("card_id"): str,
})
@websocket_api.async_response
async def websocket_delete_loyalty_card(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Delete a loyalty card."""
storage = get_storage(hass)
card_id = msg["card_id"]
card = storage.get_loyalty_card(card_id)
if card is None:
connection.send_error(msg["id"], "not_found", "Loyalty card not found")
return
user = connection.user
if card.owner_id is not None and not (user and (user.is_admin or user.id == card.owner_id)):
connection.send_error(msg["id"], "forbidden", "Only the card owner can delete it")
return
await storage.delete_loyalty_card(card_id)
connection.send_result(msg["id"], {"success": True})
@websocket_api.websocket_command({
vol.Required("type"): WS_TYPE_LOYALTY_UPDATE_MEMBERS,
vol.Required("card_id"): str,
vol.Required("allowed_users"): [str],
})
@websocket_api.async_response
async def websocket_update_loyalty_card_members(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Update the allowed_users for a private loyalty card."""
storage = get_storage(hass)
card_id = msg["card_id"]
card = storage.get_loyalty_card(card_id)
if card is None:
connection.send_error(msg["id"], "not_found", "Loyalty card not found")
return
user = connection.user
if card.owner_id is not None and not (user and (user.is_admin or user.id == card.owner_id)):
connection.send_error(msg["id"], "forbidden", "Only the card owner can manage members")
return
updated = await storage.update_loyalty_card_members(card_id, msg["allowed_users"])
connection.send_result(msg["id"], {"card": updated.to_dict()})