Unverified · Commit 3f17596a authored by ronnywang, committed by GitHub

[PHI CAPI] Add support for registering a new operator, PART1 (#55532)

Parent 36bc5511
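For context, the user-facing declaration that this registration machinery consumes looks like the following; a minimal sketch assuming Paddle's existing custom-operator API from paddle/extension.h, with the op name custom_relu and kernel ReluForward as illustrative placeholders:

```cpp
#include <vector>
#include "paddle/extension.h"

// Hypothetical user kernel, implemented and compiled outside the framework.
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x);

// The Inputs/Outputs name lists declared here are exactly what the
// CustomOpMaker / CustomGradOpMaker classes in this patch receive.
PD_BUILD_OP(custom_relu)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForward));
```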
......@@ -28,9 +28,6 @@ limitations under the License. */
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/custom_operator_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
......@@ -843,289 +840,6 @@ class CustomOperator : public OperatorWithKernel {
}
};
class CustomOpMaker : public OpProtoAndCheckerMaker {
public:
explicit CustomOpMaker(const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs)
: inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
void Make() override {
for (auto& in_name : inputs_) {
auto input_var_builder =
AddInput(in_name, "The input " + in_name + " of Custom operator.");
if (detail::IsDuplicableVar(in_name)) {
input_var_builder.AsDuplicable();
}
if (detail::IsOptionalVar(in_name)) {
input_var_builder.AsDispensable();
}
}
for (auto& out_name : outputs_) {
auto output_var_builder =
AddOutput(out_name, "The output " + out_name + " of Custom operator.");
if (detail::IsDuplicableVar(out_name)) {
output_var_builder.AsDuplicable();
}
if (detail::IsOptionalVar(out_name)) {
output_var_builder.AsDispensable();
}
}
for (auto& attr : attrs_) {
auto attr_name_and_type = paddle::ParseAttrStr(attr);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
AddAttr<bool>(attr_name, "custom operator bool attribute.")
.SetDefault(false);
} else if (attr_type_str == "int") {
AddAttr<int>(attr_name, "custom operator int attribute.").SetDefault(1);
} else if (attr_type_str == "float") {
AddAttr<float>(attr_name, "custom operator float attribute.")
.SetDefault(1.0f);
} else if (attr_type_str == "int64_t") {
AddAttr<int64_t>(attr_name, "custom operator int64_t attribute.")
.SetDefault(1);
} else if (attr_type_str == "std::string") {
AddAttr<std::string>(attr_name, "custom operator std::string attribute.")
.SetDefault("");
} else if (attr_type_str == "std::vector<int>") {
AddAttr<std::vector<int>>(attr_name,
"custom operator std::vector<int> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<float>") {
AddAttr<std::vector<float>>(
attr_name, "custom operator std::vector<float> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<int64_t>") {
AddAttr<std::vector<int64_t>>(
attr_name, "custom operator std::vector<int64_t> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<std::string>") {
AddAttr<std::vector<std::string>>(
attr_name, "custom operator std::vector<std::string> attribute.")
.SetDefault({});
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>`, "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.",
attr_type_str));
}
}
AddComment(R"DOC(
Custom Operator.
Wraps a phi::DenseTensor computation function, implemented by the user
independently of the framework, into a framework operator that adapts to
execution scenarios such as dynamic graph mode, static graph mode, and
inference mode.
)DOC");
}
private:
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
std::vector<std::string> attrs_;
};
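The attrs_ strings parsed in Make() follow Paddle's "name: type" convention, which paddle::ParseAttrStr splits into a name/type pair; a small sketch with illustrative attribute names:

```cpp
#include <string>
#include <vector>

// Each entry pairs an attribute name with one of the supported type strings.
std::vector<std::string> attrs = {"scale: float",
                                  "axis: int",
                                  "repeats: std::vector<int64_t>"};

// paddle::ParseAttrStr("scale: float") yields {"scale", "float"}; Make()
// dispatches on the type string to choose the matching AddAttr<T> call.
```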
template <typename T>
class CustomGradOpMaker;
template <>
class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
public:
explicit CustomGradOpMaker(
const OpDesc& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDesc*>& grad_block,
const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
bool is_double_grad)
: SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block),
name_(name),
inputs_(inputs),
outputs_(outputs),
is_double_grad_(is_double_grad) {}
protected:
void Apply(GradOpPtr<OpDesc> grad_op) const override {
grad_op->SetType(name_);
auto fwd_op_inputs = this->InputNames();
auto fwd_op_outputs = this->OutputNames();
for (auto& in_name : inputs_) {
VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
if (!detail::IsGradVar(in_name, is_double_grad_)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name));
} else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
grad_op->SetInput(in_name, this->Output(in_name));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor name `%s` is invalid, expected it is the input "
"or output of forward operator.",
in_name));
}
} else {
if (this->HasOutput(detail::NoGrad(in_name))) {
grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name)));
} else {
// May reach here: handle the inplace optional case.
PADDLE_ENFORCE(
in_name.find(paddle::kOptionalSuffix) != std::string::npos,
phi::errors::InvalidArgument(
"Custom operator couldn't find grad operator input name for "
"%s. If you are using inplace optional inputs & outputs, "
"please check your InplaceMap and `Outputs` again and make "
"sure %s is wrapped by `paddle::Optional`",
in_name,
in_name));
VLOG(3) << "Custom Operator: GradOpDescMaker - handle unfound input: "
<< in_name;
}
}
}
for (auto& out_name : outputs_) {
// Handle inplace optional case
if (!this->HasInput(detail::NoGrad(out_name, is_double_grad_))) {
PADDLE_ENFORCE(
out_name.find(paddle::kOptionalSuffix) != std::string::npos,
phi::errors::InvalidArgument(
"Custom operator couldn't find grad operator output name for "
"%s. If you are using inplace optional inputs & outputs, "
"please check your InplaceMap and `Outputs` again and make "
"sure %s is wrapped by `paddle::Optional`",
out_name,
out_name));
VLOG(3) << "Custom Operator: GradOpDescMaker - handle unfound output: "
<< out_name;
continue;
}
VLOG(3) << "Custom Operator: GradOpDescMaker - output: " << out_name;
if (detail::IsDuplicableVar(out_name)) {
grad_op->SetOutput(
out_name,
this->InputGrad(detail::NoGrad(out_name, is_double_grad_),
/*drop_empty_grad=*/false));
} else {
grad_op->SetOutput(
out_name,
this->InputGrad(detail::NoGrad(out_name, is_double_grad_)));
}
}
grad_op->SetAttrMap(this->Attrs());
}
private:
std::string name_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
bool is_double_grad_{false};
};
template <>
class CustomGradOpMaker<imperative::OpBase>
: public SingleGradOpMaker<imperative::OpBase> {
public:
explicit CustomGradOpMaker(
const std::string& type,
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const AttributeMap& attrs,
const std::map<std::string, std::string>& inplace_map,
const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
bool is_double_grad)
: SingleGradOpMaker<imperative::OpBase>(
type, var_base_map_in, var_base_map_out, attrs, inplace_map),
name_(name),
inputs_(inputs),
outputs_(outputs),
is_double_grad_(is_double_grad) {}
protected:
// TODO(chenweihang): This code duplicates the OpDesc specialization above
// because OpMaker's Input, Output and other methods are protected; moving the
// implementation outside the class would make those methods uncallable, so it
// stays inside the class for now.
void Apply(GradOpPtr<imperative::OpBase> grad_op) const override {
grad_op->SetType(name_);
auto fwd_op_inputs = this->InputNames();
auto fwd_op_outputs = this->OutputNames();
for (auto& in_name : inputs_) {
VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
if (!detail::IsGradVar(in_name, is_double_grad_)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name));
} else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
grad_op->SetInput(in_name, this->Output(in_name));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor name `%s` is invalid, expected it is the input "
"or output of forward operator.",
in_name));
}
} else {
// Handle inplace optional case
if (this->HasOutput(detail::NoGrad(in_name))) {
grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name)));
} else {
PADDLE_ENFORCE(
in_name.find(paddle::kOptionalSuffix) != std::string::npos,
phi::errors::InvalidArgument(
"Custom operator couldn't find grad operator input name for "
"%s. If you are using inplace optional inputs & outputs, "
"please check your InplaceMap and `Outputs` again and make "
"sure %s is wrapped by `paddle::Optional`",
in_name,
in_name));
VLOG(3) << "Custom Operator: GradOpBaseMaker - handle unfound input: "
<< in_name;
}
}
}
for (auto& out_name : outputs_) {
// Handle inplace optional case
if (!this->HasInput(detail::NoGrad(out_name, is_double_grad_))) {
PADDLE_ENFORCE(
out_name.find(paddle::kOptionalSuffix) != std::string::npos,
phi::errors::InvalidArgument(
"Custom operator couldn't find grad operator output name for "
"%s. If you are using inplace optional inputs & outputs, "
"please check your InplaceMap and `Outputs` again and make "
"sure %s is wrapped by `paddle::Optional`",
out_name,
out_name));
VLOG(3) << "Custom Operator: GradOpBaseMaker - handle unfound output: "
<< out_name;
continue;
}
VLOG(3) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
grad_op->SetOutput(
out_name, this->InputGrad(detail::NoGrad(out_name, is_double_grad_)));
}
grad_op->SetAttrMap(this->Attrs());
}
private:
std::string name_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
bool is_double_grad_{false};
};
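Both specializations rely on Paddle's grad-variable naming convention (the "@GRAD" suffix) that detail::IsGradVar and detail::NoGrad operate on; a hedged sketch of the backward declaration that pairs with the custom_relu example above, again with illustrative names:

```cpp
// Hypothetical backward kernel matching the forward declaration above.
std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
                                         const paddle::Tensor& out_grad);

// paddle::Grad("Out") expands to the grad-suffixed name ("Out@GRAD") that
// detail::IsGradVar detects and detail::NoGrad maps back to "Out".
PD_BUILD_GRAD_OP(custom_relu)
    .Inputs({"X", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackward));
```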
//////////// Operator and Kernel Register //////////////
static void RegisterOperatorKernelWithPlace(
......
......@@ -16,10 +16,297 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/custom_operator_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/api/ext/op_meta_info.h"
namespace paddle {
namespace framework {
class CustomOpMaker : public OpProtoAndCheckerMaker {
public:
explicit CustomOpMaker(const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs)
: inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
void Make() override {
for (auto& in_name : inputs_) {
auto input_var_builder =
AddInput(in_name, "The input " + in_name + " of Custom operator.");
if (detail::IsDuplicableVar(in_name)) {
input_var_builder.AsDuplicable();
}
if (detail::IsOptionalVar(in_name)) {
input_var_builder.AsDispensable();
}
}
for (auto& out_name : outputs_) {
auto output_var_builder =
AddOutput(out_name, "The output " + out_name + " of Custom operator.");
if (detail::IsDuplicableVar(out_name)) {
output_var_builder.AsDuplicable();
}
if (detail::IsOptionalVar(out_name)) {
output_var_builder.AsDispensable();
}
}
for (auto& attr : attrs_) {
auto attr_name_and_type = paddle::ParseAttrStr(attr);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
AddAttr<bool>(attr_name, "custom operator bool attribute.")
.SetDefault(false);
} else if (attr_type_str == "int") {
AddAttr<int>(attr_name, "custom operator int attribute.").SetDefault(1);
} else if (attr_type_str == "float") {
AddAttr<float>(attr_name, "custom operator float attribute.")
.SetDefault(1.0f);
} else if (attr_type_str == "int64_t") {
AddAttr<int64_t>(attr_name, "custom operator int64_t attribute.")
.SetDefault(1);
} else if (attr_type_str == "std::string") {
AddAttr<std::string>(attr_name, "custom operator std::string attribute.")
.SetDefault("");
} else if (attr_type_str == "std::vector<int>") {
AddAttr<std::vector<int>>(attr_name,
"custom operator std::vector<int> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<float>") {
AddAttr<std::vector<float>>(
attr_name, "custom operator std::vector<float> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<int64_t>") {
AddAttr<std::vector<int64_t>>(
attr_name, "custom operator std::vector<int64_t> attribute.")
.SetDefault({});
} else if (attr_type_str == "std::vector<std::string>") {
AddAttr<std::vector<std::string>>(
attr_name, "custom operator std::vector<std::string> attribute.")
.SetDefault({});
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>`, "
"`std::vector<std::string>`, Please check whether "
"the attribute data type and data type string are matched.",
attr_type_str));
}
}
AddComment(R"DOC(
Custom Operator.
Wraps a phi::DenseTensor computation function, implemented by the user
independently of the framework, into a framework operator that adapts to
execution scenarios such as dynamic graph mode, static graph mode, and
inference mode.
)DOC");
}
private:
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
std::vector<std::string> attrs_;
};
template <typename T>
class CustomGradOpMaker;
template <>
class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
public:
explicit CustomGradOpMaker(
const OpDesc& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDesc*>& grad_block,
const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
bool is_double_grad)
: SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block),
name_(name),
inputs_(inputs),
outputs_(outputs),
is_double_grad_(is_double_grad) {}
protected:
void Apply(GradOpPtr<OpDesc> grad_op) const override {
grad_op->SetType(name_);
auto fwd_op_inputs = this->InputNames();
auto fwd_op_outputs = this->OutputNames();
for (auto& in_name : inputs_) {
VLOG(3) << "Custom Operator: GradOpDescMaker - input: " << in_name;
if (!detail::IsGradVar(in_name, is_double_grad_)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name));
} else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
grad_op->SetInput(in_name, this->Output(in_name));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor name `%s` is invalid, expected it is the input "
"or output of forward operator.",
in_name));
}
} else {
if (this->HasOutput(detail::NoGrad(in_name))) {
grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name)));
} else {
// May reach here: handle the inplace optional case.
PADDLE_ENFORCE(
in_name.find(paddle::kOptionalSuffix) != std::string::npos,
phi::errors::InvalidArgument(
"Custom operator couldn't find grad operator input name for "
"%s. If you are using inplace optional inputs & outputs, "
"please check your InplaceMap and `Outputs` again and make "
"sure %s is wrapped by `paddle::Optional`",
in_name,
in_name));
VLOG(3) << "Custom Operator: GradOpDescMaker - handle unfound input: "
<< in_name;
}
}
}
for (auto& out_name : outputs_) {
// Handle inplace optional case
if (!this->HasInput(detail::NoGrad(out_name, is_double_grad_))) {
PADDLE_ENFORCE(
out_name.find(paddle::kOptionalSuffix) != std::string::npos,
phi::errors::InvalidArgument(
"Custom operator couldn't find grad operator output name for "
"%s. If you are using inplace optional inputs & outputs, "
"please check your InplaceMap and `Outputs` again and make "
"sure %s is wrapped by `paddle::Optional`",
out_name,
out_name));
VLOG(3) << "Custom Operator: GradOpDescMaker - handle unfound output: "
<< out_name;
continue;
}
VLOG(3) << "Custom Operator: GradOpDescMaker - output: " << out_name;
if (detail::IsDuplicableVar(out_name)) {
grad_op->SetOutput(
out_name,
this->InputGrad(detail::NoGrad(out_name, is_double_grad_),
/*drop_empty_grad=*/false));
} else {
grad_op->SetOutput(
out_name,
this->InputGrad(detail::NoGrad(out_name, is_double_grad_)));
}
}
grad_op->SetAttrMap(this->Attrs());
}
private:
std::string name_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
bool is_double_grad_{false};
};
template <>
class CustomGradOpMaker<imperative::OpBase>
: public SingleGradOpMaker<imperative::OpBase> {
public:
explicit CustomGradOpMaker(
const std::string& type,
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const AttributeMap& attrs,
const std::map<std::string, std::string>& inplace_map,
const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
bool is_double_grad)
: SingleGradOpMaker<imperative::OpBase>(
type, var_base_map_in, var_base_map_out, attrs, inplace_map),
name_(name),
inputs_(inputs),
outputs_(outputs),
is_double_grad_(is_double_grad) {}
protected:
// TODO(chenweihang): This code duplicates the OpDesc specialization above
// because OpMaker's Input, Output and other methods are protected; moving the
// implementation outside the class would make those methods uncallable, so it
// stays inside the class for now.
void Apply(GradOpPtr<imperative::OpBase> grad_op) const override {
grad_op->SetType(name_);
auto fwd_op_inputs = this->InputNames();
auto fwd_op_outputs = this->OutputNames();
for (auto& in_name : inputs_) {
VLOG(3) << "Custom Operator: GradOpBaseMaker - input: " << in_name;
if (!detail::IsGradVar(in_name, is_double_grad_)) {
if (detail::IsMemberOf(fwd_op_inputs, in_name)) {
grad_op->SetInput(in_name, this->Input(in_name));
} else if (detail::IsMemberOf(fwd_op_outputs, in_name)) {
grad_op->SetInput(in_name, this->Output(in_name));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor name `%s` is invalid, expected it is the input "
"or output of forward operator.",
in_name));
}
} else {
// Handle inplace optional case
if (this->HasOutput(detail::NoGrad(in_name))) {
grad_op->SetInput(in_name, this->OutputGrad(detail::NoGrad(in_name)));
} else {
PADDLE_ENFORCE(
in_name.find(paddle::kOptionalSuffix) != std::string::npos,
phi::errors::InvalidArgument(
"Custom operator couldn't find grad operator input name for "
"%s. If you are using inplace optional inputs & outputs, "
"please check your InplaceMap and `Outputs` again and make "
"sure %s is wrapped by `paddle::Optional`",
in_name,
in_name));
VLOG(3) << "Custom Operator: GradOpBaseMaker - handle unfound input: "
<< in_name;
}
}
}
for (auto& out_name : outputs_) {
// Handle inplace optional case
if (!this->HasInput(detail::NoGrad(out_name, is_double_grad_))) {
PADDLE_ENFORCE(
out_name.find(paddle::kOptionalSuffix) != std::string::npos,
phi::errors::InvalidArgument(
"Custom operator couldn't find grad operator output name for "
"%s. If you are using inplace optional inputs & outputs, "
"please check your InplaceMap and `Outputs` again and make "
"sure %s is wrapped by `paddle::Optional`",
out_name,
out_name));
VLOG(3) << "Custom Operator: GradOpBaseMaker - handle unfound output: "
<< out_name;
continue;
}
VLOG(3) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
grad_op->SetOutput(
out_name, this->InputGrad(detail::NoGrad(out_name, is_double_grad_)));
}
grad_op->SetAttrMap(this->Attrs());
}
private:
std::string name_;
std::vector<std::string> inputs_;
std::vector<std::string> outputs_;
bool is_double_grad_{false};
};
// Load custom op api: register op after user compiled
const std::unordered_map<std::string, std::vector<OpMetaInfo>>&
LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
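A hedged usage sketch of this entry point; the library name is illustrative:

```cpp
// After a user builds their custom op into a shared library, the framework
// loads the DSO and registers every operator it declares. The returned map
// is keyed by operator name.
const auto& meta_info_map =
    paddle::framework::LoadOpMetaInfoAndRegisterOp("libcustom_relu_op.so");
for (const auto& pair : meta_info_map) {
  VLOG(3) << "Registered custom operator: " << pair.first;
}
```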
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if !defined(_WIN32)
#include "paddle/phi/capi/include/c_data_type.h"
#include "paddle/phi/capi/include/c_int_array.h"
#include "paddle/phi/capi/include/c_meta_tensor.h"
#include "paddle/phi/capi/include/c_place.h"
#include "paddle/phi/capi/include/c_scalar.h"
#ifdef __cplusplus
extern "C" {
#endif
typedef struct PD_InferMetaContext PD_InferMetaContext;
PD_MetaTensor *PD_InferMetaContextInputAt(PD_InferMetaContext *ctx,
size_t index);
PD_List PD_InferMetaContextMultiInputAt(PD_InferMetaContext *ctx, size_t index);
PD_MetaTensor *PD_InferMetaContextOutputAt(PD_InferMetaContext *ctx,
size_t index);
PD_List PD_InferMetaContextMultiOutputAt(PD_InferMetaContext *ctx,
size_t index);
bool PD_InferMetaContextBoolAttrAt(PD_InferMetaContext *ctx, size_t index);
int32_t PD_InferMetaContextInt32AttrAt(PD_InferMetaContext *ctx, size_t index);
int64_t PD_InferMetaContextInt64AttrAt(PD_InferMetaContext *ctx, size_t index);
float PD_InferMetaContextFloatAttrAt(PD_InferMetaContext *ctx, size_t index);
double PD_InferMetaContextDoubleAttrAt(PD_InferMetaContext *ctx, size_t index);
PD_Scalar *PD_InferMetaContextScalarAttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_IntArray *PD_InferMetaContextIntArrayAttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_DataType PD_InferMetaContextDataTypeAttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_DataLayout PD_InferMetaContextDataLayoutAttrAt(PD_InferMetaContext *ctx,
size_t index);
char *PD_InferMetaContextStringAttrAt(PD_InferMetaContext *ctx, size_t index);
PD_List PD_InferMetaContextListBoolAttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_List PD_InferMetaContextListInt32AttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_List PD_InferMetaContextListInt64AttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_List PD_InferMetaContextListFloatAttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_List PD_InferMetaContextListDoubleAttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_List PD_InferMetaContextListStringAttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_List PD_InferMetaContextListScalarAttrAt(PD_InferMetaContext *ctx,
size_t index);
PD_Place *PD_InferMetaContextPlaceAttrAt(PD_InferMetaContext *ctx,
size_t index);
#ifdef __cplusplus
} // extern "C"
#endif
#endif
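Together with the PD_MetaTensor accessors declared in the next header, these hooks let a plugin author implement shape/dtype inference against the C ABI; a minimal sketch, assuming a hypothetical unary operator whose output mirrors its input (MyReluInferMeta is illustrative, and error handling via status is elided):

```cpp
#include <vector>

#include "paddle/phi/capi/include/c_infer_meta_context.h"

// Minimal sketch: propagate the input's shape and dtype to the output.
void MyReluInferMeta(PD_InferMetaContext* ctx) {
  PD_Status status;
  PD_MetaTensor* x = PD_InferMetaContextInputAt(ctx, 0);
  PD_MetaTensor* out = PD_InferMetaContextOutputAt(ctx, 0);

  // Copy the input dims to the output.
  int64_t ndims = PD_MetaTensorGetNumDims(x, &status);
  std::vector<int64_t> dims(ndims);
  for (int64_t i = 0; i < ndims; ++i) {
    dims[i] = PD_MetaTensorGetDim(x, static_cast<size_t>(i), &status);
  }
  PD_MetaTensorSetDims(out, ndims, dims.data(), &status);
  PD_MetaTensorSetDataType(out, PD_MetaTensorGetPDDataType(x, &status),
                           &status);
}
```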
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if !defined(_WIN32)
#include "paddle/phi/capi/include/c_data_type.h"
#ifdef __cplusplus
extern "C" {
#endif
typedef struct PD_MetaTensor PD_MetaTensor;
PD_DataType PD_MetaTensorGetPDDataType(const PD_MetaTensor *tensor,
PD_Status *status);
PD_DataLayout PD_MetaTensorGetDataLayout(const PD_MetaTensor *tensor,
PD_Status *status);
int64_t PD_MetaTensorGetElementCount(const PD_MetaTensor *tensor,
PD_Status *status);
int64_t PD_MetaTensorGetNumDims(const PD_MetaTensor *tensor, PD_Status *status);
int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor,
size_t index,
PD_Status *status);
bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status);
void PD_MetaTensorSetDims(PD_MetaTensor *tensor,
int64_t ndims,
const int64_t *dims,
PD_Status *status);
void PD_MetaTensorSetDataType(PD_MetaTensor *tensor,
PD_DataType dtype,
PD_Status *status);
void PD_MetaTensorSetDataLayout(PD_MetaTensor *tensor,
PD_DataLayout layout,
PD_Status *status);
#ifdef __cplusplus
} // extern "C"
#endif
#endif
......@@ -25,10 +25,12 @@
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/capi/include/c_device_context.h"
#include "paddle/phi/capi/include/c_infer_meta_context.h"
#include "paddle/phi/capi/include/c_int_array.h"
#include "paddle/phi/capi/include/c_kernel_context.h"
#include "paddle/phi/capi/include/c_kernel_factory.h"
#include "paddle/phi/capi/include/c_kernel_registry.h"
#include "paddle/phi/capi/include/c_meta_tensor.h"
#include "paddle/phi/capi/include/c_place.h"
#include "paddle/phi/capi/include/c_scalar.h"
#include "paddle/phi/capi/include/c_tensor.h"
......@@ -70,6 +72,19 @@ inline std::vector<int64_t> PD_TensorGetDims(PD_Tensor* tensor,
return std::vector<int64_t>();
}
inline std::vector<int64_t> PD_MetaTensorGetDims(PD_MetaTensor* tensor,
PD_Status* status) {
int64_t ndims = PD_MetaTensorGetNumDims(tensor, status);
if (ndims > 0) {
std::vector<int64_t> shape(ndims);
for (int64_t i = 0; i < ndims; ++i) {
shape[i] = PD_MetaTensorGetDim(tensor, i, status);
}
return shape;
}
return std::vector<int64_t>();
}
template <typename T>
class WrapperBase {
public:
......@@ -212,26 +227,6 @@ class DenseTensor : public WrapperBase<PD_Tensor> {
return static_cast<T*>(ptr);
}
// template <typename T>
// T* mutable_data(int64_t size = 0, const PD_DeviceContext* ctx = nullptr) {
// C_Status status;
// auto ptr = PD_DeviceContextAllocateTensor(
// ctx, raw_data(), size, phi::capi::CppTypeToPDType<T>::Type(),
// &status);
// PD_CHECK_STATUS(status);
// return static_cast<T*>(ptr);
// }
// void* mutable_data(PD_DataType data_type,
// int64_t size = 0,
// const PD_DeviceContext* ctx = nullptr) {
// C_Status status;
// auto ptr = PD_DeviceContextAllocateTensor(
// ctx, raw_data(), size, data_type, &status);
// PD_CHECK_STATUS(status);
// return static_cast<void*>(ptr);
// }
DenseTensor& ShareDataWith(const DenseTensor& src) {
C_Status status;
PD_TensorShareDataWith(raw_data(), src.raw_data(), &status);
......@@ -448,10 +443,6 @@ class KernelArgsDef : WrapperBase<PD_KernelArgsDef> {
PD_DeletePointerList(list);
return ret;
}
// std::vector<AttributeArgDef>
// attribute_defs() {
// }
};
class KernelKey : WrapperBase<PD_KernelKey> {
......@@ -459,7 +450,6 @@ class KernelKey : WrapperBase<PD_KernelKey> {
explicit KernelKey(PD_KernelKey* kernel_key)
: WrapperBase<PD_KernelKey>(kernel_key) {}
// Backend backend() const { return backend_; }
PD_DataLayout layout() const {
PD_Status status;
auto layout_ = PD_KernelKeyGetLayout(raw_data(), &status);
......@@ -491,6 +481,58 @@ class Kernel : WrapperBase<PD_Kernel> {
TensorArgDef OutputAt(size_t idx) { return args_def().output_defs()[idx]; }
};
class MetaTensor : WrapperBase<PD_MetaTensor> {
public:
explicit MetaTensor(PD_MetaTensor* meta_tensor)
: WrapperBase<PD_MetaTensor>(meta_tensor) {}
std::vector<int64_t> dims() const {
C_Status status;
auto dimension = PD_MetaTensorGetDims(raw_data(), &status);
PD_CHECK_STATUS(status);
return dimension;
}
PD_DataType dtype() const {
C_Status status;
auto data_type = PD_MetaTensorGetPDDataType(raw_data(), &status);
PD_CHECK_STATUS(status);
return data_type;
}
PD_DataLayout layout() const {
C_Status status;
auto data_layout = PD_MetaTensorGetDataLayout(raw_data(), &status);
PD_CHECK_STATUS(status);
return data_layout;
}
int64_t numel() const {
C_Status status;
auto element_count = PD_MetaTensorGetElementCount(raw_data(), &status);
PD_CHECK_STATUS(status);
return element_count;
}
void set_dims(const std::vector<int64_t>& dims) {
C_Status status;
PD_MetaTensorSetDims(raw_data(), dims.size(), dims.data(), &status);
PD_CHECK_STATUS(status);
}
void set_dtype(PD_DataType data_type) {
C_Status status;
PD_MetaTensorSetDataType(raw_data(), data_type, &status);
PD_CHECK_STATUS(status);
}
void set_layout(PD_DataLayout data_layout) {
C_Status status;
PD_MetaTensorSetDataLayout(raw_data(), data_layout, &status);
PD_CHECK_STATUS(status);
}
};
} // namespace capi
} // namespace phi
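The MetaTensor wrapper reduces the same infer-meta example to a few lines, with PD_CHECK_STATUS handled inside each accessor; same hypothetical operator as before:

```cpp
// Minimal sketch using the C++ wrapper instead of the raw C API.
void MyReluInferMeta(PD_InferMetaContext* ctx) {
  phi::capi::MetaTensor x(PD_InferMetaContextInputAt(ctx, 0));
  phi::capi::MetaTensor out(PD_InferMetaContextOutputAt(ctx, 0));
  out.set_dims(x.dims());
  out.set_dtype(x.dtype());
  out.set_layout(x.layout());
}
```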
......