Signal Processing -- Sleep Stage Classification of Single-Channel EEG Using CNN + LSTM


Contents

Background

Highlights

Environment

Data

Method

Results

Code Availability

References


Background

Sleep is essential to human health. Monitoring a person's sleep stages is therefore of great significance for both healthcare and medicine.

Highlights

  • The architecture uses two CNNs with different filter sizes at the first layers, followed by a bidirectional LSTM. The CNNs can be trained to learn filters that extract time-invariant features from raw single-channel EEG, while the bidirectional LSTM can be trained to encode temporal information, such as sleep-stage transition rules, into the model (a minimal sketch of this dual-branch front end follows this list).
  • A two-step training algorithm trains the model end-to-end via backpropagation while preventing it from suffering from the class-imbalance problem found in large sleep datasets (i.e., learning to classify only the majority sleep stages).
  • Without changing the model architecture or the training algorithm, the model automatically learns features for sleep-stage scoring from the raw single-channel EEG of two datasets that have different properties (e.g., sampling rate) and different scoring standards (AASM and R&K).
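
The dual-branch front end from the first highlight can be sketched as follows. This is my own illustration in the same TensorFlow 1.x style as the code below, with assumed layer sizes (64 filters; a small filter of fs/2 samples with stride fs/16, and a large filter of fs*4 samples with stride fs/2), not the author's exact implementation:

import tensorflow as tf

def dual_cnn_features(x, fs=100):
    # x: (batch, n_samples, 1, 1) -- one raw single-channel EEG epoch at fs Hz
    def branch(inputs, filter_len, stride, scope):
        with tf.compat.v1.variable_scope(scope):
            h = tf.compat.v1.layers.conv2d(
                inputs, filters=64, kernel_size=(filter_len, 1),
                strides=(stride, 1), padding="same", activation=tf.nn.relu)
            h = tf.compat.v1.layers.max_pooling2d(h, (8, 1), (8, 1), padding="same")
            return tf.compat.v1.layers.flatten(h)
    # Small filters capture fine temporal detail; large filters capture
    # frequency content. Both branches see the same raw epoch.
    small = branch(x, fs // 2, fs // 16, "cnn_small")
    large = branch(x, fs * 4, fs // 2, "cnn_large")
    return tf.concat([small, large], axis=1)  # one feature vector per epoch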

Environment

  • Python 3.5.4
  • tensorflow-gpu 1.15.2
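
A quick sanity check (my own snippet, not part of the original post) that the interpreter and the TensorFlow build match the versions above:

import sys
import tensorflow as tf

print(sys.version)                  # expect 3.5.x
print(tf.__version__)               # expect 1.15.2
print(tf.test.is_gpu_available())   # True if the GPU build can see a device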

Data

Sleep-EDF

MASS
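
To get data into the shape the model expects, the sketch below (my own illustration; it assumes the mne package and uses one Sleep-EDF recording name as an example) loads the single EEG Fpz-Cz channel and slices it into non-overlapping 30-s epochs of 3000 samples:

import mne

# Load one Sleep-EDF PSG recording (EEG channels are sampled at 100 Hz)
raw = mne.io.read_raw_edf("SC4001E0-PSG.edf", preload=True)
fs = int(raw.info["sfreq"])
eeg = raw.get_data(picks=["EEG Fpz-Cz"])[0]    # single-channel EEG signal

# Each row becomes one 30-s scoring epoch (input_dims = 30 * fs = 3000)
n_epochs = len(eeg) // (30 * fs)
epochs = eeg[: n_epochs * 30 * fs].reshape(n_epochs, 30 * fs)
print(epochs.shape)                            # (n_epochs, 3000)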

Method

Main model code (the imports below are added for context; fc, batch_norm_new, and DeepFeatureNet are project-local helpers, so the module paths may differ in your copy of the code):

import tensorflow as tf

# Project-local building blocks; adjust these paths to match your copy
from deepsleep.model import DeepFeatureNet
from deepsleep.nn import fc, batch_norm_new


class MyModel(DeepFeatureNet):

    def __init__(
        self, 
        batch_size, 
        input_dims, 
        n_classes, 
        seq_length,
        n_rnn_layers,
        return_last,
        is_train, 
        reuse_params,
        use_dropout_feature, 
        use_dropout_sequence,
        name="deepsleepnet"
    ):
        super(MyModel, self).__init__(
            batch_size=batch_size, 
            input_dims=input_dims, 
            n_classes=n_classes, 
            is_train=is_train, 
            reuse_params=reuse_params, 
            use_dropout=use_dropout_feature, 
            name=name
        )

        self.seq_length = seq_length
        self.n_rnn_layers = n_rnn_layers
        self.return_last = return_last

        self.use_dropout_sequence = use_dropout_sequence

    def _build_placeholder(self):
        # Input
        name = "x_train" if self.is_train else "x_valid"
        self.input_var = tf.compat.v1.placeholder(
            tf.float32, 
            shape=[self.batch_size*self.seq_length, self.input_dims, 1, 1],
            name=name + "_inputs"
        )
        # Target
        self.target_var = tf.compat.v1.placeholder(
            tf.int32, 
            shape=[self.batch_size*self.seq_length, ],
            name=name + "_targets"
        )

    def build_model(self, input_var):
        # Create a network with superclass method
        network = super(MyModel, self).build_model(
            input_var=self.input_var
        )

        # Residual (or shortcut) connection
        output_conns = []

        # Fully-connected to select some part of the output to add with the output from bi-directional LSTM
        name = "l{}_fc".format(self.layer_idx)
        with tf.compat.v1.variable_scope(name) as scope:
            output_tmp = fc(name="fc", input_var=network, n_hiddens=1024, bias=None, wd=0)
            output_tmp = batch_norm_new(name="bn", input_var=output_tmp, is_train=self.is_train)
            output_tmp = tf.nn.relu(output_tmp, name="relu")
        self.activations.append((name, output_tmp))
        self.layer_idx += 1
        output_conns.append(output_tmp)

        ######################################################################

        # Reshape the input from (batch_size * seq_length, input_dim) to
        # (batch_size, seq_length, input_dim)
        name = "l{}_reshape_seq".format(self.layer_idx)
        input_dim = network.get_shape()[-1].value
        seq_input = tf.reshape(network,
                               shape=[-1, self.seq_length, input_dim],
                               name=name)
        assert self.batch_size == seq_input.get_shape()[0].value
        self.activations.append((name, seq_input))
        self.layer_idx += 1

        # Bidirectional LSTM network
        name = "l{}_bi_lstm".format(self.layer_idx)
        hidden_size = 512   # will output 1024 (512 forward, 512 backward)
        with tf.compat.v1.variable_scope(name) as scope:

            def lstm_cell():
                # One LSTM layer, optionally wrapped with output dropout
                cell = tf.compat.v1.nn.rnn_cell.LSTMCell(
                    hidden_size,
                    use_peepholes=True,
                    state_is_tuple=True,
                    reuse=tf.compat.v1.get_variable_scope().reuse
                )
                if self.use_dropout_sequence:
                    keep_prob = 0.5 if self.is_train else 1.0
                    cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
                        cell,
                        output_keep_prob=keep_prob
                    )
                return cell

            fw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
                [lstm_cell() for _ in range(self.n_rnn_layers)], state_is_tuple=True)
            bw_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
                [lstm_cell() for _ in range(self.n_rnn_layers)], state_is_tuple=True)

            # Initial state of RNN
            self.fw_initial_state = fw_cell.zero_state(self.batch_size, tf.float32)
            self.bw_initial_state = bw_cell.zero_state(self.batch_size, tf.float32)

            # Feedforward to MultiRNNCell
            list_rnn_inputs = tf.unstack(seq_input, axis=1)
            outputs, fw_state, bw_state = tf.compat.v1.nn.static_bidirectional_rnn(
                cell_fw=fw_cell,
                cell_bw=bw_cell,
                inputs=list_rnn_inputs,
                initial_state_fw=self.fw_initial_state,
                initial_state_bw=self.bw_initial_state
            )

            if self.return_last:
                network = outputs[-1]
            else:
                network = tf.reshape(tf.concat(axis=1, values=outputs),
                                     [-1, hidden_size * 2], name=name)
            self.activations.append((name, network))
            self.layer_idx += 1

            self.fw_final_state = fw_state
            self.bw_final_state = bw_state

        # Append output
        output_conns.append(network)

        ######################################################################

        # Add
        name = "l{}_add".format(self.layer_idx)
        network = tf.add_n(output_conns, name=name)
        self.activations.append((name, network))
        self.layer_idx += 1

        # Dropout (active only during training; keep_prob=1.0 is an identity op)
        if self.use_dropout_sequence:
            name = "l{}_dropout".format(self.layer_idx)
            keep_prob = 0.5 if self.is_train else 1.0
            network = tf.nn.dropout(network, keep_prob=keep_prob, name=name)
            self.activations.append((name, network))
        self.layer_idx += 1

        return network

    def init_ops(self):
        self._build_placeholder()

        # Get loss and prediction operations
        with tf.compat.v1.variable_scope(self.name) as scope:
            
            # Reuse variables for validation
            if self.reuse_params:
                scope.reuse_variables()

            # Build model
            network = self.build_model(input_var=self.input_var)

            # Softmax linear
            name = "l{}_softmax_linear".format(self.layer_idx)
            network = fc(name=name, input_var=network, n_hiddens=self.n_classes, bias=0.0, wd=0)
            self.activations.append((name, network))
            self.layer_idx += 1

            # Outputs of softmax linear are logits
            self.logits = network

            ######### Compute loss #########

            # Cross-entropy loss for a sequence of logits (per example);
            # the all-ones weights make every epoch in the sequence count equally
            loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
                [self.logits],
                [self.target_var],
                [tf.ones([self.batch_size * self.seq_length])],
                name="sequence_loss_by_example"
            )
            loss = tf.reduce_sum(loss) / self.batch_size

            # Regularization loss (weight-decay terms collected in this scope)
            regular_loss = tf.add_n(
                tf.compat.v1.get_collection("losses", scope=scope.name + "/"),
                name="regular_loss"
            )

            # Total loss
            self.loss_op = tf.add(loss, regular_loss)

            # Predictions
            self.pred_op = tf.argmax(self.logits, 1)
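
To make the interface concrete, here is a hypothetical instantiation (a sketch; the values are illustrative assumptions, not the author's published configuration). A 30-s epoch at 100 Hz gives input_dims=3000, and sequences of epochs are fed to the bidirectional LSTM:

model = MyModel(
    batch_size=10,
    input_dims=3000,            # 30 s x 100 Hz single-channel EEG epoch
    n_classes=5,                # W, N1, N2, N3, REM
    seq_length=25,              # epochs per sequence seen by the bi-LSTM
    n_rnn_layers=2,
    return_last=False,
    is_train=True,
    reuse_params=False,
    use_dropout_feature=True,
    use_dropout_sequence=True,
)
model.init_ops()                # builds placeholders, logits, loss_op and pred_op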

Results

Sleep staging results (figure).

Per-class classification results on the MASS dataset (table).

Code Availability

Send a private message with "1" to obtain the code.

References

K. Wulff et al., “Sleep and circadian rhythm disruption in psychiatric and neurodegenerative disease,” Nature Reviews Neuroscience, vol. 11, no. 8, pp. 589–599, 2010.

C. S. Huang et al., “Knowledge-based identification of sleep stages based on two forehead electroencephalogram channels,” Frontiers in Neuroscience, vol. 8, p. 263, 2014.

N. Srivastava et al., “Dropout: A Simple Way to Prevent Neural Networks from Overfitting,” J. of Machine Learning Research, vol. 15, pp. 1929–1958, 2014.

B. Kemp et al., “Analysis of a sleep-dependent neuronal feedback loop: The slow-wave microcontinuity of the EEG,” IEEE Trans. Biomed. Eng., vol. 47, no. 9, pp. 1185–1194, 2000.

