From c4e3010b14cfbc3847466843ee58e49792e31b27 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Wed, 25 Apr 2018 22:30:57 +0800 Subject: [PATCH] use template to do registry --- .../inference/tensorrt/convert/CMakeLists.txt | 2 +- .../{convert_conv2d.h => conv2d_op.cc} | 9 +--- .../inference/tensorrt/convert/convert.cc | 8 +--- .../inference/tensorrt/convert/convert.h | 46 +++++++++---------- .../convert/{convert_mul.h => mul_op.cc} | 8 +--- 5 files changed, 26 insertions(+), 47 deletions(-) rename paddle/fluid/inference/tensorrt/convert/{convert_conv2d.h => conv2d_op.cc} (87%) rename paddle/fluid/inference/tensorrt/convert/{convert_mul.h => mul_op.cc} (87%) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index cd51fd609c0..c4b8514c1c9 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,2 +1,2 @@ -nv_library(tensorrt_convert SRCS convert.cc DEPS dynload_cuda) +nv_library(tensorrt_convert SRCS convert.cc mul_op.cc conv2d_op.cc DEPS dynload_cuda) nv_test(test_tensorrt_convert SRCS test_convert.cc DEPS tensorrt paddle_fluid) diff --git a/paddle/fluid/inference/tensorrt/convert/convert_conv2d.h b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc similarity index 87% rename from paddle/fluid/inference/tensorrt/convert/convert_conv2d.h rename to paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 34622f92a49..1201a7696ae 100644 --- a/paddle/fluid/inference/tensorrt/convert/convert_conv2d.h +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -12,25 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once #include "paddle/fluid/inference/tensorrt/convert/convert.h" namespace paddle { namespace inference { namespace tensorrt { -class Conv2dOpConverter : public OpConverter { - public: - Conv2dOpConverter() {} - void Convert(const framework::OpDesc& op); -}; +REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter); void Conv2dOpConverter::Convert(const framework::OpDesc& op) { LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias"; } -REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter); - } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/convert/convert.cc b/paddle/fluid/inference/tensorrt/convert/convert.cc index bf6f1cd2c1c..78a72b1a8ba 100644 --- a/paddle/fluid/inference/tensorrt/convert/convert.cc +++ b/paddle/fluid/inference/tensorrt/convert/convert.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/convert.h" -#include "paddle/fluid/inference/tensorrt/convert/convert_conv2d.h" -#include "paddle/fluid/inference/tensorrt/convert/convert_mul.h" namespace paddle { namespace inference { @@ -23,10 +21,8 @@ namespace tensorrt { void TensorRTConverter::ConvertBlock(const framework::BlockDesc& block) { for (auto op : block.AllOps()) { std::string type = op->Type(); - PADDLE_ENFORCE(GetOpConverter().count(type), - "No converter registered for op: %s", type); - auto op_converter = GetOpConverter()[type]; - op_converter->Convert(*op); + OpConverter op_converter; + op_converter.Convert(*op); } } diff --git a/paddle/fluid/inference/tensorrt/convert/convert.h b/paddle/fluid/inference/tensorrt/convert/convert.h index 4f952330573..953086ace96 100644 --- a/paddle/fluid/inference/tensorrt/convert/convert.h +++ b/paddle/fluid/inference/tensorrt/convert/convert.h @@ -26,9 +26,21 @@ namespace paddle { namespace inference { namespace tensorrt { -class ConverterBase { +class OpConverter { public: - ConverterBase() {} + OpConverter() {} + + void Convert(const framework::OpDesc& op) { + std::string type = op.Type(); + OpConverter& op_converter = this->register_op_converter_[type]; + op_converter.Convert(op); + } + + template + static void Register(const std::string key) { + register_op_converter_[key] = T(); + } + static std::unordered_map register_op_converter_; // fluid inference scope framework::Scope* scope_; @@ -37,30 +49,14 @@ class ConverterBase { std::unordered_map tr_tensors_; }; -class OpConverter : public ConverterBase { - public: - OpConverter() {} - virtual ~OpConverter() {} - - // convert fluid op to tensorrt layer - virtual void Convert(const framework::OpDesc& op) = 0; -}; - -static std::unordered_map& GetOpConverter() { - static std::unordered_map register_op_converter; - return register_op_converter; -} - -#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \ - class convert_class##Register { \ - public: \ - convert_class##Register() { \ - GetOpConverter()[#op_type] = new convert_class; \ - } \ - }; \ - convert_class##Register convert_class##reg; +#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \ + class convert_class : public OpConverter { \ + public: \ + convert_class() { OpConverter::Register(#op_type); } \ + void Convert(const framework::OpDesc& op); \ + } -class TensorRTConverter : public ConverterBase { +class TensorRTConverter { public: TensorRTConverter() {} diff --git a/paddle/fluid/inference/tensorrt/convert/convert_mul.h b/paddle/fluid/inference/tensorrt/convert/mul_op.cc similarity index 87% rename from paddle/fluid/inference/tensorrt/convert/convert_mul.h rename to paddle/fluid/inference/tensorrt/convert/mul_op.cc index a626300cf32..0ce5eb73024 100644 --- a/paddle/fluid/inference/tensorrt/convert/convert_mul.h +++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc @@ -12,20 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once #include "paddle/fluid/inference/tensorrt/convert/convert.h" namespace paddle { namespace inference { namespace tensorrt { -class MulOpConverter : public OpConverter { - public: - MulOpConverter() {} - void Convert(const framework::OpDesc& op); -}; - REGISTER_TRT_OP_CONVETER(mul, MulOpConverter); + void MulOpConverter::Convert(const framework::OpDesc& op) { LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias"; } -- GitLab