diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt index 288789d6e484100820c937e6081701f1e9245706..c8b656394b403c4965e01e96c9215d9406091907 100644 --- a/paddle/fluid/inference/tensorrt/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -1,4 +1,5 @@ nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda) +nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc) add_subdirectory(convert) diff --git a/paddle/fluid/inference/tensorrt/io_converter.cc b/paddle/fluid/inference/tensorrt/io_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..2baac96c26453af7e70e541d80b437df3d5c2657 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/io_converter.cc @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/io_converter.h" +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +using platform::is_gpu_place; +using platform::is_cpu_place; + +class DefaultInputConverter : public EngineInputConverter { + public: + DefaultInputConverter() {} + // NOTE out is GPU memory. + virtual void operator()(const LoDTensor& in, void* out, + size_t max_size) override { + PADDLE_ENFORCE(out != nullptr); + PADDLE_ENFORCE_LE(in.memory_size(), max_size); + const auto& place = in.place(); + if (is_cpu_place(place)) { + PADDLE_ENFORCE(stream_ != nullptr); + PADDLE_ENFORCE_EQ(0, + cudaMemcpyAsync(out, in.data(), in.memory_size(), + cudaMemcpyHostToDevice, *stream_)); + + } else if (is_gpu_place(place)) { + PADDLE_ENFORCE_EQ(0, + cudaMemcpyAsync(out, in.data(), in.memory_size(), + cudaMemcpyHostToHost, *stream_)); + + } else { + PADDLE_THROW("Unknown device for converter"); + } + cudaStreamSynchronize(*stream_); + } +}; + +REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter); + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/io_converter.h b/paddle/fluid/inference/tensorrt/io_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..6ea61cbbac05f106f736b7d6a13912157c5ef48c --- /dev/null +++ b/paddle/fluid/inference/tensorrt/io_converter.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/inference/utils/singleton.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +using framework::LoDTensor; + +/* + * Convert Input from Fluid to an Engine. + * TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in + * most cases just need to copy the data. + */ +class EngineInputConverter { + public: + EngineInputConverter() {} + + virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {} + + void SetStream(cudaStream_t* stream) { stream_ = stream; } + + static void Run(const std::string& in_op_type, const LoDTensor& in, void* out, + size_t max_size, cudaStream_t* stream) { + PADDLE_ENFORCE(stream != nullptr); + auto* converter = Registry::Lookup(in_op_type); + PADDLE_ENFORCE_NOT_NULL(converter); + converter->SetStream(stream); + (*converter)(in, out, max_size); + } + + virtual ~EngineInputConverter() {} + + protected: + cudaStream_t* stream_{nullptr}; +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \ + struct trt_input_##in_op_type__##_converter { \ + trt_input_##in_op_type__##_converter() { \ + ::paddle::inference::Registry::Register< \ + Converter__>(#in_op_type__); \ + } \ + }; \ + trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__; diff --git a/paddle/fluid/inference/tensorrt/test_io_converter.cc b/paddle/fluid/inference/tensorrt/test_io_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..365e9366862bee25c70dba0cdd92f318ab3ee90f --- /dev/null +++ b/paddle/fluid/inference/tensorrt/test_io_converter.cc @@ -0,0 +1,53 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/inference/tensorrt/io_converter.h" + +#include + +namespace paddle { +namespace inference { +namespace tensorrt { + +class EngineInputConverterTester : public ::testing::Test { + public: + void SetUp() override { tensor.Resize({10, 10}); } + + framework::LoDTensor tensor; +}; + +TEST_F(EngineInputConverterTester, DefaultCPU) { + void* buffer; + tensor.mutable_data(platform::CPUPlace()); + ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0); + + cudaStream_t stream; + EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(), + &stream); +} + +TEST_F(EngineInputConverterTester, DefaultGPU) { + void* buffer; + tensor.mutable_data(platform::CUDAPlace()); + ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0); + + cudaStream_t stream; + EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(), + &stream); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/utils/singleton.h b/paddle/fluid/inference/utils/singleton.h new file mode 100644 index 0000000000000000000000000000000000000000..f05921067c45f156319375b613f51101cfda8e90 --- /dev/null +++ b/paddle/fluid/inference/utils/singleton.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { + +// NOTE not thread-safe. +template +struct Singleton { + static T& Global() { + static T* x = new T; + return *x; + } + + Singleton() = delete; + Singleton& operator=(const Singleton&) = delete; +}; + +/* + * An registor for any type. + * NOTE not thread-safe. + */ +template +struct Registry { + static Registry& Global() { + static auto* x = new Registry; + return *x; + } + + template + static void Register(const std::string& name) { + PADDLE_ENFORCE_EQ(items_.count(name), 0); + items_[name] = new ItemChild; + } + + static ItemParent* Lookup(const std::string& name) { + auto it = items_.find(name); + if (it == items_.end()) return nullptr; + return it->second; + } + + ~Registry() { + for (auto& item : items_) { + delete item.second; + } + } + + private: + Registry() = default; + static std::unordered_map items_; +}; + +template +std::unordered_map Registry::items_; + +} // namespace inference +} // namespace paddle