From b77e20ac69681691cecb5851a39299ef93027af7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com>
Date: Mon, 21 Mar 2022 10:48:14 +0800
Subject: [PATCH] add the map for dense tensor, test=develop (#40665)

---
 paddle/infrt/api/infrt_api.cc                 |  2 +-
 paddle/infrt/dialect/dense_tensor.td          |  6 +-
 paddle/infrt/dialect/infrt/ir/infrt_base.td   |  2 +-
 .../infrt/dialect/infrt/ir/infrt_dialect.cc   | 10 +-
 paddle/infrt/dialect/phi/data_type.h          | 16 +--
 paddle/infrt/dialect/phi/ir/infrt_phi_base.td |  4 +
 .../infrt/dialect/phi/ir/infrt_phi_tensor.td  | 34 +++++++
 .../host_context/mlir_to_runtime_translate.cc | 24 +++--
 paddle/infrt/host_context/paddle_mlir.cc      |  2 +-
 paddle/infrt/host_context/value.h             | 19 ++--
 .../infrt/kernel/phi/dense_tensor_kernels.cc  | 98 +++++++++++++++++++
 .../infrt/kernel/phi/dense_tensor_kernels.h   | 14 +++
 .../infershaped/infershape_launchers_test.cc  | 29 +++---
 paddle/infrt/kernel/phi/registry.cc           | 13 +++
 paddle/infrt/kernel/tensor_kernels.cc         |  6 +-
 paddle/infrt/tensor/CMakeLists.txt            |  2 +
 paddle/infrt/tensor/phi/CMakeLists.txt        |  3 +
 paddle/infrt/tensor/phi/tensor_map.cc         | 47 +++++++++
 paddle/infrt/tensor/phi/tensor_map.h          | 37 +++++++
 .../tests/dialect/tensor/tensor_map.mlir.in   | 27 +++++
 tools/infrt/fake_models/multi_fc.py           |  3 +
 tools/infrt/get_phi_kernel_function.sh        |  4 +-
 tools/infrt/get_phi_kernel_info.py            | 16 +--
 23 files changed, 353 insertions(+), 65 deletions(-)
 create mode 100644 paddle/infrt/tensor/phi/CMakeLists.txt
 create mode 100644 paddle/infrt/tensor/phi/tensor_map.cc
 create mode 100644 paddle/infrt/tensor/phi/tensor_map.h

diff --git a/paddle/infrt/api/infrt_api.cc b/paddle/infrt/api/infrt_api.cc
index 0500a812304..5ac51fb6715 100644
--- a/paddle/infrt/api/infrt_api.cc
+++ b/paddle/infrt/api/infrt_api.cc
@@ -129,7 +129,7 @@ class PredictExecutor : public MlirToRuntimeTranslator {
       auto arg = predict_func.getArgument(i);
       auto type = arg.getType();
       // this param is TensorMap
-      if (type.isa<infrt::DenseTensorMapType>()) {
+      if (type.isa<infrt::DenseHostTensorMapType>()) {
         auto* value = new host_context::Value(std::move(*map));
         arguments_.push_back(value);
         AddValue(predict_func.getArgument(i), value);
diff --git a/paddle/infrt/dialect/dense_tensor.td b/paddle/infrt/dialect/dense_tensor.td
index 59df4e96973..822a4879e6f 100644
--- a/paddle/infrt/dialect/dense_tensor.td
+++ b/paddle/infrt/dialect/dense_tensor.td
@@ -106,7 +106,7 @@ def LoadParamsOp : DT_Op<"load_params", [NoSideEffect]> {
   // input path of model params.
   let arguments = (ins StrAttr:$path);
-  let results = (outs DenseTensorMap:$out);
+  let results = (outs DenseHostTensorMap:$out);

   let assemblyFormat = "`(``)`attr-dict";
 }
@@ -121,7 +121,7 @@ def TensorMapGetTensorOp : DT_Op<"tensor_map_get_tensor", [NoSideEffect]> {
   // input path of model params.
   let arguments = (ins
-    DenseTensorMap:$map,
+    DenseHostTensorMap:$map,
     StrAttr:$name
   );
   let results = (outs DenseTensor:$output);
@@ -136,7 +136,7 @@ def TensorMapGetSizeOp : DT_Op<"tensor_map_get_size", [NoSideEffect]> {
     An operation that gets the size of a TensorMap.
   }];
-  let arguments = (ins DenseTensorMap:$map);
+  let arguments = (ins DenseHostTensorMap:$map);
   let results = (outs I32:$size);
   let assemblyFormat = "`(` $map `)` attr-dict `->` type($size)";
 }
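For orientation, the `infrt_api.cc` hunk above is the call site where a loaded parameter map enters the runtime. A condensed, hypothetical sketch of that argument-binding loop; only the TensorMap branch is shown, `map` is the preloaded parameter map the executor was constructed with, and other argument kinds take different paths in the real class:

```cpp
// Sketch of PredictExecutor's argument binding after the type rename.
for (size_t i = 0; i < predict_func.getNumArguments(); ++i) {
  mlir::Type type = predict_func.getArgument(i).getType();
  if (type.isa<infrt::DenseHostTensorMapType>()) {
    // Move the preloaded parameter map into a heap-allocated runtime Value
    // and bind it to the function argument.
    auto* value = new host_context::Value(std::move(*map));
    arguments_.push_back(value);
    AddValue(predict_func.getArgument(i), value);
  }
}
```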
diff --git a/paddle/infrt/dialect/infrt/ir/infrt_base.td b/paddle/infrt/dialect/infrt/ir/infrt_base.td
index 86cfc375330..9b1d2132292 100644
--- a/paddle/infrt/dialect/infrt/ir/infrt_base.td
+++ b/paddle/infrt/dialect/infrt/ir/infrt_base.td
@@ -83,7 +83,7 @@ def DenseTensor : Infrt_Type<"DenseTensor"> {
   );
 }

-def DenseTensorMap : Infrt_Type<"DenseTensorMap"> {
+def DenseHostTensorMap : Infrt_Type<"DenseHostTensorMap"> {
   let summary = "infrt dense tensor map";
   let description = [{dense_tensor map}];
   let parameters = (ins);
diff --git a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc
index f8d8f514749..eb69a95c583 100644
--- a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc
+++ b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc
@@ -91,7 +91,7 @@ mlir::Type InfrtDialect::parseType(::mlir::DialectAsmParser &parser) const {
         parser.getContext(), shape, elementType, lod_level);
   }
   if (keyword == "dense_tensor_map") {
-    return DenseTensorMapType::get(parser.getContext());
+    return DenseHostTensorMapType::get(parser.getContext());
   }
   if (keyword == "dense_tensor") {
     // parse DenseTensor, for example: !infrt.dense_tensor<CPU, FP32, NCHW>
@@ -162,7 +162,7 @@ void InfrtDialect::printType(::mlir::Type type,
        << lod_tensor_type.getLod_level() << ">";
     return;
   }
-  if (type.isa<DenseTensorMapType>()) {
+  if (type.isa<DenseHostTensorMapType>()) {
     os << "dense_tensor_map";
     return;
   }
@@ -180,12 +180,6 @@ void InfrtDialect::printType(::mlir::Type type,
     os << "tensor_list";
     return;
   }
-  // print DenseTensorType, for example: !infrt.dense_tensor<CPU, FP32, NCHW>
-  if (type.isa<DenseTensorMapType>()) {
-    os << "dense_tensor_map";
-    return;
-  }
-
   llvm_unreachable("unknown infrt type.");
 }
diff --git a/paddle/infrt/dialect/phi/data_type.h b/paddle/infrt/dialect/phi/data_type.h
index bd258cb1038..8e831c8c27d 100644
--- a/paddle/infrt/dialect/phi/data_type.h
+++ b/paddle/infrt/dialect/phi/data_type.h
@@ -23,16 +23,16 @@

 namespace infrt {

-phi::Backend ConvertTargetToPhi(TargetType target);
-TargetType ConvertTargetFromPhi(phi::Backend backend);
+::phi::Backend ConvertTargetToPhi(TargetType target);
+TargetType ConvertTargetFromPhi(::phi::Backend backend);

-phi::DataType ConvertPrecisionToPhi(PrecisionType precision);
-PrecisionType ConvertPrecisionFromPhi(phi::DataType datatype);
+::phi::DataType ConvertPrecisionToPhi(PrecisionType precision);
+PrecisionType ConvertPrecisionFromPhi(::phi::DataType datatype);

-phi::DataLayout ConvertLayoutToPhi(LayoutType layout);
-LayoutType ConvertLayoutFromPhi(phi::DataLayout layout);
+::phi::DataLayout ConvertLayoutToPhi(LayoutType layout);
+LayoutType ConvertLayoutFromPhi(::phi::DataLayout layout);

-phi::KernelKey ConvertPlaceToPhi(const Place& place);
-Place ConvertPlaceFromPhi(phi::TensorArgDef tensor_arg);
+::phi::KernelKey ConvertPlaceToPhi(const Place& place);
+Place ConvertPlaceFromPhi(::phi::TensorArgDef tensor_arg);

 }  // namespace infrt
diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_base.td b/paddle/infrt/dialect/phi/ir/infrt_phi_base.td
index 5d7338ec429..8e21283183d 100644
--- a/paddle/infrt/dialect/phi/ir/infrt_phi_base.td
+++ b/paddle/infrt/dialect/phi/ir/infrt_phi_base.td
@@ -37,4 +37,8 @@ def Allocator : PHI_Type<"Allocator"> {
   let assemblyFormat = "`<` $target `>`";
 }

+def PD_DenseTensorMap : PHI_Type<"DenseTensorMap"> {
+  let mnemonic = "dense_tensor_map";
+}
+
 #endif
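The `data_type.h` hunk only re-qualifies the `phi` namespace, but it may help to see what one of these declared converters does. A hypothetical body for `ConvertPrecisionToPhi`, assuming a switch over the dialect's `PrecisionType` enum; the enumerator names are assumptions, and the real implementation lives in `data_type.cc` and covers every precision:

```cpp
// Sketch only: maps the infrt dialect's precision enum onto phi's DataType.
#include "paddle/infrt/dialect/infrt/common/types.h"  // PrecisionType
#include "paddle/phi/common/data_type.h"              // ::phi::DataType

namespace infrt {
::phi::DataType ConvertPrecisionToPhi(PrecisionType precision) {
  switch (precision) {
    case PrecisionType::FLOAT32: return ::phi::DataType::FLOAT32;
    case PrecisionType::INT32:   return ::phi::DataType::INT32;
    default:                     return ::phi::DataType::UNDEFINED;
  }
}
}  // namespace infrt
```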
diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
index 1fda2d9d888..3af7033d2f4 100644
--- a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
+++ b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
@@ -51,12 +51,46 @@ class CreateContextOp<string target>
   let results = (outs Context:$output);
 }

+def PDT_LoadParamsOp : PDT_Op<"load_params", [NoSideEffect]> {
+  // input path of model params.
+  let arguments = (ins StrAttr:$path);
+  let results = (outs PD_DenseTensorMap:$out);
+
+  let assemblyFormat = "`(``)`attr-dict";
+}
+
+def PDT_LoadCombinedParamsOp : PDT_Op<"load_combined_params", [NoSideEffect]> {
+  // input path of model params.
+  let arguments = (ins StrAttr:$model_path, StrAttr:$params_path);
+  let results = (outs PD_DenseTensorMap:$out);
+
+  let assemblyFormat = "`(``)`attr-dict";
+}
+
+def PDT_TensorMapGetSizeOp : PDT_Op<"tensor_map_get_size", [NoSideEffect]> {
+  let arguments = (ins PD_DenseTensorMap:$map);
+  let results = (outs I32:$size);
+  let assemblyFormat = "`(` $map `)` attr-dict `->` type($size)";
+}
+
+class TensorMapGetTensorOp:
+      PDT_Op<"tensor_map_get_tensor"> {
+  let arguments = (ins
+    PD_DenseTensorMap:$map,
+    StrAttr:$name
+  );
+  let results = (outs DenseTensor:$output);
+  let assemblyFormat = "`(` operands `)` attr-dict `->` type($output)";
+  let verifier = ?;
+}
+
 def PDT_CreateCPUDenseTensorOp : CreateDenseTensorOp<"cpu">;
 def PDT_CreateGPUDenseTensorOp : CreateDenseTensorOp<"gpu">;
 def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp<F32ArrayAttr, "f32">;
 def PDT_CreateCPUContextOp : CreateContextOp<"cpu">;
 def PDT_CreateGPUContextOp : CreateContextOp<"gpu">;
 def PDT_PrintDenseTensor : PrintDenseTensorOp;
+def PDT_TensorMapGetTensorOp: TensorMapGetTensorOp;

 def FakeKernelOp : PDT_Op<"fake_phi_kernel"> {
   let arguments = (ins Context:$dev_ctx, DenseTensor:$x, DenseTensor:$y, BoolAttr:$transpose_x, BoolAttr:$transpose_y);
diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.cc b/paddle/infrt/host_context/mlir_to_runtime_translate.cc
index bcd44540b33..7e90f225cff 100644
--- a/paddle/infrt/host_context/mlir_to_runtime_translate.cc
+++ b/paddle/infrt/host_context/mlir_to_runtime_translate.cc
@@ -351,18 +351,26 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
   auto attrs = op->getAttrs();

   // MLIR's underlying attr storage type is `Builtin_Dictionary`, and its
-  // elements
-  // are sorted by name. The following code adapts the order of function
-  // signatures
-  // of the phi operator library.
+  // elements are sorted by name. The following code adapts the order of
+  // function signatures of the phi operator library.
   llvm::SmallVector<Value*, 4> tmp;
   tmp.resize(attrs.size());
   const std::string& kernel_name = op->getName().getStringRef().str();
   const auto& attr_names = kernel_registry.GetAttrNameList(kernel_name);
-  if (attrs.size() && attr_names.empty()) {
-    LOG(WARNING) << "The kernel `" << kernel_name
-                 << "` has no specified attr order.";
+  if (attrs.size()) {
+    if (attr_names.empty()) {
+      LOG(WARNING) << "The kernel `" << kernel_name
+                   << "` has not been registered with "
+                      "`KernelRegistry::AddKernelWithAttrs()`.";
+    } else {
+      CHECK_EQ(attr_names.size(), attrs.size())
+          << "The number of kernel `" << kernel_name
+          << "` attributes specified by mlir (" << attrs.size()
+          << ") is inconsistent with the registration (" << attr_names.size()
+          << ").";
+    }
   }
+
   auto get_offset = [](const char* attr,
                        const std::vector<const char*>& names,
                        const std::string& kernel_name) -> int {
@@ -385,7 +393,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
     } else {
       offset = i;
     }
-    CHECK_NE(offset, -1);
+    CHECK_GT(offset, -1);
     if (auto v = EmitAttribute<int32_t>(attr.getValue())) {
       tmp[offset] = new Value(*v);
     } else if (auto v = EmitAttribute<int64_t>(attr.getValue())) {
diff --git a/paddle/infrt/host_context/paddle_mlir.cc b/paddle/infrt/host_context/paddle_mlir.cc
index 29328520212..e161dc47075 100644
--- a/paddle/infrt/host_context/paddle_mlir.cc
+++ b/paddle/infrt/host_context/paddle_mlir.cc
@@ -79,7 +79,7 @@ mlir::FuncOp MLIRModelGenImpl::UpdateModelModule(
 llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetModelInputsType(
     const infrt::paddle::framework_proto::ProgramDesc &program) {
   llvm::SmallVector<mlir::Type, 4> operandTypes;
-  operandTypes.push_back(infrt::DenseTensorMapType::get(context_));
+  operandTypes.push_back(infrt::DenseHostTensorMapType::get(context_));
   for (auto &op_desc : main_block_.ops()) {
     if (op_desc.type() != "feed") continue;
     for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) {
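The translator hunk above exists because MLIR hands back attributes sorted by name while a phi kernel wants them in its registered positional order. A self-contained sketch of the lookup the `get_offset` lambda performs, with an illustrative attribute list borrowed from the registrations later in this patch:

```cpp
// Minimal demo of the reordering idea: find an attribute's registered
// position, independent of MLIR's alphabetical storage order.
#include <cassert>
#include <cstring>
#include <vector>

int GetOffset(const char* attr, const std::vector<const char*>& names) {
  for (size_t i = 0; i < names.size(); ++i)
    if (std::strcmp(attr, names[i]) == 0) return static_cast<int>(i);
  return -1;  // the caller CHECKs that the attribute was registered
}

int main() {
  // Registered order for phi_dt.load_combined_params (see registry.cc below).
  std::vector<const char*> names{"model_path", "params_path"};
  assert(GetOffset("model_path", names) == 0);
  assert(GetOffset("params_path", names) == 1);
  assert(GetOffset("unknown", names) == -1);
}
```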
diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h
index 1f0b1dabd94..5b92d183b79 100644
--- a/paddle/infrt/host_context/value.h
+++ b/paddle/infrt/host_context/value.h
@@ -34,6 +34,7 @@
 #ifdef INFRT_WITH_PHI
 #include "paddle/infrt/backends/host/phi_allocator.h"
 #include "paddle/infrt/backends/host/phi_context.h"
+#include "paddle/infrt/tensor/phi/tensor_map.h"
 #include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/common/backend.h"
 #include "paddle/phi/common/data_type.h"
@@ -84,22 +85,23 @@ using ValueVariantType =
 #ifdef INFRT_WITH_GPU
     backends::GpuPhiContext,
     ::phi::GPUContext,
-#endif
+#endif  // INFRT_WITH_GPU
     ::phi::CPUContext,
-    std::vector<phi::DenseTensor>,
-    std::vector<phi::DenseTensor*>,
-    paddle::experimental::ScalarBase<phi::DenseTensor>,
-    paddle::experimental::ScalarArrayBase<phi::DenseTensor>,
-    std::vector<phi::MetaTensor*>,
-    phi::MetaConfig,
+    std::vector<::phi::DenseTensor>,
+    std::vector<::phi::DenseTensor*>,
+    paddle::experimental::ScalarBase<::phi::DenseTensor>,
+    paddle::experimental::ScalarArrayBase<::phi::DenseTensor>,
+    std::vector<::phi::MetaTensor*>,
+    ::phi::MetaConfig,
     paddle::experimental::Backend,
     paddle::experimental::DataLayout,
     paddle::experimental::DataType,
+    ::infrt::phi::DenseTensorMap,
+#endif  // INFRT_WITH_PHI
 #ifdef INFRT_WITH_TRT
     ::infrt::backends::tensorrt::TrtEngine,
     ::infrt::kernel::tensorrt::MlirOperationWithInfrtSymbol,
 #endif  // INFRT_WITH_TRT
-#endif
     std::vector<int16_t>,
     std::vector<int32_t>,
     std::vector<int64_t>,
@@ -136,6 +138,7 @@ class Value : public common::Object {
   explicit Value(tensor::DenseHostTensor&& x) : data(std::move(x)) {}
   explicit Value(MlirFunctionExecutable* x) : data(x) {}
 #ifdef INFRT_WITH_PHI
+  explicit Value(::infrt::phi::DenseTensorMap&& x) : data(std::move(x)) {}
   explicit Value(::phi::CPUContext&& x) : data(std::move(x)) {}
   explicit Value(backends::CpuPhiContext&& x) : data(std::move(x)) {}
 #ifdef INFRT_WITH_GPU
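Adding `::infrt::phi::DenseTensorMap` to `ValueVariantType`, plus the matching move constructor, is what lets a whole parameter map travel through the host-context runtime as a single `Value`. A sketch of the round trip, assuming `Value` exposes a `get<T>()` accessor over the variant (an assumption; the exact accessor name may differ):

```cpp
// Sketch: a DenseTensorMap is moved into a Value and read back by type.
#include "paddle/infrt/host_context/value.h"

void StoreMapInValue(::infrt::phi::DenseTensorMap&& map) {
  // Uses the new explicit Value(DenseTensorMap&&) constructor.
  infrt::host_context::Value value(std::move(map));
  auto& stored = value.get<::infrt::phi::DenseTensorMap>();  // assumed accessor
  (void)stored.size();  // the Value now owns the named tensors
}
```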
diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc
index 6d16b814c6b..c8b1bd8c9eb 100644
--- a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc
+++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc
@@ -13,8 +13,11 @@
 // limitations under the License.

 #include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"
+#include "paddle/infrt/common/string.h"
 #include "paddle/infrt/dialect/phi/data_type.h"
 #include "paddle/infrt/kernel/phi/context_kernels.h"
+#include "paddle/infrt/paddle/model_parser.h"
+#include "paddle/infrt/paddle/scope.h"
 #include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/common/place.h"

@@ -22,6 +25,18 @@
 #include <cuda_runtime.h>
 #endif

+namespace paddle {
+namespace platform {
+using DeviceContext = ::phi::DeviceContext;
+}  // namespace platform
+namespace framework {
+using LoDTensor = ::phi::DenseTensor;
+void DeserializeFromStream(std::istream& is,
+                           LoDTensor* tensor,
+                           const platform::DeviceContext& dev_ctx);
+}
+}  // namespace paddle
+
 namespace infrt {
 namespace kernel {
 namespace phi {
@@ -130,6 +145,89 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
   std::cout << "]\n";
 #undef PRINT_META_DATA
 }
+
+::infrt::phi::DenseTensorMap LoadParams(
+    host_context::Attribute<std::string> path) {
+  const auto& file_path = path.get();
+  std::cout << "loading params from: " << file_path << std::endl;
+  ::infrt::phi::DenseTensorMap map;
+
+  const std::string model_path = file_path + "/__model__";
+  auto pb_proto_prog = paddle::LoadProgram(model_path);
+  auto main_block = pb_proto_prog->blocks(0);
+
+  for (auto& var : main_block.vars()) {
+    if (var.name() == "feed" || var.name() == "fetch" || !var.persistable())
+      continue;
+    std::string param_path = file_path + "/" + var.name();
+    std::ifstream param_file(param_path, std::ios::binary);
+    switch (var.type().type()) {
+      case ::paddle::framework::proto::VarType_Type_LOD_TENSOR: {
+        std::unique_ptr<::phi::DenseTensor> tensor{
+            std::make_unique<::phi::DenseTensor>()};
+        ::phi::CPUContext ctx;
+        ::paddle::framework::DeserializeFromStream(
+            param_file, tensor.get(), ctx);
+        map.SetDenseTensor(var.name(), std::move(tensor));
+      } break;
+      default: {
+        LOG(WARNING) << "Var `" << var.name() << "` type `"
+                     << static_cast<int>(var.type().type())
+                     << "` is not supported yet.";
+      }
+    }
+  }
+  return map;
+}
+
+::infrt::phi::DenseTensorMap LoadCombinedParams(
+    host_context::Attribute<std::string> model_path,
+    host_context::Attribute<std::string> params_path) {
+  const auto& model = model_path.get();
+  std::cout << "loading params from: " << model << std::endl;
+  ::infrt::phi::DenseTensorMap map;
+
+  auto pb_proto_prog = paddle::LoadProgram(model);
+  auto main_block = pb_proto_prog->blocks(0);
+
+  std::ifstream param_file(params_path.get(), std::ios::binary);
+
+  std::set<std::string> tmp;
+  for (auto& var : main_block.vars()) {
+    if (var.name() == "feed" || var.name() == "fetch" || !var.persistable()) {
+      continue;
+    }
+    if (var.type().type() ==
+        ::paddle::framework::proto::VarType_Type_LOD_TENSOR) {
+      tmp.emplace(var.name());
+    } else {
+      llvm_unreachable("the tensor type is illegal.");
+    }
+  }
+
+  for (auto& var : tmp) {
+    std::unique_ptr<::phi::DenseTensor> tensor{
+        std::make_unique<::phi::DenseTensor>()};
+    ::phi::CPUContext ctx;
+    ::paddle::framework::DeserializeFromStream(param_file, tensor.get(), ctx);
+    map.SetDenseTensor(var, std::move(tensor));
+  }
+
+  return map;
+}
+
+::phi::DenseTensor TensorMapGetTensor(
+    const ::infrt::phi::DenseTensorMap& map,
+    host_context::Attribute<std::string> name) {
+  auto* tensor = map.GetDenseTensor(name.get());
+  CHECK(tensor);
+  return *tensor;
+}
+
+int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map) {
+  return map.size();
+}
+
 }  // namespace phi
 }  // namespace kernel
 }  // namespace infrt
diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.h b/paddle/infrt/kernel/phi/dense_tensor_kernels.h
index 47d89506e2a..6cfcc6f91be 100644
--- a/paddle/infrt/kernel/phi/dense_tensor_kernels.h
+++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.h
@@ -17,6 +17,7 @@
 #include "paddle/infrt/backends/host/phi_allocator.h"
 #include "paddle/infrt/dialect/infrt/common/types.h"
 #include "paddle/infrt/host_context/kernel_utils.h"
+#include "paddle/infrt/tensor/phi/tensor_map.h"
 #include "paddle/phi/core/dense_tensor.h"

 namespace infrt {
@@ -41,6 +42,19 @@ void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
                         host_context::Attribute<std::vector<float>> values);
 void PrintDenseTensor(::phi::DenseTensor* dense_tensor);

+infrt::phi::DenseTensorMap LoadParams(
+    host_context::Attribute<std::string> path);
+
+::phi::DenseTensor TensorMapGetTensor(
+    const ::infrt::phi::DenseTensorMap& map,
+    host_context::Attribute<std::string> name);
+
+::infrt::phi::DenseTensorMap LoadCombinedParams(
+    host_context::Attribute<std::string> model_path,
+    host_context::Attribute<std::string> params_path);
+
+int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map);
+
 }  // namespace phi
 }  // namespace kernel
 }  // namespace infrt
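Taken together, the new kernels mirror the `phi_dt` ops: load a map once, then query it by name or size. A schematic usage sketch; the `Attribute<std::string>` wrappers are normally constructed by the runtime's kernel frame, not by hand, so treat `path_attr`/`name_attr` as placeholders:

```cpp
#include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"

void LoadAndInspect(infrt::host_context::Attribute<std::string> path_attr,
                    infrt::host_context::Attribute<std::string> name_attr) {
  // Directory containing __model__ plus one file per persistable var.
  ::infrt::phi::DenseTensorMap map = infrt::kernel::phi::LoadParams(path_attr);
  int32_t n = infrt::kernel::phi::TensorMapGetSize(map);
  // Returns the tensor by value; CHECK-fails if the name is absent.
  ::phi::DenseTensor bias =
      infrt::kernel::phi::TensorMapGetTensor(map, name_attr);
  LOG(INFO) << "params: " << n << ", bias dims: " << bias.dims();
}
```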
diff --git a/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc b/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc
index 08c2e19dedd..5a314817c24 100644
--- a/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc
+++ b/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc
@@ -37,15 +37,16 @@ TEST(utils, registry) {
   CHECK_EQ(count, 2U);
 }

-class FancyAllocator : public phi::Allocator {
+class FancyAllocator : public ::phi::Allocator {
  public:
-  static void Delete(phi::Allocation* allocation) {
+  static void Delete(::phi::Allocation* allocation) {
     ::operator delete(allocation->ptr());
   }

   AllocationPtr Allocate(size_t bytes_size) override {
     void* data = ::operator new(bytes_size);
-    auto* allocation = new phi::Allocation(data, bytes_size, phi::CPUPlace());
+    auto* allocation =
+        new ::phi::Allocation(data, bytes_size, ::phi::CPUPlace());
     return AllocationPtr(allocation, Delete);
   }
 };
@@ -56,20 +57,20 @@ TEST(ElementwiseAdd, launcher_registry) {
   ASSERT_GE(registry.size(), 1UL);
   auto creator = registry.GetKernel("phi_cpu.add.float32.any");

-  const phi::DDim dims({1, 2});
-  const phi::DataType dtype{phi::DataType::FLOAT32};
-  const phi::DataLayout layout{phi::DataLayout::NHWC};
-  const phi::LoD lod{};
-  phi::DenseTensorMeta meta(dtype, dims, layout, lod);
+  const ::phi::DDim dims({1, 2});
+  const ::phi::DataType dtype{::phi::DataType::FLOAT32};
+  const ::phi::DataLayout layout{::phi::DataLayout::NHWC};
+  const ::phi::LoD lod{};
+  ::phi::DenseTensorMeta meta(dtype, dims, layout, lod);

-  auto fancy_allocator = std::unique_ptr<phi::Allocator>(new FancyAllocator);
+  auto fancy_allocator = std::unique_ptr<::phi::Allocator>(new FancyAllocator);
   auto* alloc = fancy_allocator.get();

-  phi::DenseTensor a(alloc, meta);
-  phi::DenseTensor b(alloc, meta);
-  phi::DenseTensor c(alloc, meta);
+  ::phi::DenseTensor a(alloc, meta);
+  ::phi::DenseTensor b(alloc, meta);
+  ::phi::DenseTensor c(alloc, meta);

-  auto place = phi::CPUPlace();
+  auto place = ::phi::CPUPlace();
   float* a_data = a.mutable_data<float>(place);
   float* b_data = b.mutable_data<float>(place);
   float* c_data = c.mutable_data<float>(place);
@@ -78,7 +79,7 @@ TEST(ElementwiseAdd, launcher_registry) {
     a_data[i] = 1.f;
     b_data[i] = 2.f;
   }

-  phi::CPUContext context;
+  ::phi::CPUContext context;
   context.SetAllocator(alloc);
   context.Init();
diff --git a/paddle/infrt/kernel/phi/registry.cc b/paddle/infrt/kernel/phi/registry.cc
index 36d40118f16..08683d7cb66 100644
--- a/paddle/infrt/kernel/phi/registry.cc
+++ b/paddle/infrt/kernel/phi/registry.cc
@@ -53,6 +53,19 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
       INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor),
       {"dims", "lod", "layout", "precision"});
 #endif
+  registry->AddKernelWithAttrs("phi_dt.load_params",
+                               INFRT_KERNEL(infrt::kernel::phi::LoadParams),
+                               {"path"});
+  registry->AddKernelWithAttrs(
+      "phi_dt.load_combined_params",
+      INFRT_KERNEL(infrt::kernel::phi::LoadCombinedParams),
+      {"model_path", "params_path"});
+  registry->AddKernelWithAttrs(
+      "phi_dt.tensor_map_get_tensor",
+      INFRT_KERNEL(infrt::kernel::phi::TensorMapGetTensor),
+      {"name"});
+  registry->AddKernel("phi_dt.tensor_map_get_size",
+                      INFRT_KERNEL(infrt::kernel::phi::TensorMapGetSize));
 }

 }  // namespace kernel
diff --git a/paddle/infrt/kernel/tensor_kernels.cc b/paddle/infrt/kernel/tensor_kernels.cc
index a9077220cfc..407ae16c19c 100644
--- a/paddle/infrt/kernel/tensor_kernels.cc
+++ b/paddle/infrt/kernel/tensor_kernels.cc
@@ -68,14 +68,14 @@ int32_t TensorMapGetSize(TensorMap map) { return map.size(); }

 // TODO(wilber): Maybe we should place TensorList type in dt dialect.
 #ifdef INFRT_WITH_PHI
-phi::DenseTensor TensorListGetTensor(std::vector<phi::DenseTensor *> list,
-                                     Attribute<int32_t> idx) {
+::phi::DenseTensor TensorListGetTensor(std::vector<::phi::DenseTensor *> list,
+                                       Attribute<int32_t> idx) {
   CHECK_LT(idx.get(), static_cast<int>(list.size()))
       << "idx should be less than list size";
   return *list[idx.get()];
 }

-int32_t TensorListGetSize(const std::vector<phi::DenseTensor *> &list) {
+int32_t TensorListGetSize(const std::vector<::phi::DenseTensor *> &list) {
   return list.size();
 }
 #endif
diff --git a/paddle/infrt/tensor/CMakeLists.txt b/paddle/infrt/tensor/CMakeLists.txt
index 95b2e8f6839..95d4090a9a3 100644
--- a/paddle/infrt/tensor/CMakeLists.txt
+++ b/paddle/infrt/tensor/CMakeLists.txt
@@ -1,5 +1,7 @@
 core_gather_headers()

+add_subdirectory(phi)
+
 gather_srcs(infrt_src SRCS
     tensor_map.cc
     tensor_metadata.cc
diff --git a/paddle/infrt/tensor/phi/CMakeLists.txt b/paddle/infrt/tensor/phi/CMakeLists.txt
new file mode 100644
index 00000000000..97e26661266
--- /dev/null
+++ b/paddle/infrt/tensor/phi/CMakeLists.txt
@@ -0,0 +1,3 @@
+gather_srcs(infrt_src SRCS
+    tensor_map.cc
+)
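These registrations are what `GetAttrNameList()` (used in the `mlir_to_runtime_translate.cc` hunk earlier) reads back: a kernel registered via `AddKernelWithAttrs` declares the positional attribute order its C++ signature expects. A minimal sketch of that pairing, under the assumption that the registry simply stores the name list alongside each kernel; the real `KernelRegistry` layout may differ:

```cpp
// Toy model of the registry's attr-name bookkeeping (assumed layout).
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct RegistryEntry {
  // KernelImplementation fn;           // elided in this sketch
  std::vector<const char*> attr_names;  // positional attribute order
};

class MiniRegistry {
 public:
  void AddKernelWithAttrs(const std::string& name,
                          std::vector<const char*> attrs) {
    entries_[name].attr_names = std::move(attrs);
  }
  // Empty result means "no attr order registered" -> the translator warns.
  const std::vector<const char*>& GetAttrNameList(
      const std::string& name) const {
    static const std::vector<const char*> empty;
    auto it = entries_.find(name);
    return it == entries_.end() ? empty : it->second.attr_names;
  }

 private:
  std::unordered_map<std::string, RegistryEntry> entries_;
};
```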
diff --git a/paddle/infrt/tensor/phi/tensor_map.cc b/paddle/infrt/tensor/phi/tensor_map.cc
new file mode 100644
index 00000000000..7690322aed4
--- /dev/null
+++ b/paddle/infrt/tensor/phi/tensor_map.cc
@@ -0,0 +1,47 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/tensor/phi/tensor_map.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace infrt {
+namespace phi {
+
+void DenseTensorMap::SetDenseTensor(
+    const std::string& name, std::unique_ptr<::phi::DenseTensor>&& tensor) {
+  std::lock_guard<std::mutex> lock(mu_);
+  auto it = map_.emplace(std::make_pair(name, std::move(tensor)));
+  if (!it.second) {
+    llvm_unreachable("dense tensor map insert failed.");
+  }
+}
+
+::phi::DenseTensor* DenseTensorMap::GetDenseTensor(
+    const std::string& name) const {
+  std::lock_guard<std::mutex> lock(mu_);
+  auto it = map_.find(name);
+  if (it != map_.end()) {
+    return it->second.get();
+  }
+  LOG(WARNING) << "cannot find `" << name << "` in the tensor map.";
+  return nullptr;
+}
+
+size_t DenseTensorMap::size() const {
+  std::lock_guard<std::mutex> lock(mu_);
+  return map_.size();
+}
+
+}  // namespace phi
+}  // namespace infrt
diff --git a/paddle/infrt/tensor/phi/tensor_map.h b/paddle/infrt/tensor/phi/tensor_map.h
new file mode 100644
index 00000000000..1b9fbdd9def
--- /dev/null
+++ b/paddle/infrt/tensor/phi/tensor_map.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace infrt {
+namespace phi {
+
+class DenseTensorMap {
+ public:
+  DenseTensorMap() = default;
+  DenseTensorMap(DenseTensorMap&& other) : map_(std::move(other.map_)) {}
+  void SetDenseTensor(const std::string& name,
+                      std::unique_ptr<::phi::DenseTensor>&& tensor);
+  ::phi::DenseTensor* GetDenseTensor(const std::string& name) const;
+  size_t size() const;
+
+ private:
+  mutable std::mutex mu_;
+  std::unordered_map<std::string, std::unique_ptr<::phi::DenseTensor>> map_;
+};
+
+}  // namespace phi
+}  // namespace infrt
diff --git a/paddle/infrt/tests/dialect/tensor/tensor_map.mlir.in b/paddle/infrt/tests/dialect/tensor/tensor_map.mlir.in
index 7aeb3f8a4d0..9e3773edd77 100644
--- a/paddle/infrt/tests/dialect/tensor/tensor_map.mlir.in
+++ b/paddle/infrt/tests/dialect/tensor/tensor_map.mlir.in
@@ -12,3 +12,30 @@
   infrt.return
 }
+
+func @load_phi_tensor_map() {
+  %map = phi_dt.load_params(){path="@CMAKE_BINARY_DIR@/multi_fc_model"}
+  %size = phi_dt.tensor_map_get_size(%map) -> i32
+  infrt.print.i32 %size
+
+  %a = phi_dt.tensor_map_get_tensor(%map) {name="fc_bias"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
+
+  // CHECK: dense_tensor: shape=shape[2], value=[0,0]
+  phi_dt.print_tensor (%a : !infrt.dense_tensor<CPU, FP32, NCHW>)
+
+  infrt.return
+}
+
+func @load_combined_phi_tensor_map() {
+  %map = phi_dt.load_combined_params(){model_path="@CMAKE_BINARY_DIR@/multi_fc_model/fc.pdmodel",
+    params_path="@CMAKE_BINARY_DIR@/multi_fc_model/fc.pdiparams"}
+  %size = phi_dt.tensor_map_get_size(%map) -> i32
+  infrt.print.i32 %size
+
+  %a = phi_dt.tensor_map_get_tensor(%map) {name="fc_bias"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
+
+  // CHECK: dense_tensor: shape=shape[2], value=[0,0]
+  phi_dt.print_tensor (%a : !infrt.dense_tensor<CPU, FP32, NCHW>)
+
+  infrt.return
+}
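A small standalone illustration of the container's contract: names map to uniquely-owned tensors, lookups return raw non-owning pointers, and every method takes the internal mutex. (The header appears to rely on transitive includes of `<memory>`, `<mutex>`, `<string>`, and `<unordered_map>` via `dense_tensor.h`, hence the explicit include here.)

```cpp
#include <cassert>
#include <memory>
#include "paddle/infrt/tensor/phi/tensor_map.h"

void TensorMapDemo() {
  infrt::phi::DenseTensorMap map;
  // Inserting a duplicate name hits llvm_unreachable in SetDenseTensor.
  map.SetDenseTensor("fc_bias", std::make_unique<::phi::DenseTensor>());
  // Lookup returns a non-owning pointer; a missing name logs a warning
  // and yields nullptr.
  assert(map.GetDenseTensor("fc_bias") != nullptr);
  assert(map.size() == 1);  // size() also locks (mu_ is mutable for this)
}
```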
diff --git a/tools/infrt/fake_models/multi_fc.py b/tools/infrt/fake_models/multi_fc.py
index 0d633cfc60a..7149c8d022a 100644
--- a/tools/infrt/fake_models/multi_fc.py
+++ b/tools/infrt/fake_models/multi_fc.py
@@ -52,4 +52,7 @@ loss =
 exe = fluid.Executor(cpu)
 exe.run(fluid.default_startup_program())
 fluid.io.save_inference_model("./multi_fc_model", [a.name], [fc_out], exe)
+fluid.io.save_inference_model("./multi_fc_model", [a.name], [fc_out], exe, None,
+                              "fc.pdmodel", "fc.pdiparams")
+
 print('output name', fc_out.name)
diff --git a/tools/infrt/get_phi_kernel_function.sh b/tools/infrt/get_phi_kernel_function.sh
index 6b2586d4081..febfe5d0476 100644
--- a/tools/infrt/get_phi_kernel_function.sh
+++ b/tools/infrt/get_phi_kernel_function.sh
@@ -49,7 +49,7 @@ all_ir_name=`grep -Eo "PDTCPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_cpu_k
 for ir in $all_ir_name
 do
   attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td | grep -Eo "Attr:.*)" \
-  | awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BOOLAttr/,""); \
+  | awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BoolAttr/,""); \
   gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \
   gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \
   gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \
@@ -62,7 +62,7 @@ all_ir_name=`grep -Eo "PDTGPU_Kernel<.*\"" paddle/infrt/dialect/phi/ir/phi_gpu_k
 for ir in $all_ir_name
 do
   attr_name=`grep "<\"$ir" -A 3 paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td | grep -Eo "Attr:.*)" \
-  | awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BOOLAttr/,""); \
+  | awk '{gsub(/F32Attr/,"");gsub(/F64Attr/,"");gsub(/StrAttr/,"");gsub(/BoolAttr/,""); \
   gsub(/SI1Attr/,"");gsub(/SI8Attr/,"");gsub(/SI16Attr/,"");gsub(/SI32Attr/,"");gsub(/SI64Attr/,""); \
   gsub(/UI1Attr/,"");gsub(/UI8Attr/,"");gsub(/I16Attr/,"");gsub(/I32Attr/,"");gsub(/I64Attr/,""); \
   gsub(/I1Attr/,"");gsub(/I8Attr/,"");gsub(/UI16Attr/,"");gsub(/UI32Attr/,"");gsub(/UI64Attr/,""); \
diff --git a/tools/infrt/get_phi_kernel_info.py b/tools/infrt/get_phi_kernel_info.py
index 85ad585cdef..8b752f92871 100644
--- a/tools/infrt/get_phi_kernel_info.py
+++ b/tools/infrt/get_phi_kernel_info.py
@@ -133,11 +133,11 @@ namespace kernel {

 def gen_context(val):
     if val == "CPU":
-        return "phi::CPUContext", "phi_cpu"
+        return "::phi::CPUContext", "phi_cpu"
     elif val == "GPU":
-        return "phi::GPUContext", "phi_gpu"
+        return "::phi::GPUContext", "phi_gpu"
     # elif val == "XPU":
-    #     return "phi::XPUContext", "phi_xpu"
+    #     return "::phi::XPUContext", "phi_xpu"
     else:
         # raise Exception(f"Unknown context type {val}")
         return "", ""
@@ -157,12 +157,12 @@ def gen_kernel_func(val, ctx_name, dtype_name):
     if '<' in val and '>' in val:
         st = val.index('<')
         ed = val.index('>')
         func_name = val[:st]
         template_name = val[st + 1:ed]
-        if 'phi::' in template_name:
-            return "&phi::" + val
+        if '::phi::' in template_name:
+            return "&::phi::" + val
         else:
-            return "&phi::" + func_name + "<phi::" + template_name + ">"
+            return "&::phi::" + func_name + "<::phi::" + template_name + ">"
     else:
-        return "&phi::" + val + "<" + dtype_name + ", " + ctx_name + ">"
+        return "&::phi::" + val + "<" + dtype_name + ", " + ctx_name + ">"


 def gen_dtype(vals: List[str]):
@@ -227,7 +227,7 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
         return ""
     item[2] = gen_layout(item[2])
     ir_dtypes, origin_dtypes = gen_dtype(item[4:-1])
-    infer_shape_func = "&phi::" + item[-1]
+    infer_shape_func = "&::phi::" + item[-1]

     res = ""
--
GitLab
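To make the codegen changes concrete: `gen_context()` yields the `phi_cpu`/`phi_gpu` prefix, and the registry key looked up in `infershape_launchers_test.cc` ("phi_cpu.add.float32.any") is that prefix joined with kernel name, dtype, and layout. A small sketch of the key scheme; the dot-joining is inferred from the test's lookup string, not taken verbatim from the generator:

```cpp
// Key scheme inferred from registry.GetKernel("phi_cpu.add.float32.any"):
// <context prefix>.<kernel name>.<dtype>.<layout>
#include <cassert>
#include <string>

std::string MakeKernelKey(const std::string& prefix, const std::string& name,
                          const std::string& dtype,
                          const std::string& layout) {
  return prefix + "." + name + "." + dtype + "." + layout;
}

int main() {
  assert(MakeKernelKey("phi_cpu", "add", "float32", "any") ==
         "phi_cpu.add.float32.any");
}
```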