from __future__ import annotations

import argparse
import json
import os
import tempfile
import time
import uuid
from pathlib import Path
from typing import Any

# Repository root: the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
import sys
# Make sure ROOT is on sys.path before the `scripts.*` imports below run.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from scripts.kling_tasks import create_kling_task
from scripts.kling_client import KlingClient, KlingClientError
# On-disk layout for watcher state, all rooted at <ROOT>/runs/watcher.
WATCHER_ROOT = ROOT / 'runs' / 'watcher'
JOBS_DIR = WATCHER_ROOT / 'jobs'  # job documents, one `<job_id>.json` per job
REPORTS_DIR = WATCHER_ROOT / 'reports'  # markdown reports written by finalize_job
CORRUPT_DIR = WATCHER_ROOT / 'corrupt'  # job files that load_job moved aside
LOCKS_DIR = WATCHER_ROOT / 'locks'  # per-job `.lock` files used for leases
# NOTE: these directories are created as a side effect of importing this module.
for d in (JOBS_DIR, REPORTS_DIR, CORRUPT_DIR, LOCKS_DIR):
    d.mkdir(parents=True, exist_ok=True)

SCHEMA_VERSION = 1  # stored as `schema_version` in every newly created job
DEFAULT_LEASE_SEC = 300  # lease lifetime in seconds; converted to ms in acquire_lease


class WatcherError(RuntimeError):
    """Error raised by the watcher's own job handling.

    ``tick`` catches this type per job and moves on to the next job file.
    """


def now_ms() -> int:
    """Return the current ``time.time()`` value as whole milliseconds."""
    seconds = time.time()
    millis = seconds * 1000
    return int(millis)


def atomic_write_json(path: Path, data: dict[str, Any]) -> None:
    """Write *data* as pretty-printed JSON to *path* via a temp file and rename.

    The document is first written, flushed and fsynced to a temporary file in
    the destination directory, then moved over *path* with ``os.replace`` so
    readers never observe a partially written file.

    Args:
        path: Destination file; its parent directory is created if needed.
        data: JSON-serializable document.

    Raises:
        TypeError, ValueError: If *data* cannot be serialized.
        OSError: If writing or replacing the file fails.

    On any failure the temporary file is removed and an existing *path* is
    left untouched.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile('w', dir=path.parent, delete=False, encoding='utf-8') as tmp:
            tmp_name = tmp.name
            json.dump(data, tmp, ensure_ascii=False, indent=2)
            tmp.flush()
            os.fsync(tmp.fileno())
        os.replace(tmp_name, path)
    except BaseException:
        # delete=False means nobody else cleans this up; without the unlink a
        # failed write would leave a stray temp file next to the job files.
        if tmp_name is not None:
            try:
                os.unlink(tmp_name)
            except OSError:
                # Best-effort cleanup; the original error is the one to report.
                pass
        raise


def normalize_payload(payload: dict[str, Any]) -> str:
    """Serialize *payload* to compact JSON with sorted keys.

    Non-ASCII characters are kept as-is and no whitespace is emitted between
    tokens, so equal payloads produce equal strings regardless of key order.
    """
    canonical_options = {
        'ensure_ascii': False,
        'sort_keys': True,
        'separators': (',', ':'),
    }
    return json.dumps(payload, **canonical_options)


def make_idempotency_key(job_id: str, step_id: str, payload: dict[str, Any]) -> str:
    """Return the SHA-256 hex digest of ``job_id:step_id:<normalized payload>``.

    The payload is serialized with ``normalize_payload`` so the key does not
    depend on the payload's key order.
    """
    import hashlib
    fingerprint = '{}:{}:{}'.format(job_id, step_id, normalize_payload(payload))
    hasher = hashlib.sha256()
    hasher.update(fingerprint.encode('utf-8'))
    return hasher.hexdigest()


def load_job(path: Path) -> dict[str, Any]:
    """Load a job document, moving it to ``CORRUPT_DIR`` if it is corrupt.

    A file counts as corrupt when it cannot be decoded as UTF-8, is not valid
    JSON, or decodes to something other than a JSON object. A file that cannot
    be read at all (missing, permission denied, ...) is NOT moved: an I/O
    failure says nothing about the file's content.

    Args:
        path: Path of the job JSON file.

    Returns:
        The parsed job document.

    Raises:
        WatcherError: If the file cannot be read or is corrupt. The message
            states whether the corrupt file was actually moved aside.
    """
    job = None
    cause = None
    try:
        job = json.loads(path.read_text(encoding='utf-8'))
    except OSError as exc:
        raise WatcherError(f'Cannot read job file {path.name}: {exc}') from exc
    except (ValueError, RecursionError) as exc:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        cause = exc

    # Callers use the result as a mapping (job.get(...)), so anything else is
    # as unusable as unparseable JSON.
    if isinstance(job, dict):
        return job

    quarantined = CORRUPT_DIR / f'{path.stem}_{now_ms()}.json'
    try:
        path.rename(quarantined)
    except OSError as exc:
        raise WatcherError(f'Corrupt job file could not be moved aside: {path.name}') from (cause or exc)
    raise WatcherError(f'Corrupt job file moved aside: {path.name}') from cause


def save_job(path: Path, job: dict[str, Any]) -> None:
    """Stamp ``job['updated_at']`` with the current time and write it to *path*.

    The write goes through ``atomic_write_json``; *job* is mutated in place.
    """
    stamp = now_ms()
    job['updated_at'] = stamp
    atomic_write_json(path, job)


def create_render_job(
    *,
    endpoint: str,
    payload: dict[str, Any],
    purpose: str | None = None,
    notify_on_start: bool = True,
    notify_on_finish: bool = True,
    scene_id: str | None = None,
) -> Path:
    """Create and persist a single-step ``kling_render`` job in state ``queued``.

    The job gets one ``create_task`` step (``render_main``) in state ``ready``
    whose ``input_snapshot`` is *payload* and whose idempotency key is derived
    from the job id, step id and payload.

    Args:
        endpoint: Endpoint stored on the step and in the job context.
        payload: Request payload stored as the step's ``input_snapshot``.
        purpose: Optional free-form label stored in the job context.
        notify_on_start: Stored as ``notify.on_start``.
        notify_on_finish: Stored as ``notify.on_finish``.
        scene_id: Optional scene id stored on the step and in the context.

    Returns:
        Path of the job file written under ``JOBS_DIR``.
    """
    stamp = time.strftime('%Y%m%d_%H%M%S')
    suffix = uuid.uuid4().hex[:6]
    job_id = f'render_{stamp}_{suffix}'
    created = now_ms()
    step_id = 'render_main'

    notify = {
        'on_start': notify_on_start,
        'on_finish': notify_on_finish,
        'final_report_sent': False,
    }
    report = {
        'status': 'pending',
        'report_path': None,
        'written_at': None,
        'sent_at': None,
        'send_failed_at': None,
    }
    context = {
        'endpoint': endpoint,
        'purpose': purpose,
        'scene_id': scene_id,
    }
    render_step = {
        'step_id': step_id,
        'kind': 'create_task',
        'state': 'ready',
        'endpoint': endpoint,
        'scene_id': scene_id,
        'idempotency_key': make_idempotency_key(job_id, step_id, payload),
        'submit_attempted_at': None,
        'external_task_id': None,
        'external_status_last_seen': None,
        'input_snapshot': payload,
        'output_snapshot': {},
        'last_error': None,
        'created_internal_task_id': None,
        'payload_ref': None,
        'updated_at': created,
    }
    job = {
        'schema_version': SCHEMA_VERSION,
        'job_id': job_id,
        'job_type': 'kling_render',
        'status': 'queued',
        'created_at': created,
        'updated_at': created,
        'lease': {'owner': None, 'expires_at': 0},
        'notify': notify,
        'report': report,
        'context': context,
        'steps': [render_step],
        'result': {'output_urls': []},
        'error': None,
        'events': [],
    }

    destination = JOBS_DIR / f'{job_id}.json'
    save_job(destination, job)
    return destination


def create_chain_job(
    *,
    seed_endpoint: str,
    seed_payload: dict[str, Any],
    child_endpoint: str,
    child_payload_template: dict[str, Any],
    refer_type: str = 'feature',
    keep_original_sound: str = 'yes',
    purpose: str | None = None,
    notify_on_start: bool = True,
    notify_on_finish: bool = True,
    scene_a_id: str | None = None,
    scene_b_id: str | None = None,
) -> Path:
    """Create and persist a two-step ``kling_chain`` job in state ``queued``.

    Step ``scene_a`` starts ``ready`` with *seed_payload* as its input. Step
    ``scene_b`` starts ``blocked`` with no input or idempotency key; it keeps
    *child_payload_template* plus a ``continuity_injection`` record that
    ``resolve_chain_child`` later uses to build its payload from ``scene_a``.

    Args:
        seed_endpoint: Endpoint for ``scene_a``.
        seed_payload: Payload stored as ``scene_a``'s ``input_snapshot``.
        child_endpoint: Endpoint for ``scene_b``.
        child_payload_template: Template stored on ``scene_b``.
        refer_type: Stored in the context and the continuity injection.
        keep_original_sound: Stored in the continuity injection.
        purpose: Optional free-form label stored in the job context.
        notify_on_start: Stored as ``notify.on_start``.
        notify_on_finish: Stored as ``notify.on_finish``.
        scene_a_id: Optional scene id stored on ``scene_a``.
        scene_b_id: Optional scene id stored on ``scene_b``.

    Returns:
        Path of the job file written under ``JOBS_DIR``.
    """
    stamp = time.strftime('%Y%m%d_%H%M%S')
    suffix = uuid.uuid4().hex[:6]
    job_id = f'chain_{stamp}_{suffix}'
    created = now_ms()

    seed_step = {
        'step_id': 'scene_a',
        'kind': 'render_scene',
        'state': 'ready',
        'endpoint': seed_endpoint,
        'scene_id': scene_a_id,
        'idempotency_key': make_idempotency_key(job_id, 'scene_a', seed_payload),
        'submit_attempted_at': None,
        'external_task_id': None,
        'external_status_last_seen': None,
        'input_snapshot': seed_payload,
        'output_snapshot': {},
        'last_error': None,
        'created_internal_task_id': None,
        'payload_ref': None,
        'updated_at': created,
    }
    # The child has no payload yet: its input and idempotency key are filled
    # in once the parent step has produced an output URL.
    child_step = {
        'step_id': 'scene_b',
        'kind': 'render_scene',
        'depends_on': 'scene_a',
        'state': 'blocked',
        'endpoint': child_endpoint,
        'scene_id': scene_b_id,
        'idempotency_key': None,
        'submit_attempted_at': None,
        'external_task_id': None,
        'external_status_last_seen': None,
        'input_snapshot': None,
        'payload_template': child_payload_template,
        'continuity_injection': {
            'type': 'video_list',
            'from_step': 'scene_a',
            'refer_type': refer_type,
            'keep_original_sound': keep_original_sound,
        },
        'resolved_parent_output_url': None,
        'output_snapshot': {},
        'last_error': None,
        'created_internal_task_id': None,
        'payload_ref': None,
        'updated_at': created,
    }
    job = {
        'schema_version': SCHEMA_VERSION,
        'job_id': job_id,
        'job_type': 'kling_chain',
        'status': 'queued',
        'created_at': created,
        'updated_at': created,
        'lease': {'owner': None, 'expires_at': 0},
        'notify': {
            'on_start': notify_on_start,
            'on_finish': notify_on_finish,
            'final_report_sent': False,
        },
        'report': {
            'status': 'pending',
            'report_path': None,
            'written_at': None,
            'sent_at': None,
            'send_failed_at': None,
        },
        'context': {
            'purpose': purpose,
            'continuity_mode': 'video_list',
            'refer_type': refer_type,
        },
        'steps': [seed_step, child_step],
        'result': {'output_urls': []},
        'error': None,
        'events': [],
    }

    destination = JOBS_DIR / f'{job_id}.json'
    save_job(destination, job)
    return destination


def append_event(job: dict[str, Any], message: str, **extra: Any) -> None:
    """Append a timestamped event to ``job['events']``, creating the list if absent.

    Keyword arguments in *extra* are merged into the event after ``ts`` and
    ``message``.
    """
    entry = {'ts': now_ms(), 'message': message}
    entry.update(extra)
    history = job.setdefault('events', [])
    history.append(entry)


def acquire_lease(job_path: Path, job: dict[str, Any]) -> bool:
    """Try to take the lease for a job by writing its lock file.

    The lease is refused while an existing lock file carries an ``expires_at``
    in the future. A lock file that is missing, unreadable, not valid JSON,
    not a JSON object, or whose ``expires_at`` is not a usable number is
    treated as stale and overwritten.

    Args:
        job_path: Path of the job file; its stem names the lock file.
        job: Job document; ``job['lease']`` is updated on success.

    Returns:
        True if the lease was taken, False if another unexpired lease exists.
    """
    lock_path = LOCKS_DIR / f'{job_path.stem}.lock'
    now = now_ms()
    lease = job.setdefault('lease', {'owner': None, 'expires_at': 0})

    try:
        lock_data = json.loads(lock_path.read_text(encoding='utf-8'))
        held_until = int(lock_data.get('expires_at') or 0)
    except (OSError, ValueError, TypeError, AttributeError, OverflowError):
        # No lock file, or one whose content cannot be interpreted: same
        # outcome as an expired lease.
        held_until = 0
    if held_until > now:
        return False

    # NOTE(review): the expiry check above and the write below are separate
    # steps, so two processes racing here could both believe they hold the
    # lease. Confirm whether more than one watcher can run at a time.
    owner = str(uuid.uuid4())
    expires_at = now + (DEFAULT_LEASE_SEC * 1000)
    atomic_write_json(lock_path, {'owner': owner, 'expires_at': expires_at})
    lease['owner'] = owner
    lease['expires_at'] = expires_at
    return True


def release_lease(job_path: Path, job: dict[str, Any]) -> None:
    """Remove the job's lock file (best effort) and clear ``job['lease']``.

    Failure to delete the lock file is ignored; the in-memory lease is reset
    to ``owner=None`` / ``expires_at=0`` either way.
    """
    lock_file = LOCKS_DIR / f'{job_path.stem}.lock'
    try:
        lock_file.unlink()
    except Exception:
        # Covers both "no lock file" and any deletion error.
        pass
    lease = job.setdefault('lease', {'owner': None, 'expires_at': 0})
    for field, cleared in (('owner', None), ('expires_at', 0)):
        lease[field] = cleared


def query_task_by_id(task_id: str) -> dict[str, Any]:
    """Fetch the omni-video task *task_id* through a fresh ``KlingClient``.

    The URL is built from the client's configured ``api_base_url`` with any
    trailing slash removed.
    """
    kling = KlingClient()
    api_root = kling.config.api_base_url.rstrip('/')
    task_url = f'{api_root}/v1/videos/omni-video/{task_id}'
    return kling.get_url(task_url, log_label='watcher_query_task')


def extract_authoritative_output_url(response: dict[str, Any]) -> str | None:
    data = response.get('data') or {}
    task_result = data.get('task_result') or {}
    videos = task_result.get('videos') or []
    if not isinstance(videos, list) or not videos:
        return None
    first = videos[0] or {}
    url = first.get('url')
    if isinstance(url, str) and url.strip():
        return url.strip()
    return None


def set_step_state(step: dict[str, Any], state: str, *, error: str | None = None) -> None:
    """Set ``step['state']`` and refresh ``step['updated_at']``.

    ``step['last_error']`` is overwritten only when *error* is given; passing
    no error leaves any previous value in place.
    """
    step['state'], step['updated_at'] = state, now_ms()
    if error is None:
        return
    step['last_error'] = error


def _safe_scene_id(scene_id: Any) -> str | None:
    if not isinstance(scene_id, str):
        return None
    scene_id = scene_id.strip()
    if not scene_id:
        return None
    try:
        from scripts.repository import fetch_one  # type: ignore
        row = fetch_one('SELECT scene_id FROM scenes WHERE scene_id=?', (scene_id,))
        return scene_id if row else None
    except Exception:
        return None


def submit_step(job: dict[str, Any], step: dict[str, Any]) -> None:
    """Advance a step through the two-phase submit: mark pending, then call Kling.

    A step in state ``ready`` is only moved to ``submit_pending`` (with
    ``submit_attempted_at`` stamped and an event appended); the function then
    raises ``WatcherError('state_persist_required_before_submit')`` instead of
    submitting. ``process_job`` treats that specific message as non-fatal and
    saves the job, so the actual ``create_kling_task`` call happens on a later
    invocation, when the step is already ``submit_pending``.

    Args:
        job: Job document; used here only to append events.
        step: Step document to submit; mutated in place.

    Raises:
        WatcherError: If ``input_snapshot`` is not a dict, or as the
            persist-before-submit signal described above.
    """
    payload = step.get('input_snapshot')
    if not isinstance(payload, dict):
        raise WatcherError(f"Step {step['step_id']} missing input_snapshot")
    # Steps in any other state are left untouched.
    if step.get('state') not in {'ready', 'submit_pending'}:
        return
    if step.get('state') == 'ready':
        set_step_state(step, 'submit_pending')
        step['submit_attempted_at'] = now_ms()
        append_event(job, 'step_submit_pending', step_id=step['step_id'])
        # Signal to the caller: record the pending state before any API call.
        raise WatcherError('state_persist_required_before_submit')

    # An external task id is already recorded, so do not create another task;
    # just mark the step submitted.
    if step.get('external_task_id'):
        set_step_state(step, 'submitted')
        return

    # The step's idempotency key is passed along as `external_task_id`.
    result = create_kling_task(step['endpoint'], payload, external_task_id=step['idempotency_key'], scene_id=_safe_scene_id(step.get('scene_id')))
    response = result['response']
    data = response.get('data') or {}
    step['created_internal_task_id'] = result.get('task_id_internal')
    step['external_task_id'] = data.get('task_id')
    # Fall back to 'submitted' when the response carries no task_status.
    step['external_status_last_seen'] = data.get('task_status') or 'submitted'
    set_step_state(step, 'submitted')
    append_event(job, 'step_submitted', step_id=step['step_id'], external_task_id=step.get('external_task_id'))


def refresh_step(step: dict[str, Any]) -> None:
    """Poll the remote task behind *step* and update the step's state.

    Remote ``submitted``/``processing`` maps to ``polling``, ``succeed`` to
    ``succeeded`` (when an output URL can be extracted), ``failed`` to
    ``failed``. A missing task id, a ``KlingClientError`` while querying, a
    success without a usable URL, or any other remote status puts the step
    into ``needs_attention`` with an explanatory ``last_error``.
    """
    external_id = step.get('external_task_id')
    if not external_id:
        set_step_state(step, 'needs_attention', error='Missing external_task_id while refreshing step')
        return

    try:
        reply = query_task_by_id(external_id)
    except KlingClientError as exc:
        set_step_state(step, 'needs_attention', error=f'Query failed: {exc}')
        return

    body = reply.get('data') or {}
    remote_status = body.get('task_status') or 'unknown'
    raw_result = body.get('task_result')
    step['external_status_last_seen'] = remote_status
    step['output_snapshot'] = raw_result or {}

    if remote_status in {'submitted', 'processing'}:
        set_step_state(step, 'polling')
    elif remote_status == 'succeed':
        video_url = extract_authoritative_output_url(reply)
        if not video_url:
            set_step_state(step, 'needs_attention', error='Succeeded but no usable output URL found')
            return
        # Only `scene_a` records its own output here; every other step keeps
        # whatever value it already had (None if the key was absent).
        if step.get('step_id') == 'scene_a':
            step['resolved_parent_output_url'] = video_url
        else:
            step['resolved_parent_output_url'] = step.get('resolved_parent_output_url')
        step['output_snapshot'] = {'url': video_url, 'raw': raw_result}
        set_step_state(step, 'succeeded')
    elif remote_status == 'failed':
        failure_reason = body.get('task_status_msg') or 'Kling task failed'
        set_step_state(step, 'failed', error=failure_reason)
    else:
        set_step_state(step, 'needs_attention', error=f'Unknown task status: {remote_status}')


def resolve_chain_child(job: dict[str, Any], child: dict[str, Any]) -> None:
    """Build the child step's payload from its parent's output and mark it ready.

    Does nothing until the step named by ``continuity_injection.from_step``
    has state ``succeeded``. The child's ``payload_template`` is copied,
    ``sound`` is forced to ``'off'`` and a one-element ``video_list`` pointing
    at the parent's output URL is added. The child goes to ``needs_attention``
    if the parent has no output URL or the template already has a
    ``video_list``.
    """
    injection = child.get('continuity_injection') or {}
    source_id = injection.get('from_step')

    source = None
    for candidate in job.get('steps', []):
        if candidate.get('step_id') == source_id:
            source = candidate
            break
    if not source or source.get('state') != 'succeeded':
        return

    source_url = (source.get('output_snapshot') or {}).get('url')
    if not source_url:
        set_step_state(child, 'needs_attention', error='Parent succeeded but output URL missing')
        return

    # Shallow copy, so adding keys does not touch the stored template.
    child_payload = dict(child.get('payload_template') or {})
    if child_payload.get('video_list'):
        set_step_state(child, 'needs_attention', error='Child payload_template already contains video_list')
        return

    reference = {
        'video_url': source_url,
        'refer_type': injection.get('refer_type', 'feature'),
        'keep_original_sound': injection.get('keep_original_sound', 'yes'),
    }
    child_payload['sound'] = 'off'
    child_payload['video_list'] = [reference]

    child['resolved_parent_output_url'] = source_url
    child['input_snapshot'] = child_payload
    child['idempotency_key'] = make_idempotency_key(job['job_id'], child['step_id'], child_payload)
    set_step_state(child, 'ready')
    append_event(job, 'chain_child_resolved', step_id=child['step_id'], parent_output_url=source_url)


def finalize_job(job_path: Path, job: dict[str, Any]) -> None:
    """Write the job's markdown report and set its terminal status.

    The terminal status is ``failed`` if any step failed, else
    ``needs_attention`` if any step needs attention, else ``completed``. It is
    computed before the report is rendered so the report's ``final_status``
    line shows the real outcome rather than the transient ``finalizing``.

    Does nothing if the report is already marked ``sent``.

    Args:
        job_path: Unused; kept so existing callers keep working.
        job: Job document; mutated in place (status, report, notify).

    Raises:
        OSError: If the report file cannot be written. ``job['status']`` is
            left at ``finalizing`` in that case.
    """
    report = job.setdefault('report', {})
    if report.get('status') == 'sent':
        return

    steps = job.get('steps', [])
    step_states = {s.get('state') for s in steps}
    if 'failed' in step_states:
        final_status = 'failed'
    elif 'needs_attention' in step_states:
        final_status = 'needs_attention'
    else:
        final_status = 'completed'

    job['status'] = 'finalizing'
    lines = [
        f"# Watcher Report — {job['job_id']}",
        '',
        f"- job_type: {job['job_type']}",
        f"- final_status: {final_status}",
        '',
        '## Steps',
    ]
    for step in steps:
        lines.append(f"- {step['step_id']}: {step['state']}")
        if step.get('external_task_id'):
            lines.append(f"  - external_task_id: {step['external_task_id']}")
        output_url = (step.get('output_snapshot') or {}).get('url')
        if output_url:
            lines.append(f"  - output_url: {output_url}")
        if step.get('last_error'):
            lines.append(f"  - error: {step['last_error']}")

    report_path = REPORTS_DIR / f"{job['job_id']}.md"
    report_path.write_text('\n'.join(lines) + '\n', encoding='utf-8')
    report['report_path'] = str(report_path.relative_to(ROOT))
    report['written_at'] = now_ms()
    # No messaging integration is wired in here, so the report is marked
    # 'sent' as soon as it has been written.
    report['status'] = 'sent'
    report['sent_at'] = now_ms()
    job.setdefault('notify', {})['final_report_sent'] = True
    job['status'] = final_status


def process_render_job(job: dict[str, Any]) -> None:
    """Advance the single step of a render job by one action.

    ``ready``/``submit_pending`` steps are submitted, ``submitted``/``polling``
    steps are refreshed; any other state is left alone.
    """
    only_step = job['steps'][0]
    current = only_step.get('state')
    if current in {'ready', 'submit_pending'}:
        submit_step(job, only_step)
    elif current in {'submitted', 'polling'}:
        refresh_step(only_step)


def process_chain_job(job: dict[str, Any]) -> None:
    """Advance a two-step chain job by exactly one action.

    ``scene_a`` is driven first (submit, then refresh). Once it has succeeded
    and ``scene_b`` is still ``blocked``, the child payload is resolved; after
    that ``scene_b`` is submitted and refreshed the same way. At most one of
    these actions runs per call.
    """
    by_id = {}
    for entry in job.get('steps', []):
        by_id[entry['step_id']] = entry
    parent, child = by_id['scene_a'], by_id['scene_b']

    awaiting_submit = {'ready', 'submit_pending'}
    in_flight = {'submitted', 'polling'}

    if parent.get('state') in awaiting_submit:
        submit_step(job, parent)
    elif parent.get('state') in in_flight:
        refresh_step(parent)
    elif parent.get('state') == 'succeeded' and child.get('state') == 'blocked':
        resolve_chain_child(job, child)
    elif child.get('state') in awaiting_submit:
        submit_step(job, child)
    elif child.get('state') in in_flight:
        refresh_step(child)


def process_job(job_path: Path) -> None:
    """Advance one job by a single action while holding its lease.

    Jobs whose final report has been sent are skipped outright, including
    those that ended in ``needs_attention``; re-processing them could only
    re-lease and rewrite the file without changing anything. Jobs whose lease
    is held elsewhere are skipped as well.

    Otherwise the job is finalized if any step is ``failed`` or
    ``needs_attention``; if not, the handler for its ``job_type`` runs once
    and the job is finalized when every step has succeeded or one has
    failed or needs attention.

    The job is saved exactly once, before the lock file is removed, so every
    write to the job file happens while this process still holds the lease.

    Args:
        job_path: Path of the job JSON file.

    Raises:
        WatcherError: From ``load_job`` when the file is unreadable or corrupt.
        Exception: Anything other than ``WatcherError`` raised by a step
            handler propagates, after the job has been saved and the lease
            released.
    """
    job = load_job(job_path)
    notify = job.get('notify') or {}
    if notify.get('final_report_sent') and job.get('status') in {'completed', 'failed', 'needs_attention'}:
        return
    if not acquire_lease(job_path, job):
        return

    halted = {'failed', 'needs_attention'}
    try:
        states = {s.get('state') for s in job.get('steps', [])}
        if states & halted:
            finalize_job(job_path, job)
            return
        if job.get('status') == 'queued':
            job['status'] = 'running'
        try:
            if job['job_type'] == 'kling_render':
                process_render_job(job)
            elif job['job_type'] == 'kling_chain':
                process_chain_job(job)
            else:
                job['status'] = 'needs_attention'
                job['error'] = f"Unknown job_type: {job['job_type']}"
        except WatcherError as e:
            # submit_step raises this specific message only to force a save
            # between "marked pending" and the actual API call.
            if str(e) != 'state_persist_required_before_submit':
                job['status'] = 'needs_attention'
                job['error'] = str(e)
        states = {s.get('state') for s in job.get('steps', [])}
        if states & halted or all(state == 'succeeded' for state in states):
            finalize_job(job_path, job)
    finally:
        # Clear the lease in the document first so the saved file does not
        # claim an owner, persist, and only then drop the lock file.
        lease = job.setdefault('lease', {'owner': None, 'expires_at': 0})
        lease['owner'] = None
        lease['expires_at'] = 0
        try:
            save_job(job_path, job)
        finally:
            release_lease(job_path, job)


def tick() -> None:
    """Process every job file under ``JOBS_DIR`` once, in sorted filename order.

    A ``WatcherError`` from one job is ignored and the loop moves on. Any
    other exception no longer aborts the loop either: the remaining jobs are
    still processed, and the first such exception is re-raised afterwards.
    Without this, one job that keeps failing unexpectedly would stop every
    job sorted after it from ever being processed.

    Raises:
        Exception: The first unexpected (non-``WatcherError``) error raised
            while processing a job, after all jobs have been attempted.
    """
    first_unexpected = None
    for job_path in sorted(JOBS_DIR.glob('*.json')):
        try:
            process_job(job_path)
        except WatcherError:
            continue
        except Exception as exc:
            # Remember it, but give the other jobs their turn first.
            if first_unexpected is None:
                first_unexpected = exc
    if first_unexpected is not None:
        raise first_unexpected


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with the ``tick`` and job-creation subcommands.

    The chosen subcommand is stored in ``args.command`` and is mandatory.
    """
    parser = argparse.ArgumentParser(description='Minimal watcher for Kling render/chain jobs')
    commands = parser.add_subparsers(dest='command', required=True)
    commands.add_parser('tick')

    # (subcommand, ((flag, add_argument kwargs), ...)) in declaration order.
    layout = (
        ('create-render-job', (
            ('--endpoint', {'required': True}),
            ('--payload-file', {'required': True}),
            ('--purpose', {}),
            ('--scene-id', {}),
        )),
        ('create-chain-job', (
            ('--seed-endpoint', {'required': True}),
            ('--seed-payload-file', {'required': True}),
            ('--child-endpoint', {'required': True}),
            ('--child-payload-file', {'required': True}),
            ('--purpose', {}),
            ('--refer-type', {'default': 'feature'}),
            ('--keep-original-sound', {'default': 'yes'}),
            ('--scene-a-id', {}),
            ('--scene-b-id', {}),
        )),
    )
    for command_name, options in layout:
        command_parser = commands.add_parser(command_name)
        for flag, settings in options:
            command_parser.add_argument(flag, **settings)

    return parser


def main() -> None:
    """CLI entry point: run one tick or create a job and print its file path."""
    parser = build_parser()
    args = parser.parse_args()
    command = args.command

    def read_json(file_arg: str) -> Any:
        # Payload files are read as UTF-8 JSON.
        return json.loads(Path(file_arg).read_text(encoding='utf-8'))

    if command == 'tick':
        tick()
    elif command == 'create-render-job':
        payload = read_json(args.payload_file)
        job_file = create_render_job(
            endpoint=args.endpoint,
            payload=payload,
            purpose=args.purpose,
            scene_id=args.scene_id,
        )
        print(job_file)
    elif command == 'create-chain-job':
        seed_payload = read_json(args.seed_payload_file)
        child_payload = read_json(args.child_payload_file)
        job_file = create_chain_job(
            seed_endpoint=args.seed_endpoint,
            seed_payload=seed_payload,
            child_endpoint=args.child_endpoint,
            child_payload_template=child_payload,
            purpose=args.purpose,
            refer_type=args.refer_type,
            keep_original_sound=args.keep_original_sound,
            scene_a_id=args.scene_a_id,
            scene_b_id=args.scene_b_id,
        )
        print(job_file)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
