Unverified commit 6e80b84d, authored by Weilong Wu, committed by GitHub

[Eager] polish several api (#49589)

Parent cac5f5a7
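Summary of the change: across `prelu`, `rrelu`, `interpolate`, and `max_pool1d`, the Python-side input validation (`check_variable_and_dtype`) and the static-graph dtype lookup are moved out of the shared prologue into the static-graph branches, so eager (dygraph) calls dispatch directly to the underlying C++ ops. The per-file hunks follow, each with a short illustrative sketch.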
@@ -494,11 +494,6 @@ def prelu(x, weight, data_format="NCHW", name=None):
             # [-1.25, 6. , 7. , -2. ],
             # [ 6. ,  7. ,  8. ,  9. ]]]]
     """
-    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu')
-    check_variable_and_dtype(
-        weight, 'weight', ['float16', 'float32', 'float64'], 'prelu'
-    )
     assert (
         len(weight.shape) == 1
     ), "The dim count of weight shape should be 1 in prelu()."
@@ -541,6 +536,12 @@ def prelu(x, weight, data_format="NCHW", name=None):
     if in_dygraph_mode():
         return _C_ops.prelu(x, weight, data_format, mode)
     else:
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64'], 'prelu'
+        )
+        check_variable_and_dtype(
+            weight, 'weight', ['float16', 'float32', 'float64'], 'prelu'
+        )
         helper = LayerHelper('prelu', **locals())
         out = helper.create_variable_for_type_inference(x.dtype)
         helper.append_op(
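Both `prelu` hunks implement the same refactor: the `check_variable_and_dtype` calls move from the function prologue into the static-graph `else` branch, so an eager-mode call dispatches straight to `_C_ops.prelu` with no Python-side dtype validation. A minimal eager-mode sketch (tensor values and the shared single-element weight are illustrative, not from the patch):

```python
import paddle
import paddle.nn.functional as F

# Eager mode is the default; this call now goes directly to _C_ops.prelu.
x = paddle.to_tensor([[-1.0, 2.0], [3.0, -4.0]], dtype='float32')
# A single shared weight (shape [1]) keeps the example rank-agnostic.
weight = paddle.to_tensor([0.25], dtype='float32')
out = F.prelu(x, weight)  # negative inputs are scaled by 0.25
print(out.numpy())  # [[-0.25  2.  ] [ 3.   -1.  ]]
```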
@@ -622,12 +623,6 @@ def rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True, name=None):
             #  [-1.3766339  6.         7.        -2.3465784 ]
             #  [ 6.         7.         8.         9.        ]]]]
     """
-    if not in_dynamic_mode():
-        check_variable_and_dtype(
-            x, 'X', ['float16', 'float32', 'float64'], 'rrelu'
-        )
     if not isinstance(lower, float) or not isinstance(upper, float):
         raise TypeError(
             "The lower and upper values must be float type. Received: lower {}, upper {}.".format(
@@ -664,6 +659,9 @@ def rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True, name=None):
         )
         return out
     else:
+        check_variable_and_dtype(
+            x, 'X', ['float16', 'float32', 'float64'], 'rrelu'
+        )
         helper = LayerHelper('rrelu', **locals())
         out = helper.create_variable_for_type_inference(x.dtype)
         noise = helper.create_variable_for_type_inference(dtype=x.dtype)
...
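`rrelu` follows the same pattern: the `if not in_dynamic_mode():` guard around the dtype check is dropped, and the check now lives only in the static-graph `else` branch. A quick deterministic sketch (with `training=False`, rrelu scales negative inputs by `(lower + upper) / 2`; the tensor values are illustrative):

```python
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([[-2.0, 0.0, 3.0]], dtype='float32')
# training=False makes the op deterministic: slope = (1/8 + 1/3) / 2 ≈ 0.2292.
out = F.rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=False)
print(out.numpy())  # [[-0.4583  0.      3.    ]]
```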
@@ -422,9 +422,7 @@ def interpolate(
             return paddle.nn.functional.adaptive_avg_pool2d(x, size)
         elif len(x.shape) == 5:
             return paddle.nn.functional.adaptive_avg_pool3d(x, size)
     helper = LayerHelper('{}_interp_v2'.format(resample_type), **locals())
-    dtype = helper.input_dtype(input_param_name='x')
     if len(x.shape) == 3 and data_format not in ['NCW', 'NWC']:
         raise ValueError(
             "Got wrong value for param `data_format`: "
@@ -678,6 +676,9 @@ def interpolate(
         else:
             out = _legacy_C_ops.bicubic_interp_v2(x, *dy_attr)
         return out
+    dtype = helper.input_dtype(input_param_name='x')
     out = helper.create_variable_for_type_inference(dtype)
     helper.append_op(
         type='{}_interp_v2'.format(resample_type),
...
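In `interpolate` the change is a pure reorder: `helper.input_dtype(...)` used to run before the dynamic-graph fast paths, so every eager call paid for a static-graph dtype lookup whose result was never used; it now runs only on the static-graph path. A small usage sketch (shapes illustrative):

```python
import paddle
import paddle.nn.functional as F

# In eager mode this returns from the C-ops fast path before any
# LayerHelper dtype lookup happens.
x = paddle.rand([1, 3, 16, 16], dtype='float32')  # NCHW layout
out = F.interpolate(x, size=[32, 32], mode='bilinear')
print(out.shape)  # [1, 3, 32, 32]
```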
@@ -597,8 +597,6 @@ def max_pool1d(
     """
     """NCL to NCHW"""
     data_format = "NCHW"
-    if not in_dynamic_mode():
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool1d')
     _check_input(x, 3)
     x = unsqueeze(x, [2])
     kernel_size = [1] + utils.convert_to_list(kernel_size, 1, 'pool_size')
@@ -641,6 +639,7 @@ def max_pool1d(
         return squeeze(pool_out, [2])
     else:
+        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool1d')
         op_type = 'max_pool2d_with_index' if return_mask else "pool2d"
         helper = LayerHelper(op_type, **locals())
         dtype = helper.input_dtype(input_param_name='x')
...
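`max_pool1d` gets the same treatment: the guarded check at the top of the function becomes an unconditional check inside the static-graph `else` branch. A minimal eager-mode sketch (shapes illustrative):

```python
import paddle
import paddle.nn.functional as F

x = paddle.rand([1, 3, 32], dtype='float32')  # NCL layout
out = F.max_pool1d(x, kernel_size=2, stride=2)
print(out.shape)  # [1, 3, 16]
```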