未验证 提交 97f86d84 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Refine PartialProgramLayer logic (#33796)

* refine temp_scope_vec logic

* polish partial_program

* fix fake var

* add stop_gradient in spec

* fix fake_var

* fix unittest
上级 9347df84
......@@ -103,8 +103,11 @@ class FunctionSpec(object):
for idx, input_var in enumerate(flatten(args)):
if isinstance(input_var, np.ndarray):
input_var = paddle.static.InputSpec.from_numpy(input_var)
_set_spec_stop_gradient(input_var, True)
elif isinstance(input_var, core.VarBase):
stop_gradient = input_var.stop_gradient
input_var = paddle.static.InputSpec.from_tensor(input_var)
_set_spec_stop_gradient(input_var, stop_gradient)
args_with_spec.append(input_var)
......@@ -172,13 +175,15 @@ class FunctionSpec(object):
block = main_program.global_block()
for i, var_spec in enumerate(flat_input_spec):
if isinstance(var_spec, paddle.static.InputSpec):
stop_gradient = getattr(var_spec, 'stop_gradient', False)
feed_layer = block.create_var(
# TODO(Aurelius84): consider a more elegant way to name this
name=var_spec.name or "feed_%s" % i,
shape=var_spec.shape,
dtype=var_spec.dtype,
is_data=True,
need_check_feed=False)
need_check_feed=False,
stop_gradient=stop_gradient)
else:
feed_layer = var_spec
inputs.append(feed_layer)
......@@ -302,7 +307,7 @@ def convert_to_input_spec(inputs, input_spec):
if isinstance(rest_input, (core.VarBase, np.ndarray)):
logging_utils.warn(
"The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. "
"Please specific InputSpec information in `@declarative` if you expect them as mutable inputs.".
"Please specific InputSpec information in `@to_static` if you expect them as mutable inputs.".
format(type_name(rest_input)))
input_with_spec.extend(inputs[len(input_spec):])
......@@ -380,3 +385,12 @@ def _replace_spec_name(name, input_spec):
return processed_specs
else:
return input_spec
def _set_spec_stop_gradient(spec, stop_gradient):
"""
Set new attribute ``stop_gradient`` for InputSpec to avoid generating redundant grad_op
while append_backward.
"""
assert isinstance(spec, paddle.static.InputSpec)
spec.stop_gradient = stop_gradient
......@@ -35,6 +35,7 @@ class NestSequence(object):
def __init__(self, raw_input, need_check=False):
self.__raw_input = raw_input
self.__input_list = self.tolist()
self.__var_ids = self._get_var_ids()
self._check_non_variable(need_check)
......@@ -48,12 +49,12 @@ class NestSequence(object):
"""
Restores the nested sequence from value list.
"""
assert len(self.tolist()) == len(value_list)
assert len(self.__input_list) == len(value_list)
return pack_sequence_as(self.__raw_input, value_list)
def _get_var_ids(self):
var_ids = []
for idx, var in enumerate(self.tolist()):
for idx, var in enumerate(self.__input_list):
if isinstance(var, (framework.Variable, core.VarBase)):
var_ids.append(idx)
......@@ -65,7 +66,7 @@ class NestSequence(object):
"""
if need_check:
warning_types = set()
for var in self.tolist():
for var in self.__input_list:
if not isinstance(var, (framework.Variable, core.VarBase)):
warning_types.add(type(var))
if warning_types:
......@@ -80,7 +81,7 @@ class NestSequence(object):
return self.__var_ids
def __getitem__(self, item):
return self.tolist()[item]
return self.__input_list[item]
class LazyInitialized(object):
......@@ -106,7 +107,7 @@ def _change_is_test_status(program, is_test):
return program
class PartialProgramLayer(layers.Layer):
class PartialProgramLayer:
"""
PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
and execute them as a static subgraph.
......@@ -134,7 +135,9 @@ class PartialProgramLayer(layers.Layer):
self._params = parameters if parameters is not None else []
self._origin_main_program = self._verify_program(main_program)
self._inner_scope = core.Scope()
self._tmp_scope_vec = self._create_scope_vec()
# A fake_var to handle empty input or output
self.__fake_vars = _create_fake_var()
# Set default mode to train
self._double_grads = self._get_double_grads(self._origin_main_program)
self.training = True
......@@ -217,19 +220,19 @@ class PartialProgramLayer(layers.Layer):
var_desc.name(),
var_desc.type(), False)
double_grads.append(var_base)
return double_grads
return self._valid_vars(double_grads)
def forward(self, inputs):
in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
def __call__(self, inputs):
in_vars, out_vars = self._prepare(inputs)
attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
0, 'end_op_index', self._infer_program.desc.block(0).op_size(),
'is_test', not self.training)
core.ops.run_program(
valid_vars(in_vars),
valid_vars(self._params),
valid_vars(out_vars), tmp_scope_vec,
valid_vars(self._double_grads), *attrs)
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
*attrs)
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
......@@ -264,7 +267,6 @@ class PartialProgramLayer(layers.Layer):
expected_place):
var = value._copy_to(expected_place, False)
var.stop_gradient = True
var.name = value.name
else:
var = value
var.name = self._inputs[i].desc.name()
......@@ -272,25 +274,29 @@ class PartialProgramLayer(layers.Layer):
continue
input_vars.append(var)
# Create VarBase to receive output data.
out_vars = []
for idx in self._outputs.var_ids:
var = self._outputs[idx]
def create_out(var_id):
var = self._outputs[var_id]
assert isinstance(var, framework.Variable)
var_desc = var.desc
var_base = core.VarBase(var_desc.dtype(),
var_desc.shape(),
var_desc.name(), var_desc.type(), False)
out_vars.append(var_base)
return var_base
# Create VarBase to receive output data.
out_vars = list(map(create_out, self._outputs.var_ids))
return input_vars, out_vars
def _create_scope_vec(self):
# Hold forward variables
tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
"program_out_scope",
core.VarDesc.VarType.STEP_SCOPES, True)
tmp_scope_vec.value().set_scope(self._inner_scope)
return input_vars, out_vars, tmp_scope_vec
inner_scope = core.Scope()
tmp_scope_vec.value().set_scope(inner_scope)
return tmp_scope_vec
def _restore_out(self, out_vars):
"""
......@@ -311,8 +317,9 @@ class PartialProgramLayer(layers.Layer):
return main_program.clone(for_test=True)
def _is_no_value(self, var):
if isinstance(var, core.VarBase):
if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
if isinstance(var, core.VarBase) and var.shape == [1]:
# NOTE: .numpy() will insert MemcpySync operation, it hits performance.
if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
return True
return False
......@@ -405,20 +412,22 @@ class PartialProgramLayer(layers.Layer):
"Please define the layer with parameters in `__init__` function."
% name)
def _valid_vars(self, vars):
"""
Note: run_program_op.InferShape requires `X`/'Out' not be null.
But it's common in dy2static, fake varBase is created to handle the
problem.
"""
return vars if vars else self.__fake_vars
def valid_vars(vars):
def _create_fake_var():
"""
Note: run_program_op.InferShape requires `X`/'Out' not be null.
But it's common in dy2static, fake varBase is created to handle the
problem.
Create a fake_var (force on CPU) to handle empty input or output
"""
if vars:
return vars
return [
core.VarBase(
value=[1],
name='Fake_var',
place=framework._current_expected_place())
core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
core.VarDesc.VarType.RAW, False)
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册