diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index 4b5866ad5dd5c2e33830b61e575ab92954cecb39..288789d6e484100820c937e6081701f1e9245706 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,4 +1,4 @@
-if(WITH_TESTING)
-  nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
-  nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
-endif()
+nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
+nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
+set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
+add_subdirectory(convert)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..572e29515f89685c528023cf6569ef8e559d6874
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -0,0 +1,3 @@
+nv_test(test_tensorrt_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
+nv_test(test_tensorrt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
+        DEPS ${FLUID_CORE_MODULES} activation_op)
diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..543784289cfc51048057e467d36fdd1f334eb903
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
@@ -0,0 +1,43 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class ReluOpConverter : public OpConverter {
+ public:
+  ReluOpConverter() {}
+  void operator()(const framework::OpDesc& op) override {
+    LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
+                 "type is Relu";
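+    // Fetch the ITensor registered under the op's input name, put there by
+    // DeclareInput() or an upstream converter; addActivation() takes a
+    // mutable ITensor&, hence the const_cast below.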
+    const nvinfer1::ITensor* input_tensor =
+        engine_->GetITensor(op.Input("X")[0]);
+    nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
+        engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
+        nvinfer1::ActivationType::kRELU);
+    engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0));
+  }
+};
+
+REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..431500b90e144e2a30fe705b72e93452f806ca65
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
@@ -0,0 +1,34 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class Conv2dOpConverter : public OpConverter {
+ public:
+  Conv2dOpConverter() {}
+  void operator()(const framework::OpDesc& op) override {
+    LOG(INFO)
+        << "convert a fluid conv2d op to tensorrt conv layer without bias";
+  }
+};
+
+REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f9834ab156c9dcc11f4e89075b7bf5457cf00268
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
@@ -0,0 +1,33 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class MulOpConverter : public OpConverter {
+ public:
+  MulOpConverter() {}
+  void operator()(const framework::OpDesc& op) override {
+    LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
+  }
+};
+
+REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
new file mode 100644
index 0000000000000000000000000000000000000000..f8ca219bb83c6db3ba640464233c6cbd337926a6
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -0,0 +1,95 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
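+// Declares a struct whose constructor registers Converter__ for the fluid op
+// type op_type__, plus a global instance of that struct, so registration
+// happens during static initialization. Example use:
+//   REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);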
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+#include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/inference/tensorrt/engine.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * Convert Op from Fluid to TensorRT Engine.
+ */
+class OpConverter {
+ public:
+  OpConverter() {}
+  virtual void operator()(const framework::OpDesc& op) {}
+
+  void Execute(const framework::OpDesc& op, TensorRTEngine* engine) {
+    std::string type = op.Type();
+    auto it = converters_.find(type);
+    PADDLE_ENFORCE(it != converters_.end(), "no OpConverter for optype [%s]",
+                   type);
+    it->second->SetEngine(engine);
+    (*it->second)(op);
+  }
+
+  static OpConverter& Global() {
+    static auto* x = new OpConverter;
+    return *x;
+  }
+
+  template <typename T>
+  void Register(const std::string& key) {
+    converters_[key] = new T;
+  }
+
+  // convert fluid op to tensorrt layer
+  void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
+    OpConverter::Global().Execute(op, engine);
+  }
+
+  // convert fluid block to tensorrt network
+  void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
+    for (auto op : block.AllOps()) {
+      OpConverter::Global().Execute(*op, engine);
+    }
+  }
+
+  void SetEngine(TensorRTEngine* engine) { engine_ = engine; }
+
+  virtual ~OpConverter() {}
+
+  // TensorRT engine
+  TensorRTEngine* engine_{nullptr};
+
+ private:
+  // registered op converter map, whose key is the fluid op type, and value is
+  // the pointer to the corresponding OpConverter instance.
+  std::unordered_map<std::string, OpConverter*> converters_;
+  // fluid inference scope
+  framework::Scope* scope_{nullptr};
+};
+
+#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__)      \
+  struct trt_##op_type__##_converter {                         \
+    trt_##op_type__##_converter() {                            \
+      OpConverter::Global().Register<Converter__>(#op_type__); \
+    }                                                          \
+  };                                                           \
+  trt_##op_type__##_converter trt_##op_type__##_converter__;
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0f390bee1f5967547ea21e6d42a1b4300de58d04
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
@@ -0,0 +1,96 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
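+  // ConvertOp dispatches to the registered ReluOpConverter, which reads "X"
+  // from the engine's ITensor map and appends an activation layer.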
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/place.h"
+
+USE_OP(relu);
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+void compare(float input, float expect) {
+  framework::Scope scope;
+  platform::CUDAPlace place;
+  platform::CUDADeviceContext ctx(place);
+
+  // init fluid op and variable
+  auto x_var = scope.Var("X");
+  auto x_tensor = x_var->GetMutable<framework::LoDTensor>();
+  x_tensor->Resize({1, 1});
+  std::vector<float> init;
+  init.push_back(input);
+  framework::TensorFromVector(init, ctx, x_tensor);
+
+  auto out_var = scope.Var("Out");
+  auto out_tensor = out_var->GetMutable<framework::LoDTensor>();
+  out_tensor->Resize({1, 1});
+  out_tensor->mutable_data<float>(place);
+
+  framework::OpDesc op_desc;
+  op_desc.SetType("relu");
+  op_desc.SetInput("X", {"X"});
+  op_desc.SetOutput("Out", {"Out"});
+
+  auto relu_op = framework::OpRegistry::CreateOp(op_desc);
+
+  // run fluid op
+  relu_op->Run(scope, place);
+  std::vector<float> out1;
+  framework::TensorToVector(*out_tensor, ctx, &out1);
+
+  // init tensorrt op
+  cudaStream_t stream;
+  ASSERT_EQ(0, cudaStreamCreate(&stream));
+  TensorRTEngine* engine = new TensorRTEngine(1, 1 << 10, &stream);
+  engine->InitNetwork();
+  engine->DeclareInput("X", nvinfer1::DataType::kFLOAT,
+                       nvinfer1::DimsCHW{1, 1, 1});
+
+  OpConverter op_converter;
+  op_converter.ConvertOp(op_desc, engine);
+
+  engine->DeclareOutput("Out");
+  engine->FreezeNetwork();
+  engine->SetInputFromCPU("X", &input, 1 * sizeof(float));
+
+  // run tensorrt op
+  engine->Execute(1);
+
+  float out2;
+  engine->GetOutputInCPU("Out", &out2, 1 * sizeof(float));
+
+  ASSERT_EQ(out1[0], out2);
+  ASSERT_EQ(out1[0], expect);
+
+  delete engine;
+  cudaStreamDestroy(stream);
+}
+
+TEST(OpConverter, ConvertRelu) {
+  compare(1, 1);   // relu(1) = 1
+  compare(-5, 0);  // relu(-5) = 0
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5c5ac10394d3cb76ac4b8409c7fa79eb5e9a90b2
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
@@ -0,0 +1,39 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
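+  // The mul and conv2d converters only LOG for now, so a null engine is
+  // safe here; a converter that emits layers would dereference it.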
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(BlockConverter, ConvertBlock) {
+  framework::ProgramDesc prog;
+  auto* block = prog.MutableBlock(0);
+  auto* mul_op = block->AppendOp();
+  mul_op->SetType("mul");
+  auto* conv2d_op = block->AppendOp();
+  conv2d_op->SetType("conv2d");
+
+  OpConverter converter;
+  converter.ConvertBlock(*block, nullptr /*TensorRTEngine*/);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 03a25f8e8b52b452cb621aeddf3563cc21a033df..df123a59079acc5f549e733b412ab302aa397a92 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -80,8 +80,8 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
   PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
   auto* input = infer_network_->addInput(name.c_str(), dtype, dim);
   PADDLE_ENFORCE(input, "infer network add input %s failed", name);
-
   buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim);
+  TensorRTEngine::SetITensor(name, input);
   return input;
 }
 
@@ -99,6 +99,21 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
   buffer_sizes_[name] = 0;
 }
 
+void TensorRTEngine::DeclareOutput(const std::string& name) {
+  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
+                    name);
+
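+  // The named ITensor must already live in itensor_map_, i.e. some converter
+  // has produced it via SetITensor(); GetITensor() enforces this.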
+  auto* output = TensorRTEngine::GetITensor(name);
+  PADDLE_ENFORCE(output != nullptr);
+  output->setName(name.c_str());
+  infer_network_->markOutput(*output);
+  // output buffers' size can only be decided later; set zero here to mark it,
+  // and it will be reset later.
+  buffer_sizes_[name] = 0;
+}
+
 void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
   return buffer(name);
 }
 
@@ -110,7 +125,6 @@ void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
   PADDLE_ENFORCE(it != buffer_sizes_.end());
   PADDLE_ENFORCE_GT(it->second, 0);
   PADDLE_ENFORCE_GE(max_size, it->second);
-
   PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second,
                                        cudaMemcpyDeviceToHost, *stream_));
 }
@@ -126,10 +140,23 @@ void*& TensorRTEngine::buffer(const std::string& name) {
 void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
                                      size_t size) {
   void* buf = buffer(name);
   PADDLE_ENFORCE_EQ(
       0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_));
 }
 
+void TensorRTEngine::SetITensor(const std::string& name,
+                                nvinfer1::ITensor* tensor) {
+  PADDLE_ENFORCE(tensor != nullptr);
+  PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate itensor name %s",
+                    name);
+  itensor_map_[name] = tensor;
+}
+
+nvinfer1::ITensor* TensorRTEngine::GetITensor(const std::string& name) {
+  PADDLE_ENFORCE(itensor_map_.count(name), "no itensor %s", name);
+  return itensor_map_[name];
+}
+
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index d6d4c2f8a2ced08ed8481e92e131f6e2bed9ec05..ec919b943d3281dd675b15e2f14adb7b3487f46f 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -80,6 +80,8 @@ class TensorRTEngine : public EngineBase {
   // name.
   void DeclareOutput(const nvinfer1::ILayer* layer, int offset,
                      const std::string& name);
+  // Set itensor_map_[name] as the network's output, and set its name.
+  void DeclareOutput(const std::string& name);
 
   // GPU memory address for an ITensor with specific name. One can operate on
   // these memory directly for acceleration, for example, output the converted
@@ -98,6 +100,10 @@ class TensorRTEngine : public EngineBase {
   // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
   // to CPU.
   void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
+  // Register an ITensor into the map itensor_map_.
+  void SetITensor(const std::string& name, nvinfer1::ITensor* tensor);
+  // Get the ITensor registered under name.
+  nvinfer1::ITensor* GetITensor(const std::string& name);
 
   nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
   nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
@@ -113,6 +119,8 @@
   std::vector<void*> buffers_;
   // max data size for the buffers.
   std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
+  std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
+      itensor_map_;
 
   // TensorRT related internal members
   template <typename T>
diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc
index c6e1c71cdc8eea48ab6197057d1d389f03c9cf54..a08b78f930d30d674247a713fadd3e42e5ada350 100644
--- a/paddle/fluid/inference/tensorrt/test_engine.cc
+++ b/paddle/fluid/inference/tensorrt/test_engine.cc
@@ -70,7 +70,6 @@ TEST_F(TensorRTEngineTest, add_layer) {
   engine_->Execute(1);
 
   LOG(INFO) << "to get output";
-  // void* y_v =
   float y_cpu;
   engine_->GetOutputInCPU("y", &y_cpu, sizeof(float));