未验证 提交 3a452e4e 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine IR builder and throw methods (#54396)

* refine code

* refine code

* refine code

* refine code

* refine code

* refine code

* refine code

* fix bug

* refine code

* refine code

* refine code

* refine code

* refine code

* delete unused code

* delete unused code

* refine code
上级 b62b384b
此差异已折叠。
......@@ -26,5 +26,23 @@ phi::DataLayout DataLayoutAttribute::data() const {
return storage()->GetAsKey();
}
phi::Scalar ScalarAttribute::data() {
if (isa<ir::FloatAttribute>()) {
return phi::Scalar(dyn_cast<ir::FloatAttribute>().data());
} else if (isa<ir::DoubleAttribute>()) {
return phi::Scalar(dyn_cast<ir::DoubleAttribute>().data());
} else if (isa<ir::Int32_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int32_tAttribute>().data());
} else if (isa<ir::Int64_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int64_tAttribute>().data());
} else if (isa<ir::BoolAttribute>()) {
return phi::Scalar(dyn_cast<ir::BoolAttribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir attribute when casting it into "
"phi scalar."));
}
}
} // namespace dialect
} // namespace paddle
......@@ -17,6 +17,8 @@
#include "paddle/fluid/ir/dialect/pd_attribute_storage.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace dialect {
......@@ -45,6 +47,8 @@ class ScalarAttribute : public ir::Attribute {
(val.type_id() == ir::Int32_tAttribute::type_id()) ||
(val.type_id() == ir::Int64_tAttribute::type_id());
}
phi::Scalar data();
};
class DataTypeAttribute : public ir::Attribute {
......
......@@ -101,14 +101,17 @@ struct OpInputInfo {
std::string type_name;
bool optional = false;
bool no_need_buffer = false;
bool is_mutable_attribute = false;
OpInputInfo(std::string name,
std::string type_name,
bool optional,
bool no_need_buffer)
bool no_need_buffer,
bool is_mutable_attribute)
: name(name),
type_name(type_name),
optional(optional),
no_need_buffer(no_need_buffer) {}
no_need_buffer(no_need_buffer),
is_mutable_attribute(is_mutable_attribute) {}
};
struct OpOutputInfo {
......
......@@ -56,7 +56,7 @@ class Builder {
template <typename OpTy, typename... Args>
OpTy Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(argument, std::forward<Args>(args)...);
OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = Build(std::move(argument));
return op->dyn_cast<OpTy>();
}
......
......@@ -57,20 +57,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs type:
if (inputs.size() != 0) {
throw("The size of inputs must be equal to 0.");
}
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
// Verify if attributes contain attribute name in attributes_name:
auto iter = attributes.find("program");
if (iter == attributes.end() || !iter->second.isa<PointerAttribute>()) {
throw("Type of attribute: program is not right.");
}
IR_ENFORCE(iter != attributes.end() && iter->second.isa<PointerAttribute>(),
"Type of attribute: program is not right.");
// Verify outputs type:
if (outputs.size() != 0) {
throw("The size of outputs must be equal to 0.");
}
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
}
const char *GetParameterOp::attributes_name[attributes_num] = {
......@@ -81,17 +76,15 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs type:
if (inputs.size() != 0) {
throw("The size of inputs must be equal to 0.");
}
// Verify outputs type:
if (outputs.size() != 1) {
throw("The size of outputs must be equal to 1.");
}
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
// Verify if attributes contain attribute name in attributes_name:
if (!attributes.at("parameter_name").isa<StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
}
auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right.");
// Verify outputs type:
IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
}
const char *SetParameterOp::attributes_name[attributes_num] = {
......@@ -102,54 +95,45 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs type:
if (inputs.size() != 1) {
throw("The size of inputs must be equal to 1.");
}
// Verify outputs type:
if (outputs.size() != 0) {
throw("The size of outputs must be equal to 0.");
}
IR_ENFORCE(inputs.size() == 1, "The size of outputs must be equal to 1.");
// Verify if attributes contain attribute name in attributes_name:
if (!attributes.at("parameter_name").isa<StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
}
auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right.");
// Verify outputs type:
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
}
void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// outputs.size() == 1
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
IR_ENFORCE(outputs.size() == 1,
"The size %d of outputs must be equal to 1.",
outputs.size());
// outputs[0].type == Vector<Type>
PADDLE_ENFORCE(outputs[0].isa<ir::VectorType>(),
phi::errors::PreconditionNotMet(
IR_ENFORCE(outputs[0].isa<ir::VectorType>(),
"The type %s of outputs[0] must be equal to VectorType.",
outputs[0]));
outputs[0]);
ir::VectorType output_type = outputs[0].dyn_cast<ir::VectorType>();
// inputs.size() == outputs[0].size()
PADDLE_ENFORCE_EQ(
output_type.size(),
inputs.size(),
phi::errors::PreconditionNotMet(
IR_ENFORCE(output_type.size() == inputs.size(),
"The size %d of outputs[0] must be equal to size %d of inputs.",
output_type.size(),
inputs.size()));
inputs.size());
// forall i in inputs.size(): inputs[i].type == outputs[0][i].type
for (size_t i = 0; i < inputs.size(); i++) {
PADDLE_ENFORCE_EQ(
output_type[i],
inputs[i].type(),
phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be "
IR_ENFORCE(output_type[i] == inputs[i].type(),
"The type %s of outputs[0][%d] must be "
"equal to type %s of inputs[%d].",
output_type[i],
i,
inputs[i].type(),
i));
i);
}
}
......@@ -158,65 +142,50 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// inputs.size() == 1
PADDLE_ENFORCE_EQ(
inputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", inputs.size()));
IR_ENFORCE(inputs.size() == 1,
"The size %d of inputs must be equal to 1.",
inputs.size());
// inputs[0].type == Vector<Type>
PADDLE_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
phi::errors::PreconditionNotMet(
IR_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
"The type %s of inputs[0] must be equal to VectorType.",
inputs[0].type()));
inputs[0].type());
ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();
// outputs.size() == 1
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
IR_ENFORCE(outputs.size() == 1,
"The size %d of outputs must be equal to 1.",
outputs.size());
// attributes contains index: Int32
PADDLE_ENFORCE_NE(
attributes.count("index"),
0,
phi::errors::PreconditionNotMet("The attributes must contains index."));
IR_ENFORCE(attributes.count("index") != 0,
"The attributes must contains index.");
const ir::Attribute &attr = attributes.at("index");
PADDLE_ENFORCE(
attr.isa<ir::Int32_tAttribute>(),
phi::errors::PreconditionNotMet("The attribute index must be INT32."));
IR_ENFORCE(attr.isa<ir::Int32_tAttribute>(),
"The attribute index must be INT32.");
auto index = attr.dyn_cast<ir::Int32_tAttribute>().data();
// index >= 0 and < inputs[0].size()
PADDLE_ENFORCE_GE(
index,
0,
phi::errors::PreconditionNotMet(
"The index %d must be greater or equal than 0.", index));
PADDLE_ENFORCE_LT(
index,
input_type.size(),
phi::errors::PreconditionNotMet(
IR_ENFORCE(
index >= 0, "The index %d must be greater or equal than 0.", index);
IR_ENFORCE(static_cast<size_t>(index) < input_type.size(),
"The index %d must be less or equal than size %d of inputs[0].",
index,
input_type.size()));
input_type.size());
// inputs[index].type == outputs[0].type
PADDLE_ENFORCE_EQ(
input_type[index],
outputs[0],
phi::errors::PreconditionNotMet(
IR_ENFORCE(
input_type[index] == outputs[0],
"The type %s of inputs[%d] must be equal to type %s of outputs[0].",
input_type[index],
index,
outputs[0]));
outputs[0]);
}
const char *ConstantOp::attributes_name[attributes_num] = {"value"};
void ConstantOp::Build(OperationArgument &argument,
void ConstantOp::Build(Builder &builder,
OperationArgument &argument,
Attribute value,
Type output_type) {
argument.AddAttribute("value", value);
......
......@@ -86,6 +86,7 @@ class CombineOp : public ir::Op<CombineOp> {
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
......@@ -125,7 +126,8 @@ class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Build(OperationArgument &argument, // NOLINT
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Attribute value,
Type output_type);
......
......@@ -16,6 +16,7 @@
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
......@@ -85,7 +86,7 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
throw("The address of OpOperandImpl must be divisible by 8.");
IR_THROW("The address of OpOperandImpl must be divisible by 8.");
}
for (size_t idx = 0; idx < num_operands; idx++) {
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
......@@ -147,7 +148,7 @@ void Operation::Destroy() {
// 2.2. Deconstruct Operation.
if (reinterpret_cast<uintptr_t>(base_ptr) !=
reinterpret_cast<uintptr_t>(this)) {
throw("Operation address error");
IR_THROW("Operation address error");
}
reinterpret_cast<Operation *>(base_ptr)->~Operation();
base_ptr += sizeof(Operation);
......@@ -178,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes,
ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
if (index >= num_results_) {
throw("index exceeds OP output range.");
IR_THROW("index exceeds OP output range.");
}
uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex();
const char *ptr =
......@@ -199,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const {
if (index >= num_operands_) {
throw("index exceeds OP input range.");
IR_THROW("index exceeds OP input range.");
}
const char *ptr = reinterpret_cast<const char *>(this) + sizeof(Operation) +
(index) * sizeof(detail::OpOperandImpl);
......
......@@ -17,6 +17,8 @@
#include <memory>
#include <unordered_map>
#include "paddle/ir/core/enforce.h"
namespace ir {
// This is a structure for creating, caching, and looking up Storage of
// parametric types.
......@@ -76,7 +78,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageImpl(
<< std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value
<< "].";
if (parametric_instance_.find(type_id) == parametric_instance_.end()) {
throw("The input data pointer is null.");
IR_THROW("The input data pointer is null.");
}
ParametricStorageManager &parametric_storage = *parametric_instance_[type_id];
return parametric_storage.GetOrCreate(hash_value, equal_func, constructor);
......@@ -88,7 +90,7 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl(
VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instance_.find(type_id) == parameterless_instance_.end())
throw("TypeId not found in IrContext.");
IR_THROW("TypeId not found in IrContext.");
StorageBase *parameterless_instance = parameterless_instance_[type_id];
return parameterless_instance;
}
......@@ -107,7 +109,7 @@ void StorageManager::RegisterParameterlessStorageImpl(
VLOG(4) << "Register a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instance_.find(type_id) != parameterless_instance_.end())
throw("storage class already registered");
IR_THROW("storage class already registered");
parameterless_instance_.emplace(type_id, constructor());
}
......
......@@ -427,6 +427,18 @@
data_type : dtype
backend : place
- op : full_int_array
args : (IntArray value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor(out)
infer_meta :
func : CreateIntArrayInferMeta
param : [value, dtype]
kernel :
func : full_int_array
param : [value, dtype]
data_type : dtype
backend : place
- op : full_like
args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor(out)
......
......@@ -62,9 +62,9 @@
beta2 :
data_type : float
tensor_name : Beta2Tensor
episilon :
epsilon :
data_type : float
tensor_name : EpisilonTensor
tensor_name : EpsilonTensor
manual_signature : [adam_]
- op : adamax_
......@@ -85,9 +85,9 @@
beta2 :
data_type : float
tensor_name : Beta2Tensor
episilon :
epsilon :
data_type : float
tensor_name : EpisilonTensor
tensor_name : EpsilonTensor
- op : add (elementwise_add)
backward : add_grad (elementwise_add_grad)
......@@ -1970,7 +1970,7 @@
outputs:
out : Out
int_array:
axis :
dims :
data_type : int
extra :
attrs : [bool use_mkldnn = false]
......
......@@ -41,6 +41,15 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) {
CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
}
void CreateIntArrayInferMeta(const IntArray& data,
DataType dtype,
MetaTensor* out) {
CreateInferMetaBase({static_cast<int64_t>(data.GetData().size())},
dtype,
DataLayout::NCHW,
out);
}
void CreateInferMetaBase(const std::vector<int64_t>& shape,
DataType dtype,
DataLayout layout,
......
......@@ -35,6 +35,10 @@ void AssignValueInferMeta(const std::vector<int>& shape,
DataType dtype,
MetaTensor* out);
void CreateIntArrayInferMeta(const IntArray& data,
DataType dtype,
MetaTensor* out);
void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out);
void CreateInferMetaBase(const std::vector<int64_t>& shape,
......
......@@ -80,6 +80,18 @@ void FullLikeKernel(const Context& dev_ctx,
FullValue<T>(dev_ctx, out, value);
}
template <typename T, typename Context>
void FullIntArrayKernel(const Context& dev_ctx,
const IntArray& val,
DataType dtype UNUSED,
DenseTensor* out) {
out->Resize(phi::make_ddim({static_cast<int64_t>(val.GetData().size())}));
T* out_data = dev_ctx.template Alloc<T>(out);
for (size_t i = 0; i < val.GetData().size(); ++i) {
out_data[i] = static_cast<T>(val.GetData()[i]);
}
}
} // namespace phi
PD_REGISTER_KERNEL(full,
......@@ -115,3 +127,6 @@ PD_REGISTER_KERNEL(full_like,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(
full_int_array, CPU, ALL_LAYOUT, phi::FullIntArrayKernel, int, int64_t) {}
......@@ -83,4 +83,10 @@ DenseTensor FullLike(const Context& dev_ctx,
return dense_out;
}
template <typename T, typename Context>
void FullIntArrayKernel(const Context& dev_ctx,
const IntArray& val,
DataType dtype,
DenseTensor* out);
} // namespace phi
......@@ -44,76 +44,61 @@
#include "paddle/phi/core/kernel_registry.h"
#include "test/cpp/ir/core/phi_kernel_adaptor.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; }
TEST(program_test, program) {
// Prepare ir env
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Program program(ctx);
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
ir::Block* block = program.block();
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
paddle::dialect::DenseTensorTypeStorage::Dim dims = {2, 2};
paddle::dialect::DenseTensorTypeStorage::DataLayout data_layout =
paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW;
paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
size_t offset = 0;
ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset);
// (1) Def a = GetParameterOp("a")
std::string op1_name = std::string(paddle::dialect::UniformOp::name());
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
// ir::Attribute shape_1 = ir::ArrayAttribute::get(ctx, {ten} );
ir::Attribute shape_1 = paddle::dialect::IntArrayAttribute::get(
ctx, std::vector<int64_t>({2, 2}));
ir::Attribute data_type =
paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32);
ir::Attribute min = ir::FloatAttribute::get(ctx, 0.0);
ir::Attribute max = ir::FloatAttribute::get(ctx, 1.0);
ir::Attribute seed = ir::Int32_tAttribute::get(ctx, 2);
ir::Attribute uni_place = paddle::dialect::PlaceAttribute::get(
ctx, phi::Place(phi::AllocationType::CPU));
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"shape", shape_1},
{"dtype", data_type},
{"min", min},
{"max", max},
{"seed", seed},
{"place", uni_place}};
ir::Operation* op1 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
block->push_back(op1);
// (2) Def b = GetParameterOp("b")
std::string op2_name = std::string(paddle::dialect::UniformOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
ir::Attribute ten2 = ir::Int32_tAttribute::get(ctx, 3);
std::unordered_map<std::string, ir::Attribute> op2_attribute{{"shape", ten2}};
ir::Operation* op2 =
ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op2_info);
block->push_back(op2);
// (3) Def out = AddOp(a, b)
std::string add_op_name = std::string(paddle::dialect::AddOp::name());
ir::OpInfo add_op_info = ctx->GetRegisteredOpInfo(add_op_name);
ir::Operation* add_op = ir::Operation::Create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{},
{dense_tensor_dtype},
add_op_info);
block->push_back(add_op);
// Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
// phi::DataType dtype, float min, float max, int seed, phi::Place place)
paddle::dialect::UniformOp uniform1 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform1->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 4u);
// Def: B = paddle::dialect::UniformOp(...)
paddle::dialect::UniformOp uniform2 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
EXPECT_EQ(uniform2->GetResultByIndex(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 8u);
// Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
uniform1->GetResultByIndex(0), uniform2->GetResultByIndex(0));
EXPECT_EQ(
add->GetResultByIndex(0).type().isa<paddle::dialect::DenseTensorType>(),
true);
EXPECT_EQ(block->size(), 9u);
// Execute program
paddle::framework::Scope scope;
PhiKernelAdaptor phi_kernel_adaptor(&scope);
phi_kernel_adaptor.run(&program);
auto out_tensor =
......
......@@ -20,6 +20,7 @@
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h"
......@@ -240,10 +241,12 @@ TEST(op_test, module_op_death) {
ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info),
ir::IrNotMetException);
EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info),
ir::IrNotMetException);
EXPECT_THROW(ir::Operation::Create({}, {}, output_types, op_info),
const char *);
ir::IrNotMetException);
ir::Program program(ctx);
......
......@@ -98,27 +98,30 @@ void build_context(ir::Operation* op,
op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto op_info_res = op_info_interface.GetOpInfo();
// inputs include input and mutable attributes
auto input_info = std::get<0>(op_info_res);
std::set<std::string> input_set;
std::map<std::string, size_t> input_index_map;
std::map<std::string, std::string> mutable_attr_type_map;
int input_index = 0;
for (auto& t : input_info) {
VLOG(6) << t.name << "\t" << t.type_name;
input_set.insert(t.name);
input_index_map[t.name] = input_index++;
if (t.is_mutable_attribute) {
mutable_attr_type_map[t.name] = t.type_name;
}
}
auto attr_map = op->attributes();
std::map<std::string, std::string> attr_type_map;
auto attr_info = std::get<1>(op_info_res);
std::map<std::string, std::string> attr_type_map;
for (auto& t : attr_info) {
VLOG(6) << t.name << "\t" << t.type_name;
attr_type_map[t.name] = t.type_name;
}
auto attr_map = op->attributes();
auto runtime_info = std::get<3>(op_info_res);
int input_index = 0;
// int input_index = 0;
std::vector<std::string> vec_param_list;
if (is_infer_meta) {
vec_param_list = runtime_info.infer_meta_param;
......@@ -126,14 +129,32 @@ void build_context(ir::Operation* op,
vec_param_list = runtime_info.kernel_param;
}
for (auto& t : vec_param_list) {
if (input_set.count(t)) {
if (input_index_map.count(t)) {
// get information from input
ir::Value ptr = op->GetOperandByIndex(input_index++).source();
ir::Value ptr = op->GetOperandByIndex(input_index_map[t]).source();
auto in_var_name = name_map.at(ptr);
if (mutable_attr_type_map.count(t)) {
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t"
<< in_var_name;
if (mutable_attr_type_map[t] == "paddle::dialect::IntArrayAttribute") {
ctx->EmplaceBackAttr(phi::IntArray(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
} else if (mutable_attr_type_map[t] ==
"paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(phi::Scalar(
*(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
mutable_attr_type_map[t]));
}
} else {
VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
ctx->EmplaceBackInput(
scope->Var(in_var_name)->GetMutable<phi::DenseTensor>());
}
}
if (attr_type_map.count(t)) {
auto type_name = attr_type_map[t];
......@@ -149,10 +170,14 @@ void build_context(ir::Operation* op,
} else if (type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
} else if (type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
type_name));
}
VLOG(6) << "ctx->EmplaceBackAttr: " << t;
}
}
......@@ -197,6 +222,9 @@ class PhiKernelAdaptor {
phi::KernelKey kernel_key(phi::TransToPhiBackend(cpu_place),
phi::DataLayout::ANY,
phi::DataType::FLOAT32);
if (runtime_info.kernel_func[0] == "full_int_array") {
kernel_key.set_dtype(phi::DataType::INT64);
}
auto found_it = phi_kernels.find(kernel_key);
if (found_it == phi_kernels.end()) {
std::cerr << "kernel name " << runtime_info.kernel_func[0] << std::endl;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册