From cf8be325b954c68fcad85738a8d164fe636bf95c Mon Sep 17 00:00:00 2001
From: Wilber
Date: Thu, 24 Mar 2022 10:26:53 +0800
Subject: [PATCH] Trt engine (#40744)

* infrt add trt engine
* fix register
* file generate
* fix ci error
* fix conflict
* add copyright
* update
* update
* update
* update engine name
* refactor trt code
* update
* update
* update
* update
* fix conflict
* update
* refactor code
* first commit
* update pdtensor to denseTensor
* code
* style
* code
* code style
* add the tensor map, test=develop
* update
* update
* update
* trt engine
* update trt mlir and runtime
* update mlir test
* update
* update
* update

Co-authored-by: DannyIsFunny <912790387@qq.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
---
 paddle/infrt/backends/tensorrt/trt_engine.cc  |  29 +--
 paddle/infrt/backends/tensorrt/trt_engine.h   |  12 +-
 paddle/infrt/backends/tensorrt/trt_utils.h    |   4 +-
 .../infrt/dialect/phi/ir/infrt_phi_tensor.td  |  13 ++
 paddle/infrt/dialect/tensorrt/CMakeLists.txt  |   1 +
 paddle/infrt/dialect/tensorrt/trt_exec.cc     |  38 ++++
 .../dialect/tensorrt/trt_op_converter_pass.cc |  62 ++++---
 .../dialect/tensorrt/trt_op_teller_pass.cc    |   4 +
 .../dialect/tensorrt/trt_type_convert_pass.cc | 169 ++++++++++++++++++
 .../dialect/tensorrt/trt_type_convert_pass.h  |  25 +++
 .../host_context/mlir_to_runtime_translate.cc |   2 +-
 paddle/infrt/kernel/phi/context_kernels.cc    |   1 +
 .../infrt/kernel/phi/dense_tensor_kernels.cc  |  64 +++++++
 .../infrt/kernel/phi/dense_tensor_kernels.h   |   7 +
 paddle/infrt/kernel/phi/registry.cc           |   3 +
 paddle/infrt/kernel/tensorrt/trt_kernels.cc   |  19 +-
 paddle/infrt/kernel/tensorrt/trt_kernels.h    |   4 +-
 paddle/infrt/tests/CMakeLists.txt             |   1 +
 .../dialect/tensorrt/disabled_linear.mlir.in  |  33 ++++
 19 files changed, 433 insertions(+), 58 deletions(-)
 create mode 100644 paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc
 create mode 100644 paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h
 create mode 100644 paddle/infrt/tests/dialect/tensorrt/disabled_linear.mlir.in

diff --git a/paddle/infrt/backends/tensorrt/trt_engine.cc b/paddle/infrt/backends/tensorrt/trt_engine.cc
index 43d356b6d69..72d98d865a6 100644
--- a/paddle/infrt/backends/tensorrt/trt_engine.cc
+++ b/paddle/infrt/backends/tensorrt/trt_engine.cc
@@ -33,19 +33,21 @@ namespace tensorrt {
 static nvinfer1::IBuilder* createInferBuilder(
     nvinfer1::ILogger& logger) {  // NOLINT
   return static_cast<nvinfer1::IBuilder*>(
-      phi::dynload::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      ::phi::dynload::createInferBuilder_INTERNAL(&logger,
+                                                  NV_TENSORRT_VERSION));
 }
 
 static nvinfer1::IRuntime* createInferRuntime(
     nvinfer1::ILogger& logger) {  // NOLINT
   return static_cast<nvinfer1::IRuntime*>(
-      phi::dynload::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      ::phi::dynload::createInferRuntime_INTERNAL(&logger,
+                                                  NV_TENSORRT_VERSION));
 }
 
 TrtEngine::TrtEngine(int device_id) : device_id_(device_id) {
   FreshDeviceId();
   logger_.reset(new TrtLogger());
   builder_.reset(createInferBuilder(logger_->GetTrtLogger()));
-  phi::dynload::initLibNvInferPlugins(&logger_->GetTrtLogger(), "");
+  ::phi::dynload::initLibNvInferPlugins(&logger_->GetTrtLogger(), "");
 }
 
 nvinfer1::IBuilder* TrtEngine::GetTrtBuilder() {
@@ -237,11 +239,11 @@ bool TrtEngine::SetupNetworkAndConfig(const BuildOptions& build,
 }
 
 void TrtEngine::PrepareOutputHandle(const std::string& out_name) {
-  phi::DenseTensor t;
+  ::phi::DenseTensor t;
   outputs_.emplace(out_name, t);
 }
 
-phi::DenseTensor* TrtEngine::GetOutput(const std::string& name) {
+::phi::DenseTensor* TrtEngine::GetOutput(const std::string& name) {
   return &outputs_[name];
 }
 
@@ -249,7 +251,7 @@ size_t TrtEngine::GetOutputNum() const { return outputs_.size(); }
 
 bool TrtEngine::SetUpInference(
     const InferenceOptions& inference,
-    const std::unordered_map<std::string, phi::DenseTensor*>& inputs) {
+    const std::unordered_map<std::string, ::phi::DenseTensor*>& inputs) {
   // TODO(wilber): now only create one exec_context
   FreshDeviceId();
   CHECK(engine_ != nullptr);
@@ -272,7 +274,7 @@ bool TrtEngine::SetUpInference(
   return true;
 }
 
-void TrtEngine::Run(const phi::GPUContext& ctx) {
+void TrtEngine::Run(const ::phi::GPUContext& ctx) {
   if (is_dynamic_shape_) {
     DynamicRun(ctx);
   } else {
@@ -280,7 +282,7 @@ void TrtEngine::Run(const phi::GPUContext& ctx) {
   }
 }
 
-void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
+void TrtEngine::StaticRun(const ::phi::GPUContext& ctx) {
   const int num_bindings = engine_->getNbBindings();
   std::vector<void*> buffers(num_bindings, nullptr);
 
@@ -291,7 +293,8 @@ void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
     buffers[bind_index] =
         const_cast<void*>(static_cast<const void*>(bind.buffer->data()));
     if (runtime_batch != -1) {
-      CHECK_EQ(runtime_batch, phi::vectorize(bind.buffer->dims())[0]);
+      CHECK_EQ(runtime_batch,
+               ::phi::vectorize(bind.buffer->dims())[0]);
     }
     runtime_batch = bind.buffer->dims()[0];
   }
@@ -306,7 +309,7 @@ void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
     for (int i = 0; i < dims.nbDims; ++i) {
       ddim.push_back(dims.d[i]);
     }
-    bind.buffer->Resize(phi::make_ddim(ddim));
+    bind.buffer->Resize(::phi::make_ddim(ddim));
     // TODO(wilber): now only support float output.
     ctx.Alloc<float>(bind.buffer, sizeof(float) * bind.buffer->numel());
     buffers[bind_index] = static_cast<void*>(bind.buffer->data());
@@ -316,7 +319,7 @@ void TrtEngine::StaticRun(const phi::GPUContext& ctx) {
       runtime_batch, buffers.data(), ctx.stream(), nullptr);
 }
 
-void TrtEngine::DynamicRun(const phi::GPUContext& ctx) {
+void TrtEngine::DynamicRun(const ::phi::GPUContext& ctx) {
   const int num_bindings = engine_->getNbBindings();
   std::vector<void*> buffers(num_bindings, nullptr);
 
@@ -344,7 +347,7 @@ void TrtEngine::DynamicRun(const phi::GPUContext& ctx) {
     for (int i = 0; i < dims.nbDims; ++i) {
       ddim[i] = dims.d[i];
    }
-    bind.buffer->Resize(phi::make_ddim(ddim));
+    bind.buffer->Resize(::phi::make_ddim(ddim));
     ctx.Alloc<float>(bind.buffer, sizeof(float) * bind.buffer->numel());
     buffers[bind_index] = static_cast<void*>(bind.buffer->data());
   }
@@ -356,7 +359,7 @@ void TrtEngine::FreshDeviceId() {
   int count;
   cudaGetDeviceCount(&count);
   CHECK_LT(device_id_, count);
-  phi::backends::gpu::SetDeviceId(device_id_);
+  ::phi::backends::gpu::SetDeviceId(device_id_);
 }
 
 void TrtEngine::GetEngineInfo() {
diff --git a/paddle/infrt/backends/tensorrt/trt_engine.h b/paddle/infrt/backends/tensorrt/trt_engine.h
index a26474f8cbb..41d11a71117 100644
--- a/paddle/infrt/backends/tensorrt/trt_engine.h
+++ b/paddle/infrt/backends/tensorrt/trt_engine.h
@@ -76,19 +76,19 @@ class TrtEngine {
                   const BuildOptions& build_options);
 
   // TODO(wilber): Modify signature after infrt-trt ready.
-  void Run(const phi::GPUContext& ctx);
+  void Run(const ::phi::GPUContext& ctx);
 
   // TODO(wilber): How to support multiple execution contexts?
   bool SetUpInference(
       const InferenceOptions& inference,
-      const std::unordered_map<std::string, phi::DenseTensor*>& inputs);
+      const std::unordered_map<std::string, ::phi::DenseTensor*>& inputs);
 
   void GetEngineInfo();
 
   void PrepareOutputHandle(const std::string& out_name);
 
   // TODO(wilber): The output tensor names are: output_0, output_1, ...
-  phi::DenseTensor* GetOutput(const std::string&);
+  ::phi::DenseTensor* GetOutput(const std::string&);
 
   size_t GetOutputNum() const;
 
@@ -104,9 +104,9 @@ class TrtEngine {
   bool ModelToBuildEnv(TrtUniquePtr<nvinfer1::INetworkDefinition> network,
                        const BuildOptions& build);
 
-  void StaticRun(const phi::GPUContext& ctx);
+  void StaticRun(const ::phi::GPUContext& ctx);
 
-  void DynamicRun(const phi::GPUContext& ctx);
+  void DynamicRun(const ::phi::GPUContext& ctx);
 
  private:
   std::unique_ptr<TrtLogger> logger_{nullptr};
@@ -118,7 +118,7 @@ class TrtEngine {
   std::vector<std::unique_ptr<Bindings>> bindings_;
   int device_id_{0};
   bool is_dynamic_shape_{false};
-  std::unordered_map<std::string, phi::DenseTensor> outputs_;
+  std::unordered_map<std::string, ::phi::DenseTensor> outputs_;
 };
 
 }  // namespace tensorrt
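Note: taken together, the signatures above imply the following lifecycle for a
built engine. The sketch below is illustrative only and is not part of the
patch; it assumes the enclosing namespace is infrt::backends::tensorrt, that
InferenceOptions can be value-initialized, and it uses the input_<i>/output_<i>
naming convention that the TRT kernels later in this patch rely on.

    #include <string>
    #include <unordered_map>

    #include "paddle/infrt/backends/tensorrt/trt_engine.h"

    namespace itrt = infrt::backends::tensorrt;  // assumed namespace

    // Feed one input, run, and fetch the first output handle.
    ::phi::DenseTensor* RunBuiltEngine(itrt::TrtEngine* engine,
                                       const ::phi::GPUContext& ctx,
                                       ::phi::DenseTensor* input) {
      engine->PrepareOutputHandle("output_0");  // outputs are output_<i>
      itrt::InferenceOptions options{};         // assumed value-initializable
      std::unordered_map<std::string, ::phi::DenseTensor*> inputs;
      inputs["input_0"] = input;                // inputs are input_<i>
      engine->SetUpInference(options, inputs);
      engine->Run(ctx);                         // StaticRun or DynamicRun inside
      return engine->GetOutput("output_0");     // device tensor; copy d2h to read
    }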
diff --git a/paddle/infrt/backends/tensorrt/trt_utils.h b/paddle/infrt/backends/tensorrt/trt_utils.h
index c66a850ffb1..c23d4608bb3 100644
--- a/paddle/infrt/backends/tensorrt/trt_utils.h
+++ b/paddle/infrt/backends/tensorrt/trt_utils.h
@@ -92,7 +92,7 @@ class TrtLogger : public nvinfer1::ILogger {
 struct Binding {
   bool is_input{false};
   nvinfer1::DataType data_type{nvinfer1::DataType::kFLOAT};
-  phi::DenseTensor* buffer{nullptr};
+  ::phi::DenseTensor* buffer{nullptr};
   std::string name;
 };
 
@@ -103,7 +103,7 @@ class Bindings {
   void AddBinding(int32_t b,
                   const std::string& name,
                   bool is_input,
-                  phi::DenseTensor* buffer,
+                  ::phi::DenseTensor* buffer,
                   nvinfer1::DataType data_type) {
     while (bindings_.size() <= static_cast<size_t>(b)) {
       bindings_.emplace_back();
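Note: Bindings pairs TensorRT binding slots with phi DenseTensor buffers. A
minimal sketch of registering an input/output pair follows; it is not from the
patch, the helper and its buffers are hypothetical, and the namespace is
assumed to be infrt::backends::tensorrt.

    #include "paddle/infrt/backends/tensorrt/trt_utils.h"

    void BindIo(infrt::backends::tensorrt::Bindings* bindings,
                ::phi::DenseTensor* in, ::phi::DenseTensor* out) {
      // Slot numbers must match the engine's binding indices.
      bindings->AddBinding(0, "input_0", /*is_input=*/true, in,
                           nvinfer1::DataType::kFLOAT);
      bindings->AddBinding(1, "output_0", /*is_input=*/false, out,
                           nvinfer1::DataType::kFLOAT);
    }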
diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
index 3af7033d2f4..9df9abe18cb 100644
--- a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
+++ b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
@@ -97,4 +97,17 @@ def FakeKernelOp : PDT_Op<"fake_phi_kernel"> {
   let results = (outs DenseTensor:$output);
 }
 
+// TODO(wilber): Add a infrt_gpu dialect.
+def PDT_GpuMemCopyOp : PDT_Op<"memcpy.gpu", [NoSideEffect]> {
+  let summary = "phi_dt.gpu.memcpy";
+  let description = [{gpu memcpy d2h or h2d}];
+  // TODO(wilber): add context argument to support stream.
+  let arguments = (ins
+    DenseTensor:$input,
+    Context:$context,
+    BoolAttr:$d2h
+  );
+  let results = (outs DenseTensor:$output);
+}
+
 #endif
diff --git a/paddle/infrt/dialect/tensorrt/CMakeLists.txt b/paddle/infrt/dialect/tensorrt/CMakeLists.txt
index 99c335ed178..5b62b78e4da 100755
--- a/paddle/infrt/dialect/tensorrt/CMakeLists.txt
+++ b/paddle/infrt/dialect/tensorrt/CMakeLists.txt
@@ -6,6 +6,7 @@ gather_srcs(infrt_src SRCS
   trt_op_teller_pass.cc
   trt_graph_fuse_pass.cc
   trt_graph_split_pass.cc
+  trt_type_convert_pass.cc
 )
 mlir_tablegen_on(trt_ops)
 mlir_add_rewriter(pd_lower_to_trt)
diff --git a/paddle/infrt/dialect/tensorrt/trt_exec.cc b/paddle/infrt/dialect/tensorrt/trt_exec.cc
index 7af1fa53d12..be239255ffb 100644
--- a/paddle/infrt/dialect/tensorrt/trt_exec.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_exec.cc
@@ -21,6 +21,26 @@
 #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
 #include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
 #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
+#include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h"
+
+#include "paddle/infrt/host_context/core_runtime.h"
+#include "paddle/infrt/host_context/kernel_registry.h"
+#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
+
+#include "paddle/infrt/kernel/basic_kernels.h"
+#include "paddle/infrt/kernel/control_flow_kernels.h"
+#include "paddle/infrt/kernel/tensor_kernels.h"
+#include "paddle/infrt/kernel/tensor_shape_kernels.h"
+#include "paddle/infrt/kernel/test_kernels.h"
+
+#include "paddle/infrt/kernel/tensorrt/registry.h"
+
+#ifdef INFRT_WITH_PHI
+#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"
+#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
+#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h"
+#include "paddle/infrt/kernel/phi/registry.h"
+#endif
 
 int main(int argc, char** argv) {
   static llvm::cl::opt<std::string> input_file(
@@ -33,6 +53,22 @@ int main(int argc, char** argv) {
   mlir::MLIRContext* context = infrt::Global::getMLIRContext();
   auto module = infrt::dialect::LoadMlirFile(input_file.c_str(), context);
 
+  infrt::host_context::KernelRegistry registry;
+
+  ::infrt::kernel::RegisterBasicKernels(&registry);
+  ::infrt::kernel::RegisterTestKernels(&registry);
+  ::infrt::kernel::RegisterTensorShapeKernels(&registry);
+  ::infrt::kernel::RegisterTensorKernels(&registry);
+  ::infrt::kernel::RegisterControlFlowKernels(&registry);
+#ifdef INFRT_WITH_PHI
+  ::infrt::kernel::RegisterPhiKernels(&registry);
+  ::infrt::kernel::RegisterInferShapeLaunchers(&registry);
+#endif
+#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
+  ::infrt::kernel::RegisterTrtKernels(&registry);
+#endif
+
+  context->loadAllAvailableDialects();
   module->dump();
 
   mlir::PassManager pm(context);
@@ -41,10 +77,12 @@ int main(int argc, char** argv) {
   trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphFusePass>());
   trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphSplitPass>(1));
   trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpConverterPass>());
+  trt_pass_manager.addPass(infrt::trt::createTrtTypeConvertPass());
   if (mlir::failed(pm.run(*module))) {
     std::cout << "\npass failed!\n" << std::endl;
     return 4;
   }
   module->dump();
+  ::infrt::host_context::TestMlir(module.get(), &registry);
   return 0;
 }
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
index 19c6b13e971..1e50b772e08 100644
--- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
@@ -12,10 +12,17 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
 #include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
+
+#include <glog/logging.h>
 #include <mlir/IR/Builders.h>
 #include <mlir/Transforms/DialectConversion.h>
+
+#include "paddle/infrt/dialect/dense_tensor.h"
 #include "paddle/infrt/dialect/pd/ir/pd_ops.h"
+#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
+#include "paddle/infrt/dialect/phi/ir/phi_base.h"
 #include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h"
+#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
 
 namespace infrt {
 namespace trt {
@@ -41,34 +48,34 @@ struct PD2TRT_GraphLower : public ::mlir::RewritePattern {
         ::llvm::SmallVector<::mlir::Type, 4>(1, EngineType::get()),
         trt_inputs,
         true /*run_once*/);
-    ::mlir::Block *block = new ::mlir::Block;
-    block->getOperations().splice(block->begin(),
-                                  casted_op.getBody()->getOperations(),
-                                  casted_op.getBody()->begin(),
-                                  casted_op.getBody()->end());
-    create_engine_op.body().push_back(block);
+    auto &block = create_engine_op.body().emplaceBlock();
+    block.getOperations().splice(block.begin(),
+                                 casted_op.getBody()->getOperations(),
+                                 casted_op.getBody()->begin(),
+                                 casted_op.getBody()->end());
 
-    // trt.execute
-    // outputs
-    ::llvm::SmallVector<::mlir::Type, 4> execute_outputs_types;
-    for (auto v : casted_op.getODSResults(0)) {
-      execute_outputs_types.push_back(v.getType());
-    }
-    // inputs
-    ::mlir::SmallVector<::mlir::Value, 4> execute_inputs(
-        create_engine_op.getODSResults(0));
-    for (auto v : inputs) {
-      execute_inputs.push_back(v);
-    }
-    auto execute_op = rewriter.create<ExecuteOp>(
-        ods_loc, execute_outputs_types, execute_inputs);
-
-    ::llvm::SmallVector<::mlir::Value, 4> replace_values;
-    for (auto v :
-         ::llvm::SmallVector<::mlir::Value, 4>{execute_op.getODSResults(0)}) {
-      replace_values.push_back(v);
+    // trt.compute
+    ::llvm::SmallVector<::mlir::Value, 4> replace_values2;
+    auto ctx_op = rewriter.create<::infrt::phi::CreateGPUContextOp>(
+        ods_loc,
+        infrt::phi::ContextType::get(rewriter.getContext(),
+                                     infrt::TargetType::GPU));
+    auto compute_op = rewriter.create<EngineComputeOp>(
+        ods_loc,
+        ::infrt::DenseTensorListType::get(rewriter.getContext()),
+        create_engine_op.engine(),
+        ctx_op.output());
+    auto tensor_list_val = compute_op.outputs();
+    for (size_t i = 0; i < casted_op.getNumResults(); ++i) {
+      auto res = casted_op->getResult(i);
+      auto int_attr = mlir::IntegerAttr::get(
+          mlir::IntegerType::get(rewriter.getContext(), 32), i);
+      auto get_tensor_op = rewriter.create<::infrt::dt::TensorListGetTensorOp>(
+          ods_loc, res.getType(), tensor_list_val, int_attr);
+      replace_values2.push_back(get_tensor_op.output());
     }
-    rewriter.replaceOp(op, replace_values);
+    ctx_op->moveBefore(ctx_op->getBlock(), ctx_op->getBlock()->begin());
+    rewriter.replaceOp(op, replace_values2);
     return ::mlir::success();
   }
 };
@@ -82,6 +89,9 @@ void TRTOpConverterPass::runOnOperation() {
   // this lowering. In our case, we are lowering to TensorRTDialect from
   // PaddleDialect
   target.addLegalDialect<TensorRTDialect>();
+  target.addLegalDialect<::infrt::phi::PHIDialect>();
+  target.addLegalDialect<::infrt::dt::DTDialect>();
+  target.addLegalDialect<phi::PHIDenseTensorDialect>();
 
   // Now that the conversion target has been defined, we just need to provide
   // the set of patterns that will lower the TensorRT operations.
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
index ef9ccc82678..5918be90cdd 100644
--- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
@@ -14,7 +14,9 @@
 
 #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
 
+#include <glog/logging.h>
 #include <mlir/IR/Builders.h>
+#include "paddle/infrt/dialect/dense_tensor.h"
 #include "paddle/infrt/dialect/infrt/ir/basic_kernels.h"
 #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
 #include "paddle/infrt/dialect/pd/ir/pd_ops.h"
@@ -35,10 +37,12 @@ void TRTOpTellerPass::runOnFunction() {
     auto *op = worklist.back();
     worklist.pop_back();
     if (op == nullptr) continue;
+    if (op->getName().getStringRef().substr(0, 3) != "pd.") continue;
     if (::llvm::dyn_cast_or_null<infrt::pd::FeedOp>(op)) continue;
     if (::llvm::dyn_cast_or_null<infrt::pd::FetchOp>(op)) continue;
     if (::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(op)) continue;
     if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue;
+
     builder.setInsertionPoint(op);
     auto loc = getFunction().getLoc();
     auto graph_op = builder.create<infrt::pd::GraphOp>(
+ +#include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h" + +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "paddle/infrt/dialect/infrt/common/types.h" +#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" +#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" + +namespace { + +class TrtTypeConvertPass + : public mlir::PassWrapper { + public: + ::llvm::StringRef getName() const override { return "TrtTypeConvertPass"; } + + void runOnFunction() override; +}; + +void TrtTypeConvertPass::runOnFunction() { + mlir::Block& body = getFunction().front(); + auto* mlir_ctx = getFunction()->getContext(); + mlir::OpBuilder builder(&body, body.begin()); + + std::vector worklist; + mlir::Operation* ctx_op{nullptr}; + worklist.reserve(body.getOperations().size()); + for (auto& op : body) { + worklist.push_back(&op); + if (op.getName().getStringRef() == "phi_dt.create_context.gpu") { + ctx_op = &op; + } + } + + ::infrt::LayoutType layout = ::infrt::LayoutType::NCHW; + ::infrt::TargetType target = ::infrt::TargetType::GPU; + for (auto& op : worklist) { + if (auto tensor_map_get_op = + llvm::dyn_cast<::infrt::phi::TensorMapGetTensorOp>(op)) { + auto res = tensor_map_get_op.output(); + if (auto t = res.getType().dyn_cast<::infrt::DenseTensorType>()) { + auto replace_type = ::infrt::DenseTensorType::get( + mlir_ctx, t.getTarget(), t.getPrecision(), layout); + res.setType(replace_type); + } + } + if (auto create_engine = llvm::dyn_cast<::infrt::trt::CreateEngineOp>(op)) { + // Insert `infrt.gpu.memcpy` op. + for (auto arg : create_engine.getOperands()) { + if (mlir::Operation* producer = arg.getDefiningOp()) { + if (arg.getType().isa<::infrt::DenseTensorType>()) { + builder.setInsertionPointAfter(producer); + auto t = arg.getType().dyn_cast<::infrt::DenseTensorType>(); + if (producer->getName().getStringRef() != + "phi_dt.tensor_map_get_tensor" && + t.getTarget() != ::infrt::TargetType::GPU) { + auto replace_type = ::infrt::DenseTensorType::get( + mlir_ctx, target, t.getPrecision(), layout); + CHECK_NOTNULL(ctx_op); + auto mem_cpy_op = builder.create<::infrt::phi::GpuMemCopyOp>( + arg.getLoc(), + replace_type, + arg, + llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op) + .output(), + mlir::BoolAttr::get(mlir_ctx, /*d2h*/ false)); + arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op); + } + } + } else { + auto blockArg = arg.cast(); + if (arg.getType().isa<::infrt::DenseTensorType>()) { + auto t = arg.getType().dyn_cast<::infrt::DenseTensorType>(); + builder.setInsertionPointAfter(ctx_op); + auto replace_type = ::infrt::DenseTensorType::get( + mlir_ctx, ::infrt::TargetType::GPU, t.getPrecision(), layout); + CHECK_NOTNULL(ctx_op); + auto mem_cpy_op = builder.create<::infrt::phi::GpuMemCopyOp>( + blockArg.getLoc(), + replace_type, + blockArg, + llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op) + .output(), + mlir::BoolAttr::get(mlir_ctx, /*d2h*/ false)); + arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op); + } + } + } + + // Change ops(in block) types. 
+      auto& block = create_engine.getRegion().getBlocks().front();
+      for (auto& op : block.without_terminator()) {
+        for (size_t i = 0; i < op.getNumResults(); ++i) {
+          if (auto t = op.getResult(i)
+                           .getType()
+                           .dyn_cast<::infrt::DenseTensorType>()) {
+            auto replace_type = ::infrt::DenseTensorType::get(
+                mlir_ctx, ::infrt::TargetType::GPU, t.getPrecision(), layout);
+            op.getResult(i).setType(replace_type);
+          }
+        }
+      }
+    } else if (auto list_get_tensor_op =
+                   llvm::dyn_cast<::infrt::dt::TensorListGetTensorOp>(op)) {
+      auto result = list_get_tensor_op.output();
+      if (auto t = result.getType().dyn_cast<::infrt::DenseTensorType>()) {
+        result.setType(::infrt::DenseTensorType::get(
+            mlir_ctx, ::infrt::TargetType::GPU, t.getPrecision(), layout));
+      }
+    } else if (auto return_op = llvm::dyn_cast<::infrt::ReturnOp>(op)) {
+      for (auto arg : return_op->getOperands()) {
+        if (auto t = arg.getType().dyn_cast<::infrt::DenseTensorType>()) {
+          if (t.getLayout() != ::infrt::LayoutType::ANY ||
+              t.getTarget() != ::infrt::TargetType::CPU ||
+              t.getPrecision() != ::infrt::PrecisionType::FLOAT32) {
+            builder.setInsertionPoint(return_op);
+            CHECK_NOTNULL(ctx_op);
+            auto mem_cpy_op = builder.create<::infrt::phi::GpuMemCopyOp>(
+                return_op.getLoc(),
+                ::infrt::DenseTensorType::get(mlir_ctx,
+                                              ::infrt::TargetType::CPU,
+                                              t.getPrecision(),
+                                              ::infrt::LayoutType::ANY),
+                arg,
+                llvm::dyn_cast<::infrt::phi::CreateGPUContextOp>(ctx_op)
+                    .output(),
+                mlir::BoolAttr::get(mlir_ctx, /*d2h*/ true));
+            arg.replaceAllUsesExcept(mem_cpy_op.output(), mem_cpy_op);
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace
+
+namespace infrt {
+namespace trt {
+
+std::unique_ptr<mlir::Pass> createTrtTypeConvertPass() {
+  return std::make_unique<TrtTypeConvertPass>();
+}
+
+}  // namespace trt
+}  // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h b/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h
new file mode 100644
index 00000000000..fbc30cdbeb7
--- /dev/null
+++ b/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h
@@ -0,0 +1,25 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
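Note: createTrtTypeConvertPass (declared in the header that follows) is the
only public entry point of the new pass. A sketch of plugging it into an
existing pipeline, mirroring the trt_exec.cc changes earlier in this patch;
nesting under FuncOp is an assumption based on it being a function pass.

    #include <mlir/Pass/PassManager.h>

    #include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h"

    void AddTrtTypeConvert(mlir::PassManager& pm) {
      mlir::OpPassManager& fpm = pm.nest<mlir::FuncOp>();
      // ... teller / fuse / split / converter passes go first ...
      fpm.addPass(infrt::trt::createTrtTypeConvertPass());
    }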
+
+#pragma once
+
+#include <mlir/Pass/Pass.h>
+
+namespace infrt {
+namespace trt {
+
+std::unique_ptr<mlir::Pass> createTrtTypeConvertPass();
+
+}  // namespace trt
+}  // namespace infrt
diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.cc b/paddle/infrt/host_context/mlir_to_runtime_translate.cc
index 7e90f225cff..609524bead1 100644
--- a/paddle/infrt/host_context/mlir_to_runtime_translate.cc
+++ b/paddle/infrt/host_context/mlir_to_runtime_translate.cc
@@ -309,7 +309,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
         arg_value = GetOpResult(upstream_op);
       }
     }
-    if (arg_value->is_type<phi::DenseTensor>()) {
+    if (arg_value->is_type<::phi::DenseTensor>()) {
       impl_->runtime->FeedInArgs(
           std::make_pair(std::to_string(i), ValueRef(arg_value)));
     }
diff --git a/paddle/infrt/kernel/phi/context_kernels.cc b/paddle/infrt/kernel/phi/context_kernels.cc
index b27eacf9e52..f38a1107716 100644
--- a/paddle/infrt/kernel/phi/context_kernels.cc
+++ b/paddle/infrt/kernel/phi/context_kernels.cc
@@ -30,6 +30,7 @@ namespace phi {
   ::phi::GPUContext context;
   context.PartialInitWithoutAllocator();
   context.SetAllocator(new ::infrt::backends::GpuPhiAllocator{});
+  context.SetHostAllocator(new backends::CpuPhiAllocator{});
   context.PartialInitWithAllocator();
   return context;
 }
diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc
index c8b1bd8c9eb..66698d36b55 100644
--- a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc
+++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "paddle/infrt/common/string.h"
 #include "paddle/infrt/dialect/phi/data_type.h"
 #include "paddle/infrt/kernel/phi/context_kernels.h"
@@ -228,6 +229,69 @@ int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map) {
   return map.size();
 }
 
+#ifdef INFRT_WITH_GPU
+inline size_t SizeOfDataType(::phi::DataType data_type) {
+  switch (data_type) {
+    case ::phi::DataType::BOOL:
+    case ::phi::DataType::UINT8:
+    case ::phi::DataType::INT8:
+      return 1;
+    case ::phi::DataType::BFLOAT16:
+    case ::phi::DataType::FLOAT16:
+    case ::phi::DataType::INT16:
+    case ::phi::DataType::UINT16:
+      return 2;
+    case ::phi::DataType::FLOAT32:
+    case ::phi::DataType::INT32:
+    case ::phi::DataType::UINT32:
+      return 4;
+    case ::phi::DataType::FLOAT64:
+    case ::phi::DataType::INT64:
+    case ::phi::DataType::UINT64:
+    case ::phi::DataType::COMPLEX64:
+      return 8;
+    case ::phi::DataType::COMPLEX128:
+      return 16;
+    case ::phi::DataType::UNDEFINED:
+      return 0;
+    default:
+      llvm_unreachable("should not reach here");
+      return 0;
+  }
+  return 0;
+}
+
+::phi::DenseTensor GpuMemCpy(const ::phi::DenseTensor& input,
+                             const ::phi::GPUContext& context,
+                             bool d2h) {
+  if (d2h) {
+    ::phi::DenseTensor ret(
+        const_cast<::phi::Allocator*>(&context.GetHostAllocator()),
+        input.meta());
+    CHECK(input.place().GetType() == ::phi::AllocationType::GPU);
+    // TODO(wilber): Add sync op and stream.
+    cudaMemcpyAsync(ret.data(),
+                    input.data(),
+                    SizeOfDataType(input.dtype()) * input.numel(),
+                    cudaMemcpyDeviceToHost,
+                    nullptr);
+    return ret;
+  } else {
+    // h2d
+    ::phi::DenseTensor ret(
+        const_cast<::phi::Allocator*>(&context.GetAllocator()), input.meta());
+    CHECK(input.place().GetType() == ::phi::AllocationType::CPU ||
+          input.place().GetType() == ::phi::AllocationType::GPUPINNED);
+    // TODO(wilber): Add sync op and stream.
+    cudaMemcpyAsync(ret.data(),
+                    input.data(),
+                    SizeOfDataType(input.dtype()) * input.numel(),
+                    cudaMemcpyHostToDevice,
+                    nullptr);
+    return ret;
+  }
+}
+#endif
+
 }  // namespace phi
 }  // namespace kernel
 }  // namespace infrt
diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.h b/paddle/infrt/kernel/phi/dense_tensor_kernels.h
index 6cfcc6f91be..75eab19396f 100644
--- a/paddle/infrt/kernel/phi/dense_tensor_kernels.h
+++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.h
@@ -18,6 +18,7 @@
 #include "paddle/infrt/dialect/infrt/common/types.h"
 #include "paddle/infrt/host_context/kernel_utils.h"
 #include "paddle/infrt/tensor/phi/tensor_map.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/dense_tensor.h"
 
 namespace infrt {
@@ -55,6 +56,12 @@ infrt::phi::DenseTensorMap LoadParams(
 
 int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map);
 
+#ifdef INFRT_WITH_GPU
+::phi::DenseTensor GpuMemCpy(const ::phi::DenseTensor& input,
+                             const ::phi::GPUContext& context,
+                             bool d2h);
+#endif
+
 }  // namespace phi
 }  // namespace kernel
 }  // namespace infrt
diff --git a/paddle/infrt/kernel/phi/registry.cc b/paddle/infrt/kernel/phi/registry.cc
index 08683d7cb66..3b437a439fc 100644
--- a/paddle/infrt/kernel/phi/registry.cc
+++ b/paddle/infrt/kernel/phi/registry.cc
@@ -52,6 +52,9 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
       "phi_dt.create_dense_tensor.gpu",
       INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor),
       {"dims", "lod", "layout", "precision"});
+  registry->AddKernelWithAttrs("phi_dt.memcpy.gpu",
+                               INFRT_KERNEL(infrt::kernel::phi::GpuMemCpy),
+                               {"d2h"});
 #endif
   registry->AddKernelWithAttrs("phi_dt.load_params",
                                INFRT_KERNEL(infrt::kernel::phi::LoadParams),
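Note: a sketch of a host-to-device-to-host round trip through the GpuMemCpy
kernel registered above. This is not from the patch: per the TODOs in
dense_tensor_kernels.cc, the copies are issued on the null stream without
synchronization, so an explicit sync is added before the host buffer is read.
GetHostAllocator is available because context_kernels.cc now calls
SetHostAllocator.

    #include <cuda_runtime.h>

    #include "paddle/infrt/kernel/phi/dense_tensor_kernels.h"

    ::phi::DenseTensor RoundTrip(const ::phi::DenseTensor& host_in,
                                 const ::phi::GPUContext& ctx) {
      // h2d: result is allocated from the context's device allocator.
      auto on_device =
          infrt::kernel::phi::GpuMemCpy(host_in, ctx, /*d2h=*/false);
      // d2h: result is allocated from the context's host allocator.
      auto back_on_host =
          infrt::kernel::phi::GpuMemCpy(on_device, ctx, /*d2h=*/true);
      cudaStreamSynchronize(nullptr);  // both copies ran on the null stream
      return back_on_host;
    }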
diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.cc b/paddle/infrt/kernel/tensorrt/trt_kernels.cc
index aa7609092b8..2f73c6b13f4 100644
--- a/paddle/infrt/kernel/tensorrt/trt_kernels.cc
+++ b/paddle/infrt/kernel/tensorrt/trt_kernels.cc
@@ -14,6 +14,7 @@
 #include "paddle/infrt/kernel/tensorrt/trt_kernels.h"
 #include <string>
+#include <unordered_set>
 #include "NvInfer.h"
 #include "NvInferRuntime.h"
 #include "NvInferRuntimeCommon.h"
@@ -68,7 +69,7 @@ namespace tensorrt {
   auto& region = operation.getRegion(0);
   auto& block = region.getBlocks().front();
 
-  std::unordered_map<std::string, phi::DenseTensor*> trt_bind_inputs;
+  std::unordered_map<std::string, ::phi::DenseTensor*> trt_bind_inputs;
   ValueToITensorMap value_to_trt_tensor_map;
   ValueToTensorMap value_to_tensor_map;
 
@@ -79,7 +80,7 @@ namespace tensorrt {
     const std::string input_name = "input_" + std::to_string(idx);
     auto* v = symbol_table->GetValue(std::to_string(idx));
     CHECK_NOTNULL(v);
-    auto* t = &v->get<phi::DenseTensor>();
+    auto* t = &v->get<::phi::DenseTensor>();
     value_to_tensor_map[operand] = t;
 
     // TODO(wilber): get input info from mlir.
@@ -93,7 +94,7 @@ namespace tensorrt {
     if (operand.isa<mlir::BlockArgument>()) {
       // TODO(wilber): A trick: the weights are CPU tensor and inputs are GPU
       // tensor, so we treat all GPU tensors as inputs to trt.
-      if (t->place().GetType() == phi::AllocationType::GPU) {
+      if (t->place().GetType() == ::phi::AllocationType::GPU) {
         trt_bind_inputs[input_name] = t;
         nvinfer1::Dims dims;
         dims.nbDims = t->dims().size() - 1;
@@ -106,8 +107,10 @@ namespace tensorrt {
       }
     } else {
       // TODO(wilber): Replace with the op name that generates the weights.
-      if (operand.getDefiningOp()->getName().getStringRef() !=
-          "phi_dt.create_dense_tensor.cpu") {
+      std::unordered_set<std::string> weight_flags{
+          "phi_dt.tensor_map_get_tensor", "phi_dt.create_dense_tensor.cpu"};
+      if (!weight_flags.count(
+              operand.getDefiningOp()->getName().getStringRef().str())) {
         trt_bind_inputs[input_name] = t;
         nvinfer1::Dims dims;
         dims.nbDims = t->dims().size() - 1;
@@ -167,10 +170,10 @@ void PrintTrtLayer(backends::tensorrt::TrtEngine* engine) {
   engine->GetEngineInfo();
 }
 
-std::vector<phi::DenseTensor*> TrtEngineCompute(
-    backends::tensorrt::TrtEngine* engine, const phi::GPUContext& context) {
+std::vector<::phi::DenseTensor*> TrtEngineCompute(
+    backends::tensorrt::TrtEngine* engine, const ::phi::GPUContext& context) {
   engine->Run(context);
-  std::vector<phi::DenseTensor*> res;
+  std::vector<::phi::DenseTensor*> res;
   for (size_t i = 0; i < engine->GetOutputNum(); ++i) {
     res.push_back(engine->GetOutput("output_" + std::to_string(i)));
   }
diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.h b/paddle/infrt/kernel/tensorrt/trt_kernels.h
index 546ee9dc788..bf23bd45c13 100644
--- a/paddle/infrt/kernel/tensorrt/trt_kernels.h
+++ b/paddle/infrt/kernel/tensorrt/trt_kernels.h
@@ -41,8 +41,8 @@ struct MlirOperationWithInfrtSymbol {
 
 void PrintTrtLayer(backends::tensorrt::TrtEngine* engine);
 
-std::vector<phi::DenseTensor*> TrtEngineCompute(
-    backends::tensorrt::TrtEngine* engine, const phi::GPUContext& context);
+std::vector<::phi::DenseTensor*> TrtEngineCompute(
+    backends::tensorrt::TrtEngine* engine, const ::phi::GPUContext& context);
 
 }  // namespace tensorrt
 }  // namespace kernel
diff --git a/paddle/infrt/tests/CMakeLists.txt b/paddle/infrt/tests/CMakeLists.txt
index 6f839cdc395..3c4a2f1cbb8 100644
--- a/paddle/infrt/tests/CMakeLists.txt
+++ b/paddle/infrt/tests/CMakeLists.txt
@@ -7,3 +7,4 @@ add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle
 configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir)
 configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir)
+configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensorrt/disabled_linear.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensorrt/disabled_linear.mlir)
diff --git a/paddle/infrt/tests/dialect/tensorrt/disabled_linear.mlir.in b/paddle/infrt/tests/dialect/tensorrt/disabled_linear.mlir.in
new file mode 100644
index 00000000000..74a7de43350
--- /dev/null
+++ b/paddle/infrt/tests/dialect/tensorrt/disabled_linear.mlir.in
@@ -0,0 +1,33 @@
+module {
+  func @main_graph(%map: !phi.dense_tensor_map, %arg0: !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, ANY> {
+    %0 = "phi_dt.create_context.gpu"() : () -> !phi.context<GPU>
+    %1 = "phi_dt.memcpy.gpu"(%arg0, %0) {d2h = false} : (!infrt.dense_tensor<CPU, FP32, NCHW>, !phi.context<GPU>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
+
+    %3 = phi_dt.tensor_map_get_tensor(%map) {name = "linear_0.b_0"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
+    %4 = phi_dt.tensor_map_get_tensor(%map) {name = "linear_0.w_0"} -> !infrt.dense_tensor<CPU, FP32, NCHW>
+    %5 = "trt.create_engine"(%1, %4, %3) ( {
+      %10 = "trt.FullyConnected"(%1, %4, %3) {out_channel_num = 10 : si32} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<GPU, FP32, NCHW>
+      infrt.return %10 : !infrt.dense_tensor<GPU, FP32, NCHW>
+    }) {run_once = true} : (!infrt.dense_tensor<GPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !trt.engine
+    %6 = "trt.compute"(%5, %0) : (!trt.engine, !phi.context<GPU>) -> !infrt.tensor_list
+    %7 = "dt.tensor_list_get_tensor"(%6) {id = 0 : i32} : (!infrt.tensor_list) -> !infrt.dense_tensor<GPU, FP32, NCHW>
= "phi_dt.memcpy.gpu"(%7, %0) {d2h = true} : (!infrt.dense_tensor, !phi.context) -> !infrt.dense_tensor + infrt.return %8 : !infrt.dense_tensor + } + + func @main() { + %map = phi_dt.load_combined_params(){model_path="@CMAKE_BINARY_DIR@/linear/linear.pdmodel", + params_path="@CMAKE_BINARY_DIR@/linear/linear.pdiparams"} + + %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context + %input_tensor = "phi_dt.create_dense_tensor.cpu" (%ctx) { + precision=#infrt.precision, + layout=#infrt.layout, + dims=[3:i64, 784:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) + "phi_dt.fill_dense_tensor.f32"(%input_tensor) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor) -> () + + %res = infrt.call @main_graph(%map, %input_tensor) {} : (!phi.dense_tensor_map, !infrt.dense_tensor) -> !infrt.dense_tensor + "phi_dt.print_tensor" (%res) : (!infrt.dense_tensor) -> () + infrt.return + } +} -- GitLab