From 609b50a8beda8b8fbc0649aaffcbd2d1063e1f44 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Fri, 13 Jan 2023 10:27:17 +0800 Subject: [PATCH] [Eager] polish some api logic (#49717) * [Eager] polish some api logic * fix split * revover --- python/paddle/nn/layer/norm.py | 2 ++ python/paddle/tensor/manipulation.py | 15 ++------------- python/paddle/tensor/search.py | 17 ++++++++--------- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index aed7e455d63..9e1486d1d2a 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -999,6 +999,8 @@ class BatchNorm(Layer): self._use_global_stats, self._trainable_statistics, ) + if self._act is None: + return batch_norm_out return dygraph_utils._append_activation_in_dygraph( batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn ) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 3c3355d7bb8..0c74a900e1b 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1916,31 +1916,20 @@ def split(x, num_or_sections, axis=0, name=None): input = x dim = axis if in_dygraph_mode(): - num = None - attrs = () - if isinstance(dim, Variable): dim = dim.numpy() dim = dim.item(0) assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0" dim = (len(input.shape) + dim) if dim < 0 else dim - attrs += ('axis', dim) - if isinstance(num_or_sections, int): - num = num_or_sections - attrs += ('num', num_or_sections) - elif isinstance(num_or_sections, (list, tuple)): - num = len(num_or_sections) + if isinstance(num_or_sections, (list, tuple)): if utils._contain_var(num_or_sections): for index, item in enumerate(num_or_sections): if isinstance(item, Variable): num_or_sections[index] = num_or_sections[index].numpy()[ 0 ] - attrs += ('sections', list(num_or_sections)) - else: - attrs += ('sections', list(num_or_sections)) - else: + elif not isinstance(num_or_sections, int): raise TypeError( "The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but " "received %s." % (type(num_or_sections)) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 8c0e912ac8e..7fe0850ac5c 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -612,15 +612,6 @@ def where(condition, x=None, y=None, name=None): if x is None or y is None: raise ValueError("either both or neither of x and y should be given") - if not paddle.in_dynamic_mode(): - check_variable_and_dtype(condition, 'condition', ['bool'], 'where') - check_variable_and_dtype( - x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where' - ) - check_variable_and_dtype( - y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where' - ) - condition_shape = list(condition.shape) x_shape = list(x.shape) y_shape = list(y.shape) @@ -646,6 +637,14 @@ def where(condition, x=None, y=None, name=None): if in_dygraph_mode(): return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y) else: + check_variable_and_dtype(condition, 'condition', ['bool'], 'where') + check_variable_and_dtype( + x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where' + ) + check_variable_and_dtype( + y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where' + ) + helper = LayerHelper("where", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) -- GitLab