diff --git a/custom_components/shopping_list_manager/__init__.py b/custom_components/shopping_list_manager/__init__.py index 074d62b..ca291bc 100644 --- a/custom_components/shopping_list_manager/__init__.py +++ b/custom_components/shopping_list_manager/__init__.py @@ -213,6 +213,16 @@ async def _async_register_websocket_handlers( handlers.websocket_set_country, ) + # Backup / Restore handlers + websocket_api.async_register_command( + hass, + handlers.websocket_export_data, + ) + websocket_api.async_register_command( + hass, + handlers.websocket_import_data, + ) + _LOGGER.debug("WebSocket handlers registered") diff --git a/custom_components/shopping_list_manager/storage.py b/custom_components/shopping_list_manager/storage.py index 3d38607..32e5563 100644 --- a/custom_components/shopping_list_manager/storage.py +++ b/custom_components/shopping_list_manager/storage.py @@ -1,5 +1,8 @@ """Storage management for Shopping List Manager.""" +import json import logging +import os +from datetime import datetime, timezone from typing import Dict, List, Optional, Any from .utils.search import ProductSearch from homeassistant.core import HomeAssistant @@ -182,6 +185,7 @@ class ShoppingListStorage: self._lists[new_list.id] = new_list self._items[new_list.id] = [] await self._save_lists() + await self._write_config_backup() _LOGGER.info("Created new list: %s", name) return new_list @@ -460,6 +464,7 @@ class ShoppingListStorage: ) self._products[new_product.id] = new_product await self._save_products() + await self._write_config_backup() # Rebuild search engine so the new product is immediately searchable products_dict = {pid: p.to_dict() for pid, p in self._products.items()} self._search_engine = ProductSearch(products_dict) @@ -518,15 +523,108 @@ class ShoppingListStorage: """Update a product.""" if product_id not in self._products: return None - + product = self._products[product_id] for key, value in kwargs.items(): if hasattr(product, key): setattr(product, key, value) - + await self._save_products() + await self._write_config_backup() _LOGGER.debug("Updated product: %s", product_id) return product + + # --------------------------------------------------------------------------- + # Backup / Restore + # --------------------------------------------------------------------------- + + async def export_user_data(self) -> dict: + """Return a serialisable snapshot of all user-created data.""" + user_products = [ + p.to_dict() for p in self._products.values() + if getattr(p, "source", "user") == "user" + ] + lists = [lst.to_dict() for lst in self._lists.values()] + items = { + list_id: [item.to_dict() for item in items_list] + for list_id, items_list in self._items.items() + } + return { + "slm_backup_version": "1.0", + "exported_at": datetime.now(timezone.utc).isoformat(), + "country": self._country, + "user_products": user_products, + "lists": lists, + "items": items, + } + + async def import_user_data(self, data: dict) -> dict: + """Merge a backup into live storage. Skips anything already present by ID.""" + imported_products = 0 + imported_lists = 0 + imported_items = 0 + + for prod_data in data.get("user_products", []): + prod_id = prod_data.get("id") + if prod_id and prod_id not in self._products: + try: + self._products[prod_id] = Product(**prod_data) + imported_products += 1 + except Exception as err: + _LOGGER.warning("Skipped product during import: %s", err) + + if imported_products: + await self._save_products() + products_dict = {pid: p.to_dict() for pid, p in self._products.items()} + self._search_engine = ProductSearch(products_dict) + + for list_data in data.get("lists", []): + list_id = list_data.get("id") + if list_id and list_id not in self._lists: + try: + lst = ShoppingList(**list_data) + lst.active = False + self._lists[list_id] = lst + imported_lists += 1 + except Exception as err: + _LOGGER.warning("Skipped list during import: %s", err) + + backup_items = data.get("items", {}) + for list_id, items_list in backup_items.items(): + if list_id in self._lists and list_id not in self._items: + try: + self._items[list_id] = [Item(**d) for d in items_list] + imported_items += len(self._items[list_id]) + except Exception as err: + _LOGGER.warning("Skipped items for list %s: %s", list_id, err) + + if imported_lists or imported_items: + await self._save_lists() + await self._save_items() + + _LOGGER.info( + "Import complete: %d products, %d lists, %d items", + imported_products, imported_lists, imported_items, + ) + return {"products": imported_products, "lists": imported_lists, "items": imported_items} + + async def _write_config_backup(self) -> None: + """Silently write a backup JSON to the HA config directory.""" + try: + backup_path = os.path.join( + self.hass.config.config_dir, + "shopping_list_manager_backup.json", + ) + data = await self.export_user_data() + + def _write() -> None: + with open(backup_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + await self.hass.async_add_executor_job(_write) + _LOGGER.debug("Auto-backup written to %s", backup_path) + except Exception as err: + _LOGGER.warning("Failed to write config backup: %s", err) # Categories methods async def _save_categories(self) -> None: diff --git a/custom_components/shopping_list_manager/websocket/handlers.py b/custom_components/shopping_list_manager/websocket/handlers.py index 12b3e15..5a7736f 100644 --- a/custom_components/shopping_list_manager/websocket/handlers.py +++ b/custom_components/shopping_list_manager/websocket/handlers.py @@ -933,3 +933,42 @@ async def websocket_set_country( msg["id"], {"success": True, "country": country, "products_loaded": count} ) + + +# ============================================================================= +# BACKUP / RESTORE HANDLERS +# ============================================================================= + +@websocket_api.websocket_command( + { + vol.Required("type"): "shopping_list_manager/export_data", + } +) +@websocket_api.async_response +async def websocket_export_data( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: Dict[str, Any], +) -> None: + """Export all user-created data as a JSON-serialisable dict.""" + storage = get_storage(hass) + data = await storage.export_user_data() + connection.send_result(msg["id"], data) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "shopping_list_manager/import_data", + vol.Required("data"): dict, + } +) +@websocket_api.async_response +async def websocket_import_data( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: Dict[str, Any], +) -> None: + """Import user data from a backup payload.""" + storage = get_storage(hass) + counts = await storage.import_user_data(msg["data"]) + connection.send_result(msg["id"], {"success": True, "imported": counts})