MXNet官方教程5Iterators-加载数据
Posted xiang_freedom
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了MXNet官方教程5Iterators-加载数据相关的知识,希望对你有一定的参考价值。
在这篇教程里,我们关注将数据放入训练或预测模型。大部分MXNet的训练和预测模型支持数据迭代器(Iterators),它简化了数据加载过程,尤其是读取大量数据的时候。这里我们介绍一下API规范和几个定义好的迭代器。
先决条件
我们需要:
- MXNet
- OpenCV Python library, Python Requests, Matplotlib 和 Jupyter Notebook.
$ pip install opencv-python requests matplotlib jupyter
- 设置
MXNET_HOME
环境变量为MXNet源代码目录
$ git clone https://github.com/dmlc/mxnet ~/mxnet
$ export MXNET_HOME='~/mxnet'
MXNet数据迭代器
MXNet里的数据迭代器和Python对象迭代器差不多。在Python里,iter
方法通过调用Python迭代器对象(比如list
)的next()
方法来顺序读取元素。迭代器给各种可迭代对象提供了抽象的接口,从而不必要暴露底层数据结构。
在MXNet里,每次调用数据迭代器的next
方法会返回一批数据DataBatch
。一个DataBatch
通常包含n个训练样本和对应的标签,这个n称为迭代器的batch_size
。当数据流末尾没有数据可读时,迭代器和Python iter
一样抛出StopIteration
异常。DataBatch
的结构定义在这里。
样本和标签的信息,比如name,shape,type和layout,由数据描述对象DataDesc
提供,DataDesc
对象可以由DataBatch
的provide_data
和provide_label
属性得到。DataDesc
的结构定义在这里。
MXNet的所有IO都集中在mx.io.DataIter
和其子类里。在这篇教程里,我们将讨论MXNet提供的几个常用的数据迭代器。
在此之前,我们先导入一些必要的包
import mxnet as mx
%matplotlib inline
import os
import sys
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
从内存中读取数据
当数据已加载到内存后,不管是NDArray还是numpy的ndarray,我们可以用NDArrayIter读取:
import numpy as np
data = np.random.rand(100,3)
label = np.random.randint(0, 10, (100,))
data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30)
for batch in data_iter:
print([batch.data, batch.label, batch.pad])
[[
[[ 0.6530531 0.46522644 0.51619464]
[ 0.16499737 0.76653564 0.03481427]
[ 0.74804795 0.52990937 0.26833427]
[ 0.66638368 0.3022213 0.06241459]
[ 0.91972476 0.55546296 0.66393465]
[ 0.08645096 0.38091755 0.56270498]
[ 0.64666849 0.21491572 0.35175693]
[ 0.86468941 0.47517869 0.05744886]
[ 0.67380011 0.05342721 0.4926973 ]
[ 0.19055551 0.69065297 0.67532688]
[ 0.02628353 0.6844967 0.19746889]
[ 0.66295046 0.09987242 0.58984101]
[ 0.9458763 0.81699216 0.1675925 ]
[ 0.46482503 0.21350947 0.06471642]
[ 0.89144808 0.94313467 0.68858165]
[ 0.81668401 0.7621479 0.27384126]
[ 0.63461167 0.65230727 0.97777712]
[ 0.22063005 0.24458201 0.10742629]
[ 0.92816764 0.13466544 0.04408605]
[ 0.30227584 0.14775786 0.87613076]
[ 0.63119137 0.81612813 0.92117757]
[ 0.94886911 0.43556175 0.46657735]
[ 0.27003673 0.76513189 0.23725513]
[ 0.2746343 0.47627121 0.54125744]
[ 0.25552508 0.01837774 0.39958724]
[ 0.83126527 0.03519598 0.34842646]
[ 0.2845566 0.64368129 0.46485582]
[ 0.08218119 0.41793332 0.51502693]
[ 0.09406272 0.91428256 0.31059062]
[ 0.13820368 0.17033553 0.28657481]]
<NDArray 30x3 @cpu(0)>], [
[ 7. 0. 0. 2. 0. 7. 0. 5. 8. 6. 8. 2. 1. 4. 5. 1. 7. 0.
9. 3. 6. 0. 4. 0. 5. 5. 4. 4. 1. 0.]
<NDArray 30 @cpu(0)>], 0]
[[
[[ 0.13700683 0.52412778 0.8671298 ]
[ 0.23877266 0.33903974 0.27537507]
[ 0.06243272 0.19531037 0.5281179 ]
[ 0.82232028 0.56463391 0.17138779]
[ 0.5569911 0.20909874 0.52542228]
[ 0.60343611 0.23063758 0.81608468]
[ 0.70935023 0.78453153 0.78045279]
[ 0.11286663 0.79524058 0.09895906]
[ 0.77800322 0.16236411 0.87968087]
[ 0.01127686 0.69375426 0.08547841]
[ 0.82750279 0.0399946 0.60687792]
[ 0.53598893 0.78744203 0.96958113]
[ 0.82449615 0.11746258 0.29264763]
[ 0.4362683 0.64713514 0.10649233]
[ 0.14894192 0.90637457 0.13595931]
[ 0.08151129 0.23844923 0.09844355]
[ 0.33128792 0.7256636 0.84794742]
[ 0.85739374 0.19100513 0.895199 ]
[ 0.21185175 0.80707216 0.97806442]
[ 0.64928633 0.63623446 0.12098809]
[ 0.39623061 0.78550547 0.52882141]
[ 0.13212556 0.1327759 0.27480963]
[ 0.9550342 0.47325855 0.08431709]
[ 0.7431556 0.03889066 0.39910018]
[ 0.52704382 0.44965392 0.76548541]
[ 0.69834208 0.30493379 0.17661361]
[ 0.60261613 0.54987943 0.25630507]
[ 0.54871225 0.27588007 0.58100933]
[ 0.44057187 0.44038841 0.16040143]
[ 0.90090007 0.6116063 0.17815863]]
<NDArray 30x3 @cpu(0)>], [
[ 3. 2. 1. 6. 7. 0. 9. 0. 3. 2. 7. 9. 3. 9. 9. 7. 0. 0.
8. 3. 1. 7. 0. 7. 2. 5. 7. 7. 2. 2.]
<NDArray 30 @cpu(0)>], 0]
[[
[[ 0.38992238 0.38013816 0.08797289]
[ 0.52365917 0.90862477 0.08777413]
[ 0.2257922 0.86751235 0.90748298]
[ 0.33275157 0.43818691 0.6017009 ]
[ 0.60666031 0.2776112 0.05172342]
[ 0.6747095 0.57716769 0.7660436 ]
[ 0.17980796 0.22112246 0.11028604]
[ 0.01789156 0.70913017 0.19826613]
[ 0.5153966 0.89273322 0.36250469]
[ 0.28154254 0.76787293 0.96268386]
[ 0.41763589 0.36206847 0.44974893]
[ 0.15868166 0.08014139 0.8982119 ]
[ 0.47615659 0.41381609 0.7943607 ]
[ 0.17594658 0.37957036 0.61883914]
[ 0.67709279 0.857777 0.1074606 ]
[ 0.69237596 0.84633547 0.50271791]
[ 0.16945797 0.68244201 0.19087805]
[ 0.1299092 0.49758923 0.37441683]
[ 0.19249067 0.42367426 0.04579045]
[ 0.74409103 0.40689459 0.38356331]
[ 0.00723282 0.24958441 0.46890974]
[ 0.61725456 0.08689262 0.64768285]
[ 0.0569613 0.89262682 0.67263258]
[ 0.17189403 0.5448252 0.02259486]
[ 0.49834073 0.36860585 0.8104018 ]
[ 0.63564551 0.62717378 0.90756214]
[ 0.3532913 0.82186127 0.07672632]
[ 0.72964108 0.71190619 0.22283019]
[ 0.77529597 0.09597207 0.45330995]
[ 0.10836289 0.07343143 0.02535379]]
<NDArray 30x3 @cpu(0)>], [
[ 7. 3. 8. 3. 1. 9. 8. 9. 9. 1. 3. 4. 7. 6. 7. 5. 4. 7.
2. 3. 3. 3. 7. 2. 5. 2. 7. 2. 6. 2.]
<NDArray 30 @cpu(0)>], 0]
[[
[[ 0.79986835 0.65680069 0.76467651]
[ 0.81243736 0.09769222 0.27826148]
[ 0.42728153 0.97143823 0.02860877]
[ 0.01750882 0.9944576 0.80612904]
[ 0.47085366 0.35999826 0.48983538]
[ 0.24489172 0.37438354 0.81461328]
[ 0.72018224 0.4823792 0.02590245]
[ 0.97203141 0.78433287 0.1679011 ]
[ 0.86534786 0.07474694 0.96176213]
[ 0.04950644 0.82327616 0.80272979]
[ 0.6530531 0.46522644 0.51619464]
[ 0.16499737 0.76653564 0.03481427]
[ 0.74804795 0.52990937 0.26833427]
[ 0.66638368 0.3022213 0.06241459]
[ 0.91972476 0.55546296 0.66393465]
[ 0.08645096 0.38091755 0.56270498]
[ 0.64666849 0.21491572 0.35175693]
[ 0.86468941 0.47517869 0.05744886]
[ 0.67380011 0.05342721 0.4926973 ]
[ 0.19055551 0.69065297 0.67532688]
[ 0.02628353 0.6844967 0.19746889]
[ 0.66295046 0.09987242 0.58984101]
[ 0.9458763 0.81699216 0.1675925 ]
[ 0.46482503 0.21350947 0.06471642]
[ 0.89144808 0.94313467 0.68858165]
[ 0.81668401 0.7621479 0.27384126]
[ 0.63461167 0.65230727 0.97777712]
[ 0.22063005 0.24458201 0.10742629]
[ 0.92816764 0.13466544 0.04408605]
[ 0.30227584 0.14775786 0.87613076]]
<NDArray 30x3 @cpu(0)>], [
[ 6. 4. 2. 5. 4. 5. 2. 3. 6. 1. 7. 0. 0. 2. 0. 7. 0. 5.
8. 6. 8. 2. 1. 4. 5. 1. 7. 0. 9. 3.]
<NDArray 30 @cpu(0)>], 20]
从CSV文件读取数据
MXNet提供CSVIter读取CSV文件:
# lets save `data` into a csv file first and try reading it back
np.savetxt('data.csv', data, delimiter=',')
data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30)
for batch in data_iter:
print([batch.data, batch.pad])
自定义迭代器
当内置的迭代器不能满足应用需求时,可以创建自己的迭代器。
一个MXNet的迭代器应该满足:
- 实现Python2的
next()
或者Python3的__next()__
,返回一个DataBatch
或者在数据流末尾抛出StopIteration
。 - 实现
reset()
方法,从头开始重新读数据。 - 提供一个
provide_data
属性,包含一个DataDesc
列表,每个DataDesc
包含数据的name, shape, type 和 layout information信息(详见这里)。 - 提供一个
provide_label
属性,包含一个DataDesc
列表,每个DataDesc
包含标签的name, shape, type 和 layout information信息。
创建新迭代器时,你可以从头开始创建或者利用已经存在的迭代器。例如,在图片字幕应用里,数据样本是图片而标签是文本。所以我们可以这样创建新迭代器:
- 用提供多线程数据预处理和增强的
ImageRecordIter
创建一个image_iter
- 用
NDArrayIter
或者rnn包里的bucketing迭代器创建一个caption_iter
next()
返回image_iter.next()
和caption_iter.next()
的组合。
下面的代码展示了如何创建一个简单的迭代器:
class SimpleIter(mx.io.DataIter):
def __init__(self, data_names, data_shapes, data_gen,
label_names, label_shapes, label_gen, num_batches=10):
self._provide_data = list(zip(data_names, data_shapes))
self._provide_label = list(zip(label_names, label_shapes))
self.num_batches = num_batches
self.data_gen = data_gen
self.label_gen = label_gen
self.cur_batch = 0
def __iter__(self):
return self
def reset(self):
self.cur_batch = 0
def __next__(self):
return self.next()
@property
def provide_data(self):
return self._provide_data
@property
def provide_label(self):
return self._provide_label
def next(self):
if self.cur_batch < self.num_batches:
self.cur_batch += 1
data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
return mx.io.DataBatch(data, label)
else:
raise StopIteration
我们可以用上面的SimpleIter
训练一个简单的MLP模型:
import mxnet as mx
num_classes = 10
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
net = mx.sym.SoftmaxOutput(data=net, name='softmax')
print(net.list_arguments())
print(net.list_outputs())
['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias', 'softmax_label']
['softmax_output']
这里,有四个需要学习的参数:全连接层fc1和fc2的weights和biases。有两个输入变量:训练样本data和标签softmax_label,最后还有一个输出softmax_output。
data是MXNet的Symbol API提供的变量,为了执行Symbol,它们需要先绑定数据。详见【MXNet官方教程3】Symbol -神经网络图和自动区分。
我们通过MXNet的module API把迭代器的数据送入神经网络。详见【MXNet官方教程4】Module - 神经网络训练和预测。
import logging
logging.basicConfig(level=logging.INFO)
n = 32
data_iter = SimpleIter(['data'], [(n, 100)],
[lambda s: np.random.uniform(-1, 1, s)],
['softmax_label'], [(n,)],
[lambda s: np.random.randint(0, num_classes, s)])
mod = mx.mod.Module(symbol=net)
mod.fit(data_iter, num_epoch=5)
INFO:root:Epoch[0] Train-accuracy=0.081250
INFO:root:Epoch[0] Time cost=0.006
INFO:root:Epoch[1] Train-accuracy=0.125000
INFO:root:Epoch[1] Time cost=0.005
INFO:root:Epoch[2] Train-accuracy=0.121875
INFO:root:Epoch[2] Time cost=0.006
INFO:root:Epoch[3] Train-accuracy=0.084375
INFO:root:Epoch[3] Time cost=0.004
INFO:root:Epoch[4] Train-accuracy=0.109375
INFO:root:Epoch[4] Time cost=0.005
使用Python3的注意事项:mxnet的许多方法用python2的字符串和python3的字节。为了保持教程的可读性,我们定义一个工具方法,将字符串转为python3的字节。
def str_or_bytes(str):
"""
A utility function for this tutorial that helps us convert string
to bytes if we are using python3.
Parameters
----------
str : string
Returns
-------
string (python2) or bytes (python3)
"""
if sys.version_info[0] < 3:
return str
else:
return bytes(str, 'utf-8')
Record IO
Record IO是MXNet 数据IO的一种文件格式。其对数据简洁的封装用于分布式文件系统比如 Hadoop HDFS 和 AWS S3的高效读写。更多内容参见这里。
MXNet提供了MXRecordIO和MXIndexedRecordIO用于顺序读取和随机读取数据。
MXRecordIO
首先,我们来看一下怎么用MXRecordIO
顺序读取数据。文件以.rec
为后缀。
record = mx.recordio.MXRecordIO('tmp.rec', 'w')
for i in range(5):
record.write(str_or_bytes('record_%d'%i))
record.close()
我们可以用参数r
读取文件:
record = mx.recordio.MXRecordIO('tmp.rec', 'r')
while True:
item = record.read()
if not item:
break
print (item)
record.close()
b'record_0'
b'record_1'
b'record_2'
b'record_3'
b'record_4'
MXIndexedRecordIO
MXIndexedRecordIO
支持随机或者按索引读取数据。我们创建一个按索引记录文件和对应的索引文件:
record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
for i in range(5):
record.write_idx(i, str_or_bytes('record_%d'%i))
record.close()
现在,我们可以用索引键读取对应的记录
record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r')
record.read_idx(3)
b'record_3'
你也可以把所有的索引键列出来:
record.keys
[0, 1, 2, 3, 4]
封装和解封装数据
.rec文件的每条记录都包含任意的二进制数据。然而,大多数深度学习任务需要数据以标签/样本的格式输入。mx.recordio
包提供了一些工具方法,比如pack
, unpack
, pack_img
, 和unpack_img
。
封装/解封装二进制数据
pack
和unpack
用于浮点数(或者一维浮点数向量)标签和二进制数据。数据和一个header包装在一起,header的结构定义在这里。
# pack
data = 'data'
label1 = 1.0
header1 = mx.recordio.IRHeader(flag=0, label=label1, id=1, id2=0)
s1 = mx.recordio.pack(header1, str_or_bytes(data))
label2 = [1.0, 2.0, 3.0]
header2 = mx.recordio.IRHeader(flag=3, label=label2, id=2, id2=0)
s2 = mx.recordio.pack(header2, str_or_bytes(data))
# unpack
print(mx.recordio.unpack(s1))
print(mx.recordio.unpack(s2))
(HEADER(flag=0, label=1.0, id=1, id2=0), b'data')
(HEADER(flag=3, label=array([ 1., 2., 3.], dtype=float32), id=2, id2=0), b'data')
封装/解封装图片
MXNet提供pack_img
和 unpack_img
来封装/解封装图片数据。pack_img
封装的记录可以直接由mx.io.ImageRecordIter
加载。
data = np.ones((3,3,1), dtype=np.uint8)
label = 1.0
header = mx.recordio.IRHeader(flag=0, label=label, id=0, id2=0)
s = mx.recordio.pack_img(header, data, quality=100, img_fmt='.jpg')
# unpack_img
print(mx.recordio.unpack_img(s))
(HEADER(flag=0, label=1.0, id=0, id2=0), array([[1, 1, 1],
[1, 1, 1],
[1, 1, 1]], dtype=uint8))
使用tools/im2rec.py
你可以用MXNet src/tools文件夹下的im2rec.py
工具脚本直接将源图片转换为RecordIO格式。在下面的图片IO部分有一个使用此脚本的例子。
图片IO
这部分,我们将学习怎么预处理和加载图片。
有4种加载图片的方式:
- 使用mx.image.imdecode加载原图。
- 使用python实现的且易于定制的mx.img.ImageIter。它可以读取.rec文件和原图。
- 使用C++实现的mx.io.ImageRecordIter。它不那么容易定制,但是可以绑定各种语言使用。
- 自定义迭代器继承
mx.io.DataIter
。
预处理图片
图片预处理有多种方式:
- 使用
mx.io.ImageRecordIter
,很快但不那么灵活。适用与简单的任务比如图片识别,但在复杂的任务比如检测和分割行不通。 - 使用
mx.recordio.unpack_img
(或者cv2.imread
,skimage
等) +numpy
,很灵活,但是比较慢(因为python全局解释器锁GIL)。 - 使用MXNet提供的
mx.image
包。它将图片存储为NDArray并且启动MXNet的dependency engine自动并行处理,规避GIL。
下面,我们演示mx.image
包里几种常见的预处理例子。
下载我们需要处理的图片
fname = mx.test_utils.download(url='http://data.mxnet.io/data/test_images.tar.gz', dirname='data', overwrite=False)
tar = tarfile.open(fname)
tar.extractall(path='./data')
tar.close()
加载原图
mx.image.imdecode
加载图片,imdecode
提供了类似于OpenCv的接口。
**注意:**为了使用mx.image.imdecode
,你仍然需要安装OpenCV
,而不是cv2 python库。
img = mx.image.imdecode(open('data/test_images/ILSVRC2012_val_00000001.JPEG', 'rb').read())
plt.imshow(img.asnumpy()); plt.show()
图片变换
# resize to w x h
tmp = mx.image.imresize(img, 100, 70)
plt.imshow(tmp.asnumpy()); plt.show()
# crop a random w x h region from image
tmp, coord = mx.image.random_crop(img, (150, 200))
print(coord)
plt.imshow(tmp.asnumpy()); plt.show()
使用图片迭代器加载数据
在使用两种内置的图片加载器之前,先获取一个包含101类对象的Caltech 101
数据集并且转为RecordIO格式。下载并解压:
fname = mx.test_utils.download(url='http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz', dirname='data', overwrite=False)
tar = tarfile.open(fname)
tar.extractall(path='./data')
tar.close()
我们先看一眼数据,在根目录下(./data/101_ObjectCategories),每一个分类都有一个子文件夹(./data/101_ObjectCategories/yin_yang)。
现在我们使用im2rec.py
脚本将图片转为RecordIO格式。首先,我们需要一个包含所有图片文件和分类的列表。
os.system('python %s/tools/im2rec.py --list=1 --recursive=1 --shuffle=1 --test-ratio=0.2 data/caltech data/101_ObjectCategories'%os.environ['MXNET_HOME'])
得到的列表文件(./data/caltech_train.lst)以index\\t(one or more label)\\tpath
的格式。在这个例子里,每一个图片只有一个标签,但是你可以修改列表用于多标签训练(参见MXNet im2rec.py使用教程)。
7167 69.000000 okapi/image_0017.jpg
6153 52.000000 ibis/image_0073.jpg
7761 81.000000 scissors/image_0005.jpg
7792 81.000000 scissors/image_0036.jpg
1326 2.000000 Faces_easy/image_0425.jpg
...
然后我们可以使用这个列表来创建RecordIO文件。
os.system("python %s/tools/im2rec.py --num-thread=4 --pass-through=1 data/caltech data/101_ObjectCategories"%os.environ['MXNET_HOME'])
record io文件保存在这里(./data)。
使用ImageRecordIter
ImageRecordIter加载RecordIO格式的图片数据,只需要简单地创建一个加载实例:
data_iter = mx.io.ImageRecordIter(
path_imgrec="./data/caltech.rec", # the target record file
data_shape=(3, 227, 227), # output data shape. An 227x227 region will be cropped from the original image.
batch_size=4, # number of samples per batch
resize=256 # resize the shorter edge to 256 before cropping
# ... you can add more augumentation options as defined in ImageRecordIter.
)
data_iter.reset()
batch = data_iter.next()
data = batch.data[0]
for i in range(4):
plt.subplot(1,4,i+1)
plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0)))
plt.show()
使用ImageIter
ImageIter是一个灵活的接口,支持从RecordIO和源文件加载图片。
data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227),
path_imgrec="./data/caltech.rec",
path_imgidx="./data/caltech.idx" )
data_iter.reset()
batch = data_iter.next()
data = batch.data[0]
for i in range(4):
plt.subplot(1,4,i+1)
plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0)))
plt.show()
以上是关于MXNet官方教程5Iterators-加载数据的主要内容,如果未能解决你的问题,请参考以下文章