Unverified commit af6ef888, authored by 石晓伟 and committed by GitHub

adjusts the mlir attrs order, test=develop (#40514)

Parent e7057932
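Background for this change: MLIR stores an operation's attributes in a `Builtin_Dictionary`, whose entries are sorted by attribute name, while an infrt/phi kernel consumes its attributes positionally, in the order of its function signature. Whenever those two orders differ, attribute values land in the wrong slots. The sketch below is a minimal, self-contained illustration of the mismatch; the attribute names and values are hypothetical, not taken from this patch:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // MLIR's name-sorted dictionary yields attributes in this order...
  std::vector<std::string> dictionary_order = {"dims", "layout", "lod",
                                               "precision"};
  // ...but the kernel signature consumes them positionally in this order:
  std::vector<std::string> signature_order = {"dims", "lod", "layout",
                                              "precision"};
  // Before this patch attrs were appended in dictionary order, so slot 1
  // received `layout` where the kernel expected `lod`.
  for (size_t i = 0; i < signature_order.size(); ++i) {
    std::cout << "slot " << i << ": kernel expects " << signature_order[i]
              << ", dictionary order supplies " << dictionary_order[i] << "\n";
  }
  return 0;
}
```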
@@ -16,7 +16,7 @@
 namespace infrt {
-phi::Backend cvtTarget2Phi(TargetType target) {
+phi::Backend ConvertTargetToPhi(TargetType target) {
   switch (target) {
     case TargetType::CPU:
       return phi::Backend::CPU;
@@ -27,7 +27,7 @@ phi::Backend cvtTarget2Phi(TargetType target) {
   }
 }
-TargetType cvtTargetFromPhi(phi::Backend backend) {
+TargetType ConvertTargetFromPhi(phi::Backend backend) {
   switch (backend) {
     case phi::Backend::CPU:
       return TargetType::CPU;
@@ -38,7 +38,7 @@ TargetType cvtTargetFromPhi(phi::Backend backend) {
   }
 }
-phi::DataType cvtPrecision2Phi(PrecisionType precision) {
+phi::DataType ConvertPrecisionToPhi(PrecisionType precision) {
 #define CONVERT_PRECISION_TO_PHI(Precision) \
   case PrecisionType::Precision:            \
     return phi::DataType::Precision;
@@ -61,7 +61,7 @@ phi::DataType cvtPrecision2Phi(PrecisionType precision) {
 #undef CONVERT_PRECISION_TO_PHI
 }
-PrecisionType cvtPrecisionFromPhi(phi::DataType datatype) {
+PrecisionType ConvertPrecisionFromPhi(phi::DataType datatype) {
 #define CONVERT_PRECISION_FROM_PHI(Precision) \
   case phi::DataType::Precision:              \
     return PrecisionType::Precision;
@@ -84,7 +84,7 @@ PrecisionType cvtPrecisionFromPhi(phi::DataType datatype) {
 #undef CONVERT_PRECISION_FROM_PHI
 }
-phi::DataLayout cvtLayout2Phi(LayoutType layout) {
+phi::DataLayout ConvertLayoutToPhi(LayoutType layout) {
   switch (layout) {
     case LayoutType::NCHW:
       return phi::DataLayout::NCHW;
@@ -97,7 +97,7 @@ phi::DataLayout cvtLayout2Phi(LayoutType layout) {
   }
 }
-LayoutType cvtLayoutFromPhi(phi::DataLayout layout) {
+LayoutType ConvertLayoutFromPhi(phi::DataLayout layout) {
   switch (layout) {
     case phi::DataLayout::NCHW:
       return LayoutType::NCHW;
@@ -110,16 +110,16 @@ LayoutType cvtLayoutFromPhi(phi::DataLayout layout) {
   }
 }
-phi::KernelKey cvtPlace2Phi(const Place& place) {
-  return phi::KernelKey(cvtTarget2Phi(place.target),
-                        cvtLayout2Phi(place.layout),
-                        cvtPrecision2Phi(place.precision));
+phi::KernelKey ConvertPlaceToPhi(const Place& place) {
+  return phi::KernelKey(ConvertTargetToPhi(place.target),
+                        ConvertLayoutToPhi(place.layout),
+                        ConvertPrecisionToPhi(place.precision));
 }
-Place cvtPlaceFromPhi(phi::TensorArgDef tensor_arg) {
-  return Place(cvtTargetFromPhi(tensor_arg.backend),
-               cvtPrecisionFromPhi(tensor_arg.dtype),
-               cvtLayoutFromPhi(tensor_arg.layout));
+Place ConvertPlaceFromPhi(phi::TensorArgDef tensor_arg) {
+  return Place(ConvertTargetFromPhi(tensor_arg.backend),
+               ConvertPrecisionFromPhi(tensor_arg.dtype),
+               ConvertLayoutFromPhi(tensor_arg.layout));
 }
 }  // namespace infrt
@@ -23,16 +23,16 @@
 namespace infrt {
-phi::Backend cvtTarget2Phi(TargetType target);
+phi::Backend ConvertTargetToPhi(TargetType target);
-TargetType cvtTargetFromPhi(phi::Backend backend);
+TargetType ConvertTargetFromPhi(phi::Backend backend);
-phi::DataType cvtPrecision2Phi(PrecisionType precision);
+phi::DataType ConvertPrecisionToPhi(PrecisionType precision);
-PrecisionType cvtPrecisionFromPhi(phi::DataType datatype);
+PrecisionType ConvertPrecisionFromPhi(phi::DataType datatype);
-phi::DataLayout cvtLayout2Phi(LayoutType layout);
+phi::DataLayout ConvertLayoutToPhi(LayoutType layout);
-LayoutType cvtLayoutFromPhi(phi::DataLayout layout);
+LayoutType ConvertLayoutFromPhi(phi::DataLayout layout);
-phi::KernelKey cvtPlace2Phi(const Place& place);
+phi::KernelKey ConvertPlaceToPhi(const Place& place);
-Place cvtPlaceFromPhi(phi::TensorArgDef tensor_arg);
+Place ConvertPlaceFromPhi(phi::TensorArgDef tensor_arg);
 }  // namespace infrt
@@ -80,7 +80,7 @@ std::vector<PhiKernelDesc> getCandidateKernels(
   phi::KernelKeyMap kernel_key_map =
       phi::KernelFactory::Instance().SelectKernelMap(name);
   for (Place place : valid_palces) {
-    phi::KernelKey kernel_key = cvtPlace2Phi(place);
+    phi::KernelKey kernel_key = ConvertPlaceToPhi(place);
     if (kernel_key_map.find(kernel_key) == kernel_key_map.end()) {
       kernel_key = phi::KernelKey(kernel_key.backend(),
                                   phi::DataLayout::ALL_LAYOUT,
@@ -97,10 +97,10 @@ std::vector<PhiKernelDesc> getCandidateKernels(
     const paddle::SmallVector<phi::TensorArgDef>& output_arg =
         args_def.output_defs();
     for (auto tensor_arg : input_arg) {
-      phi_kernel_desc.inputsType.emplace_back(cvtPlaceFromPhi(tensor_arg));
+      phi_kernel_desc.inputsType.emplace_back(ConvertPlaceFromPhi(tensor_arg));
     }
     for (auto tensor_arg : output_arg) {
-      phi_kernel_desc.outputsType.emplace_back(cvtPlaceFromPhi(tensor_arg));
+      phi_kernel_desc.outputsType.emplace_back(ConvertPlaceFromPhi(tensor_arg));
     }
     candidate_kernels.emplace_back(phi_kernel_desc);
   }
...
@@ -23,8 +23,9 @@ namespace infrt {
 namespace host_context {
 struct KernelRegistry::Impl {
-  std::unordered_map<std::string, KernelImplementation> data;
-  std::unordered_map<std::string, llvm::SmallVector<std::string, 4>> attr_names;
+  std::unordered_map<std::string,
+                     std::pair<KernelImplementation, std::vector<const char *>>>
+      data;
 };
 KernelRegistry::KernelRegistry() : impl_(std::make_unique<Impl>()) {}
@@ -33,20 +34,29 @@ void KernelRegistry::AddKernel(const std::string &key,
                                KernelImplementation fn) {
   CHECK(!impl_->data.count(key)) << "kernel [" << key
                                  << "] is registered twice";
-  impl_->data.emplace(key, fn);
+  impl_->data.emplace(
+      key, std::make_pair(std::move(fn), std::vector<const char *>{}));
 }
-void KernelRegistry::AddKernelAttrNameList(
-    const std::string &key, const std::vector<std::string> &names) {
-  CHECK(!impl_->attr_names.count(key))
-      << "kernel [" << key << "] is registered twice in attribute names";
-  impl_->attr_names.emplace(
-      key, llvm::SmallVector<std::string, 4>(names.begin(), names.end()));
+const std::vector<const char *> &KernelRegistry::GetAttrNameList(
+    const std::string &key) const {
+  CHECK(impl_->data.count(key));
+  return impl_->data[key].second;
+}
+void KernelRegistry::AddKernelWithAttrs(
+    const std::string &key,
+    KernelImplementation fn,
+    std::vector<const char *> &&attr_order) {
+  CHECK(!impl_->data.count(key)) << "kernel [" << key
+                                 << "] is registered twice";
+  impl_->data.emplace(key,
+                      std::make_pair(std::move(fn), std::move(attr_order)));
 }
 KernelImplementation KernelRegistry::GetKernel(const std::string &key) const {
   auto it = impl_->data.find(key);
-  return it != impl_->data.end() ? it->second : KernelImplementation{};
+  return it != impl_->data.end() ? it->second.first : KernelImplementation{};
 }
 std::vector<std::string> KernelRegistry::GetKernelList() const {
...
@@ -34,10 +34,14 @@ class KernelRegistry {
   KernelRegistry();
   void AddKernel(const std::string &key, KernelImplementation fn);
-  void AddKernelAttrNameList(const std::string &key,
-                             const std::vector<std::string> &names);
+  void AddKernelWithAttrs(const std::string &key,
+                          KernelImplementation fn,
+                          std::vector<const char *> &&attrs_order);
   KernelImplementation GetKernel(const std::string &key) const;
+  const std::vector<const char *> &GetAttrNameList(
+      const std::string &key) const;
   std::vector<std::string> GetKernelList() const;
   size_t size() const;
...
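For reference, a minimal usage sketch of the reworked registry API, assuming `demo_fn` and `demo_print_fn` are already valid `KernelImplementation` values; the kernel names and attribute names here are illustrative only:

```cpp
// Register a kernel together with the attribute order its signature expects
// (not MLIR's name-sorted dictionary order).
registry->AddKernelWithAttrs("demo.make_tensor", demo_fn,
                             {"dims", "lod", "layout", "precision"});

// Attribute-free kernels keep using the plain overload.
registry->AddKernel("demo.print_tensor", demo_print_fn);

// The translator later recovers the registered order by kernel name.
const std::vector<const char *> &order =
    registry->GetAttrNameList("demo.make_tensor");
```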
@@ -43,6 +43,7 @@ MlirFunctionExecutable::MlirFunctionExecutable(
                func_op.getNumResults()),
       MlirToRuntimeTranslator(&core_runtime_builder_),
       region_(&func_op.getRegion()),
+      kernel_registry_(kernel_registry),
       core_runtime_builder_(kernel_registry),
       function_table_(function_table) {}
@@ -54,6 +55,7 @@ MlirFunctionExecutable::MlirFunctionExecutable(
     : Function("", func_type.getNumInputs(), func_type.getNumResults()),
       MlirToRuntimeTranslator(&core_runtime_builder_),
       region_(region),
+      kernel_registry_(kernel_registry),
       core_runtime_builder_(kernel_registry),
       function_table_(function_table) {}
@@ -90,7 +92,7 @@ void MlirFunctionExecutable::BuildExecutables(
     if (EmitCallOp(&op, &function_table_)) continue;
-    if (EmitGeneralOp(&op)) continue;
+    if (EmitGeneralOp(&op, *kernel_registry_)) continue;
     LOG(FATAL) << "Not supported op: " << DumpToString(op);
   }
...
@@ -70,6 +70,7 @@ class MlirFunctionExecutable : public Function, public MlirToRuntimeTranslator {
  private:
   mlir::Region* region_{};
+  KernelRegistry* kernel_registry_{};
   CoreRuntimeBuilder core_runtime_builder_;
   MlirToRuntimeTranslator::function_defs_t& function_table_;
   std::function<void()> copy_res_fn_;
...
@@ -270,7 +270,8 @@ static bool IsReturn(mlir::Operation* op) {
   return op->getName().getStringRef() == "infrt.return";
 }
-bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
+bool MlirToRuntimeTranslator::EmitGeneralOp(
+    mlir::Operation* op, const KernelRegistry& kernel_registry) {
   CHECK(impl_->runtime);
   impl_->cur_op =
       impl_->runtime->NewOpExecutable(op->getName().getStringRef().str());
@@ -308,42 +309,80 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
   // process attributes
   auto attrs = op->getAttrs();
+  // MLIR's underlying attr storage type is `Builtin_Dictionary`, and its
+  // elements are sorted by name. The following code restores the order
+  // expected by the function signatures of the phi operator library.
+  llvm::SmallVector<Value*, 4> tmp;
+  tmp.resize(attrs.size());
+  const std::string& kernel_name = op->getName().getStringRef().str();
+  const auto& attr_names = kernel_registry.GetAttrNameList(kernel_name);
+  if (attrs.size() && attr_names.empty()) {
+    LOG(WARNING) << "The kernel `" << kernel_name
+                 << "` has no specified attr order.";
+  }
+  auto get_offset = [](const char* attr,
+                       const std::vector<const char*>& names,
+                       const std::string& kernel_name) -> int {
+    for (size_t i = 0; i < names.size(); ++i) {
+      if (!std::strcmp(attr, names[i])) {
+        return i;
+      }
+    }
+    LOG(WARNING) << "The attribute `" << attr << "` of kernel `" << kernel_name
+                 << "` is not properly registered with "
+                    "`KernelRegistry::AddKernelWithAttrs()`.";
+    return -1;
+  };
   for (size_t i = 0; i < attrs.size(); i++) {
     auto& attr = attrs[i];
+    int offset{};
+    if (attr_names.size()) {
+      offset = get_offset(attr.getName().data(), attr_names, kernel_name);
+    } else {
+      offset = i;
+    }
+    CHECK_NE(offset, -1);
     if (auto v = EmitAttribute<int32_t>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<int64_t>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<float>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<double>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
     } else if (auto v = EmitAttribute<bool>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<::infrt::TargetType>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v =
                    EmitAttribute<::infrt::PrecisionType>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<::infrt::LayoutType>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(*v));
+      tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
     } else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<int64_t>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<float>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<double>>(attr.getValue())) {
-      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+      tmp[offset] = new Value(std::move(*v));
     } else {
       LOG(FATAL) << "Not supported attribute type";
     }
   }
+  for (size_t i = 0; i < tmp.size(); i++) {
+    impl_->cur_op->AppendAttribute(tmp[i]);
+  }
   // process results
   llvm::SmallVector<Value*, 4> res_values;
   for (int i = 0, e = op->getNumResults(); i < e; i++) {
@@ -598,7 +637,7 @@ class MlirProgramTestExecutor : public MlirToRuntimeTranslator {
       llvm::SmallVector<mlir::Value, 3> results;
       if (EmitReturnOp(&op, &results)) continue;
       if (EmitCallOp(&op, &impl_->func_defs)) continue;
-      if (EmitGeneralOp(&op)) continue;
+      if (EmitGeneralOp(&op, *registry)) continue;
       LOG(FATAL) << "Not supported op: " << DumpToString(op);
     }
...
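The heart of the translator change above is a permutation step: each attribute coming out of the name-sorted MLIR dictionary is written into the slot its name occupies in the registered order, and only then appended. A self-contained sketch of that remapping, with simplified value types and illustrative attribute names:

```cpp
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>

// Simplified stand-in for the get_offset lambda in EmitGeneralOp: return the
// slot of `attr` in the registered order, or -1 if it was never registered.
int GetOffset(const char* attr, const std::vector<const char*>& names) {
  for (size_t i = 0; i < names.size(); ++i) {
    if (!std::strcmp(attr, names[i])) return static_cast<int>(i);
  }
  return -1;
}

int main() {
  // Attributes as MLIR hands them over: sorted by name (values are dummies).
  std::vector<std::pair<const char*, int>> sorted_attrs = {
      {"dims", 1}, {"layout", 3}, {"lod", 0}, {"precision", 32}};
  // The order the kernel was registered with (its signature order).
  std::vector<const char*> registered = {"dims", "lod", "layout", "precision"};

  // Scatter each value into its signature slot, mirroring `tmp[offset] = ...`.
  std::vector<int> tmp(sorted_attrs.size());
  for (const auto& attr : sorted_attrs) {
    int offset = GetOffset(attr.first, registered);
    if (offset == -1) return 1;  // unregistered attribute: bail out
    tmp[offset] = attr.second;
  }
  for (int v : tmp) std::cout << v << ' ';  // prints: 1 0 3 32
  std::cout << '\n';
  return 0;
}
```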
@@ -63,7 +63,8 @@ class MlirToRuntimeTranslator {
   //! Emit a "ts.build_shape" operation.
   bool EmitBuildShapeOp(mlir::Operation* op);
   //! Emit an operation other than the special cases above.
-  bool EmitGeneralOp(mlir::Operation* op);
+  bool EmitGeneralOp(mlir::Operation* op,
+                     const KernelRegistry& kernel_registry);
   //! Emit all the functions.
   bool EmitFunctions();
...
@@ -23,23 +23,23 @@ namespace phi {
 ::phi::DenseTensor CreateDenseTensor(
     const ::phi::CPUContext& context,
     host_context::Attribute<std::vector<int64_t>> dims,
-    host_context::Attribute<::infrt::LayoutType> layout,
     host_context::Attribute<std::vector<int64_t>> lod,
+    host_context::Attribute<::infrt::LayoutType> layout,
     host_context::Attribute<::infrt::PrecisionType> precision) {
   return ::phi::DenseTensor(
       const_cast<::phi::Allocator*>(&context.GetAllocator()),
-      ::phi::DenseTensorMeta(cvtPrecision2Phi(precision.get()),
+      ::phi::DenseTensorMeta(ConvertPrecisionToPhi(precision.get()),
                              ::phi::make_ddim(dims.get()),
-                             cvtLayout2Phi(layout.get()),
+                             ConvertLayoutToPhi(layout.get()),
                              {}));
 }
 void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
-                        host_context::Attribute<std::vector<float>> values) {
+                        host_context::Attribute<std::vector<float>> value) {
   auto place = ::phi::CPUPlace();
   float* a_data = dense_tensor->mutable_data<float>(place);
   for (int64_t i = 0; i < dense_tensor->numel(); ++i) {
-    a_data[i] = (values.get())[i];
+    a_data[i] = (value.get())[i];
   }
 }
@@ -57,7 +57,7 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
   ::phi::DDim dims = dense_tensor->dims();
   std::cout << "dense_tensor: shape=shape" << dims.to_str() << ","
-            << " values=[";
+            << " value=[";
   switch (dense_tensor->dtype()) {
     PRINT_META_DATA(FLOAT32, float);
     PRINT_META_DATA(INT32, int32_t);
...
@@ -26,8 +26,8 @@ namespace phi {
 ::phi::DenseTensor CreateDenseTensor(
     const ::phi::CPUContext& context,
     host_context::Attribute<std::vector<int64_t>> dims,
-    host_context::Attribute<::infrt::LayoutType> layout,
     host_context::Attribute<std::vector<int64_t>> lod,
+    host_context::Attribute<::infrt::LayoutType> layout,
     host_context::Attribute<::infrt::PrecisionType> precision);
 void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
...
@@ -34,10 +34,14 @@ namespace kernel {
 void RegisterPhiKernels(host_context::KernelRegistry* registry) {
   registry->AddKernel("phi_dt.create_context.cpu",
                       INFRT_KERNEL(infrt::kernel::phi::CreateCPUContext));
-  registry->AddKernel("phi_dt.create_dense_tensor",
-                      INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor));
-  registry->AddKernel("phi_dt.fill_dense_tensor.f32",
-                      INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32));
+  registry->AddKernelWithAttrs(
+      "phi_dt.create_dense_tensor",
+      INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor),
+      {"dims", "lod", "layout", "precision"});
+  registry->AddKernelWithAttrs(
+      "phi_dt.fill_dense_tensor.f32",
+      INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32),
+      {"value"});
   registry->AddKernel("phi_dt.print_tensor",
                       INFRT_KERNEL(infrt::kernel::phi::PrintDenseTensor));
 }
...
@@ -111,9 +111,9 @@ void NaiveMatmul(const DenseHostTensor &x,
 /// ===== Kernel end ====
 void RegisterTensorKernels(host_context::KernelRegistry *registry) {
-  registry->AddKernel("dt.create_uninit_tensor.f32",
-                      INFRT_KERNEL(CreateUninitTensor<float>));
-  registry->AddKernelAttrNameList("dt.create_uninit_tensor.f32", {"shape"});
+  registry->AddKernelWithAttrs("dt.create_uninit_tensor.f32",
+                               INFRT_KERNEL(CreateUninitTensor<float>),
+                               {"shape"});
   registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor));
   registry->AddKernel("dt.fill_tensor_with_constant.f32",
                       INFRT_KERNEL(FillTensorWithConstant<float>));
...
@@ -9,7 +9,7 @@ func @sign_any_float32_execute() {
   "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
   %e = "phi_cpu.sign.float32.any"(%ctx, %t) : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
-  // CHECK: dense_tensor: shape=shape[1], values=[1]
+  // CHECK: dense_tensor: shape=shape[1], value=[1]
   "phi_dt.print_tensor" (%e) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
   infrt.return
 }
...