Merge pull request #2 from thekiwismarthome/v2.0.0-phase2-backend-foundation

V2.0.0 phase2 backend foundation
This commit is contained in:
thekiwismarthome
2026-02-14 16:16:23 +13:00
committed by GitHub
21 changed files with 30062 additions and 866 deletions
@@ -1,56 +1,195 @@
"""
Shopping List Manager - Home Assistant Custom Integration
Clean-slate architecture with enforced invariants
"""
"""Shopping List Manager integration for Home Assistant."""
import logging
import os
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.components import websocket_api as ha_websocket
from .websocket_api import websocket_create_list
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .manager import ShoppingListManager
# Import websocket handler functions directly
from .websocket_api import (
websocket_add_product,
websocket_set_qty,
websocket_get_products,
websocket_get_active,
websocket_delete_product,
ws_get_catalogues,
ws_get_lists,
)
from .storage import ShoppingListStorage
from .utils.images import ImageHandler
_LOGGER = logging.getLogger(__name__)
# Track storage instance globally
DATA_STORAGE = f"{DOMAIN}_storage"
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Shopping List Manager component from yaml (not used)."""
# This integration doesn't support YAML configuration
# All setup is done via config entries (UI configuration)
return True
# In async_setup_entry function, after storage initialization:
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Shopping List Manager from a config entry."""
# Initialize the manager
manager = ShoppingListManager(hass)
await manager.async_load()
_LOGGER.info("Setting up Shopping List Manager")
# Store manager in hass.data
# Get component path for loading data files
component_path = os.path.dirname(__file__)
config_path = hass.config.path()
# Get country from options (or fall back to data, or default to NZ)
country = entry.options.get("country") or entry.data.get("country", "NZ")
_LOGGER.info("Using country: %s", country)
# Initialize storage with country
storage = ShoppingListStorage(hass, component_path, country)
await storage.async_load()
# Initialize image handler
image_handler = ImageHandler(hass, config_path)
# Store instances in hass.data
hass.data.setdefault(DOMAIN, {})
hass.data[DOMAIN]["manager"] = manager
hass.data[DOMAIN][DATA_STORAGE] = storage
hass.data[DOMAIN]["image_handler"] = image_handler
hass.data[DOMAIN]["country"] = country
# Register WebSocket commands using Home Assistant's websocket_api
ha_websocket.async_register_command(hass, websocket_create_list)
ha_websocket.async_register_command(hass, websocket_add_product)
ha_websocket.async_register_command(hass, websocket_set_qty)
ha_websocket.async_register_command(hass, websocket_get_products)
ha_websocket.async_register_command(hass, websocket_get_active)
ha_websocket.async_register_command(hass, websocket_delete_product)
ha_websocket.async_register_command(hass, ws_get_catalogues)
ha_websocket.async_register_command(hass, ws_get_lists)
# Register update listener for options changes
entry.async_on_unload(entry.add_update_listener(update_listener))
_LOGGER.info("Shopping List Manager setup complete - registered 7 WebSocket commands")
# Register WebSocket commands
await _async_register_websocket_handlers(hass, storage)
# Register frontend resources
await _async_register_frontend(hass)
_LOGGER.info("Shopping List Manager setup complete")
return True
async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle options update."""
# Reload the integration when options change
await hass.config_entries.async_reload(entry.entry_id)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Shopping List Manager."""
hass.data[DOMAIN].pop("manager", None)
"""Unload a config entry."""
_LOGGER.info("Unloading Shopping List Manager")
# Clean up hass.data
if DOMAIN in hass.data:
hass.data[DOMAIN].pop(DATA_STORAGE, None)
return True
async def _async_register_websocket_handlers(
hass: HomeAssistant,
storage: ShoppingListStorage
) -> None:
"""Register WebSocket API handlers."""
from homeassistant.components import websocket_api
from .websocket import handlers
# Lists handlers
websocket_api.async_register_command(
hass,
handlers.websocket_get_lists,
)
websocket_api.async_register_command(
hass,
handlers.websocket_create_list,
)
websocket_api.async_register_command(
hass,
handlers.websocket_update_list,
)
websocket_api.async_register_command(
hass,
handlers.websocket_delete_list,
)
websocket_api.async_register_command(
hass,
handlers.websocket_set_active_list,
)
# Items handlers
websocket_api.async_register_command(
hass,
handlers.websocket_get_items,
)
websocket_api.async_register_command(
hass,
handlers.websocket_add_item,
)
websocket_api.async_register_command(
hass,
handlers.websocket_update_item,
)
websocket_api.async_register_command(
hass,
handlers.websocket_check_item,
)
websocket_api.async_register_command(
hass,
handlers.websocket_delete_item,
)
websocket_api.async_register_command(
hass,
handlers.websocket_reorder_items,
)
websocket_api.async_register_command(
hass,
handlers.websocket_bulk_check_items,
)
websocket_api.async_register_command(
hass,
handlers.websocket_clear_checked_items,
)
websocket_api.async_register_command(
hass,
handlers.websocket_get_list_total,
)
# Products handlers
websocket_api.async_register_command(
hass,
handlers.websocket_search_products,
)
websocket_api.async_register_command(
hass,
handlers.websocket_get_product_suggestions,
)
websocket_api.async_register_command(
hass,
handlers.websocket_add_product,
)
websocket_api.async_register_command(
hass,
handlers.websocket_update_product,
)
websocket_api.async_register_command(
hass,
handlers.websocket_get_product_substitutes,
)
# Categories handlers
websocket_api.async_register_command(
hass,
handlers.websocket_get_categories,
)
_LOGGER.debug("WebSocket handlers registered")
async def _async_register_frontend(hass: HomeAssistant) -> None:
"""Register frontend resources."""
# Since frontend is a separate HACS module, we don't need to register it here
# The frontend card registers itself independently
_LOGGER.debug("Frontend resources skipped (separate HACS module)")
_LOGGER.debug("Frontend resources registered")
def get_storage(hass: HomeAssistant) -> ShoppingListStorage:
"""Get the storage instance from hass.data.
Helper function for WebSocket handlers to access storage.
"""
return hass.data[DOMAIN][DATA_STORAGE]
@@ -1,5 +1,7 @@
"""Config flow for Shopping List Manager."""
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.core import callback
from .const import DOMAIN
@@ -16,10 +18,76 @@ class ShoppingListManagerConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return self.async_abort(reason="single_instance_allowed")
if user_input is not None:
# Create entry with default country
return self.async_create_entry(
title="Shopping List Manager",
data={}
data={"country": "NZ"},
options={
"country": "NZ",
"enable_price_tracking": True,
"enable_image_search": True,
"metric_units_only": True,
}
)
# Show simple form
return self.async_show_form(step_id="user")
# Show simple setup form
return self.async_show_form(
step_id="user",
description_placeholders={
"info": "Country and other settings can be configured after setup via the Configure button."
}
)
@staticmethod
@callback
def async_get_options_flow(config_entry):
"""Get the options flow for this handler."""
return OptionsFlowHandler(config_entry)
class OptionsFlowHandler(config_entries.OptionsFlow):
"""Handle options flow for Shopping List Manager."""
def __init__(self, config_entry):
"""Initialize options flow."""
self.config_entry = config_entry
async def async_step_init(self, user_input=None):
"""Manage the options."""
if user_input is not None:
# Update options
return self.async_create_entry(title="", data=user_input)
# Get current settings
current_country = self.config_entry.options.get(
"country",
self.config_entry.data.get("country", "NZ")
)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema({
vol.Required("country", default=current_country): vol.In({
"NZ": "New Zealand",
"AU": "Australia",
"US": "United States",
"GB": "United Kingdom",
"CA": "Canada",
}),
vol.Optional(
"enable_price_tracking",
default=self.config_entry.options.get("enable_price_tracking", True)
): bool,
vol.Optional(
"enable_image_search",
default=self.config_entry.options.get("enable_image_search", True)
): bool,
vol.Optional(
"metric_units_only",
default=self.config_entry.options.get("metric_units_only", True)
): bool,
}),
description_placeholders={
"info": "Changing country will reload the product catalog on next restart."
}
)
@@ -1,12 +1,119 @@
"""Constants for Shopping List Manager."""
# Domain
DOMAIN = "shopping_list_manager"
# Storage keys
STORAGE_VERSION = 1
# Storage Keys
STORAGE_VERSION = 2
STORAGE_KEY_LISTS = f"{DOMAIN}.lists"
STORAGE_KEY_ITEMS = f"{DOMAIN}.items"
STORAGE_KEY_PRODUCTS = f"{DOMAIN}.products"
STORAGE_KEY_ACTIVE = f"{DOMAIN}.active_list"
LISTS_STORE_KEY = "shopping_list_manager.lists"
STORAGE_KEY_CATEGORIES = f"{DOMAIN}.categories"
# WebSocket Commands - Lists
WS_TYPE_LISTS_GET_ALL = f"{DOMAIN}/lists/get_all"
WS_TYPE_LISTS_CREATE = f"{DOMAIN}/lists/create"
WS_TYPE_LISTS_UPDATE = f"{DOMAIN}/lists/update"
WS_TYPE_LISTS_DELETE = f"{DOMAIN}/lists/delete"
WS_TYPE_LISTS_SET_ACTIVE = f"{DOMAIN}/lists/set_active"
# WebSocket Commands - Items
WS_TYPE_ITEMS_GET = f"{DOMAIN}/items/get"
WS_TYPE_ITEMS_ADD = f"{DOMAIN}/items/add"
WS_TYPE_ITEMS_UPDATE = f"{DOMAIN}/items/update"
WS_TYPE_ITEMS_CHECK = f"{DOMAIN}/items/check"
WS_TYPE_ITEMS_DELETE = f"{DOMAIN}/items/delete"
WS_TYPE_ITEMS_REORDER = f"{DOMAIN}/items/reorder"
WS_TYPE_ITEMS_BULK_CHECK = f"{DOMAIN}/items/bulk_check"
WS_TYPE_ITEMS_CLEAR_CHECKED = f"{DOMAIN}/items/clear_checked"
WS_TYPE_ITEMS_GET_TOTAL = f"{DOMAIN}/items/get_total"
# WebSocket Commands - Products
WS_TYPE_PRODUCTS_SEARCH = f"{DOMAIN}/products/search"
WS_TYPE_PRODUCTS_SUGGESTIONS = f"{DOMAIN}/products/suggestions"
WS_TYPE_PRODUCTS_ADD = f"{DOMAIN}/products/add"
WS_TYPE_PRODUCTS_UPDATE = f"{DOMAIN}/products/update"
WS_TYPE_PRODUCTS_DELETE = f"{DOMAIN}/products/delete"
# WebSocket Commands - Categories
WS_TYPE_CATEGORIES_GET_ALL = f"{DOMAIN}/categories/get_all"
WS_TYPE_CATEGORIES_REORDER = f"{DOMAIN}/categories/reorder"
# WebSocket Commands - Subscriptions
WS_TYPE_SUBSCRIBE = f"{DOMAIN}/subscribe"
WS_TYPE_UNSUBSCRIBE = f"{DOMAIN}/unsubscribe"
# WebSocket Commands - Barcode (Phase 5)
WS_TYPE_BARCODE_SCAN = f"{DOMAIN}/barcode/scan"
WS_TYPE_BARCODE_ADD = f"{DOMAIN}/barcode/add_to_list"
# WebSocket Commands - OpenFoodFacts (Phase 5)
WS_TYPE_OFF_FETCH = f"{DOMAIN}/openfoodfacts/fetch"
WS_TYPE_OFF_IMPORT = f"{DOMAIN}/openfoodfacts/import"
# Events
EVENT_SHOPPING_LIST_UPDATED = f"{DOMAIN}_updated"
EVENT_ITEM_ADDED = f"{DOMAIN}_item_added"
EVENT_ITEM_UPDATED = f"{DOMAIN}_item_updated"
EVENT_ITEM_CHECKED = f"{DOMAIN}_item_checked"
EVENT_ITEM_DELETED = f"{DOMAIN}_item_deleted"
EVENT_LIST_UPDATED = f"{DOMAIN}_list_updated"
EVENT_LIST_DELETED = f"{DOMAIN}_list_deleted"
# Image Configuration
IMAGE_FORMAT = "webp"
IMAGE_SIZE = 200 # 200x200px
IMAGE_QUALITY = 85
IMAGE_MAX_SIZE_KB = 15
IMAGES_LOCAL_DIR = "www/shopping_list_manager/images"
# Placeholder image (inline SVG)
PLACEHOLDER_IMAGE = "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='200' height='200'%3E%3Crect width='200' height='200' fill='%23f0f0f0'/%3E%3Ctext x='50%25' y='50%25' dominant-baseline='middle' text-anchor='middle' font-family='Arial' font-size='16' fill='%23999'%3ENo Image%3C/text%3E%3C/svg%3E"
# Metric Units (always metric, regardless of country)
METRIC_UNITS = {
"weight": ["kg", "g"],
"volume": ["L", "mL"],
"count": ["units", "pack", "loaf", "dozen", "ea", "pkt", "tray", "bottle", "can", "bunch", "pottle", "roll", "sachet", "tub", "bar"]
}
# Default quantities for common products (NZ-focused, can be country-specific later)
DEFAULT_QUANTITIES = {
"milk": {"quantity": 2, "unit": "L"},
"bread": {"quantity": 1, "unit": "loaf"},
"butter": {"quantity": 500, "unit": "g"},
"eggs": {"quantity": 12, "unit": "ea"},
"cheese": {"quantity": 500, "unit": "g"},
"yogurt": {"quantity": 1, "unit": "kg"},
"flour": {"quantity": 1.5, "unit": "kg"},
"sugar": {"quantity": 1.5, "unit": "kg"},
"rice": {"quantity": 1, "unit": "kg"},
"pasta": {"quantity": 500, "unit": "g"},
"chicken breast": {"quantity": 1, "unit": "kg"},
"beef mince": {"quantity": 500, "unit": "g"},
"sausages": {"quantity": 500, "unit": "g"},
"bacon": {"quantity": 500, "unit": "g"},
"apples": {"quantity": 1, "unit": "kg"},
"bananas": {"quantity": 1, "unit": "kg"},
"potatoes": {"quantity": 2, "unit": "kg"},
"onions": {"quantity": 1, "unit": "kg"},
"carrots": {"quantity": 1, "unit": "kg"},
"tomatoes": {"quantity": 500, "unit": "g"},
"lettuce": {"quantity": 1, "unit": "ea"},
"capsicum": {"quantity": 1, "unit": "ea"},
"broccoli": {"quantity": 1, "unit": "ea"},
"cereal": {"quantity": 1, "unit": "pack"},
"baked beans": {"quantity": 1, "unit": "can"},
"tuna": {"quantity": 1, "unit": "can"},
"olive oil": {"quantity": 1, "unit": "L"},
"coffee": {"quantity": 200, "unit": "g"},
"tea bags": {"quantity": 100, "unit": "ea"},
"toilet paper": {"quantity": 12, "unit": "roll"},
"paper towels": {"quantity": 2, "unit": "roll"},
"dishwashing liquid": {"quantity": 500, "unit": "mL"},
"laundry powder": {"quantity": 2, "unit": "kg"}
}
# Paths
CATEGORIES_FILE = "categories.json"
PRODUCTS_CATALOG_FILE = "products_catalog.json"
IMAGES_PATH = "images/products"
@@ -0,0 +1,62 @@
"""Product catalog loader for Shopping List Manager."""
import json
import logging
from typing import List, Dict, Any
import aiofiles
_LOGGER = logging.getLogger(__name__)
async def load_product_catalog(component_path: str, country_code: str = "NZ") -> List[Dict[str, Any]]:
"""Load product catalog from JSON file asynchronously.
Args:
component_path: Path to the component directory
country_code: Country code (e.g., 'NZ', 'AU', 'US')
Returns:
List of product dictionaries
"""
import os
# Try country-specific catalog first
if country_code:
catalog_file = os.path.join(
component_path, "data", f"products_catalog_{country_code.lower()}.json"
)
if not os.path.exists(catalog_file):
_LOGGER.warning(
"No country-specific catalog found for %s at %s",
country_code,
catalog_file
)
return []
else:
return []
try:
# Use aiofiles for async file reading
async with aiofiles.open(catalog_file, "r", encoding="utf-8") as f:
content = await f.read()
data = json.loads(content)
_LOGGER.info(
"Loaded product catalog version %s for region %s",
data.get("version", "unknown"),
data.get("region", "default")
)
products = data.get("products", [])
_LOGGER.info("Loaded %d products from catalog", len(products))
return products
except FileNotFoundError:
_LOGGER.error("Product catalog file not found: %s", catalog_file)
return []
except json.JSONDecodeError as err:
_LOGGER.error("Failed to parse product catalog file: %s", err)
return []
except Exception as err:
_LOGGER.error("Unexpected error loading product catalog: %s", err)
return []
@@ -0,0 +1,110 @@
{
"version": "1.0.0",
"region": "NZ",
"categories": [
{
"id": "produce",
"name": "Fruit & Veg",
"icon": "mdi:fruit-cherries",
"color": "#4CAF50",
"sort_order": 1,
"system": true
},
{
"id": "dairy",
"name": "Dairy & Eggs",
"icon": "mdi:cheese",
"color": "#FFC107",
"sort_order": 2,
"system": true
},
{
"id": "meat",
"name": "Meat & Seafood",
"icon": "mdi:food-steak",
"color": "#F44336",
"sort_order": 3,
"system": true
},
{
"id": "bakery",
"name": "Bakery",
"icon": "mdi:bread-slice",
"color": "#FF9800",
"sort_order": 4,
"system": true
},
{
"id": "frozen",
"name": "Frozen Foods",
"icon": "mdi:snowflake",
"color": "#2196F3",
"sort_order": 5,
"system": true
},
{
"id": "pantry",
"name": "Pantry",
"icon": "mdi:package-variant",
"color": "#795548",
"sort_order": 6,
"system": true
},
{
"id": "beverages",
"name": "Drinks",
"icon": "mdi:cup",
"color": "#00BCD4",
"sort_order": 7,
"system": true
},
{
"id": "snacks",
"name": "Snacks & Biscuits",
"icon": "mdi:food-apple",
"color": "#E91E63",
"sort_order": 8,
"system": true
},
{
"id": "household",
"name": "Household",
"icon": "mdi:spray-bottle",
"color": "#9C27B0",
"sort_order": 9,
"system": true
},
{
"id": "health",
"name": "Health & Beauty",
"icon": "mdi:heart-pulse",
"color": "#E91E63",
"sort_order": 10,
"system": true
},
{
"id": "pet",
"name": "Pet Supplies",
"icon": "mdi:paw",
"color": "#FF5722",
"sort_order": 11,
"system": true
},
{
"id": "baby",
"name": "Baby",
"icon": "mdi:baby-face",
"color": "#FFEB3B",
"sort_order": 12,
"system": true
},
{
"id": "other",
"name": "Other",
"icon": "mdi:dots-horizontal",
"color": "#9E9E9E",
"sort_order": 99,
"system": true
}
]
}
@@ -0,0 +1,98 @@
"""Category loader utility."""
import json
import logging
import os
from typing import List, Dict, Any
import aiofiles
_LOGGER = logging.getLogger(__name__)
async def load_categories(component_path: str, country_code: str = None) -> List[Dict[str, Any]]:
"""Load categories from JSON file asynchronously.
Args:
component_path: Path to the component directory
country_code: Country code from HA config (e.g., 'NZ', 'AU', 'US')
If None, loads default categories.json
Returns:
List of category dictionaries
"""
import os
# Try country-specific file first if country_code provided
if country_code:
country_file = os.path.join(
component_path, "data", f"categories_{country_code.lower()}.json"
)
if os.path.exists(country_file):
categories_file = country_file
_LOGGER.debug("Using country-specific categories: %s", country_code)
else:
_LOGGER.debug(
"No country-specific categories found for %s, using default",
country_code
)
categories_file = os.path.join(component_path, "data", "categories.json")
else:
categories_file = os.path.join(component_path, "data", "categories.json")
try:
async with aiofiles.open(categories_file, "r", encoding="utf-8") as f:
content = await f.read()
data = json.loads(content)
_LOGGER.info(
"Loaded categories version %s for region %s",
data.get("version", "unknown"),
data.get("region", "default")
)
return data.get("categories", [])
except FileNotFoundError:
_LOGGER.error("Categories file not found: %s", categories_file)
return _get_fallback_categories()
except json.JSONDecodeError as err:
_LOGGER.error("Failed to parse categories file: %s", err)
return _get_fallback_categories()
except Exception as err:
_LOGGER.error("Unexpected error loading categories: %s", err)
return _get_fallback_categories()
def _get_fallback_categories() -> List[Dict[str, Any]]:
"""Get minimal fallback categories if file loading fails.
Returns:
List of basic category dictionaries
"""
_LOGGER.warning("Using fallback categories")
return [
{
"id": "produce",
"name": "Produce",
"icon": "mdi:fruit-cherries",
"color": "#4CAF50",
"sort_order": 1,
"system": True
},
{
"id": "dairy",
"name": "Dairy",
"icon": "mdi:cheese",
"color": "#FFC107",
"sort_order": 2,
"system": True
},
{
"id": "other",
"name": "Other",
"icon": "mdi:dots-horizontal",
"color": "#9E9E9E",
"sort_order": 99,
"system": True
}
]
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,406 +0,0 @@
"""Core Shopping List Manager with invariant enforcement."""
import asyncio
import logging
from typing import Dict, Optional, List
from homeassistant.core import HomeAssistant
from homeassistant.helpers import storage
from homeassistant.helpers.storage import Store
from .const import (
DOMAIN,
EVENT_SHOPPING_LIST_UPDATED,
STORAGE_KEY_ACTIVE,
STORAGE_KEY_PRODUCTS,
STORAGE_VERSION,
)
from .models import Product, ActiveItem, InvariantError, validate_invariant
_LOGGER = logging.getLogger(__name__)
class ShoppingListManager:
"""
Manages multiple independent shopping lists.
Each list_id gets its own pair of storage files:
- shopping_list_manager.{list_id}.products
- shopping_list_manager.{list_id}.active_list
The default "groceries" list uses the original flat keys for backward compat:
- shopping_list_manager.products groceries products
- shopping_list_manager.active_list groceries active
Architecture principles:
1. Products and active_list are separate concerns per list
2. Products are authoritative, persistent data
3. Active list is ephemeral state
4. Invariant (active products) enforced on every mutation
5. Lock ensures atomic operations per list
"""
def __init__(self, hass: HomeAssistant):
"""Initialize the manager."""
self.hass = hass
# Per-list in-memory caches: list_id -> {key: Product}
self._products: Dict[str, Dict[str, Product]] = {}
self._active_list: Dict[str, Dict[str, ActiveItem]] = {}
# Per-list locks
self._locks: Dict[str, asyncio.Lock] = {}
# Per-list storage Store instances (created lazily, except groceries)
self._store_products: Dict[str, storage.Store] = {}
self._store_active: Dict[str, storage.Store] = {}
# Pre-create stores for the default "groceries" list using the original flat keys
# for backward compatibility — existing data just works
self._store_products["groceries"] = storage.Store(
hass, STORAGE_VERSION, STORAGE_KEY_PRODUCTS # "shopping_list_manager.products"
)
self._store_active["groceries"] = storage.Store(
hass, STORAGE_VERSION, STORAGE_KEY_ACTIVE # "shopping_list_manager.active_list"
)
# --- Catalogue + list metadata (NEW, additive) ---
self._store_catalogues = Store(
hass,
STORAGE_VERSION,
f"{DOMAIN}.catalogues",
)
self._catalogues: Dict[str, dict] = {}
self._store_lists = Store(
hass,
STORAGE_VERSION,
f"{DOMAIN}.lists",
)
self._lists: Dict[str, dict] = {}
def _lock_for(self, list_id: str) -> asyncio.Lock:
"""Get or create lock for a list."""
if list_id not in self._locks:
self._locks[list_id] = asyncio.Lock()
return self._locks[list_id]
def _store_products_for(self, list_id: str) -> storage.Store:
"""Get or create products Store for a list."""
if list_id not in self._store_products:
catalogue_id = self._lists.get(list_id, {}).get("catalogue", list_id)
catalogue = self._catalogues.get(catalogue_id)
key = (
catalogue["products_store"]
if catalogue
else f"{DOMAIN}.{list_id}.products"
)
self._store_products[list_id] = storage.Store(
self.hass, STORAGE_VERSION, key
)
return self._store_products[list_id]
def _store_active_for(self, list_id: str) -> storage.Store:
"""Get or create active Store for a list."""
if list_id not in self._store_active:
key = f"{DOMAIN}.{list_id}.active_list"
self._store_active[list_id] = storage.Store(self.hass, STORAGE_VERSION, key)
return self._store_active[list_id]
async def _ensure_loaded(self, list_id: str) -> None:
"""Lazily load a list from storage if not yet in memory."""
await self._ensure_catalogues_loaded()
await self._ensure_lists_loaded()
# Register list if it does not exist yet
if list_id not in self._lists:
self._lists[list_id] = {
"catalogue": list_id
}
await self._store_lists.async_save(self._lists)
if list_id in self._products:
return # already loaded
products_data = await self._store_products_for(list_id).async_load()
self._products[list_id] = {
key: Product.from_dict(data) for key, data in (products_data or {}).items()
}
active_data = await self._store_active_for(list_id).async_load()
self._active_list[list_id] = {
key: ActiveItem.from_dict(data) for key, data in (active_data or {}).items()
}
# Repair any orphaned active items
await self._async_repair_invariant(list_id)
_LOGGER.info(
"Loaded list '%s': %d products, %d active",
list_id, len(self._products[list_id]), len(self._active_list[list_id])
)
async def _ensure_catalogues_loaded(self) -> None:
data = await self._store_catalogues.async_load()
if isinstance(data, dict):
self._catalogues = data
return
# Bootstrap from existing behavior (no changes)
self._catalogues = {
"groceries": {
"name": "Groceries",
"icon": "🛒",
"products_store": f"{DOMAIN}.products",
}
}
await self._store_catalogues.async_save(self._catalogues)
async def _ensure_lists_loaded(self) -> None:
data = await self._store_lists.async_load()
if isinstance(data, dict):
self._lists = data
return
# Default: list_id == catalogue_id (current behavior)
self._lists = {
"groceries": {
"catalogue": "groceries",
}
}
await self._store_lists.async_save(self._lists)
async def async_load(self) -> None:
"""Pre-load the default groceries list for backward compat."""
async with self._lock_for("groceries"):
await self._ensure_loaded("groceries")
async def _async_repair_invariant(self, list_id: str) -> None:
"""Remove active items whose product no longer exists."""
orphaned = [k for k in self._active_list[list_id] if k not in self._products[list_id]]
if orphaned:
_LOGGER.warning(
"List '%s': removing %d orphaned active items: %s",
list_id, len(orphaned), orphaned
)
for k in orphaned:
del self._active_list[list_id][k]
await self._async_save_active(list_id)
async def _async_save_products(self, list_id: str) -> None:
"""Persist products to storage."""
data = {key: p.to_dict() for key, p in self._products[list_id].items()}
await self._store_products_for(list_id).async_save(data)
async def _async_save_active(self, list_id: str) -> None:
"""Persist active list to storage."""
data = {key: a.to_dict() for key, a in self._active_list[list_id].items()}
await self._store_active_for(list_id).async_save(data)
def _fire_update_event(self) -> None:
"""Fire event to notify listeners of changes."""
self.hass.bus.async_fire(EVENT_SHOPPING_LIST_UPDATED)
# ========================================================================
# PUBLIC API - All operations enforce invariants
# ========================================================================
import time
async def async_create_list(
self,
list_id: str,
catalogue: str,
owner: str,
visibility: str = "shared",
):
await self._ensure_catalogues_loaded()
await self._ensure_lists_loaded()
if list_id in self._lists:
raise ValueError(f"List '{list_id}' already exists")
if catalogue not in self._catalogues:
raise ValueError(f"Catalogue '{catalogue}' does not exist")
self._lists[list_id] = {
"catalogue": catalogue,
"owner": owner,
"visibility": visibility,
"created_at": time.time(),
"updated_at": time.time(),
}
await self._store_lists.async_save(self._lists)
async def async_add_product(
self,
list_id: str,
key: str,
name: str,
category: str = "other",
unit: str = "pcs",
image: str = ""
) -> Product:
"""
Add or update a product in a list's catalog.
This operation:
- Creates/updates product metadata
- Does NOT modify quantities
- Is idempotent
- Persists to storage
Args:
list_id: List identifier
key: Unique product identifier
name: Display name
category: Product category
unit: Unit of measurement
image: Image URL
Returns:
The created/updated Product
"""
async with self._lock_for(list_id):
await self._ensure_loaded(list_id)
product = Product(
key=key,
name=name,
category=category,
unit=unit,
image=image
)
self._products[list_id][key] = product
await self._async_save_products(list_id)
_LOGGER.debug("List '%s': added/updated product %s (%s)", list_id, name, key)
self._fire_update_event()
return product
async def async_set_qty(self, list_id: str, key: str, qty: int) -> None:
"""
Set quantity for a product on the shopping list.
This operation:
- REQUIRES product to exist (enforces invariant)
- qty > 0: adds/updates active_list
- qty == 0: removes from active_list
- Persists state
- Fires update event
Args:
list_id: List identifier
key: Product key (must exist in catalog)
qty: New quantity (0 to remove, >0 to add/update)
Raises:
InvariantError: If product doesn't exist
ValueError: If qty is negative
"""
if qty < 0:
raise ValueError(f"Quantity cannot be negative: {qty}")
async with self._lock_for(list_id):
await self._ensure_loaded(list_id)
# INVARIANT ENFORCEMENT: Product must exist
if key not in self._products[list_id]:
raise InvariantError(
f"Cannot set quantity for unknown product '{key}' in list '{list_id}'. "
f"Product must be created first with add_product."
)
# Update or remove from active list
if qty > 0:
self._active_list[list_id][key] = ActiveItem(qty=qty)
_LOGGER.debug("List '%s': set qty for %s: %d", list_id, key, qty)
else:
# qty == 0: remove from list
if key in self._active_list[list_id]:
del self._active_list[list_id][key]
_LOGGER.debug("List '%s': removed %s from active list", list_id, key)
await self._async_save_active(list_id)
self._fire_update_event()
async def async_delete_product(self, list_id: str, key: str) -> None:
"""
Delete a product from the catalog.
This operation:
- Removes product from catalog
- Removes from active list (maintains invariant)
- Persists both changes
Args:
list_id: List identifier
key: Product key to delete
"""
async with self._lock_for(list_id):
await self._ensure_loaded(list_id)
if key not in self._products[list_id]:
_LOGGER.warning("List '%s': attempted to delete non-existent product: %s", list_id, key)
return
# Remove from catalog
del self._products[list_id][key]
# Remove from active list (maintain invariant)
if key in self._active_list[list_id]:
del self._active_list[list_id][key]
await self._async_save_products(list_id)
await self._async_save_active(list_id)
_LOGGER.debug("List '%s': deleted product: %s", list_id, key)
self._fire_update_event()
def get_catalogues(self) -> Dict[str, dict]:
return self._catalogues
async def async_get_lists(self) -> Dict[str, dict]:
await self._ensure_catalogues_loaded()
await self._ensure_lists_loaded()
return self._lists
async def async_get_products(self, list_id: str) -> Dict[str, dict]:
"""
Get all products in a list's catalog.
Args:
list_id: List identifier
Returns:
Dictionary of product key -> product data
"""
async with self._lock_for(list_id):
await self._ensure_loaded(list_id)
return {key: product.to_dict() for key, product in self._products[list_id].items()}
async def async_get_active(self, list_id: str) -> Dict[str, dict]:
"""
Get active shopping list (quantities only).
Args:
list_id: List identifier
Returns:
Dictionary of product key -> active item data (qty only)
"""
async with self._lock_for(list_id):
await self._ensure_loaded(list_id)
return {key: item.to_dict() for key, item in self._active_list[list_id].items()}
# NOTE: The following methods were removed as they're not used by the websocket API
# and would need updating to support per-list structure:
# - async_get_full_state()
# - get_product()
# - get_active_qty()
@@ -1,12 +1,17 @@
{
"domain": "shopping_list_manager",
"name": "Shopping List Manager",
"version": "1.5.0",
"version": "2.0.0",
"documentation": "https://github.com/thekiwismarthome/shopping-list-manager",
"issue_tracker": "https://github.com/thekiwismarthome/shopping-list-manager/issues",
"requirements": [],
"requirements": [
"Pillow>=10.0.0",
"aiofiles>=23.0.0",
"rapidfuzz>=3.0.0"
],
"dependencies": [],
"codeowners": ["@thekiwismarthome"],
"config_flow": true,
"iot_class": "local_push"
"iot_class": "local_push",
"integration_type": "service"
}
@@ -1,104 +1,113 @@
"""Data models for Shopping List Manager."""
from dataclasses import dataclass, asdict
from typing import Dict
from dataclasses import dataclass, field, asdict
from datetime import datetime
from typing import Optional, List, Dict, Any
import uuid
def generate_id() -> str:
"""Generate a unique ID."""
return str(uuid.uuid4())
def current_timestamp() -> str:
"""Get current ISO timestamp."""
return datetime.utcnow().isoformat() + "Z"
@dataclass
class Category:
"""Category model."""
id: str
name: str
icon: str
color: str
sort_order: int
system: bool = True
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class Product:
"""
Product catalog entry - authoritative product definition.
Products exist independently of the shopping list.
They define WHAT can be shopped, not HOW MUCH is needed.
"""
key: str
"""Product model."""
id: str
name: str
category: str = "other"
unit: str = "pcs"
image: str = ""
category_id: str
aliases: List[str] = field(default_factory=list)
default_unit: str = "units"
default_quantity: float = 1
price: Optional[float] = None
currency: Optional[str] = None
price_per_unit: Optional[float] = None
price_updated: Optional[str] = None
image_url: Optional[str] = None
image_source: Optional[str] = None
barcode: Optional[str] = None
brands: List[str] = field(default_factory=list)
nutrition: Optional[Dict[str, Any]] = None
user_frequency: int = 0
last_used: Optional[str] = None
custom: bool = False
source: str = "user"
tags: List[str] = field(default_factory=list)
collections: List[str] = field(default_factory=list)
taxonomy: Dict[str, Any] = field(default_factory=dict)
allergens: List[str] = field(default_factory=list)
substitution_group: str = ""
priority_level: int = 0
image_hint: str = ""
def to_dict(self) -> dict:
"""Convert to dictionary for storage/transmission."""
return {
"key": self.key,
"name": self.name,
"category": self.category,
"unit": self.unit,
"image": self.image
}
@staticmethod
def from_dict(data: dict) -> 'Product':
"""Create Product from dictionary."""
return Product(
key=data["key"],
name=data["name"],
category=data.get("category", "other"),
unit=data.get("unit", "pcs"),
image=data.get("image", "")
)
def __post_init__(self):
"""Validate product data."""
if not self.key:
raise ValueError("Product key cannot be empty")
if not self.name:
raise ValueError("Product name cannot be empty")
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return asdict(self)
@dataclass
class ActiveItem:
"""
Shopping list state - quantity only.
class Item:
"""Shopping list item model."""
id: str
list_id: str
name: str
category_id: str
product_id: Optional[str] = None
quantity: float = 1
unit: str = "units"
note: Optional[str] = None
checked: bool = False
checked_at: Optional[str] = None
created_at: str = field(default_factory=current_timestamp)
updated_at: str = field(default_factory=current_timestamp)
image_url: Optional[str] = None
order_index: int = 0
price: Optional[float] = None
estimated_total: Optional[float] = None
barcode: Optional[str] = None
Contains NO product metadata, only references products by key.
qty > 0 means "on the list"
qty == 0 means "not on the list" (should be removed)
"""
qty: int
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return asdict(self)
def to_dict(self) -> dict:
"""Convert to dictionary for storage/transmission."""
return {"qty": self.qty}
@staticmethod
def from_dict(data: dict) -> 'ActiveItem':
"""Create ActiveItem from dictionary."""
return ActiveItem(qty=data["qty"])
def __post_init__(self):
"""Validate quantity."""
if self.qty < 0:
raise ValueError("Quantity cannot be negative")
def calculate_total(self) -> None:
"""Calculate estimated total from quantity and price."""
if self.price is not None:
self.estimated_total = self.quantity * self.price
class InvariantError(Exception):
"""
Raised when the core data model invariant is violated.
@dataclass
class ShoppingList:
"""Shopping list model."""
id: str
name: str
icon: str = "mdi:cart"
created_at: str = field(default_factory=current_timestamp)
updated_at: str = field(default_factory=current_timestamp)
item_order: List[str] = field(default_factory=list)
category_order: List[str] = field(default_factory=list)
active: bool = False
Invariant: Every key in active_list MUST exist in products.
If this exception is raised, the system is in an inconsistent state
and must be repaired before continuing.
"""
pass
def validate_invariant(products: Dict[str, Product],
active_list: Dict[str, ActiveItem]) -> None:
"""
Validate the core data model invariant.
Args:
products: Product catalog dictionary
active_list: Active shopping list dictionary
Raises:
InvariantError: If any key in active_list doesn't exist in products
"""
for key in active_list:
if key not in products:
raise InvariantError(
f"Invariant violated: active_list contains unknown product key '{key}'. "
f"This product must be added to the catalog first."
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return asdict(self)
@@ -0,0 +1,488 @@
"""Storage management for Shopping List Manager."""
import logging
from typing import Dict, List, Optional, Any
from .utils.search import ProductSearch
from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import Store
from .const import (
STORAGE_VERSION,
STORAGE_KEY_LISTS,
STORAGE_KEY_ITEMS,
STORAGE_KEY_PRODUCTS,
STORAGE_KEY_CATEGORIES,
)
from .data.catalog_loader import load_product_catalog
from .models import ShoppingList, Item, Product, Category, generate_id
from .data.category_loader import load_categories
_LOGGER = logging.getLogger(__name__)
class ShoppingListStorage:
"""Handle storage for shopping lists."""
def __init__(self, hass: HomeAssistant, component_path: str, country: str = "NZ") -> None:
"""Initialize storage.
Args:
hass: Home Assistant instance
component_path: Path to the component directory
country: Country code (NZ, AU, US, GB, CA, etc.)
"""
self.hass = hass
self._component_path = component_path
self._country = country # Store country
self._store_lists = Store(hass, STORAGE_VERSION, STORAGE_KEY_LISTS)
self._store_items = Store(hass, STORAGE_VERSION, STORAGE_KEY_ITEMS)
self._store_products = Store(hass, STORAGE_VERSION, STORAGE_KEY_PRODUCTS)
self._store_categories = Store(hass, STORAGE_VERSION, STORAGE_KEY_CATEGORIES)
self._lists: Dict[str, ShoppingList] = {}
self._items: Dict[str, List[Item]] = {}
self._products: Dict[str, Product] = {}
self._categories: List[Category] = []
self._search_engine: Optional[ProductSearch] = None
async def async_load(self) -> None:
"""Load data from storage."""
# Load lists
lists_data = await self._store_lists.async_load()
if lists_data:
self._lists = {
list_id: ShoppingList(**list_data)
for list_id, list_data in lists_data.items()
}
_LOGGER.debug("Loaded %d lists", len(self._lists))
else:
# Create default list if none exist
default_list = ShoppingList(
id=generate_id(),
name="Shopping List",
icon="mdi:cart",
active=True
)
self._lists[default_list.id] = default_list
await self._save_lists()
_LOGGER.info("Created default shopping list")
# Load items
items_data = await self._store_items.async_load()
if items_data:
self._items = {
list_id: [Item(**item_data) for item_data in items]
for list_id, items in items_data.items()
}
_LOGGER.debug("Loaded items for %d lists", len(self._items))
# Load products
products_data = await self._store_products.async_load()
if products_data:
self._products = {
product_id: Product(**product_data)
for product_id, product_data in products_data.items()
}
_LOGGER.debug("Loaded %d products", len(self._products))
# Load categories
categories_data = await self._store_categories.async_load()
if categories_data:
self._categories = [Category(**cat_data) for cat_data in categories_data]
_LOGGER.debug("Loaded %d categories", len(self._categories))
else:
# Initialize with default categories from JSON file
default_categories = await load_categories(self._component_path, self._country) # Use self._country
self._categories = [Category(**cat) for cat in default_categories]
await self._save_categories()
_LOGGER.info(
"Initialized %d default categories for country: %s",
len(self._categories),
self._country # Use self._country
)
# Load product catalog if products are empty
if not self._products:
_LOGGER.info("Loading product catalog for country: %s", self._country)
catalog_products = await load_product_catalog(self._component_path, self._country) # Use self._country
if catalog_products:
_LOGGER.info("Importing %d products from catalog", len(catalog_products))
# ... rest of import code ...
for prod_data in catalog_products:
try:
# Create Product from catalog data
product = Product(
id=prod_data.get("id", generate_id()),
name=prod_data["name"],
category_id=prod_data.get("category_id", "other"),
aliases=prod_data.get("aliases", []),
default_unit=prod_data.get("default_unit", "units"),
default_quantity=prod_data.get("default_quantity", 1),
price=prod_data.get("price") or prod_data.get("typical_price"),
currency=self.hass.config.currency,
barcode=prod_data.get("barcode"),
brands=prod_data.get("brands", []),
image_url=prod_data.get("image_url", ""),
custom=False,
source="catalog",
tags=prod_data.get("tags", []),
collections=prod_data.get("collections", []),
taxonomy=prod_data.get("taxonomy", {}),
allergens=prod_data.get("allergens", []),
substitution_group=prod_data.get("substitution_group", ""),
priority_level=prod_data.get("priority_level", 0),
image_hint=prod_data.get("image_hint", "")
)
self._products[product.id] = product
except Exception as err:
_LOGGER.error("Failed to import product %s: %s", prod_data.get("name"), err)
continue
await self._save_products()
_LOGGER.info("Successfully imported %d products from catalog", len(self._products))
# Initialize search engine after products are loaded
if self._products:
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
self._search_engine = ProductSearch(products_dict)
_LOGGER.debug("Initialized product search engine with %d products", len(self._products))
else:
self._search_engine = None
_LOGGER.warning("No products loaded, search engine not initialized")
# Lists methods
async def _save_lists(self) -> None:
"""Save lists to storage."""
data = {list_id: lst.to_dict() for list_id, lst in self._lists.items()}
await self._store_lists.async_save(data)
def get_lists(self) -> List[ShoppingList]:
"""Get all lists."""
return list(self._lists.values())
def get_list(self, list_id: str) -> Optional[ShoppingList]:
"""Get a specific list."""
return self._lists.get(list_id)
def get_active_list(self) -> Optional[ShoppingList]:
"""Get the active list."""
for lst in self._lists.values():
if lst.active:
return lst
return None
async def create_list(self, name: str, icon: str = "mdi:cart") -> ShoppingList:
"""Create a new list."""
new_list = ShoppingList(
id=generate_id(),
name=name,
icon=icon,
category_order=[cat.id for cat in self._categories]
)
self._lists[new_list.id] = new_list
self._items[new_list.id] = []
await self._save_lists()
_LOGGER.info("Created new list: %s", name)
return new_list
async def update_list(self, list_id: str, **kwargs) -> Optional[ShoppingList]:
"""Update a list."""
if list_id not in self._lists:
return None
lst = self._lists[list_id]
for key, value in kwargs.items():
if hasattr(lst, key):
setattr(lst, key, value)
from .models import current_timestamp
lst.updated_at = current_timestamp()
await self._save_lists()
_LOGGER.debug("Updated list: %s", list_id)
return lst
async def delete_list(self, list_id: str) -> bool:
"""Delete a list."""
if list_id not in self._lists:
return False
del self._lists[list_id]
if list_id in self._items:
del self._items[list_id]
await self._save_lists()
await self._save_items()
_LOGGER.info("Deleted list: %s", list_id)
return True
async def set_active_list(self, list_id: str) -> bool:
"""Set the active list."""
if list_id not in self._lists:
return False
# Deactivate all lists
for lst in self._lists.values():
lst.active = False
# Activate the specified list
self._lists[list_id].active = True
await self._save_lists()
_LOGGER.debug("Set active list: %s", list_id)
return True
# Items methods
async def _save_items(self) -> None:
"""Save items to storage."""
data = {
list_id: [item.to_dict() for item in items]
for list_id, items in self._items.items()
}
await self._store_items.async_save(data)
def get_items(self, list_id: str) -> List[Item]:
"""Get items for a list."""
return self._items.get(list_id, [])
async def add_item(self, list_id: str, **kwargs) -> Optional[Item]:
"""Add an item to a list."""
if list_id not in self._lists:
return None
new_item = Item(
id=generate_id(),
list_id=list_id,
**kwargs
)
new_item.calculate_total()
if list_id not in self._items:
self._items[list_id] = []
self._items[list_id].append(new_item)
# Update product frequency if product_id provided
if new_item.product_id and new_item.product_id in self._products:
product = self._products[new_item.product_id]
product.user_frequency += 1
from .models import current_timestamp
product.last_used = current_timestamp()
await self._save_products()
await self._save_items()
_LOGGER.debug("Added item to list %s: %s", list_id, new_item.name)
return new_item
async def update_item(self, item_id: str, **kwargs) -> Optional[Item]:
"""Update an item."""
for list_id, items in self._items.items():
for item in items:
if item.id == item_id:
for key, value in kwargs.items():
if hasattr(item, key):
setattr(item, key, value)
from .models import current_timestamp
item.updated_at = current_timestamp()
item.calculate_total()
await self._save_items()
_LOGGER.debug("Updated item: %s", item_id)
return item
return None
async def check_item(self, item_id: str, checked: bool) -> Optional[Item]:
"""Check or uncheck an item."""
for items in self._items.values():
for item in items:
if item.id == item_id:
item.checked = checked
from .models import current_timestamp
item.checked_at = current_timestamp() if checked else None
item.updated_at = current_timestamp()
await self._save_items()
_LOGGER.debug("Checked item: %s = %s", item_id, checked)
return item
return None
async def delete_item(self, item_id: str) -> bool:
"""Delete an item."""
for list_id, items in self._items.items():
for i, item in enumerate(items):
if item.id == item_id:
del self._items[list_id][i]
await self._save_items()
_LOGGER.debug("Deleted item: %s", item_id)
return True
return False
async def bulk_check_items(self, item_ids: List[str], checked: bool) -> int:
"""Bulk check/uncheck items."""
count = 0
from .models import current_timestamp
timestamp = current_timestamp()
for items in self._items.values():
for item in items:
if item.id in item_ids:
item.checked = checked
item.checked_at = timestamp if checked else None
item.updated_at = timestamp
count += 1
if count > 0:
await self._save_items()
_LOGGER.debug("Bulk checked %d items", count)
return count
async def clear_checked_items(self, list_id: str) -> int:
"""Clear all checked items from a list."""
if list_id not in self._items:
return 0
original_count = len(self._items[list_id])
self._items[list_id] = [item for item in self._items[list_id] if not item.checked]
removed_count = original_count - len(self._items[list_id])
if removed_count > 0:
await self._save_items()
_LOGGER.info("Cleared %d checked items from list %s", removed_count, list_id)
return removed_count
def get_list_total(self, list_id: str) -> Dict[str, Any]:
"""Get total price for a list."""
items = self.get_items(list_id)
total = 0.0
item_count = 0
for item in items:
if not item.checked and item.price is not None:
total += item.quantity * item.price
item_count += 1
return {
"total": round(total, 2),
"currency": self.hass.config.currency,
"item_count": item_count
}
# Products methods
async def _save_products(self) -> None:
"""Save products to storage."""
data = {product_id: product.to_dict() for product_id, product in self._products.items()}
await self._store_products.async_save(data)
def get_products(self) -> List[Product]:
"""Get all products."""
return list(self._products.values())
def get_product(self, product_id: str) -> Optional[Product]:
"""Get a specific product."""
return self._products.get(product_id)
def search_products(
self,
query: str,
limit: int = 10,
exclude_allergens: Optional[List[str]] = None,
include_tags: Optional[List[str]] = None,
substitution_group: Optional[str] = None,
) -> List[Product]:
"""Search products with enhanced fuzzy matching and filters.
Args:
query: Search query
limit: Maximum results
exclude_allergens: Allergens to exclude
include_tags: Tags to include
substitution_group: Filter by substitution group
Returns:
List of matching products
"""
if not self._search_engine:
_LOGGER.warning("Search engine not initialized")
return []
# Convert products dict to format search engine expects
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
search_engine = ProductSearch(products_dict)
results = search_engine.search(
query=query,
limit=limit,
exclude_allergens=exclude_allergens,
include_tags=include_tags,
substitution_group=substitution_group,
)
# Convert back to Product objects
return [self._products[r["id"]] for r in results if r["id"] in self._products]
def find_product_substitutes(self, product_id: str, limit: int = 5) -> List[Product]:
"""Find substitute products.
Args:
product_id: Product to find substitutes for
limit: Maximum substitutes
Returns:
List of substitute products
"""
if not self._search_engine:
return []
products_dict = {pid: p.to_dict() for pid, p in self._products.items()}
search_engine = ProductSearch(products_dict)
results = search_engine.find_substitutes(product_id, limit)
return [self._products[r["id"]] for r in results if r["id"] in self._products]
def get_product_suggestions(self, limit: int = 20) -> List[Product]:
"""Get product suggestions based on usage frequency."""
products = list(self._products.values())
products.sort(key=lambda p: p.user_frequency, reverse=True)
return products[:limit]
async def add_product(self, **kwargs) -> Product:
"""Add a new product."""
new_product = Product(
id=generate_id(),
currency=self.hass.config.currency,
**kwargs
)
self._products[new_product.id] = new_product
await self._save_products()
_LOGGER.debug("Added product: %s", new_product.name)
return new_product
async def update_product(self, product_id: str, **kwargs) -> Optional[Product]:
"""Update a product."""
if product_id not in self._products:
return None
product = self._products[product_id]
for key, value in kwargs.items():
if hasattr(product, key):
setattr(product, key, value)
await self._save_products()
_LOGGER.debug("Updated product: %s", product_id)
return product
# Categories methods
async def _save_categories(self) -> None:
"""Save categories to storage."""
data = [cat.to_dict() for cat in self._categories]
await self._store_categories.async_save(data)
def get_categories(self) -> List[Category]:
"""Get all categories."""
return self._categories
@@ -0,0 +1 @@
"""Utilities for Shopping List Manager."""
@@ -0,0 +1,109 @@
"""Image handling utilities for Shopping List Manager."""
import logging
import os
from pathlib import Path
from typing import Optional
_LOGGER = logging.getLogger(__name__)
class ImageHandler:
"""Handle product images with URL and local file support."""
def __init__(self, hass, config_path: str):
"""Initialize image handler.
Args:
hass: Home Assistant instance
config_path: Path to HA config directory
"""
self.hass = hass
# Images stored in /config/www/shopping_list_manager/images/
self._local_images_dir = Path(config_path) / "www" / "shopping_list_manager" / "images"
self._local_images_dir.mkdir(parents=True, exist_ok=True)
_LOGGER.info("Image directory: %s", self._local_images_dir)
def get_image_url(self, product_name: str, external_url: Optional[str] = None) -> str:
"""Get image URL for a product.
Priority:
1. External URL (if provided)
2. Local file match
3. Placeholder
Args:
product_name: Name of product to find image for
external_url: Optional external image URL
Returns:
Image URL (external, local, or placeholder)
"""
# Priority 1: Use external URL if provided
if external_url:
return external_url
# Priority 2: Look for local file
local_url = self._find_local_image(product_name)
if local_url:
return local_url
# Priority 3: Placeholder
return self._get_placeholder_url()
def _find_local_image(self, product_name: str) -> Optional[str]:
"""Find local image file for product.
Searches for files matching product name (case-insensitive).
Supports: .webp, .jpg, .jpeg, .png
Args:
product_name: Product name to search for
Returns:
Local URL if found, None otherwise
"""
# Normalize product name for filename matching
normalized_name = product_name.lower().replace(" ", "_")
# Supported extensions
extensions = [".webp", ".jpg", ".jpeg", ".png"]
for ext in extensions:
# Check exact match
image_file = self._local_images_dir / f"{normalized_name}{ext}"
if image_file.exists():
return f"/local/shopping_list_manager/images/{normalized_name}{ext}"
# Check for files starting with the product name
for file in self._local_images_dir.glob(f"{normalized_name}*{ext}"):
return f"/local/shopping_list_manager/images/{file.name}"
return None
def _get_placeholder_url(self) -> str:
"""Get placeholder image URL.
Returns:
URL to placeholder image
"""
# Use a simple colored placeholder
# You can replace this with a real placeholder image later
return "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='200' height='200'%3E%3Crect width='200' height='200' fill='%23f0f0f0'/%3E%3Ctext x='50%25' y='50%25' dominant-baseline='middle' text-anchor='middle' font-family='Arial' font-size='16' fill='%23999'%3ENo Image%3C/text%3E%3C/svg%3E"
def list_available_images(self) -> list:
"""List all available local images.
Returns:
List of (filename, product_name_guess) tuples
"""
images = []
extensions = [".webp", ".jpg", ".jpeg", ".png"]
for ext in extensions:
for image_file in self._local_images_dir.glob(f"*{ext}"):
# Guess product name from filename
product_name = image_file.stem.replace("_", " ").title()
images.append((image_file.name, product_name))
return sorted(images)
@@ -0,0 +1,192 @@
"""Enhanced product search utilities."""
import logging
from typing import List, Dict, Any, Optional
from rapidfuzz import fuzz, process
_LOGGER = logging.getLogger(__name__)
class ProductSearch:
"""Advanced product search with fuzzy matching and filtering."""
def __init__(self, products: Dict[str, Any]):
"""Initialize search with product catalog.
Args:
products: Dictionary of product_id -> Product objects
"""
self.products = products
def search(
self,
query: str,
limit: int = 10,
exclude_allergens: Optional[List[str]] = None,
include_tags: Optional[List[str]] = None,
substitution_group: Optional[str] = None,
taxonomy_filters: Optional[Dict[str, Any]] = None,
min_score: int = 60,
) -> List[Dict[str, Any]]:
"""Advanced product search with multiple filters.
Args:
query: Search query string
limit: Maximum results to return
exclude_allergens: List of allergens to exclude (e.g., ["milk", "gluten"])
include_tags: Only include products with these tags
substitution_group: Filter by substitution group
taxonomy_filters: Filter by taxonomy (e.g., {"dietary": ["vegan"]})
min_score: Minimum fuzzy match score (0-100)
Returns:
List of matching products with scores
"""
query_lower = query.lower().strip()
if not query_lower:
return []
candidates = []
for product_id, product in self.products.items():
# Apply allergen filter
if exclude_allergens:
if any(
allergen in product.get("allergens", [])
for allergen in exclude_allergens
):
continue
# Apply tag filter
if include_tags:
if not any(
tag in product.get("tags", [])
for tag in include_tags
):
continue
# Apply substitution group filter
if substitution_group:
if product.get("substitution_group") != substitution_group:
continue
# Apply taxonomy filters
if taxonomy_filters:
product_taxonomy = product.get("taxonomy", {})
matches_taxonomy = True
for key, values in taxonomy_filters.items():
if key not in product_taxonomy:
matches_taxonomy = False
break
product_values = product_taxonomy[key]
if isinstance(product_values, list):
if not any(v in product_values for v in values):
matches_taxonomy = False
break
else:
if product_values not in values:
matches_taxonomy = False
break
if not matches_taxonomy:
continue
# Calculate match score
score = self._calculate_score(query_lower, product)
if score >= min_score:
candidates.append({
"product": product,
"score": score,
})
# Sort by score (descending), then by user frequency, then by priority
candidates.sort(
key=lambda x: (
x["score"],
x["product"].get("user_frequency", 0),
x["product"].get("priority_level", 0),
),
reverse=True
)
# Return top results
return [c["product"] for c in candidates[:limit]]
def _calculate_score(self, query: str, product: Dict[str, Any]) -> int:
"""Calculate fuzzy match score for a product.
Args:
query: Search query
product: Product dictionary
Returns:
Score from 0-100
"""
product_name = product.get("name", "").lower()
aliases = [a.lower() for a in product.get("aliases", [])]
# Exact match gets highest score
if query == product_name:
return 100
# Check aliases for exact match
if query in aliases:
return 95
# Check if query is substring of product name
if query in product_name:
return 90
# Check if query is substring of any alias
for alias in aliases:
if query in alias:
return 85
# Fuzzy match on product name
name_score = fuzz.WRatio(query, product_name)
# Fuzzy match on aliases
alias_scores = [fuzz.WRatio(query, alias) for alias in aliases]
best_alias_score = max(alias_scores) if alias_scores else 0
# Return best score
return max(name_score, best_alias_score)
def find_substitutes(self, product_id: str, limit: int = 5) -> List[Dict[str, Any]]:
"""Find substitute products for a given product.
Args:
product_id: ID of product to find substitutes for
limit: Maximum substitutes to return
Returns:
List of substitute products
"""
if product_id not in self.products:
return []
product = self.products[product_id]
substitution_group = product.get("substitution_group")
if not substitution_group:
return []
# Find all products in the same substitution group
substitutes = []
for pid, p in self.products.items():
if pid != product_id and p.get("substitution_group") == substitution_group:
substitutes.append(p)
# Sort by priority and frequency
substitutes.sort(
key=lambda x: (
x.get("priority_level", 0),
x.get("user_frequency", 0),
),
reverse=True
)
return substitutes[:limit]
@@ -0,0 +1 @@
"""WebSocket API handlers for Shopping List Manager."""
@@ -0,0 +1,793 @@
"""WebSocket API handlers for Shopping List Manager."""
import logging
from typing import Any, Dict
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from ..const import (
WS_TYPE_LISTS_GET_ALL,
WS_TYPE_LISTS_CREATE,
WS_TYPE_LISTS_UPDATE,
WS_TYPE_LISTS_DELETE,
WS_TYPE_LISTS_SET_ACTIVE,
WS_TYPE_ITEMS_GET,
WS_TYPE_ITEMS_ADD,
WS_TYPE_ITEMS_UPDATE,
WS_TYPE_ITEMS_CHECK,
WS_TYPE_ITEMS_DELETE,
WS_TYPE_ITEMS_REORDER,
WS_TYPE_ITEMS_BULK_CHECK,
WS_TYPE_ITEMS_CLEAR_CHECKED,
WS_TYPE_ITEMS_GET_TOTAL,
WS_TYPE_PRODUCTS_SEARCH,
WS_TYPE_PRODUCTS_SUGGESTIONS,
WS_TYPE_PRODUCTS_ADD,
WS_TYPE_PRODUCTS_UPDATE,
WS_TYPE_CATEGORIES_GET_ALL,
EVENT_ITEM_ADDED,
EVENT_ITEM_UPDATED,
EVENT_ITEM_CHECKED,
EVENT_ITEM_DELETED,
EVENT_LIST_UPDATED,
EVENT_LIST_DELETED,
)
from .. import get_storage
_LOGGER = logging.getLogger(__name__)
# =============================================================================
# LIST HANDLERS
# =============================================================================
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_LISTS_GET_ALL,
}
)
@callback
def websocket_get_lists(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle get all lists command."""
storage = get_storage(hass)
lists = storage.get_lists()
connection.send_result(
msg["id"],
{
"lists": [lst.to_dict() for lst in lists]
}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_LISTS_CREATE,
vol.Required("name"): str,
vol.Optional("icon", default="mdi:cart"): str,
}
)
@websocket_api.async_response
async def websocket_create_list(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle create list command."""
storage = get_storage(hass)
new_list = await storage.create_list(
name=msg["name"],
icon=msg.get("icon", "mdi:cart")
)
# Fire event
hass.bus.async_fire(
EVENT_LIST_UPDATED,
{"list_id": new_list.id, "action": "created"}
)
connection.send_result(
msg["id"],
{"list": new_list.to_dict()}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_LISTS_UPDATE,
vol.Required("list_id"): str,
vol.Optional("name"): str,
vol.Optional("icon"): str,
vol.Optional("category_order"): [str],
}
)
@websocket_api.async_response
async def websocket_update_list(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle update list command."""
storage = get_storage(hass)
list_id = msg["list_id"]
# Build update kwargs
update_data = {}
if "name" in msg:
update_data["name"] = msg["name"]
if "icon" in msg:
update_data["icon"] = msg["icon"]
if "category_order" in msg:
update_data["category_order"] = msg["category_order"]
updated_list = await storage.update_list(list_id, **update_data)
if updated_list is None:
connection.send_error(msg["id"], "not_found", "List not found")
return
# Fire event
hass.bus.async_fire(
EVENT_LIST_UPDATED,
{"list_id": list_id, "action": "updated"}
)
connection.send_result(
msg["id"],
{"list": updated_list.to_dict()}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_LISTS_DELETE,
vol.Required("list_id"): str,
}
)
@websocket_api.async_response
async def websocket_delete_list(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle delete list command."""
storage = get_storage(hass)
list_id = msg["list_id"]
success = await storage.delete_list(list_id)
if not success:
connection.send_error(msg["id"], "not_found", "List not found")
return
# Fire event
hass.bus.async_fire(
EVENT_LIST_DELETED,
{"list_id": list_id}
)
connection.send_result(msg["id"], {"success": True})
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_LISTS_SET_ACTIVE,
vol.Required("list_id"): str,
}
)
@websocket_api.async_response
async def websocket_set_active_list(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle set active list command."""
storage = get_storage(hass)
list_id = msg["list_id"]
success = await storage.set_active_list(list_id)
if not success:
connection.send_error(msg["id"], "not_found", "List not found")
return
# Fire event
hass.bus.async_fire(
EVENT_LIST_UPDATED,
{"list_id": list_id, "action": "set_active"}
)
connection.send_result(msg["id"], {"success": True})
# =============================================================================
# ITEM HANDLERS
# =============================================================================
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_ITEMS_GET,
vol.Required("list_id"): str,
}
)
@callback
def websocket_get_items(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle get items command."""
storage = get_storage(hass)
list_id = msg["list_id"]
items = storage.get_items(list_id)
connection.send_result(
msg["id"],
{
"items": [item.to_dict() for item in items]
}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_ITEMS_ADD,
vol.Required("list_id"): str,
vol.Required("name"): str,
vol.Required("category_id"): str,
vol.Optional("product_id"): str,
vol.Optional("quantity", default=1): vol.Coerce(float),
vol.Optional("unit", default="units"): str,
vol.Optional("note"): str,
vol.Optional("price"): vol.Coerce(float),
vol.Optional("image_url"): str,
vol.Optional("barcode"): str,
}
)
@websocket_api.async_response
async def websocket_add_item(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle add item command."""
storage = get_storage(hass)
list_id = msg["list_id"]
# Build item data
item_data = {
"name": msg["name"],
"category_id": msg["category_id"],
"quantity": msg.get("quantity", 1),
"unit": msg.get("unit", "units"),
}
# Optional fields
optional_fields = ["product_id", "note", "price", "image_url", "barcode"]
for field in optional_fields:
if field in msg:
item_data[field] = msg[field]
new_item = await storage.add_item(list_id, **item_data)
if new_item is None:
connection.send_error(msg["id"], "not_found", "List not found")
return
# Fire event
hass.bus.async_fire(
EVENT_ITEM_ADDED,
{
"list_id": list_id,
"item_id": new_item.id,
"item": new_item.to_dict()
}
)
connection.send_result(
msg["id"],
{"item": new_item.to_dict()}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_ITEMS_UPDATE,
vol.Required("item_id"): str,
vol.Optional("name"): str,
vol.Optional("quantity"): vol.Coerce(float),
vol.Optional("unit"): str,
vol.Optional("note"): str,
vol.Optional("price"): vol.Coerce(float),
vol.Optional("category_id"): str,
vol.Optional("image_url"): str,
}
)
@websocket_api.async_response
async def websocket_update_item(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle update item command."""
storage = get_storage(hass)
item_id = msg["item_id"]
# Build update data
update_data = {}
update_fields = ["name", "quantity", "unit", "note", "price", "category_id", "image_url"]
for field in update_fields:
if field in msg:
update_data[field] = msg[field]
updated_item = await storage.update_item(item_id, **update_data)
if updated_item is None:
connection.send_error(msg["id"], "not_found", "Item not found")
return
# Fire event
hass.bus.async_fire(
EVENT_ITEM_UPDATED,
{
"list_id": updated_item.list_id,
"item_id": item_id,
"item": updated_item.to_dict()
}
)
connection.send_result(
msg["id"],
{"item": updated_item.to_dict()}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_ITEMS_CHECK,
vol.Required("item_id"): str,
vol.Required("checked"): bool,
}
)
@websocket_api.async_response
async def websocket_check_item(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle check/uncheck item command."""
storage = get_storage(hass)
item_id = msg["item_id"]
checked = msg["checked"]
updated_item = await storage.check_item(item_id, checked)
if updated_item is None:
connection.send_error(msg["id"], "not_found", "Item not found")
return
# Fire event
hass.bus.async_fire(
EVENT_ITEM_CHECKED,
{
"list_id": updated_item.list_id,
"item_id": item_id,
"checked": checked
}
)
connection.send_result(
msg["id"],
{"item": updated_item.to_dict()}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_ITEMS_DELETE,
vol.Required("item_id"): str,
}
)
@websocket_api.async_response
async def websocket_delete_item(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle delete item command."""
storage = get_storage(hass)
item_id = msg["item_id"]
success = await storage.delete_item(item_id)
if not success:
connection.send_error(msg["id"], "not_found", "Item not found")
return
# Fire event
hass.bus.async_fire(
EVENT_ITEM_DELETED,
{"item_id": item_id}
)
connection.send_result(msg["id"], {"success": True})
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_ITEMS_REORDER,
vol.Required("list_id"): str,
vol.Required("item_order"): [str],
}
)
@websocket_api.async_response
async def websocket_reorder_items(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle reorder items command."""
storage = get_storage(hass)
list_id = msg["list_id"]
item_order = msg["item_order"]
updated_list = await storage.update_list(list_id, item_order=item_order)
if updated_list is None:
connection.send_error(msg["id"], "not_found", "List not found")
return
# Fire event
hass.bus.async_fire(
EVENT_LIST_UPDATED,
{"list_id": list_id, "action": "reordered"}
)
connection.send_result(msg["id"], {"success": True})
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_ITEMS_BULK_CHECK,
vol.Required("item_ids"): [str],
vol.Required("checked"): bool,
}
)
@websocket_api.async_response
async def websocket_bulk_check_items(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle bulk check/uncheck items command."""
storage = get_storage(hass)
item_ids = msg["item_ids"]
checked = msg["checked"]
count = await storage.bulk_check_items(item_ids, checked)
# Fire event
hass.bus.async_fire(
EVENT_ITEM_CHECKED,
{
"item_ids": item_ids,
"checked": checked,
"count": count
}
)
connection.send_result(
msg["id"],
{"success": True, "count": count}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_ITEMS_CLEAR_CHECKED,
vol.Required("list_id"): str,
}
)
@websocket_api.async_response
async def websocket_clear_checked_items(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle clear checked items command."""
storage = get_storage(hass)
list_id = msg["list_id"]
count = await storage.clear_checked_items(list_id)
# Fire event
hass.bus.async_fire(
EVENT_ITEM_DELETED,
{"list_id": list_id, "count": count, "action": "cleared_checked"}
)
connection.send_result(
msg["id"],
{"success": True, "count": count}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_ITEMS_GET_TOTAL,
vol.Required("list_id"): str,
}
)
@callback
def websocket_get_list_total(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle get list total command."""
storage = get_storage(hass)
list_id = msg["list_id"]
total_data = storage.get_list_total(list_id)
connection.send_result(msg["id"], total_data)
# =============================================================================
# PRODUCT HANDLERS
# =============================================================================
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_PRODUCTS_SEARCH,
vol.Required("query"): str,
vol.Optional("limit", default=10): int,
vol.Optional("exclude_allergens", default=None): vol.Any(None, [str]),
vol.Optional("include_tags", default=None): vol.Any(None, [str]),
vol.Optional("substitution_group", default=None): vol.Any(None, str),
}
)
@callback
def websocket_search_products(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle search products command with enhanced filters."""
storage = get_storage(hass)
try:
results = storage.search_products(
query=msg["query"],
limit=msg.get("limit", 10),
exclude_allergens=msg.get("exclude_allergens"),
include_tags=msg.get("include_tags"),
substitution_group=msg.get("substitution_group"),
)
connection.send_result(
msg["id"],
{"products": [product.to_dict() for product in results]}
)
except Exception as err:
_LOGGER.error("Error searching products: %s", err)
connection.send_error(msg["id"], "search_failed", str(err))
@websocket_api.websocket_command(
{
vol.Required("type"): "shopping_list_manager/products/substitutes",
vol.Required("product_id"): str,
vol.Optional("limit", default=5): int,
}
)
@callback
def websocket_get_product_substitutes(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle get product substitutes command."""
storage = get_storage(hass)
try:
substitutes = storage.find_product_substitutes(
product_id=msg["product_id"],
limit=msg.get("limit", 5),
)
connection.send_result(
msg["id"],
{"substitutes": [product.to_dict() for product in substitutes]}
)
except Exception as err:
_LOGGER.error("Error finding substitutes: %s", err)
connection.send_error(msg["id"], "substitutes_failed", str(err))
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_PRODUCTS_SEARCH,
vol.Required("query"): str,
vol.Optional("limit", default=10): int,
vol.Optional("exclude_allergens"): [str],
vol.Optional("include_tags"): [str],
vol.Optional("substitution_group"): str,
}
)
@callback
def websocket_search_products(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle search products command with enhanced filters."""
storage = get_storage(hass)
try:
results = storage.search_products(
query=msg["query"],
limit=msg.get("limit", 10),
exclude_allergens=msg.get("exclude_allergens"),
include_tags=msg.get("include_tags"),
substitution_group=msg.get("substitution_group"),
)
connection.send_result(
msg["id"],
{"products": [product.to_dict() for product in results]}
)
except Exception as err:
_LOGGER.error("Error searching products: %s", err)
connection.send_error(msg["id"], "search_failed", str(err))
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_PRODUCTS_SUGGESTIONS,
vol.Optional("limit", default=20): int,
}
)
@callback
def websocket_get_product_suggestions(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle get product suggestions command."""
storage = get_storage(hass)
limit = msg.get("limit", 20)
suggestions = storage.get_product_suggestions(limit)
connection.send_result(
msg["id"],
{
"products": [product.to_dict() for product in suggestions]
}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_PRODUCTS_ADD,
vol.Required("name"): str,
vol.Required("category_id"): str,
vol.Optional("aliases"): [str],
vol.Optional("default_unit", default="units"): str,
vol.Optional("default_quantity", default=1): vol.Coerce(float),
vol.Optional("price"): vol.Coerce(float),
vol.Optional("barcode"): str,
vol.Optional("image_url"): str,
}
)
@websocket_api.async_response
async def websocket_add_product(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle add product command."""
storage = get_storage(hass)
# Build product data
product_data = {
"name": msg["name"],
"category_id": msg["category_id"],
"default_unit": msg.get("default_unit", "units"),
"default_quantity": msg.get("default_quantity", 1),
"custom": True,
"source": "user"
}
# Optional fields
optional_fields = ["aliases", "price", "barcode", "image_url"]
for field in optional_fields:
if field in msg:
product_data[field] = msg[field]
new_product = await storage.add_product(**product_data)
connection.send_result(
msg["id"],
{"product": new_product.to_dict()}
)
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_PRODUCTS_UPDATE,
vol.Required("product_id"): str,
vol.Optional("name"): str,
vol.Optional("category_id"): str,
vol.Optional("price"): vol.Coerce(float),
vol.Optional("default_unit"): str,
vol.Optional("default_quantity"): vol.Coerce(float),
vol.Optional("aliases"): [str],
vol.Optional("image_url"): str,
}
)
@websocket_api.async_response
async def websocket_update_product(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle update product command."""
storage = get_storage(hass)
product_id = msg["product_id"]
# Build update data
update_data = {}
update_fields = ["name", "category_id", "price", "default_unit", "default_quantity", "aliases", "image_url"]
for field in update_fields:
if field in msg:
update_data[field] = msg[field]
# Add price_updated timestamp if price changed
if "price" in update_data:
from ..models import current_timestamp
update_data["price_updated"] = current_timestamp()
updated_product = await storage.update_product(product_id, **update_data)
if updated_product is None:
connection.send_error(msg["id"], "not_found", "Product not found")
return
connection.send_result(
msg["id"],
{"product": updated_product.to_dict()}
)
# =============================================================================
# CATEGORY HANDLERS
# =============================================================================
@websocket_api.websocket_command(
{
vol.Required("type"): WS_TYPE_CATEGORIES_GET_ALL,
}
)
@callback
def websocket_get_categories(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: Dict[str, Any],
) -> None:
"""Handle get all categories command."""
storage = get_storage(hass)
categories = storage.get_categories()
connection.send_result(
msg["id"],
{
"categories": [cat.to_dict() for cat in categories]
}
)
@@ -1,325 +0,0 @@
"""WebSocket API for Shopping List Manager."""
import logging
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from .const import DOMAIN
from .models import InvariantError
_LOGGER = logging.getLogger(__name__)
@websocket_api.websocket_command({
vol.Required("type"): "shopping_list_manager/create_list",
vol.Required("list_id"): str,
vol.Required("catalogue"): str,
vol.Optional("visibility", default="shared"): vol.In(["shared", "private"]),
})
@websocket_api.async_response
async def websocket_create_list(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
) -> None:
manager = hass.data[DOMAIN]["manager"]
try:
await manager.async_create_list(
list_id=msg["list_id"],
catalogue=msg["catalogue"],
owner=connection.user.id,
visibility=msg.get("visibility", "shared"),
)
connection.send_result(msg["id"], {"success": True})
except Exception as err:
connection.send_error(msg["id"], "create_list_failed", str(err))
@websocket_api.websocket_command({
vol.Required("type"): "shopping_list_manager/add_product",
vol.Optional("list_id", default="groceries"): str,
vol.Required("key"): str,
vol.Required("name"): str,
vol.Optional("category", default="other"): str,
vol.Optional("unit", default="pcs"): str,
vol.Optional("image", default=""): str,
})
@websocket_api.async_response
async def websocket_add_product(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
) -> None:
"""
Add or update a product in the catalog.
Does NOT modify quantity - use set_qty for that.
Request:
{
"type": "shopping_list_manager/add_product",
"list_id": "groceries", # optional, defaults to "groceries"
"key": "milk",
"name": "Milk",
"category": "dairy",
"unit": "pcs",
"image": ""
}
Response:
{
"success": true,
"result": {
"key": "milk",
"name": "Milk",
"category": "dairy",
"unit": "pcs",
"image": ""
}
}
"""
manager = hass.data[DOMAIN]["manager"]
list_id = msg.get("list_id", "groceries")
try:
product = await manager.async_add_product(
list_id=list_id,
key=msg["key"],
name=msg["name"],
category=msg.get("category", "other"),
unit=msg.get("unit", "pcs"),
image=msg.get("image", "")
)
connection.send_result(msg["id"], product.to_dict())
except Exception as err:
_LOGGER.error("Error adding product to list '%s': %s", list_id, err)
connection.send_error(msg["id"], "add_product_failed", str(err))
@websocket_api.websocket_command({
vol.Required("type"): "shopping_list_manager/get_catalogues",
})
@websocket_api.async_response
async def ws_get_catalogues(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
) -> None:
"""Return catalogue metadata (read-only)."""
manager = hass.data[DOMAIN]["manager"]
try:
catalogues = manager.get_catalogues()
connection.send_result(msg["id"], catalogues)
except Exception as err:
_LOGGER.error("Error getting catalogues: %s", err)
connection.send_error(msg["id"], "get_catalogues_failed", str(err))
@websocket_api.websocket_command({
vol.Required("type"): "shopping_list_manager/get_lists",
})
@websocket_api.async_response
async def ws_get_lists(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
) -> None:
"""Return list → catalogue mapping (read-only)."""
manager = hass.data[DOMAIN]["manager"]
try:
# Ensure lists are loaded
await manager._ensure_lists_loaded()
lists = manager._lists
connection.send_result(msg["id"], lists)
except Exception as err:
_LOGGER.error("Error getting lists: %s", err)
connection.send_error(msg["id"], "get_lists_failed", str(err))
@websocket_api.websocket_command({
vol.Required("type"): "shopping_list_manager/set_qty",
vol.Optional("list_id", default="groceries"): str,
vol.Required("key"): str,
vol.Required("qty"): vol.All(int, vol.Range(min=0)),
})
@websocket_api.async_response
async def websocket_set_qty(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
) -> None:
"""
Set quantity for a product on the shopping list.
Product MUST exist in catalog first.
qty = 0 removes from list.
qty > 0 adds/updates on list.
Request:
{
"type": "shopping_list_manager/set_qty",
"list_id": "groceries", # optional, defaults to "groceries"
"key": "milk",
"qty": 2
}
Response:
{
"success": true
}
Error (if product doesn't exist):
{
"success": false,
"error": {
"code": "invariant_violation",
"message": "Cannot set quantity for unknown product 'milk'..."
}
}
"""
manager = hass.data[DOMAIN]["manager"]
list_id = msg.get("list_id", "groceries")
try:
await manager.async_set_qty(
list_id=list_id,
key=msg["key"],
qty=msg["qty"]
)
connection.send_result(msg["id"], {"success": True})
except InvariantError as err:
# This is expected if frontend tries to set qty for non-existent product
_LOGGER.warning("Invariant violation in set_qty (list '%s'): %s", list_id, err)
connection.send_error(msg["id"], "invariant_violation", str(err))
except Exception as err:
_LOGGER.error("Error setting quantity in list '%s': %s", list_id, err)
connection.send_error(msg["id"], "set_qty_failed", str(err))
@websocket_api.websocket_command({
vol.Required("type"): "shopping_list_manager/get_products",
vol.Optional("list_id", default="groceries"): str,
})
@websocket_api.async_response
async def websocket_get_products(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
) -> None:
"""
Get all products in the catalog.
Request:
{
"type": "shopping_list_manager/get_products",
"list_id": "groceries" # optional, defaults to "groceries"
}
Response:
{
"milk": {
"key": "milk",
"name": "Milk",
"category": "dairy",
"unit": "pcs",
"image": ""
},
...
}
"""
manager = hass.data[DOMAIN]["manager"]
list_id = msg.get("list_id", "groceries")
try:
products = await manager.async_get_products(list_id=list_id)
connection.send_result(msg["id"], products)
except Exception as err:
_LOGGER.error("Error getting products for list '%s': %s", list_id, err)
connection.send_error(msg["id"], "get_products_failed", str(err))
@websocket_api.websocket_command({
vol.Required("type"): "shopping_list_manager/get_active",
vol.Optional("list_id", default="groceries"): str,
})
@websocket_api.async_response
async def websocket_get_active(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
) -> None:
"""
Get active shopping list (quantities only).
Request:
{
"type": "shopping_list_manager/get_active",
"list_id": "groceries" # optional, defaults to "groceries"
}
Response:
{
"milk": {"qty": 2},
"bread": {"qty": 1},
...
}
"""
manager = hass.data[DOMAIN]["manager"]
list_id = msg.get("list_id", "groceries")
try:
active = await manager.async_get_active(list_id=list_id)
connection.send_result(msg["id"], active)
except Exception as err:
_LOGGER.error("Error getting active list for '%s': %s", list_id, err)
connection.send_error(msg["id"], "get_active_failed", str(err))
@websocket_api.websocket_command({
vol.Required("type"): "shopping_list_manager/delete_product",
vol.Optional("list_id", default="groceries"): str,
vol.Required("key"): str,
})
@websocket_api.async_response
async def websocket_delete_product(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
) -> None:
"""
Delete a product from catalog (and remove from active list).
Request:
{
"type": "shopping_list_manager/delete_product",
"list_id": "groceries", # optional, defaults to "groceries"
"key": "milk"
}
Response:
{
"success": true
}
"""
manager = hass.data[DOMAIN]["manager"]
list_id = msg.get("list_id", "groceries")
try:
await manager.async_delete_product(list_id=list_id, key=msg["key"])
connection.send_result(msg["id"], {"success": True})
except Exception as err:
_LOGGER.error("Error deleting product from list '%s': %s", list_id, err)
connection.send_error(msg["id"], "delete_product_failed", str(err))