热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

数据项_Tensorflow_datasets中batch(batch_size)和shuffle(buffer_size)理解

篇首语:本文由编程笔记#小编为大家整理,主要介绍了Tensorflow_datasets中batch(batch_size)和shuffle(buffer_size)理解相关的知识,希望对你有一

篇首语:本文由编程笔记#小编为大家整理,主要介绍了Tensorflow_datasets中batch(batch_size)和shuffle(buffer_size)理解相关的知识,希望对你有一定的参考价值。


相关内容引用:https://zhuanlan.zhihu.com/p/42417456

1.shuffle(buffer_size)

tensorflow中的数据集类Dataset有一个shuffle方法,用来打乱数据集中数据顺序,训练时非常常用。其中shuffle方法有一个参数buffer_size,文档的解释如下:

dataset.shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) 
Randomly shuffles the elements of this dataset.
This dataset fills a buffer with `buffer_size` elements, then randomly
samples elements from this buffer, replacing the selected elements with new
elements. For perfect shuffling, a buffer size greater than or equal to the
full size of the dataset is required.
For instance, if your dataset contains 10,000 elements but `buffer_size` is
set to 1,000, then `shuffle` will initially select a random element from
only the first 1,000 elements in the buffer. Once an element is selected,
its space in the buffer is replaced by the next (i.e. 1,001-st) element,
maintaining the 1,000 element buffer.
`reshuffle_each_iteration` controls whether the shuffle order should be
different for each epoch.

首先,Dataset会取所有数据的前buffer_size数据项,填充 buffer,如下图

然后,从buffer中随机选择一条数据输出。假设随机选中了,item 7,那么bufferitem 7对应的位置就空出来了 。

然后,从Dataset中顺序选择最新的一条数据填充到buffer中。这里顺序选择到的是item 10。

然后在从Buffer中随机选择下一条数据输出。

用一个实际的例子来说明:

import tensorflow as tf
import numpy as np
buffer_size=4
data = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
label = np.array([0, 0, 1, 0, 1, 1, 0, 1, 0, 0])
dataset = tf.data.Dataset.from_tensor_slices((data, label))
dataset = dataset.shuffle(buffer_size)
it = dataset.__iter__()
for i in range(10):
x, y = it.next()
print(x, y)

 输出:

tf.Tensor(0.1, shape=(), dtype=float64) tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0.2, shape=(), dtype=float64) tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0.6, shape=(), dtype=float64) tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(0.5, shape=(), dtype=float64) tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(0.8, shape=(), dtype=float64) tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(0.7, shape=(), dtype=float64) tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0.4, shape=(), dtype=float64) tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0.3, shape=(), dtype=float64) tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(0.9, shape=(), dtype=float64) tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1.0, shape=(), dtype=float64) tf.Tensor(0, shape=(), dtype=int32)

0.1, 0.2, 0.3, 0.40.1(随机选中)0.1, 0.2, 0.3, 0.4
0.5, 0.2, 0.3, 0.40.2(随机选中)0.5, 0.2, 0.3, 0.4
0.5, 0.6, 0.3, 0.40.6(随机选中)0.5, 0.6, 0.3, 0.4
0.5, 0.7, 0.3, 0.40.5(随机选中)0.5, 0.7, 0.3, 0.4
0.8, 0.7, 0.3, 0.40.8(随机选中)0.8, 0.7, 0.3, 0.4
0.9, 0.7, 0.3, 0.40.7(随机选中)0.9, 0.7, 0.3, 0.4
0.9, 1.0, 0.3, 0.40.4(随机选中)0.9, 1.0, 0.3, 0.4
0.9, 1.0, 0.30.3(随机选中)0.9, 1.0, 0.3
0.9, 1.00.9(随机选中)0.9, 1.0
1.01.0(随机选中)1.0

如此,shuffle 后的dataset序列为上述output中的序列。


2.batch(batch_size)

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3, 4, 5, 6,
7,8,9,10,11,12,13,14,15,16]))
#有序的
batch_dataset=dataset.batch(4)
for ele in batch_dataset:
print(ele)

output:

tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor([5 6 7 8], shape=(4,), dtype=int32)
tf.Tensor([ 9 10 11 12], shape=(4,), dtype=int32)
tf.Tensor([13 14 15 16], shape=(4,), dtype=int32)

这里batch就是从dataset中按顺序分成4个批次,仔细看可以知道上面所有输出结果都是有序的,这在机器学习中用来训练模型是浪费资源且没有意义的,所以我们需要将数据打乱,这样每批次训练的时候所用到的数据集是不一样的,这样啊可以提高模型训练效果。

因此需要和shuffle结合起来使用。


3.shuffle(buffer_size)+ batch(batch_size)

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3, 4, 5, 6,
7,8,9,10,11,12,13,14,15,16]))
dataset1=dataset.shuffle(16)
dataset2=dataset1.batch(2)
for i in dataset1:
print(i)
print("separate")
for j in dataset2:
print(j)

output:

tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(16, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)
tf.Tensor(13, shape=(), dtype=int32)
tf.Tensor(12, shape=(), dtype=int32)
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(5, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(10, shape=(), dtype=int32)
tf.Tensor(7, shape=(), dtype=int32)
tf.Tensor(8, shape=(), dtype=int32)
tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(9, shape=(), dtype=int32)
separate
tf.Tensor([8 2], shape=(2,), dtype=int32)
tf.Tensor([4 7], shape=(2,), dtype=int32)
tf.Tensor([ 3 12], shape=(2,), dtype=int32)
tf.Tensor([ 9 16], shape=(2,), dtype=int32)
tf.Tensor([10 5], shape=(2,), dtype=int32)
tf.Tensor([15 14], shape=(2,), dtype=int32)
tf.Tensor([ 1 13], shape=(2,), dtype=int32)
tf.Tensor([ 6 11], shape=(2,), dtype=int32)

在这里buffer_size:该函数的作用就是先构建buffer,大小为buffer_size,然后从dataset中提取数据将它填满。batch操作,从buffer中提取。如果buffer_size小于Dataset的大小,每次提取buffer中的数据,会再次从Dataset中抽取数据将它填满(当然是之前没有抽过的)。所以一般最好的方式是buffer_size=Dataset_size
 

交换shuffle 和 batch的前后会有什么不同呢?

t1 = t.shuffle(int).batch(int)

#这个是先打乱t的顺序,然后batch

t2 = t.batch(int).shuffle(int)

#这个是打乱batch的顺序

dataset3=dataset.shuffle(2)
dataset4=dataset3.batch(16)
for i in dataset3:
print(i)
print("separate")
for j in dataset4:
print(j)

输出:

tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([5 6], shape=(2,), dtype=int32)
tf.Tensor([7 8], shape=(2,), dtype=int32)
tf.Tensor([ 9 10], shape=(2,), dtype=int32)
tf.Tensor([11 12], shape=(2,), dtype=int32)
tf.Tensor([13 14], shape=(2,), dtype=int32)
tf.Tensor([15 16], shape=(2,), dtype=int32)
separate
tf.Tensor([11 12], shape=(2,), dtype=int32)
tf.Tensor([13 14], shape=(2,), dtype=int32)
tf.Tensor([15 16], shape=(2,), dtype=int32)
tf.Tensor([5 6], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([7 8], shape=(2,), dtype=int32)
tf.Tensor([ 9 10], shape=(2,), dtype=int32)


推荐阅读
  • 大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式
    大类|电阻器_使用Requests、Etree、BeautifulSoup、Pandas和Path库进行数据抓取与处理 | 将指定区域内容保存为HTML和Excel格式 ... [详细]
  • 本文介绍了UUID(通用唯一标识符)的概念及其在JavaScript中生成Java兼容UUID的代码实现与优化技巧。UUID是一个128位的唯一标识符,广泛应用于分布式系统中以确保唯一性。文章详细探讨了如何利用JavaScript生成符合Java标准的UUID,并提供了多种优化方法,以提高生成效率和兼容性。 ... [详细]
  • vue引入echarts地图的四种方式
    一、vue中引入echart1、安装echarts:npminstallecharts--save2、在main.js文件中引入echarts实例:  Vue.prototype.$echartsecharts3、在需要用到echart图形的vue文件中引入:   importechartsfrom"echarts";4、如果用到map(地图),还 ... [详细]
  • iOS snow animation
    CTSnowAnimationView.hCTMyCtripCreatedbyalexon1614.Copyright©2016年ctrip.Allrightsreserved.# ... [详细]
  • 本文整理了一份基础的嵌入式Linux工程师笔试题,涵盖填空题、编程题和简答题,旨在帮助考生更好地准备考试。 ... [详细]
  • iOS 不定参数 详解 ... [详细]
  • 本文节选自《NLTK基础教程——用NLTK和Python库构建机器学习应用》一书的第1章第1.2节,作者Nitin Hardeniya。本文将带领读者快速了解Python的基础知识,为后续的机器学习应用打下坚实的基础。 ... [详细]
  • 浅析python实现布隆过滤器及Redis中的缓存穿透原理_python
    本文带你了解了位图的实现,布隆过滤器的原理及Python中的使用,以及布隆过滤器如何应对Redis中的缓存穿透,相信你对布隆过滤 ... [详细]
  • Webpack 初探:Import 和 Require 的使用
    本文介绍了 Webpack 中 Import 和 Require 的基本概念和使用方法,帮助读者更好地理解和应用模块化开发。 ... [详细]
  • 本文介绍如何通过 Python 的 `unittest` 和 `functools` 模块封装一个依赖方法,用于管理测试用例之间的依赖关系。该方法能够确保在某个测试用例失败时,依赖于它的其他测试用例将被跳过。 ... [详细]
  • 本文将详细介绍如何在Mac上安装Jupyter Notebook,并提供一些常见的问题解决方法。通过这些步骤,您将能够顺利地在Mac上运行Jupyter Notebook。 ... [详细]
  • Python错误重试让多少开发者头疼?高效解决方案出炉
    ### 优化后的摘要在处理 Python 开发中的错误重试问题时,许多开发者常常感到困扰。为了应对这一挑战,`tenacity` 库提供了一种高效的解决方案。首先,通过 `pip install tenacity` 安装该库。使用时,可以通过简单的规则配置重试策略。例如,可以设置多个重试条件,使用 `|`(或)和 `&`(与)操作符组合不同的参数,从而实现灵活的错误重试机制。此外,`tenacity` 还支持自定义等待时间、重试次数和异常处理,为开发者提供了强大的工具来提高代码的健壮性和可靠性。 ... [详细]
  • 2018 HDU 多校联合第五场 G题:Glad You Game(线段树优化解法)
    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6356在《Glad You Game》中,Steve 面临一个复杂的区间操作问题。该题可以通过线段树进行高效优化。具体来说,线段树能够快速处理区间更新和查询操作,从而大大提高了算法的效率。本文详细介绍了线段树的构建和维护方法,并给出了具体的代码实现,帮助读者更好地理解和应用这一数据结构。 ... [详细]
  • 每年,意甲、德甲、英超和西甲等各大足球联赛的赛程表都是球迷们关注的焦点。本文通过 Python 编程实现了一种生成赛程表的方法,该方法基于蛇形环算法。具体而言,将所有球队排列成两列的环形结构,左侧球队对阵右侧球队,首支队伍固定不动,其余队伍按顺时针方向循环移动,从而确保每场比赛不重复。此算法不仅高效,而且易于实现,为赛程安排提供了可靠的解决方案。 ... [详细]
  • 机器学习中的标准化缩放、最小-最大缩放及鲁棒缩放技术解析 ... [详细]
author-avatar
品位人生2602905223
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有