Save the MNIST data directly with pickle so TensorFlow can load it back without re-parsing the raw files:
(1) the DataSet class definition must be included in the loading script, because pickle stores only instance data and looks the class up by name at load time;
(2) the datasets are then restored with mnist = pickle.load(pickle_data).
The full code is as follows (a sketch of the save side follows the listing):
# -*- coding: utf-8 -*-
"""
Created on Mon Aug 28 21:16:13 2017
@author: Administrator
"""
import time
import pickle

import numpy
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.framework import dtypes
class DataSet(object):
  def __init__(self,
               images,
               labels,
               fake_data=False,
               one_hot=False,
               dtype=dtypes.float32,
               reshape=True):
    """Construct a DataSet.
    one_hot arg is used only if fake_data is true. `dtype` can be either
    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
    `[0, 1]`.
    """
    dtype = dtypes.as_dtype(dtype).base_dtype
    if dtype not in (dtypes.uint8, dtypes.float32):
      raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                      dtype)
    if fake_data:
      self._num_examples = 10000
      self.one_hot = one_hot
    else:
      assert images.shape[0] == labels.shape[0], (
          'images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
      self._num_examples = images.shape[0]
      # Convert shape from [num examples, rows, columns, depth]
      # to [num examples, rows*columns] (assuming depth == 1).
      if reshape:
        assert images.shape[3] == 1
        images = images.reshape(images.shape[0],
                                images.shape[1] * images.shape[2])
      if dtype == dtypes.float32:
        # Convert from [0, 255] -> [0.0, 1.0].
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0
  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed
  def next_batch(self, batch_size, fake_data=False, shuffle=True):
    """Return the next `batch_size` examples from this data set."""
    if fake_data:
      fake_image = [1] * 784
      if self.one_hot:
        fake_label = [1] + [0] * 9
      else:
        fake_label = 0
      return [fake_image for _ in xrange(batch_size)], [
          fake_label for _ in xrange(batch_size)
      ]
    start = self._index_in_epoch
    # Shuffle for the first epoch.
    if self._epochs_completed == 0 and start == 0 and shuffle:
      perm0 = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm0)
      self._images = self.images[perm0]
      self._labels = self.labels[perm0]
    # Go to the next epoch.
    if start + batch_size > self._num_examples:
      # Finished epoch.
      self._epochs_completed += 1
      # Get the rest of the examples in this epoch.
      rest_num_examples = self._num_examples - start
      images_rest_part = self._images[start:self._num_examples]
      labels_rest_part = self._labels[start:self._num_examples]
      # Shuffle the data.
      if shuffle:
        perm = numpy.arange(self._num_examples)
        numpy.random.shuffle(perm)
        self._images = self.images[perm]
        self._labels = self.labels[perm]
      # Start the next epoch.
      start = 0
      self._index_in_epoch = batch_size - rest_num_examples
      end = self._index_in_epoch
      images_new_part = self._images[start:end]
      labels_new_part = self._labels[start:end]
      return (numpy.concatenate((images_rest_part, images_new_part), axis=0),
              numpy.concatenate((labels_rest_part, labels_new_part), axis=0))
    else:
      self._index_in_epoch += batch_size
      end = self._index_in_epoch
      return self._images[start:end], self._labels[start:end]
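
# Illustration of the epoch rollover above (hypothetical toy set of 6
# examples; the names and shapes here are made up for the example):
#   ds = DataSet(imgs, lbls, reshape=False)  # imgs: (6, 784) uint8
#   ds.next_batch(4)  # -> 4 examples from the shuffled first epoch
#   ds.next_batch(4)  # -> the 2 leftovers plus 2 from a reshuffled new epoch;
#                     #    ds.epochs_completed is now 1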
# Load the pickled MNIST datasets, timing the load.
start = time.time()
with open('mnist.pkl', 'rb') as pickle_data:
  mnist = pickle.load(pickle_data)
end = time.time()
print('Running time = %s Seconds' % (end - start))
# Create the model.
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b

# Define loss and optimizer.
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# Train.
for _ in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# Test the trained model.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels}))
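
For reference, one way mnist.pkl could have been produced is to pickle the Datasets object returned by TensorFlow's own tutorial loader. The sketch below assumes TF 1.x with the tensorflow.examples.tutorials.mnist helper available, and 'MNIST_data' is just a scratch download directory. Note that pickle records each instance's class by module path, so a file written this way would resolve to TensorFlow's bundled DataSet class; the author's file was evidently written from a script defining its own DataSet, which is why the loading code above re-declares it.

# save_mnist.py -- minimal save-side sketch (assumes the TF 1.x tutorial helper).
import pickle
from tensorflow.examples.tutorials.mnist import input_data

# Downloads and parses the raw MNIST files once into 'MNIST_data'.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
with open('mnist.pkl', 'wb') as f:
  pickle.dump(mnist, f)

Loading the pickle afterwards skips the download and gzip parsing entirely, which is what the timing printout above is measuring.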