#! /usr/bin/env python3
"""stack_corr — compute mean correlation from a stack of correlation grids.

Python port of csh stack_corr.csh (X. Tong, D. Sandwell 2013).
Implements the Rosen et al. (2000) mean-coherence stacking formula:
  mean_corr = sqrt( 1 / (1 + (1/N) * sum_i (1 - corr_i^2) / corr_i^2) )
Step by step, as in the csh version:
  for each corr:  tmp = corr^2;  sum += (1 - tmp) / tmp
  after all N grids:  mean_corr = sqrt( 1 / (1 + sum/N) )

Usage:  stack_corr grid_list out.grd
"""
import os
import sys
import tempfile

from gmtsar_lib import run


def stack_corr():
    """Average a stack of correlation grids into one mean-correlation grid.

    Command-line usage (arguments are read from ``sys.argv``)::

        stack_corr grid_list out.grd

    ``grid_list`` is a text file naming one correlation grid per line; blank
    lines are ignored and surrounding whitespace is stripped. For N grids the
    output is computed cell-by-cell with ``gmt grdmath`` as::

        mean = sqrt( 1 / (1 + (1/N) * sum_i (1 - corr_i^2) / corr_i^2) )

    Intermediate grids live in a private temporary directory created in the
    current working directory. They therefore cannot overwrite or delete user
    files, including an output named ``tmp.grd`` or ``sum.grd``. They are
    removed even if a step fails.

    Exits via ``sys.exit`` with a message on wrong argument count, a missing
    list file, an empty list, or a missing grid file. Every grid path is
    checked before any GMT command runs.
    """
    if len(sys.argv) != 3:
        sys.exit(
            "Usage: stack_corr grid_list out.grd\n"
            "  grid_list: text file listing correlation .grd paths\n"
            "  All grids must have consistent dimensions."
        )
    grid_list, out = sys.argv[1], sys.argv[2]
    if not os.path.isfile(grid_list):
        sys.exit(f"no input file found: {grid_list}")

    with open(grid_list) as f:
        files = [ln.strip() for ln in f if ln.strip()]
    if not files:
        sys.exit("stack_corr: empty input list")

    # Validate the whole list first, so a bad path late in the list fails
    # before any GMT work is done or any intermediate file is created.
    for cor in files:
        if not os.path.isfile(cor):
            sys.exit(f" Error: file not found: {cor}")

    print("computing the mean correlation of the grids ..")
    # NOTE(review): paths are interpolated into a command string for run();
    # names containing spaces or shell metacharacters are presumably unsafe,
    # depending on how gmtsar_lib.run executes the command. Confirm.
    with tempfile.TemporaryDirectory(prefix="stack_corr_", dir=".") as work_abs:
        # mkdtemp may return an absolute path (always, since Python 3.12),
        # and that path could contain spaces. Use the short relative name
        # instead, which only holds [a-z0-9_] characters.
        work = os.path.relpath(work_abs)
        sq = os.path.join(work, "tmp.grd")
        acc = os.path.join(work, "sum.grd")
        acc_next = os.path.join(work, "tmp2.grd")

        for i, cor in enumerate(files):
            # sq = corr^2;   per-grid contribution = (1 - sq) / sq
            run(f"gmt grdmath {cor} SQR = {sq}")
            if i == 0:
                run(f"gmt grdmath 1 {sq} SUB {sq} DIV = {acc}")
            else:
                run(f"gmt grdmath 1 {sq} SUB {sq} DIV {acc} ADD = {acc_next}")
                os.replace(acc_next, acc)

        # Final: mean_corr = sqrt(1 / (1 + sum/N))
        num = len(files)
        run(f"gmt grdmath 1 {acc} {num} DIV 1 ADD DIV SQRT = {out}")


if __name__ == "__main__":
    # Script entry point: stack_corr reads its arguments from sys.argv itself.
    stack_corr()
