# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
ByteHouse 多模态检索客户端
支持向量检索、混合检索、以文搜图、以图搜图等功能
"""

import os
import json
import asyncio
from typing import List, Dict, Any

from .embedding import MultimodalEmbedding


class ByteHouseMultimodalSearch:
    """ByteHouse多模态检索客户端"""
    
    def __init__(self, 
                 connection_type: str = "http",
                 secure: bool = True,
                 compress: str = "zstd",
                 connect_timeout: int = 300,
                 send_receive_timeout: int = 1000,
                 prefer_mcp: bool = True):
        """
        初始化ByteHouse多模态检索客户端
        
        Args:
            connection_type: 连接方式，可选 http/tcp
            secure: 是否启用加密连接
            compress: 压缩方式，可选 zstd/lz4/False
            connect_timeout: 连接超时时间，单位秒
            send_receive_timeout: 请求超时时间，单位秒
            prefer_mcp: 是否优先使用ByteHouse MCP Skill
        """
        self.connection_type = connection_type
        self.dimensions = int(os.environ.get("EMBEDDING_DIMENSIONS", 1536))
        self.embedding = MultimodalEmbedding()
        self.use_mcp = False
        self.mcp_client = None
        
        # 优先尝试使用MCP连接
        if prefer_mcp:
            try:
                from mcp_client import ByteHouseMCPClient
                
                async def test_mcp_connection():
                    async with ByteHouseMCPClient() as client:
                        await client.connect()
                        return client
                
                self.mcp_client = asyncio.run(test_mcp_connection())
                self.use_mcp = True
            except Exception:
                self.use_mcp = False
        
        # MCP不可用时使用原生驱动连接
        if not self.use_mcp:
            if connection_type == "http":
                import clickhouse_connect
                self.client = clickhouse_connect.get_client(
                    host=os.environ.get("BYTEHOUSE_HOST"),
                    port=int(os.environ.get("BYTEHOUSE_PORT", 8123)),
                    username=os.environ.get("BYTEHOUSE_USER"),
                    password=os.environ.get("BYTEHOUSE_PASSWORD"),
                    database=os.environ.get("BYTEHOUSE_DATABASE", "default"),
                    secure=secure,
                    compress=compress,
                    send_receive_timeout=send_receive_timeout
                )
            elif connection_type == "tcp":
                from clickhouse_driver import Client
                self.client = Client(
                    host=os.environ.get("BYTEHOUSE_HOST"),
                    port=int(os.environ.get("BYTEHOUSE_PORT", 9000)),
                    user=os.environ.get("BYTEHOUSE_USER"),
                    password=os.environ.get("BYTEHOUSE_PASSWORD"),
                    database=os.environ.get("BYTEHOUSE_DATABASE", "default"),
                    connect_timeout=connect_timeout,
                    send_receive_timeout=send_receive_timeout,
                    compression=compress if compress else False,
                    secure=secure,
                    client_revision=54430
                )
            else:
                raise ValueError(f"不支持的连接类型: {connection_type}")
    
    def _execute_sql(self, sql: str, query_type: str = "select"):
        """内部通用SQL执行方法，自动适配MCP和原生驱动"""
        try:
            if self.use_mcp:
                tool_name = "run_select_query" if query_type == "select" else "run_dml_ddl_query"
                
                async def run_mcp_query():
                    return await self.mcp_client.call_tool(tool_name, {"query": sql})
                
                result = asyncio.run(run_mcp_query())
                if result and len(result) > 0:
                    try:
                        return [list(item.values()) for item in json.loads(result[0])]
                    except:
                        return [line.split('\t') for line in result[0].strip().split('\n')]
                return []
            else:
                if query_type == "select":
                    result = self.client.query(sql)
                    return result.result_rows if hasattr(result, 'result_rows') else result
                else:
                    return self.client.command(sql)
        
        except Exception as e:
            error_msg = str(e).lower()
            if "connection" in error_msg or "timeout" in error_msg:
                raise ConnectionError(f"数据库连接异常：{e}")
            elif "syntax" in error_msg or "parse" in error_msg:
                raise ValueError(f"SQL语法错误：{e}")
            elif "permission" in error_msg or "auth" in error_msg:
                raise PermissionError(f"权限不足：{e}")
            else:
                raise Exception(f"数据库操作失败：{e}")
    
    def create_multimodal_table(self, 
                               table_name: str,
                               enable_text_search: bool = True,
                               index_type: str = "HNSW",
                               metric: str = "COSINE",
                               hnsw_m: int = 32,
                               hnsw_ef_construction: int = 512):
        """
        创建多模态检索表
        
        Args:
            table_name: 表名
            enable_text_search: 是否开启全文检索
            index_type: 索引类型，可选 HNSW/HNSW_SQ/IVF_FLAT/IVF_PQ/IVF_PQ_FS
            metric: 距离度量，可选 COSINE/L2
            hnsw_m: HNSW 每个节点最大连接数
            hnsw_ef_construction: HNSW 构建时探索因子
        """
        if index_type in ["HNSW", "HNSW_SQ"]:
            index_config = f"TYPE {index_type}('DIM={self.dimensions}, METRIC={metric}, M={hnsw_m}, EF_CONSTRUCTION={hnsw_ef_construction}')"
        else:
            index_config = f"TYPE {index_type}('dim={self.dimensions}', 'metric={metric}')"
        
        text_index = f"INDEX text_idx (title, content) TYPE inverted" if enable_text_search else ""
        
        create_sql = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            id UInt64 COMMENT '唯一ID',
            content_type Enum('text' = 1, 'image' = 2, 'video' = 3) COMMENT '内容类型',
            content String COMMENT '原始内容或URL',
            title String COMMENT '标题/描述',
            embedding Array(Float32) COMMENT '向量',
            CONSTRAINT cons_vec_len CHECK length(embedding) = {self.dimensions},
            metadata Map(String, String) COMMENT '元数据',
            create_time DateTime DEFAULT now() COMMENT '创建时间',
            INDEX vec_idx embedding {index_config},
            {text_index}
        ) ENGINE = MergeTree
        ORDER BY id
        SETTINGS 
            index_granularity = 1024,
            index_granularity_bytes = 0,
            enable_vector_index_preload = 1
        """
        
        self._execute_sql(create_sql, query_type="ddl")
    
    def insert_document(self,
                       table_name: str,
                       doc_id: int,
                       content_type: str,
                       content: str,
                       title: str = "",
                       metadata: Dict = None,
                       embedding: List[float] = None,
                       instruction: str = None) -> bool:
        """插入单条文档"""
        if embedding is None:
            if content_type == "text":
                embedding = self.embedding.encode_text(content, instruction)
            elif content_type == "image":
                embedding = self.embedding.encode_image(content, instruction)
            elif content_type == "video":
                embedding = self.embedding.encode_video(content, instruction)
            else:
                raise ValueError(f"不支持的内容类型: {content_type}")
        
        if len(embedding) != self.dimensions:
            raise ValueError(f"向量维度错误：期望{self.dimensions}维，实际{len(embedding)}维")
        
        metadata_str = json.dumps(metadata).replace("'", "''") if metadata else "{}"
        
        insert_sql = f"""
        INSERT INTO {table_name}
        (id, content_type, content, title, embedding, metadata)
        VALUES
        ({doc_id}, '{content_type}', '{content.replace("'", "''")}', 
         '{title.replace("'", "''")}', {embedding}, '{metadata_str}')
        """
        
        self._execute_sql(insert_sql, query_type="dml")
        return True
    
    def insert_batch_documents(self,
                              table_name: str,
                              documents: List[Dict],
                              instruction: str = None,
                              skip_error: bool = True) -> Dict:
        """批量插入文档"""
        rows = []
        failed = []
        
        for idx, doc in enumerate(documents):
            try:
                required_fields = ["doc_id", "content_type", "content"]
                for field in required_fields:
                    if field not in doc:
                        raise ValueError(f"缺少必填字段: {field}")
                
                if doc.get('embedding'):
                    embedding = doc['embedding']
                    if len(embedding) != self.dimensions:
                        raise ValueError(f"向量维度错误")
                else:
                    if doc['content_type'] == "text":
                        embedding = self.embedding.encode_text(doc['content'], instruction)
                    elif doc['content_type'] == "image":
                        embedding = self.embedding.encode_image(doc['content'], instruction)
                    elif doc['content_type'] == "video":
                        embedding = self.embedding.encode_video(doc['content'], instruction)
                    else:
                        raise ValueError(f"不支持的内容类型: {doc['content_type']}")
                
                metadata = doc.get('metadata', {})
                metadata_str = json.dumps(metadata).replace("'", "''")
                
                rows.append([
                    doc['doc_id'],
                    doc['content_type'],
                    doc['content'].replace("'", "''"),
                    doc.get('title', '').replace("'", "''"),
                    embedding,
                    metadata_str
                ])
            
            except Exception as e:
                failed.append({"doc_id": doc.get("doc_id", idx), "error": str(e)})
                if not skip_error:
                    raise
        
        if rows:
            try:
                if self.use_mcp:
                    values_str = [f"({row[0]}, '{row[1]}', '{row[2]}', '{row[3]}', {row[4]}, '{row[5]}')" for row in rows]
                    insert_sql = f"INSERT INTO {table_name} VALUES {','.join(values_str)}"
                    self._execute_sql(insert_sql, query_type="dml")
                else:
                    if self.connection_type == "http":
                        self.client.insert(
                            table_name,
                            rows,
                            column_names=['id', 'content_type', 'content', 'title', 'embedding', 'metadata'],
                            column_type_names=['UInt64', 'Enum', 'String', 'String', 'Array(Float32)', 'Map(String, String)']
                        )
                    else:
                        self.client.execute(
                            f'INSERT INTO {table_name} VALUES',
                            rows
                        )
                success_count = len(rows)
            except Exception as e:
                for row in rows:
                    failed.append({"doc_id": row[0], "error": f"批量插入失败: {str(e)}"})
                success_count = 0
        else:
            success_count = 0
        
        return {
            "success_count": success_count,
            "failed_count": len(failed),
            "failed_details": failed
        }
    
    def vector_search(self,
                     table_name: str,
                     query_embedding: List[float],
                     top_k: int = 10,
                     filter_condition: str = None,
                     metric: str = "COSINE",
                     hnsw_ef_s: int = 200) -> List[Dict]:
        """纯向量检索"""
        distance_func = "cosineDistance" if metric == "COSINE" else "L2Distance"
        
        sql = f"""
        SELECT 
            id, content_type, content, title, metadata, create_time,
            {distance_func}(embedding, {query_embedding}) AS score
        FROM {table_name}
        {f"WHERE {filter_condition}" if filter_condition else ""}
        ORDER BY score ASC
        LIMIT {top_k}
        SETTINGS enable_new_ann = 1, hnsw_ef_s = {hnsw_ef_s}
        """
        
        rows = self._execute_sql(sql, query_type="select")
        columns = ['id', 'content_type', 'content', 'title', 'metadata', 'create_time', 'score']
        return [dict(zip(columns, row)) for row in rows]
    
    def hybrid_search(self,
                     table_name: str,
                     query_text: str,
                     query_embedding: List[float] = None,
                     top_k: int = 10,
                     filter_condition: str = None,
                     vector_weight: float = 0.7,
                     text_weight: float = 0.3,
                     metric: str = "COSINE") -> List[Dict]:
        """混合检索：向量检索 + 全文检索"""
        if query_embedding is None:
            query_embedding = self.embedding.encode_text(query_text)
        
        vector_results = self.vector_search(table_name, query_embedding, top_k * 2, filter_condition, metric)
        
        text_search_sql = f"""
        SELECT id, content_type, content, title, metadata, create_time, 1.0 AS text_score
        FROM {table_name}
        WHERE (title LIKE '%{query_text.replace("'", "''")}%' OR content LIKE '%{query_text.replace("'", "''")}%')
        {f'AND {filter_condition}' if filter_condition else ''}
        LIMIT {top_k * 2}
        """
        text_rows = self._execute_sql(text_search_sql, query_type="select")
        text_results = [dict(zip(['id', 'content_type', 'content', 'title', 'metadata', 'create_time', 'text_score'], row)) 
                       for row in text_rows]
        
        # RRF 融合算法
        all_results = {}
        k = 60
        
        for rank, item in enumerate(vector_results):
            doc_id = item['id']
            if doc_id not in all_results:
                all_results[doc_id] = item
            all_results[doc_id]['vector_rank'] = rank
        
        for rank, item in enumerate(text_results):
            doc_id = item['id']
            if doc_id not in all_results:
                all_results[doc_id] = item
            all_results[doc_id]['text_rank'] = rank
        
        for doc_id, item in all_results.items():
            vector_score = 1.0 / (k + item.get('vector_rank', 10000))
            text_score = 1.0 / (k + item.get('text_rank', 10000))
            item['final_score'] = vector_weight * vector_score + text_weight * text_score
        
        sorted_results = sorted(all_results.values(), key=lambda x: x['final_score'], reverse=True)
        return sorted_results[:top_k]
    
    def text_search_image(self, table_name: str, query_text: str, top_k: int = 10, **kwargs) -> List[Dict]:
        """以文搜图"""
        filter_cond = "content_type = 'image'"
        if 'filter_condition' in kwargs:
            filter_cond += f" AND {kwargs.pop('filter_condition')}"
        
        instruction = kwargs.pop('instruction', None)
        query_embedding = self.embedding.encode_text(query_text, instruction)
        return self.vector_search(table_name, query_embedding, top_k, filter_condition=filter_cond, **kwargs)
    
    def image_search_image(self, table_name: str, image_url: str, top_k: int = 10, **kwargs) -> List[Dict]:
        """以图搜图"""
        filter_cond = "content_type = 'image'"
        if 'filter_condition' in kwargs:
            filter_cond += f" AND {kwargs.pop('filter_condition')}"
        
        instruction = kwargs.pop('instruction', None)
        query_embedding = self.embedding.encode_image(image_url, instruction)
        return self.vector_search(table_name, query_embedding, top_k, filter_condition=filter_cond, **kwargs)
    
    def text_search_video(self, table_name: str, query_text: str, top_k: int = 10, **kwargs) -> List[Dict]:
        """以文搜视频"""
        filter_cond = "content_type = 'video'"
        if 'filter_condition' in kwargs:
            filter_cond += f" AND {kwargs.pop('filter_condition')}"
        
        instruction = kwargs.pop('instruction', None)
        query_embedding = self.embedding.encode_text(query_text, instruction)
        return self.vector_search(table_name, query_embedding, top_k, filter_condition=filter_cond, **kwargs)
