TensorFlow 2.0 tf.dataset类的使用

这个笔记主要是TensorFlow 2.0的tf.dataset接口的使用。下面的示例会把numpy array的数据写入到TFRecord文件中,以及从TFRecord文件中读取数据到numpy array。

安装

可以参考官网的https://www.tensorflow.org/install教程来安装。

安装完成后,检查安装的TensorFlow的版本:

import tensorflow as tf
print(tf.__version__)

TensorFlow Dataset的使用

在TensorFlow 2.0中,向网络灌输数据的最好方法是使用tf.dataset类,dataset本身就是一个迭代器,所以可以使用for循环的方法来迭代dataset里的数据。

1、使用numpy array来创建一个dataset:

import numpy as np
np.random.seed(0)
data = np.random.randn(256, 8, 8, 3)
dataset = tf.data.Dataset.from_tensor_slices(data)
print(dataset)
...
<TensorSliceDataset shapes: (8, 8, 3), types: tf.float64>

可以通过print()方法输出dataset,可以看到dataset的shap。

通常,第一维度的数据表示训练样本的数量。DataSet可以生产任何大小的batch size,但默认情况下,batch size的值为1, 也即是生成各自独立的训练样本。

2、迭代dataset

使用for循环可以dataset做迭代,如果想获取每个批次的数据,可以使用Python的enumerate,或者使用Dataset自身的方法enumerate(),迭代示例如下:

for i, batch in enumerate(dataset):
if i == 255 or i == 256:
print(i, batch.shape)
...
255 (8, 8, 3)
...
for i, batch in dataset.enumerate():
if i == 255 or i == 256:
print(i, batch.shape)
print(i.numpy(), batch.shape)
...
tf.Tensor(255, shape=(), dtype=int64) (8, 8, 3) 255 (8, 8, 3)
...

可以看到,使用dataset.enumerate()内置方法,返回的的一个值是一个Tensor(张量)。

3、重复迭代dataset

如果需要重复多次对dataset进行迭代,可以使用dataset的内置方法repeat()。示例:

for i, batch in dataset.repeat(2).enumerate():
if i == 255 or i == 256:
print(i.numpy(), batch.shape)
...
255 (8, 8, 3)
256 (8, 8, 3)

4、使用take()获取指定数量大小的样本数

如果不想使用整个数据集,可以使用take()方法来获取指定数量的数据集:

for batch in dataset.take(3):
print(batch.shape)
...
(8, 8, 3)
(8, 8, 3)
(8, 8, 3)

5、设置batch size

默认情况下,dataset是以batch size为1来迭代,可以使用batch()方法设置batch size的大小。

dataset = dataset.batch(16)
for batch in dataset.take(3):
print(batch.shape)
...
(16, 8, 8, 3)
(16, 8, 8, 3)
(16, 8, 8, 3)

设置了batch size为16

6、打乱数据集

shuffle()方法可以用来打乱数据,其中shuffle()方法会接收一个buffer_size的参数,这个参数作为一个每一次打乱数据的缓存区,也即是每次去出buffer_size大小的数据进行打乱。如果想完全打乱整个数据集,buffer_size需要设置为整个数据集的大小。

示例:

dataset = tf.data.Dataset.from_tensor_slices(np.arange(19))
for batch in dataset.batch(5):
print(batch)
...
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int64)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int64)
tf.Tensor([15 16 17 18], shape=(4,), dtype=int64)
...
for batch in dataset.shuffle(5).batch(5):
print(batch)
...
tf.Tensor([2 5 0 4 1], shape=(5,), dtype=int64)
tf.Tensor([ 6 9 3 12 10], shape=(5,), dtype=int64)
tf.Tensor([13 8 15 17 11], shape=(5,), dtype=int64)
tf.Tensor([18 16 14 7], shape=(4,), dtype=int64)

可以看到shuffle()的buffer_size为5,batch size也是5,每次取出5个数据,并进行打乱。打乱后,每个批次的数据就不是原来按顺序的了。

需要注意的是,如果把shuffle()方法和batch()方法调转,会导致的结果是对批次打乱,而不是对数据集里的数据打乱。

for batch in dataset.batch(5).shuffle(5):
print(batch)
...
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int64)
tf.Tensor([15 16 17 18], shape=(4,), dtype=int64)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int64)

7、转换数据

如果我们想对导入的数据做预处理,可以使用map方法。

def tranform(data):
mean = tf.reduce_mean(data)
return data - mean
for batch in dataset.shuffle(5).batch(5).map(tranform):
print(batch)
...
tf.Tensor([ 2 3 -1 0 -2], shape=(5,), dtype=int64)
tf.Tensor([-2 -5 2 3 4], shape=(5,), dtype=int64)
tf.Tensor([-1 1 2 -5 3], shape=(5,), dtype=int64)
tf.Tensor([ 3 -3 7 -4], shape=(4,), dtype=int64)

8、预取指定大小的batch来做训练

通常,读取和处理dataset的数据会很耗时,即耗CPU时间,为了让GPU不出现太多空闲,可以使用prefetch()方法预取一定数据的batch来做训练。

dataset.shuffle(5).batch(5).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

其中,把buffer_size设置为tf.data.experimental.AUTOTUNE,意思是让TensorFlow自己找到一个合适的最优的buffer_size。



版权声明:著作权归作者所有。

相关推荐

ASP.NET Core 2.0 日志配置

ASP.NET Core 2.0的日志系统做了break change的升级。.NET Core 2.0日志配置的改变主要体现在三点:使用新的方法AddLogging和Builder API配置services允许在Program.cs使用WebHostBuilder配置日志 ASP.NET Core 2.0模板里的W

JavaScript获取Object类名的几种方法

有以下几种方法可以用来获取Object的类名:typeofinstanceofobj.constructorfunc.prototype, proto.isPrototypeOffunc.name(ES6)使用示例:function Foo() {} var foo = new Foo(); typeof Foo; 

Django 2.0:路径转换器(Path converter)的用法

Django2.0于2017年12月2日已经正式发布。Django2.0支持Python3.4,3.5以及3.6,移除了对Python2.7的支持。官方强烈建议Python 3.x使用最新的版本。在Django2.0其中一个新特性为:简化Url路由的语法。在代码上主要体现在新增了django.urls.path函数,它带来了更简洁、更可读的路由语法,如:原来的urlurl(r'^arti

Kotlin:类的定义

基本定义Kotlin使用关键词class定义类,如:class User { } 声明类主要包括三部分:类名:必选,类的名称,一般以大写字母开头。类头:可选,类头包括type parameter(如泛型),主构造(primary constructor)等。类体:可选,在Kotlin,类体是可选的,它有大括号{}括起来。类头和类体是可选的,一个最简单的类可