Hello Mat

 找回密码
 立即注册
查看: 2965|回复: 0

RNN之NLSTM递归神经网络

[复制链接]

1323

主题

1551

帖子

0

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
22647
发表于 2017-9-14 23:41:18 | 显示全部楼层 |阅读模式
RNN之NLSTM（N层堆叠LSTM）循环神经网络
百度网盘代码分享:http://pan.baidu.com/s/1eRQcCuu
电脑:Win7旗舰版+64Bit+AMD Athlon(tm)X2 DualCore QL-64 2.10GHz   RAM2.75GB
Anaconda3-4.2.0-Windows-x86_64
  1. import time
  2. #from tensorflow.examples.tutorials.mnist import input_data
  3. import tensorflow as tf
  4. import Get_Mnist_Data

  5. start=time.clock()
  6. #mnist = input_data.read_data_sets('/temp/', one_hot=True)
  7. mnist = Get_Mnist_Data.read_data_sets('Get_Mnist_Data', one_hot=True)
  8. end=time.clock()  
  9. print('Runing time = %s Seconds'%(end-start))

  10. def compute_accuracy(v_x, v_y):
  11.     global pred
  12.     #input v_x to nn and get the result with y_pre
  13.     y_pre = sess.run(pred, feed_dict={x:v_x})
  14.     #find how many right
  15.     correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(v_y,1))
  16.     #calculate average
  17.     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  18.     #get input content
  19.     result = sess.run(accuracy,feed_dict={x: v_x, y: v_y})
  20.     return result

  21. def LSTM_cell():
  22.     return tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)

  23. def Drop_lstm_cell():
  24.     return tf.contrib.rnn.DropoutWrapper(LSTM_cell(), output_keep_prob=0.5)

  25. def Mul_lstm_cell():
  26.     return tf.contrib.rnn.MultiRNNCell([Drop_lstm_cell() for _ in range(lstm_layer)], state_is_tuple=True)

  27. def RNN(X,weights,biases):
  28.     # hidden layer for input
  29.     X = tf.reshape(X, [-1, n_inputs])
  30.     X_in = tf.matmul(X, weights['in']) + biases['in']
  31.     X_in = tf.reshape(X_in, [-1,n_steps, n_hidden_units])
  32.    
  33.     # cell
  34.     #lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
  35.     lstm_cell = Mul_lstm_cell()
  36.     _init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
  37.     outputs,states = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=_init_state, time_major=False)
  38.    
  39.     #hidden layer for output as the final results
  40.     #results = tf.matmul(states[2][1], weights['out']) + biases['out']
  41.     # or
  42.     outputs = tf.unstack(tf.transpose(outputs, [1,0,2]))
  43.     results = tf.matmul(outputs[-1], weights['out']) + biases['out']

  44.     return results
  45.    

  46. #load mnist data
  47. #mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

  48. # parameters init
  49. lstm_layer = 3
  50. l_r = 0.001
  51. training_iters = 100
  52. batch_size = 128

  53. n_inputs = 28
  54. n_steps = 28
  55. n_hidden_units = 128
  56. n_classes = 10

  57. #define placeholder for input
  58. x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
  59. y = tf.placeholder(tf.float32, [None, n_classes])

  60. # define w and b
  61. weights = {
  62.     'in': tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),
  63.     'out': tf.Variable(tf.random_normal([n_hidden_units,n_classes]))
  64. }
  65. biases = {
  66.     'in': tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),
  67.     'out': tf.Variable(tf.constant(0.1,shape=[n_classes,]))
  68. }

  69. pred = RNN(x, weights, biases)
  70. cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
  71. train_op = tf.train.AdamOptimizer(l_r).minimize(cost)

  72. correct_pred = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
  73. accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

  74. #init session
  75. sess = tf.Session()
  76. #init all variables
  77. sess.run(tf.global_variables_initializer())
  78. #start training

  79. # x_image,x_label = mnist.test.next_batch(500)
  80. # x_image = x_image.reshape([500, n_steps, n_inputs])

  81. for i in range(training_iters):
  82.     #get batch to learn easily
  83.     batch_x, batch_y = mnist.train.next_batch(batch_size)
  84.     batch_x = batch_x.reshape([batch_size, n_steps, n_inputs])
  85.     sess.run(train_op,feed_dict={x: batch_x, y: batch_y})
  86.     if i % 50 == 0:
  87.         print(sess.run(accuracy,feed_dict={x: batch_x, y: batch_y,}))
  88.       #  print(sess.run(accuracy,feed_dict={x: x_image, y: x_label}))
  89. sess.close()
复制代码







算法QQ  3283892722
群智能算法链接http://halcom.cn/forum.php?mod=forumdisplay&fid=73
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-11-22 23:19 , Processed in 0.241629 second(s), 22 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表