热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

干货!!!学习笔记TensorFlowNN

1、NN1.1基础网络的解析其中,测试和验证集能赋值到tf.constant()中,而训练集可以导入tf.placeholder()中,

1、NN

1.1


基础网络的解析


       
其中,测试和验证集能赋值到 tf.constant() 中,而训练集可以导入 tf.placeholder() 中,训练集只有导入占位符我们才能在随机梯度下降中成批量地进行训练。网络定义的权重矩阵和偏置向量后需要执行初始化,每一层需要一个权重矩阵和一个偏置向量。构建损失函数,并计算训练损失。模型会输出一个预测向量,我们可以比较预测标签和真实标签并使用交叉熵函数和 softmax 回归来确定损失值。训练损失衡量预测值和真实值之间差距,并用于更新权重矩阵。优化器,优化器将使用计算的损失值和反向传播算法更新权重和偏置项参数。

1.2 数据的加载过程



       这里有一些定义的函数用来加载数据时候使用:


  1. def randomize(dataset, labels):
  2. permutation = np.random.permutation(labels.shape[0])
  3. shuffled_dataset = dataset[permutation, :, :]
  4. shuffled_labels = labels[permutation]
  5. return shuffled_dataset, shuffled_labels
  6. def one_hot_encode(np_array):
  7. return (np.arange(10) == np_array[:,None]).astype(np.float32)

def reformat_data(dataset, labels, image_width, image_height, image_depth):


  1. np_dataset_ = np.array([np.array(image_data).reshape(image_width, image_height, image_depth) for image_data in dataset])
  2. np_labels_ = one_hot_encode(np.array(labels, dtype=np.float32))
  3. np_dataset, np_labels = randomize(np_dataset_, np_labels_)
  4. return np_dataset, np_labels

def flatten_tf_array(array):


  1. shape = array.get_shape().as_list()
  2. return tf.reshape(array, [shape[0], shape[1] * shape[2] * shape[3]])


  1. def accuracy(predictions, labels):
  2. return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1)) / predictions.shape[0])def randomize(dataset, labels):
  3. permutation = np.random.permutation(labels.shape[0])
  4. shuffled_dataset = dataset[permutation, :, :]
  5. shuffled_labels = labels[permutation]
  6. return shuffled_dataset, shuffled_labels
  7. def one_hot_encode(np_array):
  8. return (np.arange(10) == np_array[:,None]).astype(np.float32)


  1. def reformat_data(dataset, labels, image_width, image_height, image_depth):
  2. np_dataset_ = np.array([np.array(image_data).reshape(image_width, image_height, image_depth) for image_data in dataset])
  3. np_labels_ = one_hot_encode(np.array(labels, dtype=np.float32))
  4. np_dataset, np_labels = randomize(np_dataset_, np_labels_)
  5. return np_dataset, np_labels
  6. def flatten_tf_array(array):
  7. shape = array.get_shape().as_list()
  8. return tf.reshape(array, [shape[0], shape[1] * shape[2] * shape[3]])
  9. def accuracy(predictions, labels):
  10. return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1)) / predictions.shape[0])

实例:
mnist手写数据集的加载方法


2.cifar10数据集的下载:





推荐阅读
author-avatar
柯怡如63391
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有