提交 c4e3010b 编写于 作者: L Luo Tao

use template to do registry

上级 d599de5c
nv_library(tensorrt_convert SRCS convert.cc DEPS dynload_cuda)
nv_library(tensorrt_convert SRCS convert.cc mul_op.cc conv2d_op.cc DEPS dynload_cuda)
nv_test(test_tensorrt_convert SRCS test_convert.cc DEPS tensorrt paddle_fluid)
......@@ -12,25 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class Conv2dOpConverter : public OpConverter {
public:
Conv2dOpConverter() {}
void Convert(const framework::OpDesc& op);
};
REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter);
void Conv2dOpConverter::Convert(const framework::OpDesc& op) {
LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias";
}
REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
#include "paddle/fluid/inference/tensorrt/convert/convert_conv2d.h"
#include "paddle/fluid/inference/tensorrt/convert/convert_mul.h"
namespace paddle {
namespace inference {
......@@ -23,10 +21,8 @@ namespace tensorrt {
void TensorRTConverter::ConvertBlock(const framework::BlockDesc& block) {
for (auto op : block.AllOps()) {
std::string type = op->Type();
PADDLE_ENFORCE(GetOpConverter().count(type),
"No converter registered for op: %s", type);
auto op_converter = GetOpConverter()[type];
op_converter->Convert(*op);
OpConverter op_converter;
op_converter.Convert(*op);
}
}
......
......@@ -26,9 +26,21 @@ namespace paddle {
namespace inference {
namespace tensorrt {
class ConverterBase {
class OpConverter {
public:
ConverterBase() {}
OpConverter() {}
void Convert(const framework::OpDesc& op) {
std::string type = op.Type();
OpConverter& op_converter = this->register_op_converter_[type];
op_converter.Convert(op);
}
template <typename T>
static void Register(const std::string key) {
register_op_converter_[key] = T();
}
static std::unordered_map<std::string, OpConverter> register_op_converter_;
// fluid inference scope
framework::Scope* scope_;
......@@ -37,30 +49,14 @@ class ConverterBase {
std::unordered_map<std::string, nvinfer1::ITensor*> tr_tensors_;
};
class OpConverter : public ConverterBase {
public:
OpConverter() {}
virtual ~OpConverter() {}
// convert fluid op to tensorrt layer
virtual void Convert(const framework::OpDesc& op) = 0;
};
static std::unordered_map<std::string, OpConverter*>& GetOpConverter() {
static std::unordered_map<std::string, OpConverter*> register_op_converter;
return register_op_converter;
}
#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \
class convert_class##Register { \
public: \
convert_class##Register() { \
GetOpConverter()[#op_type] = new convert_class; \
} \
}; \
convert_class##Register convert_class##reg;
#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \
class convert_class : public OpConverter { \
public: \
convert_class() { OpConverter::Register<convert_class>(#op_type); } \
void Convert(const framework::OpDesc& op); \
}
class TensorRTConverter : public ConverterBase {
class TensorRTConverter {
public:
TensorRTConverter() {}
......
......@@ -12,20 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class MulOpConverter : public OpConverter {
public:
MulOpConverter() {}
void Convert(const framework::OpDesc& op);
};
REGISTER_TRT_OP_CONVETER(mul, MulOpConverter);
void MulOpConverter::Convert(const framework::OpDesc& op) {
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册