本篇文章主要介绍如何使用TensorFlow构建自己的图片数据集TFRecord的方法,并使用最新的数据处理Dataset API进行操作。

TFRecord

TFRecord数据文件是一种对任何数据进行存储的二进制文件,能更好的利用内存,在TensorFlow中快速的复制,移动,读取,存储等,只要生成一次TFRecord,之后的数据读取和加工处理的效率都会得到提高。

一般来说,我们使用TensorFlow进行数据读取的方式有以下4种:

  • (1)预先把所有数据加载进内存
  • (2)在每轮训练中使用原生Python代码读取一部分数据,然后使用feed_dict输入到计算图
  • (3)利用Threading和Queues从TFRecord中分批次读取数据
  • (4)使用Dataset API

(1)方案对于数据量不大的场景来说是足够简单而高效的,但是随着数据量的增长,势必会对有限的内存空间带来极大的压力,还有长时间的数据预加载,甚至导致我们十分熟悉的OutOfMemoryError。

(2)方案可以一定程度上缓解了方案(1)的内存压力问题,但是由于在单线程环境下我们的IO操作一般都是同步阻塞的,势必会在一定程度上导致学习时间的增加,尤其是相同的数据需要重复多次读取的情况下。

而方案(3)和方案(4)都利用了我们的TFRecord,由于使用了多线程使得IO操作不再阻塞我们的模型训练,同时为了实现线程间的数据传输引入了Queues。

在本文中,我们主要使用方案(4)进行操作。

建立TFRecord

整体上建立TFRecord文件的流程主要如下;

  • 在TFRecord数据文件中,任何数据都是以bytes列表或float列表或int64列表的形式存储(注意:是列表形式),因此,将每条数据转化为列表格式。
  • 创建的每条数据列表都必须由一个Feature类包装,并且,每个feature都存储在一个key-value键值对中,其中key对应每个feature的名称。这些key将在后面从TFRecord提取数据时使用。
  • 当所需的字典创建完之后,会传递给Features类。
  • 最后,将features对象作为输入传递给example类,然后这个example类对象会被追加到TFRecord中。
  • 对于所有数据,重复上述过程。

接下来,对一个简单数据创建TFRecord。我们创建了两条样例数据,包含了整型、浮点型、字符串型和列表型,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import  tensorflow as tf
# 案例数据
data_arr = [
{
'int_data':108, # 整型
'float_data':2.45, #浮点型
'str_data':'string 100'.encode(), # 字符串型,python3下转化为byte
'float_list_data':[256.78,13.9] # 列表型
},
{
'int_data': 2108,
'float_data': 12.45,
'str_data': 'string 200'.encode(),
'float_list_data': [1.34,256.78, 65.22]
}
]

首先,我们将原始数据的每一个值转换成列表形式。需要注意的是每条数据对应的数据类型。

1
2
3
4
5
6
7
8
#处理一条数据
def get_example_object(data_record):
# 将数据转化为int64 float 或bytes类型的列表
# 注意都是list形式
int_list1 = tf.train.Int64List(value = [data_record['int_data']])
float_list1 = tf.train.FloatList(value = [data_record['float_data']])
str_list1 = tf.train.BytesList(value = [data_record['str_data']])
float_list2 = tf.train.FloatList(value = data_record['float_list_data'])

然后,使用Feature类对每个数据列表进行包装,并且以key-value的字典格式存储。

1
2
3
4
5
6
7
# 将数据封装成一个dict
feature_key_value_pair = {
'int_list':tf.train.Feature(int64_list = int_list1),
'float_list': tf.train.Feature(float_list=float_list1),
'str_list': tf.train.Feature(bytes_list=str_list1),
'float_list2': tf.train.Feature(float_list=float_list2),
}

接着,将创建好的feature字典传递给features类,并且使用Example类处理成一个example。

1
2
3
4
5
# 创建一个features
features = tf.train.Features(feature = feature_key_value_pair)
# 创建一个example
example = tf.train.Example(features = features)
return example

最后,遍历所有数据集,将每条数据写入tfrecord中。

1
2
3
4
5
6
with tf.python_io.TFRecordWriter('example.tfrecord') as tfwriter:
#遍历所有数据
for data_record in data_arr:
example = get_example_object(data_record)
# 写入tfrecord数据文件
tfwriter.write(example.SerializeToString())

运行整个代码之后,我们在磁盘中将看到一个’example.tfrecord’文件

1
2
3
$ ls |grep *.tfrecord

example.tfrecord

该文件中存储的就是上面我们定义好的两条数据,接下来,我们将图像数据保存到TFRecord文件中。

图像数据-TFRecord

通过上面一个简单例子,我们基本了解了如何为包含字典和列表的文本类型的数据创建TFRecord,接下来,我们对图像数据创建TFRecord。我们使用kaggle上面的猫狗数据集。

该数据集可以从:kaggle猫狗进行下载。

下载完之后,我们会得到两个文件夹

1
test  train

其中train文件夹中主要是训练数据集,test文件夹中主要是预测数据集,主要对train数据集进行操作。

1
2
3
ls |wc -w

25000

该训练集中一共有25000张图像,其中猫狗图像各一半,接下来我们看看数据格式。

1
2
3
4
$ ls 

cat.124.jpg cat.3750.jpg cat.6250.jpg cat.8751.jpg dog.11250.jpg dog.2500.jpg dog.5000.jpg dog.7501.jpg
...

在train文件夹中,我们可以看到图片数据主要是以.jpg结尾的,并且文件名中包含了该图像的所属标签,我们需要从文件名中提取每张图像对应的标签类别。

对图像数据进行保存,主要有两种方式。首先我们来看看常见的方式,即首先读取这些图像数据,然后将这些数值化的图像数据转化为字符串形式,并存储到TFRecord。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

import tensorflow as tf
import os
import time
from glob import glob
import progressbar
from PIL import Image

class GenerateTFRecord():
def __init__(self,labels):
self.labels = labels

def _get_label_with_filename(self,filename):
basename = os.path.basename(filename).split(".")[0]
return self.labels[basename]

def _convert_image(self,img_path,is_train=True):
label = self._get_label_with_filename(img_path)
image_data = Image.open(img_path)
image_data = image_data.resize((227, 227)) # 重新定义图片的大小
image_str = image_data.tobytes()
filename = os.path.basename(img_path)

首先,我们创建一个生成TFRecorf类——GenerateTFRecord,其中,label一般是一个字典格式,将文本型的标签转化为对应的数值型标签,比如,这里,我们令0表示猫,1表示狗,从而label为

1
labels = {"cat":0,'dog':1}

另外,函数_get_label_with_fielname主要是从文件名中提取对应的标签类别。

接着,我们定义一个转换函数-_convert_image,

  • img_path:表示一张图片的具体路径
  • is_train:表示是否是训练集,上面我们下载了两份数据,训练数据集中带有标签,而test数据集中没有标签,在保存成TFRecord时,令test的数据label为-1

首先使用Image读取数据,接着将数据大小统一成227x227x3(这里只是一个案例,一般我们在构建模型之前会将图像数据大小统一成一个指定的大小),然后将图像数据转化为二进制格式。

处理完原始图像数据之后,构建一个example。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
if is_train:
feature_key_value_pair = {
'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename.encode()])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_str])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}
else:
feature_key_value_pair = {
'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename.encode()])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_str])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[-1]))
}
feature = tf.train.Features(feature = feature_key_value_pair)
example = tf.train.Example(features = feature)
return example

这里,我们保存了三个信息,即文件名、处理之后的图像信息和图像标签(当然还可以保存其他数据,只要按照上面格式定义好就行了)。

每张图像处理模式完成之后,遍历所有train数据集,并保存到tfrecord中。

1
2
3
4
5
6
7
8
9
10
11
12
def convert_image_folder(self,img_folder,tfrecord_file_name):
img_paths = [img_path for img_path in glob(os.path.join(img_folder,'*'))]

with tf.python_io.TFRecordWriter(tfrecord_file_name) as tfwriter:
widgets = ['[INFO] write image to tfrecord: ', progressbar.Percentage(), " ",
progressbar.Bar(), " ", progressbar.ETA()]
pbar = progressbar.ProgressBar(maxval=len(img_paths), widgets=widgets).start()
for i,img_path in enumerate(img_paths):
example = self._convert_image(img_path,is_train=True)
tfwriter.write(example.SerializeToString())
pbar.update(i)
pbar.finish()

其中:

  • img_folder:原始图像存放的路径
  • tfrecord_file_name:tfrecord文件保存路径

上面,我们使用了progressbar模块,该模块是一个进度条显示模块,可以帮助我们很好的监控数据处理情况。

最后,加入下列代码,并运行整个代码以完成train数据集的tfrexord构建。

1
2
3
4
5
6
if __name__ == "__main__":
start = time.time()
labels = {"cat":0,'dog':1}
t = GenerateTFRecord(labels)
t.convert_image_folder('train','train.tfrecord')
print("Took %f seconds." % (time.time() - start))

该方法使用了约115s完成了整个train数据集的TFRecord生成过程,在目录中,我们生成了一个名为train.tfrecord的文件。

1
2
3
$ ls -lht

11G train.tfrecord

该文件大小居然达到了11G(注意:该文件直接保存的是原始图像,不是处理之后的,因为需要跟另一种方法进行比较)。从前面,我们知道该train数据集中只有25000张图像数据,每张图像大小差不多50kb左右,25000张图像大小总共差不多1.2G左右,而生成的TFRecord文件居然达到11G,那么对于imagenet的数据集,可能会发生磁盘装不下的。这或许是许多人不喜欢使用TFRecord的一个原因吧。

为什么TFRecord变得如此巨大?

我们来简单的分析下,通过查看每张图像的shape,比如cat.8739.jpg,

1
2
3
4
5
6
7
import matplotlib.image as mpimg
from PIL import Image
img_path = 'train/cat.8739.jpg'
img_data = mpimg.imread(img_path)
img_data.shape

# output:(324,319,3)

该猫图像数据的shape是(324,319,3)。对每个维度进行相乘,即324x319x3=310068,那么在numpy数据格式中(假设数据类型为unit8),该图片以310069个整数表示。当我们调用.tobytes()时,这些数字将按顺序存在在一个二进制序列中。我们假设每一个数字都是大于100的,也就是需要三个字符,如果每个数字之间使用’,'分割,则对于该图片,我们需要:

310068 x(3+1) = 1240232个字符,如果一个字符对应一个字节,那么一张图片就差不多需要1MB。

上面只是个人计算,也许本身就不对的。

如何解决?

我们从另一个角度考虑:图片的存储大小,即上面我们分析每张图片差不多就50kb左右。其实在实际应用中,很多训练数据集的图像存储大小一般都在几kb到几百kb左右。因此,我们可以直接存储图像的bytes到tfrecord中。tensorflow模块提供了一个tf.gfile.FastGFile类,可以直接读取图像的bytes形式。我们来看看tf.gfile.FastGFile主要读取的是什么内容。

1
2
3
4
5
path_jpg = img_path = 'train/cat.8739.jpg'
image_raw_data = tf.gfile.FastGFile(path_jpg,'rb').read()

with tf.Session() as sess:
print(image_raw_data)

你将在屏幕上看到一大串的bytes,比如;

1
2
b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\n\x07\x07\x08\x07\x06\n\x08\x08\x08\x0b\n\n\x0b\x0e\x18\x10\x0e\r\r\x0e\x1d\x15\x16\x11\x18#\x1f%$"\x1f"!&+7/&)4)!
....

我们可以看到tf.gfile.FastGFile读取的不在是原始图像的内容,也不是numpy格式。

因此,我们将读取图像部分代码替换为:

1
2
with tf.gfile.FastGFile(img_path,'rb') as fid:
image_str = fid.read()

其他保持不变,并且保存为train2.tfrecord文件。即:

1
2
3
4
5
6
if __name__ == "__main__":
start = time.time()
labels = {"cat":0,'dog':1}
t = GenerateTFRecord(labels)
t.convert_image_folder('train','train2.tfrecord')
print("Took %f seconds." % (time.time() - start))

该方法只使用了约8s完成了整个train数据集的TFRecord生成过程,在目录中,我们生成了一个新的train2.tfrecord的文件

1
2
3
$ ls -lht

548M train2.tfrecord

从结果中可以看到,新的TFRecord文件只有548M,相比原先的11G,减小了很多。因此使用tf.gfile.FastGFile读取图像数据,明显的好处有:

  • 缩短了读取数据时间
  • 降低了磁盘使用大小

当然还有其他办法可以再进一步降低大小,但是可能会改变图像的内容。因此,这里就不做描述了。因为这种降低已经可以满足我目前的项目需求了。

从TFRecord中提取数据

上面我们已经对数据生成了TFRecord文件,接下来,我们将从中读取出数据。具体如下:

  • 首先,对生成的TFRecord初始化一个TFRecordDataset类
  • 接着,从TFRecord中提取数据,这里就需要利用到我们之前设定的key值,另外。如果我们知道每个值列表中的大小(即大小相同的),那么我们可以使用FixedLenFeature,否则,我们应该使用VarLenFeature。
  • 最后,使用parse_single_example api从每条data record中提取我们定义的数据字典。

下面,我们通过一个简单的提取数据代码来说明整个过程。

1
2
3
4
5
6
7
8
9
10
11
12
import tensorflow as tf
def extract_fn(data_record):
features = {
'int_list':tf.FixedLenFeature([],tf.int64),
'float_list':tf.FixedLenFeature([],tf.float32),
'str_list':tf.FixedLenFeature([],tf.string),
# 如果不同的record中的大小不一样,则使用VarLenFeature
'float_list2':tf.VarLenFeature(tf.float32)
}

sample = tf.parse_single_example(data_record,features)
return sample

上面的extract_fn函数对应了整个过程,下面我们使用Dataset模块处理数据

1
2
3
4
5
6
# 使用dataset模块读取数据
dataset = tf.data.TFRecordDataset(filenames=['example.tfrecord'])
# 对每一条record进行解析
dataset = dataset.map(extractz_fn)
iterator = dataset.make_one_shot_iterator()
next_example = iterator.get_next()

首先,对TFRrecord初始化一个TFRecordDataset类,然后通过map函数对TFRecords中的每条记录提取数据,最后通过一个迭代器一条条返回数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# eager 模式下
tf.enable_eager_execution()
try:
while True:
next_example = iterator.get_next()
print(next_example)
except:
pass

# 非eager模式
with tf.Session() as sess:
try:
while True:
data_record = sess.run(next_example)
print(data_record)
except:
pass

从TFRecord中提取图像

在对图像TFRecord数据文件提取数据时,需要利用tf.image.decode_image API,可以对图像数据进行解码,直接看代码:

1
2
3
4
5
6
7
import tensorflow as tf
import os
class TFRecordExtractor():
def __init__(self,tfrecord_file,epochs,batch_size):
self.tfrecord_file = os.path.abspath(tfrecord_file)
self.epochs = epochs
self.batch_size = batch_size

其中:

  • tfrecord_file:tfrecord数据文件路径
  • epochs:模型训练的epochs
  • batch_size: batch的大小,每次返回的数据量

定义一个提取数据函数,该函数后面通过map函数对每个data record进行解析。类似于生成TFRecord的feature格式,解析成字典格式,主要是通过key值获取对应的数据。

1
2
3
4
5
6
7
8
9
10
11
12
def _extract_fn(self,tfrecord):
# 解码器
# 解析出一条数据,如果需要解析多条数据,可以使用parse_example函数
# tf提供了两种不同的属性解析方法:
## 1. tf.FixdLenFeature:得到的是一个Tensor
## 2. tf.VarLenFeature:得到的是一个sparseTensor,用于处理稀疏数据
features ={
'filename': tf.FixedLenFeature([],tf.string),
'image': tf.FixedLenFeature([],tf.string),
'label': tf.FixedLenFeature([],tf.int64)
}

下面,使用tf.image.decode_image API对图像数据进行解码,并重新定义图像的大小(由于使用tf.gfile.FastGFile读取图像数据时无法重新定义图像大小,因此我们在解码时候进行重新定义图像大小)。最后返回图像数据、标签和文件名。

1
2
3
4
5
6
sample = tf.parse_single_example(tfrecord,features)
image = tf.image.decode_jpeg(sample['image'])
image = tf.image.resize_images(image, (227, 227),method=1)
label = sample['label']
filename = sample['filename']
return [image,label,filename]

使用Dataset对TFRecord文件进行操作:

1
2
3
4
5
def extract_image(self):
dataset = tf.data.TFRecordDataset([self.tfrecord_file])
dataset = dataset.map(self._extract_fn)
dataset = dataset.repeat(count = self.epochs).batch(batch_size=self.batch_size)
return dataset

首先,对TFRecord文件初始化一个 tf.data.TFRecordDataset类。接着使用map函数对每条data record进行_extract_fn解析。这里的epochs和batch_size跟模型训练有关,该函数最后返回一个迭代器,每次调取的是batch大小的数据量。

1
2
3
4
5
6
if __name__ == "__main__":
#tf.enable_eager_execution()
t = TFRecordExtractor('train2.tfrecord',epochs=1,batch_size=10)
dataset = t.extract_image()
for (batch,batch_data) in enumerate(dataset):
pass 

完成代码

我将两个功能何在一个TFRecord类中,主要是方便后续使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# encoding:utf-8
import tensorflow as tf
import os
from glob import glob
import progressbar

class TFRecord():
def __init__(self, labels, tfrecord_file):
self.labels = labels
self.tfrecord_file = tfrecord_file

def _get_label_with_filename(self, filename):
basename = os.path.basename(filename).split(".")[0]
return self.labels[basename]

def _convert_image(self, img_path, is_train=True):
label = self._get_label_with_filename(img_path)
filename = os.path.basename(img_path)
with tf.gfile.FastGFile(img_path, 'rb') as fid:
image_str = fid.read()
if is_train:
feature_key_value_pair = {
'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename.encode()])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_str])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}
else:
feature_key_value_pair = {
'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename.encode()])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_str])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[-1]))
}
feature = tf.train.Features(feature=feature_key_value_pair)
example = tf.train.Example(features=feature)
return example

def convert_image_folder(self, img_folder):
img_paths = [img_path for img_path in glob(os.path.join(img_folder, '*'))]

with tf.python_io.TFRecordWriter(self.tfrecord_file) as tfwriter:
widgets = ['[INFO] write image to tfrecord: ', progressbar.Percentage(), " ",
progressbar.Bar(), " ", progressbar.ETA()]
pbar = progressbar.ProgressBar(maxval=len(img_paths), widgets=widgets).start()
for i, img_path in enumerate(img_paths):
example = self._convert_image(img_path, is_train=True)
tfwriter.write(example.SerializeToString())
pbar.update(i)
pbar.finish()

def _extract_fn(self, tfrecord):
# 解码器
# 解析出一条数据,如果需要解析多条数据,可以使用parse_example函数
# tf提供了两种不同的属性解析方法:
## 1. tf.FixdLenFeature:得到的是一个Tensor
## 2. tf.VarLenFeature:得到的是一个sparseTensor,用于处理稀疏数据
features = {
'filename': tf.FixedLenFeature([], tf.string),
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
}

sample = tf.parse_single_example(tfrecord, features)
image = tf.image.decode_jpeg(sample['image'])
image = tf.image.resize_images(image, (227, 227), method=1)
label = sample['label']
filename = sample['filename']
return [image, label, filename]

def extract_image(self, shuffle_size,batch_size):
dataset = tf.data.TFRecordDataset([self.tfrecord_file])
dataset = dataset.map(self._extract_fn)
dataset = dataset.shuffle(shuffle_size).batch(batch_size)
return dataset