TensorFlow数据集-Dataset-API

TensorFlow在学习的时候, 它是有一个mnist数据集让我们学习的. 通过batch_size来每轮数据训练的大小, 现在打算将一个我们实际的数据转换为跟mnist数据集一样的效果

内存中自定义数据

Dataset-API是1.3版引入的, 支持从内存中/硬盘中生成数据集.

  • 一维数组
    直接就是给到一维数组, 生成数据集

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19


    1

    2

    3

    4

    5

    6

    7

    8


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19


    dataset = tf.data.Dataset.from\_tensor\_slices(np.array(\[1.0, 2.0, 3.0, 4.0, 5.0\]))

    out\_put=""""

    1.0

    2.0

    3.0

    4.0

    5.0

    """


  • 多维数组, size=(5,2) , 可以理解为(count, dimension)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19


    1

    2

    3

    4

    5

    6

    7

    8


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19


    dataset = tf.data.Dataset.from\_tensor\_slices(np.random.uniform(size=(5, 2)))

    out\_put="""

    \[ 0.01721917 0.4621821 \]

    \[ 0.58484623 0.13534625\]

    \[ 0.83591111 0.01397783\]

    \[ 0.88934806 0.97464257\]

    \[ 0.76707649 0.42036516\]

    """


  • 元组

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25


    1

    2

    3

    4

    5

    6

    7

    8

    9

    10

    11


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25


    dataset = tf.data.Dataset.from\_tensor\_slices(

    (np.array(\[1.0, 2.0, 3.0, 4.0, 5.0\]), np.random.uniform(size=(5, 2)))

    )

    out\_put="""

    元组的形式:

    (1.0, array(\[ 0.41459922, 0.75492457\]))

    (2.0, array(\[ 0.47954237, 0.93916116\]))

    (3.0, array(\[ 0.70576017, 0.58064858\]))

    (4.0, array(\[ 0.8239234 , 0.92814029\]))

    (5.0, array(\[ 0.03073594, 0.16718188\]))

    """


  • 字典
    字典的格式给到,结果也是字典, 按照key索引(这种用的比较多一些)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31


    1

    2

    3

    4

    5

    6

    7

    8

    9

    10

    11

    12

    13

    14


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31


    dataset = tf.data.Dataset.from\_tensor\_slices(

    {

    "a": np.array(\[1.0, 2.0, 3.0, 4.0, 5.0\]),

    "b": np.random.uniform(size=(5, 2))

    }

    )

    out\_put="""

    字典的形式:

    {'a': 1.0, 'b': array(\[ 0.15337225, 0.97730736\])}

    {'a': 2.0, 'b': array(\[ 0.89860896, 0.95473649\])}

    {'a': 3.0, 'b': array(\[ 0.79198725, 0.84507321\])}

    {'a': 4.0, 'b': array(\[ 0.27289686, 0.56223038\])}

    {'a': 5.0, 'b': array(\[ 0.19825011, 0.44183586\])}

    """


两种输出验证方式

当内存中数据生成后, 我们通过这两种方式来验证下
dataset来自上方代码

  • index提取(缺点是需要提前知道个数)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13


    1

    2

    3

    4

    5


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13


    iterator = dataset.make\_one\_shot\_iterator()

    one\_element = iterator.get\_next()

    with tf.Session() as sess:

    for i in range(5):

    print(sess.run(one\_element))


  • 死循环提取, 当超过个数的时候会抛出OutOfRangeError,以此停止(因此它适合全部的场合)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21


    1

    2

    3

    4

    5

    6

    7

    8

    9


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19


    iterator = dataset.make\_one\_shot\_iterator()

    one\_element = iterator.get\_next()

    with tf.Session() as sess:

    try:

    while True:

    print(sess.run(one\_element))

    except tf.errors.OutOfRangeError:

    print("end!")


数据混淆预处理

当数据从数据集读取的时候,可能存在一些不合理性,我们需要进行混淆. 又或者你的数据不是很足够, 类似图片左右投影一般进行N*2复制来扩大数据集. 诸如此类的预处理是很有必要的.

1
2
3
4
5


1


1
2
3
4
5


dataset = tf.data.Dataset.from\_tensor\_slices(np.array(\[1.0, 2.0, 3.0, 4.0, 5.0\]))


  • dataset.map(func), 可以理解一个很酷的装饰器,如例子

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11


    1

    2

    3

    4


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11


    dataset = dataset.map(lambda x: x + 1)

    out\_put="""

    2.0, 3.0, 4.0, 5.0, 6.0

    """


  • dataset = dataset.batch(batch_size), 可以按batch_size格式, 无法放大

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17


    1

    2

    3

    4

    5

    6

    7


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17


    ataset = dataset.batch(2)

    out\_put="""

    \# 注意上方batch原先是5的, 现在传入2之后最大只有2.

    2,

    2,

    1

    """


  • dataset.shuffle(buffer_size) , 按照大小进行打乱,特别注意buffer最好大于数据的数量, 详见stackoverflow

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19


    1

    2

    3

    4

    5

    6

    7

    8


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19


    dataset = dataset.shuffle(buffer\_size=5)

    out\_put="""

    3.0

    1.0

    2.0

    5.0

    4.0

    """


  • dataset.repeat(repeat_count) 数据重复放大repeat_count倍, 搭配shuffle打乱更酷(数据总数*repeat_count,这样子才足以覆盖全部数据), repeat必须带有参数, 不然会无限重复下去

    1
    2
    3
    4
    5
    6
    7


    1

    2


    1
    2
    3
    4
    5
    6
    7


    dataset = dataset.repeat(10).shuffle(buffer\_size=1000\*10)

    \# 数量太多不贴了