Unverified commit cf8be325, authored by Wilber, committed by GitHub

Trt engine (#40744)

* infrt add trt engine

* fix register

* file generate

* fix ci error

* fix conflict

* add copyright

* update

* update

* update

* update engine name

* refactor trt code

* update

* update

* update

* update

* fix conflict

* update

* refactor code

* first commit

* update pdtensor to denseTensor

* code

* style

* code

* code style

* add the tensor map, test=develop

* update

* update

* update

* trt engine

* update trt mlir and runtime

* update mlir test

* update

* update

* update
Co-authored-by: DannyIsFunny <912790387@qq.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Parent 7e1155ed
......@@ -33,19 +33,21 @@ namespace tensorrt {
static nvinfer1::IBuilder* createInferBuilder(
nvinfer1::ILogger& logger) { // NOLINT
return static_cast<nvinfer1::IBuilder*>(
phi::dynload::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
::phi::dynload::createInferBuilder_INTERNAL(&logger,
NV_TENSORRT_VERSION));
}
static nvinfer1::IRuntime* createInferRuntime(
nvinfer1::ILogger& logger) { // NOLINT
return static_cast<nvinfer1::IRuntime*>(
phi::dynload::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
::phi::dynload::createInferRuntime_INTERNAL(&logger,
NV_TENSORRT_VERSION));
}
TrtEngine::TrtEngine(int device_id) : device_id_(device_id) {
FreshDeviceId();
logger_.reset(new TrtLogger());
builder_.reset(createInferBuilder(logger_->GetTrtLogger()));
phi::dynload::initLibNvInferPlugins(&logger_->GetTrtLogger(), "");
::phi::dynload::initLibNvInferPlugins(&logger_->GetTrtLogger(), "");
}
nvinfer1::IBuilder* TrtEngine::GetTrtBuilder() {
......@@ -237,11 +239,11 @@ bool TrtEngine::SetupNetworkAndConfig(const BuildOptions& build,
}
void TrtEngine::PrepareOutputHandle(const std::string& out_name) {
phi::DenseTensor t;
::phi::DenseTensor t;
outputs_.emplace(out_name, t);
}
phi::DenseTensor* TrtEngine::GetOutput(const std::string& name) {
::phi::DenseTensor* TrtEngine::GetOutput(const std::string& name) {
return &outputs_[name];
}
......@@ -249,7 +251,7 @@ size_t TrtEngine::GetOutputNum() const { return outputs_.size(); }
bool TrtEngine::SetUpInference(
const InferenceOptions& inference,
const std::unordered_map<std::string, phi::DenseTensor*>& inputs) {
const std::unordered_map<std::string, ::phi::DenseTensor*>& inputs) {
// TODO(wilber): now only create one exec_context
FreshDeviceId();
CHECK(engine_ != nullptr);
......@@ -272,7 +274,7 @@ bool TrtEngine::SetUpInference(
return true;
}
void TrtEngine::Run(const phi::GPUContext& ctx) {
void TrtEngine::Run(const ::phi::GPUContext& ctx) {
if (is_dynamic_shape_) {
DynamicRun(ctx);
} else {
......@@ -280,7 +282,7 @@ void TrtEngine::Run(const phi::GPUContext& ctx) {
}
}
void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
void TrtEngine::StaticRun(const ::phi::GPUContext& ctx) {
const int num_bindings = engine_->getNbBindings();
std::vector<void*> buffers(num_bindings, nullptr);
......@@ -291,7 +293,8 @@ void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
buffers[bind_index] =
const_cast<void*>(static_cast<const void*>(bind.buffer->data<float>()));
if (runtime_batch != -1) {
CHECK_EQ(runtime_batch, phi::vectorize<int64_t>(bind.buffer->dims())[0]);
CHECK_EQ(runtime_batch,
::phi::vectorize<int64_t>(bind.buffer->dims())[0]);
}
runtime_batch = bind.buffer->dims()[0];
}
......@@ -306,7 +309,7 @@ void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
for (int i = 0; i < dims.nbDims; ++i) {
ddim.push_back(dims.d[i]);
}
bind.buffer->Resize(phi::make_ddim(ddim));
bind.buffer->Resize(::phi::make_ddim(ddim));
// TODO(wilber): now only support float output.
ctx.Alloc<float>(bind.buffer, sizeof(float) * bind.buffer->numel());
buffers[bind_index] = static_cast<void*>(bind.buffer->data<float>());
......@@ -316,7 +319,7 @@ void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
runtime_batch, buffers.data(), ctx.stream(), nullptr);
}
void TrtEngine::DynamicRun(const phi::GPUContext& ctx) {
void TrtEngine::DynamicRun(const ::phi::GPUContext& ctx) {
const int num_bindings = engine_->getNbBindings();
std::vector<void*> buffers(num_bindings, nullptr);
......@@ -344,7 +347,7 @@ void TrtEngine::DynamicRun(const phi::GPUContext& ctx) {
for (int i = 0; i < dims.nbDims; ++i) {
ddim[i] = dims.d[i];
}
bind.buffer->Resize(phi::make_ddim(ddim));
bind.buffer->Resize(::phi::make_ddim(ddim));
ctx.Alloc<float>(bind.buffer, sizeof(float) * bind.buffer->numel());
buffers[bind_index] = static_cast<void*>(bind.buffer->data<float>());
}
......@@ -356,7 +359,7 @@ void TrtEngine::FreshDeviceId() {
int count;
cudaGetDeviceCount(&count);
CHECK_LT(device_id_, count);
phi::backends::gpu::SetDeviceId(device_id_);
::phi::backends::gpu::SetDeviceId(device_id_);
}
void TrtEngine::GetEngineInfo() {
......
......@@ -76,19 +76,19 @@ class TrtEngine {
const BuildOptions& build_options);
// TODO(wilber): Modify signature after infrt-trt ready.
void Run(const phi::GPUContext& ctx);
void Run(const ::phi::GPUContext& ctx);
// TODO(wilber): How to support multiple execution contexts?
bool SetUpInference(
const InferenceOptions& inference,
const std::unordered_map<std::string, phi::DenseTensor*>& inputs);
const std::unordered_map<std::string, ::phi::DenseTensor*>& inputs);
void GetEngineInfo();
void PrepareOutputHandle(const std::string& out_name);
// TODO(wilber): The output tensor names are: output_0, output_1, ...
phi::DenseTensor* GetOutput(const std::string&);
::phi::DenseTensor* GetOutput(const std::string&);
size_t GetOutputNum() const;
......@@ -104,9 +104,9 @@ class TrtEngine {
bool ModelToBuildEnv(TrtUniquePtr<nvinfer1::INetworkDefinition> network,
const BuildOptions& build);
void StaticRun(const phi::GPUContext& ctx);
void StaticRun(const ::phi::GPUContext& ctx);
void DynamicRun(const phi::GPUContext& ctx);
void DynamicRun(const ::phi::GPUContext& ctx);
private:
std::unique_ptr<TrtLogger> logger_{nullptr};
......@@ -118,7 +118,7 @@ class TrtEngine {
std::vector<std::unique_ptr<Bindings>> bindings_;
int device_id_{0};
bool is_dynamic_shape_{false};
std::unordered_map<std::string, phi::DenseTensor> outputs_;
std::unordered_map<std::string, ::phi::DenseTensor> outputs_;
};
} // namespace tensorrt
......
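The hunks above only show the TrtEngine declarations that changed. As a minimal sketch (not code from this commit), the intended call order is roughly the following; the include path, the infrt::backends::tensorrt namespace qualification, and a default-constructed InferenceOptions are assumptions, and the "input_0"/"output_0" names follow the convention used in trt_kernels.cc further below.

// Minimal usage sketch for TrtEngine (illustrative, not part of this commit).
// Assumed: header path, namespace alias, and that InferenceOptions works with
// its defaults; outputs follow the "output_0", "output_1", ... pattern.
#include <glog/logging.h>

#include <string>
#include <unordered_map>

#include "paddle/infrt/backends/tensorrt/trt_engine.h"  // path assumed

namespace trt_backend = infrt::backends::tensorrt;

void RunEngineOnce(trt_backend::TrtEngine* engine,
                   const ::phi::GPUContext& gpu_ctx,
                   ::phi::DenseTensor* gpu_input) {
  // Create the DenseTensor handle that will hold the first engine output.
  engine->PrepareOutputHandle("output_0");

  // Bind GPU inputs by name and set up the execution context.
  std::unordered_map<std::string, ::phi::DenseTensor*> inputs{
      {"input_0", gpu_input}};
  trt_backend::InferenceOptions options;  // defaults assumed
  CHECK(engine->SetUpInference(options, inputs));

  // Run() dispatches to StaticRun or DynamicRun depending on the network.
  engine->Run(gpu_ctx);

  // Outputs are exposed as ::phi::DenseTensor* owned by the engine.
  for (size_t i = 0; i < engine->GetOutputNum(); ++i) {
    ::phi::DenseTensor* out = engine->GetOutput("output_" + std::to_string(i));
    (void)out;  // copy to host or consume here
  }
}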
......@@ -92,7 +92,7 @@ class TrtLogger : public nvinfer1::ILogger {
struct Binding {
bool is_input{false};
nvinfer1::DataType data_type{nvinfer1::DataType::kFLOAT};
phi::DenseTensor* buffer{nullptr};
::phi::DenseTensor* buffer{nullptr};
std::string name;
};
......@@ -103,7 +103,7 @@ class Bindings {
void AddBinding(int32_t b,
const std::string& name,
bool is_input,
phi::DenseTensor* buffer,
::phi::DenseTensor* buffer,
nvinfer1::DataType data_type) {
while (bindings_.size() <= static_cast<size_t>(b)) {
bindings_.emplace_back();
......
......@@ -97,4 +97,17 @@ def FakeKernelOp : PDT_Op<"fake_phi_kernel"> {
let results = (outs DenseTensor:$output);
}
// TODO(wilber): Add a infrt_gpu dialect.
def PDT_GpuMemCopyOp : PDT_Op<"memcpy.gpu", [NoSideEffect]> {
let summary = "phi_dt.gpu.memcpy";
let description = [{gpu memcpy d2h or h2d}];
// TODO(wilber): add context argument to support stream.
let arguments = (ins
DenseTensor:$input,
Context:$context,
BoolAttr:$d2h
);
let results = (outs DenseTensor:$output);
}
#endif
......@@ -6,6 +6,7 @@ gather_srcs(infrt_src SRCS
trt_op_teller_pass.cc
trt_graph_fuse_pass.cc
trt_graph_split_pass.cc
trt_type_convert_pass.cc
)
mlir_tablegen_on(trt_ops)
mlir_add_rewriter(pd_lower_to_trt)
......
......@@ -21,6 +21,26 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h"
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
#include "paddle/infrt/kernel/basic_kernels.h"
#include "paddle/infrt/kernel/control_flow_kernels.h"
#include "paddle/infrt/kernel/tensor_kernels.h"
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
#include "paddle/infrt/kernel/tensorrt/registry.h"
#ifdef INFRT_WITH_PHI
#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"
#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/kernel/phi/registry.h"
#endif
int main(int argc, char** argv) {
static llvm::cl::opt<std::string> input_file(
......@@ -33,6 +53,22 @@ int main(int argc, char** argv) {
mlir::MLIRContext* context = infrt::Global::getMLIRContext();
auto module = infrt::dialect::LoadMlirFile(input_file.c_str(), context);
infrt::host_context::KernelRegistry registry;
::infrt::kernel::RegisterBasicKernels(&registry);
::infrt::kernel::RegisterTestKernels(&registry);
::infrt::kernel::RegisterTensorShapeKernels(&registry);
::infrt::kernel::RegisterTensorKernels(&registry);
::infrt::kernel::RegisterControlFlowKernels(&registry);
#ifdef INFRT_WITH_PHI
::infrt::kernel::RegisterPhiKernels(&registry);
::infrt::kernel::RegisterInferShapeLaunchers(&registry);
#endif
#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
::infrt::kernel::RegisterTrtKernels(&registry);
#endif
context->loadAllAvailableDialects();
module->dump();
mlir::PassManager pm(context);
......@@ -41,10 +77,12 @@ int main(int argc, char** argv) {
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphFusePass>());
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphSplitPass>(1));
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpConverterPass>());
trt_pass_manager.addPass(infrt::trt::createTrtTypeConvertPass());
if (mlir::failed(pm.run(*module))) {
std::cout << "\npass failed!\n" << std::endl;
return 4;
}
module->dump();
::infrt::host_context::TestMlir(module.get(), &registry);
return 0;
}
......@@ -12,10 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
#include <glog/logging.h>
#include <mlir/IR/Builders.h>
#include <mlir/Transforms/DialectConversion.h>
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/pd/ir/pd_ops.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -41,34 +48,34 @@ struct PD2TRT_GraphLower : public ::mlir::RewritePattern {
::llvm::SmallVector<mlir::Type, 4>(1, EngineType::get()),
trt_inputs,
true /*run_once*/);
::mlir::Block *block = new ::mlir::Block;
block->getOperations().splice(block->begin(),
casted_op.getBody()->getOperations(),
casted_op.getBody()->begin(),
casted_op.getBody()->end());
create_engine_op.body().push_back(block);
auto &block = create_engine_op.body().emplaceBlock();
block.getOperations().splice(block.begin(),
casted_op.getBody()->getOperations(),
casted_op.getBody()->begin(),
casted_op.getBody()->end());
// trt.execute
// outputs
::llvm::SmallVector<::mlir::Type, 4> execute_outputs_types;
for (auto v : casted_op.getODSResults(0)) {
execute_outputs_types.push_back(v.getType());
}
// inputs
::mlir::SmallVector<::mlir::Value, 4> execute_inputs(
create_engine_op.getODSResults(0));
for (auto v : inputs) {
execute_inputs.push_back(v);
}
auto execute_op = rewriter.create<ExecuteOp>(
ods_loc, execute_outputs_types, execute_inputs);
::llvm::SmallVector<::mlir::Value, 4> replace_values;
for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{execute_op.getODSResults(0)}) {
replace_values.push_back(v);
// trt.compute
::llvm::SmallVector<::mlir::Value, 4> replace_values2;
auto ctx_op = rewriter.create<::infrt::phi::CreateGPUContextOp>(
ods_loc,
infrt::phi::ContextType::get(rewriter.getContext(),
infrt::TargetType::GPU));
auto compute_op = rewriter.create<EngineComputeOp>(
ods_loc,
::infrt::DenseTensorListType::get(rewriter.getContext()),
create_engine_op.engine(),
ctx_op.output());
auto tensor_list_val = compute_op.outputs();
for (size_t i = 0; i < casted_op.getNumResults(); ++i) {
auto res = casted_op->getResult(i);
auto int_attr = mlir::IntegerAttr::get(
mlir::IntegerType::get(rewriter.getContext(), 32), i);
auto get_tensor_op = rewriter.create<::infrt::dt::TensorListGetTensorOp>(
ods_loc, res.getType(), tensor_list_val, int_attr);
replace_values2.push_back(get_tensor_op.output());
}
rewriter.replaceOp(op, replace_values);
ctx_op->moveBefore(ctx_op->getBlock(), ctx_op->getBlock()->begin());
rewriter.replaceOp(op, replace_values2);
return ::mlir::success();
}
};
......@@ -82,6 +89,9 @@ void TRTOpConverterPass::runOnOperation() {
// this lowering. In our case, we are lowering to TensorRTDialect from
// PaddleDialect
target.addLegalDialect<TensorRTDialect>();
target.addLegalDialect<::infrt::phi::PHIDialect>();
target.addLegalDialect<::infrt::dt::DTDialect>();
target.addLegalDialect<phi::PHIDenseTensorDialect>();
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the TensorRT operations.
......
......@@ -14,7 +14,9 @@
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include <llvm/Support/Casting.h>
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/infrt/ir/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/pd/ir/pd_ops.h"
......@@ -35,10 +37,12 @@ void TRTOpTellerPass::runOnFunction() {
auto *op = worklist.back();
worklist.pop_back();
if (op == nullptr) continue;
if (op->getName().getStringRef().substr(0, 3) != "pd.") continue;
if (::llvm::dyn_cast_or_null<infrt::pd::FeedOp>(op)) continue;
if (::llvm::dyn_cast_or_null<infrt::pd::FetchOp>(op)) continue;
if (::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(op)) continue;
if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue;
builder.setInsertionPoint(op);
auto loc = getFunction().getLoc();
auto graph_op = builder.create<infrt::pd::GraphOp>(
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h"
#include <glog/logging.h>
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "paddle/infrt/dialect/infrt/common/types.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace {
class TrtTypeConvertPass
: public mlir::PassWrapper<TrtTypeConvertPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "TrtTypeConvertPass"; }
void runOnFunction() override;
};
void TrtTypeConvertPass::runOnFunction() {
mlir::Block& body = getFunction().front();
auto* mlir_ctx = getFunction()->getContext();
mlir::OpBuilder builder(&body, body.begin());
std::vector<mlir::Operation*> worklist;
mlir::Operation* ctx_op{nullptr};
worklist.reserve(body.getOperations().size());
for (auto& op : body) {
worklist.push_back(&op);
if (op.getName().getStringRef() == "phi_dt.create_context.gpu") {
ctx_op = &op;
}
}
::infrt::LayoutType layout = ::infrt::LayoutType::NCHW;
::infrt::TargetType target = ::infrt::TargetType::GPU;
for (auto& op : worklist) {
if (auto tensor_map_get_op =
llvm::dyn_cast<::infrt::phi::TensorMapGetTensorOp>(op)) {
auto res = tensor_map_get_op.output();
if (auto t = res.getType().dyn_cast<::infrt::DenseTensorType>()) {
auto replace_type = ::infrt::DenseTensorType::get(
mlir_ctx, t.getTarget(), t.getPrecision(), layout);
res.setType(replace_type);
}
}
if (auto create_engine = llvm::dyn_cast<::infrt::trt::CreateEngineOp>(op)) {
// Insert `infrt.gpu.memcpy` op.
for (auto arg : create_engine.getOperands()) {
if (mlir::Operation* producer = arg.getDefiningOp()) {
if (arg.getType().isa<::infrt::DenseTensorType>()) {
builder.setInsertionPointAfter(producer);
auto t = arg.getType().dyn_cast<::infrt::DenseTensorType>();
if (producer->getName().getStringRef() !=
"phi_dt.tensor_map_get_tensor" &&
t.getTarget() != ::infrt::TargetType::GPU) {
auto replace_type = ::infrt::DenseTensorType::get(
mlir_ctx, target, t.getPrecision(), layout);
CHECK_NOTNULL(ctx_op);
auto mem_cpy_op = builder.create<::infrt::phi::GpuMemCopyOp>(
arg.getLoc(),
replace_type,
arg,
llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op)
.output(),
mlir::BoolAttr::get(mlir_ctx, /*d2h*/ false));
arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op);
}
}
} else {
auto blockArg = arg.cast<mlir::BlockArgument>();
if (arg.getType().isa<::infrt::DenseTensorType>()) {
auto t = arg.getType().dyn_cast<::infrt::DenseTensorType>();
builder.setInsertionPointAfter(ctx_op);
auto replace_type = ::infrt::DenseTensorType::get(
mlir_ctx, ::infrt::TargetType::GPU, t.getPrecision(), layout);
CHECK_NOTNULL(ctx_op);
auto mem_cpy_op = builder.create<::infrt::phi::GpuMemCopyOp>(
blockArg.getLoc(),
replace_type,
blockArg,
llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op)
.output(),
mlir::BoolAttr::get(mlir_ctx, /*d2h*/ false));
arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op);
}
}
}
// Change ops(in block) types.
auto& block = create_engine.getRegion().getBlocks().front();
for (auto& op : block.without_terminator()) {
for (size_t i = 0; i < op.getNumResults(); ++i) {
if (auto t = op.getResult(i)
.getType()
.dyn_cast<::infrt::DenseTensorType>()) {
auto replace_type = ::infrt::DenseTensorType::get(
mlir_ctx, ::infrt::TargetType::GPU, t.getPrecision(), layout);
op.getResult(i).setType(replace_type);
}
}
}
} else if (auto list_get_tensor_op =
llvm::dyn_cast<::infrt::dt::TensorListGetTensorOp>(op)) {
auto result = list_get_tensor_op.output();
if (auto t = result.getType().dyn_cast<::infrt::DenseTensorType>()) {
result.setType(::infrt::DenseTensorType::get(
mlir_ctx, ::infrt::TargetType::GPU, t.getPrecision(), layout));
}
} else if (auto return_op = llvm::dyn_cast<::infrt::ReturnOp>(op)) {
for (auto arg : return_op->getOperands()) {
if (auto t = arg.getType().dyn_cast<::infrt::DenseTensorType>()) {
if (t.getLayout() != ::infrt::LayoutType::ANY ||
t.getTarget() != ::infrt::TargetType::CPU ||
t.getPrecision() != ::infrt::PrecisionType::FLOAT32) {
builder.setInsertionPoint(return_op);
CHECK_NOTNULL(ctx_op);
auto mem_cpy_op = builder.create<::infrt::phi::GpuMemCopyOp>(
return_op.getLoc(),
::infrt::DenseTensorType::get(mlir_ctx,
::infrt::TargetType::CPU,
t.getPrecision(),
::infrt::LayoutType::ANY),
arg,
llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op)
.output(),
mlir::BoolAttr::get(mlir_ctx, /*d2h*/ true));
arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op);
}
}
}
}
}
}
} // namespace
namespace infrt {
namespace trt {
std::unique_ptr<mlir::Pass> createTrtTypeConvertPass() {
return std::make_unique<TrtTypeConvertPass>();
}
} // namespace trt
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
std::unique_ptr<mlir::Pass> createTrtTypeConvertPass();
} // namespace trt
} // namespace infrt
......@@ -309,7 +309,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
arg_value = GetOpResult(upstream_op);
}
}
if (arg_value->is_type<phi::DenseTensor>()) {
if (arg_value->is_type<::phi::DenseTensor>()) {
impl_->runtime->FeedInArgs(
std::make_pair(std::to_string(i), ValueRef(arg_value)));
}
......
......@@ -30,6 +30,7 @@ namespace phi {
::phi::GPUContext context;
context.PartialInitWithoutAllocator();
context.SetAllocator(new ::infrt::backends::GpuPhiAllocator{});
context.SetHostAllocator(new backends::CpuPhiAllocator{});
context.PartialInitWithAllocator();
return context;
}
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"
#include "llvm/Support/ErrorHandling.h"
#include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/phi/data_type.h"
#include "paddle/infrt/kernel/phi/context_kernels.h"
......@@ -228,6 +229,69 @@ int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map) {
return map.size();
}
#ifdef INFRT_WITH_GPU
inline size_t SizeOfDataType(::phi::DataType data_type) {
switch (data_type) {
case ::phi::DataType::BOOL:
case ::phi::DataType::UINT8:
case ::phi::DataType::INT8:
return 1;
case ::phi::DataType::BFLOAT16:
case ::phi::DataType::FLOAT16:
case ::phi::DataType::INT16:
case ::phi::DataType::UINT16:
return 2;
case ::phi::DataType::FLOAT32:
case ::phi::DataType::INT32:
case ::phi::DataType::UINT32:
return 4;
case ::phi::DataType::FLOAT64:
case ::phi::DataType::INT64:
case ::phi::DataType::UINT64:
case ::phi::DataType::COMPLEX64:
return 8;
case ::phi::DataType::COMPLEX128:
return 16;
case ::phi::DataType::UNDEFINED:
return 0;
default:
llvm_unreachable("should not reach here");
return 0;
}
return 0;
}
::phi::DenseTensor GpuMemCpy(const ::phi::DenseTensor& input,
const ::phi::GPUContext& context,
bool d2h) {
if (d2h) {
::phi::DenseTensor ret(
const_cast<::phi::Allocator*>(&context.GetHostAllocator()),
input.meta());
CHECK(input.place().GetType() == ::phi::AllocationType::GPU);
// TODO(wilber): Add sync op and stream.
cudaMemcpyAsync(ret.data(),
input.data(),
SizeOfDataType(input.dtype()) * input.numel(),
cudaMemcpyDeviceToHost,
nullptr);
return ret;
} else {
// h2d
::phi::DenseTensor ret(
const_cast<::phi::Allocator*>(&context.GetAllocator()), input.meta());
CHECK(input.place().GetType() == ::phi::AllocationType::CPU ||
input.place().GetType() == ::phi::AllocationType::GPUPINNED);
// TODO(wilber): Add sync op and stream.
cudaMemcpyAsync(ret.data(),
input.data(),
SizeOfDataType(input.dtype()) * input.numel(),
cudaMemcpyHostToDevice,
nullptr);
return ret;
}
}
#endif
} // namespace phi
} // namespace kernel
} // namespace infrt
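A small sketch of how the new GpuMemCpy kernel is meant to be driven from C++ (illustrative only; the wrapper function and variable names below are not part of this commit). The GPU context is assumed to come from the context kernel shown above, so it carries both the device allocator and the newly added host allocator.

// Illustrative round-trip through GpuMemCpy; wrapper and names are assumptions.
#include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"

::phi::DenseTensor CopyToHostAndBack(const ::phi::DenseTensor& gpu_tensor,
                                     const ::phi::GPUContext& gpu_ctx) {
  // Device -> host: the result is allocated with the context's host allocator.
  ::phi::DenseTensor host_copy =
      ::infrt::kernel::phi::GpuMemCpy(gpu_tensor, gpu_ctx, /*d2h=*/true);

  // Host -> device: the result is allocated with the context's device allocator.
  ::phi::DenseTensor device_copy =
      ::infrt::kernel::phi::GpuMemCpy(host_copy, gpu_ctx, /*d2h=*/false);

  // Both copies are issued with cudaMemcpyAsync on the default stream and are
  // not synchronized here (see the TODO about adding a sync op and stream).
  return device_copy;
}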
......@@ -18,6 +18,7 @@
#include "paddle/infrt/dialect/infrt/common/types.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/tensor/phi/tensor_map.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
namespace infrt {
......@@ -55,6 +56,12 @@ infrt::phi::DenseTensorMap LoadParams(
int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map);
#ifdef INFRT_WITH_GPU
::phi::DenseTensor GpuMemCpy(const ::phi::DenseTensor& input,
const ::phi::GPUContext& context,
bool d2h);
#endif
} // namespace phi
} // namespace kernel
} // namespace infrt
......@@ -52,6 +52,9 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
"phi_dt.create_dense_tensor.gpu",
INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor),
{"dims", "lod", "layout", "precision"});
registry->AddKernelWithAttrs("phi_dt.memcpy.gpu",
INFRT_KERNEL(infrt::kernel::phi::GpuMemCpy),
{"d2h"});
#endif
registry->AddKernelWithAttrs("phi_dt.load_params",
INFRT_KERNEL(infrt::kernel::phi::LoadParams),
......
......@@ -14,6 +14,7 @@
#include "paddle/infrt/kernel/tensorrt/trt_kernels.h"
#include <string>
#include <unordered_set>
#include "NvInfer.h"
#include "NvInferRuntime.h"
#include "NvInferRuntimeCommon.h"
......@@ -68,7 +69,7 @@ namespace tensorrt {
auto& region = operation.getRegion(0);
auto& block = region.getBlocks().front();
std::unordered_map<std::string, phi::DenseTensor*> trt_bind_inputs;
std::unordered_map<std::string, ::phi::DenseTensor*> trt_bind_inputs;
ValueToITensorMap value_to_trt_tensor_map;
ValueToTensorMap value_to_tensor_map;
......@@ -79,7 +80,7 @@ namespace tensorrt {
const std::string input_name = "input_" + std::to_string(idx);
auto* v = symbol_table->GetValue(std::to_string(idx));
CHECK_NOTNULL(v);
auto* t = &v->get<phi::DenseTensor>();
auto* t = &v->get<::phi::DenseTensor>();
value_to_tensor_map[operand] = t;
// TODO(wilber): get input info from mlir.
......@@ -93,7 +94,7 @@ namespace tensorrt {
if (operand.isa<mlir::BlockArgument>()) {
// TODO(wilber): A trick: the weights are CPU tensor and inputs are GPU
// tensor, so we treat all GPU tensors as inputs to trt.
if (t->place().GetType() == phi::AllocationType::GPU) {
if (t->place().GetType() == ::phi::AllocationType::GPU) {
trt_bind_inputs[input_name] = t;
nvinfer1::Dims dims;
dims.nbDims = t->dims().size() - 1;
......@@ -106,8 +107,10 @@ namespace tensorrt {
}
} else {
// TODO(wilber): Replace with the op name that generates the weights.
if (operand.getDefiningOp()->getName().getStringRef() !=
"phi_dt.create_dense_tensor.cpu") {
std::unordered_set<std::string> weight_flags{
"phi_dt.tensor_map_get_tensor", "phi_dt.create_dense_tensor.cpu"};
if (!weight_flags.count(
operand.getDefiningOp()->getName().getStringRef().str())) {
trt_bind_inputs[input_name] = t;
nvinfer1::Dims dims;
dims.nbDims = t->dims().size() - 1;
......@@ -167,10 +170,10 @@ void PrintTrtLayer(backends::tensorrt::TrtEngine* engine) {
engine->GetEngineInfo();
}
std::vector<phi::DenseTensor*> TrtEngineCompute(
backends::tensorrt::TrtEngine* engine, const phi::GPUContext& context) {
std::vector<::phi::DenseTensor*> TrtEngineCompute(
backends::tensorrt::TrtEngine* engine, const ::phi::GPUContext& context) {
engine->Run(context);
std::vector<phi::DenseTensor*> res;
std::vector<::phi::DenseTensor*> res;
for (size_t i = 0; i < engine->GetOutputNum(); ++i) {
res.push_back(engine->GetOutput("output_" + std::to_string(i)));
}
......
......@@ -41,8 +41,8 @@ struct MlirOperationWithInfrtSymbol {
void PrintTrtLayer(backends::tensorrt::TrtEngine* engine);
std::vector<phi::DenseTensor*> TrtEngineCompute(
backends::tensorrt::TrtEngine* engine, const phi::GPUContext& context);
std::vector<::phi::DenseTensor*> TrtEngineCompute(
backends::tensorrt::TrtEngine* engine, const ::phi::GPUContext& context);
} // namespace tensorrt
} // namespace kernel
......
......@@ -7,3 +7,4 @@ add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensorrt/disabled_linear.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensorrt/disabled_linear.mlir)
module {
  func @main_graph(%map: !phi.dense_tensor_map, %arg0: !infrt.dense_tensor<CPU, FP32, ANY>) -> !infrt.dense_tensor<CPU, FP32, ANY> {
    %0 = "phi_dt.create_context.gpu"() : () -> !phi.context<GPU>
    %1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor<CPU, FP32, ANY>, !phi.context<GPU>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
    %3 = phi_dt.tensor_map_get_tensor(%map) {name = "linear_0.b_0"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
    %4 = phi_dt.tensor_map_get_tensor(%map) {name = "linear_0.w_0"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
    %5 = "trt.create_engine"(%1, %4, %3) ( {
      %10 = "trt.FullyConnected"(%1, %4, %3) {out_channel_num = 10 : si32} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
      infrt.return %10 : !infrt.dense_tensor<GPU, FP32, NCHW>
    }) {run_once = true} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !trt.engine
    %6 = "trt.compute"(%5, %0) : (!trt.engine, !phi.context<GPU>) -> !infrt.tensor_list
    %7 = "dt.tensor_list_get_tensor"(%6) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor<GPU, FP32, NCHW>
    %8 = "phi_dt.memcpy.gpu"(%7, %0) {d2h = true} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !phi.context<GPU>) -> !infrt.dense_tensor<CPU, FP32, ANY>
    infrt.return %8 : !infrt.dense_tensor<CPU, FP32, ANY>
  }

  func @main() {
    %map = phi_dt.load_combined_params(){model_path="@CMAKE_BINARY_DIR@/linear/linear.pdmodel",
                                         params_path="@CMAKE_BINARY_DIR@/linear/linear.pdiparams"}
    %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context<CPU>
    %input_tensor = "phi_dt.create_dense_tensor.cpu" (%ctx) {
      precision=#infrt.precision<FP32>,
      layout=#infrt.layout<NCHW>,
      dims=[3:i64, 784:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context<CPU>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
    "phi_dt.fill_dense_tensor.f32"(%input_tensor) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
    %res = infrt.call @main_graph(%map, %input_tensor) {} : (!phi.dense_tensor_map, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
    "phi_dt.print_tensor" (%res) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
    infrt.return
  }
}