import numpy as np
import odeg.potentials as pt
import matplotlib.pyplot as plt
import odeg.refs.hartreeref as hr
from time import time




if __name__ == '__main__':
    # Driver script: evaluate pt.freeEnergy and pt.eqGrandPotential on a grid
    # of Theta values for one fixed set of model parameters, report the
    # wall-clock time spent, and plot both curves against Theta.
    # NOTE(review): the meaning/units of r_0, r_s, p_max and pos_charge are
    # defined in odeg.potentials -- confirm there before changing them.
    r_0 = 0.5
    r_s = 8
    p_max = 4
    pos_charge = 10

    theta_grid = np.linspace(0.1, 2, 50)
    # Only used by the commented-out pt.grandPotential call below; presumably
    # the chemical potential (TODO confirm against odeg.potentials).
    mu = 0

    start_time = time()

    fe = pt.freeEnergy(r_0, r_s, pos_charge, p_max, theta_grid,
                       stored=True, parallel=False)

    # Alternative: grand potential at the fixed value `mu` instead of
    # pt.eqGrandPotential.
    # om = pt.grandPotential(r_0, r_s, pos_charge, p_max, theta_grid, mu,
    #                           stored = True,
    #                           parallel = False)

    om = pt.eqGrandPotential(r_0, r_s, pos_charge, p_max, theta_grid,
                             stored=True,
                             parallel=False,
                             method='right')

    # Elapsed time covers both the freeEnergy and eqGrandPotential calls.
    print(f'time used: {time() - start_time}')

    fig, ax = plt.subplots()
    ax.plot(theta_grid, om, label=r"$\Omega$")
    ax.plot(theta_grid, fe, label=r"$F$")
    ax.set_xlabel(r'$\Theta$')
    ax.legend()
    # Without an explicit show() the figure is never displayed when this file
    # is run as a plain (non-interactive) script.
    plt.show()

    # NOTE(review): archived parameter scan. It references fr_en_dpp,
    # p_max_grid and pos_charge_grid, none of which are defined in this
    # script, so it cannot be re-enabled by simply uncommenting it.
    # np.save('results/testrun_fe', fr_en_dpp)
    #
    # for idx in range(p_max_grid.size):
    #     fr_en_dpp[idx,:] = pt.freeEnergy(r_0, r_s, pos_charge_grid[idx], p_max_grid[idx], 
    #                                      theta_grid)/(r_s*np.log(pos_charge_grid[idx]))
        
        # for t_idx in range(theta_grid.size):
        #     fr_en_dpp[idx,t_idx] -= hr.getHartreeFreeEnergyPerElectron(r_s,
        #                                                                theta_grid[t_idx],
        #                                                                2,
        #                                                                r_0,
        #                                                                pos_charge_grid[idx],
        #                                                                p_max_grid[idx])
    #
    # fig, ax = plt.subplots()
    # ax.plot(pos_charge_grid ** (-1.), fr_en_dpp, linestyle = '', marker = 'x',
    #         label = [(r'$\Theta$ = ' + str(t)) for t in theta_grid] )
    # ax.legend(loc='upper center')
    # ax.set_xlabel(r'$\mathcal{N}^{\,-1}$')
    # ax.set_ylabel(r'$F/r_s \mathcal{N}$', rotation = 0)
    # ax.yaxis.set_label_coords(-.1, .95)
    # fig.savefig('plots/testrun_fe.jpg', dpi = 500)