Unverified commit af6ef888, authored by 石晓伟, committed by GitHub

adjusts the mlir attrs order, test=develop (#40514)

Parent: e7057932
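What this commit does: MLIR keeps an operation's attributes in a `Builtin_Dictionary`, whose entries are sorted alphabetically by attribute name, while a phi kernel consumes its attributes positionally, in the order of its function signature. The commit therefore lets kernels register an explicit attribute-name order (`AddKernelWithAttrs`) and makes the MLIR-to-runtime translator permute the emitted attribute values into that order. Below is a self-contained sketch of the core permutation, with toy names and values standing in for the MLIR and phi types:

```cpp
#include <cassert>
#include <cstring>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in: MLIR hands us attributes sorted by *name*;
// the kernel wants the values in *signature* order.
using Attr = std::pair<std::string, int>;  // (name, value), simplified

// Mirrors the get_offset lambda in this commit: find an attribute's
// position in the registered signature order, or -1 if unregistered.
int GetOffset(const char* attr, const std::vector<const char*>& names) {
  for (size_t i = 0; i < names.size(); ++i) {
    if (!std::strcmp(attr, names[i])) return static_cast<int>(i);
  }
  return -1;
}

int main() {
  // Alphabetical order, as an MLIR Builtin_Dictionary would yield it.
  std::vector<Attr> sorted_attrs = {
      {"dims", 4}, {"layout", 0}, {"lod", 1}, {"precision", 2}};
  // Signature order, as registered via AddKernelWithAttrs in this commit.
  std::vector<const char*> attr_order = {"dims", "lod", "layout", "precision"};

  std::vector<int> reordered(sorted_attrs.size());
  for (const Attr& attr : sorted_attrs) {
    int offset = GetOffset(attr.first.c_str(), attr_order);
    assert(offset != -1 && "attribute not registered");
    reordered[offset] = attr.second;  // place value at its signature slot
  }
  for (int v : reordered) std::cout << v << ' ';  // prints: 4 1 0 2
  std::cout << '\n';
}
```

With the alphabetical input `dims, layout, lod, precision` and the registered order `dims, lod, layout, precision`, the `layout` and `lod` values swap places, which is exactly the case handled for `phi_dt.create_dense_tensor` further down.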
@@ -16,7 +16,7 @@
 namespace infrt {

-phi::Backend cvtTarget2Phi(TargetType target) {
+phi::Backend ConvertTargetToPhi(TargetType target) {
   switch (target) {
     case TargetType::CPU:
       return phi::Backend::CPU;
@@ -27,7 +27,7 @@ phi::Backend cvtTarget2Phi(TargetType target) {
   }
 }

-TargetType cvtTargetFromPhi(phi::Backend backend) {
+TargetType ConvertTargetFromPhi(phi::Backend backend) {
   switch (backend) {
     case phi::Backend::CPU:
       return TargetType::CPU;
@@ -38,7 +38,7 @@ TargetType cvtTargetFromPhi(phi::Backend backend) {
   }
 }

-phi::DataType cvtPrecision2Phi(PrecisionType precision) {
+phi::DataType ConvertPrecisionToPhi(PrecisionType precision) {
 #define CONVERT_PRECISION_TO_PHI(Precision) \
   case PrecisionType::Precision:            \
     return phi::DataType::Precision;
@@ -61,7 +61,7 @@ phi::DataType cvtPrecision2Phi(PrecisionType precision) {
 #undef CONVERT_PRECISION_TO_PHI
 }

-PrecisionType cvtPrecisionFromPhi(phi::DataType datatype) {
+PrecisionType ConvertPrecisionFromPhi(phi::DataType datatype) {
 #define CONVERT_PRECISION_FROM_PHI(Precision) \
   case phi::DataType::Precision:              \
     return PrecisionType::Precision;
@@ -84,7 +84,7 @@ PrecisionType cvtPrecisionFromPhi(phi::DataType datatype) {
 #undef CONVERT_PRECISION_FROM_PHI
 }

-phi::DataLayout cvtLayout2Phi(LayoutType layout) {
+phi::DataLayout ConvertLayoutToPhi(LayoutType layout) {
   switch (layout) {
     case LayoutType::NCHW:
       return phi::DataLayout::NCHW;
@@ -97,7 +97,7 @@ phi::DataLayout cvtLayout2Phi(LayoutType layout) {
   }
 }

-LayoutType cvtLayoutFromPhi(phi::DataLayout layout) {
+LayoutType ConvertLayoutFromPhi(phi::DataLayout layout) {
   switch (layout) {
     case phi::DataLayout::NCHW:
       return LayoutType::NCHW;
@@ -110,16 +110,16 @@ LayoutType cvtLayoutFromPhi(phi::DataLayout layout) {
   }
 }

-phi::KernelKey cvtPlace2Phi(const Place& place) {
-  return phi::KernelKey(cvtTarget2Phi(place.target),
-                        cvtLayout2Phi(place.layout),
-                        cvtPrecision2Phi(place.precision));
+phi::KernelKey ConvertPlaceToPhi(const Place& place) {
+  return phi::KernelKey(ConvertTargetToPhi(place.target),
+                        ConvertLayoutToPhi(place.layout),
+                        ConvertPrecisionToPhi(place.precision));
 }

-Place cvtPlaceFromPhi(phi::TensorArgDef tensor_arg) {
-  return Place(cvtTargetFromPhi(tensor_arg.backend),
-               cvtPrecisionFromPhi(tensor_arg.dtype),
-               cvtLayoutFromPhi(tensor_arg.layout));
+Place ConvertPlaceFromPhi(phi::TensorArgDef tensor_arg) {
+  return Place(ConvertTargetFromPhi(tensor_arg.backend),
+               ConvertPrecisionFromPhi(tensor_arg.dtype),
+               ConvertLayoutFromPhi(tensor_arg.layout));
 }

 }  // namespace infrt
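Aside from the rename (`cvt*2Phi` / `cvt*FromPhi` to `Convert*ToPhi` / `Convert*FromPhi`), these helpers are unchanged: each is an exhaustive switch mapping one enum vocabulary onto another, in both directions. A reduced sketch of the pattern, with toy enums standing in for `infrt::TargetType` and `phi::Backend`:

```cpp
#include <cassert>

// Toy stand-ins for infrt::TargetType and phi::Backend.
enum class TargetType { CPU, GPU };
enum class Backend { CPU, GPU };

// One switch per direction keeps the mapping explicit and lets the
// compiler warn when a new enumerator is left unhandled.
Backend ConvertTargetToPhi(TargetType target) {
  switch (target) {
    case TargetType::CPU: return Backend::CPU;
    case TargetType::GPU: return Backend::GPU;
  }
  return Backend::CPU;  // unreachable; silences -Wreturn-type
}

TargetType ConvertTargetFromPhi(Backend backend) {
  switch (backend) {
    case Backend::CPU: return TargetType::CPU;
    case Backend::GPU: return TargetType::GPU;
  }
  return TargetType::CPU;  // unreachable
}

int main() {
  // Round-tripping through both switches is the identity.
  assert(ConvertTargetFromPhi(ConvertTargetToPhi(TargetType::GPU)) ==
         TargetType::GPU);
}
```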
@@ -23,16 +23,16 @@
 namespace infrt {

-phi::Backend cvtTarget2Phi(TargetType target);
-TargetType cvtTargetFromPhi(phi::Backend backend);
+phi::Backend ConvertTargetToPhi(TargetType target);
+TargetType ConvertTargetFromPhi(phi::Backend backend);

-phi::DataType cvtPrecision2Phi(PrecisionType precision);
-PrecisionType cvtPrecisionFromPhi(phi::DataType datatype);
+phi::DataType ConvertPrecisionToPhi(PrecisionType precision);
+PrecisionType ConvertPrecisionFromPhi(phi::DataType datatype);

-phi::DataLayout cvtLayout2Phi(LayoutType layout);
-LayoutType cvtLayoutFromPhi(phi::DataLayout layout);
+phi::DataLayout ConvertLayoutToPhi(LayoutType layout);
+LayoutType ConvertLayoutFromPhi(phi::DataLayout layout);

-phi::KernelKey cvtPlace2Phi(const Place& place);
-Place cvtPlaceFromPhi(phi::TensorArgDef tensor_arg);
+phi::KernelKey ConvertPlaceToPhi(const Place& place);
+Place ConvertPlaceFromPhi(phi::TensorArgDef tensor_arg);

 }  // namespace infrt
@@ -80,7 +80,7 @@ std::vector<PhiKernelDesc> getCandidateKernels(
   phi::KernelKeyMap kernel_key_map =
       phi::KernelFactory::Instance().SelectKernelMap(name);
   for (Place place : valid_palces) {
-    phi::KernelKey kernel_key = cvtPlace2Phi(place);
+    phi::KernelKey kernel_key = ConvertPlaceToPhi(place);
     if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) {
       kernel_key = phi::KernelKey(kernel_key.backend(),
                                   phi::DataLayout::ALL_LAYOUT,
@@ -97,10 +97,10 @@ std::vector<PhiKernelDesc> getCandidateKernels(
     const paddle::SmallVector<phi::TensorArgDef>& output_arg =
         args_def.output_defs();
     for (auto tensor_arg : input_arg) {
-      phi_kernel_desc.inputsType.emplace_back(cvtPlaceFromPhi(tensor_arg));
+      phi_kernel_desc.inputsType.emplace_back(ConvertPlaceFromPhi(tensor_arg));
     }
     for (auto tensor_arg : output_arg) {
-      phi_kernel_desc.outputsType.emplace_back(cvtPlaceFromPhi(tensor_arg));
+      phi_kernel_desc.outputsType.emplace_back(ConvertPlaceFromPhi(tensor_arg));
     }
     candidate_kernels.emplace_back(phi_kernel_desc);
   }
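For context, `getCandidateKernels` (only the helper names changed here) resolves each requested `Place` against the kernel map and, when the exact key misses, retries with the layout relaxed to `phi::DataLayout::ALL_LAYOUT`. A minimal sketch of that two-step lookup, using a plain `std::map` and a tuple key as stand-ins for `phi::KernelKeyMap` and `phi::KernelKey`:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <tuple>

// Toy kernel key: (backend, layout, precision). phi::KernelKey is richer.
using KernelKey = std::tuple<std::string, std::string, std::string>;

int main() {
  // Kernel registered only for the layout-agnostic key.
  std::map<KernelKey, std::string> kernel_key_map = {
      {{"CPU", "ALL_LAYOUT", "FP32"}, "sign_cpu_fp32_any"}};

  // The requested place asks for NCHW explicitly...
  KernelKey kernel_key{"CPU", "NCHW", "FP32"};
  if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) {
    // ...so on a miss we relax the layout, mirroring the fallback to
    // phi::DataLayout::ALL_LAYOUT in getCandidateKernels.
    kernel_key = KernelKey{std::get<0>(kernel_key), "ALL_LAYOUT",
                           std::get<2>(kernel_key)};
  }
  auto it = kernel_key_map.find(kernel_key);
  std::cout << (it != kernel_key_map.end() ? it->second : "<miss>") << '\n';
  // prints: sign_cpu_fp32_any
}
```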
@@ -23,8 +23,9 @@ namespace infrt {
 namespace host_context {

 struct KernelRegistry::Impl {
-  std::unordered_map<std::string, KernelImplementation> data;
-  std::unordered_map<std::string, llvm::SmallVector<std::string, 4>> attr_names;
+  std::unordered_map<std::string,
+                     std::pair<KernelImplementation, std::vector<const char *>>>
+      data;
 };

 KernelRegistry::KernelRegistry() : impl_(std::make_unique<Impl>()) {}
@@ -33,20 +34,29 @@ void KernelRegistry::AddKernel(const std::string &key,
                                KernelImplementation fn) {
   CHECK(!impl_->data.count(key)) << "kernel [" << key
                                  << "] is registered twice";
-  impl_->data.emplace(key, fn);
+  impl_->data.emplace(
+      key, std::make_pair(std::move(fn), std::vector<const char *>{}));
 }

-void KernelRegistry::AddKernelAttrNameList(
-    const std::string &key, const std::vector<std::string> &names) {
-  CHECK(!impl_->attr_names.count(key))
-      << "kernel [" << key << "] is registered twice in attribute names";
-  impl_->attr_names.emplace(
-      key, llvm::SmallVector<std::string, 4>(names.begin(), names.end()));
+const std::vector<const char *> &KernelRegistry::GetAttrNameList(
+    const std::string &key) const {
+  CHECK(impl_->data.count(key));
+  return impl_->data[key].second;
+}
+
+void KernelRegistry::AddKernelWithAttrs(
+    const std::string &key,
+    KernelImplementation fn,
+    std::vector<const char *> &&attr_order) {
+  CHECK(!impl_->data.count(key)) << "kernel [" << key
+                                 << "] is registered twice";
+  impl_->data.emplace(key,
+                      std::make_pair(std::move(fn), std::move(attr_order)));
 }

 KernelImplementation KernelRegistry::GetKernel(const std::string &key) const {
   auto it = impl_->data.find(key);
-  return it != impl_->data.end() ? it->second : KernelImplementation{};
+  return it != impl_->data.end() ? it->second.first : KernelImplementation{};
 }

 std::vector<std::string> KernelRegistry::GetKernelList() const {
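The registry used to keep kernel functions and their attribute-name lists in two separate maps that had to be kept in sync by hand; they are now fused into one map from kernel name to a (function, attribute-order) pair. A self-contained sketch of the consolidated storage, with `KernelImplementation` reduced to a plain function pointer and the `Impl` indirection dropped:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using KernelImplementation = void (*)();  // simplified stand-in

class KernelRegistry {
 public:
  // Kernels without attributes get an empty order list.
  void AddKernel(const std::string &key, KernelImplementation fn) {
    assert(!data_.count(key) && "kernel registered twice");
    data_.emplace(key, std::make_pair(fn, std::vector<const char *>{}));
  }

  // Kernels with attributes register their signature order up front.
  void AddKernelWithAttrs(const std::string &key, KernelImplementation fn,
                          std::vector<const char *> &&attr_order) {
    assert(!data_.count(key) && "kernel registered twice");
    data_.emplace(key, std::make_pair(fn, std::move(attr_order)));
  }

  const std::vector<const char *> &GetAttrNameList(
      const std::string &key) const {
    return data_.at(key).second;  // .at() keeps the method const-correct
  }

 private:
  // One map, so a kernel's function and attr order can never disagree.
  std::unordered_map<std::string,
                     std::pair<KernelImplementation,
                               std::vector<const char *>>>
      data_;
};

int main() {
  KernelRegistry registry;
  registry.AddKernelWithAttrs("dt.create_uninit_tensor.f32", nullptr,
                              {"shape"});
  assert(registry.GetAttrNameList("dt.create_uninit_tensor.f32").size() == 1);
}
```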
@@ -34,10 +34,14 @@ class KernelRegistry {
   KernelRegistry();

   void AddKernel(const std::string &key, KernelImplementation fn);
-  void AddKernelAttrNameList(const std::string &key,
-                             const std::vector<std::string> &names);
+  void AddKernelWithAttrs(const std::string &key,
+                          KernelImplementation fn,
+                          std::vector<const char *> &&attrs_order);

   KernelImplementation GetKernel(const std::string &key) const;
+  const std::vector<const char *> &GetAttrNameList(
+      const std::string &key) const;
   std::vector<std::string> GetKernelList() const;
   size_t size() const;
@@ -43,6 +43,7 @@ MlirFunctionExecutable::MlirFunctionExecutable(
                func_op.getNumResults()),
       MlirToRuntimeTranslator(&core_runtime_builder_),
       region_(&func_op.getRegion()),
+      kernel_registry_(kernel_registry),
       core_runtime_builder_(kernel_registry),
       function_table_(function_table) {}

@@ -54,6 +55,7 @@ MlirFunctionExecutable::MlirFunctionExecutable(
     : Function("", func_type.getNumInputs(), func_type.getNumResults()),
       MlirToRuntimeTranslator(&core_runtime_builder_),
       region_(region),
+      kernel_registry_(kernel_registry),
       core_runtime_builder_(kernel_registry),
       function_table_(function_table) {}

@@ -90,7 +92,7 @@ void MlirFunctionExecutable::BuildExecutables(
     if (EmitCallOp(&op, &function_table_)) continue;

-    if (EmitGeneralOp(&op)) continue;
+    if (EmitGeneralOp(&op, *kernel_registry_)) continue;

     LOG(FATAL) << "Not supported op: " << DumpToString(op);
   }
@@ -70,6 +70,7 @@ class MlirFunctionExecutable : public Function, public MlirToRuntimeTranslator {
  private:
   mlir::Region* region_{};
+  KernelRegistry* kernel_registry_{};
   CoreRuntimeBuilder core_runtime_builder_;
   MlirToRuntimeTranslator::function_defs_t& function_table_;
   std::function<void()> copy_res_fn_;
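Note where the new `kernel_registry_` member lands: before `core_runtime_builder_`, matching its position in the constructors' initializer lists above. C++ initializes non-static data members in declaration order, not in initializer-list order, so keeping the two consistent avoids `-Wreorder` warnings and accidental reads of not-yet-constructed members. A minimal illustration, with hypothetical stand-in classes rather than the infrt types:

```cpp
// Hypothetical stand-ins to illustrate declaration-order initialization.
struct KernelRegistry {};

struct CoreRuntimeBuilder {
  explicit CoreRuntimeBuilder(KernelRegistry *registry)
      : registry_(registry) {}
  KernelRegistry *registry_;
};

class MlirFunctionExecutableSketch {
 public:
  explicit MlirFunctionExecutableSketch(KernelRegistry *kernel_registry)
      // Written in declaration order; members are initialized in that
      // order regardless of how this list is spelled.
      : kernel_registry_(kernel_registry),
        core_runtime_builder_(kernel_registry) {}

 private:
  KernelRegistry *kernel_registry_{};        // declared first, built first
  CoreRuntimeBuilder core_runtime_builder_;  // may rely on earlier members
};

int main() {
  KernelRegistry registry;
  MlirFunctionExecutableSketch exe(&registry);
}
```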
@@ -270,7 +270,8 @@ static bool IsReturn(mlir::Operation* op) {
   return op->getName().getStringRef() == "infrt.return";
 }

-bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
+bool MlirToRuntimeTranslator::EmitGeneralOp(
+    mlir::Operation* op, const KernelRegistry& kernel_registry) {
   CHECK(impl_->runtime);
   impl_->cur_op =
       impl_->runtime->NewOpExecutable(op->getName().getStringRef().str());
@@ -308,42 +309,80 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
   // process attributes
   auto attrs = op->getAttrs();
+  // MLIR's underlying attribute storage (`Builtin_Dictionary`) keeps its
+  // elements sorted by name. The code below restores the attribute order
+  // expected by the function signatures of the phi operator library.
+  llvm::SmallVector<Value*, 4> tmp;
+  tmp.resize(attrs.size());
+  const std::string& kernel_name = op->getName().getStringRef().str();
+  const auto& attr_names = kernel_registry.GetAttrNameList(kernel_name);
+  if (attrs.size() && attr_names.empty()) {
+    LOG(WARNING) << "The kernel `" << kernel_name
+                 << "` has no specified attr order.";
+  }
+
+  auto get_offset = [](const char* attr,
+                       const std::vector<const char*>& names,
+                       const std::string& kernel_name) -> int {
+    for (size_t i = 0; i < names.size(); ++i) {
+      if (!std::strcmp(attr, names[i])) {
+        return i;
+      }
+    }
+    LOG(WARNING) << "The attribute `" << attr << "` of kernel `" << kernel_name
+                 << "` is not properly registered with "
+                    "`KernelRegistry::AddKernelWithAttrs()`.";
+    return -1;
+  };
+
   for (size_t i = 0; i < attrs.size(); i++) {
     auto& attr = attrs[i];
+    int offset{};
+    if (attr_names.size()) {
+      offset = get_offset(attr.getName().data(), attr_names, kernel_name);
+    } else {
+      offset = i;
+    }
+    CHECK_NE(offset, -1);
     if (auto v = EmitAttribute<int32_t>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<int64_t>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<float>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<double>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
     } else if (auto v = EmitAttribute<bool>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<::infrt::TargetType>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v =
                    EmitAttribute<::infrt::PrecisionType>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<::infrt::LayoutType>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
     } else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<int64_t>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<float>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<double>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
    } else {
      LOG(FATAL) << "Not supported attribute type";
    }
  }
+
+  for (size_t i = 0; i < tmp.size(); i++) {
+    impl_->cur_op->AppendAttribute(tmp[i]);
+  }

   // process results
   llvm::SmallVector<Value*, 4> res_values;
   for (int i = 0, e = op->getNumResults(); i < e; i++) {
@@ -598,7 +637,7 @@ class MlirProgramTestExecutor : public MlirToRuntimeTranslator {
       llvm::SmallVector<mlir::Value, 3> results;
       if (EmitReturnOp(&op, &results)) continue;
       if (EmitCallOp(&op, &impl_->func_defs)) continue;
-      if (EmitGeneralOp(&op)) continue;
+      if (EmitGeneralOp(&op, *registry)) continue;
       LOG(FATAL) << "Not supported op: " << DumpToString(op);
     }
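Separately from the reordering, `EmitGeneralOp` discovers each attribute's runtime type by trying `EmitAttribute<T>` for every supported `T` in turn; a call yields an engaged value only when the MLIR attribute actually holds that type. The same dispatch style in a reduced, self-contained form, with `std::variant`/`std::optional` standing in for the MLIR attribute classes:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Stand-in for mlir::Attribute: a closed set of payload types.
using Attribute = std::variant<int32_t, float, std::string>;

// Stand-in for MlirToRuntimeTranslator::EmitAttribute<T>: engaged only
// when the attribute actually holds a T.
template <typename T>
std::optional<T> EmitAttribute(const Attribute &attr) {
  if (auto *v = std::get_if<T>(&attr)) return *v;
  return std::nullopt;
}

int main() {
  std::vector<Attribute> attrs = {int32_t{42}, 1.5f, std::string("NCHW")};
  for (const Attribute &attr : attrs) {
    // The same if/else-if ladder as EmitGeneralOp, one branch per type.
    if (auto v = EmitAttribute<int32_t>(attr)) {
      std::cout << "i32: " << *v << '\n';
    } else if (auto v = EmitAttribute<float>(attr)) {
      std::cout << "f32: " << *v << '\n';
    } else if (auto v = EmitAttribute<std::string>(attr)) {
      std::cout << "str: " << *v << '\n';
    } else {
      std::cout << "unsupported attribute type\n";
    }
  }
}
```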
@@ -63,7 +63,8 @@ class MlirToRuntimeTranslator {
   //! Emit a "ts.build_shape" operation.
   bool EmitBuildShapeOp(mlir::Operation* op);
   //! Emit an operation other than the special cases above.
-  bool EmitGeneralOp(mlir::Operation* op);
+  bool EmitGeneralOp(mlir::Operation* op,
+                     const KernelRegistry& kernel_registry);
   //! Emit all the functions.
   bool EmitFunctions();
@@ -23,23 +23,23 @@ namespace phi {
 ::phi::DenseTensor CreateDenseTensor(
     const ::phi::CPUContext& context,
     host_context::Attribute<std::vector<int64_t>> dims,
-    host_context::Attribute<::infrt::LayoutType> layout,
     host_context::Attribute<std::vector<int64_t>> lod,
+    host_context::Attribute<::infrt::LayoutType> layout,
     host_context::Attribute<::infrt::PrecisionType> precision) {
   return ::phi::DenseTensor(
       const_cast<::phi::Allocator*>(&context.GetAllocator()),
-      ::phi::DenseTensorMeta(cvtPrecision2Phi(precision.get()),
+      ::phi::DenseTensorMeta(ConvertPrecisionToPhi(precision.get()),
                              ::phi::make_ddim(dims.get()),
-                             cvtLayout2Phi(layout.get()),
+                             ConvertLayoutToPhi(layout.get()),
                              {}));
 }

 void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
-                        host_context::Attribute<std::vector<float>> values) {
+                        host_context::Attribute<std::vector<float>> value) {
   auto place = ::phi::CPUPlace();
   float* a_data = dense_tensor->mutable_data<float>(place);
   for (int64_t i = 0; i < dense_tensor->numel(); ++i) {
-    a_data[i] = (values.get())[i];
+    a_data[i] = (value.get())[i];
   }
 }
@@ -57,7 +57,7 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
   ::phi::DDim dims = dense_tensor->dims();
   std::cout << "dense_tensor: shape=shape" << dims.to_str() << ","
-            << " values=[";
+            << " value=[";
   switch (dense_tensor->dtype()) {
     PRINT_META_DATA(FLOAT32, float);
     PRINT_META_DATA(INT32, int32_t);
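`PrintDenseTensor` dispatches on the tensor's runtime dtype through the `PRINT_META_DATA` macro, which stamps out one `switch` case per (dtype enumerator, C++ type) pair. The same trick in a self-contained form; the enum and what each case prints here are simplified stand-ins, not phi's actual definitions:

```cpp
#include <cstdint>
#include <iostream>

enum class DType { FLOAT32, INT32 };

// Stamp out one switch case per (enumerator, C++ type) pair, the same
// pattern PRINT_META_DATA uses in PrintDenseTensor.
#define PRINT_META_DATA(ENUM, TYPE)                                  \
  case DType::ENUM:                                                  \
    std::cout << #ENUM << " element size: " << sizeof(TYPE) << '\n'; \
    break;

void PrintDType(DType dtype) {
  switch (dtype) {
    PRINT_META_DATA(FLOAT32, float)
    PRINT_META_DATA(INT32, int32_t)
    default:
      std::cout << "Error: unsupported data type.\n";
      break;
  }
}
#undef PRINT_META_DATA

int main() {
  PrintDType(DType::FLOAT32);  // FLOAT32 element size: 4
  PrintDType(DType::INT32);    // INT32 element size: 4
}
```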
@@ -26,8 +26,8 @@ namespace phi {
 ::phi::DenseTensor CreateDenseTensor(
     const ::phi::CPUContext& context,
     host_context::Attribute<std::vector<int64_t>> dims,
-    host_context::Attribute<::infrt::LayoutType> layout,
     host_context::Attribute<std::vector<int64_t>> lod,
+    host_context::Attribute<::infrt::LayoutType> layout,
     host_context::Attribute<::infrt::PrecisionType> precision);

 void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
@@ -34,10 +34,14 @@ namespace kernel {
 void RegisterPhiKernels(host_context::KernelRegistry* registry) {
   registry->AddKernel("phi_dt.create_context.cpu",
                       INFRT_KERNEL(infrt::kernel::phi::CreateCPUContext));
-  registry->AddKernel("phi_dt.create_dense_tensor",
-                      INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor));
-  registry->AddKernel("phi_dt.fill_dense_tensor.f32",
-                      INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32));
+  registry->AddKernelWithAttrs(
+      "phi_dt.create_dense_tensor",
+      INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor),
+      {"dims", "lod", "layout", "precision"});
+  registry->AddKernelWithAttrs(
+      "phi_dt.fill_dense_tensor.f32",
+      INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32),
+      {"value"});
   registry->AddKernel("phi_dt.print_tensor",
                       INFRT_KERNEL(infrt::kernel::phi::PrintDenseTensor));
 }
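`INFRT_KERNEL` adapts a typed kernel function such as `CreateDenseTensor` to the registry's uniform `KernelImplementation` signature; its real definition lives in infrt's kernel utilities. The adapter idea can be sketched as follows, where every type is a toy stand-in (a type-erased frame plus a template wrapper), not the actual macro:

```cpp
#include <any>
#include <functional>
#include <iostream>
#include <vector>

// Toy stand-in for host_context::KernelFrame: type-erased argument slots.
struct KernelFrame {
  std::vector<std::any> args;
};

using KernelImplementation = std::function<void(KernelFrame *)>;

// Toy INFRT_KERNEL-style adapter for a fixed arity: wraps a typed
// kernel into the uniform KernelImplementation signature.
template <typename A, typename B>
KernelImplementation WrapKernel(void (*fn)(A, B)) {
  return [fn](KernelFrame *frame) {
    fn(std::any_cast<A>(frame->args[0]), std::any_cast<B>(frame->args[1]));
  };
}

void MulKernel(float a, float x) { std::cout << a * x << '\n'; }

int main() {
  KernelImplementation impl = WrapKernel(&MulKernel);
  KernelFrame frame;
  frame.args = {std::any(2.0f), std::any(3.0f)};
  impl(&frame);  // prints 6
}
```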
@@ -111,9 +111,9 @@ void NaiveMatmul(const DenseHostTensor &x,
 /// ===== Kernel end ====

 void RegisterTensorKernels(host_context::KernelRegistry *registry) {
-  registry->AddKernel("dt.create_uninit_tensor.f32",
-                      INFRT_KERNEL(CreateUninitTensor<float>));
-  registry->AddKernelAttrNameList("dt.create_uninit_tensor.f32", {"shape"});
+  registry->AddKernelWithAttrs("dt.create_uninit_tensor.f32",
+                               INFRT_KERNEL(CreateUninitTensor<float>),
+                               {"shape"});
   registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor));
   registry->AddKernel("dt.fill_tensor_with_constant.f32",
                       INFRT_KERNEL(FillTensorWithConstant<float>));
@@ -9,7 +9,7 @@ func @sign_any_float32_execute() {
   "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
   %e = "phi_cpu.sign.float32.any"(%ctx, %t) : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)

-  // CHECK: dense_tensor: shape=shape[1], values=[1]
+  // CHECK: dense_tensor: shape=shape[1], value=[1]
   "phi_dt.print_tensor" (%e) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
   infrt.return
 }