import numpy as np
import odeg as od
from time import time
import matplotlib.pyplot as plt

# Use the same 300 dpi resolution for on-screen figures and saved files.
plt.rcParams.update({'figure.dpi': 300, 'savefig.dpi': 300})

if __name__ == '__main__':
    # Sweep the positive charge N for several basis cutoffs p_max, call
    # od.fctExp at each point, and plot the returned value against N.
    # With 'return_mu': True this is presumably the chemical potential mu
    # (matching the y-label below) -- NOTE(review): confirm against odeg.

    pos_charge_grid = np.arange(1.5, 6.5, .05)
    pmax_grid = np.arange(1, 5)
    n_spin = 2

    # Base keyword arguments for od.fctExp; 'p_max' and 'pos_charge' are
    # overwritten on every iteration of the sweep below.
    cfg = {
        'r0': 0.5,
        'rs': 2,
        'p_max': 5,
        'pos_charge': 10,
        'theta': 2,
        'return_mu': True,
        'mu_rel_target': 1 + 1e-8
        }

    fig, ax = plt.subplots()

    # NaN marks (p_max, N) pairs that are never evaluated (N at or above the
    # cap below), so they cannot be mistaken for a genuine result of 0.
    data = np.full((pmax_grid.size, pos_charge_grid.size), np.nan)

    for pmax_ind, pmax in enumerate(pmax_grid):
        # Only charges strictly below n_spin * (2*p_max + 1) are evaluated
        # (presumably the number of available states -- confirm). Because
        # pos_charge_grid is ascending, this mask selects a prefix of it, so
        # pc_ind below also indexes the matching column of `data`.
        red_pc_grid = pos_charge_grid[pos_charge_grid < n_spin * (2*pmax + 1)]
        cfg['p_max'] = pmax

        for pc_ind, pos_charge in enumerate(red_pc_grid):
            cfg['pos_charge'] = pos_charge

            print(f"exact code working on pmax = {pmax}, pos_charge = {pos_charge}")
            data[pmax_ind, pc_ind] = od.fctExp(**cfg)

        ax.plot(red_pc_grid, data[pmax_ind, :red_pc_grid.size],
                label=r"$p_{\mathrm{max}} = $" + str(pmax))

    ax.set_xlabel(r"$\cal{N}$")
    ax.set_ylabel(r"$\mu$")
    ax.set_title(f"theta = {cfg['theta']}, r0 = {cfg['r0']}, rs = {cfg['rs']}")
    # Attach the legend to the axes (placed inside the plot area, avoiding
    # the curves) rather than to the figure corner, where it overlaps the axes.
    ax.legend()
    # Without this the figure is never displayed when run as a script.
    plt.show()