import numpy as np
from time import time
import matplotlib.pyplot as plt
from fractions import Fraction

import sys
sys.path.append('../.')
import odeg as od

# Use 300 dpi for new figures ('figure.dpi') and for saved output
# ('savefig.dpi'). For LaTeX text rendering, additionally set
# plt.rcParams['text.usetex'] = True.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
})

def app_rel_err(fill_ratio, **cfg):
    """Return the largest relative free-energy deviation over p_max = 2, 3, 4.

    For each trial p_max, pos_charge is set to
    ``round(fill_ratio * 2 * (2*p_max + 1), 5)`` and ``od.freeEnergy`` is
    evaluated twice: once at the trial p_max, and once with p_max = 5 while
    keeping that same pos_charge. The result is the maximum of
    ``|(approx - reference) / reference|`` over the three trials, with 0 as
    the starting value of the maximum.

    NOTE(review): the p_max = 5 reference reuses the pos_charge rescaled for
    the *trial* p_max rather than ``fill_ratio * 2 * 11``; confirm this is
    the intended comparison.

    Args:
        fill_ratio: factor used to derive pos_charge from each trial p_max.
        **cfg: keyword arguments forwarded to ``od.freeEnergy``; any
            'p_max' / 'pos_charge' entries are overridden per trial.

    Returns:
        The maximum absolute relative deviation (0 if none is positive).
    """
    deviations = [0]
    for trial_p_max in (2, 3, 4):
        charge = round(fill_ratio * 2 * (2*trial_p_max+1), 5)
        trial_cfg = {**cfg, 'p_max': trial_p_max, 'pos_charge': charge}

        approx = od.freeEnergy(**trial_cfg)
        # Reference run: only p_max changes; pos_charge stays at `charge`.
        reference = od.freeEnergy(**{**trial_cfg, 'p_max': 5})

        deviations.append(abs((approx - reference) / reference))

    # max() over [0, d1, d2, d3] keeps the earliest maximal entry, exactly
    # like a running `best = max(best, d)` seeded with 0.
    return max(deviations)


def err_repr(val, rel_err):
    """Format ``val`` with its uncertainty in LaTeX concise notation.

    The absolute uncertainty ``|val * rel_err|`` is rounded to one
    significant figure, and ``val`` is rounded to the same decimal place.
    The uncertainty is written in parentheses in units of the last printed
    digit, e.g. ``1.23(1)`` for 1.23 +- 0.01 or ``1230(40)`` for 1230 +- 40.
    The result is wrapped in ``$...$`` for LaTeX math mode.

    Args:
        val: value to format.
        rel_err: relative uncertainty of ``val``.

    Returns:
        A string such as ``"$1.23(1)$"``. If the uncertainty is exactly zero,
        the value is printed without parentheses using ``%g``.

    Raises:
        ValueError: if the absolute uncertainty is NaN or infinite.
    """
    abs_err = abs(val * rel_err)
    if abs_err == 0:
        # log10(0) is -inf and has no digit position; print the bare value.
        return "$%g$" % val
    if not np.isfinite(abs_err):
        raise ValueError(
            f"non-finite uncertainty for value {val!r}: rel_err={rel_err!r}")

    # Decimal position of the uncertainty's leading significant digit.
    digit = int(np.floor(np.log10(abs_err)))
    abs_err = round(abs_err, -digit)
    # Rounding may carry into the next decade (e.g. 0.96 -> 1.0): recompute.
    digit = int(np.floor(np.log10(abs_err)))

    val_rounded = round(val, -digit)

    if digit < 0:
        # Uncertainty below 1: print -digit decimals plus the single
        # significant digit of the uncertainty.
        err_digits = round(abs_err * 10**(-digit))
        res = f"%.{-digit}f(%d)" % (val_rounded, err_digits)
    else:
        # Uncertainty >= 1: the value has no decimals, so its last printed
        # digit is the units digit and the uncertainty is written in full.
        res = "%.0f(%d)" % (val_rounded, round(abs_err))

    return "$" + res + "$"

if __name__ == '__main__':

    # Parameter grids. The output table has one row per theta and one column
    # per (r0, rs) pair, with rs varying fastest.
    r0_grid = np.array([0.5, 1.])
    rs_grid = np.array([2., 4., 8.])
    theta_grid = np.array([.125, .5, 1.])

    # Base configuration; r0, rs and theta are filled in per grid point below.
    cfg = {
        'p_max': 5,
        'pos_charge': 5,
        'path': "../results",
    }

    # pos_charge / (2 * (2*p_max + 1)); app_rel_err uses this ratio to
    # rescale pos_charge when it evaluates smaller p_max values.
    fill_ratio = cfg['pos_charge'] / (2 * (2*cfg['p_max']+1))

    grid_shape = (r0_grid.size, rs_grid.size, theta_grid.size)
    free_energy_per_pc = np.zeros(grid_shape)
    rel_error = np.zeros(grid_shape)

    for r0_ind, rs_ind, theta_ind in np.ndindex(*grid_shape):

        cfg['r0'] = r0_grid[r0_ind]
        cfg['rs'] = rs_grid[rs_ind]
        cfg['theta'] = theta_grid[theta_ind]

        free_energy_per_pc[r0_ind, rs_ind, theta_ind] = od.freeEnergy(**cfg) / cfg['pos_charge']
        # app_rel_err receives a copy of cfg (via **), so cfg is not modified.
        rel_error[r0_ind, rs_ind, theta_ind] = app_rel_err(fill_ratio, **cfg)

        print("\n\ncfg:", cfg)
        print("fe:", free_energy_per_pc[r0_ind, rs_ind, theta_ind])
        print("err:", rel_error[r0_ind, rs_ind, theta_ind])

    # Format every row before touching the output file, so a formatting
    # error cannot leave a truncated table behind.
    rows = []
    for theta_ind, theta in enumerate(theta_grid):
        # limit_denominator keeps labels short for theta values that are not
        # exact binary fractions (e.g. 0.1); 0.125, 0.5 and 1.0 are unchanged.
        cells = [r"$\Theta = " + str(Fraction(theta).limit_denominator()) + "$"]
        for r0_ind, rs_ind in np.ndindex(r0_grid.size, rs_grid.size):
            cells.append(err_repr(free_energy_per_pc[r0_ind, rs_ind, theta_ind],
                                  rel_error[r0_ind, rs_ind, theta_ind]))
        rows.append(" & ".join(cells))

    # Rows are separated by a LaTeX line break plus newline; nothing follows
    # the last row. The with-block closes the file even if writing fails.
    with open("table_f.tex", "w") as tex_file:
        tex_file.write((r"\\" + "\n").join(rows))