# -*- coding: utf-8 -*-
"""
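Stable MODIS L1B -> L2 pipeline for AutoDL
- uses the OCSSW installation at /root/autodl-fs/ocssw
- runs getanc serially (stage 1)
- recovers the .anc file from the getanc log if getanc does not write it
- runs l2gen in parallel (stage 2) and renders a true-color PNG per scene
- automatically links missing HKM/QKM companion files for 500 m / 250 m processing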
"""

import os
import re
import stat
import subprocess
from pathlib import Path
from datetime import datetime
from multiprocessing import Pool

import numpy as np
from netCDF4 import Dataset
from PIL import Image

# =========================================================
# 1. CONFIG
# =========================================================

# Root of the OCSSW installation and the bash env script sourced by
# load_ocssw_env().
OCSSW_ROOT = Path("/root/autodl-fs/ocssw")
OCSSW_ENV = OCSSW_ROOT / "OCSSW_bash.env"

# Input tree (searched recursively for *.L1B_LAC) and output root.
ROOT_DATA = Path("/root/autodl-tmp/500m_resolution/OCSSW_OUT/2017/10/30")
OUT_BASE = Path("/root/autodl-tmp/500m_resolution/l2")

ANC_DIR = OUT_BASE / "ANC"
LOG_DIR = OUT_BASE / "LOG_l2"
PNG_DIR = OUT_BASE / "TrueColor_png"

# Interpreter used to run getanc.py and the environment debug checks.
PYTHON3 = "/root/miniconda3/bin/python3"

CMD_L2GEN = OCSSW_ROOT / "bin" / "l2gen"
CMD_GETANC = OCSSW_ROOT / "bin" / "getanc.py"

# Passed to l2gen as resolution=...; "500" requires the HKM companion,
# "250" requires both HKM and QKM companions.
TARGET_RESOLUTION = "500"

# Passed to l2gen as l2prod=...
PRODUCTS = "rhos_469,rhos_555,rhos_645,rhos_859,rhos_1240,l2_flags"

# Extra key=value options appended to every l2gen command line.
L2GEN_OPTS = [
    "oformat=netcdf4",
    "proc_ocean=1",
    "maskland=0",
    "deflate=0",
]

# Subprocess timeouts, in seconds.
TIMEOUT_GETANC = 3600
TIMEOUT_L2GEN = 7200

# Reuse a stage's existing output when it is present and passes its size check.
SKIP_EXISTING_ANC = True
SKIP_EXISTING_L2 = True
SKIP_EXISTING_PNG = True

# Minimum sizes in bytes for a file to count as valid.
# .anc files are short text files, so they must not be held to the
# 1024-byte threshold used for the binary inputs.
MIN_TEXT_FILE_SIZE = 10
MIN_VALID_FILE_SIZE = 1024
MIN_L2_SIZE = 50000

# Number of worker processes for the l2gen stage.
L2_WORKERS = 40

# Optional lat/lon subset passed to l2gen as north/south/west/east.
USE_SUBSET = False
NORTH = 45
SOUTH = 0
WEST = 100
EAST = 145

# When True, a LAADS-named HKM/QKM file with a matching timestamp is
# symlinked next to the L1B file under the OCSSW naming convention.
AUTO_LINK_LAADS_COMPANION = True

# Percentile clip bounds and gamma for the true-color PNG stretch.
P_LOW = 1.0
P_HIGH = 99.5
GAMMA = 1.2


# =========================================================
# 2. BASIC UTILS
# =========================================================

def ensure_dir(p: Path):
    """Create directory *p*, including missing parents; no error if it exists."""
    os.makedirs(p, exist_ok=True)


def is_nonempty(path: Path, min_size: int = MIN_VALID_FILE_SIZE) -> bool:
    """Return True when *path* is not None, exists, and is >= *min_size* bytes."""
    if path is None or not path.exists():
        return False
    return path.stat().st_size >= min_size


def safe_symlink(src: Path, dst: Path):
    """Create or repair the symlink ``dst -> src``.

    - If ``dst`` is already a symlink resolving to ``src``, it is kept.
    - Any other symlink at ``dst`` (including a broken one) is replaced.
    - An existing regular file at ``dst`` is never overwritten and is
      returned as-is.

    Returns:
        ``dst`` on success (or when a regular file is already there), or
        ``None`` when ``src`` does not exist or the link cannot be created,
        for example when ``dst`` is a directory.
    """
    if not src.exists():
        print(f"[WARN] source not found, skip link: {src}")
        return None

    dst.parent.mkdir(parents=True, exist_ok=True)

    if dst.is_symlink():
        try:
            if dst.resolve() == src.resolve():
                return dst
        except (OSError, RuntimeError):
            # Unresolvable link (e.g. a symlink loop): replace it below.
            pass
        try:
            dst.unlink()
        except FileNotFoundError:
            pass  # already removed by someone else
    elif dst.is_file():
        # Keep a real file that is already in place.
        return dst

    try:
        dst.symlink_to(src)
    except OSError as exc:
        # dst is a directory, appeared concurrently, or the filesystem
        # refuses symlinks. Report it and let the caller fall back instead
        # of aborting the whole pipeline.
        print(f"[WARN] cannot create link {dst} -> {src}: {exc}")
        return None
    return dst


# =========================================================
# 3. ENVIRONMENT
# =========================================================

def _source_env_script(env_script: Path) -> dict:
    """Return the environment left behind by sourcing *env_script* in bash.

    The shell is a login shell (``bash -lc``) that inherits this process's
    environment. The script's own stdout/stderr are discarded.
    """
    import shlex

    # Quote the path so quotes or spaces in it cannot break (or inject into)
    # the shell command.
    script = f"source {shlex.quote(str(env_script))} >/dev/null 2>&1; env -0"
    proc = subprocess.run(
        ["bash", "-lc", script],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=True,
    )

    sourced = {}
    for entry in proc.stdout.split(b"\x00"):
        if not entry:
            continue
        key, _, value = entry.partition(b"=")
        # os.fsdecode uses surrogateescape: non-UTF-8 bytes cannot raise here
        # and round-trip unchanged when the env is passed to a subprocess.
        sourced[os.fsdecode(key)] = os.fsdecode(value)
    return sourced


def _build_ocssw_path(inherited_path: str) -> str:
    """Return a PATH with OCSSW_ROOT's bin dirs first and system dirs next.

    Inherited entries mentioning "ocssw" are dropped, and duplicates are
    removed while keeping first-seen order.
    """
    base_path = [
        str(OCSSW_ROOT / "bin"),
        str(OCSSW_ROOT / "opt" / "bin"),
        "/opt/bin",
        "/usr/local/bin",
        "/usr/bin",
        "/bin",
    ]
    # Drop entries pointing at any (possibly different) OCSSW install, so
    # only OCSSW_ROOT's tools are found.
    inherited = [
        p for p in inherited_path.split(":")
        if p and "ocssw" not in p.lower()
    ]
    return ":".join(dict.fromkeys(base_path + inherited))


def load_ocssw_env(env_script: Path):
    """Build the environment dict used for every OCSSW subprocess.

    Steps:
    1. Keep a few user/locale variables from ``os.environ``.
    2. Overlay the environment produced by sourcing *env_script*.
    3. Force the OCSSW variables, PYTHONPATH and PATH to point at OCSSW_ROOT.
    4. Remove variables that could redirect the Python interpreter.

    OCVARROOT is created if it is missing.

    Raises:
        SystemExit: if *env_script* does not exist.
        subprocess.CalledProcessError: if the bash command fails.
    """
    if not env_script.exists():
        raise SystemExit(f"OCSSW env script not found: {env_script}")

    keep_keys = {
        "HOME", "USER", "LOGNAME", "LANG", "LC_ALL", "LC_CTYPE",
        "SHELL", "TMPDIR", "TEMP", "TMP",
    }
    env = {k: v for k, v in os.environ.items() if k in keep_keys}
    env.update(_source_env_script(env_script))

    env["OCSSWROOT"] = str(OCSSW_ROOT)
    env["OCDATAROOT"] = str(OCSSW_ROOT / "share")
    env["OCVARROOT"] = str(OCSSW_ROOT / "var")
    env["OCSSW_BIN"] = str(OCSSW_ROOT / "bin")
    env["LIB3_BIN"] = str(OCSSW_ROOT / "bin")
    env["LIB3_DIR"] = str(OCSSW_ROOT / "lib3")

    ensure_dir(Path(env["OCVARROOT"]))

    # Replace (not extend) PYTHONPATH so only this OCSSW tree's modules load.
    env["PYTHONPATH"] = ":".join([
        str(OCSSW_ROOT / "bin"),
        str(OCSSW_ROOT / "bin" / "modis"),
        str(OCSSW_ROOT / "bin" / "ocssw_src" / "src" / "scripts"),
    ])

    env["PATH"] = _build_ocssw_path(env.get("PATH", ""))

    for bad_key in (
        "PYTHONHOME",
        "PYTHONSTARTUP",
        "PYTHONUSERBASE",
        "VIRTUAL_ENV",
        "CONDA_PREFIX",
        "CONDA_DEFAULT_ENV",
    ):
        env.pop(bad_key, None)

    return env


def debug_ocssw_env(env: dict):
    """Print the key OCSSW variables and probe imports under *env*.

    For seadasutils.anc_utils and modis_utils, the PYTHON3 interpreter is run
    with *env*, and the module file it imports (or the error text) is printed.
    """
    print("========== DEBUG OCSSW ENV ==========")
    for var in ("OCSSWROOT", "OCDATAROOT", "OCVARROOT", "PYTHONPATH", "PATH"):
        print(f"{var} =", env.get(var))

    module_probes = {
        "seadasutils.anc_utils": "import seadasutils.anc_utils; print(seadasutils.anc_utils.__file__)",
        "modis_utils": "import modis_utils; print(modis_utils.__file__)",
    }
    for label, snippet in module_probes.items():
        result = subprocess.run(
            [PYTHON3, "-c", snippet],
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        print(f"{label} = {result.stdout.strip()}")
    print("=====================================")


def require_executable(path_candidates, name):
    """Return (as str) the first candidate that is an executable regular file.

    Empty or None candidates are ignored. An ``[OK]`` line is printed for
    the match.

    Raises:
        SystemExit: when no candidate qualifies.
    """
    usable = (Path(c) for c in path_candidates if c)
    for candidate in usable:
        if candidate.is_file() and os.access(candidate, os.X_OK):
            print(f"[OK] {name}: {candidate}")
            return str(candidate)
    raise SystemExit(f"{name} not found or not executable: {path_candidates}")


def check_runtime(env: dict):
    """Abort early if a tool or data directory the pipeline needs is missing.

    Checks the PYTHON3 interpreter, getanc.py, the l2gen binary, and the
    OCDATAROOT / OCVARROOT directories named in *env*.

    Raises:
        SystemExit: describing the first missing item.
    """
    # getanc.py is launched as ``PYTHON3 getanc.py`` (see prepare_one_anc).
    # The interpreter must be executable, otherwise every getanc call fails.
    # The script itself only needs to exist as a file.
    require_executable([PYTHON3], "python3")
    if not CMD_GETANC.is_file():
        raise SystemExit(f"getanc.py not found: {CMD_GETANC}")
    print(f"[OK] getanc.py: {CMD_GETANC}")

    require_executable([str(CMD_L2GEN)], "l2gen")

    for var in ("OCDATAROOT", "OCVARROOT"):
        if not Path(env[var]).is_dir():
            raise SystemExit(f"{var} not found: {env[var]}")


def check_netrc():
    """Verify that ~/.netrc exists, is private, and has NASA login entries.

    Raises:
        SystemExit: when the file is missing, is accessible by group/others,
            or mentions neither urs.earthdata.nasa.gov nor
            oceandata.sci.gsfc.nasa.gov.
    """
    netrc = Path.home() / ".netrc"
    if not netrc.exists():
        raise SystemExit("~/.netrc not found. Earthdata/OceanColor authentication is required.")

    mode = stat.S_IMODE(netrc.stat().st_mode)
    # The file holds plaintext credentials, so group/others must have no
    # access. Owner-only modes such as 600 and 400 are both fine. This mirrors
    # the check in Python's own ``netrc`` module.
    if mode & (stat.S_IRWXG | stat.S_IRWXO):
        raise SystemExit(
            f"~/.netrc must not be accessible by group/others (e.g. chmod 600), current: {oct(mode)}"
        )

    text = netrc.read_text(errors="ignore")
    if "urs.earthdata.nasa.gov" not in text and "oceandata.sci.gsfc.nasa.gov" not in text:
        raise SystemExit("~/.netrc missing Earthdata/OceanColor machine entries")

    print("[OK] ~/.netrc looks valid")


# =========================================================
# 4. FILE NAMING / SEARCH
# =========================================================

def mission_from_l1b_name(name: str) -> str:
    """Map an OCSSW L1B file name to the getanc ``--mission`` value.

    Names starting with "T" (Terra) give "modist". Everything else, including
    "A" (Aqua) and unrecognised prefixes, falls back to "modisa".
    """
    return "modist" if name.startswith("T") else "modisa"


def parse_ocssw_l1b_datetime(l1b_path: Path):
    """Extract the scene time from an OCSSW L1B file name.

    Expects names like ``A2017303023000.L1B_LAC``: platform letter (A/T),
    year, day-of-year, HHMMSS. Returns a naive ``datetime``, or ``None`` when
    the name does not match.
    """
    match = re.match(r"^[AT](\d{4})(\d{3})(\d{6})\.L1B_LAC$", l1b_path.name)
    if match is None:
        return None
    year_txt, doy_txt, hms_txt = match.groups()
    stamp = f"{int(year_txt)} {int(doy_txt):03d} {hms_txt}"
    return datetime.strptime(stamp, "%Y %j %H%M%S")


def parse_laads_datetime(path: Path):
    """Return the ``.YYYYMMDDTHHMMSS.`` timestamp in *path*'s name as a naive
    datetime, or ``None`` if the name has no such segment."""
    found = re.search(r"\.(\d{8}T\d{6})\.", path.name)
    return None if found is None else datetime.strptime(found.group(1), "%Y%m%dT%H%M%S")


def find_companion_file(l1b_path: Path, subdir_name: str, src_suffix: str, dst_suffix: str):
    """Map an L1B path to its companion path in a sibling directory tree.

    The first ``L1B`` path component is replaced by *subdir_name*, and every
    occurrence of *src_suffix* in the file name by *dst_suffix*. Returns
    ``None`` when the path has no ``L1B`` component. Existence is not checked.
    """
    components = list(l1b_path.parts)
    try:
        pos = components.index("L1B")
    except ValueError:
        return None
    components[pos] = subdir_name
    new_name = l1b_path.name.replace(src_suffix, dst_suffix)
    return Path(*components).with_name(new_name)


def find_parallel_dir(l1b_path: Path, subdir_name: str):
    """Return the directory parallel to the L1B tree that is named *subdir_name*.

    The first ``L1B`` path component is replaced by *subdir_name*, the same
    mapping find_companion_file uses, and the containing directory is
    returned. For example, ``.../30/L1B/A...L1B_LAC`` maps to ``.../30/HKM``.
    Returns ``None`` when the path has no ``L1B`` component.
    """
    parts = list(l1b_path.parts)
    if "L1B" not in parts:
        return None
    parts[parts.index("L1B")] = subdir_name
    # parts already contains subdir_name in place of "L1B". Appending it
    # again would yield ``.../HKM/HKM``, a directory that normally does not
    # exist.
    return Path(*parts).parent


def find_laads_companion_by_datetime(l1b_path: Path, product_kind: str):
    """Find a LAADS-named HKM/QKM file whose timestamp equals the L1B's.

    Searches the directory from find_parallel_dir(l1b_path, product_kind)
    with these glob patterns, in order (each pattern's hits sorted):
    ``<PLATFORM>_MODIS_<kind>.*.L1B.hdf``, its lower-case variant, and
    ``*<kind>*.hdf``. Duplicates keep their first position. Returns the first
    candidate whose parse_laads_datetime() equals
    parse_ocssw_l1b_datetime(l1b_path), else ``None``.
    """
    l1b_time = parse_ocssw_l1b_datetime(l1b_path)
    if l1b_time is None:
        return None

    search_dir = find_parallel_dir(l1b_path, product_kind)
    if search_dir is None or not search_dir.exists():
        return None

    platform = "AQUA" if l1b_path.name.startswith("A") else "TERRA"
    globs = (
        f"{platform}_MODIS_{product_kind}.*.L1B.hdf",
        f"{platform.lower()}_modis_{product_kind.lower()}.*.l1b.hdf",
        f"*{product_kind}*.hdf",
    )

    # dict keeps first-seen order, so a file hit by several patterns is
    # checked only once, at its earliest position.
    ordered = dict.fromkeys(
        hit for pattern in globs for hit in sorted(search_dir.glob(pattern))
    )
    return next(
        (cand for cand in ordered if parse_laads_datetime(cand) == l1b_time),
        None,
    )


def resolve_required_companion(l1b_path: Path, product_kind: str):
    """Make the HKM/QKM companion of *l1b_path* available next to it.

    The companion is expected beside the L1B file, with ``L1B_LAC`` in the
    name replaced by ``L1B_<kind>``. Lookup order:
      1. that path already holds a file of at least MIN_VALID_FILE_SIZE bytes;
      2. the OCSSW-named file in the parallel ``<kind>`` tree
         (find_companion_file), which is symlinked into place;
      3. if AUTO_LINK_LAADS_COMPANION is set, a LAADS-named file with the
         same timestamp (find_laads_companion_by_datetime), also symlinked.

    Returns:
        The expected companion path, or ``None`` when nothing exists there
        and no link was made. The file may still be missing or undersized,
        so callers re-check it with is_nonempty().

    Raises:
        ValueError: if *product_kind* is not "HKM" or "QKM".
    """
    # Explicit check rather than assert: asserts are stripped under ``-O``.
    if product_kind not in {"HKM", "QKM"}:
        raise ValueError(f"product_kind must be 'HKM' or 'QKM', got {product_kind!r}")

    target_in_l1b = l1b_path.parent / l1b_path.name.replace("L1B_LAC", f"L1B_{product_kind}")

    if is_nonempty(target_in_l1b):
        return target_in_l1b

    standard = find_companion_file(
        l1b_path=l1b_path,
        subdir_name=product_kind,
        src_suffix="L1B_LAC",
        dst_suffix=f"L1B_{product_kind}",
    )

    if is_nonempty(standard):
        link = safe_symlink(standard, target_in_l1b)
        return link if link is not None else target_in_l1b

    # Only scan the LAADS directory when its files may actually be linked.
    if AUTO_LINK_LAADS_COMPANION:
        laads = find_laads_companion_by_datetime(l1b_path, product_kind)
        if is_nonempty(laads):
            link = safe_symlink(laads, target_in_l1b)
            if link is not None:
                print(f"[LINK {product_kind}] {link} -> {laads}")
                return link

    return target_in_l1b if target_in_l1b.exists() else None


def find_l1b_geo_pairs(root: Path):
    """Yield ``(l1b_path, geo_path)`` for every ``*.L1B_LAC`` under *root*.

    The GEO file is looked for first beside the L1B file (``L1B_LAC`` in the
    name replaced by ``GEO``), then in the parallel ``GEO`` tree via
    find_companion_file. ``geo_path`` is ``None`` when neither exists.
    Results come in sorted path order.
    """
    for l1b in sorted(root.rglob("*.L1B_LAC")):
        beside = l1b.with_name(l1b.name.replace("L1B_LAC", "GEO"))
        if beside.exists():
            geo = beside
        else:
            mirrored = find_companion_file(
                l1b_path=l1b,
                subdir_name="GEO",
                src_suffix="L1B_LAC",
                dst_suffix="GEO",
            )
            geo = mirrored if mirrored and mirrored.exists() else None
        yield l1b, geo


# =========================================================
# 5. RUNNER
# =========================================================

def run(cmd, log_path: Path, env, timeout=7200, check=True, cwd=None):
    """Run *cmd* and write its combined stdout/stderr into *log_path*.

    The log file is overwritten. Returns the exit code.

    Raises:
        subprocess.CalledProcessError: if *check* is true and the exit code
            is non-zero.
        subprocess.TimeoutExpired: propagated from subprocess.run when the
            command exceeds *timeout* seconds.
    """
    ensure_dir(log_path.parent)
    with log_path.open("w", encoding="utf-8") as sink:
        completed = subprocess.run(
            cmd,
            stdout=sink,
            stderr=subprocess.STDOUT,
            env=env,
            timeout=timeout,
            cwd=cwd,
        )
    rc = completed.returncode
    if check and rc != 0:
        raise subprocess.CalledProcessError(rc, cmd)
    return rc


# =========================================================
# 6. PNG
# =========================================================

def _read_l2_var(ds: Dataset, name: str):
    if "geophysical_data" in ds.groups and name in ds.groups["geophysical_data"].variables:
        return np.array(ds.groups["geophysical_data"].variables[name][:], dtype=np.float32)
    if name in ds.variables:
        return np.array(ds.variables[name][:], dtype=np.float32)
    raise KeyError(name)


def stretch_to_uint8(arr, p_low=1.0, p_high=99.5, gamma=1.2):
    """Percentile-stretch *arr* into a uint8 image of the same shape.

    - Percentiles are computed over finite values only.
    - The *p_low* percentile maps to 0 and the *p_high* percentile to 255,
      with values clipped in between.
    - A ``1/gamma`` power is then applied.
    - Non-finite pixels (NaN/inf, e.g. masked fill values) are always 0.
    - The result is all zeros when there are no finite values or the two
      percentiles coincide.
    """
    arr = np.asarray(arr, dtype=np.float32)
    out = np.zeros(arr.shape, dtype=np.uint8)
    valid = np.isfinite(arr)
    if not np.any(valid):
        return out

    values = arr[valid]
    lo, hi = np.percentile(values, [p_low, p_high])
    if hi <= lo:
        return out

    scaled = np.clip((values - lo) / (hi - lo), 0.0, 1.0)
    if gamma != 1.0:
        scaled = np.power(scaled, 1.0 / gamma)
    # Write only finite pixels. Casting NaN/inf to uint8 is undefined and
    # makes numpy warn "invalid value encountered in cast".
    out[valid] = (scaled * 255.0).astype(np.uint8)
    return out


def make_truecolor_png_from_l2(l2_nc: Path, out_png: Path):
    """Write an RGB PNG from an L2 file using rhos_645 / rhos_555 / rhos_469.

    The three bands are used as R, G and B. Each channel is stretched
    independently with stretch_to_uint8 using the module-level P_LOW, P_HIGH
    and GAMMA. The parent directory of *out_png* is created if needed.
    """
    band_names = ("rhos_645", "rhos_555", "rhos_469")  # R, G, B
    with Dataset(str(l2_nc), "r") as ds:
        bands = [_read_l2_var(ds, band) for band in band_names]

    channels = [stretch_to_uint8(band, P_LOW, P_HIGH, GAMMA) for band in bands]
    rgb = np.dstack(channels)
    ensure_dir(out_png.parent)
    Image.fromarray(rgb).save(out_png)


# =========================================================
# 7. ANC STAGE (SERIAL)
# =========================================================

def recover_anc_from_log(log_path: Path, anc_path: Path) -> bool:
    """Rebuild an .anc parameter file from ``key=value`` lines in a getanc log.

    Each log line containing ``=`` is split at its first ``=``. A line is
    kept when its stripped key starts with one of the ancillary prefixes
    (met, ozone, sstfile, icefile, no2file, att1/att2, eph1/eph2) and its
    value is non-empty. Exact duplicates are dropped and first-seen order is
    kept.

    Returns:
        True if lines were recovered and the written file passes the
        MIN_TEXT_FILE_SIZE check. False if the log cannot be read or yields
        nothing; in that case *anc_path* is not written.
    """
    try:
        log_text = log_path.read_text(encoding="utf-8", errors="ignore")
    except Exception:
        return False

    anc_keys = (
        "met", "ozone", "sstfile", "icefile",
        "no2file", "att1", "eph1", "att2", "eph2",
    )

    # dict keys double as an insertion-ordered "seen" set for dedup.
    recovered = {}
    for raw in log_text.splitlines():
        stripped = raw.strip()
        if "=" not in stripped:
            continue
        key, value = (part.strip() for part in stripped.split("=", 1))
        if value and key.startswith(anc_keys):
            recovered.setdefault(f"{key}={value}", None)

    if not recovered:
        return False

    anc_path.write_text("\n".join(recovered) + "\n", encoding="utf-8")
    return is_nonempty(anc_path, min_size=MIN_TEXT_FILE_SIZE)


def prepare_one_anc(args):
    """Stage 1: make sure one L1B scene has an ``.anc`` parameter file.

    Args:
        args: ``(l1b_path, geo_path)`` tuple as yielded by find_l1b_geo_pairs.

    The scene is skipped when:
    - it has no GEO file,
    - the L1B is smaller than MIN_VALID_FILE_SIZE, or
    - a companion required by TARGET_RESOLUTION cannot be resolved (HKM for
      "500", HKM and QKM for "250").

    Otherwise getanc.py is run with ``ANC_DIR/ancfiles`` as working directory
    and asked to write ``<rootname>.anc``. If it fails or writes nothing, the
    file is rebuilt from the getanc log where possible.

    Returns:
        ``(status, str(l1b_path))``. Possible statuses: skip_geo,
        invalid_l1b, skip_hkm, skip_qkm, anc_exists, anc_ok, anc_recovered,
        fail_getanc, timeout_getanc, fail_anc_missing.
    """
    l1b_path, geo_path = args

    if geo_path is None:
        return ("skip_geo", str(l1b_path))

    if not is_nonempty(l1b_path):
        return ("invalid_l1b", str(l1b_path))

    if TARGET_RESOLUTION == "500":
        hkm_path = resolve_required_companion(l1b_path, "HKM")
        if not is_nonempty(hkm_path):
            return ("skip_hkm", str(l1b_path))

    if TARGET_RESOLUTION == "250":
        hkm_path = resolve_required_companion(l1b_path, "HKM")
        qkm_path = resolve_required_companion(l1b_path, "QKM")
        if not is_nonempty(hkm_path):
            return ("skip_hkm", str(l1b_path))
        if not is_nonempty(qkm_path):
            return ("skip_qkm", str(l1b_path))

    rootname = l1b_path.name.replace(".L1B_LAC", "")
    ancfiles_dir = ANC_DIR / "ancfiles"
    ensure_dir(ancfiles_dir)
    ensure_dir(LOG_DIR)

    anc_path = ancfiles_dir / f"{rootname}.anc"
    log_path = LOG_DIR / f"{rootname}.getanc.log"

    # .anc is a short text file: use the text-size threshold.
    if SKIP_EXISTING_ANC and is_nonempty(anc_path, min_size=MIN_TEXT_FILE_SIZE):
        return ("anc_exists", str(l1b_path))

    # Loading the env sources a bash login shell, so only do it once we know
    # getanc will actually run.
    env = load_ocssw_env(OCSSW_ENV)

    cmd = [
        PYTHON3, str(CMD_GETANC),
        "--mission", mission_from_l1b_name(l1b_path.name),
        "--ancdir", str(ANC_DIR),
        # Bare file name: this code expects getanc to write it into cwd
        # (ancfiles_dir, set below), where anc_path points.
        "-o", anc_path.name,
        str(l1b_path),
    ]

    try:
        run(
            cmd,
            log_path=log_path,
            env=env,
            timeout=TIMEOUT_GETANC,
            check=True,
            cwd=ancfiles_dir,
        )
    except subprocess.TimeoutExpired:
        # subprocess.run has killed getanc and its log may be truncated, so
        # do not rebuild a possibly incomplete .anc from it. Returning a
        # status (instead of raising) keeps the serial loop going.
        print(f"[ANC TIMEOUT] {l1b_path} exceeded {TIMEOUT_GETANC}s, see {log_path}")
        return ("timeout_getanc", str(l1b_path))
    except OSError as exc:
        # e.g. the PYTHON3 interpreter is missing or the log cannot be opened.
        print(f"[ANC ERROR] {l1b_path}: {exc}")
        return ("fail_getanc", str(l1b_path))
    except subprocess.CalledProcessError:
        if recover_anc_from_log(log_path, anc_path):
            print(f"[ANC RECOVERED AFTER FAIL] {anc_path}")
            return ("anc_recovered", str(l1b_path))
        return ("fail_getanc", str(l1b_path))

    if is_nonempty(anc_path, min_size=MIN_TEXT_FILE_SIZE):
        return ("anc_ok", str(l1b_path))

    # getanc exited cleanly but did not write the file: rebuild it from its log.
    if recover_anc_from_log(log_path, anc_path):
        print(f"[ANC RECOVERED FROM LOG] {anc_path}")
        return ("anc_recovered", str(l1b_path))

    return ("fail_anc_missing", str(l1b_path))


# =========================================================
# 8. L2 STAGE (PARALLEL)
# =========================================================

def _build_l2gen_cmd(l1b_path: Path, geo_path: Path, l2_path: Path, anc_path: Path) -> list:
    """Return the l2gen argv for one scene (subset bounds added if USE_SUBSET)."""
    cmd = [
        str(CMD_L2GEN),
        f"ifile={l1b_path}",
        f"geofile={geo_path}",
        f"ofile={l2_path}",
        f"l2prod={PRODUCTS}",
        f"par={anc_path}",
        f"resolution={TARGET_RESOLUTION}",
        *L2GEN_OPTS,
    ]
    if USE_SUBSET:
        cmd.extend([
            f"north={NORTH}",
            f"south={SOUTH}",
            f"west={WEST}",
            f"east={EAST}",
        ])
    return cmd


def _remove_quietly(path: Path):
    """Delete *path*, ignoring a file that does not exist."""
    try:
        path.unlink()
    except FileNotFoundError:
        pass


def process_one_scene(args):
    """Stage 2: run l2gen for one scene and render its true-color PNG.

    Args:
        args: ``(l1b_path, geo_path)`` tuple as yielded by find_l1b_geo_pairs.

    Outputs go to ``OUT_BASE/<rel>/L2/<rootname>.L2.nc`` and
    ``PNG_DIR/<rel>/TrueColor/<rootname>_truecolor.png``, where ``<rel>`` is
    the parent of the L1B directory relative to ROOT_DATA. This function runs
    inside worker processes, so every expected failure is reported as a
    status instead of an exception.

    Returns:
        ``(status, str(l1b_path))``. Possible statuses: skip_geo,
        invalid_l1b, skip_hkm, skip_qkm, missing_anc, fail_l2gen,
        timeout_l2gen, fail_l2_empty, fail_png, ok.
    """
    l1b_path, geo_path = args

    if geo_path is None:
        return ("skip_geo", str(l1b_path))

    if not is_nonempty(l1b_path):
        return ("invalid_l1b", str(l1b_path))

    if TARGET_RESOLUTION == "500":
        hkm_path = resolve_required_companion(l1b_path, "HKM")
        if not is_nonempty(hkm_path):
            return ("skip_hkm", str(l1b_path))

    if TARGET_RESOLUTION == "250":
        hkm_path = resolve_required_companion(l1b_path, "HKM")
        qkm_path = resolve_required_companion(l1b_path, "QKM")
        if not is_nonempty(hkm_path):
            return ("skip_hkm", str(l1b_path))
        if not is_nonempty(qkm_path):
            return ("skip_qkm", str(l1b_path))

    rootname = l1b_path.name.replace(".L1B_LAC", "")

    rel_dir = l1b_path.parent.relative_to(ROOT_DATA)
    l2_dir = OUT_BASE / rel_dir.parent / "L2"
    png_dir = PNG_DIR / rel_dir.parent / "TrueColor"
    ancfiles_dir = ANC_DIR / "ancfiles"

    ensure_dir(l2_dir)
    ensure_dir(png_dir)
    ensure_dir(ancfiles_dir)
    ensure_dir(LOG_DIR)

    l2_path = l2_dir / f"{rootname}.L2.nc"
    png_path = png_dir / f"{rootname}_truecolor.png"
    anc_path = ancfiles_dir / f"{rootname}.anc"

    # .anc files are short text files: use the text-size threshold, not
    # MIN_VALID_FILE_SIZE.
    if not is_nonempty(anc_path, min_size=MIN_TEXT_FILE_SIZE):
        return ("missing_anc", str(l1b_path))

    if not (SKIP_EXISTING_L2 and is_nonempty(l2_path, MIN_L2_SIZE)):
        # Only source the OCSSW env (a bash login shell) when l2gen runs.
        env = load_ocssw_env(OCSSW_ENV)
        cmd = _build_l2gen_cmd(l1b_path, geo_path, l2_path, anc_path)
        try:
            run(cmd, LOG_DIR / f"{rootname}.l2gen.log", env, TIMEOUT_L2GEN, check=True)
        except subprocess.TimeoutExpired:
            # Left uncaught, this would escape the worker and abort the whole
            # pool stage. Also remove the half-written output so a later run
            # cannot accept it via SKIP_EXISTING_L2.
            _remove_quietly(l2_path)
            return ("timeout_l2gen", str(l1b_path))
        except (subprocess.CalledProcessError, OSError):
            _remove_quietly(l2_path)
            return ("fail_l2gen", str(l1b_path))

        if not is_nonempty(l2_path, MIN_L2_SIZE):
            return ("fail_l2_empty", str(l1b_path))

    png_done = SKIP_EXISTING_PNG and png_path.exists() and png_path.stat().st_size > 0
    if not png_done:
        # Render to a side file and rename on success, so an interrupted write
        # never leaves a truncated PNG that a later run would skip. The
        # ".png" suffix is kept so PIL still picks the PNG format.
        tmp_png = png_path.with_name(f"{png_path.stem}.partial{png_path.suffix}")
        try:
            make_truecolor_png_from_l2(l2_path, tmp_png)
            os.replace(tmp_png, png_path)
        except Exception as exc:
            # Boundary catch: report this scene and keep the others going.
            print(f"[PNG FAIL] {l2_path}: {exc!r}")
            _remove_quietly(tmp_png)
            return ("fail_png", str(l1b_path))

    return ("ok", str(l1b_path))


# =========================================================
# 9. MAIN
# =========================================================

if __name__ == "__main__":
    from collections import Counter

    print("========================================")
    print(f"ROOT_DATA          = {ROOT_DATA}")
    print(f"OUT_BASE           = {OUT_BASE}")
    print(f"TARGET_RESOLUTION  = {TARGET_RESOLUTION}")
    print(f"PRODUCTS           = {PRODUCTS}")
    print(f"L2_WORKERS         = {L2_WORKERS}")
    print(f"USE_SUBSET         = {USE_SUBSET}")
    if USE_SUBSET:
        print(f"N/S/W/E            = {NORTH}/{SOUTH}/{WEST}/{EAST}")
    print("========================================")

    ensure_dir(ANC_DIR)
    ensure_dir(LOG_DIR)
    ensure_dir(PNG_DIR)

    env = load_ocssw_env(OCSSW_ENV)
    debug_ocssw_env(env)
    check_netrc()
    check_runtime(env)

    pairs = list(find_l1b_geo_pairs(ROOT_DATA))
    print(f"Found {len(pairs)} L1B files")

    # ---- Stage 1: getanc, one scene at a time --------------------------
    print("\n========== STAGE 1: PREPARE ANC (SERIAL) ==========")
    anc_ready_statuses = {"anc_ok", "anc_exists", "anc_recovered"}
    anc_summary = Counter()
    ready_pairs = []

    for pair in pairs:
        status, item = prepare_one_anc(pair)
        anc_summary[status] += 1
        if status in anc_ready_statuses:
            ready_pairs.append(pair)
        print(f"[ANC] {status}: {item}")

    print("\n========== ANC SUMMARY ==========")
    for k in sorted(anc_summary):
        print(f"{k:16s}: {anc_summary[k]}")

    print(f"\nReady for l2gen: {len(ready_pairs)}")

    if not ready_pairs:
        print("\nNo scenes ready for l2gen. Exit.")
        raise SystemExit(0)

    # ---- Stage 2: l2gen + PNG in a process pool ------------------------
    print("\n========== STAGE 2: L2GEN (PARALLEL) ==========")
    total = len(ready_pairs)
    # Never start more workers than there are scenes.
    workers = max(1, min(L2_WORKERS, total))
    l2_summary = Counter()
    # imap_unordered with its default chunksize=1 hands out one scene at a
    # time. Pool.map would auto-batch long l2gen runs onto the same worker and
    # report nothing until the whole stage finished; here each result is
    # printed as soon as it is done.
    with Pool(processes=workers) as pool:
        for done, (status, item) in enumerate(
            pool.imap_unordered(process_one_scene, ready_pairs), start=1
        ):
            l2_summary[status] += 1
            print(f"[L2] {status}: {item} ({done}/{total})")

    print("\n========== L2 SUMMARY ==========")
    for k in sorted(l2_summary):
        print(f"{k:16s}: {l2_summary[k]}")
