from __future__ import annotations

import base64
from pathlib import Path
from typing import Any

from configs.kling import V1_MODELS, ALLOWED_30_MODELS, OMNI_MODEL_ENDPOINTS, PRODUCTION_30_ENDPOINTS


class ValidationError(ValueError):
    """Raised when a request payload fails one of this module's validation rules.

    Derives from ``ValueError`` so generic ``except ValueError`` handlers also
    catch it.
    """


# Top-level fields that must be present and non-empty (None, '' and [] count as
# empty) for each endpoint type. validate_payload drops the 'prompt'
# requirement for multi-shot text2video/omni payloads.
REQUIRED_FIELDS = dict(
    text2video=['prompt'],
    image2video=['image', 'prompt'],
    omni=['prompt'],
    reference2video=['image_list', 'prompt'],
    extend=['prompt'],
)

# Complete set of top-level keys accepted per endpoint type; validate_payload
# rejects any other key as an unknown field.
ALLOWED_FIELDS = dict(
    text2video={
        'model_name', 'prompt', 'negative_prompt', 'duration', 'mode',
        'aspect_ratio', 'callback_url', 'external_task_id', 'multi_shot',
        'shot_type', 'multi_prompt', 'sound', 'cfg_scale', 'camera_control',
        'watermark_info',
    },
    image2video={
        'model_name', 'image', 'image_tail', 'prompt', 'negative_prompt',
        'duration', 'mode', 'aspect_ratio', 'callback_url',
        'external_task_id', 'element_list', 'voice_list', 'sound',
        'cfg_scale', 'static_mask', 'dynamic_masks', 'camera_control',
        'watermark_info', 'multi_shot', 'shot_type', 'multi_prompt',
    },
    omni={
        'model_name', 'prompt', 'image_list', 'video_list', 'element_list',
        'duration', 'mode', 'aspect_ratio', 'callback_url',
        'external_task_id', 'multi_shot', 'shot_type', 'multi_prompt', 'sound',
    },
    reference2video={
        'model_name', 'image_list', 'prompt', 'negative_prompt', 'duration',
        'mode', 'aspect_ratio', 'callback_url', 'external_task_id',
    },
    extend={
        'video_id', 'video_url', 'prompt', 'negative_prompt', 'callback_url',
        'external_task_id',
    },
)

# Closed value sets for enum-like fields; only checked when the field is present.
ALLOWED_ENUMS = dict(
    mode={'std', 'pro'},
    aspect_ratio={'16:9', '9:16', '1:1'},
)


# Only validate rules that are clearly grounded in current docs/live notes.
# If a semantic remains unsettled (for example, whether a top-level prompt should be rejected
# for every multi-shot family in practice), do not encode that as a hard error yet.
# Current production posture supports the latest 3.0 base model for non-Omni video families
# and the latest 3.0 Omni families for the Omni endpoint.

def validate_payload(endpoint_type: str, payload: dict[str, Any]) -> None:
    """Validate a request payload for the given Kling video endpoint type.

    Rules are checked in a fixed order and the first violation raises:

    1. required fields are present and non-empty (``None``, ``''`` and ``[]``
       count as empty); ``text2video``/``omni`` drop the top-level ``prompt``
       requirement when ``multi_shot`` is truthy;
    2. no key outside ``ALLOWED_FIELDS[endpoint_type]``;
    3. ``model_name`` (when given) is permitted for this endpoint;
    4. ``mode`` / ``aspect_ratio`` (when given) are one of the allowed strings;
    5. ``duration`` (when given) parses and is within range;
    6. extend payloads name a source video (``video_id`` or ``video_url``);
    7. ``shot_type`` / ``multi_prompt`` are consistent with ``multi_shot``;
    8. endpoint-specific list and image checks.

    An ``endpoint_type`` with no entry in ``REQUIRED_FIELDS`` / ``ALLOWED_FIELDS``
    skips steps 1-2.

    Args:
        endpoint_type: Endpoint family key, e.g. ``'text2video'`` or ``'omni'``.
        payload: Request body to be validated.

    Raises:
        ValidationError: Describing the first rule the payload violates.
    """
    multi_shot_active = bool(payload.get('multi_shot'))

    required = REQUIRED_FIELDS.get(endpoint_type, [])
    # A key that is absent yields None from .get(), so this also covers missing keys.
    missing = [k for k in required if payload.get(k) in (None, '', [])]
    if multi_shot_active and endpoint_type in ('omni', 'text2video') and 'prompt' in missing:
        # Multi-shot storyboards carry a prompt per entry in multi_prompt, so the
        # top-level prompt is optional for these families.
        # NOTE(review): image2video still requires it under multi_shot; confirm intended.
        missing.remove('prompt')
    if missing:
        raise ValidationError(f'Missing required fields for {endpoint_type}: {missing}')

    allowed = ALLOWED_FIELDS.get(endpoint_type)
    if allowed is not None:
        unknown = sorted(set(payload) - allowed)
        if unknown:
            raise ValidationError(f'Unknown fields for {endpoint_type}: {unknown}')

    model = payload.get('model_name')
    if model is not None:
        if endpoint_type not in OMNI_MODEL_ENDPOINTS and endpoint_type not in PRODUCTION_30_ENDPOINTS:
            raise ValidationError(
                f"model_name is not supported for endpoint {endpoint_type} in the current scaffold policy"
            )
        # Require a string up front: a malformed value such as a JSON list must
        # raise ValidationError, not TypeError from the set lookups below.
        if not isinstance(model, str) or model not in ALLOWED_30_MODELS:
            raise ValidationError(f"Unsupported model_name: {model}. Only synchronized 3.0/3.0 Omni models are allowed in the current scaffold: {sorted(ALLOWED_30_MODELS)}")
        allowed_models = set(V1_MODELS.get(endpoint_type, []))
        if allowed_models and model not in allowed_models:
            raise ValidationError(f"Model {model} is not allowed for endpoint {endpoint_type}. Allowed: {sorted(allowed_models)}")

    for field in ('mode', 'aspect_ratio'):
        value = payload.get(field)
        choices = ALLOWED_ENUMS[field]
        # isinstance guard: an unhashable value would otherwise raise TypeError
        # from the set membership test instead of a ValidationError.
        if value is not None and (not isinstance(value, str) or value not in choices):
            raise ValidationError(f"Invalid {field}: {value}. Expected one of {sorted(choices)}")

    duration = payload.get('duration')
    if duration is not None:
        _validate_duration(duration)

    if endpoint_type == 'extend' and not (payload.get('video_id') or payload.get('video_url')):
        raise ValidationError('Extend requires either video_id or video_url')

    if multi_shot_active:
        if endpoint_type not in ('text2video', 'image2video', 'omni'):
            raise ValidationError(f'multi_shot is not supported in current scaffold for {endpoint_type}')
        _validate_multi_shot(payload, endpoint_type)
    else:
        if payload.get('shot_type') is not None:
            raise ValidationError('shot_type must be omitted when multi_shot is false')
        if payload.get('multi_prompt') is not None:
            raise ValidationError('multi_prompt must be omitted when multi_shot is false')

    # Endpoint-specific structural checks.
    if endpoint_type == 'omni':
        if 'image_list' in payload:
            _validate_omni_image_list(payload)
        if 'video_list' in payload:
            _validate_omni_video_list(payload)
        if 'element_list' in payload:
            _validate_omni_element_list(payload)
    elif endpoint_type == 'reference2video':
        if 'image_list' in payload:
            _validate_reference2video_image_list(payload)
    elif endpoint_type == 'image2video':
        _validate_image2video_inputs(payload)
        if payload.get('voice_list') and payload.get('element_list'):
            raise ValidationError('image2video element_list and voice_list are mutually exclusive in current preserved docs')


def _validate_duration(duration: Any) -> None:
    """Check that a top-level ``duration`` is 3-15 whole seconds inclusive.

    Parsing is delegated to ``_coerce_duration_seconds`` so the top-level
    duration and per-shot storyboard durations accept exactly the same input
    formats and report the same parse error.

    Args:
        duration: Raw ``duration`` value from the payload.

    Raises:
        ValidationError: If ``duration`` cannot be parsed or is outside 3-15.
    """
    seconds = _coerce_duration_seconds(duration)
    if not 3 <= seconds <= 15:
        raise ValidationError(f'duration must be between 3 and 15 seconds inclusive, got {seconds}')


def _validate_multi_shot(payload: dict[str, Any], endpoint_type: str) -> None:
    """Validate the storyboard fields of a multi-shot payload.

    Requires ``shot_type == 'customize'`` and a non-empty ``multi_prompt`` list
    of at most 6 entries. Each entry must be a dict using only the keys
    ``index``, ``prompt`` and ``duration``:

    * ``index`` - a positive integer, unique across entries;
    * ``prompt`` - a non-blank string of at most 512 characters;
    * ``duration`` - at least 1 second and not above the task duration.

    The entry durations must sum exactly to the payload's top-level ``duration``.

    Args:
        payload: Full request payload (validate_payload calls this only when
            ``multi_shot`` is truthy).
        endpoint_type: Used only to prefix error messages.

    Raises:
        ValidationError: On the first violated rule, including a missing or
            unparsable top-level ``duration``.
    """
    shot_type = payload.get('shot_type')
    # Plain equality rather than set membership: an unhashable shot_type (e.g. a
    # JSON list) must produce ValidationError, not TypeError.
    if shot_type != 'customize':
        raise ValidationError(f'{endpoint_type} multi_shot currently requires shot_type="customize" in the synchronized scaffold')

    multi_prompt = payload.get('multi_prompt')
    if not isinstance(multi_prompt, list) or not multi_prompt:
        raise ValidationError(f'{endpoint_type} multi_shot requires a non-empty multi_prompt list')
    if len(multi_prompt) > 6:
        raise ValidationError(f'{endpoint_type} multi_prompt supports at most 6 storyboard entries')

    declared_duration = _coerce_duration_seconds(payload.get('duration'))
    allowed_item_keys = {'index', 'prompt', 'duration'}
    total_storyboard_duration = 0
    seen_indexes: set[int] = set()
    for i, item in enumerate(multi_prompt, start=1):
        if not isinstance(item, dict):
            raise ValidationError(f'{endpoint_type} multi_prompt[{i}] must be an object')
        unknown_item_keys = sorted(set(item) - allowed_item_keys)
        if unknown_item_keys:
            raise ValidationError(f'{endpoint_type} multi_prompt[{i}] has unknown keys: {unknown_item_keys}')

        index = item.get('index')
        # bool is an int subclass; True/False are not meaningful shot indexes.
        if isinstance(index, bool) or not isinstance(index, int) or index < 1:
            raise ValidationError(f'{endpoint_type} multi_prompt[{i}] index must be a positive integer')
        if index in seen_indexes:
            raise ValidationError(f'{endpoint_type} multi_prompt indexes must be unique; duplicate index {index} found')
        seen_indexes.add(index)

        prompt = item.get('prompt')
        if not isinstance(prompt, str) or not prompt.strip():
            raise ValidationError(f'{endpoint_type} multi_prompt[{i}] requires a non-empty prompt')
        if len(prompt) > 512:
            raise ValidationError(f'{endpoint_type} multi_prompt[{i}] prompt must not exceed 512 characters')

        item_duration = _coerce_duration_seconds(item.get('duration'), label=f'{endpoint_type} multi_prompt[{i}] duration')
        if item_duration < 1:
            raise ValidationError(f'{endpoint_type} multi_prompt[{i}] duration must be at least 1 second')
        # Redundant with the sum check below, but gives a more specific message.
        if item_duration > declared_duration:
            raise ValidationError(f'{endpoint_type} multi_prompt[{i}] duration must not exceed total task duration {declared_duration}')
        total_storyboard_duration += item_duration

    if total_storyboard_duration != declared_duration:
        raise ValidationError(f'{endpoint_type} multi_prompt durations must sum to total duration {declared_duration}; got {total_storyboard_duration}')


def _coerce_duration_seconds(value: Any, *, label: str = 'duration') -> int:
    """Return ``value`` as a whole number of seconds.

    Accepts an ``int`` (but not ``bool``) or a string of ASCII digits such as
    ``'5'``. No range check is applied here.

    Args:
        value: Raw duration value.
        label: Field description used as the error-message prefix.

    Raises:
        ValidationError: For any other value.
    """
    # bool is an int subclass; True/False are not meaningful durations.
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    # str.isdigit() alone also accepts characters such as '²' that int() rejects
    # with a bare ValueError, so restrict to ASCII digits.
    if isinstance(value, str) and value.isascii() and value.isdigit():
        return int(value)
    raise ValidationError(f'{label} must be a numeric string or int, got {value!r}')


def _validate_image2video_inputs(payload: dict[str, Any]) -> None:
    """Validate the image references of an image2video payload.

    ``image`` is always checked (so a missing value is rejected); ``image_tail``
    is checked only when it is not None.

    Raises:
        ValidationError: Propagated from ``_validate_image_value``.
    """
    for key, label, optional in (
        ('image', 'image2video image', False),
        ('image_tail', 'image2video image_tail', True),
    ):
        candidate = payload.get(key)
        if optional and candidate is None:
            continue
        _validate_image_value(candidate, label=label)


def _validate_omni_image_list(payload: dict[str, Any]) -> None:
    """Validate the ``image_list`` of an omni payload.

    Each entry must be a dict using only the keys ``image_url`` and ``type``.
    ``image_url`` is required (truthy) and checked with ``_validate_image_value``.
    ``type`` is optional, but when given it must be ``'first_frame'`` or
    ``'end_frame'``.

    Frame rules:

    * at most one entry of each type;
    * an ``end_frame`` requires a ``first_frame``;
    * an ``end_frame`` permits no more than two images in total.

    A missing, None or empty ``image_list`` is accepted.

    Raises:
        ValidationError: On the first violated rule.
    """
    image_list = payload.get('image_list') or []
    if not isinstance(image_list, list):
        raise ValidationError('omni image_list must be a list')
    allowed_item_keys = {'image_url', 'type'}
    first_frame_count = 0
    end_frame_count = 0
    for i, item in enumerate(image_list, start=1):
        if not isinstance(item, dict):
            raise ValidationError(f'omni image_list[{i}] must be an object')
        unknown_item_keys = sorted(set(item) - allowed_item_keys)
        if unknown_item_keys:
            raise ValidationError(f'omni image_list[{i}] has unknown keys: {unknown_item_keys}')
        if not item.get('image_url'):
            raise ValidationError(f'omni image_list[{i}] requires image_url')
        _validate_image_value(item['image_url'], label=f'omni image_list[{i}].image_url')
        image_type = item.get('type')
        # Tuple membership compares with ==, so an unhashable value (e.g. a JSON
        # list) yields ValidationError instead of TypeError.
        if image_type not in (None, 'first_frame', 'end_frame'):
            raise ValidationError(f'omni image_list[{i}] type must be first_frame or end_frame when provided')
        if image_type == 'first_frame':
            first_frame_count += 1
        elif image_type == 'end_frame':
            end_frame_count += 1
    if first_frame_count > 1:
        raise ValidationError('omni image_list supports at most one first_frame entry in the current synchronized scaffold')
    if end_frame_count > 1:
        raise ValidationError('omni image_list supports at most one end_frame entry in the current synchronized scaffold')
    if end_frame_count and not first_frame_count:
        raise ValidationError('omni end_frame is not supported without a first_frame')
    if end_frame_count and len(image_list) > 2:
        raise ValidationError('omni end_frame is not supported when more than two images are supplied')


def _validate_reference2video_image_list(payload: dict[str, Any]) -> None:
    """Validate the ``image_list`` of a reference2video payload.

    The list must be non-empty. Each entry is a dict whose only permitted keys
    are ``image`` and ``image_url``. When ``image_url`` is truthy it is
    validated; otherwise ``image`` is validated; if neither is truthy the entry
    is rejected.

    Raises:
        ValidationError: On the first malformed entry.
    """
    entries = payload.get('image_list') or []
    if not (isinstance(entries, list) and entries):
        raise ValidationError('reference2video image_list must be a non-empty list')
    permitted = {'image', 'image_url'}
    for position, entry in enumerate(entries, start=1):
        if not isinstance(entry, dict):
            raise ValidationError(f'reference2video image_list[{position}] must be an object')
        extra = sorted(set(entry) - permitted)
        if extra:
            raise ValidationError(f'reference2video image_list[{position}] has unknown keys: {extra}')
        if entry.get('image_url'):
            _validate_image_value(entry['image_url'], label=f'reference2video image_list[{position}].image_url')
            continue
        if entry.get('image'):
            _validate_image_value(entry['image'], label=f'reference2video image_list[{position}].image')
            continue
        raise ValidationError(f'reference2video image_list[{position}] requires image_url or image')


def _validate_image_value(value: Any, *, label: str) -> None:
    """Check that ``value`` is an image reference in one of the accepted forms.

    Surrounding whitespace is stripped first. The accepted forms, tried in this
    order, are:

    * an ``http://`` or ``https://`` URL (only the prefix is checked; nothing is fetched);
    * raw base64 data, or a ``data:`` URL whose text after the first comma is
      non-empty, valid base64;
    * a ``file://`` URL or a plain path naming an existing regular file.

    Args:
        value: Candidate image reference.
        label: Field description used as the error-message prefix.

    Raises:
        ValidationError: If ``value`` is not a non-empty string, a ``data:`` URL
            lacks a decodable non-empty payload, a ``file://`` target is not an
            existing file, or no accepted form matches.
    """
    if not isinstance(value, str) or not value.strip():
        raise ValidationError(f'{label} must be a non-empty string')
    raw = value.strip()
    if raw.startswith(('http://', 'https://')):
        return

    is_data_url = raw.startswith('data:')
    if is_data_url:
        # Only the text after the first comma carries the encoded bytes.
        raw = raw.partition(',')[2]
    # b64decode('') succeeds, so an empty payload must never count as valid data.
    if raw:
        try:
            base64.b64decode(raw, validate=True)
        except ValueError:
            # binascii.Error subclasses ValueError, and non-ASCII text also
            # raises ValueError: either way this is not base64.
            pass
        else:
            return
    if is_data_url:
        # A data: URL holds inline data; never reinterpret its payload as a path.
        raise ValidationError(f'{label} data URL must carry a non-empty base64 payload')

    is_file_url = raw.startswith('file://')
    path_text = raw[len('file://'):] if is_file_url else raw
    try:
        exists = Path(path_text).is_file()
    except (OSError, ValueError):
        # is_file() can still raise for some unusable paths (e.g. a name too long
        # for the OS); treat those as "not a file" instead of leaking the error.
        exists = False
    if exists:
        return
    if is_file_url:
        raise ValidationError(f'{label} file path does not exist: {raw}')
    raise ValidationError(f'{label} must be a reachable URL, local file path, or raw base64 image data')


def _validate_omni_video_list(payload: dict[str, Any]) -> None:
    """Validate the ``video_list`` of an omni payload and its interaction with ``sound``.

    The list must hold at most one entry. Each entry is a dict using only these keys:

    * ``video_url`` - required and truthy;
    * ``refer_type`` - optional, ``'base'`` or ``'feature'``;
    * ``keep_original_sound`` - optional, ``'yes'`` or ``'no'``.

    The top-level ``sound`` must be omitted, None or ``'off'``. A missing, None
    or empty list is treated as empty.

    Raises:
        ValidationError: On the first violated rule.
    """
    video_list = payload.get('video_list') or []
    if not isinstance(video_list, list):
        raise ValidationError('omni video_list must be a list')
    if len(video_list) > 1:
        raise ValidationError('omni video_list currently supports only one video entry')
    allowed_item_keys = {'video_url', 'refer_type', 'keep_original_sound'}
    for i, item in enumerate(video_list, start=1):
        if not isinstance(item, dict):
            raise ValidationError(f'omni video_list[{i}] must be an object')
        unknown_item_keys = sorted(set(item) - allowed_item_keys)
        if unknown_item_keys:
            raise ValidationError(f'omni video_list[{i}] has unknown keys: {unknown_item_keys}')
        if not item.get('video_url'):
            raise ValidationError(f'omni video_list[{i}] requires video_url')
        # Tuple membership compares with ==, so unhashable values (e.g. a JSON
        # list or object) yield ValidationError instead of TypeError.
        if item.get('refer_type') not in (None, 'base', 'feature'):
            raise ValidationError(f'omni video_list[{i}] refer_type must be base or feature')
        if item.get('keep_original_sound') not in (None, 'yes', 'no'):
            raise ValidationError(f'omni video_list[{i}] keep_original_sound must be yes or no')
    # NOTE(review): validate_payload calls this whenever the key is present, so an
    # empty video_list also forbids sound; confirm that is intended.
    if payload.get('sound') not in (None, 'off'):
        raise ValidationError('omni sound must be omitted or off when video_list is present')


def _validate_omni_element_list(payload: dict[str, Any]) -> None:
    """Validate the ``element_list`` of an omni payload.

    Each entry must be a dict whose only key is ``element_id``, holding an
    integer. A missing, None or empty ``element_list`` is treated as empty.

    Raises:
        ValidationError: On the first malformed entry.
    """
    element_list = payload.get('element_list') or []
    if not isinstance(element_list, list):
        raise ValidationError('omni element_list must be a list')
    allowed_item_keys = {'element_id'}
    for i, item in enumerate(element_list, start=1):
        if not isinstance(item, dict):
            raise ValidationError(f'omni element_list[{i}] must be an object')
        unknown_item_keys = sorted(set(item) - allowed_item_keys)
        if unknown_item_keys:
            raise ValidationError(f'omni element_list[{i}] has unknown keys: {unknown_item_keys}')
        if 'element_id' not in item:
            raise ValidationError(f'omni element_list[{i}] requires element_id')
        element_id = item['element_id']
        # bool is an int subclass; True/False must not pass as element ids.
        if isinstance(element_id, bool) or not isinstance(element_id, int):
            raise ValidationError(f'omni element_list[{i}] element_id must be an integer')
