You need to sign in or sign up before continuing.
未验证 提交 674bd839 编写于 作者: Y Yan Chunwei 提交者: GitHub

OpConverter change BlockDesc to proto::BlockDesc (#10623)

上级 de81ccb5
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc op_converter.h DEPS ${FLUID_CORE_MODULES})
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
......@@ -21,15 +21,18 @@ namespace tensorrt {
class ReluOpConverter : public OpConverter {
public:
ReluOpConverter() {}
void operator()(const framework::OpDesc& op) override {
void operator()(const framework::proto::OpDesc& op) override {
// Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr, nullptr);
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
"type is Relu";
const nvinfer1::ITensor* input_tensor =
engine_->GetITensor(op.Input("X")[0]);
engine_->GetITensor(op_desc.Input("X")[0]);
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
nvinfer1::ActivationType::kRELU);
engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0));
engine_->SetITensor(op_desc.Output("Out")[0], layer->getOutput(0));
}
};
......
......@@ -21,7 +21,7 @@ namespace tensorrt {
class Conv2dOpConverter : public OpConverter {
public:
Conv2dOpConverter() {}
void operator()(const framework::OpDesc& op) override {
void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO)
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
}
......
......@@ -21,7 +21,7 @@ namespace tensorrt {
class MulOpConverter : public OpConverter {
public:
MulOpConverter() {}
void operator()(const framework::OpDesc& op) override {
void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
}
};
......
......@@ -31,10 +31,10 @@ namespace tensorrt {
class OpConverter {
public:
OpConverter() {}
virtual void operator()(const framework::OpDesc& op) {}
virtual void operator()(const framework::proto::OpDesc& op) {}
void Run(const framework::OpDesc& op, TensorRTEngine* engine) {
std::string type = op.Type();
void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
std::string type = op.type();
auto* it = Registry<OpConverter>::Lookup(type);
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
it->SetEngine(engine);
......@@ -42,14 +42,16 @@ class OpConverter {
}
// convert fluid op to tensorrt layer
void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
OpConverter::Run(op, engine);
}
// convert fluid block to tensorrt network
void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
for (auto op : block.AllOps()) {
OpConverter::Run(*op, engine);
void ConvertBlock(const framework::proto::BlockDesc& block,
TensorRTEngine* engine) {
for (size_t i = 0; i < block.ops_size(); i++) {
const auto& op = block.ops(i);
OpConverter::Run(op, engine);
}
}
......
......@@ -49,7 +49,7 @@ void Compare(float input, float expect) {
op_desc.SetInput("X", {"X"});
op_desc.SetOutput("Out", {"Out"});
auto relu_op = framework::OpRegistry::CreateOp(op_desc);
auto relu_op = framework::OpRegistry::CreateOp(*op_desc.Proto());
// run fluid op
relu_op->Run(scope, place);
......@@ -65,7 +65,7 @@ void Compare(float input, float expect) {
nvinfer1::DimsCHW{1, 1, 1});
OpConverter op_converter;
op_converter.ConvertOp(op_desc, engine);
op_converter.ConvertOp(*op_desc.Proto(), engine);
engine->DeclareOutput("Out");
engine->FreezeNetwork();
......
......@@ -29,7 +29,7 @@ TEST(OpConverter, ConvertBlock) {
conv2d_op->SetType("conv2d");
OpConverter converter;
converter.ConvertBlock(*block, nullptr /*TensorRTEngine*/);
converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/);
}
} // namespace tensorrt
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册