# app/main.py (refactored)
"""FastAPI RAG API with Redis vector store."""
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import redis
from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Path
from fastapi.responses import FileResponse
from pydantic import BaseModel

from app.config import settings
from app.services.embeddings_service import EmbeddingsService
from app.services.query_service import QueryService
from app.services.vectorstore_service import VectorstoreService
from app.services.file_extractor_service import FileExtractorService
from app.services.llm_service import LLMService, SourceDocument as LLMSourceDocument
from app.main_helpers import (
    is_safe_filename,
    load_registry,
    save_registry,
    chunk_text,
    read_text_file,
    mask_redis_url,
    get_index_name,
    get_registry_path,
    get_all_courses,
)

# Configure logging
# NOTE: basicConfig configures the root logger at import time; per the stdlib
# docs it does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="RAG Query API",
    description="An API to query a Redis-backed RAG system.",
    version="1.1.0",
)

# Service instances (per-course)
# Shared embeddings backend; assigned in startup_event and left as None if its
# construction raised (endpoints then report the vector store as unavailable).
embeddings_service: Optional[EmbeddingsService] = None
# Per-course caches keyed by course_id, populated lazily by _get_or_init_services.
vectorstore_services: Dict[str, VectorstoreService] = {}
query_services: Dict[str, QueryService] = {}
# Optional LLM backend used by the chat endpoint; assigned in startup_event only
# when settings.OPENAI_API_KEY is set.
llm_service: Optional[LLMService] = None

# Config
# Directory holding the extracted text of every asset. It is shared by ALL
# courses: assets are stored as DATA_DIR/<filename> with no course component.
DATA_DIR = settings.DATA_DIR
# Base location handed to get_registry_path(BASE_REGISTRY_PATH, course_id) to
# derive each course's asset registry (exact derivation lives in main_helpers).
BASE_REGISTRY_PATH = os.path.join(os.path.dirname(__file__), "assets_index.json")


# ===========================
# Pydantic Models
# ===========================


class QueryRequest(BaseModel):
    """Body of ``POST /courses/{course_id}/query``."""

    # Free-text search query.
    query: str
    # Maximum number of hits; None makes the endpoint use settings.EMBEDDING_TOP_K.
    top_k: Optional[int] = None
    # Distance cutoff; None makes the endpoint use settings.EMBEDDING_DISTANCE_THRESHOLD.
    distance_threshold: Optional[float] = None


class YouTubeImportRequest(BaseModel):
    """Body of ``POST /courses/{course_id}/asset/add-youtube``."""

    # Video URL; handed to FileExtractorService as the content of a ".youtube" file.
    url: str
    # Optional title; when set the asset is stored as "<title>.youtube",
    # otherwise a name is derived from a hash of the URL.
    title: Optional[str] = None


class ChatSourceDocument(BaseModel):
    """One source chunk cited in a chat response.

    Built from each entry of ``synthesis_result.sources`` returned by the LLM service.
    """
    # Text of the cited chunk.
    content: str
    # Name of the asset the chunk came from.
    filename: str
    # Chunk identifier, when the LLM service provides one.
    chunk_id: Optional[str] = None


class ChatRequest(BaseModel):
    """Body of ``POST /courses/{course_id}/llm/chat``."""
    # Question to answer.
    query: str
    # Number of chunks retrieved from the vector store and passed to the LLM as context.
    top_k: int = 3
    # Optional text forwarded to llm_service.synthesize(guidelines=...);
    # presumably extra answering instructions — confirm against LLMService.
    guidelines: Optional[str] = None


class ChatResponse(BaseModel):
    """Response of the chat endpoint: the synthesized answer and its cited sources."""
    # Answer text produced by the LLM.
    answer: str
    # Chunks the answer was based on.
    sources: List[ChatSourceDocument]


class DocumentResponse(BaseModel):
    """One search hit returned by the query endpoint.

    Built directly from each dict returned by ``QueryService.search``.
    """

    # Chunk text.
    page_content: str
    # Metadata stored with the chunk at ingestion time (e.g. file_id, source, chunk).
    metadata: Dict[str, Any]
    # Distance score reported by the search, when available.
    distance: Optional[float] = None


class AssetSummary(BaseModel):
    """Short description of a stored asset, as used in paginated listings."""

    # Asset id; equal to the stored filename.
    id: str
    # Display name; also the stored filename.
    name: str
    # Size in bytes of the stored (extracted text) file.
    size_bytes: int
    # Last-modification time of the stored file.
    modified_at: datetime
    # Number of chunks the text was split into.
    chunk_count: int


class AssetDetail(AssetSummary):
    """Full description of a stored asset, including its chunking and embedding ids."""

    # Filesystem path of the stored text file.
    path: str
    # Chunking parameters the text was split with.
    chunk_size: int
    chunk_overlap: int
    # Vector-store ids of the chunks, formatted "<file_id>:<chunk index padded to 6 digits>".
    embedding_ids: List[str]


class AssetListResponse(BaseModel):
    """One page of a course's asset listing."""

    # 1-based page number requested.
    page: int
    # Page size requested.
    per_page: int
    # Number of registry entries for the course (entries whose file is missing
    # on disk are still counted here but omitted from ``items``).
    total: int
    items: List[AssetSummary]


class IndexInfoResponse(BaseModel):
    """Diagnostic report about a course's Redis search index (debug endpoint)."""

    # Redis index name derived from the course id.
    index_name: str
    # Redis URL passed through mask_redis_url (presumably hides credentials).
    redis_url: str
    # Whether the course's vector store could be connected.
    vectorstore_connected: bool
    # True when FT.INFO succeeded for the index.
    exists: bool
    # Document count parsed from the FT.INFO reply, if present.
    num_docs: Optional[int] = None
    # Total hit count of a match-all FT.SEARCH, if that probe succeeded.
    num_docs_search: Optional[int] = None
    # FT.INFO reply converted to a dict.
    raw_info: Optional[Dict[str, Any]] = None
    # Error text from FT.INFO, if it failed.
    error: Optional[str] = None


class CourseResponse(BaseModel):
    """A single course entry in the course listing."""

    # Course identifier as returned by get_all_courses.
    id: str


class CoursesListResponse(BaseModel):
    """Response of ``GET /courses``."""

    courses: List[CourseResponse]


# ===========================
# Service Initialization & Lifecycle
# ===========================


def _get_or_init_services(course_id: str) -> tuple[Optional[QueryService], Optional[VectorstoreService]]:
    """Return the cached, or lazily created, query/vector-store services for a course.

    The ``VectorstoreService`` wrapper is cached once a connection attempt has
    completed, even if the course index does not exist yet. Ingestion endpoints
    can then create the index through it.

    A ``QueryService`` is cached only once the vector store is actually
    connected. Until then every call retries ``get_or_connect()``, so queries
    start working as soon as ingestion has created the index. Caching ``None``
    would pin the course to "not initialized" forever.

    Args:
        course_id: Course identifier; used to derive the Redis index name.

    Returns:
        ``(query_service, vectorstore_service)``. ``query_service`` is ``None``
        while the course vector store is not connected. Both are ``None`` if the
        embeddings service is unavailable or initialization raised.
    """
    cached_qs = query_services.get(course_id)
    cached_vs = vectorstore_services.get(course_id)
    if cached_qs is not None and cached_vs is not None:
        return cached_qs, cached_vs

    try:
        index_name = get_index_name(course_id)
        vs_service = cached_vs
        if vs_service is None:
            if embeddings_service is None:
                return None, None
            vs_service = VectorstoreService(
                redis_url=settings.REDIS_URL,
                index_name=index_name,
                embeddings=embeddings_service.get(),
            )

        vs = vs_service.get_or_connect()
        qs = QueryService(vectorstore=vs) if vs else None

        vectorstore_services[course_id] = vs_service
        if qs is not None:
            query_services[course_id] = qs
            logger.info("Services initialized for course: %s (index: %s)", course_id, index_name)
        else:
            logger.debug("Vector store for course %s (index: %s) not connected yet", course_id, index_name)
        return qs, vs_service
    except Exception as e:
        logger.exception("Failed to initialize services for course %s: %s", course_id, e)
        return None, None


@app.on_event("startup")
async def startup_event() -> None:
    """Build the process-wide embeddings service and, if configured, the LLM service.

    Construction failures are logged instead of raised, so the API still starts.
    Endpoints that need a missing service answer 503 instead.
    """
    global embeddings_service, llm_service

    try:
        embeddings_service = EmbeddingsService(model_name=settings.MODEL_NAME)
        logger.info("Embeddings service initialized")
    except Exception as exc:
        logger.error("Failed to initialize embeddings service: %s", exc)

    # The LLM is optional: without an API key the chat endpoint stays disabled.
    if not settings.OPENAI_API_KEY:
        logger.warning("OPENAI_API_KEY not set; LLM service will not be available")
        return

    try:
        llm_service = LLMService(
            api_key=settings.OPENAI_API_KEY,
            model=settings.OPENAI_MODEL,
            temperature=settings.LLM_TEMPERATURE,
            max_tokens=settings.LLM_MAX_TOKENS,
        )
        logger.info("LLM service initialized with model: %s", settings.OPENAI_MODEL)
    except Exception as exc:
        logger.error("Failed to initialize LLM service: %s", exc)
        llm_service = None


# ===========================
# Course Management
# ===========================


@app.get("/courses", response_model=CoursesListResponse)
async def get_courses() -> CoursesListResponse:
    """List every course id reported by ``get_all_courses`` for the registry base path."""
    entries = [CourseResponse(id=course) for course in get_all_courses(BASE_REGISTRY_PATH)]
    return CoursesListResponse(courses=entries)


# ===========================
# Query Endpoint
# ===========================


@app.post("/courses/{course_id}/query", response_model=List[DocumentResponse])
async def query(
    course_id: str = Path(..., description="Course ID"),
    req: QueryRequest = None,
) -> List[DocumentResponse]:
    """Query the RAG system for documents similar to ``req.query`` in a course.

    Args:
        course_id: Course identifier.
        req: Request body.
            - ``req.query``: search query string.
            - ``req.top_k``: maximum number of results; defaults to
              ``settings.EMBEDDING_TOP_K``.
            - ``req.distance_threshold``: optional distance cutoff; defaults to
              ``settings.EMBEDDING_DISTANCE_THRESHOLD``.

    Returns:
        List of ``DocumentResponse`` with page_content, metadata and distance score.

    Raises:
        HTTPException:
            - 422 if the request body is missing.
            - 503 if the course vector store is not available.
            - 500 if the search fails.
    """
    # ``req`` defaults to None (kept for signature compatibility), which makes the
    # body optional to FastAPI. Reject a missing body explicitly; otherwise
    # ``req.top_k`` below raises AttributeError and surfaces as an opaque 500.
    if req is None:
        raise HTTPException(status_code=422, detail="Request body is required")

    query_service, _ = _get_or_init_services(course_id)

    if query_service is None or query_service.vectorstore is None:
        raise HTTPException(
            status_code=503,
            detail=f"Vector store not initialized for course {course_id}. Run ingestion first.",
        )

    try:
        # Use defaults from config if not provided in request
        top_k = req.top_k if req.top_k is not None else settings.EMBEDDING_TOP_K
        distance_threshold = (
            req.distance_threshold
            if req.distance_threshold is not None
            else settings.EMBEDDING_DISTANCE_THRESHOLD
        )

        results = query_service.search(
            query=req.query,
            top_k=top_k,
            distance_threshold=distance_threshold,
        )
        return [DocumentResponse(**r) for r in results]
    except Exception as e:
        logger.exception("Query failed for course %s: %s", course_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e


# ===========================
# LLM Chat Endpoint
# ===========================


@app.post("/courses/{course_id}/llm/chat", response_model=ChatResponse)
async def chat(
    course_id: str = Path(..., description="Course ID"),
    req: ChatRequest = None,
) -> ChatResponse:
    """Answer a question about course content using RAG + LLM.

    Process:
    1. Retrieve the ``req.top_k`` most relevant chunks for ``req.query``.
    2. Pass those chunks, the query and optional ``req.guidelines`` to the LLM.
    3. Return the synthesized answer with its source references.

    Raises:
        HTTPException:
            - 422 if the request body is missing.
            - 503 if the LLM service or the course vector store is not available.
            - 404 if retrieval finds no documents.
            - 500 for any other failure.
    """
    # ``req`` defaults to None (kept for signature compatibility), which makes the
    # body optional to FastAPI. Reject a missing body explicitly instead of
    # failing with an AttributeError that would be reported as a 500.
    if req is None:
        raise HTTPException(status_code=422, detail="Request body is required")

    if llm_service is None:
        raise HTTPException(
            status_code=503,
            detail="LLM service not initialized. Set OPENAI_API_KEY environment variable.",
        )

    query_service, _ = _get_or_init_services(course_id)

    if query_service is None or query_service.vectorstore is None:
        raise HTTPException(
            status_code=503,
            detail=f"Vector store not initialized for course {course_id}. Run ingestion first.",
        )

    try:
        # Step 1: Retrieve relevant documents from RAG
        rag_results = query_service.search(query=req.query, top_k=req.top_k)

        if not rag_results:
            raise HTTPException(
                status_code=404,
                detail=f"No relevant documents found for query in course {course_id}.",
            )

        # Step 2: Synthesize answer using LLM with RAG context
        synthesis_result = llm_service.synthesize(
            query=req.query,
            context_docs=rag_results,
            guidelines=req.guidelines,
        )

        # Step 3: Convert source documents to response format
        sources = [
            ChatSourceDocument(
                content=source.content,
                filename=source.filename,
                chunk_id=source.chunk_id,
            )
            for source in synthesis_result.sources
        ]

        logger.info(
            "Chat endpoint successful for course %s, query: %s",
            course_id,
            req.query[:50],
        )
        return ChatResponse(answer=synthesis_result.answer, sources=sources)

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Chat failed for course %s: %s", course_id, e)
        raise HTTPException(status_code=500, detail=f"Chat failed: {e}") from e


# ===========================
# Asset Management
# ===========================


def ensure_data_dir() -> None:
    """Create ``DATA_DIR``, including missing parents; a no-op when it already exists."""
    os.makedirs(name=DATA_DIR, exist_ok=True)


@app.post("/courses/{course_id}/asset/add", response_model=AssetDetail)
async def asset_add(
    course_id: str = Path(..., description="Course ID"),
    file: UploadFile = File(...),
    overwrite: bool = Query(False),
) -> AssetDetail:
    """Upload a file, extract its text, and chunk/embed/register it for a course.

    Supported types are those accepted by ``FileExtractorService.is_supported``
    (txt, md, pdf, youtube per the 415 message). The *extracted text*, not the
    uploaded bytes, is stored as ``DATA_DIR/<filename>``.

    Ordering matters. The text is chunked and indexed before anything is written
    to disk, so an upload that yields no chunks or fails to index leaves no stray
    file behind (a stray file would make every retry fail with 409).

    When a previously registered version is replaced, its embedding ids that the
    new version does not reuse are deleted. Otherwise stale chunks of the old
    version would keep appearing in search results.

    Args:
        course_id: Course identifier.
        file: Uploaded file; its filename becomes the asset id.
        overwrite: Replace an existing file of the same name instead of
            returning 409.

    Returns:
        The registered ``AssetDetail``.

    Raises:
        HTTPException:
            - 400: missing or unsafe filename, extraction failure, or no chunks.
            - 409: the file exists and ``overwrite`` is false.
            - 415: unsupported file type.
            - 500: indexing failed.
            - 503: vector store not available.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    filename = file.filename
    if not is_safe_filename(filename):
        raise HTTPException(status_code=400, detail="Invalid filename")

    # Check file type support
    if not FileExtractorService.is_supported(filename):
        raise HTTPException(
            status_code=415,
            detail="Unsupported file type. Supported: txt, md, pdf, youtube",
        )

    ensure_data_dir()
    # NOTE(review): DATA_DIR is shared by all courses, so the same filename
    # uploaded to two courses maps to the same file on disk.
    target_path = os.path.join(DATA_DIR, filename)
    file_id = filename

    if os.path.exists(target_path) and not overwrite:
        raise HTTPException(
            status_code=409, detail="File exists. Use overwrite=true to replace."
        )

    # Read and extract text from file
    try:
        content = await file.read()
        text = FileExtractorService.extract_text(filename, content)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to extract text from %s: %s", filename, e)
        raise HTTPException(status_code=400, detail=f"Failed to extract text: {str(e)}") from e

    # Chunk first: an empty result must not leave a file on disk.
    chunks = chunk_text(text, settings.CHUNK_SIZE, settings.CHUNK_OVERLAP)
    if not chunks:
        raise HTTPException(status_code=400, detail="File is empty after chunking")

    ids: List[str] = [f"{file_id}:{i:06d}" for i in range(len(chunks))]
    metadatas: List[Dict[str, Any]] = [
        {
            "file_id": file_id,
            "source": f"./app/data/{file_id}",
            "chunk": i,
            "filename": filename,
            "course_id": course_id,
        }
        for i in range(len(chunks))
    ]

    # Remember the embedding ids of any previously registered version.
    registry_path = get_registry_path(BASE_REGISTRY_PATH, course_id)
    reg = load_registry(registry_path)
    previous_ids = list((reg.get(file_id) or {}).get("embedding_ids") or [])

    # Add to vectorstore
    try:
        _, vectorstore_service = _get_or_init_services(course_id)
        if vectorstore_service is None:
            raise HTTPException(status_code=503, detail="Vector store not ready")

        # create_or_append will create the index if it doesn't exist
        # or append to existing one
        vectorstore_service.create_or_append(
            texts=chunks, metadatas=metadatas, ids=ids
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to index file for course %s: %s", course_id, e)
        raise HTTPException(status_code=500, detail=f"Indexing failed: {e}") from e

    # Re-used ids were overwritten by the append above. Any old ids beyond the
    # new chunk count would otherwise linger in the index. Best effort only.
    new_ids = set(ids)
    stale_ids = [old_id for old_id in previous_ids if old_id not in new_ids]
    if stale_ids:
        vectorstore = getattr(vectorstore_service, "_vectorstore", None)
        if vectorstore is not None:
            try:
                vectorstore.delete(ids=stale_ids)
            except Exception as e:
                logger.warning(
                    "Stale embedding cleanup failed for %s in course %s: %s",
                    file_id, course_id, e,
                )

    # Only persist the text once it has been indexed successfully.
    with open(target_path, "w", encoding="utf-8") as f:
        f.write(text)

    # Update registry
    stat = os.stat(target_path)
    modified_at = datetime.fromtimestamp(stat.st_mtime)
    reg[file_id] = {
        "id": file_id,
        "name": filename,
        "size_bytes": stat.st_size,
        "modified_at": modified_at.isoformat(),
        "path": target_path,
        "chunk_count": len(chunks),
        "chunk_size": settings.CHUNK_SIZE,
        "chunk_overlap": settings.CHUNK_OVERLAP,
        "embedding_ids": ids,
        "course_id": course_id,
    }
    save_registry(reg, registry_path)

    return AssetDetail(
        id=file_id,
        name=filename,
        size_bytes=stat.st_size,
        modified_at=modified_at,
        path=target_path,
        chunk_count=len(chunks),
        chunk_size=settings.CHUNK_SIZE,
        chunk_overlap=settings.CHUNK_OVERLAP,
        embedding_ids=ids,
    )


@app.post("/courses/{course_id}/asset/add-youtube", response_model=AssetDetail)
async def asset_add_youtube(
    course_id: str = Path(..., description="Course ID"),
    req: YouTubeImportRequest = None,
) -> AssetDetail:
    """Import a YouTube video's transcript and index it for a course.

    The transcript comes from ``FileExtractorService``: the URL is passed as the
    content of a synthetic ``.youtube`` file. The transcript is stored as
    ``DATA_DIR/<name>.youtube`` and chunked/embedded like an uploaded file.

    ``<name>`` is ``req.title`` when given; otherwise it is ``youtube_`` plus the
    first 8 hex digits of the URL's MD5. An existing asset with the same name is
    replaced. Any of its old embedding ids not reused by the new transcript are
    deleted.

    Raises:
        HTTPException:
            - 400: missing URL or body, unsafe title, extraction failure, or
              empty transcript.
            - 500: indexing failed.
            - 503: vector store not available.
    """
    # A missing body (req is None) is reported like a missing URL instead of
    # crashing with an AttributeError.
    if req is None or not req.url:
        raise HTTPException(status_code=400, detail="URL is required")

    # Derive and validate the target name before the potentially slow transcript fetch.
    if req.title:
        filename = f"{req.title}.youtube"
    else:
        import hashlib

        # Not a security use: only a short, stable name derived from the URL.
        url_hash = hashlib.md5(req.url.encode()).hexdigest()[:8]
        filename = f"youtube_{url_hash}.youtube"

    if not is_safe_filename(filename):
        raise HTTPException(status_code=400, detail="Invalid title for YouTube video")

    # Extract transcript from YouTube
    try:
        text = FileExtractorService.extract_text("video.youtube", req.url.encode("utf-8"))
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to extract YouTube transcript from %s: %s", req.url, e)
        raise HTTPException(status_code=400, detail=f"Failed to extract transcript: {e}") from e

    chunks = chunk_text(text, settings.CHUNK_SIZE, settings.CHUNK_OVERLAP)
    if not chunks:
        raise HTTPException(status_code=400, detail="YouTube video has no transcript content")

    file_id = filename
    ids: List[str] = [f"{file_id}:{i:06d}" for i in range(len(chunks))]
    metadatas: List[Dict[str, Any]] = [
        {
            "file_id": file_id,
            "source": req.url,
            "chunk": i,
            # Same keys the regular upload endpoint stores with each chunk.
            "filename": filename,
            "course_id": course_id,
        }
        for i in range(len(chunks))
    ]

    # Remember the embedding ids of any previously registered version.
    registry_path = get_registry_path(BASE_REGISTRY_PATH, course_id)
    reg = load_registry(registry_path)
    previous_ids = list((reg.get(file_id) or {}).get("embedding_ids") or [])

    # Add to vectorstore. create_or_append creates the index when it does not
    # exist yet, so a video may be the first asset of a course.
    try:
        _, vectorstore_service = _get_or_init_services(course_id)
        if vectorstore_service is None:
            raise HTTPException(
                status_code=503, detail="Vector store not initialized for course"
            )
        vectorstore_service.create_or_append(texts=chunks, metadatas=metadatas, ids=ids)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to index YouTube video for course %s: %s", course_id, e)
        raise HTTPException(status_code=500, detail=f"Indexing failed: {e}") from e

    # Drop old embeddings the re-import did not overwrite (best effort).
    new_ids = set(ids)
    stale_ids = [old_id for old_id in previous_ids if old_id not in new_ids]
    if stale_ids:
        vectorstore = getattr(vectorstore_service, "_vectorstore", None)
        if vectorstore is not None:
            try:
                vectorstore.delete(ids=stale_ids)
            except Exception as e:
                logger.warning(
                    "Stale embedding cleanup failed for %s in course %s: %s",
                    file_id, course_id, e,
                )

    # Persist the transcript only after it has been indexed.
    ensure_data_dir()
    target_path = os.path.join(DATA_DIR, filename)
    with open(target_path, "w", encoding="utf-8") as f:
        f.write(text)

    # Update registry
    stat = os.stat(target_path)
    modified_at = datetime.fromtimestamp(stat.st_mtime)
    reg[file_id] = {
        "id": file_id,
        "name": filename,
        "size_bytes": stat.st_size,
        "modified_at": modified_at.isoformat(),
        "path": target_path,
        "chunk_count": len(chunks),
        "chunk_size": settings.CHUNK_SIZE,
        "chunk_overlap": settings.CHUNK_OVERLAP,
        "embedding_ids": ids,
        "course_id": course_id,
        "type": "youtube",
        "youtube_url": req.url,
    }
    save_registry(reg, registry_path)

    return AssetDetail(
        id=file_id,
        name=filename,
        size_bytes=stat.st_size,
        modified_at=modified_at,
        path=target_path,
        chunk_count=len(chunks),
        chunk_size=settings.CHUNK_SIZE,
        chunk_overlap=settings.CHUNK_OVERLAP,
        embedding_ids=ids,
    )


def _asset_referenced_by_other_course(asset_id: str, course_id: str) -> bool:
    """Return True if a course other than ``course_id`` still registers ``asset_id``.

    Every course stores its files in the shared ``DATA_DIR`` under the bare asset
    name. Deleting the file on behalf of one course would therefore break any
    other course that registered the same name.
    """
    for other_id in get_all_courses(BASE_REGISTRY_PATH):
        if other_id == course_id:
            continue
        if asset_id in load_registry(get_registry_path(BASE_REGISTRY_PATH, other_id)):
            return True
    return False


@app.delete("/courses/{course_id}/asset/remove/{asset_id}")
async def asset_remove(
    course_id: str = Path(..., description="Course ID"),
    asset_id: str = Path(...),
) -> Dict[str, Any]:
    """Remove an asset's embeddings, stored file and registry entry from a course.

    Embedding removal is best effort: failures are logged, not raised. The file
    on disk is kept while another course's registry still lists the same asset
    id, because files are shared across courses. In that case ``file_removed``
    is False.

    Returns:
        ``{"id", "file_removed", "removed_embeddings"}``, where
        ``removed_embeddings`` is the number of embedding ids submitted for
        deletion.

    Raises:
        HTTPException: 400 for an unsafe asset id, 500 if the file could not be
            deleted.
    """
    if not is_safe_filename(asset_id):
        raise HTTPException(status_code=400, detail="Invalid asset id")

    path = os.path.join(DATA_DIR, asset_id)
    # Load the registry once; it is updated at the end. No awaits happen in
    # between, so no other request can interleave.
    registry_path = get_registry_path(BASE_REGISTRY_PATH, course_id)
    reg = load_registry(registry_path)

    # Remove embeddings (best effort)
    removed_embeddings = 0
    embedding_ids = list((reg.get(asset_id) or {}).get("embedding_ids") or [])
    if embedding_ids:
        try:
            _, vectorstore_service = _get_or_init_services(course_id)
            if vectorstore_service is not None and vectorstore_service._vectorstore:
                vectorstore_service._vectorstore.delete(ids=embedding_ids)
                removed_embeddings = len(embedding_ids)
        except Exception as e:
            logger.warning("Embedding deletion failed for course %s: %s", course_id, e)

    # Remove file, unless another course still relies on it.
    removed_file = False
    if os.path.exists(path):
        try:
            shared = _asset_referenced_by_other_course(asset_id, course_id)
        except Exception as e:
            # Fall back to the previous behaviour (delete) if the check itself fails.
            logger.warning("Could not check other courses for %s: %s", asset_id, e)
            shared = False

        if shared:
            logger.info("Keeping %s on disk: still registered by another course", asset_id)
        else:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # Removed concurrently; the outcome is the same.
            except OSError as e:
                raise HTTPException(status_code=500, detail=f"Failed to remove file: {e}") from e
            removed_file = not os.path.exists(path)

    # Update registry
    if asset_id in reg:
        del reg[asset_id]
        save_registry(reg, registry_path)

    return {
        "id": asset_id,
        "file_removed": removed_file,
        "removed_embeddings": removed_embeddings,
    }


@app.get("/courses/{course_id}/asset/get_all", response_model=AssetListResponse)
async def asset_get_all(
    course_id: str = Path(..., description="Course ID"),
    page: int = Query(1, ge=1),
    per_page: int = Query(20, ge=1, le=200),
) -> AssetListResponse:
    """Return one page of the assets registered for a course, ordered by id.

    ``total`` counts every registry entry tagged with this course. ``items``
    skips entries whose file is missing from ``DATA_DIR``, so a page may hold
    fewer than ``per_page`` items.
    """
    ensure_data_dir()
    registry = load_registry(get_registry_path(BASE_REGISTRY_PATH, course_id))

    names = sorted(
        asset_name
        for asset_name, entry in registry.items()
        if entry.get("course_id") == course_id
    )
    offset = (page - 1) * per_page
    page_names = names[offset:offset + per_page]

    summaries: List[AssetSummary] = []
    for asset_name in page_names:
        file_path = os.path.join(DATA_DIR, asset_name)
        if not os.path.exists(file_path):
            continue  # registered but missing on disk
        file_stat = os.stat(file_path)

        n_chunks = registry[asset_name].get("chunk_count")
        if n_chunks is None:
            # Entry lacks chunk_count: recompute it from the stored text.
            try:
                n_chunks = len(
                    chunk_text(read_text_file(file_path), settings.CHUNK_SIZE, settings.CHUNK_OVERLAP)
                )
            except HTTPException:
                n_chunks = 0

        summaries.append(
            AssetSummary(
                id=asset_name,
                name=asset_name,
                size_bytes=file_stat.st_size,
                modified_at=datetime.fromtimestamp(file_stat.st_mtime),
                chunk_count=n_chunks,
            )
        )

    return AssetListResponse(page=page, per_page=per_page, total=len(names), items=summaries)


@app.get("/courses/{course_id}/asset/get/{asset_id}", response_model=AssetDetail)
async def asset_get(
    course_id: str = Path(..., description="Course ID"),
    asset_id: str = Path(...),
) -> AssetDetail:
    """Return the details of one stored file for a course.

    If the course registry has an entry for ``asset_id`` tagged with this
    ``course_id``, the recorded values are returned. Otherwise the details are
    rebuilt from the file on disk: the text is re-chunked with the current
    settings, and the embedding ids are derived from the chunk count.
    """
    if not is_safe_filename(asset_id):
        raise HTTPException(status_code=400, detail="Invalid asset id")

    file_path = os.path.join(DATA_DIR, asset_id)
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    registry = load_registry(get_registry_path(BASE_REGISTRY_PATH, course_id))
    file_stat = os.stat(file_path)

    is_registered = asset_id in registry and registry[asset_id].get("course_id") == course_id
    if not is_registered:
        # No registry record for this course: reconstruct the details from disk.
        n_chunks = len(
            chunk_text(read_text_file(file_path), settings.CHUNK_SIZE, settings.CHUNK_OVERLAP)
        )
        return AssetDetail(
            id=asset_id,
            name=asset_id,
            size_bytes=file_stat.st_size,
            modified_at=datetime.fromtimestamp(file_stat.st_mtime),
            path=file_path,
            chunk_count=n_chunks,
            chunk_size=settings.CHUNK_SIZE,
            chunk_overlap=settings.CHUNK_OVERLAP,
            embedding_ids=[f"{asset_id}:{idx:06d}" for idx in range(n_chunks)],
        )

    record = registry[asset_id]
    return AssetDetail(
        id=record["id"],
        name=record["name"],
        size_bytes=record["size_bytes"],
        modified_at=datetime.fromisoformat(record["modified_at"]),
        path=record["path"],
        chunk_count=record["chunk_count"],
        chunk_size=record["chunk_size"],
        chunk_overlap=record["chunk_overlap"],
        embedding_ids=record.get("embedding_ids", []),
    )


@app.get("/courses/{course_id}/asset/download/{asset_id}")
async def asset_download(
    course_id: str = Path(..., description="Course ID"),
    asset_id: str = Path(...),
) -> FileResponse:
    """Send a course's stored file back to the client.

    Answers 404 unless the asset is registered to this course and its file
    still exists under ``DATA_DIR``.
    """
    if not is_safe_filename(asset_id):
        raise HTTPException(status_code=400, detail="Invalid asset id")

    registry = load_registry(get_registry_path(BASE_REGISTRY_PATH, course_id))
    owned = asset_id in registry and registry[asset_id].get("course_id") == course_id

    file_path = os.path.join(DATA_DIR, asset_id)
    # Short-circuit: the filesystem is only consulted for assets this course owns.
    if not owned or not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(file_path, filename=asset_id)


# ===========================
# Debug Endpoints
# ===========================


def _first_doc_count(info: Optional[Dict[str, Any]]) -> Optional[int]:
    """Extract a document count from an FT.INFO dict.

    Uses the first of ``num_docs`` / ``numDocs`` / ``num_docs_cached`` that is
    present. The value is coerced to ``int`` directly, or via ``float`` for
    values such as ``"3.0"``. Returns ``None`` when no key is present or the
    value cannot be converted.
    """
    if not info:
        return None
    present = [key for key in ("num_docs", "numDocs", "num_docs_cached") if key in info]
    if not present:
        return None
    value = info[present[0]]
    try:
        return int(value)
    except Exception:
        pass
    try:
        return int(float(value))
    except Exception:
        return None


@app.get("/courses/{course_id}/debug/index_info", response_model=IndexInfoResponse)
async def debug_index_info(course_id: str = Path(..., description="Course ID")) -> IndexInfoResponse:
    """Report diagnostic information about a course's Redis search index.

    Combines three probes:
    - whether the course vector store connects;
    - the ``FT.INFO`` reply for the index;
    - the total hit count of a match-all ``FT.SEARCH``.

    The two Redis probes swallow their errors. FT.INFO's error text is returned
    in ``error``, and a failed search leaves ``num_docs_search`` as ``None``.
    """
    _, vs_service = _get_or_init_services(course_id)
    connected = bool(vs_service and vs_service.get_or_connect() is not None)

    index_name = get_index_name(course_id)
    info, err = _fetch_ft_info(index_name)

    search_count: Optional[int] = None
    try:
        client = redis.Redis.from_url(settings.REDIS_URL, decode_responses=True)
        reply = client.execute_command(
            "FT.SEARCH", index_name, "*", "RETURN", "0", "LIMIT", "0", "0"
        )
        if isinstance(reply, list) and reply:
            # RediSearch puts the total match count first in the reply.
            search_count = int(reply[0])
    except Exception:
        search_count = None

    return IndexInfoResponse(
        index_name=index_name,
        redis_url=mask_redis_url(settings.REDIS_URL),
        vectorstore_connected=connected,
        exists=info is not None and err is None,
        num_docs=_first_doc_count(info),
        num_docs_search=search_count,
        raw_info=info or None,
        error=err,
    )


def _fetch_ft_info(index_name: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Run ``FT.INFO`` for ``index_name`` and return ``(info, error)``.

    On success ``info`` is a dict and ``error`` is ``None``. List replies are
    converted with ``VectorstoreService._list_to_dict``, and any other non-dict
    reply is wrapped as ``{"raw": reply}``.

    On any failure ``info`` is ``None`` and ``error`` holds the exception text.
    """
    from contextlib import closing

    try:
        # Close the throwaway client explicitly rather than relying on GC.
        with closing(redis.Redis.from_url(settings.REDIS_URL, decode_responses=True)) as client:
            info: Any = client.execute_command("FT.INFO", index_name)
        # VectorstoreService is already imported at module level; no re-import needed.
        if isinstance(info, list):
            info = VectorstoreService._list_to_dict(info)
        if not isinstance(info, dict):
            info = {"raw": info}
        return info, None
    except Exception as e:
        return None, str(e)
