diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 68eaf1a0ed469464790bffab9e6963e3c9513cfe..63bf3ab6a0382be4764976eedac0ca5314bcd584 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -22,6 +22,7 @@ add_subdirectory(reduce_ops) add_subdirectory(sequence_ops) add_subdirectory(string) add_subdirectory(jit) +add_subdirectory(prim_ops) if(WITH_MKLDNN) add_subdirectory(mkldnn) endif() diff --git a/paddle/fluid/operators/prim_ops/CMakeLists.txt b/paddle/fluid/operators/prim_ops/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a58ee6dc1f7ba150d5d3ce8bd154a9d43d8bb945 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/CMakeLists.txt @@ -0,0 +1,28 @@ +include(operators) +if(WITH_UNITY_BUILD) + # Load Unity Build rules for operators in paddle/fluid/operators/prim_ops. + include(unity_build_rule.cmake) +endif() +register_operators() + +SET(PRIM_OP_SRCS + reshape_p_op.cc + broadcast_p_op.cc + reduce_p_op.cc + transpose_p_op.cc + split_p_op.cc + concat_p_op.cc + slice_select_p_op.cc + slice_assign_p_op.cc + gather_p_op.cc + scatter_add_p_op.cc + add_p_op.cc + sub_p_op.cc + mul_p_op.cc + div_p_op.cc + sqrt_p_op.cc + tanh_p_op.cc + matmul_p_op.cc + fill_constant_p_op.cc) + +cc_test(prim_op_test SRCS prim_op_test.cc ${PRIM_OP_SRCS} DEPS op_registry) diff --git a/paddle/fluid/operators/prim_ops/add_p_op.cc b/paddle/fluid/operators/prim_ops/add_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4789ed8958f91f748f4c90a219f29052e1c43225 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/add_p_op.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class AddPrimOp : public framework::OperatorBase { + public: + AddPrimOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator add_p should not be excuted directly")); + } +}; + +class AddPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of add_p op."); + AddInput("Y", "(Tensor), The input tensor of add_p op."); + AddOutput("Z", "(Tensor), The output tensor of add_p op."); + AddComment(R"DOC( +Autograd primitive add_p operator. +)DOC"); + } +}; + +class AddPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, y_rank)); + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ( + x_shape[i], y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, x_shape[i], y_shape[i])); + } + + BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class AddPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, y_type)); + PADDLE_ENFORCE_EQ(x_dtype, y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, x_dtype); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(add_p, paddle::operators::AddPrimOp, + paddle::operators::AddPrimOpMaker, + paddle::operators::AddPrimOpShapeInference, + paddle::operators::AddPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/broadcast_p_op.cc b/paddle/fluid/operators/prim_ops/broadcast_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5459b73911473616c79e7e40f61951cf81a13c35 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/broadcast_p_op.cc @@ -0,0 +1,110 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class BroadcastPrimOp : public framework::OperatorBase { + public: + BroadcastPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator broadcast_p should not be excuted directly")); + } +}; + +class BroadcastPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of broadcast_p op."); + AddOutput("Y", "(Tensor), The output tensor of broadcast_p op."); + AddAttr>( + "shape", + "(std::vector) Target shape of broadcast_p operator."); + AddComment(R"DOC( +Autograd primitive broadcast_p operator. +)DOC"); + } +}; + +static void CheckShapeValid(const std::vector &x_shape, + const std::vector &target_shape) { + size_t x_rank = x_shape.size(); + size_t target_rank = target_shape.size(); + PADDLE_ENFORCE_GE(target_rank, x_rank, + platform::errors::InvalidArgument( + "The rank of target shape should be greater than or " + "equal to input tensor's dimensions, " + "but received %d and %d", + target_rank, x_rank)); + std::vector::const_iterator it = target_shape.begin(); + for (size_t i = 0; i < x_rank; i++, it++) { + if (x_shape[i] != 1) { + it = std::find(it, target_shape.end(), x_shape[i]); + } + PADDLE_ENFORCE_EQ( + it != target_shape.end(), true, + platform::errors::InvalidArgument( + "Invalid shape, can not broadcast input tensor into target shape," + "the first dismatching shape %d is shape of input tensor at " + "dimension %d", + x_shape[i], i)); + } +} + +class BroadcastPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + auto x_shape = x_var->GetShape(); + auto target_shape = ctx->Attrs().Get>("shape"); + CheckShapeValid(x_shape, target_shape); + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(target_shape); + } +}; + +class BroadcastPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(broadcast_p, paddle::operators::BroadcastPrimOp, + paddle::operators::BroadcastPrimOpMaker, + paddle::operators::BroadcastPrimOpShapeInference, + paddle::operators::BroadcastPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/concat_p_op.cc b/paddle/fluid/operators/prim_ops/concat_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..24516356a28367acef30424680fe368fcd0ca030 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/concat_p_op.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class ConcatPrimOp : public framework::OperatorBase { + public: + ConcatPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator concat_p should not be excuted directly")); + } +}; + +class ConcatPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("XS", "(Tensor), The input tensors of concat_p op.") + .AsDuplicable(); + AddOutput("Y", "(Tensor), The output tensor of concat_p op."); + AddAttr("axis", "(int64_t), The axis along which to concat."); + AddComment(R"DOC( +Autograd primitive concat_p operator. +)DOC"); + } +}; + +class ConcatPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + auto x_var_ptrs = ctx->GetInputVarPtrs("XS"); + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + auto axis = ctx->Attrs().Get("axis"); + int64_t cnt_along_axis = 0; + framework::VarDesc *first_x_var = + BOOST_GET(framework::VarDesc *, x_var_ptrs[0]); + auto first_x_shape = first_x_var->GetShape(); + cnt_along_axis += first_x_shape[axis]; + size_t first_x_rank = first_x_shape.size(); + for (size_t i = 1; i < x_var_ptrs.size(); ++i) { + framework::VarDesc *x_var = + BOOST_GET(framework::VarDesc *, x_var_ptrs[i]); + auto x_shape = x_var->GetShape(); + cnt_along_axis += x_shape[axis]; + size_t x_rank = x_shape.size(); + PADDLE_ENFORCE_EQ( + x_rank, first_x_rank, + platform::errors::InvalidArgument("The dimensions of %d input tensor " + "should be same as the dimensions " + "of 1st input tensor's, " + "but get %d and %d", + i + 1, x_rank, first_x_rank)); + for (size_t j = 0; j < x_rank; ++j) { + if (j != size_t(axis)) { + PADDLE_ENFORCE_EQ(x_shape[j], first_x_shape[j], + platform::errors::InvalidArgument( + "The shape of %d input tensor at dimension %d " + "should be same as the 1st input tensor's, " + "but get %d and %d", + i + 1, j, x_shape[j], first_x_shape[j])); + } + } + } + + std::vector y_shape(first_x_shape); + y_shape[axis] = cnt_along_axis; + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(y_shape); + } +}; + +class ConcatPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_names = Input(ctx, "XS"); + auto y_name = Output(ctx, "Y")[0]; + auto first_x_name = x_names[0]; + auto first_x_type = GetType(ctx, first_x_name); + auto first_x_dtype = GetDataType(ctx, first_x_name); + for (size_t i = 1; i < x_names.size(); ++i) { + auto x_name = x_names[i]; + auto x_type = GetType(ctx, x_name); + auto x_dtype = GetDataType(ctx, x_name); + PADDLE_ENFORCE_EQ(x_type, first_x_type, + platform::errors::InvalidArgument( + "The type of %d input tensor should be same as the " + "first input tensor's, " + "but get %d and %d", + i + 1, x_type, first_x_type)); + PADDLE_ENFORCE_EQ(x_dtype, first_x_dtype, + platform::errors::InvalidArgument( + "The datatype of %d input tensor should be same as " + "the first input tensor's, " + "but get %d and %d", + i + 1, x_dtype, first_x_dtype)); + } + SetType(ctx, y_name, GetType(ctx, first_x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, first_x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(concat_p, paddle::operators::ConcatPrimOp, + paddle::operators::ConcatPrimOpMaker, + paddle::operators::ConcatPrimOpShapeInference, + paddle::operators::ConcatPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/div_p_op.cc b/paddle/fluid/operators/prim_ops/div_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..35ae1f69cd2c8864d6807580a3fccd41520a7610 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/div_p_op.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class DivPrimOp : public framework::OperatorBase { + public: + DivPrimOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator div_p should not be excuted directly")); + } +}; + +class DivPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of div_p op."); + AddInput("Y", "(Tensor), The input tensor of div_p op."); + AddOutput("Z", "(Tensor), The output tensor of div_p op."); + AddComment(R"DOC( +Autograd primitive div_p operator. +)DOC"); + } +}; + +class DivPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, y_rank)); + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ( + x_shape[i], y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, x_shape[i], y_shape[i])); + } + + BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class DivPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, y_type)); + PADDLE_ENFORCE_EQ(x_dtype, y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, x_dtype); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(div_p, paddle::operators::DivPrimOp, + paddle::operators::DivPrimOpMaker, + paddle::operators::DivPrimOpShapeInference, + paddle::operators::DivPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/fill_constant_p_op.cc b/paddle/fluid/operators/prim_ops/fill_constant_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9831599e46ccc0ee23fae77a5914f228770e63d9 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/fill_constant_p_op.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class FillConstantPrimOp : public framework::OperatorBase { + public: + FillConstantPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator fill_constant_p should not be excuted directly")); + } +}; + +class FillConstantPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("Y", "(Tensor), The output tensor of fill_constant_p op."); + AddAttr("value", "(float) The value of output tensor."); + AddAttr>( + "shape", "(std::vector) The shape of output tensor."); + AddAttr("dtype", "(int) The dtype of output tensor."); + AddComment(R"DOC( +Autograd primitive fill_constant_p operator. +)DOC"); + } +}; + +class FillConstantPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + auto shape = ctx->Attrs().Get>("shape"); + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(shape); + } +}; + +class FillConstantPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto y_name = Output(ctx, "Y")[0]; + auto data_type = static_cast( + BOOST_GET_CONST(int, ctx->GetAttr("dtype"))); + SetDataType(ctx, y_name, data_type); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(fill_constant_p, paddle::operators::FillConstantPrimOp, + paddle::operators::FillConstantPrimOpMaker, + paddle::operators::FillConstantPrimOpShapeInference, + paddle::operators::FillConstantPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/gather_p_op.cc b/paddle/fluid/operators/prim_ops/gather_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..be777de055803bc2b4d65469dfcc7c1ea9c2d796 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/gather_p_op.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class GatherPrimOp : public framework::OperatorBase { + public: + GatherPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator gather_p should not be excuted directly")); + } +}; + +class GatherPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of gather_p op."); + AddInput("IndexTensor", + "(Tensor), The index tensor of gather_p op, which is a 1D tensor.") + .AsDispensable(); + AddOutput("Y", "(Tensor), The output tensor of gather_p op."); + AddAttr("axis", "(int64_t), The axis along which to gather."); + AddAttr>( + "index", "(std::vector) The index of gather_p op") + .SetDefault({0}); + AddComment(R"DOC( +Autograd primitive gather_p operator. +)DOC"); + } +}; + +class GatherPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + int64_t num_index = 0; + if (ctx->HasInput("IndexTensor")) { + framework::InferShapeVarPtr index_var_ptr = + ctx->GetInputVarPtrs("IndexTensor")[0]; + framework::VarDesc *index_var = + BOOST_GET(framework::VarDesc *, index_var_ptr); + auto index_shape = index_var->GetShape(); + PADDLE_ENFORCE_EQ(index_shape.size(), 1, + platform::errors::InvalidArgument( + "The index tensor should be a 1D tensor," + "but get rank %d", + index_shape.size())); + num_index = index_shape[0]; + } else { + num_index = ctx->Attrs().Get>("index").size(); + } + auto axis = ctx->Attrs().Get("axis"); + + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + auto x_shape = x_var->GetShape(); + x_shape[axis] = num_index; + + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape); + } +}; + +class GatherPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + if (ctx->HasInput("IndexTensor")) { + auto index_name = Input(ctx, "IndexTensor")[0]; + auto index_dtype = GetDataType(ctx, index_name); + PADDLE_ENFORCE_EQ( + index_dtype, framework::proto::VarType_Type_INT32, + platform::errors::InvalidArgument( + "The datatype of input tensor should be VarType_Type_INT32(%d), " + "but get %d", + framework::proto::VarType_Type_INT32, index_dtype)); + } + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(gather_p, paddle::operators::GatherPrimOp, + paddle::operators::GatherPrimOpMaker, + paddle::operators::GatherPrimOpShapeInference, + paddle::operators::GatherPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/matmul_p_op.cc b/paddle/fluid/operators/prim_ops/matmul_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1a28e1ca5c427591fc6422f64aaeb9725df15b04 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/matmul_p_op.cc @@ -0,0 +1,138 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class MatmulPrimOp : public framework::OperatorBase { + public: + MatmulPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator matmul_p should not be excuted directly")); + } +}; + +class MatmulPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of matmul_p op."); + AddInput("Y", "(Tensor), The input tensor of matmul_p op."); + AddOutput("Z", "(Tensor), The output tensor of matmul_p op."); + AddComment(R"DOC( +Autograd primitive matmul_p operator. +)DOC"); + } +}; + +class MatmulPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, y_rank, + platform::errors::InvalidArgument( + "The two input tensor's dimension should be equal" + "But received first input tensor's dimension is %d, " + "and another input tensor's dimension is %d", + x_rank, y_rank)); + + PADDLE_ENFORCE_EQ(x_rank == 2 || x_rank == 3, true, + platform::errors::InvalidArgument( + "The input tensor's dimension should be 2 or 3" + "But received input tensor's dimension is %d", + x_rank)); + + PADDLE_ENFORCE_EQ( + x_shape[x_rank - 1], y_shape[y_rank - 2], + platform::errors::InvalidArgument( + "Invalid shape for matmul, the last dimension of first input and " + "the penultimate dimension for the second input should be same." + "But received %d and %d.", + x_shape[x_rank - 1], y_shape[y_rank - 2])); + if (x_rank == 2) { + std::vector z_shape{x_shape[x_rank - 2], y_shape[y_rank - 1]}; + BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(z_shape); + } else { + PADDLE_ENFORCE_EQ(x_shape[0], y_shape[0], + platform::errors::InvalidArgument( + "Invalid shape for matmul when input tensor's " + "dimension is 3, the first dimension of first " + "input and the second input should be same." + "But received %d and %d.", + x_shape[0], y_shape[0])); + + std::vector z_shape{x_shape[0], x_shape[x_rank - 2], + y_shape[y_rank - 1]}; + BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(z_shape); + } + } +}; + +class MatmulPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, y_type)); + PADDLE_ENFORCE_EQ(x_dtype, y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, x_dtype); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(matmul_p, paddle::operators::MatmulPrimOp, + paddle::operators::MatmulPrimOpMaker, + paddle::operators::MatmulPrimOpShapeInference, + paddle::operators::MatmulPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/mul_p_op.cc b/paddle/fluid/operators/prim_ops/mul_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a60e2601a339bf8f9e02dad0a1ca9c9580386cf2 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/mul_p_op.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class MulPrimOp : public framework::OperatorBase { + public: + MulPrimOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator mul_p should not be excuted directly")); + } +}; + +class MulPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of mul_p op."); + AddInput("Y", "(Tensor), The input tensor of mul_p op."); + AddOutput("Z", "(Tensor), The output tensor of mul_p op."); + AddComment(R"DOC( +Autograd primitive mul_p operator. +)DOC"); + } +}; + +class MulPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, y_rank)); + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ( + x_shape[i], y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, x_shape[i], y_shape[i])); + } + + BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class MulPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, y_type)); + PADDLE_ENFORCE_EQ(x_dtype, y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, x_dtype); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(mul_p, paddle::operators::MulPrimOp, + paddle::operators::MulPrimOpMaker, + paddle::operators::MulPrimOpShapeInference, + paddle::operators::MulPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/prim_op_test.cc b/paddle/fluid/operators/prim_ops/prim_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d65149d130bbc7d8e317233df9f78846e2f66af --- /dev/null +++ b/paddle/fluid/operators/prim_ops/prim_op_test.cc @@ -0,0 +1,553 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" + +USE_OP_ITSELF(reshape_p); +USE_OP_ITSELF(broadcast_p); +USE_OP_ITSELF(reduce_p); +USE_OP_ITSELF(transpose_p); +USE_OP_ITSELF(split_p); +USE_OP_ITSELF(concat_p); +USE_OP_ITSELF(slice_select_p); +USE_OP_ITSELF(slice_assign_p); +USE_OP_ITSELF(gather_p); +USE_OP_ITSELF(scatter_add_p); +USE_OP_ITSELF(add_p); +USE_OP_ITSELF(sub_p); +USE_OP_ITSELF(mul_p); +USE_OP_ITSELF(div_p); +USE_OP_ITSELF(sqrt_p); +USE_OP_ITSELF(tanh_p); +USE_OP_ITSELF(matmul_p); +USE_OP_ITSELF(fill_constant_p); + +namespace paddle { +namespace framework { + +static void NewVar(BlockDesc *block, const std::string &name, + const std::vector &shape) { + auto *var_desc = block->Var(name); + if (shape.size() > 0) { + var_desc->SetShape(shape); + var_desc->SetType(proto::VarType::LOD_TENSOR); + var_desc->SetDataType(proto::VarType_Type_FP32); + } +} + +static void AppendOp(BlockDesc *block, const std::string &type, + VariableNameMap inputs, VariableNameMap outputs, + AttributeMap attrs) { + auto &op_info = OpInfoMap::Instance().Get(type); + if (op_info.Checker()) { + op_info.Checker()->Check(&attrs); + } + + auto *op = block->AppendOp(); + op->SetType(type); + for (auto &pair : inputs) { + op->SetInput(pair.first, pair.second); + } + + for (auto &pair : outputs) { + op->SetOutput(pair.first, pair.second); + for (auto &var_name : pair.second) { + if (!block->FindVarRecursive(var_name)) { + NewVar(block, var_name, {}); + } + } + } + + op->SetAttrMap(attrs); + op->InferVarType(block); + op->InferShape(*block); +} + +TEST(PrimOp, reshape_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + + NewVar(block, x0, shape); + AppendOp(block, "reshape_p", {{"X", {x0}}}, {{"Y", {x1}}}, + {{"shape", std::vector{12, 5}}}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 2UL); + ASSERT_EQ(shapes[0], 12L); + ASSERT_EQ(shapes[1], 5L); +} + +TEST(PrimOp, broadcast_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 1}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + + NewVar(block, x0, shape); + AppendOp(block, "broadcast_p", {{"X", {x0}}}, {{"Y", {x1}}}, + {{"shape", std::vector{3, 4, 5}}}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, reduce_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + + NewVar(block, x0, shape); + AppendOp(block, "reduce_p", {{"X", {x0}}}, {{"Y", {x1}}}, + {{"axis", std::vector{0, 2}}, {"keepdim", false}}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 1UL); + ASSERT_EQ(shapes[0], 4L); + AppendOp(block, "reduce_p", {{"X", {x0}}}, {{"Y", {x2}}}, + {{"axis", std::vector{0, 2}}, {"keepdim", true}}); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 1L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 1L); +} + +TEST(PrimOp, transpose_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + + NewVar(block, x0, shape); + AppendOp(block, "transpose_p", {{"X", {x0}}}, {{"Y", {x1}}}, + {{"axis", std::vector{2, 1, 0}}}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 5L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 3L); +} + +TEST(PrimOp, split_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{6, 8, 10}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + std::string x3 = "x3"; + + NewVar(block, x0, shape); + AppendOp(block, "split_p", {{"X", {x0}}}, {{"YS", {x1, x2, x3}}}, + {{"axis", int64_t{1}}, + {"num_or_sections", std::vector{2, 4, 2}}}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 2L); + ASSERT_EQ(shapes[2], 10L); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 10L); + ASSERT_EQ(block->Var("x3")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x3")->GetDataType(), proto::VarType_Type_FP32); + shapes = block->Var("x3")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 2L); + ASSERT_EQ(shapes[2], 10L); + std::string x4 = "x4"; + std::string x5 = "x5"; + AppendOp( + block, "split_p", {{"X", {x0}}}, {{"YS", {x4, x5}}}, + {{"axis", int64_t{2}}, {"num_or_sections", std::vector{2}}}); + ASSERT_EQ(block->Var("x4")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x4")->GetDataType(), proto::VarType_Type_FP32); + shapes = block->Var("x4")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 8L); + ASSERT_EQ(shapes[2], 5L); + ASSERT_EQ(block->Var("x5")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x5")->GetDataType(), proto::VarType_Type_FP32); + shapes = block->Var("x5")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 8L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, concat_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape_0{3, 1, 5}; + std::vector shape_1{3, 4, 5}; + std::vector shape_2{3, 6, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + std::string x3 = "x3"; + + NewVar(block, x0, shape_0); + NewVar(block, x1, shape_1); + NewVar(block, x2, shape_2); + AppendOp(block, "concat_p", {{"XS", {x0, x1, x2}}}, {{"Y", {x3}}}, + {{"axis", int64_t{1}}}); + ASSERT_EQ(block->Var("x3")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x3")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x3")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 11L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, slice_select_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{6, 8, 10}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + + NewVar(block, x0, shape); + AppendOp(block, "slice_select_p", {{"X", {x0}}}, {{"Y", {x1}}}, + {{"axis", std::vector{0, 1, 2}}, + {"starts", std::vector{0, 0, 0}}, + {"ends", std::vector{5, 7, 9}}, + {"strides", std::vector{2, 2, 2}}}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, slice_assign_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape_0{6, 8, 10}; + std::vector shape_1{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + + NewVar(block, x0, shape_0); + NewVar(block, x1, shape_1); + AppendOp(block, "slice_assign_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, + {{"axis", std::vector{0, 1, 2}}, + {"starts", std::vector{0, 0, 0}}, + {"ends", std::vector{5, 7, 9}}, + {"strides", std::vector{2, 2, 2}}}); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 8L); + ASSERT_EQ(shapes[2], 10L); +} + +TEST(PrimOp, gather_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{6, 8, 10}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + + NewVar(block, x0, shape); + AppendOp(block, "gather_p", {{"X", {x0}}}, {{"Y", {x1}}}, + {{"axis", int64_t{1}}, {"index", std::vector{0, 2, 5}}}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 3L); + ASSERT_EQ(shapes[2], 10L); + std::string index_t = "index_t"; + std::string x2 = "x2"; + + auto *var_desc = block->Var(index_t); + var_desc->SetShape(std::vector{3}); + var_desc->SetType(proto::VarType::LOD_TENSOR); + var_desc->SetDataType(proto::VarType_Type_INT32); + AppendOp(block, "gather_p", {{"X", {x0}}, {"IndexTensor", {index_t}}}, + {{"Y", {x2}}}, {{"axis", int64_t{1}}}); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 3L); + ASSERT_EQ(shapes[2], 10L); +} + +TEST(PrimOp, scatter_add_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape_0{6, 8, 10}; + std::vector shape_1{6, 3, 10}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + + NewVar(block, x0, shape_0); + NewVar(block, x1, shape_1); + AppendOp(block, "scatter_add_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, + {{"axis", int64_t{1}}, {"index", std::vector{0, 2, 5}}}); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 8L); + ASSERT_EQ(shapes[2], 10L); + std::string index_t = "index_t"; + std::string x3 = "x3"; + + auto *var_desc = block->Var(index_t); + var_desc->SetShape(std::vector{3}); + var_desc->SetType(proto::VarType::LOD_TENSOR); + var_desc->SetDataType(proto::VarType_Type_INT32); + AppendOp(block, "scatter_add_p", + {{"X", {x0}}, {"Y", {x1}}, {"IndexTensor", {index_t}}}, + {{"Z", {x3}}}, {{"axis", int64_t{1}}}); + ASSERT_EQ(block->Var("x3")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x3")->GetDataType(), proto::VarType_Type_FP32); + shapes = block->Var("x3")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 6L); + ASSERT_EQ(shapes[1], 8L); + ASSERT_EQ(shapes[2], 10L); +} + +TEST(PrimOp, add_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + + NewVar(block, x0, shape); + NewVar(block, x1, shape); + AppendOp(block, "add_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {}); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, sub_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + + NewVar(block, x0, shape); + NewVar(block, x1, shape); + AppendOp(block, "sub_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {}); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, mul_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + + NewVar(block, x0, shape); + NewVar(block, x1, shape); + AppendOp(block, "mul_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {}); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, div_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + + NewVar(block, x0, shape); + NewVar(block, x1, shape); + AppendOp(block, "div_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {}); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, sqrt_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + + NewVar(block, x0, shape); + AppendOp(block, "sqrt_p", {{"X", {x0}}}, {{"Y", {x1}}}, {}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, tanh_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{3, 4, 5}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + + NewVar(block, x0, shape); + AppendOp(block, "tanh_p", {{"X", {x0}}}, {{"Y", {x1}}}, {}); + ASSERT_EQ(block->Var("x1")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x1")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x1")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +TEST(PrimOp, matmul_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape_0{3, 4, 5}; + std::vector shape_1{3, 5, 8}; + + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + + NewVar(block, x0, shape_0); + NewVar(block, x1, shape_1); + AppendOp(block, "matmul_p", {{"X", {x0}}, {"Y", {x1}}}, {{"Z", {x2}}}, {}); + ASSERT_EQ(block->Var("x2")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x2")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x2")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 8L); + std::vector shape_2{4, 5}; + std::vector shape_3{5, 8}; + + std::string x3 = "x3"; + std::string x4 = "x4"; + std::string x5 = "x5"; + + NewVar(block, x3, shape_2); + NewVar(block, x4, shape_3); + AppendOp(block, "matmul_p", {{"X", {x3}}, {"Y", {x4}}}, {{"Z", {x5}}}, {}); + ASSERT_EQ(block->Var("x5")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x5")->GetDataType(), proto::VarType_Type_FP32); + shapes = block->Var("x5")->GetShape(); + ASSERT_EQ(shapes.size(), 2UL); + ASSERT_EQ(shapes[0], 4L); + ASSERT_EQ(shapes[1], 8L); +} + +TEST(PrimOp, fill_constant_p) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::string x0 = "x0"; + + AppendOp(block, "fill_constant_p", {{}}, {{"Y", {x0}}}, + {{"value", 0.0f}, + {"dtype", proto::VarType_Type_FP32}, + {"shape", std::vector{3, 4, 5}}}); + ASSERT_EQ(block->Var("x0")->GetType(), proto::VarType::LOD_TENSOR); + ASSERT_EQ(block->Var("x0")->GetDataType(), proto::VarType_Type_FP32); + auto shapes = block->Var("x0")->GetShape(); + ASSERT_EQ(shapes.size(), 3UL); + ASSERT_EQ(shapes[0], 3L); + ASSERT_EQ(shapes[1], 4L); + ASSERT_EQ(shapes[2], 5L); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/operators/prim_ops/reduce_p_op.cc b/paddle/fluid/operators/prim_ops/reduce_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9f2b5f3ed2c43ff294e99383d8cd083ebc36675d --- /dev/null +++ b/paddle/fluid/operators/prim_ops/reduce_p_op.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class ReducePrimOp : public framework::OperatorBase { + public: + ReducePrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator reduce_p should not be excuted directly")); + } +}; + +class ReducePrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of reduce_p op."); + AddOutput("Y", "(Tensor), The output tensor of reduce_p op."); + AddAttr>( + "axis", + "(std::vector) The axis along which to reduce on. Must be in " + "range [-rank(input), rank(input)]. If `axis[i] < 0`, the axis[i] to " + "reduce is `rank + axis[i]`."); + AddAttr("keepdim", + "(bool, default false) " + "If true, retain the reduced axis with length 1.") + .SetDefault(false); + AddComment(R"DOC( +Autograd primitive reduce_p operator. +)DOC"); + } +}; + +class ReducePrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + auto x_shape = x_var->GetShape(); + auto axis = ctx->Attrs().Get>("axis"); + auto keepdim = ctx->Attrs().Get("keepdim"); + if (keepdim) { + for (size_t i = 0; i < axis.size(); ++i) { + x_shape[axis[i]] = 1; + } + } else { + const int kDelFlag = -2; + for (size_t i = 0; i < axis.size(); ++i) { + x_shape[axis[i]] = kDelFlag; + } + x_shape.erase(remove(x_shape.begin(), x_shape.end(), kDelFlag), + x_shape.end()); + } + if (!keepdim && x_shape.size() == 0) { + x_shape.push_back(1); + } + + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape); + } +}; + +class ReducePrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(reduce_p, paddle::operators::ReducePrimOp, + paddle::operators::ReducePrimOpMaker, + paddle::operators::ReducePrimOpShapeInference, + paddle::operators::ReducePrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/reshape_p_op.cc b/paddle/fluid/operators/prim_ops/reshape_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..497bc8fbaffb3906dda1d1c7c0d5f0952db81b0d --- /dev/null +++ b/paddle/fluid/operators/prim_ops/reshape_p_op.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class ReshapePrimOp : public framework::OperatorBase { + public: + ReshapePrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator reshape_p should not be excuted directly")); + } +}; + +class ReshapePrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of reshape_p op."); + AddOutput("Y", "(Tensor), The output tensor of reshape_p op."); + AddAttr>( + "shape", "(std::vector) Target shape of reshape_p operator."); + AddComment(R"DOC( +Autograd primitive reshape_p operator. +)DOC"); + } +}; + +static int64_t product(const std::vector &shape) { + int64_t rslt = 1; + for (size_t i = 0; i < shape.size(); ++i) { + rslt *= shape[i]; + } + return rslt; +} + +class ReshapePrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + auto x_shape = x_var->GetShape(); + auto shape = ctx->Attrs().Get>("shape"); + PADDLE_ENFORCE_EQ(product(x_shape), product(shape), + platform::errors::InvalidArgument( + "The input tensor can't be reshaped to target shape, " + "the input tensor has %d elements but target shape " + "contains %d elements", + product(x_shape), product(shape))); + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(shape); + } +}; + +class ReshapePrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(reshape_p, paddle::operators::ReshapePrimOp, + paddle::operators::ReshapePrimOpMaker, + paddle::operators::ReshapePrimOpShapeInference, + paddle::operators::ReshapePrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/scatter_add_p_op.cc b/paddle/fluid/operators/prim_ops/scatter_add_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..420e6907e193dc9d6299380124ee259948d5bbc5 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/scatter_add_p_op.cc @@ -0,0 +1,160 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class ScatterAddPrimOp : public framework::OperatorBase { + public: + ScatterAddPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator scatter_add_p should not be excuted directly")); + } +}; + +class ScatterAddPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The tensor to apply scatter rule and add on."); + AddInput("Y", "(Tensor), The source tensor of scatter_add_p op."); + AddInput( + "IndexTensor", + "(Tensor), The index tensor of scatter_add_p op, which is a 1D tensor.") + .AsDispensable(); + AddOutput("Z", "(Tensor), The output tensor of scatter_add_p op."); + AddAttr("axis", + "(int64_t), The axis along which to scatter and add."); + AddAttr>( + "index", "(std::vector) The index of scatter_add_p op") + .SetDefault({0}); + AddComment(R"DOC( +Autograd primitive scatter_add_p operator. +)DOC"); + } +}; + +class ScatterAddPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + int64_t num_index = 0; + if (ctx->HasInput("IndexTensor")) { + framework::InferShapeVarPtr index_var_ptr = + ctx->GetInputVarPtrs("IndexTensor")[0]; + framework::VarDesc *index_var = + BOOST_GET(framework::VarDesc *, index_var_ptr); + auto index_shape = index_var->GetShape(); + PADDLE_ENFORCE_EQ(index_shape.size(), 1, + platform::errors::InvalidArgument( + "The index tensor should be a 1D tensor," + "but get rank %d", + index_shape.size())); + num_index = index_shape[0]; + } else { + num_index = ctx->Attrs().Get>("index").size(); + } + auto axis = ctx->Attrs().Get("axis"); + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, y_rank)); + PADDLE_ENFORCE_EQ(y_shape[axis], num_index, + platform::errors::InvalidArgument( + "The shape of source input tensor at scatter axis " + "should be equal to num_index, " + "but get %d and %d", + y_shape[axis], num_index)); + for (size_t i = 0; i < x_rank; ++i) { + if (i != size_t(axis)) { + PADDLE_ENFORCE_EQ( + x_shape[i], y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, x_rank, y_rank)); + } + } + + BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class ScatterAddPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, y_type)); + PADDLE_ENFORCE_EQ(x_dtype, y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, y_dtype)); + + if (ctx->HasInput("IndexTensor")) { + auto index_name = Input(ctx, "IndexTensor")[0]; + auto index_dtype = GetDataType(ctx, index_name); + PADDLE_ENFORCE_EQ( + index_dtype, framework::proto::VarType_Type_INT32, + platform::errors::InvalidArgument( + "The datatype of input tensor should be VarType_Type_INT32(%d), " + "but get %d", + framework::proto::VarType_Type_INT32, index_dtype)); + } + SetType(ctx, z_name, GetType(ctx, x_name)); + SetDataType(ctx, z_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(scatter_add_p, paddle::operators::ScatterAddPrimOp, + paddle::operators::ScatterAddPrimOpMaker, + paddle::operators::ScatterAddPrimOpShapeInference, + paddle::operators::ScatterAddPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/slice_assign_p_op.cc b/paddle/fluid/operators/prim_ops/slice_assign_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6fff54cced55093e09c3fba4a980b66454aa806b --- /dev/null +++ b/paddle/fluid/operators/prim_ops/slice_assign_p_op.cc @@ -0,0 +1,152 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class SliceAssignPrimOp : public framework::OperatorBase { + public: + SliceAssignPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator slice_assign_p should not be excuted directly")); + } +}; + +class SliceAssignPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The tensor to slice from and assign on."); + AddInput("Y", "(Tensor), The source tensor of slice_assign_p op."); + AddOutput("Z", "(Tensor), The output tensor of slice_assign_p op."); + AddAttr>( + "axis", "(std::vector), The axis along which to gather."); + AddAttr>( + "starts", + "(std::vector) The slice starts of slice_assign_p op"); + AddAttr>( + "ends", "(std::vector) The slice ends of slice_assign_p op"); + AddAttr>( + "strides", + "(std::vector) The slice strides of slice_assign_p op"); + AddComment(R"DOC( +Autograd primitive slice_assign_p operator. +)DOC"); + } +}; + +class SliceAssignPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + auto axis = ctx->Attrs().Get>("axis"); + auto starts = ctx->Attrs().Get>("starts"); + auto ends = ctx->Attrs().Get>("ends"); + auto strides = ctx->Attrs().Get>("strides"); + PADDLE_ENFORCE_EQ( + starts.size(), axis.size(), + platform::errors::InvalidArgument( + "Number of starts attribute and axis attribute should be same, " + "but get %d and %d", + starts.size(), axis.size())); + PADDLE_ENFORCE_EQ( + ends.size(), axis.size(), + platform::errors::InvalidArgument( + "Number of ends attribute and axis attribute should be same, " + "but get %d and %d", + ends.size(), axis.size())); + PADDLE_ENFORCE_EQ( + strides.size(), axis.size(), + platform::errors::InvalidArgument( + "Number of strides attribute and axis attribute should be same, " + "but get %d and %d", + strides.size(), axis.size())); + PADDLE_ENFORCE_EQ(x_rank, y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, y_rank)); + std::vector y_target_shape(x_shape); + for (size_t i = 0; i < axis.size(); ++i) { + y_target_shape[axis[i]] = + (ends[i] - starts[i] + strides[i] - 1) / strides[i]; + } + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ(y_target_shape[i], y_shape[i], + platform::errors::InvalidArgument( + "The shape of source tensor of slice_assign_p op " + "at dimension %d should be %d, " + "but get %d", + i, y_target_shape[i], y_shape[i])); + } + BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class SliceAssignPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, y_type)); + PADDLE_ENFORCE_EQ(x_dtype, y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, y_dtype)); + + SetType(ctx, z_name, GetType(ctx, x_name)); + SetDataType(ctx, z_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(slice_assign_p, paddle::operators::SliceAssignPrimOp, + paddle::operators::SliceAssignPrimOpMaker, + paddle::operators::SliceAssignPrimOpShapeInference, + paddle::operators::SliceAssignPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/slice_select_p_op.cc b/paddle/fluid/operators/prim_ops/slice_select_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9456ab403737daf51eb80352eb97ebfd0f234fca --- /dev/null +++ b/paddle/fluid/operators/prim_ops/slice_select_p_op.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class SliceSelectPrimOp : public framework::OperatorBase { + public: + SliceSelectPrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator slice_select_p should not be excuted directly")); + } +}; + +class SliceSelectPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of slice_select_p op."); + AddOutput("Y", "(Tensor), The output tensor of slice_select_p op."); + AddAttr>( + "axis", "(std::vector), The axis along which to gather."); + AddAttr>( + "starts", + "(std::vector) The slice starts of slice_select_p op"); + AddAttr>( + "ends", "(std::vector) The slice ends of slice_select_p op"); + AddAttr>( + "strides", + "(std::vector) The slice strides of slice_select_p op"); + AddComment(R"DOC( +Autograd primitive slice_select_p operator. +)DOC"); + } +}; + +class SliceSelectPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + auto x_shape = x_var->GetShape(); + auto axis = ctx->Attrs().Get>("axis"); + auto starts = ctx->Attrs().Get>("starts"); + auto ends = ctx->Attrs().Get>("ends"); + auto strides = ctx->Attrs().Get>("strides"); + PADDLE_ENFORCE_EQ( + starts.size(), axis.size(), + platform::errors::InvalidArgument( + "Number of starts attribute and axis attribute should be same, " + "but get %d and %d", + starts.size(), axis.size())); + PADDLE_ENFORCE_EQ( + ends.size(), axis.size(), + platform::errors::InvalidArgument( + "Number of ends attribute and axis attribute should be same, " + "but get %d and %d", + ends.size(), axis.size())); + PADDLE_ENFORCE_EQ( + strides.size(), axis.size(), + platform::errors::InvalidArgument( + "Number of strides attribute and axis attribute should be same, " + "but get %d and %d", + strides.size(), axis.size())); + for (size_t i = 0; i < axis.size(); ++i) { + x_shape[axis[i]] = (ends[i] - starts[i] + strides[i] - 1) / strides[i]; + } + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_shape); + } +}; + +class SliceSelectPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(slice_select_p, paddle::operators::SliceSelectPrimOp, + paddle::operators::SliceSelectPrimOpMaker, + paddle::operators::SliceSelectPrimOpShapeInference, + paddle::operators::SliceSelectPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/split_p_op.cc b/paddle/fluid/operators/prim_ops/split_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..212692bf0355b9fee0916bc5f7b1ba5daf197c81 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/split_p_op.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class SplitPrimOp : public framework::OperatorBase { + public: + SplitPrimOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator split_p should not be excuted directly")); + } +}; + +class SplitPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of split_p op."); + AddOutput("YS", "(Tensor), The output tensors of split_p op.") + .AsDuplicable(); + AddAttr("axis", "(int64_t), The axis along which to split."); + AddAttr>( + "num_or_sections", + "(std::vector) If num_or_sections has only one element, then " + "num_or_sections indicates the number of equal sized sub-Tensors that " + "the input will be divided into. If num_or_sections has more then one " + "element, the length of it indicates the number of sub-Tensors and the " + "elements in it indicate the sizes of sub-Tensors’ dimension orderly. " + "The length of the vector must not be larger than the input's size of " + "specified axis."); + AddComment(R"DOC( +Autograd primitive split_p operator. +)DOC"); + } +}; + +class SplitPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + auto y_var_ptrs = ctx->GetOutputVarPtrs("YS"); + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + auto x_shape = x_var->GetShape(); + auto axis = ctx->Attrs().Get("axis"); + auto num_or_sections = + ctx->Attrs().Get>("num_or_sections"); + std::vector y_shape(x_shape); + if (num_or_sections.size() == 1) { + PADDLE_ENFORCE_EQ(x_shape[axis] % num_or_sections[0], 0, + platform::errors::InvalidArgument( + "The input tensor can't be devided equally into %d " + "parts equally along axis %d", + num_or_sections[0], axis)); + y_shape[axis] = x_shape[axis] / num_or_sections[0]; + for (size_t i = 0; i < size_t(num_or_sections[0]); ++i) { + BOOST_GET(framework::VarDesc *, y_var_ptrs[i])->SetShape(y_shape); + } + } else { + int64_t cnt_along_axis = 0; + for (size_t i = 0; i < num_or_sections.size(); ++i) { + y_shape[axis] = num_or_sections[i]; + cnt_along_axis += num_or_sections[i]; + BOOST_GET(framework::VarDesc *, y_var_ptrs[i])->SetShape(y_shape); + } + PADDLE_ENFORCE_EQ( + x_shape[axis], cnt_along_axis, + platform::errors::InvalidArgument( + "The input tensor has %d elements along axis %d, thus can't be " + "devided into %d tensor with %d elements totally.", + x_shape[axis], axis, num_or_sections.size(), cnt_along_axis)); + } + } +}; + +class SplitPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_names = Output(ctx, "YS"); + for (auto y_name : y_names) { + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(split_p, paddle::operators::SplitPrimOp, + paddle::operators::SplitPrimOpMaker, + paddle::operators::SplitPrimOpShapeInference, + paddle::operators::SplitPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/sqrt_p_op.cc b/paddle/fluid/operators/prim_ops/sqrt_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..de4958d29f9331f7b5377d302d8d06b1fabe1a02 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/sqrt_p_op.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class SqrtPrimOp : public framework::OperatorBase { + public: + SqrtPrimOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator sqrt_p should not be excuted directly")); + } +}; + +class SqrtPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of sqrt_p op."); + AddOutput("Y", "(Tensor), The output tensor of sqrt_p op."); + AddComment(R"DOC( +Autograd primitive sqrt_p operator. +)DOC"); + } +}; + +class SqrtPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); + } +}; + +class SqrtPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(sqrt_p, paddle::operators::SqrtPrimOp, + paddle::operators::SqrtPrimOpMaker, + paddle::operators::SqrtPrimOpShapeInference, + paddle::operators::SqrtPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/sub_p_op.cc b/paddle/fluid/operators/prim_ops/sub_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f689f2d2d918b882a271e80cd27b43b3bfa2d49b --- /dev/null +++ b/paddle/fluid/operators/prim_ops/sub_p_op.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class SubPrimOp : public framework::OperatorBase { + public: + SubPrimOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator sub_p should not be excuted directly")); + } +}; + +class SubPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of sub_p op."); + AddInput("Y", "(Tensor), The input tensor of sub_p op."); + AddOutput("Z", "(Tensor), The output tensor of sub_p op."); + AddComment(R"DOC( +Autograd primitive sub_p operator. +)DOC"); + } +}; + +class SubPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0]; + framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0]; + + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr); + auto x_shape = x_var->GetShape(); + auto y_shape = y_var->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + PADDLE_ENFORCE_EQ(x_rank, y_rank, + platform::errors::InvalidArgument( + "The dimensions of two input tensor should be same, " + "but get %d and %d", + x_rank, y_rank)); + for (size_t i = 0; i < x_rank; ++i) { + PADDLE_ENFORCE_EQ( + x_shape[i], y_shape[i], + platform::errors::InvalidArgument( + "The shape of two input tensor at dimension %d should be same, " + "but get %d and %d", + i, x_shape[i], y_shape[i])); + } + + BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape); + } +}; + +class SubPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Input(ctx, "Y")[0]; + auto z_name = Output(ctx, "Z")[0]; + auto x_type = GetType(ctx, x_name); + auto y_type = GetType(ctx, y_name); + auto x_dtype = GetDataType(ctx, x_name); + auto y_dtype = GetDataType(ctx, y_name); + PADDLE_ENFORCE_EQ(x_type, y_type, + platform::errors::InvalidArgument( + "The type of two input tensor should be same, " + "but get %d and %d", + x_type, y_type)); + PADDLE_ENFORCE_EQ(x_dtype, y_dtype, + platform::errors::InvalidArgument( + "The datatype of two input tensor should be same, " + "but get %d and %d", + x_dtype, y_dtype)); + + SetType(ctx, z_name, x_type); + SetDataType(ctx, z_name, x_dtype); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(sub_p, paddle::operators::SubPrimOp, + paddle::operators::SubPrimOpMaker, + paddle::operators::SubPrimOpShapeInference, + paddle::operators::SubPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/tanh_p_op.cc b/paddle/fluid/operators/prim_ops/tanh_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c2afdcbe4b20719da95096873e225885f488393f --- /dev/null +++ b/paddle/fluid/operators/prim_ops/tanh_p_op.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class TanhPrimOp : public framework::OperatorBase { + public: + TanhPrimOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator tanh_p should not be excuted directly")); + } +}; + +class TanhPrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of tanh_p op."); + AddOutput("Y", "(Tensor), The output tensor of tanh_p op."); + AddComment(R"DOC( +Autograd primitive tanh_p operator. +)DOC"); + } +}; + +class TanhPrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(x_var->GetShape()); + } +}; + +class TanhPrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(tanh_p, paddle::operators::TanhPrimOp, + paddle::operators::TanhPrimOpMaker, + paddle::operators::TanhPrimOpShapeInference, + paddle::operators::TanhPrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/transpose_p_op.cc b/paddle/fluid/operators/prim_ops/transpose_p_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3b72318cd51ded5ab40b276b99494b54b14a33a --- /dev/null +++ b/paddle/fluid/operators/prim_ops/transpose_p_op.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace framework { +class InferShapeContext; +class VarDesc; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class TransposePrimOp : public framework::OperatorBase { + public: + TransposePrimOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Prim operator transpose_p should not be excuted directly")); + } +}; + +class TransposePrimOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of transpose_p op."); + AddOutput("Y", "(Tensor), The output tensor of transpose_p op."); + AddAttr>("axis", + "(std::vector) Tanspose axis."); + AddComment(R"DOC( +Autograd primitive transpose_p operator. +)DOC"); + } +}; + +class TransposePrimOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0]; + framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0]; + framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr); + auto x_shape = x_var->GetShape(); + auto axis = ctx->Attrs().Get>("axis"); + size_t x_rank = x_shape.size(); + size_t axis_size = axis.size(); + PADDLE_ENFORCE_EQ(x_rank, axis_size, + platform::errors::InvalidArgument( + "The input tensor's dimension " + "should be equal to the axis's size. " + "But received input tensor's dimension is %d, " + "axis's size is %d", + x_rank, axis_size)); + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + PADDLE_ENFORCE_GE(axis[i], 0, + platform::errors::InvalidArgument( + "The axis should be greater than or equal to 0." + "But received %d of axis[%d]", + axis[i], i)); + + PADDLE_ENFORCE_EQ( + axis[i] < static_cast(axis_size) && ++count[axis[i]] == 1, true, + platform::errors::InvalidArgument( + "Each element of Attribute axis should " + "be a unique value range from 0 to (dims - 1), " + "where the dims is the axis's size, " + "unique value means this axis value can appear only once. " + "But received axis[%d] is %d, axis_size is %d, " + "count[axis[%d]] is %d", + i, axis[i], axis_size, i, count[axis[i]])); + } + std::vector y_shape(axis_size); + for (size_t i = 0; i < axis_size; i++) { + y_shape[i] = x_shape[axis[i]]; + } + BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(y_shape); + } +}; + +class TransposePrimOpVarTypeInference + : public framework::StaticGraphVarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_name = Input(ctx, "X")[0]; + auto y_name = Output(ctx, "Y")[0]; + SetType(ctx, y_name, GetType(ctx, x_name)); + SetDataType(ctx, y_name, GetDataType(ctx, x_name)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(transpose_p, paddle::operators::TransposePrimOp, + paddle::operators::TransposePrimOpMaker, + paddle::operators::TransposePrimOpShapeInference, + paddle::operators::TransposePrimOpVarTypeInference); diff --git a/paddle/fluid/operators/prim_ops/unity_build_rule.cmake b/paddle/fluid/operators/prim_ops/unity_build_rule.cmake new file mode 100644 index 0000000000000000000000000000000000000000..5d6a732272b9bf906a5931712cb10cd5dcb02471 --- /dev/null +++ b/paddle/fluid/operators/prim_ops/unity_build_rule.cmake @@ -0,0 +1,20 @@ +register_unity_group(cc + reshape_p_op.cc + broadcast_p_op.cc + reduce_p_op.cc + transpose_p_op.cc + split_p_op.cc + concat_p_op.cc + slice_select_p_op.cc + slice_assign_p_op.cc + gather_p_op.cc + scatter_add_p_op.cc + add_p_op.cc + sub_p_op.cc + mul_p_op.cc + div_p_op.cc + sqrt_p_op.cc + tanh_p_op.cc + matmul_p_op.cc + fill_constant_p_op.cc + )