Unverified commit 4c82e455, authored by W WangZhen, committed by GitHub

[Cherry-Pick][Dy2St]Support call backward() without params in dy2st (#49812) (#50144)

* [Dy2St]Support call backward() without params in dy2st (#49812)

* Support call backward() without params in dy2st

* format code

* format code
Parent 8c5e432b
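For context, the user-visible behavior this cherry-pick enables is calling backward() on the output of a @paddle.jit.to_static function whose program holds no trainable parameters, with gradients flowing back to the inputs. The snippet below is a minimal sketch of that usage, not part of the patch: it assumes a Paddle build that already contains this fix, mirrors the new unit test added at the end of the diff, and uses an illustrative layer name (PlusNet) and a slightly different forward function.

import numpy as np

import paddle


class PlusNet(paddle.nn.Layer):  # illustrative: a layer with no parameters
    @paddle.jit.to_static
    def forward(self, x):
        return 2 * x + 1


net = PlusNet()
x = paddle.ones([2, 2])
x.stop_gradient = False  # request gradients w.r.t. the input itself
loss = paddle.mean(net(x))
loss.backward()  # with this fix, no trainable parameters are required
# d(mean(2x + 1))/dx = 2 / 4 = 0.5 for each of the four elements
np.testing.assert_equal(x.grad.numpy(), np.full([2, 2], 0.5))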
@@ -93,7 +93,8 @@ class SelectOutputInferShape : public framework::InferShapeBase {
   void operator()(framework::InferShapeContext *context) const override {
     OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "SelectOutput");
     OP_INOUT_CHECK(context->HasInput("Mask"), "Input", "Mask", "SelectOutput");
-    OP_INOUT_CHECK(context->HasOutputs("Out"), "Output", "Out", "SelectOutput");
+    OP_INOUT_CHECK(
+        context->HasOutputs("Out", true), "Output", "Out", "SelectOutput");
   }
 };
...
@@ -26,9 +26,9 @@ static PyObject *eager_api_run_program(PyObject *self,
                                        PyObject *kwargs) {
   PyThreadState *tstate = nullptr;
   try {
-    auto X = GetTensorListFromArgs("run_program", "X", args, 0, false);
+    auto X = GetTensorListFromArgs("run_program", "X", args, 0, true);
     auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true);
-    auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, false);
+    auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, true);
     auto OutScope =
         GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false);
     auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true);
...
@@ -18,19 +18,32 @@ import six
 import paddle
 from paddle.fluid import framework, backward, core, program_guard
-from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor
+from paddle.fluid.executor import (
+    _is_enable_standalone_executor,
+    _is_dy2st_enable_standalone_executor,
+)
 from paddle.fluid.dygraph import layers
 from paddle.fluid.dygraph.base import switch_to_static_graph
 from paddle.fluid.dygraph.dygraph_to_static import logging_utils
-from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
+from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
+    RETURN_NO_VALUE_MAGIC_NUM,
+)
 from paddle.fluid.layers.utils import flatten
 from paddle.fluid.layers.utils import pack_sequence_as
 from paddle.fluid.layers.utils import _hash_with_id
 from paddle.fluid.compiler import BuildStrategy
 from paddle.fluid.framework import _apply_pass
-from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
-from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program, cast_model_to_fp16
-from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
+from paddle.fluid.contrib.mixed_precision.decorator import (
+    AutoMixedPrecisionLists,
+)
+from paddle.fluid.contrib.mixed_precision.fp16_utils import (
+    rewrite_program,
+    cast_model_to_fp16,
+)
+from paddle.fluid.dygraph.amp.auto_cast import (
+    _in_amp_guard,
+    _in_pure_fp16_guard,
+)
 import paddle.compat as cpt
 from paddle import _C_ops, _legacy_C_ops
@@ -64,7 +77,8 @@ class NestSequence(object):
         var_ids = []
         for idx, var in enumerate(self.__input_list):
             if isinstance(
-                    var, (framework.Variable, core.VarBase, core.eager.Tensor)):
+                var, (framework.Variable, core.VarBase, core.eager.Tensor)
+            ):
                 var_ids.append(idx)
         return var_ids
@@ -77,15 +91,17 @@ class NestSequence(object):
         warning_types = set()
         for var in self.__input_list:
             if not isinstance(
-                    var,
-                    (framework.Variable, core.VarBase, core.eager.Tensor)):
+                var, (framework.Variable, core.VarBase, core.eager.Tensor)
+            ):
                 warning_types.add(type(var))
         if warning_types:
             logging_utils.warn(
                 "Output of traced function contains non-tensor type values: {}. "
                 "Currently, We don't support to update them while training and will return "
-                "what we first saw. Please try to return them as tensor.".
-                format(list(warning_types)))
+                "what we first saw. Please try to return them as tensor.".format(
+                    list(warning_types)
+                )
+            )

     @property
     def var_ids(self):
@@ -139,12 +155,9 @@ class PartialProgramLayer:
         Layer: A Layer object that run all ops internally in static mode.
     """

-    def __init__(self,
-                 main_program,
-                 inputs,
-                 outputs,
-                 parameters=None,
-                 **kwargs):
+    def __init__(
+        self, main_program, inputs, outputs, parameters=None, **kwargs
+    ):
         super(PartialProgramLayer, self).__init__()
         self._inputs = NestSequence(inputs)
         self._outputs = NestSequence(outputs, need_check=True)
@@ -167,7 +180,8 @@ class PartialProgramLayer:
         # For AMP training
         self._amp_list = AutoMixedPrecisionLists(
             custom_white_list=custom_white_list,
-            custom_black_list=custom_black_list)
+            custom_black_list=custom_black_list,
+        )

         # program_id -> list(scope)
         self._scope_cache = {}
@@ -188,10 +202,6 @@ class PartialProgramLayer:
         else:
             return core.Scope()

-    @LazyInitialized
-    def __fake_vars(self):
-        return _create_fake_var()
-
     @LazyInitialized
     def _double_grads(self):
         return self._get_double_grads(self._origin_main_program)
@@ -203,7 +213,8 @@ class PartialProgramLayer:
             return self._origin_main_program.clone(for_test=is_infer_mode)
         else:
             train_program = self._append_backward_desc(
-                self._origin_main_program)
+                self._origin_main_program
+            )
             # Note: Only set grad type once after initializing train program. So we put it here.
             self._set_grad_type(self._params, train_program)
             return train_program
@@ -223,16 +234,18 @@ class PartialProgramLayer:
     @switch_to_static_graph
     def _create_pure_fp16_program(self, is_infer_mode=False):
         pure_fp16_program = self._origin_main_program.clone(
-            for_test=is_infer_mode)
+            for_test=is_infer_mode
+        )
         with program_guard(pure_fp16_program):
-            cast_model_to_fp16(pure_fp16_program,
-                               self._amp_list,
-                               use_fp16_guard=False)
+            cast_model_to_fp16(
+                pure_fp16_program, self._amp_list, use_fp16_guard=False
+            )

         if is_infer_mode:
             return pure_fp16_program
         else:
             train_pure_fp16_program = self._append_backward_desc(
-                pure_fp16_program)
+                pure_fp16_program
+            )
             self._set_grad_type(self._params, train_pure_fp16_program)
             return train_pure_fp16_program
@@ -240,23 +253,27 @@ class PartialProgramLayer:
     def _create_forward_backward_train_program(self):
         whole_program = self._create_program()
         forward_end_op_index = self._infer_program.desc.block(0).op_size()
-        return self._get_forward_backward_program_form(whole_program,
-                                                       forward_end_op_index)
+        return self._get_forward_backward_program_form(
+            whole_program, forward_end_op_index
+        )

     @switch_to_static_graph
     def _create_forward_backward_train_amp_program(self):
         whole_program = self._create_amp_program()
         forward_end_op_index = self._infer_amp_program.desc.block(0).op_size()
-        return self._get_forward_backward_program_form(whole_program,
-                                                       forward_end_op_index)
+        return self._get_forward_backward_program_form(
+            whole_program, forward_end_op_index
+        )

     @switch_to_static_graph
     def _create_forward_backward_train_pure_fp16_program(self):
         whole_program = self._create_pure_fp16_program()
         forward_end_op_index = self._infer_pure_fp16_program.desc.block(
-            0).op_size()
-        return self._get_forward_backward_program_form(whole_program,
-                                                       forward_end_op_index)
+            0
+        ).op_size()
+        return self._get_forward_backward_program_form(
+            whole_program, forward_end_op_index
+        )

     @LazyInitialized
     def _train_program(self):
@@ -352,8 +369,9 @@ class PartialProgramLayer:
     @LazyInitialized
     def _train_program_id(self):
         program_id = _hash_with_id(self._train_program, self)
-        core._set_cached_executor_build_strategy(program_id,
-                                                 self._build_strategy)
+        core._set_cached_executor_build_strategy(
+            program_id, self._build_strategy
+        )
         return program_id

     @LazyInitialized
@@ -363,8 +381,9 @@ class PartialProgramLayer:
     @LazyInitialized
     def _train_amp_program_id(self):
         program_id = _hash_with_id(self._train_amp_program, self)
-        core._set_cached_executor_build_strategy(program_id,
-                                                 self._build_strategy)
+        core._set_cached_executor_build_strategy(
+            program_id, self._build_strategy
+        )
         return program_id

     @LazyInitialized
@@ -374,8 +393,9 @@ class PartialProgramLayer:
     @LazyInitialized
     def _train_pure_fp16_program_id(self):
         program_id = _hash_with_id(self._train_pure_fp16_program, self)
-        core._set_cached_executor_build_strategy(program_id,
-                                                 self._build_strategy)
+        core._set_cached_executor_build_strategy(
+            program_id, self._build_strategy
+        )
         return program_id

     @LazyInitialized
@@ -411,8 +431,9 @@ class PartialProgramLayer:
         return main_program

-    def prepare_gradient_aggregation(self, start_idx, main_program,
-                                     target_program):
+    def prepare_gradient_aggregation(
+        self, start_idx, main_program, target_program
+    ):
         """
         Why we need add gradient aggregation operation ?
         In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
@@ -420,7 +441,7 @@ class PartialProgramLayer:
            x = 2 * in  # <---- x is a non-leaf node in program.
            y = x + 3
            return x, y

        loss = forward(in)[0].sum()
        loss.backward()  # <----- x@grad will be overwrited by elementwise_add_grad Op
        """
@@ -430,8 +451,8 @@ class PartialProgramLayer:
            if exist a op whose inputs is var, then return True
            """
            if not isinstance(var, framework.Variable) or var.type not in [
-                    core.VarDesc.VarType.LOD_TENSOR,
-                    core.VarDesc.VarType.SELECTED_ROWS
+                core.VarDesc.VarType.LOD_TENSOR,
+                core.VarDesc.VarType.SELECTED_ROWS,
            ]:
                return False
            if var.dtype not in [paddle.float32, paddle.float64]:
@@ -448,20 +469,28 @@ class PartialProgramLayer:
             new_grad_name = var.name + suffix + "@GRAD"
             finded_ops = list(
                 filter(
-                    lambda x: x[0] >= start_idx and any([
-                        out_arg == var_grad_name
-                        for out_arg in x[1].output_arg_names
-                    ]), enumerate(target_program.block(0).ops)))
+                    lambda x: x[0] >= start_idx
+                    and any(
+                        [
+                            out_arg == var_grad_name
+                            for out_arg in x[1].output_arg_names
+                        ]
+                    ),
+                    enumerate(target_program.block(0).ops),
+                )
+            )

             # len(finded_ops) may equals zero when stop_gradient works.
             # len(finded_ops) may > 1, because we may have fill_constant op.
             if len(finded_ops) == 0:
                 return None
             # step1: create a new var named var.name@GRAD
-            target_program.block(0).create_var(name=new_grad_name,
-                                               type=var.type,
-                                               dtype=var.dtype,
-                                               shape=var.shape)
+            target_program.block(0).create_var(
+                name=new_grad_name,
+                type=var.type,
+                dtype=var.dtype,
+                shape=var.shape,
+            )
             # step2: rename the var.name@GRAD to var.name@GRAD@dy2static
             for idx, op in finded_ops:
                 op._rename_input(var_grad_name, new_grad_name)
@@ -472,11 +501,13 @@ class PartialProgramLayer:
                 finded_ops[-1][0] + 1,
                 type='sum',
                 inputs={'X': [var_grad_name, new_grad_name]},
-                outputs={"Out": var_grad_name})
+                outputs={"Out": var_grad_name},
+            )
             return None

         to_processed_vars = list(
-            filter(_need_aggregation, self._outputs.tolist()))
+            filter(_need_aggregation, self._outputs.tolist())
+        )
         for _var in to_processed_vars:
             _insert_aggregation_ops_for_var(target_program, _var)
@@ -489,11 +520,12 @@ class PartialProgramLayer:
             if isinstance(out, framework.Variable):
                 targets.append(program.global_block().var(out.name))

-        if targets and self._params:
+        if targets:
             backward.gradients(targets=targets, inputs=[])

-        start_idx = len(
-            main_program.block(0).ops) + 2 * len(self._outputs.tolist())
+        start_idx = len(main_program.block(0).ops) + 2 * len(
+            self._outputs.tolist()
+        )

         self.prepare_gradient_aggregation(start_idx, main_program, program)
@@ -512,7 +544,10 @@ class PartialProgramLayer:
             found_param = False
             for block in program.blocks:
                 for op in block.ops:
-                    if param.name in op.input_arg_names or param.name in op.output_arg_names:
+                    if (
+                        param.name in op.input_arg_names
+                        or param.name in op.output_arg_names
+                    ):
                         required_params.append(param)
                         found_param = True
                         break
@@ -529,15 +564,21 @@ class PartialProgramLayer:
                     var_desc = block.vars[name].desc
                     var_base = None
                     if not framework._in_eager_mode_:
-                        var_base = core.VarBase(var_desc.dtype(),
-                                                var_desc.shape(),
-                                                var_desc.name(),
-                                                var_desc.type(), False)
+                        var_base = core.VarBase(
+                            var_desc.dtype(),
+                            var_desc.shape(),
+                            var_desc.name(),
+                            var_desc.type(),
+                            False,
+                        )
                     else:
-                        var_base = core.eager.Tensor(var_desc.dtype(),
-                                                     var_desc.shape(),
-                                                     var_desc.name(),
-                                                     var_desc.type(), False)
+                        var_base = core.eager.Tensor(
+                            var_desc.dtype(),
+                            var_desc.shape(),
+                            var_desc.name(),
+                            var_desc.type(),
+                            False,
+                        )
                     double_grads.append(var_base)
         return self._valid_vars(double_grads)
@@ -557,36 +598,62 @@ class PartialProgramLayer:
         attrs = [
             'global_block',
-            self.program.desc.block(0), 'start_op_index', 0, 'end_op_index',
-            self._get_end_op_index(), 'is_test', not self.training,
-            'program_id', self.program_id
+            self.program.desc.block(0),
+            'start_op_index',
+            0,
+            'end_op_index',
+            self._get_end_op_index(),
+            'is_test',
+            not self.training,
+            'program_id',
+            self.program_id,
         ]
         if self._cuda_graph_capture_mode:
             attrs.extend(
-                ('cuda_graph_capture_mode', self._cuda_graph_capture_mode,
-                 'cuda_graph_pool_id', self._cuda_graph_pool_id))
+                (
+                    'cuda_graph_capture_mode',
+                    self._cuda_graph_capture_mode,
+                    'cuda_graph_pool_id',
+                    self._cuda_graph_pool_id,
+                )
+            )

-        use_interpretorcore = _is_enable_standalone_executor(
-        ) and _is_dy2st_enable_standalone_executor()
+        use_interpretorcore = (
+            _is_enable_standalone_executor()
+            and _is_dy2st_enable_standalone_executor()
+        )
         attrs.extend(('use_interpretorcore', use_interpretorcore))
         if use_interpretorcore:
             attrs.extend(
-                ('forward_global_block', self.forward_program.desc.block(0),
-                 'backward_global_block', self.backward_program.desc.block(0)))
+                (
+                    'forward_global_block',
+                    self.forward_program.desc.block(0),
+                    'backward_global_block',
+                    self.backward_program.desc.block(0),
+                )
+            )

             _legacy_C_ops.run_program(
-                self._valid_vars(in_vars), self._valid_vars(self._params),
+                self._valid_vars(in_vars),
+                self._valid_vars(self._params),
                 self._valid_vars(out_vars),
-                self._create_scope_vec(program_id=self.program_id,
-                                       use_scope_cache=True),
-                self._double_grads, self._cuda_graph_vec, *attrs)
+                self._create_scope_vec(
+                    program_id=self.program_id, use_scope_cache=True
+                ),
+                self._double_grads,
+                self._cuda_graph_vec,
+                *attrs
+            )
         else:
-            _legacy_C_ops.run_program(self._valid_vars(in_vars),
-                                      self._valid_vars(self._params),
-                                      self._valid_vars(out_vars),
-                                      self._create_scope_vec(),
-                                      self._double_grads, self._cuda_graph_vec,
-                                      *attrs)
+            _legacy_C_ops.run_program(
+                self._valid_vars(in_vars),
+                self._valid_vars(self._params),
+                self._valid_vars(out_vars),
+                self._create_scope_vec(),
+                self._double_grads,
+                self._cuda_graph_vec,
+                *attrs
+            )
         restored_nest_out = self._restore_out(out_vars)
         return self._remove_no_value(restored_nest_out)
@@ -594,9 +661,11 @@ class PartialProgramLayer:
         if _in_pure_fp16_guard():
             for i, var in enumerate(in_vars):
                 name = var.name
-                if (self.program.global_block().has_var(name)
-                        and self.program.global_block().var(name).dtype
-                        == paddle.float16):
+                if (
+                    self.program.global_block().has_var(name)
+                    and self.program.global_block().var(name).dtype
+                    == paddle.float16
+                ):
                     in_vars[i] = var.astype('float16')
                     in_vars[i].name = name
@@ -627,25 +696,32 @@ class PartialProgramLayer:
         return self._infer_program

     @switch_to_static_graph
-    def _get_forward_backward_program_form(self, whole_program,
-                                           forward_end_op_index):
+    def _get_forward_backward_program_form(
+        self, whole_program, forward_end_op_index
+    ):
         forward_builded_program = add_build_strategy_for(
-            whole_program, 0, forward_end_op_index, self._build_strategy)
+            whole_program, 0, forward_end_op_index, self._build_strategy
+        )
         backward_start_op_index = forward_end_op_index + 2 * len(
-            self._outputs.var_ids)
+            self._outputs.var_ids
+        )
         backward_end_op_index = whole_program.desc.block(0).op_size()
         backward_builded_program = add_build_strategy_for(
-            whole_program, backward_start_op_index, backward_end_op_index,
-            self._build_strategy)
-        self._apply_inplace_pass(forward_builded_program,
-                                 backward_builded_program)
+            whole_program,
+            backward_start_op_index,
+            backward_end_op_index,
+            self._build_strategy,
+        )
+        self._apply_inplace_pass(
+            forward_builded_program, backward_builded_program
+        )
         return [forward_builded_program, backward_builded_program]

     def _apply_inplace_pass(self, forward_program, backward_program):
         attr_types = {
             "use_cuda": "bool",
             "mem_opt_skip_vars": "list[str]",
-            "for_partial_block": "bool"
+            "for_partial_block": "bool",
         }
         empty_startup_program = paddle.static.Program()
         use_cuda = True if core.is_compiled_with_cuda() else False
@@ -667,22 +743,33 @@ class PartialProgramLayer:
                 forward_mem_opt_skip_vars.append(var.desc.name())
                 backward_mem_opt_skip_vars.append(var.desc.name())
         for var_name in core.parse_safe_eager_deletion_skip_vars(
-                backward_program.desc):
+            backward_program.desc
+        ):
             forward_mem_opt_skip_vars.append(var_name)
         attrs = {
             "use_cuda": use_cuda,
             "mem_opt_skip_vars": forward_mem_opt_skip_vars,
-            "for_partial_block": True
+            "for_partial_block": True,
         }
-        _apply_pass(forward_program, empty_startup_program,
-                    "buffer_shared_inplace_pass", attrs, attr_types)
+        _apply_pass(
+            forward_program,
+            empty_startup_program,
+            "buffer_shared_inplace_pass",
+            attrs,
+            attr_types,
+        )
         attrs = {
             "use_cuda": use_cuda,
             "mem_opt_skip_vars": backward_mem_opt_skip_vars,
-            "for_partial_block": True
+            "for_partial_block": True,
        }
-        _apply_pass(backward_program, empty_startup_program,
-                    "buffer_shared_inplace_pass", attrs, attr_types)
+        _apply_pass(
+            backward_program,
+            empty_startup_program,
+            "buffer_shared_inplace_pass",
+            attrs,
+            attr_types,
+        )

     def _prepare(self, inputs):
         """
@@ -698,23 +785,28 @@ class PartialProgramLayer:
             if isinstance(value, np.ndarray):
                 var = None
                 if not framework._in_eager_mode_:
-                    var = core.VarBase(value=value,
-                                       name=self._inputs[i].desc.name(),
-                                       persistable=False,
-                                       place=expected_place,
-                                       zero_copy=True)
+                    var = core.VarBase(
+                        value=value,
+                        name=self._inputs[i].desc.name(),
+                        persistable=False,
+                        place=expected_place,
+                        zero_copy=True,
+                    )
                 else:
-                    var = core.eager.Tensor(value=value,
-                                            name=self._inputs[i].desc.name(),
-                                            persistable=False,
-                                            place=expected_place,
-                                            zero_copy=True)
+                    var = core.eager.Tensor(
+                        value=value,
+                        name=self._inputs[i].desc.name(),
+                        persistable=False,
+                        place=expected_place,
+                        zero_copy=True,
+                    )
             elif isinstance(value, (core.VarBase, core.eager.Tensor)):
                 # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
                 # into CUDAPlace when it's as input of multi Ops. so we move it in advance
                 # to avoid this problem.
                 if value.stop_gradient and not value.place._equals(
-                        expected_place):
+                    expected_place
+                ):
                     var = value._copy_to(expected_place, False)
                     var.stop_gradient = True
                 else:
@@ -737,12 +829,21 @@ class PartialProgramLayer:
                 return out_varbase_map[var_desc.name()]

             if not framework._in_eager_mode_:
-                var_base = core.VarBase(var_desc.dtype(), var_desc.shape(),
-                                        var_desc.name(), var_desc.type(), False)
+                var_base = core.VarBase(
+                    var_desc.dtype(),
+                    var_desc.shape(),
+                    var_desc.name(),
+                    var_desc.type(),
+                    False,
+                )
             else:
-                var_base = core.eager.Tensor(var_desc.dtype(), var_desc.shape(),
-                                             var_desc.name(), var_desc.type(),
-                                             False)
+                var_base = core.eager.Tensor(
+                    var_desc.dtype(),
+                    var_desc.shape(),
+                    var_desc.name(),
+                    var_desc.type(),
+                    False,
+                )
             var_base.stop_gradient = var.stop_gradient
             out_varbase_map[var_desc.name()] = var_base
             return var_base
@@ -755,20 +856,30 @@ class PartialProgramLayer:
     def _create_scope_vec(self, program_id=None, use_scope_cache=False):
         # Hold forward variables
         tmp_scope_vec = None
-        inner_scope = self._get_scope(program_id=program_id,
-                                      use_scope_cache=use_scope_cache)
+        inner_scope = self._get_scope(
+            program_id=program_id, use_scope_cache=use_scope_cache
+        )
         if not framework._in_eager_mode_:
-            tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
-                                         "program_out_scope",
-                                         core.VarDesc.VarType.STEP_SCOPES, True)
+            tmp_scope_vec = core.VarBase(
+                core.VarDesc.VarType.FP32,
+                [],
+                "program_out_scope",
+                core.VarDesc.VarType.STEP_SCOPES,
+                True,
+            )
             tmp_scope_vec.value().set_scope(inner_scope)
         else:
             tmp_scope_vec = [inner_scope]
         return tmp_scope_vec

     def _create_cuda_graph_vec(self):
-        var = core.VarBase(core.VarDesc.VarType.FP32, [], "cuda_graph",
-                           core.VarDesc.VarType.RAW, True)
+        var = core.VarBase(
+            core.VarDesc.VarType.FP32,
+            [],
+            "cuda_graph",
+            core.VarDesc.VarType.RAW,
+            True,
+        )
         var.stop_gradient = True
         return var
@@ -791,8 +902,9 @@ class PartialProgramLayer:
         return main_program.clone(for_test=True)

     def _is_no_value(self, var):
-        if isinstance(var,
-                      (core.VarBase, core.eager.Tensor)) and var.shape == [1]:
+        if isinstance(var, (core.VarBase, core.eager.Tensor)) and var.shape == [
+            1
+        ]:
             # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
             if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
                 return True
@@ -808,13 +920,14 @@ class PartialProgramLayer:
             return out_vars
         elif isinstance(out_vars, (tuple, list)):
             if isinstance(out_vars, tuple):
-                res = tuple(var for var in out_vars
-                            if not self._is_no_value(var))
+                res = tuple(
+                    var for var in out_vars if not self._is_no_value(var)
+                )
             else:
                 # isinstance(out_vars, list)
                 res = [var for var in out_vars if not self._is_no_value(var)]
-            has_removed = (len(out_vars) > len(res))
+            has_removed = len(out_vars) > len(res)
             # len(out_vars) > len(res) means we have removed var. This is
             # preventing out_vars is empty or just one element at the beginning
             if len(res) == 0 and has_removed:
@@ -835,7 +948,8 @@ class PartialProgramLayer:
         for param in params:
             grad_name = param.name + core.grad_var_suffix()
             grad_var = train_program.desc.block(0).find_var(
-                cpt.to_bytes(grad_name))
+                cpt.to_bytes(grad_name)
+            )
             # NOTE: cannot find var desc maybe no problem, such as in batch_norm
             if grad_var is None:
                 continue
@@ -864,15 +978,18 @@ class PartialProgramLayer:
         if not isinstance(self._params, (list, tuple)):
             raise TypeError(
                 "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
-                % type(self._params))
+                % type(self._params)
+            )
         param_and_buffer_names_set = set()
         for i, var in enumerate(self._params):
             # self._params constains parameters and buffers with persistable=True.
             if not isinstance(var, (core.VarBase, core.eager.Tensor)):
                 raise TypeError(
-                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'
-                    .format(i, type(var)))
+                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'.format(
+                        i, type(var)
+                    )
+                )
             param_and_buffer_names_set.add(var.name)

         for block in main_program.blocks:
@@ -886,15 +1003,11 @@ class PartialProgramLayer:
                         "\n\tRevise suggestion: "
                         "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
                         "\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
-                        % name)
+                        % name
+                    )

     def _valid_vars(self, vars):
-        """
-        Note: run_program_op.InferShape requires `X`/'Out' not be null.
-        But it's common in dy2static, fake varBase is created to handle the
-        problem.
-        """
-        return vars if vars else self.__fake_vars
+        return vars if vars else None


 def _create_fake_var():
@@ -903,13 +1016,23 @@ def _create_fake_var():
     """
     if not framework._in_eager_mode_:
         return [
-            core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
-                         core.VarDesc.VarType.RAW, False)
+            core.VarBase(
+                core.VarDesc.VarType.FP32,
+                [],
+                "Fake_var",
+                core.VarDesc.VarType.RAW,
+                False,
+            )
         ]
     else:
         return [
-            core.eager.Tensor(core.VarDesc.VarType.FP32, [], "Fake_var",
-                              core.VarDesc.VarType.RAW, False)
+            core.eager.Tensor(
+                core.VarDesc.VarType.FP32,
+                [],
+                "Fake_var",
+                core.VarDesc.VarType.RAW,
+                False,
+            )
         ]
@@ -918,23 +1041,27 @@ def partial_program_from(concrete_program):
     if inputs and isinstance(inputs[0], layers.Layer):
         inputs = inputs[1:]
-    return PartialProgramLayer(concrete_program.main_program, inputs,
-                               concrete_program.outputs,
-                               concrete_program.parameters,
-                               **concrete_program.kwargs)
+    return PartialProgramLayer(
+        concrete_program.main_program,
+        inputs,
+        concrete_program.outputs,
+        concrete_program.parameters,
+        **concrete_program.kwargs
+    )


 @switch_to_static_graph
-def add_build_strategy_for(program,
-                           start_op_index,
-                           end_op_index,
-                           build_strategy=None):
-    if (start_op_index < end_op_index):
+def add_build_strategy_for(
+    program, start_op_index, end_op_index, build_strategy=None
+):
+    if start_op_index < end_op_index:
         compiled_program = paddle.static.CompiledProgram(
             core.Graph(program.desc, start_op_index, end_op_index),
-            build_strategy=build_strategy)
-        compiled_program._compile(core.Scope(),
-                                  framework._current_expected_place())
+            build_strategy=build_strategy,
+        )
+        compiled_program._compile(
+            core.Scope(), framework._current_expected_place()
+        )
         ir_graph = framework.IrGraph(compiled_program._graph)
         builded_program = ir_graph.to_program()
         if hasattr(compiled_program._program, 'lr_sheduler'):
...
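Aside from the mechanical reformatting, the functional core of the partial_program.py change above is small: the lazily created fake variables are removed, empty X/Params/Out lists are now handed to run_program as None (which the operator and pybind changes at the top of the diff were taught to accept), and backward ops are appended whenever there is a differentiable target rather than only when the layer owns parameters. The following standalone sketch condenses those two behaviors; the names mirror the patched methods, but the functions themselves are illustrative only:

def _valid_vars(vars):
    # Before this patch, an empty list was replaced by a shared "Fake_var"
    # VarBase so that run_program's InferShape never saw a null X/Out. The op
    # now tolerates nulls (see the HasOutputs("Out", true) change above), so
    # an empty list simply becomes None.
    return vars if vars else None


def should_append_backward(targets, params):
    # Before: backward ops were appended only for `targets and params`.
    # After: any differentiable target is enough, which is what lets
    # loss.backward() work for a program without trainable parameters.
    return bool(targets)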
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

import paddle


class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()

    @paddle.jit.to_static
    def forward(self, x):
        out = x + 1
        return out


class TestBackwardWithoutParams(unittest.TestCase):
    def test_run(self):
        net = Net()

        x = paddle.ones([2, 2])
        x.stop_gradient = False
        out = net(x)
        loss = paddle.mean(out)
        loss.backward()
        np.testing.assert_equal(x.grad.numpy(), np.full(x.shape, 0.25))


if __name__ == '__main__':
    unittest.main()
@@ -292,7 +292,6 @@ def for_tuple_as_enumerate_value(x_array):

 # 20. test for function in a class
 class ForwardContainsForLayer(paddle.nn.Layer):
-
     def __init__(self):
         super(ForwardContainsForLayer, self).__init__()
         self.high = 5
@@ -328,8 +327,8 @@ def for_original_tuple():
 # 23. for zip error
 @paddle.jit.to_static(
-    input_spec=[InputSpec(shape=[None, 10]),
-                InputSpec(shape=[None, 10])])
+    input_spec=[InputSpec(shape=[None, 10]), InputSpec(shape=[None, 10])]
+)
 def for_zip_error(x, y):
     for i, j in zip(x, y):
         a = i + j
@@ -338,8 +337,8 @@ def for_zip_error(x, y):
 # 24. for zip
 @paddle.jit.to_static(
-    input_spec=[InputSpec(shape=[2, 10]),
-                InputSpec(shape=[2, 10])])
+    input_spec=[InputSpec(shape=[2, 10]), InputSpec(shape=[2, 10])]
+)
 def for_zip(x, y):
     for i, j in zip(x, y):
         a = i + j
@@ -347,10 +346,12 @@ def for_zip(x, y):

 class TestTransformBase(unittest.TestCase):
-
     def setUp(self):
-        self.place = fluid.CUDAPlace(
-            0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
+        self.place = (
+            fluid.CUDAPlace(0)
+            if fluid.is_compiled_with_cuda()
+            else fluid.CPUPlace()
+        )
         self.set_input()
         self.set_test_func()
@@ -359,7 +360,8 @@ class TestTransformBase(unittest.TestCase):
     def set_test_func(self):
         raise NotImplementedError(
-            "For Enumerate test should implement set_test_func")
+            "For Enumerate test should implement set_test_func"
+        )

     def _run(self, to_static):
         program_translator.enable(to_static)
@@ -374,22 +376,21 @@ class TestTransformBase(unittest.TestCase):

 class TestTransform(TestTransformBase):
-
     def transformed_result_compare(self):
         dy_outs = self.get_dygraph_output()
         if not isinstance(dy_outs, (tuple, list)):
-            dy_outs = (dy_outs, )
+            dy_outs = (dy_outs,)

+        self.dygraph_func.eval()
         st_outs = self.get_static_output()
         if not isinstance(st_outs, (tuple, list)):
-            st_outs = (st_outs, )
+            st_outs = (st_outs,)

         for x, y in zip(dy_outs, st_outs):
             np.testing.assert_allclose(x.numpy(), y.numpy(), rtol=1e-05)


 class TestTransformForOriginalList(TestTransform):
-
     def _run(self, to_static):
         program_translator.enable(to_static)
         with fluid.dygraph.guard():
@@ -397,7 +398,6 @@ class TestTransformForOriginalList(TestTransform):

 class TestTransformError(TestTransformBase):
-
     def transformed_error(self, etype):
         with self.assertRaises(etype):
             dy_out = self.get_dygraph_output()
@@ -405,7 +405,6 @@ class TestTransformError(TestTransformBase):

 class TestForInRange(TestTransform):
-
     def set_input(self):
         self.input = np.array([5])
@@ -417,7 +416,6 @@ class TestForInRange(TestTransform):

 class TestForIterList(TestTransform):
-
     def set_test_func(self):
         self.dygraph_func = for_iter_list
@@ -426,19 +424,16 @@ class TestForIterList(TestTransform):

 class TestForEnumerateSimple(TestForIterList):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_list


 class TestForInRangeWithBreak(TestForInRange):
-
     def set_test_func(self):
         self.dygraph_func = for_in_range_with_break


 class TestForIterVarNumpy(TestTransform):
-
     def set_input(self):
         self.input = np.array([1, 2, 3, 4, 5])
@@ -450,103 +445,86 @@ class TestForIterVarNumpy(TestTransform):

 class TestForEnumerateVarNumpy(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_var_numpy


 class TestForEnumerateVarNumpyWithStart(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_var_numpy_with_start


 class TestForEnumerateVarNumpyWithBreak(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_var_numpy_with_break


 class TestForEnumerateVarNumpyWithContinue(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_var_numpy_with_continue


 class TestForEnumerateVarNumpyWithStartAndBreak(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_var_numpy_with_start_break


 class TestForEnumerateVarNumpyWithStartAndContinue(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_var_numpy_with_start_continue


 class TestForIterVar(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_iter_var


 class TestForIterVarIdx(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_iter_var_idx


 class TestForEnumerateVar(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_var


 class TestForEnumerateVarWithNestedRange(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_var_with_nested_range


 class TestForIterVarList(TestForInRange):
-
     def set_test_func(self):
         self.dygraph_func = for_iter_var_list


 class TestForEnumerateVarList(TestForInRange):
-
     def set_test_func(self):
         self.dygraph_func = for_enumerate_var_list


 class TestForTupleAsIterVar(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_tuple_as_iter_var


 class TestForTupleAsEnumerateIter(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_tuple_as_enumerate_iter


 class TestForTupleAsEnumerateValue(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = for_tuple_as_enumerate_value


 class TestForwardContainsForLayer(TestForIterVarNumpy):
-
     def set_test_func(self):
         self.dygraph_func = ForwardContainsForLayer()


 class TestForOriginalList(TestTransformForOriginalList):
-
     def set_test_func(self):
         self.dygraph_func = for_original_list
@@ -555,7 +533,6 @@ class TestForOriginalList(TestTransformForOriginalList):

 class TestForOriginalTuple(TestTransformForOriginalList):
-
     def set_test_func(self):
         self.dygraph_func = for_original_tuple
@@ -564,7 +541,6 @@ class TestForOriginalTuple(TestTransformForOriginalList):

 class TestForZip(unittest.TestCase):
-
     def setUp(self):
         self.temp_dir = tempfile.TemporaryDirectory()
...