# tests/test_llm_service.py
"""Unit tests for LLM service."""
import pytest
from unittest.mock import MagicMock, patch
from langchain_core.messages import AIMessage

from app.services.llm_service import LLMService, SynthesisResult, SourceDocument


class TestLLMServiceInitialization:
    """Tests covering LLMService constructor behavior."""

    def test_init_with_valid_api_key(self):
        """A valid API key yields the default model settings and builds one client."""
        with patch('app.services.llm_service.ChatOpenAI') as chat_cls:
            svc = LLMService(api_key="test-key-12345")

        # Defaults documented by the service constructor.
        assert svc.api_key == "test-key-12345"
        assert svc.model == "gpt-4o-mini"
        assert svc.temperature == 0.7
        assert svc.max_tokens == 1024
        chat_cls.assert_called_once()

    def test_init_without_api_key_raises_error(self):
        """An empty API key is rejected with a descriptive ValueError."""
        with pytest.raises(ValueError, match="OPENAI_API_KEY must be provided"):
            LLMService(api_key="")

    def test_init_with_custom_parameters(self):
        """Explicit model/temperature/max_tokens override the defaults."""
        overrides = {"model": "gpt-4", "temperature": 0.5, "max_tokens": 2048}
        with patch('app.services.llm_service.ChatOpenAI'):
            svc = LLMService(api_key="test-key", **overrides)

        for attr, expected in overrides.items():
            assert getattr(svc, attr) == expected


class TestLLMServicePromptBuilding:
    """Tests for the internal _build_prompt helper."""

    def test_build_prompt_basic(self, mock_llm_service, sample_rag_results):
        """Prompt is a [system, human] pair that embeds the query."""
        query = "What is machine learning?"
        system_msg, user_msg = mock_llm_service._build_prompt(query, sample_rag_results)

        assert system_msg.type == "system"
        assert "helpful educational assistant" in system_msg.content
        assert user_msg.type == "human"
        assert query in user_msg.content

    def test_build_prompt_with_custom_guidelines(self, mock_llm_service, sample_rag_results):
        """Caller-supplied guidelines land in the system message verbatim."""
        query = "Explain neural networks"
        custom_guidelines = "Be very technical and use mathematical notation"
        system_msg, user_msg = mock_llm_service._build_prompt(
            query, sample_rag_results, custom_guidelines
        )

        assert custom_guidelines in system_msg.content
        assert query in user_msg.content

    def test_build_prompt_includes_context(self, mock_llm_service, sample_rag_results):
        """Every retrieved chunk's text and source name appears in the user prompt."""
        query = "What is AI?"
        _, user_msg = mock_llm_service._build_prompt(query, sample_rag_results)

        for doc in sample_rag_results:
            assert doc["page_content"] in user_msg.content
            assert doc["metadata"]["source"] in user_msg.content

    def test_build_prompt_empty_context(self, mock_llm_service):
        """With no retrieved documents the prompt still forms and carries the query."""
        query = "What is AI?"
        messages = mock_llm_service._build_prompt(query, [])

        assert len(messages) == 2
        assert query in messages[1].content


class TestLLMServiceSynthesis:
    """Tests for synthesize() and the Pydantic-style result models."""

    def test_synthesize_basic(self, mock_llm_service, sample_rag_results):
        """Synthesis yields a populated SynthesisResult with one source per chunk."""
        outcome = mock_llm_service.synthesize("What is machine learning?", sample_rag_results)

        assert isinstance(outcome, SynthesisResult)
        assert outcome.answer
        assert len(outcome.sources) == 3

    def test_synthesize_sources_match_input(self, mock_llm_service, sample_rag_results):
        """Sources preserve the order, content, and filename of the input docs."""
        outcome = mock_llm_service.synthesize("What is machine learning?", sample_rag_results)

        assert len(outcome.sources) == len(sample_rag_results)
        for source, doc in zip(outcome.sources, sample_rag_results):
            assert source.content == doc["page_content"]
            assert source.filename == doc["metadata"]["source"]

    def test_synthesize_with_custom_guidelines(self, mock_llm_service, sample_rag_results):
        """Passing guidelines still produces a valid, non-empty answer."""
        outcome = mock_llm_service.synthesize(
            "What is deep learning?",
            sample_rag_results,
            "Use simple language for beginners",
        )

        assert isinstance(outcome, SynthesisResult)
        assert outcome.answer

    def test_synthesize_calls_llm(self, mock_llm_service, sample_rag_results):
        """synthesize() invokes the underlying LLM exactly once and relays its reply."""
        stubbed_invoke = MagicMock(return_value=AIMessage(content="Test answer"))
        mock_llm_service._llm.invoke = stubbed_invoke

        outcome = mock_llm_service.synthesize("Test query?", sample_rag_results)

        stubbed_invoke.assert_called_once()
        assert outcome.answer == "Test answer"

    def test_synthesize_llm_failure_raises_exception(self, mock_llm_service, sample_rag_results):
        """LLM errors propagate to the caller rather than being swallowed."""
        mock_llm_service._llm.invoke = MagicMock(side_effect=Exception("API Error"))

        with pytest.raises(Exception, match="API Error"):
            mock_llm_service.synthesize("Test query?", sample_rag_results)

    def test_synthesize_with_empty_context(self, mock_llm_service):
        """No context docs still yields an answer, just with zero sources."""
        outcome = mock_llm_service.synthesize("What is AI?", [])

        assert isinstance(outcome, SynthesisResult)
        assert outcome.answer
        assert not outcome.sources

    def test_synthesize_source_document_model(self):
        """SourceDocument stores content, filename, and chunk_id as given."""
        fields = {
            "content": "Test content",
            "filename": "test.txt",
            "chunk_id": "chunk_001",
        }
        doc = SourceDocument(**fields)

        for name, value in fields.items():
            assert getattr(doc, name) == value

    def test_synthesis_result_model(self, sample_rag_results):
        """SynthesisResult holds the answer and a source list built from the fixture."""
        sources = [
            SourceDocument(
                content=doc["page_content"],
                filename=doc["metadata"]["source"],
            )
            for doc in sample_rag_results
        ]

        outcome = SynthesisResult(answer="Test answer", sources=sources)

        assert outcome.answer == "Test answer"
        assert len(outcome.sources) == 3


class TestLLMServiceGetLLM:
    """Tests for the get_llm accessor."""

    def test_get_llm_returns_chat_openai(self, mock_llm_service):
        """get_llm() hands back the service's own ChatOpenAI instance, not a copy."""
        returned = mock_llm_service.get_llm()

        assert returned is not None
        assert returned == mock_llm_service._llm
