def yield_test(num):
    """Yield ``num``, ``num - 5``, ``num - 10``, ... forever, tracing each step.

    Prints ``"a"`` just before each value is handed out. Prints ``"b"`` right
    after the generator is resumed and the value has been decremented. This
    makes it visible where execution pauses at ``yield``.
    """
    value = num
    print("a")
    while True:
        yield value
        value -= 5
        print("b")
        print("a")


# Demo: step the generator twice and observe the printed trace.
a = yield_test(10)
# First next(): runs the body up to the first yield.
#   output: a
#   num == 10
num = next(a)
# Second next(): resumes right after the yield, decrements, then loops back
# to the next yield.
#   output: b
#           a
#   num == 5
num = next(a)
class iterable():
    """One-shot iterator over the even numbers 2, 4, 6, 8, 10.

    The object is its own iterator (``__iter__`` returns ``self``), so it can
    be traversed only once. After the first full pass, every further
    ``next()`` call raises ``StopIteration``.
    """

    def __init__(self):
        # Counter advanced by 2 on each next(); starts at 0 so the first
        # value produced is 2.
        self.count = 0

    def __iter__(self):
        # Returning self (rather than a fresh iterator) is what makes this
        # object single-pass: a second for-loop sees the exhausted state.
        return self

    def __next__(self):
        """Return the next even number, or raise StopIteration past 10."""
        self.count += 2
        if self.count <= 10:
            return self.count
        raise StopIteration


it = iterable()
for i in it:
    print(i)
# Output:
# 2
# 4
# 6
# 8
# 10

# A second pass prints nothing: the iterator is already exhausted.
for i in it:
    print(i)

# Calling next() on the exhausted iterator raises StopIteration. Catch it so
# the demo does not abort the module before later definitions are reached.
try:
    next(it)
except StopIteration:
    print("StopIteration")
class Sampler(object):
    r"""Abstract base for every sampler.

    A concrete sampler must override ``__iter__`` to produce the indices of
    dataset elements. It must also override ``__len__`` to report how many
    indices that iterator yields. Here both methods only raise
    ``NotImplementedError``.
    """

    def __init__(self, data_source):
        """Accept ``data_source`` for a uniform constructor; nothing is stored."""

    def __iter__(self):
        """Subclasses must return an iterator over dataset indices."""
        raise NotImplementedError()

    def __len__(self):
        """Subclasses must return the number of indices their iterator yields."""
        raise NotImplementedError()
class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Yields the indices ``0, 1, ..., len(data_source) - 1`` in increasing order.

    Arguments:
        data_source (Dataset): dataset to sample from; must support ``len()``
    """

    def __init__(self, data_source):
        # Keep a reference so the length is re-read on each pass; the base
        # class constructor stores nothing.
        self.data_source = data_source

    def __iter__(self):
        """Return a fresh iterator over ``range(len(data_source))``."""
        return iter(range(len(self.data_source)))

    def __len__(self):
        """Return the number of indices produced, i.e. ``len(data_source)``."""
        return len(self.data_source)
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        """Validate and store the wrapped sampler and the batching options.

        Raises:
            ValueError: if ``sampler`` is not a ``Sampler`` instance, if
                ``batch_size`` is not a positive integer (``bool`` values are
                rejected), or if ``drop_last`` is not a ``bool``.
        """
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # bool is a subclass of int, so True/False would otherwise pass the
        # integer check; they are rejected explicitly.
        # NOTE(review): _int_classes is defined outside this view; presumably
        # the integer type(s) for the running Python version -- confirm.
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last