未验证 提交 674bd839 编写于 作者: Y Yan Chunwei 提交者: GitHub

OpConverter change BlockDesc to proto::BlockDesc (#10623)

上级 de81ccb5
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES}) nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc op_converter.h DEPS ${FLUID_CORE_MODULES})
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine) DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
...@@ -21,15 +21,18 @@ namespace tensorrt { ...@@ -21,15 +21,18 @@ namespace tensorrt {
class ReluOpConverter : public OpConverter { class ReluOpConverter : public OpConverter {
public: public:
ReluOpConverter() {} ReluOpConverter() {}
void operator()(const framework::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op) override {
// Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr, nullptr);
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
"type is Relu"; "type is Relu";
const nvinfer1::ITensor* input_tensor = const nvinfer1::ITensor* input_tensor =
engine_->GetITensor(op.Input("X")[0]); engine_->GetITensor(op_desc.Input("X")[0]);
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor), engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
nvinfer1::ActivationType::kRELU); nvinfer1::ActivationType::kRELU);
engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0)); engine_->SetITensor(op_desc.Output("Out")[0], layer->getOutput(0));
} }
}; };
......
...@@ -21,7 +21,7 @@ namespace tensorrt { ...@@ -21,7 +21,7 @@ namespace tensorrt {
class Conv2dOpConverter : public OpConverter { class Conv2dOpConverter : public OpConverter {
public: public:
Conv2dOpConverter() {} Conv2dOpConverter() {}
void operator()(const framework::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO) LOG(INFO)
<< "convert a fluid conv2d op to tensorrt conv layer without bias"; << "convert a fluid conv2d op to tensorrt conv layer without bias";
} }
......
...@@ -21,7 +21,7 @@ namespace tensorrt { ...@@ -21,7 +21,7 @@ namespace tensorrt {
class MulOpConverter : public OpConverter { class MulOpConverter : public OpConverter {
public: public:
MulOpConverter() {} MulOpConverter() {}
void operator()(const framework::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias"; LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
} }
}; };
......
...@@ -31,10 +31,10 @@ namespace tensorrt { ...@@ -31,10 +31,10 @@ namespace tensorrt {
class OpConverter { class OpConverter {
public: public:
OpConverter() {} OpConverter() {}
virtual void operator()(const framework::OpDesc& op) {} virtual void operator()(const framework::proto::OpDesc& op) {}
void Run(const framework::OpDesc& op, TensorRTEngine* engine) { void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
std::string type = op.Type(); std::string type = op.type();
auto* it = Registry<OpConverter>::Lookup(type); auto* it = Registry<OpConverter>::Lookup(type);
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type); PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
it->SetEngine(engine); it->SetEngine(engine);
...@@ -42,14 +42,16 @@ class OpConverter { ...@@ -42,14 +42,16 @@ class OpConverter {
} }
// convert fluid op to tensorrt layer // convert fluid op to tensorrt layer
void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) { void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
OpConverter::Run(op, engine); OpConverter::Run(op, engine);
} }
// convert fluid block to tensorrt network // convert fluid block to tensorrt network
void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) { void ConvertBlock(const framework::proto::BlockDesc& block,
for (auto op : block.AllOps()) { TensorRTEngine* engine) {
OpConverter::Run(*op, engine); for (size_t i = 0; i < block.ops_size(); i++) {
const auto& op = block.ops(i);
OpConverter::Run(op, engine);
} }
} }
......
...@@ -49,7 +49,7 @@ void Compare(float input, float expect) { ...@@ -49,7 +49,7 @@ void Compare(float input, float expect) {
op_desc.SetInput("X", {"X"}); op_desc.SetInput("X", {"X"});
op_desc.SetOutput("Out", {"Out"}); op_desc.SetOutput("Out", {"Out"});
auto relu_op = framework::OpRegistry::CreateOp(op_desc); auto relu_op = framework::OpRegistry::CreateOp(*op_desc.Proto());
// run fluid op // run fluid op
relu_op->Run(scope, place); relu_op->Run(scope, place);
...@@ -65,7 +65,7 @@ void Compare(float input, float expect) { ...@@ -65,7 +65,7 @@ void Compare(float input, float expect) {
nvinfer1::DimsCHW{1, 1, 1}); nvinfer1::DimsCHW{1, 1, 1});
OpConverter op_converter; OpConverter op_converter;
op_converter.ConvertOp(op_desc, engine); op_converter.ConvertOp(*op_desc.Proto(), engine);
engine->DeclareOutput("Out"); engine->DeclareOutput("Out");
engine->FreezeNetwork(); engine->FreezeNetwork();
......
...@@ -29,7 +29,7 @@ TEST(OpConverter, ConvertBlock) { ...@@ -29,7 +29,7 @@ TEST(OpConverter, ConvertBlock) {
conv2d_op->SetType("conv2d"); conv2d_op->SetType("conv2d");
OpConverter converter; OpConverter converter;
converter.ConvertBlock(*block, nullptr /*TensorRTEngine*/); converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/);
} }
} // namespace tensorrt } // namespace tensorrt
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册