From 463fc15e8d7ceb0a80d95ecd58e71d59ae28c804 Mon Sep 17 00:00:00 2001
From: Sing_chan <51314274+betterpig@users.noreply.github.com>
Date: Mon, 8 Aug 2022 15:07:51 +0800
Subject: [PATCH] [autograd] add log_p primitive operator for new autograd
 (#44779)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add log_p for auto_grad

* add log_p_op.cc in prim_op_test srcs

* fix bug of wrong op name; add test in test_primops

* add test case of log in testprimapi

* fix bug of test_without_guard

* no need to fix test_without_guard
---
 .../fluid/operators/prim_ops/CMakeLists.txt   |  3 +-
 paddle/fluid/operators/prim_ops/log_p_op.cc   | 75 +++++++++++++++++++
 .../fluid/operators/prim_ops/prim_op_test.cc  | 20 +++++
 .../autograd/test_jvp_and_transpose.py        | 29 +++++++
 .../unittests/autograd/test_orig2prim.py      | 20 +++++
 .../unittests/autograd/test_prim2orig.py      | 20 +++++
 .../tests/unittests/autograd/test_primapi.py  |  2 +
 .../tests/unittests/autograd/test_primops.py  |  1 +
 python/paddle/incubate/autograd/primops.py    |  5 ++
 python/paddle/incubate/autograd/primrules.py  | 20 ++++-
 10 files changed, 193 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/operators/prim_ops/log_p_op.cc

diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt
index d29933bc196..2583d8cfd9c 100644
--- a/paddle/fluid/operators/prim_ops/CMakeLists.txt
+++ b/paddle/fluid/operators/prim_ops/CMakeLists.txt
@@ -23,7 +23,8 @@ set(PRIM_OP_SRCS
   sqrt_p_op.cc
   tanh_p_op.cc
   matmul_p_op.cc
-  fill_constant_p_op.cc)
+  fill_constant_p_op.cc
+  log_p_op.cc)
 
 cc_test(
   prim_op_test
diff --git a/paddle/fluid/operators/prim_ops/log_p_op.cc b/paddle/fluid/operators/prim_ops/log_p_op.cc
new file mode 100644
index 00000000000..199ef0bad36
--- /dev/null
+++ b/paddle/fluid/operators/prim_ops/log_p_op.cc
@@ -0,0 +1,75 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+class LogPrimOp : public framework::OperatorBase {
+ public:
+  LogPrimOp(const std::string &type,
+            const framework::VariableNameMap &inputs,
+            const framework::VariableNameMap &outputs,
+            const framework::AttributeMap &attrs)
+      : framework::OperatorBase(type, inputs, outputs, attrs) {}
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Prim operator log_p should not be executed directly"));
+  }
+};
+
+class LogPrimOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of log_p op.");
+    AddOutput("Y", "(Tensor), The output tensor of log_p op.");
+    AddComment(R"DOC(
+Autograd primitive log_p operator.
+)DOC"); + } +}; + +class LogPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + + PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); + } +}; + +class LogPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(log_p, + paddle::operators::LogPrimOp, + paddle::operators::LogPrimOpMaker, + paddle::operators::LogPrimOpShapeInference, + paddle::operators::LogPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/prim_op_test.cc b/paddle/fluid/operators/prim_ops/prim_op_test.cc index df5de4e1ab4..5fb7ae82308 100644 --- a/paddle/fluid/operators/prim_ops/prim_op_test.cc +++ b/paddle/fluid/operators/prim_ops/prim_op_test.cc @@ -34,6 +34,7 @@ USE_OP_ITSELF(sqrt_p); USE_OP_ITSELF(tanh_p); USE_OP_ITSELF(matmul_p); USE_OP_ITSELF(fill_constant_p); +USE_OP_ITSELF(log_p); namespace paddle { namespace framework { @@ -595,5 +596,24 @@ TEST(PrimOp, fill_constant_p) { ASSERT_EQ(shapes[2], 5L); } +TEST(PrimOp, log_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + + NewVar(block, x0, shape); + AppendOp(block, "log_p", {{"X", {x0}}}, {{"Y", {x1}}}, {}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + } // namespace framework } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py index 9c5df9148c6..718ea255bb2 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py @@ -364,6 +364,35 @@ class TestExpPJVPAndTranspose(TestAddPJVPAndTranspose): ] +class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'log_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + self.prim_input = { + 'X': X, + } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + self.all_ops = [ + # prim op: + 'log_p', + # jvp op: + 'div_p', + # transpose op: + ] + + class TestReshapePJVPAndTranspose(TestAddPJVPAndTranspose): def init_data(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index 7557d2ba668..7745d1d59b3 100644 --- 
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
index 7557d2ba668..7745d1d59b3 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py
@@ -208,6 +208,26 @@ class TestExpOrig2Prim(TestElementWiseAddOrig2Prim):
         self.out_map = {0: self.output['Out']}
 
 
+class TestLogOrig2Prim(TestElementWiseAddOrig2Prim):
+
+    def init_data(self):
+        self.op_type = 'log'
+        X = paddle.static.data(name='X', shape=[3, 4], dtype='float')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Out':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.orig2prim_args = (X, )
+        self.all_ops = ['log', 'log_p']
+        self.out_map = {0: self.output['Out']}
+
+
 class TestReshape2Orig2Prim(TestElementWiseAddOrig2Prim):
 
     def init_data(self):
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
index 42c8cce0a8f..9ab5c563a51 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py
@@ -224,6 +224,26 @@ class TestExpPPrim2Orig(TestAddPPrim2Orig):
         self.out_map = {self.output['Y']: 0}
 
 
+class TestLogPPrim2Orig(TestAddPPrim2Orig):
+
+    def init_data(self):
+        self.op_type = 'log_p'
+        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
+
+        self.input = {
+            'X': X,
+        }
+        self.output = {
+            'Y':
+            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
+        }
+        self.attrs = {}
+
+        self.prim2orig_args = (X, )
+        self.all_ops = ['log_p', 'log']
+        self.out_map = {self.output['Y']: 0}
+
+
 class TestReshapePPrim2Orig(TestAddPPrim2Orig):
 
     def init_data(self):
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
index 777c16a41e6..d6baf16a5b6 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -146,6 +146,7 @@ class TestWithoutProgramGuard(unittest.TestCase):
     ('input_gradients_not_none', paddle.matmul,
      (np.random.rand(3, 3), np.random.rand(3, 3)),
      (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'),
+    ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
 ))
 class TestForwardGrad(unittest.TestCase):
 
@@ -254,6 +255,7 @@ class TestForwardGrad(unittest.TestCase):
     ('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'),
     ('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'),
     ('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'),
+    ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'),
 ))
 class TestGrad(unittest.TestCase):
 
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primops.py b/python/paddle/fluid/tests/unittests/autograd/test_primops.py
index f95e6304b9a..00a30899a58 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primops.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primops.py
@@ -41,6 +41,7 @@ paddle.enable_static()
     ('sin', primops.sin, randn(2, 3), {}, (2, 3), 'float64'),
     ('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'),
     ('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'),
+    ('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'),
     ('reshape', primops.reshape, randn(2, 3), {
         'shape': (3, 2)
     }, (3, 2), 'float64'),
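
The test_primapi.py entries above exercise paddle.log end to end through the primitive transforms. A hedged usage sketch of that workflow in a static-graph program; the incubate API names used here (enable_prim, grad, prim2orig) are assumed from the surrounding tests rather than shown in this patch:

    import numpy as np
    import paddle

    paddle.enable_static()
    paddle.incubate.autograd.enable_prim()  # assumed switch exercised by the tests

    main, startup = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        x = paddle.static.data('x', shape=[3, 4], dtype='float32')
        x.stop_gradient = False
        y = paddle.log(x)                             # rewritten to log_p by the orig2prim rule
        x_grad = paddle.incubate.autograd.grad(y, x)  # derivative built from the log_p JVP rule
    # Lower the remaining primitive ops (which have no kernels) back to original ops.
    paddle.incubate.autograd.prim2orig(main.block(0))

    exe = paddle.static.Executor()
    exe.run(startup)
    feed = {'x': np.random.rand(3, 4).astype('float32') + 1.0}
    outs = exe.run(main, feed=feed, fetch_list=[x_grad])
    print(outs[0])  # expected to be close to 1 / x
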
diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py
index f2313dfc2e8..c8b8a54df60 100644
--- a/python/paddle/incubate/autograd/primops.py
+++ b/python/paddle/incubate/autograd/primops.py
@@ -137,6 +137,11 @@ def exp(x, out=None):
     return _simple_unop(LayerHelper('exp_p', **locals()))
 
 
+@REGISTER_FN('log_p', 'X', 'Y')
+def log(x, out=None):
+    return _simple_unop(LayerHelper('log_p', **locals()))
+
+
 @REGISTER_FN('reshape_p', 'X', 'Y')
 def reshape(x, shape, out=None):
     return _manipulation_unop(LayerHelper('reshape_p', **locals()))
diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py
index 56dffe932fa..f6f32c32375 100644
--- a/python/paddle/incubate/autograd/primrules.py
+++ b/python/paddle/incubate/autograd/primrules.py
@@ -18,7 +18,7 @@ import paddle
 from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather,
                       matmul, mul, neg, reduce, reshape, scatter_add, set_value,
                       sin, slice_assign, slice_select, split, sqrt, sub, tanh,
-                      transpose)
+                      transpose, log)
 from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG,
                       REGISTER_TRANSPOSE, lookup_fn, lookup_jvp,
                       lookup_orig2prim, lookup_prim2orig, lookup_transpose,
@@ -166,6 +166,11 @@ def exp_orig2prim(op, x):
     return exp(x)
 
 
+@REGISTER_ORIG2PRIM('log')
+def log_orig2prim(op, x):
+    return log(x)
+
+
 @REGISTER_ORIG2PRIM('fill_zeros_like')
 def fill_zeros_like_orig2prim(op, x):
     return fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
@@ -333,6 +338,11 @@ def exp_prim2orig(op, x):
     return paddle.exp(x)
 
 
+@REGISTER_PRIM2ORIG('log_p')
+def log_prim2orig(op, x):
+    return paddle.log(x)
+
+
 @REGISTER_PRIM2ORIG('reshape_p')
 def reshape_prim2orig(op, x):
     return paddle.reshape(x, shape=op.attr('shape'))
@@ -509,6 +519,14 @@ def exp_jvp(op, x_dot):
     return mul(x_dot, y)
 
 
+@REGISTER_JVP('log_p')
+def log_jvp(op, x_dot):
+    if x_dot is None:
+        return None
+    x, = op_position_inputs(op)
+    return div(x_dot, x)
+
+
 @REGISTER_JVP('reshape_p')
 def reshape_jvp(op, x_dot):
     if x_dot is None:
-- 
GitLab
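
Taken together, the patch hooks log into the primitive system at four points: primops.log builds a log_p op, the orig2prim rule rewrites paddle.log into log_p, the JVP rule differentiates it with div_p, and the prim2orig rule lowers it back to paddle.log for execution. A small hedged sketch (illustrative, not from the patch) that only builds a log_p op in a static program; this is safe because shape and variable-type inference run at graph-construction time, while the C++ RunImpl above deliberately throws if the op is ever executed:

    import paddle
    from paddle.incubate.autograd import primops

    paddle.enable_static()

    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data('x', shape=[2, 3], dtype='float64')
        y = primops.log(x)  # appends a log_p op via LayerHelper('log_p', ...)
        # Shape inference copies X's shape to Y; var-type inference copies the dtype.
        print(y.shape, y.dtype)  # expected: (2, 3) and the same dtype as x
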