import numpy as np
import odeg as od
from time import time
import matplotlib.pyplot as plt

plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

def pearson_from_moments(exp_xy, exp_x, exp_y, exp_x2, exp_y2):
    """Return the Pearson correlation coefficient of X and Y from raw moments.

    cor(X, Y) = (<XY> - <X><Y>) / sqrt((<X^2> - <X>^2) * (<Y^2> - <Y>^2))

    Parameters
    ----------
    exp_xy : float
        Mixed moment <XY>.
    exp_x, exp_y : float
        First moments <X> and <Y>.
    exp_x2, exp_y2 : float
        Second raw moments <X^2> and <Y^2>.

    Returns
    -------
    float
        The correlation coefficient in [-1, 1], or NaN when either variance
        is not strictly positive (the correlation is undefined for a
        deterministic variable; a tiny negative value can also arise from
        floating-point cancellation).
    """
    var_x = exp_x2 - exp_x ** 2
    var_y = exp_y2 - exp_y ** 2
    if var_x <= 0 or var_y <= 0:
        return np.nan
    return (exp_xy - exp_x * exp_y) / np.sqrt(var_x * var_y)


if __name__ == '__main__':

    # Scan grid: x axis of each curve, and one curve per p_max value.
    pos_charge_grid = np.arange(.125, 8, .125)
    pmax_grid = np.arange(1, 6)
    n_spin = 2

    # Base configuration forwarded to od.fctExp; 'p_max' and 'pos_charge'
    # are overwritten inside the scan loops below.
    cfg = {
        'r0': 0.5,
        'rs': 2,
        'p_max': 5,
        'pos_charge': 10,
        'theta': 0.01,
        'mu_rel_target': 1 + 1e-8,
        }

    fig, ax = plt.subplots()

    # data[i, j]: correlation for pmax_grid[i] and the j-th scanned charge.
    # Entries past the admissible charge range of a given p_max stay 0 and
    # are never plotted.
    data = np.zeros((pmax_grid.size, pos_charge_grid.size))

    for pmax_ind, pmax in enumerate(pmax_grid):
        # Only charges strictly below n_spin * (2*pmax + 1) are scanned.
        # NOTE(review): presumably this is the number of available
        # single-particle states for this p_max — confirm against odeg.
        red_pc_grid = pos_charge_grid[pos_charge_grid < n_spin * (2*pmax + 1)]
        cfg['p_max'] = pmax

        for pc_ind, pos_charge in enumerate(red_pc_grid):
            cfg['pos_charge'] = pos_charge

            print(f"exact code working on pmax = {pmax}, pos_charge = {pos_charge}")
            # One expectation value per observable function, unpacked in the
            # same order the functions are passed.
            exp_p2sz2, exp_p2, exp_sz2, exp_sz4, exp_p4 = od.fctExp(od.mom2sz2Fct,
                                                                    od.mom2Fct,
                                                                    od.sz2Fct,
                                                                    od.sz4Fct,
                                                                    od.mom4Fct,
                                                                    **cfg)

            # X = P^2 and Y = S_z^2, hence <X^2> = <P^4> and <Y^2> = <S_z^4>.
            # The denominator must use the variances, not the raw fourth
            # moments, for this to be the correlation shown on the y axis.
            data[pmax_ind, pc_ind] = pearson_from_moments(exp_p2sz2,
                                                          exp_p2,
                                                          exp_sz2,
                                                          exp_p4,
                                                          exp_sz4)

        ax.plot(red_pc_grid, data[pmax_ind, :red_pc_grid.size],
                label = r"$p_{\mathrm{max}} = $" + str(pmax))

    ax.set_xlabel(r"$\cal{N}$")
    ax.set_ylabel(r"$\mathrm{cor}(\hat P^2, \hat S_z^2)$")
    ax.set_title(f"theta = {cfg['theta']}, r0 = {cfg['r0']}, rs = {cfg['rs']}")
    fig.legend()