LSTM递归神经网络
百度网盘代码分享:http://pan.baidu.com/s/1eRQcCuu
电脑:Win7旗舰版+64Bit+AMD Athlon(tm)X2 DualCore QL-64 2.10GHz RAM2.75GB
Anaconda3-4.2.0-Windows-x86_64- import time
- #from tensorflow.examples.tutorials.mnist import input_data
- import tensorflow as tf
- import Get_Mnist_Data
- start=time.clock()
- #mnist = input_data.read_data_sets('/temp/', one_hot=True)
- mnist = Get_Mnist_Data.read_data_sets('Get_Mnist_Data', one_hot=True)
- end=time.clock()
- print('Runing time = %s Seconds'%(end-start))
- def RNN(X,weights,biases):
- # hidden layer for input
- X = tf.reshape(X, [-1, n_inputs])
- X_in = tf.matmul(X, weights['in']) + biases['in']
- X_in = tf.reshape(X_in, [-1,n_steps, n_hidden_units])
-
- # cell
- lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
- _init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
- outputs,states = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=_init_state, time_major=False)
-
- #hidden layer for output as the final results
- #results = tf.matmul(states[1], weights['out']) + biases['out']
- # or
- outputs = tf.unstack(tf.transpose(outputs, [1,0,2]))
- results = tf.matmul(outputs[-1], weights['out']) + biases['out']
- return results
-
- #load mnist data
- #mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
- # parameters init
- l_r = 0.001
- training_iters = 100
- batch_size = 128
- n_inputs = 28
- n_steps = 28
- n_hidden_units = 128
- n_classes = 10
- #define placeholder for input
- x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
- y = tf.placeholder(tf.float32, [None, n_classes])
- # define w and b
- weights = {
- 'in': tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),
- 'out': tf.Variable(tf.random_normal([n_hidden_units,n_classes]))
- }
- biases = {
- 'in': tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),
- 'out': tf.Variable(tf.constant(0.1,shape=[n_classes,]))
- }
- pred = RNN(x, weights, biases)
- cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
- train_op = tf.train.AdamOptimizer(l_r).minimize(cost)
- correct_pred = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
- accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
- #init session
- sess = tf.Session()
- #init all variables
- sess.run(tf.global_variables_initializer())
- #start training
- #for i in range(training_iters):
- for i in range(training_iters):
- #get batch to learn easily
- batch_x, batch_y = mnist.train.next_batch(batch_size)
- batch_x = batch_x.reshape([batch_size, n_steps, n_inputs])
- sess.run(train_op,feed_dict={x: batch_x, y: batch_y})
- if i % 50 == 0:
- print(sess.run(accuracy,feed_dict={x: batch_x, y: batch_y,}))
- #test_data = mnist.test.images.reshape([-1, n_steps, n_inputs])
- #test_label = mnist.test.labels
- #print("Testing Accuracy: ", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
- sess.close()
复制代码 原理分析如下:
RNN递归神经网络主要用于解决时序问题类的预测研究,例如文本处理、语音识别、视频处理等。对于图像的分类识别,推荐使用CNN卷积神经网络。 LSTM是RNN的增强学习算法; 对于传统网络而言,设数据序列x1和x2,x1和x2通过隐藏层传递函数、输出层传递函数,输出得到y1和y2;
传统递归网络结构
对于递归神经网络而言,x1和x2通过隐藏层传递函数得到s1和s2,s1和s2将会被保存,在下一次迭代计算中,s1和s2与x1和x2一并作为输入,从而达到优化神经网络性能; 一般的传递函数(激励函数): Sigmoid(x) = 1/(1+exp(-x)) Tanh(x) = 2/(1+exp(-2x)) – 1 ReLU(x) = {0 if x<0, x if x≥0} Leaky ReLU(x) = {0.01x if x<0, xif x≥0} 基于传统递归神经网络这个思路,我们以文本上下文语义识别为例进行进一步说明,逐步引入到RNN深度学习网络。 文本上下文语义识别流程图 如图所示,以arrive、Taipei、on三个单词为例,分别设为x1、x2、x3,传统的网络训练输出直接将x1、x2、x3对应输出为y1、y2、y3,这种传统的网络结构,依赖于每一个输入x的特征,没有相应的语意关系,即各个输入量之间没有依赖关系。 一般动词arrive会接一个地名Taipei,地名Taipei后面接一个副词on,RNN深度学习网络将x1训练得到的激励特征a1存储起来,然后将a1和x2作为原始x2的输入,得到x2的激励特征a2,那么a2包含x1和x2的基本信息,最后将a2和x3作为原始x3的输入;也就是y1由x1决定,y2由a1和x2决定,a1由x1得到,y3由a2和x3决定,a2由x1和x2得到;这样一种操作时序,也就更加符合我们的期望了,相信我们此时对RNN深度学习网络有更加深入的理解了。
RNN深度学习网络能够较好处理时序数据问题,但是有一点值得考虑的时,这个时序怎么把握,对于arrive、Taipei、on三个单词为例,arrive对on的时序影响是减弱的,当arrive和on之间再穿插几个单词时,那么arrive对on的时序影响是不可把控的,这也就是RNN深度学习网络的局限性。 RNN深度学习网络结构 对于RNN深度学习网络的局限性,引入LSTM(长短时记忆型递归神经网络),LSTM递归神经网络在整个x序列上引入控制参数C,如图所示。 LSTM网络结构 如图所示的LSTM网络结构,σ为sigmoid函数,网络结构为链式图(如同MATLAB Simulink模型规则)。 如图所示LSTM网络结构,相应的引入各种LSTM变种模型,如Peephole connection、Gate忘记/更新相互影响、GRU等。 (a) Peephole connection (b) Gate忘记/更新相互影响 (c)GRU
参考: https://github.com/wiibrew/DeepLearningBook
|