From 6e80b84dc7c13931949be51ce5e341ff230e0519 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Fri, 6 Jan 2023 11:10:44 +0800 Subject: [PATCH] [Eager] polish several api (#49589) --- python/paddle/nn/functional/activation.py | 20 +++++++++----------- python/paddle/nn/functional/common.py | 5 +++-- python/paddle/nn/functional/pooling.py | 3 +-- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index ada12acf6e..74f90c2970 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -494,11 +494,6 @@ def prelu(x, weight, data_format="NCHW", name=None): # [-1.25, 6. , 7. , -2. ], # [ 6. , 7. , 8. , 9. ]]]] """ - check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu') - check_variable_and_dtype( - weight, 'weight', ['float16', 'float32', 'float64'], 'prelu' - ) - assert ( len(weight.shape) == 1 ), "The dim count of weight shape should be 1 in prelu()." @@ -541,6 +536,12 @@ def prelu(x, weight, data_format="NCHW", name=None): if in_dygraph_mode(): return _C_ops.prelu(x, weight, data_format, mode) else: + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64'], 'prelu' + ) + check_variable_and_dtype( + weight, 'weight', ['float16', 'float32', 'float64'], 'prelu' + ) helper = LayerHelper('prelu', **locals()) out = helper.create_variable_for_type_inference(x.dtype) helper.append_op( @@ -622,12 +623,6 @@ def rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True, name=None): # [-1.3766339 6. 7. -2.3465784 ] # [ 6. 7. 8. 9. ]]]] """ - - if not in_dynamic_mode(): - check_variable_and_dtype( - x, 'X', ['float16', 'float32', 'float64'], 'rrelu' - ) - if not isinstance(lower, float) or not isinstance(upper, float): raise TypeError( "The lower and upper values must be float type. Received: lower {}, upper {}.".format( @@ -664,6 +659,9 @@ def rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=True, name=None): ) return out else: + check_variable_and_dtype( + x, 'X', ['float16', 'float32', 'float64'], 'rrelu' + ) helper = LayerHelper('rrelu', **locals()) out = helper.create_variable_for_type_inference(x.dtype) noise = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index beaec4c91c..6768288c3b 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -422,9 +422,7 @@ def interpolate( return paddle.nn.functional.adaptive_avg_pool2d(x, size) elif len(x.shape) == 5: return paddle.nn.functional.adaptive_avg_pool3d(x, size) - helper = LayerHelper('{}_interp_v2'.format(resample_type), **locals()) - dtype = helper.input_dtype(input_param_name='x') if len(x.shape) == 3 and data_format not in ['NCW', 'NWC']: raise ValueError( "Got wrong value for param `data_format`: " @@ -678,6 +676,9 @@ def interpolate( else: out = _legacy_C_ops.bicubic_interp_v2(x, *dy_attr) return out + + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype) helper.append_op( type='{}_interp_v2'.format(resample_type), diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 42b58147b9..6aab78ec11 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -597,8 +597,6 @@ def max_pool1d( """ """NCL to NCHW""" data_format = "NCHW" - if not in_dynamic_mode(): - check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool1d') _check_input(x, 3) x = unsqueeze(x, [2]) kernel_size = [1] + utils.convert_to_list(kernel_size, 1, 'pool_size') @@ -641,6 +639,7 @@ def max_pool1d( return squeeze(pool_out, [2]) else: + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool1d') op_type = 'max_pool2d_with_index' if return_mask else "pool2d" helper = LayerHelper(op_type, **locals()) dtype = helper.input_dtype(input_param_name='x') -- GitLab