From 57b6d52ddfe831ad65ead564181bcbb6a5b2f351 Mon Sep 17 00:00:00 2001 From: thekiwismarthome Date: Thu, 26 Feb 2026 06:51:26 +1300 Subject: [PATCH] feat: Add country selection for product catalogs, including new WebSocket handlers and catalog reloading logic. --- .../shopping_list_manager/__init__.py | 12 +++- .../shopping_list_manager/storage.py | 48 ++++++++++++++ .../websocket/handlers.py | 64 +++++++++++++++++++ 3 files changed, 123 insertions(+), 1 deletion(-) diff --git a/custom_components/shopping_list_manager/__init__.py b/custom_components/shopping_list_manager/__init__.py index f95bfc3..074d62b 100644 --- a/custom_components/shopping_list_manager/__init__.py +++ b/custom_components/shopping_list_manager/__init__.py @@ -202,7 +202,17 @@ async def _async_register_websocket_handlers( hass, handlers.websocket_get_categories, ) - + + # Integration settings handlers + websocket_api.async_register_command( + hass, + handlers.websocket_get_integration_settings, + ) + websocket_api.async_register_command( + hass, + handlers.websocket_set_country, + ) + _LOGGER.debug("WebSocket handlers registered") diff --git a/custom_components/shopping_list_manager/storage.py b/custom_components/shopping_list_manager/storage.py index c02f5dd..3d38607 100644 --- a/custom_components/shopping_list_manager/storage.py +++ b/custom_components/shopping_list_manager/storage.py @@ -466,6 +466,54 @@ class ShoppingListStorage: _LOGGER.debug("Added product: %s", new_product.name) return new_product + async def reload_catalog(self, country_code: str) -> int: + """Replace catalog-sourced products with those from a new country's catalog. + Products with source='user' are preserved.""" + catalog_ids = [ + pid for pid, p in self._products.items() + if getattr(p, 'source', 'user') == 'catalog' + ] + for pid in catalog_ids: + del self._products[pid] + + self._country = country_code + catalog_products = await load_product_catalog(self._component_path, country_code) + count = 0 + for prod_data in catalog_products: + try: + product = Product( + id=prod_data.get("id", generate_id()), + name=prod_data["name"], + category_id=prod_data.get("category_id", "other"), + aliases=prod_data.get("aliases", []), + default_unit=prod_data.get("default_unit", "units"), + default_quantity=prod_data.get("default_quantity", 1), + price=prod_data.get("price") or prod_data.get("typical_price"), + currency=self.hass.config.currency, + barcode=prod_data.get("barcode"), + brands=prod_data.get("brands", []), + image_url=prod_data.get("image_url", ""), + custom=False, + source="catalog", + tags=prod_data.get("tags", []), + collections=prod_data.get("collections", []), + taxonomy=prod_data.get("taxonomy", {}), + allergens=prod_data.get("allergens", []), + substitution_group=prod_data.get("substitution_group", ""), + priority_level=prod_data.get("priority_level", 0), + image_hint=prod_data.get("image_hint", "") + ) + self._products[product.id] = product + count += 1 + except Exception as err: + _LOGGER.error("Failed to import product %s: %s", prod_data.get("name"), err) + + await self._save_products() + products_dict = {pid: p.to_dict() for pid, p in self._products.items()} + self._search_engine = ProductSearch(products_dict) + _LOGGER.info("Reloaded catalog for %s: %d products imported", country_code, count) + return count + async def update_product(self, product_id: str, **kwargs) -> Optional[Product]: """Update a product.""" if product_id not in self._products: diff --git a/custom_components/shopping_list_manager/websocket/handlers.py b/custom_components/shopping_list_manager/websocket/handlers.py index 68e0d99..12b3e15 100644 --- a/custom_components/shopping_list_manager/websocket/handlers.py +++ b/custom_components/shopping_list_manager/websocket/handlers.py @@ -869,3 +869,67 @@ def websocket_get_categories( "categories": [cat.to_dict() for cat in categories] } ) + + +# ============================================================================= +# INTEGRATION SETTINGS HANDLERS +# ============================================================================= + +@websocket_api.websocket_command( + { + vol.Required("type"): "shopping_list_manager/get_integration_settings", + } +) +@callback +def websocket_get_integration_settings( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: Dict[str, Any], +) -> None: + """Return current country and available country options.""" + country = hass.data[DOMAIN].get("country", "NZ") + connection.send_result( + msg["id"], + { + "country": country, + "available_countries": { + "NZ": "New Zealand", + "AU": "Australia", + "US": "United States", + "GB": "United Kingdom", + "CA": "Canada", + }, + } + ) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "shopping_list_manager/set_country", + vol.Required("country"): str, + } +) +@websocket_api.async_response +async def websocket_set_country( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: Dict[str, Any], +) -> None: + """Switch to a different country catalog. Preserves user-added products.""" + country = msg["country"].upper() + storage = get_storage(hass) + + count = await storage.reload_catalog(country) + + # Persist to HA config entry so country survives restart + entries = hass.config_entries.async_entries(DOMAIN) + if entries: + entry = entries[0] + hass.config_entries.async_update_entry(entry, options={**entry.options, "country": country}) + + hass.data[DOMAIN]["country"] = country + + connection.send_result( + msg["id"], + {"success": True, "country": country, "products_loaded": count} + )