diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt
index 30e162a4dd2a9671d42cf9760a84f1af649220fe..7a75b9b98572707e04e352e4dbc648263dd6cfd3 100644
--- a/paddle/fluid/operators/prim_ops/CMakeLists.txt
+++ b/paddle/fluid/operators/prim_ops/CMakeLists.txt
@@ -37,7 +37,8 @@ set(PRIM_OP_SRCS
   max_p_op.cc
   erf_p_op.cc
   abs_p_op.cc
-  cast_p_op.cc)
+  cast_p_op.cc
+  rsqrt_p_op.cc)
 
 cc_test(
   prim_op_test
diff --git a/paddle/fluid/operators/prim_ops/rsqrt_p_op.cc b/paddle/fluid/operators/prim_ops/rsqrt_p_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9f094cee8948c9a259c7881558ddd939ca6d7b2d
--- /dev/null
+++ b/paddle/fluid/operators/prim_ops/rsqrt_p_op.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace framework {
+class InferShapeContext;
+class VarDesc;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace operators {
+class RsqrtPrimOp : public framework::OperatorBase {
+ public:
+  RsqrtPrimOp(const std::string &type,
+              const framework::VariableNameMap &inputs,
+              const framework::VariableNameMap &outputs,
+              const framework::AttributeMap &attrs)
+      : framework::OperatorBase(type, inputs, outputs, attrs) {}
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Prim operator rsqrt_p should not be executed directly"));
+  }
+};
+
+class RsqrtPrimOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of rsqrt_p op.");
+    AddOutput("Y", "(Tensor), The output tensor of rsqrt_p op.");
+    AddComment(R"DOC(
+Autograd primitive rsqrt_p operator.
+)DOC");
+  }
+};
+
+class RsqrtPrimOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
+    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
+
+    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
+
+    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
+  }
+};
+
+class RsqrtPrimOpVarTypeInference
+    : public framework::StaticGraphVarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto x_name = Input(ctx, "X")[0];
+    auto y_name = Output(ctx, "Y")[0];
+    SetType(ctx, y_name, GetType(ctx, x_name));
+    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OPERATOR(rsqrt_p,
+                  paddle::operators::RsqrtPrimOp,
+                  paddle::operators::RsqrtPrimOpMaker,
+                  paddle::operators::RsqrtPrimOpShapeInference,
+                  paddle::operators::RsqrtPrimOpVarTypeInference);
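
Like the other *_p primitives in this directory, rsqrt_p is a purely symbolic op: RunImpl() above only raises Unimplemented, and the shape/var-type inference classes mirror X onto Y unchanged. A minimal sketch of the intended workflow, building a program that contains the primitive and lowering it before execution (enable_prim/prim2orig come from the surrounding paddle.incubate.autograd module; their exact signatures are assumed here, not prescribed by this patch):

    import paddle
    from paddle.incubate.autograd import primops

    paddle.enable_static()
    paddle.incubate.autograd.enable_prim()
    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data(name='x', shape=[3, 4], dtype='float32')
        y = primops.rsqrt(x)  # appends a rsqrt_p op; nothing is computed yet
        # rsqrt_p cannot be run directly (RunImpl throws), so lower it back
        # to the original rsqrt op before handing the program to an Executor.
        paddle.incubate.autograd.prim2orig()
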
+)DOC"); + } +}; + +class RsqrtPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + + PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); + } +}; + +class RsqrtPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(rsqrt_p, + paddle::operators::RsqrtPrimOp, + paddle::operators::RsqrtPrimOpMaker, + paddle::operators::RsqrtPrimOpShapeInference, + paddle::operators::RsqrtPrimOpVarTypeInference); diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py index c8e1b3965228d71bbcf4a8942f7c6adfbb672f7c..4a92f1bc65bfdb4d913137fa4951dd3cddd472a9 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py @@ -241,6 +241,39 @@ class TestSqrtPJVPAndTranspose(TestAddPJVPAndTranspose): ] +class TestRSqrtPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'rsqrt_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + self.prim_input = { + 'X': X, + } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + self.all_ops = [ + # prim op: + 'rsqrt_p', + # jvp op: + 'div_p', + 'div_p', + 'mul_p', + 'fill_constant_p', + # 'sqrt_p', + # transpose op: + ] + + class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose): def init_data(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index e1d5ee11a13ace711db583a1772d5d5d5b94e76a..d289f3dad9c692b96378d04b8eb7985f0d904156 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -879,5 +879,26 @@ class TestSquareOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestRSqrtOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'rsqrt' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + + self.input = { + 'X': X, + } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, ) + self.all_ops = ['rsqrt', 'rsqrt_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py index 
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
index e1d5ee11a13ace711db583a1772d5d5d5b94e76a..d289f3dad9c692b96378d04b8eb7985f0d904156 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
@@ -879,5 +879,26 @@ class TestSquareOrig2Prim(TestElementWiseAddOrig2Prim):
         self.out_map = {0: self.output['Out']}
 
 
+class TestRsqrtOrig2Prim(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'rsqrt'
+        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Out':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.orig2prim_args = (X, )
+        self.all_ops = ['rsqrt', 'rsqrt_p']
+        # { prim_op_output_index: orig_op_output_var }
+        self.out_map = {0: self.output['Out']}
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
index a89b91bdd2b64a7e28b7782a003992eb2b2e16f9..abc2803a8c0889868785d5d4cfbb440e1eb70325 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
@@ -690,5 +690,25 @@ class TestCastPPrim2Orig(TestAddPPrim2Orig):
         self.out_map = {self.output['Y']: 0}
 
 
+class TestRsqrtPPrim2Orig(TestAddPPrim2Orig):
+
+    def init_data(self):
+        self.op_type = 'rsqrt_p'
+        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Y':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.prim2orig_args = (X, )
+        self.all_ops = ['rsqrt_p', 'rsqrt']
+        self.out_map = {self.output['Y']: 0}
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
index bdc54563fc8d2a127670106af005ed00a1b9abf3..1a086e12f20a671219d6092be8f28d46d687c2f4 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -152,6 +152,7 @@ class TestWithoutProgramGuard(unittest.TestCase):
     ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
     ('abs', paddle.abs, (np.random.uniform(-10, 10,
                                            (10, 10)), ), None, 'float32'),
+    ('rsqrt', paddle.rsqrt, (np.random.rand(100, 200), ), None, 'float32'),
 ))
 # paddle.where, paddle.pow, paddle.maximum has no double grad definition,
 # can not compute forward grad use double trick
@@ -267,6 +268,7 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y)
      (np.random.rand(3, 3), np.random.rand(3, 3)),
      (np.random.rand(3, 3), ), 'float64'),
     ('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
+    ('rsqrt', paddle.rsqrt, (np.random.rand(100, 200), ), None, 'float32'),
     ('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
     ('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
     # In where op, grad of condition computed by paddle.static.gradients is None,
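
For context, test_primapi exercises these registrations end to end through the forward-mode API. A hedged sketch of that flow (API names as exposed by paddle.incubate.autograd in this release line; treat the exact signatures as assumptions):

    import numpy as np
    import paddle

    paddle.enable_static()
    paddle.incubate.autograd.enable_prim()

    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data('x', shape=[100, 200], dtype='float32')
        x.stop_gradient = False
        y = paddle.rsqrt(x)  # orig2prim rewrites this op to rsqrt_p
        y_dot = paddle.incubate.autograd.forward_grad(y, x)  # applies the JVP rule
        paddle.incubate.autograd.prim2orig()  # lower prim ops before running

    exe = paddle.static.Executor()
    exe.run(startup)
    out, = exe.run(main,
                   feed={'x': (np.random.rand(100, 200) + 0.1).astype('float32')},
                   fetch_list=[y_dot])
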
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_transform.py b/python/paddle/fluid/tests/unittests/autograd/test_transform.py
index 6c0aa697550bc369981408e1b7a7bd0f26ac4df2..32ef176fa5282fd4bbd3c62bf300366361bc443d 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_transform.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_transform.py
@@ -48,15 +48,16 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
 
         A = paddle.tanh(X0)
         B = paddle.tanh(X1)
-        Y = paddle.add(A, B)
+        C = paddle.rsqrt(B)
+        Y = paddle.add(A, C)
 
         self.orig_xs = [X0, X1]
         self.orig_ys = [
             Y,
         ]
 
-        self.orig_ops = ['tanh', 'tanh', 'elementwise_add']
-        self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p']
+        self.orig_ops = ['tanh', 'tanh', 'elementwise_add', 'rsqrt']
+        self.orig2prim_ops = ['tanh_p', 'tanh_p', 'add_p', 'rsqrt_p']
         self.linearize_ops = self.orig2prim_ops + [
             # call fill_const() in linearize() function
             'fill_constant_p',
@@ -71,6 +72,10 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
             'fill_constant_p',
             'mul_p',
             'add_p',
+            'fill_constant_p',
+            'div_p',
+            'div_p',
+            'mul_p',
         ]
         self.transpose_ops = self.orig2prim_ops + [
             # call fill_const() in transpose() function
@@ -84,6 +89,10 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
             'mul_p',
             'sub_p',
             'fill_constant_p',
+            'mul_p',
+            'div_p',
+            'div_p',
+            'fill_constant_p',
             # transposed op
             'mul_p',
             'mul_p'
@@ -92,13 +101,16 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
             'tanh', 'tanh', 'add_p', 'fill_constant', 'fill_constant',
             'fill_constant', 'elementwise_mul', 'sub_p', 'fill_constant',
             'elementwise_mul', 'sub_p', 'fill_constant', 'elementwise_mul',
-            'elementwise_mul'
+            'elementwise_mul', 'rsqrt', 'fill_constant', 'elementwise_div',
+            'elementwise_div', 'elementwise_mul'
         ]
         self.prim2orig_ops = [
             'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant',
             'fill_constant', 'elementwise_mul', 'elementwise_sub',
             'fill_constant', 'elementwise_mul', 'elementwise_sub',
-            'fill_constant', 'elementwise_mul', 'elementwise_mul'
+            'fill_constant', 'elementwise_mul', 'elementwise_mul', 'rsqrt',
+            'fill_constant', 'elementwise_div', 'elementwise_div',
+            'elementwise_mul'
         ]
 
     def test_run(self):
diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py
index 636dc8922049053daf4546823864acc483e45b02..a000bec277ec8ff4e96610a044e5d156471ef048 100644
--- a/python/paddle/incubate/autograd/primops.py
+++ b/python/paddle/incubate/autograd/primops.py
@@ -394,3 +394,8 @@ def cast(x, dtype, out=None):
                      outputs={'Y': out},
                      attrs={'dtype': dtype})
     return out
+
+
+@REGISTER_FN('rsqrt_p', 'X', 'Y')
+def rsqrt(x, out=None):
+    return _simple_unop(LayerHelper('rsqrt_p', **locals()))
diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py
index 4625cfd362f07030feba94c7191b178108384474..4dbcc421498c7236fcacf3d31e3d4d4b28f517b5 100644
--- a/python/paddle/incubate/autograd/primrules.py
+++ b/python/paddle/incubate/autograd/primrules.py
@@ -23,7 +23,7 @@ from .primops import (add, broadcast, concat, cos, div, eq, erf, exp,
                       fill_const, gather, ge, gt, log, matmul, max, mul, ne,
                       neg, reduce_sum, reshape, scatter_add, select, set_value,
                       sin, slice_assign, slice_select, split, sqrt, sub, tanh,
-                      transpose)
+                      transpose, rsqrt)
 from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
                       REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
                       lookup_orig2prim, lookup_prim2orig, lookup_transpose,
@@ -252,6 +252,11 @@ def sqrt_orig2prim(op, x):
     return sqrt(x)
 
 
+@REGISTER_ORIG2PRIM('rsqrt')
+def rsqrt_orig2prim(op, x):
+    return rsqrt(x)
+
+
 @REGISTER_ORIG2PRIM('matmul_v2')
 def matmul_v2_orig2prim(op, x, y):
 
@@ -456,6 +461,11 @@ def sub_prim2orig(op, x, y):
     return paddle.subtract(x, y)
 
 
+@REGISTER_PRIM2ORIG('rsqrt_p')
+def rsqrt_prim2orig(op, x):
+    return paddle.rsqrt(x)
+
+
 @REGISTER_PRIM2ORIG('mul_p')
 def mul_prim2orig(op, x, y):
     return paddle.multiply(x, y)
@@ -969,6 +979,17 @@ def cast_jvp(op, x_dot):
     return primops.cast(x_dot, y.dtype)
 
 
+@REGISTER_JVP('rsqrt_p')
+def rsqrt_jvp(op, x_dot):
+    if x_dot is None:
+        return None
+    y = op_position_output(op)
+    x = op_position_inputs(op)
+    c2 = fill_const(value=-2.0, shape=y.shape, dtype=y.dtype)
+    y_dot = mul(x_dot, div(div(y, x), c2))
+    return y_dot
+
+
 ## Register transpose rules
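
A quick NumPy sanity check (illustrative only, not part of the patch) that the rule registered in rsqrt_jvp above, y_dot = mul(x_dot, div(div(y, x), c2)) with c2 = -2, agrees with the closed-form derivative d/dx x**-0.5 = -0.5 * x**-1.5:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.random((5, 6)) + 0.1    # keep inputs away from zero
    x_dot = rng.random((5, 6))      # an arbitrary tangent
    y = 1.0 / np.sqrt(x)            # forward value of rsqrt
    y_dot = x_dot * (y / x / -2.0)  # the registered JVP, written in NumPy
    assert np.allclose(y_dot, x_dot * -0.5 * x**-1.5)

Reusing the forward output y instead of materializing x**-1.5 saves a fractional power in the JVP, which is why the rule divides y by x rather than recomputing the derivative from x alone.
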