未验证 提交 6e80b84d 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] polish several api (#49589)

上级 cac5f5a7
......@@ -494,11 +494,6 @@ def prelu(x, weight, data_format="NCHW", name=None):
# [-1.25, 6. , 7. , -2. ],
# [ 6. , 7. , 8. , 9. ]]]]
"""
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu')
check_variable_and_dtype(
weight, 'weight', ['float16', 'float32', 'float64'], 'prelu'
)
assert (
len(weight.shape) == 1
), "The dim count of weight shape should be 1 in prelu()."
......@@ -541,6 +536,12 @@ def prelu(x, weight, data_format="NCHW", name=None):
if in_dygraph_mode():
return _C_ops.prelu(x, weight, data_format, mode)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'prelu'
)
check_variable_and_dtype(
weight, 'weight', ['float16', 'float32', 'float64'], 'prelu'
)
helper = LayerHelper('prelu', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
......@@ -622,12 +623,6 @@ def rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True, name=None):
# [-1.3766339 6. 7. -2.3465784 ]
# [ 6. 7. 8. 9. ]]]]
"""
if not in_dynamic_mode():
check_variable_and_dtype(
x, 'X', ['float16', 'float32', 'float64'], 'rrelu'
)
if not isinstance(lower, float) or not isinstance(upper, float):
raise TypeError(
"The lower and upper values must be float type. Received: lower {}, upper {}.".format(
......@@ -664,6 +659,9 @@ def rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True, name=None):
)
return out
else:
check_variable_and_dtype(
x, 'X', ['float16', 'float32', 'float64'], 'rrelu'
)
helper = LayerHelper('rrelu', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
noise = helper.create_variable_for_type_inference(dtype=x.dtype)
......
......@@ -422,9 +422,7 @@ def interpolate(
return paddle.nn.functional.adaptive_avg_pool2d(x, size)
elif len(x.shape) == 5:
return paddle.nn.functional.adaptive_avg_pool3d(x, size)
helper = LayerHelper('{}_interp_v2'.format(resample_type), **locals())
dtype = helper.input_dtype(input_param_name='x')
if len(x.shape) == 3 and data_format not in ['NCW', 'NWC']:
raise ValueError(
"Got wrong value for param `data_format`: "
......@@ -678,6 +676,9 @@ def interpolate(
else:
out = _legacy_C_ops.bicubic_interp_v2(x, *dy_attr)
return out
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='{}_interp_v2'.format(resample_type),
......
......@@ -597,8 +597,6 @@ def max_pool1d(
"""
"""NCL to NCHW"""
data_format = "NCHW"
if not in_dynamic_mode():
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool1d')
_check_input(x, 3)
x = unsqueeze(x, [2])
kernel_size = [1] + utils.convert_to_list(kernel_size, 1, 'pool_size')
......@@ -641,6 +639,7 @@ def max_pool1d(
return squeeze(pool_out, [2])
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool1d')
op_type = 'max_pool2d_with_index' if return_mask else "pool2d"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='x')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册