未验证 提交 13a250a2 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] add 'to_static' in engine api (#44202)

* add 'to_static' in engine api

* fix cmakelist
上级 c57e12be
......@@ -125,6 +125,9 @@ class DistributedContext:
# A flag indicates whether the used parallelism is data parallel
self._data_parallel = False
# flag whether using `to_static`
self._dygraph_mode = True
@property
def serial_main_program(self):
return self._serial_main_program
......
......@@ -21,6 +21,7 @@ import paddle.utils as utils
from paddle import fluid, static
from paddle.io import Dataset
from paddle.jit import to_static
from paddle.metric import Metric
from paddle.static import InputSpec
from paddle.fluid import core
......@@ -28,7 +29,7 @@ from paddle.fluid import program_guard
from paddle.fluid.layers.utils import flatten
from paddle.fluid.executor import global_scope, _to_name_str
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
......@@ -82,6 +83,7 @@ class Engine:
self._feed_vars = {}
self._fetch_vars = {}
self._planners = {}
self._dygraph_mode = False
def prepare(self,
optimizer=None,
......@@ -131,6 +133,88 @@ class Engine:
def _build(self, mode):
if _non_static_mode() or self._dygraph_mode:
self._dygraph_mode = True
self._logger.info("Building model with 'to_static' method.")
# build forward main program
self.static_model = to_static(self.model,
input_spec=self.inputs_spec)
inputs = self.static_model.forward.inputs
outputs = self.static_model.forward.outputs
forward_main_prog = self.static_model.forward.main_program
forward_startup_prog = self.static_model.forward.concrete_program.startup_program
self.concrete_program = self.static_model.forward.concrete_program
# build loss main program
outputs_spec = []
outputs_name = []
for out in outputs:
outputs_spec.append(InputSpec(out.shape, out.dtype, out.name))
outputs_name.append(out.name)
if isinstance(self._loss, paddle.nn.Layer):
self.static_loss = to_static(self._loss.forward,
input_spec=outputs_spec +
self.labels_spec)
loss_main_prog = self.static_loss.main_program
elif callable(self._loss):
self.static_loss = to_static(self._loss,
input_spec=outputs_spec +
self.labels_spec)
loss_main_prog = self.static_loss.main_program
# build startup program
for param in self.concrete_program.parameters:
Parameter(name=param.name,
desc=param,
type=param.type,
shape=param.shape,
dtype=param.dtype,
stop_gradient=param.stop_gradient,
block=forward_startup_prog.global_block())
paddle.enable_static()
# NOTE: pure program will loss dist_attr
# feeded_var_names = [var.name for var in inputs]
# main_prog_0 = main_prog_0._prune_with_input(
# feeded_var_names=feeded_var_names, targets=outputs)
labels = []
losses = []
metrics = []
# concat forward and loss prog
if mode != 'predict' and self._loss:
forward_block = forward_main_prog.global_block()
loss_block = loss_main_prog.global_block()
for idx, op in enumerate(loss_block.ops):
op_desc = forward_block.desc.append_op()
op_desc.copy_from(op.desc)
for in_name in op.input_arg_names:
if in_name in outputs_name:
continue
in_var = forward_block._clone_variable(
loss_block.vars[in_name], force_persistable=False)
if loss_block.vars[in_name].is_data:
labels.append(in_var)
for out_name in op.output_arg_names:
out_var = forward_block._clone_variable(
loss_block.vars[out_name], force_persistable=False)
if idx == len(loss_block.ops) - 1:
losses.append(out_var)
forward_block._sync_with_cpp()
serial_main_prog = forward_main_prog
serial_startup_prog = forward_startup_prog
# update metrics op in program
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
if mode != "predict":
for metric in self._metrics:
metrics.extend(
to_list(metric.compute(*(outputs + labels))))
else:
# build program in static mode
serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None:
return
......@@ -151,7 +235,8 @@ class Engine:
if mode != "predict":
for metric in self._metrics:
metrics.extend(to_list(metric.compute(*(outputs + labels))))
metrics.extend(
to_list(metric.compute(*(outputs + labels))))
default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation:
......@@ -172,6 +257,7 @@ class Engine:
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode
def _plan(self, mode):
if self._planned_mode is None:
......@@ -236,6 +322,35 @@ class Engine:
self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace):
self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
if self._dygraph_mode:
paddle.disable_static()
main_program = self._dist_main_progs[mode][self._cur_rank]
for param in self.concrete_program.parameters:
# create var in scope and share parameters to scope
if param.name not in main_program.global_block().vars:
continue
# get param_var's dist_attr
var = main_program.global_block().vars[param.name]
var_dist_attr = self._dist_contexts[
mode].get_tensor_dist_attr_for_program(var)
dist_attr = {
"dims_mapping": var_dist_attr.dims_mapping,
"process_shape": var_dist_attr.process_mesh.topology,
"process_group": var_dist_attr.process_mesh.processes
}
# slice param_value with dist_attr
# share sliced_param_value with param_tensor in global_scope
from .converter import Converter
param_tensor = global_scope().var(param.name).get_tensor()
sliced_param = Converter.slice_with_dist_attr(
param.numpy(), dist_attr)
shared_tensor = paddle.to_tensor(sliced_param,
place=self._place)
param_tensor._share_data_with(
shared_tensor.value().get_tensor())
paddle.enable_static()
if self._executor is None:
self._executor = paddle.static.Executor(self._place)
uninitialized = []
......
......@@ -15,8 +15,10 @@
import copy
from collections import defaultdict
import paddle
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.passes import new_pass
from .reshard import Resharder
......@@ -110,9 +112,14 @@ class Parallelizer:
def _generate_optimizer(self, main_program, startup_program, optimizer,
params_grads):
if self._dist_context._dygraph_mode:
paddle.disable_static()
optimizer = copy.deepcopy(optimizer)
paddle.enable_static()
else:
optimizer = copy.deepcopy(optimizer)
with program_guard(main_program, startup_program):
optimizer_ops = copy.deepcopy(optimizer).apply_gradients(
params_grads)
optimizer_ops = optimizer.apply_gradients(params_grads)
self._completer.complete_update_annotation(main_program)
return optimizer_ops
......
......@@ -53,4 +53,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS})
py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS})
py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
import paddle.distributed.fleet as fleet
from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine
batch_size = 4
batch_num = 30
hidden_size = 1024
class_num = 10
class MyDataset(Dataset):
def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=hidden_size).astype("float32")
label = np.random.randint(0, class_num - 1, dtype="int64")
return input, label
def __len__(self):
return self.num_samples
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
self.linear0 = nn.Linear(d_model,
dim_feedforward,
weight_attr,
bias_attr=None)
self.linear1 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=None)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=None)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
return out
class TestToStatic(unittest.TestCase):
def test_to_static(self):
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.SGD(learning_rate=0.00001,
parameters=mlp.parameters())
dataset = MyDataset(batch_num * batch_size)
inputs = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels = InputSpec([batch_size], 'int64', 'label')
engine = Engine(model=mlp,
inputs_spec=inputs,
labels_spec=labels,
strategy=None)
assert _non_static_mode() == True
engine.prepare(optimizer=optimizer,
loss=loss,
metrics=paddle.metric.Accuracy())
assert _non_static_mode() == False
engine.fit(dataset, batch_size=batch_size)
engine.evaluate(dataset, batch_size=batch_size)
engine.predict(dataset, batch_size=batch_size)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册