From f6fb51a164890267141f240633b74cb6c8b90b3b Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Thu, 7 Jun 2018 13:35:12 +0800 Subject: [PATCH] add test_mode in trt/activation_op --- paddle/fluid/inference/tensorrt/convert/activation_op.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index 7814f6d354..e1cace9cc1 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -23,7 +23,7 @@ class ReluOpConverter : public OpConverter { public: ReluOpConverter() {} void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope) override { + const framework::Scope& scope, bool test_mode) override { // Here the two nullptr looks strange, that's because the // framework::OpDesc's constructor is strange. framework::OpDesc op_desc(op, nullptr); @@ -34,7 +34,12 @@ class ReluOpConverter : public OpConverter { nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( engine_, Activation, *const_cast(input_tensor), nvinfer1::ActivationType::kRELU); - engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]); + auto output_name = op_desc.Output("Out")[0]; + engine_->SetITensor(output_name, layer->getOutput(0)); + if (test_mode) { // the test framework can not determine which is the + // output, so place the declaration inside. + engine_->DeclareOutput(output_name); + } } }; -- GitLab