未验证 提交 b3caa233 编写于 作者: G Guoxia Wang 提交者: GitHub

add inplace sigmoid_ and multiply_ (#50267)

上级 77d9b4c3
......@@ -1196,6 +1196,7 @@
kernel :
func : multiply {dense, dense -> dense},
multiply_sr {selected_rows, dense -> selected_rows}
inplace : (x -> out)
backward : multiply_grad
- op : nms
......
......@@ -1261,6 +1261,7 @@
func : UnchangedInferMeta
kernel :
func : sigmoid
inplace : (x -> out)
backward : sigmoid_grad
- op : sign
......
......@@ -181,5 +181,57 @@ class TestMultiplyError(unittest.TestCase):
self.assertRaises(ValueError, paddle.multiply, x_data, y_data)
class TestMultiplyInplaceApi(TestMultiplyApi):
def _run_static_graph_case(self, x_data, y_data):
with program_guard(Program(), Program()):
paddle.enable_static()
x = paddle.static.data(
name='x', shape=x_data.shape, dtype=x_data.dtype
)
y = paddle.static.data(
name='y', shape=y_data.shape, dtype=y_data.dtype
)
res = x.multiply_(y)
place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
exe = paddle.static.Executor(place)
outs = exe.run(
paddle.static.default_main_program(),
feed={'x': x_data, 'y': y_data},
fetch_list=[res],
)
res = outs[0]
return res
def _run_dynamic_graph_case(self, x_data, y_data):
paddle.disable_static()
with paddle.no_grad():
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
x.multiply_(y)
return x.numpy()
class TestMultiplyInplaceError(unittest.TestCase):
def test_errors(self):
paddle.disable_static()
# test dynamic computation graph: inputs must be broadcastable
x_data = np.random.rand(3, 4)
y_data = np.random.rand(2, 3, 4)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
def multiply_shape_error():
with paddle.no_grad():
x.multiply_(y)
self.assertRaises(ValueError, multiply_shape_error)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
......@@ -192,6 +192,7 @@ from .math import remainder_ # noqa: F401
from .math import mod # noqa: F401
from .math import floor_mod # noqa: F401
from .math import multiply # noqa: F401
from .math import multiply_ # noqa: F401
from .math import add # noqa: F401
from .math import add_ # noqa: F401
from .math import subtract # noqa: F401
......@@ -243,6 +244,8 @@ from .math import frac # noqa: F401
from .math import sgn # noqa: F401
from .math import take # noqa: F401
from .math import frexp # noqa: F401
from .math import sigmoid # noqa: F401
from .math import sigmoid_ # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
......@@ -380,6 +383,7 @@ tensor_method_func = [ # noqa
'mod',
'floor_mod',
'multiply',
'multiply_',
'add',
'add_',
'subtract',
......@@ -525,6 +529,8 @@ tensor_method_func = [ # noqa
'bucketize',
'sgn',
'frexp',
'sigmoid',
'sigmoid_',
]
# this list used in math_op_patch.py for magic_method bind
......
......@@ -34,6 +34,7 @@ from ..fluid.data_feeder import (
)
from ..framework import (
LayerHelper,
_dygraph_tracer,
convert_np_dtype_to_dtype_,
core,
in_dygraph_mode,
......@@ -64,6 +65,8 @@ from .ops import round # noqa: F401
from .ops import round_ # noqa: F401
from .ops import rsqrt # noqa: F401
from .ops import rsqrt_ # noqa: F401
from .ops import sigmoid # noqa: F401
from .ops import sigmoid_ # noqa: F401
from .ops import sin # noqa: F401
from .ops import sinh # noqa: F401
from .ops import sqrt # noqa: F401
......@@ -893,6 +896,28 @@ def multiply(x, y, name=None):
return _elementwise_op(LayerHelper('elementwise_mul', **locals()))
@inplace_apis_in_dygraph_only
def multiply_(x, y, name=None):
"""
Inplace version of ``multiply`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_multiply`.
"""
assert (
_dygraph_tracer()._has_grad is False
), "The current inplace version of multiply_ needs to be used in the context of paddle.no_grad() since inplace multiply_grad is not yet supported."
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError(
"The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(
out_shape, x.shape
)
)
return _C_ops.multiply_(x, y)
@dygraph_only
def _elementwise_op_with_axis_in_dygraph(
x, y, axis=-1, name=None, op_type="Undifined"
......
......@@ -47,6 +47,7 @@ __inplace_unary_func__ = [
'floor_',
'round_',
'reciprocal_',
'sigmoid_',
]
__all__ = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册