未验证 提交 609b50a8 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] polish some api logic (#49717)

* [Eager] polish some api logic

* fix split

* revover
上级 0b24d167
......@@ -999,6 +999,8 @@ class BatchNorm(Layer):
self._use_global_stats,
self._trainable_statistics,
)
if self._act is None:
return batch_norm_out
return dygraph_utils._append_activation_in_dygraph(
batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
)
......
......@@ -1916,31 +1916,20 @@ def split(x, num_or_sections, axis=0, name=None):
input = x
dim = axis
if in_dygraph_mode():
num = None
attrs = ()
if isinstance(dim, Variable):
dim = dim.numpy()
dim = dim.item(0)
assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0"
dim = (len(input.shape) + dim) if dim < 0 else dim
attrs += ('axis', dim)
if isinstance(num_or_sections, int):
num = num_or_sections
attrs += ('num', num_or_sections)
elif isinstance(num_or_sections, (list, tuple)):
num = len(num_or_sections)
if isinstance(num_or_sections, (list, tuple)):
if utils._contain_var(num_or_sections):
for index, item in enumerate(num_or_sections):
if isinstance(item, Variable):
num_or_sections[index] = num_or_sections[index].numpy()[
0
]
attrs += ('sections', list(num_or_sections))
else:
attrs += ('sections', list(num_or_sections))
else:
elif not isinstance(num_or_sections, int):
raise TypeError(
"The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but "
"received %s." % (type(num_or_sections))
......
......@@ -612,15 +612,6 @@ def where(condition, x=None, y=None, name=None):
if x is None or y is None:
raise ValueError("either both or neither of x and y should be given")
if not paddle.in_dynamic_mode():
check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where'
)
check_variable_and_dtype(
y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where'
)
condition_shape = list(condition.shape)
x_shape = list(x.shape)
y_shape = list(y.shape)
......@@ -646,6 +637,14 @@ def where(condition, x=None, y=None, name=None):
if in_dygraph_mode():
return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
else:
check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where'
)
check_variable_and_dtype(
y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where'
)
helper = LayerHelper("where", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册