Unverified commit 686fa07a, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Improve the fine-grained APIs (#46552)

* [Auto Parallel] Support different dataloaders

* [Auto Parallel] Add num_shards config for dataset

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py

* [Auto Parallel] Add the prepare API and replace __call__ with run

* [Auto Parallel] Improve the private implementations of Engine

* [Auto Parallel] Set capacity of dataloader for opt tuning

* [Auto Parallel] [WIP] Change the fine-grained API

* [Auto Parallel] Improve APIs to support different user cases

* [Auto Parallel] Add removed config

* [Auto Parallel] Add imports

* [Auto Parallel] Fix bugs for to_static

* [Auto Parallel] Remove unnecessary imports
Parent 01baa0b6
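Overview: the diffs below move the Engine workflow to explicit dataloader/prepare/run calls. A minimal sketch of that flow, pieced together from the test changes in this commit (not part of the commit itself; the toy model and dataset are placeholders, the `paddle.distributed.fleet.auto` import path is assumed, and a real job still needs the usual distributed launch):

import numpy as np
import paddle
from paddle.io import Dataset
from paddle.distributed.fleet import auto   # assumed public entry point for the auto-parallel API

paddle.enable_static()


class RandomDataset(Dataset):
    # Tiny stand-in dataset; shapes are arbitrary and only for illustration.
    def __init__(self, num_samples=8):
        self.num_samples = num_samples

    def __getitem__(self, index):
        return np.random.random([16]).astype("float32"), np.array([0], dtype="int64")

    def __len__(self):
        return self.num_samples


model = paddle.nn.Linear(16, 2)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)

strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(model, loss, optimizer, strategy=strategy)

# Fine-grained flow introduced by this commit: build the dataloader,
# prepare the programs, then run step by step.
train_dataloader = engine.dataloader(RandomDataset(), batch_size=2, mode="train")
engine.prepare(mode="train")
for data in train_dataloader:
    outs = engine.run(data, mode="train")   # run() replaces the removed __call__ usage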
@@ -116,3 +116,10 @@ set_field_default_config(TUNING, "profile_start_step", 1)
 set_field_default_config(TUNING, "profile_end_step", 1)
 set_field_default_config(TUNING, "run_after_tuning", True)
 set_field_default_config(TUNING, "verbose", True)
+
+#########################################
+# dataset configuration
+#########################################
+DATASET = "dataset"
+set_field_default_config(DATASET, "enable", False)
+set_field_default_config(DATASET, "num_shards", 1)
@@ -17,38 +17,11 @@ import numpy as np
 import paddle
 from paddle.io import BatchSampler, IterableDataset
-from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler
+from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
 from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn


-class DistributedDataLoader(metaclass=abc.ABCMeta):
+class DistributedDataLoaderBase(metaclass=abc.ABCMeta):

-    def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False):
-        if isinstance(dataset, IterableDataset):
-            self.dataset_kind = _DatasetKind.ITER
-        else:
-            self.dataset_kind = _DatasetKind.MAP
-        self.dataset = dataset
-        self.epochs = epochs
-        self.drop_last = drop_last
-        if batch_size is None:
-            self.batch_size = None
-            self.batch_sampler = None
-        else:
-            self.batch_size = batch_size
-            if isinstance(dataset, IterableDataset):
-                self.batch_sampler = _InfiniteIterableSampler(
-                    dataset, batch_size)
-            else:
-                self.batch_sampler = BatchSampler(dataset,
-                                                  batch_size=batch_size,
-                                                  shuffle=False,
-                                                  drop_last=drop_last)
-        self.auto_collate_batch = self.batch_sampler is not None
-        self.sampler_iter = iter(self.index_sampler)
-
     @abc.abstractmethod
     def __iter__(self):
@@ -58,48 +31,70 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
     def __next__(self):
         raise NotImplementedError

-    @property
-    def index_sampler(self):
-        if self.auto_collate_batch:
-            return self.batch_sampler
-        else:
-            if self.dataset_kind == _DatasetKind.MAP:
-                return list(range(len(self.dataset)))
-            else:
-                return _InfiniteIterableSampler(self.dataset, 1)


-class NonIterableGeneratorLoader(DistributedDataLoader):
+class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):

     def __init__(self,
                  dataset,
-                 feed_list,
-                 places,
+                 feed_list=None,
+                 capacity=None,
+                 use_double_buffer=True,
+                 iterable=True,
+                 return_list=False,
+                 use_multiprocess=False,
+                 drop_last=True,
+                 places=None,
                  batch_size=1,
                  epochs=1,
                  steps_per_epoch=None,
                  collate_fn=None,
+                 split_data=True,
                  data_parallel_world_size=[],
-                 data_parallel_rank=[],
-                 drop_last=False,
-                 split_data=True):
+                 data_parallel_rank=[]):
+        self.dataset = dataset
         self.feed_list = feed_list
+        self.capacity = capacity
+        self.use_double_buffer = use_double_buffer
+        self.iterable = iterable
+        self.return_list = return_list
+        self.use_multiprocess = use_multiprocess
+        self.drop_last = drop_last
         self.places = places
+        self.batch_size = batch_size
+        self.epochs = epochs
         self.steps_per_epoch = steps_per_epoch
+        self.collate_fn = collate_fn
+        self.split_data = split_data
         assert len(data_parallel_world_size) == len(feed_list)
         assert len(data_parallel_rank) == len(feed_list)
         self.dp_world_sizes = data_parallel_world_size
         self.dp_ranks = data_parallel_rank
-        self.split_data = split_data

-        super(NonIterableGeneratorLoader,
-              self).__init__(dataset, batch_size, epochs, drop_last)
+        if isinstance(dataset, IterableDataset):
+            self.dataset_kind = _DatasetKind.ITER
+        else:
+            self.dataset_kind = _DatasetKind.MAP
+
+        if self.batch_size is None:
+            self.batch_sampler = None
+        else:
+            if isinstance(dataset, IterableDataset):
+                self.batch_sampler = _InfiniteIterableSampler(
+                    dataset, batch_size)
+            else:
+                self.batch_sampler = BatchSampler(dataset,
+                                                  batch_size=batch_size,
+                                                  shuffle=False,
+                                                  drop_last=drop_last)
+
+        self.auto_collate_batch = self.batch_sampler is not None
+        self.sampler_iter = iter(self.index_sampler)

         if self.auto_collate_batch:
             self.collate_fn = collate_fn or default_collate_fn
         else:
             self.collate_fn = collate_fn or default_convert_fn
         self.dataset_fetcher = _DatasetKind.create_fetcher(
             self.dataset_kind, self.dataset, self.auto_collate_batch,
             self.collate_fn, self.drop_last)
@@ -115,8 +110,10 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
     def __next__(self):
         if not self._steps:
             self._cur_step += 1
+            return None
         elif self._cur_step < self._steps:
             self._cur_step += 1
+            return None
         else:
             self._inner_dataloader.reset()
             self.sampler_iter = iter(self.index_sampler)
@@ -138,6 +135,16 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
            )
        return steps_per_epoch

+    @property
+    def index_sampler(self):
+        if self.auto_collate_batch:
+            return self.batch_sampler
+        else:
+            if self.dataset_kind == _DatasetKind.MAP:
+                return list(range(len(self.dataset)))
+            else:
+                return _InfiniteIterableSampler(self.dataset, 1)
+
     def _create_inner_dataloader(self):

         def data_generator():
@@ -170,7 +177,83 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
                 yield partial_data

         dataloader = paddle.fluid.io.DataLoader.from_generator(
-            feed_list=self.feed_list, capacity=70, iterable=False)
+            feed_list=self.feed_list,
+            capacity=self.capacity,
+            use_double_buffer=self.use_double_buffer,
+            # iterable=self.iterable,
+            iterable=False,
+            return_list=self.return_list,
+            use_multiprocess=self.use_multiprocess,
+            drop_last=self.drop_last)
         dataloader.set_batch_generator(data_generator, self.places)

         return dataloader
+
+
+class DistributedDataLoader(DistributedDataLoaderBase):
+
+    def __init__(self,
+                 dataset,
+                 feed_list=None,
+                 places=None,
+                 return_list=True,
+                 batch_size=1,
+                 shuffle=False,
+                 drop_last=False,
+                 collate_fn=None,
+                 num_workers=0,
+                 use_buffer_reader=True,
+                 use_shared_memory=True,
+                 timeout=0,
+                 worker_init_fn=None,
+                 epochs=1,
+                 steps_per_epoch=None,
+                 split_data=True,
+                 data_parallel_world_size=[],
+                 data_parallel_rank=[]):
+        self.dataset = dataset
+        self.feed_list = feed_list
+        self.return_list = return_list
+        self.places = places
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.drop_last = drop_last
+        self.collate_fn = collate_fn
+        self.num_workers = num_workers
+        self.use_buffer_reader = use_buffer_reader
+        self.use_shared_memory = use_shared_memory
+        self.timeout = timeout
+        self.worker_init_fn = worker_init_fn
+        self.epochs = epochs
+        self.steps_per_epoch = steps_per_epoch
+        self.dp_world_sizes = data_parallel_world_size
+        self.dp_ranks = data_parallel_rank
+        self.split_data = split_data
+        # TODO: rank info
+        self.batch_sampler = DistributedBatchSampler(
+            self.dataset, self.batch_size, self.dp_world_sizes[0],
+            self.dp_ranks[0], self.shuffle, self.drop_last)
+        self._inner_dataloader = self._create_inner_dataloader()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return next(self.data)
+
+    def _create_inner_dataloader(self):
+        dataloader = paddle.fluid.io.DataLoader(
+            self.dataset,
+            feed_list=self.feed_list,
+            places=self.places,
+            return_list=self.return_list,
+            batch_sampler=self.batch_sampler,
+            collate_fn=self.collate_fn,
+            num_workers=self.num_workers,
+            use_buffer_reader=self.use_buffer_reader,
+            use_shared_memory=self.use_shared_memory,
+            timeout=self.timeout,
+            worker_init_fn=self.worker_init_fn)
+        self.data = (x for x in dataloader)
+
+        return dataloader
@@ -210,11 +210,11 @@ def get_collection(name):
     return _g_collections[name]


-def add_to_collection(collection_name, value, value_name=None):
+def add_to_collection(collection_name, value, name=None):
     if collection_name not in _g_collections:
         _g_collections[collection_name] = []
-    if value_name is not None:
-        _g_collections[collection_name].append((value_name, value))
+    if name is not None:
+        _g_collections[collection_name].append((name, value))
     else:
         _g_collections[collection_name].append((None, value))
...
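For callers, the only visible change above is the keyword rename (`value_name` becomes `name`). A small self-contained sketch of the new call shape, re-implementing the helper locally so it runs without importing the private module (`_g_collections` here is a stand-in for the module-level registry):

# Local mirror of the renamed helper shown in the diff above.
_g_collections = {}


def add_to_collection(collection_name, value, name=None):
    if collection_name not in _g_collections:
        _g_collections[collection_name] = []
    if name is not None:
        _g_collections[collection_name].append((name, value))
    else:
        _g_collections[collection_name].append((None, value))


add_to_collection("fetches", 3.14, name="my_out")  # keyword is now `name`, not `value_name`
add_to_collection("fetches", 2.71)                 # unnamed values are stored under a None key
print(_g_collections["fetches"])                   # [('my_out', 3.14), (None, 2.71)]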
@@ -23,7 +23,7 @@ from .dist_attribute import OperatorDistributedAttribute
 from .utils import is_backward_op, is_forward_op, is_loss_op, is_optimize_op
 from .operators.common import BACKWARD_ONLY_DIST_OPS

-__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
+__varname_not_in_block__ = ["lod_tensor_blocking_queue"]
 __not_shape_var_type__ = [
     core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
 ]
@@ -238,7 +238,9 @@ class Partitioner(object):
                         target_block, serial_input_varname,
                         new_varname)
                 else:
-                    assert serial_input_varname in __varname_not_in_block__
+                    for varname_not_in_block in __varname_not_in_block__:
+                        assert varname_not_in_block in serial_input_varname, \
+                            "{} is not found".format(serial_input_varname)

                 self._serial2dist_varname_mapping[
                     serial_input_varname] = new_varname
...
@@ -45,7 +45,8 @@ def get_var_with_recursion(var_name, block, program):
         parent_block = program.blocks[block.parent_idx]
         if var_name in parent_block.vars:
             var = parent_block.vars[var_name]
-    assert var is not None
+    assert var is not None, \
+        "{} is not found".format(var.name)

     return var
@@ -1838,8 +1839,8 @@ class Resharder:
             idx_offset = 0
             for var_name in input_var_names:
-                # skip lod_tensor_blocking_queue_0
-                if var_name == "lod_tensor_blocking_queue_0":
+                # skip lod_tensor_blocking_queue_? name
+                if "lod_tensor_blocking_queue" in var_name:
                     continue
                 var = get_var_with_recursion(var_name, block,
                                              self.auto_parallel_main_prog)
...
@@ -114,6 +114,13 @@ class TuningConfig(BaseConfig):
         super(TuningConfig, self).__init__(category, config_dict)


+class DatasetConfig(BaseConfig):
+
+    def __init__(self, config_dict=None):
+        category = constants.DATASET
+        super(DatasetConfig, self).__init__(category, config_dict)
+
+
 class Strategy(BaseConfig):
     """
     The `Strategy` object is used to configure the paralleization and optimization beheviors.
@@ -178,3 +185,6 @@ class Strategy(BaseConfig):
         config_dict = self._config_dict.get(constants.TUNING, None)
         self.tuning = TuningConfig(config_dict)
+
+        config_dict = self._config_dict.get(constants.DATASET, None)
+        self.dataset = DatasetConfig(config_dict)
@@ -23,7 +23,7 @@ import paddle
 from paddle.fluid.framework import Program, _current_expected_place
 from paddle.fluid.framework import Operator
 from paddle.distributed.auto_parallel.process_group import get_all_process_groups, new_process_group
-from paddle.distributed.auto_parallel.dist_loader import NonIterableGeneratorLoader
+from paddle.distributed.auto_parallel.dist_loader import DistributedDataLoaderFromGenerator
 from paddle.distributed.collective import _get_global_env

 paddle.enable_static()
@@ -132,13 +132,14 @@ def create_dataloader(main_program,
     # insert read op at the end of program
     places = paddle.static.cuda_places()
     with paddle.static.program_guard(main_program, startup_program):
-        dataloader = NonIterableGeneratorLoader(
-            dataset,
-            feed_list,
-            places,
-            dataset.batch_size,
-            epochs,
-            steps_per_epoch,
+        dataloader = DistributedDataLoaderFromGenerator(
+            dataset=dataset,
+            feed_list=feed_list,
+            capacity=70,
+            places=places,
+            batch_size=dataset.batch_size,
+            epochs=epochs,
+            steps_per_epoch=steps_per_epoch,
             data_parallel_world_size=dataset.dp_world_size,
             data_parallel_rank=dataset.dp_rank)
...
@@ -16,6 +16,8 @@ import tempfile
 import os
 import numpy as np

 import paddle
+import paddle.static as static
+import paddle.utils as utils
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.io import Dataset
@@ -26,7 +28,8 @@ paddle.enable_static()
 global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
 PP_MESH_0 = auto.ProcessMesh([0])
 PP_MESH_1 = auto.ProcessMesh([1])
-batch_size = 1
+epoch_num = 1
+batch_size = 2
 batch_num = 10
 hidden_size = 1024
 sequence_len = 512
@@ -36,6 +39,8 @@ class_num = 10
 paddle.seed(44)

 is_fetch = True
+is_feed = True
+my_feed_vars = []


 class MyDataset(Dataset):
@@ -53,6 +58,23 @@ class MyDataset(Dataset):
         return self.num_samples


+def get_random_inputs_and_labels(image_shape, label_shape):
+    input = np.random.random(size=image_shape).astype('float32')
+    label = np.random.random(size=label_shape).astype('int64')
+    return input, label
+
+
+def batch_generator_creator():
+
+    def __reader__():
+        for _ in range(batch_num):
+            batch_input, batch_label = get_random_inputs_and_labels(
+                [batch_size, image_size], [batch_size, 1])
+            yield batch_input, batch_label
+
+    return __reader__
+
+
 class MLPLayer(nn.Layer):

     def __init__(self,
@@ -82,16 +104,20 @@ class MLPLayer(nn.Layer):
     def forward(self, input):
         out = auto.shard_op(self.norm, PP_MESH_0)(input)
         out = self.linear0(out)
+        if is_feed:
+            my_feed_vars.append((out, out.shape))
         out = F.gelu(out, approximate=True)
         out = auto.shard_op(self.linear1, PP_MESH_1)(out)
         out = self.dropout(out)
         out = self.linear2(out)
+        if is_feed:
+            my_feed_vars.append((out, out.shape))
         if is_fetch:
             auto.fetch(out, "my_out", logging=True)
         return out


-def train(fetch):
+def train_high_level(fetch):
     global is_fetch
     is_fetch = fetch
     mlp = MLPLayer(hidden_size=hidden_size,
@@ -135,7 +161,7 @@ def train(fetch):
     temp_dir.cleanup()


-def train_callable():
+def train_low_level():
     mlp = MLPLayer(hidden_size=hidden_size,
                    intermediate_size=4 * hidden_size,
                    dropout_ratio=0.1,
@@ -151,31 +177,38 @@ def train_callable():
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"

-    engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
+    engine = auto.Engine(mlp, loss, optimizer, metrics=None, strategy=strategy)

+    feed_dict = {}
+    for feed_var, shape in my_feed_vars:
+        feed_dict[feed_var.name] = np.zeros(shape, dtype="float32")
+
+    # Build normal normal dataloader
     # train
     train_dataset = MyDataset(batch_num * batch_size)
     train_dataloader = engine.dataloader(train_dataset,
                                          batch_size=batch_size,
                                          mode="train")
-    for _ in train_dataloader:
-        outs = engine(mode="train")
+    engine.prepare(mode="train")
+    for data in train_dataloader:
+        outs = engine.run(data, feed=feed_dict, mode="train")

     # eval
     eval_dataset2 = MyDataset(batch_size)
     eval_dataloader = engine.dataloader(eval_dataset2,
                                         batch_size=batch_size,
                                         mode="eval")
-    for _ in eval_dataloader:
-        outs = engine(mode="eval")
+    engine.prepare(mode="eval")
+    for data in eval_dataloader:
+        outs = engine.run(data, feed=feed_dict, mode="eval")

     # predict
+    engine.to_mode("predict")
     test_dataset = MyDataset(batch_size)
-    predict_dataloader = engine.dataloader(test_dataset,
-                                           batch_size=batch_size,
-                                           mode="predict")
-    for _ in predict_dataloader:
-        outs = engine(mode="predict")
+    predict_dataloader = engine.dataloader(test_dataset, batch_size=batch_size)
+    engine.prepare()
+    for data in predict_dataloader:
+        outs = engine.run(data, feed=feed_dict)

     # save
     temp_dir = tempfile.TemporaryDirectory()
@@ -184,8 +217,144 @@ def train_callable():
     engine.load(model_filename)
     temp_dir.cleanup()

+    # Build dataloader from generator
+    # train
+    train_dataset = MyDataset(batch_num * batch_size)
+    train_dataloader = engine.dataloader_from_generator(train_dataset,
+                                                        batch_size=batch_size,
+                                                        mode="train")
+    engine.prepare(mode="train")
+    for data in train_dataloader:
+        outs = engine.run(data, feed=feed_dict, mode="train")
+
+    # eval
+    engine.to_mode("eval")
+    eval_dataset2 = MyDataset(batch_size)
+    eval_dataloader = engine.dataloader_from_generator(eval_dataset2,
+                                                       batch_size=batch_size)
+    engine.prepare()
+    for data in eval_dataloader:
+        outs = engine.run(data, feed=feed_dict)
+
+    # predict
+    test_dataset = MyDataset(batch_size)
+    predict_dataloader = engine.dataloader_from_generator(test_dataset,
+                                                          batch_size=batch_size,
+                                                          mode="predict")
+    engine.prepare(mode="predict")
+    for data in predict_dataloader:
+        outs = engine.run(data, feed=feed_dict, mode="predict")
+
+    # save
+    temp_dir = tempfile.TemporaryDirectory()
+    model_filename = os.path.join(temp_dir.name, 'mlp')
+    engine.save(model_filename, training=True)
+    engine.load(model_filename)
+    temp_dir.cleanup()
+
+
+def train_builtin_data_vars():
+    mlp = MLPLayer(hidden_size=hidden_size,
+                   intermediate_size=4 * hidden_size,
+                   dropout_ratio=0.1,
+                   initializer_range=0.02)
+    loss = paddle.nn.CrossEntropyLoss()
+    optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
+                                      beta1=0.9,
+                                      beta2=0.999,
+                                      epsilon=1e-08,
+                                      grad_clip=None)
+    metric = paddle.metric.Accuracy()
+
+    strategy = auto.Strategy()
+    strategy.auto_mode = "semi"
+
+    engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
+
+    # train
+    engine.to_mode("train")
+
+    input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input')
+    label_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
+    engine.prepare(inputs_spec=[input_spec], labels_spec=[label_spec])
+
+    with static.program_guard(engine.main_program, engine.startup_program):
+        feed_list = engine.inputs + engine.labels
+        print(feed_list)
+        loader = paddle.io.DataLoader.from_generator(feed_list=feed_list,
+                                                     capacity=4 * batch_size,
+                                                     iterable=False)
+
+        places = static.cuda_places()
+        loader.set_batch_generator(batch_generator_creator(), places=places)
+
+    for _ in range(epoch_num):
+        loader.start()  # call DataLoader.start() before each epoch starts
+        try:
+            while True:
+                engine.run()
+        except paddle.fluid.core.EOFException:
+            loader.reset(
+            )  # call DataLoader.reset() after catching EOFException
+
+
+def train_non_builtin_data_vars():
+    main_program = static.Program()
+    startup_program = static.Program()
+    with static.program_guard(main_program,
+                              startup_program), utils.unique_name.guard():
+        input = static.data(name="input",
+                            shape=[batch_size, image_size],
+                            dtype='float32')
+        label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
+
+        loader = paddle.io.DataLoader.from_generator(feed_list=[input, label],
+                                                     capacity=4 * batch_size,
+                                                     iterable=False)
+        places = static.cuda_places()
+        loader.set_batch_generator(batch_generator_creator(), places=places)
+
+        mlp = MLPLayer(hidden_size=hidden_size,
+                       intermediate_size=4 * hidden_size,
+                       dropout_ratio=0.1,
+                       initializer_range=0.02)
+        loss = paddle.nn.CrossEntropyLoss()
+        optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
+                                          beta1=0.9,
+                                          beta2=0.999,
+                                          epsilon=1e-08,
+                                          grad_clip=None)
+        metric = paddle.metric.Accuracy()
+        predict = mlp(input)
+        loss_var = loss(predict, label)
+
+        strategy = auto.Strategy()
+        strategy.auto_mode = "semi"
+
+        engine = auto.Engine(loss=loss_var,
+                             optimizer=optimizer,
+                             metrics=metric,
+                             strategy=strategy)
+
+        # train
+        engine.to_mode("train")
+        engine.prepare(inputs=[input],
+                       labels=[label],
+                       main_program=main_program,
+                       startup_program=startup_program)
+        for _ in range(epoch_num):
+            loader.start()  # call DataLoader.start() before each epoch starts
+            try:
+                while True:
+                    engine.run()
+            except paddle.fluid.core.EOFException:
+                loader.reset(
+                )  # call DataLoader.reset() after catching EOFException
+
+
 if __name__ == "__main__":
-    train(fetch=True)
-    train(fetch=False)
-    train_callable()
+    train_high_level(fetch=True)
+    train_high_level(fetch=False)
+    train_low_level()
+    train_builtin_data_vars()
+    train_non_builtin_data_vars()