From 22342d51258ae885d2122c90fb694410220729cc Mon Sep 17 00:00:00 2001
From: Xiaoxu Chen
Date: Tue, 26 Jul 2022 19:41:33 +0800
Subject: [PATCH] add sin,cos,exp primitive operators (#44345)

---
 paddle/fluid/operators/prim_ops/cos_p_op.cc   |  71 +++++
 paddle/fluid/operators/prim_ops/exp_p_op.cc   |  71 +++++
 paddle/fluid/operators/prim_ops/sin_p_op.cc   |  71 +++++
 .../autograd/test_jvp_and_transpose.py        |  91 ++++++
 .../unittests/autograd/test_orig2prim.py      |  60 ++++
 .../unittests/autograd/test_prim2orig.py      |  60 ++++
 .../tests/unittests/autograd/test_primapi.py  | 201 ++++++++++----
 .../tests/unittests/autograd/test_primops.py  | 260 ++++++++----------
 python/paddle/incubate/autograd/primops.py    |  15 +
 python/paddle/incubate/autograd/primrules.py  |  70 ++++-
 10 files changed, 765 insertions(+), 205 deletions(-)
 create mode 100644 paddle/fluid/operators/prim_ops/cos_p_op.cc
 create mode 100644 paddle/fluid/operators/prim_ops/exp_p_op.cc
 create mode 100644 paddle/fluid/operators/prim_ops/sin_p_op.cc

diff --git a/paddle/fluid/operators/prim_ops/cos_p_op.cc b/paddle/fluid/operators/prim_ops/cos_p_op.cc
new file mode 100644
index 00000000000..1eba8878fb2
--- /dev/null
+++ b/paddle/fluid/operators/prim_ops/cos_p_op.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+class CosPrimOp : public framework::OperatorBase {
+ public:
+  CosPrimOp(const std::string &type,
+            const framework::VariableNameMap &inputs,
+            const framework::VariableNameMap &outputs,
+            const framework::AttributeMap &attrs)
+      : framework::OperatorBase(type, inputs, outputs, attrs) {}
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Prim operator cos_p should not be executed directly"));
+  }
+};
+
+class CosPrimOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of cos_p op.");
+    AddOutput("Y", "(Tensor), The output tensor of cos_p op.");
+    AddComment(R"DOC(Autograd primitive cos_p operator.)DOC");
+  }
+};
+
+class CosPrimOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
+    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
+    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
+    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
+  }
+};
+
+class CosPrimOpVarTypeInference
+    : public framework::StaticGraphVarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto x_name = Input(ctx, "X")[0];
+    auto y_name = Output(ctx, "Y")[0];
+    SetType(ctx, y_name, GetType(ctx, x_name));
+    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OPERATOR(cos_p,
+                  paddle::operators::CosPrimOp,
+                  paddle::operators::CosPrimOpMaker,
+                  paddle::operators::CosPrimOpShapeInference,
+                  paddle::operators::CosPrimOpVarTypeInference);
diff --git a/paddle/fluid/operators/prim_ops/exp_p_op.cc b/paddle/fluid/operators/prim_ops/exp_p_op.cc
new file mode 100644
index 00000000000..629f900c787
--- /dev/null
+++ b/paddle/fluid/operators/prim_ops/exp_p_op.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+class ExpPrimOp : public framework::OperatorBase {
+ public:
+  ExpPrimOp(const std::string &type,
+            const framework::VariableNameMap &inputs,
+            const framework::VariableNameMap &outputs,
+            const framework::AttributeMap &attrs)
+      : framework::OperatorBase(type, inputs, outputs, attrs) {}
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Prim operator exp_p should not be executed directly"));
+  }
+};
+
+class ExpPrimOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of exp_p op.");
+    AddOutput("Y", "(Tensor), The output tensor of exp_p op.");
+    AddComment(R"DOC(Autograd primitive exp_p operator.)DOC");
+  }
+};
+
+class ExpPrimOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
+    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
+    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
+    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
+  }
+};
+
+class ExpPrimOpVarTypeInference
+    : public framework::StaticGraphVarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto x_name = Input(ctx, "X")[0];
+    auto y_name = Output(ctx, "Y")[0];
+    SetType(ctx, y_name, GetType(ctx, x_name));
+    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OPERATOR(exp_p,
+                  paddle::operators::ExpPrimOp,
+                  paddle::operators::ExpPrimOpMaker,
+                  paddle::operators::ExpPrimOpShapeInference,
+                  paddle::operators::ExpPrimOpVarTypeInference);
diff --git a/paddle/fluid/operators/prim_ops/sin_p_op.cc b/paddle/fluid/operators/prim_ops/sin_p_op.cc
new file mode 100644
index 00000000000..2e658cf2f94
--- /dev/null
+++ b/paddle/fluid/operators/prim_ops/sin_p_op.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+class SinPrimOp : public framework::OperatorBase {
+ public:
+  SinPrimOp(const std::string &type,
+            const framework::VariableNameMap &inputs,
+            const framework::VariableNameMap &outputs,
+            const framework::AttributeMap &attrs)
+      : framework::OperatorBase(type, inputs, outputs, attrs) {}
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Prim operator sin_p should not be executed directly"));
+  }
+};
+
+class SinPrimOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of sin_p op.");
+    AddOutput("Y", "(Tensor), The output tensor of sin_p op.");
+    AddComment(R"DOC(Autograd primitive sin_p operator.)DOC");
+  }
+};
+
+class SinPrimOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
+    framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
+    framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr);
+    PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape());
+  }
+};
+
+class SinPrimOpVarTypeInference
+    : public framework::StaticGraphVarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto x_name = Input(ctx, "X")[0];
+    auto y_name = Output(ctx, "Y")[0];
+    SetType(ctx, y_name, GetType(ctx, x_name));
+    SetDataType(ctx, y_name, GetDataType(ctx, x_name));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OPERATOR(sin_p,
+                  paddle::operators::SinPrimOp,
+                  paddle::operators::SinPrimOpMaker,
+                  paddle::operators::SinPrimOpShapeInference,
+                  paddle::operators::SinPrimOpVarTypeInference);
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
index f99bb9074c9..9c5df9148c6 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
@@ -273,6 +273,97 @@ class TestTanhPJVPAndTranspose(TestAddPJVPAndTranspose):
         ]
 
 
+class TestSinPJVPAndTranspose(TestAddPJVPAndTranspose):
+
+    def init_data(self):
+        # Set prim op
+        self.op_type = 'sin_p'
+        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
+        self.prim_input = {
+            'X': X,
+        }
+        self.prim_output = {
+            'Y':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.prim_attrs = {}
+
+        # Set JVP
+        X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
+        self.jvp_args = (X_DOT, )
+        self.jvp_out_shape_map = {0: self.prim_output['Y']}
+
+        self.all_ops = [
+            # prim op:
+            'sin_p',
+            # jvp op:
+            'mul_p',
+            'cos_p',
+            # transpose op:
+        ]
+
+
+class TestCosPJVPAndTranspose(TestAddPJVPAndTranspose):
+
+    def init_data(self):
+        # Set prim op
+        self.op_type = 'cos_p'
+        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
+        self.prim_input = {
+            'X': X,
+        }
+        self.prim_output = {
+            'Y':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.prim_attrs = {}
+
+        # Set JVP
+        X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
+        self.jvp_args = (X_DOT, )
+        self.jvp_out_shape_map = {0: self.prim_output['Y']}
+
+        self.all_ops = [
+            # prim op:
+            'cos_p',
+            # jvp op:
+            'mul_p',
+            'sin_p',
+            'fill_constant_p',
+            'sub_p'
+            # transpose op:
+        ]
+
+
+class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose):
+
+    def init_data(self):
+        # Set prim op
+        self.op_type = 'exp_p'
+        X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
+        self.prim_input = {
+            'X': X,
+        }
+        self.prim_output = {
+            'Y':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.prim_attrs = {}
+
+        # Set JVP
+        X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64')
+        self.jvp_args = (X_DOT, )
+        self.jvp_out_shape_map = {0: self.prim_output['Y']}
+
+        self.all_ops = [
+            # prim op:
+            'exp_p',
+            # jvp op:
+            'mul_p',
+            # transpose op:
+        ]
+
+
 class TestReshapePJVPAndTranspose(TestAddPJVPAndTranspose):
 
     def init_data(self):
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
index 924292c4a4a..7557d2ba668 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
@@ -148,6 +148,66 @@ class TestTanhOrig2Prim(TestElementWiseAddOrig2Prim):
         self.out_map = {0: self.output['Out']}
 
 
+class TestSinOrig2Prim(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'sin'
+        X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Out':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.orig2prim_args = (X, )
+        self.all_ops = ['sin', 'sin_p']
+        self.out_map = {0: self.output['Out']}
+
+
+class TestCosOrig2Prim(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'cos'
+        X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Out':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.orig2prim_args = (X, )
+        self.all_ops = ['cos', 'cos_p']
+        self.out_map = {0: self.output['Out']}
+
+
+class TestExpOrig2Prim(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'exp'
+        X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Out':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.orig2prim_args = (X, )
+        self.all_ops = ['exp', 'exp_p']
+        self.out_map = {0: self.output['Out']}
+
+
 class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim):
 
     def init_data(self):
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
index 56a28f38712..42c8cce0a8f 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
@@ -164,6 +164,66 @@ class TestTanhPPrim2Orig(TestAddPPrim2Orig):
         self.out_map = {self.output['Y']: 0}
 
 
+class TestSinPPrim2Orig(TestAddPPrim2Orig):
+
+    def init_data(self):
+        self.op_type = 'sin_p'
+        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Y':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.prim2orig_args = (X, )
+        self.all_ops = ['sin_p', 'sin']
+        self.out_map = {self.output['Y']: 0}
+
+
+class TestCosPPrim2Orig(TestAddPPrim2Orig):
+
+    def init_data(self):
+        self.op_type = 'cos_p'
+        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Y':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.prim2orig_args = (X, )
+        self.all_ops = ['cos_p', 'cos']
+        self.out_map = {self.output['Y']: 0}
+
+
+class TestExpPPrim2Orig(TestAddPPrim2Orig):
+
+    def init_data(self):
+        self.op_type = 'exp_p'
+        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Y':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.prim2orig_args = (X, )
+        self.all_ops = ['exp_p', 'exp']
+        self.out_map = {self.output['Y']: 0}
+
+
 class TestReshapePPrim2Orig(TestAddPPrim2Orig):
 
     def init_data(self):
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
index 09bd64ee678..777c16a41e6 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -17,7 +17,6 @@ import unittest
 
 import numpy as np
 import paddle
-from paddle.incubate.autograd import primapi
 
 import config
 import utils
@@ -135,19 +134,19 @@ class TestWithoutProgramGuard(unittest.TestCase):
 
 
 @utils.place(config.DEVICES)
-@utils.parameterize(
-    (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'),
-    (('matmul', paddle.matmul,
-      (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
-     ('multiply', paddle.multiply,
-      (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
-     ('add', paddle.add,
-      (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
-     ('input_not_sequence', paddle.tanh,
-      (np.random.rand(5, 5), ), None, 'float64'),
-     ('input_gradients_not_none', paddle.matmul,
-      (np.random.rand(3, 3), np.random.rand(3, 3)),
-      (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64')))
+@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), (
+    ('matmul', paddle.matmul,
+     (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
+    ('multiply', paddle.multiply,
+     (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
+    ('add', paddle.add,
+     (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
+    ('input_not_sequence', paddle.tanh,
+     (np.random.rand(5, 5), ), None, 'float64'),
+    ('input_gradients_not_none', paddle.matmul,
+     (np.random.rand(3, 3), np.random.rand(3, 3)),
+     (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
+))
 class TestForwardGrad(unittest.TestCase):
 
     @classmethod
@@ -219,7 +218,8 @@ class TestForwardGrad(unittest.TestCase):
                 self.xs, self.v, stop_gradient=False)
             ys = self.fun(*static_xs) if isinstance(
                 static_xs, typing.Sequence) else self.fun(static_xs)
-            ys_grad = primapi.forward_grad(ys, static_xs, static_v)
+            ys_grad = paddle.incubate.autograd.forward_grad(
+                ys, static_xs, static_v)
             paddle.incubate.autograd.prim2orig(mp.block(0))
         exe = paddle.static.Executor()
         exe.run(sp)
@@ -229,15 +229,144 @@ class TestForwardGrad(unittest.TestCase):
     def test_illegal_param(self):
         paddle.incubate.autograd.enable_prim()
         with self.assertRaises(TypeError):
-            primapi.forward_grad(1, paddle.static.data('inputs', shape=[1]))
+            paddle.incubate.autograd.forward_grad(
+                1, paddle.static.data('inputs', shape=[1]))
 
         with self.assertRaises(TypeError):
-            primapi.forward_grad(paddle.static.data('targets', shape=[1]), 1)
+            paddle.incubate.autograd.forward_grad(
+                paddle.static.data('targets', shape=[1]), 1)
         paddle.incubate.autograd.disable_prim()
 
 
+@utils.place(config.DEVICES)
+@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), (
+    ('matmul', paddle.matmul,
+     (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'),
+    ('multiply', paddle.multiply,
+     (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'),
+    ('add', paddle.add,
+     (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'),
+    ('input_not_sequence', paddle.tanh,
+     (np.random.rand(5, 5), ), None, 'float64'),
+    ('input_gradients_not_none', paddle.matmul,
+     (np.random.rand(3, 3), np.random.rand(3, 3)),
+     (np.random.rand(3, 3), ), 'float64'),
+    ('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
+    ('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
+    ('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
+))
 class TestGrad(unittest.TestCase):
 
+    def setUp(self):
+        paddle.enable_static()
+        paddle.incubate.autograd.enable_prim()
+
+    def tearDown(self):
+        paddle.incubate.autograd.disable_prim()
+        paddle.disable_static()
+
+    @classmethod
+    def setUpClass(cls):
+        cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs)
+        cls._rtol = config.TOLERANCE.get(str(
+            cls.dtype)).get("first_order_grad").get("rtol")
+        cls._atol = config.TOLERANCE.get(str(
+            cls.dtype)).get("first_order_grad").get("atol")
+
+    def test_grad(self):
+
+        def expected():
+            paddle.incubate.autograd.disable_prim()
+            sp = paddle.static.Program()
+            mp = paddle.static.Program()
+            with paddle.static.program_guard(mp, sp):
+                feed, static_xs, static_v = utils.gen_static_data_and_feed(
+                    self.xs, self.v, stop_gradient=False)
+                _, ys_grad = paddle.incubate.autograd.vjp(
+                    self.fun, static_xs, static_v)
+            exe = paddle.static.Executor()
+            exe.run(sp)
+            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
+            paddle.incubate.autograd.enable_prim()
+            return out
+
+        def actual():
+            paddle.incubate.autograd.enable_prim()
+            sp = paddle.static.Program()
+            mp = paddle.static.Program()
+            with paddle.static.program_guard(mp, sp):
+                feed, static_xs, static_v = utils.gen_static_data_and_feed(
+                    self.xs, self.v, stop_gradient=False)
+                ys = self.fun(*static_xs) if isinstance(
+                    static_xs, typing.Sequence) else self.fun(static_xs)
+                ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v)
+                paddle.incubate.autograd.prim2orig(mp.block(0))
+            exe = paddle.static.Executor()
+            exe.run(sp)
+            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
+            paddle.incubate.autograd.disable_prim()
+            return out
+
+        actual = actual()
+        expected = expected()
+        self.assertEqual(type(actual), type(expected))
+        for i, j in zip(actual, expected):
+            np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)
+
+    def test_illegal_param(self):
+        paddle.incubate.autograd.enable_prim()
+        with self.assertRaises(TypeError):
+            paddle.incubate.autograd.grad(
+                1, paddle.static.data('inputs', shape=[1]))
+
+        with self.assertRaises(TypeError):
+            paddle.incubate.autograd.grad(
+                paddle.static.data('targets', shape=[1]), 1)
+        paddle.incubate.autograd.disable_prim()
+
+    def test_disable_prim(self):
+
+        def expected():
+            paddle.incubate.autograd.disable_prim()
+            sp = paddle.static.Program()
+            mp = paddle.static.Program()
+            with paddle.static.program_guard(mp, sp):
+                feed, static_xs, static_v = utils.gen_static_data_and_feed(
+                    self.xs, self.v, stop_gradient=False)
+                ys = self.fun(*static_xs) if isinstance(
+                    static_xs, typing.Sequence) else self.fun(static_xs)
+                ys_grad = paddle.incubate.autograd.grad(ys, static_xs, static_v)
+            exe = paddle.static.Executor()
+            exe.run(sp)
+            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
+            paddle.incubate.autograd.enable_prim()
+            return out
+
+        def actual():
+            paddle.incubate.autograd.disable_prim()
+            sp = paddle.static.Program()
+            mp = paddle.static.Program()
+            with paddle.static.program_guard(mp, sp):
+                feed, static_xs, static_v = utils.gen_static_data_and_feed(
+                    self.xs, self.v, stop_gradient=False)
+                ys = self.fun(*static_xs) if isinstance(
+                    static_xs, typing.Sequence) else self.fun(static_xs)
+                ys_grad = paddle.static.gradients(ys, static_xs, static_v)
+            exe = paddle.static.Executor()
+            exe.run(sp)
+            out = exe.run(mp, feed=feed, fetch_list=ys_grad)
+            paddle.incubate.autograd.enable_prim()
+            return out
+
+        actual = actual()
+        expected = expected()
+        self.assertEqual(type(actual), type(expected))
+        for i, j in zip(actual, expected):
+            np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)
+
+
+class TestGradWithHigherOrder(unittest.TestCase):
+
     def setUp(self):
         paddle.enable_static()
         paddle.incubate.autograd.enable_prim()
@@ -346,44 +475,6 @@ class TestGrad(unittest.TestCase):
         np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5)
         paddle.incubate.autograd.disable_prim()
 
-    def test_disable_prim(self):
-
-        def actual(x: np.array):
-            paddle.incubate.autograd.disable_prim()
-            main = paddle.static.Program()
-            startup = paddle.static.Program()
-            with paddle.static.program_guard(main, startup):
-                var_x = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
-                var_x.stop_gradient = False
-                y = paddle.tanh(var_x)
-                y_grad = paddle.incubate.autograd.grad(y, var_x)
-                y_second_grad = paddle.incubate.autograd.grad(y_grad, var_x)
-            exe = paddle.static.Executor()
-            exe.run(startup)
-            return exe.run(main,
-                           feed={'x': x},
-                           fetch_list=[y_grad, y_second_grad])
-
-        def expect(x: np.array):
-            paddle.incubate.autograd.disable_prim()
-            main = paddle.static.Program()
-            startup = paddle.static.Program()
-            with paddle.static.program_guard(main, startup):
-                var_x = paddle.static.data('x', shape=x.shape, dtype=x.dtype)
-                var_x.stop_gradient = False
-                y = paddle.tanh(var_x)
-                y_grad = paddle.static.gradients(y, var_x)
-                y_second_grad = paddle.static.gradients(y_grad, var_x)
-            exe = paddle.static.Executor()
-            exe.run(startup)
-            return exe.run(main,
-                           feed={'x': x},
-                           fetch_list=[y_grad, y_second_grad])
-
-        x = np.random.randn(100, 200)
-        for i, j in zip(actual(x), expect(x)):
-            np.testing.assert_allclose(i, j)
-
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primops.py b/python/paddle/fluid/tests/unittests/autograd/test_primops.py
index f14664237f3..f95e6304b9a 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primops.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primops.py
@@ -11,154 +11,128 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import unittest
+import uuid
+
 import numpy as np
 import paddle
-from paddle.incubate.autograd.primops import (neg, set_value, add, sub, mul,
-                                              div, sqrt, tanh, reshape,
-                                              broadcast, transpose, split,
-                                              concat, reduce, matmul,
-                                              slice_select, slice_assign,
-                                              gather, scatter_add, fill_const)
-from paddle.incubate.autograd.primx import Transform, topo_path, orig2prim, prim2orig
-from paddle.incubate.autograd.utils import enable_prim, disable_prim, prim_enabled
-
-
-class TestPyPrimOps(unittest.TestCase):
-    """ Test Python wrappers of primitive ops. """
-
-    def setUp(self):
+from numpy.random import randint, randn
+from paddle.incubate.autograd import primops, primx
+from paddle.incubate.autograd import utils as prim_utils
+
+import config
+import utils
+
+paddle.enable_static()
+
+
+@utils.place(config.DEVICES)
+@utils.parameterize(
+    (utils.TEST_CASE_NAME, 'op', 'args', 'kwargs', 'expected_shape',
+     'expected_dtype'),
+    (
+        ('add', primops.add, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
+        ('sub', primops.sub, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
+        ('mul', primops.mul, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
+        ('div', primops.div, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
+        ('sub', primops.sub, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'),
+        ('sqrt', primops.sqrt, randn(2, 3), {}, (2, 3), 'float64'),
+        ('tanh', primops.tanh, randn(2, 3), {}, (2, 3), 'float64'),
+        ('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'),
+        ('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
+        ('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
+        ('reshape', primops.reshape, randn(2, 3), {
+            'shape': (3, 2)
+        }, (3, 2), 'float64'),
+        ('broadcast', primops.broadcast, randn(2), {
+            'shape': (3, 2)
+        }, (3, 2), 'float64'),
+        ('transpose', primops.transpose, randn(2, 3), {
+            'axis': (1, 0)
+        }, (3, 2), 'float64'),
+        ('concat_axis0', primops.concat, ((randn(2, 3), randn(2, 3)), ), {
+            'axis': 0
+        }, (4, 3), 'float64'),
+        ('concat_axis1', primops.concat, ((randn(2, 3), randn(2, 3)), ), {
+            'axis': 1
+        }, (2, 6), 'float64'),
+        ('reduce_axis1', primops.reduce, randn(2, 3), {
+            'axis': (1, )
+        }, (2, ), 'float64'),
+        ('reduce_axis01', primops.reduce, randn(2, 3), {
+            'axis': (0, 1)
+        }, (1, ), 'float64'),
+        ('split', primops.split, randn(2, 3), {
+            'num_or_sections': [1, 2],
+            'axis': 1
+        }, ((2, 1), (2, 2)), ('float64', 'float64')),
+        ('matmul', primops.matmul, (randn(2, 3), randn(3, 2)), {},
+         (2, 2), 'float64'),
+        ('slice_select', primops.slice_select, randn(3, 2), {
+            'axis': [0],
+            'starts': [0],
+            'ends': [2],
+            'strides': [1]
+        }, (2, 2), 'float64'),
+        ('slice_assign', primops.slice_assign, (randn(2, 3), randn(2, 2)), {
+            'axis': [1],
+            'starts': [1],
+            'ends': [3],
+            'strides': [1]
+        }, (2, 3), 'float64'),
+        ('gather', primops.gather, (randn(3, 2), randint(0, 2,
+                                                         (5, ), np.int32)), {
+            'axis': 0
+        }, (5, 2), 'float64'),
+        ('scatter_add', primops.scatter_add,
+         (randn(3, 2), randn(5, 2), randint(0, 2, (5, ), np.int32)), {
+            'axis': 0
+        }, (3, 2), 'float64'),
+        ('fill_const', primops.fill_const, (), {
+            'value': 10,
+            'shape': (3, 2),
+            'dtype': paddle.float32
+        }, (3, 2), 'float32'),
+        ('neg', primops.neg, randn(2, 3), {}, (2, 3), 'float64'),
+    ))
+class TestPrimops(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
         paddle.enable_static()
 
-    def test_ops(self):
-        A = np.random.rand(1)
-        B = np.random.rand(2)
-        C = np.random.rand(2, 3)
-        D = np.random.rand(2, 3)
-        E = np.random.rand(3, 2)
-
-        a = paddle.static.data(name='A', shape=A.shape, dtype='float32')
-        b = paddle.static.data(name='B', shape=B.shape, dtype='float32')
-        c = paddle.static.data(name='C', shape=C.shape, dtype='float32')
-        d = paddle.static.data(name='D', shape=D.shape, dtype='float32')
-        e = paddle.static.data(name='E', shape=E.shape, dtype='float32')
-
-        add_1 = add(a, a)
-        self.assertEqual(add_1.dtype, a.dtype)
-        self.assertEqual(add_1.shape, a.shape)
-
-        add_2 = add(c, d)
-        self.assertEqual(add_2.dtype, c.dtype)
-        self.assertEqual(add_2.shape, c.shape)
-
-        sub_1 = sub(c, d)
-        self.assertEqual(sub_1.dtype, c.dtype)
-        self.assertEqual(sub_1.shape, c.shape)
-
-        mul_1 = mul(c, d)
-        self.assertEqual(mul_1.dtype, c.dtype)
-        self.assertEqual(mul_1.shape, c.shape)
-
-        div_1 = div(c, d)
-        self.assertEqual(div_1.dtype, c.dtype)
-        self.assertEqual(div_1.shape, c.shape)
-
-        sqrt_1 = sqrt(b)
-        self.assertEqual(sqrt_1.dtype, b.dtype)
-        self.assertEqual(sqrt_1.shape, b.shape)
-
-        tanh_1 = tanh(d)
-        self.assertEqual(tanh_1.dtype, d.dtype)
-        self.assertEqual(tanh_1.shape, d.shape)
-
-        reshape_1 = reshape(c, d.shape)
-        self.assertEqual(reshape_1.dtype, c.dtype)
-        self.assertEqual(reshape_1.shape, d.shape)
-
-        broadcast_1 = broadcast(b, e.shape)
-        self.assertEqual(broadcast_1.dtype, b.dtype)
-        self.assertEqual(broadcast_1.shape, e.shape)
-
-        transpose_1 = transpose(c, axis=[1, 0])
-        self.assertEqual(transpose_1.dtype, c.dtype)
-        self.assertEqual(transpose_1.shape, e.shape)
-
-        split_1_0, split_1_1 = split(c, num_or_sections=[1, 2], axis=1)
-        self.assertEqual(split_1_0.dtype, c.dtype)
-        self.assertEqual(split_1_0.shape, (2, 1))
-        self.assertEqual(split_1_1.shape, (2, 2))
-
-        concat_1 = concat([c, d], axis=0)
-        self.assertEqual(concat_1.dtype, c.dtype)
-        self.assertEqual(concat_1.shape, (4, 3))
-
-        reduce_1 = reduce(d, axis=[1])
-        self.assertEqual(reduce_1.dtype, d.dtype)
-        self.assertEqual(reduce_1.shape, (2, ))
-
-        reduce_2 = reduce(c, axis=[0, 1])
-        self.assertEqual(reduce_2.dtype, c.dtype)
-        self.assertEqual(reduce_2.shape, (1, ))
-        # TODO: reduce + keepdim
-
-        matmul_1 = matmul(d, e)
-        self.assertEqual(matmul_1.dtype, d.dtype)
-        self.assertEqual(matmul_1.shape, (2, 2))
-
-        slice_select_1 = slice_select(e,
-                                      axis=[0],
-                                      starts=[0],
-                                      ends=[2],
-                                      strides=[1])
-        self.assertEqual(slice_select_1.dtype, e.dtype)
-        self.assertEqual(slice_select_1.shape, (2, 2))
-
-        slice_select_2 = slice_select(d,
-                                      axis=[0, 1],
-                                      starts=[0, 1],
-                                      ends=[2, 3],
-                                      strides=[1, 2])
-        self.assertEqual(slice_select_2.dtype, d.dtype)
-        self.assertEqual(slice_select_2.shape, (2, 1))
-
-        y = broadcast(b, [2, 2])
-        slice_assign_1 = slice_assign(d,
-                                      y,
-                                      axis=[1],
-                                      starts=[1],
-                                      ends=[3],
-                                      strides=[1])
-        self.assertEqual(slice_assign_1.dtype, d.dtype)
-        self.assertEqual(slice_assign_1.shape, d.shape)
-
-        index = paddle.static.data('index', shape=[5], dtype='int32')
-        gather_1 = gather(e, index, axis=0)
-        self.assertEqual(gather_1.dtype, e.dtype)
-        self.assertEqual(gather_1.shape, (5, 2))
-
-        y = paddle.rand([5, 2], dtype='float32')
-        scatter_add_1 = scatter_add(e, y, index, axis=0)
-        self.assertEqual(scatter_add_1.dtype, e.dtype)
-        self.assertEqual(scatter_add_1.shape, e.shape)
-
-        fill_const_1 = fill_const(value=10, shape=a.shape, dtype=a.dtype)
-        self.assertEqual(fill_const_1.shape, a.shape)
-        self.assertEqual(fill_const_1.dtype, a.dtype)
-
-        neg_1 = neg(x=b)
-        self.assertEqual(neg_1.shape, b.shape)
-        self.assertEqual(neg_1.dtype, b.dtype)
-
-        set_value_1 = set_value(d,
-                                a,
-                                axis=[1],
-                                starts=[1],
-                                ends=[3],
-                                strides=[1],
-                                out=d)
-        self.assertEqual(set_value_1.shape, d.shape)
-        self.assertEqual(set_value_1.dtype, d.dtype)
+    @classmethod
+    def tearDownClass(cls):
+        paddle.disable_static()
+
+    def test_prim_ops(self):
+        program = paddle.static.Program()
+        with paddle.static.program_guard(program):
+            args = self._as_tuple(self.args)
+            args = self.arr2var(args)
+            results = self.op(*args, **self.kwargs)
+            results = self._as_tuple(results)
+            expected_shape = self._as_tuple(self.expected_shape)
+            expected_dtype = self._as_tuple(self.expected_dtype)
+
+            for r, shape, dtype in zip(results, expected_shape, expected_dtype):
+                self.assertEqual(r.shape, shape)
+                self.assertEqual(str(r.dtype).split('.')[1], dtype)
+
+    def arr2var(self, arr):
+        """convert numpy ndarray to paddle Variable recursively."""
+        return [
+            paddle.static.data(f'x{uuid.uuid4()}', v.shape, v.dtype)
+            if isinstance(v, np.ndarray) else self.arr2var(v) for v in arr
+        ]
+
+    def _as_tuple(self, input):
+        if isinstance(input, (tuple, list)) and len(input) == 0:
+            return input
+        if not isinstance(input, (tuple, list)) or all(
+                isinstance(i, int) for i in input):
+            return (input, )
+        return input
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py
index b9a3ac45996..f2313dfc2e8 100644
--- a/python/paddle/incubate/autograd/primops.py
+++ b/python/paddle/incubate/autograd/primops.py
@@ -122,6 +122,21 @@ def tanh(x, out=None):
     return _simple_unop(LayerHelper('tanh_p', **locals()))
 
 
+@REGISTER_FN('sin_p', 'X', 'Y')
+def sin(x, out=None):
+    return _simple_unop(LayerHelper('sin_p', **locals()))
+
+
+@REGISTER_FN('cos_p', 'X', 'Y')
+def cos(x, out=None):
+    return _simple_unop(LayerHelper('cos_p', **locals()))
+
+
+@REGISTER_FN('exp_p', 'X', 'Y')
+def exp(x, out=None):
+    return _simple_unop(LayerHelper('exp_p', **locals()))
+
+
 @REGISTER_FN('reshape_p', 'X', 'Y')
 def reshape(x, shape, out=None):
     return _manipulation_unop(LayerHelper('reshape_p', **locals()))
diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py
index 24e48e8c542..56dffe932fa 100644
--- a/python/paddle/incubate/autograd/primrules.py
+++ b/python/paddle/incubate/autograd/primrules.py
@@ -15,13 +15,15 @@ import typing
 
 import paddle
 
-from .primreg import REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_JVP, REGISTER_TRANSPOSE
-from .primreg import (lookup_fn, lookup_orig2prim, lookup_prim2orig, lookup_jvp,
-                      lookup_transpose, op_position_inputs, op_position_output)
-from .primops import (neg, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
-                      transpose, split, concat, reduce, matmul, slice_select,
-                      slice_assign, gather, scatter_add, fill_const, set_value)
-from .utils import get_input_var_list, get_output_var_list, INT_DTYPE_2_STRING
+from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
+                      matmul, mul, neg, reduce, reshape, scatter_add, set_value,
+                      sin, slice_assign, slice_select, split, sqrt, sub, tanh,
+                      transpose)
+from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
+                      REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
+                      lookup_orig2prim, lookup_prim2orig, lookup_transpose,
+                      op_position_inputs, op_position_output)
+from .utils import INT_DTYPE_2_STRING, get_input_var_list, get_output_var_list
 
 
 def _orig2prim(op, *args):
@@ -149,6 +151,21 @@ def tanh_orig2prim(op, x):
     return tanh(x)
 
 
+@REGISTER_ORIG2PRIM('sin')
+def sin_orig2prim(op, x):
+    return sin(x)
+
+
+@REGISTER_ORIG2PRIM('cos')
+def cos_orig2prim(op, x):
+    return cos(x)
+
+
+@REGISTER_ORIG2PRIM('exp')
+def exp_orig2prim(op, x):
+    return exp(x)
+
+
 @REGISTER_ORIG2PRIM('fill_zeros_like')
 def fill_zeros_like_orig2prim(op, x):
     return fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
@@ -301,6 +318,21 @@ def tanh_prim2orig(op, x):
     return paddle.tanh(x)
 
 
+@REGISTER_PRIM2ORIG('sin_p')
+def sin_prim2orig(op, x):
+    return paddle.sin(x)
+
+
+@REGISTER_PRIM2ORIG('cos_p')
+def cos_prim2orig(op, x):
+    return paddle.cos(x)
+
+
+@REGISTER_PRIM2ORIG('exp_p')
+def exp_prim2orig(op, x):
+    return paddle.exp(x)
+
+
 @REGISTER_PRIM2ORIG('reshape_p')
 def reshape_prim2orig(op, x):
     return paddle.reshape(x, shape=op.attr('shape'))
@@ -453,6 +485,30 @@ def tanh_jvp(op, x_dot):
     return y_dot
 
 
+@REGISTER_JVP('sin_p')
+def sin_jvp(op, x_dot):
+    if x_dot is None:
+        return None
+    x, = op_position_inputs(op)
+    return mul(x_dot, cos(x))
+
+
+@REGISTER_JVP('cos_p')
+def cos_jvp(op, x_dot):
+    if x_dot is None:
+        return None
+    x, = op_position_inputs(op)
+    return mul(x_dot, neg(sin(x)))
+
+
+@REGISTER_JVP('exp_p')
+def exp_jvp(op, x_dot):
+    if x_dot is None:
+        return None
+    y = op_position_output(op)
+    return mul(x_dot, y)
+
+
 @REGISTER_JVP('reshape_p')
 def reshape_jvp(op, x_dot):
     if x_dot is None:
-- 
GitLab
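
The JVP rules added above encode the usual derivative identities: d(sin x) = cos(x)·dx, d(cos x) = -sin(x)·dx, and d(exp x) = exp(x)·dx, with exp_jvp reusing the primitive's own output y instead of recomputing exp(x). Below is a minimal sketch of exercising the new lowering and JVP rules end to end, in the style of the tests added in test_primapi.py. It assumes a Paddle build that already contains this patch; the shapes, tolerances, and program names are illustrative only, not part of the patch.

    # Minimal sketch: differentiate paddle.sin through the primitive pipeline.
    # Assumes a Paddle build that includes this patch; values below are illustrative.
    import numpy as np
    import paddle

    paddle.enable_static()
    paddle.incubate.autograd.enable_prim()

    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data('x', shape=[100, 200], dtype='float32')
        x.stop_gradient = False
        y = paddle.sin(x)  # lowered to sin_p by sin_orig2prim
        # reverse-mode grad linearizes with sin_jvp: dy = cos(x) * dx
        x_grad = paddle.incubate.autograd.grad(y, x)
        # lower remaining primitives back to original ops so the program can run
        paddle.incubate.autograd.prim2orig(main.block(0))

    exe = paddle.static.Executor()
    exe.run(startup)
    x_np = np.random.rand(100, 200).astype('float32')
    grad_np = exe.run(main, feed={'x': x_np}, fetch_list=[x_grad])[0]
    # gradient of sum(sin(x)) w.r.t. x is cos(x)
    np.testing.assert_allclose(grad_np, np.cos(x_np), rtol=1e-5, atol=1e-5)

    paddle.incubate.autograd.disable_prim()
    paddle.disable_static()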