Unverified commit 140d786d, authored by 姜永久, committed by GitHub

rm in_legacy_dygraph python/paddle/nn/functional/ part1 (#49258)

* rm in_legacy_dygraph nn part1

* rm non_static_mode

* modify rrelu
Parent 861fef52
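Every function touched in this diff follows the same pattern: the legacy-dygraph branch (`_in_legacy_dygraph()` / `_non_static_mode()` plus `_legacy_C_ops` calls) is deleted, leaving a single `in_dygraph_mode()` check with the static-graph `LayerHelper` path in the `else` branch. Below is a minimal, hedged sketch of that two-path dispatch; `relu_like` is a hypothetical wrapper used only for illustration, while the `paddle` calls it makes are real public/internal APIs.

# Minimal sketch of the dispatch shape this PR converges on.
# `relu_like` is a made-up example, not a function from this diff.
import paddle
from paddle import _C_ops


def relu_like(x):
    if paddle.in_dynamic_mode():
        # Dynamic graph: call the C++ op binding directly; the old
        # `_in_legacy_dygraph()` / `_legacy_C_ops` branch no longer exists.
        return _C_ops.relu(x)
    # Static graph: fall back to the layered API, which appends an op to the
    # current program (the real functions here build it via LayerHelper).
    return paddle.nn.functional.relu(x)


x = paddle.to_tensor([-1.0, 0.5, 2.0])
print(relu_like(x))  # takes the dynamic-graph path by default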
......@@ -25,11 +25,7 @@ from ...fluid.data_feeder import (
check_type,
check_variable_and_dtype,
)
-from ...fluid.framework import (
-    _in_legacy_dygraph,
-    _non_static_mode,
-    in_dygraph_mode,
-)
+from ...fluid.framework import in_dygraph_mode
from ...tensor import clip, concat, sqrt, sum
from ...tensor.creation import zeros
......@@ -927,24 +923,22 @@ def bilinear(x1, x2, weight, bias=None, name=None):
if in_dygraph_mode():
return _C_ops.bilinear_tensor_product(x1, x2, weight, bias)
-    elif _non_static_mode():
-        return _legacy_C_ops.bilinear_tensor_product(x1, x2, weight, bias)
-    check_variable_and_dtype(x1, 'x1', ['float32', 'float64'], 'bilinear')
-    check_variable_and_dtype(x2, 'x2', ['float32', 'float64'], 'bilinear')
-    inputs = {"X": x1, "Y": x2, "Weight": weight}
-    if bias is not None:
-        inputs["Bias"] = bias
-    helper = LayerHelper("bilinear", **locals())
-    out = helper.create_variable_for_type_inference(dtype=x1.dtype)
-    helper.append_op(
-        type="bilinear_tensor_product", inputs=inputs, outputs={"Out": out}
-    )
-    return out
+    else:
+        check_variable_and_dtype(x1, 'x1', ['float32', 'float64'], 'bilinear')
+        check_variable_and_dtype(x2, 'x2', ['float32', 'float64'], 'bilinear')
+        inputs = {"X": x1, "Y": x2, "Weight": weight}
+        if bias is not None:
+            inputs["Bias"] = bias
+        helper = LayerHelper("bilinear", **locals())
+        out = helper.create_variable_for_type_inference(dtype=x1.dtype)
+        helper.append_op(
+            type="bilinear_tensor_product", inputs=inputs, outputs={"Out": out}
+        )
+        return out
def dropout(
......@@ -1118,77 +1112,62 @@ def dropout(
'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
) # semantic transfer
-    if _non_static_mode():
+    if in_dygraph_mode():
        if default_main_program().random_seed != 0:
            seed = default_main_program().random_seed
-        if in_dygraph_mode():
-            out, mask = _C_ops.dropout(
-                x,
-                None,
-                p,
-                not training,
-                mode,
-                seed if seed is not None else 0,
-                seed is not None,
-            )
-            return out
-        out, mask = _legacy_C_ops.dropout(
+        out, mask = _C_ops.dropout(
            x,
-            'dropout_prob',
+            None,
            p,
-            'is_test',
            not training,
-            'fix_seed',
-            seed is not None,
-            'seed',
-            seed if seed is not None else 0,
-            'dropout_implementation',
            mode,
+            seed if seed is not None else 0,
+            seed is not None,
        )
        return out
-    helper = LayerHelper('dropout', **locals())
-    check_variable_and_dtype(
-        x, 'x', ['float16', 'float32', 'float64'], 'dropout'
-    )
+    else:
+        helper = LayerHelper('dropout', **locals())
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64'], 'dropout'
+        )
-    out = helper.create_variable_for_type_inference(dtype=x.dtype)
-    mask = helper.create_variable_for_type_inference(
-        dtype=core.VarDesc.VarType.UINT8, stop_gradient=True
-    )
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        mask = helper.create_variable_for_type_inference(
+            dtype=core.VarDesc.VarType.UINT8, stop_gradient=True
+        )
-    def get_attrs(prog, dropout_prob, is_test, seed):
-        if (seed is None or seed == 0) and prog.random_seed != 0:
-            seed = prog.random_seed
-        if isinstance(
-            dropout_prob, Variable
-        ) and not dropout_prob.shape != [1]:
-            raise TypeError(
-                "Required p.shape == [1] if type(p) is Variable, but received p.shape = {}".format(
-                    p.shape
-                )
-            )
-        attrs = {
-            'dropout_prob': dropout_prob,
-            'is_test': is_test,
-            'fix_seed': seed is not None,
-            'seed': seed if seed is not None else 0,
-            'dropout_implementation': mode,
-        }
-        return attrs
+        def get_attrs(prog, dropout_prob, is_test, seed):
+            if (seed is None or seed == 0) and prog.random_seed != 0:
+                seed = prog.random_seed
+            if isinstance(
+                dropout_prob, Variable
+            ) and not dropout_prob.shape != [1]:
+                raise TypeError(
+                    "Required p.shape == [1] if type(p) is Variable, but received p.shape = {}".format(
+                        p.shape
+                    )
+                )
+            attrs = {
+                'dropout_prob': dropout_prob,
+                'is_test': is_test,
+                'fix_seed': seed is not None,
+                'seed': seed if seed is not None else 0,
+                'dropout_implementation': mode,
+            }
+            return attrs
-    attrs = get_attrs(helper.main_program, p, not training, seed)
+        attrs = get_attrs(helper.main_program, p, not training, seed)
-    helper.append_op(
-        type='dropout',
-        inputs={'X': [x]},
-        outputs={'Out': [out], 'Mask': [mask]},
-        attrs=attrs,
-    )
-    return out
+        helper.append_op(
+            type='dropout',
+            inputs={'X': [x]},
+            outputs={'Out': [out], 'Mask': [mask]},
+            attrs=attrs,
+        )
+        return out
else: # sometimes called dropout_nd #TODO: optimize with c++
if not in_dynamic_mode():
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'dropout')
......@@ -1684,38 +1663,21 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None):
pad = pad.numpy().tolist()
out = _C_ops.pad3d(x, pad, mode, value, data_format)
else:
-        if _in_legacy_dygraph():
-            if isinstance(pad, Variable):
-                pad = pad.numpy().tolist()
-            out = _legacy_C_ops.pad3d(
-                x,
-                "paddings",
-                pad,
-                "mode",
-                mode,
-                "value",
-                value,
-                "data_format",
-                data_format,
-                "name",
-                name,
-            )
-        else:
-            attrs = {'mode': mode, 'value': value, 'data_format': data_format}
-            inputs = {'X': [x]}
-            if isinstance(pad, Variable):
-                inputs['Paddings'] = [pad]
-                attrs['paddings'] = []
-            else:
-                attrs['paddings'] = pad
-            helper = LayerHelper('pad3d', **locals())
-            dtype = helper.input_dtype(input_param_name='input')
-            out = helper.create_variable_for_type_inference(dtype)
-            helper.append_op(
-                type='pad3d', inputs=inputs, outputs={"Out": out}, attrs=attrs
-            )
+        attrs = {'mode': mode, 'value': value, 'data_format': data_format}
+        inputs = {'X': [x]}
+        if isinstance(pad, Variable):
+            inputs['Paddings'] = [pad]
+            attrs['paddings'] = []
+        else:
+            attrs['paddings'] = pad
+        helper = LayerHelper('pad3d', **locals())
+        dtype = helper.input_dtype(input_param_name='input')
+        out = helper.create_variable_for_type_inference(dtype)
+        helper.append_op(
+            type='pad3d', inputs=inputs, outputs={"Out": out}, attrs=attrs
+        )
if len(unsqueezed_dim) != 0:
out = squeeze(out, axis=unsqueezed_dim)
......@@ -1873,46 +1835,34 @@ def linear(x, weight, bias=None, name=None):
# TODO(jiabin): using addmm for fast forward route
return _C_ops.linear(x, weight, bias)
else:
-        if _in_legacy_dygraph():
-            pre_bias = _legacy_C_ops.matmul_v2(
-                x, weight, 'trans_x', False, 'trans_y', False
-            )
-            if bias is None:
-                return pre_bias
-            return _legacy_C_ops.elementwise_add(pre_bias, bias)
-        else:
-            helper = LayerHelper('linear', **locals())
-            dtype = x.dtype
-            check_variable_and_dtype(
-                x, 'x', ['float16', 'float32', 'float64'], 'linear'
-            )
-            check_dtype(
-                dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
-            )
-            inputs = {'X': [x], 'Y': [weight]}
-            attrs = {'trans_x': False, 'trans_y': False}
-            tmp = helper.create_variable_for_type_inference(dtype)
-            helper.append_op(
-                type='matmul_v2',
-                inputs=inputs,
-                outputs={'Out': tmp},
-                attrs=attrs,
-            )
-            if bias is not None:
-                res = helper.create_variable_for_type_inference(dtype)
-                helper.append_op(
-                    type='elementwise_add',
-                    inputs={'X': [tmp], 'Y': [bias]},
-                    outputs={'Out': [res]},
-                    attrs={'axis': len(x.shape) - 1},
-                )
-            else:
-                res = tmp
-            return res
+        helper = LayerHelper('linear', **locals())
+        dtype = x.dtype
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64'], 'linear'
+        )
+        check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear')
+        inputs = {'X': [x], 'Y': [weight]}
+        attrs = {'trans_x': False, 'trans_y': False}
+        tmp = helper.create_variable_for_type_inference(dtype)
+        helper.append_op(
+            type='matmul_v2',
+            inputs=inputs,
+            outputs={'Out': tmp},
+            attrs=attrs,
+        )
+        if bias is not None:
+            res = helper.create_variable_for_type_inference(dtype)
+            helper.append_op(
+                type='elementwise_add',
+                inputs={'X': [tmp], 'Y': [bias]},
+                outputs={'Out': [res]},
+                attrs={'axis': len(x.shape) - 1},
+            )
+        else:
+            res = tmp
+        return res
def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
......
......@@ -19,11 +19,7 @@ from paddle.device import (
is_compiled_with_npu,
is_compiled_with_rocm,
)
-from paddle.fluid.framework import (
-    _global_flags,
-    _in_legacy_dygraph,
-    in_dygraph_mode,
-)
+from paddle.fluid.framework import _global_flags, in_dygraph_mode
from paddle.tensor.math import _add_with_axis
from ...device import get_cudnn_version
......@@ -489,30 +485,6 @@ def conv1d(
)
if bias is not None:
out = _add_with_axis(out, bias, axis=channel_dim)
-    elif _in_legacy_dygraph():
-        attrs = (
-            'strides',
-            stride,
-            'paddings',
-            padding,
-            'dilations',
-            dilation,
-            'groups',
-            groups,
-            'use_cudnn',
-            use_cudnn,
-            'use_mkldnn',
-            False,
-            'fuse_relu_before_depthwise_conv',
-            False,
-            "padding_algorithm",
-            padding_algorithm,
-            "data_format",
-            conv2d_data_format,
-        )
-        out = getattr(_legacy_C_ops, l_type)(x, weight, *attrs)
-        if bias is not None:
-            out = _add_with_axis(out, bias, axis=channel_dim)
else:
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
......@@ -1044,30 +1016,6 @@ def conv1d_transpose(
)
if bias is not None:
out = _add_with_axis(out, bias, axis=channel_dim)
-    elif _in_legacy_dygraph():
-        attrs = (
-            'output_padding',
-            output_padding,
-            'output_size',
-            output_size,
-            'strides',
-            stride,
-            'paddings',
-            padding,
-            'padding_algorithm',
-            padding_algorithm,
-            'dilations',
-            dilation,
-            'groups',
-            groups,
-            'use_cudnn',
-            use_cudnn,
-            'data_format',
-            conv2d_data_format,
-        )
-        out = getattr(_legacy_C_ops, op_type)(x, weight, *attrs)
-        if bias is not None:
-            out = _add_with_axis(out, bias, axis=channel_dim)
else:
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
......@@ -1350,33 +1298,6 @@ def conv2d_transpose(
return _add_with_axis(pre_bias, bias, axis=channel_dim)
else:
return pre_bias
-    if _in_legacy_dygraph():
-        attrs = (
-            'output_padding',
-            output_padding,
-            'output_size',
-            output_size,
-            'strides',
-            stride,
-            'paddings',
-            padding,
-            'padding_algorithm',
-            padding_algorithm,
-            'dilations',
-            dilation,
-            'groups',
-            groups,
-            'use_cudnn',
-            use_cudnn,
-            'data_format',
-            data_format,
-        )
-        pre_bias = getattr(_legacy_C_ops, op_type)(x, weight, *attrs)
-        if bias is not None:
-            out = _add_with_axis(pre_bias, bias, axis=channel_dim)
-        else:
-            out = pre_bias
else:
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
......@@ -1823,33 +1744,6 @@ def conv3d_transpose(
return _add_with_axis(pre_bias, bias, axis=channel_dim)
else:
return pre_bias
-    if _in_legacy_dygraph():
-        attrs = (
-            'output_padding',
-            output_padding,
-            'output_size',
-            output_size,
-            'paddings',
-            padding,
-            "padding_algorithm",
-            padding_algorithm,
-            'strides',
-            stride,
-            'dilations',
-            dilation,
-            'groups',
-            groups,
-            'use_cudnn',
-            use_cudnn,
-            "data_format",
-            data_format_,
-        )
-        pre_bias = getattr(_legacy_C_ops, op_type)(x, weight, *attrs)
-        if bias is not None:
-            out = _add_with_axis(pre_bias, bias, axis=channel_dim)
-        else:
-            out = pre_bias
else:
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
......
......@@ -13,8 +13,8 @@
# limitations under the License.
import paddle
-from paddle import _C_ops, _legacy_C_ops
-from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
+from paddle import _C_ops
+from paddle.fluid.framework import in_dygraph_mode
from ...fluid.data_feeder import check_type, check_variable_and_dtype
from ...fluid.layer_helper import LayerHelper
......@@ -81,36 +81,30 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
sub = _C_ops.add(sub, epsilon)
return _C_ops.p_norm(sub, p, -1, 0.0, keepdim, False)
-    if _in_legacy_dygraph():
-        sub = _legacy_C_ops.elementwise_sub(x, y)
-        if epsilon != 0.0:
-            epsilon = paddle.fluid.dygraph.base.to_variable(
-                [epsilon], dtype=sub.dtype
-            )
-            sub = _legacy_C_ops.elementwise_add(sub, epsilon)
-        return _legacy_C_ops.p_norm(
-            sub, 'axis', -1, 'porder', p, 'keepdim', keepdim, 'epsilon', 0.0
-        )
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'PairwiseDistance')
-    check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'PairwiseDistance')
-    sub = paddle.subtract(x, y)
-    if epsilon != 0.0:
-        epsilon_var = sub.block.create_var(dtype=sub.dtype)
-        epsilon_var = paddle.full(
-            shape=[1], fill_value=epsilon, dtype=sub.dtype
-        )
-        sub = paddle.add(sub, epsilon_var)
-    helper = LayerHelper("PairwiseDistance", name=name)
-    attrs = {
-        'axis': -1,
-        'porder': p,
-        'keepdim': keepdim,
-        'epsilon': 0.0,
-    }
-    out = helper.create_variable_for_type_inference(dtype=x.dtype)
-    helper.append_op(
-        type='p_norm', inputs={'X': sub}, outputs={'Out': out}, attrs=attrs
-    )
-    return out
+    else:
+        check_variable_and_dtype(
+            x, 'x', ['float32', 'float64'], 'PairwiseDistance'
+        )
+        check_variable_and_dtype(
+            y, 'y', ['float32', 'float64'], 'PairwiseDistance'
+        )
+        sub = paddle.subtract(x, y)
+        if epsilon != 0.0:
+            epsilon_var = sub.block.create_var(dtype=sub.dtype)
+            epsilon_var = paddle.full(
+                shape=[1], fill_value=epsilon, dtype=sub.dtype
+            )
+            sub = paddle.add(sub, epsilon_var)
+        helper = LayerHelper("PairwiseDistance", name=name)
+        attrs = {
+            'axis': -1,
+            'porder': p,
+            'keepdim': keepdim,
+            'epsilon': 0.0,
+        }
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        helper.append_op(
+            type='p_norm', inputs={'X': sub}, outputs={'Out': out}, attrs=attrs
+        )
+        return out
......@@ -34,17 +34,11 @@ import numpy as np
from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode
from paddle.device import get_all_custom_device_type
-from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
+from paddle.fluid.framework import in_dygraph_mode
from ...fluid import dygraph_utils
from ...fluid.data_feeder import check_variable_and_dtype
-from ...framework import (
-    ParamAttr,
-    _global_flags,
-    _non_static_mode,
-    get_default_dtype,
-    no_grad,
-)
+from ...framework import ParamAttr, _global_flags, get_default_dtype, no_grad
from .. import Layer
from .. import functional as F
from ..functional import batch_norm, instance_norm, layer_norm
......@@ -492,20 +486,6 @@ class GroupNorm(Layer):
dtype=input.dtype, stop_gradient=True
)
-        if _in_legacy_dygraph():
-            pre_act, _, _ = _legacy_C_ops.group_norm(
-                input,
-                self.weight,
-                self.bias,
-                mean_out,
-                variance_out,
-                'epsilon',
-                self._epsilon,
-                'groups',
-                self._num_groups,
-            )
-            return pre_act
inputs = {'X': input}
if self.bias is not None:
inputs['Bias'] = self.bias
......@@ -1005,121 +985,86 @@ class BatchNorm(Layer):
self._trainable_statistics = trainable_statistics
def forward(self, input):
-        # create output
-        # mean and mean_out share the same memory
-        mean_out = self._mean
-        # variance and variance out share the same memory
-        variance_out = self._variance
-        if _non_static_mode():
-            if in_dygraph_mode():
-                batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm(
-                    input,
-                    self._mean,
-                    self._variance,
-                    self.weight,
-                    self.bias,
-                    not self.training,
-                    self._momentum,
-                    self._epsilon,
-                    self._data_layout,
-                    self._use_global_stats,
-                    self._trainable_statistics,
-                )
-                return dygraph_utils._append_activation_in_dygraph(
-                    batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
-                )
-            elif _in_legacy_dygraph():
-                attrs = (
-                    "momentum",
-                    self._momentum,
-                    "epsilon",
-                    self._epsilon,
-                    "is_test",
-                    not self.training,
-                    "data_layout",
-                    self._data_layout,
-                    "use_mkldnn",
-                    self._use_mkldnn,
-                    "fuse_with_relu",
-                    self._fuse_with_relu,
-                    "use_global_stats",
-                    self._use_global_stats,
-                    'trainable_statistics',
-                    self._trainable_statistics,
-                )
-                batch_norm_out, _, _, _, _, _ = _legacy_C_ops.batch_norm(
-                    input,
-                    self.weight,
-                    self.bias,
-                    self._mean,
-                    self._variance,
-                    None,
-                    mean_out,
-                    variance_out,
-                    *attrs
-                )
+        if in_dygraph_mode():
+            batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm(
+                input,
+                self._mean,
+                self._variance,
+                self.weight,
+                self.bias,
+                not self.training,
+                self._momentum,
+                self._epsilon,
+                self._data_layout,
+                self._use_global_stats,
+                self._trainable_statistics,
+            )
            return dygraph_utils._append_activation_in_dygraph(
                batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
            )
-        check_variable_and_dtype(
-            input, 'input', ['float16', 'float32', 'float64'], 'BatchNorm'
-        )
-        attrs = {
-            "momentum": self._momentum,
-            "epsilon": self._epsilon,
-            "is_test": self._is_test,
-            "data_layout": self._data_layout,
-            "use_mkldnn": False,
-            "fuse_with_relu": self._fuse_with_relu,
-            "use_global_stats": self._use_global_stats,
-            "trainable_statistics": self._trainable_statistics,
-        }
-        inputs = {
-            "X": [input],
-            "Scale": [self.weight],
-            "Bias": [self.bias],
-            "Mean": [self._mean],
-            "Variance": [self._variance],
-        }
-        saved_mean = self._helper.create_variable_for_type_inference(
-            dtype=self._dtype, stop_gradient=True
-        )
-        saved_variance = self._helper.create_variable_for_type_inference(
-            dtype=self._dtype, stop_gradient=True
-        )
-        reserve_space = self._helper.create_variable_for_type_inference(
-            dtype=self._helper.input_dtype(input), stop_gradient=True
-        )
-        batch_norm_out = (
-            input
-            if self._in_place
-            else self._helper.create_variable_for_type_inference(self._dtype)
-        )
-        outputs = {
-            "Y": [batch_norm_out],
-            "MeanOut": [mean_out],
-            "VarianceOut": [variance_out],
-            "SavedMean": [saved_mean],
-            "SavedVariance": [saved_variance],
-        }
-        if reserve_space is not None:
-            outputs["ReserveSpace"] = [reserve_space]
-        self._helper.append_op(
-            type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs
-        )
-        # Currently, we don't support inplace in dygraph mode
-        return self._helper.append_activation(batch_norm_out, self._act)
+        else:
+            # create output
+            # mean and mean_out share the same memory
+            mean_out = self._mean
+            # variance and variance out share the same memory
+            variance_out = self._variance
+            check_variable_and_dtype(
+                input, 'input', ['float16', 'float32', 'float64'], 'BatchNorm'
+            )
+            attrs = {
+                "momentum": self._momentum,
+                "epsilon": self._epsilon,
+                "is_test": self._is_test,
+                "data_layout": self._data_layout,
+                "use_mkldnn": False,
+                "fuse_with_relu": self._fuse_with_relu,
+                "use_global_stats": self._use_global_stats,
+                "trainable_statistics": self._trainable_statistics,
+            }
+            inputs = {
+                "X": [input],
+                "Scale": [self.weight],
+                "Bias": [self.bias],
+                "Mean": [self._mean],
+                "Variance": [self._variance],
+            }
+            saved_mean = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype, stop_gradient=True
+            )
+            saved_variance = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype, stop_gradient=True
+            )
+            reserve_space = self._helper.create_variable_for_type_inference(
+                dtype=self._helper.input_dtype(input), stop_gradient=True
+            )
+            batch_norm_out = (
+                input
+                if self._in_place
+                else self._helper.create_variable_for_type_inference(
+                    self._dtype
+                )
+            )
+            outputs = {
+                "Y": [batch_norm_out],
+                "MeanOut": [mean_out],
+                "VarianceOut": [variance_out],
+                "SavedMean": [saved_mean],
+                "SavedVariance": [saved_variance],
+            }
+            if reserve_space is not None:
+                outputs["ReserveSpace"] = [reserve_space]
+            self._helper.append_op(
+                type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs
+            )
+            # Currently, we don't support inplace in dygraph mode
+            return self._helper.append_activation(batch_norm_out, self._act)
class BatchNorm1D(_BatchNormBase):
......