Unverified commit f860de4a, authored by Pei Yang, committed by GitHub

support clip op trt converter (#29411)

Parent 1dd7b97b
......@@ -1100,6 +1100,7 @@ USE_TRT_CONVERTER(skip_layernorm);
USE_TRT_CONVERTER(slice);
USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
+USE_TRT_CONVERTER(clip);
#endif
namespace paddle_infer {
......
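For orientation: REGISTER_TRT_OP_CONVERTER(clip, ClipOpConverter) in the new converter file (below) defines the converter together with a "touch" symbol, and USE_TRT_CONVERTER(clip) here references that symbol so the linker keeps the converter's object file in statically linked builds. A simplified sketch of the pattern (hypothetical names, not the verbatim Paddle macros):

// Simplified sketch of the register/use pattern (hypothetical names,
// NOT the verbatim Paddle macros).
#define REGISTER_TRT_OP_CONVERTER_SKETCH(op_type, Converter)   \
  int TouchTrtConverter_##op_type() {                          \
    static Converter registrar; /* ctor registers converter */ \
    return 0;                                                  \
  }

#define USE_TRT_CONVERTER_SKETCH(op_type)   \
  extern int TouchTrtConverter_##op_type(); \
  static int use_trt_converter_##op_type =  \
      TouchTrtConverter_##op_type();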
......@@ -307,7 +307,7 @@ class PD_INFER_DECL PaddlePredictor {
/// This will save the IO copy for transferring inputs and outputs to predictor
/// workspace
/// and get some performance improvement.
-/// To use it, one should call the AnalysisConfig.SwitchUseFeedFetchOp(true)
+/// To use it, one should call the AnalysisConfig.SwitchUseFeedFetchOp(false)
/// and then use the `GetInputTensor` and `GetOutputTensor`
/// to directly write or read the input/output tensors.
/// \return Whether the run is successful
......
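The hunk above fixes the comment: zero-copy requires SwitchUseFeedFetchOp(false), not true. For reference, a minimal sketch of the zero-copy flow it describes, assuming the AnalysisConfig/ZeroCopyTensor API of this release; the model path and the tensor name "x" are placeholders:

#include <vector>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("model_dir");        // placeholder model path
  config.SwitchUseFeedFetchOp(false);  // required for zero-copy tensors
  config.SwitchSpecifyInputNames(true);
  auto predictor = paddle::CreatePaddlePredictor(config);

  // Write the input directly into the predictor workspace (no feed-op copy).
  auto input = predictor->GetInputTensor("x");  // placeholder tensor name
  std::vector<float> in(1 * 3 * 224 * 224, 1.0f);
  input->Reshape({1, 3, 224, 224});
  input->copy_from_cpu(in.data());

  predictor->ZeroCopyRun();

  // Read the output directly from the workspace (no fetch-op copy).
  auto output = predictor->GetOutputTensor(predictor->GetOutputNames()[0]);
  int numel = 1;
  for (int d : output->shape()) numel *= d;
  std::vector<float> out(numel);
  output->copy_to_cpu(out.data());
  return 0;
}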
......@@ -4,7 +4,7 @@ nv_library(tensorrt_converter
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc
-emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc
+emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * ClipOp
 */
class ClipOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(5130)
    VLOG(3) << "convert a paddle clip op to tensorrt IActivationLayer.";
    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    float min = BOOST_GET_CONST(float, op_desc.GetAttr("min"));
    float max = BOOST_GET_CONST(float, op_desc.GetAttr("max"));
    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Activation, *input,
                                       nvinfer1::ActivationType::kCLIP);
    layer->setAlpha(min);
    layer->setBeta(max);
    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "clip", {output_name}, test_mode);
#else
    PADDLE_THROW(
        platform::errors::Fatal("clip TRT converter is only supported on TRT "
                                "5.1.3.0 or higher version."));
#endif
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(clip, ClipOpConverter);
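The new converter (the clip_op.cc added to the converter library in the CMake change above) maps Paddle's clip onto TensorRT's IActivationLayer with ActivationType::kCLIP: alpha and beta carry the lower and upper bounds, so the layer computes out = min(max(x, alpha), beta), which matches clip(x, min, max). The converter is only reachable through the TensorRT subgraph pass; a minimal sketch of enabling that pass, assuming this release's AnalysisConfig API, with placeholder model path and size values:

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("model_dir");  // placeholder: a model containing clip
  config.EnableUseGpu(100, 0);   // 100 MB initial pool on GPU 0
  // Offload eligible subgraphs, now including clip on TRT >= 5.1.3.0.
  config.EnableTensorRtEngine(1 << 20 /* workspace_size */,
                              1 /* max_batch_size */,
                              3 /* min_subgraph_size */,
                              paddle::AnalysisConfig::Precision::kFloat32,
                              false /* use_static */,
                              false /* use_calib_mode */);
  auto predictor = paddle::CreatePaddlePredictor(config);
  (void)predictor;
  return 0;
}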
......@@ -32,8 +32,10 @@ struct SimpleOpTypeSetTeller : public Teller {
#if IS_TRT_VERSION_GE(5130)
    teller_set.insert("relu6");
    teller_set.insert("hard_sigmoid");
+    teller_set.insert("clip");
    int8_teller_set.insert("relu6");
    int8_teller_set.insert("hard_sigmoid");
+    int8_teller_set.insert("clip");
#endif
#if IS_TRT_VERSION_GE(6000)
teller_set.insert("fused_embedding_eltwise_layernorm");
......@@ -132,8 +134,9 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
      auto* var_desc = block->FindVar(var_name);
      const auto shape = var_desc->GetShape();
      if (shape.size() < 3) {
-        VLOG(1) << "matmul op dims < 3 not supported in tensorrt, but got dims "
-                << shape.size() << ", so jump it.";
+        VLOG(1)
+            << "matmul op dims < 3 not supported in tensorrt, but got dims "
+            << shape.size() << ", so jump it.";
        return false;
      }
    }
......
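The teller change is what makes the converter reachable: the subgraph pass only offloads an op if OpTeller approves it, and clip is now whitelisted in both the default and the INT8 teller sets, behind the same TRT >= 5.1.3.0 guard as the converter itself. (The matmul VLOG change in the same file is a formatting-only rewrap.) An illustrative sketch of how the pass consults the teller, assuming the Tell overload shown above takes a use_no_calib_int8 flag; this is not the verbatim pass code:

#include <string>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"

// Hypothetical helper: returns true if the subgraph pass may place this
// op inside the TensorRT engine.
bool GoesToTensorRT(const paddle::framework::OpDesc& desc,
                    bool use_no_calib_int8) {
  return paddle::inference::tensorrt::OpTeller::Global().Tell(
      desc.Type(), desc, use_no_calib_int8);
}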
......@@ -343,6 +343,11 @@ class TensorRTSubgraphPassHardSigmoidTest(TensorRTSubgraphPassActivationTest):
        return fluid.layers.hard_sigmoid(x)


+class TensorRTSubgraphPassClipTest(TensorRTSubgraphPassActivationTest):
+    def append_act(self, x):
+        return fluid.layers.clip(x, 0, 1)
+
+
class TensorRTSubgraphPassTanhTest(TensorRTSubgraphPassActivationTest):
    def append_act(self, x):
        return fluid.layers.tanh(x)
......