未验证 提交 9aef0e3e 编写于 作者: H HongyuJia 提交者: GitHub

[Clean fluid] Add inner function _elementwise_op_with_axis (#48748)

* add inner function _elementwise_op_with_axis

* fix transformer_model

* polish API code

* remove elementwise_div/mul api

* delete API in __all__

* delete elementwise_mul completely

* polish elementwise_mul call

* polish internal api

* resolve conflict, fix rnn.py

* use non-inplace call

* delete elementwise_mul api test

* delete elementwise_mul api test

* clean elementwise_add/sub

* restore _elementwise_op_in_dygraph in nn.py
上级 6c0755d9
无相关合并请求
......@@ -367,9 +367,9 @@ def basic_gru(
new_hidden = unit_list[i](step_input, pre_hidden)
if mask:
new_hidden = layers.elementwise_mul(
new_hidden = paddle.tensor.math._multiply_with_axis(
new_hidden, step_mask, axis=0
) - layers.elementwise_mul(
) - paddle.tensor.math._multiply_with_axis(
pre_hidden, (step_mask - 1), axis=0
)
rnn.update_memory(pre_hidden, new_hidden)
......@@ -661,14 +661,14 @@ def basic_lstm(
)
if mask:
new_hidden = layers.elementwise_mul(
new_hidden = paddle.tensor.math._multiply_with_axis(
new_hidden, step_mask, axis=0
) - layers.elementwise_mul(
) - paddle.tensor.math._multiply_with_axis(
pre_hidden, (step_mask - 1), axis=0
)
new_cell = layers.elementwise_mul(
new_cell = paddle.tensor.math._multiply_with_axis(
new_cell, step_mask, axis=0
) - layers.elementwise_mul(
) - paddle.tensor.math._multiply_with_axis(
pre_cell, (step_mask - 1), axis=0
)
......
......@@ -115,8 +115,6 @@ class LayerHelperBase:
)
def _create_weight_normalize(self, attr, shape, dtype):
from .layers import elementwise_mul
# Remove these ops when LayerHelper and layers support indicating
# program and block.
def __norm_op(
......@@ -272,7 +270,7 @@ class LayerHelperBase:
# Currently, elementwise_mul only support broadcast when the shape
# of y is a subset of the shape of x. Thus, we reshape y to squeeze
# to achieve the subset.
w = elementwise_mul(
w = paddle.tensor.math._multiply_with_axis(
x=v,
y=scale
if dim is None
......
......@@ -79,10 +79,6 @@ __all__ = [
'unsqueeze',
'lod_reset',
'relu',
'elementwise_add',
'elementwise_div',
'elementwise_sub',
'elementwise_mul',
'clip',
'clip_by_norm',
'mul',
......@@ -2464,512 +2460,6 @@ def relu(x, name=None):
return out
from paddle.fluid.framework import convert_np_dtype_to_dtype_
def _elementwise_op(helper):
op_type = helper.layer_type
x = helper.kwargs.get('x', None)
y = helper.kwargs.get('y', None)
assert x is not None, 'x cannot be None in {}'.format(op_type)
assert y is not None, 'y cannot be None in {}'.format(op_type)
check_variable_and_dtype(
x,
'x',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
op_type,
)
check_variable_and_dtype(
y,
'y',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
op_type,
)
axis = helper.kwargs.get('axis', -1)
use_mkldnn = helper.kwargs.get('use_mkldnn', False)
name = helper.kwargs.get('name', None)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type=op_type,
inputs={'X': x, 'Y': y},
outputs={'Out': out},
attrs={'axis': axis, 'use_mkldnn': use_mkldnn},
)
return helper.append_activation(out)
def elementwise_add(x, y, axis=-1, act=None, name=None):
"""
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = fluid.layers.elementwise_add(x, y)
# z = x + y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # [3., 8., 6.]
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.ones((2, 3, 4, 5)).astype('float32'),
"y": np.zeros((3, 4)).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[3,4], dtype='float32')
z = fluid.layers.elementwise_add(x, y, axis=1)
# z = x + y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # z.shape=[2,3,4,5]
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.random.randint(1, 5, size=[2, 3, 4, 5]).astype('float32'),
"y": np.random.randint(1, 5, size=[5]).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[5], dtype='float32')
z = fluid.layers.elementwise_add(x, y, axis=3)
# z = x + y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # z.shape=[2,3,4,5]
"""
if _non_static_mode():
return _elementwise_op_in_dygraph(
x,
y,
axis=axis,
act=act,
op_name='elementwise_add',
use_mkldnn=_global_flags()["FLAGS_use_mkldnn"],
)
return _elementwise_op(LayerHelper('elementwise_add', **locals()))
@deprecated(since="2.0.0", update_to="paddle.divide")
def elementwise_div(x, y, axis=-1, act=None, name=None):
"""
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = fluid.layers.elementwise_div(x, y)
# z = x / y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # [2., 0.6, 2.]
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.ones((2, 3, 4, 5)).astype('float32'),
"y": np.zeros((3, 4)).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[3,4], dtype='float32')
z = fluid.layers.elementwise_div(x, y, axis=1)
# z = x / y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # z.shape=[2,3,4,5]
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.random.randint(1, 5, size=[2, 3, 4, 5]).astype('float32'),
"y": np.random.randint(1, 5, size=[5]).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[5], dtype='float32')
z = fluid.layers.elementwise_div(x, y, axis=3)
# z = x / y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # z.shape=[2,3,4,5]
"""
if _non_static_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name='elementwise_div'
)
return _elementwise_op(LayerHelper('elementwise_div', **locals()))
def elementwise_sub(x, y, axis=-1, act=None, name=None):
"""
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = fluid.layers.elementwise_sub(x, y)
# z = x - y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # [1., -2., 2.]
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.ones((2, 3, 4, 5)).astype('float32'),
"y": np.zeros((3, 4)).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[3,4], dtype='float32')
z = fluid.layers.elementwise_sub(x, y, axis=1)
# z = x - y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # z.shape=[2,3,4,5]
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.random.randint(1, 5, size=[2, 3, 4, 5]).astype('float32'),
"y": np.random.randint(1, 5, size=[5]).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[5], dtype='float32')
z = fluid.layers.elementwise_sub(x, y, axis=3)
# z = x - y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # z.shape=[2,3,4,5]
"""
if _non_static_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name='elementwise_sub'
)
return _elementwise_op(LayerHelper('elementwise_sub', **locals()))
@deprecated(since="2.0.0", update_to="paddle.multiply")
def elementwise_mul(x, y, axis=-1, act=None, name=None):
"""
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = fluid.layers.elementwise_mul(x, y)
# z = x * y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # [2., 15., 8.]
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.ones((2, 3, 4, 5)).astype('float32'),
"y": np.zeros((3, 4)).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[3,4], dtype='float32')
z = fluid.layers.elementwise_mul(x, y, axis=1)
# z = x * y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # z.shape=[2,3,4,5]
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
import paddle
def gen_data():
return {
"x": np.random.randint(1, 5, size=[2, 3, 4, 5]).astype('float32'),
"y": np.random.randint(1, 5, size=[5]).astype('float32')
}
paddle.enable_static()
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[5], dtype='float32')
z = fluid.layers.elementwise_mul(x, y, axis=3)
# z = x * y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # z.shape=[2,3,4,5]
"""
if _non_static_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name='elementwise_mul'
)
return _elementwise_op(LayerHelper('elementwise_mul', **locals()))
for func in [
elementwise_add,
elementwise_div,
elementwise_sub,
elementwise_mul,
]:
op_proto = OpProtoHolder.instance().get_op_proto(func.__name__)
# insert the c++ doc string on top of python doc string
func.__doc__ = (
_generate_doc_string_(
op_proto,
additional_args_lines=[
"axis (int32, optional): If X.dimension != Y.dimension, \
Y.dimension must be a subsequence of x.dimension. \
And axis is the start dimension index for broadcasting Y onto X. ",
"act (string, optional): Activation applied to the output. \
Default is None. Details: :ref:`api_guide_activations_en` ",
"name (string, optional): Name of the output. \
Default is None. It's used to print debug info for developers. Details: \
:ref:`api_guide_Name` ",
],
skip_attrs_set={
"x_data_format",
"y_data_format",
"axis",
"use_quantizer",
"mkldnn_data_type",
"Scale_x",
"Scale_y",
"Scale_out",
},
)
+ """\n"""
+ str(func.__doc__)
)
doc_list = func.__doc__.splitlines()
for idx, val in enumerate(doc_list):
if (
val.startswith("Warning: ")
and val.endswith(" instead.")
and "and will be removed in future versions." in val
):
doc_list.insert(0, doc_list.pop(idx))
func.__doc__ = "\n" + "\n".join(i for i in doc_list)
break
for func in []:
op_proto = OpProtoHolder.instance().get_op_proto(func.__name__)
func.__doc__ = _generate_doc_string_(
op_proto,
additional_args_lines=[
"act (basestring|None): Activation applied to the output.",
"name (basestring|None): Name of the output.",
],
)
func.__doc__ = (
func.__doc__
+ """
Examples:
.. code-block:: python
import paddle.fluid as fluid
# example 1: shape(x) = (2, 3, 4, 5), shape(y) = (2, 3, 4, 5)
x0 = fluid.layers.data(name="x0", shape=[2, 3, 4, 5], dtype='float32')
y0 = fluid.layers.data(name="y0", shape=[2, 3, 4, 5], dtype='float32')
z0 = fluid.layers.%s(x0, y0)
# example 2: shape(X) = (2, 3, 4, 5), shape(Y) = (5)
x1 = fluid.layers.data(name="x1", shape=[2, 3, 4, 5], dtype='float32')
y1 = fluid.layers.data(name="y1", shape=[5], dtype='float32')
z1 = fluid.layers.%s(x1, y1)
# example 3: shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
x2 = fluid.layers.data(name="x2", shape=[2, 3, 4, 5], dtype='float32')
y2 = fluid.layers.data(name="y2", shape=[4, 5], dtype='float32')
z2 = fluid.layers.%s(x2, y2, axis=2)
# example 4: shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
x3 = fluid.layers.data(name="x3", shape=[2, 3, 4, 5], dtype='float32')
y3 = fluid.layers.data(name="y3", shape=[3, 4], dtype='float32')
z3 = fluid.layers.%s(x3, y3, axis=1)
# example 5: shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
x4 = fluid.layers.data(name="x4", shape=[2, 3, 4, 5], dtype='float32')
y4 = fluid.layers.data(name="y4", shape=[2], dtype='float32')
z4 = fluid.layers.%s(x4, y4, axis=0)
# example 6: shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
x5 = fluid.layers.data(name="x5", shape=[2, 3, 4, 5], dtype='float32')
y5 = fluid.layers.data(name="y5", shape=[2], dtype='float32')
z5 = fluid.layers.%s(x5, y5, axis=0)
"""
% (
func.__name__,
func.__name__,
func.__name__,
func.__name__,
func.__name__,
func.__name__,
)
)
def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
if _non_static_mode():
op = getattr(_legacy_C_ops, op_name)
......
......@@ -547,9 +547,9 @@ class ArrayWrapper:
def _maybe_copy(state, new_state, step_mask):
"""update rnn state or just pass the old state through"""
new_state = nn.elementwise_mul(
new_state = paddle.tensor.math._multiply_with_axis(
new_state, step_mask, axis=0
) + nn.elementwise_mul(state, (1 - step_mask), axis=0)
) + paddle.tensor.math._multiply_with_axis(state, (1 - step_mask), axis=0)
return new_state
......@@ -833,9 +833,11 @@ def _dynamic_decode_imperative(
# otherwise, renamed bool gradients of would be summed up leading
# to sum(bool) error.
step_mask.stop_gradient = True
new_state = nn.elementwise_mul(
new_state = paddle.tensor.math._multiply_with_axis(
state, step_mask, axis=0
) - nn.elementwise_mul(new_state, (step_mask - 1), axis=0)
) - paddle.tensor.math._multiply_with_axis(
new_state, (step_mask - 1), axis=0
)
if convert_dtype(state_dtype) in ["bool"]:
new_state = tensor.cast(new_state, dtype=state_dtype)
return new_state
......@@ -988,9 +990,11 @@ def _dynamic_decode_declarative(
# otherwise, renamed bool gradients of would be summed up leading
# to sum(bool) error.
step_mask.stop_gradient = True
new_state = nn.elementwise_mul(
new_state = paddle.tensor.math._multiply_with_axis(
state, step_mask, axis=0
) - nn.elementwise_mul(new_state, (step_mask - 1), axis=0)
) - paddle.tensor.math._multiply_with_axis(
new_state, (step_mask - 1), axis=0
)
if convert_dtype(state_dtype) in ["bool"]:
new_state = tensor.cast(new_state, dtype=state_dtype)
return new_state
......
......@@ -139,7 +139,7 @@ class SqueezeExcitation(fluid.dygraph.Layer):
y = paddle.nn.functional.relu(y)
y = self._excitation(y)
y = paddle.nn.functional.sigmoid(y)
y = fluid.layers.elementwise_mul(x=input, y=y, axis=0)
y = paddle.tensor.math._multiply_with_axis(x=input, y=y, axis=0)
return y
......
......@@ -199,7 +199,9 @@ class SE_ResNeXt:
),
act='sigmoid',
)
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
scale = paddle.tensor.math._multiply_with_axis(
x=input, y=excitation, axis=0
)
return scale
......
......@@ -196,9 +196,11 @@ class BaseModel(fluid.dygraph.Layer):
return x
def _real_state(self, state, new_state, step_mask):
new_state = fluid.layers.elementwise_mul(
new_state = paddle.tensor.math._multiply_with_axis(
new_state, step_mask, axis=0
) - fluid.layers.elementwise_mul(state, (step_mask - 1), axis=0)
) - paddle.tensor.math._multiply_with_axis(
state, (step_mask - 1), axis=0
)
return new_state
def _gather(self, x, indices, batch_pos):
......@@ -452,10 +454,10 @@ class BaseModel(fluid.dygraph.Layer):
[-1, -1, self.tar_vocab_size],
),
noend_mask_tensor,
) - fluid.layers.elementwise_mul(
) - paddle.tensor.math._multiply_with_axis(
step_log_probs, (beam_finished - 1), axis=0
)
log_probs = fluid.layers.elementwise_add(
log_probs = paddle.tensor.math._add_with_axis(
x=step_log_probs, y=beam_state_log_probs, axis=0
)
scores = paddle.reshape(
......@@ -689,9 +691,11 @@ class AttentionModel(fluid.dygraph.Layer):
return x
def _real_state(self, state, new_state, step_mask):
new_state = fluid.layers.elementwise_mul(
new_state = paddle.tensor.math._multiply_with_axis(
new_state, step_mask, axis=0
) - fluid.layers.elementwise_mul(state, (step_mask - 1), axis=0)
) - paddle.tensor.math._multiply_with_axis(
state, (step_mask - 1), axis=0
)
return new_state
def _gather(self, x, indices, batch_pos):
......
......@@ -105,18 +105,14 @@ class Cycle_Gan(fluid.dygraph.Layer):
G = g_A_loss + g_B_loss
idt_A = self.build_generator_resnet_9blocks_a(input_B)
idt_loss_A = (
paddle.mean(
paddle.abs(fluid.layers.elementwise_sub(x=input_B, y=idt_A))
)
paddle.mean(paddle.abs(paddle.subtract(x=input_B, y=idt_A)))
* lambda_B
* lambda_identity
)
idt_B = self.build_generator_resnet_9blocks_b(input_A)
idt_loss_B = (
paddle.mean(
paddle.abs(fluid.layers.elementwise_sub(x=input_A, y=idt_B))
)
paddle.mean(paddle.abs(paddle.subtract(x=input_A, y=idt_B)))
* lambda_A
* lambda_identity
)
......
......@@ -152,7 +152,7 @@ class SqueezeExcitation(fluid.dygraph.Layer):
y = paddle.nn.functional.relu(y)
y = self._excitation(y)
y = paddle.nn.functional.sigmoid(y)
y = fluid.layers.elementwise_mul(x=input, y=y, axis=0)
y = paddle.tensor.math._multiply_with_axis(x=input, y=y, axis=0)
return y
......
......@@ -758,7 +758,9 @@ class Transformer(Layer):
[-1, -1, self.trg_vocab_size],
),
noend_mask_tensor,
) - layers.elementwise_mul(probs, (finished - 1), axis=0)
) - paddle.tensor.math._multiply_with_axis(
probs, (finished - 1), axis=0
)
return probs
def gather(input, indices, batch_pos):
......@@ -846,7 +848,7 @@ class Transformer(Layer):
step_log_probs = mask_probs(
step_log_probs, finished, noend_mask_tensor
)
log_probs = layers.elementwise_add(
log_probs = paddle.tensor.math._add_with_axis(
x=step_log_probs, y=log_probs, axis=0
)
log_probs = paddle.reshape(
......
......@@ -35,7 +35,7 @@ class TestMul(IPUOpTest):
return True
def set_test_op(self):
self.op = paddle.fluid.layers.elementwise_mul
self.op = paddle.tensor.math._multiply_with_axis
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
......
......@@ -360,7 +360,7 @@ class ElementwiseScaleOneDNNFusePassTest_Add(
ElementwiseActivationMkldnnFusePassTest
):
def set_params(self):
self.operand = fluid.layers.elementwise_add
self.operand = paddle.add
self.act_alpha = 0.6
self.act = paddle.scale
......@@ -369,7 +369,7 @@ class ElementwiseScaleOneDNNFusePassTest_Sub(
ElementwiseActivationMkldnnFusePassTest
):
def set_params(self):
self.operand = fluid.layers.elementwise_sub
self.operand = paddle.subtract
self.act_alpha = 0.6
self.act = paddle.scale
......@@ -378,7 +378,7 @@ class ElementwiseScaleOneDNNFusePassTest_Mul(
ElementwiseActivationMkldnnFusePassTest
):
def set_params(self):
self.operand = fluid.layers.elementwise_mul
self.operand = paddle.multiply
self.act_alpha = 0.6
self.act = paddle.scale
......@@ -387,7 +387,7 @@ class ElementwiseScaleOneDNNFusePassTest_Div(
ElementwiseActivationMkldnnFusePassTest
):
def set_params(self):
self.operand = fluid.layers.elementwise_div
self.operand = paddle.divide
self.act_alpha = 0.6
self.act = paddle.scale
......
......@@ -19,6 +19,7 @@ import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.static.nn as nn
......@@ -49,7 +50,7 @@ class TensorRTSubgraphPassElementwiseBroadcastTest(InferencePassTest):
self.fetch_list = [out]
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_add(x=data1, y=data2, axis=0)
return paddle.tensor.math._add_with_axis(x=data1, y=data2, axis=0)
def test_check_output(self):
if os.path.exists(self.path + "_opt_cache"):
......@@ -66,21 +67,21 @@ class TensorRTSubgraphPassElementwiseBroadcastTest1(
TensorRTSubgraphPassElementwiseBroadcastTest
):
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_sub(x=data1, y=data2, axis=0)
return paddle.tensor.math._subtract_with_axis(x=data1, y=data2, axis=0)
class TensorRTSubgraphPassElementwiseBroadcastTest2(
TensorRTSubgraphPassElementwiseBroadcastTest
):
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_mul(x=data1, y=data2, axis=0)
return paddle.tensor.math._multiply_with_axis(x=data1, y=data2, axis=0)
class TensorRTSubgraphPassElementwiseBroadcastTest3(
TensorRTSubgraphPassElementwiseBroadcastTest
):
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_div(x=data1, y=data2, axis=0)
return paddle.tensor.math._divide_with_axis(x=data1, y=data2, axis=0)
if __name__ == "__main__":
......
......@@ -394,7 +394,7 @@ class TensorRTSubgraphPassElementwiseMulTest(
TensorRTSubgraphPassElementwiseTest
):
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_mul(x=data1, y=data2)
return paddle.multiply(x=data1, y=data2)
class TensorRTSubgraphPassElementwiseSerializeTest(
......
......@@ -396,24 +396,6 @@ class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp):
self.axis = 0
class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_add must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.MLUPlace(0)
)
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.MLUPlace(0)
)
self.assertRaises(TypeError, fluid.layers.elementwise_add, x1, y1)
# the input dtype of elementwise_add must be float16 or float32
x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8")
y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)
class TestAddApi(unittest.TestCase):
def _executed_api(self, x, y, name=None):
return paddle.add(x, y, name)
......
......@@ -222,23 +222,5 @@ class TestElementwiseMulOp_xsize_lessthan_ysize(ElementwiseMulOp):
self.init_kernel_type()
class TestElementwiseMulOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_mul must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()
)
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()
)
self.assertRaises(TypeError, fluid.layers.elementwise_mul, x1, y1)
# the input dtype of elementwise_mul must be float16 or float32 or int32
x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8")
y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.elementwise_mul, x2, y2)
if __name__ == '__main__':
unittest.main()
......@@ -226,7 +226,7 @@ class TestAddAPI(unittest.TestCase):
class TestAddError(unittest.TestCase):
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# the input of elementwise_add must be Variable.
# the input of paddle.add must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0)
)
......@@ -504,25 +504,6 @@ class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp):
self.axis = 0
class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_add must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0)
)
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0)
)
self.assertRaises(TypeError, fluid.layers.elementwise_add, x1, y1)
# the input dtype of elementwise_add must be float16 or float32 or float64 or int32 or int64
# float16 only can be set on GPU place
x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8")
y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)
class TestAddApi(unittest.TestCase):
def _executed_api(self, x, y, name=None):
return paddle.add(x, y, name)
......
......@@ -59,7 +59,9 @@ def squeeze_excitation(input, num_channels, reduction_ratio):
excitation = fluid.layers.fc(
input=squeeze, size=num_channels, act='sigmoid'
)
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
scale = paddle.tensor.math._multiply_with_axis(
x=input, y=excitation, axis=0
)
return scale
......
......@@ -65,7 +65,6 @@ class TestDeprecatedDocorator(unittest.TestCase):
"""
tests for paddle's Deprecated Docorator.
test_fluid_data: test for old fluid.data API.
test_fluid_elementwise_mul: test for old fluid.layers.elementwise_xxx APIs.
test_new_multiply: test for new api, which should not insert warning information.
test_ops_elementwise_mul: test for C++ elementwise_mul op, which should not insert warning information.
"""
......@@ -89,28 +88,6 @@ class TestDeprecatedDocorator(unittest.TestCase):
# testting
self.assertGreater(expected, captured)
def test_fluid_elementwise_mul(self):
"""
test old fluid elementwise_mul api, it should trigger Warinng function,
which insert the Warinng info on top of API's doc string.
"""
# Initialization
a = np.random.uniform(0.1, 1, [51, 76]).astype(np.float32)
b = np.random.uniform(0.1, 1, [51, 76]).astype(np.float32)
x = paddle.to_tensor(a)
y = paddle.to_tensor(b)
res = fluid.layers.elementwise_mul(x, y)
# expected
expected = LOWEST_WARNING_POSTION
# captured
captured = get_warning_index(fluid.layers.elementwise_mul)
# testting
self.assertGreater(expected, captured)
def test_new_multiply(self):
"""
Test for new multiply api, expected result should be False.
......@@ -147,7 +124,7 @@ class TestDeprecatedDocorator(unittest.TestCase):
expected = LOWEST_WARNING_POSTION
# captured
captured = get_warning_index(fluid.layers.elementwise_mul)
captured = get_warning_index(paddle.multiply)
# testting
self.assertGreater(expected, captured)
......
......@@ -19,7 +19,6 @@ import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.tests.unittests.op_test import (
OpTest,
......@@ -490,25 +489,6 @@ class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp):
self.axis = 0
class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_add must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()
)
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()
)
self.assertRaises(TypeError, fluid.layers.elementwise_add, x1, y1)
# the input dtype of elementwise_add must be float16 or float32 or float64 or int32 or int64
# float16 only can be set on GPU place
x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8")
y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)
class TestAddApi(unittest.TestCase):
def _executed_api(self, x, y, name=None):
return paddle.add(x, y, name)
......
......@@ -17,9 +17,7 @@ import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.tests.unittests.op_test import (
OpTest,
......@@ -289,25 +287,6 @@ class TestElementwiseMulOp_xsize_lessthan_ysize(ElementwiseMulOp):
self.init_kernel_type()
class TestElementwiseMulOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_mul must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()
)
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()
)
self.assertRaises(TypeError, fluid.layers.elementwise_mul, x1, y1)
# the input dtype of elementwise_mul must be float16 or float32 or float64 or int32 or int64
# float16 only can be set on GPU place
x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8")
y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.elementwise_mul, x2, y2)
class TestComplexElementwiseMulOp(OpTest):
def setUp(self):
self.op_type = "elementwise_mul"
......
......@@ -36,7 +36,7 @@ class TestElementwiseMulDoubleGradCheck(unittest.TestCase):
y = layers.data('y', shape, False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_mul(x, y)
out = paddle.multiply(x, y)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape).astype(dtype)
......@@ -65,7 +65,7 @@ class TestElementwiseMulBroadcastDoubleGradCheck(unittest.TestCase):
y = layers.data('y', shape[:-1], False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_mul(x, y, axis=0)
out = paddle.tensor.math._multiply_with_axis(x, y, axis=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)
......@@ -123,7 +123,7 @@ class TestElementwiseAddBroadcastDoubleGradCheck(unittest.TestCase):
y = layers.data('y', shape[:-1], False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_add(x, y, axis=0)
out = paddle.tensor.math._add_with_axis(x, y, axis=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)
......@@ -191,7 +191,7 @@ class TestElementwiseSubBroadcastDoubleGradCheck(unittest.TestCase):
y = layers.data('y', shape[:-1], False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_sub(x, y, axis=0)
out = paddle.tensor.math._subtract_with_axis(x, y, axis=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)
......@@ -223,7 +223,7 @@ class TestElementwiseDivDoubleGradCheck(unittest.TestCase):
y = layers.data('y', shape, False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_div(x, y, axis=0)
out = paddle.tensor.math._divide_with_axis(x, y, axis=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr[np.abs(y_arr) < 0.005] = 0.02
......@@ -261,7 +261,7 @@ class TestElementwiseDivBroadcastDoubleGradCheck(unittest.TestCase):
y = layers.data('y', shape[1:-1], False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_div(x, y, axis=1)
out = paddle.tensor.math._divide_with_axis(x, y, axis=1)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape[1:-1]).astype(dtype)
y_arr[np.abs(y_arr) < 0.005] = 0.02
......@@ -320,7 +320,7 @@ class TestElementwiseAddBroadcastTripleGradCheck(unittest.TestCase):
y = layers.data('y', shape[:-1], False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_add(x, y, axis=0)
out = paddle.tensor.math._add_with_axis(x, y, axis=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)
......@@ -352,7 +352,7 @@ class TestElementwiseMulTripleGradCheck(unittest.TestCase):
y = layers.data('y', shape, False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_mul(x, y)
out = paddle.multiply(x, y)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape).astype(dtype)
......@@ -390,7 +390,7 @@ class TestElementwiseMulBroadcastTripleGradCheck(unittest.TestCase):
y = layers.data('y', shape[:-1], False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_add(x, y, axis=0)
out = paddle.tensor.math._add_with_axis(x, y, axis=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)
......
......@@ -98,7 +98,7 @@ class TestFuseActElewiseAddInplaceGradPass(unittest.TestCase):
Y = fluid.data(name="Y", shape=[3, 3], dtype='float32')
Out1 = X * 5
Out2 = fluid.layers.relu(Out1)
prediction = fluid.layers.elementwise_add(Y, Out2, axis=1)
prediction = paddle.tensor.math._add_with_axis(Y, Out2, axis=1)
loss = paddle.mean(prediction)
sgd = fluid.optimizer.SGD(learning_rate=0.001)
sgd.minimize(loss)
......
......@@ -307,7 +307,7 @@ class SimpleAttention(fluid.dygraph.Layer):
)
weights_reshape = paddle.nn.functional.softmax(weights_reshape)
scaled = fluid.layers.elementwise_mul(
scaled = paddle.tensor.math._multiply_with_axis(
x=encoder_vec, y=weights_reshape, axis=0
)
context = paddle.sum(scaled, axis=1)
......
......@@ -130,7 +130,7 @@ class SqueezeExcitation(fluid.dygraph.Layer):
y = self.act_1(y)
y = self._excitation(y)
y = self.act_2(y)
y = fluid.layers.elementwise_mul(x=input, y=y, axis=0)
y = paddle.tensor.math._multiply_with_axis(x=input, y=y, axis=0)
return y
......
......@@ -409,9 +409,9 @@ def gradient_penalty(f, real, fake, no_grad_set, cfg):
input=a, shape=shape, min=0.1, max=1.0, seed=cfg.seed
)
inner = fluid.layers.elementwise_mul(
inner = paddle.tensor.math._multiply_with_axis(
b, 1.0 - alpha, axis=0
) + fluid.layers.elementwise_mul(a, alpha, axis=0)
) + paddle.tensor.math._multiply_with_axis(a, alpha, axis=0)
return inner
x = _interpolate(real, fake)
......
......@@ -159,8 +159,8 @@ def multi_head_attention(
def __softmax(x, eps=1e-9):
exp_out = paddle.exp(x=x)
sum_out = paddle.sum(exp_out, axis=-1, keepdim=False)
return layers.elementwise_div(x=exp_out, y=sum_out, axis=0)
sum_out = paddle.sum(exp_out, axis=-1, keepdim=True)
return paddle.divide(x=exp_out, y=sum_out)
scaled_q = paddle.scale(x=q, scale=d_model**-0.5)
product = paddle.matmul(x=scaled_q, y=k, transpose_y=True)
......
......@@ -29,7 +29,6 @@ from xpu.get_test_cover_info import (
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
paddle.enable_static()
......@@ -263,32 +262,6 @@ class XPUTestElementwiseAddOp(XPUOpTestWrapper):
def init_axis(self):
self.axis = 2
class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_add must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0)
)
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0)
)
self.assertRaises(
TypeError, fluid.layers.elementwise_add, x1, y1
)
# the input dtype of elementwise_add must be float16 or float32 or float64 or int32 or int64
# float16 only can be set on GPU place
x2 = fluid.layers.data(
name='x2', shape=[3, 4, 5, 6], dtype="uint8"
)
y2 = fluid.layers.data(
name='y2', shape=[3, 4, 5, 6], dtype="uint8"
)
self.assertRaises(
TypeError, fluid.layers.elementwise_add, x2, y2
)
class TestAddOp(unittest.TestCase):
def test_name(self):
with fluid.program_guard(fluid.Program()):
......
......@@ -24,7 +24,6 @@ from op_test_xpu import XPUOpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
paddle.enable_static()
......@@ -307,28 +306,6 @@ class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp):
self.axis = 2
@unittest.skipIf(
not paddle.is_compiled_with_xpu(), "core is not compiled with XPU"
)
class TestElementwiseAddOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_add must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0)
)
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0)
)
self.assertRaises(TypeError, fluid.layers.elementwise_add, x1, y1)
# the input dtype of elementwise_add must be float16 or float32 or float64 or int32 or int64
# float16 only can be set on GPU place
x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8")
y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8")
self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)
@unittest.skipIf(
not paddle.is_compiled_with_xpu(), "core is not compiled with XPU"
)
......
......@@ -26,8 +26,6 @@ from xpu.get_test_cover_info import (
)
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
paddle.enable_static()
......@@ -241,31 +239,6 @@ class XPUTestElementwiseMulOp(XPUOpTestWrapper):
'Out': self.inputs['X'].reshape(1, 1, 10, 10) * self.inputs['Y']
}
class TestElementwiseMulOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input of elementwise_mul must be Variable.
x1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0)
)
y1 = fluid.create_lod_tensor(
np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.XPUPlace(0)
)
self.assertRaises(
TypeError, fluid.layers.elementwise_mul, x1, y1
)
# the input dtype of elementwise_mul must be float32
x2 = fluid.layers.data(
name='x2', shape=[3, 4, 5, 6], dtype="uint8"
)
y2 = fluid.layers.data(
name='y2', shape=[3, 4, 5, 6], dtype="uint8"
)
self.assertRaises(
TypeError, fluid.layers.elementwise_mul, x2, y2
)
support_types = get_xpu_op_support_types('elementwise_mul')
for stype in support_types:
......
......@@ -26,11 +26,11 @@ from paddle.fluid.framework import (
_in_legacy_dygraph,
in_dygraph_mode,
)
from paddle.tensor.math import _add_with_axis
from ...device import get_cudnn_version
from ...fluid.data_feeder import check_dtype, check_variable_and_dtype
from ...fluid.layer_helper import LayerHelper
from ...fluid.layers import nn
from ...fluid.layers.utils import (
_contain_var,
_convert_to_tensor_list,
......@@ -226,7 +226,7 @@ def _conv_nd(
)
pre_bias = getattr(_legacy_C_ops, op_type)(x, weight, *attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
out = _add_with_axis(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
else:
......@@ -491,7 +491,7 @@ def conv1d(
False,
)
if bias is not None:
out = nn.elementwise_add(out, bias, axis=channel_dim)
out = _add_with_axis(out, bias, axis=channel_dim)
elif _in_legacy_dygraph():
attrs = (
'strides',
......@@ -515,7 +515,7 @@ def conv1d(
)
out = getattr(_legacy_C_ops, l_type)(x, weight, *attrs)
if bias is not None:
out = nn.elementwise_add(out, bias, axis=channel_dim)
out = _add_with_axis(out, bias, axis=channel_dim)
else:
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
......@@ -540,7 +540,7 @@ def conv1d(
type=l_type, inputs=inputs, outputs=outputs, attrs=attrs
)
if bias is not None:
out = nn.elementwise_add(out, bias, axis=channel_dim)
out = _add_with_axis(out, bias, axis=channel_dim)
out = squeeze(out, axis=[squeeze_aixs])
return out
......@@ -1047,7 +1047,7 @@ def conv1d_transpose(
conv2d_data_format,
)
if bias is not None:
out = nn.elementwise_add(out, bias, axis=channel_dim)
out = _add_with_axis(out, bias, axis=channel_dim)
elif _in_legacy_dygraph():
attrs = (
'output_padding',
......@@ -1071,7 +1071,7 @@ def conv1d_transpose(
)
out = getattr(_legacy_C_ops, op_type)(x, weight, *attrs)
if bias is not None:
out = nn.elementwise_add(out, bias, axis=channel_dim)
out = _add_with_axis(out, bias, axis=channel_dim)
else:
inputs = {'Input': [x], 'Filter': [weight]}
attrs = {
......@@ -1096,7 +1096,7 @@ def conv1d_transpose(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs
)
if bias is not None:
out = nn.elementwise_add(out, bias, axis=channel_dim)
out = _add_with_axis(out, bias, axis=channel_dim)
out = squeeze(out, axis=[squeeze_axis])
return out
......@@ -1351,7 +1351,7 @@ def conv2d_transpose(
data_format,
)
if bias is not None:
return nn.elementwise_add(pre_bias, bias, axis=channel_dim)
return _add_with_axis(pre_bias, bias, axis=channel_dim)
else:
return pre_bias
......@@ -1378,7 +1378,7 @@ def conv2d_transpose(
)
pre_bias = getattr(_legacy_C_ops, op_type)(x, weight, *attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
out = _add_with_axis(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
else:
......@@ -1405,7 +1405,7 @@ def conv2d_transpose(
)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
out = _add_with_axis(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
......@@ -1824,7 +1824,7 @@ def conv3d_transpose(
data_format_,
)
if bias is not None:
return nn.elementwise_add(pre_bias, bias, axis=channel_dim)
return _add_with_axis(pre_bias, bias, axis=channel_dim)
else:
return pre_bias
......@@ -1851,7 +1851,7 @@ def conv3d_transpose(
)
pre_bias = getattr(_legacy_C_ops, op_type)(x, weight, *attrs)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
out = _add_with_axis(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
else:
......@@ -1879,7 +1879,7 @@ def conv3d_transpose(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs
)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
out = _add_with_axis(pre_bias, bias, axis=channel_dim)
else:
out = pre_bias
......
......@@ -1405,15 +1405,13 @@ def l1_loss(input, label, reduction='mean', name=None):
)
if reduction == 'sum':
unreduced = paddle.fluid.layers.elementwise_sub(input, label, act='abs')
unreduced = paddle.abs(paddle.subtract(x=input, y=label))
return paddle.sum(unreduced, name=name)
elif reduction == 'mean':
unreduced = paddle.fluid.layers.elementwise_sub(input, label, act='abs')
unreduced = paddle.abs(paddle.subtract(x=input, y=label))
return paddle.mean(unreduced, name=name)
else:
return paddle.fluid.layers.elementwise_sub(
input, label, act='abs', name=name
)
return paddle.abs(paddle.subtract(x=input, y=label, name=name))
def nll_loss(
......
......@@ -90,7 +90,7 @@ def _weight_norm(v, g, dim):
v_normalized = F.l2_normalize(p_matrix, axis=1)
v_normalized = paddle.reshape(v_normalized, transposed_shape)
v_normalized = paddle.transpose(v_normalized, perm)
weight = F.elementwise_mul(
weight = paddle.tensor.math._multiply_with_axis(
v_normalized, g, axis=dim if dim is not None else -1
)
return weight
......
......@@ -31,7 +31,7 @@ from ..fluid.data_feeder import (
check_variable_and_dtype,
convert_dtype,
)
from ..fluid.layers import elementwise_sub, utils
from ..fluid.layers import utils
from ..framework import (
LayerHelper,
_in_legacy_dygraph,
......@@ -985,6 +985,115 @@ def multiply(x, y, name=None):
return _elementwise_op(LayerHelper(op_type, **locals()))
@dygraph_only
def _elementwise_op_with_axis_in_dygraph(
x, y, axis=-1, name=None, op_type="Undifined"
):
assert (
in_dygraph_mode()
), "You can only call `_elementwise_op_with_axis_in_dygraph` function within in_dygraph_mode"
assert op_type in ["add", "subtract", "multiply", "divide"], (
"op_name input error! _elementwise_op_with_axis is an inner function to replace elementwise_add/sub/mul/div. Input op_name=%s, Expect op_name=[add|subtract|multiply|divide]\n"
% op_type
)
op = getattr(_C_ops, op_type)
x_shape = list(x.shape)
y_shape = list(y.shape)
if axis == -1 or len(x_shape) == len(y_shape):
return op(x, y)
if len(x_shape) > len(y_shape):
padding = len(x_shape) - len(y_shape) - axis
y = paddle.reshape(y, [1] * axis + y_shape + [1] * padding)
else:
padding = len(y_shape) - len(x_shape) - axis
x = paddle.reshape(x, [1] * axis + y_shape + [1] * padding)
return op(x, y)
def _add_with_axis(x, y, axis=-1, name=None):
# opt performance, only dynamic mode needs reshape
if in_dygraph_mode():
return _elementwise_op_with_axis_in_dygraph(x, y, axis, name, "add")
else:
op_type = 'elementwise_add'
act = None
if _in_legacy_dygraph():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type
)
else:
if x.dtype != y.dtype:
raise TypeError(
'Input tensors must be same type, but received type of x: %s, type of y: %s '
% (x.dtype, y.dtype)
)
return _elementwise_op(LayerHelper(op_type, **locals()))
def _subtract_with_axis(x, y, axis=-1, name=None):
# opt performance, only dynamic mode needs reshape
if in_dygraph_mode():
return _elementwise_op_with_axis_in_dygraph(
x, y, axis, name, "subtract"
)
else:
op_type = 'elementwise_sub'
act = None
if _in_legacy_dygraph():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type
)
else:
if x.dtype != y.dtype:
raise TypeError(
'Input tensors must be same type, but received type of x: %s, type of y: %s '
% (x.dtype, y.dtype)
)
return _elementwise_op(LayerHelper(op_type, **locals()))
def _multiply_with_axis(x, y, axis=-1, name=None):
# opt performance, only dynamic mode needs reshape
if in_dygraph_mode():
return _elementwise_op_with_axis_in_dygraph(
x, y, axis, name, "multiply"
)
else:
op_type = 'elementwise_mul'
act = None
if _in_legacy_dygraph():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type
)
else:
if x.dtype != y.dtype:
raise TypeError(
'Input tensors must be same type, but received type of x: %s, type of y: %s '
% (x.dtype, y.dtype)
)
return _elementwise_op(LayerHelper(op_type, **locals()))
def _divide_with_axis(x, y, axis=-1, name=None):
# opt performance, only dynamic mode needs reshape
if in_dygraph_mode():
return _elementwise_op_with_axis_in_dygraph(x, y, axis, name, "divide")
else:
op_type = 'elementwise_div'
act = None
if _in_legacy_dygraph():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type
)
else:
if x.dtype != y.dtype:
raise TypeError(
'Input tensors must be same type, but received type of x: %s, type of y: %s '
% (x.dtype, y.dtype)
)
return _elementwise_op(LayerHelper(op_type, **locals()))
def maximum(x, y, name=None):
"""
Compare two tensors and returns a new tensor containing the element-wise maxima. The equation is:
......@@ -4877,7 +4986,9 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
if x.dtype == paddle.bool:
return _legacy_C_ops.logical_xor(input_back, input_front)
else:
return elementwise_sub(input_back, input_front, axis=axis)
return paddle.tensor.math._subtract_with_axis(
input_back, input_front, axis=axis
)
else:
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'bool', 'int32', 'int64'], 'diff'
......@@ -4941,7 +5052,9 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
outputs={"Out": out},
)
else:
out = elementwise_sub(input_back, input_front, axis=axis)
out = paddle.tensor.math._subtract_with_axis(
input_back, input_front, axis=axis
)
return out
......
......@@ -15,6 +15,7 @@
import numpy as np
from paddle import _C_ops, _legacy_C_ops
from paddle.tensor.math import _add_with_axis
from ..fluid.data_feeder import check_type, check_variable_and_dtype
from ..fluid.framework import (
......@@ -25,7 +26,7 @@ from ..fluid.framework import (
)
from ..fluid.initializer import Normal
from ..fluid.layer_helper import LayerHelper
from ..fluid.layers import nn, utils
from ..fluid.layers import utils
from ..framework import _current_expected_place
from ..nn import BatchNorm2D, Conv2D, Layer, ReLU, Sequential
......@@ -985,7 +986,7 @@ def deform_conv2d(
1,
)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=1)
out = _add_with_axis(pre_bias, bias, axis=1)
else:
out = pre_bias
elif _in_legacy_dygraph():
......@@ -1014,7 +1015,7 @@ def deform_conv2d(
x, offset, mask, weight, *attrs
)
if bias is not None:
out = nn.elementwise_add(pre_bias, bias, axis=1)
out = _add_with_axis(pre_bias, bias, axis=1)
else:
out = pre_bias
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部