import threading
import time
import sqlite3
import json
import logging
import ssl
import os
import uuid
import urllib.parse
from http.server import HTTPServer, SimpleHTTPRequestHandler
from typing import Dict, Any, Optional, List
from plugins import DriverPlugin
# 导入规则引擎
from core.rule_engine import RuleEngine
from core.database import get_db_manager
from core.log_utils import sanitize_for_log


# Module-level logging setup.
# NOTE(review): basicConfig() configures the *root* logger as an import side
# effect (it is a no-op if the root logger already has handlers).
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- 修改请求处理器 ---
class StaticFileHTTPRequestHandler(SimpleHTTPRequestHandler):
    '''
    Static-file request handler with integrated countermeasure logic.

    http.server creates one handler instance per connection and processes the
    request from inside __init__ (BaseRequestHandler.__init__ calls handle()),
    so every per-environment attribute must be assigned before
    super().__init__() runs.
    '''

    # Class attribute; assigned by the driver when an environment starts.
    rule_engine: Optional[RuleEngine] = None

    def __init__(self, *args, directory=None, environment_id=None, environment_rules=None, **kwargs):
        """
        :param directory: static root to serve (the parent falls back to the
                          current working directory when this is None).
        :param environment_id: ID of the honeypot environment being served.
        :param environment_rules: environment-specific rules; None -> [].
        """
        self.environment_id = environment_id
        # Environment-specific rules; fall back to an empty list.
        self.environment_rules = environment_rules if environment_rules is not None else []
        # Content-injection info for the current request (set in do_GET).
        self._inject_info = None
        # Forward ``directory`` to the parent: SimpleHTTPRequestHandler.__init__
        # assigns self.directory itself (os.getcwd() when directory is None), so
        # a value assigned here without forwarding it would be overwritten and
        # the process CWD would be served instead of the configured static root.
        super().__init__(*args, directory=directory, **kwargs)

    def log_message(self, format, *args):
        """Route http.server's access log through the module logger, sanitized."""
        safe_addr = sanitize_for_log(self.address_string())
        safe_fmt = sanitize_for_log(format % args)
        # Unlike the parent (which writes to stderr), logging appends its own
        # newline, so no trailing "\n" is added here.
        logger.info("%s - - [%s] %s", safe_addr, self.log_date_time_string(), safe_fmt)

    def _build_request_context(self) -> Dict[str, Any]:
        """Build the request context passed to the rule engine."""
        return {
            'protocol': 'https' if isinstance(self.connection, ssl.SSLSocket) else 'http',
            'method': self.command,
            'path': self.path,
            'headers': dict(self.headers),
            'ip': self.client_address[0],
            # Extend with more context as needed.
        }

    def _execute_countermeasures(self, request_context: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Evaluate this environment's rules and return the countermeasure results.

        Returns an empty list when no rule engine / environment ID is available
        or the engine returns nothing.
        """
        if self.rule_engine and self.environment_id:
            # NOTE(review): per the original note, an empty rule list makes the
            # rule engine fall back to global rules — confirm in RuleEngine.
            results = self.rule_engine.evaluate_and_execute_for_environment(
                request_context, self.environment_rules, self.environment_id
            )
            # Guard against a None return so callers can always iterate.
            return results or []
        logger.warning("Rule engine or environment ID not available for countermeasure execution.")
        return []

    def _handle_inject_content(self, results: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Return the inject_info of the first successful 'inject_content' result, or None."""
        for result in results:
            outcome = result.get('result') or {}
            if result.get('action') == 'inject_content' and outcome.get('success'):
                return outcome.get('inject_info')
        return None

    def _apply_countermeasure_side_effects(self, results: List[Dict[str, Any]]):
        """Apply countermeasure side effects before the response is sent.

        Only 'delay' is handled for now: it sleeps for
        ``result['result']['config']['duration']`` seconds when that value is
        positive. Durations that float() cannot convert are logged and ignored
        instead of aborting the request.
        """
        for result in results:
            if result.get('action') != 'delay':
                continue
            delay_config = (result.get('result') or {}).get('config') or {}
            duration = delay_config.get('duration', 0)
            try:
                seconds = float(duration)
            except (TypeError, ValueError):
                logger.warning(f"Ignoring invalid delay duration {duration!r} for countermeasure.")
                continue
            if seconds > 0:
                logger.info(f"Applying delay of {duration} seconds for countermeasure.")
                time.sleep(seconds)

    def do_GET(self):
        """Handle GET: run countermeasures, then send an error or serve the file."""
        request_context = self._build_request_context()
        countermeasure_results = self._execute_countermeasures(request_context)

        # Side effects (e.g. delays) happen before anything is written.
        self._apply_countermeasure_side_effects(countermeasure_results)

        # Stash injection info on the instance; send_head() reads it.
        self._inject_info = self._handle_inject_content(countermeasure_results)
        if self._inject_info:
            safe_path = sanitize_for_log(self.path)
            safe_info = sanitize_for_log(str(self._inject_info))
            logger.info(f"Content injection prepared for {safe_path}. Inject info: {safe_info}")

        # The first 'error_response' countermeasure replaces the normal response.
        for result in countermeasure_results:
            if result.get('action') == 'error_response':
                error_config = (result.get('result') or {}).get('config') or {}
                try:
                    error_code = int(error_config.get('code', 500))
                except (TypeError, ValueError):
                    # send_error() compares the code numerically; a bad value
                    # would otherwise abort the request without any response.
                    error_code = 500
                error_message = error_config.get('message', 'Internal Server Error')
                self.send_error(error_code, error_message)
                return

        super().do_GET()

    def _is_within_root(self, path: str) -> bool:
        """Return True if *path*, with symlinks resolved, lies inside self.directory."""
        try:
            resolved_path = os.path.realpath(path)
            resolved_base = os.path.realpath(self.directory)
            return os.path.commonpath([resolved_path, resolved_base]) == resolved_base
        except Exception:
            # Fail closed: any error (e.g. paths on different drives) is an escape.
            return False

    @staticmethod
    def _inject_into_html(html: str, inject_info: Dict[str, Any]) -> str:
        """Return *html* with inject_info['content'] inserted at inject_info['location'].

        Supported locations (exact, case-sensitive tag match, first occurrence):
        'before_body_end' -> before '</body>', 'head' -> before '</head>',
        'body_start' -> after '<body>'. For an unsupported location, or when
        the anchor tag is absent, the content is appended to the document.
        """
        payload = inject_info.get('content')
        payload = '' if payload is None else str(payload)
        # location -> (anchor tag, replacement text)
        anchors = {
            'before_body_end': ('</body>', f'{payload}\n</body>'),
            'head': ('</head>', f'{payload}\n</head>'),
            'body_start': ('<body>', f'<body>\n{payload}'),
        }
        anchor = anchors.get(inject_info.get('location'))
        if anchor is None or anchor[0] not in html:
            return html + payload
        tag, replacement = anchor
        return html.replace(tag, replacement, 1)

    def send_head(self):
        """
        Resolve the request path and send headers, with HTML content injection.

        Mirrors SimpleHTTPRequestHandler.send_head(): returns an open file for
        the caller to copy, or None when the response has already been written
        (403/404 error, directory redirect, or injected HTML body).
        """
        path = self.translate_path(self.path)
        # Runtime boundary check: refuse paths that escape the static root
        # (e.g. through symlinks).
        if not self._is_within_root(path):
            self.send_error(403, "Forbidden")
            return None

        if os.path.isdir(path):
            parts = urllib.parse.urlsplit(self.path)
            if not parts.path.endswith('/'):
                # Redirect to the slash-terminated URL, as the parent does.
                self.send_response(301)
                new_parts = (parts[0], parts[1], parts[2] + '/',
                             parts[3], parts[4])
                self.send_header("Location", urllib.parse.urlunsplit(new_parts))
                self.end_headers()
                return None
            for index in ("index.html", "index.htm"):
                index_path = os.path.join(path, index)
                if os.path.exists(index_path):
                    path = index_path
                    break
            else:
                # No index file: let the parent render the directory listing
                # (listings are never injected into).
                return super().send_head()
            # The index file itself may be a symlink pointing outside the root.
            if not self._is_within_root(path):
                self.send_error(403, "Forbidden")
                return None

        # After index resolution, ctype describes the file actually served.
        ctype = self.guess_type(path)
        # Only HTML with pending injection info is handled here.
        if not (ctype.startswith('text/html') and self._inject_info):
            return super().send_head()

        try:
            with open(path, 'rb') as f:
                fs = os.fstat(f.fileno())
                content = f.read()
        except OSError:
            self.send_error(404, "File not found")
            return None

        try:
            content_str = content.decode('utf-8')
        except UnicodeDecodeError:
            # Not UTF-8: fall back to serving the file unmodified.
            logger.warning(f"Failed to decode {path} as utf-8 for injection. Serving as-is.")
            return super().send_head()

        injected_bytes = self._inject_into_html(content_str, self._inject_info).encode('utf-8')

        self.send_response(200)
        self.send_header("Content-type", ctype)
        self.send_header("Content-Length", str(len(injected_bytes)))
        self.send_header("Last-Modified",
            self.date_time_string(fs.st_mtime))
        self.end_headers()
        self.wfile.write(injected_bytes)
        # The response is complete; None tells do_GET there is nothing to copy.
        return None

    
# --- 修改驱动 ---
class StaticFileHoneypotDriver(DriverPlugin):
    '''
    Static-file honeypot driver plugin (with integrated countermeasures).

    Environment records (status plus a JSON ``config`` column) live in the
    ``environments`` table; running HTTPServer instances are tracked in memory
    in ``self.servers``.
    '''

    def __init__(self):
        self.db_manager = get_db_manager()
        # env_id -> {'server': HTTPServer, 'thread': threading.Thread}
        self.servers = {}
        # RuleEngine instance, fetched lazily from api.server.
        self.rule_engine = None

    def _get_rule_engine(self):
        """Return the global rule_engine from api.server, or None if unavailable."""
        # Imported lazily; presumably to avoid a circular import — confirm.
        try:
            import api.server
            return api.server.rule_engine
        except (ImportError, AttributeError):
            logger.warning("Failed to get global rule_engine instance in StaticFileHoneypotDriver.")
            return None

    def _get_rule_manager(self):
        """Return the global rule_manager from api.server (used for environment rules), or None."""
        try:
            import api.server
            return api.server.rule_manager
        except (ImportError, AttributeError):
            logger.warning("Failed to get global rule_manager instance in StaticFileHoneypotDriver.")
            return None

    def _resolve_and_validate_static_root(self, static_root: str) -> str:
        """Resolve the configured static root and confine it to ``<project_root>/index``.

        NOTE(review): earlier comments named a ``static_sites`` directory, but
        the enforced base directory is ``index`` — confirm which is intended.

        :param static_root: configured static directory (relative paths are
                            taken relative to the project root)
        :return: the realpath-resolved absolute directory
        :raises ValueError: if the path does not exist, is not a directory, or
                            lies outside the allowed base
        """
        # Project root, assuming this file is drivers/<driver_name>/<module>.py.
        drivers_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        project_root = os.path.dirname(drivers_dir)
        allowed_base = os.path.realpath(os.path.join(project_root, 'index'))

        candidate_path = static_root
        if not os.path.isabs(candidate_path):
            candidate_path = os.path.join(project_root, candidate_path)
        resolved = os.path.realpath(candidate_path)

        if not os.path.exists(resolved):
            raise ValueError(f"Static root does not exist: {resolved}")
        if not os.path.isdir(resolved):
            raise ValueError(f"Static root is not a directory: {resolved}")

        # Only commonpath() is guarded: it raises ValueError e.g. for paths on
        # different Windows drives. The "outside" error below must not be
        # swallowed and replaced by the generic message.
        try:
            common = os.path.commonpath([resolved, allowed_base])
        except ValueError as exc:
            raise ValueError(f"Static root validation failed for: {resolved}") from exc
        if common != allowed_base:
            raise ValueError(f"Static root is outside allowed base: {resolved}")
        return resolved

    @staticmethod
    def _row_to_environment(row) -> Dict[str, Any]:
        """Convert a DB row into a dict, decoding the JSON ``config`` column."""
        # NOTE(review): dict(row) assumes the connection uses a mapping row
        # factory such as sqlite3.Row — confirm in the db manager.
        env_data = dict(row)
        env_data['config'] = json.loads(env_data['config'])
        return env_data

    def list_environments(self) -> List[Dict[str, Any]]:
        """
        List all environments from the database.
        :return: list of environment dicts ([] on database error).
        """
        try:
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM environments")
                return [self._row_to_environment(row) for row in cursor.fetchall()]
        except sqlite3.Error as e:
            logger.error(f"Error listing environments: {e}")
            return []

    def get_environment(self, env_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch one environment from the database.
        :param env_id: environment ID.
        :return: environment dict, or None if not found or on database error.
        """
        try:
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM environments WHERE id = ?", (env_id,))
                row = cursor.fetchone()
                if row:
                    return self._row_to_environment(row)
        except sqlite3.Error as e:
            logger.error(f"Error getting environment {env_id}: {e}")
        return None

    def update_environment(self, env_id: str, status: str) -> bool:
        """Set the ``status`` column of environment *env_id*.

        :return: True if a row was updated; False if no matching row exists or
                 a database error occurred.
        """
        try:
            with self.db_manager.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "UPDATE environments set status = ? where id = ?",
                    (status, env_id)
                )
                conn.commit()
                # rowcount is 0 when no row matched (-1 if unsupported).
                if cursor.rowcount == 0:
                    logger.warning(f"Environment {env_id} not found in database; status not updated.")
                    return False
                logger.info(f"Environment {env_id} update in database.")
                return True
        except sqlite3.IntegrityError as e:
            logger.warning(f"Integrity error while updating environment {env_id}: {e}")
            return False
        except sqlite3.Error as e:
            logger.error(f"Error updating environment {env_id}: {e}")
            return False

    def create_environment(self, env_config: Dict[str, Any]) -> str:
        '''Create a static-file honeypot environment ID.

        Only generates and returns a unique ID; nothing is written to the
        database here (presumably the caller persists the record — confirm).
        '''
        env_id = f"static_env_{uuid.uuid4().hex}"
        logger.info(f"Created Static File Honeypot environment {env_id} with config {env_config}")
        return env_id

    def _load_environment_rules(self, env_id: str):
        """Fetch environment-specific rules, falling back to [] on any failure."""
        rule_manager = self._get_rule_manager()
        if not rule_manager:
            logger.warning(f"Rule manager not available for environment {env_id}. Using empty rule list.")
            return []
        try:
            return rule_manager.get_environment_rules(env_id)
        except Exception as e:
            logger.warning(f"Failed to get environment rules for {env_id}: {e}. Using empty list.")
            return []

    @staticmethod
    def _make_handler_factory(static_root: str, env_id: str, environment_rules):
        """Return a callable HTTPServer can use as its RequestHandlerClass."""
        def handler_factory(*args, **kwargs):
            return StaticFileHTTPRequestHandler(
                *args,
                directory=static_root,
                environment_id=env_id,
                environment_rules=environment_rules,
                **kwargs
            )
        return handler_factory

    def _enable_ssl(self, server: HTTPServer, ssl_config: Dict[str, Any], env_id: str) -> bool:
        """Wrap *server*'s listening socket in TLS if certfile/keyfile exist.

        Uses ssl.SSLContext: the module-level ssl.wrap_socket() was deprecated
        in Python 3.7 and removed in 3.12.

        :return: True if TLS was enabled, False if the config was invalid.
        :raises ssl.SSLError / OSError: if the certificate chain cannot be loaded.
        """
        certfile = ssl_config.get('certfile')
        keyfile = ssl_config.get('keyfile')
        if not (certfile and keyfile and os.path.exists(certfile) and os.path.exists(keyfile)):
            logger.warning(f"Invalid SSL configuration for environment {env_id}. Starting without SSL.")
            return False
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile=certfile, keyfile=keyfile)
        server.socket = context.wrap_socket(server.socket, server_side=True)
        logger.info(f"SSL enabled for environment {env_id}")
        return True

    def _teardown_server(self, env_id: str, server: HTTPServer, thread) -> None:
        """Stop and close a server created during a failed start and forget it."""
        self.servers.pop(env_id, None)
        try:
            # shutdown() waits for serve_forever() to finish, so only call it
            # when the serving thread was actually started.
            if thread is not None and thread.is_alive():
                server.shutdown()
            server.server_close()
        except Exception as cleanup_error:
            logger.warning(f"Error cleaning up server for environment {env_id}: {cleanup_error}")

    def start_environment(self, env_id: str) -> bool:
        '''Start the static-file honeypot for *env_id* (with countermeasures).

        Validates the config, binds an HTTPServer (optionally TLS-wrapped),
        serves it on a daemon thread and marks the environment 'running'.

        :return: True on success; False if the environment is unknown, already
                 served by this process, or any step fails (any server created
                 here is then shut down and closed again).
        '''
        env = self.get_environment(env_id)
        if not env:
            logger.error(f"Environment {env_id} not found")
            return False

        # The DB status check was disabled (presumably because it can be stale
        # after a restart), but a server tracked in this process must not be
        # started twice: a second bind either fails or, platform-dependently,
        # succeeds and overwrites self.servers[env_id], leaking the first one.
        if env_id in self.servers:
            logger.warning(f"Environment {env_id} is already running")
            return False

        server = None
        server_thread = None
        try:
            # Make sure a rule engine is available and share it with the handler class.
            if not self.rule_engine:
                self.rule_engine = self._get_rule_engine()
            if not self.rule_engine:
                logger.error(f"Rule engine not available for environment {env_id}")
                return False
            StaticFileHTTPRequestHandler.rule_engine = self.rule_engine

            config = env['config']
            host = config.get('host', '0.0.0.0')
            # JSON configs may carry the port as a string; bind() needs an int.
            port = int(config.get('port', 8080))
            # Validate the static root at start-up to prevent directory escape.
            static_root = self._resolve_and_validate_static_root(config.get('static_root', '.'))
            ssl_config = config.get('ssl', None)

            # Rules are read once here; rule changes apply on the next start.
            environment_rules = self._load_environment_rules(env_id)
            handler_factory = self._make_handler_factory(static_root, env_id, environment_rules)

            # NOTE(review): HTTPServer handles one request at a time, so a
            # 'delay' countermeasure stalls every client of this environment.
            # ThreadingHTTPServer would avoid that if RuleEngine is thread-safe — confirm.
            server = HTTPServer((host, port), handler_factory)
            ssl_enabled = bool(ssl_config) and self._enable_ssl(server, ssl_config, env_id)

            server_thread = threading.Thread(target=server.serve_forever, daemon=True)
            server_thread.start()
            self.servers[env_id] = {
                'server': server,
                'thread': server_thread
            }

            if not self.update_environment(env_id, 'running'):
                # Keep process and DB state consistent: don't leave a server
                # running when its status could not be recorded.
                self._teardown_server(env_id, server, server_thread)
                return False

            logger.info(f"Started Static File Honeypot environment {env_id} on {host}:{port} serving files from '{static_root}'{' (HTTPS)' if ssl_enabled else ' (HTTP)'}")
            return True
        except Exception as e:
            logger.error(f"Failed to start Static File Honeypot environment {env_id}: {e}")
            if server is not None:
                self._teardown_server(env_id, server, server_thread)
            return False

    def stop_environment(self, env_id: str) -> bool:
        '''Stop the static-file honeypot environment.

        :return: True if stopped (the DB status is then set to 'stopped');
                 False if unknown, not 'running' in the DB, or shutdown failed.
        '''
        env = self.get_environment(env_id)
        if not env:
            logger.error(f"Environment {env_id} not found")
            return False

        if env['status'] != 'running':
            logger.warning(f"Environment {env_id} is not running")
            return False

        try:
            server_info = self.servers.get(env_id)
            if server_info:
                server = server_info['server']
                server.shutdown()
                server.server_close()
                self.servers.pop(env_id, None)

            self.update_environment(env_id, "stopped")
            logger.info(f"Stopped Static File Honeypot environment {env_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to stop Static File Honeypot environment {env_id}: {e}")
            return False

    def delete_environment(self, env_id: str) -> bool:
        '''Delete the static-file honeypot environment, stopping it first if running.

        NOTE(review): no database row is removed here; presumably deletion of
        the record happens elsewhere — confirm.
        '''
        env = self.get_environment(env_id)
        if not env:
            logger.error(f"Environment {env_id} not found")
            return False

        if env['status'] == 'running':
            self.stop_environment(env_id)

        logger.info(f"Deleted Static File Honeypot environment {env_id}")
        return True

    def get_environment_status(self, env_id: str) -> str:
        '''Return the environment's DB status, or 'not_found'.'''
        env = self.get_environment(env_id)
        if not env:
            return 'not_found'
        return env['status']
