Unverified commit df3f74df authored by 姜永久, committed by GitHub

rm legacy dygraph part7 (#49285)

* rm legacy dygraph part7

* rm non_static_mode

* modify

* modify

* add static test

* set static for lstm_cudnn test

* reset tracer

* reset varbase

* fix
Parent e81883e6
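The pattern repeated across the hunks below: every dispatch over eager dygraph, legacy dygraph, and static graph is collapsed into a single in_dygraph_mode() check, the _legacy_C_ops fallbacks are deleted, and the _non_static_mode() / _in_legacy_dygraph() imports are removed. A minimal sketch of the before/after shape (illustrative only; forward_scale is a hypothetical name, and the scale call mirrors the math_op_patch hunk rather than being code taken verbatim from this PR):

# Hypothetical illustration of the migration this PR applies; not part of the diff.
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode

def forward_scale(var, scale=2.0, bias=0.0):
    # Before: if _non_static_mode(): ... if in_dygraph_mode(): _C_ops path
    #         ... elif _in_legacy_dygraph(): _legacy_C_ops path ... else: static path.
    # After this PR only the eager check remains:
    if in_dygraph_mode():
        # eager dygraph: call the new C++ op directly (same call shape as in the diff)
        return _C_ops.scale(var, float(scale), bias, True)
    # static-graph path unchanged in the real code (LayerHelper.append_op(...));
    # omitted here to keep the sketch short.
    raise NotImplementedError("static-graph path omitted in this sketch")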
@@ -20,7 +20,6 @@ from paddle.fluid import core
 import contextlib
 from paddle.fluid.framework import (
     Variable,
-    _non_static_mode,
     OpProtoHolder,
     Parameter,
     _dygraph_tracer,
...
@@ -27,7 +27,6 @@ from ..data_feeder import convert_dtype
 import warnings
 from ..framework import (
     _get_paddle_place,
-    _in_legacy_dygraph,
     _in_eager_without_dygraph_check,
 )
 import paddle
@@ -113,11 +112,7 @@ _functional_dygraph_context_manager = None
 @signature_safe_contextmanager
 def param_guard(parameters):
     # Note: parameters is a reference of self._parameters or self._buffers
-    if (
-        in_declarative_mode()
-        and not framework._non_static_mode()
-        and parameters
-    ):
+    if in_declarative_mode() and not framework.in_dygraph_mode() and parameters:
         origin_parameters = parameters.copy()
         for name, var_base in parameters.items():
             if isinstance(var_base, list):
@@ -189,7 +184,7 @@ def enabled():
            print(fluid.dygraph.enabled())  # False
     """
     # TODO(jiabin): Make this check as in_dygraph_mode when we support default eager mode.
-    return framework._non_static_mode()
+    return framework.in_dygraph_mode()


 def enable_dygraph(place=None):
...
@@ -18,7 +18,6 @@ import functools
 from ..framework import (
     Variable,
     default_main_program,
-    _non_static_mode,
     dygraph_only,
     Parameter,
     ParamBase,
...
@@ -14,7 +14,7 @@
 import warnings

-from paddle.fluid.framework import default_main_program, _non_static_mode
+from paddle.fluid.framework import default_main_program, in_dygraph_mode


 class LayerOpsRecoder:
@@ -34,7 +34,7 @@ def record_program_ops_pre_hook(layer, inputs):
     """
     A pre-hook to mark op numbers before enter layer.forward.
     """
-    if not _non_static_mode():
+    if not in_dygraph_mode():
         if layer._op_recorder.start < 0:
             layer._op_recorder.start = len(
                 default_main_program().current_block().ops
@@ -55,7 +55,7 @@ def set_op_customized_attrs_post_hook(layer, inputs, outputs):
     """
     A post-hook to append customized attributes into all operators generated in current layer.
     """
-    if not _non_static_mode() and layer._op_recorder.is_valid:
+    if not in_dygraph_mode() and layer._op_recorder.is_valid:
         start = layer._op_recorder.start
         end = len(default_main_program().current_block().ops)
...
@@ -13,7 +13,7 @@
 # limitations under the License.

 import copy
-from ..framework import Parameter, _non_static_mode, _global_flags
+from ..framework import Parameter, in_dygraph_mode, _global_flags
 from ..param_attr import ParamAttr
 from .. import core
@@ -169,7 +169,7 @@ class LayerObjectHelper(LayerHelperBase):
         if (use_mkldnn is not None) and use_mkldnn:
             act['use_mkldnn'] = use_mkldnn
         act_type = act.pop('type')
-        if _non_static_mode():
+        if in_dygraph_mode():
             res = _append_activation_in_dygraph(
                 input_var, act_type, use_cudnn, use_mkldnn
             )
...
@@ -46,7 +46,6 @@ from paddle.fluid import framework
 from ..param_attr import ParamAttr
 from paddle.fluid.executor import Executor, global_scope
 from paddle.fluid.framework import (
-    _non_static_mode,
     convert_np_dtype_to_dtype_,
     in_dygraph_mode,
 )
@@ -153,7 +152,7 @@ class Layer:
         self._helper = LayerObjectHelper(self._full_name)
         self._built = False
         self._dtype = dtype
-        self._init_in_dynamic_mode = framework._non_static_mode()
+        self._init_in_dynamic_mode = in_dygraph_mode()

         self._parameters = collections.OrderedDict()
         # Buffers the variable (not parameter) created in layer
@@ -211,7 +210,7 @@ class Layer:
         # global setting in dygraph
         # NOTE(chenweihang): nn.Layer also can be used in static mode,
         # but _dygraph_tracer() can not be called in static mode
-        if _non_static_mode():
+        if in_dygraph_mode():
             framework._dygraph_tracer().train_mode()
         # Layer-level setting
         self.training = True
@@ -252,7 +251,7 @@ class Layer:
         # global setting in dygraph
         # NOTE(chenweihang): nn.Layer also can be used in static mode,
         # but _dygraph_tracer() can not be called in static mode
-        if _non_static_mode():
+        if in_dygraph_mode():
             framework._dygraph_tracer().eval_mode()
         # Layer-level setting
         self.training = False
@@ -1667,7 +1666,7 @@ class Layer:
         for key in state_dict.keys():
             if key not in match_keys:
                 unexpected_keys.append(key)
-        if _non_static_mode():
+        if in_dygraph_mode():
             for param, state in matched_param_state:
                 param.set_value(state)
         else:
...
@@ -17,7 +17,6 @@ from ..framework import (
     Variable,
     convert_np_dtype_to_dtype_,
     _varbase_creator,
-    _in_legacy_dygraph,
     in_dygraph_mode,
 )
 from ..layers.layer_function_generator import OpProtoHolder
@@ -123,17 +122,13 @@ def monkey_patch_math_varbase():
         """
         if not isinstance(dtype, core.VarDesc.VarType):
             dtype = convert_np_dtype_to_dtype_(dtype)
-        if _in_legacy_dygraph():
-            return _legacy_C_ops.cast(
-                self, 'in_dtype', self.dtype, 'out_dtype', dtype
-            )
         return _C_ops.cast(self, dtype)

     def _scalar_elementwise_op_(var, scale, bias):
         if framework.in_dygraph_mode():
             return _C_ops.scale(var, float(scale), bias, True)
-        return _legacy_C_ops.scale(var, 'scale', scale, 'bias', bias)
+        else:
+            return _legacy_C_ops.scale(var, 'scale', scale, 'bias', bias)

     def _neg_(var):
         return _scalar_elementwise_op_(var, -1.0, 0.0)
@@ -194,10 +189,7 @@ def monkey_patch_math_varbase():
         perm = []
         for i in range(len(var.shape)):
             perm.insert(0, i)
-        if _in_legacy_dygraph():
-            out, _ = _legacy_C_ops.transpose2(var, 'axis', perm)
-        else:
-            out = _C_ops.transpose(var, perm)
+        out = _C_ops.transpose(var, perm)
         return out

     def _scalar_add_(var, value):
...
@@ -20,7 +20,6 @@ from .. import dygraph_utils
 from . import layers
 from ..framework import (
     Variable,
-    _non_static_mode,
     OpProtoHolder,
     Parameter,
     _dygraph_tracer,
@@ -28,7 +27,6 @@ from ..framework import (
     default_main_program,
     _global_flags,
     in_dygraph_mode,
-    _in_legacy_dygraph,
 )
 from ..data_feeder import (
@@ -247,115 +245,81 @@ class BatchNorm(layers.Layer):
         # variance and variance out share the same memory
         variance_out = self._variance

-        if _non_static_mode():
-            if in_dygraph_mode():
-                batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm(
-                    input,
-                    self._mean,
-                    self._variance,
-                    self.weight,
-                    self.bias,
-                    not self.training,
-                    self._momentum,
-                    self._epsilon,
-                    self._data_layout,
-                    self._use_global_stats,
-                    self._trainable_statistics,
-                )
-                return dygraph_utils._append_activation_in_dygraph(
-                    batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
-                )
-
-            elif _in_legacy_dygraph():
-                attrs = (
-                    "momentum",
-                    self._momentum,
-                    "epsilon",
-                    self._epsilon,
-                    "is_test",
-                    not self.training,
-                    "data_layout",
-                    self._data_layout,
-                    "use_mkldnn",
-                    self._use_mkldnn,
-                    "fuse_with_relu",
-                    self._fuse_with_relu,
-                    "use_global_stats",
-                    self._use_global_stats,
-                    'trainable_statistics',
-                    self._trainable_statistics,
-                )
-                batch_norm_out, _, _, _, _, _ = _legacy_C_ops.batch_norm(
-                    input,
-                    self.weight,
-                    self.bias,
-                    self._mean,
-                    self._variance,
-                    None,
-                    mean_out,
-                    variance_out,
-                    *attrs
-                )
-
-                return dygraph_utils._append_activation_in_dygraph(
-                    batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
-                )
-
-        check_variable_and_dtype(
-            input, 'input', ['float16', 'float32', 'float64'], 'BatchNorm'
-        )
-
-        attrs = {
-            "momentum": self._momentum,
-            "epsilon": self._epsilon,
-            "is_test": self._is_test,
-            "data_layout": self._data_layout,
-            "use_mkldnn": False,
-            "fuse_with_relu": self._fuse_with_relu,
-            "use_global_stats": self._use_global_stats,
-            "trainable_statistics": self._trainable_statistics,
-        }
-
-        inputs = {
-            "X": [input],
-            "Scale": [self.weight],
-            "Bias": [self.bias],
-            "Mean": [self._mean],
-            "Variance": [self._variance],
-        }
-
-        saved_mean = self._helper.create_variable_for_type_inference(
-            dtype=self._dtype, stop_gradient=True
-        )
-        saved_variance = self._helper.create_variable_for_type_inference(
-            dtype=self._dtype, stop_gradient=True
-        )
-        reserve_space = self._helper.create_variable_for_type_inference(
-            dtype=self._helper.input_dtype(input), stop_gradient=True
-        )
-
-        batch_norm_out = (
-            input
-            if self._in_place
-            else self._helper.create_variable_for_type_inference(self._dtype)
-        )
-
-        outputs = {
-            "Y": [batch_norm_out],
-            "MeanOut": [mean_out],
-            "VarianceOut": [variance_out],
-            "SavedMean": [saved_mean],
-            "SavedVariance": [saved_variance],
-        }
-        if reserve_space is not None:
-            outputs["ReserveSpace"] = [reserve_space]
-
-        self._helper.append_op(
-            type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs
-        )
+        if in_dygraph_mode():
+            batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm(
+                input,
+                self._mean,
+                self._variance,
+                self.weight,
+                self.bias,
+                not self.training,
+                self._momentum,
+                self._epsilon,
+                self._data_layout,
+                self._use_global_stats,
+                self._trainable_statistics,
+            )
+            return dygraph_utils._append_activation_in_dygraph(
+                batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
+            )
+        else:
+            check_variable_and_dtype(
+                input, 'input', ['float16', 'float32', 'float64'], 'BatchNorm'
+            )
+
+            attrs = {
+                "momentum": self._momentum,
+                "epsilon": self._epsilon,
+                "is_test": self._is_test,
+                "data_layout": self._data_layout,
+                "use_mkldnn": False,
+                "fuse_with_relu": self._fuse_with_relu,
+                "use_global_stats": self._use_global_stats,
+                "trainable_statistics": self._trainable_statistics,
+            }
+
+            inputs = {
+                "X": [input],
+                "Scale": [self.weight],
+                "Bias": [self.bias],
+                "Mean": [self._mean],
+                "Variance": [self._variance],
+            }
+
+            saved_mean = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype, stop_gradient=True
+            )
+            saved_variance = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype, stop_gradient=True
+            )
+            reserve_space = self._helper.create_variable_for_type_inference(
+                dtype=self._helper.input_dtype(input), stop_gradient=True
+            )
+
+            batch_norm_out = (
+                input
+                if self._in_place
+                else self._helper.create_variable_for_type_inference(
+                    self._dtype
+                )
+            )
+
+            outputs = {
+                "Y": [batch_norm_out],
+                "MeanOut": [mean_out],
+                "VarianceOut": [variance_out],
+                "SavedMean": [saved_mean],
+                "SavedVariance": [saved_variance],
+            }
+            if reserve_space is not None:
+                outputs["ReserveSpace"] = [reserve_space]
+
+            self._helper.append_op(
+                type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs
+            )

         # Currently, we don't support inplace in dygraph mode
         return self._helper.append_activation(batch_norm_out, self._act)


 class RowConv(layers.Layer):
@@ -410,7 +374,7 @@ class RowConv(layers.Layer):
         self, name_scope, future_context_size, param_attr=None, act=None
     ):
         assert (
-            not _non_static_mode()
+            not in_dygraph_mode()
         ), "RowConv is not supported by dynamic graph mode yet!"
         super().__init__(name_scope)
         self._act = act
...
@@ -32,8 +32,6 @@ from ..layers import collective
 from paddle.fluid.dygraph import base as imperative_base
 from paddle.fluid.framework import (
     ParamBase,
-    _in_legacy_dygraph,
-    _non_static_mode,
     in_dygraph_mode,
 )
@@ -302,23 +300,7 @@ def _reshape_inplace(x, shape):

 @framework.dygraph_only
 def _split_tensors(coalesced_grads_and_grad_vars):
-    if _in_legacy_dygraph():
-        for (
-            coalesced_grad,
-            origin_grad_vars,
-            grad_shapes,
-        ) in coalesced_grads_and_grad_vars:
-            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
-            framework._dygraph_tracer().trace_op(
-                type='split',
-                inputs={'X': coalesced_grad},
-                outputs={'Out': origin_grad_vars},
-                attrs={'sections': grad_var_len, 'axis': 0},
-            )
-            for g_var, g_shape in zip(origin_grad_vars, grad_shapes):
-                _reshape_inplace(x=g_var, shape=g_shape)
-                assert g_var.shape == g_shape
-    elif in_dygraph_mode():
+    if in_dygraph_mode():
         for (
             coalesced_grad,
             origin_grad_vars,
@@ -587,7 +569,7 @@ class DataParallel(layers.Layer):
         super().__init__(layers.full_name() + "_data_parallel")

         assert (
-            _non_static_mode()
+            in_dygraph_mode()
         ), "It's not supported to construct DataParallel in static mode."

         self._layers = layers
@@ -704,21 +686,6 @@ class DataParallel(layers.Layer):
                 [self.last_comm_buffer_size, self.comm_buffer_size],
                 self.find_unused_parameters,
             )
-        elif _in_legacy_dygraph():
-            self.group_indices = core.assign_group_by_size(
-                trainable_parameters,
-                is_sparse_gradient,
-                [self.last_comm_buffer_size, self.comm_buffer_size],
-            )
-            self._reducer = core.Reducer(
-                trainable_parameters,
-                list(reversed(self.group_indices)),
-                is_sparse_gradient,
-                parallel_helper.__parallel_ctx__clz__,
-                [self.last_comm_buffer_size, self.comm_buffer_size],
-                self.find_unused_parameters,
-            )

     def _find_varbase(self, obj):
         var_type = core.eager.Tensor if in_dygraph_mode() else core.VarBase
...
@@ -20,7 +20,7 @@ import sys
 import paddle
 from .. import framework
-from ..framework import convert_np_dtype_to_dtype_, _in_legacy_dygraph
+from ..framework import convert_np_dtype_to_dtype_
 from .. import core
 from .. import unique_name
 from ..framework import (
...
@@ -42,7 +42,9 @@ class TestBprLossOp1(OpTest):
         self.outputs = {"Y": bpr_loss}

     def test_check_output(self):
+        paddle.enable_static()
         self.check_output()
+        paddle.disable_static()

     def test_check_grad(self):
         self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
...
@@ -522,9 +522,11 @@ class TestCUDNNLstmOp(OpTest):
                 place, atol=1e-5, no_check_set=['Reserve', 'StateOut']
             )
         else:
+            paddle.enable_static()
             self.check_output_with_place(
                 place, no_check_set=['Reserve', 'StateOut']
             )
+            paddle.disable_static()

     def test_grad_with_place(self):
         place = core.CUDAPlace(0)
...
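The two test hunks above follow the same pattern: the OpTest checks (check_output / check_output_with_place) exercise the static-graph executor, so they are wrapped in paddle.enable_static() and paddle.disable_static(), matching the "add static test" and "set static for lstm_cudnn test" items in the commit message. A minimal sketch of that wrapping as a reusable helper (hypothetical; the PR simply inlines the two calls around each check, without try/finally):

import paddle

def run_in_static_mode(check_fn):
    # Switch to static graph mode for an OpTest-style check, then restore
    # dynamic mode so later tests keep running in eager dygraph.
    # The try/finally is an extra safeguard not present in the PR itself.
    paddle.enable_static()
    try:
        check_fn()
    finally:
        paddle.disable_static()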