Update storage.py

This commit is contained in:
thekiwismarthome
2026-02-13 12:24:19 +13:00
committed by GitHub
parent 16b594f11a
commit 296a8ad5fd
@@ -1 +1,393 @@
"""Storage management for Shopping List Manager."""
import logging
from typing import Dict, List, Optional, Any
from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import Store
from .const import (
STORAGE_VERSION,
STORAGE_KEY_LISTS,
STORAGE_KEY_ITEMS,
STORAGE_KEY_PRODUCTS,
STORAGE_KEY_CATEGORIES,
)
from .models import ShoppingList, Item, Product, Category, generate_id
from .data.category_loader import load_categories
_LOGGER = logging.getLogger(__name__)
class ShoppingListStorage:
"""Handle storage for shopping lists."""
def __init__(self, hass: HomeAssistant, component_path: str) -> None:
"""Initialize storage.
Args:
hass: Home Assistant instance
component_path: Path to the component directory
"""
self.hass = hass
self._component_path = component_path
self._store_lists = Store(hass, STORAGE_VERSION, STORAGE_KEY_LISTS)
self._store_items = Store(hass, STORAGE_VERSION, STORAGE_KEY_ITEMS)
self._store_products = Store(hass, STORAGE_VERSION, STORAGE_KEY_PRODUCTS)
self._store_categories = Store(hass, STORAGE_VERSION, STORAGE_KEY_CATEGORIES)
self._lists: Dict[str, ShoppingList] = {}
self._items: Dict[str, List[Item]] = {}
self._products: Dict[str, Product] = {}
self._categories: List[Category] = []
async def async_load(self) -> None:
"""Load data from storage."""
# Load lists
lists_data = await self._store_lists.async_load()
if lists_data:
self._lists = {
list_id: ShoppingList(**list_data)
for list_id, list_data in lists_data.items()
}
_LOGGER.debug("Loaded %d lists", len(self._lists))
else:
# Create default list if none exist
default_list = ShoppingList(
id=generate_id(),
name="Shopping List",
icon="mdi:cart",
active=True
)
self._lists[default_list.id] = default_list
await self._save_lists()
_LOGGER.info("Created default shopping list")
# Load items
items_data = await self._store_items.async_load()
if items_data:
self._items = {
list_id: [Item(**item_data) for item_data in items]
for list_id, items in items_data.items()
}
_LOGGER.debug("Loaded items for %d lists", len(self._items))
# Load products
products_data = await self._store_products.async_load()
if products_data:
self._products = {
product_id: Product(**product_data)
for product_id, product_data in products_data.items()
}
_LOGGER.debug("Loaded %d products", len(self._products))
# Load categories
categories_data = await self._store_categories.async_load()
if categories_data:
self._categories = [Category(**cat_data) for cat_data in categories_data]
_LOGGER.debug("Loaded %d categories", len(self._categories))
else:
# Initialize with default categories from JSON file
default_categories = load_categories(self._component_path)
self._categories = [Category(**cat) for cat in default_categories]
await self._save_categories()
_LOGGER.info("Initialized %d default categories", len(self._categories))
# Lists methods
async def _save_lists(self) -> None:
"""Save lists to storage."""
data = {list_id: lst.to_dict() for list_id, lst in self._lists.items()}
await self._store_lists.async_save(data)
def get_lists(self) -> List[ShoppingList]:
"""Get all lists."""
return list(self._lists.values())
def get_list(self, list_id: str) -> Optional[ShoppingList]:
"""Get a specific list."""
return self._lists.get(list_id)
def get_active_list(self) -> Optional[ShoppingList]:
"""Get the active list."""
for lst in self._lists.values():
if lst.active:
return lst
return None
async def create_list(self, name: str, icon: str = "mdi:cart") -> ShoppingList:
"""Create a new list."""
new_list = ShoppingList(
id=generate_id(),
name=name,
icon=icon,
category_order=[cat.id for cat in self._categories]
)
self._lists[new_list.id] = new_list
self._items[new_list.id] = []
await self._save_lists()
_LOGGER.info("Created new list: %s", name)
return new_list
async def update_list(self, list_id: str, **kwargs) -> Optional[ShoppingList]:
"""Update a list."""
if list_id not in self._lists:
return None
lst = self._lists[list_id]
for key, value in kwargs.items():
if hasattr(lst, key):
setattr(lst, key, value)
from .models import current_timestamp
lst.updated_at = current_timestamp()
await self._save_lists()
_LOGGER.debug("Updated list: %s", list_id)
return lst
async def delete_list(self, list_id: str) -> bool:
"""Delete a list."""
if list_id not in self._lists:
return False
del self._lists[list_id]
if list_id in self._items:
del self._items[list_id]
await self._save_lists()
await self._save_items()
_LOGGER.info("Deleted list: %s", list_id)
return True
async def set_active_list(self, list_id: str) -> bool:
"""Set the active list."""
if list_id not in self._lists:
return False
# Deactivate all lists
for lst in self._lists.values():
lst.active = False
# Activate the specified list
self._lists[list_id].active = True
await self._save_lists()
_LOGGER.debug("Set active list: %s", list_id)
return True
# Items methods
async def _save_items(self) -> None:
"""Save items to storage."""
data = {
list_id: [item.to_dict() for item in items]
for list_id, items in self._items.items()
}
await self._store_items.async_save(data)
def get_items(self, list_id: str) -> List[Item]:
"""Get items for a list."""
return self._items.get(list_id, [])
async def add_item(self, list_id: str, **kwargs) -> Optional[Item]:
"""Add an item to a list."""
if list_id not in self._lists:
return None
new_item = Item(
id=generate_id(),
list_id=list_id,
**kwargs
)
new_item.calculate_total()
if list_id not in self._items:
self._items[list_id] = []
self._items[list_id].append(new_item)
# Update product frequency if product_id provided
if new_item.product_id and new_item.product_id in self._products:
product = self._products[new_item.product_id]
product.user_frequency += 1
from .models import current_timestamp
product.last_used = current_timestamp()
await self._save_products()
await self._save_items()
_LOGGER.debug("Added item to list %s: %s", list_id, new_item.name)
return new_item
async def update_item(self, item_id: str, **kwargs) -> Optional[Item]:
"""Update an item."""
for list_id, items in self._items.items():
for item in items:
if item.id == item_id:
for key, value in kwargs.items():
if hasattr(item, key):
setattr(item, key, value)
from .models import current_timestamp
item.updated_at = current_timestamp()
item.calculate_total()
await self._save_items()
_LOGGER.debug("Updated item: %s", item_id)
return item
return None
async def check_item(self, item_id: str, checked: bool) -> Optional[Item]:
"""Check or uncheck an item."""
for items in self._items.values():
for item in items:
if item.id == item_id:
item.checked = checked
from .models import current_timestamp
item.checked_at = current_timestamp() if checked else None
item.updated_at = current_timestamp()
await self._save_items()
_LOGGER.debug("Checked item: %s = %s", item_id, checked)
return item
return None
async def delete_item(self, item_id: str) -> bool:
"""Delete an item."""
for list_id, items in self._items.items():
for i, item in enumerate(items):
if item.id == item_id:
del self._items[list_id][i]
await self._save_items()
_LOGGER.debug("Deleted item: %s", item_id)
return True
return False
async def bulk_check_items(self, item_ids: List[str], checked: bool) -> int:
"""Bulk check/uncheck items."""
count = 0
from .models import current_timestamp
timestamp = current_timestamp()
for items in self._items.values():
for item in items:
if item.id in item_ids:
item.checked = checked
item.checked_at = timestamp if checked else None
item.updated_at = timestamp
count += 1
if count > 0:
await self._save_items()
_LOGGER.debug("Bulk checked %d items", count)
return count
async def clear_checked_items(self, list_id: str) -> int:
"""Clear all checked items from a list."""
if list_id not in self._items:
return 0
original_count = len(self._items[list_id])
self._items[list_id] = [item for item in self._items[list_id] if not item.checked]
removed_count = original_count - len(self._items[list_id])
if removed_count > 0:
await self._save_items()
_LOGGER.info("Cleared %d checked items from list %s", removed_count, list_id)
return removed_count
def get_list_total(self, list_id: str) -> Dict[str, Any]:
"""Get total price for a list."""
items = self.get_items(list_id)
total = 0.0
item_count = 0
for item in items:
if not item.checked and item.price is not None:
total += item.quantity * item.price
item_count += 1
return {
"total": round(total, 2),
"currency": self.hass.config.currency,
"item_count": item_count
}
# Products methods
async def _save_products(self) -> None:
"""Save products to storage."""
data = {product_id: product.to_dict() for product_id, product in self._products.items()}
await self._store_products.async_save(data)
def get_products(self) -> List[Product]:
"""Get all products."""
return list(self._products.values())
def get_product(self, product_id: str) -> Optional[Product]:
"""Get a specific product."""
return self._products.get(product_id)
def search_products(self, query: str, limit: int = 10) -> List[Product]:
"""Search products by name or alias."""
query_lower = query.lower()
results = []
for product in self._products.values():
# Check name
if query_lower in product.name.lower():
results.append(product)
continue
# Check aliases
if any(query_lower in alias.lower() for alias in product.aliases):
results.append(product)
continue
# Sort by frequency
results.sort(key=lambda p: p.user_frequency, reverse=True)
return results[:limit]
def get_product_suggestions(self, limit: int = 20) -> List[Product]:
"""Get product suggestions based on usage frequency."""
products = list(self._products.values())
products.sort(key=lambda p: p.user_frequency, reverse=True)
return products[:limit]
async def add_product(self, **kwargs) -> Product:
"""Add a new product."""
new_product = Product(
id=generate_id(),
currency=self.hass.config.currency,
**kwargs
)
self._products[new_product.id] = new_product
await self._save_products()
_LOGGER.debug("Added product: %s", new_product.name)
return new_product
async def update_product(self, product_id: str, **kwargs) -> Optional[Product]:
"""Update a product."""
if product_id not in self._products:
return None
product = self._products[product_id]
for key, value in kwargs.items():
if hasattr(product, key):
setattr(product, key, value)
await self._save_products()
_LOGGER.debug("Updated product: %s", product_id)
return product
# Categories methods
async def _save_categories(self) -> None:
"""Save categories to storage."""
data = [cat.to_dict() for cat in self._categories]
await self._store_categories.async_save(data)
def get_categories(self) -> List[Category]:
"""Get all categories."""
return self._categories