Tensorflow2.0 tf.data.Dataset api数据集常用操作

1.创建

from_tensor_slices

创建一个Dataset,其元素为给定张量的切片,给定的张量沿其第一维被切片。此操作将保留输入张量的结构,删除每个张量的第一维并将其用作数据集维。所有输入张量的第一个尺寸必须相同。

data = np.array([0.1, 0.4, 0.6, 0.2, 0.8, 0.8, 0.4, 0.9, 0.3, 0.2])

label = np.array([0, 0, 1, 0, 1, 1, 0, 1, 0, 0])

dataset = tf.data.Dataset.from_tensor_slices((data, label))

2.遍历

for x,y in dataset:

print(x,y)

也可以用

for i in dataset.__iter__():

print(i)

3.repeat

<code>repeat(
count=None
)
eg:dataset = dataset.repeat(3)
如果参数为空 代表无限期重复/<code>

4.cache

<code>cache(    
filename=''
)
缓存数据集元素。
第一次迭代数据集时,其元素将缓存在指定文件或内存中。随后的迭代将使用缓存的数据。
filename为空表示缓存到内存中
/<code>

5.shuffle

<code>shuffle(    buffer_size, seed=None, reshuffle_each_iteration=None)/<code>

随机重新排列此数据集的元素。

该数据集用buffer_size元素填充缓冲区,然后从该缓冲区中随机采样元素,用新元素替换所选元素。为了实现完美的改组,需要缓冲区大小大于或等于数据集的完整大小

<code>dataset = tf.data.Dataset.range(3) 
dataset = dataset.shuffle(3)
[1,0,2]/<code>

6.batch

<code>batch(    batch_size, drop_remainder=False)/<code>

将此数据集的连续元素合并为批

<code>dataset = tf.data.Dataset.range(8) 
dataset = dataset.batch(3)
for i in dataset.__iter__():
print(i)
-----------------------------------------------------
tf.Tensor([0 1 2], shape=(3,), dtype=int64)
tf.Tensor([3 4 5], shape=(3,), dtype=int64)
tf.Tensor([6 7], shape=(2,), dtype=int64)/<code>


分享到:


相關文章: