Unverified · Commit 97f86d84 · Authored by: Aurelius84 · Committed by: GitHub

[Dy2Stat] Refine PartialProgramLayer logic (#33796)

* refine temp_scope_vec logic

* polish partial_program

* fix fake var

* add stop_gradient in spec

* fix fake_var

* fix unittest
Parent: 9347df84
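For context (illustrative only, not part of this commit's diff): "add stop_gradient in spec" means that when dy2stat traces a dygraph Tensor whose `stop_gradient` is True, the flag is now recorded on the derived `InputSpec` and forwarded to the static feed variable, so `append_backward` does not generate redundant grad ops for it. A minimal usage sketch, assuming the public Paddle 2.x dygraph API:

```python
import paddle

# to_static wraps the layer's forward and builds a static program on first call
net = paddle.jit.to_static(paddle.nn.Linear(4, 2))

x = paddle.randn([8, 4])
x.stop_gradient = True        # plain data input, no gradient required

out = net(x)                  # with this commit, the traced feed variable for x
out.mean().backward()         # keeps stop_gradient=True, so no x@GRAD ops appear
```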
@@ -103,8 +103,11 @@ class FunctionSpec(object):
             for idx, input_var in enumerate(flatten(args)):
                 if isinstance(input_var, np.ndarray):
                     input_var = paddle.static.InputSpec.from_numpy(input_var)
+                    _set_spec_stop_gradient(input_var, True)
                 elif isinstance(input_var, core.VarBase):
+                    stop_gradient = input_var.stop_gradient
                     input_var = paddle.static.InputSpec.from_tensor(input_var)
+                    _set_spec_stop_gradient(input_var, stop_gradient)
                 args_with_spec.append(input_var)
@@ -172,13 +175,15 @@ class FunctionSpec(object):
         block = main_program.global_block()
         for i, var_spec in enumerate(flat_input_spec):
             if isinstance(var_spec, paddle.static.InputSpec):
+                stop_gradient = getattr(var_spec, 'stop_gradient', False)
                 feed_layer = block.create_var(
                     # TODO(Aurelius84): consider a more elegant way to name this
                     name=var_spec.name or "feed_%s" % i,
                     shape=var_spec.shape,
                     dtype=var_spec.dtype,
                     is_data=True,
-                    need_check_feed=False)
+                    need_check_feed=False,
+                    stop_gradient=stop_gradient)
             else:
                 feed_layer = var_spec
             inputs.append(feed_layer)
@@ -302,7 +307,7 @@ def convert_to_input_spec(inputs, input_spec):
                 if isinstance(rest_input, (core.VarBase, np.ndarray)):
                     logging_utils.warn(
                         "The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. "
-                        "Please specific InputSpec information in `@declarative` if you expect them as mutable inputs.".
+                        "Please specific InputSpec information in `@to_static` if you expect them as mutable inputs.".
                         format(type_name(rest_input)))
     input_with_spec.extend(inputs[len(input_spec):])
@@ -380,3 +385,12 @@ def _replace_spec_name(name, input_spec):
         return processed_specs
     else:
         return input_spec
+
+
+def _set_spec_stop_gradient(spec, stop_gradient):
+    """
+    Set new attribute ``stop_gradient`` for InputSpec to avoid generating redundant grad_op
+    while append_backward.
+    """
+    assert isinstance(spec, paddle.static.InputSpec)
+    spec.stop_gradient = stop_gradient
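For context (illustrative only, not from this commit): the `stop_gradient` attribute attached by `_set_spec_stop_gradient` is read back above via `getattr(var_spec, 'stop_gradient', False)` and passed to `block.create_var(...)`. The static-graph effect it targets can be sketched with the public `paddle.static` API: a feed variable marked `stop_gradient=True` gets no gradient when the backward pass is appended.

```python
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[8, 4], dtype='float32')
    x.stop_gradient = True                 # what the generated feed var now carries
    w = paddle.static.create_parameter(shape=[4, 2], dtype='float32')
    loss = paddle.mean(paddle.matmul(x, w))
    paddle.static.append_backward(loss)

grad_ops = [op.type for op in main_prog.global_block().ops
            if op.type.endswith('_grad')]
# grad ops exist for w, but none of them produce a gradient for x
```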
@@ -35,6 +35,7 @@ class NestSequence(object):
     def __init__(self, raw_input, need_check=False):
         self.__raw_input = raw_input
+        self.__input_list = self.tolist()
         self.__var_ids = self._get_var_ids()
         self._check_non_variable(need_check)
@@ -48,12 +49,12 @@ class NestSequence(object):
         """
         Restores the nested sequence from value list.
         """
-        assert len(self.tolist()) == len(value_list)
+        assert len(self.__input_list) == len(value_list)
         return pack_sequence_as(self.__raw_input, value_list)

     def _get_var_ids(self):
         var_ids = []
-        for idx, var in enumerate(self.tolist()):
+        for idx, var in enumerate(self.__input_list):
             if isinstance(var, (framework.Variable, core.VarBase)):
                 var_ids.append(idx)
@@ -65,7 +66,7 @@ class NestSequence(object):
         """
         if need_check:
             warning_types = set()
-            for var in self.tolist():
+            for var in self.__input_list:
                 if not isinstance(var, (framework.Variable, core.VarBase)):
                     warning_types.add(type(var))
             if warning_types:
@@ -80,7 +81,7 @@ class NestSequence(object):
         return self.__var_ids

     def __getitem__(self, item):
-        return self.tolist()[item]
+        return self.__input_list[item]


 class LazyInitialized(object):
@@ -106,7 +107,7 @@ def _change_is_test_status(program, is_test):
     return program


-class PartialProgramLayer(layers.Layer):
+class PartialProgramLayer:
     """
     PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
     and execute them as a static subgraph.
@@ -134,7 +135,9 @@ class PartialProgramLayer(layers.Layer):
         self._params = parameters if parameters is not None else []

         self._origin_main_program = self._verify_program(main_program)
-        self._inner_scope = core.Scope()
+        self._tmp_scope_vec = self._create_scope_vec()
+        # A fake_var to handle empty input or output
+        self.__fake_vars = _create_fake_var()
         # Set default mode to train
         self._double_grads = self._get_double_grads(self._origin_main_program)
         self.training = True
@@ -217,19 +220,19 @@ class PartialProgramLayer(layers.Layer):
                                             var_desc.name(),
                                             var_desc.type(), False)
                     double_grads.append(var_base)
-        return double_grads
+        return self._valid_vars(double_grads)

-    def forward(self, inputs):
-        in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
+    def __call__(self, inputs):
+        in_vars, out_vars = self._prepare(inputs)

         attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
                  0, 'end_op_index', self._infer_program.desc.block(0).op_size(),
                  'is_test', not self.training)

         core.ops.run_program(
-            valid_vars(in_vars),
-            valid_vars(self._params),
-            valid_vars(out_vars), tmp_scope_vec,
-            valid_vars(self._double_grads), *attrs)
+            self._valid_vars(in_vars),
+            self._valid_vars(self._params),
+            self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
+            *attrs)
         restored_nest_out = self._restore_out(out_vars)
         return self._remove_no_value(restored_nest_out)
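A note on the `forward` → `__call__` change (illustrative, not part of the diff): since `PartialProgramLayer` no longer inherits from `layers.Layer`, it is invoked as a plain callable by the dy2stat machinery, and the scope VarBase and fake vars are created once in `__init__` instead of on every call. User-facing code is unchanged; a decorated layer is still used as before:

```python
import paddle

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = paddle.nn.Linear(4, 2)

    @paddle.jit.to_static
    def forward(self, x):
        return self.fc(x)

net = SimpleNet()
out = net(paddle.randn([8, 4]))   # internally dispatches to PartialProgramLayer.__call__
out.mean().backward()             # gradients flow through run_program's backward pass
```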
@@ -264,7 +267,6 @@ class PartialProgramLayer(layers.Layer):
                         expected_place):
                     var = value._copy_to(expected_place, False)
                     var.stop_gradient = True
-                    var.name = value.name
                 else:
                     var = value
                 var.name = self._inputs[i].desc.name()
@@ -272,25 +274,29 @@ class PartialProgramLayer(layers.Layer):
                 continue
             input_vars.append(var)

-        # Create VarBase to receive output data.
-        out_vars = []
-        for idx in self._outputs.var_ids:
-            var = self._outputs[idx]
+        def create_out(var_id):
+            var = self._outputs[var_id]
             assert isinstance(var, framework.Variable)
             var_desc = var.desc
             var_base = core.VarBase(var_desc.dtype(),
                                     var_desc.shape(),
                                     var_desc.name(), var_desc.type(), False)
-            out_vars.append(var_base)
+            return var_base
+
+        # Create VarBase to receive output data.
+        out_vars = list(map(create_out, self._outputs.var_ids))
+
+        return input_vars, out_vars

+    def _create_scope_vec(self):
         # Hold forward variables
         tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
                                      "program_out_scope",
                                      core.VarDesc.VarType.STEP_SCOPES, True)
-        tmp_scope_vec.value().set_scope(self._inner_scope)
-
-        return input_vars, out_vars, tmp_scope_vec
+        inner_scope = core.Scope()
+        tmp_scope_vec.value().set_scope(inner_scope)
+        return tmp_scope_vec

     def _restore_out(self, out_vars):
         """
@@ -311,8 +317,9 @@ class PartialProgramLayer(layers.Layer):
         return main_program.clone(for_test=True)

     def _is_no_value(self, var):
-        if isinstance(var, core.VarBase):
-            if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
+        if isinstance(var, core.VarBase) and var.shape == [1]:
+            # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
+            if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
                 return True
         return False
@@ -405,20 +412,22 @@ class PartialProgramLayer(layers.Layer):
                 "Please define the layer with parameters in `__init__` function."
                 % name)

-def valid_vars(vars):
+    def _valid_vars(self, vars):
         """
         Note: run_program_op.InferShape requires `X`/'Out' not be null.
         But it's common in dy2static, fake varBase is created to handle the
         problem.
         """
-    if vars:
-        return vars
+        return vars if vars else self.__fake_vars
+
+
+def _create_fake_var():
+    """
+    Create a fake_var (force on CPU) to handle empty input or output
+    """
     return [
-        core.VarBase(
-            value=[1],
-            name='Fake_var',
-            place=framework._current_expected_place())
+        core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
+                     core.VarDesc.VarType.RAW, False)
     ]
......
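A closing note on the fake-var change (illustrative, not part of the diff): the old placeholder was a real FP32 tensor holding `[1]`, allocated on the current expected place; the new one is an uninitialized RAW-type VarBase (forced on CPU, per the docstring) that exists only so `run_program_op`'s InferShape sees non-empty `X`/`Out` lists. A minimal sketch of the substitution rule `_valid_vars` now implements:

```python
from paddle.fluid import core

# Data-free placeholder, built once (mirrors _create_fake_var in the diff above).
fake_vars = [
    core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
                 core.VarDesc.VarType.RAW, False)
]

def valid_vars(vars):
    # run_program_op.InferShape rejects empty input/output lists, so an empty
    # list is swapped for the placeholder instead of allocating a real tensor.
    return vars if vars else fake_vars

assert valid_vars([]) is fake_vars
```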