未验证 提交 c182e5dd 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Make Engine class callable (#46416)

* [Auto Parallel] Improve the user-defined fetches and logging

* [Auto Parallel] Make Engine class callable

* [Auto Parallel] Update the data loading of tuner
上级 55accdfc
......@@ -34,7 +34,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
self.dataset = dataset
self.epochs = epochs
self.drop_lost = drop_last
self.drop_last = drop_last
if batch_size is None:
self.batch_size = None
......@@ -105,7 +105,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
self.collate_fn = collate_fn or default_convert_fn
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_lost)
self.collate_fn, self.drop_last)
self._steps = self._infer_steps()
self._inner_dataloader = self._create_inner_dataloader()
......@@ -153,7 +153,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset,
self.auto_collate_batch, self.collate_fn,
self.drop_lost)
self.drop_last)
break
partial_data = []
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddle.fluid import core
from .process_mesh import ProcessMesh
......@@ -196,15 +198,42 @@ def recompute(op):
return RecomputeOperator(op)
_g_fetched_tensors = {}
# _g_fetched_tensors = {}
# def fetch(tensor, name=None):
# if name is None:
# _g_fetched_tensors[tensor.name] = tensor
# else:
# _g_fetched_tensors[name] = tensor
def fetch(tensor, name=None):
if name is None:
_g_fetched_tensors[tensor.name] = tensor
# def _get_fetches():
# return _g_fetched_tensors
_g_collections = {}
class CollectionNames(object):
    """Well-known keys for the module-level ``_g_collections`` registry."""
    # Tensors fed into the program (inputs).
    FEEDS = "feeds"
    # Tensors the user asked to fetch back from execution.
    FETCHES = "fetches"
def get_collection(name):
    """Return the global collection registered under ``name``.

    Creates and registers an empty list on first access, so callers always
    receive the live, shared list (mutations are visible globally).

    Args:
        name (str): collection key, typically one of ``CollectionNames``.

    Returns:
        list: the (possibly empty) collection stored in ``_g_collections``.
    """
    # setdefault performs the get-or-insert in a single lookup, replacing
    # the original get()/assign dance with the idiomatic equivalent.
    return _g_collections.setdefault(name, [])
def add_to_collection(collection_name, value, value_name=None):
    """Append ``value`` to the global collection named ``collection_name``.

    Args:
        collection_name (str): collection key, typically one of
            ``CollectionNames``.
        value: object to record (e.g. a tensor).
        value_name (str, optional): label stored alongside ``value``;
            ``None`` when the entry is unnamed.

    Entries are stored as ``(value_name_or_None, value)`` pairs.
    """
    # NOTE: the rendered diff spliced two removed lines of the legacy
    # fetch() implementation into this body; this is the clean new version.
    if collection_name not in _g_collections:
        _g_collections[collection_name] = []
    if value_name is not None:
        _g_collections[collection_name].append((value_name, value))
    else:
        _g_collections[collection_name].append((None, value))
def _get_fetches():
    # NOTE(review): legacy accessor for the old module-level registry
    # `_g_fetched_tensors`, which the surrounding change removes in favor
    # of get_collection(CollectionNames.FETCHES) — presumably this removed
    # line is only diff residue; verify it is gone in the merged file.
    return _g_fetched_tensors
def fetch(tensor, name=None):
    """Register ``tensor`` in the global FETCHES collection, optionally
    labeled with ``name`` (``None`` stores it unnamed)."""
    add_to_collection(CollectionNames.FETCHES, tensor, name)
......@@ -97,7 +97,7 @@ class MLPLayer(nn.Layer):
out = self.dropout(out)
out = self.linear2(out)
if is_fetch:
auto.fetch(out, "out")
auto.fetch(out, "my_out")
return out
......@@ -145,6 +145,57 @@ def train(fetch):
temp_dir.cleanup()
def train_callable():
    """Exercise the callable Engine API end to end.

    Builds an MLP with an auto-parallel Engine, drives the train/eval/predict
    phases by invoking the engine directly per dataloader step, then
    round-trips the parameters through save()/load().
    """
    model = MLPLayer(hidden_size=hidden_size,
                     intermediate_size=4 * hidden_size,
                     dropout_ratio=0.1,
                     initializer_range=0.02)
    criterion = paddle.nn.CrossEntropyLoss()
    opt = paddle.optimizer.Adam(learning_rate=0.00001,
                                beta1=0.9,
                                beta2=0.999,
                                epsilon=1e-08,
                                grad_clip=None)
    accuracy = paddle.metric.Accuracy()
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    engine = auto.Engine(model, criterion, opt, accuracy, strategy=strategy)

    # Each phase: build a dataset of the given size, wrap it in an engine
    # dataloader for that mode, and call the engine once per batch.
    phases = (("train", batch_num * batch_size),
              ("eval", batch_size),
              ("predict", batch_size))
    for mode, num_samples in phases:
        dataset = MyDataset(num_samples)
        dataloader = engine.dataloader(dataset,
                                       batch_size=batch_size,
                                       mode=mode)
        for _ in dataloader:
            outs = engine(mode=mode)

    # Save the trained state and reload it to verify the round trip.
    temp_dir = tempfile.TemporaryDirectory()
    model_filename = os.path.join(temp_dir.name, 'mlp')
    engine.save(model_filename, training=True)
    engine.load(model_filename)
    temp_dir.cleanup()
if __name__ == "__main__":
    # Cover both fetch configurations of the classic fit path, then the
    # new callable-Engine path added by this change.
    train(fetch=True)
    train(fetch=False)
    train_callable()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册