from __future__ import annotations

import json
import time
import urllib.error
import urllib.request
from pathlib import Path
from typing import Any

from configs.kling import KlingConfig, ENDPOINTS
from scripts.kling_auth import make_jwt_from_env

# Parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# Directory that receives one JSON file per logged request (see KlingClient._write_log).
REQ_LOG_DIR = ROOT.joinpath('logs', 'requests')
# Created at import time; already-existing directories are left untouched.
REQ_LOG_DIR.mkdir(exist_ok=True, parents=True)


class KlingClientError(RuntimeError):
    """Runtime error carrying extra context about a failed client operation.

    Attributes:
        category: Short failure label; defaults to ``'unknown'``.
        status_code: HTTP status code tied to the failure, or ``None``.
        response_body: Response body text tied to the failure, or ``None``.
    """

    def __init__(
        self,
        message: str,
        *,
        category: str = 'unknown',
        status_code: int | None = None,
        response_body: str | None = None,
    ):
        """Store *message* on the base exception and keep the keyword-only metadata."""
        super().__init__(message)
        self.category, self.status_code, self.response_body = (
            category,
            status_code,
            response_body,
        )


class KlingClient:
    """Small ``urllib``-based JSON client for the Kling API.

    Every request carries a bearer token and a JSON content type. Each
    exchange, successful or failed, is dumped to a JSON file under
    ``REQ_LOG_DIR``. Request failures are raised as ``KlingClientError``
    with a ``category`` computed by :meth:`_categorize`.
    """

    def __init__(self, config: KlingConfig | None = None):
        """Store *config*, minting an API token when the config has none.

        Args:
            config: Client configuration. When omitted (or falsy) a default
                ``KlingConfig()`` is constructed.

        Raises:
            KlingClientError: With ``category='config'`` if building the
                token via ``make_jwt_from_env()`` or rebuilding the config
                fails for any reason.
        """
        self.config = config or KlingConfig()
        if not self.config.api_token:
            try:
                # Rebuild the config instead of mutating it in place, copying
                # every other setting over and filling in a fresh token.
                self.config = KlingConfig(
                    api_base_url=self.config.api_base_url,
                    api_token=make_jwt_from_env(),
                    callback_base_url=self.config.callback_base_url,
                    request_timeout_sec=self.config.request_timeout_sec,
                    concurrency_limit=self.config.concurrency_limit,
                    callback_shared_secret=self.config.callback_shared_secret,
                    assets_root=self.config.assets_root,
                )
            except Exception as e:
                raise KlingClientError(f'Kling auth bootstrap failed: {e}', category='config') from e

    def _headers(self) -> dict[str, str]:
        """Return the HTTP headers sent with every request."""
        return {'Authorization': f'Bearer {self.config.api_token}', 'Content-Type': 'application/json'}

    def _url(self, endpoint_type: str) -> str:
        """Resolve *endpoint_type* to a full URL using ``ENDPOINTS``.

        Raises:
            KlingClientError: With ``category='config'`` when the endpoint
                type is missing from ``ENDPOINTS`` or maps to an empty path.
        """
        path = ENDPOINTS.get(endpoint_type)
        if not path:
            raise KlingClientError(f'Unknown endpoint_type: {endpoint_type}', category='config')
        # Strip any trailing slash on the base URL before appending the path.
        return self.config.api_base_url.rstrip('/') + path

    def _categorize(self, status: int | None, body: str | None) -> str:
        """Classify a failure from its HTTP status and response body.

        The checks are ordered and the first match wins:

        1. status 429/409, or body mentions ``concurrency`` / ``quota`` /
           ``rate limit`` -> ``'quota_or_concurrency'``
        2. status 5xx -> ``'transient'``
        3. body mentions ``moderation`` / ``content risk`` -> ``'moderation'``
        4. status 4xx -> ``'malformed'``
        5. body mentions ``json`` / ``html`` / ``doctype`` ->
           ``'unexpected_response'``
        6. otherwise -> ``'unknown'``

        Body matching is case-insensitive; a ``None`` body is treated as
        empty.
        """
        blob = (body or '').lower()
        if status in (429, 409) or 'concurrency' in blob or 'quota' in blob or 'rate limit' in blob:
            return 'quota_or_concurrency'
        if status and 500 <= status < 600:
            return 'transient'
        if 'moderation' in blob or 'content risk' in blob:
            return 'moderation'
        if status and 400 <= status < 500:
            return 'malformed'
        if 'json' in blob or 'html' in blob or 'doctype' in blob:
            return 'unexpected_response'
        return 'unknown'

    def post(self, endpoint_type: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST *payload* as JSON to the endpoint named *endpoint_type*.

        The endpoint type doubles as the label of the request log file.

        Raises:
            KlingClientError: For an unknown endpoint type or any request
                failure (see :meth:`_request`).
        """
        return self._request('POST', self._url(endpoint_type), payload, endpoint_type)

    def get_url(self, url: str, log_label: str = 'query') -> dict[str, Any]:
        """GET the absolute *url* and return the decoded JSON response.

        Args:
            url: Full URL to fetch; it is used as-is.
            log_label: Label used in the request log file name.

        Raises:
            KlingClientError: On any request failure (see :meth:`_request`).
        """
        return self._request('GET', url, None, log_label)

    def _request(self, method: str, url: str, payload: dict[str, Any] | None, log_label: str) -> dict[str, Any]:
        """Send one HTTP request, log the exchange, and return the parsed JSON.

        Args:
            method: HTTP method name.
            url: Full request URL.
            payload: JSON-serialisable body, or ``None`` for no body.
            log_label: Label passed to :meth:`_write_log`.

        Returns:
            The response body decoded with ``json.loads``.

        Raises:
            KlingClientError: On an HTTP error status (category from
                :meth:`_categorize`), on any other exception while sending
                (``category='unknown'``), or when the response body is not
                valid JSON.
        """
        raw = json.dumps(payload).encode('utf-8') if payload is not None else None
        req = urllib.request.Request(url, data=raw, headers=self._headers(), method=method)
        try:
            with urllib.request.urlopen(req, timeout=self.config.request_timeout_sec) as resp:
                body = resp.read().decode('utf-8', errors='replace')
                status = getattr(resp, 'status', None)
        except urllib.error.HTTPError as e:
            # Error statuses still carry a body; keep it for logging and
            # for the caller, and use it to pick the failure category.
            body = e.read().decode('utf-8', errors='ignore')
            category = self._categorize(e.code, body)
            self._write_log(log_label, payload or {}, {'error': str(e), 'status_code': e.code, 'response_body': body, 'category': category})
            raise KlingClientError(str(e), category=category, status_code=e.code, response_body=body) from e
        except Exception as e:
            # Anything else raised while sending or reading the response.
            self._write_log(log_label, payload or {}, {'error': str(e), 'category': 'unknown'})
            raise KlingClientError(str(e), category='unknown') from e

        try:
            data = json.loads(body)
        except json.JSONDecodeError as e:
            category = self._categorize(status, body)
            self._write_log(log_label, payload or {}, {'error': f'Non-JSON response: {e}', 'status_code': status, 'response_body': body, 'category': category})
            raise KlingClientError(f'Non-JSON response from Kling: {e}', category=category, status_code=status, response_body=body) from e

        self._write_log(log_label, payload or {}, data)
        return data

    def _write_log(self, endpoint_type: str, payload: dict[str, Any], response: dict[str, Any]) -> None:
        """Write one request/response record to ``REQ_LOG_DIR`` as JSON.

        The file is named ``<millisecond timestamp>_<endpoint_type>.json``.

        Args:
            endpoint_type: Label stored in the record and used in the name.
            payload: Request payload that was sent (``{}`` when none).
            response: Parsed response, or a dict describing the failure.
        """
        ts = int(time.time() * 1000)
        path = REQ_LOG_DIR / f'{ts}_{endpoint_type}.json'
        record = {'endpoint_type': endpoint_type, 'payload': payload, 'response': response}
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be written as UTF-8 explicitly; the platform default encoding may
        # be unable to represent them and would raise UnicodeEncodeError.
        path.write_text(json.dumps(record, ensure_ascii=False, indent=2), encoding='utf-8')