提交 6f6f3304 编写于 作者: L Luo Tao

update the register method

上级 326221ac
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda) nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
cc_library(tensorrt DEPS tensorrt_convert)
add_subdirectory(convert) add_subdirectory(convert)
nv_library(tensorrt_convert SRCS convert.cc mul_op.cc conv2d_op.cc DEPS dynload_cuda) file(GLOB TENSORRT_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
nv_test(test_tensorrt_convert SRCS test_convert.cc DEPS tensorrt paddle_fluid) nv_test(test_tensorrt_op_converter SRCS test_op_converter.cc ${TENSORRT_OPS} DEPS ${FLUID_CORE_MODULES})
...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); ...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
...@@ -12,17 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,17 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/convert.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
REGISTER_TRT_OP_CONVETER(conv2d, Conv2dOpConverter); class Conv2dOpConverter : public OpConverter {
public:
Conv2dOpConverter() {}
void operator()(const framework::OpDesc& op) override {
LOG(INFO)
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
}
};
void Conv2dOpConverter::Convert(const framework::OpDesc& op) { REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
LOG(INFO) << "convert a fluid conv2d op to tensorrt conv layer without bias";
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/convert.h"
namespace paddle {
namespace inference {
namespace tensorrt {
void TensorRTConverter::ConvertBlock(const framework::BlockDesc& block) {
for (auto op : block.AllOps()) {
std::string type = op->Type();
OpConverter op_converter;
op_converter.Convert(*op);
}
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); ...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
...@@ -12,17 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,17 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/convert.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
REGISTER_TRT_OP_CONVETER(mul, MulOpConverter); class MulOpConverter : public OpConverter {
public:
MulOpConverter() {}
void operator()(const framework::OpDesc& op) override {
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
}
};
void MulOpConverter::Convert(const framework::OpDesc& op) { REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); ...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
...@@ -14,54 +14,74 @@ limitations under the License. */ ...@@ -14,54 +14,74 @@ limitations under the License. */
#pragma once #pragma once
#include <NvInfer.h>
#include <functional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
/*
* Convert Op from Fluid to TensorRT Engine.
*/
class OpConverter { class OpConverter {
public: public:
OpConverter() {} OpConverter() {}
void Convert(const framework::OpDesc& op) { virtual void operator()(const framework::OpDesc& op) {}
void Execute(const framework::OpDesc& op) {
std::string type = op.Type(); std::string type = op.Type();
OpConverter& op_converter = this->register_op_converter_[type]; auto it = converters_.find(type);
op_converter.Convert(op); PADDLE_ENFORCE(it != converters_.end(), "no OpConverter for optype [%s]",
type);
(*it->second)(op);
}
static OpConverter& Global() {
static auto* x = new OpConverter;
return *x;
} }
template <typename T> template <typename T>
static void Register(const std::string key) { void Register(const std::string& key) {
register_op_converter_[key] = T(); converters_[key] = new T;
} }
static std::unordered_map<std::string, OpConverter> register_op_converter_;
virtual ~OpConverter() {}
private:
// registered op converter map, whose key is the fluid op type, and value is
// the pointer position of corresponding OpConverter class.
std::unordered_map<std::string, OpConverter*> converters_;
// fluid inference scope // fluid inference scope
framework::Scope* scope_; framework::Scope* scope_;
// tensorrt input/output tensor list, whose key is the fluid variable name, // tensorrt input/output tensor map, whose key is the fluid variable name,
// and value is the pointer position of tensorrt tensor // and value is the pointer position of tensorrt tensor
std::unordered_map<std::string, nvinfer1::ITensor*> tr_tensors_; std::unordered_map<std::string, nvinfer1::ITensor*> tr_tensors_;
}; };
#define REGISTER_TRT_OP_CONVETER(op_type, convert_class) \ #define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
class convert_class : public OpConverter { \ struct trt_##op_type__##_converter { \
public: \ trt_##op_type__##_converter() { \
convert_class() { OpConverter::Register<convert_class>(#op_type); } \ OpConverter::Global().Register<Converter__>(#op_type__); \
void Convert(const framework::OpDesc& op); \ } \
} }; \
trt_##op_type__##_converter trt_##op_type__##_converter__;
class TensorRTConverter { class BlockConverter {
public: public:
TensorRTConverter() {} BlockConverter() {}
// convert fluid block to tensorrt network // convert fluid block to tensorrt network
void ConvertBlock(const framework::BlockDesc& block); void ConvertBlock(const framework::BlockDesc& block) {
for (auto op : block.AllOps()) {
OpConverter::Global().Execute(*op);
}
}
}; };
} // namespace tensorrt } // namespace tensorrt
......
...@@ -14,13 +14,13 @@ limitations under the License. */ ...@@ -14,13 +14,13 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/tensorrt/convert/convert.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
TEST(tensorrt, ConvertBlock) { TEST(BlockConverter, ConvertBlock) {
framework::ProgramDesc prog; framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0); auto* block = prog.MutableBlock(0);
auto* mul_op = block->AppendOp(); auto* mul_op = block->AppendOp();
...@@ -28,7 +28,7 @@ TEST(tensorrt, ConvertBlock) { ...@@ -28,7 +28,7 @@ TEST(tensorrt, ConvertBlock) {
auto* conv2d_op = block->AppendOp(); auto* conv2d_op = block->AppendOp();
conv2d_op->SetType("conv2d"); conv2d_op->SetType("conv2d");
TensorRTConverter converter; BlockConverter converter;
converter.ConvertBlock(*block); converter.ConvertBlock(*block);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册