diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index cd51fd609c02b483bc289c513584855aed6a3f0b..c4b8514c1c966f64bd5abfa061989034caeb128a 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,2 +1,2 @@ -nv_library(tensorrt_convert SRCS convert.cc DEPS dynload_cuda) +nv_library(tensorrt_convert SRCS convert.cc mul_op.cc conv2d_op.cc DEPS dynload_cuda) nv_test(test_tensorrt_convert SRCS test_convert.cc DEPS tensorrt paddle_fluid) diff --git a/paddle/fluid/inference/tensorrt/convert/convert_conv2d.h b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc similarity index 87% rename from paddle/fluid/inference/tensorrt/convert/convert_conv2d.h rename to paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 34622f92a492b725e65a3e5121de6c47a6aa3a91..1201a7696ae424197610b2ed1c7299d91c05d689 100644 --- a/paddle/fluid/inference/tensorrt/convert/convert_conv2d.h +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -12,25 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once #include "paddle/fluid/inference/tensorrt/convert/convert.h" namespace paddle { namespace inference { namespace tensorrt { -class Conv2dOpConverter : public OpConverter { - public: - Conv2dOpConverter() {} - void Convert(const framework::OpDesc& op); -}; +REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter); void Conv2dOpConverter::Convert(const framework::OpDesc& op) { LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias"; } -REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter); - } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/convert/convert.cc b/paddle/fluid/inference/tensorrt/convert/convert.cc index bf6f1cd2c1c943128ff0c84570bda2b0d638b6e2..78a72b1a8baea3caa3836db03179745e8d52bf63 100644 --- a/paddle/fluid/inference/tensorrt/convert/convert.cc +++ b/paddle/fluid/inference/tensorrt/convert/convert.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/convert.h" -#include "paddle/fluid/inference/tensorrt/convert/convert_conv2d.h" -#include "paddle/fluid/inference/tensorrt/convert/convert_mul.h" namespace paddle { namespace inference { @@ -23,10 +21,8 @@ namespace tensorrt { void TensorRTConverter::ConvertBlock(const framework::BlockDesc& block) { for (auto op : block.AllOps()) { std::string type = op->Type(); - PADDLE_ENFORCE(GetOpConverter().count(type), - "No converter registered for op: %s", type); - auto op_converter = GetOpConverter()[type]; - op_converter->Convert(*op); + OpConverter op_converter; + op_converter.Convert(*op); } } diff --git a/paddle/fluid/inference/tensorrt/convert/convert.h b/paddle/fluid/inference/tensorrt/convert/convert.h index 4f9523305737b4ac70a839b64657b4af032beefd..953086ace962a7ede0c785c322b90f9f1bd9c0eb 100644 --- a/paddle/fluid/inference/tensorrt/convert/convert.h +++ b/paddle/fluid/inference/tensorrt/convert/convert.h @@ -26,9 +26,21 @@ namespace paddle { namespace inference { namespace tensorrt { -class ConverterBase { +class OpConverter { public: - ConverterBase() {} + OpConverter() {} + + void Convert(const framework::OpDesc& op) { + std::string type = op.Type(); + OpConverter& op_converter = this->register_op_converter_[type]; + op_converter.Convert(op); + } + + template + static void Register(const std::string key) { + register_op_converter_[key] = T(); + } + static std::unordered_map register_op_converter_; // fluid inference scope framework::Scope* scope_; @@ -37,30 +49,14 @@ class ConverterBase { std::unordered_map tr_tensors_; }; -class OpConverter : public ConverterBase { - public: - OpConverter() {} - virtual ~OpConverter() {} - - // convert fluid op to tensorrt layer - virtual void Convert(const framework::OpDesc& op) = 0; -}; - -static std::unordered_map& GetOpConverter() { - static std::unordered_map register_op_converter; - return register_op_converter; -} - -#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \ - class convert_class##Register { \ - public: \ - convert_class##Register() { \ - GetOpConverter()[#op_type] = new convert_class; \ - } \ - }; \ - convert_class##Register convert_class##reg; +#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \ + class convert_class : public OpConverter { \ + public: \ + convert_class() { OpConverter::Register(#op_type); } \ + void Convert(const framework::OpDesc& op); \ + } -class TensorRTConverter : public ConverterBase { +class TensorRTConverter { public: TensorRTConverter() {} diff --git a/paddle/fluid/inference/tensorrt/convert/convert_mul.h b/paddle/fluid/inference/tensorrt/convert/mul_op.cc similarity index 87% rename from paddle/fluid/inference/tensorrt/convert/convert_mul.h rename to paddle/fluid/inference/tensorrt/convert/mul_op.cc index a626300cf32e440ec77cd193c67be8174802c8d5..0ce5eb73024466777488dd9576419fe67f9843ea 100644 --- a/paddle/fluid/inference/tensorrt/convert/convert_mul.h +++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc @@ -12,20 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once #include "paddle/fluid/inference/tensorrt/convert/convert.h" namespace paddle { namespace inference { namespace tensorrt { -class MulOpConverter : public OpConverter { - public: - MulOpConverter() {} - void Convert(const framework::OpDesc& op); -}; - REGISTER_TRT_OP_CONVETER(mul, MulOpConverter); + void MulOpConverter::Convert(const framework::OpDesc& op) { LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias"; }