# app/services/vectorstore_service.py
import logging
from typing import Any, Dict, List, Optional, Tuple

import redis
from langchain_redis import RedisVectorStore

logger = logging.getLogger(__name__)


class VectorstoreService:
    """Wrapper for a Redis vector store; handles index creation and connection.

    The ``RedisVectorStore`` is built lazily and cached on the instance in
    ``self._vectorstore``. When the index ``index_name`` already exists, the
    first key prefix from its ``FT.INFO`` definition is passed through as
    ``key_prefix`` so that new reads/writes use the index's existing prefix.
    """

    # Exceptions treated as transient and retried by create_or_append/add_texts.
    # redis-py's ConnectionError/TimeoutError derive from
    # redis.exceptions.RedisError, NOT from the builtin ConnectionError/OSError,
    # so they must be listed explicitly or Redis connection drops are never
    # retried.
    _TRANSIENT_ERRORS: Tuple[type, ...] = (
        ConnectionError,
        OSError,
        redis.exceptions.ConnectionError,
        redis.exceptions.TimeoutError,
    )

    def __init__(
        self,
        redis_url: str,
        index_name: str,
        embeddings: Any,
    ):
        """Store connection settings; no network I/O happens here.

        Args:
            redis_url: Redis connection URL used by redis-py and RedisVectorStore.
            index_name: Name of the RediSearch index backing the vector store.
            embeddings: Embedding model passed to RedisVectorStore as ``embedding``.
        """
        self.redis_url = redis_url
        self.index_name = index_name
        self.embeddings = embeddings
        # Lazily populated by get_or_connect / create_or_append.
        self._vectorstore: Optional[RedisVectorStore] = None

    def get_or_connect(self) -> Optional[RedisVectorStore]:
        """Return the cached vector store, connecting to an existing index if needed.

        Returns:
            The cached or newly connected ``RedisVectorStore``, or ``None`` when
            ``RedisVectorStore.from_existing_index`` raises (logged as a warning;
            presumably because the index has not been created yet).
        """
        if self._vectorstore is not None:
            return self._vectorstore

        try:
            key_prefix = self._get_index_key_prefix()
            self._vectorstore = RedisVectorStore.from_existing_index(
                embedding=self.embeddings,
                index_name=self.index_name,
                redis_url=self.redis_url,
                **self._prefix_kwargs(key_prefix),
            )
        except Exception as e:
            # Best-effort: failure is reported as None rather than raised.
            logger.warning("Vector store not ready: %s", e)
            return None

        logger.info(
            "Vector store connected%s.",
            f" (key_prefix='{key_prefix}')" if key_prefix else "",
        )
        return self._vectorstore

    def create_or_append(
        self,
        texts: List[str],
        metadatas: List[Dict[str, Any]],
        ids: List[str],
        max_retries: int = 2,
    ) -> RedisVectorStore:
        """Index ``texts`` via ``RedisVectorStore.from_texts`` and cache the store.

        Intended to create the index on first write or append to an existing
        one; the exact create/append semantics are those of ``from_texts``.
        If the index already exists, its key prefix is reused.

        Args:
            texts: Documents to embed and store.
            metadatas: One metadata dict per text.
            ids: One id per text.
            max_retries: Extra attempts after a transient connection error.
                Negative values are treated as 0 (a single attempt).

        Returns:
            The resulting ``RedisVectorStore`` (also cached on the instance).

        Raises:
            The last transient connection error once retries are exhausted, or
            any non-transient error immediately.
        """
        attempts = max(max_retries, 0) + 1  # always try at least once
        for attempt in range(1, attempts + 1):
            try:
                key_prefix = self._get_index_key_prefix()
                self._vectorstore = RedisVectorStore.from_texts(
                    texts=texts,
                    metadatas=metadatas,
                    embedding=self.embeddings,
                    index_name=self.index_name,
                    redis_url=self.redis_url,
                    ids=ids,
                    **self._prefix_kwargs(key_prefix),
                )
            except self._TRANSIENT_ERRORS as e:
                if attempt == attempts:
                    logger.error("Failed to create vectorstore after %d attempts: %s", attempts, e)
                    raise
                logger.warning("Connection error (attempt %d/%d), retrying: %s", attempt, attempts, e)
                # Drop any cached store so nothing stale survives a failed write.
                self._vectorstore = None
                continue
            except Exception as e:
                logger.error("Failed to create vectorstore: %s", e)
                raise
            logger.info(
                "Vector store created and first assets indexed%s.",
                f" (key_prefix='{key_prefix}')" if key_prefix else "",
            )
            return self._vectorstore
        # Unreachable: the final attempt either returns or re-raises.
        raise RuntimeError("create_or_append retry loop exited without a result")

    def add_texts(
        self,
        texts: List[str],
        metadatas: List[Dict[str, Any]],
        ids: List[str],
        max_retries: int = 2,
    ) -> None:
        """Add texts to the already-initialized vector store.

        On a transient connection error the store is re-created via
        ``get_or_connect`` and the write is retried.

        Args:
            texts: Documents to embed and store.
            metadatas: One metadata dict per text.
            ids: One id per text.
            max_retries: Extra attempts after a transient connection error.
                Negative values are treated as 0 (a single attempt).

        Raises:
            ValueError: If no store is initialized, or reconnecting fails.
            The last transient connection error once retries are exhausted.
        """
        if self._vectorstore is None:
            raise ValueError("Vectorstore not initialized")

        attempts = max(max_retries, 0) + 1  # always try at least once
        for attempt in range(1, attempts + 1):
            try:
                self._vectorstore.add_texts(texts=texts, metadatas=metadatas, ids=ids)
                return
            except self._TRANSIENT_ERRORS as e:
                if attempt == attempts:
                    logger.error("Failed to add texts after %d attempts: %s", attempts, e)
                    raise
                logger.warning(
                    "Connection error in add_texts (attempt %d/%d), reconnecting: %s",
                    attempt,
                    attempts,
                    e,
                )
                # Force get_or_connect to build a fresh store instead of
                # returning the cached one.
                self._vectorstore = None
                if self.get_or_connect() is None:
                    raise ValueError("Failed to reconnect to vectorstore") from e

    def _get_index_key_prefix(self) -> Optional[str]:
        """Return the first key prefix of the existing index, or None.

        Best-effort: any failure (Redis unreachable, index missing, unexpected
        ``FT.INFO`` reply shape) yields ``None`` so callers omit ``key_prefix``.
        """
        try:
            info, err = self._fetch_ft_info()
            if not info:
                if err:
                    logger.debug("FT.INFO failed for index '%s': %s", self.index_name, err)
                return None
            id_def = info.get("index_definition")
            # The definition arrives as a flat [k, v, k, v, ...] list.
            if isinstance(id_def, list):
                id_def = self._list_to_dict(id_def)
            if isinstance(id_def, dict):
                prefixes = id_def.get("prefixes")
                if isinstance(prefixes, list) and prefixes:
                    return prefixes[0]
            return None
        except Exception:
            return None

    def _fetch_ft_info(self) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
        """Run ``FT.INFO`` for the index.

        Returns:
            ``(info, None)`` on success, where ``info`` is the reply converted to
            a dict (or ``{"raw": reply}`` if it is neither a list nor a dict), or
            ``(None, error_message)`` on any failure.
        """
        try:
            # Context manager closes the client (and the pool from_url created)
            # so each call does not leak a connection.
            with redis.Redis.from_url(self.redis_url, decode_responses=True) as client:
                info = client.execute_command("FT.INFO", self.index_name)
            if isinstance(info, list):
                info = self._list_to_dict(info)
            if not isinstance(info, dict):
                info = {"raw": info}
            return info, None
        except Exception as e:
            return None, str(e)

    @staticmethod
    def _prefix_kwargs(key_prefix: Optional[str]) -> Dict[str, str]:
        """Return ``{"key_prefix": key_prefix}`` when a prefix is set, else ``{}``."""
        return {"key_prefix": key_prefix} if key_prefix else {}

    @staticmethod
    def _list_to_dict(items: List[Any]) -> Dict[str, Any]:
        """Convert a flat Redis ``[k1, v1, k2, v2, ...]`` reply into a dict.

        Bytes keys/values are decoded as UTF-8 (undecodable bytes are dropped),
        keys are coerced with ``str``, nested values are left as-is, and a
        trailing key without a value maps to ``None``.
        """
        d: Dict[str, Any] = {}
        it = iter(items)
        for k in it:
            v = next(it, None)
            if isinstance(k, (bytes, bytearray)):
                k = k.decode("utf-8", errors="ignore")
            if isinstance(v, (bytes, bytearray)):
                v = v.decode("utf-8", errors="ignore")
            d[str(k)] = v
        return d
