提交 f6fb51a1 编写于 作者: L Luo Tao

add test_mode in trt/activation_op

上级 c73977af
......@@ -23,7 +23,7 @@ class ReluOpConverter : public OpConverter {
public:
ReluOpConverter() {}
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override {
const framework::Scope& scope, bool test_mode) override {
// Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr);
......@@ -34,7 +34,12 @@ class ReluOpConverter : public OpConverter {
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
nvinfer1::ActivationType::kRELU);
engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册