Hello Mat

 找回密码
 立即注册
查看: 2908|回复: 0

直接保存mnist数据,供tensorflow使用

[复制链接]

1294

主题

1520

帖子

110

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
22633
发表于 2017-8-28 21:29:10 | 显示全部楼层 |阅读模式
将mnist数据直接保存为pickle文件(mnist.pkl),供tensorflow加载使用。要点:
(1)脚本中需要添加DataSet类的定义——pickle.load反序列化时要按类名找到该类,否则会报错;
(2)用 mnist = pickle.load(pickle_data) 直接加载整个数据集对象。
具体代码如下:
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Mon Aug 28 21:16:13 2017

  4. @author: Administrator
  5. """

  6. import time   
  7. import pickle
  8. #from PIL import Image
  9. import numpy
  10. #import os
  11. from six.moves import xrange  # pylint: disable=redefined-builtin
  12. import tensorflow as tf
  13. from tensorflow.python.framework import dtypes

  14. class DataSet(object):

  15.   def __init__(self,
  16.                images,
  17.                labels,
  18.                fake_data=False,
  19.                one_hot=False,
  20.                dtype=dtypes.float32,
  21.                reshape=True):
  22.     """Construct a DataSet.
  23.     one_hot arg is used only if fake_data is true.  `dtype` can be either
  24.     `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
  25.     `[0, 1]`.
  26.     """
  27.     dtype = dtypes.as_dtype(dtype).base_dtype
  28.     if dtype not in (dtypes.uint8, dtypes.float32):
  29.       raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
  30.                       dtype)
  31.     if fake_data:
  32.       self._num_examples = 10000
  33.       self.one_hot = one_hot
  34.     else:
  35.       assert images.shape[0] == labels.shape[0], (
  36.           'images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
  37.       self._num_examples = images.shape[0]

  38.       # Convert shape from [num examples, rows, columns, depth]
  39.       # to [num examples, rows*columns] (assuming depth == 1)
  40.       if reshape:
  41.         assert images.shape[3] == 1
  42.         images = images.reshape(images.shape[0],
  43.                                 images.shape[1] * images.shape[2])
  44.       if dtype == dtypes.float32:
  45.         # Convert from [0, 255] -> [0.0, 1.0].
  46.         images = images.astype(numpy.float32)
  47.         images = numpy.multiply(images, 1.0 / 255.0)
  48.     self._images = images
  49.     self._labels = labels
  50.     self._epochs_completed = 0
  51.     self._index_in_epoch = 0

  52.   @property
  53.   def images(self):
  54.     return self._images

  55.   @property
  56.   def labels(self):
  57.     return self._labels

  58.   @property
  59.   def num_examples(self):
  60.     return self._num_examples

  61.   @property
  62.   def epochs_completed(self):
  63.     return self._epochs_completed

  64.   def next_batch(self, batch_size, fake_data=False, shuffle=True):
  65.     """Return the next `batch_size` examples from this data set."""
  66.     if fake_data:
  67.       fake_image = [1] * 784
  68.       if self.one_hot:
  69.         fake_label = [1] + [0] * 9
  70.       else:
  71.         fake_label = 0
  72.       return [fake_image for _ in xrange(batch_size)], [
  73.           fake_label for _ in xrange(batch_size)
  74.       ]
  75.     start = self._index_in_epoch
  76.     # Shuffle for the first epoch
  77.     if self._epochs_completed == 0 and start == 0 and shuffle:
  78.       perm0 = numpy.arange(self._num_examples)
  79.       numpy.random.shuffle(perm0)
  80.       self._images = self.images[perm0]
  81.       self._labels = self.labels[perm0]
  82.     # Go to the next epoch
  83.     if start + batch_size > self._num_examples:
  84.       # Finished epoch
  85.       self._epochs_completed += 1
  86.       # Get the rest examples in this epoch
  87.       rest_num_examples = self._num_examples - start
  88.       images_rest_part = self._images[start:self._num_examples]
  89.       labels_rest_part = self._labels[start:self._num_examples]
  90.       # Shuffle the data
  91.       if shuffle:
  92.         perm = numpy.arange(self._num_examples)
  93.         numpy.random.shuffle(perm)
  94.         self._images = self.images[perm]
  95.         self._labels = self.labels[perm]
  96.       # Start next epoch
  97.       start = 0
  98.       self._index_in_epoch = batch_size - rest_num_examples
  99.       end = self._index_in_epoch
  100.       images_new_part = self._images[start:end]
  101.       labels_new_part = self._labels[start:end]
  102.       return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
  103.     else:
  104.       self._index_in_epoch += batch_size
  105.       end = self._index_in_epoch
  106.       #print(start, end)
  107.       return self._images[start:end], self._labels[start:end]


  108. start=time.clock()

  109. # 加载数据
  110. pickle_data = open('mnist.pkl','rb');
  111. mnist = pickle.load(pickle_data)
  112. pickle_data.close()

  113. end=time.clock()  
  114. print('Runing time = %s Seconds'%(end-start))

  115. # Create the model
  116. x = tf.placeholder(tf.float32, [None, 784])
  117. W = tf.Variable(tf.zeros([784, 10]))
  118. b = tf.Variable(tf.zeros([10]))
  119. y = tf.matmul(x, W) + b

  120. # Define loss and optimizer
  121. y_ = tf.placeholder(tf.float32, [None, 10])


  122. cross_entropy = tf.reduce_mean(
  123.       tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
  124. train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  125. sess = tf.InteractiveSession()
  126. tf.global_variables_initializer().run()

  127. # Train
  128. for _ in range(1000):
  129.     batch_xs, batch_ys = mnist.train.next_batch(100)
  130.     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

  131. # Test trained model
  132. correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  133. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  134. print(sess.run(accuracy, feed_dict={x: mnist.test.images,
  135.                                       y_: mnist.test.labels}))
复制代码



算法QQ  3283892722
群智能算法链接http://halcom.cn/forum.php?mod=forumdisplay&fid=73
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-4-25 05:23 , Processed in 0.217013 second(s), 21 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表