Commit c4e3010b authored by Luo Tao

use template to do registry

Parent d599de5c
-nv_library(tensorrt_convert SRCS convert.cc DEPS dynload_cuda)
+nv_library(tensorrt_convert SRCS convert.cc mul_op.cc conv2d_op.cc DEPS dynload_cuda)
 nv_test(test_tensorrt_convert SRCS test_convert.cc DEPS tensorrt paddle_fluid)
@@ -12,25 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#pragma once
 #include "paddle/fluid/inference/tensorrt/convert/convert.h"
 
 namespace paddle {
 namespace inference {
 namespace tensorrt {
 
-class Conv2dOpConverter : public OpConverter {
- public:
-  Conv2dOpConverter() {}
-  void Convert(const framework::OpDesc& op);
-};
+REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter);
 
 void Conv2dOpConverter::Convert(const framework::OpDesc& op) {
   LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias";
 }
 
-REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter);
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/inference/tensorrt/convert/convert.h"
-#include "paddle/fluid/inference/tensorrt/convert/convert_conv2d.h"
-#include "paddle/fluid/inference/tensorrt/convert/convert_mul.h"
 
 namespace paddle {
 namespace inference {
@@ -23,10 +21,8 @@ namespace tensorrt {
 void TensorRTConverter::ConvertBlock(const framework::BlockDesc& block) {
   for (auto op : block.AllOps()) {
     std::string type = op->Type();
-    PADDLE_ENFORCE(GetOpConverter().count(type),
-                   "No converter registered for op: %s", type);
-    auto op_converter = GetOpConverter()[type];
-    op_converter->Convert(*op);
+    OpConverter op_converter;
+    op_converter.Convert(*op);
   }
 }
...
@@ -26,9 +26,21 @@ namespace paddle {
 namespace inference {
 namespace tensorrt {
 
-class ConverterBase {
+class OpConverter {
  public:
-  ConverterBase() {}
+  OpConverter() {}
+
+  void Convert(const framework::OpDesc& op) {
+    std::string type = op.Type();
+    OpConverter& op_converter = this->register_op_converter_[type];
+    op_converter.Convert(op);
+  }
+
+  template <typename T>
+  static void Register(const std::string key) {
+    register_op_converter_[key] = T();
+  }
+
+  static std::unordered_map<std::string, OpConverter> register_op_converter_;
 
   // fluid inference scope
   framework::Scope* scope_;
@@ -37,30 +49,14 @@ class ConverterBase {
   std::unordered_map<std::string, nvinfer1::ITensor*> tr_tensors_;
 };
 
-class OpConverter : public ConverterBase {
- public:
-  OpConverter() {}
-  virtual ~OpConverter() {}
-
-  // convert fluid op to tensorrt layer
-  virtual void Convert(const framework::OpDesc& op) = 0;
-};
-
-static std::unordered_map<std::string, OpConverter*>& GetOpConverter() {
-  static std::unordered_map<std::string, OpConverter*> register_op_converter;
-  return register_op_converter;
-}
-
-#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \
-  class convert_class##Register {                        \
-   public:                                               \
-    convert_class##Register() {                          \
-      GetOpConverter()[#op_type] = new convert_class;    \
-    }                                                    \
-  };                                                     \
-  convert_class##Register convert_class##reg;
+#define REGISTER_TRT_OP_CONVETER(op_type, convert_class)                \
+  class convert_class : public OpConverter {                            \
+   public:                                                              \
+    convert_class() { OpConverter::Register<convert_class>(#op_type); } \
+    void Convert(const framework::OpDesc& op);                          \
+  }
 
-class TensorRTConverter : public ConverterBase {
+class TensorRTConverter {
  public:
   TensorRTConverter() {}
...
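The convert.h change above is the core of the commit: the pointer map behind GetOpConverter() and the generated convert_class##Register helper are replaced by a static map on OpConverter, a template Register<T> method, and a macro that declares the converter subclass itself. The following standalone sketch illustrates that string-keyed, template-based registration idea. It is an illustration only, not the committed Paddle code: all names here (BaseConverter, REGISTER_CONVERTER, ReluConverter) are made up, the registry stores factory functions rather than converter objects by value, and registration is triggered by a small static registrar object so the example compiles and dispatches to the derived Convert on its own.

// Standalone sketch of a string-keyed, template-based converter registry.
// None of these names exist in Paddle; they only illustrate the pattern.
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

class BaseConverter {
 public:
  virtual ~BaseConverter() = default;
  virtual void Convert(const std::string& op_desc) = 0;

  // Template registration: store a factory for T under the given op type.
  template <typename T>
  static void Register(const std::string& key) {
    Registry()[key] = [] { return std::unique_ptr<BaseConverter>(new T); };
  }

  // Dispatch: build the converter registered for this op type and run it.
  static void Run(const std::string& op_type, const std::string& op_desc) {
    Registry().at(op_type)()->Convert(op_desc);
  }

 private:
  using Factory = std::function<std::unique_ptr<BaseConverter>()>;
  static std::unordered_map<std::string, Factory>& Registry() {
    static std::unordered_map<std::string, Factory> registry;
    return registry;
  }
};

// The macro declares the converter subclass and registers it at program
// start through one static registrar object per converter.
#define REGISTER_CONVERTER(op_type, convert_class)                        \
  class convert_class : public BaseConverter {                            \
   public:                                                                \
    void Convert(const std::string& op_desc) override;                    \
  };                                                                      \
  static struct convert_class##Registrar {                                \
    convert_class##Registrar() {                                          \
      BaseConverter::Register<convert_class>(#op_type);                   \
    }                                                                     \
  } convert_class##_registrar

REGISTER_CONVERTER(relu, ReluConverter);

void ReluConverter::Convert(const std::string& op_desc) {
  std::cout << "convert " << op_desc << " with ReluConverter\n";
}

int main() {
  BaseConverter::Run("relu", "a fluid relu op");
  return 0;
}

Keeping factories or pointers in the registry preserves the derived type at dispatch time; a map that holds base-class objects by value, as register_op_converter_ in the committed header does, relies on each derived Convert being reachable some other way.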
@@ -12,20 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#pragma once
 #include "paddle/fluid/inference/tensorrt/convert/convert.h"
 
 namespace paddle {
 namespace inference {
 namespace tensorrt {
 
-class MulOpConverter : public OpConverter {
- public:
-  MulOpConverter() {}
-  void Convert(const framework::OpDesc& op);
-};
-
 REGISTER_TRT_OP_CONVETER(mul, MulOpConverter);
 
 void MulOpConverter::Convert(const framework::OpDesc& op) {
   LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
 }
...
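With this macro in place, adding a converter for another op follows the same pattern as mul_op.cc and conv2d_op.cc above: one REGISTER_TRT_OP_CONVETER line plus a Convert definition. A hypothetical relu converter, not part of this commit, might look like the sketch below; its source file would also need to be added to the nv_library SRCS list in the CMake change at the top.

/* Hypothetical relu converter mirroring the pattern of this commit;
   the "relu" op type and this file are illustrative, not part of the diff. */
#include "paddle/fluid/inference/tensorrt/convert/convert.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Declares ReluOpConverter; its constructor registers it under the "relu"
// op type, exactly as conv2d and mul do above.
REGISTER_TRT_OP_CONVETER(relu, ReluOpConverter);

void ReluOpConverter::Convert(const framework::OpDesc& op) {
  LOG(INFO) << "convert a fluid relu op to tensorrt activation layer";
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle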