diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index 153dca576bd6734d62f00c4a7cb9b503506b33e2..58eb0e715cb71d87179f3240de55021603cd7423 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -18,6 +18,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -107,6 +109,13 @@ class OrderedRegistry { std::vector> data_; }; +template +T &GetFromScope(const framework::Scope &scope, const std::string &name) { + framework::Variable *var = scope.FindVar(name); + PADDLE_ENFORCE(var != nullptr); + return *var->GetMutable(); +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 23ca8bfac84f35ebdca2e2a1a8538d366358ca8b..0dd0e5c9a2b08e406bf500f40e2fc8926012ac0e 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,10 +1,16 @@ # Add TRT tests -nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine) # This test is not stable # See https://paddleci.ngrok.io/viewLog.html?tab=buildLog&buildTypeId=Paddle_PrCi2&buildId=36834&_focus=8828 #nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc # DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine # SERIAL) +nv_library(tensorrt_converter + SRCS mul_op.cc conv2d_op.cc fc_op.cc + DEPS tensorrt_engine mul_op) + +nv_test(test_op_converter SRCS test_op_converter.cc DEPS + ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter) + nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index 79d01b640a214ed5eb86173a36d5e85a6626066f..7facf30d781a26c2c6eb0a8966ef1b87e5dfdf0b 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" namespace paddle { @@ -36,8 +37,8 @@ class ReluOpConverter : public OpConverter { } }; -REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter); - } // namespace tensorrt } // namespace inference } // namespace paddle + +REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 668d344f1bba1c012dcb42c71b996209b4703d78..8e7e23377d4b2fe7afd51f1f58048fc4ed3c6d99 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -22,14 +22,14 @@ class Conv2dOpConverter : public OpConverter { public: Conv2dOpConverter() {} void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope) override { + const framework::Scope& scope, bool test_mode) override { LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias"; } }; -REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter); - } // namespace tensorrt } // namespace inference } // namespace paddle + +REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index 45b079559754a8f5c3fe39781b5700a75f425e99..bb603efaf30bb72d74b5583abc45d01a16c076a3 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -56,7 +56,7 @@ void ReorderCKtoKC(TensorRTEngine::Weight& iweights, class FcOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope) override { + const framework::Scope& scope, bool test_mode) override { VLOG(4) << "convert a fluid fc op to tensorrt fc layer without bias"; framework::OpDesc op_desc(op, nullptr); @@ -106,14 +106,16 @@ class FcOpConverter : public OpConverter { n_output, weight.get(), bias.get()); auto output_name = op_desc.Output("Out").front(); - engine_->DeclareOutput(layer, 0, output_name); + engine_->SetITensor(output_name, layer->getOutput(0)); + if (test_mode) { + engine_->DeclareOutput(output_name); + } } }; -REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter); - } // namespace tensorrt } // namespace inference } // namespace paddle +REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter); USE_OP(mul); diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc index 6bb07709c7ee1c6b29c46425849a4f472d3df59d..3c342957360ad4192d838147bf37e84d233c2629 100644 --- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc @@ -23,9 +23,8 @@ namespace tensorrt { */ class MulOpConverter : public OpConverter { public: - MulOpConverter() {} void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope) override { + const framework::Scope& scope, bool test_mode) override { VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias"; framework::OpDesc op_desc(op, nullptr); @@ -37,12 +36,18 @@ class MulOpConverter : public OpConverter { engine_, MatrixMultiply, *const_cast(input1), false, *const_cast(input2), false); - engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]); + auto output_name = op_desc.Output("Out")[0]; + engine_->SetITensor(output_name, layer->getOutput(0)); + if (test_mode) { // the test framework can not determine which is the + // output, so place the declaration inside. + engine_->DeclareOutput(output_name); + } } }; -REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter); - } // namespace tensorrt } // namespace inference } // namespace paddle + +USE_OP(mul); +REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 3beafeefd06f24ec50b0e61c1fabe13d7e53f242..c7a5a49dd02d0db022fabff5c3ae1c7800bac25c 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/utils/singleton.h" @@ -34,12 +35,15 @@ class OpConverter { // Converter logic for an op. virtual void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope) {} + const framework::Scope& scope, + bool test_mode = false) {} - // Convert a single fluid operaotr and add the corresponding layer to TRT. + // Convert a single fluid operator and add the corresponding layer to TRT. + // test_mode: whether the instance executes in an unit test. void ConvertOp(const framework::proto::OpDesc& op, const std::unordered_set& parameters, - const framework::Scope& scope, TensorRTEngine* engine) { + const framework::Scope& scope, TensorRTEngine* engine, + bool test_mode = false) { framework::OpDesc op_desc(op, nullptr); OpConverter* it{nullptr}; @@ -57,7 +61,7 @@ class OpConverter { PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", op_desc.Type()); it->SetEngine(engine); - (*it)(op, scope); + (*it)(op, scope, test_mode); } // convert fluid block to tensorrt network @@ -77,6 +81,9 @@ class OpConverter { // TensorRT engine TensorRTEngine* engine_{nullptr}; + protected: + bool test_mode_; + private: // registered op converter map, whose key is the fluid op type, and value is // the pointer position of corresponding OpConverter class. @@ -85,13 +92,24 @@ class OpConverter { framework::Scope* scope_{nullptr}; }; -#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \ - struct trt_##op_type__##_converter { \ - trt_##op_type__##_converter() { \ - Registry::Register(#op_type__); \ - } \ - }; \ - trt_##op_type__##_converter trt_##op_type__##_converter__; +#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \ + struct trt_##op_type__##_converter : public ::paddle::framework::Registrar { \ + trt_##op_type__##_converter() { \ + ::paddle::inference:: \ + Registry::Register< \ + ::paddle::inference::tensorrt::Converter__>(#op_type__); \ + } \ + }; \ + trt_##op_type__##_converter trt_##op_type__##_converter__; \ + int TouchConverterRegister_##op_type__() { \ + trt_##op_type__##_converter__.Touch(); \ + return 0; \ + } + +#define USE_TRT_CONVERTER(op_type__) \ + extern int TouchConverterRegister_##op_type__(); \ + static int use_op_converter_trt_##op_type__ __attribute__((unused)) = \ + TouchConverterRegister_##op_type__(); } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc index 1d3f5eabb2f839b2acfa9da6527589df1ec3767f..9b79f86b0edba983019bd932f52b08711ff36d41 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc @@ -36,3 +36,5 @@ TEST(OpConverter, ConvertBlock) { } // namespace tensorrt } // namespace inference } // namespace paddle + +USE_TRT_CONVERTER(conv2d) diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index d7e05dd5b5b235b7b166b22c5b094dc364e28dfc..8613d5b1c13bc24572b374a8d115690f089a71d1 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -27,6 +27,7 @@ limitations under the License. */ #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/utils/singleton.h" namespace paddle { namespace inference { @@ -104,8 +105,8 @@ class TRTConvertValidation { void SetOp(const framework::proto::OpDesc& desc) { op_ = framework::OpRegistry::CreateOp(desc); - OpConverter op_converter; - op_converter.ConvertOp(desc, parameters_, scope_, engine_.get()); + Singleton::Global().ConvertOp( + desc, parameters_, scope_, engine_.get(), true /*test_mode*/); engine_->FreezeNetwork(); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 3d75fefc1a735168131a6c67ac073e80aba32945..596e0fe9da3d272ecb1c0f8dbef09a75d08a4b1a 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -43,9 +43,10 @@ void TensorRTEngine::Execute(int batch_size) { } TensorRTEngine::~TensorRTEngine() { + cudaStreamSynchronize(*stream_); // clean buffer for (auto& buf : buffers_) { - if (buf.buffer != nullptr) { + if (buf.device == DeviceType::GPU && buf.buffer != nullptr) { PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer)); buf.buffer = nullptr; buf.max_size = 0; @@ -80,6 +81,8 @@ void TensorRTEngine::FreezeNetwork() { auto& buf = buffer(item.first); CHECK(buf.buffer == nullptr); // buffer should be allocated only once. PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second)); + VLOG(4) << "buffer malloc " << item.first << " " << item.second << " " + << buf.buffer; buf.size = buf.max_size = item.second; buf.device = DeviceType::GPU; } @@ -96,6 +99,7 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name, PADDLE_ENFORCE(input, "infer network add input %s failed", name); buffer_sizes_[name] = kDataTypeSize[static_cast(dtype)] * analysis::AccuDims(dims.d, dims.nbDims); + PADDLE_ENFORCE(input->isNetworkInput()); TensorRTEngine::SetITensor(name, input); return input; } @@ -109,7 +113,9 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset, SetITensor(name, output); PADDLE_ENFORCE(output != nullptr); output->setName(name.c_str()); + PADDLE_ENFORCE(!output->isNetworkInput()); infer_network_->markOutput(*output); + PADDLE_ENFORCE(output->isNetworkOutput()); // output buffers' size can only be decided latter, set zero here to mark this // and will reset latter. buffer_sizes_[name] = 0; @@ -122,6 +128,7 @@ void TensorRTEngine::DeclareOutput(const std::string& name) { auto* output = TensorRTEngine::GetITensor(name); PADDLE_ENFORCE(output != nullptr); output->setName(name.c_str()); + PADDLE_ENFORCE(!output->isNetworkInput()); infer_network_->markOutput(*output); // output buffers' size can only be decided latter, set zero here to mark this // and will reset latter. diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index fabcfd9e80cc0ef2637201a1499ebbe2d6adfd8c..b60f00de9fa5fc8f8f4537379bf9ee9c8bb6f31c 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -21,6 +21,7 @@ limitations under the License. */ #include #include "paddle/fluid/inference/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" +#include "paddle/fluid/inference/utils/singleton.h" namespace paddle { namespace inference { @@ -131,7 +132,11 @@ class TensorRTEngine : public EngineBase { // TensorRT related internal members template struct Destroyer { - void operator()(T* x) { x->destroy(); } + void operator()(T* x) { + if (x) { + x->destroy(); + } + } }; template using infer_ptr = std::unique_ptr>; @@ -155,6 +160,27 @@ class TensorRTEngine : public EngineBase { #define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \ engine__->network()->add##layer__(ARGS); +/* + * Helper to control the TensorRT engine's creation and deletion. + */ +class TRT_EngineManager { + public: + TensorRTEngine* Create(int max_batch, int max_workspace, + cudaStream_t* stream) { + engines_.emplace_back(new TensorRTEngine(max_batch, max_workspace, stream)); + return engines_.back().get(); + } + + void DeleteALl() { + for (auto& ptr : engines_) { + ptr.reset(nullptr); + } + } + + private: + std::vector> engines_; +}; + } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index de6ff29c6f8edbcf930546ff157a1c226e1311db..f75b7c70d60e77eb07927261d3c60bd526986f98 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -227,6 +227,8 @@ op_library(softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax) if (WITH_GPU AND TENSORRT_FOUND) op_library(tensorrt_engine_op DEPS tensorrt_engine) + nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc + DEPS tensorrt_engine_op tensorrt_engine tensorrt_converter) else() set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) endif() diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index 855157e7c4c5c4a43091d28d3a5414e6e386b727..4b1208c4376b48e25866fc510f3a6d2ea06e7610 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -17,23 +17,93 @@ #include "paddle/fluid/operators/tensorrt_engine_op.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/utils/singleton.h" namespace paddle { namespace operators { +using inference::Singleton; +using inference::tensorrt::TRT_EngineManager; + +using FluidDT = framework::proto::VarType_Type; +using TRT_DT = nvinfer1::DataType; + +namespace { + +TRT_DT FluidDataType2TRT(FluidDT type) { + switch (type) { + case FluidDT::VarType_Type_FP32: + return TRT_DT::kFLOAT; + case FluidDT::VarType_Type_INT32: + return TRT_DT::kINT32; + default: + return TRT_DT::kINT32; + } + PADDLE_THROW("unkown type"); + return TRT_DT::kINT32; +} + +nvinfer1::Dims Vec2TRT_Dims(const std::vector &shape) { + PADDLE_ENFORCE_GT(shape.size(), 1UL, + "TensorRT' tensor input requires at least 2 dimensions"); + PADDLE_ENFORCE_LE(shape.size(), 4UL, + "TensorRT' tensor input requires at most 4 dimensions"); + + switch (shape.size()) { + case 2: + return nvinfer1::Dims2(shape[0], shape[1]); + case 3: + return nvinfer1::Dims3(shape[0], shape[1], shape[2]); + case 4: + return nvinfer1::Dims4(shape[0], shape[1], shape[2], shape[3]); + default: + return nvinfer1::Dims(); + } + return nvinfer1::Dims(); +} + +} // namespace + template void paddle::operators::TensorRTEngineKernel::Prepare( const framework::ExecutionContext &context) const { + VLOG(4) << "Prepare engine"; // Get the ProgramDesc and pass to convert. - const auto &block = context.Attr("subgraph"); + framework::proto::BlockDesc block_desc; + block_desc.ParseFromString(context.Attr("subgraph")); max_batch_ = context.Attr("max_batch"); auto max_workspace = context.Attr("max_workspace"); - engine_.reset(new inference::tensorrt::TensorRTEngine( - max_batch_, max_workspace, nullptr)); + engine_ = Singleton::Global().Create( + max_batch_, max_workspace, &stream_); + engine_->InitNetwork(); + + framework::BlockDesc block(nullptr /*programdesc*/, &block_desc); + // Add inputs + VLOG(4) << "declare inputs"; + for (auto &input : context.Inputs("Xs")) { + VLOG(4) << "declare input " << input; + auto *var = block.FindVar(input); + PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, + "TensorRT engine only takes LoDTensor as input"); + auto shape = var->GetShape(); + engine_->DeclareInput( + input, FluidDataType2TRT( + var->Proto()->type().lod_tensor().tensor().data_type()), + Vec2TRT_Dims(var->GetShape())); + } + // TODO(Superjomn) parameters should be passed after analysised from outside. inference::Singleton::Global().ConvertBlock( - block, {}, context.scope(), engine_.get()); + block_desc, {}, context.scope(), engine_); + + // Add outputs + VLOG(4) << "declare outputs"; + for (auto &output : context.Outputs("Ys")) { + VLOG(4) << "declare output " << output; + engine_->DeclareOutput(output); + } + engine_->FreezeNetwork(); } @@ -42,7 +112,9 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("Xs", "A list of inputs.").AsDuplicable(); AddOutput("Ys", "A list of outputs").AsDuplicable(); - AddAttr("subgraph", "the subgraph"); + AddAttr("subgraph", "the subgraph."); + AddAttr("max_batch", "the maximum batch size."); + AddAttr("max_workspace", "the maximum batch size."); AddComment("TensorRT engine operator."); } }; diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index fe273d386c529be3df05a955f492e2c39d4d8812..4b089601ff76eedd87bb3a52a38c4d22d4a94bf6 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -32,9 +32,12 @@ class TensorRTEngineOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { + auto input0 = ctx.Inputs("Xs").front(); framework::OpKernelType kt = framework::OpKernelType( - framework::ToDataType( - ctx.Input("pre_ids")->type()), + framework::ToDataType(ctx.scope() + .FindVar(input0) + ->GetMutable() + ->type()), platform::CPUPlace()); return kt; } @@ -50,17 +53,16 @@ class TensorRTEngineKernel : public framework::OpKernel { auto input_names = context.op().Inputs("Xs"); PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs"); // Try to determine a batch_size - auto* tensor0 = context.Input(input_names.front()); - PADDLE_ENFORCE_NOT_NULL(tensor0); - int batch_size = tensor0->dims()[0]; + auto& tensor0 = inference::analysis::GetFromScope( + context.scope(), input_names.front()); + int batch_size = tensor0.dims()[0]; PADDLE_ENFORCE_LE(batch_size, max_batch_); // Convert input tensor from fluid to engine. for (const auto& x : context.Inputs("Xs")) { // convert input and copy to TRT engine's buffer - auto* v = context.scope().FindVar(x); - PADDLE_ENFORCE_NOT_NULL(v, "no variable called %s", x); - auto& t = v->Get(); + auto& t = inference::analysis::GetFromScope( + context.scope(), x); if (platform::is_cpu_place(t.place())) { engine_->SetInputFromCPU(x, static_cast(t.data()), t.memory_size()); @@ -86,13 +88,18 @@ class TensorRTEngineKernel : public framework::OpKernel { fluid_t->Resize(framework::make_ddim(ddim)); auto size = inference::analysis::AccuDims(dims.d, dims.nbDims); if (platform::is_cpu_place(fluid_t->place())) { + // TODO(Superjomn) change this float to dtype size. engine_->GetOutputInCPU( - y, fluid_t->mutable_data(platform::CPUPlace()), size); + y, fluid_t->mutable_data(platform::CPUPlace()), + size * sizeof(float)); } else { engine_->GetOutputInGPU( - y, fluid_t->mutable_data(platform::CUDAPlace()), size); + y, fluid_t->mutable_data(platform::CUDAPlace()), + size * sizeof(float)); } } + + cudaStreamSynchronize(stream_); } protected: @@ -100,7 +107,8 @@ class TensorRTEngineKernel : public framework::OpKernel { void Prepare(const framework::ExecutionContext& context) const; private: - mutable std::unique_ptr engine_; + mutable cudaStream_t stream_; + mutable inference::tensorrt::TensorRTEngine* engine_{nullptr}; mutable int max_batch_{0}; }; diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6f383de259b270038c32296b59007f6c7d895f12 --- /dev/null +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -0,0 +1,152 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +USE_CPU_ONLY_OP(tensorrt_engine); + +namespace paddle { +namespace operators { + +namespace { +void CreateCPUTensor(framework::Scope* scope, const std::string& name, + const std::vector& shape) { + auto* var = scope->Var(name); + auto* tensor = var->GetMutable(); + auto dims = framework::make_ddim(shape); + tensor->Resize(dims); + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + inference::tensorrt::RandomizeTensor(tensor, place, ctx); +} + +void AddTensorToBlockDesc(framework::proto::BlockDesc* block, + const std::string& name, + const std::vector& shape) { + using framework::proto::VarType; + auto* var = block->add_vars(); + framework::VarDesc desc(name); + desc.SetType(VarType::LOD_TENSOR); + desc.SetDataType(VarType::FP32); + desc.SetShape(shape); + *var = *desc.Proto(); +} + +template +void SetAttr(framework::proto::OpDesc* op, const std::string& name, + const T& data); + +template <> +void SetAttr(framework::proto::OpDesc* op, const std::string& name, + const std::string& data) { + auto* attr = op->add_attrs(); + attr->set_name(name); + attr->set_type(paddle::framework::proto::AttrType::STRING); + attr->set_s(data); +} +template <> +void SetAttr(framework::proto::OpDesc* op, const std::string& name, + const int& data) { + auto* attr = op->add_attrs(); + attr->set_name(name); + attr->set_type(paddle::framework::proto::AttrType::INT); + attr->set_i(data); +} +template <> +void SetAttr(framework::proto::OpDesc* op, const std::string& name, + const int64_t& data) { + auto* attr = op->add_attrs(); + attr->set_name(name); + attr->set_type(paddle::framework::proto::AttrType::LONG); + attr->set_l(data); +} + +} // namespace + +TEST(TensorRTEngineOp, manual) { + framework::ProgramDesc program; + auto* block_ = program.Proto()->add_blocks(); + block_->set_idx(0); + block_->set_parent_idx(-1); + + LOG(INFO) << "create block desc"; + framework::BlockDesc block_desc(&program, block_); + LOG(INFO) << "create mul op"; + auto* mul = block_desc.AppendOp(); + mul->SetType("mul"); + mul->SetInput("X", std::vector({"x"})); // 2 x 4 + mul->SetInput("Y", std::vector({"y"})); // 4 x 6 + mul->SetOutput("Out", std::vector({"z"})); // 2 x 6 + + LOG(INFO) << "create fc op"; + auto* fc = block_desc.AppendOp(); + fc->SetType("mul"); + fc->SetInput("X", std::vector({"z"})); + fc->SetInput("Y", std::vector({"y0"})); // 6 x 8 + fc->SetOutput("Out", std::vector({"z0"})); // 2 x 8 + + // Set inputs' variable shape in BlockDesc + AddTensorToBlockDesc(block_, "x", std::vector({2, 4})); + AddTensorToBlockDesc(block_, "y", std::vector({4, 6})); + AddTensorToBlockDesc(block_, "y0", std::vector({6, 8})); + AddTensorToBlockDesc(block_, "z", std::vector({2, 6})); + + // It is wired, need to copy manually. + *block_->add_ops() = *mul->Proto(); + *block_->add_ops() = *fc->Proto(); + + ASSERT_EQ(block_->ops_size(), 2); + + LOG(INFO) << "create tensorrt desc"; + framework::OpDesc engine_op_desc(nullptr); + engine_op_desc.SetType("tensorrt_engine"); + engine_op_desc.SetInput("Xs", std::vector({"x", "y", "y0"})); + engine_op_desc.SetOutput("Ys", std::vector({"z0"})); + SetAttr(engine_op_desc.Proto(), "subgraph", + block_->SerializeAsString()); + SetAttr(engine_op_desc.Proto(), "max_batch", 30); + SetAttr(engine_op_desc.Proto(), "max_workspace", 1 << 10); + + LOG(INFO) << "create engine op"; + auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto()); + + framework::Scope scope; + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + // Prepare variables. + CreateCPUTensor(&scope, "x", std::vector({2, 4})); + CreateCPUTensor(&scope, "y", std::vector({4, 6})); + CreateCPUTensor(&scope, "z", std::vector({2, 6})); + + CreateCPUTensor(&scope, "y0", std::vector({6, 8})); + CreateCPUTensor(&scope, "z0", std::vector({2, 8})); + + // Execute them. + LOG(INFO) << "engine_op run"; + engine_op->Run(scope, place); +} + +} // namespace operators +} // namespace paddle + +USE_TRT_CONVERTER(mul) +USE_TRT_CONVERTER(fc)