'''
Fit an RC charging curve, V_t = A * (1 - exp(-t / tau)), to measured
capacitor voltages and plot the data, the fitted curve, and the residuals.

Author: Unsaturated Transistor; code written and exported to HTML using the
eric IDE; dated 12 October 2024.
'''

import numpy as np
import scipy.optimize as opt
import matplotlib.pyplot as plt
import csv
import os

def csv_import(filename):
    """Imports two-column data from a CSV file into time and voltage lists.

    Each non-empty row contributes its first column to ``time`` and its
    second column to ``voltages``. Values are returned as the raw strings
    read from the file; the caller is responsible for numeric conversion.

    Args:
        filename: Path to the CSV file.

    Returns:
        A tuple ``(time, voltages)`` of two equal-length lists of strings.

    Raises:
        IndexError: If a non-empty row has fewer than two columns.
    """
    time = []
    voltages = []

    # newline='' lets the csv module handle line endings itself, as the csv
    # documentation requires; otherwise '\r\n' files and quoted fields that
    # contain newlines can be mis-parsed.
    with open(filename, 'r', newline='') as file:
        for row in csv.reader(file):
            # csv.reader yields [] for blank lines (e.g. a trailing empty
            # line); skip them instead of raising IndexError on row[0].
            if not row:
                continue
            time.append(row[0])
            voltages.append(row[1])

    return time, voltages
    
def func(t, A, tau):
    '''
    Evaluate the capacitor charging model V_t = A * (1 - exp(-t / tau)).

    Args:
        t: time (a scalar or a numpy array); A: amplitude; tau: time constant

    Returns:
        The model voltage V_t at each time in t.
    '''
    decay = np.exp(-t / tau)
    return A * (1 - decay)

def main():
    '''
    Fit the RC charging model to the measured data and plot the result.

    Reads "main_data.csv" (time in the first column, voltage in the second),
    divides the voltages by 10 to account for attenuation, fits
    V_t = A * (1 - exp(-t / tau)) with scipy.optimize.curve_fit, and derives
    R = tau / C for C = 10 nF. The top subplot shows the raw data, the fitted
    curve and the fitted values; the bottom subplot shows the residuals
    (data minus fit). The figure is saved to "rc_circuit_plot.png"
    (overwriting any existing file) and then displayed.

    # Args:
        None

    # Returns:
        None

    # Raises:
        ValueError: If the CSV file holds fewer than two data points, the
            minimum needed to fit the two model parameters.
        RuntimeError: Propagated from curve_fit if the fit does not converge.
    '''

    # Import relevant data
    time, voltages = csv_import("main_data.csv")

    # Convert the string columns read from the CSV to float arrays
    time_np = np.array([float(t) for t in time])
    voltages_np = np.array([float(v) for v in voltages])

    # curve_fit needs at least as many points as parameters (A and tau);
    # fail early with a clear message instead of an opaque scipy error.
    if time_np.size < 2:
        raise ValueError(
            f"need at least 2 data points to fit the RC curve, got {time_np.size}"
        )

    # Divide all voltages by 10 to account for attenuation
    voltages_np = voltages_np / 10

    # Two stacked subplots on an A4-sized canvas (8.27 x 11.69 in), with the
    # top (data + fit) twice as tall as the bottom (residuals)
    fig, axes = plt.subplots(
        2, 1, figsize=(8.27, 11.69), gridspec_kw={'height_ratios': [2, 1]}
    )

    # Draw a black border around the whole figure
    fig.patch.set_linewidth(2)
    fig.patch.set_edgecolor('black')

    # Share the time axis between the subplots
    axes[0].sharex(axes[1])
    ax1, ax2 = axes

    # Top subplot: raw data
    ax1.scatter(time_np, voltages_np, marker='x', s=10, color='blue', label='Raw data')

    # Initial guess for the fit. Without p0, curve_fit starts every parameter
    # at 1 (i.e. tau = 1 s), which may be far from the actual RC time
    # constant and can keep the fit from converging. Seed A with the largest
    # measured voltage and tau with a fifth of the recorded time span.
    a_guess = voltages_np.max()
    if a_guess == 0:
        a_guess = 1.0
    t_span = time_np.max() - time_np.min()
    tau_guess = t_span / 5 if t_span > 0 else 1.0

    # Perform curve fitting (the covariance matrix is not used here)
    popt, _ = opt.curve_fit(func, time_np, voltages_np, p0=[a_guess, tau_guess])
    A_fit, tau_fit = popt

    # Calculate resistance R from tau = R * C
    C = 10e-9  # Capacitance in Farads from circuit diagram
    R = tau_fit / C

    # Plot the fitted curve
    ax1.plot(time_np, func(time_np, A_fit, tau_fit), color='orange', label='Fitted Curve')

    # Set labels and title
    ax1.set_xlabel(r"$t$ / s")
    ax1.set_ylabel(r"$V_t$ / V")
    ax1.set_title(r"Voltage $v_{t}$ Across the 10nF Capacitor Due to Increase in Charge Over Time $t$.")

    # Add legend
    ax1.legend(loc='lower right')

    # Axis limits: the recorded time range, with 0.5 V headroom above the data
    ax1.set_xlim(time_np[0], time_np[-1])
    ax1.set_ylim(voltages_np.min(), voltages_np.max() + 0.5)

    # Six evenly spaced ticks across each data range. np.linspace includes
    # both endpoints exactly and needs no step size, so it cannot overshoot
    # the range through float rounding. A max / 5 step for np.arange would be
    # zero (an error) or negative (no ticks) whenever the maximum is <= 0.
    ax1.set_xticks(np.linspace(time_np.min(), time_np.max(), 6))
    ax1.set_yticks(np.linspace(voltages_np.min(), voltages_np.max(), 6))

    # Annotate the plot with the fitted values. The general ('g') format keeps
    # significant digits for small values, where fixed-point '%f' would round
    # e.g. a microsecond-scale tau down to a single digit.
    ax1.text(
        0.5, 0.5,
        rf"$R$ = {R / 1000:.2f} kΩ; ($A$ = {A_fit:.4g} V, $\tau$ = {tau_fit:.4g} s)",
        transform=ax1.transAxes, horizontalalignment='center',
        verticalalignment='center', fontsize=14,
    )

    # Bottom subplot: residuals (raw data minus fitted curve)
    difference = voltages_np - func(time_np, A_fit, tau_fit)
    ax2.scatter(time_np, difference, color='brown', marker='x', s=10, label='Difference')

    # Set labels for the bottom subplot
    ax2.set_xlabel(r"$t$ / s")
    ax2.set_ylabel(r"$\Delta V_t$ / V")
    ax2.set_title("Difference Between Fitted Curve and Raw Data.")

    # Adjust spacing between subplots
    fig.subplots_adjust(hspace=0.3)

    # Save the figure as a PNG. savefig overwrites an existing file of the
    # same name, so no separate delete step is needed.
    filename = "rc_circuit_plot.png"
    fig.savefig(filename, transparent=True)

    # Show the plot
    plt.show()
    
if __name__ == '__main__':
    # Run the analysis only when executed as a script, not when imported.
    main()