Back to Article
Cross Correlations
Download Notebook

Cross Correlations

Imports

In [2]:
from concurrent.futures import ProcessPoolExecutor, as_completed

import h5py
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io as sio
import seaborn as sns
from scipy.interpolate import interp1d
from scipy.signal import (butter, correlate, filtfilt, find_peaks, resample,
                          sosfiltfilt, welch)

Loading experiment data

In [3]:
trial = 'A0'
eeg_filepath = f'../data/{trial}/converted/eeg_matrix.mat'
eeg_trialinfo = f'../data/{trial}/converted/trialinfo_matrix_{trial}_cleanedtrials.mat'
eeg_raw_data = f'../data/{trial}/experimental/data.mat'
eeg_labels_path = f'../data/A0/preprocessing/channel_labels.mat' # always get the labels from A0
In [4]:
f = sio.loadmat(eeg_labels_path)
labels = f['channellabels_ADSselection']
len(labels) # 37 + 4
41
In [5]:
def load_mat_file(filepath):
    """ Load a MATLAB .mat file and return its content. """
    try:
        mat_contents = sio.loadmat(filepath)
        return mat_contents
    except NotImplementedError:
        # If the file is v7.3, it will need h5py to handle it
        with h5py.File(filepath, 'r') as file:
            return {key: np.array(value) for key, value in file.items()}

def get_eeg_data(eeg_data_path):
    """ Extract EEG data from the .mat file. """
    data = load_mat_file(eeg_data_path)
    eeg_data = data['data_eeg']
    return eeg_data

def get_trial_info(trial_info_path):
    """ Extract trial information from the .mat file. """
    data = load_mat_file(trial_info_path)
    trial_info = data['all_info']
    return trial_info

def load_raw_data(noise_data_path):
    """ Load and return the entire noise stimuli structure from the .mat file. """
    mat_contents = load_mat_file(noise_data_path)
    # Extract the 'noise_stims' from the nested structure
    noise_stims = mat_contents['data'][0, 0]
    return noise_stims

def get_specific_raw_data(idx, eeg_raw_data, block_index, trial_index):
    """ Fetch the specific noise stimulus for the given block, trial, and channel (0 for left, 1 for right). """
    idxs = ['reaction_times', 'reaction_err', 'answers', 'base_delay', 'target_timings', 'flicker_sides', 'attend_sides', 'orients_L', 'orients_R', 'angle_magnitude', 'probe_sides', 'missed', 'targets_binary', 'tagging_types', 'noise_stims', 'p_num', 'date']
    return eeg_raw_data[idxs.index(idx)][block_index, trial_index]

def trial_number_to_indices(trial_number):
    """ Convert trial number to block and trial indices, zero-indexed. """
    trial_number = int(trial_number) - 1
    block_index = trial_number // 32 + 1 # +1 block because we've discard the first block
    trial_index = trial_number  % 32
    return block_index, trial_index

def load_and_process_data(eeg_filepath, trialinfo_filepath, raw_data_filepath):
    eeg_data = get_eeg_data(eeg_filepath)
    trial_info = get_trial_info(trialinfo_filepath)
    eeg_raw_data = load_raw_data(raw_data_filepath)

    processed_trials = []
    # Assuming we loop over actual trial data, adjust as necessary
    for trial, eeg_data in zip(trial_info, eeg_data):
        block_index, trial_index = trial_number_to_indices(trial[14])

        # Access the corresponding noise stimulus for left and right channels
        noise_stim_left, noise_stim_right = get_specific_raw_data('noise_stims', eeg_raw_data, block_index, trial_index)

        trial_data = {
            'eeg': eeg_data,  # Adjust for zero-index and skip block 1
            'trial_info': trial,
            'reaction_time_1': trial[2],
            'reaction_time_2': get_specific_raw_data('reaction_times', eeg_raw_data, block_index, trial_index),
            'has_noise': bool(int(trial[13])),
            'has_56|64': int(trial[7]) == 0,
            'has_56|60': int(trial[7]) == 1,
            'attended_side': trial[8],
            'noise_stim_left': noise_stim_left,
            'noise_stim_right': noise_stim_right,
            'base_delay': trial[6],
            'target_timings': trial[5],
        }
        processed_trials.append(trial_data)

    return processed_trials

experiment_data = load_and_process_data(eeg_filepath, eeg_trialinfo, eeg_raw_data)
experiment_data = pd.DataFrame(experiment_data)
# These are in the form trials (440), samples (6758), channels (37)
eeg_data = np.stack(experiment_data['eeg'])
channel_labels = load_mat_file(eeg_labels_path)['channellabels_ADSselection'].flatten()
channel_labels = [x[0] for x in channel_labels][:-4] # remove the last 4 channels, they're not used (EXG1 EXG2 EXG3 EXG4)
print(' '.join(channel_labels))
Fp1 AF7 AF3 F1 F3 F5 F7 FT7 FC5 FC3 FC1 C1 C3 C5 T7 TP7 CP5 CP3 CP1 P1 P3 P5 P7 PO7 PO3 O1 Iz Oz POz Pz P2 P4 P6 P8 PO8 PO4 O2
In [16]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
from scipy.signal import resample
from matplotlib.colors import LinearSegmentedColormap


plt.rcParams['font.sans-serif'] = 'Open Sans'
plt.rcParams.update({'figure.autolayout': True})

figsize = 21.124755984047216

mm_to_pt = 1/25.4*72
linewidth = 0.23550824252304192 * mm_to_pt # mm
linewidth_thicc = 0.38106034103303515 * mm_to_pt # mm

def plot_artistic_eeg_signals(eeg_data, noise_signal=None, sampling_rate=2048, noise_sample_rate=480, output_path='artistic_eeg.svg'):
    num_channels = eeg_data.shape[1]
    duration = eeg_data.shape[0] / sampling_rate
    time_vector = np.linspace(0, duration, eeg_data.shape[0])

    # Set up the plot``
    fig, ax = plt.subplots(figsize=(figsize, figsize))
    
    # Create a deep blue to purple gradient background
    gradient_cmap = LinearSegmentedColormap.from_list("deep_blue_purple", ["#001140", "#5a2184"], N=256)

    # Add background gradient
    # ax.imshow([[0, 1], [0, 1]], cmap=gradient_cmap, interpolation='bicubic', aspect='auto', alpha=0.9,
    #           extent=[time_vector[0], time_vector[-1], -num_channels-3, num_channels+3])
    
    # Calculate vertical spacing
    spacing = 2
    shift_values = np.linspace(-spacing * (num_channels // 2), spacing * (num_channels // 2), num_channels)
    
    # Normalize and plot each EEG signal within the same plot
    colors = plt.cm.rainbow(np.linspace(0, 1, num_channels))  # Rainbow colors

    for channel, shift, color in zip(range(num_channels), shift_values, colors):
        normalized_eeg = (eeg_data[:, channel] - np.mean(eeg_data[:, channel])) / np.std(eeg_data[:, channel])
        ax.plot(time_vector, normalized_eeg + shift, color=color, lw=linewidth, alpha=1,  clip_on=False)

    # Customize the plot for a striking appearance
    ax.set_aspect('auto')
    ax.axis('off')  # Remove axes for a clean look
    fig.patch.set_facecolor('white')  # Set background to white

    # Adjust layout and remove extra padding
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.margins(0)
    ax.margins(0)

    # Save the plot as an SVG file
    plt.savefig(output_path, format='svg', bbox_inches='tight', pad_inches=0, transparent=True)
    plt.show()
    plt.close()

def plot_noise_as_circles(inner_noise, outer_noise, inner_radius=-3, outer_radius=3, output_path='noise_circles.svg'):
    # Normalize and resample noise signals to the same length
    num_points = 480*3  # Number of points for the circle
    inner_noise_resampled = resample(inner_noise, num_points)
    outer_noise_resampled = resample(outer_noise, num_points)
    
    inner_noise_normalized = (inner_noise_resampled - np.mean(inner_noise_resampled)) / np.std(inner_noise_resampled)
    outer_noise_normalized = (outer_noise_resampled - np.mean(outer_noise_resampled)) / np.std(outer_noise_resampled)

    # Create polar coordinates
    theta = np.linspace(0, 2 * np.pi, num_points)

    # Calculate x, y positions for inner and outer circles
    inner_x = (inner_radius + inner_noise_normalized) * np.cos(theta)
    inner_y = (inner_radius + inner_noise_normalized) * np.sin(theta)
    outer_x = (outer_radius + outer_noise_normalized) * np.cos(theta)
    outer_y = (outer_radius + outer_noise_normalized) * np.sin(theta)

    # Set up the plot
    fig, ax = plt.subplots(figsize=(figsize, figsize), subplot_kw={'projection': 'polar'})

    # print(inner_radius + inner_noise_normalized, outer_radius + outer_noise_normalized)

    # Plot inner and outer circles
    # ax.plot(theta, inner_radius + inner_noise_normalized, color='black', lw=0.6165685835560771, alpha=1, label='Inner Noise')
    ax.plot(theta, outer_radius + outer_noise_normalized, color='black', lw=linewidth_thicc, alpha=0.3, label='Outer Noise', clip_on=False)

    # Customize the plot for a striking appearance
    # ax.set_aspect('auto')
    ax.grid(False)  # Hide grid
    ax.set_xticks([])  # Hide angular ticks
    ax.set_yticks([])  # Hide radial ticks
    ax.spines['polar'].set_visible(False)  # Hide polar spine

    # Remove the frame
    fig.patch.set_visible(False)
    ax.patch.set_visible(False)
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.margins(0)
    ax.margins(0)
    
    # Save the plot as an SVG file
    plt.savefig(output_path, format='svg', bbox_inches='tight', pad_inches=0, transparent=True)
    plt.show()
    plt.close()

noise_experiments = experiment_data[experiment_data['has_noise']]
skip = 4
for _, experiment in noise_experiments.iterrows():
    if skip > 0:
        skip -= 1
        continue
    eeg_data = experiment['eeg']
    print(experiment['base_delay'])
    print(experiment['target_timings'])
    print(eeg_data.shape)
    noise_stim_left = experiment['noise_stim_left']
    noise_stim_right = experiment['noise_stim_right']
    print(noise_stim_left.shape)
    plot_artistic_eeg_signals(eeg_data, output_path='../presentations/eeg_art.svg')
    plot_noise_as_circles(noise_stim_left, noise_stim_right, output_path='../presentations/noise_circles.svg')
    break
0.119
0.768
(6758, 37)
(4800,)

Functions to calculate cross correlation of noise over a trial

In [8]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(order, [low, high], btype='band')
    return b, a

def bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = filtfilt(b, a, data)
    return y

def butter_bandpass_sos(lowcut, highcut, fs, order=5):
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    sos = butter(order, [low, high], btype='band', output='sos')
    return sos

def bandpass_filter_sos(data, lowcut, highcut, fs, order=5):
    sos = butter_bandpass_sos(lowcut, highcut, fs, order=order)
    y = sosfiltfilt(sos, data)
    return y

def mix_signals(eeg_data, noise_signal, strength):
    mixed_data = (1-strength) * eeg_data + strength * noise_signal
    return mixed_data

def upsample_or_downsample(data, original_freq, new_freq):
    """ Upsample or downsample data to the new frequency. """
    num_samples = int(len(data) * new_freq / original_freq)
    return resample(data, num_samples)

def normalise_signal(signal):
    return (signal - np.mean(signal)) / np.std(signal)

def compute_best_combined_offset(eeg_data, noise_stim_left, noise_stim_right, channel_labels,
                                 sampling_rate=2048, noise_freq=480,
                                 upsample_noise=True, eeg_start_time=0.0, eeg_end_time=3.2,
                                 noise_start_time=0.0, noise_end_time=2.0,
                                 bandpass_eeg=None, bandpass_noise=None, mix_noise_strength=None, plot_results=False):

    num_samples, num_channels = eeg_data.shape
    # num_channels = len(channel_labels)
    eeg_segment_start = int(eeg_start_time * sampling_rate)
    eeg_segment_end = int(eeg_end_time * sampling_rate)

    # Select EEG segment
    eeg_data_segment = eeg_data[eeg_segment_start:eeg_segment_end, :].copy()

    # Select noise segment
    noise_segment_start = int(noise_start_time * noise_freq)
    noise_segment_end = int(noise_end_time * noise_freq)
    noise_stim_left_segment = noise_stim_left[noise_segment_start:noise_segment_end].copy()
    noise_stim_right_segment = noise_stim_right[noise_segment_start:noise_segment_end].copy()

    # Optionally bandpass filter the signals
    if bandpass_eeg:
        lowcut, highcut = bandpass_eeg
        for channel in range(num_channels):
            eeg_data_segment[:, channel] = bandpass_filter_sos(eeg_data_segment[:, channel], lowcut, highcut, sampling_rate)

    if bandpass_noise:
        lowcut, highcut = bandpass_noise
        noise_stim_left_segment = bandpass_filter_sos(noise_stim_left_segment, lowcut, highcut, noise_freq)
        noise_stim_right_segment = bandpass_filter_sos(noise_stim_right_segment, lowcut, highcut, noise_freq)

    # Upsample or downsample the noise signals
    if upsample_noise:
        noise_stim_left_segment = upsample_or_downsample(noise_stim_left_segment, noise_freq, sampling_rate)
        noise_stim_right_segment = upsample_or_downsample(noise_stim_right_segment, noise_freq, sampling_rate)
    else:
        eeg_data_segment = upsample_or_downsample(eeg_data_segment, sampling_rate, noise_freq)
        sampling_rate = noise_freq

    # Normalise the signals before mixing and cross-correlation
    noise_stim_left_segment = normalise_signal(noise_stim_left_segment)
    noise_stim_right_segment = normalise_signal(noise_stim_right_segment)
    eeg_data_segment = normalise_signal(eeg_data_segment)

    # Optionally merge the noise signals
    mixed_noise_signal = (noise_stim_left_segment + noise_stim_right_segment) / 2

    # Optionally mix the noise signal into the EEG data
    if mix_noise_strength is not None:
        offset = 1 * sampling_rate
        for channel in range(num_channels):
            eeg_data_segment[offset:offset+len(mixed_noise_signal), channel] = mix_signals(eeg_data_segment[offset:offset+len(mixed_noise_signal), channel], mixed_noise_signal, mix_noise_strength)

    # Store the cross-correlation results and lags
    cross_corr_results = []
    lags = []
    max_cors = []

    for channel, channel_label in zip(range(num_channels), channel_labels):
        # if channel_label not in ['Pz', 'POz', 'Iz', 'O1', 'PO3']: continue # Only plot the interesting occipital channels

        eeg_channel_data = eeg_data_segment[:, channel]

        # Compute cross-correlation with the mixed noise signal
        corr = correlate(eeg_channel_data, mixed_noise_signal, mode='full')[len(mixed_noise_signal) - 1:
                                                                            len(mixed_noise_signal) - 1 + len(eeg_channel_data)]
        # Above aligns corr, so that the measured lag corresponds directly to where the signal starts in the EEG data

        # Find the peak in the cross-correlation
        peak_lag = np.argmax(corr)
        # Store results
        cross_corr_results.append(corr)
        lags.append(peak_lag)
        max_cors.append(np.max(corr))

    # lags = [2048]
    # print('Max Cross-Correlation:', max_cors)
    # print('Lags:', sorted(lags))
    # print('Sorted:')
    # for tup in sorted(zip(channel_labels, lags, max_cors), key=lambda x: x[1]):
    #     print(tup)

    if not plot_results:
        return lags, cross_corr_results

    # Plotting the results
    time_vector = np.linspace(eeg_start_time, eeg_end_time, eeg_data_segment.shape[0])
    for channel in range(num_channels):
        plt.figure(figsize=(40, 10))  # Adjusting the figure size for better visualization

        normalized_eeg = normalise_signal(eeg_data_segment[:, channel])
        normalized_noise = normalise_signal(mixed_noise_signal)
        normalized_corr = normalise_signal(cross_corr_results[channel])

        print(len(normalized_eeg), len(normalized_noise), len(normalized_corr))

        peak_time = lags[channel] / sampling_rate
        offset_time_vector = np.linspace(0, len(normalized_noise) / sampling_rate, len(normalized_noise)) + peak_time

        # Plot cross-correlation in the upper subplot
        ax1 = plt.subplot(2, 1, 1)
        plt.plot(time_vector, normalized_corr, color='g', label='Cross-Correlation', alpha=0.7, linestyle='-')
        plt.axvline(x=peak_time, color='r', linestyle='--', label='Detected Peak')
        plt.legend()
        plt.title(f'Cross-Correlation for Channel {channel + 1} / {channel_labels[channel]}')
        plt.xlabel('Time (seconds)')
        plt.ylabel('Normalized Correlation')
        plt.xlim(eeg_start_time, eeg_end_time)

        # Plot EEG and noise signals in the lower subplot
        ax2 = plt.subplot(2, 1, 2)
        plt.plot(time_vector, normalized_eeg, label=f'EEG Channel {channel + 1}', alpha=0.7)
        plt.plot(offset_time_vector, normalized_noise, label='Offset Noise Signal', alpha=0.7)
        plt.axvline(x=peak_time, color='r', linestyle='--', label='Detected Peak')
        plt.legend()
        plt.title(f'EEG, Offset Noise, and Cross-Correlation for Channel {channel + 1} / {channel_labels[channel]}')
        plt.xlabel('Time (seconds)')
        plt.ylabel('Normalized Amplitude')
        plt.xlim(eeg_start_time, eeg_end_time)

        plt.show()

    return lags, cross_corr_results

# Example usage with the provided loop

noise_experiments = experiment_data[experiment_data['has_noise']]
skip = 4
for _, experiment in noise_experiments.iterrows():
    if skip > 0:
        skip -= 1
        continue
    eeg_data = experiment['eeg']
    print(experiment['base_delay'])
    print(experiment['target_timings'])
    print(eeg_data.shape)
    noise_stim_left = experiment['noise_stim_left']
    noise_stim_right = experiment['noise_stim_right']
    print(noise_stim_left.shape)

    lags, _ = compute_best_combined_offset(eeg_data, noise_stim_left, noise_stim_right, channel_labels,
                                        upsample_noise=True,
                                        eeg_start_time=0.0, eeg_end_time=3.2,
                                        noise_start_time=0.0, noise_end_time=2.0, # noise times present in the experiment
                                        bandpass_eeg=(50, 80), bandpass_noise=(50, 80),
                                        mix_noise_strength=0, # this can be used to mix in noise
                                        plot_results=False
    )
    print('Lags:', lags)
    break
0.119
0.768
(6758, 37)
(4800,)
Lags: [186, 1866, 2470, 1217, 1221, 423, 332, 151, 224, 1223, 1185, 1039, 2697, 304, 121, 2484, 2283, 2835, 1039, 2698, 406, 406, 919, 1669, 1352, 2256, 2256, 2255, 2505, 1077, 2136, 1820, 1850, 1848, 201, 2696, 2882]

Visualise cross correlation over all trials

In [6]:
eeg_data.shape
(6758, 37)
In [9]:
num_channels = eeg_data.shape[1]
noise_experiments = experiment_data[experiment_data['has_noise']]
print(noise_experiments.shape)
cross_corrs_all_trials = [[] for _ in range(num_channels)]

def process_experiment(experiment, channel_labels):
    eeg_data = experiment['eeg']
    noise_stim_left = experiment['noise_stim_left']
    noise_stim_right = experiment['noise_stim_right']

    lags, correlations = compute_best_combined_offset(eeg_data, noise_stim_left, noise_stim_right, channel_labels,
                                                      upsample_noise=True,
                                                      eeg_start_time=0.0, eeg_end_time=3.2,
                                                      noise_start_time=0.0, noise_end_time=2.0, # noise times present in the experiment
                                                    #   noise_start_time=4.0, noise_end_time=6.0, # noise times not present in the experiment
                                                      bandpass_eeg=(50, 80), bandpass_noise=(50, 80),
                                                    #   mix_noise_strength=0.2,
                                                      plot_results=False)
    # correlations = [interpolate_local_maxima(corr, 2048)[0] for corr in correlations]
    return correlations

def parallel_process_experiments(experiments, channel_labels, num_channels):
    cross_corrs_all_trials = [[] for _ in range(num_channels)]

    with ProcessPoolExecutor() as executor:
        futures = {executor.submit(process_experiment, experiment, channel_labels): i for i, experiment in experiments.iterrows()}

        for future in as_completed(futures):
            correlations = future.result()
            for channel in range(num_channels):
                cross_corrs_all_trials[channel].append(correlations[channel])

    return cross_corrs_all_trials

cross_corrs_all_trials = parallel_process_experiments(noise_experiments, channel_labels, num_channels)

# Plot heatmaps for each channel
sampling_rate = 2048
for channel, channel_label in enumerate(channel_labels):
    if channel_label not in ['Pz', 'POz', 'Iz', 'O1', 'PO3']: continue # Only plot the interesting occipital channels

    plt.figure(figsize=(40, 10))
    # Plot heatmap of cross-correlations for all trials
    # plt.subplot(2, 1, 1)
    # sns.heatmap(np.array(cross_corrs_all_trials[channel]), cmap='viridis', cbar=False)
    # plt.title(f'Heatmap of Cross-Correlations for Channel {channel + 1} / {channel_labels[channel]}')
    # plt.xlabel('Lag')
    # plt.ylabel('Trial')

    # plt.subplot(2, 1, 2)
    # Calculate and plot cumulative cross-correlation
    cumulative_corr = np.sum(np.array(cross_corrs_all_trials[channel]), axis=0)
    time_vector = np.linspace(0, len(cumulative_corr) / sampling_rate, len(cumulative_corr))

    plt.plot(time_vector, cumulative_corr, label='Cumulative Cross-Correlation')
    plt.title(f'Cumulative Cross-Correlation for Channel {channel + 1} / {channel_labels[channel]}')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Cumulative Correlation')
    plt.xlim(0, time_vector[-1])
    plt.legend()

    plt.tight_layout()
    plt.show()
(219, 12)

In [8]:
# Function to plot EEG channel before and after filtering
def plot_signal_filtering(signal, sampling_rate, lowcut, highcut, order=5, signal_name='EEG Signal'):
    # Apply bandpass filter
    filtered_signal = bandpass_filter(signal, lowcut, highcut, sampling_rate, order=order)

    # Time vector for plotting
    time_vector = np.linspace(0, len(signal) / sampling_rate, len(signal))

    # Plot EEG signal before and after filtering
    plt.figure(figsize=(40, 10))

    plt.plot(time_vector, signal, label=f'Original {signal_name}')
    plt.plot(time_vector, filtered_signal, label=f'Filtered {signal_name}', color='orange')
    plt.title(f'{signal_name} Beforeand After Filtering')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Amplitude')
    plt.xlim(0, len(signal) / sampling_rate)
    plt.legend()

    plt.show()

    # Frequency analysis before and after filtering
    freqs_before, psd_before = welch(signal, fs=sampling_rate)
    freqs_after, psd_after = welch(filtered_signal, fs=sampling_rate)

    plt.figure(figsize=(40, 10))

    plt.semilogy(freqs_before, psd_before, label=f'Original {signal_name}')
    plt.semilogy(freqs_after, psd_after, label=f'Filtered {signal_name}', color='orange')
    plt.title('Frequency Analysis Before and After Filtering')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power Spectral Density (dB/Hz)')
    plt.xlim(0, 200)
    plt.legend()

    plt.show()

# Example usage with the provided loop
skip = 2
for _, experiment in experiment_data.iterrows():
    if not experiment['has_noise']: continue
    if skip > 0:
        skip -= 1
        continue
    channel = 1
    eeg_data = experiment['eeg'][:, channel]
    noise_stim_left = experiment['noise_stim_left']
    print(eeg_data.shape)
    plot_signal_filtering(eeg_data, sampling_rate=2048, lowcut=50, highcut=80)
    plot_signal_filtering(noise_stim_left, sampling_rate=480, lowcut=50, highcut=80, signal_name='Noise Stimulus')
    break
(6758,)

Topography plot of SNR’s

In [14]:
import mne

def calculate_snr(cross_corrs_all_trials, sampling_rate, expected_peak_time=None):
    snrs = []
    for channel_corrs in cross_corrs_all_trials:
        cumulative_corr = np.sum(channel_corrs, axis=0)
        # Time vector for the cross-correlation
        # time_vector = np.linspace(0, len(cumulative_corr) / sampling_rate, len(cumulative_corr))

        if expected_peak_time is not None:
            # Find the peak closest to the expected peak time
            peak_signal = np.max(cumulative_corr[int((expected_peak_time - 0.0)*sampling_rate):
                                                 int((expected_peak_time + 0.1)*sampling_rate)])
        else:
            # Find the highest peak in the cross-correlation
            print(np.max(cumulative_corr))
            peak_signal = np.max(cumulative_corr)

        # Calculate the noise as the standard deviation of the cross-correlation
        noise = np.std(cumulative_corr)  # Assuming this represents the noise
        snr = 10 * np.log10(peak_signal**2 / noise**2)
        snrs.append(snr)
    return snrs

# Plot topography
def plot_topography(snr_values, channel_labels, sampling_rate):
    info = mne.create_info(ch_names=channel_labels, sfreq=sampling_rate, ch_types='eeg')

    # Use a standard montage for electrode positions
    montage = mne.channels.make_standard_montage('standard_1020')
    info.set_montage(montage)

    # Create Evoked object with SNR values
    evoked_data = np.array(snr_values).reshape(-1, 1)
    evoked = mne.EvokedArray(evoked_data, info)

    # Plot topography
    fig, ax = plt.subplots(figsize=(8, 6))
    im, cn = mne.viz.plot_topomap(evoked.data[:, 0], evoked.info, axes=ax, show=False)
    fig.colorbar(im, ax=ax, format="%0.1f dB")
    ax.set_title('Topography of SNR (dB)', fontsize=14)
    plt.show()

# Example usage
expected_peak_time = 1.077 # Adjust based on where you expect the peak
snr_values = calculate_snr(cross_corrs_all_trials, sampling_rate, expected_peak_time)
plot_topography(snr_values, channel_labels, sampling_rate)

Predicting cued side

In [15]:
from sklearn.metrics import accuracy_score
from scipy.signal import coherence, resample, sosfilt, butter

def determine_dominant_side_coherence(eeg_data, noise_stim_left, noise_stim_right, channel_labels, lags,
                                      sampling_rate=2048, noise_sample_rate=480,
                                      upsample_noise=True, eeg_start_time=0.0, eeg_end_time=3.2,
                                      noise_start_time=0.0, noise_end_time=2.0,
                                      bandpass_eeg=None, bandpass_noise=None):

    num_samples, num_channels = eeg_data.shape
    eeg_segment_start = int(eeg_start_time * sampling_rate)
    eeg_segment_end = int(eeg_end_time * sampling_rate)

    # Select EEG segment
    eeg_data_segment = eeg_data[eeg_segment_start:eeg_segment_end, :].copy()

    # Select noise segment
    noise_segment_start = int(noise_start_time * noise_sample_rate)
    noise_segment_end = int(noise_end_time * noise_sample_rate)
    noise_stim_left_segment = noise_stim_left[noise_segment_start:noise_segment_end].copy()
    noise_stim_right_segment = noise_stim_right[noise_segment_start:noise_segment_end].copy()

    # Optionally bandpass filter the signals
    if bandpass_eeg:
        lowcut, highcut = bandpass_eeg
        for channel in range(num_channels):
            eeg_data_segment[:, channel] = bandpass_filter_sos(eeg_data_segment[:, channel], lowcut, highcut, sampling_rate)

    if bandpass_noise:
        lowcut, highcut = bandpass_noise
        noise_stim_left_segment = bandpass_filter_sos(noise_stim_left_segment, lowcut, highcut, noise_sample_rate)
        noise_stim_right_segment = bandpass_filter_sos(noise_stim_right_segment, lowcut, highcut, noise_sample_rate)

    # Upsample or downsample the noise signals
    if upsample_noise:
        noise_stim_left_segment = upsample_or_downsample(noise_stim_left_segment, noise_sample_rate, sampling_rate)
        noise_stim_right_segment = upsample_or_downsample(noise_stim_right_segment, noise_sample_rate, sampling_rate)
    else:
        eeg_data_segment = upsample_or_downsample(eeg_data_segment, sampling_rate, noise_sample_rate)
        sampling_rate = noise_sample_rate

    # Normalise the signals before calculating coherence
    noise_stim_left_segment = normalise_signal(noise_stim_left_segment)
    noise_stim_right_segment = normalise_signal(noise_stim_right_segment)
    eeg_data_segment = normalise_signal(eeg_data_segment)

    predictions = []

    for channel, channel_label in enumerate(channel_labels):
        lag = lags[channel_label]

        # Select EEG segment starting from the lag
        if upsample_noise:
            # Find the coherent part from lag in EEG to the end
            eeg_aligned = eeg_data_segment[:, channel][lag+sampling_rate:lag+len(noise_stim_left_segment)]
        else:
            # Find the coherent part from lag in EEG to the end
            eeg_aligned = eeg_data_segment[:, channel][lag+sampling_rate:lag+len(noise_stim_left_segment)]
            eeg_aligned = upsample_or_downsample(eeg_aligned, noise_sample_rate, sampling_rate)

        # Compute coherence for left and right noise signals
        f, coh_left = coherence(eeg_aligned[:sampling_rate], noise_stim_left_segment[:sampling_rate], fs=sampling_rate, nperseg=256)
        f, coh_right = coherence(eeg_aligned[:sampling_rate], noise_stim_right_segment[:sampling_rate], fs=sampling_rate, nperseg=256)

        # print(coh_left.shape, coh_right.shape)

        # Calculate mean coherence value
        mean_coh_left = np.max(coh_left)
        mean_coh_right = np.max(coh_right)

        # Predict side: False for left, True for right
        predicted_side = mean_coh_right > mean_coh_left
        predictions.append(predicted_side)

    return predictions

def determine_dominant_side(eeg_data, noise_stim_left, noise_stim_right, channel_labels, lags,
                            sampling_rate=2048, noise_sample_rate=480,
                            upsample_noise=True, eeg_start_time=0.0, eeg_end_time=3.2,
                            noise_start_time=0.0, noise_end_time=2.0,
                            bandpass_eeg=None, bandpass_noise=None,):

    num_samples, num_channels = eeg_data.shape
    # num_channels = len(channel_labels)
    eeg_segment_start = int(eeg_start_time * sampling_rate)
    eeg_segment_end = int(eeg_end_time * sampling_rate)

    # Select EEG segment
    eeg_data_segment = eeg_data[eeg_segment_start:eeg_segment_end, :].copy()

    # Select noise segment
    noise_segment_start = int(noise_start_time * noise_sample_rate)
    noise_segment_end = int(noise_end_time * noise_sample_rate)
    noise_stim_left_segment = noise_stim_left[noise_segment_start:noise_segment_end].copy()
    noise_stim_right_segment = noise_stim_right[noise_segment_start:noise_segment_end].copy()

    # Optionally bandpass filter the signals
    if bandpass_eeg:
        lowcut, highcut = bandpass_eeg
        for channel in range(num_channels):
            eeg_data_segment[:, channel] = bandpass_filter_sos(eeg_data_segment[:, channel], lowcut, highcut, sampling_rate)

    if bandpass_noise:
        lowcut, highcut = bandpass_noise
        noise_stim_left_segment = bandpass_filter_sos(noise_stim_left_segment, lowcut, highcut, noise_sample_rate)
        noise_stim_right_segment = bandpass_filter_sos(noise_stim_right_segment, lowcut, highcut, noise_sample_rate)

    # Upsample or downsample the noise signals
    if upsample_noise:
        noise_stim_left_segment = upsample_or_downsample(noise_stim_left_segment, noise_sample_rate, sampling_rate)
        noise_stim_right_segment = upsample_or_downsample(noise_stim_right_segment, noise_sample_rate, sampling_rate)
    else:
        eeg_data_segment = upsample_or_downsample(eeg_data_segment, sampling_rate, noise_sample_rate)
        sampling_rate = noise_sample_rate

    # Normalise the signals before mixing and cross-correlation
    noise_stim_left_segment = normalise_signal(noise_stim_left_segment)
    noise_stim_right_segment = normalise_signal(noise_stim_right_segment)
    eeg_data_segment = normalise_signal(eeg_data_segment)

    predictions = []

    for channel, channel_label in enumerate(channel_labels):
        lag = lags[channel_label]

        # Compute correlation for left and right noise signals starting from the lag
        corr_left = correlate(eeg_data_segment[:, channel], noise_stim_left_segment, mode='full')
        corr_right = correlate(eeg_data_segment[:, channel], noise_stim_right_segment, mode='full')
        # print(corr_left.shape, corr_right.shape)

        # Use lag to index into the correlation
        corr_left_value = corr_left[lag + (len(noise_stim_left_segment) - 1)]
        corr_right_value = corr_right[lag + (len(noise_stim_right_segment) - 1)]
        print(corr_left_value, corr_right_value)
        # Predict side: False for left, True for right
        predicted_side = corr_right_value > corr_left_value
        predictions.append(predicted_side)

    return predictions

def compute_channel_accuracy(experiment_data, cross_corrs_all_trials, sampling_rate):
    lags = {}
    for channel_corrs, channel_label in zip(cross_corrs_all_trials, channel_labels):
        cumulative_corr = np.sum(channel_corrs, axis=0)
        lag = np.argmax(cumulative_corr)
        lags[channel_label] = lag
    print(sorted(lags.items(), key=lambda x: x[1]))

    accuracies = {label: [] for label in channel_labels}
    overall_predictions = []

    for _, experiment in experiment_data.iterrows():
        eeg_data = experiment['eeg']
        noise_stim_left = experiment['noise_stim_left']
        noise_stim_right = experiment['noise_stim_right']
        actual_side = experiment['attended_side']

        predictions = determine_dominant_side_coherence(eeg_data, noise_stim_left, noise_stim_right, channel_labels, lags,
                                              upsample_noise=True, bandpass_eeg=(50, 80), bandpass_noise=(50, 80),
                                              eeg_start_time=0.0, eeg_end_time=3.2,
                                              noise_start_time=0.0, noise_end_time=2.0)

        for channel, predicted_side in zip(channel_labels, predictions):
            if channel == 'Iz' or channel == 'POz':
                print(f'Channel {channel}: Predicted Side = {predicted_side}, Actual Side = {actual_side}')
            accuracies[channel].append(predicted_side == actual_side)

        overall_predictions.append(predictions)

    # Compute accuracy per channel
    # 'Iz', 2206), ('POz'
    # print(channel['Iz'])
    channel_accuracies = {channel: np.mean(acc) for channel, acc in accuracies.items()}

    return channel_accuracies, overall_predictions

def majority_voting(overall_predictions, experiment_data, selected_channels):
    majority_predictions = []

    for trial_predictions, experiment in zip(overall_predictions, experiment_data.iterrows()):
        selected_predictions = [trial_predictions[channel_labels.index(channel)] for channel in selected_channels]
        predicted_side = np.mean(selected_predictions) > 0.5
        actual_side = experiment[1]['attended_side']
        majority_predictions.append(predicted_side == actual_side)

    overall_accuracy = np.mean(majority_predictions)
    return overall_accuracy

# Example usage
noise_experiments = experiment_data[experiment_data['has_noise']]
channel_accuracies, overall_predictions = compute_channel_accuracy(noise_experiments, cross_corrs_all_trials, sampling_rate)

# Print accuracy for each channel
for channel, accuracy in channel_accuracies.items():
    print(f'Channel {channel}: Accuracy = {accuracy:.2f}')

# Majority voting across selected channels
selected_channels = channel_labels  # Example: use all channels
selected_channels = ['P5', 'PO3', 'O1', 'Iz', 'POz',]
overall_accuracy = majority_voting(overall_predictions, experiment_data, selected_channels)
print(f'Overall accuracy with majority voting: {overall_accuracy:.2f}')
[('FC5', 2), ('AF7', 3), ('PO8', 79), ('F5', 473), ('P6', 720), ('FT7', 726), ('P7', 787), ('T7', 907), ('Fp1', 934), ('PO7', 1002), ('C3', 1213), ('C5', 1245), ('CP5', 1247), ('F1', 1552), ('P8', 1562), ('FC1', 1691), ('PO4', 1717), ('P4', 1763), ('CP3', 1765), ('O2', 1779), ('F3', 1824), ('P1', 2173), ('Pz', 2174), ('P3', 2175), ('P5', 2206), ('PO3', 2206), ('O1', 2206), ('Iz', 2206), ('POz', 2207), ('AF3', 2371), ('F7', 2444), ('Oz', 2528), ('FC3', 2565), ('C1', 2569), ('TP7', 2875), ('CP1', 3216), ('P2', 3248)]
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 0.0
Channel POz: Predicted Side = True, Actual Side = 0.0
Channel Iz: Predicted Side = False, Actual Side = 0.0
Channel POz: Predicted Side = False, Actual Side = 0.0
Channel Iz: Predicted Side = True, Actual Side = 1.0
Channel POz: Predicted Side = True, Actual Side = 1.0
Channel Iz: Predicted Side = False, Actual Side = 1.0
Channel POz: Predicted Side = False, Actual Side = 1.0
Channel Fp1: Accuracy = 0.51
Channel AF7: Accuracy = 0.49
Channel AF3: Accuracy = 0.55
Channel F1: Accuracy = 0.54
Channel F3: Accuracy = 0.56
Channel F5: Accuracy = 0.49
Channel F7: Accuracy = 0.51
Channel FT7: Accuracy = 0.48
Channel FC5: Accuracy = 0.45
Channel FC3: Accuracy = 0.54
Channel FC1: Accuracy = 0.51
Channel C1: Accuracy = 0.52
Channel C3: Accuracy = 0.47
Channel C5: Accuracy = 0.53
Channel T7: Accuracy = 0.52
Channel TP7: Accuracy = 0.52
Channel CP5: Accuracy = 0.52
Channel CP3: Accuracy = 0.51
Channel CP1: Accuracy = 0.56
Channel P1: Accuracy = 0.48
Channel P3: Accuracy = 0.54
Channel P5: Accuracy = 0.56
Channel P7: Accuracy = 0.58
Channel PO7: Accuracy = 0.50
Channel PO3: Accuracy = 0.53
Channel O1: Accuracy = 0.54
Channel Iz: Accuracy = 0.45
Channel Oz: Accuracy = 0.52
Channel POz: Accuracy = 0.47
Channel Pz: Accuracy = 0.52
Channel P2: Accuracy = 0.55
Channel P4: Accuracy = 0.47
Channel P6: Accuracy = 0.44
Channel P8: Accuracy = 0.54
Channel PO8: Accuracy = 0.55
Channel PO4: Accuracy = 0.43
Channel O2: Accuracy = 0.52
Overall accuracy with majority voting: 0.52

Don’t mind this below :)

In [9]:

def interpolate_local_maxima(signal, sampling_rate):
    # Find peaks (local maxima)
    peaks, _ = find_peaks(signal)

    # Extract peak values and their corresponding times
    peak_values = signal[peaks]
    peak_times = peaks / sampling_rate

    # Create an interpolation function
    interp_func = interp1d(peak_times, peak_values, kind='cubic', fill_value="extrapolate")

    # Interpolate over the entire signal duration
    full_time_vector = np.arange(len(signal)) / sampling_rate
    interpolated_signal = interp_func(full_time_vector)

    return interpolated_signal, peaks

sampling_rate = 2048
noise_experiments = experiment_data[experiment_data['has_noise']]
for _, experiment in noise_experiments.iterrows():
    eeg_data = experiment['eeg']
    noise_stim_left = experiment['noise_stim_left']
    noise_stim_right = experiment['noise_stim_right']

    lags, cross_corrs = compute_best_combined_offset(eeg_data, noise_stim_left, noise_stim_right, channel_labels,
                                                    upsample_noise=True,
                                                    eeg_start_time=0.0, eeg_end_time=3.2,
                                                    noise_start_time=0.1, noise_end_time=1.9,
                                                    bandpass_eeg=(50, 80), bandpass_noise=(50, 80),
                                                    # mix_noise_strength=0.1,
                                                    plot_results=False)
    # Plot the cross-correlation for the first channel
    channel = 26
    cross_corr = cross_corrs[channel]
    interpolated_signal, peaks = interpolate_local_maxima(cross_corr, sampling_rate=sampling_rate)
    
    # Plot the original cross-correlation signal and the interpolated local maxima signal
    plt.figure(figsize=(50, 6))

    # Time vector for cross-correlation signal
    corr_time_vector = np.linspace(0, len(cross_corr) / (sampling_rate), len(cross_corr))

    # Plot original cross-correlation signal
    plt.plot(corr_time_vector, cross_corr, label='Original Cross-Correlation Signal', color='blue')

    # Plot interpolated local maxima signal
    plt.plot(corr_time_vector, interpolated_signal, label='Interpolated Local Maxima Signal', color='orange', linestyle='--')

    # Mark peaks
    plt.plot(corr_time_vector[peaks], cross_corr[peaks], 'x', label='Local Maxima', color='red')

    # Add titles and labels
    plt.title('Cross-Correlation Signal and Interpolated Local Maxima Signal')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Amplitude')
    plt.legend()
    plt.xlim(0, len(cross_corr) / sampling_rate)
    plt.grid(True)

    # Show plot
    plt.show()

    # Compute the FFT of the original cross-correlation signal and the interpolated local maxima signal
    freqs = np.fft.fftfreq(len(cross_corr), 1 / sampling_rate)
    fft_original = np.abs(np.fft.fft(cross_corr))
    fft_interpolated = np.abs(np.fft.fft(interpolated_signal))

    # Plot the frequency spectrum before and after interpolation
    plt.figure(figsize=(14, 6))

    # Plot original signal frequency spectrum
    plt.plot(freqs[:len(freqs)//2], fft_original[:len(freqs)//2], label='Original Signal Frequency Spectrum', color='blue')

    # Plot interpolated signal frequency spectrum
    plt.plot(freqs[:len(freqs)//2], fft_interpolated[:len(freqs)//2], label='Interpolated Signal Frequency Spectrum', color='orange', linestyle='--')

    # Add titles and labels
    plt.title('Frequency Spectrum Before and After Interpolation')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Amplitude')
    plt.legend()
    plt.xlim(0, 100)
    plt.grid(True)
    plt.tight_layout()

    # Show plot
    plt.show()

    break