TensorFlow数据集-Dataset-API

2 Mins read

TensorFlow在学习的时候, 它是有一个mnist数据集让我们学习的. 通过batch_size来每轮数据训练的大小, 现在打算将一个我们实际的数据转换为跟mnist数据集一样的效果

内存中自定义数据

Dataset-API是1.3版引入的, 支持从内存中/硬盘中生成数据集.

  • 一维数组
    直接就是给到一维数组, 生成数据集
dataset = tf.data.Dataset.from_tensor_slices(
    (np.array([1.0, 2.0, 3.0, 4.0, 5.0]), np.random.uniform(size=(5, 2)))
)
out_put="""
元组的形式:
(1.0, array([ 0.41459922,  0.75492457]))
(2.0, array([ 0.47954237,  0.93916116]))
(3.0, array([ 0.70576017,  0.58064858]))
(4.0, array([ 0.8239234 ,  0.92814029]))
(5.0, array([ 0.03073594,  0.16718188]))
"""
  • 字典
    字典的格式给到,结果也是字典, 按照key索引(这种用的比较多一些)
dataset = tf.data.Dataset.from_tensor_slices(
    {
        "a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),
        "b": np.random.uniform(size=(5, 2))
    }
)

out_put="""
字典的形式:
{'a': 1.0, 'b': array([ 0.15337225,  0.97730736])}
{'a': 2.0, 'b': array([ 0.89860896,  0.95473649])}
{'a': 3.0, 'b': array([ 0.79198725,  0.84507321])}
{'a': 4.0, 'b': array([ 0.27289686,  0.56223038])}
{'a': 5.0, 'b': array([ 0.19825011,  0.44183586])}
"""

两种输出验证方式

当内存中数据生成后, 我们通过这两种方式来验证下
dataset来自上方代码

  • index提取(缺点是需要提前知道个数)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
        for i in range(5):
            print(sess.run(one_element))
  • 死循环提取, 当超过个数的时候会抛出OutOfRangeError,以此停止(因此它适合全部的场合)
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        print("end!")

数据混淆预处理

当数据从数据集读取的时候,可能存在一些不合理性,我们需要进行混淆. 又或者你的数据不是很足够, 类似图片左右投影一般进行N*2复制来扩大数据集. 诸如此类的预处理是很有必要的.

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
  • dataset.map(func), 可以理解一个很酷的装饰器,如例子
dataset = dataset.map(lambda x: x + 1)
out_put="""
2.0, 3.0, 4.0, 5.0, 6.0
"""
  • dataset = dataset.batch(batch_size), 可以按batch_size格式, 无法放大


dataset = dataset.batch(2)
out_put="""
# 注意上方batch原先是5的, 现在传入2之后最大只有2. 
2,
2,
1
"""
  • dataset.shuffle(buffer_size) , 按照大小进行打乱,特别注意buffer最好大于数据的数量, 详见 stackoverflow
dataset = dataset.shuffle(buffer_size=5)
out_put="""
3.0
1.0
2.0
5.0
4.0
"""
  • dataset.repeat(repeat_count) 数据重复放大repeat_count倍, 搭配shuffle打乱更酷(数据总数*repeat_count,这样子才足以覆盖全部数据), repeat必须带有参数, 不然会无限重复下去
dataset = dataset.repeat(10).shuffle(buffer_size=1000*10)
# 数量太多不贴了