Tensorflow2.0 tf.data.Dataset api數據集常用操作

1.創建

from_tensor_slices

創建一個Dataset,其元素為給定張量的切片,給定的張量沿其第一維被切片。此操作將保留輸入張量的結構,刪除每個張量的第一維並將其用作數據集維。所有輸入張量的第一個尺寸必須相同。

data = np.array([0.1, 0.4, 0.6, 0.2, 0.8, 0.8, 0.4, 0.9, 0.3, 0.2])

label = np.array([0, 0, 1, 0, 1, 1, 0, 1, 0, 0])

dataset = tf.data.Dataset.from_tensor_slices((data, label))

2.遍歷

for x,y in dataset:

print(x,y)

也可以用

for i in dataset.__iter__():

print(i)

3.repeat

<code>repeat(
count=None
)
eg:dataset = dataset.repeat(3)
如果參數為空 代表無限期重複/<code>

4.cache

<code>cache(    
filename=''
)
緩存數據集元素。
第一次迭代數據集時,其元素將緩存在指定文件或內存中。隨後的迭代將使用緩存的數據。
filename為空表示緩存到內存中
/<code>

5.shuffle

<code>shuffle(    buffer_size, seed=None, reshuffle_each_iteration=None)/<code>

隨機重新排列此數據集的元素。

該數據集用buffer_size元素填充緩衝區,然後從該緩衝區中隨機採樣元素,用新元素替換所選元素。為了實現完美的改組,需要緩衝區大小大於或等於數據集的完整大小

<code>dataset = tf.data.Dataset.range(3) 
dataset = dataset.shuffle(3)
[1,0,2]/<code>

6.batch

<code>batch(    batch_size, drop_remainder=False)/<code>

將此數據集的連續元素合併為批

<code>dataset = tf.data.Dataset.range(8) 
dataset = dataset.batch(3)
for i in dataset.__iter__():
print(i)
-----------------------------------------------------
tf.Tensor([0 1 2], shape=(3,), dtype=int64)
tf.Tensor([3 4 5], shape=(3,), dtype=int64)
tf.Tensor([6 7], shape=(2,), dtype=int64)/<code>


分享到:


相關文章: