Getting started with TensorFlow: a CNN for MNIST
Published: 2019-05-25


[Figure: CNN architecture — three conv + max-pool + dropout blocks (32, 64, 128 filters), a 625-unit fully connected layer, and a 10-way softmax output.]

We use TensorFlow to build the CNN shown above and perform softmax classification on the MNIST dataset.
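Since the architecture figure is not reproduced here, it helps to trace the tensor shapes implied by the code below (batch dimension omitted; with SAME padding, each stride-2 pooling halves the spatial size and rounds up):

# input image          28 x 28 x 1
# conv1 (32 filters)   28 x 28 x 32   -> pool1: 14 x 14 x 32
# conv2 (64 filters)   14 x 14 x 64   -> pool2:  7 x  7 x 64
# conv3 (128 filters)   7 x  7 x 128  -> pool3:  4 x  4 x 128   (ceil(7/2) = 4)
# flatten               4 * 4 * 128 = 2048
# dense (ReLU)          625
# logits                10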

The theory is not repeated here; the complete code is as follows:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# hyperparameters
learning_rate = 0.001
training_epoches = 20
batch_size = 100


class Model:

    def __init__(self, sess, name):
        self.sess = sess
        self.name = name
        self._build_net()

    def _build_net(self):
        # with tf.variable_scope(self.name):
        self.training = tf.placeholder(tf.bool)

        # input placeholders for X & Y
        self.X = tf.placeholder(tf.float32, [None, 784])
        self.Y = tf.placeholder(tf.float32, [None, 10])

        # img 28x28x1 (black/white)
        X_img = tf.reshape(self.X, [-1, 28, 28, 1])

        # convolutional layer 1 & pooling layer 1
        conv1 = tf.layers.conv2d(inputs=X_img, filters=32, kernel_size=[3, 3],
                                 padding="SAME", activation=tf.nn.relu)
        pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2],
                                        padding="SAME", strides=2)
        dropout1 = tf.layers.dropout(inputs=pool1, rate=0.3, training=self.training)

        # convolutional layer 2 & pooling layer 2
        conv2 = tf.layers.conv2d(inputs=dropout1, filters=64, kernel_size=[3, 3],
                                 padding="SAME", activation=tf.nn.relu)
        pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2],
                                        padding="SAME", strides=2)
        dropout2 = tf.layers.dropout(inputs=pool2, rate=0.3, training=self.training)

        # convolutional layer 3 & pooling layer 3
        conv3 = tf.layers.conv2d(inputs=dropout2, filters=128, kernel_size=[3, 3],
                                 padding="SAME", activation=tf.nn.relu)
        pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2],
                                        padding="SAME", strides=2)
        dropout3 = tf.layers.dropout(inputs=pool3, rate=0.3, training=self.training)

        # dense layer with ReLU
        flat = tf.reshape(dropout3, [-1, 128 * 4 * 4])
        dense4 = tf.layers.dense(inputs=flat, units=625, activation=tf.nn.relu)
        dropout4 = tf.layers.dropout(inputs=dense4, rate=0.5, training=self.training)

        # FC layer: 625 inputs -> 10 outputs, no activation function
        self.logits = tf.layers.dense(inputs=dropout4, units=10)

        # define loss & optimizer
        self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.logits, labels=self.Y))
        self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.cost)

        # accuracy
        correct_prediction = tf.equal(tf.argmax(self.logits, 1), tf.argmax(self.Y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    def train(self, x_data, y_data, training=True):
        return self.sess.run([self.cost, self.optimizer],
                             feed_dict={self.X: x_data, self.Y: y_data, self.training: training})

    def predict(self, x_test, training=False):
        return self.sess.run(self.logits,
                             feed_dict={self.X: x_test, self.training: training})

    def get_accuracy(self, x_test, y_test, training=False):
        return self.sess.run(self.accuracy,
                             feed_dict={self.X: x_test, self.Y: y_test, self.training: training})


# train the models
with tf.Session() as sess:
    models = []
    num_models = 2
    for m in range(num_models):
        models.append(Model(sess, "model" + str(m)))

    sess.run(tf.global_variables_initializer())

    print('Learning Start!')
    for epoch in range(training_epoches):
        avg_cost_list = np.zeros(len(models))
        total_batch = int(mnist.train.num_examples / batch_size)
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # train each model
            for m_id, m in enumerate(models):
                c, _ = m.train(batch_xs, batch_ys)
                avg_cost_list[m_id] += c / total_batch
        print('Epoch: ', '%04d' % (epoch + 1), 'cost=', avg_cost_list)
    print('Learning finished!')

    # test & accuracy
    test_size = len(mnist.test.labels)
    predictions = np.zeros([test_size, 10])
    for m_id, m in enumerate(models):
        print(m_id, "Accuracy:", m.get_accuracy(mnist.test.images, mnist.test.labels))
        p = m.predict(mnist.test.images)
        predictions += p

    ensemble_correct_prediction = tf.equal(tf.argmax(predictions, 1),
                                           tf.argmax(mnist.test.labels, 1))
    ensemble_accuracy = tf.reduce_mean(tf.cast(ensemble_correct_prediction, tf.float32))
    print("Ensemble_accuracy:", sess.run(ensemble_accuracy))

Results:

Learning Start!
Epoch:  0001 cost= [0.29211415 0.28355632]
Epoch:  0002 cost= [0.08716567 0.0870499 ]
Epoch:  0003 cost= [0.06902521 0.06623169]
Epoch:  0004 cost= [0.05563359 0.05452387]
Epoch:  0005 cost= [0.04963774 0.04871382]
Epoch:  0006 cost= [0.04462749 0.04449957]
Epoch:  0007 cost= [0.04132144 0.03907955]
Epoch:  0008 cost= [0.03792324 0.03861412]
Epoch:  0009 cost= [0.0354344  0.03323769]
Epoch:  0010 cost= [0.03516847 0.03405525]
Epoch:  0011 cost= [0.03143759 0.03219781]
Epoch:  0012 cost= [0.03051504 0.02993162]
Epoch:  0013 cost= [0.02906878 0.02711077]
Epoch:  0014 cost= [0.02729127 0.02754832]
Epoch:  0015 cost= [0.02729633 0.02632647]
Epoch:  0016 cost= [0.02438517 0.02701174]
Epoch:  0017 cost= [0.02482958 0.0244114 ]
Epoch:  0018 cost= [0.02455271 0.02649499]
Epoch:  0019 cost= [0.02371975 0.02178147]
Epoch:  0020 cost= [0.02260135 0.0213784 ]
Learning finished!
0 Accuracy: 0.995
1 Accuracy: 0.9949
Ensemble_accuracy: 0.9954

The output above is actually preceded by a long deprecation warning, which is omitted here. The warning says that newer versions of TensorFlow have moved the MNIST dataset elsewhere and recommend loading it from another location. This post is only meant as an example; in practice you normally write your own data-loading functions, with different Python code depending on how the dataset is stored.
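If you would rather avoid the deprecated input_data helper altogether, one possible alternative is to load MNIST through tf.keras.datasets and reshape it to match the placeholders above. This is only a minimal sketch (it assumes a TensorFlow version that ships tf.keras; the variable names are mine), and you would still need to write your own mini-batching loop in place of mnist.train.next_batch:

import numpy as np
import tensorflow as tf

# load MNIST from tf.keras instead of tensorflow.examples.tutorials.mnist
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

# flatten the 28x28 images into 784-dim float vectors in [0, 1],
# matching the [None, 784] placeholder used by the model above
train_x = train_x.reshape(-1, 784).astype(np.float32) / 255.0
test_x = test_x.reshape(-1, 784).astype(np.float32) / 255.0

# one-hot encode the integer labels to shape [N, 10]
train_y = np.eye(10, dtype=np.float32)[train_y]
test_y = np.eye(10, dtype=np.float32)[test_y]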

Reposted from: http://osvpi.baihongyu.com/
