import matplotlib.pyplot as plt
import numpy as np
import time
from IPython import display

def plot_signal(signal, tend, tlines=None):
    """Plot the envelope and the full waveform of a signal on two stacked axes.

    Parameters
    ----------
    signal
        Object exposing ``draw(t0, tf, n, function, axis=...)``. It is called once
        with ``"envelope"`` and once with ``"signal"`` as the fourth argument,
        sampling 50000 points on [0, tend]. (Presumably a ``qiskit_dynamics``
        ``Signal`` -- confirm against callers.)
    tend
        End of the plotted time window; drawing always starts at 0.
    tlines
        Optional sequence of times at which dashed black vertical marker lines
        are drawn on both axes. ``None`` (the default) draws no markers.
    """
    # ``None`` instead of a mutable ``[]`` default, so no list object is shared
    # between calls.
    tlines = [] if tlines is None else list(tlines)

    fig, axs = plt.subplots(2, 1, figsize=(20, 6))
    fig.tight_layout(h_pad=3)
    for ax, title in zip(axs, ["envelope", "signal"]):
        signal.draw(0, tend, 50000, title, axis=ax)
        ax.set_xlabel("Time [s]")
        ax.set_ylabel("Amplitude")
        ax.set_title(title)
        if tlines:
            # Read the limits once so every marker spans the same y-range;
            # adding lines one by one can re-autoscale the axes in between.
            ymin, ymax = ax.get_ylim()
            ax.vlines(tlines, ymin, ymax, "k", linestyle="dashed")

def get_meas_signal(results,n,nshots):
    """Return the fraction of shots that read ``'1'`` for each of the first ``n`` experiments.

    ``results.get_counts(idx)`` is looked up for ``idx`` in ``range(n)``. A missing
    ``'1'`` entry (a ``KeyError``) counts as zero occurrences. Each count is
    divided by ``nshots``, and the fractions are returned as a list in experiment
    order.
    """
    def _ones_fraction(idx):
        # Counts dictionaries omit outcomes that never occurred, hence the KeyError path.
        try:
            ones = results.get_counts(idx)['1']
        except KeyError:
            ones = 0
        return ones / nshots

    return [_ones_fraction(idx) for idx in range(n)]

def video_bloch_traj(y,mod=1):
    """Animate a trajectory as a sequence of Bloch-sphere frames in notebook output.

    Elements of ``y`` at indices where ``np.mod(index, mod) == 0`` are rendered
    with ``.draw("bloch")`` and displayed. Each frame is followed by
    ``clear_output(wait=True)`` so it replaces the previous one, with a 1 ms
    pause between frames. A ``KeyboardInterrupt`` raised while drawing or
    displaying a frame stops the animation without propagating.
    """
    # Indices are produced lazily in ascending order; np.mod (rather than %)
    # keeps the selection rule identical for every kind of ``mod`` value.
    frame_indices = (idx for idx in range(len(y)) if np.mod(idx, mod) == 0)
    for idx in frame_indices:
        try:
            display.display(y[idx].draw("bloch"))
            display.clear_output(wait=True)
            time.sleep(0.001)
        except KeyboardInterrupt:
            break

def plot_bloch_traj(states,mod=1):
    """Render every ``mod``-th state of a trajectory as colored points on one Bloch sphere.

    For each selected state, the Bloch vector (<X>, <Y>, <Z>) is taken from the
    real parts of ``expectation_value`` and added with its own ``add_points``
    call. Point colors run through the ``"winter"`` colormap from the first
    selected state to the last, so the direction of time is visible.

    Parameters
    ----------
    states
        Indexable sequence of single-qubit states exposing
        ``expectation_value(Operator)``. (Presumably
        ``qiskit.quantum_info.Statevector`` -- confirm against callers.)
    mod
        Positive integer stride; states at indices 0, mod, 2*mod, ... are plotted.

    Raises
    ------
    ValueError
        If ``mod`` is smaller than 1.
    """
    # Validate before the (heavy) qiskit imports so bad input fails fast.
    if mod < 1:
        raise ValueError(f"mod must be a positive integer, got {mod!r}")

    from qiskit.visualization.bloch import Bloch
    from qiskit.quantum_info import Operator
    import matplotlib as mpl

    paulis = [Operator.from_label(label) for label in ("X", "Y", "Z")]
    selected = [states[i] for i in range(0, len(states), mod)]

    # One hex color per plotted point, ordered from the first to the last
    # selected state.
    cmap = mpl.colormaps["winter"]
    colors = [mpl.colors.rgb2hex(rgba) for rgba in cmap(np.linspace(0, 1, len(selected)))]

    bloch = Bloch()
    bloch.point_marker = ['o']
    bloch.point_color = colors

    for state in selected:
        bloch.add_points([state.expectation_value(op).real for op in paulis])

    bloch.render()
    
def plot_qubit_dynamics(sol, t_eval):
    """Plot the Bloch-vector components <X>, <Y>, <Z> of a solved trajectory versus time.

    Parameters
    ----------
    sol
        Solver result whose ``y`` attribute is a sequence of states exposing
        ``expectation_value(Operator)``. (Presumably ``Statevector`` objects --
        confirm against callers.)
    t_eval
        Times matching ``sol.y``. Any array-like is accepted. Values are divided
        by 1e-9 for the x-axis, which is labeled in ns, so they are presumably
        given in seconds.

    Notes
    -----
    This updates the global ``plt.rcParams['font.size']`` to 16, which also
    affects figures created afterwards. The figure is shown with ``plt.show()``.
    """
    from qiskit.quantum_info import Operator

    labels = ("X", "Y", "Z")
    paulis = [Operator.from_label(label) for label in labels]

    # One row per state in sol.y; columns are <X>, <Y>, <Z>. reshape keeps the
    # (0, 3) shape when sol.y is empty.
    bloch = np.array(
        [[state.expectation_value(op).real for op in paulis] for state in sol.y],
        dtype=float,
    ).reshape(-1, 3)
    # asarray lets plain lists/tuples of times be rescaled as well.
    t_ns = np.asarray(t_eval, dtype=float) / 1e-9

    _, ax = plt.subplots(figsize=(20, 6))
    fontsize = 16
    # Global side effect kept on purpose: later plots may rely on this font size.
    plt.rcParams.update({'font.size': fontsize})
    for column, label in enumerate(labels):
        ax.plot(t_ns, bloch[:, column], label=f"$\\langle {label} \\rangle$")
    ax.legend(fontsize=fontsize)
    ax.set_xlabel('$t$ [ns]', fontsize=fontsize)
    ax.set_title('Bloch vector vs. $t$', fontsize=fontsize)
    ax.set_ylim([-1, 1])
    plt.show()

def fit_function(x_values, y_values, function, init_params):
    """Least-squares fit ``function`` to data and evaluate it at the fitted parameters.

    Parameters
    ----------
    x_values, y_values
        Sample points and observed values. A list or tuple of x values is
        converted to a float array, the same conversion ``curve_fit`` applies
        internally. The model therefore sees the same kind of input during the
        fit and when ``y_fit`` is computed afterwards.
    function
        Model ``f(x, *params)``.
    init_params
        Initial parameter guess (the ``p0`` argument of ``scipy.optimize.curve_fit``).

    Returns
    -------
    tuple
        ``(fitparams, y_fit)``: the optimal parameters and ``function`` evaluated
        at ``x_values`` with those parameters.

    Raises
    ------
    RuntimeError
        Propagated from ``curve_fit`` when the least-squares fit does not converge.
    """
    from scipy.optimize import curve_fit

    # Without this, a model such as ``2*np.pi*x`` fits fine (curve_fit converts
    # x) but then fails on the raw Python list when computing y_fit below.
    if isinstance(x_values, (list, tuple)):
        x_values = np.asarray(x_values, dtype=float)
    fitparams, _covariance = curve_fit(function, x_values, y_values, p0=init_params)
    y_fit = function(x_values, *fitparams)
    return fitparams, y_fit

def plot_populations(sol,tend):
    """Plot the |0> and |1> populations of a solved trajectory on a new 8x5 figure.

    Each state in ``sol.y`` provides ``probabilities()``; entries 0 and 1 are
    plotted against ``sol.t``. The axes are limited to x in [0, tend] and
    y in [0, 1.05], with a dashed black marker at ``tend``. The figure is not
    shown here.
    """
    # Query each state once and split out the two basis populations.
    probs = [psi.probabilities() for psi in sol.y]
    curves = (
        ([p[0] for p in probs], "Population in |0>"),
        ([p[1] for p in probs], "Population in |1>"),
    )

    plt.figure(figsize=(8, 5))
    for values, text in curves:
        plt.plot(sol.t, values, lw=3, label=text)
    # NOTE(review): the axis says ns but sol.t is plotted unscaled -- confirm
    # that sol.t is already in nanoseconds.
    plt.xlabel("Time (ns)")
    plt.ylabel("Population")
    plt.legend(frameon=False)
    plt.ylim([0, 1.05])
    plt.xlim([0, tend])
    plt.vlines(tend, 0, 1.05, "k", linestyle="dashed")

def calibrate_rabi(backend, num_rabi_points, duration, drive_amp_min, drive_amp_max,
                   num_shots_per_point=100):
    """Estimate the pi and pi/2 drive amplitudes of qubit 0 from an amplitude Rabi sweep.

    Schedule: for each of ``num_rabi_points`` amplitudes evenly spaced in
    [drive_amp_min, drive_amp_max], a constant pulse of length ``duration`` is
    played on ``DriveChannel(0)``, followed by an acquisition into
    ``MemorySlot(0)``.

    Fit: the measured excited-state fraction P1 is fitted with
    ``A*sin(2*pi*x/T - phi) + B``. The pi amplitude is taken as half the
    fitted period, ``|T|/2``, and the pi/2 amplitude as ``|T|/4``. This
    presumably assumes the oscillation starts from |0> at zero amplitude --
    confirm for the device in use.

    Output: the data, the fit and the pi interval are plotted on a new figure,
    and both amplitudes are printed.

    Parameters
    ----------
    backend
        Backend exposing ``run(schedules, meas_level=..., meas_return=..., shots=...)``.
    num_rabi_points
        Number of drive amplitudes to test (e.g. 50).
    duration
        Pulse duration in samples/cycles (e.g. 100).
    drive_amp_min, drive_amp_max
        Amplitude range to sweep (e.g. 0 and 1).
    num_shots_per_point
        Shots per amplitude; defaults to 100.

    Returns
    -------
    list
        ``[pi_amp, pih_amp]``.
    """
    from qiskit.circuit import Parameter
    import qiskit.pulse as pulse

    drive_amps = np.linspace(drive_amp_min, drive_amp_max, num_rabi_points)

    # Build one parametric schedule, then bind each amplitude into a copy.
    drive_amp = Parameter('drive_amp')
    with pulse.build() as rabi_sched:
        with pulse.align_sequential():
            pulse.play(pulse.Constant(duration, drive_amp, name="Rabi Pulse"), pulse.DriveChannel(0))
            pulse.acquire(1, 0, pulse.MemorySlot(0))
    rabi_schedules = [rabi_sched.assign_parameters({drive_amp: a}, inplace=False) for a in drive_amps]

    # Execution: meas_level=2 yields discriminated counts, converted to P1 per point.
    job = backend.run(rabi_schedules,
                      meas_level=2,
                      meas_return='avg',
                      shots=num_shots_per_point)
    rabi_results = job.result()
    rabi_values = get_meas_signal(rabi_results, num_rabi_points, num_shots_per_point)

    # NOTE(review): the initial offset guess B=-0.5 puts the starting model
    # below the [0, 1] range of P1. curve_fit usually recovers, but 0.5 may be
    # a more robust guess -- confirm before changing.
    fit_params, y_fit = fit_function(drive_amps,
                                     rabi_values,
                                     lambda x, A, B, drive_period, phi: (A*np.sin(2*np.pi*x/drive_period - phi) + B),
                                     [0.5, -0.5, 0.9, 0])

    drive_period = fit_params[2]
    # The fit may return a negative period (an equivalent parametrization), so
    # use its magnitude both for the result and for the plot markers below.
    pi_amp = abs(drive_period / 2)
    pih_amp = pi_amp / 2

    # Draw on a fresh figure so the plot never lands on an unrelated open figure.
    plt.figure()
    plt.scatter(drive_amps, rabi_values, color='black')
    plt.plot(drive_amps, y_fit, color='red')

    plt.axvline(pi_amp, color='red', linestyle='--')
    plt.axvline(0, color='red', linestyle='--')
    plt.annotate("", xy=(0, 0.5), xytext=(pi_amp, 0.5), arrowprops=dict(arrowstyle="<->", color='red'))
    plt.annotate("$\\pi$", xy=(pi_amp * 0.8, 0.52), color='red')

    plt.xlabel("Drive amp [a.u.]", fontsize=15)
    plt.ylabel("P$_1$", fontsize=15)
    plt.show()

    print(f"Pi Amplitude = {pi_amp}")
    print(f"Pi/2 Amplitude = {pih_amp}")

    return [pi_amp, pih_amp]

