未验证 提交 308073de 编写于 作者: P pangyoki 提交者: GitHub

Add 12 inplace APIs including auto generated (#32573)

* add relu6_ hardsigmoid_ leaky_relu_ Inplace APIs

* add softmax_with_cross_entropy_ Inplace API

* add clip_ scale_ add_ subtract_ Inplace APIs

* add wlist

* fix parameter of scale api

* add add_n_ Inplace API and remove log_ Inplace API

* fix elementwise_add_ and elementwise_sub_ broadcast problem

* elementwise inplace api give error message before run the op

* use broadcast_shape in elementwise inplace op

* add 8 inplace apis that is auto generated

* add unittest for all inplace apis

* add decorator for inplace apis in static mode

* fix windows blas fail of exp inplace api, change array_equal to allclose

* add flatten inplace api

* add flatten unittest

* fix flatten unittest

* add decorator

* fix grad.numpy in test_pylayer_op

* unsupport softmax_with_cross_entropy_

* add test_inplace_softmax_with_cross_entropy to static_mode_white_list

* delete __all__ in inplace_utils

* delete activation inplace function and add Tensor.inplace_func

* change paddle.inplace_ to Tensor.inplace_

* fix little problem

* add paddle in inplace_utils
上级 9b4fabf9
......@@ -408,7 +408,8 @@ void BasicEngine::Execute() {
VLOG(10) << "create temporary var of " << var->Name()
<< " for sum gradient within this graph!";
} else if (!inplace_grad_name_map.empty() &&
inplace_grad_name_map.count(pair.first)) {
inplace_grad_name_map.count(pair.first) &&
bwd_ins.count(inplace_grad_name_map.at(pair.first))) {
// When calculate Inplace grad op, create a new output var.
// If a tmp var has been created, there is no need to create it
// again.
......
......@@ -120,23 +120,9 @@ template <typename DeviceContext, typename T>
class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto &start_axis = context.Attr<int>("start_axis");
auto &stop_axis = context.Attr<int>("stop_axis");
auto *in = context.Input<framework::LoDTensor>("X");
auto x_dims = in->dims();
int in_dims_size = x_dims.size();
int real_start_axis = start_axis, real_stop_axis = stop_axis;
if (start_axis < 0) {
real_start_axis = start_axis + in_dims_size;
}
if (stop_axis < 0) {
real_stop_axis = stop_axis + in_dims_size;
}
auto *out = context.Output<framework::LoDTensor>("Out");
auto out_dims = framework::make_ddim(
GetOutputShape(real_start_axis, real_stop_axis, x_dims));
auto out_dims = out->dims();
out->mutable_data(context.GetPlace(), in->type());
framework::TensorCopy(
......@@ -144,27 +130,6 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
context.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims);
}
static std::vector<int32_t> GetOutputShape(const int start_axis,
const int stop_axis,
const framework::DDim &in_dims) {
int64_t outer = 1;
std::vector<int32_t> out_shape;
int in_dims_size = in_dims.size();
out_shape.reserve(in_dims_size - stop_axis + start_axis);
for (int i = 0; i < start_axis; ++i) {
out_shape.push_back(in_dims[i]);
}
for (int i = start_axis; i <= stop_axis; i++) {
outer *= in_dims[i];
}
out_shape.push_back(outer);
for (int i = stop_axis + 1; i < in_dims_size; i++) {
out_shape.push_back(in_dims[i]);
}
return out_shape;
}
};
template <typename DeviceContext, typename T>
......
......@@ -58,6 +58,8 @@ from .amp import *
from .math_op_patch import monkey_patch_math_varbase
from .inplace_utils import inplace_apis_in_dygraph_only
__all__ = []
__all__ += layers.__all__
__all__ += base.__all__
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..wrapped_decorator import wrap_decorator
from ..framework import in_dygraph_mode
import warnings
import paddle
# NOTE(pangyoki): The Inplace APIs with underline(`_`) is only valid for the method of calling `core.ops`
# in dygraph mode. If static mode is used, the inplace mechanism will not be used, and the static method
# of the original API will be called.
def _inplace_apis_in_dygraph_only_(func):
def __impl__(*args, **kwargs):
if not in_dygraph_mode():
origin_api_name = func.__name__[:-1]
warnings.warn(
"In static mode, {}() is the same as {}() and does not perform inplace operation.".
format(func.__name__, origin_api_name))
origin_func = "{}.{}".format(func.__module__, origin_api_name)
return eval(origin_func)(*args, **kwargs)
return func(*args, **kwargs)
return __impl__
inplace_apis_in_dygraph_only = wrap_decorator(_inplace_apis_in_dygraph_only_)
......@@ -25,7 +25,8 @@ from ..layer_helper import LayerHelper
from ..data_feeder import check_variable_and_dtype
__all__ = [
'generate_layer_fn', 'generate_activation_fn', 'autodoc', 'templatedoc'
'generate_layer_fn', 'generate_activation_fn', 'generate_inplace_fn',
'autodoc', 'templatedoc'
]
......@@ -283,6 +284,35 @@ def generate_activation_fn(op_type):
return func
def generate_inplace_fn(inplace_op_type):
"""Register the Python layer for an Inplace Operator without Attribute.
Args:
inplace_op_type: The name of the inplace operator to be created.
This function takes in the inplace operator type (exp_ , ceil_ etc) and
creates the operator functionality.
"""
origin_op_type = inplace_op_type[:-1]
def func(x, name=None):
if in_dygraph_mode():
op = getattr(core.ops, inplace_op_type)
return op(x)
warnings.warn(
"In static mode, {}() is the same as {}() and does not perform inplace operation.".
format(inplace_op_type, origin_op_type))
return generate_activation_fn(origin_op_type)(x, name)
func.__name__ = inplace_op_type
func.__doc__ = """
Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_fluid_layers_{1}`.
""".format(origin_op_type, origin_op_type)
return func
def autodoc(comment=""):
def __impl__(func):
func.__doc__ = _generate_doc_string_(OpProtoHolder.instance(
......
......@@ -14,7 +14,7 @@
from __future__ import print_function
import os
from .layer_function_generator import generate_layer_fn, generate_activation_fn, add_sample_code
from .layer_function_generator import generate_layer_fn, generate_activation_fn, generate_inplace_fn, add_sample_code
from .. import core
from ..framework import convert_np_dtype_to_dtype_, Variable
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
......@@ -55,6 +55,16 @@ __unary_func__ = [
'square',
]
__inplace_unary_func__ = [
'exp_',
'sqrt_',
'rsqrt_',
'ceil_',
'floor_',
'round_',
'reciprocal_',
]
__all__ = []
for _OP in set(__all__):
......@@ -69,6 +79,7 @@ globals()['_elementwise_div'] = generate_layer_fn('elementwise_div')
__all__ += __activations_noattr__
__all__ += __unary_func__
__all__ += __inplace_unary_func__
for _OP in set(__activations_noattr__):
_new_OP = _OP
......@@ -87,6 +98,14 @@ for _OP in set(__unary_func__):
func = deprecated(since="2.0.0", update_to="paddle.%s" % (_new_OP))(func)
globals()[_OP] = func
for _OP in set(__inplace_unary_func__):
_new_OP = _OP
if _OP in __deprecated_func_name__:
_new_OP = __deprecated_func_name__[_OP]
func = generate_inplace_fn(_OP)
func = deprecated(since="2.0.0", update_to="paddle.%s" % (_new_OP))(func)
globals()[_OP] = func
add_sample_code(globals()["sigmoid"], r"""
Examples:
.. code-block:: python
......
......@@ -124,6 +124,9 @@ class TestClipOpError(unittest.TestCase):
class TestClipAPI(unittest.TestCase):
def _executed_api(self, x, min=None, max=None):
return paddle.clip(x, min, max)
def test_clip(self):
paddle.enable_static()
data_shape = [1, 9, 9, 4]
......@@ -136,18 +139,20 @@ class TestClipAPI(unittest.TestCase):
) else fluid.CPUPlace()
exe = fluid.Executor(place)
out_1 = paddle.clip(images, min=min, max=max)
out_2 = paddle.clip(images, min=0.2, max=0.9)
out_3 = paddle.clip(images, min=0.3)
out_4 = paddle.clip(images, max=0.7)
out_5 = paddle.clip(images, min=min)
out_6 = paddle.clip(images, max=max)
out_7 = paddle.clip(images, max=-1.)
out_8 = paddle.clip(images)
out_9 = paddle.clip(paddle.cast(images, 'float64'), min=0.2, max=0.9)
out_10 = paddle.clip(paddle.cast(images * 10, 'int32'), min=2, max=8)
out_11 = paddle.clip(paddle.cast(images * 10, 'int64'), min=2, max=8)
out_1 = self._executed_api(images, min=min, max=max)
out_2 = self._executed_api(images, min=0.2, max=0.9)
out_3 = self._executed_api(images, min=0.3)
out_4 = self._executed_api(images, max=0.7)
out_5 = self._executed_api(images, min=min)
out_6 = self._executed_api(images, max=max)
out_7 = self._executed_api(images, max=-1.)
out_8 = self._executed_api(images)
out_9 = self._executed_api(
paddle.cast(images, 'float64'), min=0.2, max=0.9)
out_10 = self._executed_api(
paddle.cast(images * 10, 'int32'), min=2, max=8)
out_11 = self._executed_api(
paddle.cast(images * 10, 'int64'), min=2, max=8)
res1, res2, res3, res4, res5, res6, res7, res8, res9, res10, res11 = exe.run(
fluid.default_main_program(),
......@@ -188,12 +193,16 @@ class TestClipAPI(unittest.TestCase):
v_min = paddle.to_tensor(np.array([0.2], dtype=np.float32))
v_max = paddle.to_tensor(np.array([0.8], dtype=np.float32))
out_1 = paddle.clip(images, min=0.2, max=0.8)
out_2 = paddle.clip(images, min=0.2, max=0.9)
out_3 = paddle.clip(images, min=v_min, max=v_max)
out_1 = self._executed_api(images, min=0.2, max=0.8)
images = paddle.to_tensor(data, dtype='float32')
out_2 = self._executed_api(images, min=0.2, max=0.9)
images = paddle.to_tensor(data, dtype='float32')
out_3 = self._executed_api(images, min=v_min, max=v_max)
out_4 = paddle.clip(paddle.cast(images * 10, 'int32'), min=2, max=8)
out_5 = paddle.clip(paddle.cast(images * 10, 'int64'), min=2, max=8)
out_4 = self._executed_api(
paddle.cast(images * 10, 'int32'), min=2, max=8)
out_5 = self._executed_api(
paddle.cast(images * 10, 'int64'), min=2, max=8)
self.assertTrue(np.allclose(out_1.numpy(), data.clip(0.2, 0.8)))
self.assertTrue(np.allclose(out_2.numpy(), data.clip(0.2, 0.9)))
......@@ -212,5 +221,10 @@ class TestClipAPI(unittest.TestCase):
paddle.disable_static()
class TestInplaceClipAPI(TestClipAPI):
def _executed_api(self, x, min=None, max=None):
return x.clip_(min, max)
if __name__ == '__main__':
unittest.main()
......@@ -408,13 +408,16 @@ class TestElementwiseAddOpError(unittest.TestCase):
self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)
class TestAddOp(unittest.TestCase):
class TestAddApi(unittest.TestCase):
def _executed_api(self, x, y, name=None):
return paddle.add(x, y, name)
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[2, 3], dtype="float32")
y = fluid.data(name='y', shape=[2, 3], dtype='float32')
y_1 = paddle.add(x, y, name='add_res')
y_1 = self._executed_api(x, y, name='add_res')
self.assertEqual(('add_res' in y_1.name), True)
def test_declarative(self):
......@@ -428,7 +431,7 @@ class TestAddOp(unittest.TestCase):
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = paddle.add(x, y)
z = self._executed_api(x, y)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -442,12 +445,75 @@ class TestAddOp(unittest.TestCase):
np_y = np.array([1, 5, 2]).astype('float64')
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
z = paddle.add(x, y)
z = self._executed_api(x, y)
np_z = z.numpy()
z_expected = np.array([3., 8., 6.])
self.assertEqual((np_z == z_expected).all(), True)
class TestAddInplaceApi(TestAddApi):
def _executed_api(self, x, y, name=None):
return x.add_(y, name)
class TestAddInplaceBroadcastSuccess(unittest.TestCase):
def init_data(self):
self.x_numpy = np.random.rand(2, 3, 4).astype('float')
self.y_numpy = np.random.rand(3, 4).astype('float')
def test_broadcast_success(self):
paddle.disable_static()
self.init_data()
x = paddle.to_tensor(self.x_numpy)
y = paddle.to_tensor(self.y_numpy)
inplace_result = x.add_(y)
numpy_result = self.x_numpy + self.y_numpy
self.assertEqual((inplace_result.numpy() == numpy_result).all(), True)
paddle.enable_static()
class TestAddInplaceBroadcastSuccess2(TestAddInplaceBroadcastSuccess):
def init_data(self):
self.x_numpy = np.random.rand(1, 2, 3, 1).astype('float')
self.y_numpy = np.random.rand(3, 1).astype('float')
class TestAddInplaceBroadcastSuccess3(TestAddInplaceBroadcastSuccess):
def init_data(self):
self.x_numpy = np.random.rand(2, 3, 1, 5).astype('float')
self.y_numpy = np.random.rand(1, 3, 1, 5).astype('float')
class TestAddInplaceBroadcastError(unittest.TestCase):
def init_data(self):
self.x_numpy = np.random.rand(3, 4).astype('float')
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
def test_broadcast_errors(self):
paddle.disable_static()
self.init_data()
x = paddle.to_tensor(self.x_numpy)
y = paddle.to_tensor(self.y_numpy)
def broadcast_shape_error():
x.add_(y)
self.assertRaises(ValueError, broadcast_shape_error)
paddle.enable_static()
class TestAddInplaceBroadcastError2(TestAddInplaceBroadcastError):
def init_data(self):
self.x_numpy = np.random.rand(2, 1, 4).astype('float')
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
class TestAddInplaceBroadcastError3(TestAddInplaceBroadcastError):
def init_data(self):
self.x_numpy = np.random.rand(5, 2, 1, 4).astype('float')
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
class TestComplexElementwiseAddOp(OpTest):
def setUp(self):
self.op_type = "elementwise_add"
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest, skip_check_grad_ci
......@@ -237,6 +238,111 @@ class TestRealComplexElementwiseSubOp(TestComplexElementwiseSubOp):
self.grad_y = -self.grad_out
class TestSubtractApi(unittest.TestCase):
def _executed_api(self, x, y, name=None):
return paddle.subtract(x, y, name)
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[2, 3], dtype="float32")
y = fluid.data(name='y', shape=[2, 3], dtype='float32')
y_1 = self._executed_api(x, y, name='subtract_res')
self.assertEqual(('subtract_res' in y_1.name), True)
def test_declarative(self):
with fluid.program_guard(fluid.Program()):
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = self._executed_api(x, y)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
z_expected = np.array([1., -2., 2.])
self.assertEqual((z_value == z_expected).all(), True)
def test_dygraph(self):
with fluid.dygraph.guard():
np_x = np.array([2, 3, 4]).astype('float64')
np_y = np.array([1, 5, 2]).astype('float64')
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
z = self._executed_api(x, y)
np_z = z.numpy()
z_expected = np.array([1., -2., 2.])
self.assertEqual((np_z == z_expected).all(), True)
class TestSubtractInplaceApi(TestSubtractApi):
def _executed_api(self, x, y, name=None):
return x.subtract_(y, name)
class TestSubtractInplaceBroadcastSuccess(unittest.TestCase):
def init_data(self):
self.x_numpy = np.random.rand(2, 3, 4).astype('float')
self.y_numpy = np.random.rand(3, 4).astype('float')
def test_broadcast_success(self):
paddle.disable_static()
self.init_data()
x = paddle.to_tensor(self.x_numpy)
y = paddle.to_tensor(self.y_numpy)
inplace_result = x.subtract_(y)
numpy_result = self.x_numpy - self.y_numpy
self.assertEqual((inplace_result.numpy() == numpy_result).all(), True)
paddle.enable_static()
class TestSubtractInplaceBroadcastSuccess2(TestSubtractInplaceBroadcastSuccess):
def init_data(self):
self.x_numpy = np.random.rand(1, 2, 3, 1).astype('float')
self.y_numpy = np.random.rand(3, 1).astype('float')
class TestSubtractInplaceBroadcastSuccess3(TestSubtractInplaceBroadcastSuccess):
def init_data(self):
self.x_numpy = np.random.rand(2, 3, 1, 5).astype('float')
self.y_numpy = np.random.rand(1, 3, 1, 5).astype('float')
class TestSubtractInplaceBroadcastError(unittest.TestCase):
def init_data(self):
self.x_numpy = np.random.rand(3, 4).astype('float')
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
def test_broadcast_errors(self):
paddle.disable_static()
self.init_data()
x = paddle.to_tensor(self.x_numpy)
y = paddle.to_tensor(self.y_numpy)
def broadcast_shape_error():
x.subtract_(y)
self.assertRaises(ValueError, broadcast_shape_error)
paddle.enable_static()
class TestSubtractInplaceBroadcastError2(TestSubtractInplaceBroadcastError):
def init_data(self):
self.x_numpy = np.random.rand(2, 1, 4).astype('float')
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
class TestSubtractInplaceBroadcastError3(TestSubtractInplaceBroadcastError):
def init_data(self):
self.x_numpy = np.random.rand(5, 2, 1, 4).astype('float')
self.y_numpy = np.random.rand(2, 3, 4).astype('float')
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -182,6 +182,30 @@ class TestFlatten2OpError(unittest.TestCase):
self.assertRaises(ValueError, test_InputError)
class TestStaticFlattenPythonAPI(unittest.TestCase):
def execute_api(self, x, start_axis=0, stop_axis=-1):
return paddle.flatten(x, start_axis, stop_axis)
def test_static_api(self):
paddle.enable_static()
np_x = np.random.rand(2, 3, 4, 4).astype('float32')
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.static.data(
name="x", shape=[2, 3, 4, 4], dtype='float32')
out = self.execute_api(x, start_axis=-2, stop_axis=-1)
exe = paddle.static.Executor(place=paddle.CPUPlace())
fetch_out = exe.run(main_prog, feed={"x": np_x}, fetch_list=[out])
self.assertTrue((2, 3, 16) == fetch_out[0].shape)
class TestStaticInplaceFlattenPythonAPI(TestStaticFlattenPythonAPI):
def execute_api(self, x, start_axis=0, stop_axis=-1):
return x.flatten_(start_axis, stop_axis)
class TestFlattenPython(unittest.TestCase):
def test_python_api(self):
image_shape = (2, 3, 4, 4)
......@@ -204,5 +228,23 @@ class TestFlattenPython(unittest.TestCase):
self.assertTrue((2, 3, 16) == res_shape)
class TestDygraphInplaceFlattenPython(unittest.TestCase):
def test_python_api(self):
image_shape = (2, 3, 4, 4)
x = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
image_shape[3]).reshape(image_shape) / 100.
x = x.astype('float32')
def test_Negative():
paddle.disable_static()
img = paddle.to_tensor(x)
out = img.flatten_(start_axis=-2, stop_axis=-1)
return out.numpy().shape
res_shape = test_Negative()
self.assertTrue((2, 3, 16) == res_shape)
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -98,11 +98,15 @@ class TestInplace(unittest.TestCase):
class TestDygraphInplace(unittest.TestCase):
def setUp(self):
self.init_data()
self.set_np_compare_func()
def init_data(self):
self.input_var_numpy = np.random.rand(2, 3, 1)
self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
self.dtype = "float32"
def set_np_compare_func(self):
self.np_compare = np.array_equal
def non_inplace_api_processing(self, var):
return paddle.squeeze(var)
......@@ -190,7 +194,7 @@ class TestDygraphInplace(unittest.TestCase):
loss.backward()
grad_var_a = var_a.grad.numpy()
self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
self.assertTrue(self.np_compare(grad_var_a_inplace, grad_var_a))
def test_backward_success_2(self):
# Although var_b is modified inplace after using it, it does not used in gradient computation.
......@@ -244,6 +248,14 @@ class TestDygraphInplaceReshape(TestDygraphInplace):
return paddle.reshape_(var, [-1])
class TestDygraphInplaceFlatten(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.flatten()
def inplace_api_processing(self, var):
return var.flatten_()
class TestDygraphInplaceScatter(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.array([[1, 1], [2, 2], [3, 3]])
......@@ -296,5 +308,106 @@ class TestDygraphInplaceTanh(TestDygraphInplace):
return paddle.tanh_(var)
class TestDygraphInplaceCeil(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.ceil()
def inplace_api_processing(self, var):
return var.ceil_()
class TestDygraphInplaceFloor(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.floor()
def inplace_api_processing(self, var):
return var.floor_()
class TestDygraphInplaceExp(TestDygraphInplace):
def set_np_compare_func(self):
self.np_compare = np.allclose
def non_inplace_api_processing(self, var):
return var.exp()
def inplace_api_processing(self, var):
return var.exp_()
class TestDygraphInplaceReciprocal(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.reciprocal()
def inplace_api_processing(self, var):
return var.reciprocal_()
class TestDygraphInplaceRound(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.round()
def inplace_api_processing(self, var):
return var.round_()
class TestDygraphInplaceSqrt(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.uniform(0, 5, [10, 20, 1])
self.dtype = "float32"
def non_inplace_api_processing(self, var):
return var.sqrt()
def inplace_api_processing(self, var):
return var.sqrt_()
class TestDygraphInplaceRsqrt(TestDygraphInplaceSqrt):
def non_inplace_api_processing(self, var):
return var.rsqrt()
def inplace_api_processing(self, var):
return var.rsqrt_()
class TestDygraphInplaceClip(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.clip(0.6, 1.5)
def inplace_api_processing(self, var):
return var.clip_(0.6, 1.5)
class TestDygraphInplaceScale(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return var.scale(scale=2.0, bias=3.0)
def inplace_api_processing(self, var):
return var.scale_(scale=2.0, bias=3.0)
class TestDygraphInplaceAdd(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.rand(2, 3, 4)
self.dtype = "float32"
input_var_numpy_2 = np.random.rand(2, 3, 4).astype(self.dtype)
self.input_var_2 = paddle.to_tensor(input_var_numpy_2)
def non_inplace_api_processing(self, var):
return var.add(self.input_var_2)
def inplace_api_processing(self, var):
return var.add_(self.input_var_2)
class TestDygraphInplaceSubtract(TestDygraphInplaceAdd):
def non_inplace_api_processing(self, var):
return var.subtract(self.input_var_2)
def inplace_api_processing(self, var):
return var.subtract_(self.input_var_2)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.static import Program, program_guard
# In static mode, inplace strategy will not be used in Inplace APIs.
class TestStaticAutoGeneratedAPI(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.init_data()
self.set_np_compare_func()
def init_data(self):
self.dtype = 'float32'
self.shape = [10, 20]
self.np_x = np.random.uniform(-5, 5, self.shape).astype(self.dtype)
def set_np_compare_func(self):
self.np_compare = np.array_equal
def executed_paddle_api(self, x):
return x.ceil()
def executed_numpy_api(self, x):
return np.ceil(x)
def test_api(self):
main_prog = Program()
with program_guard(main_prog, Program()):
x = paddle.static.data(name="x", shape=self.shape, dtype=self.dtype)
out = self.executed_paddle_api(x)
exe = paddle.static.Executor(place=paddle.CPUPlace())
fetch_x, fetch_out = exe.run(main_prog,
feed={"x": self.np_x},
fetch_list=[x, out])
self.assertTrue(np.array_equal(fetch_x, self.np_x))
self.assertTrue(
self.np_compare(fetch_out, self.executed_numpy_api(self.np_x)))
class TestStaticInplaceAutoGeneratedAPI(TestStaticAutoGeneratedAPI):
def executed_paddle_api(self, x):
return x.ceil_()
class TestStaticFloorAPI(TestStaticAutoGeneratedAPI):
def executed_paddle_api(self, x):
return x.floor()
def executed_numpy_api(self, x):
return np.floor(x)
class TestStaticInplaceFloorAPI(TestStaticFloorAPI):
def executed_paddle_api(self, x):
return x.floor_()
class TestStaticExpAPI(TestStaticAutoGeneratedAPI):
def set_np_compare_func(self):
self.np_compare = np.allclose
def executed_paddle_api(self, x):
return x.exp()
def executed_numpy_api(self, x):
return np.exp(x)
class TestStaticInplaceExpAPI(TestStaticExpAPI):
def executed_paddle_api(self, x):
return x.exp_()
class TestStaticReciprocalAPI(TestStaticAutoGeneratedAPI):
def executed_paddle_api(self, x):
return x.reciprocal()
def executed_numpy_api(self, x):
return np.reciprocal(x)
class TestStaticInplaceReciprocalAPI(TestStaticReciprocalAPI):
def executed_paddle_api(self, x):
return x.reciprocal_()
class TestStaticRoundAPI(TestStaticAutoGeneratedAPI):
def executed_paddle_api(self, x):
return x.round()
def executed_numpy_api(self, x):
return np.round(x)
class TestStaticInplaceRoundAPI(TestStaticRoundAPI):
def executed_paddle_api(self, x):
return x.round_()
class TestStaticSqrtAPI(TestStaticAutoGeneratedAPI):
def init_data(self):
self.dtype = 'float32'
self.shape = [10, 20]
self.np_x = np.random.uniform(0, 5, self.shape).astype(self.dtype)
def set_np_compare_func(self):
self.np_compare = np.allclose
def executed_paddle_api(self, x):
return x.sqrt()
def executed_numpy_api(self, x):
return np.sqrt(x)
class TestStaticInplaceSqrtAPI(TestStaticSqrtAPI):
def executed_paddle_api(self, x):
return x.sqrt_()
class TestStaticRsqrtAPI(TestStaticSqrtAPI):
def executed_paddle_api(self, x):
return x.rsqrt()
def executed_numpy_api(self, x):
return 1 / np.sqrt(x)
class TestStaticInplaceRsqrtAPI(TestStaticRsqrtAPI):
def executed_paddle_api(self, x):
return x.rsqrt_()
# In dygraph mode, inplace strategy will be used in Inplace APIs.
class TestDygraphAutoGeneratedAPI(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.init_data()
self.set_np_compare_func()
def init_data(self):
self.dtype = 'float32'
self.shape = [10, 20]
self.np_x = np.random.uniform(-5, 5, self.shape).astype(self.dtype)
def set_np_compare_func(self):
self.np_compare = np.array_equal
def executed_paddle_api(self, x):
return x.ceil()
def executed_numpy_api(self, x):
return np.ceil(x)
def test_api(self):
x = paddle.to_tensor(self.np_x, dtype=self.dtype)
out = self.executed_paddle_api(x)
self.assertTrue(
self.np_compare(out.numpy(), self.executed_numpy_api(self.np_x)))
class TestDygraphInplaceAutoGeneratedAPI(TestDygraphAutoGeneratedAPI):
def executed_paddle_api(self, x):
return x.ceil_()
class TestDygraphFloorAPI(TestDygraphAutoGeneratedAPI):
def executed_paddle_api(self, x):
return x.floor()
def executed_numpy_api(self, x):
return np.floor(x)
class TestDygraphInplaceFloorAPI(TestDygraphFloorAPI):
def executed_paddle_api(self, x):
return x.floor_()
class TestDygraphExpAPI(TestDygraphAutoGeneratedAPI):
def executed_paddle_api(self, x):
return x.exp()
def executed_numpy_api(self, x):
return np.exp(x)
def set_np_compare_func(self):
self.np_compare = np.allclose
class TestDygraphInplaceExpAPI(TestDygraphExpAPI):
def executed_paddle_api(self, x):
return x.exp_()
class TestDygraphReciprocalAPI(TestDygraphAutoGeneratedAPI):
def executed_paddle_api(self, x):
return x.reciprocal()
def executed_numpy_api(self, x):
return np.reciprocal(x)
class TestDygraphInplaceReciprocalAPI(TestDygraphReciprocalAPI):
def executed_paddle_api(self, x):
return x.reciprocal_()
class TestDygraphRoundAPI(TestDygraphAutoGeneratedAPI):
def executed_paddle_api(self, x):
return x.round()
def executed_numpy_api(self, x):
return np.round(x)
class TestDygraphInplaceRoundAPI(TestDygraphRoundAPI):
def executed_paddle_api(self, x):
return x.round_()
class TestDygraphSqrtAPI(TestDygraphAutoGeneratedAPI):
def init_data(self):
self.dtype = 'float32'
self.shape = [10, 20]
self.np_x = np.random.uniform(0, 100, self.shape).astype(self.dtype)
def set_np_compare_func(self):
self.np_compare = np.allclose
def executed_paddle_api(self, x):
return x.sqrt()
def executed_numpy_api(self, x):
return np.sqrt(x)
class TestDygraphInplaceSqrtAPI(TestDygraphSqrtAPI):
def executed_paddle_api(self, x):
return x.sqrt_()
class TestDygraphRsqrtAPI(TestDygraphSqrtAPI):
def executed_paddle_api(self, x):
return x.rsqrt()
def executed_numpy_api(self, x):
return 1. / np.sqrt(x)
class TestDygraphInplaceRsqrtAPI(TestDygraphRsqrtAPI):
def executed_paddle_api(self, x):
return x.rsqrt_()
if __name__ == "__main__":
unittest.main()
......@@ -17,9 +17,11 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.static import Program, program_guard
class TestScaleOp(OpTest):
......@@ -168,5 +170,45 @@ class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows):
self.check_with_place(place, 'in', 'in')
class TestScaleApiStatic(unittest.TestCase):
def _executed_api(self, x, scale=1.0, bias=0.0):
return paddle.scale(x, scale, bias)
def test_api(self):
paddle.enable_static()
input = np.random.random([2, 25]).astype("float32")
main_prog = Program()
with program_guard(main_prog, Program()):
x = paddle.static.data(name="x", shape=[2, 25], dtype="float32")
out = self._executed_api(x, scale=2.0, bias=3.0)
exe = paddle.static.Executor(place=paddle.CPUPlace())
out = exe.run(main_prog, feed={"x": input}, fetch_list=[out])
self.assertEqual(np.array_equal(out[0], input * 2.0 + 3.0), True)
class TestScaleInplaceApiStatic(TestScaleApiStatic):
def _executed_api(self, x, scale=1.0, bias=0.0):
return x.scale_(scale, bias)
class TestScaleApiDygraph(unittest.TestCase):
def _executed_api(self, x, scale=1.0, bias=0.0):
return paddle.scale(x, scale, bias)
def test_api(self):
paddle.disable_static()
input = np.random.random([2, 25]).astype("float32")
x = paddle.to_tensor(input)
out = self._executed_api(x, scale=2.0, bias=3.0)
self.assertEqual(np.array_equal(out.numpy(), input * 2.0 + 3.0), True)
paddle.enable_static()
class TestScaleInplaceApiDygraph(TestScaleApiDygraph):
def _executed_api(self, x, scale=1.0, bias=0.0):
return x.scale_(scale, bias)
if __name__ == "__main__":
unittest.main()
......@@ -16,7 +16,7 @@ from ...fluid.layers import sigmoid # noqa: F401
from ...tensor.math import tanh # noqa: F401
from ...tensor.math import tanh_ # noqa: F401
from ...tensor.manipulation import _print_warning_in_static_mode
from ...fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only
from ...tensor.manipulation import chunk
from ...tensor.math import multiply
......@@ -73,17 +73,13 @@ def elu(x, alpha=1.0, name=None):
return out
@inplace_apis_in_dygraph_only
def elu_(x, alpha=1.0, name=None):
r"""
Inplace version of ``elu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_nn_cn_elu`.
"""
if in_dygraph_mode():
return core.ops.elu_(x, 'alpha', alpha)
_print_warning_in_static_mode("elu")
return elu(x, alpha, name)
return core.ops.elu_(x, 'alpha', alpha)
def gelu(x, approximate=False, name=None):
......@@ -501,17 +497,13 @@ def relu(x, name=None):
return out
@inplace_apis_in_dygraph_only
def relu_(x, name=None):
"""
Inplace version of ``relu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_nn_cn_relu`.
"""
if in_dygraph_mode():
return core.ops.relu_(x)
_print_warning_in_static_mode("relu")
return relu(x, name)
return core.ops.relu_(x)
def log_sigmoid(x, name=None):
......@@ -912,21 +904,16 @@ def softmax(x, axis=-1, dtype=None, name=None):
return outs_softmax
@inplace_apis_in_dygraph_only
def softmax_(x, axis=-1, dtype=None, name=None):
r"""
Inplace version of ``softmax`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_nn_cn_softmax`.
"""
if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
dtype = convert_np_dtype_to_dtype_(dtype)
use_cudnn = True
if in_dygraph_mode():
return core.ops.softmax_(x, 'axis', axis, 'use_cudnn', use_cudnn)
_print_warning_in_static_mode("softmax")
return softmax(x, axis, dtype, name)
return core.ops.softmax_(x, 'axis', axis, 'use_cudnn', use_cudnn)
def softplus(x, beta=1, threshold=20, name=None):
......
......@@ -65,6 +65,7 @@ from .manipulation import broadcast_to # noqa: F401
from .manipulation import expand_as # noqa: F401
from .manipulation import tile # noqa: F401
from .manipulation import flatten # noqa: F401
from .manipulation import flatten_ # noqa: F401
from .manipulation import gather # noqa: F401
from .manipulation import gather_nd # noqa: F401
from .manipulation import reshape # noqa: F401
......@@ -95,24 +96,32 @@ from .math import acos # noqa: F401
from .math import asin # noqa: F401
from .math import atan # noqa: F401
from .math import ceil # noqa: F401
from .math import ceil_ # noqa: F401
from .math import cos # noqa: F401
from .math import tan # noqa: F401
from .math import cosh # noqa: F401
from .math import cumsum # noqa: F401
from .math import exp # noqa: F401
from .math import exp_ # noqa: F401
from .math import floor # noqa: F401
from .math import floor_ # noqa: F401
from .math import increment # noqa: F401
from .math import log # noqa: F401
from .math import multiplex # noqa: F401
from .math import pow # noqa: F401
from .math import reciprocal # noqa: F401
from .math import reciprocal_ # noqa: F401
from .math import round # noqa: F401
from .math import round_ # noqa: F401
from .math import rsqrt # noqa: F401
from .math import rsqrt_ # noqa: F401
from .math import scale # noqa: F401
from .math import scale_ # noqa: F401
from .math import sign # noqa: F401
from .math import sin # noqa: F401
from .math import sinh # noqa: F401
from .math import sqrt # noqa: F401
from .math import sqrt_ # noqa: F401
from .math import square # noqa: F401
from .math import stanh # noqa: F401
from .math import sum # noqa: F401
......@@ -131,7 +140,9 @@ from .math import mod # noqa: F401
from .math import floor_mod # noqa: F401
from .math import multiply # noqa: F401
from .math import add # noqa: F401
from .math import add_ # noqa: F401
from .math import subtract # noqa: F401
from .math import subtract_ # noqa: F401
from .math import atan # noqa: F401
from .math import logsumexp # noqa: F401
from .math import inverse # noqa: F401
......@@ -141,6 +152,7 @@ from .math import log1p # noqa: F401
from .math import erf # noqa: F401
from .math import addmm # noqa: F401
from .math import clip # noqa: F401
from .math import clip_ # noqa: F401
from .math import trace # noqa: F401
from .math import kron # noqa: F401
from .math import isfinite # noqa: F401
......@@ -202,11 +214,14 @@ tensor_method_func = [ #noqa
'asin',
'atan',
'ceil',
'ceil_',
'cos',
'cosh',
'cumsum',
'exp',
'exp_',
'floor',
'floor_',
'increment',
'log',
'log2',
......@@ -217,13 +232,18 @@ tensor_method_func = [ #noqa
'pow',
'prod',
'reciprocal',
'reciprocal_',
'round',
'round_',
'rsqrt',
'rsqrt_',
'scale',
'scale_',
'sign',
'sin',
'sinh',
'sqrt',
'sqrt_',
'square',
'stanh',
'sum',
......@@ -242,7 +262,9 @@ tensor_method_func = [ #noqa
'floor_mod',
'multiply',
'add',
'add_',
'subtract',
'subtract_',
'atan',
'logsumexp',
'inverse',
......@@ -250,6 +272,7 @@ tensor_method_func = [ #noqa
'erf',
'addmm',
'clip',
'clip_',
'trace',
'kron',
'isfinite',
......@@ -277,6 +300,7 @@ tensor_method_func = [ #noqa
'broadcast_to',
'expand_as',
'flatten',
'flatten_',
'gather',
'gather_nd',
'reshape',
......
......@@ -31,18 +31,12 @@ from ..fluid.layers import unstack # noqa: F401
from ..fluid.layers import scatter_nd # noqa: F401
from ..fluid.layers import shard_index # noqa: F401
from ..fluid import layers
from ..fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only
import paddle
import warnings
__all__ = []
def _print_warning_in_static_mode(api_name):
warnings.warn(
"In static mode, {}_() is the same as {}() and does not perform inplace operation.".
format(api_name, api_name))
@dygraph_only
def tolist(x):
"""
......@@ -289,6 +283,36 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
return out
@inplace_apis_in_dygraph_only
def flatten_(x, start_axis=0, stop_axis=-1, name=None):
"""
Inplace version of ``flatten`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_flatten`.
"""
if not (isinstance(x, Variable)):
raise ValueError("The input x should be a Tensor")
x_dim = len(x.shape)
if not (isinstance(start_axis, int)) or (
start_axis > x_dim - 1) or start_axis < -x_dim:
raise ValueError(
"The start_axis should be a int, and in range [-rank(x), rank(x))")
if not (isinstance(stop_axis, int)) or (
stop_axis > x_dim - 1) or stop_axis < -x_dim:
raise ValueError(
"The stop_axis should be a int, and in range [-rank(x), rank(x))")
if start_axis < 0:
start_axis = start_axis + x_dim
if stop_axis < 0:
stop_axis = stop_axis + x_dim
if start_axis > stop_axis:
raise ValueError("The stop_axis should be larger than stat_axis")
dy_out, _ = core.ops.flatten_contiguous_range_(x, 'start_axis', start_axis,
'stop_axis', stop_axis)
return dy_out
def roll(x, shifts, axis=None, name=None):
"""
Roll the `x` tensor along the given axis(axes). With specific 'shifts', Elements that
......@@ -582,6 +606,7 @@ def squeeze(x, axis=None, name=None):
return layers.squeeze(x, axis, name)
@inplace_apis_in_dygraph_only
def squeeze_(x, axis=None, name=None):
"""
Inplace version of ``squeeze`` API, the output Tensor will be inplaced with input ``x``.
......@@ -594,12 +619,8 @@ def squeeze_(x, axis=None, name=None):
elif isinstance(axis, tuple):
axis = list(axis)
if in_dygraph_mode():
out, _ = core.ops.squeeze2_(x, 'axes', axis)
return out
_print_warning_in_static_mode("squeeze")
return squeeze(x, axis, name)
out, _ = core.ops.squeeze2_(x, 'axes', axis)
return out
def unique(x,
......@@ -775,26 +796,23 @@ def unsqueeze(x, axis, name=None):
return layers.unsqueeze(x, axis, name)
@inplace_apis_in_dygraph_only
def unsqueeze_(x, axis, name=None):
"""
Inplace version of ``unsqueeze`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_tensor_unsqueeze`.
"""
if in_dygraph_mode():
if isinstance(axis, int):
axis = [axis]
elif isinstance(axis, Variable):
axis = axis.numpy().tolist()
elif isinstance(axis, (list, tuple)):
axis = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in axis
]
out, _ = core.ops.unsqueeze2_(x, 'axes', axis)
return out
_print_warning_in_static_mode("unsqueeze")
return unsqueeze(x, axis, name)
if isinstance(axis, int):
axis = [axis]
elif isinstance(axis, Variable):
axis = axis.numpy().tolist()
elif isinstance(axis, (list, tuple)):
axis = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in axis
]
out, _ = core.ops.unsqueeze2_(x, 'axes', axis)
return out
def gather(x, index, axis=None, name=None):
......@@ -1023,16 +1041,13 @@ def scatter(x, index, updates, overwrite=True, name=None):
return out
@inplace_apis_in_dygraph_only
def scatter_(x, index, updates, overwrite=True, name=None):
"""
Inplace version of ``scatter`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_tensor_scatter`.
"""
if in_dygraph_mode():
return core.ops.scatter_(x, index, updates, 'overwrite', overwrite)
_print_warning_in_static_mode("scatter")
return scatter(x, index, updates, overwrite, name)
return core.ops.scatter_(x, index, updates, 'overwrite', overwrite)
def scatter_nd_add(x, index, updates, name=None):
......@@ -1555,26 +1570,23 @@ def reshape(x, shape, name=None):
return paddle.fluid.layers.reshape(x=x, shape=shape, name=name)
@inplace_apis_in_dygraph_only
def reshape_(x, shape, name=None):
"""
Inplace version of ``reshape`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_tensor_reshape`.
"""
if in_dygraph_mode():
if isinstance(shape, (list, tuple)):
shape = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape
]
out, _ = core.ops.reshape2_(x, None, 'shape', shape)
return out
elif isinstance(shape, Variable):
shape.stop_gradient = True
out, _ = core.ops.reshape2_(x, shape)
return out
_print_warning_in_static_mode("reshape")
return reshape(x, shape, name)
if isinstance(shape, (list, tuple)):
shape = [
item.numpy().item(0) if isinstance(item, Variable) else item
for item in shape
]
out, _ = core.ops.reshape2_(x, None, 'shape', shape)
return out
elif isinstance(shape, Variable):
shape.stop_gradient = True
out, _ = core.ops.reshape2_(x, shape)
return out
def gather_nd(x, index, name=None):
......
......@@ -30,7 +30,7 @@ from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable,
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from ..fluid.layers.layer_function_generator import _generate_doc_string_, generate_activation_fn, generate_layer_fn
from .manipulation import _print_warning_in_static_mode
from ..fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only
# TODO: define math functions
# yapf: disable
......@@ -38,22 +38,29 @@ from ..fluid.layers import abs # noqa: F401
from ..fluid.layers import acos # noqa: F401
from ..fluid.layers import asin # noqa: F401
from ..fluid.layers import ceil # noqa: F401
from ..fluid.layers import ceil_ # noqa: F401
from ..fluid.layers import cos # noqa: F401
from ..fluid.layers import tan # noqa: F401
from ..fluid.layers import sinh # noqa: F401
from ..fluid.layers import cosh # noqa: F401
from ..fluid.layers import exp # noqa: F401
from ..fluid.layers import exp_ # noqa: F401
from ..fluid.layers import floor # noqa: F401
from ..fluid.layers import floor_ # noqa: F401
from ..fluid.layers import log # noqa: F401
from ..fluid.layers import reciprocal # noqa: F401
from ..fluid.layers import reciprocal_ # noqa: F401
from ..fluid.layers import round # noqa: F401
from ..fluid.layers import round_ # noqa: F401
from ..fluid.layers import rsqrt # noqa: F401
from ..fluid.layers import rsqrt_ # noqa: F401
from ..fluid.layers import scale # noqa: F401
from ..fluid.layers import square # noqa: F401
from ..fluid.layers import stanh # noqa: F401
from ..fluid.layers import atan # noqa: F401
from ..fluid.layers import erf # noqa: F401
from ..fluid.layers import sqrt # noqa: F401
from ..fluid.layers import sqrt_ # noqa: F401
from ..fluid.layers import sin # noqa: F401
from ..fluid.layers import multiplex # noqa: F401
......@@ -74,6 +81,19 @@ _supported_float_dtype_ = [
VarDesc.VarType.FP64,
]
@inplace_apis_in_dygraph_only
def scale_(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
"""
Inplace version of ``scale`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_scale`.
"""
_scale = scale.numpy().item(0) if isinstance(scale, Variable) else scale
return core.ops.scale_(x, 'scale',
float(_scale), 'bias',
float(bias), 'bias_after_scale', bias_after_scale)
def pow(x, y, name=None):
"""
Compute the power of tensor elements. The equation is:
......@@ -221,6 +241,24 @@ def add(x, y, name=None):
return _elementwise_op(LayerHelper(op_type, **locals()))
@inplace_apis_in_dygraph_only
def add_(x, y, name=None):
"""
Inplace version of ``add`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_add`.
"""
op_type = 'elementwise_add_'
axis = -1
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
out = _elementwise_op_in_dygraph(
x, y, axis=axis, op_name=op_type)
return out
def subtract(x, y, name=None):
"""
Substract two tensors element-wise. The equation is:
......@@ -282,6 +320,24 @@ def subtract(x, y, name=None):
return _elementwise_op(LayerHelper(op_type, **locals()))
@inplace_apis_in_dygraph_only
def subtract_(x, y, name=None):
"""
Inplace version of ``subtract`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_subtract`.
"""
axis = -1
act = None
out_shape = broadcast_shape(x.shape, y.shape)
if out_shape != x.shape:
raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape))
out = _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name='elementwise_sub_')
return out
def divide(x, y, name=None):
"""
Divide two tensors element-wise. The equation is:
......@@ -1489,6 +1545,24 @@ def clip(x, min=None, max=None, name=None):
return output
@inplace_apis_in_dygraph_only
def clip_(x, min=None, max=None, name=None):
"""
Inplace version of ``clip`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_clip`.
"""
fmin = float(np.finfo(np.float32).min)
fmax = float(np.finfo(np.float32).max)
if isinstance(min, Variable):
min = min.numpy().item(0)
if isinstance(max, Variable):
max = max.numpy().item(0)
min = fmin if min is None else min
max = fmax if max is None else max
return core.ops.clip_(x, "min", min, "max", max)
def trace(x, offset=0, axis1=0, axis2=1, name=None):
"""
**trace**
......@@ -1908,16 +1982,14 @@ def tanh(x, name=None):
helper.append_op(type='tanh', inputs={'X': x}, outputs={'Out': out})
return out
@inplace_apis_in_dygraph_only
def tanh_(x, name=None):
r"""
Inplace version of ``tanh`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_tensor_tanh`.
"""
if in_dygraph_mode():
return core.ops.tanh_(x)
return core.ops.tanh_(x)
_print_warning_in_static_mode("tanh")
return tanh(x, name)
def increment(x, value=1.0, name=None):
"""
......
......@@ -34,6 +34,10 @@
"name":"reshape_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"flatten_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"scatter_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
......@@ -53,6 +57,50 @@
{
"name":"tanh_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"ceil_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"floor_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"exp_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"reciprocal_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"round_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"sqrt_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"rsqrt_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"clip_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"scale_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"subtract_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
},
{
"name":"add_",
"annotation":"Inplace APIs don't need sample code. There is a special document introducing Inplace strategy"
}
],
"wlist_temp_api":[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册