From 93a2f5652fa632e4eff8febf49304d64ec72b569 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com>
Date: Mon, 28 Mar 2022 21:29:13 +0800
Subject: [PATCH] predictor supports phi, test=develop (#40856)

---
 paddle/infrt/api/.gitignore                   |   1 +
 paddle/infrt/api/CMakeLists.txt               |   3 +-
 paddle/infrt/api/infrt_api.cc                 | 100 +++++++++++++-----
 paddle/infrt/api/infrt_api.h                  |  14 +--
 paddle/infrt/api/infrt_api_test.cc            |  79 --------------
 paddle/infrt/api/infrt_api_test.cc.in         |  60 +++++++++++
 paddle/infrt/backends/host/phi_context.h      |  24 ++---
 .../dialect/phi/pass/phi_op_convert_pass.cc   |   2 +
 .../infrt/kernel/phi/dense_tensor_kernels.cc  |  27 +++--
 .../infrt/kernel/phi/dense_tensor_kernels.h   |   7 +-
 paddle/infrt/tensor/dense_host_tensor.cc      |   2 +
 paddle/infrt/tensor/dense_host_tensor.h       |   2 +
 paddle/infrt/tests/timer.h                    |  99 +++++++++++++++++
 13 files changed, 283 insertions(+), 137 deletions(-)
 create mode 100644 paddle/infrt/api/.gitignore
 delete mode 100644 paddle/infrt/api/infrt_api_test.cc
 create mode 100644 paddle/infrt/api/infrt_api_test.cc.in
 create mode 100644 paddle/infrt/tests/timer.h

diff --git a/paddle/infrt/api/.gitignore b/paddle/infrt/api/.gitignore
new file mode 100644
index 00000000000..06196d34f87
--- /dev/null
+++ b/paddle/infrt/api/.gitignore
@@ -0,0 +1 @@
+infrt_api_test.cc
diff --git a/paddle/infrt/api/CMakeLists.txt b/paddle/infrt/api/CMakeLists.txt
index 93a7ae83695..27d736cfdf7 100644
--- a/paddle/infrt/api/CMakeLists.txt
+++ b/paddle/infrt/api/CMakeLists.txt
@@ -3,6 +3,7 @@ core_gather_headers()
 gather_srcs(infrt_src SRCS
     infrt_api.cc
     )
+configure_file(${CMAKE_CURRENT_SOURCE_DIR}/infrt_api_test.cc.in ${CMAKE_CURRENT_SOURCE_DIR}/infrt_api_test.cc)
 
 # Disabled temporarily because the external-kernel's mkldnn is outdated
-# cc_test(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS})
+cc_test_tiny(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS})
diff --git a/paddle/infrt/api/infrt_api.cc b/paddle/infrt/api/infrt_api.cc
index 5ac51fb6715..91668dc176e 100644
--- a/paddle/infrt/api/infrt_api.cc
+++ b/paddle/infrt/api/infrt_api.cc
@@ -22,18 +22,27 @@
 #include
 #include
 
+#include "mlir/Pass/PassManager.h"
+#include "paddle/infrt/backends/host/phi_allocator.h"
 #include "paddle/infrt/common/global.h"
 #include "paddle/infrt/dialect/dense_tensor.h"
 #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
+#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"
 #include "paddle/infrt/dialect/mlir_loader.h"
+#include "paddle/infrt/dialect/phi/ir/phi_base.h"
+#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
 #include "paddle/infrt/host_context/core_runtime.h"
 #include "paddle/infrt/host_context/kernel_registry.h"
 #include "paddle/infrt/host_context/mlir_function_executable.h"
 #include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
 #include "paddle/infrt/host_context/op_executable.h"
+#include "paddle/infrt/host_context/paddle_mlir.h"
 #include "paddle/infrt/host_context/value.h"
 #include "paddle/infrt/kernel/basic_kernels.h"
 #include "paddle/infrt/kernel/control_flow_kernels.h"
+#include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"
+#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h"
+#include "paddle/infrt/kernel/phi/registry.h"
 #include "paddle/infrt/kernel/tensor_kernels.h"
 #include "paddle/infrt/kernel/tensor_shape_kernels.h"
 #include "paddle/infrt/kernel/test_kernels.h"
@@ -84,12 +93,12 @@ class PredictExecutor : public MlirToRuntimeTranslator {
   PredictExecutor(mlir::ModuleOp module,
                   KernelRegistry* registry,
-                  TensorMap* map)
+                  ::infrt::phi::DenseTensorMap&& map)
       : MlirToRuntimeTranslator(module, &core_runtime),
         core_runtime(registry),
         registry_(registry) {
     CHECK(registry_);
-    Init(map);
+    Init(std::move(map));
   }
 
   void Run() {
@@ -100,18 +109,18 @@ class PredictExecutor : public MlirToRuntimeTranslator {
 
   int GetInputNum() { return inputs_.size(); }
 
-  DenseHostTensor* GetInput(int i) { return inputs_[i]; }
+  ::phi::DenseTensor* GetInput(int i) { return inputs_[i]; }
 
   int GetOutputNum() { return outputs_.size(); }
 
-  DenseHostTensor* GetOutput(int i) { return outputs_[i]; }
+  ::phi::DenseTensor* GetOutput(int i) { return outputs_[i]; }
 
  private:
-  void Init(TensorMap* map) {
+  void Init(::infrt::phi::DenseTensorMap&& map) {
     EmitFunctions();
     llvm::Optional<mlir::FuncOp> predict_func_ = llvm::None;
     for (auto func_op : impl_->module.getOps<mlir::FuncOp>()) {
-      if (func_op.getName().str() != "predict") continue;
+      if (func_op.getName().str() != "main_graph") continue;
       predict_func_ = func_op;
       break;
     }
@@ -125,20 +134,24 @@ class PredictExecutor : public MlirToRuntimeTranslator {
         new MlirFunctionExecutable(predict_func, registry_, impl_->func_defs);
 
     // process parameters
+    VLOG(3) << "Arguments num of predict func: "
+            << predict_func.getNumArguments();
     for (size_t i = 0; i < predict_func.getNumArguments(); ++i) {
       auto arg = predict_func.getArgument(i);
       auto type = arg.getType();
       // this param is TensorMap
-      if (type.isa()) {
-        auto* value = new host_context::Value(std::move(*map));
+      if (type.isa<::infrt::phi::DenseTensorMapType>()) {
+        auto* value = new host_context::Value(std::move(map));
         arguments_.push_back(value);
         AddValue(predict_func.getArgument(i), value);
-      } else {
+      } else if (type.isa<::infrt::DenseTensorType>()) {
         // this param is an input Tensor
-        auto dht = DenseHostTensor();
+        auto dht = ::phi::DenseTensor();
         auto* value = new host_context::Value(std::move(dht));
         arguments_.push_back(value);
-        inputs_.push_back(&(value->get<DenseHostTensor>()));
+        inputs_.push_back(&(value->get<::phi::DenseTensor>()));
+      } else {
+        llvm_unreachable("The input type has not been supported by predictor.");
       }
     }
 
@@ -146,9 +159,18 @@ class PredictExecutor : public MlirToRuntimeTranslator {
     auto& last_op = predict_func.front().back();
     if (last_op.getName().getStringRef() == "infrt.return") {
       for (size_t i = 0; i < last_op.getNumOperands(); ++i) {
-        auto* value = AddValue(mlir::Value(last_op.getOperand(i)));
-        results_.push_back(ValueRef(value));
-        outputs_.push_back(&(value->get<DenseHostTensor>()));
+        auto operand = last_op.getOperand(i);
+        if (operand.getType().isa<::infrt::DenseTensorType>()) {
+          auto r = impl_->value_map.try_emplace(
+              operand, ValueRef(new host_context::Value(::phi::DenseTensor())));
+          CHECK(r.second) << "Duplicate add mlir value ["
+                          << DumpToString(operand) << "]";
+          auto* value = r.first->second.get();
+          results_.push_back(ValueRef(value));
+          outputs_.push_back(&(value->get<::phi::DenseTensor>()));
+        } else {
+          llvm_unreachable("infrt.return only supports DenseTensor now.");
+        }
       }
     }
   }
@@ -166,22 +188,22 @@ class PredictExecutor : public MlirToRuntimeTranslator {
 
  private:
  KernelRegistry* registry_{};
  MlirFunctionExecutable* function_executable_;
-  llvm::SmallVector<DenseHostTensor*, 1> inputs_;
+  llvm::SmallVector<::phi::DenseTensor*, 1> inputs_;
  llvm::SmallVector arguments_;
-  llvm::SmallVector<DenseHostTensor*, 1> outputs_;
+  llvm::SmallVector<::phi::DenseTensor*, 1> outputs_;
  llvm::SmallVector results_;
 };
 
-std::shared_ptr<InfRtPredictor> CreateInfRtPredictor(
+std::unique_ptr<InfRtPredictor> CreateInfRtPredictor(
     const InfRtConfig& config) {
-  auto x = std::make_shared<InfRtPredictor>();
+  auto x = std::make_unique<InfRtPredictor>();
   x->Init(config);
   return x;
 }
 
 struct InfRtPredictor::Impl {
-  mlir::OwningModuleRef module_ref;
   std::unique_ptr<PredictExecutor> executor;
+  MLIRModelGenImpl module_gen_;
 };
 
 InfRtPredictor::InfRtPredictor() : impl_(new Impl) {}
@@ -190,8 +212,7 @@ InfRtPredictor::~InfRtPredictor() {}
 void InfRtPredictor::Run() { impl_->executor->Run(); }
 
 int InfRtPredictor::Init(const InfRtConfig& config) {
-  mlir::MLIRContext* context = infrt::Global::getMLIRContext();
-  auto module_ref = dialect::LoadMlirFile(config.mlir_path(), context);
+  mlir::MLIRContext* context = ::infrt::Global::getMLIRContext();
 
   KernelRegistry* registry = new KernelRegistry();
 
@@ -200,8 +221,32 @@ int InfRtPredictor::Init(const InfRtConfig& config) {
   kernel::RegisterTensorShapeKernels(registry);
   kernel::RegisterTensorKernels(registry);
   kernel::RegisterControlFlowKernels(registry);
-
-  impl_->module_ref = std::move(module_ref);
+#ifdef INFRT_WITH_PHI
+  kernel::RegisterPhiKernels(registry);
+  kernel::RegisterInferShapeLaunchers(registry);
+#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
+  kernel::RegisterTrtKernels(registry);
+#endif  // INFRT_WITH_GPU && INFRT_WITH_TRT
+#endif
+
+  auto module_op = impl_->module_gen_.ImportPaddleModel(config.model_dir(),
+                                                        config.param_dir());
+
+  context->loadAllAvailableDialects();
+  ::mlir::PassManager pm(context);
+  ::mlir::OpPassManager& phi_pass_manager = pm.nest<::mlir::FuncOp>();
+  std::vector<::infrt::Place> valid_places = {{::infrt::TargetType::CPU,
+                                               ::infrt::PrecisionType::FLOAT32,
+                                               ::infrt::LayoutType::NCHW}};
+  phi_pass_manager.addPass(::infrt::createPhiOpCvtPass(valid_places));
+  phi_pass_manager.addPass(::infrt::createInfrtOpFusePass());
+  if (mlir::failed(pm.run(module_op))) {
+    std::cout << "\npass failed!\n" << std::endl;
+    return 4;
+  }
+#ifndef NDEBUG
+  module_op.dump();
+#endif  // NDEBUG
 
   // load extra shared library
   for (const std::string& lib_path : config.shared_libs()) {
@@ -222,23 +267,24 @@ int InfRtPredictor::Init(const InfRtConfig& config) {
   }
 
   // Load params
-  TensorMap* tensor_map = LoadParams(config.model_dir());
+  auto tensor_map = ::infrt::kernel::phi::LoadCombinedParameters(
+      config.model_dir(), config.param_dir());
 
   // Create PredictExecutor
   impl_->executor.reset(
-      new PredictExecutor(impl_->module_ref.get(), registry, tensor_map));
+      new PredictExecutor(module_op, registry, std::move(tensor_map)));
 
   return 0;
 }
 
 int InfRtPredictor::GetInputNum() { return impl_->executor->GetInputNum(); }
 
-DenseHostTensor* InfRtPredictor::GetInput(int i) {
+::phi::DenseTensor* InfRtPredictor::GetInput(int i) {
   return impl_->executor->GetInput(i);
 }
 
 int InfRtPredictor::GetOutputNum() { return impl_->executor->GetOutputNum(); }
 
-DenseHostTensor* InfRtPredictor::GetOutput(int i) {
+::phi::DenseTensor* InfRtPredictor::GetOutput(int i) {
   return impl_->executor->GetOutput(i);
 }
 
diff --git a/paddle/infrt/api/infrt_api.h b/paddle/infrt/api/infrt_api.h
index 82b6cb8df91..cf14cab3c06 100644
--- a/paddle/infrt/api/infrt_api.h
+++ b/paddle/infrt/api/infrt_api.h
@@ -17,13 +17,13 @@
 #include
 #include
 
-#include "paddle/infrt/tensor/dense_host_tensor.h"
+#include "paddle/phi/core/dense_tensor.h"
 
 namespace infrt {
 
 class InfRtConfig {
   std::string model_dir_;
-  std::string mlir_path_;
+  std::string param_dir_;
   std::vector<std::string> shared_libs_;
 
  public:
@@ -31,8 +31,8 @@ class InfRtConfig {
   void set_model_dir(const std::string& model_dir) { model_dir_ = model_dir; }
   const std::string& model_dir() const { return model_dir_; }
 
-  void set_mlir_path(const std::string& mlir_path) { mlir_path_ = mlir_path; }
-  const std::string& mlir_path() const { return mlir_path_; }
+  void set_param_dir(const std::string& param_dir) { param_dir_ = param_dir; }
+  const std::string& param_dir() const { return param_dir_; }
 
   void set_shared_libs(const std::vector<std::string>& shared_libs) {
     shared_libs_ = shared_libs;
   }
@@ -49,15 +49,15 @@ class InfRtPredictor {
   void Run();
   int Init(const InfRtConfig& config);
   int GetInputNum();
-  tensor::DenseHostTensor* GetInput(int i);
+  ::phi::DenseTensor* GetInput(int i);
   int GetOutputNum();
-  tensor::DenseHostTensor* GetOutput(int i);
+  ::phi::DenseTensor* GetOutput(int i);
 
  protected:
   struct Impl;
   std::unique_ptr<Impl> impl_;
 };
 
-std::shared_ptr<InfRtPredictor> CreateInfRtPredictor(const InfRtConfig& config);
+std::unique_ptr<InfRtPredictor> CreateInfRtPredictor(const InfRtConfig& config);
 
 }  // namespace infrt
diff --git a/paddle/infrt/api/infrt_api_test.cc b/paddle/infrt/api/infrt_api_test.cc
deleted file mode 100644
index 92e069f4752..00000000000
--- a/paddle/infrt/api/infrt_api_test.cc
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/infrt/api/infrt_api.h"
-
-#include <gtest/gtest.h>
-
-#include <iostream>
-#include <vector>
-
-#include "llvm/Support/raw_ostream.h"
-#include "paddle/infrt/common/buffer.h"
-#include "paddle/infrt/common/dtype.h"
-
-using infrt::InfRtConfig;
-using infrt::InfRtPredictor;
-using infrt::CreateInfRtPredictor;
-
-namespace infrt {
-
-TEST(InfRtPredictor, predictor) {
-  std::vector<std::string> shared_libs;
-  shared_libs.push_back("../../paddle/libexternal_kernels.so");
-
-  InfRtConfig config;
-
-  // set external shared libraries that contain kernels.
-  config.set_shared_libs(shared_libs);
-  // set model dir
-  config.set_model_dir("../../paddle/paddle_1.8_fc_model");
-  // set mlir path
-  config.set_mlir_path("../../../infrt/dialect/mlir_tests/tensor_map.mlir");
-
-  std::shared_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
-
-  auto* input = predictor->GetInput(0);
-  std::vector<int64_t> shape = {3, 3};
-  input->Init(shape, infrt::GetDType<float>());
-  llvm::outs() << input->shape() << "\n";
-
-  // init input tensor
-  auto* input_data = reinterpret_cast<float*>(input->buffer()->data()->memory);
-  for (int i = 0; i < input->shape().GetNumElements(); i++) input_data[i] = 1.0;
-
-  predictor->Run();
-
-  // get and print output tensor
-  auto* output = predictor->GetOutput(0);
-  auto* output_data =
-      reinterpret_cast<float*>(output->buffer()->data()->memory);
-
-  std::vector<float> ans = {0.428458,
-                            0.244493,
-                            0.572342,
-                            0.572008,
-                            0.509771,
-                            0.495599,
-                            0.651287,
-                            0.326426,
-                            0.404649};
-
-  ASSERT_EQ(output->shape().GetNumElements(), ans.size());
-  for (int i = 0; i < output->shape().GetNumElements(); ++i) {
-    ASSERT_NEAR(output_data[i], ans[i], 0.000001);
-  }
-}
-
-}  // namespace infrt
diff --git a/paddle/infrt/api/infrt_api_test.cc.in b/paddle/infrt/api/infrt_api_test.cc.in
new file mode 100644
index 00000000000..6323b6a540a
--- /dev/null
+++ b/paddle/infrt/api/infrt_api_test.cc.in
@@ -0,0 +1,60 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include <iostream>
+#include <vector>
+
+#include "llvm/Support/raw_ostream.h"
+#include "paddle/infrt/api/infrt_api.h"
+#include "paddle/infrt/backends/host/phi_allocator.h"
+#include "paddle/infrt/common/buffer.h"
+#include "paddle/infrt/common/dtype.h"
+
+using infrt::InfRtConfig;
+using infrt::InfRtPredictor;
+using infrt::CreateInfRtPredictor;
+
+namespace infrt {
+
+TEST(InfRtPredictor, predictor) {
+  std::vector<std::string> shared_libs;
+
+  InfRtConfig config;
+
+  config.set_model_dir("@CMAKE_BINARY_DIR@/linear/linear.pdmodel");
+  config.set_param_dir("@CMAKE_BINARY_DIR@/linear/linear.pdiparams");
+
+  std::unique_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
+
+  ::infrt::backends::CpuPhiAllocator cpu_allocator;
+  ::phi::DenseTensor* input = predictor->GetInput(0);
+  input->Resize({16, 784});
+  input->AllocateFrom(&cpu_allocator, ::phi::DataType::FLOAT32);
+  auto* input_data = reinterpret_cast<float*>(input->data());
+  for (int i = 0; i < input->numel(); i++) input_data[i] = 1.0;
+
+  predictor->Run();
+
+  // get and print output tensor
+  auto* output = predictor->GetOutput(0);
+
+  // TODO(Shixiaowei02): Automatic result validation for training then inference.
+  // auto* output_data = reinterpret_cast<float*>(output->data());
+
+  ASSERT_EQ(output->dims(), ::phi::DDim({16, 10}));
+}
+
+}  // namespace infrt
diff --git a/paddle/infrt/backends/host/phi_context.h b/paddle/infrt/backends/host/phi_context.h
index bcd63dbb39f..2af1fab1008 100644
--- a/paddle/infrt/backends/host/phi_context.h
+++ b/paddle/infrt/backends/host/phi_context.h
@@ -18,10 +18,10 @@ limitations under the License.
 */
 namespace infrt {
 namespace backends {
 
-class CpuPhiContext : public phi::CPUContext {
+class CpuPhiContext : public ::phi::CPUContext {
  public:
-  using Base = phi::CPUContext;
-  using phi::CPUContext::SetEigenDevice;
+  using Base = ::phi::CPUContext;
+  using ::phi::CPUContext::SetEigenDevice;
 
   CpuPhiContext() {
     Init();
@@ -29,18 +29,18 @@ class CpuPhiContext : public phi::CPUContext {
   }
 
  private:
-  std::unique_ptr<phi::Allocator> alloc_{std::make_unique<CpuPhiAllocator>()};
+  std::unique_ptr<::phi::Allocator> alloc_{std::make_unique<CpuPhiAllocator>()};
 };
 
-class GpuPhiContext : public phi::GPUContext {
+class GpuPhiContext : public ::phi::GPUContext {
  public:
-  using Base = phi::GPUContext;
-  using phi::GPUContext::SetStream;
-  using phi::GPUContext::SetEigenDevice;
-  using phi::GPUContext::SetBlasHandle;
-  using phi::GPUContext::SetDnnHandle;
-  using phi::GPUContext::SetSolverHandle;
-  using phi::GPUContext::SetSparseHandle;
+  using Base = ::phi::GPUContext;
+  using ::phi::GPUContext::SetStream;
+  using ::phi::GPUContext::SetEigenDevice;
+  using ::phi::GPUContext::SetBlasHandle;
+  using ::phi::GPUContext::SetDnnHandle;
+  using ::phi::GPUContext::SetSolverHandle;
+  using ::phi::GPUContext::SetSparseHandle;
 };
 
 }  // namespace backends
diff --git a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc
index 4abdb388dc2..bfc43125b8b 100644
--- a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc
+++ b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc
@@ -110,6 +110,8 @@ void PhiOpConvertPass::convertStage() {
     ::phi::KernelSignature kernel_sign =
         ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
             infrt::ProtoArgumentMappingContext(op));
+    VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel("
+            << kernel_sign.name << ")";
     // resort input&output according to kernel_sign
     ::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
     ::llvm::SmallVector<mlir::Type, 4> output_types;
diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc
index 26048a43f99..844db8aecb2 100644
--- a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc
+++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc
@@ -19,6 +19,7 @@
 #include "paddle/infrt/kernel/phi/context_kernels.h"
 #include "paddle/infrt/paddle/model_parser.h"
 #include "paddle/infrt/paddle/scope.h"
+#include "paddle/infrt/tensor/tensor_map.h"
 #include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/common/place.h"
 
@@ -167,9 +168,7 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
 #undef PRINT_META_DATA
 }
 
-::infrt::phi::DenseTensorMap LoadParams(
-    host_context::Attribute<std::string> path) {
-  const auto& file_path = path.get();
+::infrt::phi::DenseTensorMap LoadParameters(const std::string& file_path) {
   std::cout << "loading params from: " << file_path << std::endl;
   ::infrt::phi::DenseTensorMap map;
 
@@ -201,17 +200,19 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
   return map;
 }
 
-::infrt::phi::DenseTensorMap LoadCombinedParams(
-    host_context::Attribute<std::string> model_path,
-    host_context::Attribute<std::string> params_path) {
-  const auto& model = model_path.get();
-  std::cout << "loading params from: " << model << std::endl;
+::infrt::phi::DenseTensorMap LoadParams(
+    host_context::Attribute<std::string> path) {
+  return LoadParameters(path.get());
+}
+
+::infrt::phi::DenseTensorMap LoadCombinedParameters(
+    const std::string& model_path, const std::string& params_path) {
   ::infrt::phi::DenseTensorMap map;
 
-  auto pb_proto_prog = paddle::LoadProgram(model);
+  auto pb_proto_prog = paddle::LoadProgram(model_path);
   auto main_block = pb_proto_prog->blocks(0);
 
-  std::ifstream param_file(params_path.get(), std::ios::binary);
+  std::ifstream param_file(params_path, std::ios::binary);
 
   std::set<std::string> tmp;
   for (auto& var : main_block.vars()) {
@@ -237,6 +238,12 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
   return map;
 }
 
+::infrt::phi::DenseTensorMap LoadCombinedParams(
+    host_context::Attribute<std::string> model_path,
+    host_context::Attribute<std::string> params_path) {
+  return LoadCombinedParameters(model_path.get(), params_path.get());
+}
+
 ::phi::DenseTensor TensorMapGetTensor(
     const ::infrt::phi::DenseTensorMap& map,
     host_context::Attribute<std::string> name) {
diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.h b/paddle/infrt/kernel/phi/dense_tensor_kernels.h
index 2d0698eb597..60cc63a928f 100644
--- a/paddle/infrt/kernel/phi/dense_tensor_kernels.h
+++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.h
@@ -50,7 +50,9 @@ void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
                         host_context::Attribute<std::vector<float>> values);
 void PrintDenseTensor(::phi::DenseTensor* dense_tensor);
 
-infrt::phi::DenseTensorMap LoadParams(
+::infrt::phi::DenseTensorMap LoadParameters(const std::string& path);
+
+::infrt::phi::DenseTensorMap LoadParams(
     host_context::Attribute<std::string> path);
 
 ::phi::DenseTensor TensorMapGetTensor(
@@ -61,6 +63,9 @@ infrt::phi::DenseTensorMap LoadParams(
     host_context::Attribute<std::string> model_path,
     host_context::Attribute<std::string> params_path);
 
+::infrt::phi::DenseTensorMap LoadCombinedParameters(
+    const std::string& model_path, const std::string& params_path);
+
 int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map);
 
 #ifdef INFRT_WITH_GPU
diff --git a/paddle/infrt/tensor/dense_host_tensor.cc b/paddle/infrt/tensor/dense_host_tensor.cc
index 639b0f9f517..26eaf2618e8 100644
--- a/paddle/infrt/tensor/dense_host_tensor.cc
+++ b/paddle/infrt/tensor/dense_host_tensor.cc
@@ -90,4 +90,6 @@ DenseHostTensor::~DenseHostTensor() {}
 
 void* DenseHostTensor::raw_data() const { return buffer_->data()->memory; }
 
+DType DenseHostTensor::dtype() const { return metadata().dtype; }
+
 }  // namespace infrt::tensor
diff --git a/paddle/infrt/tensor/dense_host_tensor.h b/paddle/infrt/tensor/dense_host_tensor.h
index 6003c821185..5ff34625344 100644
--- a/paddle/infrt/tensor/dense_host_tensor.h
+++ b/paddle/infrt/tensor/dense_host_tensor.h
@@ -78,6 +78,8 @@ class DenseHostTensor : public HostTensor {
   const TensorShape& shape() const;
   TensorShape* mutable_shape();
 
+  DType dtype() const;
+
   const Buffer* buffer() const;
 
   void* raw_data() const;
diff --git a/paddle/infrt/tests/timer.h b/paddle/infrt/tests/timer.h
new file mode 100644
index 00000000000..18372cbe541
--- /dev/null
+++ b/paddle/infrt/tests/timer.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <chrono>
+#include <ctime>
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace infrt {
+namespace tests {
+
+template <typename ClockT>
+class ChronoTimer {
+ public:
+  using TimePoint = std::chrono::time_point<ClockT>;
+  ChronoTimer() : start_{TimePoint::min()} {}
+  void Clear() { start_ = TimePoint::min(); }
+  void Start() { start_ = ClockT::now(); }
+
+  double GetMs() {
+    auto diff = ClockT::now() - start_;
+    return static_cast<double>(
+               std::chrono::duration_cast<std::chrono::duration<double>>(diff)
+                   .count()) *
+           1000.0;
+  }
+
+ private:
+  TimePoint start_;
+};
+
+using WallClockTimer = ChronoTimer<std::chrono::steady_clock>;
+
+class CpuClockTimer {
+ public:
+  CpuClockTimer() = default;
+  void Clear() { start_ = 0; }
+  void Start() { start_ = std::clock(); }
+  double GetMs() {
+    std::clock_t diff = std::clock() - start_;
+    return static_cast<double>(diff * 1000.0 / CLOCKS_PER_SEC);
+  }
+
+ private:
+  std::clock_t start_{0};
+};
+
+class BenchmarkStats {
+ public:
+  void Start() {
+    wall_timer_.Start();
+    cpu_timer_.Start();
+  }
+
+  void Stop() {
+    wall_time_.push_back(wall_timer_.GetMs());
+    cpu_time_.push_back(cpu_timer_.GetMs());
+  }
+
+  std::string Summerize(const std::vector<float>& percents) {
+    std::stringstream ss;
+    std::sort(wall_time_.begin(), wall_time_.end());
+    std::sort(cpu_time_.begin(), cpu_time_.end());
+    auto percentile = [](float p, const std::vector<float>& stats) {
+      assert(p >= 0 && p < 1);
+      return stats[stats.size() * p];
+    };
+    for (auto p : percents) {
+      ss << "=== Wall Time (ms): \n";
+      ss << "  * percent " << std::to_string(static_cast<int>(p * 100));
+      ss << ": " << percentile(p, wall_time_) << '\n';
+    }
+    for (auto p : percents) {
+      ss << "=== CPU Time (ms): \n";
+      ss << "  * percent " << std::to_string(static_cast<int>(p * 100));
+      ss << ": " << percentile(p, cpu_time_) << '\n';
+    }
+    return ss.str();
+  }
+
+ private:
+  WallClockTimer wall_timer_;
+  std::vector<float> wall_time_;
+  CpuClockTimer cpu_timer_;
+  std::vector<float> cpu_time_;
+};
+
+}  // namespace tests
+}  // namespace infrt
-- 
GitLab
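
Usage sketches (illustrative, not part of the commit):

The first sketch mirrors the infrt_api_test.cc.in added by this patch and shows the API flow it introduces: configure model and parameter paths, create the predictor (now returned as a std::unique_ptr), fill a ::phi::DenseTensor input through the host CPU allocator, run, and read the output. The file paths, the main() wrapper, and the 16x784 shape are placeholders taken from the test; error handling and build setup are omitted.

#include <iostream>

#include "paddle/infrt/api/infrt_api.h"
#include "paddle/infrt/backends/host/phi_allocator.h"

int main() {
  infrt::InfRtConfig config;
  // Placeholder paths: point these at a real exported Paddle model.
  config.set_model_dir("/path/to/linear.pdmodel");
  config.set_param_dir("/path/to/linear.pdiparams");

  // CreateInfRtPredictor now returns std::unique_ptr<InfRtPredictor>.
  std::unique_ptr<infrt::InfRtPredictor> predictor =
      infrt::CreateInfRtPredictor(config);

  // Inputs and outputs are ::phi::DenseTensor; host memory is obtained
  // through the CPU phi allocator used by the test above.
  infrt::backends::CpuPhiAllocator cpu_allocator;
  ::phi::DenseTensor* input = predictor->GetInput(0);
  input->Resize({16, 784});
  input->AllocateFrom(&cpu_allocator, ::phi::DataType::FLOAT32);
  auto* input_data = reinterpret_cast<float*>(input->data());
  for (int i = 0; i < input->numel(); ++i) input_data[i] = 1.0f;

  predictor->Run();

  ::phi::DenseTensor* output = predictor->GetOutput(0);
  std::cout << "output elements: " << output->numel() << std::endl;
  return 0;
}

The second sketch wraps such a run loop with the timing utilities from the new paddle/infrt/tests/timer.h. The function name and the iteration count are arbitrary; Summerize is spelled as committed, and it reports the requested percentiles of wall and CPU time in milliseconds.

#include <iostream>

#include "paddle/infrt/api/infrt_api.h"
#include "paddle/infrt/tests/timer.h"

void BenchPredictor(infrt::InfRtPredictor* predictor) {
  infrt::tests::BenchmarkStats stats;
  for (int i = 0; i < 100; ++i) {
    stats.Start();   // starts both the wall-clock and CPU-clock timers
    predictor->Run();
    stats.Stop();    // records one wall/CPU time sample per iteration
  }
  // Print the 50th, 90th, and 99th percentiles of the collected samples.
  std::cout << stats.Summerize({0.5f, 0.9f, 0.99f});
}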