diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index 72a377603edc759c91bf483ec58bbab556704e97..b9ff116d2443bc334f774e69effcf2597315fd5c 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -37,6 +37,7 @@ from paddle.distributed import fleet
 from paddle.distributed.utils import get_logger
 from paddle.distributed.passes import new_pass, PassContext
+from .hepler import ProgramHelper
 from ..collective import _get_global_env
 from .cluster import Cluster, get_default_cluster
 from .planner_v2 import Planner
@@ -141,87 +142,28 @@ class Engine:
         self._mode_init_states[mode] = True
 
     def _build(self, mode):
         if _non_static_mode() or self._dygraph_mode:
+            paddle.disable_static()
             self._dygraph_mode = True
             self._logger.info("Building model with 'to_static' method.")
 
+            program_helper = ProgramHelper(self.model, self._loss,
+                                           self._metrics, self.inputs_spec,
+                                           self.labels_spec)
             # build forward main program
-            self.static_model = to_static(self.model,
-                                          input_spec=self.inputs_spec)
-            inputs = self.static_model.forward.inputs
-            outputs = self.static_model.forward.outputs
-            forward_main_prog = self.static_model.forward.main_program
-            forward_startup_prog = self.static_model.forward.concrete_program.startup_program
-            self.concrete_program = self.static_model.forward.concrete_program
-
-            # build loss main program
-            outputs_spec = []
-            outputs_name = []
-            for out in outputs:
-                outputs_spec.append(InputSpec(out.shape, out.dtype, out.name))
-                outputs_name.append(out.name)
-            if isinstance(self._loss, paddle.nn.Layer):
-                self.static_loss = to_static(self._loss.forward,
-                                             input_spec=outputs_spec +
-                                             self.labels_spec)
-                loss_main_prog = self.static_loss.main_program
-            elif callable(self._loss):
-                self.static_loss = to_static(self._loss,
-                                             input_spec=outputs_spec +
-                                             self.labels_spec)
-                loss_main_prog = self.static_loss.main_program
-
-            # build startup program
-            for param in self.concrete_program.parameters:
-                Parameter(name=param.name,
-                          desc=param,
-                          type=param.type,
-                          shape=param.shape,
-                          dtype=param.dtype,
-                          stop_gradient=param.stop_gradient,
-                          block=forward_startup_prog.global_block())
+            program_helper.build_program(mode)
 
-            paddle.enable_static()
+            self.concrete_program = program_helper.concrete_program
+            serial_main_prog = program_helper.main_program
+            serial_startup_prog = program_helper.startup_program
 
-            # NOTE: pure program will loss dist_attr
-            # feeded_var_names = [var.name for var in inputs]
-            # main_prog_0 = main_prog_0._prune_with_input(
-            #     feeded_var_names=feeded_var_names, targets=outputs)
-
-            labels = []
-            losses = []
-            metrics = []
-            # concat forward and loss prog
-            if mode != 'predict' and self._loss:
-                forward_block = forward_main_prog.global_block()
-                loss_block = loss_main_prog.global_block()
-                for idx, op in enumerate(loss_block.ops):
-                    op_desc = forward_block.desc.append_op()
-                    op_desc.copy_from(op.desc)
-                    for in_name in op.input_arg_names:
-                        if in_name in outputs_name:
-                            continue
-                        in_var = forward_block._clone_variable(
-                            loss_block.vars[in_name], force_persistable=False)
-                        if loss_block.vars[in_name].is_data:
-                            labels.append(in_var)
-                    for out_name in op.output_arg_names:
-                        out_var = forward_block._clone_variable(
-                            loss_block.vars[out_name], force_persistable=False)
-                        if idx == len(loss_block.ops) - 1:
-                            losses.append(out_var)
-                forward_block._sync_with_cpp()
-            serial_main_prog = forward_main_prog
-            serial_startup_prog = forward_startup_prog
-            # update metrics op in program
-            with static.program_guard(serial_main_prog, serial_startup_prog), \
-                utils.unique_name.guard():
-                if mode != "predict":
-                    for metric in self._metrics:
-                        metrics.extend(
-                            to_list(metric.compute(*(outputs + labels))))
+            inputs = program_helper.input_vars
+            outputs = program_helper.output_vars
+            labels = program_helper.label_vars
+            losses = program_helper.loss_vars
+            metrics = program_helper.metric_vars
+            paddle.enable_static()
         else:
             # build program in static mode
             serial_main_prog = self._serial_main_progs.get(mode, None)
diff --git a/python/paddle/distributed/auto_parallel/hepler.py b/python/paddle/distributed/auto_parallel/hepler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d85489daf64caccae1a6a76c07de6bb0e56c951e
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/hepler.py
@@ -0,0 +1,244 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from paddle.nn import Layer
+from paddle.jit import to_static, not_to_static
+from paddle.distributed.utils import get_logger
+from paddle.fluid.framework import Operator, Parameter, _non_static_mode
+
+from .utils import to_list
+
+
+class ProxyLayer(Layer):
+    """
+    ProxyLayer implements all the logic for converting a dygraph model
+    into static Program IR. Meanwhile, it provides convenient interfaces
+    for auto parallel to access feed/fetch/loss/metric variables.
+    """
+
+    def __init__(self, layer, loss_func, metrics):
+        super(ProxyLayer, self).__init__()
+        # NOTE: All verification logic is finished in Engine.prepare()
+        self.inner_layer = layer
+        self.loss_func = loss_func
+        self.metrics = metrics
+        # train / eval / predict
+        self.mode = None
+
+        # generated program vars
+        self.input_vars = []
+        self.label_vars = []
+        self.output_vars = []
+        self.loss_vars = []
+        self.metric_vars = []
+
+    def _train(self, inputs, labels):
+        """
+        Train process of inner_layer with forward/loss/metric logic.
+        """
+        # step 1. save feed variables of Program
+        self.input_vars = inputs
+        self.label_vars = labels
+
+        # step 2. call inner_layer.forward
+        self.output_vars = self.inner_layer(*inputs)
+
+        # step 3. calculate loss if needed
+        new_inputs = self._prepare(self.output_vars, labels)
+        self.loss_vars = self.call_loss(new_inputs)
+
+        # step 4. calculate metrics if needed
+        self.metric_vars = self.call_metrics(new_inputs)
+
+    def _eval(self, inputs, labels):
+        """
+        Evaluate process of inner_layer with forward/loss/metric logic.
+        """
+        # TODO(dev): consider reusing self._train here once we verify the
+        # two code paths stay identical.
+
+        # step 1. save feed variables of Program
+        self.input_vars = inputs
+        self.label_vars = labels
+
+        # step 2. call inner_layer.forward
+        self.output_vars = self.inner_layer(*inputs)
+
+        # step 3. calculate loss if needed
+        new_inputs = self._prepare(self.output_vars, labels)
+        self.loss_vars = self.call_loss(new_inputs)
+
+        # step 4. calculate metrics if needed
+        self.metric_vars = self.call_metrics(new_inputs)
+
+    def _predict(self, inputs):
+        """
+        Predict process of inner_layer with forward logic.
+        """
+        # step 1. save feed variables of Program
+        self.input_vars = inputs
+
+        # step 2. call inner_layer.forward
+        self.output_vars = self.inner_layer(*inputs)
+
+    @not_to_static
+    def _prepare(self, outputs, labels):
+        """
+        Concatenate outputs and labels into a single list.
+
+        NOTE(dev): We use @not_to_static to avoid AST analysis.
+        """
+        return to_list(outputs) + to_list(labels)
+
+    def call_loss(self, inputs):
+        """
+        Apply the loss function on outputs and labels.
+
+        Args:
+            inputs: List[Variable]
+
+        Returns: List[Variable]
+        """
+        res = []
+        if self.loss_func is not None:
+            res = self.loss_func(*inputs)
+        return res
+
+    def call_metrics(self, inputs):
+        """
+        Apply metric functions on outputs and labels.
+
+        Args:
+            inputs: List[Variable]
+
+        Returns: List[Variable]
+        """
+        outs = []
+        for metric in self.metrics:
+            outs.extend(metric.compute(*inputs))
+
+        return outs
+
+    def set_mode(self, mode):
+        self.mode = mode
+        self.training = mode == 'train'
+
+
+class BuildInfo:
+
+    def __init__(self, mode=None, state=False):
+        self.mode = mode
+        self.state = state
+
+    def has_cache(self, mode):
+        return self.mode == mode and self.state is True
+
+
+class ProgramHelper(object):
+    """
+    A helper class for Engine that provides different Program IR
+    according to the specified 'mode'.
+    """
+
+    def __init__(self, layer, loss_func, metrics, inputs_spec, labels_spec):
+        # original model config information
+        # TODO(Aurelius84): Implement append_backward and optimizer in
+        # ProxyLayer after the distributed engine satisfies the basic
+        # conditions.
+        self.proxy_layer = ProxyLayer(layer, loss_func, metrics)
+        self.inputs_spec = inputs_spec
+        self.labels_spec = labels_spec
+
+        self.build_info = BuildInfo()
+        self._logger = get_logger(logging.INFO)
+
+    def build_program(self, mode):
+        """
+        Convert dygraph model into static Program IR.
+        """
+        assert mode in ['train', 'eval', 'predict']
+        # skip if the program has already been built.
+        if self.build_info.has_cache(mode):
+            self._logger.info(
+                "Already built program with mode = %s, use cached program." %
+                mode)
+            return
+
+        self._logger.info("Start to build program for mode = %s." % mode)
+        self.proxy_layer.mode = mode
+        input_spec = [self.inputs_spec, self.labels_spec
+                      ] if mode != 'predict' else [self.inputs_spec]
+        static_func = to_static(self.static_func(), input_spec=input_spec)
+
+        func_name = '_' + mode
+        setattr(self.proxy_layer, func_name, static_func)
+
+        # NOTE(dev): Because @to_static is a lazy mechanism, we explicitly
+        # access concrete_program here to trigger generating the Program IR
+        # immediately.
+        getattr(self.proxy_layer, func_name).concrete_program
+
+    def _build_startup_program(self):
+        """
+        Create and sync parameters into the startup program.
+        """
+        for param in self.concrete_program.parameters:
+            Parameter(name=param.name,
+                      desc=param,
+                      type=param.type,
+                      shape=param.shape,
+                      dtype=param.dtype,
+                      stop_gradient=param.stop_gradient,
+                      block=self.startup_program.global_block())
+
+    def static_func(self):
+        """
+        Return the function corresponding to the current mode.
+        """
+        assert self.proxy_layer.mode in [
+            'train', 'eval', 'predict'
+        ], "Please call build_program(mode) first."
+        func_name = '_' + self.proxy_layer.mode
+        return getattr(self.proxy_layer, func_name)
+
+    @property
+    def concrete_program(self):
+        return self.static_func().concrete_program
+
+    @property
+    def main_program(self):
+        return self.concrete_program.main_program
+
+    @property
+    def startup_program(self):
+        return self.concrete_program.startup_program
+
+    @property
+    def input_vars(self):
+        return to_list(self.proxy_layer.input_vars)
+
+    @property
+    def output_vars(self):
+        return to_list(self.proxy_layer.output_vars)
+
+    @property
+    def label_vars(self):
+        return to_list(self.proxy_layer.label_vars)
+
+    @property
+    def loss_vars(self):
+        return to_list(self.proxy_layer.loss_vars)
+
+    @property
+    def metric_vars(self):
+        return to_list(self.proxy_layer.metric_vars)
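
For context on how the new class is driven, below is a minimal usage sketch that mirrors the updated Engine._build flow. It is illustrative only and not part of this patch: the toy MLP layer, the spec shapes, and CrossEntropyLoss are assumptions, while ProgramHelper, build_program, and the main_program/loss_vars properties come from the code above.

import paddle
from paddle.static import InputSpec
from paddle.distributed.auto_parallel.hepler import ProgramHelper


# Toy dygraph layer standing in for a user model (assumption, not in the patch).
class MLP(paddle.nn.Layer):

    def __init__(self):
        super(MLP, self).__init__()
        self.linear = paddle.nn.Linear(8, 2)

    def forward(self, x):
        return self.linear(x)


paddle.disable_static()  # ProgramHelper expects dygraph mode, as Engine._build ensures

inputs_spec = [InputSpec(shape=[None, 8], dtype='float32', name='x')]
labels_spec = [InputSpec(shape=[None, 1], dtype='int64', name='label')]
loss_func = paddle.nn.CrossEntropyLoss()

helper = ProgramHelper(MLP(), loss_func, [], inputs_spec, labels_spec)
helper.build_program('train')  # traces forward + loss into a static Program

print(helper.main_program)  # serial main program handed on to the planner
print(helper.loss_vars)     # loss variables, read back by Engine

The same instance can serve build_program('eval') or build_program('predict') as well, which is why Engine now reads every feed/fetch variable back through the helper's properties instead of concatenating and pruning programs by hand.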