From b681c88c61535e06505e604acd75b539672c1355 Mon Sep 17 00:00:00 2001 From: Sing_chan <51314274+betterpig@users.noreply.github.com> Date: Tue, 16 Aug 2022 10:35:36 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90autograd=E3=80=91add=20select=5Fp?= =?UTF-8?q?=E3=80=81eq=5Fp=E3=80=81pow=5Fp=20primitive=20operator=20for=20?= =?UTF-8?q?new=20autograd=20(#44813)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add select_p * fix bugs * add custom test for select_p; modify select_p primrules * modify according to xiaoxu's comment * add eq_p, select_p, pow_p, use autograd to test grad of high order * add requirement of autograd, modify expected type of eq * modify according to xiaoxu's comment * import primops to use primops.pow --- .../fluid/operators/prim_ops/CMakeLists.txt | 5 +- paddle/fluid/operators/prim_ops/eq_p_op.cc | 119 +++++++++ paddle/fluid/operators/prim_ops/pow_p_op.cc | 114 +++++++++ .../fluid/operators/prim_ops/prim_op_test.cc | 72 ++++++ .../fluid/operators/prim_ops/select_p_op.cc | 153 ++++++++++++ .../tests/unittests/autograd/CMakeLists.txt | 1 + .../autograd/test_jvp_and_transpose.py | 115 +++++++++ .../unittests/autograd/test_orig2prim.py | 58 +++++ .../unittests/autograd/test_prim2orig.py | 57 +++++ .../tests/unittests/autograd/test_primapi.py | 225 +++++++++--------- .../tests/unittests/autograd/test_primops.py | 4 + python/paddle/incubate/autograd/primops.py | 45 +++- python/paddle/incubate/autograd/primrules.py | 112 ++++++++- python/unittest_py/requirements.txt | 1 + 14 files changed, 960 insertions(+), 121 deletions(-) create mode 100644 paddle/fluid/operators/prim_ops/eq_p_op.cc create mode 100644 paddle/fluid/operators/prim_ops/pow_p_op.cc create mode 100644 paddle/fluid/operators/prim_ops/select_p_op.cc diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt index 2583d8cfd9..5bbf4fbc61 100644 --- a/paddle/fluid/operators/prim_ops/CMakeLists.txt +++ b/paddle/fluid/operators/prim_ops/CMakeLists.txt @@ -24,7 +24,10 @@ set(PRIM_OP_SRCS tanh_p_op.cc matmul_p_op.cc fill_constant_p_op.cc - log_p_op.cc) + log_p_op.cc + select_p_op.cc + eq_p_op.cc + pow_p_op.cc) cc_test( prim_op_test diff --git a/paddle/fluid/operators/prim_ops/eq_p_op.cc b/paddle/fluid/operators/prim_ops/eq_p_op.cc new file mode 100644 index 0000000000..a22ff0d3b8 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/eq_p_op.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { +class EqPrimOp : public framework::OperatorBase { + public: + EqPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator eq_p should not be excuted directly")); + } +}; + +class EqPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of eq_p op."); + AddInput("Y", "(Tensor), The input tensor of eq_p op."); + AddOutput("Z", "(Tensor), The output tensor of eq_p op."); + AddComment(R"DOC( +Autograd primitive eq_p operator. +)DOC"); + } +}; + +class EqPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, + y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, + y_rank)); + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ( + x_shape[i], + y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, + x_shape[i], + y_shape[i])); + } + + PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class EqPrimOpVarTypeInference : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, + y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, + y_type)); + PADDLE_ENFORCE_EQ(x_dtype, + y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, + y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, framework::proto::VarType::BOOL); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(eq_p, + paddle::operators::EqPrimOp, + paddle::operators::EqPrimOpMaker, + paddle::operators::EqPrimOpShapeInference, + paddle::operators::EqPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/pow_p_op.cc b/paddle/fluid/operators/prim_ops/pow_p_op.cc new file mode 100644 index 0000000000..c9b5cf4331 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/pow_p_op.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { +class PowPrimOp : public framework::OperatorBase { + public: + PowPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator pow_p should not be excuted directly")); + } +}; + +class PowPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The base of pow_p op."); + AddInput("Y", "(Tensor), The exponents of pow_p op."); + AddOutput("Z", "(Tensor), The output tensor of pow_p op."); + AddComment(R"DOC( +Autograd primitive pow_p operator. +)DOC"); + } +}; + +class PowPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + + PADDLE_ENFORCE_EQ(x_rank, + y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, + y_rank)); + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ( + x_shape[i], + y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, + x_shape[i], + y_shape[i])); + } + + PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class PowPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + + PADDLE_ENFORCE_EQ(x_type, + y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, + y_type)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, x_dtype); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(pow_p, + paddle::operators::PowPrimOp, + paddle::operators::PowPrimOpMaker, + paddle::operators::PowPrimOpShapeInference, + paddle::operators::PowPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/prim_op_test.cc b/paddle/fluid/operators/prim_ops/prim_op_test.cc index 5fb7ae8230..7f2f07cf1a 100644 --- a/paddle/fluid/operators/prim_ops/prim_op_test.cc +++ b/paddle/fluid/operators/prim_ops/prim_op_test.cc @@ -35,6 +35,9 @@ USE_OP_ITSELF(tanh_p); USE_OP_ITSELF(matmul_p); USE_OP_ITSELF(fill_constant_p); USE_OP_ITSELF(log_p); +USE_OP_ITSELF(select_p); +USE_OP_ITSELF(eq_p); +USE_OP_ITSELF(pow_p); namespace paddle { namespace framework { @@ -615,5 +618,74 @@ TEST(PrimOp, log_p) { ASSERT_EQ(shapes[2], 5L); } +TEST(PrimOp, select_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{2, 3}; + + std::string cond = "cond"; + std::string x = "x"; + std::string y = "y"; + std::string z = "z"; + + NewVar(block, cond, shape); + NewVar(block, x, shape); + NewVar(block, y, shape); + + AppendOp(block, + "select_p", + {{"Condition", {cond}}, {"X", {x}}, {"Y", {y}}}, + {{"Z", {z}}}, + {}); + ASSERT_EQ(block->Var("z")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("z")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("z")->GetShape(); + ASSERT_EQ(shapes.size(), 2UL); + ASSERT_EQ(shapes[0], 2L); + ASSERT_EQ(shapes[1], 3L); +} + +TEST(PrimOp, eq_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x = "x"; + std::string y = "y"; + std::string z = "z"; + + NewVar(block, x, shape); + NewVar(block, y, shape); + AppendOp(block, "eq_p", {{"X", {x}}, {"Y", {y}}}, {{"Z", {z}}}, {}); + ASSERT_EQ(block->Var("z")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("z")->GetDataType(), proto::VarType::BOOL); + auto shapes = block->Var("z")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, pow_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x = "x"; + std::string y = "y"; + std::string z = "z"; + + NewVar(block, x, shape); + NewVar(block, y, shape); + AppendOp(block, "pow_p", {{"X", {x}}, {"Y", {y}}}, {{"Z", {z}}}, {}); + ASSERT_EQ(block->Var("z")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("z")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("z")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/prim_ops/select_p_op.cc b/paddle/fluid/operators/prim_ops/select_p_op.cc new file mode 100644 index 0000000000..8c4c3c2f18 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/select_p_op.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { +class SelectPrimOp : public framework::OperatorBase { + public: + SelectPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator select_p should not be excuted directly")); + } +}; + +class SelectPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Condition", "(Tensor), The input condition of select_p op."); + AddInput("X", "(Tensor), The input tensor of select_p op."); + AddInput("Y", "(Tensor), The input tensor of select_p op."); + AddOutput("Z", "(Tensor), The output tensor of select_p op."); + AddComment(R"DOC( +Autograd primitive select_p operator. +)DOC"); + } +}; + +class SelectPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr condition_var_ptr = + ctx->GetInputVarPtrs("Condition")[0]; + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *condition_var = + PADDLE_GET(framework::VarDesc *, condition_var_ptr); + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); + + auto condition_shape = condition_var->GetShape(); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + + size_t condition_rank = condition_shape.size(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + + PADDLE_ENFORCE_EQ( + condition_rank, + x_rank, + platform::errors::InvalidArgument( + "The dimensions of condtion and Inputs(X) should be same, " + "but get %d and %d", + condition_rank, + x_rank)); + PADDLE_ENFORCE_EQ( + x_rank, + y_rank, + platform::errors::InvalidArgument( + "The dimensions of Inputs(X) and Inputs(Y) should be same, " + "but get %d and %d", + x_rank, + y_rank)); + for (size_t i = 0; i < condition_rank; ++i) { + PADDLE_ENFORCE_EQ(condition_shape[i], + x_shape[i], + platform::errors::InvalidArgument( + "The shape of condition and Inputs(X) at dimension " + "%d should be same, " + "but get %d and %d", + i, + condition_shape[i], + x_shape[i])); + } + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ(x_shape[i], + y_shape[i], + platform::errors::InvalidArgument( + "The shape of Inputs(X) and Inputs(Y) at dimension " + "%d should be same, " + "but get %d and %d", + i, + x_shape[i], + y_shape[i])); + } + + PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(condition_shape); + } +}; + +class SelectPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + + PADDLE_ENFORCE_EQ(x_type, + y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, + y_type)); + PADDLE_ENFORCE_EQ(x_dtype, + y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, + y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, x_dtype); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(select_p, + paddle::operators::SelectPrimOp, + paddle::operators::SelectPrimOpMaker, + paddle::operators::SelectPrimOpShapeInference, + paddle::operators::SelectPrimOpVarTypeInference); diff --git a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt index 45c0a08efe..f1af779f4f 100644 --- a/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/autograd/CMakeLists.txt @@ -20,4 +20,5 @@ set_tests_properties(test_autograd_functional_static PROPERTIES TIMEOUT 160) set_tests_properties(test_minimize PROPERTIES TIMEOUT 60) if(NOT WIN32) set_tests_properties(test_autograd_functional_prim PROPERTIES TIMEOUT 60) + set_tests_properties(test_primapi PROPERTIES TIMEOUT 60) endif() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py index 718ea255bb..c09e5bf864 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py @@ -867,5 +867,120 @@ class TestScatterAddPJVPAndTranspose(TestAddPJVPAndTranspose): ] +class TestSelectPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'select_p' + Cond = paddle.static.data(name='Condition', shape=[9, 5], dtype='bool') + X = paddle.static.data(name='X', shape=[9, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[9, 5], dtype='float64') + + self.prim_input = {'Condition': Cond, 'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + Cond_DOT = paddle.static.data(name='Cond_DOT', + shape=[9, 5], + dtype='float64') + X_DOT = paddle.static.data(name='X_DOT', shape=[9, 5], dtype='float64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[9, 5], dtype='float64') + self.jvp_args = (Cond_DOT, X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + # Set transpose + check_dot = lambda v: True + Z_BAR = paddle.static.data(name='Z_BAR', shape=[9, 5], dtype='float64') + self.transpose_args = (check_dot, Z_BAR) + self.transpose_out_shape_map = {0: X, 1: Y} + + self.all_ops = [ + # prim op: + 'select_p', + # jvp op: + 'select_p', + # transpose op: + 'fill_constant_p', + 'fill_constant_p', + 'fill_constant_p', + 'select_p', + 'select_p', + ] + + +class TestEqPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'eq_p' + X = paddle.static.data(name='X', shape=[4, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64') + + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + self.all_ops = [ + # prim op: + 'eq_p', + # jvp op: + 'fill_constant_p', + # transpose op: + ] + + +class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'pow_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='float32') + Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32') + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='float32') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[5, 6], dtype='float32') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + self.all_ops = [ + # prim op: + 'pow_p', + # jvp op: + 'fill_constant_p', + 'fill_constant_p', + 'eq_p', + 'select_p', + 'sub_p', + 'mul_p', + 'mul_p', + 'pow_p', + 'mul_p', + 'mul_p', + 'log_p', + 'add_p' + # transpose op: + ] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index 7745d1d59b..87f9490341 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -481,5 +481,63 @@ class TestAssignOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestWhereOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'where' + Cond = paddle.static.data(name='Condition', shape=[5, 6], dtype='bool') + X = paddle.static.data(name='X', shape=[5, 6], dtype='float32') + Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32') + + self.input = {'Condition': Cond, 'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + self.orig2prim_args = (Cond, X, Y) + self.all_ops = ['where', 'select_p'] + self.out_map = {0: self.output['Out']} + + +class TestEqualOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'equal' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype='bool') + } + self.attrs = {} + self.orig2prim_args = (X, Y) + self.all_ops = ['equal', 'eq_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestPowOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'elementwise_pow' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, Y) + self.all_ops = ['elementwise_pow', 'pow_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py index 9ab5c563a5..8d23ddad1d 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py @@ -497,5 +497,62 @@ class TestFillConstantPPrim2Orig(TestAddPPrim2Orig): self.out_map = {self.output['Y']: 0} +class TestSelectPPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'select_p' + Cond = paddle.static.data(name='Condition', shape=[5, 6], dtype='bool') + X = paddle.static.data(name='X', shape=[5, 6], dtype='float32') + Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32') + + self.input = {'Condition': Cond, 'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + self.prim2orig_args = (Cond, X, Y) + self.all_ops = ['select_p', 'where'] + self.out_map = {self.output['Z']: 0} + + +class TestEqPPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'eq_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype='bool') + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['eq_p', 'equal'] + self.out_map = {self.output['Z']: 0} + + +class TestPowPPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'pow_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['pow_p', 'elementwise_pow'] + self.out_map = {self.output['Z']: 0} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py index d6baf16a5b..04610ce2c7 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py @@ -16,6 +16,9 @@ import typing import unittest import numpy as np +import autograd +import autograd.numpy as np_autograd + import paddle import config @@ -148,6 +151,8 @@ class TestWithoutProgramGuard(unittest.TestCase): (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'), ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'), )) +# paddle.where, paddle.pow has no double grad definition, +# can not compute forward grad use double trick class TestForwardGrad(unittest.TestCase): @classmethod @@ -239,24 +244,36 @@ class TestForwardGrad(unittest.TestCase): paddle.incubate.autograd.disable_prim() +where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y) + + @utils.place(config.DEVICES) -@utils.parameterize((utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), ( - ('matmul', paddle.matmul, - (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'), - ('multiply', paddle.multiply, - (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'), - ('add', paddle.add, - (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'), - ('input_not_sequence', paddle.tanh, - (np.random.rand(5, 5), ), None, 'float64'), - ('input_gradients_not_none', paddle.matmul, - (np.random.rand(3, 3), np.random.rand(3, 3)), - (np.random.rand(3, 3), ), 'float64'), - ('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'), - ('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'), - ('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'), - ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'), -)) +@utils.parameterize( + (utils.TEST_CASE_NAME, 'fun', 'xs', 'v', 'dtype'), + ( + ('matmul', paddle.matmul, + (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'), + ('multiply', paddle.multiply, + (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'), + ('add', paddle.add, + (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'), + ('input_not_sequence', paddle.tanh, + (np.random.rand(5, 5), ), None, 'float64'), + ('input_gradients_not_none', paddle.matmul, + (np.random.rand(3, 3), np.random.rand(3, 3)), + (np.random.rand(3, 3), ), 'float64'), + ('sin', paddle.sin, (np.random.rand(100, 200), ), None, 'float32'), + ('cos', paddle.cos, (np.random.rand(200, 90), ), None, 'float32'), + ('exp', paddle.exp, (np.random.rand(299, 320), ), None, 'float32'), + # In where op, grad of condition computed by paddle.static.gradients is None, + # and paddle.incubate.autograd.grad will replace None with zeros while transpose + # will just return None because cond_dot is unused, that is a diff. + ('select', where_wrap, + (np.random.rand(3, 4), np.random.rand(3, 4)), None, 'float32'), + # pow_p and pow has diff when compute z_dot of 0^0 + ('pow', paddle.pow, + (np.array([1, 2, 3]), np.array([0, 2, 7])), None, 'float32'), + )) class TestGrad(unittest.TestCase): def setUp(self): @@ -367,6 +384,33 @@ class TestGrad(unittest.TestCase): np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol) +def multiply_pd(x): + x2 = paddle.multiply(x, x) + x3 = paddle.multiply(x2, x2) + x4 = paddle.multiply(x3, x) + return x4 + + +multiply_ag = lambda xs: xs[0] * xs[0] * xs[0] * xs[0] * xs[0] +sin_ag = lambda xs: np_autograd.sin(xs[0]) +cos_ag = lambda xs: np_autograd.cos(xs[0]) +exp_ag = lambda xs: np_autograd.exp(xs[0]) +pow_ag = lambda xs: xs[0]**xs[1] +log_ag = lambda xs: np_autograd.log(xs[0]) + + +@utils.place(config.DEVICES) +@utils.parameterize( + (utils.TEST_CASE_NAME, 'fun_pd', 'fun_ag', 'xs', 'v', 'dtype'), ( + ('multiply', multiply_pd, multiply_ag, + (np.random.rand(3, 5), ), None, 'float32'), + ('sin', paddle.sin, sin_ag, (np.random.rand(2, 3), ), None, 'float32'), + ('cos', paddle.cos, cos_ag, (np.random.rand(3, 4), ), None, 'float32'), + ('exp', paddle.exp, exp_ag, (np.random.rand(2, 3), ), None, 'float32'), + ('pow', paddle.pow, pow_ag, + (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'), + ('log', paddle.log, log_ag, (np.random.rand(3, 8), ), None, 'float32'), + )) class TestGradWithHigherOrder(unittest.TestCase): def setUp(self): @@ -377,105 +421,58 @@ class TestGradWithHigherOrder(unittest.TestCase): paddle.incubate.autograd.disable_prim() paddle.disable_static() - def test_third_order(self): - paddle.incubate.autograd.enable_prim() - main = paddle.static.Program() - startup = paddle.static.Program() - with paddle.static.program_guard(main, startup): - x = paddle.static.data(name='x', shape=[1], dtype='float32') - x2 = paddle.multiply(x, x) - x3 = paddle.multiply(x2, x) - x4 = paddle.multiply(x3, x) - - grad1, = paddle.incubate.autograd.grad([x4], [x]) - grad2, = paddle.incubate.autograd.grad([grad1], [x]) - grad3, = paddle.incubate.autograd.grad([grad2], [x]) - - paddle.incubate.autograd.prim2orig(main.block(0)) - - feed = {x.name: np.array([2.]).astype('float32')} - fetch_list = [grad3.name] - result = [np.array([48.])] - - place = paddle.CPUPlace() - if paddle.device.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - exe.run(startup) - outs = exe.run(main, feed=feed, fetch_list=fetch_list) - np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5) - paddle.incubate.autograd.disable_prim() + @classmethod + def setUpClass(cls): + cls.xs = tuple(x.astype(cls.dtype) for x in cls.xs) + cls._rtol = config.TOLERANCE.get(str( + cls.dtype)).get("first_order_grad").get("rtol") + cls._atol = config.TOLERANCE.get(str( + cls.dtype)).get("first_order_grad").get("atol") - def test_fourth_order(self): - paddle.incubate.autograd.enable_prim() - main = paddle.static.Program() - startup = paddle.static.Program() - with paddle.static.program_guard(main, startup): - x = paddle.static.data(name='x', shape=[1], dtype='float32') - x2 = paddle.multiply(x, x) - x3 = paddle.multiply(x2, x) - x4 = paddle.multiply(x3, x) - x5 = paddle.multiply(x4, x) - out = paddle.sqrt(x5 + x4) - - grad1, = paddle.incubate.autograd.grad([out], [x]) - grad2, = paddle.incubate.autograd.grad([grad1], [x]) - grad3, = paddle.incubate.autograd.grad([grad2], [x]) - grad4, = paddle.incubate.autograd.grad([grad3], [x]) - - paddle.incubate.autograd.prim2orig(main.block(0)) - - feed = { - x.name: np.array([2.]).astype('float32'), - } - fetch_list = [grad4.name] - # (3*(-5*x^2-16*x-16))/(16*(x+1)^3.5) - result = [np.array([-0.27263762711])] - - place = paddle.CPUPlace() - if paddle.device.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - exe.run(startup) - outs = exe.run(main, feed=feed, fetch_list=fetch_list) - np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5) - paddle.incubate.autograd.disable_prim() + def test_grad(self): - def test_fifth_order(self): - paddle.incubate.autograd.enable_prim() - main = paddle.static.Program() - startup = paddle.static.Program() - with paddle.static.program_guard(main, startup): - x = paddle.static.data(name='x', shape=[1], dtype='float32') - x2 = paddle.multiply(x, x) - x3 = paddle.multiply(x2, x) - x4 = paddle.multiply(x3, x) - x5 = paddle.multiply(x4, x) - x6 = paddle.multiply(x5, x) - out = x6 + x5 - - grad1, = paddle.incubate.autograd.grad([out], [x]) - grad2, = paddle.incubate.autograd.grad([grad1], [x]) - grad3, = paddle.incubate.autograd.grad([grad2], [x]) - grad4, = paddle.incubate.autograd.grad([grad3], [x]) - grad5, = paddle.incubate.autograd.grad([grad4], [x]) - - paddle.incubate.autograd.prim2orig() - - feed = { - x.name: np.array([2.]).astype('float32'), - } - fetch_list = [grad5.name] - result = [np.array([1560.0])] - - place = paddle.CPUPlace() - if paddle.device.is_compiled_with_cuda(): - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - exe.run(startup) - outs = exe.run(main, feed=feed, fetch_list=fetch_list) - np.testing.assert_allclose(outs, result, rtol=1e-5, atol=1e-5) - paddle.incubate.autograd.disable_prim() + def expected(): + egrad = autograd.elementwise_grad + grad_3 = egrad(egrad(egrad(self.fun_ag)))(self.xs) + grad_4 = egrad(egrad(egrad(egrad(self.fun_ag))))(self.xs) + grad_5 = egrad(egrad(egrad(egrad(egrad(self.fun_ag)))))(self.xs) + # the output of egrad is tuple + return list(grad_3 + grad_4 + grad_5) + + def actual(): + paddle_grad = paddle.incubate.autograd.grad + paddle.incubate.autograd.enable_prim() + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + feed, static_xs, static_v = utils.gen_static_data_and_feed( + self.xs, self.v, stop_gradient=False) + ys = self.fun_pd(*static_xs) if isinstance( + static_xs, typing.Sequence) else self.fun_pd(static_xs) + + grad1 = paddle_grad(ys, static_xs, static_v) + grad2 = paddle_grad(grad1, static_xs, static_v) + grad3 = paddle_grad(grad2, static_xs, static_v) + grad4 = paddle_grad(grad3, static_xs, static_v) + grad5 = paddle_grad(grad4, static_xs, static_v) + paddle.incubate.autograd.prim2orig() + + fetch_list = [grad3, grad4, grad5] + + place = paddle.CPUPlace() + if paddle.device.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + exe.run(startup) + outs = exe.run(main, feed=feed, fetch_list=fetch_list) + paddle.incubate.autograd.disable_prim() + return outs + + actual = actual() + expected = expected() + self.assertEqual(type(actual), type(expected)) + for i, j in zip(actual, expected): + np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primops.py b/python/paddle/fluid/tests/unittests/autograd/test_primops.py index 00a30899a5..79e9326a8c 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primops.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primops.py @@ -95,6 +95,10 @@ paddle.enable_static() 'dtype': paddle.float32 }, (3, 2), 'float32'), ('neg', primops.neg, randn(2, 3), {}, (2, 3), 'float64'), + ('select', primops.select, + (randn(2, 3) > 0, randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), + ('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), + ('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), )) class TestPrimops(unittest.TestCase): diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py index c8b8a54df6..bd48a86fe4 100644 --- a/python/paddle/incubate/autograd/primops.py +++ b/python/paddle/incubate/autograd/primops.py @@ -137,11 +137,6 @@ def exp(x, out=None): return _simple_unop(LayerHelper('exp_p', **locals())) -@REGISTER_FN('log_p', 'X', 'Y') -def log(x, out=None): - return _simple_unop(LayerHelper('log_p', **locals())) - - @REGISTER_FN('reshape_p', 'X', 'Y') def reshape(x, shape, out=None): return _manipulation_unop(LayerHelper('reshape_p', **locals())) @@ -315,3 +310,43 @@ def scatter_add(x, y, indextensor, axis, out=None): outputs={'Z': out}, attrs=attrs) return out + + +@REGISTER_FN('log_p', 'X', 'Y') +def log(x, out=None): + return _simple_unop(LayerHelper('log_p', **locals())) + + +@REGISTER_FN('select_p', 'Condition', 'X', 'Y', 'Z') +def select(cond, x, y, out=None): + if len(cond.shape) != len(x.shape): + raise ValueError( + "len(cond.shape) should be equal to len(x.shape), but len(cond.shape)={} and len(x.shape)={}." + .format(len(cond.shape), len(x.shape))) + + if len(x.shape) != len(y.shape): + raise ValueError( + "len(x.shape) should be equal to len(y.shape), but len(x.shape)={} and len(y.shape)={}." + .format(len(x.shape), len(y.shape))) + + helper = LayerHelper('select_p', **locals()) + if out is None: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op(type=helper.layer_type, + inputs={ + 'Condition': cond, + 'X': x, + 'Y': y + }, + outputs={'Z': out}) + return out + + +@REGISTER_FN('eq_p', 'X', 'Y', 'Z') +def eq(x, y, out=None): + return _simple_binop(LayerHelper('eq_p', **locals())) + + +@REGISTER_FN('pow_p', 'X', 'Y', 'Z') +def pow(x, y, out=None): + return _simple_binop(LayerHelper('pow_p', **locals())) diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index f6f32c3237..3795bffae0 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -15,10 +15,11 @@ import typing import paddle +from . import primops from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather, matmul, mul, neg, reduce, reshape, scatter_add, set_value, sin, slice_assign, slice_select, split, sqrt, sub, tanh, - transpose, log) + transpose, log, select, eq) from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_TRANSPOSE, lookup_fn, lookup_jvp, lookup_orig2prim, lookup_prim2orig, lookup_transpose, @@ -66,6 +67,10 @@ index_select scale assign sqrt +log +select +equal +elementwise_pow These original ops are partially supported: @@ -290,6 +295,28 @@ def p_norm_orig2prim(op, x): raise RuntimeError('Only support lower l2/l1 norm currently') +# TODO: support broadcast +@REGISTER_ORIG2PRIM('where') +def select_orig2prim(op, condition, x, y): + return select(condition, x, y) + + +@REGISTER_ORIG2PRIM('equal') +def equal_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + return eq(x, y) + + +@REGISTER_ORIG2PRIM('elementwise_pow') +def elementwise_pow_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + + z = primops.pow(x, y) + return z + + ## Register prim2orig lower rules @@ -424,6 +451,21 @@ def fill_constant_prim2orig(op): dtype=INT_DTYPE_2_STRING[op.attr('dtype')]) +@REGISTER_PRIM2ORIG('select_p') +def select_prim2orig(op, condition, x, y): + return paddle.where(condition, x, y) + + +@REGISTER_PRIM2ORIG('eq_p') +def eq_prim2orig(op, x, y): + return paddle.equal(x, y) + + +@REGISTER_PRIM2ORIG('pow_p') +def pow_prim2orig(op, x, y): + return paddle.pow(x, y) + + ## Register linearize rules @REGISTER_JVP('add_p') def add_jvp(op, x_dot, y_dot): @@ -646,6 +688,55 @@ def scatter_add_jvp(op, x_dot, y_dot): return linear_jvp(op, x_dot, y_dot, indextensor, axis=axis) +@REGISTER_JVP('select_p') +def select_jvp(op, cond_dot, x_dot, y_dot): + if x_dot is None and y_dot is None: + return None + + cond, x, y = op_position_inputs(op) + if x_dot is None: + x_dot = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) + if y_dot is None: + y_dot = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) + return select(cond, x_dot, y_dot) + + +@REGISTER_JVP('eq_p') +def eq_jvp(op, x_dot, y_dot): + if x_dot is None and y_dot is None: + return None + x, _ = op_position_inputs(op) + z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype) + return z_dot + + +@REGISTER_JVP('pow_p') +def pow_jvp(op, x_dot, y_dot): + + def _compute_t1(x, y): + zero_y = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) + one_y = fill_const(value=1.0, shape=y.shape, dtype=y.dtype) + + cond = eq(y, zero_y) + new_y = select(cond, one_y, sub(y, one_y)) + t1 = mul(x_dot, mul(y, primops.pow(x, new_y))) + return t1 + + if x_dot is None and y_dot is None: + return None + x, y = op_position_inputs(op) + z = op_position_output(op) + + if y_dot is None: + return _compute_t1(x, y) + elif x_dot is None: + return mul(y_dot, mul(log(x), z)) + else: + t1, t2 = _compute_t1(x, y), mul(y_dot, mul(log(x), z)) + z_dot = add(t1, t2) + return z_dot + + ## Register transpose rules @@ -831,3 +922,22 @@ def scatter_add_transpose(op, check_dot, z_bar): y_bar = gather(z_bar, indextensor, axis=axis) indextensor_bar = None return x_bar, y_bar, indextensor_bar + + +@REGISTER_TRANSPOSE('select_p') +def select_transpose(op, check_dot, z_bar): + cond, x, y = op_position_inputs(op) + assert check_dot(cond) or check_dot(x) or check_dot(y), ( + f'check_dot(cond) ^ (check_dot(x) ^ check_dot(y)) must be True, ' + f'but check_dot(cond)={check_dot(cond)}, check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.' + ) + + zeros_x = fill_const(value=0.0, shape=x.shape, dtype=x.dtype) + zeros_y = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) + + cond_bar = fill_const(value=0.0, shape=y.shape, + dtype=cond.dtype) if check_dot(cond) else None + x_bar = select(cond, z_bar, zeros_x) if check_dot(x) else None + y_bar = select(cond, zeros_y, z_bar) if check_dot(y) else None + + return cond_bar, x_bar, y_bar diff --git a/python/unittest_py/requirements.txt b/python/unittest_py/requirements.txt index ea82c46b95..f70037e716 100644 --- a/python/unittest_py/requirements.txt +++ b/python/unittest_py/requirements.txt @@ -14,3 +14,4 @@ scipy>=1.5; python_version == "3.6" prettytable distro numpy>=1.20,<1.22; python_version >= "3.7" +autograd==1.4 -- GitLab