From 211e131525ac3270b8afdd4d40d1f46b7f9dcfda Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Wed, 30 May 2018 21:50:05 +0800
Subject: [PATCH] feature/tensorrt engine op (#11001)

---
 paddle/fluid/inference/tensorrt/engine.cc    |  26 ++++-
 paddle/fluid/inference/tensorrt/engine.h     |   8 +-
 paddle/fluid/operators/CMakeLists.txt        |   3 +
 paddle/fluid/operators/tensorrt_engine_op.cc |  70 ++++++++++++
 paddle/fluid/operators/tensorrt_engine_op.h  | 110 +++++++++++++++++++
 5 files changed, 213 insertions(+), 4 deletions(-)
 create mode 100644 paddle/fluid/operators/tensorrt_engine_op.cc
 create mode 100644 paddle/fluid/operators/tensorrt_engine_op.h

diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index fb27c8394c1..a88236ae98e 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -131,6 +131,20 @@ void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
   return buffer(name).buffer;
 }
 
+void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst,
+                                    size_t max_size) {
+  // determine data size
+  auto it = buffer_sizes_.find(name);
+  PADDLE_ENFORCE(it != buffer_sizes_.end());
+  PADDLE_ENFORCE_GT(it->second, 0);
+  PADDLE_ENFORCE_GE(max_size, it->second);
+  auto& buf = buffer(name);
+  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
+  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
+                                    cudaMemcpyDeviceToDevice, *stream_),
+                    0);
+}
+
 void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
                                     size_t max_size) {
   // determine data size
@@ -152,7 +166,7 @@ Buffer& TensorRTEngine::buffer(const std::string& name) {
   return buffers_[slot_offset];
 }
 
-void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
+void TensorRTEngine::SetInputFromCPU(const std::string& name, const void* data,
                                      size_t size) {
   auto& buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer);
@@ -162,6 +176,16 @@ void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
                                 cudaMemcpyHostToDevice, *stream_));
 }
 
+void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data,
+                                     size_t size) {
+  auto& buf = buffer(name);
+  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
+  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
+  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
+  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
+                                       cudaMemcpyDeviceToDevice, *stream_));
+}
+
 void TensorRTEngine::SetITensor(const std::string& name,
                                 nvinfer1::ITensor* tensor) {
   PADDLE_ENFORCE(tensor != nullptr);
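Note: every new copy path above is enqueued with cudaMemcpyAsync on the engine's stream, so a caller must synchronize before reading the destination buffer. A minimal caller-side sketch of the device-to-device round trip (not part of the patch; the tensor names "x" and "y", the byte sizes, and the caller-owned device pointers are assumptions for illustration):

    #include <cuda_runtime.h>
    #include "paddle/fluid/inference/tensorrt/engine.h"

    namespace trt = paddle::inference::tensorrt;

    // Feed a GPU-resident input, run the engine, and copy the result into a
    // caller-owned GPU buffer via the new GetOutputInGPU overload.
    void RunOnce(trt::TensorRTEngine* engine, const float* x_gpu,
                 size_t x_bytes, float* y_gpu, size_t y_bytes, int batch) {
      engine->SetInputFromGPU("x", x_gpu, x_bytes);  // device-to-device copy in
      engine->Execute(batch);
      engine->GetOutputInGPU("y", y_gpu, y_bytes);   // device-to-device copy out
      // The copies are asynchronous on the engine's stream; wait before the
      // caller dereferences y_gpu.
      cudaStreamSynchronize(*engine->stream());
    }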
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index b8298c6059e..d9d3163b66d 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -92,13 +92,15 @@ class TensorRTEngine : public EngineBase {
   cudaStream_t* stream() { return stream_; }
 
   // Fill an input from CPU memory with name and size.
-  void SetInputFromCPU(const std::string& name, void* data, size_t size);
+  void SetInputFromCPU(const std::string& name, const void* data, size_t size);
   // TODO(Superjomn) is this method necessary given that buffer(xxx) can be
   // accessed directly. Fill an input from GPU memory with name and size.
-  void SetInputFromGPU(const std::string& name, void* data, size_t size);
+  void SetInputFromGPU(const std::string& name, const void* data, size_t size);
   // Get an output called name, the output of tensorrt is in GPU, so this method
-  // will just return the output's GPU memory address.
+  // will just return the output's GPU memory address, without copying.
   void* GetOutputInGPU(const std::string& name);
+  // Copy data into dst inside the GPU device.
+  void GetOutputInGPU(const std::string& name, void* dst, size_t max_size);
   // LOW EFFICIENCY! Get output to CPU, this will trigger a memory copy from GPU
   // to CPU.
   void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 52cdc0bd0ae..819d02ae6f3 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -225,6 +225,9 @@ op_library(cross_entropy_op DEPS cross_entropy)
 op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(softmax_op DEPS softmax)
 op_library(sequence_softmax_op DEPS softmax)
+if (WITH_GPU AND TENSORRT_FOUND)
+  op_library(tensorrt_engine_op DEPS tensorrt_engine)
+endif()
 op_library(sum_op DEPS selected_rows_functor)
 op_library(sgd_op DEPS selected_rows_functor)
 op_library(print_op DEPS lod_tensor)
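Note on the two GetOutputInGPU overloads declared in engine.h above: the pointer-returning form aliases the engine-owned buffer, while the new form snapshots it into caller memory. A sketch of when each is appropriate (not part of the patch; the output name "y" is an assumption for illustration):

    #include <cuda_runtime.h>
    #include "paddle/fluid/inference/tensorrt/engine.h"

    namespace trt = paddle::inference::tensorrt;

    // Borrow: no copy, but the pointer aliases engine-owned memory that the
    // next Execute() will overwrite.
    void* Borrow(trt::TensorRTEngine* engine) {
      return engine->GetOutputInGPU("y");
    }

    // Snapshot: copies into caller-owned device memory, so the result
    // survives later Execute() calls.
    void Snapshot(trt::TensorRTEngine* engine, void* dst, size_t max_size) {
      engine->GetOutputInGPU("y", dst, max_size);
      cudaStreamSynchronize(*engine->stream());
    }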
diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc
new file mode 100644
index 00000000000..83e768b4dc9
--- /dev/null
+++ b/paddle/fluid/operators/tensorrt_engine_op.cc
@@ -0,0 +1,70 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#ifdef PADDLE_WITH_CUDA
+
+#include "paddle/fluid/operators/tensorrt_engine_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/utils/singleton.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
+    const framework::ExecutionContext &context) const {
+  // Get the ProgramDesc and pass to convert.
+  const auto &block = context.Attr<framework::proto::BlockDesc>("subgraph");
+  max_batch_ = context.Attr<int>("max_batch");
+  auto max_workspace = context.Attr<int>("max_workspace");
+  engine_.reset(new inference::tensorrt::TensorRTEngine(
+      max_batch_, max_workspace, nullptr));
+  inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
+      block, engine_.get());
+  engine_->FreezeNetwork();
+}
+
+class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Xs", "A list of inputs.").AsDuplicable();
+    AddOutput("Ys", "A list of outputs").AsDuplicable();
+    AddAttr<framework::proto::BlockDesc>("subgraph", "the subgraph");
+    AddComment("TensorRT engine operator.");
+  }
+};
+
+class TensorRTEngineInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp,
+                  ops::TensorRTEngineOpMaker, ops::TensorRTEngineInferVarType);
+
+REGISTER_OP_CPU_KERNEL(
+    tensorrt_engine,
+    ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int64_t>);
+
+#endif  // PADDLE_WITH_CUDA
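Note: Prepare() above is invoked lazily from the first Compute() call (see tensorrt_engine_op.h below), which is why engine_ and max_batch_ are declared mutable. The pattern, distilled into a self-contained sketch (the Engine stand-in is hypothetical, used only to keep the example compilable):

    #include <memory>

    struct Engine { /* stand-in for inference::tensorrt::TensorRTEngine */ };

    class LazyKernel {
     public:
      // OpKernel::Compute is const, so the one-time build has to go through
      // mutable members.
      void Compute() const {
        if (!engine_) Prepare();  // build exactly once, on first use
        // ... feed inputs, run the engine, fetch outputs ...
      }

     private:
      void Prepare() const { engine_.reset(new Engine); }
      mutable std::unique_ptr<Engine> engine_;
    };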
diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h
new file mode 100644
index 00000000000..fe273d386c5
--- /dev/null
+++ b/paddle/fluid/operators/tensorrt_engine_op.h
@@ -0,0 +1,110 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#ifdef PADDLE_WITH_CUDA
+
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/inference/analysis/helper.h"
+#include "paddle/fluid/inference/tensorrt/engine.h"
+
+namespace paddle {
+namespace operators {
+
+class TensorRTEngineOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {}
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    framework::OpKernelType kt = framework::OpKernelType(
+        framework::ToDataType(
+            ctx.MultiInput<framework::LoDTensor>("Xs").front()->type()),
+        platform::CPUPlace());
+    return kt;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class TensorRTEngineKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    if (!engine_) {
+      Prepare(context);
+    }
+    auto input_names = context.op().Inputs("Xs");
+    PADDLE_ENFORCE(!input_names.empty(), "should pass at least one input");
+    // Try to determine a batch_size
+    auto* tensor0 = context.Input<framework::LoDTensor>(input_names.front());
+    PADDLE_ENFORCE_NOT_NULL(tensor0);
+    int batch_size = tensor0->dims()[0];
+    PADDLE_ENFORCE_LE(batch_size, max_batch_);
+
+    // Convert input tensor from fluid to engine.
+    for (const auto& x : context.Inputs("Xs")) {
+      // convert input and copy to TRT engine's buffer
+      auto* v = context.scope().FindVar(x);
+      PADDLE_ENFORCE_NOT_NULL(v, "no variable called %s", x);
+      auto& t = v->Get<framework::LoDTensor>();
+      if (platform::is_cpu_place(t.place())) {
+        engine_->SetInputFromCPU(x, static_cast<const void*>(t.data<float>()),
+                                 t.memory_size());
+      } else {
+        engine_->SetInputFromGPU(x, static_cast<const void*>(t.data<float>()),
+                                 t.memory_size());
+      }
+    }
+    // Execute the engine.
+    PADDLE_ENFORCE_GT(batch_size, 0);
+    engine_->Execute(batch_size);
+    // Convert output tensor from engine to fluid
+    for (const auto& y : context.Outputs("Ys")) {
+      // convert output and copy to fluid.
+      nvinfer1::ITensor* trt_t = engine_->GetITensor(y);
+      auto dims = trt_t->getDimensions();
+      // Use the output ITensor's dims to reshape the Fluid Tensor.
+      std::vector<int64_t> ddim(dims.d, dims.d + dims.nbDims);
+
+      auto* fluid_v = context.scope().FindVar(y);
+      PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
+      auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
+      fluid_t->Resize(framework::make_ddim(ddim));
+      auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
+      if (platform::is_cpu_place(fluid_t->place())) {
+        engine_->GetOutputInCPU(
+            y, fluid_t->mutable_data<float>(platform::CPUPlace()), size);
+      } else {
+        engine_->GetOutputInGPU(
+            y, fluid_t->mutable_data<float>(platform::CUDAPlace()), size);
+      }
+    }
+  }
+
+ protected:
+  // Build the engine.
+  void Prepare(const framework::ExecutionContext& context) const;
+
+ private:
+  mutable std::unique_ptr<inference::tensorrt::TensorRTEngine> engine_;
+  mutable int max_batch_{0};
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+#endif  // PADDLE_WITH_CUDA
--
GitLab
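Closing note on the output loop in tensorrt_engine_op.h: the fluid tensor is reshaped from the ITensor's dimensions, and the copy size is the element count accumulated over those same dimensions. The two conversions in isolation (a sketch assuming only the nvinfer1::Dims POD from the TensorRT headers; the function names are hypothetical):

    #include <cstdint>
    #include <vector>
    #include <NvInfer.h>

    // Mirror of the ddim construction in Compute().
    std::vector<int64_t> TrtDimsToDDim(const nvinfer1::Dims& dims) {
      return std::vector<int64_t>(dims.d, dims.d + dims.nbDims);
    }

    // The product that inference::analysis::AccuDims accumulates over dims.d.
    int64_t ElementCount(const nvinfer1::Dims& dims) {
      int64_t n = 1;
      for (int i = 0; i < dims.nbDims; ++i) n *= dims.d[i];
      return n;
    }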