#!/usr/bin/env python3
"""okama-update - safe OkamaOS update manager.

Usage:
  okama-update check
  okama-update apply [--dry-run] [--sha256 <hex>] <path-or-url.okupdate>
  okama-update rollback
"""

import sys
import os

# Locate the OkamaOS Python library: OKAMA_LIB wins when set; otherwise fall
# back to the ../lib directory next to this script (realpath follows symlinks,
# so an installed symlinked launcher still finds the real tree).
_lib = os.environ.get("OKAMA_LIB", "")
if not _lib:
    _lib = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../lib")
# Prepend so the bundled library shadows any system-wide copy.
sys.path.insert(0, os.path.abspath(_lib))

import argparse
import datetime as _dt
import json
import posixpath
import shutil
import subprocess
import tarfile
import tempfile
import urllib.parse
import urllib.request

import okamaos.config as cfg_mod
import okamaos.store as store_mod
import okamaos.updates as updates_mod

# Writable update workspace and the filesystem root updates are applied to.
# Both can be overridden via the environment (useful for tests and for
# building images in a chroot-like staging directory).
UPDATES_DIR = os.environ.get("OKAMA_UPDATES", "/var/okamaos/updates")
UPDATE_ROOT = os.environ.get("OKAMA_UPDATE_ROOT", "/")

# Paths an update bundle must never overwrite: user data (games, saves),
# credentials / network / SSH configuration, and the update workspace itself.
PRESERVED_PATHS = (
    "/var/okamaos/games",
    "/var/okamaos/saves",
    "/var/okamaos/logs",
    "/var/okamaos/cache",
    "/var/okamaos/controllers",
    "/var/okamaos/updates",
    "/etc/okamaos/parent.conf",
    "/etc/okamaos/devmode.conf",
    "/etc/wpa_supplicant.conf",
    "/etc/wpa_supplicant",
    "/etc/dropbear",
)

# Allow-list of path *prefixes* an overlay file may be installed under.
ALLOWED_FILE_PREFIXES = (
    "/usr/bin/okama-",
    "/usr/lib/okamaos/",
    "/usr/share/okamaos/",
    "/etc/init.d/S",
    "/boot/okamaos/",
    "/boot/extlinux/",
)

# Allow-list of *exact* file paths an update may replace outright.
ALLOWED_FILE_EXACT = (
    "/etc/profile",
    "/etc/issue",
    "/etc/motd",
)

# Config files that are merged key-by-key instead of being overwritten.
MERGE_CONFIG_PATHS = (
    "/etc/okamaos/okama.conf",
)


def cmd_check(_args):
    """Query the OS and game update feeds and print a human-readable report."""
    conf = cfg_mod.get()
    print("okama-update: checking for updates...")
    print(f"  OS updates   : {conf.get('UPDATE_URL', updates_mod.UPDATE_URL_DEFAULT)}")
    print(f"  Game catalog : {conf.get('STORE_URL', store_mod.CATALOG_URL_DEFAULT)}")

    summary = updates_mod.check_all_updates()
    try:
        updates_mod.write_update_state(summary)
    except Exception:
        # Best-effort state cache; a failed write must not abort the check.
        pass

    os_notice = summary.get("os_update")
    notices = ([os_notice] if os_notice else []) + list(summary.get("game_updates", []))

    for error in summary.get("errors", []):
        print(f"  Warning      : {error}")

    if not notices:
        print("  Status       : no updates available")
        return

    print(f"  Status       : {len(notices)} update(s) available")
    for notice in notices:
        label = "OS" if notice["type"] == "os" else "Game"
        size_text = store_mod.format_size(int(notice.get("size_bytes", 0) or 0))
        print("")
        print(f"{label}: {notice.get('title', notice.get('name', '?'))}")
        print(f"  Installed    : {notice.get('current_version', '?')}")
        print(f"  Available    : {notice.get('version', '?')}")
        if notice.get("summary"):
            print(f"  Summary      : {notice['summary']}")
        if notice.get("download_url"):
            print(f"  Download     : {notice['download_url']}")
        if size_text != "? MB":
            print(f"  Size         : {size_text}")
        if notice.get("sha256"):
            print(f"  SHA-256      : {notice['sha256']}")


def cmd_apply(args):
    """Resolve, verify, and apply an update bundle, then report what changed."""
    try:
        bundle = _resolve_update_file(args.update_file, expected_sha256=args.sha256)
        summary = _apply_bundle(bundle, dry_run=args.dry_run)
    except Exception as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(1)

    verb = "would update" if args.dry_run else "updated"
    print(f"okama-update: {verb} OkamaOS to {summary['version']}")
    print(f"  Files        : {len(summary['files'])}")
    print(f"  Config merges: {len(summary['configs'])}")
    if summary.get("pip_packages"):
        pip_verb = "would install" if args.dry_run else "installed"
        print(f"  Pip packages : {pip_verb}: {', '.join(summary['pip_packages'])}")
    if summary.get("backup_dir"):
        print(f"  Backup       : {summary['backup_dir']}")
    if summary.get("requires_reboot"):
        print("  Reboot       : required to finish the system update")


def cmd_rollback(_args):
    """Restore every file recorded in the most recent update backup."""
    try:
        backup_dir = _latest_backup_dir()
        count = _restore_backup(backup_dir)
    except Exception as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(1)

    print("okama-update: rollback complete")
    print(f"  Backup       : {backup_dir}")
    print(f"  Restored     : {count} file(s)")
    print("  Reboot       : recommended")


def _resolve_update_file(source: str, expected_sha256: str = "") -> str:
    """Return a local path for *source*, downloading it first when it is a URL.

    Verifies the SHA-256 when a hash was supplied, or — for URLs only — when
    the release feed advertises one for that exact download URL.
    """
    if not source.startswith(("http://", "https://")):
        # Local file: just verify it exists (and its hash, when given).
        if not os.path.exists(source):
            raise FileNotFoundError(f"file not found: {source}")
        if expected_sha256:
            _verify_sha256(source, expected_sha256)
        return source

    downloads_dir = _data_path(os.path.join(UPDATES_DIR, "downloads"))
    os.makedirs(downloads_dir, exist_ok=True)
    basename = os.path.basename(urllib.parse.urlparse(source).path) or "update.okupdate"
    dest = os.path.join(downloads_dir, basename)
    urllib.request.urlretrieve(source, dest)
    expected = expected_sha256 or _feed_hash_for_url(source)
    if expected:
        _verify_sha256(dest, expected)
    return dest


def _feed_hash_for_url(source: str) -> str:
    """Return the SHA-256 the release feed advertises for *source*, or ""."""
    try:
        info = updates_mod.fetch_release_info(timeout=3)
    except Exception:
        return ""  # feed unreachable: caller falls back to no verification
    if info.get("download_url") != source:
        return ""
    checksum = info.get("checksum", info.get("sha256", ""))
    if checksum.startswith("sha256:"):
        return checksum.split(":", 1)[-1]
    return checksum


def _verify_sha256(path: str, expected: str) -> None:
    """Raise ValueError unless *path* hashes to *expected* (hex, case-insensitive)."""
    want = expected.strip().lower()
    got = updates_mod.sha256_file(path)
    if got != want:
        raise ValueError(
            f"SHA-256 mismatch for {os.path.basename(path)}: "
            f"expected {want}, got {got}"
        )


def _apply_bundle(path: str, dry_run: bool = False) -> dict:
    """Validate and apply an .okupdate bundle (tarball or bare JSON manifest).

    Args:
        path: local path to the bundle file.
        dry_run: when True, report what would happen without touching the system.

    Returns:
        Summary dict with ``version``, ``files``, ``configs``, ``pip_packages``,
        ``requires_reboot`` and — for a real run — ``backup_dir``.

    Raises:
        ValueError: for malformed bundles, wrong bundle types, unsupported
            versions, or targets outside the allow-lists.
    """
    with tempfile.TemporaryDirectory(prefix="okama-update-") as tmp:
        manifest = None
        overlay_root = os.path.join(tmp, "files")
        try:
            with tarfile.open(path, "r:*") as tf:
                _safe_extract(tf, tmp)
        except tarfile.TarError:
            # Not a tarball: accept a manifest-only JSON update file.
            try:
                with open(path, encoding="utf-8") as f:
                    manifest = json.load(f)
            except Exception as exc:
                raise ValueError("not an installable .okupdate bundle") from exc

        if manifest is None:
            manifest_path = _first_existing(
                os.path.join(tmp, "manifest.okupdate.json"),
                os.path.join(tmp, "manifest.json"),
            )
            if not manifest_path:
                raise ValueError("missing manifest.okupdate.json in update bundle")
            with open(manifest_path, encoding="utf-8") as f:
                manifest = json.load(f)
            overlay_root = os.path.join(tmp, manifest.get("overlay_root", "files"))

        kind = manifest.get("type", "okamaos-system-update")
        if kind not in ("okamaos-system-update", "okamaos-update-manifest", "system"):
            raise ValueError("update bundle is not an OkamaOS system update")

        # BUGFIX: the minimum-version gate was inverted — it raised when the
        # *installed* OS was newer than min_version.  Reject only when
        # min_version is newer than what is currently installed.
        min_version = manifest.get("min_version", manifest.get("minimum_supported_version", ""))
        if min_version and updates_mod.is_newer(min_version, updates_mod.current_version()):
            raise ValueError(f"update requires OkamaOS {min_version} or newer")

        files = _collect_overlay_files(overlay_root)
        configs = manifest.get("config_merges", [])
        _validate_targets(files, configs)

        summary = {
            "version": manifest.get("version", "unknown"),
            "files": files,
            "configs": configs,
            "requires_reboot": bool(manifest.get("requires_reboot", True)),
        }
        # First pass only *lists* the requirements so dry runs can report them.
        summary["pip_packages"] = _install_requirements(tmp, manifest, dry_run=True)

        if not files and not configs:
            summary["requires_reboot"] = False
            # BUGFIX: a requirements-only bundle previously returned here
            # without installing the packages it then reported as "installed".
            if not dry_run and summary["pip_packages"]:
                summary["pip_packages"] = _install_requirements(tmp, manifest)
            return summary
        if dry_run:
            return summary

        backup_dir = _make_backup_dir(manifest)
        _backup_targets(files, configs, backup_dir)
        _install_files(overlay_root, files)
        _merge_configs(configs)
        pip_packages = _install_requirements(tmp, manifest)
        _write_history(manifest, files, configs, backup_dir, pip_packages)
        summary["backup_dir"] = backup_dir
        summary["pip_packages"] = pip_packages
        return summary


def _first_existing(*paths: str) -> str:
    for path in paths:
        if os.path.exists(path):
            return path
    return ""


def _safe_extract(tf: tarfile.TarFile, dest: str) -> None:
    dest_abs = os.path.abspath(dest)
    for member in tf.getmembers():
        name = member.name
        if name.startswith("/") or ".." in name.split("/"):
            raise ValueError(f"unsafe update path: {name}")
        target = os.path.abspath(os.path.join(dest_abs, name))
        if not target.startswith(dest_abs + os.sep) and target != dest_abs:
            raise ValueError(f"unsafe update path: {name}")
        if member.islnk() or member.issym():
            raise ValueError(f"links are not allowed in updates: {name}")

    for member in tf.getmembers():
        target = os.path.abspath(os.path.join(dest_abs, member.name))
        if member.isdir():
            os.makedirs(target, exist_ok=True)
        elif member.isfile():
            os.makedirs(os.path.dirname(target), exist_ok=True)
            src = tf.extractfile(member)
            if src is None:
                raise ValueError(f"cannot read update file: {member.name}")
            with src, open(target, "wb") as out:
                shutil.copyfileobj(src, out)
            os.chmod(target, member.mode & 0o777)
        else:
            raise ValueError(f"unsupported update entry: {member.name}")


def _collect_overlay_files(overlay_root: str) -> list:
    if not os.path.isdir(overlay_root):
        return []
    files = []
    for root, _, names in os.walk(overlay_root):
        for name in names:
            src = os.path.join(root, name)
            rel = os.path.relpath(src, overlay_root)
            target = "/" + rel.replace(os.sep, "/")
            files.append({"source": rel, "target": target})
    return sorted(files, key=lambda item: item["target"])


def _validate_targets(files: list, configs: list) -> None:
    """Normalize every install/merge target in place and enforce the allow-lists.

    Raises ValueError for any target that is unsafe, preserved, or not
    explicitly allowed.
    """
    for entry in files:
        target = _clean_target(entry["target"])
        if _is_preserved_path(target):
            raise ValueError(f"update refuses to overwrite preserved data: {target}")
        if not _is_allowed_file_target(target):
            raise ValueError(f"update target is not allowed: {target}")
        entry["target"] = target

    for entry in configs:
        conf_path = _clean_target(entry.get("path", ""))
        if conf_path not in MERGE_CONFIG_PATHS:
            raise ValueError(f"config merge target is not allowed: {conf_path}")
        entry["path"] = conf_path


def _clean_target(path: str) -> str:
    if not path:
        raise ValueError("empty update target")
    clean = posixpath.normpath("/" + path.lstrip("/"))
    if clean == "/" or clean.startswith("/../"):
        raise ValueError(f"unsafe update target: {path}")
    return clean


def _is_preserved_path(path: str) -> bool:
    """True when *path* is, or lies under, a user-data path updates must not touch."""
    for preserved in PRESERVED_PATHS:
        if path == preserved or path.startswith(preserved + "/"):
            return True
    return False


def _is_allowed_file_target(path: str) -> bool:
    """True when *path* matches the exact-file allow-list or an allowed prefix."""
    if path in ALLOWED_FILE_EXACT:
        return True
    # str.startswith accepts a tuple of prefixes.
    return path.startswith(ALLOWED_FILE_PREFIXES)


def _root_path(path: str) -> str:
    """Map an absolute target path onto the (possibly re-rooted) update root."""
    relative = path.lstrip("/")
    return os.path.join(UPDATE_ROOT, relative)


def _data_path(path: str) -> str:
    """Re-root an absolute data path under UPDATE_ROOT unless it already lives there.

    With the default root "/" (or a relative *path*) the input is returned
    unchanged.
    """
    if UPDATE_ROOT == "/" or not os.path.isabs(path):
        return path
    root_abs = os.path.abspath(UPDATE_ROOT)
    path_abs = os.path.abspath(path)
    already_inside = path_abs == root_abs or path_abs.startswith(root_abs + os.sep)
    return path if already_inside else _root_path(path)


def _make_backup_dir(manifest: dict) -> str:
    """Create and return a fresh timestamped backup directory for this update."""
    stamp = _dt.datetime.now(_dt.UTC).strftime("%Y%m%dT%H%M%SZ")
    base = _data_path(cfg_mod.get().get(
        "UPDATE_BACKUP_DIR", os.path.join(UPDATES_DIR, "backups")))
    backup_dir = os.path.join(base, f"{stamp}-{manifest.get('version', 'unknown')}")
    # exist_ok=False: two applies in the same second must not share a backup.
    os.makedirs(backup_dir, exist_ok=False)
    return backup_dir


def _backup_targets(files: list, configs: list, backup_dir: str) -> None:
    """Copy every existing target into *backup_dir* and record a backup manifest.

    Targets that do not yet exist are recorded with ``existed: false`` so a
    rollback knows to delete them.
    """
    targets = {item["target"] for item in files}
    targets.update(item["path"] for item in configs)

    entries = []
    for target in sorted(targets):
        src = _root_path(target)
        existed = os.path.exists(src)
        if existed:
            copy_to = os.path.join(backup_dir, "files", target.lstrip("/"))
            os.makedirs(os.path.dirname(copy_to), exist_ok=True)
            shutil.copy2(src, copy_to)
        entries.append({"target": target, "existed": existed})

    with open(os.path.join(backup_dir, "backup.json"), "w", encoding="utf-8") as f:
        json.dump({"entries": entries}, f, indent=2)


def _install_files(overlay_root: str, files: list) -> None:
    """Copy overlay files onto their targets; mark binaries and init scripts executable."""
    for item in files:
        dest = _root_path(item["target"])
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.copy2(os.path.join(overlay_root, item["source"]), dest)
        needs_exec = item["target"].startswith("/usr/bin/") or "/init.d/" in item["target"]
        if needs_exec:
            os.chmod(dest, 0o755)


def _merge_configs(configs: list) -> None:
    """Merge manifest config entries into key=value files.

    "set" entries always win; "defaults" only fill keys that are missing.
    """
    for item in configs:
        conf_path = _root_path(item["path"])
        merged = _read_kv(conf_path)
        for key, value in item.get("set", {}).items():
            merged[str(key)] = str(value)
        for key, value in item.get("defaults", {}).items():
            merged.setdefault(str(key), str(value))
        _write_kv(conf_path, merged)


def _read_kv(path: str) -> dict:
    data = {}
    try:
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, value = line.split("=", 1)
                data[key] = value
    except FileNotFoundError:
        pass
    return data


def _write_kv(path: str, data: dict) -> None:
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for key in sorted(data):
            f.write(f"{key}={data[key]}\n")


def _install_requirements(bundle_tmp: str, manifest: dict, dry_run: bool = False) -> list:
    """Install Python dependencies shipped at the bundle root.

    The requirements file name comes from the manifest key
    ``requirements_file`` (default ``requirements.txt``).  Installation runs
    ``sys.executable -m pip install`` so the active interpreter is used, with
    extra flags taken from the ``PIP_INSTALL_ARGS`` okama.conf key (default
    ``--break-system-packages``); a failed run is retried once without the
    extra flags in case pip does not support them.

    Returns the list of requirement specs (empty when nothing to install);
    with ``dry_run`` the list is returned without installing anything.
    Raises RuntimeError when both pip invocations fail.
    """
    req_path = os.path.join(bundle_tmp, manifest.get("requirements_file", "requirements.txt"))
    if not os.path.isfile(req_path):
        return []
    with open(req_path, encoding="utf-8") as f:
        specs = [line.strip() for line in f
                 if line.strip() and not line.strip().startswith("#")]
    if dry_run or not specs:
        return specs

    extra_args = cfg_mod.get().get("PIP_INSTALL_ARGS", "--break-system-packages").split()
    base_cmd = [sys.executable, "-m", "pip", "install", "--quiet"]
    result = subprocess.run(base_cmd + extra_args + ["-r", req_path],
                            capture_output=True, text=True, timeout=300)
    if result.returncode != 0:
        # Retry bare, in case the configured extra flags are unsupported.
        result = subprocess.run(base_cmd + ["-r", req_path],
                                capture_output=True, text=True, timeout=300)
    if result.returncode != 0:
        raise RuntimeError(
            f"pip install failed for update requirements:\n"
            f"{(result.stderr or result.stdout).strip()[:300]}"
        )
    return specs


def _write_history(manifest: dict, files: list, configs: list, backup_dir: str,
                   pip_packages: list | None = None) -> None:
    """Record an applied update as ``<version>.json`` in the history directory.

    Args:
        manifest: parsed update manifest (supplies the version string).
        files: overlay file entries that were installed.
        configs: config-merge entries that were applied.
        backup_dir: backup directory created for this update.
        pip_packages: requirement specs installed, if any (fixed annotation:
            the default is None, so the parameter is ``list | None``).
    """
    history_dir = _data_path(cfg_mod.get().get(
        "UPDATE_HISTORY_DIR", os.path.join(UPDATES_DIR, "history")))
    os.makedirs(history_dir, exist_ok=True)
    version = manifest.get("version", "unknown")
    entry = {
        "version": version,
        "applied_at": _dt.datetime.now(_dt.UTC).isoformat(),
        "backup_dir": backup_dir,
        "files": files,
        "configs": configs,
    }
    if pip_packages:
        entry["pip_packages"] = pip_packages
    with open(os.path.join(history_dir, f"{version}.json"), "w", encoding="utf-8") as f:
        json.dump(entry, f, indent=2)


def _latest_backup_dir() -> str:
    """Return the newest backup directory (names sort chronologically).

    Raises FileNotFoundError when no backups exist.
    """
    base = _data_path(cfg_mod.get().get(
        "UPDATE_BACKUP_DIR", os.path.join(UPDATES_DIR, "backups")))
    try:
        candidates = [name for name in os.listdir(base)
                      if os.path.isdir(os.path.join(base, name))]
    except FileNotFoundError:
        candidates = []
    if not candidates:
        raise FileNotFoundError("no update backups available")
    # Timestamped names mean lexicographic max == most recent.
    return os.path.join(base, max(candidates))


def _restore_backup(backup_dir: str) -> int:
    """Replay backup.json: restore saved files, delete files the update created.

    Returns the number of filesystem changes made.
    """
    with open(os.path.join(backup_dir, "backup.json"), encoding="utf-8") as f:
        backup_manifest = json.load(f)

    changed = 0
    for entry in backup_manifest.get("entries", []):
        target = _clean_target(entry["target"])
        dest = _root_path(target)
        if entry.get("existed"):
            saved = os.path.join(backup_dir, "files", target.lstrip("/"))
            os.makedirs(os.path.dirname(dest), exist_ok=True)
            shutil.copy2(saved, dest)
            changed += 1
        elif os.path.exists(dest):
            # The update created this file; rolling back means removing it.
            os.remove(dest)
            changed += 1
    return changed


def main():
    """Parse the command line and dispatch to the requested subcommand."""
    parser = argparse.ArgumentParser(prog="okama-update",
                                     description="OkamaOS OTA update tool")
    sub = parser.add_subparsers(dest="command", required=True)

    sub.add_parser("check")

    apply_parser = sub.add_parser("apply")
    apply_parser.add_argument("--dry-run", action="store_true")
    apply_parser.add_argument("--sha256", default="",
                              help="Expected SHA-256 for the update bundle")
    apply_parser.add_argument("update_file")

    sub.add_parser("rollback")

    args = parser.parse_args()
    handlers = {"check": cmd_check, "apply": cmd_apply, "rollback": cmd_rollback}
    handlers[args.command](args)


# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
