未验证 提交 93a2f565 编写于 作者: 石晓伟 提交者: GitHub

predictor supports phi, test=develop (#40856)

上级 ca871957
......@@ -3,6 +3,7 @@ core_gather_headers()
gather_srcs(infrt_src SRCS
infrt_api.cc
)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/infrt_api_test.cc.in ${CMAKE_CURRENT_SOURCE_DIR}/infrt_api_test.cc)
# Disable temporarily for the external-kernel's mkldnn is outdate
# cc_test(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS})
......@@ -22,18 +22,27 @@
#include <unordered_map>
#include <vector>
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/backends/host/phi_allocator.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/paddle_mlir.h"
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/kernel/basic_kernels.h"
#include "paddle/infrt/kernel/control_flow_kernels.h"
#include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/kernel/phi/registry.h"
#include "paddle/infrt/kernel/tensor_kernels.h"
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
......@@ -84,12 +93,12 @@ class PredictExecutor : public MlirToRuntimeTranslator {
PredictExecutor(mlir::ModuleOp module,
KernelRegistry* registry,
TensorMap* map)
::infrt::phi::DenseTensorMap&& map)
: MlirToRuntimeTranslator(module, &core_runtime),
core_runtime(registry),
registry_(registry) {
CHECK(registry_);
Init(map);
Init(std::move(map));
}
void Run() {
......@@ -100,18 +109,18 @@ class PredictExecutor : public MlirToRuntimeTranslator {
int GetInputNum() { return inputs_.size(); }
DenseHostTensor* GetInput(int i) { return inputs_[i]; }
::phi::DenseTensor* GetInput(int i) { return inputs_[i]; }
int GetOutputNum() { return outputs_.size(); }
DenseHostTensor* GetOutput(int i) { return outputs_[i]; }
::phi::DenseTensor* GetOutput(int i) { return outputs_[i]; }
private:
void Init(TensorMap* map) {
void Init(::infrt::phi::DenseTensorMap&& map) {
EmitFunctions();
llvm::Optional<mlir::FuncOp> predict_func_ = llvm::None;
for (auto func_op : impl_->module.getOps<mlir::FuncOp>()) {
if (func_op.getName().str() != "predict") continue;
if (func_op.getName().str() != "main_graph") continue;
predict_func_ = func_op;
break;
}
......@@ -125,20 +134,24 @@ class PredictExecutor : public MlirToRuntimeTranslator {
new MlirFunctionExecutable(predict_func, registry_, impl_->func_defs);
// process parammeters
VLOG(3) << "Arguments num of predict func: "
<< predict_func.getNumArguments();
for (size_t i = 0; i < predict_func.getNumArguments(); ++i) {
auto arg = predict_func.getArgument(i);
auto type = arg.getType();
// this param is TensorMap
if (type.isa<infrt::DenseHostTensorMapType>()) {
auto* value = new host_context::Value(std::move(*map));
if (type.isa<::infrt::phi::DenseTensorMapType>()) {
auto* value = new host_context::Value(std::move(map));
arguments_.push_back(value);
AddValue(predict_func.getArgument(i), value);
} else {
} else if (type.isa<::infrt::DenseTensorType>()) {
// this param is an input Tensor
auto dht = DenseHostTensor();
auto dht = ::phi::DenseTensor();
auto* value = new host_context::Value(std::move(dht));
arguments_.push_back(value);
inputs_.push_back(&(value->get<DenseHostTensor>()));
inputs_.push_back(&(value->get<::phi::DenseTensor>()));
} else {
llvm_unreachable("The input type has not been supported by predictor.");
}
}
......@@ -146,9 +159,18 @@ class PredictExecutor : public MlirToRuntimeTranslator {
auto& last_op = predict_func.front().back();
if (last_op.getName().getStringRef() == "infrt.return") {
for (size_t i = 0; i < last_op.getNumOperands(); ++i) {
auto* value = AddValue(mlir::Value(last_op.getOperand(i)));
auto operand = last_op.getOperand(i);
if (operand.getType().isa<::infrt::DenseTensorType>()) {
auto r = impl_->value_map.try_emplace(
operand, ValueRef(new host_context::Value(::phi::DenseTensor())));
CHECK(r.second) << "Duplicate add mlir value ["
<< DumpToString(operand) << "]";
auto* value = r.first->second.get();
results_.push_back(ValueRef(value));
outputs_.push_back(&(value->get<DenseHostTensor>()));
outputs_.push_back(&(value->get<::phi::DenseTensor>()));
} else {
llvm_unreachable("infrt.return only supports DenseTensor now.");
}
}
}
}
......@@ -166,22 +188,22 @@ class PredictExecutor : public MlirToRuntimeTranslator {
private:
KernelRegistry* registry_{};
MlirFunctionExecutable* function_executable_;
llvm::SmallVector<DenseHostTensor*, 1> inputs_;
llvm::SmallVector<::phi::DenseTensor*, 1> inputs_;
llvm::SmallVector<host_context::Value*, 2> arguments_;
llvm::SmallVector<DenseHostTensor*, 1> outputs_;
llvm::SmallVector<::phi::DenseTensor*, 1> outputs_;
llvm::SmallVector<ValueRef, 1> results_;
};
std::shared_ptr<InfRtPredictor> CreateInfRtPredictor(
std::unique_ptr<InfRtPredictor> CreateInfRtPredictor(
const InfRtConfig& config) {
auto x = std::make_shared<InfRtPredictor>();
auto x = std::make_unique<InfRtPredictor>();
x->Init(config);
return x;
}
struct InfRtPredictor::Impl {
mlir::OwningModuleRef module_ref;
std::unique_ptr<PredictExecutor> executor;
MLIRModelGenImpl module_gen_;
};
InfRtPredictor::InfRtPredictor() : impl_(new Impl) {}
......@@ -190,8 +212,7 @@ InfRtPredictor::~InfRtPredictor() {}
void InfRtPredictor::Run() { impl_->executor->Run(); }
int InfRtPredictor::Init(const InfRtConfig& config) {
mlir::MLIRContext* context = infrt::Global::getMLIRContext();
auto module_ref = dialect::LoadMlirFile(config.mlir_path(), context);
mlir::MLIRContext* context = ::infrt::Global::getMLIRContext();
KernelRegistry* registry = new KernelRegistry();
......@@ -200,8 +221,32 @@ int InfRtPredictor::Init(const InfRtConfig& config) {
kernel::RegisterTensorShapeKernels(registry);
kernel::RegisterTensorKernels(registry);
kernel::RegisterControlFlowKernels(registry);
impl_->module_ref = std::move(module_ref);
#ifdef INFRT_WITH_PHI
kernel::RegisterPhiKernels(registry);
kernel::RegisterInferShapeLaunchers(registry);
#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
kernel::RegisterTrtKernels(registry);
#endif // INFRT_WITH_GPU && INFRT_WITH_TRT
#endif
auto module_op = impl_->module_gen_.ImportPaddleModel(config.model_dir(),
config.param_dir());
context->loadAllAvailableDialects();
::mlir::PassManager pm(context);
::mlir::OpPassManager& phi_pass_manager = pm.nest<::mlir::FuncOp>();
std::vector<::infrt::Place> valid_places = {{::infrt::TargetType::CPU,
::infrt::PrecisionType::FLOAT32,
::infrt::LayoutType::NCHW}};
phi_pass_manager.addPass(::infrt::createPhiOpCvtPass(valid_places));
phi_pass_manager.addPass(::infrt::createInfrtOpFusePass());
if (mlir::failed(pm.run(module_op))) {
std::cout << "\npass failed!\n" << std::endl;
return 4;
}
#ifndef NDEBUG
module_op.dump();
#endif // NDEBUG
// load extra shared library
for (const std::string& lib_path : config.shared_libs()) {
......@@ -222,23 +267,24 @@ int InfRtPredictor::Init(const InfRtConfig& config) {
}
// Load params
TensorMap* tensor_map = LoadParams(config.model_dir());
auto tensor_map = ::infrt::kernel::phi::LoadCombinedParameters(
config.model_dir(), config.param_dir());
// Create PredictExecutor
impl_->executor.reset(
new PredictExecutor(impl_->module_ref.get(), registry, tensor_map));
new PredictExecutor(module_op, registry, std::move(tensor_map)));
return 0;
}
int InfRtPredictor::GetInputNum() { return impl_->executor->GetInputNum(); }
DenseHostTensor* InfRtPredictor::GetInput(int i) {
::phi::DenseTensor* InfRtPredictor::GetInput(int i) {
return impl_->executor->GetInput(i);
}
int InfRtPredictor::GetOutputNum() { return impl_->executor->GetOutputNum(); }
DenseHostTensor* InfRtPredictor::GetOutput(int i) {
::phi::DenseTensor* InfRtPredictor::GetOutput(int i) {
return impl_->executor->GetOutput(i);
}
......
......@@ -17,13 +17,13 @@
#include <string>
#include <vector>
#include "paddle/infrt/tensor/dense_host_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
namespace infrt {
class InfRtConfig {
std::string model_dir_;
std::string mlir_path_;
std::string param_dir_;
std::vector<std::string> shared_libs_;
public:
......@@ -31,8 +31,8 @@ class InfRtConfig {
void set_model_dir(const std::string& model_dir) { model_dir_ = model_dir; }
const std::string& model_dir() const { return model_dir_; }
void set_mlir_path(const std::string& mlir_path) { mlir_path_ = mlir_path; }
const std::string& mlir_path() const { return mlir_path_; }
void set_param_dir(const std::string& param_dir) { param_dir_ = param_dir; }
const std::string& param_dir() const { return param_dir_; }
void set_shared_libs(const std::vector<std::string>& shared_libs) {
shared_libs_ = shared_libs;
......@@ -49,15 +49,15 @@ class InfRtPredictor {
void Run();
int Init(const InfRtConfig& config);
int GetInputNum();
tensor::DenseHostTensor* GetInput(int i);
::phi::DenseTensor* GetInput(int i);
int GetOutputNum();
tensor::DenseHostTensor* GetOutput(int i);
::phi::DenseTensor* GetOutput(int i);
protected:
struct Impl;
std::unique_ptr<Impl> impl_;
};
std::shared_ptr<InfRtPredictor> CreateInfRtPredictor(const InfRtConfig& config);
std::unique_ptr<InfRtPredictor> CreateInfRtPredictor(const InfRtConfig& config);
} // namespace infrt
......@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/api/infrt_api.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "llvm/Support/raw_ostream.h"
#include "paddle/infrt/api/infrt_api.h"
#include "paddle/infrt/backends/host/phi_allocator.h"
#include "paddle/infrt/common/buffer.h"
#include "paddle/infrt/common/dtype.h"
......@@ -31,49 +31,30 @@ namespace infrt {
TEST(InfRtPredictor, predictor) {
std::vector<std::string> shared_libs;
shared_libs.push_back("../../paddle/libexternal_kernels.so");
InfRtConfig config;
// set external shared libraries that contain kernels.
config.set_shared_libs(shared_libs);
// set model dir
config.set_model_dir("../../paddle/paddle_1.8_fc_model");
// set mlir path
config.set_mlir_path("../../../infrt/dialect/mlir_tests/tensor_map.mlir");
std::shared_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
config.set_model_dir("@CMAKE_BINARY_DIR@/linear/linear.pdmodel");
config.set_param_dir("@CMAKE_BINARY_DIR@/linear/linear.pdiparams");
auto* input = predictor->GetInput(0);
std::vector<int64_t> shape = {3, 3};
input->Init(shape, infrt::GetDType<float>());
llvm::outs() << input->shape() << "\n";
std::unique_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
// init input tensor
auto* input_data = reinterpret_cast<float*>(input->buffer()->data()->memory);
for (int i = 0; i < input->shape().GetNumElements(); i++) input_data[i] = 1.0;
::infrt::backends::CpuPhiAllocator cpu_allocator;
::phi::DenseTensor* input = predictor->GetInput(0);
input->Resize({16, 784});
input->AllocateFrom(&cpu_allocator, ::phi::DataType::FLOAT32);
auto* input_data = reinterpret_cast<float*>(input->data());
for (int i = 0; i < input->numel(); i++) input_data[i] = 1.0;
predictor->Run();
// get and print output tensor
auto* output = predictor->GetOutput(0);
auto* output_data =
reinterpret_cast<float*>(output->buffer()->data()->memory);
std::vector<float> ans = {0.428458,
0.244493,
0.572342,
0.572008,
0.509771,
0.495599,
0.651287,
0.326426,
0.404649};
// TODO(Shixiaowei02): Automatic result validation for training then inference.
// auto* output_data = reinterpret_cast<float*>(output->data());
ASSERT_EQ(output->shape().GetNumElements(), ans.size());
for (int i = 0; i < output->shape().GetNumElements(); ++i) {
ASSERT_NEAR(output_data[i], ans[i], 0.000001);
}
ASSERT_EQ(output->dims(), ::phi::DDim({16, 10}));
}
} // namespace infrt
......@@ -18,10 +18,10 @@ limitations under the License. */
namespace infrt {
namespace backends {
class CpuPhiContext : public phi::CPUContext {
class CpuPhiContext : public ::phi::CPUContext {
public:
using Base = phi::CPUContext;
using phi::CPUContext::SetEigenDevice;
using Base = ::phi::CPUContext;
using ::phi::CPUContext::SetEigenDevice;
CpuPhiContext() {
Init();
......@@ -29,18 +29,18 @@ class CpuPhiContext : public phi::CPUContext {
}
private:
std::unique_ptr<phi::Allocator> alloc_{std::make_unique<CpuPhiAllocator>()};
std::unique_ptr<::phi::Allocator> alloc_{std::make_unique<CpuPhiAllocator>()};
};
class GpuPhiContext : public phi::GPUContext {
class GpuPhiContext : public ::phi::GPUContext {
public:
using Base = phi::GPUContext;
using phi::GPUContext::SetStream;
using phi::GPUContext::SetEigenDevice;
using phi::GPUContext::SetBlasHandle;
using phi::GPUContext::SetDnnHandle;
using phi::GPUContext::SetSolverHandle;
using phi::GPUContext::SetSparseHandle;
using Base = ::phi::GPUContext;
using ::phi::GPUContext::SetStream;
using ::phi::GPUContext::SetEigenDevice;
using ::phi::GPUContext::SetBlasHandle;
using ::phi::GPUContext::SetDnnHandle;
using ::phi::GPUContext::SetSolverHandle;
using ::phi::GPUContext::SetSparseHandle;
};
} // namespace backends
......
......@@ -110,6 +110,8 @@ void PhiOpConvertPass::convertStage() {
::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
infrt::ProtoArgumentMappingContext(op));
VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel("
<< kernel_sign.name << ")";
// resort input&output according to kernel_sign
::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
::llvm::SmallVector<mlir::Type, 4> output_types;
......
......@@ -19,6 +19,7 @@
#include "paddle/infrt/kernel/phi/context_kernels.h"
#include "paddle/infrt/paddle/model_parser.h"
#include "paddle/infrt/paddle/scope.h"
#include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/place.h"
......@@ -167,9 +168,7 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
#undef PRINT_META_DATA
}
::infrt::phi::DenseTensorMap LoadParams(
host_context::Attribute<std::string> path) {
const auto& file_path = path.get();
::infrt::phi::DenseTensorMap LoadParameters(const std::string& file_path) {
std::cout << "loading params from: " << file_path << std::endl;
::infrt::phi::DenseTensorMap map;
......@@ -201,17 +200,19 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
return map;
}
::infrt::phi::DenseTensorMap LoadCombinedParams(
host_context::Attribute<std::string> model_path,
host_context::Attribute<std::string> params_path) {
const auto& model = model_path.get();
std::cout << "loading params from: " << model << std::endl;
::infrt::phi::DenseTensorMap LoadParams(
host_context::Attribute<std::string> path) {
return LoadParameters(path.get());
}
::infrt::phi::DenseTensorMap LoadCombinedParameters(
const std::string& model_path, const std::string& params_path) {
::infrt::phi::DenseTensorMap map;
auto pb_proto_prog = paddle::LoadProgram(model);
auto pb_proto_prog = paddle::LoadProgram(model_path);
auto main_block = pb_proto_prog->blocks(0);
std::ifstream param_file(params_path.get(), std::ios::binary);
std::ifstream param_file(params_path, std::ios::binary);
std::set<std::string> tmp;
for (auto& var : main_block.vars()) {
......@@ -237,6 +238,12 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
return map;
}
::infrt::phi::DenseTensorMap LoadCombinedParams(
host_context::Attribute<std::string> model_path,
host_context::Attribute<std::string> params_path) {
return LoadCombinedParameters(model_path.get(), params_path.get());
}
::phi::DenseTensor TensorMapGetTensor(
const ::infrt::phi::DenseTensorMap& map,
host_context::Attribute<std::string> name) {
......
......@@ -50,7 +50,9 @@ void FillDenseTensorF32(::phi::DenseTensor* dense_tensor,
host_context::Attribute<std::vector<float>> values);
void PrintDenseTensor(::phi::DenseTensor* dense_tensor);
infrt::phi::DenseTensorMap LoadParams(
::infrt::phi::DenseTensorMap LoadParameters(const std::string& path);
::infrt::phi::DenseTensorMap LoadParams(
host_context::Attribute<std::string> path);
::phi::DenseTensor TensorMapGetTensor(
......@@ -61,6 +63,9 @@ infrt::phi::DenseTensorMap LoadParams(
host_context::Attribute<std::string> model_path,
host_context::Attribute<std::string> params_path);
::infrt::phi::DenseTensorMap LoadCombinedParameters(
const std::string& model_path, const std::string& params_path);
int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map);
#ifdef INFRT_WITH_GPU
......
......@@ -90,4 +90,6 @@ DenseHostTensor::~DenseHostTensor() {}
void* DenseHostTensor::raw_data() const { return buffer_->data()->memory; }
DType DenseHostTensor::dtype() const { return metadata().dtype; }
} // namespace infrt::tensor
......@@ -78,6 +78,8 @@ class DenseHostTensor : public HostTensor {
const TensorShape& shape() const;
TensorShape* mutable_shape();
DType dtype() const;
const Buffer* buffer() const;
void* raw_data() const;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
namespace infrt {
namespace tests {
template <typename ClockT>
class ChronoTimer {
public:
using TimePoint = std::chrono::time_point<ClockT>;
ChronoTimer() : start_{TimePoint::min()} {}
void Clear() { start_ = TimePoint::min(); }
void Start() { start_ = ClockT::now(); }
double GetMs() {
auto diff = ClockT::now() - start_;
return static_cast<double>(
std::chrono::duration_cast<std::chrono::duration<double>>(diff)
.count()) *
1000.0;
}
private:
TimePoint start_;
};
using WallClockTimer = ChronoTimer<std::chrono::steady_clock>;
class CpuClockTimer {
public:
CpuClockTimer() = default;
void Clear() { start_ = 0; }
void Start() { start_ = std::clock(); }
double GetMs() {
std::clock_t diff = std::clock() - start_;
return static_cast<double>(diff * 1000.0 / CLOCKS_PER_SEC);
}
private:
std::clock_t start_{0};
};
class BenchmarkStats {
public:
void Start() {
wall_timer_.Start();
cpu_timer_.Start();
}
void Stop() {
wall_time_.push_back(wall_timer_.GetMs());
cpu_time_.push_back(cpu_timer_.GetMs());
}
std::string Summerize(const std::vector<float>& percents) {
std::stringstream ss;
std::sort(wall_time_.begin(), wall_time_.end());
std::sort(cpu_time_.begin(), cpu_time_.end());
auto percentile = [](float p, const std::vector<float>& stats) {
assert(p >= 0 && p < 1);
return stats[stats.size() * p];
};
for (auto p : percents) {
ss << "=== Wall Time (ms): \n";
ss << " * percent " << std::to_string(static_cast<int>(p * 100));
ss << ": " << percentile(p, wall_time_) << '\n';
}
for (auto p : percents) {
ss << "=== CPU Time (ms): \n";
ss << " * percent " << std::to_string(static_cast<int>(p * 100));
ss << ": " << percentile(p, cpu_time_) << '\n';
}
return ss.str();
}
private:
WallClockTimer wall_timer_;
std::vector<float> wall_time_;
CpuClockTimer cpu_timer_;
std::vector<float> cpu_time_;
};
} // namespace tests
} // namespace infrt
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册