Unverified commit aa67c292, authored by H huzhiqiang, committed by GitHub

[infrt] support resnet50 on gpu backend (#41473)

Parent 9ac6b7ed
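Note: the following is a minimal usage sketch, not part of this diff, showing how a caller would exercise the GPU backend that this commit enables. It is assembled from the gpu_predictor test added below; the include paths, the infrt:: namespace qualification, and the model/param file locations are assumptions for illustration.

// Sketch (assumptions noted above): ResNet-50 inference on the GPU backend
// via the new InfRtConfig::enable_gpu() switch added in this commit.
#include <cuda_runtime.h>
#include <memory>
#include <vector>

#include "paddle/infrt/api/infrt_api.h"                // assumed header: InfRtConfig, CreateInfRtPredictor
#include "paddle/infrt/backends/gpu/phi_allocator.h"   // assumed header: GpuPhiAllocator

int main() {
  infrt::InfRtConfig config;
  config.enable_gpu();  // new in this commit: inserts a GPU Place into the pass pipeline
  config.set_model_dir("models/resnet50/model.pdmodel");     // illustrative path
  config.set_param_dir("models/resnet50/model.pdiparams");   // illustrative path
  std::unique_ptr<infrt::InfRtPredictor> predictor = infrt::CreateInfRtPredictor(config);

  // Allocate the input on the device and fill it from the host, as the test does.
  infrt::backends::GpuPhiAllocator gpu_allocator;
  phi::DenseTensor* input = predictor->GetInput(0);
  input->Resize({2, 3, 256, 256});
  input->AllocateFrom(&gpu_allocator, phi::DataType::FLOAT32);
  std::vector<float> host_input(2 * 3 * 256 * 256, 1.0f);
  cudaMemcpy(input->data(), host_input.data(),
             sizeof(float) * input->numel(), cudaMemcpyHostToDevice);

  predictor->Run();

  // Copy the logits back to the host for inspection.
  phi::DenseTensor* output = predictor->GetOutput(0);
  std::vector<float> host_output(output->numel());
  cudaMemcpy(host_output.data(), output->data<float>(),
             sizeof(float) * output->numel(), cudaMemcpyDeviceToHost);
  return 0;
}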
@@ -270,6 +270,12 @@ int InfRtPredictor::Init(const InfRtConfig& config) {
{::infrt::TargetType::CPU,
::infrt::PrecisionType::FLOAT32,
::infrt::LayoutType::NCHW}};
if (config.gpu_enabled()) {
valid_places.insert(valid_places.begin(),
::infrt::Place(::infrt::TargetType::GPU,
::infrt::PrecisionType::FLOAT32,
::infrt::LayoutType::NCHW));
}
pass_manager.addPass(CreatePhiOpCvtPass(valid_places));
pass_manager.addPass(CreateInfrtOpFusePass());
}
@@ -300,12 +306,19 @@ int InfRtPredictor::Init(const InfRtConfig& config) {
}
// Load params
auto tensor_map = ::infrt::kernel::phi::LoadCombinedParameters(
config.model_dir(), config.param_dir());
if (config.gpu_enabled() && !config.tensorrt_enabled()) {
auto tensor_map = ::infrt::kernel::phi::LoadCombinedParamsToGpu(
config.model_dir(), config.param_dir());
impl_->executor.reset(
new PredictExecutor(module_op, registry, std::move(tensor_map)));
} else {
auto tensor_map = ::infrt::kernel::phi::LoadCombinedParameters(
config.model_dir(), config.param_dir());
impl_->executor.reset(
new PredictExecutor(module_op, registry, std::move(tensor_map)));
}
// Create PredictExecutor
impl_->executor.reset(
new PredictExecutor(module_op, registry, std::move(tensor_map)));
return 0;
}
......
@@ -27,6 +27,7 @@ class InfRtConfig {
std::vector<std::string> shared_libs_;
// TODO(wilber): Design an easy-to-use interface.
bool gpu_enabled_{false};
bool tensorrt_enabled_{false};
public:
@@ -42,6 +43,9 @@ class InfRtConfig {
}
const std::vector<std::string>& shared_libs() const { return shared_libs_; }
void enable_gpu() { gpu_enabled_ = true; }
bool gpu_enabled() const { return gpu_enabled_; }
// TODO(wilber): Design an easy-to-use interface.
void enable_tensorrt() { tensorrt_enabled_ = true; }
void disable_tensorrt() { tensorrt_enabled_ = false; }
......
@@ -57,6 +57,57 @@ TEST(InfRtPredictor, predictor) {
ASSERT_EQ(output->dims(), ::phi::DDim({16, 10}));
}
TEST(InfRtPredictor, cpu_predictor) {
std::vector<std::string> shared_libs;
InfRtConfig config;
config.set_model_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdmodel");
config.set_param_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdiparams");
std::unique_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
::infrt::backends::CpuPhiAllocator cpu_allocator;
::phi::DenseTensor* input = predictor->GetInput(0);
input->Resize({2, 3, 256, 256});
input->AllocateFrom(&cpu_allocator, ::phi::DataType::FLOAT32);
auto* input_data = reinterpret_cast<float*>(input->data());
for (int i = 0; i < input->numel(); i++) input_data[i] = 1.0;
for(int i = 0; i < 10; i++) {
predictor->Run();
}
auto start = std::chrono::steady_clock::now();
for(int i = 0; i < 10; i++) {
predictor->Run();
}
auto end = std::chrono::steady_clock::now();
auto msec = std::chrono::duration_cast<std::chrono::milliseconds>(end-start);
std::cout <<"One predict period costs " << msec.count()/1000 << "ms.\n";
// get and print output tensor
auto* output = predictor->GetOutput(0);
ASSERT_EQ(output->dims(), ::phi::DDim({2, 1000}));
const std::vector<float> true_vals {
-3.319006264209747314e-01, -1.418896913528442383e+00,
-6.934890151023864746e-01, -1.498023152351379395e+00,
3.078042864799499512e-01, -1.340998053550720215e+00,
3.508620023727416992e+00, 2.274388313293457031e+00,
-1.321727275848388672e+00, -8.888689428567886353e-02,
-3.319006264209747314e-01, -1.418896913528442383e+00,
-6.934890151023864746e-01, -1.498023152351379395e+00,
3.078042864799499512e-01, -1.340998053550720215e+00,
3.508620023727416992e+00, 2.274388313293457031e+00,
-1.321727275848388672e+00, -8.888689428567886353e-02
};
// spot-check every 100th logit against the reference values
for (size_t i = 0; i < true_vals.size(); i++) {
CHECK_NEAR(output->data<float>()[i*100], true_vals[i], 1e-5);
}
}
#ifdef INFRT_WITH_TRT
TEST(InfRtPredictor, trt_predictor) {
std::vector<std::string> shared_libs;
@@ -100,4 +151,67 @@ TEST(InfRtPredictor, trt_predictor) {
}
#endif
#ifdef INFRT_WITH_GPU
TEST(InfRtPredictor, gpu_predictor) {
std::vector<std::string> shared_libs;
InfRtConfig config;
config.enable_gpu();
config.set_model_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdmodel");
config.set_param_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdiparams");
std::unique_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
::infrt::backends::GpuPhiAllocator gpu_allocator;
::phi::DenseTensor* input = predictor->GetInput(0);
input->Resize({2, 3, 256, 256});
input->AllocateFrom(&gpu_allocator, ::phi::DataType::FLOAT32);
auto* data = reinterpret_cast<float*>(input->data());
std::vector<float> input_data(2 * 3 * 256 * 256, 1.0);
cudaMemcpy(data,
input_data.data(),
sizeof(float) * input->numel(),
cudaMemcpyHostToDevice);
for(int i = 0; i < 10; i++) {
predictor->Run();
}
auto start = std::chrono::steady_clock::now();
for(int i = 0; i < 1000; i++) {
predictor->Run();
}
auto end = std::chrono::steady_clock::now();
auto msec = std::chrono::duration_cast<std::chrono::milliseconds>(end-start);
std::cout <<"One predict period costs " << msec.count()/1000 << "ms.\n";
auto* output = predictor->GetOutput(0);
std::vector<float> output_data(output->numel());
cudaMemcpy(output_data.data(),
output->data<float>(),
sizeof(float) * output->numel(),
cudaMemcpyDeviceToHost);
ASSERT_EQ(output->dims(), ::phi::DDim({2, 1000}));
const std::vector<float> true_vals {
-3.319006264209747314e-01, -1.418896913528442383e+00,
-6.934890151023864746e-01, -1.498023152351379395e+00,
3.078042864799499512e-01, -1.340998053550720215e+00,
3.508620023727416992e+00, 2.274388313293457031e+00,
-1.321727275848388672e+00, -8.888689428567886353e-02,
-3.319006264209747314e-01, -1.418896913528442383e+00,
-6.934890151023864746e-01, -1.498023152351379395e+00,
3.078042864799499512e-01, -1.340998053550720215e+00,
3.508620023727416992e+00, 2.274388313293457031e+00,
-1.321727275848388672e+00, -8.888689428567886353e-02
};
// spot-check every 100th logit against the reference values
for (size_t i = 0; i < true_vals.size(); i++) {
CHECK_NEAR(output_data[i*100], true_vals[i], 1e-5);
}
}
#endif
} // namespace infrt
@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/memory/malloc.h"
#include "paddle/phi/core/allocator.h"
#ifdef INFRT_WITH_GPU
@@ -40,12 +41,8 @@ class GpuPhiAllocator : public phi::Allocator {
static void deleter(phi::Allocation* ptr) { cudaFree(ptr->ptr()); }
AllocationPtr Allocate(size_t bytes_size) {
void* ptr;
cudaMalloc(&ptr, bytes_size);
return AllocationPtr(
new phi::Allocation(
ptr, bytes_size, phi::Place(phi::AllocationType::GPU)),
deleter);
return paddle::memory::Alloc(phi::Place(phi::AllocationType::GPU),
bytes_size);
}
};
#endif
......
@@ -34,9 +34,8 @@ void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
InfrtDialect,
dt::DTDialect,
pd::PaddleDialect,
trt::TensorRTDialect
trt::TensorRTDialect,
#ifdef INFRT_WITH_PHI
,
phi::PHIDenseTensorDialect,
phi::PHICPUKernelDialect,
phi::PHIGPUKernelDialect,
......
@@ -40,6 +40,13 @@ def CreateHostInitedDenseTensorOp : PDT_Op<"create_host_inited_dense_tensor.f32"
let results = (outs DenseTensor:$output);
}
def CreateInitedGpuFLOAT32DenseTensorOp
: PDT_Op<"create_inited_dense_tensor.gpu.f32", [NoSideEffect]> {
let arguments = (ins Context:$context, I64ArrayAttr:$dims,
LayoutAttr:$layout, I64ArrayAttr:$lod, F32Attr:$value);
let results = (outs DenseTensor:$output);
}
def CreateInitedCpuFLOAT32DenseTensorOp
: PDT_Op<"create_inited_dense_tensor.cpu.f32", [NoSideEffect]> {
let arguments = (ins Context:$context, I64ArrayAttr:$dims,
@@ -86,6 +93,14 @@ def PDT_LoadCombinedParamsOp : PDT_Op<"load_combined_params", [NoSideEffect]> {
let assemblyFormat = "`(``)`attr-dict";
}
def PDT_LoadCombinedParamsGpuOp : PDT_Op<"load_combined_params_to_gpu", [NoSideEffect]> {
// input path of model params.
let arguments = (ins StrAttr:$model_path, StrAttr:$params_path);
let results = (outs PD_DenseTensorMap:$out);
let assemblyFormat = "`(``)`attr-dict";
}
def PDT_TensorMapGetSizeOp : PDT_Op<"tensor_map_get_size", [NoSideEffect]> {
let arguments = (ins PD_DenseTensorMap:$map);
let results = (outs I32:$size);
......
@@ -76,6 +76,7 @@ class PhiOpConvertPass
void getDependentDialects(mlir::DialectRegistry &registry) const override;
private:
void updateInputsAndResults(infrt::TargetType target);
void convertStage();
void dispatchStage();
@@ -110,10 +111,50 @@ mlir::LogicalResult PhiOpConvertPass::initialize(mlir::MLIRContext *context) {
// Implementation of the PhiOpConvertPass.
void PhiOpConvertPass::runOnFunction() {
updateInputsAndResults(valid_places_[0].target);
convertStage();
dispatchStage();
}
void PhiOpConvertPass::updateInputsAndResults(infrt::TargetType target) {
mlir::Block &body = getFunction().front();
auto loc = getFunction().getLoc();
mlir::Operation &operation = body.front();
mlir::MLIRContext *context = operation.getContext();
size_t num_input = body.getNumArguments();
// step 1: update the input tensors to the requested target place
for (size_t index = 0; index < num_input; index++) {
auto argument = body.getArgument(index);
if (auto t = argument.getType().dyn_cast<::infrt::DenseTensorType>()) {
mlir::Type replace_type = infrt::DenseTensorType::get(
context, target, t.getPrecision(), infrt::LayoutType::NCHW);
getFunction().insertArgument(index, replace_type, {}, loc);
argument.replaceAllUsesWith(getFunction().getArgument(index));
getFunction().eraseArgument(index + 1);
}
}
// update output tensors
unsigned int num_result = getFunction().getNumResults();
for (unsigned int index = 0; index < num_result; index++) {
mlir::Type replace_type =
infrt::DenseTensorType::get(context,
target,
infrt::PrecisionType::FLOAT32,
infrt::LayoutType::NCHW);
getFunction().eraseResult(index);
getFunction().insertResult(index, replace_type, {});
}
// update dense_tensor_map
mlir::Type replace_type = infrt::DenseTensorType::get(
context, target, infrt::PrecisionType::FLOAT32, infrt::LayoutType::NCHW);
for (auto &op : body.without_terminator()) {
if (op.getName().getIdentifier().str() == "phi_dt.tensor_map_get_tensor")
op.getResult(0).setType(replace_type);
}
}
void PhiOpConvertPass::convertStage() {
mlir::Block &body = getFunction().front();
std::vector<mlir::Operation *> worklist;
@@ -200,6 +241,7 @@ void PhiOpConvertPass::dispatchStage() {
mlir::OpBuilder builder(&block, block.begin());
std::map<infrt::TargetType, mlir::Value> phi_context;
for (infrt::KernelOp kernel_op : worklist) {
std::string kernel_name = kernel_op.name().str();
std::vector<infrt::PhiKernelDesc> candidates =
@@ -257,15 +299,25 @@ void PhiOpConvertPass::dispatchStage() {
for (size_t index = 0; index < phi_kernel_desc.input_types.size();
++index) {
mlir::Value input = kernel_op.getOperand(index);
auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
kernel_op.getLoc(),
infrt::DenseTensorType::get(
kernel_op.getContext(),
phi_kernel_desc.input_types[index].target,
phi_kernel_desc.input_types[index].precision,
phi_kernel_desc.input_types[index].layout),
input);
operation_state.addOperands(cvt_tensor_type_op.output());
if (input.getType().dyn_cast<::infrt::DenseTensorType>().getTarget() ==
::infrt::TargetType::CPU &&
phi_kernel_desc.input_types[index].target ==
::infrt::TargetType::GPU) {
auto cvt_tensor_type_op = builder.create<infrt::phi::GpuMemCopyOp>(
kernel_op.getLoc(),
infrt::DenseTensorType::get(
kernel_op.getContext(),
phi_kernel_desc.input_types[index].target,
phi_kernel_desc.input_types[index].precision,
phi_kernel_desc.input_types[index].layout),
input,
phi_context[infrt::TargetType::GPU],
mlir::BoolAttr::get(kernel_op.getContext(), /*d2h*/ false));
operation_state.addOperands(cvt_tensor_type_op.output());
} else {
operation_state.addOperands(input);
}
}
for (size_t index = 0; index < phi_kernel_desc.output_types.size();
@@ -280,11 +332,8 @@ void PhiOpConvertPass::dispatchStage() {
mlir::Operation *phi_operation = builder.createOperation(operation_state);
for (size_t index = 0; index < phi_kernel_desc.output_types.size();
++index) {
mlir::Value input = phi_operation->getResult(index);
auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
kernel_op.getLoc(), kernel_op.getResultTypes()[index], input);
kernel_op.getResult(index).replaceAllUsesWith(
cvt_tensor_type_op.output());
phi_operation->getResult(index));
}
kernel_op.erase();
}
......
@@ -62,7 +62,7 @@ namespace phi {
::phi::make_ddim(dims.get()),
ConvertLayoutToPhi(layout.get()),
{}));
float* a_data = dense_tensor.mutable_data<float>(::phi::CPUPlace());
float* a_data = dense_tensor.mutable_data<float>(context.GetPlace());
for (int64_t i = 0; i < dense_tensor.numel(); ++i) {
a_data[i] = value.get();
}
@@ -260,6 +260,43 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
return map;
}
::infrt::phi::DenseTensorMap LoadCombinedParamsToGpu(
const std::string& model_path, const std::string& params_path) {
::infrt::phi::DenseTensorMap map;
auto pb_proto_prog = paddle::LoadProgram(model_path);
auto main_block = pb_proto_prog->blocks(0);
std::ifstream param_file(params_path, std::ios::binary);
std::set<std::string> tmp;
for (auto& var : main_block.vars()) {
if (var.name() == "feed" || var.name() == "fetch" || !var.persistable()) {
continue;
}
if (var.type().type() ==
::paddle::framework::proto::VarType_Type_LOD_TENSOR) {
tmp.emplace(var.name());
} else {
llvm_unreachable("the tensor type is illegal.");
}
}
#ifdef INFRT_WITH_GPU
::phi::GPUContext ctx;
ctx.PartialInitWithoutAllocator();
for (auto& var : tmp) {
std::unique_ptr<::phi::DenseTensor> tensor{
std::make_unique<::phi::DenseTensor>()};
::paddle::framework::DeserializeFromStream(param_file, tensor.get(), ctx);
map.SetDenseTensor(var, std::move(tensor));
}
#endif
return map;
}
::infrt::phi::DenseTensorMap LoadCombinedParams(
host_context::Attribute<std::string> model_path,
host_context::Attribute<std::string> params_path) {
......
@@ -73,6 +73,9 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor);
::infrt::phi::DenseTensorMap LoadCombinedParameters(
const std::string& model_path, const std::string& params_path);
::infrt::phi::DenseTensorMap LoadCombinedParamsToGpu(
const std::string& model_path, const std::string& params_path);
int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map);
#ifdef INFRT_WITH_GPU
......
@@ -68,6 +68,9 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("phi_dt.load_params",
INFRT_KERNEL(infrt::kernel::phi::LoadParams),
{"path"});
registry->AddKernel("phi_dt.load_combined_params_to_gpu",
INFRT_KERNEL(infrt::kernel::phi::LoadCombinedParamsToGpu),
{"model_path", "params_path"});
registry->AddKernel("phi_dt.load_combined_params",
INFRT_KERNEL(infrt::kernel::phi::LoadCombinedParams),
{"model_path", "params_path"});
......
// RUN: infrtopt -phi-op-convert=valid-targets=CPU-FP32-NCHW -infrt-op-fuse %s
// CHECK-LABEL: @ops
func @ops(%a:!infrt.lod_tensor<?xf32,0>, %b:!infrt.lod_tensor<?xf32,0>) {
%g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.lod_tensor<?xf32,0>, !infrt.lod_tensor<?xf32>) -> tensor<?xf32>
%h = "pd.abs"(%g):(tensor<?xf32>) -> tensor<?xf32>
infrt.return %h:tensor<?xf32>
func @ops(%a:!infrt.dense_tensor<CPU, FP32, NCHW>, %b:!infrt.dense_tensor<CPU, FP32, NCHW>) {
%g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%h = "pd.abs"(%g):(!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
infrt.return %h:!infrt.dense_tensor<CPU, FP32, NCHW>
}
// CHECK-LABEL: @op_execute
func @op_execute(%a:!infrt.lod_tensor<?xf32,0>, %b:!infrt.lod_tensor<?xf32,0>, %c:!infrt.lod_tensor<?xf32,0>) -> !infrt.lod_tensor<?xf32,0> {
%g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.lod_tensor<?xf32,0>, !infrt.lod_tensor<?xf32>) -> tensor<?xf32>
%h = "pd.abs"(%g):(tensor<?xf32>) -> tensor<?xf32>
infrt.return %h:tensor<?xf32>
func @op_execute(%a:!infrt.dense_tensor<CPU, FP32, NCHW>, %b:!infrt.dense_tensor<CPU, FP32, NCHW>, %c:!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW> {
%g = "pd.elementwise_add"(%a, %b) {axis=1:si32} : (!infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%h = "pd.abs"(%g):(!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
infrt.return %h:!infrt.dense_tensor<CPU, FP32, NCHW>
}