diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index ebfed9a6f73f6e900b4bd74efc7d6e98073b34b5..059524b21c6d617bb752281f16e28e4b90391c61 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -28,9 +28,6 @@ limitations under the License. */
 #include "paddle/fluid/eager/api/utils/global_utils.h"
 #include "paddle/fluid/framework/attribute.h"
 #include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/custom_operator_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/phi_utils.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
@@ -843,289 +840,6 @@ class CustomOperator : public OperatorWithKernel {
   }
 };
 
-class CustomOpMaker : public OpProtoAndCheckerMaker {
- public:
-  explicit CustomOpMaker(const std::vector<std::string>& inputs,
-                         const std::vector<std::string>& outputs,
-                         const std::vector<std::string>& attrs)
-      : inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
-
-  void Make() override {
-    for (auto& in_name : inputs_) {
-      auto input_var_builder =
-          AddInput(in_name, "The input " + in_name + "of Custom operator.");
-      if (detail::IsDuplicableVar(in_name)) {
-        input_var_builder.AsDuplicable();
-      }
-      if (detail::IsOptionalVar(in_name)) {
-        input_var_builder.AsDispensable();
-      }
-    }
-    for (auto& out_name : outputs_) {
-      auto output_var_builder =
-          AddOutput(out_name, "The output " + out_name + "of Custom Operator.");
-      if (detail::IsDuplicableVar(out_name)) {
-        output_var_builder.AsDuplicable();
-      }
-      if (detail::IsOptionalVar(out_name)) {
-        output_var_builder.AsDispensable();
-      }
-    }
-    for (auto& attr : attrs_) {
-      auto attr_name_and_type = paddle::ParseAttrStr(attr);
-      auto attr_name = attr_name_and_type[0];
-      auto attr_type_str = attr_name_and_type[1];
-      if (attr_type_str == "bool") {
-        AddAttr<bool>(attr_name, "custom operator bool attribute.")
-            .SetDefault(false);
-      } else if (attr_type_str == "int") {
-        AddAttr<int>(attr_name, "custom operator int attribute.").SetDefault(1);
-      } else if (attr_type_str == "float") {
-        AddAttr<float>(attr_name, "custom operator float attribute.")
-            .SetDefault(1.0f);
-      } else if (attr_type_str == "int64_t") {
-        AddAttr<int64_t>(attr_name, "custom operator int64_t attribute.")
-            .SetDefault(1);
-      } else if (attr_type_str == "std::string") {
-        AddAttr<std::string>(attr_name, "custom operator int attribute.")
-            .SetDefault("");
-      } else if (attr_type_str == "std::vector<int>") {
-        AddAttr<std::vector<int>>(attr_name,
-                                  "custom operator std::vector<int> attribute.")
-            .SetDefault({});
-      } else if (attr_type_str == "std::vector<float>") {
-        AddAttr<std::vector<float>>(
-            attr_name, "custom operator std::vector<float> attribute.")
-            .SetDefault({});
-      } else if (attr_type_str == "std::vector<int64_t>") {
-        AddAttr<std::vector<int64_t>>(
-            attr_name, "custom operator std::vector<int64_t> attribute.")
-            .SetDefault({});
-      } else if (attr_type_str == "std::vector<std::string>") {
-        AddAttr<std::vector<std::string>>(
-            attr_name, "custom operator std::vector<std::string> attribute.")
-            .SetDefault({});
-      } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Unsupported `%s` type value as custom attribute now. "
-            "Supported data types include `bool`, `int`, `float`, "
-            "`int64_t`, `std::string`, `std::vector<int>`, "
-            "`std::vector<float>`, `std::vector<int64_t>`, "
-            "`std::vector<std::string>`, Please check whether "
-            "the attribute data type and data type string are matched.",
-            attr_type_str));
-      }
-    }
-    AddComment(R"DOC(
-Custom Operator.
-
-According to the phi::DenseTensor operation function implemented by the user
-independently of the framework, it is encapsulated into a framework
-operator to adapt to various execution scenarios such as dynamic graph
-mode, static graph mode, and inference mode.
-
-)DOC");
-  }
-
- private:
-  std::vector<std::string> inputs_;
-  std::vector<std::string> outputs_;
-  std::vector<std::string> attrs_;
-};
-
-template <typename T>
-class CustomGradOpMaker;
-
-template <>
-class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
- public:
-  explicit CustomGradOpMaker(
-      const OpDesc& fwd_op,
-      const std::unordered_set<std::string>& no_grad_set,
-      std::unordered_map<std::string, std::string>* grad_to_var,
-      const std::vector<BlockDesc*>& grad_block,
-      const std::string& name,
-      const std::vector<std::string>& inputs,
-      const std::vector<std::string>& outputs,
-      bool is_double_grad)
-      : SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block),
-        name_(name),
-        inputs_(inputs),
-        outputs_(outputs),
-        is_double_grad_(is_double_grad) {}
-
- protected:
-  void Apply(GradOpPtr<OpDesc> grad_op) const override {
-    grad_op->SetType(name_);
-
-    auto fwd_op_inputs = this->InputNames();
-    auto fwd_op_outputs = this->OutputNames();
-
-    for (auto& in_name : inputs_) {
-      VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
-      if (!detail::IsGradVar(in_name, is_double_grad_)) {
-        if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
-          grad_op->SetInput(in_name, this->Input(in_name));
-        } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
-          grad_op->SetInput(in_name, this->Output(in_name));
-        } else {
-          PADDLE_THROW(platform::errors::InvalidArgument(
-              "The input tensor name `%s` is invalid, expected it is the input "
-              "or output of forward operator.",
-              in_name));
-        }
-      } else {
-        if (this->HasOutput(detail::NoGrad(in_name))) {
-          grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name)));
-        } else {
-          // Maybe visit here! handle inplace optional case
-          PADDLE_ENFORCE(
-              in_name.find(paddle::kOptionalSuffix) != std::string::npos,
-              phi::errors::InvalidArgument(
-                  "Custom operator couldn't find grad operator input name for "
-                  "%s. If you are using inplace optional inputs & outputs, "
-                  "please check your InplaceMap and `Outputs` again and make "
-                  "sure %s is wrapped by `paddle::Optional`",
-                  in_name,
-                  in_name));
-          VLOG(3) << "Custom Operator: GradOpDescMaker - handle unfound input: "
-                  << in_name;
-        }
-      }
-    }
-    for (auto& out_name : outputs_) {
-      // Handle inplace optional case
-      if (!this->HasInput(detail::NoGrad(out_name, is_double_grad_))) {
-        PADDLE_ENFORCE(
-            out_name.find(paddle::kOptionalSuffix) != std::string::npos,
-            phi::errors::InvalidArgument(
-                "Custom operator couldn't find grad operator output name for "
-                "%s. If you are using inplace optional inputs & outputs, "
-                "please check your InplaceMap and `Outputs` again and make "
-                "sure %s is wrapped by `paddle::Optional`",
-                out_name,
-                out_name));
-        VLOG(3) << "Custom Operator: GradOpDescMaker - handle unfound output: "
-                << out_name;
-        continue;
-      }
-      VLOG(3) << "Custom Operator: GradOpDescMaker - output: " << out_name;
-      if (detail::IsDuplicableVar(out_name)) {
-        grad_op->SetOutput(
-            out_name,
-            this->InputGrad(detail::NoGrad(out_name, is_double_grad_),
-                            /*drop_empty_grad=*/false));
-      } else {
-        grad_op->SetOutput(
-            out_name,
-            this->InputGrad(detail::NoGrad(out_name, is_double_grad_)));
-      }
-    }
-    grad_op->SetAttrMap(this->Attrs());
-  }
-
- private:
-  std::string name_;
-  std::vector<std::string> inputs_;
-  std::vector<std::string> outputs_;
-  bool is_double_grad_{false};
-};
-
-template <>
-class CustomGradOpMaker<imperative::OpBase>
-    : public SingleGradOpMaker<imperative::OpBase> {
- public:
-  explicit CustomGradOpMaker(
-      const std::string& type,
-      const imperative::NameVarBaseMap& var_base_map_in,
-      const imperative::NameVarBaseMap& var_base_map_out,
-      const AttributeMap& attrs,
-      const std::map<std::string, std::string>& inplace_map,
-      const std::string& name,
-      const std::vector<std::string>& inputs,
-      const std::vector<std::string>& outputs,
-      bool is_double_grad)
-      : SingleGradOpMaker<imperative::OpBase>(
-            type, var_base_map_in, var_base_map_out, attrs, inplace_map),
-        name_(name),
-        inputs_(inputs),
-        outputs_(outputs),
-        is_double_grad_(is_double_grad) {}
-
- protected:
-  // TODO(chenweihang): The code is duplicated with the previous one, because
-  // here OpMaker's Input, Output and other methods are protected. Putting the
-  // function implementation outside the class will cause the method to be
-  // uncallable,
-  // so it is still implemented in the class for the time being.
-  void Apply(GradOpPtr<imperative::OpBase> grad_op) const override {
-    grad_op->SetType(name_);
-
-    auto fwd_op_inputs = this->InputNames();
-    auto fwd_op_outputs = this->OutputNames();
-
-    for (auto& in_name : inputs_) {
-      VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
-      if (!detail::IsGradVar(in_name, is_double_grad_)) {
-        if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
-          grad_op->SetInput(in_name, this->Input(in_name));
-        } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
-          grad_op->SetInput(in_name, this->Output(in_name));
-        } else {
-          PADDLE_THROW(platform::errors::InvalidArgument(
-              "The input tensor name `%s` is invalid, expected it is the input "
-              "or output of forward operator.",
-              in_name));
-        }
-      } else {
-        // Handle inplace optional case
-        if (this->HasOutput(detail::NoGrad(in_name))) {
-          grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name)));
-        } else {
-          PADDLE_ENFORCE(
-              in_name.find(paddle::kOptionalSuffix) != std::string::npos,
-              phi::errors::InvalidArgument(
-                  "Custom operator couldn't find grad operator input name for "
-                  "%s. If you are using inplace optional inputs & outputs, "
-                  "please check your InplaceMap and `Outputs` again and make "
-                  "sure %s is wrapped by `paddle::Optional`",
-                  in_name,
-                  in_name));
-          VLOG(3) << "Custom Operator: GradOpBaseMaker - handle unfound input: "
-                  << in_name;
-        }
-      }
-    }
-    for (auto& out_name : outputs_) {
-      // Handle inplace optional case
-      if (!this->HasInput(detail::NoGrad(out_name, is_double_grad_))) {
-        PADDLE_ENFORCE(
-            out_name.find(paddle::kOptionalSuffix) != std::string::npos,
-            phi::errors::InvalidArgument(
-                "Custom operator couldn't find grad operator output name for "
-                "%s. If you are using inplace optional inputs & outputs, "
-                "please check your InplaceMap and `Outputs` again and make "
-                "sure %s is wrapped by `paddle::Optional`",
-                out_name,
-                out_name));
-        VLOG(3) << "Custom Operator: GradOpBaseMaker - handle unfound output: "
-                << out_name;
-        continue;
-      }
-      VLOG(3) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
-      grad_op->SetOutput(
-          out_name, this->InputGrad(detail::NoGrad(out_name, is_double_grad_)));
-    }
-    grad_op->SetAttrMap(this->Attrs());
-  }
-
- private:
-  std::string name_;
-  std::vector<std::string> inputs_;
-  std::vector<std::string> outputs_;
-  bool is_double_grad_{false};
-};
-
 //////////// Operator and Kernel Register //////////////
 
 static void RegisterOperatorKernelWithPlace(
diff --git a/paddle/fluid/framework/custom_operator.h b/paddle/fluid/framework/custom_operator.h
index fef1e82a14fe6e03de40c8376f922f87f64564f8..35d133f9d0f31f5d9329624cc8bd9171ddc6ed94 100644
--- a/paddle/fluid/framework/custom_operator.h
+++ b/paddle/fluid/framework/custom_operator.h
@@ -16,10 +16,297 @@ limitations under the License. */
 
 #include <string>
 
+#include "paddle/fluid/framework/custom_operator_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/phi/api/ext/op_meta_info.h"
 
 namespace paddle {
 namespace framework {
+
+class CustomOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  explicit CustomOpMaker(const std::vector<std::string>& inputs,
+                         const std::vector<std::string>& outputs,
+                         const std::vector<std::string>& attrs)
+      : inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
+
+  void Make() override {
+    for (auto& in_name : inputs_) {
+      auto input_var_builder =
+          AddInput(in_name, "The input " + in_name + "of Custom operator.");
+      if (detail::IsDuplicableVar(in_name)) {
+        input_var_builder.AsDuplicable();
+      }
+      if (detail::IsOptionalVar(in_name)) {
+        input_var_builder.AsDispensable();
+      }
+    }
+    for (auto& out_name : outputs_) {
+      auto output_var_builder =
+          AddOutput(out_name, "The output " + out_name + "of Custom Operator.");
+      if (detail::IsDuplicableVar(out_name)) {
+        output_var_builder.AsDuplicable();
+      }
+      if (detail::IsOptionalVar(out_name)) {
+        output_var_builder.AsDispensable();
+      }
+    }
+    for (auto& attr : attrs_) {
+      auto attr_name_and_type = paddle::ParseAttrStr(attr);
+      auto attr_name = attr_name_and_type[0];
+      auto attr_type_str = attr_name_and_type[1];
+      if (attr_type_str == "bool") {
+        AddAttr<bool>(attr_name, "custom operator bool attribute.")
+            .SetDefault(false);
+      } else if (attr_type_str == "int") {
+        AddAttr<int>(attr_name, "custom operator int attribute.").SetDefault(1);
+      } else if (attr_type_str == "float") {
+        AddAttr<float>(attr_name, "custom operator float attribute.")
+            .SetDefault(1.0f);
+      } else if (attr_type_str == "int64_t") {
+        AddAttr<int64_t>(attr_name, "custom operator int64_t attribute.")
+            .SetDefault(1);
+      } else if (attr_type_str == "std::string") {
+        AddAttr<std::string>(attr_name, "custom operator int attribute.")
+            .SetDefault("");
+      } else if (attr_type_str == "std::vector<int>") {
+        AddAttr<std::vector<int>>(attr_name,
+                                  "custom operator std::vector<int> attribute.")
+            .SetDefault({});
+      } else if (attr_type_str == "std::vector<float>") {
+        AddAttr<std::vector<float>>(
+            attr_name, "custom operator std::vector<float> attribute.")
+            .SetDefault({});
+      } else if (attr_type_str == "std::vector<int64_t>") {
+        AddAttr<std::vector<int64_t>>(
+            attr_name, "custom operator std::vector<int64_t> attribute.")
+            .SetDefault({});
+      } else if (attr_type_str == "std::vector<std::string>") {
+        AddAttr<std::vector<std::string>>(
+            attr_name, "custom operator std::vector<std::string> attribute.")
+            .SetDefault({});
+      } else {
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "Unsupported `%s` type value as custom attribute now. "
+            "Supported data types include `bool`, `int`, `float`, "
+            "`int64_t`, `std::string`, `std::vector<int>`, "
+            "`std::vector<float>`, `std::vector<int64_t>`, "
+            "`std::vector<std::string>`, Please check whether "
+            "the attribute data type and data type string are matched.",
+            attr_type_str));
+      }
+    }
+    AddComment(R"DOC(
+Custom Operator.
+
+According to the phi::DenseTensor operation function implemented by the user
+independently of the framework, it is encapsulated into a framework
+operator to adapt to various execution scenarios such as dynamic graph
+mode, static graph mode, and inference mode.
+
+)DOC");
+  }
+
+ private:
+  std::vector<std::string> inputs_;
+  std::vector<std::string> outputs_;
+  std::vector<std::string> attrs_;
+};
+
+template <typename T>
+class CustomGradOpMaker;
+
+template <>
+class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
+ public:
+  explicit CustomGradOpMaker(
+      const OpDesc& fwd_op,
+      const std::unordered_set<std::string>& no_grad_set,
+      std::unordered_map<std::string, std::string>* grad_to_var,
+      const std::vector<BlockDesc*>& grad_block,
+      const std::string& name,
+      const std::vector<std::string>& inputs,
+      const std::vector<std::string>& outputs,
+      bool is_double_grad)
+      : SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block),
+        name_(name),
+        inputs_(inputs),
+        outputs_(outputs),
+        is_double_grad_(is_double_grad) {}
+
+ protected:
+  void Apply(GradOpPtr<OpDesc> grad_op) const override {
+    grad_op->SetType(name_);
+
+    auto fwd_op_inputs = this->InputNames();
+    auto fwd_op_outputs = this->OutputNames();
+
+    for (auto& in_name : inputs_) {
+      VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
+      if (!detail::IsGradVar(in_name, is_double_grad_)) {
+        if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
+          grad_op->SetInput(in_name, this->Input(in_name));
+        } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
+          grad_op->SetInput(in_name, this->Output(in_name));
+        } else {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "The input tensor name `%s` is invalid, expected it is the input "
+              "or output of forward operator.",
+              in_name));
+        }
+      } else {
+        if (this->HasOutput(detail::NoGrad(in_name))) {
+          grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name)));
+        } else {
+          // Maybe visit here! handle inplace optional case
+          PADDLE_ENFORCE(
+              in_name.find(paddle::kOptionalSuffix) != std::string::npos,
+              phi::errors::InvalidArgument(
+                  "Custom operator couldn't find grad operator input name for "
+                  "%s. If you are using inplace optional inputs & outputs, "
+                  "please check your InplaceMap and `Outputs` again and make "
+                  "sure %s is wrapped by `paddle::Optional`",
+                  in_name,
+                  in_name));
+          VLOG(3) << "Custom Operator: GradOpDescMaker - handle unfound input: "
+                  << in_name;
+        }
+      }
+    }
+    for (auto& out_name : outputs_) {
+      // Handle inplace optional case
+      if (!this->HasInput(detail::NoGrad(out_name, is_double_grad_))) {
+        PADDLE_ENFORCE(
+            out_name.find(paddle::kOptionalSuffix) != std::string::npos,
+            phi::errors::InvalidArgument(
+                "Custom operator couldn't find grad operator output name for "
+                "%s. If you are using inplace optional inputs & outputs, "
+                "please check your InplaceMap and `Outputs` again and make "
+                "sure %s is wrapped by `paddle::Optional`",
+                out_name,
+                out_name));
+        VLOG(3) << "Custom Operator: GradOpDescMaker - handle unfound output: "
+                << out_name;
+        continue;
+      }
+      VLOG(3) << "Custom Operator: GradOpDescMaker - output: " << out_name;
+      if (detail::IsDuplicableVar(out_name)) {
+        grad_op->SetOutput(
+            out_name,
+            this->InputGrad(detail::NoGrad(out_name, is_double_grad_),
+                            /*drop_empty_grad=*/false));
+      } else {
+        grad_op->SetOutput(
+            out_name,
+            this->InputGrad(detail::NoGrad(out_name, is_double_grad_)));
+      }
+    }
+    grad_op->SetAttrMap(this->Attrs());
+  }
+
+ private:
+  std::string name_;
+  std::vector<std::string> inputs_;
+  std::vector<std::string> outputs_;
+  bool is_double_grad_{false};
+};
+
+template <>
+class CustomGradOpMaker<imperative::OpBase>
+    : public SingleGradOpMaker<imperative::OpBase> {
+ public:
+  explicit CustomGradOpMaker(
+      const std::string& type,
+      const imperative::NameVarBaseMap& var_base_map_in,
+      const imperative::NameVarBaseMap& var_base_map_out,
+      const AttributeMap& attrs,
+      const std::map<std::string, std::string>& inplace_map,
+      const std::string& name,
+      const std::vector<std::string>& inputs,
+      const std::vector<std::string>& outputs,
+      bool is_double_grad)
+      : SingleGradOpMaker<imperative::OpBase>(
+            type, var_base_map_in, var_base_map_out, attrs, inplace_map),
+        name_(name),
+        inputs_(inputs),
+        outputs_(outputs),
+        is_double_grad_(is_double_grad) {}
+
+ protected:
+  // TODO(chenweihang): The code is duplicated with the previous one, because
+  // here OpMaker's Input, Output and other methods are protected. Putting the
+  // function implementation outside the class will cause the method to be
+  // uncallable,
+  // so it is still implemented in the class for the time being.
+  void Apply(GradOpPtr<imperative::OpBase> grad_op) const override {
+    grad_op->SetType(name_);
+
+    auto fwd_op_inputs = this->InputNames();
+    auto fwd_op_outputs = this->OutputNames();
+
+    for (auto& in_name : inputs_) {
+      VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
+      if (!detail::IsGradVar(in_name, is_double_grad_)) {
+        if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
+          grad_op->SetInput(in_name, this->Input(in_name));
+        } else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
+          grad_op->SetInput(in_name, this->Output(in_name));
+        } else {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "The input tensor name `%s` is invalid, expected it is the input "
+              "or output of forward operator.",
+              in_name));
+        }
+      } else {
+        // Handle inplace optional case
+        if (this->HasOutput(detail::NoGrad(in_name))) {
+          grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name)));
+        } else {
+          PADDLE_ENFORCE(
+              in_name.find(paddle::kOptionalSuffix) != std::string::npos,
+              phi::errors::InvalidArgument(
+                  "Custom operator couldn't find grad operator input name for "
+                  "%s. If you are using inplace optional inputs & outputs, "
+                  "please check your InplaceMap and `Outputs` again and make "
+                  "sure %s is wrapped by `paddle::Optional`",
+                  in_name,
+                  in_name));
+          VLOG(3) << "Custom Operator: GradOpBaseMaker - handle unfound input: "
+                  << in_name;
+        }
+      }
+    }
+    for (auto& out_name : outputs_) {
+      // Handle inplace optional case
+      if (!this->HasInput(detail::NoGrad(out_name, is_double_grad_))) {
+        PADDLE_ENFORCE(
+            out_name.find(paddle::kOptionalSuffix) != std::string::npos,
+            phi::errors::InvalidArgument(
+                "Custom operator couldn't find grad operator output name for "
+                "%s. If you are using inplace optional inputs & outputs, "
+                "please check your InplaceMap and `Outputs` again and make "
+                "sure %s is wrapped by `paddle::Optional`",
+                out_name,
+                out_name));
+        VLOG(3) << "Custom Operator: GradOpBaseMaker - handle unfound output: "
+                << out_name;
+        continue;
+      }
+      VLOG(3) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
+      grad_op->SetOutput(
+          out_name, this->InputGrad(detail::NoGrad(out_name, is_double_grad_)));
+    }
+    grad_op->SetAttrMap(this->Attrs());
+  }
+
+ private:
+  std::string name_;
+  std::vector<std::string> inputs_;
+  std::vector<std::string> outputs_;
+  bool is_double_grad_{false};
+};
+
 // Load custom op api: register op after user compiled
 const std::unordered_map<std::string, std::vector<OpMetaInfo>>&
 LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
diff --git a/paddle/phi/capi/include/c_infer_meta_context.h b/paddle/phi/capi/include/c_infer_meta_context.h
new file mode 100644
index 0000000000000000000000000000000000000000..5f7fbbca0b34c1a0d0ef33cb3d493d1223275a0c
--- /dev/null
+++ b/paddle/phi/capi/include/c_infer_meta_context.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if !defined(_WIN32)
+
+#include "paddle/phi/capi/include/c_data_type.h"
+#include "paddle/phi/capi/include/c_int_array.h"
+#include "paddle/phi/capi/include/c_meta_tensor.h"
+#include "paddle/phi/capi/include/c_place.h"
+#include "paddle/phi/capi/include/c_scalar.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct PD_InferMetaContext PD_InferMetaContext;
+
+PD_MetaTensor *PD_InferMetaContextInputAt(PD_InferMetaContext *ctx,
+                                          size_t index);
+
+PD_List PD_InferMetaContextMultiInputAt(PD_InferMetaContext *ctx, size_t index);
+
+PD_MetaTensor *PD_InferMetaContextOutputAt(PD_InferMetaContext *ctx,
+                                           size_t index);
+
+PD_List PD_InferMetaContextMultiOutputAt(PD_InferMetaContext *ctx,
+                                         size_t index);
+
+bool PD_InferMetaContextBoolAttrAt(PD_InferMetaContext *ctx, size_t index);
+
+int32_t PD_InferMetaContextInt32AttrAt(PD_InferMetaContext *ctx, size_t index);
+
+int64_t PD_InferMetaContextInt64AttrAt(PD_InferMetaContext *ctx, size_t index);
+
+float PD_InferMetaContextFloatAttrAt(PD_InferMetaContext *ctx, size_t index);
+
+double PD_InferMetaContextDoubleAttrAt(PD_InferMetaContext *ctx, size_t index);
+
+PD_Scalar *PD_InferMetaContextScalarAttrAt(PD_InferMetaContext *ctx,
+                                           size_t index);
+
+PD_IntArray *PD_InferMetaContextIntArrayAttrAt(PD_InferMetaContext *ctx,
+                                               size_t index);
+
+PD_DataType PD_InferMetaContextDataTypeAttrAt(PD_InferMetaContext *ctx,
+                                              size_t index);
+
+PD_DataLayout PD_InferMetaContextDataLayoutAttrAt(PD_InferMetaContext *ctx,
+                                                  size_t index);
+
+char *PD_InferMetaContextStringAttrAt(PD_InferMetaContext *ctx, size_t index);
+
+PD_List PD_InferMetaContextListBoolAttrAt(PD_InferMetaContext *ctx,
+                                          size_t index);
+
+PD_List PD_InferMetaContextListInt32AttrAt(PD_InferMetaContext *ctx,
+                                           size_t index);
+
+PD_List PD_InferMetaContextListInt64AttrAt(PD_InferMetaContext *ctx,
+                                           size_t index);
+
+PD_List PD_InferMetaContextListFloatAttrAt(PD_InferMetaContext *ctx,
+                                           size_t index);
+
+PD_List PD_InferMetaContextListDoubleAttrAt(PD_InferMetaContext *ctx,
+                                            size_t index);
+
+PD_List PD_InferMetaContextListStringAttrAt(PD_InferMetaContext *ctx,
+                                            size_t index);
+
+PD_List PD_InferMetaContextListScalarAttrAt(PD_InferMetaContext *ctx,
+                                            size_t index);
+
+PD_Place *PD_InferMetaContextPlaceAttrAt(PD_InferMetaContext *ctx,
+                                         size_t index);
+
+PD_DataType PD_InferMetaContextDataTypeAttrAt(PD_InferMetaContext *ctx,
+                                              size_t index);
+
+PD_DataLayout PD_InferMetaContextDataLayoutAttrAt(PD_InferMetaContext *ctx,
+                                                  size_t index);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+#endif
diff --git a/paddle/phi/capi/include/c_meta_tensor.h b/paddle/phi/capi/include/c_meta_tensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..08f01084c6abf3d1bc108bb7a72dbd9db22203be
--- /dev/null
+++ b/paddle/phi/capi/include/c_meta_tensor.h
@@ -0,0 +1,60 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#if !defined(_WIN32)
+
+#include "paddle/phi/capi/include/c_data_type.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct PD_MetaTensor PD_MetaTensor;
+
+PD_DataType PD_MetaTensorGetPDDataType(const PD_MetaTensor *tensor,
+                                       PD_Status *status);
+
+PD_DataLayout PD_MetaTensorGetDataLayout(const PD_MetaTensor *tensor,
+                                         PD_Status *status);
+
+int64_t PD_MetaTensorGetElementCount(const PD_MetaTensor *tensor,
+                                     PD_Status *status);
+
+int64_t PD_MetaTensorGetNumDims(const PD_MetaTensor *tensor, PD_Status *status);
+
+int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor,
+                            size_t index,
+                            PD_Status *status);
+
+bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status);
+
+void PD_MetaTensorSetDims(PD_MetaTensor *tensor,
+                          int64_t ndims,
+                          const int64_t *dims,
+                          PD_Status *status);
+
+void PD_MetaTensorSetDataType(PD_MetaTensor *tensor,
+                              PD_DataType dtype,
+                              PD_Status *status);
+
+void PD_MetaTensorSetDataLayout(PD_MetaTensor *tensor,
+                                PD_DataLayout layout,
+                                PD_Status *status);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+#endif
diff --git a/paddle/phi/capi/include/wrapper_base.h b/paddle/phi/capi/include/wrapper_base.h
index 36653a86f2a28527ee1798ddd093774c04f96026..1379bf108880f96a21dcb77e3a7ac22d11ef2b76 100644
--- a/paddle/phi/capi/include/wrapper_base.h
+++ b/paddle/phi/capi/include/wrapper_base.h
@@ -25,10 +25,12 @@
 
 #include "paddle/phi/api/ext/exception.h"
 #include "paddle/phi/capi/include/c_device_context.h"
+#include "paddle/phi/capi/include/c_infer_meta_context.h"
 #include "paddle/phi/capi/include/c_int_array.h"
 #include "paddle/phi/capi/include/c_kernel_context.h"
 #include "paddle/phi/capi/include/c_kernel_factory.h"
 #include "paddle/phi/capi/include/c_kernel_registry.h"
+#include "paddle/phi/capi/include/c_meta_tensor.h"
 #include "paddle/phi/capi/include/c_place.h"
 #include "paddle/phi/capi/include/c_scalar.h"
 #include "paddle/phi/capi/include/c_tensor.h"
"paddle/phi/capi/include/c_tensor.h" @@ -70,6 +72,19 @@ inline std::vector PD_TensorGetDims(PD_Tensor* tensor, return std::vector(); } +inline std::vector PD_MetaTensorGetDims(PD_MetaTensor* tensor, + PD_Status* status) { + int64_t ndims = PD_MetaTensorGetNumDims(tensor, status); + if (ndims > 0) { + std::vector shape(ndims); + for (int64_t i = 0; i < ndims; ++i) { + shape[i] = PD_MetaTensorGetDim(tensor, i, status); + } + return shape; + } + return std::vector(); +} + template class WrapperBase { public: @@ -212,26 +227,6 @@ class DenseTensor : public WrapperBase { return static_cast(ptr); } - // template - // T* mutable_data(int64_t size = 0, const PD_DeviceContext* ctx = nullptr) { - // C_Status status; - // auto ptr = PD_DeviceContextAllocateTensor( - // ctx, raw_data(), size, phi::capi::CppTypeToPDType::Type(), - // &status); - // PD_CHECK_STATUS(status); - // return static_cast(ptr); - // } - - // void* mutable_data(PD_DataType data_type, - // int64_t size = 0, - // const PD_DeviceContext* ctx = nullptr) { - // C_Status status; - // auto ptr = PD_DeviceContextAllocateTensor( - // ctx, raw_data(), size, data_type, &status); - // PD_CHECK_STATUS(status); - // return static_cast(ptr); - // } - DenseTensor& ShareDataWith(const DenseTensor& src) { C_Status status; PD_TensorShareDataWith(raw_data(), src.raw_data(), &status); @@ -448,10 +443,6 @@ class KernelArgsDef : WrapperBase { PD_DeletePointerList(list); return ret; } - - // std::vector - // attribute_defs() { - // } }; class KernelKey : WrapperBase { @@ -459,7 +450,6 @@ class KernelKey : WrapperBase { explicit KernelKey(PD_KernelKey* kernel_key) : WrapperBase(kernel_key) {} - // Backend backend() const { return backend_; } PD_DataLayout layout() const { PD_Status status; auto layout_ = PD_KernelKeyGetLayout(raw_data(), &status); @@ -491,6 +481,58 @@ class Kernel : WrapperBase { TensorArgDef OutputAt(size_t idx) { return args_def().input_defs()[idx]; } }; +class MetaTensor : WrapperBase { + public: + explicit MetaTensor(PD_MetaTensor* meta_tensor) + : WrapperBase(meta_tensor) {} + + std::vector dims() const { + C_Status status; + auto dimension = PD_MetaTensorGetDims(raw_data(), &status); + PD_CHECK_STATUS(status); + return dimension; + } + + PD_DataType dtype() const { + C_Status status; + auto data_type = PD_MetaTensorGetPDDataType(raw_data(), &status); + PD_CHECK_STATUS(status); + return data_type; + } + + PD_DataLayout layout() const { + C_Status status; + auto data_layout = PD_MetaTensorGetDataLayout(raw_data(), &status); + PD_CHECK_STATUS(status); + return data_layout; + } + + int64_t numel() const { + C_Status status; + auto element_count = PD_MetaTensorGetElementCount(raw_data(), &status); + PD_CHECK_STATUS(status); + return element_count; + } + + void set_dims(const std::vector& dims) { + C_Status status; + PD_MetaTensorSetDims(raw_data(), dims.size(), dims.data(), &status); + PD_CHECK_STATUS(status); + } + + void set_dtype(PD_DataType data_type) { + C_Status status; + PD_MetaTensorSetDataType(raw_data(), data_type, &status); + PD_CHECK_STATUS(status); + } + + void set_layout(PD_DataLayout data_layout) { + C_Status status; + PD_MetaTensorSetDataLayout(raw_data(), data_layout, &status); + PD_CHECK_STATUS(status); + } +}; + } // namespace capi } // namespace phi