Unverified · Commit 97f86d84 · authored by Aurelius84, committed by GitHub

[Dy2Stat] Refine PartialProgramLayer logic (#33796)

* refine temp_scope_vec logic

* polish partial_program

* fix fake var

* add stop_gradient in spec

* fix fake_var

* fix unittest
Parent 9347df84
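
The sketch below is not part of the commit; the layer, shapes, and data are made up for illustration. It shows the behaviour the "add stop_gradient in spec" change targets: when a dygraph Tensor marked with `stop_gradient=True` is traced by `@to_static`, the flag is now carried onto the `InputSpec` derived from it, so the matching feed variable in the static program keeps `stop_gradient=True` and `append_backward` no longer emits a redundant grad op for that input.

```python
import numpy as np
import paddle

# Input that should not receive a gradient.
x = paddle.to_tensor(np.random.rand(4, 8).astype('float32'))
x.stop_gradient = True

# What FunctionSpec.args_to_input_spec now does internally (simplified):
spec = paddle.static.InputSpec.from_tensor(x)
spec.stop_gradient = x.stop_gradient  # recorded via the new _set_spec_stop_gradient helper

# A hypothetical network just to exercise the dy2stat path.
net = paddle.jit.to_static(paddle.nn.Linear(8, 2))
loss = net(x).mean()
loss.backward()  # backward should not generate a grad op for the stop_gradient input
```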
@@ -103,8 +103,11 @@ class FunctionSpec(object):
         for idx, input_var in enumerate(flatten(args)):
             if isinstance(input_var, np.ndarray):
                 input_var = paddle.static.InputSpec.from_numpy(input_var)
+                _set_spec_stop_gradient(input_var, True)
             elif isinstance(input_var, core.VarBase):
+                stop_gradient = input_var.stop_gradient
                 input_var = paddle.static.InputSpec.from_tensor(input_var)
+                _set_spec_stop_gradient(input_var, stop_gradient)

             args_with_spec.append(input_var)
@@ -172,13 +175,15 @@ class FunctionSpec(object):
         block = main_program.global_block()
         for i, var_spec in enumerate(flat_input_spec):
             if isinstance(var_spec, paddle.static.InputSpec):
+                stop_gradient = getattr(var_spec, 'stop_gradient', False)
                 feed_layer = block.create_var(
                     # TODO(Aurelius84): consider a more elegant way to name this
                     name=var_spec.name or "feed_%s" % i,
                     shape=var_spec.shape,
                     dtype=var_spec.dtype,
                     is_data=True,
-                    need_check_feed=False)
+                    need_check_feed=False,
+                    stop_gradient=stop_gradient)
             else:
                 feed_layer = var_spec
             inputs.append(feed_layer)
@@ -302,7 +307,7 @@ def convert_to_input_spec(inputs, input_spec):
             if isinstance(rest_input, (core.VarBase, np.ndarray)):
                 logging_utils.warn(
                     "The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. "
-                    "Please specific InputSpec information in `@declarative` if you expect them as mutable inputs.".
+                    "Please specific InputSpec information in `@to_static` if you expect them as mutable inputs.".
                     format(type_name(rest_input)))
         input_with_spec.extend(inputs[len(input_spec):])
@@ -380,3 +385,12 @@ def _replace_spec_name(name, input_spec):
         return processed_specs
     else:
         return input_spec
+
+
+def _set_spec_stop_gradient(spec, stop_gradient):
+    """
+    Set new attribute ``stop_gradient`` for InputSpec to avoid generating redundant grad_op
+    while append_backward.
+    """
+    assert isinstance(spec, paddle.static.InputSpec)
+    spec.stop_gradient = stop_gradient
@@ -35,6 +35,7 @@ class NestSequence(object):
     def __init__(self, raw_input, need_check=False):
         self.__raw_input = raw_input
+        self.__input_list = self.tolist()
         self.__var_ids = self._get_var_ids()
         self._check_non_variable(need_check)
@@ -48,12 +49,12 @@ class NestSequence(object):
         """
         Restores the nested sequence from value list.
         """
-        assert len(self.tolist()) == len(value_list)
+        assert len(self.__input_list) == len(value_list)
         return pack_sequence_as(self.__raw_input, value_list)

     def _get_var_ids(self):
         var_ids = []
-        for idx, var in enumerate(self.tolist()):
+        for idx, var in enumerate(self.__input_list):
             if isinstance(var, (framework.Variable, core.VarBase)):
                 var_ids.append(idx)
@@ -65,7 +66,7 @@ class NestSequence(object):
         """
         if need_check:
             warning_types = set()
-            for var in self.tolist():
+            for var in self.__input_list:
                 if not isinstance(var, (framework.Variable, core.VarBase)):
                     warning_types.add(type(var))
             if warning_types:
@@ -80,7 +81,7 @@ class NestSequence(object):
         return self.__var_ids

     def __getitem__(self, item):
-        return self.tolist()[item]
+        return self.__input_list[item]


 class LazyInitialized(object):
@@ -106,7 +107,7 @@ def _change_is_test_status(program, is_test):
     return program


-class PartialProgramLayer(layers.Layer):
+class PartialProgramLayer:
     """
     PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
     and execute them as a static subgraph.
@@ -134,7 +135,9 @@ class PartialProgramLayer(layers.Layer):
         self._params = parameters if parameters is not None else []
         self._origin_main_program = self._verify_program(main_program)
-        self._inner_scope = core.Scope()
+        self._tmp_scope_vec = self._create_scope_vec()
+        # A fake_var to handle empty input or output
+        self.__fake_vars = _create_fake_var()
         # Set default mode to train
         self._double_grads = self._get_double_grads(self._origin_main_program)
         self.training = True
@@ -217,19 +220,19 @@ class PartialProgramLayer(layers.Layer):
                                             var_desc.name(),
                                             var_desc.type(), False)
                     double_grads.append(var_base)
-        return double_grads
+        return self._valid_vars(double_grads)

-    def forward(self, inputs):
-        in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
+    def __call__(self, inputs):
+        in_vars, out_vars = self._prepare(inputs)
         attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
                  0, 'end_op_index', self._infer_program.desc.block(0).op_size(),
                  'is_test', not self.training)
         core.ops.run_program(
-            valid_vars(in_vars),
-            valid_vars(self._params),
-            valid_vars(out_vars), tmp_scope_vec,
-            valid_vars(self._double_grads), *attrs)
+            self._valid_vars(in_vars),
+            self._valid_vars(self._params),
+            self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
+            *attrs)

         restored_nest_out = self._restore_out(out_vars)
         return self._remove_no_value(restored_nest_out)
@@ -264,7 +267,6 @@ class PartialProgramLayer(layers.Layer):
                         expected_place):
                     var = value._copy_to(expected_place, False)
                     var.stop_gradient = True
-                    var.name = value.name
                 else:
                     var = value
                 var.name = self._inputs[i].desc.name()
@@ -272,25 +274,29 @@ class PartialProgramLayer(layers.Layer):
                 continue
             input_vars.append(var)

-        # Create VarBase to receive output data.
-        out_vars = []
-        for idx in self._outputs.var_ids:
-            var = self._outputs[idx]
+        def create_out(var_id):
+            var = self._outputs[var_id]
             assert isinstance(var, framework.Variable)
             var_desc = var.desc
             var_base = core.VarBase(var_desc.dtype(),
                                     var_desc.shape(),
                                     var_desc.name(), var_desc.type(), False)
-            out_vars.append(var_base)
+            return var_base

+        # Create VarBase to receive output data.
+        out_vars = list(map(create_out, self._outputs.var_ids))
+
+        return input_vars, out_vars
+
+    def _create_scope_vec(self):
         # Hold forward variables
         tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
                                      "program_out_scope",
                                      core.VarDesc.VarType.STEP_SCOPES, True)
-        tmp_scope_vec.value().set_scope(self._inner_scope)
-
-        return input_vars, out_vars, tmp_scope_vec
+        inner_scope = core.Scope()
+        tmp_scope_vec.value().set_scope(inner_scope)
+        return tmp_scope_vec

     def _restore_out(self, out_vars):
         """
@@ -311,8 +317,9 @@ class PartialProgramLayer(layers.Layer):
         return main_program.clone(for_test=True)

     def _is_no_value(self, var):
-        if isinstance(var, core.VarBase):
-            if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
+        if isinstance(var, core.VarBase) and var.shape == [1]:
+            # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
+            if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
                 return True
         return False
@@ -405,20 +412,22 @@ class PartialProgramLayer(layers.Layer):
                 "Please define the layer with parameters in `__init__` function."
                 % name)

+    def _valid_vars(self, vars):
+        """
+        Note: run_program_op.InferShape requires `X`/'Out' not be null.
+        But it's common in dy2static, fake varBase is created to handle the
+        problem.
+        """
+        return vars if vars else self.__fake_vars

-def valid_vars(vars):
+def _create_fake_var():
     """
-    Note: run_program_op.InferShape requires `X`/'Out' not be null.
-    But it's common in dy2static, fake varBase is created to handle the
-    problem.
+    Create a fake_var (force on CPU) to handle empty input or output
     """
-    if vars:
-        return vars
     return [
-        core.VarBase(
-            value=[1],
-            name='Fake_var',
-            place=framework._current_expected_place())
+        core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
+                     core.VarDesc.VarType.RAW, False)
     ]