共空间模式 (CSP)python 实现

news/2024/5/19 20:13:06 标签: python, EEG

代码参考自:https://github.com/orvindemsy/MI-BCI-CSP, 做了整理与封装,更方便使用

输入数据格式为:x_shape = [trial, channal, timepoint], y_shape = [trial]

python">from mimetypes import init
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io as io
from pandas import DataFrame as df
from scipy.linalg import inv, sqrtm
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import (GridSearchCV, KFold, cross_val_score,
                                     train_test_split)
from sklearn.svm import SVC

def compute_cov(EEG_data):
    '''
    INPUT:
    EEG_data : EEG_data in shape T x N x S
    
    OUTPUT:
    avg_cov : covariance matrix of averaged over all trials
    '''
    cov = []
    for i in range(EEG_data.shape[0]):
        cov.append(EEG_data[i]@EEG_data[i].T/np.trace(EEG_data[i]@EEG_data[i].T))
        
    cov = np.mean(np.array(cov), 0)
    
    return cov

def decompose_cov(avg_cov):
    '''
    This function will decompose average covariance matrix of one class of each subject into 
    eigenvalues denoted by lambda and eigenvector denoted by V
    Both will be in descending order
    
    Parameter:
    avgCov = the averaged covariance of one class
    
    Return:
    λ_dsc and V_dsc, i.e. eigenvalues and eigenvector in descending order
    
    '''
    λ, V = np.linalg.eig(avg_cov)
    λ_dsc = np.sort(λ)[::-1] # Sort eigenvalue descending order, default is ascending order sort
    idx_dsc = np.argsort(λ)[::-1] # Find index in descending order
    V_dsc = V[:, idx_dsc] # Sort eigenvectors descending order
    λ_dsc = np.diag(λ_dsc) # Diagonalize λ_dsc
    
    return λ_dsc, V_dsc

def white_matrix(λ_dsc, V_dsc):
    '''
    '''
    λ_dsc_sqr = sqrtm(inv(λ_dsc))
    P = (λ_dsc_sqr)@(V_dsc.T)
    
    return P

def compute_S(avg_Cov, white):
    '''
    This function will compute S matrix, S = P * C * P.T

    INPUT:
    avg_Cov: averaged covariance of one class, dimension N x N, where N is number of electrodes
    white: the whitening transformation matrix
    
    OUTPUT:
    S
    '''
    S = white@avg_Cov@white.T
    
    return S

def decompose_S(S_one_class, order='descending'):
    '''
    This function will decompose the S matrix of one class to get the eigen vector
    Both eigenvector will be the same but in opposite order
    
    i.e the highest eigenvector in S left will be equal to lowest eigenvector in S right matrix 
    '''
    # Decompose S
    λ, B = np.linalg.eig(S_one_class)
    
    # Sort eigenvalues either descending or ascending
    if order == 'ascending':
        idx = λ.argsort() # Use this index to sort eigenvector smallest -> largest
    elif order == 'descending':
        idx = λ.argsort()[::-1] # Use this index to sort eigenvector largest -> smallest
    else:
        print('Wrong order input')
    
    λ = λ[idx]
    B = B[:, idx]
    
    return B, λ 

def spatial_filter(B, P):
    '''
    Will compute projection matrix using the following equation:
    W = B' @ P
    
    INPUT:
    B: the eigenvector either left or right class, choose one, size N x N, N is number of electrodes
    P: white matrix in size of N x N 
    
    OUTPUT:
    W spatial filter to filter EEG
    '''
    
    return (B.T@P)

def compute_Z(W, E, m):
    '''
    Will compute the Z
    Z = W @ E, 
    
    E is in the shape of N x M, N is number of electrodes, M is sample
    In application, E has nth trial, so there will be n numbers of Z
    
    Z, in each trial will have dimension of m x M, 
    where m is the first and last m rows of W, corresponds to smallest and largest eigenvalues
    '''
    Z = []
    
    W = np.delete(W, np.s_[m:-m:], 0)
    
    for i in range(E.shape[0]):
        Z.append(W @ E[i])
    
    return np.array(Z)

def feat_vector(Z):
    '''
    Will compute the feature vector of Z matrix
    
    INPUT:
    Z : projected EEG shape of T x N x S
    
    OUTPUT:
    feat : feature vector shape of T x m
    
    T = trial
    N = channel
    S = sample
    m = number of filter
    '''
    
    feat = []
    
    for i in range(Z.shape[0]):
        var = np.var(Z[i], ddof=1, axis=1)
        varsum = np.sum(var)
        
        feat.append(np.log10(var/varsum))
        
    return np.array(feat)


class CSP():
    def __init__(self,m = 2):
        super(CSP,self).__init__()
        self.m = m
        self. W = None
        
    def fit(self,x_train,y_train):
        # 分成左右手的数据
        x_train_l = x_train[y_train==0]
        x_train_r = x_train[y_train==1]

        # 算平均协方差阵
        C_l = compute_cov(x_train_l)
        C_r = compute_cov(x_train_r)
        C_c = C_l+C_r

        # 白化
        eigval, eigvec = decompose_cov(C_c)
        P = white_matrix(eigval,eigvec)

        # 计算S_l,S_r
        S_l = compute_S(C_l, P)
        S_r = compute_S(C_r, P)

        # 分解S_l,S_r
        S_l_eigvec, S_l_eigval = decompose_S(S_l, 'descending')
        S_r_eigvec, S_r_eigval = decompose_S(S_r, 'ascending')

        # 计算W
        self.W = spatial_filter(S_l_eigvec,P)
        
    def transform(self,x_train):
        Z_train = compute_Z(self.W,x_train,self.m)
        feat_train = feat_vector(Z_train)
        
        return feat_train
    
    def fit_transform(self,x_train,y_trian):
        self.fit(x_train,y_trian)
        feat_train = self.transform(x_train)
        
        return feat_train

http://www.niftyadmin.cn/n/770233.html

相关文章

Untiy-Resources 加载图片

一开始以为 将图片导入Unity时, 将其 图片转为sprite 以为就可以直接 load为sprite了,可是 一直报null异常 原来是 加载后Debug出来是这个类型 因为 加载的时候 是Texture2D类型,而我硬生生将其 转为 sprite,难怪会报异常了&…

Unity2018基于百度SDK的人脸识别

百度AI 中人脸识别,也可以应用到游戏里面 下面给出 一张 图下面是识别成功后给出的信息。详细的参数 可以到sdk文档中查看。 下面是sdk下载网址https://ai.baidu.com/sdk#bfr 百度api 下载 好sdk 后 // 设置APPID/AK/SK var APP_ID "你的 App ID"; va…

Unity安装VScode

https://code.visualstudio.com/ 下载好 之后 , 打开unity edit - preference external tools open by file extension , 点开, Browse , 然后选 你VSCode .exe 的位置。 之后 双击打开 会有我这个是安装好扩展的界面&#x…

设计模式-策略模式 C#

先放张UMI类图还是 挺好理解的, context上下文的 主要 有个对策略类的 引用,最终 给客服端引用。下面是在unity中的 代码 using System; using System.Collections; using System.Collections.Generic; using UnityEngine;abstract class Strategy {/…

Unity连接本地mysql数据库

虽然unity自带了Playerprefs 游戏存档,也有Get 和Set方法, 不懂的同学 请前去官方圣殿 只能作为游戏存档, 可以看作一个字典,但是要涉及到游戏帐号之类的化 就要用数据库来存储了。 之前学php和mysql,可以用mysql来做…

系统找不到文件C:\ProgramData\Oracle\Java\javapath\java.exe

jdk安装好之后,也配置好了环境变量,打开cmd输入javac可以跳出相关信息,可是输入java却一直提示:系统找不到文件C:\ProgramData\Oracle\Java\javapath\java.exe 在网上核对环境变量都正确啊,最后打开系统变量中的path可…

unity NGUI学习记录

大家好,我很菜,还在学习中。 欢迎大家关注我的blog。 另外zelog.xyz也是我的博客. 嘻嘻 最近一直纠结到底是NGUI好,还是UGUI好,搜了很多文章,其实没必要去纠结那一个好,反正用了之后就知道 了,于…

VRTK 开箱子功能

VRTK中 有个类是VRTK_Chest, 这个类运用了物理的join 链关节和事件委托 Max Angle 是箱子的盖子开的角度,一般90度比较合适, 在开箱子的过程中会触发事件。lid是盖子,handle要加上碰撞器。