import matplotlib.pyplot as plt
import numpy as np
import time
from IPython import display

def plot_signal(signal, tend, tlines=None):
    """Plot ``signal`` on two stacked axes, titled "envelope" and "signal".

    Args:
        signal: object exposing ``draw(t0, tf, n, function, axis=ax)`` —
            presumably a qiskit_dynamics ``Signal``, where the 4th argument
            selects what is drawn; TODO confirm.
        tend: end of the plotted time window (the window starts at 0).
        tlines: optional iterable of times at which dashed black vertical
            markers are drawn on both axes. ``None`` (default) draws none.
    """
    # Immutable default instead of a shared mutable ``[]``.
    if tlines is None:
        tlines = ()
    fig, axs = plt.subplots(2, 1, figsize=(20, 6))
    for ax, title in zip(axs, ["envelope", "signal"]):
        signal.draw(0, tend, 50000, title, axis=ax)
        ax.set_xlabel("Time [s]")
        ax.set_ylabel("Amplitude")
        ax.set_title(title)
        for t in tlines:
            # Limits are re-read per marker so each spans the current y-range.
            ax.vlines(t, *ax.get_ylim(), "k", linestyle="dashed")
    # Run the layout pass after titles/labels exist so it can account for them.
    fig.tight_layout(h_pad=3)

def get_meas_signal(results,n,nshots):
    """Return the fraction of shots that read '1' for each of the first ``n`` experiments.

    Args:
        results: object whose ``get_counts(i)`` returns a counts mapping for
            experiment ``i`` (e.g. a qiskit ``Result``).
        n: number of experiments to read, indices ``0 .. n-1``.
        nshots: shot count used to normalise each '1' count.

    Returns:
        List of ``n`` values ``counts['1'] / nshots``. An experiment that never
        produced '1' contributes 0.
    """
    def _ones(counts):
        # A missing '1' key means the outcome was never observed.
        try:
            return counts['1']
        except KeyError:
            return 0

    return [_ones(results.get_counts(k)) / nshots for k in range(n)]

def video_bloch_traj(y,mod=1):
    """Animate a state trajectory in a notebook by redrawing a Bloch sphere per frame.

    Every element of ``y`` whose index is a multiple of ``mod`` is shown via
    ``state.draw("bloch")``, after which the output is cleared for the next
    frame. A KeyboardInterrupt raised while a frame is being shown stops the
    animation without propagating.

    Args:
        y: sequence of states exposing ``draw("bloch")`` (presumably qiskit
            quantum_info states; TODO confirm).
        mod: frame stride; only indices with ``index % mod == 0`` are drawn.
    """
    frames = (state for idx, state in enumerate(y) if np.mod(idx, mod) == 0)
    for state in frames:
        try:
            display.display(state.draw("bloch"))
            # wait=True defers the clear until the next output arrives.
            display.clear_output(wait=True)
            time.sleep(0.001)
        except KeyboardInterrupt:
            break

def plot_bloch_traj(states,mod=1):
    """Render every ``mod``-th state of ``states`` as a point on one Bloch sphere.

    Points are coloured along matplotlib's "winter" colormap in trajectory
    order, one colour per plotted point.

    Args:
        states: sequence of states exposing ``expectation_value(op)`` —
            presumably qiskit ``Statevector``/``DensityMatrix``; TODO confirm.
        mod: only states whose index is a multiple of ``mod`` are plotted.
    """
    from qiskit.visualization.bloch import Bloch
    from qiskit.quantum_info import Operator
    import matplotlib as mpl

    paulis = [Operator.from_label(label) for label in ('X', 'Y', 'Z')]

    # Plotted indices are 0, mod, 2*mod, ... < len(states): ceil(len/mod) of them.
    n_points = int(np.ceil(len(states) / mod))
    rgba_rows = mpl.colormaps["winter"](np.linspace(0, 1, n_points))

    sphere = Bloch()
    sphere.point_marker = ['o']
    sphere.point_color = [mpl.colors.rgb2hex(row) for row in rgba_rows]

    for idx, state in enumerate(states):
        if np.mod(idx, mod) == 0:
            # Bloch vector components (<X>, <Y>, <Z>).
            sphere.add_points([state.expectation_value(op).real for op in paulis])

    sphere.render()
    
def plot_qubit_dynamics(sol, t_eval):
    """Plot the Bloch-vector components <X>, <Y>, <Z> of ``sol.y`` versus time.

    Args:
        sol: solver output whose ``y`` is a sequence of states exposing
            ``expectation_value(op)`` (presumably qiskit_dynamics; TODO confirm).
        t_eval: times matching ``sol.y``. They are divided by 1e-9 and labelled
            in ns, so presumably given in seconds.

    Note:
        Sets the global ``plt.rcParams['font.size']`` to 16, which persists for
        figures created afterwards.
    """
    from qiskit.quantum_info import Operator

    names = ('X', 'Y', 'Z')
    ops = [Operator.from_label(name) for name in names]

    # Row k holds <names[k]> at every time step.
    components = np.zeros((len(names), len(sol.y)))
    for col, state in enumerate(sol.y):
        for row, op in enumerate(ops):
            components[row, col] = state.expectation_value(op).real

    _, ax = plt.subplots(figsize=(20, 6))
    fontsize = 16
    plt.rcParams.update({'font.size': fontsize})
    t_ns = t_eval / 1e-9
    for name, values in zip(names, components):
        plt.plot(t_ns, values, label=f'$\\langle {name} \\rangle$')
    plt.legend(fontsize=fontsize)
    ax.set_xlabel('$t$ [ns]', fontsize=fontsize)
    ax.set_title('Bloch vector vs. $t$', fontsize=fontsize)
    ax.set_ylim([-1, 1])
    plt.show()

def fit_function(x_values, y_values, function, init_params):
    """Least-squares fit ``function`` to the data and evaluate it at the optimum.

    Args:
        x_values: independent-variable samples.
        y_values: observed values at ``x_values``.
        function: model ``function(x, *params)``.
        init_params: initial parameter guess, passed to ``curve_fit`` as ``p0``.

    Returns:
        ``(fitparams, y_fit)``: the optimal parameters, and ``function``
        evaluated at ``x_values`` with them.
    """
    from scipy.optimize import curve_fit

    # curve_fit returns (popt, pcov); only the optimum is needed here.
    best_params = curve_fit(function, x_values, y_values, p0=init_params)[0]
    return best_params, function(x_values, *best_params)

def plot_populations(sol,tend):
    """Plot the |0> and |1> populations of ``sol.y`` against ``sol.t``.

    Args:
        sol: solver output with times ``t`` and states ``y`` exposing
            ``probabilities()`` (presumably qiskit states; TODO confirm).
        tend: right x-limit. A dashed vertical line is also drawn there.

    NOTE(review): the x-axis is labelled "Time (ns)" but ``sol.t`` is plotted
    unscaled — confirm the units of ``sol.t``.
    """
    probs = [psi.probabilities() for psi in sol.y]

    plt.figure(figsize=(8, 5))
    for level in (0, 1):
        plt.plot(sol.t, [p[level] for p in probs], lw=3, label=f"Population in |{level}>")
    plt.xlabel("Time (ns)")
    plt.ylabel("Population")
    plt.legend(frameon=False)
    plt.ylim([0, 1.05])
    plt.xlim([0, tend])
    plt.vlines(tend, 0, 1.05, "k", linestyle="dashed")

def build_dissipative_backend(omega,Omega,omegad,Gamma_1,Gamma_2,dt):
    """Build a single-qubit Lindblad Solver and a DynamicsBackend wrapping it.

    Model:
        * static Hamiltonian (also used as the rotating frame): ``omega/2 * Z``
        * drive operator ``Omega * X`` on channel "d0", with carrier frequency
          ``omegad / (2*pi)`` (so ``omegad`` is presumably angular; TODO confirm)
        * static dissipators ``sqrt(Gamma_1) * (X + iY)/2`` (the ``|0><1|``
          operator) and ``sqrt(Gamma_2) * Z``

    Args:
        omega, Omega, omegad, Gamma_1, Gamma_2: model parameters as above.
        dt: sample time of the backend. The ODE max step is set to ``200*dt``.

    Returns:
        ``[backend, solver]``.
    """
    from qiskit.quantum_info import Operator
    from qiskit_dynamics import Solver, DynamicsBackend

    sx, sy, sz = (Operator.from_label(p) for p in ('X', 'Y', 'Z'))

    # Expression forms kept as-is (scalar * Operator ordering unchanged).
    drift = 1/2 * omega * sz
    drive = Omega * sx
    decay = np.sqrt(Gamma_1) * 0.5 * (sx + 1j * sy)
    dephasing = np.sqrt(Gamma_2) * sz

    solver = Solver(
        static_hamiltonian=drift,
        hamiltonian_operators=[drive],
        hamiltonian_channels=["d0"],
        rotating_frame=drift,
        channel_carrier_freqs={"d0": omegad/2/np.pi},
        dt=dt,
        static_dissipators=[decay, dephasing],
    )
    backend = DynamicsBackend(solver=solver, solver_options={'max_step': 200*dt})

    return [backend, solver]

def calibrate_rabi(backend, num_rabi_points, duration, drive_amp_min, drive_amp_max, freq, fit0=(0.5, -0.5, 1, 0), num_shots_per_point=100):
    """Calibrate the pi and pi/2 drive amplitudes from a Rabi amplitude sweep.

    A constant pulse of ``duration`` samples is played on drive channel 0 at
    ``freq`` for each of ``num_rabi_points`` amplitudes in
    [``drive_amp_min``, ``drive_amp_max``]. The measured P1 is fitted with
    ``A*sin(2*pi*x/drive_period - phi) + B``, and ``|drive_period|/2`` is
    taken as the pi amplitude.

    Args:
        backend: backend with ``run(schedules, meas_level=..., meas_return=...,
            shots=...)``, e.g. the one returned by ``build_dissipative_backend``.
        num_rabi_points: number of amplitudes in the sweep.
        duration: pulse length in samples (backend dt units).
        drive_amp_min: smallest amplitude tested.
        drive_amp_max: largest amplitude tested.
        freq: drive frequency in Hz, applied with ``pulse.set_frequency``.
        fit0: initial guess ``(A, B, drive_period, phi)`` for the sine fit.
        num_shots_per_point: shots per amplitude (default 100, the value that
            used to be hard-coded).

    Returns:
        ``[pi_amp, pih_amp]``: the pi and pi/2 amplitudes.
    """
    from qiskit.circuit import Parameter
    import qiskit.pulse as pulse

    drive_amps = np.linspace(drive_amp_min, drive_amp_max, num_rabi_points)

    #########################################
    # Parametric pulse schedule
    #########################################
    drive_amp = Parameter('drive_amp')
    with pulse.build() as rabi_sched:
        with pulse.align_sequential():
            pulse.set_frequency(freq, pulse.DriveChannel(0))
            pulse.play(pulse.Constant(duration, drive_amp, name="Rabi Pulse"), pulse.DriveChannel(0))
            pulse.acquire(1, 0, pulse.MemorySlot(0))

    rabi_schedules = [rabi_sched.assign_parameters({drive_amp: a}, inplace=False) for a in drive_amps]

    #########################################
    # Execution
    #########################################
    job = backend.run(rabi_schedules,
                      meas_level=2,
                      meas_return='avg',
                      shots=num_shots_per_point)
    rabi_results = job.result()
    rabi_values = get_meas_signal(rabi_results, num_rabi_points, num_shots_per_point)

    #########################################
    # Curve fitting
    #########################################
    fit_params, y_fit = fit_function(drive_amps,
                                     rabi_values,
                                     lambda x, A, B, drive_period, phi: (A*np.sin(2*np.pi*x/drive_period - phi) + B),
                                     fit0)

    drive_period = fit_params[2]
    # A sine with a negative period is the same curve, so use the magnitude.
    pi_amp = abs(drive_period / 2)
    pih_amp = pi_amp / 2

    #########################################
    # Graphics
    #########################################
    plt.scatter(drive_amps, rabi_values, color='black')
    plt.plot(drive_amps, y_fit, color='red')

    # Mark the returned pi amplitude (previously the signed drive_period/2).
    plt.axvline(pi_amp, color='red', linestyle='--')
    plt.axvline(0, color='red', linestyle='--')
    plt.annotate("", xy=(0, 0.5), xytext=(pi_amp, 0.5), arrowprops=dict(arrowstyle="<->", color='red'))
    plt.annotate("$\\pi$", xy=(0.8 * pi_amp, 0.52), color='red')

    plt.xlabel("Drive amp [a.u.]", fontsize=15)
    plt.ylabel("P$_1$", fontsize=15)
    plt.show()

    print(f"Pi Amplitude = {pi_amp}")
    print(f"Pi/2 Amplitude = {pih_amp}")

    return [pi_amp, pih_amp]

def calibrate_coarse_frequency(backend,duration,pulse_amp,min_freq,max_freq,num_points,num_shots_per_point,fit0=(1, 4.94, 0.01, -0.001)):
    """Sweep the drive frequency of a constant pulse and fit a Lorentzian to locate the qubit.

    Args:
        backend: backend with ``run(schedules, meas_level=..., meas_return=..., shots=...)``.
        duration: pulse length in samples (backend dt units).
        pulse_amp: amplitude of the constant excitation pulse.
        min_freq, max_freq: sweep bounds in Hz. The plots divide by 1e9 and
            label the axis in GHz.
        num_points: number of frequencies in the sweep.
        num_shots_per_point: shots per frequency.
        fit0: initial guess ``(A, q_freq, B, C)`` for
            ``(A/pi) * B/((x - q_freq)**2 + B**2) + C`` with ``x`` in GHz.

    Returns:
        The fitted qubit frequency in Hz.
    """
    from qiskit.circuit import Parameter
    import qiskit.pulse as pulse

    #########################################
    # Parametric schedule creation
    #########################################
    freq_vec = np.linspace(min_freq, max_freq, num_points)
    freq = Parameter('freq')
    with pulse.build(default_alignment='sequential') as sweep_sched:
        pulse.set_frequency(freq, pulse.DriveChannel(0))
        pulse.play(pulse.Constant(duration,
                                  pulse_amp,
                                  name="Excitation Pulse"),
                   pulse.DriveChannel(0))
        pulse.acquire(1, 0, pulse.MemorySlot(0))
    sweep_schedules = [sweep_sched.assign_parameters({freq: a}, inplace=False) for a in freq_vec]

    #########################################
    # Backend simulation
    #########################################
    job = backend.run(sweep_schedules,
                      meas_level=2,
                      meas_return='avg',
                      shots=num_shots_per_point)
    results = job.result()

    #########################################
    # P1 computation
    #########################################
    # Take the sweep length from num_points. len(results.get_counts()) counts
    # outcome keys rather than experiments when only one experiment ran.
    sweep_values = get_meas_signal(results, num_points, num_shots_per_point)
    freq_ghz = freq_vec / 1e9

    #########################################
    # Graphics (raw sweep)
    #########################################
    plt.xlabel("Drive frequency [GHz]")
    plt.ylabel("P$_1$ [a.u.]")
    plt.scatter(freq_ghz, sweep_values, color='black')
    plt.show()

    #########################################
    # Curve fitting
    #########################################
    fit_params, y_fit = fit_function(freq_ghz,
                                     sweep_values,
                                     lambda x, A, q_freq, B, C: (A / np.pi) * (B / ((x - q_freq)**2 + B**2)) + C,
                                     fit0)
    rough_f = fit_params[1] * 1e9
    print(f"Estimated qubit frequency: {rough_f/1e9:.4f} GHz")

    #########################################
    # Graphics (sweep + fit)
    #########################################
    plt.xlabel("Drive frequency [GHz]")
    plt.ylabel("P$_1$ [a.u.]")
    plt.scatter(freq_ghz, sweep_values, color='black')
    plt.plot(freq_ghz, y_fit, color='red')
    plt.show()

    return rough_f

def ramsey_experiment(backend,duration,pi_amp,num_points,delay_dt,detune,num_shots_per_point,freq,dt,fit0=(0.5, -0.5, 100e-9, 0)):
    """Run a Ramsey (X90 - delay - X90) sweep at ``freq + detune`` and refine the qubit frequency.

    Args:
        backend: backend with ``run(schedules, meas_level=..., meas_return=..., shots=...)``.
        duration: length in samples of the X90 pulse.
        pi_amp: calibrated pi amplitude. The X90 pulse uses ``pi_amp/2``.
        num_points: number of delays; delay ``k`` is ``k * delay_dt`` samples.
        delay_dt: delay step in samples.
        detune: deliberate drive detuning in Hz added to ``freq``.
        num_shots_per_point: shots per delay.
        freq: current drive-frequency estimate in Hz.
        dt: sample time in seconds, used to convert delays to time.
        fit0: initial guess ``(A, B, delay_period, phi)`` for
            ``A*sin(2*pi*x/delay_period - phi) + B`` with ``x`` in seconds.

    Returns:
        Updated qubit frequency estimate ``freq - (ramsey_f - detune)`` in Hz.
    """
    from qiskit.circuit import Parameter
    import qiskit.pulse as pulse

    #########################################
    # X-90 pulse creation
    #########################################
    with pulse.build(default_alignment='sequential') as x_pihalf:
        pulse.play(pulse.Constant(duration, pi_amp/2, name="$\\pi/2$-Rabi Pulse"), pulse.DriveChannel(0))

    #########################################
    # Parametric schedule construction
    #########################################
    delay = Parameter('delay')
    with pulse.build(default_alignment='sequential') as rams_sched:
        pulse.set_frequency(freq + detune, pulse.DriveChannel(0))
        pulse.call(x_pihalf)
        pulse.delay(delay, pulse.DriveChannel(0))
        pulse.call(x_pihalf)
        pulse.acquire(1, 0, pulse.MemorySlot(0))
    delay_vec = np.linspace(0, (num_points-1), num_points) * delay_dt
    rams_schedules = [rams_sched.assign_parameters({delay: a}, inplace=False) for a in delay_vec]

    #########################################
    # Backend simulation
    #########################################
    job = backend.run(rams_schedules,
                      meas_level=2,
                      meas_return='avg',
                      shots=num_shots_per_point)
    results = job.result()

    #########################################
    # P1 computation
    #########################################
    rams_values = get_meas_signal(results, num_points, num_shots_per_point)

    #########################################
    # Curve fitting
    #########################################
    delay_times = delay_vec * dt
    fit_params, y_fit = fit_function(delay_times,
                                     rams_values,
                                     lambda x, A, B, delay_period, phi: (A*np.sin(2*np.pi*x/delay_period - phi) + B),
                                     fit0)
    # A sine with a negative period is the same curve, so the oscillation
    # frequency is the magnitude (matches calibrate_rabi's use of abs).
    delay_period = abs(fit_params[2])
    ramsey_f = 1 / delay_period
    qubit_f = freq - (ramsey_f - detune)
    print(f"Ramsey frequency (natural + forced): {ramsey_f/1e6:.3f} MHz")
    print(f"Ramsey frequency (natural): {(ramsey_f-detune)/1e6:.3f} MHz")
    print(f"Updated drive frequency: {qubit_f/1e9:.6f} GHz")

    #########################################
    # Graphics
    #########################################
    t0_ns = delay_times[0] / 1e-9
    half_period_ns = (delay_times[0] + delay_period/2) / 1e-9
    plt.scatter(delay_times/1e-9, rams_values, color='black')
    plt.plot(delay_times/1e-9, y_fit, color='red')
    plt.axvline(half_period_ns, color='red', linestyle='--')
    plt.axvline(t0_ns, color='red', linestyle='--')
    plt.annotate("", xy=(t0_ns, 0.5), xytext=(half_period_ns, 0.5), arrowprops=dict(arrowstyle="<->", color='red'))
    plt.annotate("$\\pi$", xy=((delay_times[0] + delay_period/4)/1e-9, 0.52), color='red')
    plt.xlabel("Delay time [ns]", fontsize=15)
    plt.ylabel("P$_1$", fontsize=15)
    plt.show()

    return qubit_f
    
def inversion_recovery(backend,duration,pi_amp,num_points,delay_dt,num_shots_per_point,freq,dt,fit0):
    """Measure T1 with an inversion-recovery sweep (X180 - delay - measure).

    Args:
        backend: backend with ``run(schedules, meas_level=..., meas_return=..., shots=...)``.
        duration: length in samples of the pi pulse.
        pi_amp: amplitude of the pi pulse.
        num_points: number of delays; delay ``k`` is ``k * delay_dt`` samples.
        delay_dt: delay step in samples.
        num_shots_per_point: shots per delay.
        freq: drive frequency in Hz.
        dt: sample time in seconds, used to convert delays to time.
        fit0: initial guess ``(T, A, C)`` for ``A*exp(-x/T) + C`` with ``x`` in seconds.

    Returns:
        The fitted decay constant ``T1`` in seconds.
    """
    from qiskit.circuit import Parameter
    import qiskit.pulse as pulse

    #########################################
    # X pulse creation
    #########################################
    with pulse.build(default_alignment='sequential') as x_pi:
        pulse.play(pulse.Constant(duration, pi_amp, name="$\\pi$-Rabi Pulse"), pulse.DriveChannel(0))

    #########################################
    # Parametric schedule creation
    #########################################
    delay = Parameter('delay')
    with pulse.build(default_alignment='sequential') as t1_sched:
        pulse.set_frequency(freq, pulse.DriveChannel(0))
        pulse.call(x_pi)
        pulse.delay(delay, pulse.DriveChannel(0))
        pulse.acquire(1, 0, pulse.MemorySlot(0))
    delay_vec = np.linspace(0, (num_points-1), num_points) * delay_dt
    t1_scheds = [t1_sched.assign_parameters({delay: a}, inplace=False) for a in delay_vec]

    #########################################
    # Backend simulation
    #########################################
    job = backend.run(t1_scheds,
                      meas_level=2,
                      meas_return='avg',
                      shots=num_shots_per_point)
    t1_results = job.result()

    #########################################
    # P1 computation
    #########################################
    t1_values = get_meas_signal(t1_results, num_points, num_shots_per_point)

    #########################################
    # Curve fitting
    #########################################
    delay_times = delay_vec * dt
    fit_params, y_fit = fit_function(delay_times,
                                     t1_values,
                                     lambda x, T, A, C: A*np.exp(-x/T)+C,
                                     fit0)
    T1 = fit_params[0]
    print(f"T1 = {T1/1e-6:.3f} µs")

    #########################################
    # Graphics
    #########################################
    plt.scatter(delay_times/1e-6, t1_values, color='black')
    plt.plot(delay_times/1e-6, y_fit, color='red')
    plt.xlabel("Delay [$\\mu$s]", fontsize=15)
    plt.ylabel("P$_1$", fontsize=15)
    plt.show()

    return T1

def hahn_echo_experiment(backend,duration,pi_amp,num_points,delay_dt,num_shots_per_point,freq,dt,fit0=(10e-6,0.5,0.5)):
    """Measure the echo coherence time T2 with a Hahn echo (X90 - tau - X180 - tau - X90).

    P0 is fitted against the single delay ``tau`` with ``A*exp(-tau/T) + C``.
    The state is free to dephase for the total time ``2*tau``, so the echo
    coherence time is ``T2 = 2*T``. Previously ``T`` itself was reported, which
    is half of T2.

    Args:
        backend: backend with ``run(schedules, meas_level=..., meas_return=..., shots=...)``.
        duration: length in samples of the X90 and X180 pulses.
        pi_amp: pi amplitude. The X90 pulse uses ``pi_amp/2``.
        num_points: number of delays; ``tau_k = k * delay_dt`` samples.
        delay_dt: delay step in samples.
        num_shots_per_point: shots per delay.
        freq: drive frequency in Hz.
        dt: sample time in seconds.
        fit0: initial guess ``(T, A, C)``. ``T`` is the decay constant versus
            the single delay tau, i.e. roughly T2/2.

    Returns:
        The echo coherence time ``T2`` in seconds.
    """
    from qiskit.circuit import Parameter
    import qiskit.pulse as pulse

    #########################################
    # X-90 pulse creation
    #########################################
    with pulse.build(default_alignment='sequential') as x_pihalf:
        pulse.play(pulse.Constant(duration, pi_amp/2, name="$\\pi/2$-Rabi Pulse"), pulse.DriveChannel(0))

    #########################################
    # X pulse creation
    #########################################
    with pulse.build(default_alignment='sequential') as x_pi:
        pulse.play(pulse.Constant(duration, pi_amp, name="$\\pi$-Rabi Pulse"), pulse.DriveChannel(0))

    #########################################
    # Echo schedule
    #########################################
    delay = Parameter('delay')
    with pulse.build(default_alignment='sequential') as t2_sched:
        pulse.set_frequency(freq, pulse.DriveChannel(0))
        pulse.call(x_pihalf)

        pulse.delay(delay, pulse.DriveChannel(0))
        pulse.call(x_pi)
        pulse.delay(delay, pulse.DriveChannel(0))

        pulse.call(x_pihalf)
        pulse.acquire(1, 0, pulse.MemorySlot(0))
    delay_vec = np.linspace(0, num_points-1, num_points) * delay_dt
    t2_scheds = [t2_sched.assign_parameters({delay: a}, inplace=False) for a in delay_vec]

    #########################################
    # Backend simulation
    #########################################
    job = backend.run(t2_scheds,
                      meas_level=2,
                      meas_return='avg',
                      shots=num_shots_per_point)
    t2_results = job.result()

    #########################################
    # P0 computation
    #########################################
    def _p0(counts):
        # A missing '0' key means the outcome was never observed.
        try:
            return counts['0'] / num_shots_per_point
        except KeyError:
            return 0.0

    t2_values = [_p0(t2_results.get_counts(i)) for i in range(len(delay_vec))]

    #########################################
    # Curve fitting
    #########################################
    delay_times = delay_vec * dt
    fit_params, y_fit = fit_function(delay_times,
                                     t2_values,
                                     lambda x, T, A, C: A*np.exp(-x/T)+C,
                                     fit0)
    # The fit is against the single delay tau, but free evolution lasts 2*tau.
    T2 = 2 * fit_params[0]
    print(f"T2 = {T2/1e-6} µs")

    #########################################
    # Graphics
    #########################################
    plt.scatter(delay_times/1e-6, t2_values, color='black')
    plt.plot(delay_times/1e-6, y_fit, color='red')
    plt.axhline(0.5, color='red', linestyle='--')
    plt.xlabel("X90-$\\pi$ delay [$\\mu$s]", fontsize=15)
    plt.ylabel("P$_0$", fontsize=15)
    plt.ylim([0, 1])
    plt.show()

    return T2

def _tomo_run(backend, schedule, shots):
    """Run ``schedule`` on ``backend`` with level-2 (counts) readout and return the result."""
    job = backend.run(schedule,
                      meas_level=2,
                      meas_return='avg',
                      shots=shots)
    return job.result()


def _tomo_probability(counts, outcome, shots):
    """Return ``counts[outcome] / shots``, or 0 if ``outcome`` was never observed."""
    try:
        return counts[outcome] / shots
    except KeyError:
        return 0


def _tomo_show_counts(result):
    """Display the counts of ``result`` as a dict and as a histogram; return the counts."""
    from qiskit.visualization import plot_histogram

    counts = result.get_counts()
    display.display(counts)
    display.display(plot_histogram(counts))
    return counts


def xth_quantum_tomography(backend,duration,pi_amp,num_shots_per_point,freq):
    """Prepare O = XTH|0> with pulses and measure it in the Z, X and Y bases.

    Pulse-level gates used:
        * H: frame phase +pi/2, X90, frame phase -pi/2, X180
        * T: frame phase shift of +pi/4
        * S^dagger: frame phase shift of -pi/2
    Z basis: measure O directly. X basis: append H. Y basis: append S^dagger
    then H. Each schedule is drawn, and its counts are displayed as a dict and
    a histogram.

    Args:
        backend: backend with ``run(schedule, meas_level=..., meas_return=..., shots=...)``.
        duration: length in samples of the X90 and X180 pulses.
        pi_amp: pi amplitude. The X90 pulse uses ``pi_amp/2``.
        num_shots_per_point: shots for every basis. Previously the X and Y
            projections ignored this and always used 1000.
        freq: drive frequency in Hz.

    Returns:
        ``[p0z, p1z, p0x, p1x, p0y, p1y]``: outcome frequencies per basis.
    """
    import qiskit.pulse as pulse

    drive = pulse.DriveChannel(0)

    #########################################
    # Elementary pulses
    #########################################
    with pulse.build(default_alignment='sequential') as x_pihalf:
        pulse.play(pulse.Constant(duration, pi_amp/2, name="$\\pi/2$-Rabi Pulse"), drive)

    with pulse.build(default_alignment='sequential') as x_pi:
        pulse.play(pulse.Constant(duration, pi_amp, name="$\\pi$-Rabi Pulse"), drive)

    #########################################
    # Hadamard pulse
    #########################################
    with pulse.build(default_alignment='sequential') as h_sched:
        pulse.shift_phase(np.pi/2, drive)
        pulse.call(x_pihalf)
        pulse.shift_phase(-np.pi/2, drive)
        pulse.call(x_pi)

    #########################################
    # XTH pulse: H, then T (virtual phase), then X
    #########################################
    with pulse.build(default_alignment='sequential') as xth_sched:
        pulse.call(h_sched)
        pulse.shift_phase(np.pi/4, drive)
        pulse.call(x_pi)

    #########################################
    # S*-gate (virtual phase)
    #########################################
    with pulse.build(default_alignment='sequential') as sstar_gate:
        pulse.shift_phase(-np.pi/2, drive)

    #########################################
    # Z projection
    #########################################
    with pulse.build(default_alignment='sequential', name="O=XTH") as o_sched:
        pulse.set_frequency(freq, drive)
        pulse.call(xth_sched)
        pulse.acquire(1, 0, pulse.MemorySlot(0))
    display.display(o_sched.draw())

    z_result = _tomo_run(backend, o_sched, num_shots_per_point)
    print("Z-projection\n")
    z_counts = _tomo_show_counts(z_result)
    p0z = _tomo_probability(z_counts, '0', num_shots_per_point)
    p1z = _tomo_probability(z_counts, '1', num_shots_per_point)

    #########################################
    # X projection
    #########################################
    print("X-projection\n")
    with pulse.build(default_alignment='sequential', name="OH=XTHH (X-proj)") as sched_x:
        pulse.set_frequency(freq, drive)
        pulse.call(xth_sched)
        pulse.call(h_sched)
        pulse.acquire(1, 0, pulse.MemorySlot(0))
    display.display(sched_x.draw())

    x_counts = _tomo_show_counts(_tomo_run(backend, sched_x, num_shots_per_point))
    p0x = _tomo_probability(x_counts, '0', num_shots_per_point)
    p1x = _tomo_probability(x_counts, '1', num_shots_per_point)

    #########################################
    # Y projection
    #########################################
    print("Y-projection\n")
    with pulse.build(default_alignment='sequential', name="OSdH=XTHSdH (Y-proj)") as sched_y:
        pulse.set_frequency(freq, drive)
        pulse.call(xth_sched)
        pulse.call(sstar_gate)
        pulse.call(h_sched)
        pulse.acquire(1, 0, pulse.MemorySlot(0))
    display.display(sched_y.draw())

    y_counts = _tomo_show_counts(_tomo_run(backend, sched_y, num_shots_per_point))
    p0y = _tomo_probability(y_counts, '0', num_shots_per_point)
    p1y = _tomo_probability(y_counts, '1', num_shots_per_point)

    return [p0z, p1z, p0x, p1x, p0y, p1y]
