未验证 提交 243acdb4 编写于 作者: A Aurelius84 提交者: GitHub

[dy2st]Add ProgramHelper to polish build program logic in autoparallel.Engine (#44513)

* [dy2st]Add ProgramHelper to polish build program logic in autoparallel.Engine

* refine code
上级 ead696ae
...@@ -37,6 +37,7 @@ from paddle.distributed import fleet ...@@ -37,6 +37,7 @@ from paddle.distributed import fleet
from paddle.distributed.utils import get_logger from paddle.distributed.utils import get_logger
from paddle.distributed.passes import new_pass, PassContext from paddle.distributed.passes import new_pass, PassContext
from .hepler import ProgramHelper
from ..collective import _get_global_env from ..collective import _get_global_env
from .cluster import Cluster, get_default_cluster from .cluster import Cluster, get_default_cluster
from .planner_v2 import Planner from .planner_v2 import Planner
...@@ -141,87 +142,28 @@ class Engine: ...@@ -141,87 +142,28 @@ class Engine:
self._mode_init_states[mode] = True self._mode_init_states[mode] = True
def _build(self, mode): def _build(self, mode):
if _non_static_mode() or self._dygraph_mode: if _non_static_mode() or self._dygraph_mode:
paddle.disable_static()
self._dygraph_mode = True self._dygraph_mode = True
self._logger.info("Building model with 'to_static' method.") self._logger.info("Building model with 'to_static' method.")
# build forward main program program_helper = ProgramHelper(self.model, self._loss,
self.static_model = to_static(self.model, self._metrics, self.inputs_spec,
input_spec=self.inputs_spec)
inputs = self.static_model.forward.inputs
outputs = self.static_model.forward.outputs
forward_main_prog = self.static_model.forward.main_program
forward_startup_prog = self.static_model.forward.concrete_program.startup_program
self.concrete_program = self.static_model.forward.concrete_program
# build loss main program
outputs_spec = []
outputs_name = []
for out in outputs:
outputs_spec.append(InputSpec(out.shape, out.dtype, out.name))
outputs_name.append(out.name)
if isinstance(self._loss, paddle.nn.Layer):
self.static_loss = to_static(self._loss.forward,
input_spec=outputs_spec +
self.labels_spec) self.labels_spec)
loss_main_prog = self.static_loss.main_program # build forward main program
elif callable(self._loss): program_helper.build_program(mode)
self.static_loss = to_static(self._loss,
input_spec=outputs_spec +
self.labels_spec)
loss_main_prog = self.static_loss.main_program
# build startup program
for param in self.concrete_program.parameters:
Parameter(name=param.name,
desc=param,
type=param.type,
shape=param.shape,
dtype=param.dtype,
stop_gradient=param.stop_gradient,
block=forward_startup_prog.global_block())
paddle.enable_static() self.concrete_program = program_helper.concrete_program
serial_main_prog = program_helper.main_program
serial_startup_prog = program_helper.startup_program
# NOTE: pure program will loss dist_attr inputs = program_helper.input_vars
# feeded_var_names = [var.name for var in inputs] outputs = program_helper.output_vars
# main_prog_0 = main_prog_0._prune_with_input( labels = program_helper.label_vars
# feeded_var_names=feeded_var_names, targets=outputs) losses = program_helper.loss_vars
metrics = program_helper.metric_vars
labels = []
losses = []
metrics = []
# concat forward and loss prog
if mode != 'predict' and self._loss:
forward_block = forward_main_prog.global_block()
loss_block = loss_main_prog.global_block()
for idx, op in enumerate(loss_block.ops):
op_desc = forward_block.desc.append_op()
op_desc.copy_from(op.desc)
for in_name in op.input_arg_names:
if in_name in outputs_name:
continue
in_var = forward_block._clone_variable(
loss_block.vars[in_name], force_persistable=False)
if loss_block.vars[in_name].is_data:
labels.append(in_var)
for out_name in op.output_arg_names:
out_var = forward_block._clone_variable(
loss_block.vars[out_name], force_persistable=False)
if idx == len(loss_block.ops) - 1:
losses.append(out_var)
forward_block._sync_with_cpp()
serial_main_prog = forward_main_prog
serial_startup_prog = forward_startup_prog
# update metrics op in program
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
if mode != "predict":
for metric in self._metrics:
metrics.extend(
to_list(metric.compute(*(outputs + labels))))
paddle.enable_static()
else: else:
# build program in static mode # build program in static mode
serial_main_prog = self._serial_main_progs.get(mode, None) serial_main_prog = self._serial_main_progs.get(mode, None)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from paddle.nn import Layer
from paddle.jit import to_static, not_to_static
from paddle.distributed.utils import get_logger
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from .utils import to_list
class ProxyLayer(Layer):
"""
ProxyLayer implements all logic for converting dygraph model into
static Program IR. Meanwhile, it provides conviential interfaces for
auto parallel to visit feed/fetch/loss/metric variables.
"""
def __init__(self, layer, loss_func, metrics):
super(ProxyLayer, self).__init__()
# NOTE: All verify logics are finished in Engine.Prepare
self.inner_layer = layer
self.loss_func = loss_func
self.metrics = metrics
# train / eval / predict
self.mode = None
# generated program vars
self.input_vars = []
self.label_vars = []
self.output_vars = []
self.loss_vars = []
self.metric_vars = []
def _train(self, inputs, labels):
"""
Train process of inner_layer with forward/loss/metric logic.
"""
# step 1. save feed variables of Program
self.input_vars = inputs
self.label_vars = labels
# step 2. call inner_layer.forward
self.output_vars = self.inner_layer(*inputs)
# step 3. calculate loss if needed
new_inputs = self._prepare(self.output_vars, labels)
self.loss_vars = self.call_loss(new_inputs)
# step 4. calculate metrics if needed
self.metric_vars = self.call_metrics(new_inputs)
def _eval(self, inputs, labels):
"""
Evaluate process of inner_layer with forward/loss/metric logic.
"""
# TODO(dev): we can reuse codes with self._train after making
# sure if they can.
# step 1. save feed variables of Program
self.input_vars = inputs
self.label_vars = labels
# step 2. call inner_layer.forward
self.output_vars = self.inner_layer(*inputs)
# step 3. calculate loss if needed
new_inputs = self._prepare(self.output_vars, labels)
self.loss_vars = self.call_loss(new_inputs)
# step 4. calculate metrics if needed
self.metric_vars = self.call_metrics(new_inputs)
def _predict(self, inputs):
"""
Predict process of inner_layer with forward logic.
"""
# step 1. save feed variables of Program
self.input_vars = inputs
# step 2. call inner_layer.forward
self.output_vars = self.inner_layer(*inputs)
@not_to_static
def _prepare(self, outputs, labels):
"""
Concat outputs and labels as a single list
NOTE(dev): We use @not_to_static to avoid AST Analysis.
"""
return to_list(outputs) + to_list(labels)
def call_loss(self, inputs):
"""
Apply Loss Function on outputs and labels.
Args:
inputs: List[Variable]
Returns: List[Variable]
"""
res = []
if self.loss_func is not None:
res = self.loss_func(*inputs)
return res
def call_metrics(self, inputs):
"""
Apply Metrics Function on outputs and labels.
Args:
inputs: List[Variable]
Returns: List[Variable]
"""
outs = []
for metric in self.metrics:
outs.extend(metric.compute(*inputs))
return outs
def set_mode(self, mode):
self.mode = mode
self.training = mode == 'train'
class BuildInfo:
def __init__(self, mode=None, state=False):
self.mode = mode
self.state = state
def has_cache(self, mode):
return self.mode == mode and self.state is True
class ProgramHelper(object):
"""
A Helper class for Engine to provides different Program IR according specified 'mode'.
"""
def __init__(self, layer, loss_func, metrics, inputs_spec, labels_spec):
# original model config information
# TODO(Aurelius84): Implenet append_backward and optimizer in ProxyLayer
# after distribute engine satisify basic condition.
self.proxy_layer = ProxyLayer(layer, loss_func, metrics)
self.inputs_spec = inputs_spec
self.labels_spec = labels_spec
self.build_info = BuildInfo()
self._logger = get_logger(logging.INFO)
def build_program(self, mode):
"""
Convert dygraph model into static Program IR.
"""
assert mode in ['train', 'eval', 'predict']
# skip if we has already built program.
if self.build_info.has_cache(mode):
self._logger.info(
"Already build program with mode = %s, use cached program." %
mode)
return
self._logger.info("start to build program for mode = %s." % mode)
self.proxy_layer.mode = mode
input_spec = [self.inputs_spec, self.labels_spec
] if mode != 'predict' else [self.inputs_spec]
static_func = to_static(self.static_func(), input_spec=input_spec)
func_name = '_' + mode
setattr(self.proxy_layer, func_name, static_func)
# NOTE(dev): Because @to_static is a Lazy mechanism, so we explicitly call this to trigger
# generating Program IR immediately.
getattr(self.proxy_layer, func_name).concrete_program
def _build_startup_program(self):
"""
Create and Sync parameters into startup program.
"""
for param in self.concrete_program.parameters:
Parameter(name=param.name,
desc=param,
type=param.type,
shape=param.shape,
dtype=param.dtype,
stop_gradient=param.stop_gradient,
block=self.startup_program.global_block())
def static_func(self):
"""
Return target mode function.
"""
assert self.proxy_layer.mode in [
'train', 'eval', 'predict'
], "Please call build_program(mode) firstly."
func_name = '_' + self.proxy_layer.mode
return getattr(self.proxy_layer, func_name)
@property
def concrete_program(self):
return self.static_func().concrete_program
@property
def main_program(self):
return self.concrete_program.main_program
@property
def startup_program(self):
return self.concrete_program.startup_program
@property
def input_vars(self):
return to_list(self.proxy_layer.input_vars)
@property
def output_vars(self):
return to_list(self.proxy_layer.output_vars)
@property
def label_vars(self):
return to_list(self.proxy_layer.label_vars)
@property
def loss_vars(self):
return to_list(self.proxy_layer.loss_vars)
@property
def metric_vars(self):
return to_list(self.proxy_layer.metric_vars)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册