diff --git a/python/paddle/distributed/fleet/data_generator/data_generator.py b/python/paddle/distributed/fleet/data_generator/data_generator.py index 6fa6652e2c6bab3f29041c0ace5d701ab87d3b3f..cddba2ddee38219056d14eecc749a3d7eeac30eb 100644 --- a/python/paddle/distributed/fleet/data_generator/data_generator.py +++ b/python/paddle/distributed/fleet/data_generator/data_generator.py @@ -38,21 +38,20 @@ class DataGenerator: .. code-block:: python - import paddle.distributed.fleet.data_generator as dg - class MyData(dg.DataGenerator): - - def generate_sample(self, line): - def local_iter(): - int_words = [int(x) for x in line.split()] - yield ("words", int_words) - return local_iter - - def generate_batch(self, samples): - def local_iter(): - for s in samples: - yield ("words", s[1].extend([s[1][0]])) - mydata = MyData() - mydata.set_batch(128) + >>> import paddle.distributed.fleet.data_generator as dg + >>> class MyData(dg.DataGenerator): + ... def generate_sample(self, line): + ... def local_iter(): + ... int_words = [int(x) for x in line.split()] + ... yield ("words", int_words) + ... return local_iter + ... + ... def generate_batch(self, samples): + ... def local_iter(): + ... for s in samples: + ... yield ("words", s[1] + [s[1][0]]) + >>> mydata = MyData() + >>> mydata.set_batch(128) ''' self.batch_size_ = batch_size @@ -65,16 +64,15 @@ class DataGenerator: Example: .. code-block:: python - import paddle.distributed.fleet.data_generator as dg - class MyData(dg.DataGenerator): - - def generate_sample(self, line): - def local_iter(): - yield ("words", [1, 2, 3, 4]) - return local_iter - - mydata = MyData() - mydata.run_from_memory() + >>> # doctest: +SKIP('raise NotImplementedError') + >>> import paddle.distributed.fleet.data_generator as dg + >>> class MyData(dg.DataGenerator): + ... def generate_sample(self, line): + ... def local_iter(): + ... yield ("words", [1, 2, 3, 4]) + ...
return local_iter + >>> mydata = MyData() + >>> mydata.run_from_memory() ''' batch_samples = [] line_iter = self.generate_sample(None) @@ -104,17 +102,15 @@ class DataGenerator: .. code-block:: python - import paddle.distributed.fleet.data_generator as dg - class MyData(dg.DataGenerator): - - def generate_sample(self, line): - def local_iter(): - int_words = [int(x) for x in line.split()] - yield ("words", [int_words]) - return local_iter - - mydata = MyData() - mydata.run_from_stdin() + >>> import paddle.distributed.fleet.data_generator as dg + >>> class MyData(dg.DataGenerator): + ... def generate_sample(self, line): + ... def local_iter(): + ... int_words = [int(x) for x in line.split()] + ... yield ("words", [int_words]) + ... return local_iter + >>> mydata = MyData() + >>> mydata.run_from_stdin() ''' batch_samples = [] @@ -177,15 +173,13 @@ class DataGenerator: .. code-block:: python - import paddle.distributed.fleet.data_generator as dg - class MyData(dg.DataGenerator): - - def generate_sample(self, line): - def local_iter(): - int_words = [int(x) for x in line.split()] - yield ("words", [int_words]) - return local_iter - + >>> import paddle.distributed.fleet.data_generator as dg + >>> class MyData(dg.DataGenerator): + ... def generate_sample(self, line): + ... def local_iter(): + ... int_words = [int(x) for x in line.split()] + ... yield ("words", [int_words]) + ... return local_iter ''' raise NotImplementedError( "Please rewrite this function to return a list or tuple: " @@ -210,21 +204,20 @@ class DataGenerator: .. 
code-block:: python - import paddle.distributed.fleet.data_generator as dg - class MyData(dg.DataGenerator): - - def generate_sample(self, line): - def local_iter(): - int_words = [int(x) for x in line.split()] - yield ("words", int_words) - return local_iter - - def generate_batch(self, samples): - def local_iter(): - for s in samples: - yield ("words", s[1].extend([s[1][0]])) - mydata = MyData() - mydata.set_batch(128) + >>> import paddle.distributed.fleet.data_generator as dg + >>> class MyData(dg.DataGenerator): + ... def generate_sample(self, line): + ... def local_iter(): + ... int_words = [int(x) for x in line.split()] + ... yield ("words", int_words) + ... return local_iter + ... + ... def generate_batch(self, samples): + ... def local_iter(): + ... for s in samples: + ... yield ("words", s[1] + [s[1][0]]) + >>> mydata = MyData() + >>> mydata.set_batch(128) ''' def local_iter():