diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt index 1f63d5d1721b5c2fe38586879d561b8c5db105ca..30e162a4dd2a9671d42cf9760a84f1af649220fe 100644 --- a/paddle/fluid/operators/prim_ops/CMakeLists.txt +++ b/paddle/fluid/operators/prim_ops/CMakeLists.txt @@ -8,7 +8,7 @@ register_operators() set(PRIM_OP_SRCS reshape_p_op.cc broadcast_p_op.cc - reduce_p_op.cc + reduce_sum_p_op.cc transpose_p_op.cc split_p_op.cc concat_p_op.cc @@ -30,9 +30,14 @@ set(PRIM_OP_SRCS log_p_op.cc select_p_op.cc eq_p_op.cc + gt_p_op.cc + ge_p_op.cc + ne_p_op.cc pow_p_op.cc max_p_op.cc - erf_p_op.cc) + erf_p_op.cc + abs_p_op.cc + cast_p_op.cc) cc_test( prim_op_test diff --git a/paddle/fluid/operators/prim_ops/abs_p_op.cc b/paddle/fluid/operators/prim_ops/abs_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8ad9d131689e70e98c511608ed74504778316d43 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/abs_p_op.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { +class AbsPrimOp : public framework::OperatorBase { + public: + AbsPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator abs_p should not be excuted directly")); + } +}; + +class AbsPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of abs_p op."); + AddOutput("Y", "(Tensor), The output tensor of abs_p op."); + AddComment(R"DOC(Autograd primitive abs_p operator.)DOC"); + } +}; + +class AbsPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); + } +}; + +class AbsPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(abs_p, + paddle::operators::AbsPrimOp, + paddle::operators::AbsPrimOpMaker, + paddle::operators::AbsPrimOpShapeInference, + paddle::operators::AbsPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/cast_p_op.cc b/paddle/fluid/operators/prim_ops/cast_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5c8b9ab45c6bca9a3bb44ee2ca047c0c82c86b44 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/cast_p_op.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class CastPrimOp : public framework::OperatorBase { + public: + CastPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator cast_p should not be excuted directly")); + } +}; + +class CastPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of cast_p op."); + AddOutput("Y", "(Tensor), The output tensor of cast_p op."); + AddAttr("dtype", "output data type"); + AddComment(R"DOC(Autograd primitive cast_p operator.)DOC"); + } +}; + +class CastPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + PADDLE_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); + } +}; + +class CastPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto out_type = static_cast( + PADDLE_GET_CONST(int, ctx->GetAttr("dtype"))); + ctx->SetOutputDataType("Y", out_type); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(cast_p, + paddle::operators::CastPrimOp, + paddle::operators::CastPrimOpMaker, + paddle::operators::CastPrimOpShapeInference, + paddle::operators::CastPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/ge_p_op.cc b/paddle/fluid/operators/prim_ops/ge_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..33fbd4cd71497ff949a64abdf5a44fc1bff458b8 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/ge_p_op.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { +class GePrimOp : public framework::OperatorBase { + public: + GePrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator ge_p should not be excuted directly")); + } +}; + +class GePrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of ge_p op."); + AddInput("Y", "(Tensor), The input tensor of ge_p op."); + AddOutput("Z", "(Tensor), The output tensor of ge_p op."); + AddComment(R"DOC( +Autograd primitive ge_p operator. +)DOC"); + } +}; + +class GePrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, + y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, + y_rank)); + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ( + x_shape[i], + y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, + x_shape[i], + y_shape[i])); + } + + PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class GePrimOpVarTypeInference : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, + y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, + y_type)); + PADDLE_ENFORCE_EQ(x_dtype, + y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, + y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, framework::proto::VarType::BOOL); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(ge_p, + paddle::operators::GePrimOp, + paddle::operators::GePrimOpMaker, + paddle::operators::GePrimOpShapeInference, + paddle::operators::GePrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/gt_p_op.cc b/paddle/fluid/operators/prim_ops/gt_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..baacab62d8c3ebfcad51819246d458002f71fd0e --- /dev/null +++ b/paddle/fluid/operators/prim_ops/gt_p_op.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { +class GtPrimOp : public framework::OperatorBase { + public: + GtPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator gt_p should not be excuted directly")); + } +}; + +class GtPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of gt_p op."); + AddInput("Y", "(Tensor), The input tensor of gt_p op."); + AddOutput("Z", "(Tensor), The output tensor of gt_p op."); + AddComment(R"DOC( +Autograd primitive gt_p operator. +)DOC"); + } +}; + +class GtPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, + y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, + y_rank)); + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ( + x_shape[i], + y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, + x_shape[i], + y_shape[i])); + } + + PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class GtPrimOpVarTypeInference : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, + y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, + y_type)); + PADDLE_ENFORCE_EQ(x_dtype, + y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, + y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, framework::proto::VarType::BOOL); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(gt_p, + paddle::operators::GtPrimOp, + paddle::operators::GtPrimOpMaker, + paddle::operators::GtPrimOpShapeInference, + paddle::operators::GtPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/ne_p_op.cc b/paddle/fluid/operators/prim_ops/ne_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fac503309de1b7c3f22497a54a424f147c84e0c4 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/ne_p_op.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { +class NePrimOp : public framework::OperatorBase { + public: + NePrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator ne_p should not be excuted directly")); + } +}; + +class NePrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of ne_p op."); + AddInput("Y", "(Tensor), The input tensor of ne_p op."); + AddOutput("Z", "(Tensor), The output tensor of ne_p op."); + AddComment(R"DOC( +Autograd primitive ne_p operator. +)DOC"); + } +}; + +class NePrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = PADDLE_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = PADDLE_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, + y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, + y_rank)); + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ( + x_shape[i], + y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, + x_shape[i], + y_shape[i])); + } + + PADDLE_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class NePrimOpVarTypeInference : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, + y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, + y_type)); + PADDLE_ENFORCE_EQ(x_dtype, + y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, + y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, framework::proto::VarType::BOOL); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(ne_p, + paddle::operators::NePrimOp, + paddle::operators::NePrimOpMaker, + paddle::operators::NePrimOpShapeInference, + paddle::operators::NePrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/prim_op_test.cc b/paddle/fluid/operators/prim_ops/prim_op_test.cc index 44872f9060bfe390cd62028ab58239eb53bc4c0d..153a4575463bc849796354857ad38d750e1420ee 100644 --- a/paddle/fluid/operators/prim_ops/prim_op_test.cc +++ b/paddle/fluid/operators/prim_ops/prim_op_test.cc @@ -18,7 +18,7 @@ USE_OP_ITSELF(reshape_p); USE_OP_ITSELF(broadcast_p); -USE_OP_ITSELF(reduce_p); +USE_OP_ITSELF(reduce_sum_p); USE_OP_ITSELF(transpose_p); USE_OP_ITSELF(split_p); USE_OP_ITSELF(concat_p); @@ -130,7 +130,7 @@ TEST(PrimOp, broadcast_p) { ASSERT_EQ(shapes[2], 5L); } -TEST(PrimOp, reduce_p) { +TEST(PrimOp, reduce_sum_p) { ProgramDesc program; auto *block = program.MutableBlock(0); std::vector shape{3, 4, 5}; @@ -141,7 +141,7 @@ TEST(PrimOp, reduce_p) { NewVar(block, x0, shape); AppendOp(block, - "reduce_p", + "reduce_sum_p", {{"X", {x0}}}, {{"Y", {x1}}}, {{"axis", std::vector{0, 2}}, {"keepdim", false}}); @@ -151,7 +151,7 @@ TEST(PrimOp, reduce_p) { ASSERT_EQ(shapes.size(), 1UL); ASSERT_EQ(shapes[0], 4L); AppendOp(block, - "reduce_p", + "reduce_sum_p", {{"X", {x0}}}, {{"Y", {x2}}}, {{"axis", std::vector{0, 2}}, {"keepdim", true}}); diff --git a/paddle/fluid/operators/prim_ops/reduce_p_op.cc b/paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc similarity index 74% rename from paddle/fluid/operators/prim_ops/reduce_p_op.cc rename to paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc index 3c18ce46f9d937969d96518b2bc6e7efb9e5505f..b31b4934706a93d5abe0a3c5d16b2b2b18ddda44 100644 --- a/paddle/fluid/operators/prim_ops/reduce_p_op.cc +++ b/paddle/fluid/operators/prim_ops/reduce_sum_p_op.cc @@ -24,25 +24,25 @@ class VarDesc; namespace paddle { namespace operators { -class ReducePrimOp : public framework::OperatorBase { +class ReduceSumPrimOp : public framework::OperatorBase { public: - ReducePrimOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) + ReduceSumPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { PADDLE_THROW(platform::errors::Unimplemented( - "Prim operator reduce_p should not be excuted directly")); + "Prim operator reduce_sum_p should not be excuted directly")); } }; -class ReducePrimOpMaker : public framework::OpProtoAndCheckerMaker { +class ReduceSumPrimOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(Tensor), The input tensor of reduce_p op."); - AddOutput("Y", "(Tensor), The output tensor of reduce_p op."); + AddInput("X", "(Tensor), The input tensor of reduce_sum_p op."); + AddOutput("Y", "(Tensor), The output tensor of reduce_sum_p op."); AddAttr>( "axis", "(std::vector) The axis along which to reduce on. Must be in " @@ -53,12 +53,12 @@ class ReducePrimOpMaker : public framework::OpProtoAndCheckerMaker { "If true, retain the reduced axis with length 1.") .SetDefault(false); AddComment(R"DOC( -Autograd primitive reduce_p operator. +Autograd primitive reduce_sum_p operator. )DOC"); } }; -class ReducePrimOpShapeInference : public framework::InferShapeBase { +class ReduceSumPrimOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; @@ -87,7 +87,7 @@ class ReducePrimOpShapeInference : public framework::InferShapeBase { } }; -class ReducePrimOpVarTypeInference +class ReduceSumPrimOpVarTypeInference : public framework::StaticGraphVarTypeInference { public: void operator()(framework::InferVarTypeContext *ctx) const override { @@ -101,8 +101,8 @@ class ReducePrimOpVarTypeInference } // namespace operators } // namespace paddle -REGISTER_OPERATOR(reduce_p, - paddle::operators::ReducePrimOp, - paddle::operators::ReducePrimOpMaker, - paddle::operators::ReducePrimOpShapeInference, - paddle::operators::ReducePrimOpVarTypeInference); +REGISTER_OPERATOR(reduce_sum_p, + paddle::operators::ReduceSumPrimOp, + paddle::operators::ReduceSumPrimOpMaker, + paddle::operators::ReduceSumPrimOpShapeInference, + paddle::operators::ReduceSumPrimOpVarTypeInference); diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 295e3557df27d355c8aa1b86f742137b3d6c1d7d..02b5138be21467e38d6f3843c3f9d0c85e3f570c 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -32,4 +32,4 @@ from . import dist_pnorm from . import dist_slice from . import dist_fused_feedforward from . import dist_fused_attention -from . import dist_reduce_p +from . import dist_reduce_sum_p diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py b/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py similarity index 92% rename from python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py rename to python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py index bdd105ef64c30313bafdf326aba93287662507a2..6b53b2eed7ad00eb9a351a290027572dedae36e0 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reduce_p.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py @@ -33,21 +33,21 @@ from ..process_group import new_process_group from ..utils import _get_comm_group, _get_corresponding_rank -class DistributedReducePrimtive(DistributedOperatorImplContainer): +class DistributedReduceSumPrimtive(DistributedOperatorImplContainer): def __init__(self, op_type): - super(DistributedReducePrimtive, self).__init__(op_type) + super(DistributedReduceSumPrimtive, self).__init__(op_type) register_distributed_operator_impl_container( - DistributedReducePrimtive("reduce_p")) + DistributedReduceSumPrimtive("reduce_sum_p")) -# Batch Dimension Reduce Primitive -class DistributedReducePrimtiveImpl0(DistributedOperatorImpl): +# Batch Dimension ReduceSum Primitive +class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl): def __init__(self, name): - super(DistributedReducePrimtiveImpl0, self).__init__(name) + super(DistributedReduceSumPrimtiveImpl0, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -149,4 +149,5 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl): register_distributed_operator_impl( - "reduce_p", DistributedReducePrimtiveImpl0("batch_dimension_reduce_p")) + "reduce_sum_p", + DistributedReduceSumPrimtiveImpl0("batch_dimension_reduce_sum_p")) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_prim_dist_op.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_prim_dist_op.py index 67894f6dd93df9b86dfa4c56a833b8faa53dd830..69f92012c17efbb825459218664dee01e92e059a 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_prim_dist_op.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_prim_dist_op.py @@ -78,7 +78,7 @@ class TestPrimDistOp(unittest.TestCase): outputs={'Z': self.w_grad}, attrs=self.attrs) - op = self.layer_help.append_op(type="reduce_p", + op = self.layer_help.append_op(type="reduce_sum_p", inputs={'X': self.tmp2}, outputs={'Y': self.batch_reduced}, attrs={"axis": [0]}) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py index 51104223f954dbe5adfc0fe750e294026e6809e2..c8e1b3965228d71bbcf4a8942f7c6adfbb672f7c 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py @@ -400,6 +400,75 @@ class TestErfPJVPAndTranspose(TestAddPJVPAndTranspose): ] +class TestAbsPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'abs_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + self.prim_input = { + 'X': X, + } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + self.all_ops = [ + # prim op: + 'abs_p', + # jvp op: + 'select_p', + 'ge_p', + 'fill_constant_p', + 'fill_constant_p', + 'sub_p', + # transpose op: + ] + + +class TestCastPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'cast_p' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + self.prim_input = { + 'X': X, + } + self.prim_output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {'dtype': paddle.float64} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[5, 6], dtype='int64') + self.jvp_args = (X_DOT, ) + self.jvp_out_shape_map = {0: self.prim_output['Y']} + + # Set transpose + check_dot = lambda v: True + Y_BAR = paddle.static.data(name='Y_BAR', shape=[5, 6], dtype='float') + self.transpose_args = (check_dot, Y_BAR) + self.transpose_out_shape_map = {0: X} + + self.all_ops = [ + # prim op: + 'cast_p', + # jvp op: + 'cast_p', + # transpose op: + 'cast_p' + ] + + class TestLogPJVPAndTranspose(TestAddPJVPAndTranspose): def init_data(self): @@ -503,7 +572,7 @@ class TestBroadcastPJVPAndTranspose(TestAddPJVPAndTranspose): # jvp op: 'broadcast_p', # transpose op: - 'reduce_p', + 'reduce_sum_p', 'reshape_p' ] @@ -650,11 +719,11 @@ class TestConcatPJVPAndTranspose(TestAddPJVPAndTranspose): ] -class TestReducePJVPAndTranspose(TestAddPJVPAndTranspose): +class TestReduceSumPJVPAndTranspose(TestAddPJVPAndTranspose): def init_data(self): # Set prim op - self.op_type = 'reduce_p' + self.op_type = 'reduce_sum_p' X = paddle.static.data(name='X', shape=[2, 3, 4, 5], dtype='float64') self.prim_input = {'X': X} self.prim_output = { @@ -682,9 +751,9 @@ class TestReducePJVPAndTranspose(TestAddPJVPAndTranspose): self.all_ops = [ # prim op: - 'reduce_p', + 'reduce_sum_p', # jvp op: - 'reduce_p', + 'reduce_sum_p', # transpose op: 'reshape_p', 'broadcast_p', @@ -978,6 +1047,96 @@ class TestEqPJVPAndTranspose(TestAddPJVPAndTranspose): ] +class TestGtPJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'gt_p' + X = paddle.static.data(name='X', shape=[4, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64') + + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + self.all_ops = [ + # prim op: + 'gt_p', + # jvp op: + 'fill_constant_p', + # transpose op: + ] + + +class TestGePJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'ge_p' + X = paddle.static.data(name='X', shape=[4, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64') + + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + self.all_ops = [ + # prim op: + 'ge_p', + # jvp op: + 'fill_constant_p', + # transpose op: + ] + + +class TestNePJVPAndTranspose(TestAddPJVPAndTranspose): + + def init_data(self): + # Set prim op + self.op_type = 'ne_p' + X = paddle.static.data(name='X', shape=[4, 5], dtype='float64') + Y = paddle.static.data(name='Y', shape=[4, 5], dtype='float64') + + self.prim_input = {'X': X, 'Y': Y} + self.prim_output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.prim_attrs = {} + + # Set JVP + X_DOT = paddle.static.data(name='X_DOT', shape=[4, 5], dtype='float64') + Y_DOT = paddle.static.data(name='Y_DOT', shape=[4, 5], dtype='float64') + self.jvp_args = (X_DOT, Y_DOT) + self.jvp_out_shape_map = {0: self.prim_output['Z']} + + self.all_ops = [ + # prim op: + 'ne_p', + # jvp op: + 'fill_constant_p', + # transpose op: + ] + + class TestPowPJVPAndTranspose(TestAddPJVPAndTranspose): def init_data(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index 5693520ef0ace5f506f84131efc3992b4bc60ad8..e1d5ee11a13ace711db583a1772d5d5d5b94e76a 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -110,6 +110,26 @@ class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestElementWiseDivOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'elementwise_div' + X = paddle.static.data(name='X', shape=[8, 8], dtype='float') + Y = paddle.static.data(name='Y', shape=[8, 8], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, Y) + self.all_ops = ['elementwise_div', 'div_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + class TestMatmulV2Orig2Prim(TestElementWiseAddOrig2Prim): def init_data(self): @@ -229,6 +249,26 @@ class TestErfOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestAbsOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'abs' + X = paddle.static.data(name='X', shape=[3, 4], dtype='float') + + self.input = { + 'X': X, + } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, ) + self.all_ops = ['abs', 'abs_p'] + self.out_map = {0: self.output['Out']} + + class TestLogOrig2Prim(TestElementWiseAddOrig2Prim): def init_data(self): @@ -422,7 +462,9 @@ class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim): } self.orig2prim_args = (X, ) - self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p'] + self.all_ops = [ + 'p_norm', 'reshape_p', 'sqrt_p', 'reduce_sum_p', 'mul_p' + ] self.out_map = {0: self.output['Out']} @@ -445,7 +487,9 @@ class TestPNormOrig2Prim2(TestElementWiseAddOrig2Prim): } self.orig2prim_args = (X, ) - self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p'] + self.all_ops = [ + 'p_norm', 'reshape_p', 'sqrt_p', 'reduce_sum_p', 'mul_p' + ] self.out_map = {0: self.output['Out']} @@ -580,6 +624,63 @@ class TestEqualOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestNeOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'not_equal' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype='bool') + } + self.attrs = {} + self.orig2prim_args = (X, Y) + self.all_ops = ['not_equal', 'ne_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestGtOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'greater_than' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype='bool') + } + self.attrs = {} + self.orig2prim_args = (X, Y) + self.all_ops = ['greater_than', 'gt_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestGeOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'greater_equal' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + Y = paddle.static.data(name='Y', shape=[5, 8], dtype='float') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype='bool') + } + self.attrs = {} + self.orig2prim_args = (X, Y) + self.all_ops = ['greater_equal', 'ge_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + class TestPowOrig2Prim(TestElementWiseAddOrig2Prim): def init_data(self): @@ -665,5 +766,118 @@ class TestGeluApproximateOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestReduceSumOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'reduce_sum' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': [0, 1], 'keep_dim': False} + + self.orig2prim_args = (X, ) + self.all_ops = ['reduce_sum', 'reduce_sum_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestReduceMeanOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'reduce_mean' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'axis': [0, 1], 'keep_dim': False} + + self.orig2prim_args = (X, ) + self.all_ops = [ + 'reduce_mean', 'reduce_sum_p', 'fill_constant_p', 'div_p' + ] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestSizeOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'size' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'Input': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference( + dtype=paddle.int64) + } + self.attrs = {} + self.orig2prim_args = (X, ) + self.all_ops = ['size', 'fill_constant_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestCastOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'cast' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'in_dtype': X.dtype, 'out_dtype': paddle.float64} + self.orig2prim_args = (X, ) + self.all_ops = ['cast', 'cast_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestPowScalarOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'pow' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'factor': 2.} + self.orig2prim_args = (None, X) + self.all_ops = ['pow', 'pow_p', 'fill_constant_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + +class TestSquareOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'square' + X = paddle.static.data(name='X', shape=[5, 8], dtype='float') + + self.input = {'X': X} + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + self.orig2prim_args = (X, ) + self.all_ops = ['square', 'pow_p', 'fill_constant_p'] + # { prim_op_output_index: orig_op_output_var } + self.out_map = {0: self.output['Out']} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py index 4d0f150073604f94faa5ad2f4e0f455c2d2d2d47..a89b91bdd2b64a7e28b7782a003992eb2b2e16f9 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_prim2orig.py @@ -244,6 +244,26 @@ class TestErfPPrim2Orig(TestAddPPrim2Orig): self.out_map = {self.output['Y']: 0} +class TestAbsPPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'abs_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + + self.input = { + 'X': X, + } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.prim2orig_args = (X, ) + self.all_ops = ['abs_p', 'abs'] + self.out_map = {self.output['Y']: 0} + + class TestLogPPrim2Orig(TestAddPPrim2Orig): def init_data(self): @@ -375,7 +395,7 @@ class TestConcatPPrim2Orig(TestAddPPrim2Orig): class TestReducePPrim2Orig(TestAddPPrim2Orig): def init_data(self): - self.op_type = 'reduce_p' + self.op_type = 'reduce_sum_p' X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64') self.input = {'X': X} @@ -386,7 +406,7 @@ class TestReducePPrim2Orig(TestAddPPrim2Orig): self.attrs = {'axis': [1], 'keepdim': True} self.prim2orig_args = (X, ) - self.all_ops = ['reduce_p', 'reduce_sum'] + self.all_ops = ['reduce_sum_p', 'reduce_sum'] self.out_map = {self.output['Y']: 0} @@ -555,6 +575,63 @@ class TestEqPPrim2Orig(TestAddPPrim2Orig): self.out_map = {self.output['Z']: 0} +class TestNePPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'ne_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype='bool') + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['ne_p', 'not_equal'] + self.out_map = {self.output['Z']: 0} + + +class TestGtPPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'gt_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype='bool') + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['gt_p', 'greater_than'] + self.out_map = {self.output['Z']: 0} + + +class TestGePPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'ge_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64') + + self.input = {'X': X, 'Y': Y} + self.output = { + 'Z': + self.layer_help.create_variable_for_type_inference(dtype='bool') + } + self.attrs = {} + + self.prim2orig_args = (X, Y) + self.all_ops = ['ge_p', 'greater_equal'] + self.out_map = {self.output['Z']: 0} + + class TestPowPPrim2Orig(TestAddPPrim2Orig): def init_data(self): @@ -593,5 +670,25 @@ class TestMaxPPrim2Orig(TestAddPPrim2Orig): self.out_map = {self.output['Z']: 0} +class TestCastPPrim2Orig(TestAddPPrim2Orig): + + def init_data(self): + self.op_type = 'cast_p' + X = paddle.static.data(name='X', shape=[7, 8], dtype='float64') + + self.input = { + 'X': X, + } + self.output = { + 'Y': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'dtype': paddle.int64} + + self.prim2orig_args = (X, ) + self.all_ops = ['cast_p', 'cast'] + self.out_map = {self.output['Y']: 0} + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py index 7edbb6ef77f26e7039b6da1a93969159ea00610f..bdc54563fc8d2a127670106af005ed00a1b9abf3 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py @@ -150,6 +150,8 @@ class TestWithoutProgramGuard(unittest.TestCase): (np.random.rand(3, 3), np.random.rand(3, 3)), (np.random.rand(3, 3), np.random.rand(3, 3)), 'float64'), ('log', paddle.log, (np.random.rand(3, 4), ), None, 'float32'), + ('abs', paddle.abs, (np.random.uniform(-10, 10, + (10, 10)), ), None, 'float32'), )) # paddle.where, paddle.pow, paddle.maximum has no double grad definition, # can not compute forward grad use double trick @@ -255,6 +257,8 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y) (np.random.rand(2, 3), np.random.rand(3, 2)), None, 'float32'), ('multiply', paddle.multiply, (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'), + ('div', paddle.divide, + (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float64'), ('add', paddle.add, (np.random.rand(2, 3), np.random.rand(2, 3)), None, 'float32'), ('input_not_sequence', paddle.tanh, @@ -283,7 +287,36 @@ where_wrap = lambda x, y: paddle.where(paddle.eye(3, 4) == 1, x, y) (np.random.rand(200, 189), ), None, 'float32'), ('gelu_approximate', lambda x: paddle.nn.functional.gelu(x, True), (np.random.rand(200, 189), ), None, 'float32'), - )) + ('sum', paddle.sum, (np.random.rand(200, 345), ), None, 'float32'), + ('sum_with_axis', lambda x: paddle.sum(x, axis=1), + (np.random.rand(200, 345), ), None, 'float32'), + ('sum_with_keepdim', lambda x: paddle.sum(x, keepdim=True), + (np.random.rand(200, 345), ), None, 'float32'), + ('mean', paddle.mean, (np.random.rand(200, 345), ), None, 'float32'), + ('mean_with_axis', lambda x: paddle.mean(x, axis=1), + (np.random.rand(200, 345), ), None, 'float32'), + ('mean_with_keepdim', lambda x: paddle.mean(x, keepdim=True), + (np.random.rand(200, 345), ), None, 'float32'), + ('mean_with_axis_keepdim', + lambda x: paddle.mean(x, axis=0, keepdim=True), + (np.random.rand(200, 345), ), None, 'float32'), + ('abs', paddle.abs, (np.random.uniform(-10, 10, + (200, 345)), ), None, 'float32'), + ('cast_float', lambda x: paddle.cast(x, paddle.float64), + (np.random.rand(10, 20), ), None, 'float32'), + ('cast_int', lambda x: paddle.cast(x, paddle.int32), + (np.random.rand(10, 20), ), None, 'float32'), + ('square', paddle.square, (np.random.rand(100), ), None, 'float32'), + ('pow_scalar', lambda x: paddle.pow(x, 2), + (np.random.rand(20, 30), ), None, 'float32'), + ('var', paddle.var, (np.random.rand(200, 324), ), None, 'float32'), + ('var_with_axis', lambda x: paddle.var(x, axis=1), + (np.random.rand(10, 20, 30), ), None, 'float32'), + ('var_without_unbiased', + lambda x: paddle.var(x, axis=1, unbiased=False), + (np.random.rand(10, 20, 30), ), None, 'float32'), + ('var_with_keepdim', lambda x: paddle.var(x, axis=1, keepdim=True), + (np.random.rand(10, 20, 30), ), None, 'float32'))) class TestGrad(unittest.TestCase): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primops.py b/python/paddle/fluid/tests/unittests/autograd/test_primops.py index 25a3d9bce235a00e9ef73a72db567688606ffb48..35291432f6e8fb8da3cf070422139680066200f8 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primops.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primops.py @@ -42,7 +42,11 @@ paddle.enable_static() ('cos', primops.cos, randn(2, 3), {}, (2, 3), 'float64'), ('exp', primops.exp, randn(2, 3), {}, (2, 3), 'float64'), ('erf', primops.erf, randn(2, 3), {}, (2, 3), 'float64'), + ('abs', primops.abs, randn(2, 3), {}, (2, 3), 'float64'), ('log', primops.log, randn(2, 3), {}, (2, 3), 'float64'), + ('cast', primops.cast, randn(2, 3), { + 'dtype': paddle.int64 + }, (2, 3), 'int64'), ('reshape', primops.reshape, randn(2, 3), { 'shape': (3, 2) }, (3, 2), 'float64'), @@ -58,10 +62,10 @@ paddle.enable_static() ('concat_axis1', primops.concat, ((randn(2, 3), randn(2, 3)), ), { 'axis': 1 }, (2, 6), 'float64'), - ('reduce_axis1', primops.reduce, randn(2, 3), { + ('reduce_axis1', primops.reduce_sum, randn(2, 3), { 'axis': (1, ) }, (2, ), 'float64'), - ('reduce_axis01', primops.reduce, randn(2, 3), { + ('reduce_axis01', primops.reduce_sum, randn(2, 3), { 'axis': (0, 1) }, (1, ), 'float64'), ('split', primops.split, randn(2, 3), { @@ -99,6 +103,9 @@ paddle.enable_static() ('select', primops.select, (randn(2, 3) > 0, randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), ('eq', primops.eq, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), + ('ne', primops.ne, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), + ('gt', primops.gt, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), + ('ge', primops.ge, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'bool'), ('pow', primops.pow, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), ('max', primops.max, (randn(2, 3), randn(2, 3)), {}, (2, 3), 'float64'), )) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_transform.py b/python/paddle/fluid/tests/unittests/autograd/test_transform.py index f976ef729cc7a00f403e1c0f5360e9b92dd974d8..6c0aa697550bc369981408e1b7a7bd0f26ac4df2 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_transform.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_transform.py @@ -290,8 +290,8 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd): 'index_select' ] self.orig2prim_ops = [ - 'broadcast_p', 'add_p', 'reshape_p', 'mul_p', 'reduce_p', 'sqrt_p', - 'broadcast_p', 'sub_p', 'concat_p', 'gather_p' + 'broadcast_p', 'add_p', 'reshape_p', 'mul_p', 'reduce_sum_p', + 'sqrt_p', 'broadcast_p', 'sub_p', 'concat_p', 'gather_p' ] self.linearize_ops = self.orig2prim_ops + [ # call fill_const() in linearize() function @@ -306,7 +306,7 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd): 'mul_p', 'mul_p', 'add_p', - 'reduce_p', + 'reduce_sum_p', 'fill_constant_p', # 'sqrt_p', Will not append sqrt_p op when apply JVP for sqrt_p 'mul_p', 'div_p', @@ -326,7 +326,7 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd): 'fill_constant_p', 'mul_p', # transposed op - 'reduce_p', + 'reduce_sum_p', 'reshape_p', 'reshape_p', 'mul_p', @@ -334,7 +334,7 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd): 'reshape_p', 'broadcast_p', 'div_p', - 'reduce_p', + 'reduce_sum_p', 'reshape_p', 'fill_constant_p', 'sub_p', diff --git a/python/paddle/incubate/autograd/primops.py b/python/paddle/incubate/autograd/primops.py index cb67d3f23e911d20d95feb7c055828db742f93c6..636dc8922049053daf4546823864acc483e45b02 100644 --- a/python/paddle/incubate/autograd/primops.py +++ b/python/paddle/incubate/autograd/primops.py @@ -137,6 +137,11 @@ def exp(x, out=None): return _simple_unop(LayerHelper('exp_p', **locals())) +@REGISTER_FN('abs_p', 'X', 'Y') +def abs(x, out=None): + return _simple_unop(LayerHelper('abs_p', **locals())) + + @REGISTER_FN('reshape_p', 'X', 'Y') def reshape(x, shape, out=None): return _manipulation_unop(LayerHelper('reshape_p', **locals())) @@ -193,15 +198,17 @@ def concat(xs, axis=0, out=None): return out -@REGISTER_FN('reduce_p', 'X', 'Y') -def reduce(x, axis, keepdim=False, out=None): +@REGISTER_FN('reduce_sum_p', 'X', 'Y') +def reduce_sum(x, axis=None, keepdim=False, out=None): + axes = axis or tuple(range(0, len(x.shape))) + axes = (axes, ) if isinstance(axes, int) else axes if not isinstance(axis, (tuple, list)): raise TypeError(f'axis must be tuple or list, but got {type(axis)}') if not isinstance(keepdim, bool): raise TypeError(f'keepdim must be bool, but got {type(keepdim)}') - attrs = {'axis': axis, 'keepdim': keepdim} - helper = LayerHelper('reduce_p', **locals()) + attrs = {'axis': axis, 'keepdim': keepdim} + helper = LayerHelper('reduce_sum_p', **locals()) if out is None: out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -347,6 +354,21 @@ def eq(x, y, out=None): return _simple_binop(LayerHelper('eq_p', **locals())) +@REGISTER_FN('gt_p', 'X', 'Y', 'Z') +def gt(x, y, out=None): + return _simple_binop(LayerHelper('gt_p', **locals())) + + +@REGISTER_FN('ge_p', 'X', 'Y', 'Z') +def ge(x, y, out=None): + return _simple_binop(LayerHelper('ge_p', **locals())) + + +@REGISTER_FN('ne_p', 'X', 'Y', 'Z') +def ne(x, y, out=None): + return _simple_binop(LayerHelper('ne_p', **locals())) + + @REGISTER_FN('pow_p', 'X', 'Y', 'Z') def pow(x, y, out=None): return _simple_binop(LayerHelper('pow_p', **locals())) @@ -360,3 +382,15 @@ def max(x, y, out=None): @REGISTER_FN('erf_p', 'X', 'Y') def erf(x, out=None): return _simple_unop(LayerHelper('erf_p', **locals())) + + +@REGISTER_FN('cast_p', 'X', 'Y') +def cast(x, dtype, out=None): + helper = LayerHelper('cast_p', **locals()) + if out is None: + out = helper.create_variable_for_type_inference(dtype) + helper.append_op(type=helper.layer_type, + inputs={'X': x}, + outputs={'Y': out}, + attrs={'dtype': dtype}) + return out diff --git a/python/paddle/incubate/autograd/primreg.py b/python/paddle/incubate/autograd/primreg.py index 6c3ece09a6be1ffc48fbf0b4600c250465f9e5dc..34b1c7f48334845947d81b5f3de5b4edf14b2f0a 100644 --- a/python/paddle/incubate/autograd/primreg.py +++ b/python/paddle/incubate/autograd/primreg.py @@ -80,7 +80,7 @@ def op_position_inputs(op): """ args = _primop_position_argnames.lookup(op.type) - assert args is not None, 'args should not be None in op_position_inputs().' + assert args is not None, f'args of {op.type} should not be None in op_position_inputs().' *input_names, _ = args inputs = [] diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index 3fe40da787d0be172b1525f9a5f7647957d3246f..4625cfd362f07030feba94c7191b178108384474 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -11,16 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import typing +import functools import math +import operator +import typing import paddle from . import primops -from .primops import (add, broadcast, concat, cos, div, exp, fill_const, gather, - matmul, mul, neg, reduce, reshape, scatter_add, set_value, +from .primops import (add, broadcast, concat, cos, div, eq, erf, exp, + fill_const, gather, ge, gt, log, matmul, max, mul, ne, + neg, reduce_sum, reshape, scatter_add, select, set_value, sin, slice_assign, slice_select, split, sqrt, sub, tanh, - transpose, log, select, eq, max, erf) + transpose) from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_TRANSPOSE, lookup_fn, lookup_jvp, lookup_orig2prim, lookup_prim2orig, lookup_transpose, @@ -155,6 +158,13 @@ def elementwise_mul_orig2prim(op, x, y): return z +@REGISTER_ORIG2PRIM('elementwise_div') +def elementwise_div_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + return primops.div(x, y) + + @REGISTER_ORIG2PRIM('tanh') def tanh_orig2prim(op, x): return tanh(x) @@ -180,6 +190,11 @@ def erf_orig2prim(op, x): return erf(x) +@REGISTER_ORIG2PRIM('abs') +def abs_orig2prim(op, x): + return primops.abs(x) + + @REGISTER_ORIG2PRIM('log') def log_orig2prim(op, x): return log(x) @@ -307,13 +322,18 @@ def p_norm_orig2prim(op, x): x = reshape(x, shape=[num_el(x.shape)]) if abs(op.attr('porder') - 2.0) < 1e-5: - return sqrt(reduce(mul(x, x), axis=[0])) + return sqrt(reduce_sum(mul(x, x), axis=[0])) elif abs(op.attr('porder') - 1.0) < 1e-5: - return reduce(sqrt(mul(x, x)), axis=[0]) + return reduce_sum(sqrt(mul(x, x)), axis=[0]) else: raise RuntimeError('Only support lower l2/l1 norm currently') +@REGISTER_ORIG2PRIM('cast') +def cast_orig2prim(op, x): + return primops.cast(x, paddle.dtype(op.attr('out_dtype'))) + + # TODO: support broadcast @REGISTER_ORIG2PRIM('where') def select_orig2prim(op, condition, x, y): @@ -327,15 +347,48 @@ def equal_orig2prim(op, x, y): return eq(x, y) +@REGISTER_ORIG2PRIM('not_equal') +def ne_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + return ne(x, y) + + +@REGISTER_ORIG2PRIM('greater_than') +def gt_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + return gt(x, y) + + +@REGISTER_ORIG2PRIM('greater_equal') +def ge_orig2prim(op, x, y): + if x.shape != y.shape: + y = broadcast(y, shape=x.shape) + return ge(x, y) + + +# paddle.pow API use "elementwise_pow" operator when y is a Tensor. @REGISTER_ORIG2PRIM('elementwise_pow') def elementwise_pow_orig2prim(op, x, y): if x.shape != y.shape: y = broadcast(y, shape=x.shape) - z = primops.pow(x, y) return z +# paddle.pow API use "pow" operator when y is a scalar. +@REGISTER_ORIG2PRIM('pow') +def pow_orig2prim(op, x, y): + # x is factorTensor defined in paddle phi op. Currently it is None. + return primops.pow(y, fill_const(op.attr('factor'), y.shape, y.dtype)) + + +@REGISTER_ORIG2PRIM('square') +def square_orig2prim(op, x): + return primops.pow(x, fill_const(2., x.shape, x.dtype)) + + @REGISTER_ORIG2PRIM('elementwise_max') def elementwise_max_orig2prim(op, x, y): if x.shape != y.shape: @@ -367,6 +420,31 @@ def gelu_orig2prim(op, x): erf(mul(x, fill_const(1 / math.sqrt(2.), x.shape, x.dtype))))) +@REGISTER_ORIG2PRIM('reduce_sum') +def reduce_sum_orig2prim(op, x): + axes = tuple(range(0, len( + x.shape))) if op.attr('reduce_all') else op.attr('dim') + return reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim')) + + +@REGISTER_ORIG2PRIM('reduce_mean') +def reduce_mean_orig2prim(op, x): + axes = tuple(range(0, len( + x.shape))) if op.attr('reduce_all') else op.attr('dim') + sum = reduce_sum(x, axis=axes, keepdim=op.attr('keep_dim')) + norm = fill_const(shape=sum.shape, + value=functools.reduce(operator.mul, + [x.shape[axis] for axis in axes]), + dtype=sum.dtype) + return div(sum, norm) + + +@REGISTER_ORIG2PRIM('size') +def size_orig2prim(op, x): + return fill_const(functools.reduce(operator.mul, x.shape), (1, ), + paddle.int64) + + ## Register prim2orig lower rules @REGISTER_PRIM2ORIG('add_p') def add_prim2orig(op, x, y): @@ -418,6 +496,11 @@ def erf_prim2orig(op, x): return paddle.erf(x) +@REGISTER_PRIM2ORIG('abs_p') +def abs_prim2orig(op, x): + return paddle.abs(x) + + @REGISTER_PRIM2ORIG('log_p') def log_prim2orig(op, x): return paddle.log(x) @@ -453,7 +536,7 @@ def concat_prim2orig(op, xs): return paddle.concat(xs, axis=op.attr('axis')) -@REGISTER_PRIM2ORIG('reduce_p') +@REGISTER_PRIM2ORIG('reduce_sum_p') def reduce_prim2orig(op, x): return paddle.sum(x, axis=op.attr('axis'), keepdim=op.attr('keepdim')) @@ -514,6 +597,21 @@ def eq_prim2orig(op, x, y): return paddle.equal(x, y) +@REGISTER_PRIM2ORIG('gt_p') +def gt_prim2orig(op, x, y): + return paddle.greater_than(x, y) + + +@REGISTER_PRIM2ORIG('ge_p') +def ge_prim2orig(op, x, y): + return paddle.greater_equal(x, y) + + +@REGISTER_PRIM2ORIG('ne_p') +def ne_prim2orig(op, x, y): + return paddle.not_equal(x, y) + + @REGISTER_PRIM2ORIG('pow_p') def pow_prim2orig(op, x, y): return paddle.pow(x, y) @@ -524,6 +622,11 @@ def max_prim2orig(op, x, y): return paddle.maximum(x, y) +@REGISTER_PRIM2ORIG('cast_p') +def cast_prim2orig(op, x): + return paddle.cast(x, paddle.dtype(op.attr('dtype'))) + + ## Register linearize rules @REGISTER_JVP('add_p') def add_jvp(op, x_dot, y_dot): @@ -629,6 +732,14 @@ def erf_jvp(op, x_dot): mul(x_dot, exp(neg(primops.pow(x, fill_const(2., x.shape, x.dtype)))))) +@REGISTER_JVP('abs_p') +def abs_jvp(op, x_dot): + if x_dot is None: + return None + x, = op_position_inputs(op) + return select(ge(x, fill_const(0., x.shape, x.dtype)), x_dot, neg(x_dot)) + + @REGISTER_JVP('log_p') def log_jvp(op, x_dot): if x_dot is None: @@ -678,8 +789,8 @@ def concat_jvp(op, xs_dot): return linear_jvp(op, xs_dot, axis=axis) -@REGISTER_JVP('reduce_p') -def reduce_jvp(op, x_dot): +@REGISTER_JVP('reduce_sum_p') +def reduce_sum_jvp(op, x_dot): if x_dot is None: return None axis = op.attr('axis') @@ -778,6 +889,33 @@ def eq_jvp(op, x_dot, y_dot): return z_dot +@REGISTER_JVP('gt_p') +def gt_jvp(op, x_dot, y_dot): + if x_dot is None and y_dot is None: + return None + x, _ = op_position_inputs(op) + z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype) + return z_dot + + +@REGISTER_JVP('ge_p') +def ge_jvp(op, x_dot, y_dot): + if x_dot is None and y_dot is None: + return None + x, _ = op_position_inputs(op) + z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype) + return z_dot + + +@REGISTER_JVP('ne_p') +def ne_jvp(op, x_dot, y_dot): + if x_dot is None and y_dot is None: + return None + x, _ = op_position_inputs(op) + z_dot = fill_const(value=0., shape=x.shape, dtype=x.dtype) + return z_dot + + @REGISTER_JVP('pow_p') def pow_jvp(op, x_dot, y_dot): @@ -825,6 +963,12 @@ def max_jvp(op, x_dot, y_dot): return select(eq(y, z), y_dot, x_dot) +@REGISTER_JVP('cast_p') +def cast_jvp(op, x_dot): + y = op_position_output(op) + return primops.cast(x_dot, y.dtype) + + ## Register transpose rules @@ -886,7 +1030,7 @@ def broadcast_transpose(op, check_dot, y_bar): keepdim = [(bat + i) for i, s in enumerate(x.shape) if s == 1] axis += keepdim # TODO: Change it. keepdim boolean - out = reduce(y_bar, axis=axis, keepdim=False) + out = reduce_sum(y_bar, axis=axis, keepdim=False) return reshape(out, x.shape) @@ -921,8 +1065,8 @@ def concat_transpose(op, check_dot, y_bar): return split(y_bar, num_or_sections=sections, axis=axis) -@REGISTER_TRANSPOSE('reduce_p') -def reduce_transpose(op, check_dot, y_bar): +@REGISTER_TRANSPOSE('reduce_sum_p') +def reduce_sum_transpose(op, check_dot, y_bar): x, = op_position_inputs(op) assert check_dot(x), 'check_dot(x) must be True' axes = op.attr('axis') @@ -1029,3 +1173,9 @@ def select_transpose(op, check_dot, z_bar): y_bar = select(cond, zeros_y, z_bar) if check_dot(y) else None return cond_bar, x_bar, y_bar + + +@REGISTER_TRANSPOSE('cast_p') +def cast_transpose(op, check_dot, y_bar): + x, = op_position_inputs(op) + return primops.cast(y_bar, x.dtype) diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 19f87dd9292154ca17dcca8364e76544107b2d5d..565fcb0b4ed836670db1ae23290820d36a52581b 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -12,18 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import OrderedDict + import paddle -from paddle.fluid import framework as framework -from paddle.fluid.framework import default_main_program -from paddle.fluid.framework import Operator from paddle import compat as cpt -from .primops import fill_const, add -from .primreg import op_position_inputs, op_position_output, lookup_orig2prim, lookup_prim2orig -from .primrules import _orig2prim, _prim2orig, _jvp, _transpose -from .utils import get_input_var_list, get_output_var_list, flatten, flatten_and_remove_none -from collections import OrderedDict +from paddle.fluid import framework as framework +from paddle.fluid.framework import Operator, default_main_program from paddle.incubate.autograd.utils import as_tensors +from .primops import add, fill_const +from .primreg import (lookup_orig2prim, lookup_prim2orig, op_position_inputs, + op_position_output) +from .primrules import _jvp, _orig2prim, _prim2orig, _transpose +from .utils import (flatten, flatten_and_remove_none, get_input_var_list, + get_output_var_list) + def topo_path(xs, ys, block=None): """ Returns the list of ops on the path from `xs` to `ys` in topological