import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

class Retriever:
    def __init__(self, embedder, tokenizer, corpus):
        """
        初始化检索器
            embedder: 嵌入模型，用于将文本转换为向量表示
            tokenizer: 对应 embedder 的分词器，负责将文本切分为 token ID
            corpus: 语料库，即所有待检索的文档/句子列表
        """
        self.embedder = embedder  # 存储嵌入模型
        self.tokenizer = tokenizer  # 存储分词器
        self.corpus = corpus  # 存储原始文档集合
        self.embeddings = self.build_index(corpus)  # 构建整个语料库的向量索引（提前编码好）
        # 初始化函数说明：
        # 在创建 Retriever 实例时，传入嵌入模型、分词器和文档库。
        # 调用 build_index() 将所有文档转为向量并保存在 self.embeddings 中，便于后续快速检索。

    def _text_to_vector(self, text_batch):
        """
        将一批文本转换为固定长度的向量表示（句子嵌入）
        Args:
            text_batch (list of str): 文本列表，例如 ["Paris is...", "Tokyo is..."]
        Returns:
            sent_vecs (torch.Tensor): 归一化前的句子向量，形状 [batch_size, hidden_size]
        """
        # [   21   ]
        # enc = self.tokenizer(text_batch, return_tensors="pt",truncation=True,max_length=512)
        # #答案，要正常运行需要加上padding
        enc = self.tokenizer(text_batch, return_tensors="pt", truncation=True, max_length=512,padding=True)
        """"
        参数说明：
        - text_batch: 输入的文本列表，可以是单个文本或多个文本
        - return_tensors="pt": 返回PyTorch张量格式
        - truncation=True: 当文本长度超过max_length时进行截断
        - max_length=512: 设置最大序列长度，超过的部分会被截断
        - padding=True: 对批处理中的不同长度文本进行填充，使它们具有相同的长度（以最长文本为准）
                        这是解决ValueError的关键：确保批处理张量具有相同长度
                        padding=True会自动将短文本用pad_token填充到相同长度
        """
        # 将张量移动到与模型相同的设备上（GPU 或 CPU）
        enc = {k: v.to(self.embedder.device) for k, v in enc.items()}
        # print("enc:\n",enc)
        with torch.no_grad():  # 推理阶段不计算梯度，节省内存和加速
            # [   22   ]
            outputs = self.embedder(**enc)  # 前向传播，获取模型输出 **enc，字典解包
            last_hidden = outputs.last_hidden_state  # 取最后一层隐藏状态，shape: [B, L, D]
            # B=batch_size(批次大小), L=seq_len(序列长度), D=hidden_dim(隐藏维度)
            attn_mask = enc["attention_mask"]  # 注意力掩码，标识哪些是真实 token, shape:[B,L]
            # 注意力掩码，包含0和1, 1表示真实token，0表示padding填充

            # [   23   ]
            sent_vecs = (last_hidden * attn_mask.unsqueeze(-1)).sum(1) / attn_mask.sum(1,keepdim=True)
            # 这行代码实现的是带注意力掩码的平均池化（Masked Mean Pooling），用于将变长的句子序列转换为固定长度的句子向量。
            # attn_mask.unsqueeze(-1)：形状变为 [B, L, 1]
            # last_hidden * attn_mask.unsqueeze(-1) 逐元素相乘，将padding位置的隐藏状态置为0
            # 张量广播机制,当两个张量进行逐元素运算时，如果它们的形状不完全相同，PyTorch会按照广播规则自动扩展较小的张量以匹配较大的张量。
            # .sum(1)：沿着序列维度求和,对每个句子的所有token的隐藏状态求和，由于padding位置已被置0，实际上只对真实token求和
            # attn_mask.sum(1,keepdim=True) 计算每个句子中真实token的数量
            # 最后总和除以真实token数量，得到平均值，得到每个句子的平均向量表示，此为平均池化操作。

        return sent_vecs  # 输出 shape: [batch_size, hidden_size]

    def build_index(self, corpus):
        """
        构建语料库的向量索引：将每条文档编码为向量，并进行 L2 归一化
        Args:
            corpus (list of str): 所有文档组成的列表
        Returns:
            normalized_vecs (torch.Tensor): 归一化后的文档向量，shape [N, D]
        """
        vecs_tensor = self._text_to_vector(corpus)  # 编码全部文档 → [N, D]
        vecs_normalized = F.normalize(vecs_tensor, p=2, dim=1)  # L2 归一化：使向量模长为 1
        return vecs_normalized.cpu()  # 移动到 CPU 保存（避免 GPU 内存占用）
        # 为什么要做 L2 归一化？
        # - 使得余弦相似度等于向量点积：cos_sim(a,b) = a·b
        # - 提高检索效率和稳定性。

    def search(self, query_vec, topk=3):
        """
        在文档库中搜索与查询向量最相似的 top-k 个文档
        Args:
            query_vec (torch.Tensor): 查询向量，shape [D,] 或 [1, D]
            topk (int): 返回前 k 个结果
        Returns:
            results (list of dict): 包含文本和得分的结果列表
        """
        if query_vec.dim() == 1:
            query_vec = query_vec.unsqueeze(0)  # 如果是单维向量，升维成 [1, D]

        query_normalized = F.normalize(query_vec, p=2, dim=1)  # 查询也做 L2 归一化
        similarities = torch.matmul(query_normalized, self.embeddings.T)  # 点积 ≈ 余弦相似度
        similarities = similarities.squeeze(0)  # 去掉 batch 维度 → [num_docs]
        # 相似度计算原理：
        # - 已知 query_normalized 和 self.embeddings 都已归一化，
        # - 则 a · b^T 就是余弦相似度。

        # 获取 top-k 最相似的文档索引和分数
        topk_scores, topk_indices = torch.topk(similarities, k=min(topk, len(self.corpus)))

        # 构造返回结果：包含原文和匹配分数
        results = [
            {"text": self.corpus[idx], "score": float(topk_scores[i])}
            for i, idx in enumerate(topk_indices)
        ]
        return results
        # 示例输出：
        # [
        #   {"text": "Paris is the capital of France.", "score": 0.87},
        #   {"text": "Beijing is the capital of China.", "score": 0.32}
        # ]


def rag_answer(query, retriever, generator, gen_tokenizer, topk=3, max_new_tokens=128):
    """
    完整的 RAG 流程：检索 + 生成
    Args:
        query (str): 用户提出的问题
        retriever (Retriever): 检索器实例
        generator (AutoModelForCausalLM): 生成模型（如 Qwen）
        gen_tokenizer (AutoTokenizer): 生成模型对应的分词器
        topk (int): 检索返回的文档数量
        max_new_tokens (int): 控制生成答案的最大长度
    Returns:
        answer (str): 最终生成的答案
    """
    # 将查询文本转为向量
    query_vec_tensor = retriever._text_to_vector([query])  # 输入必须是 list → ["query"]
    query_vec = query_vec_tensor.squeeze(0).cpu()  # 去掉 batch 维度并移到 CPU
    # 注意：虽然 _text_to_vector 是 Retriever 的私有方法，但这里外部调用了它来获得查询向量。

    # [   24   ]
    docs = retriever.search(query_vec, topk=topk)  # 调用检索器查找 top-k 相关文档
    #查询 Retriever 类中编写的search方法，根据这个方法可以得出答案

    # 合并检索到的文档作为上下文
    context = "\n".join([d["text"] for d in docs]) if docs else ""

    # 构建 prompt，指导模型使用上下文回答问题
    prompt = f"Use the context to answer.\nContext:\n{context}\n\nQ: {query}\nA: "

    # 使用生成模型的 tokenizer 对 prompt 进行编码
    inputs = gen_tokenizer(prompt, return_tensors="pt")
    device = generator.device
    inputs = {k: v.to(device) for k, v in inputs.items()}  # 移动到模型所在设备

    # 设置 pad_token_id（若未设置，则用 eos_token_id 替代）
    pad_id = gen_tokenizer.pad_token_id if gen_tokenizer.pad_token_id is not None else gen_tokenizer.eos_token_id
    # 有些模型没有显式 pad_token，需用 eos_token（句尾符）代替，防止警告。

    with torch.no_grad():  # 推理模式，不记录梯度
        output_ids = generator.generate(
            **inputs,
            do_sample=True,  # 是否采样（而非贪心）
            temperature=0.7,  # 控制随机性：越高越多样
            top_p=0.9,  # 核采样（nucleus sampling），过滤低概率词
            max_new_tokens=max_new_tokens,  # 最多生成多少新 token
            pad_token_id=pad_id,
            eos_token_id=gen_tokenizer.eos_token_id,  # 遇到 EOS 停止生成
        )
    # 生成参数解释：
    # - do_sample=True: 开启随机采样，避免重复输出。
    # - temperature=0.7: 适中温度，平衡创造性和准确性。
    # - top_p=0.9: 只从累计概率最高的 90% 的词汇中采样。

    # [   25   ]
    # 提取新生成的部分（去掉输入 prompt 的部分）
    # 原因:分离输入和输出：模型返回的是完整的序列（输入+生成）,我们只需要新生成的部分
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    # output_ids[0] 选择第一个（也是通常唯一的）样本
    # 因为 output_ids 是批处理格式 [batch_size, ...]
    # 即使只处理一个样本，输出也是 [1, total_length] 形状
    # output_ids[0] 得到形状为 [total_length] 的一维张量

    # inputs["input_ids"].shape[1] 作用是获取prompt的长度
    # inputs["input_ids"] 是一个二维张量，形状为 [batch_size, sequence_length]
    # .shape 返回一个包含两个元素的元组 (batch_size, sequence_length)
    # [1] 表示取元组的第二个元素，即序列长度
    # 这个值告诉我们输入的prompt有多少个token，用于确定从哪里开始提取新生成的内容

    # 示例：
    # - 输入 prompt 有 100 个 token
    # - 总共输出 130 个 token
    # - new_tokens = output_ids[0][100:] → 只保留后 30 个生成的内容

    # print("output_ids:\n",output_ids)
    # print("new_tokens:",new_tokens)

    # 解码为可读文本，跳过特殊符号（如 [PAD], </s> 等）
    answer = gen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return answer


if __name__ == "__main__":
    print("Start!!!")

    # 加载 Embedder（用于检索）
    embedder_name = "../Qwen/Qwen3-Embedding-0.6B"
    retriever_tokenizer = AutoTokenizer.from_pretrained(embedder_name)
    embedder = AutoModel.from_pretrained(embedder_name)
    # 使用的是通义千问的专用嵌入模型：Qwen3-Embedding-0.6B
    # - 专为生成高质量句向量设计
    # - 支持多语言、长文本

    # 加载 Generator（用于回答）
    generator_name = "../Qwen/Qwen3-0.6B"
    gen_tokenizer = AutoTokenizer.from_pretrained(generator_name)
    generator = AutoModelForCausalLM.from_pretrained(generator_name)
    # 使用的是通义千问的小型因果语言模型：Qwen3-0.6B
    # - 可本地运行，适合轻量级 RAG 系统
    # - 支持对话、问答等多种任务

    # 定义语料库
    corpus = [
        "Paris is the capital of France.",
        "Tokyo is the capital of Japan.",
        "Beijing is the capital of China.",
    ]
    # 当前知识库只有三句话，可用于测试基本功能。

    # 创建 Retriever 并执行问答
    retriever = Retriever(embedder, retriever_tokenizer, corpus)
    query = "What is the capital of France?"

    ans = rag_answer(query, retriever, generator, gen_tokenizer, topk=2)
    print("Q:", query)
    print("A:", ans)




