DataLoader Source Code Study Notes

1. Iterators and Generators

Before digging into DataLoader, it helps to understand iterators first.

[Figure]

1.1 Iterator and Iterable

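Python containers such as lists, dictionaries, and sets are all iterable, that is, they are Iterables. You can use for...in... to take out each element in turn.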

for e in [1, 2, 3, 4]:
    print(e)
'''
out:
1
2
3
4
'''

This is equivalent to:

it = iter([1, 2, 3, 4])
while True:
    try:
        print(next(it))
    except StopIteration:  # catch only the end-of-iteration signal, not every error
        break

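In fact, lists, dictionaries, and similar containers are only Iterables; they are not Iterators.

for...in... essentially converts the iterable container into an iterator and then calls next() repeatedly to take out the elements.

StopIteration signals that every element has been taken out.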
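
The difference between an Iterable and an Iterator can be checked directly with the abstract base classes in collections.abc. The following is a minimal sketch; the variable names are only for this illustration.

```python
from collections.abc import Iterable, Iterator

numbers = [1, 2, 3, 4]

# A list is Iterable, but it is not itself an Iterator.
list_is_iterable = isinstance(numbers, Iterable)
list_is_iterator = isinstance(numbers, Iterator)

# iter() hands back a separate iterator object that remembers its position.
it = iter(numbers)
it_is_iterator = isinstance(it, Iterator)

first = next(it)
second = next(it)
rest = list(it)  # consumes whatever is left

# Once exhausted, next() raises StopIteration, which is what ends a for loop.
try:
    next(it)
    exhausted = False
except StopIteration:
    exhausted = True

print(list_is_iterable, list_is_iterator, it_is_iterator)
print(first, second, rest, exhausted)
```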


An iterator needs both the __iter__ and __next__ methods; the __iter__ method is what makes a class iterable.

# collections.Iterable was removed in Python 3.10; import from collections.abc instead
from collections.abc import Iterable

class isIterable():
    def __init__(self):
        pass

class isIterable2():
    def __init__(self):
        pass
    def __iter__(self):
        return self

a = isIterable()
a2 = isIterable2()
print(isinstance(a, Iterable))
print(isinstance(a2, Iterable))
'''
out:
False
True
'''

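An object with an __iter__ method can be passed to iter() to produce an iterator (provided __iter__ actually returns one); an object with neither __iter__ nor __getitem__ cannot.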

Once an object also has a __next__ method, it is an iterator.

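# collections.Iterator was removed in Python 3.10; import from collections.abc instead
from collections.abc import Iterator

class isIterator():
    def __init__(self):
        pass
    def __iter__(self):
        # Never called in this example; calling it would recurse without end.
        return iter(self)

class isIterator2():
    def __init__(self):
        pass
    def __iter__(self):
        return self
    def __next__(self):
        return 1

a = isIterator()
a2 = isIterator2()
print(isinstance(a, Iterator))
print(isinstance(a2, Iterator))
'''
out:
False
True
'''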

1.2 Generators and yield

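A generator is really a special kind of iterator, but you do not have to implement __iter__ and __next__ yourself as you would for an iterator class; using the yield keyword is enough.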

See the first figure in this article and this post: "python中yield的用法详解: 最简单,最清晰的解释" (mieleizhi0522's blog, CSDN).

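Put simply, yield behaves like return, except that each call to next() resumes execution from the point where the previous call stopped (that is, at the yield).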

def yield_test(num):
    while True:
        print("a")
        yield num
        num -= 5
        print("b")

a = yield_test(10)
# first next()
num = next(a)
# second next()
num = next(a)
'''
On the first next():
out:
a
num = 10


On the second next():
out:
b
a
num = 5
'''
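
To make the "special kind of iterator" point concrete, here is a small sketch showing that a generator object passes the Iterator check and is exhausted after one pass. The function name count_by_two is made up for this example.

```python
from collections.abc import Iterator

def count_by_two(limit):
    """Yield 2, 4, 6, ... up to and including limit."""
    value = 2
    while value <= limit:
        yield value
        value += 2

gen = count_by_two(10)

# Calling the function runs none of its body; it only builds a generator object.
gen_is_iterator = isinstance(gen, Iterator)
gen_is_own_iterator = iter(gen) is gen

values = list(gen)        # drives the generator until the while loop ends
values_again = list(gen)  # the same generator object is now exhausted

print(gen_is_iterator, gen_is_own_iterator)
print(values, values_again)
```

Python supplies both __iter__ and __next__ for the generator object, which is why no class definition is needed.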

2. DataLoader

Original post: "PyTorch学习笔记(6): DataLoader源代码剖析" (g11d111's blog, CSDN)

The post above is very detailed; these notes add some of my own understanding.

class DataLoader(object):
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):

        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        ...

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with "shuffle"')
        ...
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True
        ...
    def __iter__(self):
        return _DataLoaderIter(self)
    ...

Why is DataLoader an iterable object rather than an iterator object itself?

See this Zhihu (zhihu.com) question: "python的迭代器为什么一定要实现__iter__方法?" (Why must a Python iterator implement __iter__?)

class iterable():
    def __init__(self):
        self.count = 0
    def __iter__(self):
        return self
    def __next__(self):
        self.count += 2
        if self.count <= 10:
            return self.count
        else:
            raise StopIteration
it = iterable()
for i in it:
    print(i)
'''
2
4
6
8
10
'''
for i in it:
    print(i)
# prints nothing
next(it)
# StopIteration:

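As shown, when an iterator is used directly, the object cannot be reused once iteration has finished, as if the elements inside the iterator had been "used up".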

class iterator():
    def __init__(self):
        self.count = 0
    def __next__(self):
        self.count += 2
        if self.count <= 10:
            return self.count
        else:
            raise StopIteration

class iterable():
    def __init__(self):
        pass
    def __iter__(self):
        return iterator()

it = iterable()
for i in it:
    print(i, end=" ")

for i in it:
    print(i, end=" ")
# 2 4 6 8 10 2 4 6 8 10

Here __iter__ creates a new iterator object on every call, so the elements never get "used up".
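
This is the same shape as DataLoader.__iter__ returning a new _DataLoaderIter: the loader holds configuration, and each call to __iter__ builds a fresh iterator, so every epoch starts from the beginning. Below is a minimal pure-Python sketch of that idea. TinyLoader and _TinyLoaderIter are hypothetical names for illustration, not PyTorch classes, and batching is reduced to plain list slicing.

```python
class _TinyLoaderIter:
    """Hypothetical one-pass iterator: holds the position for a single epoch."""

    def __init__(self, loader):
        self.data = loader.data
        self.batch_size = loader.batch_size
        self.pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.data):
            raise StopIteration
        batch = self.data[self.pos:self.pos + self.batch_size]
        self.pos += self.batch_size
        return batch


class TinyLoader:
    """Hypothetical reusable iterable: holds only configuration, no position."""

    def __init__(self, data, batch_size):
        self.data = list(data)
        self.batch_size = batch_size

    def __iter__(self):
        # A new iterator per call, so each for loop (epoch) starts over.
        return _TinyLoaderIter(self)


loader = TinyLoader(range(5), batch_size=2)
epoch1 = list(loader)
epoch2 = list(loader)
print(epoch1)
print(epoch2)
```

Both passes produce the same batches because the position lives in the short-lived iterator, not in the loader.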

2.1 DataLoader: RandomSampler(dataset) and SequentialSampler(dataset)

import torch  # needed for torch.randperm in RandomSampler


class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.
    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


class RandomSampler(Sampler):
    r"""Samples elements randomly, without replacement.
    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).tolist())

    def __len__(self):
        return len(self.data_source)

if __name__ == "__main__":
    print(list(RandomSampler(range(10))))
    # e.g. [2, 8, 3, 5, 9, 4, 6, 0, 1, 7] (the order changes on every run)
    print(list(SequentialSampler(range(10))))
    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
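
A sampler is just an iterable over dataset indices that also reports a length, so a custom one needs only __iter__ and __len__. The sketch below avoids torch so that it runs on its own. SeededRandomSampler is a hypothetical name, and using random.Random in place of torch.randperm is an assumption of this example, not how PyTorch implements it.

```python
import random

class SeededRandomSampler:
    """Hypothetical sampler: a reproducible permutation of dataset indices."""

    def __init__(self, data_source, seed=0):
        self.data_source = data_source
        self.seed = seed

    def __iter__(self):
        indices = list(range(len(self.data_source)))
        # A fresh Random with a fixed seed gives the same order on every pass.
        random.Random(self.seed).shuffle(indices)
        return iter(indices)

    def __len__(self):
        return len(self.data_source)


sampler = SeededRandomSampler(range(10), seed=42)
first_pass = list(sampler)
second_pass = list(sampler)
print(first_pass)
print(len(sampler))
```

Unlike RandomSampler, which draws a new permutation on each pass, this variant repeats the same order; that is a deliberate simplification to make the result reproducible.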

2.2 DataLoader: BatchSampler(Sampler)

# Uses Sampler and SequentialSampler from section 2.1.
# Assumption: the original source imports _int_classes from torch._six;
# plain int is used here as a stand-in so the excerpt runs on its own.
_int_classes = int


class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.
    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integeral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        # Once the batch reaches batch_size it is full and can be yielded
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        # e.g. with 100 samples per epoch and batch_size=64,
        # the result is 1 with drop_last and 2 without
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

if __name__ == "__main__":
    print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)))
    # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)))
    # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
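
The 100-sample, batch_size=64 case mentioned in the __len__ comment can be worked through without torch. The sketch below re-implements only the grouping and counting logic as two standalone functions; batch_indices and num_batches are hypothetical helper names, not part of PyTorch.

```python
def batch_indices(indices, batch_size, drop_last):
    """Group indices into lists of batch_size, mirroring BatchSampler.__iter__."""
    batch = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch


def num_batches(n, batch_size, drop_last):
    """Batch count, mirroring BatchSampler.__len__."""
    if drop_last:
        return n // batch_size                  # floor division
    return (n + batch_size - 1) // batch_size   # ceiling division


kept = list(batch_indices(range(100), 64, drop_last=False))
dropped = list(batch_indices(range(100), 64, drop_last=True))
print([len(b) for b in kept])
print([len(b) for b in dropped])
print(num_batches(100, 64, False), num_batches(100, 64, True))
```

Without drop_last the trailing 36 samples form a second, shorter batch; with drop_last they are discarded, and in both cases the counting formula agrees with the number of batches actually yielded.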

class iterable():
    def __iter__(self):
        return iter([1, 2])

it = iterable()
list(it)
# [1, 2]

As shown, list() builds a list by going through __iter__, which is why list(BatchSampler(...)) works in the examples above.
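
The same holds for other built-ins that consume an iterable. A short sketch, with Pair as a made-up class name:

```python
class Pair:
    """Hypothetical iterable that delegates to a list iterator."""

    def __iter__(self):
        return iter([1, 2])


p = Pair()
as_list = list(p)
as_tuple = tuple(p)
first, second = p   # unpacking also calls iter() under the hood
total = sum(p)
print(as_list, as_tuple, first, second, total)
```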

3. _DataLoaderIter