TensorFlow数据集制作实战

1 Mins read

这是实战篇,往下看先

当前数据


当前我有这些图片数据, 分为是飞机(plane_)和不是飞机(plane_false)的二分类数据集原图.分别存在两个不同的文件夹

  • 数据1 -> 大概是1600+张
  • 数据0 -> 大概是160+张

比例是10:1左右, 而数量1800+上来看也不是很理想, 我这里为了省事是直接水平翻转图片, 得到1800*2(待会见代码). 当然为了避免欠拟合情况当然更好的方法是挖掘更多的资源.

根据文件夹标签化图片

头部一些声明:

import tensorflow as tf
import os
import random
from PIL import Image
plane_path = '/Users/dobby/Documents/data_img/plane'
UNplane_path = '/Users/dobby/Documents/data_img/UNplane'
records_path = '/Users/dobby/Documents/data_img/train.tfrecords'
dataset_list = []

遍历这两个文件夹, 如果是plane文件夹,那里面存储是label=1的图片, 反之是0. 进行标签化
存储方式为(image_path, label)的元组 一起存储在dataset_list数组中

def classic_data(path):
    # 根据目录标签数据
    if path == plane_path:
        label = 1
    else:
        label = 0
    file_list =  os.listdir(path)
    for each in file_list:
        if each[:5] != 'plane':
            continue
        im_full_path = os.path.join(path, each)
        dataset_list.append((im_full_path, label))

转为数据集

def create_record(data_list, should_transpose=False):
    """
     图片转为bytes写入
     字符串也是bytes
     1/0 Int
    """
    counter = 0
     # 新建一个写入session
    writer = tf.python_io.TFRecordWriter(records_path)
    for path,label in data_list:
        counter += 1
        print("{i},{j}n".format(i=path, j=label))
        # 打开图片
        img = Image.open(path)
        # 将图片统一大小
        img = img.resize((300, 300))
        # 转换为bytes
        img_raw = img.tobytes()
        data = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), # 0/1 分类,所以是INT
                    'image':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) # 字符串/图片/语音用bytes
                }
            )
        )
        writer.write(data.SerializeToString())
        if should_transpose:
            counter += 1
            # 将图片左右翻转后生成一张新的图片,label不变,
            rot_img = img.transpose(Image.FLIP_LEFT_RIGHT)
            rot_img_raw = rot_img.tobytes()
            data_2 = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                        'image':tf.train.Feature(bytes_list=tf.train.BytesList(value=[rot_img_raw]))
                    }
                )
            )
            writer.write(data_2.SerializeToString())
    writer.close()
    print("写入数据集-DONE, 共存{}个数据".format(counter))
  • writer = tf.python_io.TFRecordWriter(records_path)启动一个写入TFRecord句柄, 遍历数组取出图片和label, 将图片设置统一大小并转为bytes.
  • tf.train.Example(tf.train.Features)是核心的处理代码, Example成TensorFlow的特定规则数据, 通过使用TFRecordWriter写入到TFRecord中.Example包含一个键值对数据结构(与dict相同), 使用属性features记录, 因此, 初始化时必须传入这个features参数
  • writer.write(data.SerializeToString())把Example序列成字符串写入TFRecord
  • should_transpose=False参数用来配置是否水平翻转图片, 并令数据扩大一倍
  • 当然关于TFRecord的写入具体规则, 如果需要可以参考该链接 Tensorflow: 文件读写

测试是否写入成功

执行函数代码, 需要注意, 我提前将数据存储在python的列表中的, 可以使用random.shuffle进行数据的洗牌

classic_data(plane_path)
classic_data(UNplane_path)
random.shuffle(dataset_list)
create_record(dataset_list, should_transpose=True)

结果:

可以看到数据洗牌, 而且数据*2, 都成功做到了.