diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt index 4e273f6d551edd74ec979e6ec34aedabdb58bd10..f394b754a8eada5550f297647d5fdf36c639896a 100644 --- a/paddle/infrt/CMakeLists.txt +++ b/paddle/infrt/CMakeLists.txt @@ -3,12 +3,22 @@ if (NOT WITH_INFRT) endif() option(INFRT_WITH_PHI "Compile INFRT with PHI" ON) +option(INFRT_WITH_GPU "Compile INFRT with GPU" OFF) +option(INFRT_WITH_TRT "Compile INFRT with TensorRT" OFF) #TODO(xiaowei) remove fluid include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform) if (INFRT_WITH_PHI) - add_definitions("-DINFRT_WITH_PHI") + add_definitions("-DINFRT_WITH_PHI") + + # TODO(wilber): Infrt gpu/trt currently depends on phi's components; modify the compile dependency options later. + if (INFRT_WITH_GPU) + add_definitions("-DINFRT_WITH_GPU") + if (INFRT_WITH_TRT) + add_definitions("-DINFRT_WITH_TRT") + endif() + endif() endif() # compile flags @@ -106,6 +116,9 @@ if (INFRT_WITH_PHI) endif() cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto infrt_naive) +if (INFRT_WITH_TRT) + target_link_libraries(infrt infrt_trt) +endif() cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto) add_dependencies(infrt ${infrt_mlir_incs} mlir-headers) diff --git a/paddle/infrt/backends/host/phi_allocator.h b/paddle/infrt/backends/host/phi_allocator.h index c8f97e04a1b8376efbac749fffa70d77c7b95e72..6e3bef9299162d493825f49e3962c75f2845e2d0 100644 --- a/paddle/infrt/backends/host/phi_allocator.h +++ b/paddle/infrt/backends/host/phi_allocator.h @@ -13,6 +13,10 @@ limitations under the License. */ #include "paddle/phi/core/allocator.h" +#ifdef INFRT_WITH_GPU +#include +#endif + namespace infrt { namespace backends { @@ -29,5 +33,22 @@ class CpuPhiAllocator : public phi::Allocator { } }; +#ifdef INFRT_WITH_GPU +// TODO(wilber): Just for demo test. We need a more efficient gpu allocator. +class GpuPhiAllocator : public phi::Allocator { + public: + static void deleter(phi::Allocation* ptr) { cudaFree(ptr->ptr()); } + + AllocationPtr Allocate(size_t bytes_size) { + void* ptr; + cudaMalloc(&ptr, bytes_size); + return AllocationPtr( + new phi::Allocation( + ptr, bytes_size, phi::Place(phi::AllocationType::GPU)), + deleter); + } +}; +#endif + } // namespace backends } // namespace infrt diff --git a/paddle/infrt/backends/host/phi_context.h b/paddle/infrt/backends/host/phi_context.h index 5713fdbbaf82b2ea2190d2ee1b1dc5d944f2c262..bcd63dbb39fe8c52499138423bc9b86fa5de9d57 100644 --- a/paddle/infrt/backends/host/phi_context.h +++ b/paddle/infrt/backends/host/phi_context.h @@ -13,6 +13,7 @@ limitations under the License.
*/ #include "paddle/infrt/backends/host/phi_allocator.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" namespace infrt { namespace backends { @@ -31,5 +32,16 @@ class CpuPhiContext : public phi::CPUContext { std::unique_ptr alloc_{std::make_unique()}; }; +class GpuPhiContext : public phi::GPUContext { + public: + using Base = phi::GPUContext; + using phi::GPUContext::SetStream; + using phi::GPUContext::SetEigenDevice; + using phi::GPUContext::SetBlasHandle; + using phi::GPUContext::SetDnnHandle; + using phi::GPUContext::SetSolverHandle; + using phi::GPUContext::SetSparseHandle; +}; + } // namespace backends } // namespace infrt diff --git a/paddle/infrt/backends/tensorrt/test_trt_engine.cc b/paddle/infrt/backends/tensorrt/test_trt_engine.cc index 12cf14060e27c1d58e3fd9b14cc12b3c1f7f8907..0ab64dd51c88758d043fb9105ffbf0d109e44cc0 100644 --- a/paddle/infrt/backends/tensorrt/test_trt_engine.cc +++ b/paddle/infrt/backends/tensorrt/test_trt_engine.cc @@ -37,9 +37,9 @@ namespace infrt { namespace backends { namespace tensorrt { -const char* model_input = "model_input"; -const char* model_output = "model_output1"; -const char* model_output2 = "model_output2"; +const char* model_input = "input_0"; +const char* model_output = "output_0"; +const char* model_output2 = "output_1"; TrtUniquePtr ConstructNetwork( nvinfer1::IBuilder* builder, nvinfer1::Dims dims, bool is_static_shape) { @@ -122,27 +122,26 @@ TEST(trt, run_static) { std::unordered_map inputs; inputs.emplace(std::make_pair(model_input, &input)); - phi::DenseTensor output, output2; - std::unordered_map outputs; - outputs.emplace(std::make_pair(model_output, &output)); - outputs.emplace(std::make_pair(model_output2, &output2)); - - static_trt_engine.SetUpInference(inference_options, inputs, &outputs); + static_trt_engine.PrepareOutputHandle("output_0"); + static_trt_engine.PrepareOutputHandle("output_1"); + static_trt_engine.SetUpInference(inference_options, inputs); static_trt_engine.GetEngineInfo(); static_trt_engine.Run(context); + phi::DenseTensor* output0 = static_trt_engine.GetOutput("output_0"); + phi::DenseTensor* output1 = static_trt_engine.GetOutput("output_1"); std::vector output_data1(inference_options.batch * 1 * 28 * 28, 0); std::vector output_data2(inference_options.batch * 2 * 28 * 28, 0); paddle::memory::Copy(phi::CPUPlace(), output_data1.data(), place, - output.data(), + output0->data(), sizeof(float) * output_data1.size(), context.stream()); paddle::memory::Copy(phi::CPUPlace(), output_data2.data(), place, - output2.data(), + output1->data(), sizeof(float) * output_data2.size(), context.stream()); cudaStreamSynchronize(context.stream()); @@ -208,27 +207,27 @@ TEST(trt, run_dynamic) { context.stream()); std::unordered_map inputs; - std::unordered_map outputs; inputs.emplace(std::make_pair(model_input, &input)); - outputs.emplace(std::make_pair(model_output, &output)); - outputs.emplace(std::make_pair(model_output2, &output2)); - - engine.SetUpInference(inference_options, inputs, &outputs); + engine.PrepareOutputHandle("output_0"); + engine.PrepareOutputHandle("output_1"); + engine.SetUpInference(inference_options, inputs); engine.GetEngineInfo(); engine.Run(context); + phi::DenseTensor* output0 = engine.GetOutput("output_0"); + phi::DenseTensor* output1 = engine.GetOutput("output_1"); std::vector output_data1(inference_options.batch * 1 * 16 * 16, 0); std::vector output_data2(inference_options.batch * 2 * 16 * 16, 0); paddle::memory::Copy(phi::CPUPlace(), 
output_data1.data(), place, - output.data(), + output0->data(), sizeof(float) * output_data1.size(), context.stream()); paddle::memory::Copy(phi::CPUPlace(), output_data2.data(), place, - output2.data(), + output1->data(), sizeof(float) * output_data2.size(), context.stream()); cudaStreamSynchronize(context.stream()); diff --git a/paddle/infrt/backends/tensorrt/trt_engine.cc b/paddle/infrt/backends/tensorrt/trt_engine.cc index 232653e8c41f71fd9bb32c9eac302b047d122b66..43d356b6d6983afdca220029d34d9d5cd27da009 100644 --- a/paddle/infrt/backends/tensorrt/trt_engine.cc +++ b/paddle/infrt/backends/tensorrt/trt_engine.cc @@ -21,6 +21,7 @@ #include "paddle/phi/backends/dynload/tensorrt.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" namespace infrt { namespace backends { @@ -235,10 +236,20 @@ bool TrtEngine::SetupNetworkAndConfig(const BuildOptions& build, return true; } +void TrtEngine::PrepareOutputHandle(const std::string& out_name) { + phi::DenseTensor t; + outputs_.emplace(out_name, t); +} + +phi::DenseTensor* TrtEngine::GetOutput(const std::string& name) { + return &outputs_[name]; +} + +size_t TrtEngine::GetOutputNum() const { return outputs_.size(); } + bool TrtEngine::SetUpInference( const InferenceOptions& inference, - const std::unordered_map& inputs, - std::unordered_map* outputs) { + const std::unordered_map& inputs) { // TODO(wilber): now only create one exec_context FreshDeviceId(); CHECK(engine_ != nullptr); @@ -252,10 +263,10 @@ bool TrtEngine::SetUpInference( bindings_.front()->AddBinding( bind_index, it.first, true, it.second, nvinfer1::DataType::kFLOAT); } - for (auto& it : *outputs) { + for (auto& it : outputs_) { const int bind_index = engine_->getBindingIndex(it.first.c_str()); bindings_.front()->AddBinding( - bind_index, it.first, false, it.second, nvinfer1::DataType::kFLOAT); + bind_index, it.first, false, &it.second, nvinfer1::DataType::kFLOAT); } return true; @@ -290,11 +301,13 @@ void TrtEngine::StaticRun(const phi::GPUContext& ctx) { const int bind_index = engine_->getBindingIndex(bind.name.c_str()); std::vector ddim; auto dims = engine_->getBindingDimensions(bind_index); + CHECK_NE(runtime_batch, -1) << "runtime_batch should not be -1."; ddim.push_back(runtime_batch); for (int i = 0; i < dims.nbDims; ++i) { ddim.push_back(dims.d[i]); } bind.buffer->Resize(phi::make_ddim(ddim)); + // TODO(wilber): now only support float output. ctx.Alloc(bind.buffer, sizeof(float) * bind.buffer->numel()); buffers[bind_index] = static_cast(bind.buffer->data()); } diff --git a/paddle/infrt/backends/tensorrt/trt_engine.h b/paddle/infrt/backends/tensorrt/trt_engine.h index 3c8243e3c3838e30eb70877f8c82d623c103eaff..a26474f8cbb357d42cd6d951829bbdc24a256640 100644 --- a/paddle/infrt/backends/tensorrt/trt_engine.h +++ b/paddle/infrt/backends/tensorrt/trt_engine.h @@ -81,11 +81,17 @@ class TrtEngine { // TODO(wilber): How to support multiple execution contexts? bool SetUpInference( const InferenceOptions& inference, - const std::unordered_map& inputs, - std::unordered_map* outputs); + const std::unordered_map& inputs); void GetEngineInfo(); + void PrepareOutputHandle(const std::string& out_name); + + // TODO(wilber): The output tensor names are: output_0, output_1, ... 
+ phi::DenseTensor* GetOutput(const std::string&); + + size_t GetOutputNum() const; + private: void FreshDeviceId(); @@ -112,6 +118,7 @@ class TrtEngine { std::vector> bindings_; int device_id_{0}; bool is_dynamic_shape_{false}; + std::unordered_map outputs_; }; } // namespace tensorrt diff --git a/paddle/infrt/dialect/dense_tensor.td b/paddle/infrt/dialect/dense_tensor.td index 666c7b300af33db0c27e5b3ab8a74aa4b1591c9b..59df4e9697370e9d8db4bbc0a5d69e8ef03950a5 100644 --- a/paddle/infrt/dialect/dense_tensor.td +++ b/paddle/infrt/dialect/dense_tensor.td @@ -130,7 +130,7 @@ def TensorMapGetTensorOp : DT_Op<"tensor_map_get_tensor", [NoSideEffect]> { } def TensorMapGetSizeOp : DT_Op<"tensor_map_get_size", [NoSideEffect]> { - let summary = "ddt.tensor_map_get_size operation"; + let summary = "dt.tensor_map_get_size operation"; let description = [{ An operation that get the size of a TensorMap. @@ -141,6 +141,32 @@ def TensorMapGetSizeOp : DT_Op<"tensor_map_get_size", [NoSideEffect]> { let assemblyFormat = "`(` $map `)` attr-dict `->` type($size)"; } +def Infrt_TensorListGetTensorOp : DT_Op<"tensor_list_get_tensor", [NoSideEffect]> { + let summary = "dt.tensor_list_get_tensor operation"; + + let description = [{ + An operation that can get a tensor from a TensorList. + }]; + + let arguments = (ins + DenseTensorList:$l, + I32Attr:$id + ); + let results = (outs DenseTensor:$output); + let verifier = ?; +} + +def TensorListGetSizeOp : DT_Op<"tensor_list_get_size", [NoSideEffect]> { + let summary = "dt.tensor_list_get_size operation"; + + let description = [{ + An operation that gets the size of a TensorList. + }]; + + let arguments = (ins DenseTensorList:$map); + let results = (outs I32:$size); +} + def GetTensorShapeOp : DT_Op<"get_tensor_shape", [NoSideEffect]> { let summary = "dt.get_tensor_shape operation"; diff --git a/paddle/infrt/dialect/infrt/ir/infrt_base.td b/paddle/infrt/dialect/infrt/ir/infrt_base.td index c5130e89bb13a58a0aa0cf3aeae1b00e269eb259..86cfc375330b19878528645a2e810efb797e153f 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_base.td +++ b/paddle/infrt/dialect/infrt/ir/infrt_base.td @@ -89,6 +89,13 @@ def DenseTensorMap : Infrt_Type<"DenseTensorMap"> { let parameters = (ins); } +// TODO(wilber): Add !infrt.vec type. +def DenseTensorList : Infrt_Type<"DenseTensorList"> { + let summary = "infrt dense tensor list"; + let description = [{dense tensor list}]; + let parameters = (ins); +} + // Type Constrait for concrete DenseTensor type.
class DenseTensor : Type, diff --git a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc index 3a1b45d3a20a1e3ff6698f37e412837fcb064f7c..8966ca13c2be08f1c744a73b4beaf20b0a3c015c 100644 --- a/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc +++ b/paddle/infrt/dialect/infrt/ir/infrt_dialect.cc @@ -138,6 +138,10 @@ mlir::Type InfrtDialect::parseType(::mlir::DialectAsmParser &parser) const { parser.getContext(), *targetType, *precisionType, *layoutType); } + if (keyword == "tensor_list") { + return infrt::DenseTensorListType::get(parser.getContext()); + } + if (keyword == "dense_tensor_map") { return DenseTensorMapType::get(parser.getContext()); } @@ -175,6 +179,9 @@ void InfrtDialect::printType(::mlir::Type type, return; } + if (type.isa()) { + os << "tensor_list"; + } // print DenseTensorType, for example: !infrt.dense_tensor if (type.isa()) { os << "dense_tensor_map"; diff --git a/paddle/infrt/dialect/init_dialects.cc b/paddle/infrt/dialect/init_dialects.cc index 55f6de625237a59f0ab73f7f7203847d4a9754e5..6183295cafb356e85c0fd8bf417c3fb18eb30787 100644 --- a/paddle/infrt/dialect/init_dialects.cc +++ b/paddle/infrt/dialect/init_dialects.cc @@ -26,6 +26,7 @@ #include "paddle/infrt/dialect/phi/ir/phi_kernels.h" #include "paddle/infrt/dialect/tensor_shape.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { void registerCinnDialects(mlir::DialectRegistry ®istry) { // NOLINT @@ -37,7 +38,8 @@ void registerCinnDialects(mlir::DialectRegistry ®istry) { // NOLINT phi::PHIDenseTensorDialect, phi::PHICPUKernelDialect, phi::PHIGPUKernelDialect, - phi::PHIDialect + phi::PHIDialect, + infrt::trt::TensorRTDialect #endif >(); } diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td index 8c3a79498d74d3b80e1590bbc2c0530c7af6411e..1fda2d9d8886008c6415b5a1cf36d53c1500707a 100644 --- a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td +++ b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td @@ -21,8 +21,8 @@ def PHI_DenseTensorDialect : Dialect { class PDT_Op traits = []> : Op {} -class CreateDenseTensorOp - : PDT_Op<"create_dense_tensor", [NoSideEffect]> { +class CreateDenseTensorOp + : PDT_Op<"create_dense_tensor." 
# target, [NoSideEffect]> { let arguments = (ins Context:$context, I64ArrayAttr:$dims, LayoutAttr:$layout, I64ArrayAttr:$lod, PrecisionAttr:$precision); let results = (outs DenseTensor:$output); @@ -51,9 +51,11 @@ class CreateContextOp let results = (outs Context:$output); } -def PDT_CreateDenseTensorOp : CreateDenseTensorOp; +def PDT_CreateCPUDenseTensorOp : CreateDenseTensorOp<"cpu">; +def PDT_CreateGPUDenseTensorOp : CreateDenseTensorOp<"gpu">; def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp; def PDT_CreateCPUContextOp : CreateContextOp<"cpu">; +def PDT_CreateGPUContextOp : CreateContextOp<"gpu">; def PDT_PrintDenseTensor : PrintDenseTensorOp; def FakeKernelOp : PDT_Op<"fake_phi_kernel"> { diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.cc b/paddle/infrt/dialect/tensorrt/trt_ops.cc index d5222976625a2adece9a87c8952dba10137ae9ba..415a78a6967ab6fd4e2a38380d09a5d5c64b1c2f 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.cc +++ b/paddle/infrt/dialect/tensorrt/trt_ops.cc @@ -21,6 +21,10 @@ #include "paddle/infrt/common/global.h" #include "paddle/infrt/dialect/tensorrt/trt_dialect_types.h" +#include "paddle/infrt/dialect/dense_tensor.h" +#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" +#include "paddle/infrt/dialect/phi/ir/phi_base.h" + namespace infrt { namespace trt { diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index 132a1d7805bdb85af8716e384ec29357a6ff68ad..31b28a38e7cfee4eb6da68302d482218d97f8350 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -7,6 +7,8 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/IR/OpBase.td" include "paddle/infrt/dialect/tensorrt/trt_op_base.td" +include "paddle/infrt/dialect/infrt/ir/infrt_base.td" +include "paddle/infrt/dialect/phi/ir/infrt_phi_base.td" def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator<"::infrt::ReturnOp">]> { let summary = "trt CreateEngine Op"; @@ -14,8 +16,8 @@ def TRT_CreateEngineOp : TRT_Op<"create_engine", [SingleBlockImplicitTerminator< Describe a tensorrt subgraph. }]; let regions = (region SizedRegion<1>:$body); - let arguments = (ins Variadic:$inputs, DefaultValuedAttr:$run_once); - let results = (outs TRT_EngineType:$output); + let arguments = (ins Variadic:$inputs, DefaultValuedAttr:$run_once); + let results = (outs TRT_EngineType:$engine); } def TRT_ExecuteOp : TRT_Op<"execute", [NoSideEffect]> { @@ -23,8 +25,25 @@ def TRT_ExecuteOp : TRT_Op<"execute", [NoSideEffect]> { let description = [{ Describe a tensorrt runtime. }]; - let arguments = (ins TRT_EngineType:$engine, Variadic:$inputs); - let results = (outs Variadic:$output); + let arguments = (ins TRT_EngineType:$engine, Variadic:$inputs); + let results = (outs Variadic:$output); +} + +def TRT_EngineComputeOp : TRT_Op<"compute", [NoSideEffect]> { + let summary = "trt compute engine"; + let description = [{ + execute engine + }]; + let arguments = (ins TRT_EngineType:$engine, Context:$context); + let results = (outs DenseTensorList:$outputs); +} + +def TRT_InspectEngineOp : TRT_Op<"inspect_engine", [NoSideEffect]> { + let summary = "trt inspect engine"; + let description = [{ + Show engine + }]; + let arguments = (ins TRT_EngineType:$engine); } def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> { @@ -34,11 +53,11 @@ def TRT_ActivationOp : TRT_Op<"Activation", [NoSideEffect]> { TensorRT IActivationLayer. 
}]; - let arguments = (ins TRT_Tensor:$input, SI32Attr:$activation_type, + let arguments = (ins DenseTensor:$input, SI32Attr:$activation_type, DefaultValuedAttr:$alpha, DefaultValuedAttr:$beta); - let results = (outs TRT_Tensor:$output); + let results = (outs DenseTensor:$output); } def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> { @@ -48,9 +67,9 @@ def TRT_ElementWiseOp : TRT_Op<"ElementWise", [NoSideEffect]> { TensorRT IElementWiseLayer. }]; - let arguments = (ins TRT_Tensor:$input1, TRT_Tensor:$input2, SI32Attr:$elementwise_operation); + let arguments = (ins DenseTensor:$input1, DenseTensor:$input2, SI32Attr:$elementwise_operation); - let results = (outs TRT_Tensor:$output); + let results = (outs DenseTensor:$output); } def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> { @@ -60,10 +79,10 @@ def TRT_MatrixMultiplyOp : TRT_Op<"MatrixMultiply", [NoSideEffect]> { TensorRT IMatrixMultiplyLayer. }]; - let arguments = (ins TRT_Tensor:$input1, BoolAttr:$transpose1, - TRT_Tensor:$input2, BoolAttr:$transpose2); + let arguments = (ins DenseTensor:$input1, BoolAttr:$transpose1, + DenseTensor:$input2, BoolAttr:$transpose2); - let results = (outs TRT_Tensor:$output); + let results = (outs DenseTensor:$output); } #endif // TRT_OPS diff --git a/paddle/infrt/host_context/mlir_exec.cc b/paddle/infrt/host_context/mlir_exec.cc index 319df90d3eec133d3f02be6749e9ad379fd225fd..81bf873ddf0cf3f1a94489bd3b0b2769274b1b4a 100644 --- a/paddle/infrt/host_context/mlir_exec.cc +++ b/paddle/infrt/host_context/mlir_exec.cc @@ -33,7 +33,10 @@ #include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h" #include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h" #include "paddle/infrt/kernel/phi/registry.h" -#endif +#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT) +#include "paddle/infrt/kernel/tensorrt/registry.h" +#endif // INFRT_WITH_GPU && INFRT_WITH_TRT +#endif // INFRT_WITH_PHI static llvm::cl::list cl_shared_libs( // NOLINT "shared_libs", @@ -62,6 +65,9 @@ int main(int argc, char** argv) { #ifdef INFRT_WITH_PHI kernel::RegisterPhiKernels(®istry); kernel::RegisterInferShapeLaunchers(®istry); +#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT) + kernel::RegisterTrtKernels(®istry); +#endif // INFRT_WITH_GPU && INFRT_WITH_TRT #endif // load extra shared library diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.cc b/paddle/infrt/host_context/mlir_to_runtime_translate.cc index c613843cd1779599fbac5aea6042b26b151534e8..3d5cccb5c32694ff05d10811bbff0f068bd6bc51 100644 --- a/paddle/infrt/host_context/mlir_to_runtime_translate.cc +++ b/paddle/infrt/host_context/mlir_to_runtime_translate.cc @@ -16,12 +16,14 @@ #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -42,6 +44,13 @@ #include "paddle/infrt/host_context/value.h" #include "paddle/infrt/tensor/tensor_shape.h" +#ifdef INFRT_WITH_PHI +#ifdef INFRT_WITH_TRT +#include "paddle/infrt/kernel/tensorrt/trt_kernels.h" +#endif +#include "paddle/phi/core/dense_tensor.h" +#endif + namespace infrt { namespace host_context { @@ -277,33 +286,58 @@ bool MlirToRuntimeTranslator::EmitGeneralOp( impl_->runtime->NewOpExecutable(op->getName().getStringRef().str()); VLOG(3) << "processing general op : " << op->getName().getStringRef().str(); + // TODO(wilber): Find a more appropriate way to handle special cases. 
+ if (op->getName().getStringRef() == "trt.create_engine") { +#ifdef INFRT_WITH_TRT + auto* symbols = impl_->runtime->symbol_table(); + ::infrt::kernel::tensorrt::MlirOperationWithInfrtSymbol mlir_operation; + mlir_operation.operation = op; + mlir_operation.symbol_table = symbols; + impl_->cur_op->AppendArgument(new Value(mlir_operation)); + // TODO(wilber): how to pass DenseTensor to create_engine op? temporarily + // add a naive implementation. + for (int i = 0, e = op->getNumOperands(); i < e; ++i) { + auto operand = op->getOperand(i); + if (operand.isa()) { + mlir::BlockArgument arg = operand.dyn_cast(); + Value* arg_value = GetValue(arg); + if (arg_value->is_type()) { + impl_->runtime->FeedInArgs( + std::make_pair(std::to_string(i), ValueRef(arg_value))); + } + } + } +#else + CHECK(false) << "should not reach here"; +#endif + } else { + // process operands + for (int i = 0, e = op->getNumOperands(); i < e; i++) { + // function argument as value + auto operand = op->getOperand(i); + /// if (operand.getKind() == mlir::Value::Kind::BlockArgument) { + if (operand.isa()) { + mlir::BlockArgument arg = operand.dyn_cast(); + Value* arg_value = GetValue(arg); + impl_->cur_op->AppendArgument(arg_value); + VLOG(3) << "* op mlir operand: " << DumpToString(arg) << " " + << GetValue(arg); + continue; + } - // process operands - for (int i = 0, e = op->getNumOperands(); i < e; i++) { - // function argument as value - auto operand = op->getOperand(i); - /// if (operand.getKind() == mlir::Value::Kind::BlockArgument) { - if (operand.isa()) { - mlir::BlockArgument arg = operand.dyn_cast(); - Value* arg_value = GetValue(arg); + // normal value + Value* arg_value = GetValue(operand); + if (!arg_value) { + auto upstream_op = operand.getDefiningOp(); + arg_value = GetOpResult(upstream_op); + } + CHECK(arg_value) << "No-exist argument value found: " + << DumpToString(operand); impl_->cur_op->AppendArgument(arg_value); - VLOG(3) << "* op mlir operand: " << DumpToString(arg) << " " - << GetValue(arg); - continue; - } - // normal value - Value* arg_value = GetValue(operand); - if (!arg_value) { - auto upstream_op = operand.getDefiningOp(); - arg_value = GetOpResult(upstream_op); + VLOG(3) << "* op mlir operand: " << DumpToString(operand) << " " + << GetValue(operand) << " vs " << arg_value; } - CHECK(arg_value) << "No-exist argument value found: " - << DumpToString(operand); - impl_->cur_op->AppendArgument(arg_value); - - VLOG(3) << "* op mlir operand: " << DumpToString(operand) << " " - << GetValue(operand) << " vs " << arg_value; } // process attributes @@ -383,33 +417,6 @@ bool MlirToRuntimeTranslator::EmitGeneralOp( impl_->cur_op->AppendAttribute(tmp[i]); } - // process results - llvm::SmallVector res_values; - for (int i = 0, e = op->getNumResults(); i < e; i++) { - auto res = op->getResult(i); - if (res.getType().isa<::infrt::DenseTensorType>()) { - auto r = impl_->value_map.try_emplace( - res, ValueRef(new Value{::phi::DenseTensor()})); - CHECK(r.second) << "Duplicate add mlir value [" << DumpToString(res) - << "]"; - res_values.push_back(r.first->second.get()); - } else { - res_values.push_back(AddValue(res)); - } - - VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res); - } - impl_->cur_op->SetResults(res_values); - -#ifdef INFRT_DEBUG - { - VLOG(3) << "check result"; - for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) { - VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i]; - } - } -#endif - // process regions, we treat regions as attribute.
auto num_regions = op->getNumRegions(); if (num_regions > 0) { @@ -438,6 +445,33 @@ bool MlirToRuntimeTranslator::EmitGeneralOp( impl_->cur_op->AppendAttribute(new Value(function)); } + // process results + llvm::SmallVector res_values; + for (int i = 0, e = op->getNumResults(); i < e; i++) { + auto res = op->getResult(i); + if (res.getType().isa<::infrt::DenseTensorType>()) { + auto r = impl_->value_map.try_emplace( + res, ValueRef(new Value{::phi::DenseTensor()})); + CHECK(r.second) << "Duplicate add mlir value [" << DumpToString(res) + << "]"; + res_values.push_back(r.first->second.get()); + } else { + res_values.push_back(AddValue(res)); + } + + VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res); + } + impl_->cur_op->SetResults(res_values); + +#ifdef INFRT_DEBUG + { + VLOG(3) << "check result"; + for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) { + VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i]; + } + } +#endif + return true; } diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h index 957d852442b10620244e230a2f7704eb7fa0a33e..1f0b1dabd94d8dcf28e8e0543a8e3b12ed250704 100644 --- a/paddle/infrt/host_context/value.h +++ b/paddle/infrt/host_context/value.h @@ -24,6 +24,7 @@ #include "paddle/infrt/common/shared.h" #include "paddle/infrt/dialect/infrt/common/types.h" #include "paddle/infrt/host_context/function.h" +#include "paddle/infrt/host_context/symbol_table.h" #include "paddle/infrt/support/variant.h" #include "paddle/infrt/tensor/dense_host_tensor.h" #include "paddle/infrt/tensor/dense_tensor_view.h" @@ -41,7 +42,15 @@ #include "paddle/phi/common/scalar_array.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/meta_tensor.h" -#endif + +#ifdef INFRT_WITH_GPU +#include "paddle/phi/backends/gpu/gpu_context.h" +#endif // INFRT_WITH_GPU +#ifdef INFRT_WITH_TRT +#include "paddle/infrt/backends/tensorrt/trt_engine.h" +#include "paddle/infrt/kernel/tensorrt/trt_kernels.h" +#endif // INFRT_WITH_TRT +#endif // INFRT_WITH_PHI namespace infrt { namespace host_context { @@ -72,8 +81,13 @@ using ValueVariantType = ::phi::MetaTensor, ::phi::DenseTensor, backends::CpuPhiContext, +#ifdef INFRT_WITH_GPU + backends::GpuPhiContext, + ::phi::GPUContext, +#endif ::phi::CPUContext, std::vector, + std::vector, paddle::experimental::ScalarBase, paddle::experimental::ScalarArrayBase, std::vector, @@ -81,6 +95,10 @@ using ValueVariantType = paddle::experimental::Backend, paddle::experimental::DataLayout, paddle::experimental::DataType, +#ifdef INFRT_WITH_TRT + ::infrt::backends::tensorrt::TrtEngine, + ::infrt::kernel::tensorrt::MlirOperationWithInfrtSymbol, +#endif // INFRT_WITH_TRT #endif std::vector, std::vector, @@ -120,8 +138,18 @@ class Value : public common::Object { #ifdef INFRT_WITH_PHI explicit Value(::phi::CPUContext&& x) : data(std::move(x)) {} explicit Value(backends::CpuPhiContext&& x) : data(std::move(x)) {} +#ifdef INFRT_WITH_GPU + explicit Value(::phi::GPUContext&& x) : data(std::move(x)) {} + explicit Value(backends::GpuPhiContext&& x) : data(std::move(x)) {} +#endif explicit Value(::phi::DenseTensor&& x) : data(std::move(x)) {} explicit Value(::phi::MetaTensor&& x) : data(std::move(x)) {} +#ifdef INFRT_WITH_TRT + explicit Value(::infrt::backends::tensorrt::TrtEngine&& x) + : data(std::move(x)) {} + explicit Value(::infrt::kernel::tensorrt::MlirOperationWithInfrtSymbol x) + : data(x) {} +#endif // INFRT_WITH_TRT #endif template diff --git a/paddle/infrt/kernel/CMakeLists.txt 
b/paddle/infrt/kernel/CMakeLists.txt index f1cbfba1c46b33e461a7c9f08cf646625fbafb24..f20344f6f6b84ae8e63f44c7b7b83c6ba9d8d6da 100644 --- a/paddle/infrt/kernel/CMakeLists.txt +++ b/paddle/infrt/kernel/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(phi) +add_subdirectory(tensorrt) core_gather_headers() diff --git a/paddle/infrt/kernel/phi/context_kernels.cc b/paddle/infrt/kernel/phi/context_kernels.cc index 39ef172fadef9e0f6317dec192c251c6a1df6828..b27eacf9e522d2bbb8b7ffd70ad57f54e5775499 100644 --- a/paddle/infrt/kernel/phi/context_kernels.cc +++ b/paddle/infrt/kernel/phi/context_kernels.cc @@ -25,6 +25,16 @@ namespace phi { return ctx; } +#ifdef INFRT_WITH_GPU +::phi::GPUContext CreateGPUContext() { + ::phi::GPUContext context; + context.PartialInitWithoutAllocator(); + context.SetAllocator(new ::infrt::backends::GpuPhiAllocator{}); + context.PartialInitWithAllocator(); + return context; +} +#endif + } // namespace phi } // namespace kernel } // namespace infrt diff --git a/paddle/infrt/kernel/phi/context_kernels.h b/paddle/infrt/kernel/phi/context_kernels.h index 3e9580b91da5724b42c72224847e45715f47dbb7..ae3f76c8fe536f96689680668cc52e4981894063 100644 --- a/paddle/infrt/kernel/phi/context_kernels.h +++ b/paddle/infrt/kernel/phi/context_kernels.h @@ -25,6 +25,10 @@ namespace phi { ::phi::CPUContext CreateCPUContext(); +#ifdef INFRT_WITH_GPU +::phi::GPUContext CreateGPUContext(); +#endif + } // namespace phi } // namespace kernel } // namespace infrt diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc index 777fb29ac60d9c7125898752747bbdf553f370c0..6d16b814c6b02b08e279190d5a685d65c124942d 100644 --- a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc +++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc @@ -15,6 +15,12 @@ #include "paddle/infrt/kernel/phi/dense_tensor_kernels.h" #include "paddle/infrt/dialect/phi/data_type.h" #include "paddle/infrt/kernel/phi/context_kernels.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/common/place.h" + +#ifdef INFRT_WITH_GPU +#include +#endif namespace infrt { namespace kernel { @@ -34,26 +40,83 @@ namespace phi { {})); } +::phi::DenseTensor CreateGPUDenseTensor( + const ::phi::GPUContext& context, + host_context::Attribute> dims, + host_context::Attribute> lod, + host_context::Attribute<::infrt::LayoutType> layout, + host_context::Attribute<::infrt::PrecisionType> precision) { + return ::phi::DenseTensor( + const_cast<::phi::Allocator*>(&context.GetAllocator()), + ::phi::DenseTensorMeta(ConvertPrecisionToPhi(precision.get()), + ::phi::make_ddim(dims.get()), + ConvertLayoutToPhi(layout.get()), + {})); +} + void FillDenseTensorF32(::phi::DenseTensor* dense_tensor, host_context::Attribute> value) { - auto place = ::phi::CPUPlace(); + auto place = dense_tensor->place(); float* a_data = dense_tensor->mutable_data(place); - for (int64_t i = 0; i < dense_tensor->numel(); ++i) { - a_data[i] = (value.get())[i]; + if (place.GetType() == ::phi::AllocationType::CPU) { + for (int64_t i = 0; i < dense_tensor->numel(); ++i) { + a_data[i] = (value.get())[i]; + } + } else if (place.GetType() == ::phi::AllocationType::GPU) { +#ifdef INFRT_WITH_GPU + // TODO(wilber): how to set the stream parameter to copy with stream. 
+ cudaMemcpy(a_data, + value.get().data(), + sizeof(float) * value.get().size(), + cudaMemcpyHostToDevice); +#endif + } else { + llvm_unreachable("temporarily do not support other targets."); } } void PrintDenseTensor(::phi::DenseTensor* dense_tensor) { -#define PRINT_META_DATA(PHI_DATATYPE, DTYPE) \ - case ::phi::DataType::PHI_DATATYPE: { \ - DTYPE* data = dense_tensor->data(); \ - if (dense_tensor->numel() == 0) break; \ - std::cout << data[0]; \ - for (int64_t i = 1; i < dense_tensor->numel(); i++) { \ - std::cout << "," << data[i]; \ - } \ - break; \ +#ifndef INFRT_WITH_GPU +#define PRINT_META_DATA(PHI_DATATYPE, DTYPE) \ + case ::phi::DataType::PHI_DATATYPE: { \ + auto place = dense_tensor->place(); \ + if (place.GetType() == ::phi::AllocationType::CPU) { \ + DTYPE* data = dense_tensor->data(); \ + if (dense_tensor->numel() == 0) break; \ + std::cout << data[0]; \ + for (int64_t i = 1; i < dense_tensor->numel(); i++) { \ + std::cout << "," << data[i]; \ + } \ + } \ + break; \ + } +#else +#define PRINT_META_DATA(PHI_DATATYPE, DTYPE) \ + case ::phi::DataType::PHI_DATATYPE: { \ + auto place = dense_tensor->place(); \ + DTYPE* data = dense_tensor->data(); \ + if (dense_tensor->numel() == 0) break; \ + if (place.GetType() == ::phi::AllocationType::CPU) { \ + std::cout << data[0]; \ + for (int64_t i = 1; i < dense_tensor->numel(); i++) { \ + std::cout << "," << data[i]; \ + } \ + } else if (place.GetType() == ::phi::AllocationType::GPU) { \ + std::vector host_data(dense_tensor->numel(), 0); \ + cudaMemcpy(host_data.data(), \ + data, \ + sizeof(DTYPE) * dense_tensor->numel(), \ + cudaMemcpyDeviceToHost); \ + std::cout << host_data[0]; \ + for (int64_t i = 1; i < dense_tensor->numel(); i++) { \ + std::cout << "," << host_data[i]; \ + } \ + } else { \ + llvm_unreachable("temporarily do not support other targets."); \ + } \ + break; \ } +#endif ::phi::DDim dims = dense_tensor->dims(); std::cout << "dense_tensor: shape=shape" << dims.to_str() << "," diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.h b/paddle/infrt/kernel/phi/dense_tensor_kernels.h index 8cc0e39e0e4431f073ac37a7f0557f2c837dc753..47d89506e2aa615b0bc425a4c373c904d937e03f 100644 --- a/paddle/infrt/kernel/phi/dense_tensor_kernels.h +++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.h @@ -30,6 +30,13 @@ namespace phi { host_context::Attribute<::infrt::LayoutType> layout, host_context::Attribute<::infrt::PrecisionType> precision); +::phi::DenseTensor CreateGPUDenseTensor( + const ::phi::GPUContext& context, + host_context::Attribute> dims, + host_context::Attribute> lod, + host_context::Attribute<::infrt::LayoutType> layout, + host_context::Attribute<::infrt::PrecisionType> precision); + void FillDenseTensorF32(::phi::DenseTensor* dense_tensor, host_context::Attribute> values); void PrintDenseTensor(::phi::DenseTensor* dense_tensor); diff --git a/paddle/infrt/kernel/phi/registry.cc b/paddle/infrt/kernel/phi/registry.cc index 0e071418603f8390ca3283f617b06cf1fa91b94c..36d40118f16a0bd1779765064caaac6dbe414772 100644 --- a/paddle/infrt/kernel/phi/registry.cc +++ b/paddle/infrt/kernel/phi/registry.cc @@ -35,7 +35,7 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) { registry->AddKernel("phi_dt.create_context.cpu", INFRT_KERNEL(infrt::kernel::phi::CreateCPUContext)); registry->AddKernelWithAttrs( - "phi_dt.create_dense_tensor", + "phi_dt.create_dense_tensor.cpu", INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor), {"dims", "lod", "layout", "precision"}); registry->AddKernelWithAttrs( @@ -44,6 +44,15 @@ void
RegisterPhiKernels(host_context::KernelRegistry* registry) { {"value"}); registry->AddKernel("phi_dt.print_tensor", INFRT_KERNEL(infrt::kernel::phi::PrintDenseTensor)); + +#ifdef INFRT_WITH_GPU + registry->AddKernel("phi_dt.create_context.gpu", + INFRT_KERNEL(infrt::kernel::phi::CreateGPUContext)); + registry->AddKernelWithAttrs( + "phi_dt.create_dense_tensor.gpu", + INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor), + {"dims", "lod", "layout", "precision"}); +#endif } } // namespace kernel diff --git a/paddle/infrt/kernel/tensor_kernels.cc b/paddle/infrt/kernel/tensor_kernels.cc index b7503aa4ef35894dda514fdb7fa4336485323094..79502f9fdfd4bd88666f61ff30bc526325b91341 100644 --- a/paddle/infrt/kernel/tensor_kernels.cc +++ b/paddle/infrt/kernel/tensor_kernels.cc @@ -25,6 +25,10 @@ #include "paddle/infrt/tensor/tensor_map.h" #include "paddle/infrt/tensor/tensor_shape.h" +#ifdef INFRT_WITH_PHI +#include "paddle/phi/core/dense_tensor.h" +#endif + namespace infrt { namespace kernel { using namespace host_context; // NOLINT @@ -62,6 +66,20 @@ DenseHostTensor TensorMapGetTensor(TensorMap map, Attribute name) { int32_t TensorMapGetSize(TensorMap map) { return map.size(); } +// TODO(wilber): Maybe we should place TensorList type in dt dialect. +#ifdef INFRT_WITH_PHI +phi::DenseTensor TensorListGetTensor(std::vector list, + Attribute idx) { + CHECK_LT(idx.get(), static_cast(list.size())) + << "idx should be less than list size"; + return *list[idx.get()]; +} + +int32_t TensorListGetSize(const std::vector &list) { + return list.size(); +} +#endif + DenseHostTensor ShallowCopyTensor(DenseHostTensor v) { return v; } template @@ -126,6 +144,14 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) { INFRT_KERNEL(TensorMapGetTensor)); registry->AddKernel("dt.tensor_map_get_size", INFRT_KERNEL(TensorMapGetSize)); +// TensorList related methods. +#ifdef INFRT_WITH_PHI + registry->AddKernel("dt.tensor_list_get_tensor", + INFRT_KERNEL(TensorListGetTensor)); + registry->AddKernel("dt.tensor_list_get_size", + INFRT_KERNEL(TensorListGetSize)); +#endif + registry->AddKernel("dt.shallow_copy_tensor", INFRT_KERNEL(ShallowCopyTensor)); diff --git a/paddle/infrt/kernel/tensorrt/CMakeLists.txt b/paddle/infrt/kernel/tensorrt/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..cd35fccbe2aa35453a4d4ac13364ef6bb5a6b6aa --- /dev/null +++ b/paddle/infrt/kernel/tensorrt/CMakeLists.txt @@ -0,0 +1,10 @@ +if (NOT (INFRT_WITH_PHI AND INFRT_WITH_GPU AND INFRT_WITH_TRT)) + return() +endif() + +core_gather_headers() + +gather_srcs(infrt_src SRCS + registry.cc + trt_kernels.cc +) diff --git a/paddle/infrt/kernel/tensorrt/registry.cc b/paddle/infrt/kernel/tensorrt/registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..a37e3c0f7f2785e23c8a0b9a25d3283396215f70 --- /dev/null +++ b/paddle/infrt/kernel/tensorrt/registry.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "paddle/infrt/kernel/tensorrt/registry.h" + +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/kernel_utils.h" +#include "paddle/infrt/kernel/tensorrt/trt_kernels.h" + +namespace infrt { +namespace kernel { + +void RegisterTrtKernels(host_context::KernelRegistry* registry) { + registry->AddKernel("trt.create_engine", + INFRT_KERNEL(tensorrt::CreateTrtEngine)); + registry->AddKernel("trt.inspect_engine", + INFRT_KERNEL(tensorrt::PrintTrtLayer)); + registry->AddKernel("trt.compute", INFRT_KERNEL(tensorrt::TrtEngineCompute)); +} + +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensorrt/registry.h b/paddle/infrt/kernel/tensorrt/registry.h new file mode 100644 index 0000000000000000000000000000000000000000..762329ca61d02a16edc150854afcc3dd431a941d --- /dev/null +++ b/paddle/infrt/kernel/tensorrt/registry.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace infrt { +namespace host_context { + +struct KernelRegistry; + +} // namespace host_context +} // namespace infrt + +namespace infrt { +namespace kernel { + +/** + * Register all the trt kernels to registry. + */ +void RegisterTrtKernels(host_context::KernelRegistry* registry); + +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.cc b/paddle/infrt/kernel/tensorrt/trt_kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..04847ac8982f861ab2799bd23b1c2ab723422327 --- /dev/null +++ b/paddle/infrt/kernel/tensorrt/trt_kernels.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/kernel/tensorrt/trt_kernels.h" +#include +#include "NvInfer.h" +#include "NvInferRuntime.h" +#include "NvInferRuntimeCommon.h" +#include "glog/logging.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "paddle/infrt/backends/tensorrt/trt_engine.h" +#include "paddle/infrt/backends/tensorrt/trt_options.h" +#include "paddle/infrt/dialect/tensorrt/trt_ops.h" +#include "paddle/infrt/host_context/symbol_table.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace infrt { +namespace kernel { +namespace tensorrt { + +::infrt::backends::tensorrt::TrtEngine CreateTrtEngine( + MlirOperationWithInfrtSymbol + create_engine_op /*, input_tensors, output_tensors, weights*/) { + // TODO(wilber): The device_id needs to get from mlir. + int device_id = 0; + backends::tensorrt::TrtEngine engine(device_id); + + auto* builder = engine.GetTrtBuilder(); + // TODO(wilber): How to process weights? + backends::tensorrt::TrtUniquePtr network; + // TODO(wilber): static_shape or dynamic_shape network? The code is just + // static_shape test. + network.reset(builder->createNetworkV2(0)); + + // TODO(wilber): The build option shoule be fiiled from mlir info. + backends::tensorrt::BuildOptions options; + options.max_batch = 4; + + // Parse mlir Region which only has one block. + mlir::Operation& operation = *create_engine_op.operation; + auto* symbol_table = create_engine_op.symbol_table; + CHECK_NOTNULL(symbol_table); + + unsigned int num_regions = operation.getNumRegions(); + CHECK_EQ(num_regions, 1U) << "only support one region case."; + auto& region = operation.getRegion(0); + auto& block = region.getBlocks().front(); + + llvm::DenseMap map_info; + std::unordered_map trt_bind_inputs; + + for (auto index_operand : llvm::enumerate(operation.getOperands())) { + mlir::Value operand = index_operand.value(); + size_t idx = index_operand.index(); + + const std::string input_name = "input_" + std::to_string(idx); + auto* v = symbol_table->GetValue(std::to_string(idx)); + CHECK_NOTNULL(v); + auto* t = &v->get(); + trt_bind_inputs[input_name] = t; + // TODO(wilber): get input info from mlir. + // TODO(wilber): input dims, now only support static_shape, and just remove + // the first dimension. + // TODO(wilber): now only suppot float input. + nvinfer1::Dims dims; + dims.nbDims = t->dims().size() - 1; + for (int i = 0; i < dims.nbDims; ++i) { + dims.d[i] = t->dims()[i + 1]; + } + auto* in = + network->addInput(input_name.c_str(), nvinfer1::DataType::kFLOAT, dims); + map_info[operand] = in; + } + + // TODO(wilber): Find a way to add layer. 
+ for (auto& inner_op : block.without_terminator()) { + if (inner_op.getName().getStringRef() == "trt.Activation") { + trt::ActivationOp act_op = llvm::dyn_cast(inner_op); + auto in_arg = act_op.getOperand(); + if (!map_info.count(in_arg)) { + CHECK(false) << "map_info not has in_arg."; + } + nvinfer1::ActivationType act_type = + static_cast(act_op.activation_type()); + auto* act_layer = network->addActivation(*map_info[in_arg], act_type); + act_layer->setAlpha(act_op.alpha().convertToFloat()); + act_layer->setBeta(act_op.beta().convertToFloat()); + for (size_t i = 0; i < act_op->getNumResults(); ++i) { + nvinfer1::ITensor* act_out_tensor = act_layer->getOutput(i); + mlir::Value act_out = act_op->getResult(i); + map_info[act_out] = act_out_tensor; + } + } + + // if (inner_op.getName().getStringRef() == "trt.Constant") { + // trt::ConstantOp op = llvm::dyn_cast(inner_op); + // mlir::Value op_out = op.getResult(); + // std::vector weight_data{1}; + // auto* layer = network->addConstant(nvinfer1::Dims2(1, 1), + // nvinfer1::Weights{nvinfer1::DataType::kFLOAT, weight_data.data(), 1}); + // auto* op_out_tenor = layer->getOutput(0); + // map_info[op_out] = op_out_tenor; + // } + } + for (auto& inner_op : block.without_terminator()) { + for (mlir::Value v : inner_op.getResults()) { + for (mlir::Operation* user : v.getUsers()) { + if (user->getName().getStringRef() == "infrt.return") { + if (!map_info.count(v)) { + CHECK(false) << "map_info not has value"; + } + network->markOutput(*map_info[v]); + } + } + } + } + // std::unordered_map trt_bind_outputs; + mlir::Operation* ret = block.getTerminator(); + for (unsigned int i = 0; i < ret->getNumOperands(); ++i) { + mlir::Value arg = ret->getOperand(i); + CHECK(map_info.count(arg)); + map_info[arg]->setName(("output_" + std::to_string(i)).c_str()); + } + for (int i = 0; i < network->getNbOutputs(); ++i) { + engine.PrepareOutputHandle(network->getOutput(i)->getName()); + } + + VLOG(3) << "trt engine build start."; + engine.Build(std::move(network), options); + VLOG(3) << "trt engine build done."; + + // TODO(wilber): get inference options from mlir. + backends::tensorrt::InferenceOptions inference_options; + inference_options.batch = 1; + // TODO(wilber): bind trt input/output tensors. + engine.SetUpInference(inference_options, trt_bind_inputs); + return engine; +} + +void PrintTrtLayer(backends::tensorrt::TrtEngine* engine) { + engine->GetEngineInfo(); +} + +std::vector TrtEngineCompute( + backends::tensorrt::TrtEngine* engine, const phi::GPUContext& context) { + engine->Run(context); + std::vector res; + for (size_t i = 0; i < engine->GetOutputNum(); ++i) { + res.push_back(engine->GetOutput("output_" + std::to_string(i))); + } + return res; +} + +} // namespace tensorrt +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.h b/paddle/infrt/kernel/tensorrt/trt_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..546ee9dc78852e6967bf8b61ae81563d32beae66 --- /dev/null +++ b/paddle/infrt/kernel/tensorrt/trt_kernels.h @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "mlir/IR/Operation.h" + +#include "paddle/infrt/backends/tensorrt/trt_engine.h" +#include "paddle/phi/backends/gpu/gpu_context.h" + +namespace infrt { +namespace host_context { +class SymbolTable; +} // namespace host_context + +namespace kernel { +namespace tensorrt { + +struct MlirOperationWithInfrtSymbol { + mlir::Operation* operation; + ::infrt::host_context::SymbolTable* symbol_table; +}; + +::infrt::backends::tensorrt::TrtEngine CreateTrtEngine( + MlirOperationWithInfrtSymbol engine_op); + +void PrintTrtLayer(backends::tensorrt::TrtEngine* engine); + +std::vector TrtEngineCompute( + backends::tensorrt::TrtEngine* engine, const phi::GPUContext& context); + +} // namespace tensorrt +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/tests/dialect/disabled_trt.mlir b/paddle/infrt/tests/dialect/disabled_trt.mlir new file mode 100644 index 0000000000000000000000000000000000000000..ef86dcf1e72a04c478a7763000cf366715665d81 --- /dev/null +++ b/paddle/infrt/tests/dialect/disabled_trt.mlir @@ -0,0 +1,37 @@ +// RUN: infrtexec -i %s | FileCheck %s + +// CHECK-LABEL: @run_trt +func @run_trt(%0 : !infrt.dense_tensor, %ctx : !phi.context) { + %a = "trt.create_engine"(%0) ({ + %1 = "trt.Activation"(%0) {activation_type = 1 : si32, alpha = 1.0 : f32, beta = 6.0 : f32} : (!infrt.dense_tensor) -> !infrt.dense_tensor + "infrt.return"(%1) : (!infrt.dense_tensor) -> () + }) : (!infrt.dense_tensor) -> !trt.engine + "trt.inspect_engine"(%a) {} : (!trt.engine) -> () + + %res = "trt.compute"(%a, %ctx) {} : (!trt.engine, !phi.context) -> (!infrt.tensor_list) + %size = "dt.tensor_list_get_size"(%res) {} : (!infrt.tensor_list) -> (i32) + "infrt.print.i32"(%size) {} : (i32) -> () + + %ts0 = "dt.tensor_list_get_tensor"(%res) {id = 0 : i32} : (!infrt.tensor_list) -> (!infrt.dense_tensor) + "phi_dt.print_tensor" (%ts0) : (!infrt.dense_tensor) -> () + + infrt.return +} + +// CHECK-LABEL: @main +func @main() { + %ctx = "phi_dt.create_context.gpu" (): () -> !phi.context + %t = "phi_dt.create_dense_tensor.gpu" (%ctx) { + precision=#infrt.precision, + layout=#infrt.layout, + dims=[1:i64, 3:i64, 1:i64, 1:i64], lod=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) + + "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32, 2.4:f32, 1.3:f32]} : (!infrt.dense_tensor) -> () + "phi_dt.print_tensor" (%t) : (!infrt.dense_tensor) -> () + + //%res = + infrt.call @run_trt(%t, %ctx) : (!infrt.dense_tensor, !phi.context) -> () + //-> (!infrt.dense_tensor) + + infrt.return +} diff --git a/paddle/infrt/tests/dialect/phi/dense_tensor.mlir b/paddle/infrt/tests/dialect/phi/dense_tensor.mlir index 3657777a5b0bce1c5a5e4df8d59695f8b122da56..b8cb1a5cec2a17d3f6d15036249fcf9f7f711948 100644 --- a/paddle/infrt/tests/dialect/phi/dense_tensor.mlir +++ b/paddle/infrt/tests/dialect/phi/dense_tensor.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: @sign_any_float32_execute func @sign_any_float32_execute() { %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context - %t = "phi_dt.create_dense_tensor" (%ctx) { + %t = "phi_dt.create_dense_tensor.cpu" (%ctx) { 
precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor) -> () diff --git a/paddle/infrt/tests/dialect/phi/phi_test.mlir b/paddle/infrt/tests/dialect/phi/phi_test.mlir index 5b0fa735897a31287bb6dea487e2f22eacd7b0aa..21ee8ebf0b705894446192b0d5d0bfeb9f10f326 100644 --- a/paddle/infrt/tests/dialect/phi/phi_test.mlir +++ b/paddle/infrt/tests/dialect/phi/phi_test.mlir @@ -6,7 +6,7 @@ module { } func @main() { %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context - %t = "phi_dt.create_dense_tensor" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) + %t = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[1:i64]}: (!phi.context) -> (!infrt.dense_tensor) "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor) -> () %2 = infrt.call@predict(%t) : (!infrt.dense_tensor) -> !infrt.dense_tensor phi_dt.print_tensor(%2 : !infrt.dense_tensor) diff --git a/paddle/infrt/tests/dialect/trt_ops.mlir b/paddle/infrt/tests/dialect/trt_ops.mlir index e3cb9670bec015e58e2a538bb55dfbe7c8b7f554..7bdf62a277896afe2f8a5e156fa8183742f1d853 100644 --- a/paddle/infrt/tests/dialect/trt_ops.mlir +++ b/paddle/infrt/tests/dialect/trt_ops.mlir @@ -1,16 +1,16 @@ // RUN: trt-exec %s // CHECK-LABEL: @main -func @main(%bias:tensor, %c:tensor, %b1:tensor, %b2:tensor, %bias1:tensor, %bias2:tensor) -> tensor { - %d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (tensor, tensor) -> tensor - %e = "pd.relu6"(%d) {} : (tensor) -> tensor +func @main(%bias:!infrt.dense_tensor, %c:!infrt.dense_tensor, %b1:!infrt.dense_tensor, %b2:!infrt.dense_tensor, %bias1:!infrt.dense_tensor, %bias2:!infrt.dense_tensor) -> !infrt.dense_tensor { + %d = "pd.elementwise_add"(%c, %bias) {axis=-1:si32} : (!infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + %e = "pd.relu6"(%d) {} : (!infrt.dense_tensor) -> !infrt.dense_tensor - %c1 = "pd.matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor, tensor) -> tensor - %d1 = "pd.elementwise_add"(%c1, %bias1) {axis=-1:si32} : (tensor, tensor) -> tensor - %e1 = "pd.relu"(%d1) {} : (tensor) -> tensor + %c1 = "pd.matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (!infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + %d1 = "pd.elementwise_add"(%c1, %bias1) {axis=-1:si32} : (!infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + %e1 = "pd.relu"(%d1) {} : (!infrt.dense_tensor) -> !infrt.dense_tensor - %c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor - %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (tensor, tensor) -> tensor - %e2 = "pd.relu"(%d2) {} : (tensor) -> tensor + %c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (!infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=-1:si32} : (!infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + %e2 = "pd.relu"(%d2) {} : (!infrt.dense_tensor) -> !infrt.dense_tensor - infrt.return %e2 : tensor + infrt.return %e2 : !infrt.dense_tensor } diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index a3b252598582bc212ba66f9c18ec52e035a29a68..0394835aa8b700ba4f9ee9b106661e2d70fc50b6 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ 
b/paddle/phi/backends/gpu/gpu_context.cc @@ -741,6 +741,10 @@ struct GPUContext::Impl { GPUContext::GPUContext() : DeviceContext(), impl_(std::make_unique()) {} +GPUContext::GPUContext(GPUContext&&) = default; + +GPUContext& GPUContext::operator=(GPUContext&&) = default; + GPUContext::GPUContext(const GPUPlace& place) : DeviceContext(), impl_(std::make_unique(place)) {} diff --git a/paddle/phi/backends/gpu/gpu_context.h b/paddle/phi/backends/gpu/gpu_context.h index 3eb4360ad35382369681308b46050cc3e6e04ea0..cd08da1c0f2f8031a461a0410a89254823a6a903 100644 --- a/paddle/phi/backends/gpu/gpu_context.h +++ b/paddle/phi/backends/gpu/gpu_context.h @@ -77,6 +77,8 @@ class DnnWorkspaceHandle { class GPUContext : public DeviceContext { public: GPUContext(); + GPUContext(GPUContext&&); + GPUContext& operator=(GPUContext&&); explicit GPUContext(const GPUPlace& place);
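
For reference, the GPUContext move operations added in the final hunk are what allow kernels such as CreateGPUContext above to build a context and return it by value. Below is a minimal standalone sketch of that flow, condensed from context_kernels.cc in this patch; MakeGpuContextSketch is a placeholder name, and the demo allocator is deliberately leaked, matching the patch.

#include "paddle/infrt/backends/host/phi_allocator.h"  // requires INFRT_WITH_GPU
#include "paddle/phi/backends/gpu/gpu_context.h"

phi::GPUContext MakeGpuContextSketch() {
  phi::GPUContext context;
  // Two-phase initialization used by the patch: set up everything that does
  // not need an allocator first...
  context.PartialInitWithoutAllocator();
  // ...then attach the demo GPU allocator and finish initialization.
  context.SetAllocator(new infrt::backends::GpuPhiAllocator{});
  context.PartialInitWithAllocator();
  // Returning by value relies on the move constructor added to GPUContext.
  return context;
}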
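The patch also reworks TrtEngine so that outputs become engine-owned handles instead of a caller-supplied output map. The sketch below condenses the flow used by trt_kernels.cc and test_trt_engine.cc; RunEngineSketch and its arguments are placeholder names, and the output name assumes the output_N convention noted in trt_engine.h.

#include <string>
#include <unordered_map>

#include "paddle/infrt/backends/tensorrt/trt_engine.h"
#include "paddle/infrt/backends/tensorrt/trt_options.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"

namespace trt = infrt::backends::tensorrt;

phi::DenseTensor* RunEngineSketch(
    trt::TrtEngine* engine,
    const phi::GPUContext& context,
    const std::unordered_map<std::string, phi::DenseTensor*>& inputs) {
  trt::InferenceOptions options;
  options.batch = 1;

  // Outputs are now owned by the engine: declare each handle before
  // SetUpInference instead of passing an output map into it.
  engine->PrepareOutputHandle("output_0");
  engine->SetUpInference(options, inputs);
  engine->Run(context);

  // Fetch the engine-owned output tensor after the run; copy it back to the
  // host with paddle::memory::Copy(...) as the tests above do.
  return engine->GetOutput("output_0");
}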