From 89dcb0bd151313c758e539e5c90aa9b0cb53d27a Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 8 May 2018 13:32:54 +0800 Subject: [PATCH] refine EngineIOConverter, and use io_convert in test_trt_activation_op --- .../inference/tensorrt/convert/CMakeLists.txt | 2 +- .../tensorrt/convert/io_converter.cc | 42 +++++++++---- .../inference/tensorrt/convert/io_converter.h | 53 ++++++++++------ .../tensorrt/convert/test_activation_op.cc | 39 +++++++----- .../tensorrt/convert/test_io_converter.cc | 63 +++++++++++++------ paddle/fluid/inference/tensorrt/engine.cc | 1 - 6 files changed, 131 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 3c5909c0be1..bf494d921a1 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,4 +1,4 @@ nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES}) -nv_test(test_trt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc +nv_test(test_trt_activation_op SRCS test_activation_op.cc io_converter.cc ${ENGINE_FILE} activation_op.cc DEPS ${FLUID_CORE_MODULES} activation_op) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.cc b/paddle/fluid/inference/tensorrt/convert/io_converter.cc index 32e8631fde3..13bc2b37595 100644 --- a/paddle/fluid/inference/tensorrt/convert/io_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/io_converter.cc @@ -23,26 +23,42 @@ namespace tensorrt { using platform::is_gpu_place; using platform::is_cpu_place; -class DefaultInputConverter : public EngineInputConverter { +class DefaultIOConverter : public EngineIOConverter { public: - DefaultInputConverter() {} + DefaultIOConverter() {} // NOTE out is GPU memory. virtual void operator()(const LoDTensor& in, void* out, size_t max_size) override { PADDLE_ENFORCE(out != nullptr); - PADDLE_ENFORCE_LE(in.memory_size(), max_size); + PADDLE_ENFORCE(stream_ != nullptr); const auto& place = in.place(); + size_t size = in.memory_size(); + PADDLE_ENFORCE_LE(size, max_size); if (is_cpu_place(place)) { - PADDLE_ENFORCE(stream_ != nullptr); - PADDLE_ENFORCE_EQ(0, - cudaMemcpyAsync(out, in.data(), in.memory_size(), - cudaMemcpyHostToDevice, *stream_)); - + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size, + cudaMemcpyHostToDevice, *stream_)); } else if (is_gpu_place(place)) { - PADDLE_ENFORCE_EQ(0, - cudaMemcpyAsync(out, in.data(), in.memory_size(), - cudaMemcpyHostToHost, *stream_)); - + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size, + cudaMemcpyHostToHost, *stream_)); + } else { + PADDLE_THROW("Unknown device for converter"); + } + cudaStreamSynchronize(*stream_); + } + // NOTE in is GPU memory. + virtual void operator()(const void* in, LoDTensor* out, + size_t max_size) override { + PADDLE_ENFORCE(in != nullptr); + PADDLE_ENFORCE(stream_ != nullptr); + const auto& place = out->place(); + size_t size = out->memory_size(); + PADDLE_ENFORCE_LE(size, max_size); + if (is_cpu_place(place)) { + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size, + cudaMemcpyDeviceToHost, *stream_)); + } else if (is_gpu_place(place)) { + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size, + cudaMemcpyHostToHost, *stream_)); } else { PADDLE_THROW("Unknown device for converter"); } @@ -50,7 +66,7 @@ class DefaultInputConverter : public EngineInputConverter { } }; -REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter); +REGISTER_TENSORRT_IO_CONVERTER(default, DefaultIOConverter); } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.h b/paddle/fluid/inference/tensorrt/convert/io_converter.h index 8972dae92be..71c48e085d2 100644 --- a/paddle/fluid/inference/tensorrt/convert/io_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/io_converter.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/inference/utils/singleton.h" @@ -25,43 +26,57 @@ namespace tensorrt { using framework::LoDTensor; /* - * Convert Input from Fluid to an Engine. - * TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in - * most cases just need to copy the data. + * Convert Input from Fluid to TensorRT Engine. + * Convert Output from TensorRT Engine to Fluid. + * + * Note that TensorRT's ITensor follows row major, NCHW. Fluid is also row + * major, + * so in the default case just need to copy the data. */ -class EngineInputConverter { +class EngineIOConverter { public: - EngineInputConverter() {} + EngineIOConverter() {} virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {} + virtual void operator()(const void* in, LoDTensor* out, size_t max_size) {} void SetStream(cudaStream_t* stream) { stream_ = stream; } - static void Run(const std::string& in_op_type, const LoDTensor& in, void* out, - size_t max_size, cudaStream_t* stream) { + static void ConvertInput(const std::string& op_type, const LoDTensor& in, + void* out, size_t max_size, cudaStream_t* stream) { PADDLE_ENFORCE(stream != nullptr); - auto* converter = Registry::Lookup( - in_op_type, "default" /* default_type */); + auto* converter = Registry::Lookup( + op_type, "default" /* default_type */); PADDLE_ENFORCE_NOT_NULL(converter); converter->SetStream(stream); (*converter)(in, out, max_size); } - virtual ~EngineInputConverter() {} + static void ConvertOutput(const std::string& op_type, const void* in, + LoDTensor* out, size_t max_size, + cudaStream_t* stream) { + PADDLE_ENFORCE(stream != nullptr); + auto* converter = Registry::Lookup( + op_type, "default" /* default_type */); + PADDLE_ENFORCE_NOT_NULL(converter); + converter->SetStream(stream); + (*converter)(in, out, max_size); + } + + virtual ~EngineIOConverter() {} protected: cudaStream_t* stream_{nullptr}; }; +#define REGISTER_TENSORRT_IO_CONVERTER(op_type__, Converter__) \ + struct trt_io_##op_type__##_converter { \ + trt_io_##op_type__##_converter() { \ + Registry::Register(#op_type__); \ + } \ + }; \ + trt_io_##op_type__##_converter trt_io_##op_type__##_converter__; + } // namespace tensorrt } // namespace inference } // namespace paddle - -#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \ - struct trt_input_##in_op_type__##_converter { \ - trt_input_##in_op_type__##_converter() { \ - ::paddle::inference::Registry::Register< \ - Converter__>(#in_op_type__); \ - } \ - }; \ - trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__; diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc index 23e3435c217..c43f7202127 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/inference/tensorrt/convert/io_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" @@ -26,7 +27,7 @@ namespace paddle { namespace inference { namespace tensorrt { -void Compare(float input, float expect) { +void Compare(const std::string op_type, float input, float expect) { framework::Scope scope; platform::CUDAPlace place; platform::CUDADeviceContext ctx(place); @@ -35,6 +36,7 @@ void Compare(float input, float expect) { auto x_var = scope.Var("X"); auto x_tensor = x_var->GetMutable(); x_tensor->Resize({1, 1}); + x_tensor->mutable_data(place); std::vector init; init.push_back(input); framework::TensorFromVector(init, ctx, x_tensor); @@ -45,14 +47,15 @@ void Compare(float input, float expect) { out_tensor->mutable_data(place); framework::OpDesc op_desc; - op_desc.SetType("relu"); + op_desc.SetType(op_type); op_desc.SetInput("X", {"X"}); op_desc.SetOutput("Out", {"Out"}); - auto relu_op = framework::OpRegistry::CreateOp(op_desc); + auto op = framework::OpRegistry::CreateOp(op_desc); // run fluid op - relu_op->Run(scope, place); + op->Run(scope, place); + // get fluid output std::vector out1; framework::TensorToVector(*out_tensor, ctx, &out1); @@ -63,21 +66,27 @@ void Compare(float input, float expect) { engine->InitNetwork(); engine->DeclareInput("X", nvinfer1::DataType::kFLOAT, nvinfer1::DimsCHW{1, 1, 1}); - + // convert op OpConverter op_converter; op_converter.ConvertOp(op_desc, engine); - engine->DeclareOutput("Out"); engine->FreezeNetwork(); - engine->SetInputFromCPU("X", &input, 1 * sizeof(float)); - // run tensorrt op + // convert LoDTensor to ITensor + size_t size = x_tensor->memory_size(); + EngineIOConverter::ConvertInput(op_type, *x_tensor, engine->buffer("X"), size, + &stream); + // run tensorrt Outp engine->Execute(1); - - float out2; - engine->GetOutputInCPU("Out", &out2, 1 * sizeof(float)); - - ASSERT_EQ(out1[0], out2); + // convert ITensor to LoDTensor + EngineIOConverter::ConvertOutput(op_type, engine->buffer("Out"), out_tensor, + size, &stream); + // get tensorrt output + std::vector out2; + framework::TensorToVector(*out_tensor, ctx, &out2); + + // compare + ASSERT_EQ(out1[0], out2[0]); ASSERT_EQ(out1[0], expect); delete engine; @@ -85,8 +94,8 @@ void Compare(float input, float expect) { } TEST(OpConverter, ConvertRelu) { - Compare(1, 1); // relu(1) = 1 - Compare(-5, 0); // relu(-5) = 0 + Compare("relu", 1, 1); // relu(1) = 1 + Compare("relu", -5, 0); // relu(-5) = 0 } } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc index afcc516e6b7..8f91309a0a0 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc @@ -12,40 +12,63 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/inference/tensorrt/convert/io_converter.h" -#include - namespace paddle { namespace inference { namespace tensorrt { -class EngineInputConverterTester : public ::testing::Test { - public: - void SetUp() override { tensor.Resize({10, 10}); } +void IOConverterTester(const platform::DeviceContext& ctx) { + cudaStream_t stream; + ASSERT_EQ(0, cudaStreamCreate(&stream)); - framework::LoDTensor tensor; -}; + // init fluid in_tensor + framework::LoDTensor in_tensor; + in_tensor.Resize({10, 10}); + auto place = ctx.GetPlace(); + in_tensor.mutable_data(place); + std::vector init; + for (int64_t i = 0; i < 10 * 10; ++i) { + init.push_back(i); + } + framework::TensorFromVector(init, ctx, &in_tensor); -TEST_F(EngineInputConverterTester, DefaultCPU) { + // init tensorrt buffer void* buffer; - tensor.mutable_data(platform::CPUPlace()); - ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0); + size_t size = in_tensor.memory_size(); + ASSERT_EQ(cudaMalloc(&buffer, size), 0); - cudaStream_t stream; - EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(), - &stream); + // convert fluid in_tensor to tensorrt buffer + EngineIOConverter::ConvertInput("test", in_tensor, buffer, size, &stream); + + // convert tensorrt buffer to fluid out_tensor + framework::LoDTensor out_tensor; + out_tensor.Resize({10, 10}); + out_tensor.mutable_data(place); + EngineIOConverter::ConvertOutput("test", buffer, &out_tensor, size, &stream); + + // compare in_tensor and out_tensor + std::vector result; + framework::TensorToVector(out_tensor, ctx, &result); + EXPECT_EQ(init.size(), result.size()); + for (size_t i = 0; i < init.size(); i++) { + EXPECT_EQ(init[i], result[i]); + } + cudaStreamDestroy(stream); } -TEST_F(EngineInputConverterTester, DefaultGPU) { - void* buffer; - tensor.mutable_data(platform::CUDAPlace()); - ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0); +TEST(EngineIOConverterTester, DefaultCPU) { + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + IOConverterTester(ctx); +} - cudaStream_t stream; - EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(), - &stream); +TEST(EngineIOConverterTester, DefaultGPU) { + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); + IOConverterTester(ctx); } } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index df123a59079..0a69ab9bdde 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -138,7 +138,6 @@ void*& TensorRTEngine::buffer(const std::string& name) { void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data, size_t size) { void* buf = buffer(name); - cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_); PADDLE_ENFORCE_EQ( 0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_)); } -- GitLab