import odeg as od
import matplotlib.pyplot as plt
import numpy as np

# Use the same high resolution (500 dpi) for on-screen and saved figures.
plt.rcParams.update({
    'figure.dpi': 500,
    'savefig.dpi': 500,
})


def with_complement(partial, n_sets, n_points, floor=1e-17):
    """Stack ``partial`` with the row that completes every column to one.

    Parameters
    ----------
    partial : array_like
        Values written into the first ``n_sets`` rows; must be broadcastable
        to shape ``(n_sets, n_points)``.
    n_sets : int
        Number of rows taken from ``partial``.
    n_points : int
        Number of grid points (columns).
    floor : float, optional
        Element-wise lower bound applied to the result, so that every entry
        stays strictly positive and can be drawn on a logarithmic axis.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_sets + 1, n_points)``. Row ``-1`` holds
        ``1 - sum(partial rows)``; all entries are clipped to ``>= floor``.
    """
    table = np.zeros((n_sets + 1, n_points))
    table[:-1] = partial
    # Whatever is not covered by the first rows goes into the last one.
    table[-1] = 1 - table[:-1].sum(axis=0)
    return np.maximum(table, floor)


if __name__ == '__main__':

    # Left panel sweeps theta at fixed mu, right panel sweeps mu at fixed theta.
    theta_l = np.arange(1e-2, 2.01, 0.01)
    theta_r = 0.125
    mu_l = -.6
    mu_r = np.linspace(-1., -.2, 200)
    k_grid_l = np.array([4, 5, 6])

    # Keyword arguments forwarded to od.fctExp for the left panel.
    cfg_l = {
        'r0': .5,
        'rs': 2,
        'p_max': 5,
        'pos_charge': 5,
        'theta': theta_l,
        'mu': mu_l,
        'path': "../results",
    }
    # Same settings with the grids swapped. A separate dict is used instead
    # of mutating cfg_l so that neither fctExp call depends on statement order.
    cfg_r = {**cfg_l, 'theta': theta_r, 'mu': mu_r}

    # One function object per entry of k_grid_l
    # (presumably the indicator of "occupation equals k" -- confirm in odeg).
    fct_grid_l = [od.occIndicatorFct(k) for k in k_grid_l]

    data_l = with_complement(od.fctExp(*fct_grid_l, **cfg_l),
                             k_grid_l.size, theta_l.size)
    data_r = with_complement(od.fctExp(*fct_grid_l, **cfg_r),
                             k_grid_l.size, mu_r.size)

    fig, axes = plt.subplots(1, 2, figsize=(9, 3), sharey=True)

    axes[0].set_xlabel(r"$\Theta$")
    # The subscript h belongs inside \mathrm so that it is set upright.
    axes[1].set_xlabel(r"$\mu/E_{\mathrm{h}}$")
    axes[0].set_ylabel(r"$\langle\chi_A(\hat N)\rangle$")
    # The y axis is shared, so scale and lower limit apply to both panels.
    axes[0].set_yscale('log')
    axes[0].set_ylim(bottom=1e-7)

    # Both panels are plotted in the same order and therefore use the same
    # colour cycle; the legend in the left panel serves for both.
    lines = []
    for k, row_l, row_r in zip(k_grid_l, data_l, data_r):
        line, = axes[0].plot(theta_l, row_l, label=r"$A=\{%d\}$" % k)
        axes[1].plot(mu_r, row_r)
        lines.append(line)

    # Last row: the remainder not covered by the single-k sets. The excluded
    # set in the label is generated from k_grid_l to keep the two in sync.
    excluded = ",".join(str(k) for k in k_grid_l)
    line, = axes[0].plot(
        theta_l, data_l[-1],
        label=r"$A=\{0,\dots ,12\} \backslash \{%s\}$" % excluded)
    lines.append(line)
    axes[1].plot(mu_r, data_r[-1])

    axes[0].legend(handles=lines, ncols=2)
    # Annotate each panel with the parameter that is held fixed in it.
    axes[0].annotate(r"$\mu = %g E_{\mathrm{h}}$" % mu_l, (.05, .9),
                     xycoords='axes fraction')
    axes[1].annotate(r"$\Theta = %g$" % theta_r, (.05, .9),
                     xycoords='axes fraction')
    fig.tight_layout()