未验证 提交 c294cca4 编写于 作者: W wenbin 提交者: GitHub

禁止在低版本TRT中使用strides>1的conv (#32997)

* revert elementwise and disable trt conv if strides > 1

* strides check

* remove useless var

* comments
上级 09bc0f59
...@@ -52,11 +52,6 @@ class ActivationOpConverter : public OpConverter { ...@@ -52,11 +52,6 @@ class ActivationOpConverter : public OpConverter {
engine_->GetITensor(op_desc.Input("X")[0]); engine_->GetITensor(op_desc.Input("X")[0]);
auto op_pair = ops.find(op_type_); auto op_pair = ops.find(op_type_);
if (op_pair == ops.end()) {
PADDLE_THROW(platform::errors::Fatal(
"Wrong activation op type, the trt do not support the %s act type.",
op_type_));
}
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor), engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
......
...@@ -55,16 +55,6 @@ class AffineChannelOpConverter : public OpConverter { ...@@ -55,16 +55,6 @@ class AffineChannelOpConverter : public OpConverter {
auto* bias_t = bias_v->GetMutable<framework::LoDTensor>(); auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false); float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false);
auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout")));
PADDLE_ENFORCE_EQ(
data_layout, framework::DataLayout::kNCHW,
platform::errors::InvalidArgument(
"TensorRT affine channel converter can only convert NCHW format. "
"Other format should be run in fluid mode. Report a bug on github "
"issue if you see this line."));
// tensorrt scalend layer only support spatial dims >= 2, // tensorrt scalend layer only support spatial dims >= 2,
// so nhwc is not availabe (spatial dims == 0) // so nhwc is not availabe (spatial dims == 0)
const int channel_axis = engine_->with_dynamic_shape(); const int channel_axis = engine_->with_dynamic_shape();
......
...@@ -25,10 +25,6 @@ static bool CheckDims(const nvinfer1::Dims& dims_x, ...@@ -25,10 +25,6 @@ static bool CheckDims(const nvinfer1::Dims& dims_x,
return false; return false;
} }
for (int i = 0; i < dims_x.nbDims; i++) { for (int i = 0; i < dims_x.nbDims; i++) {
// conservative judgment
if (dims_x.d[i] == -1 || dims_y.d[i] == -1) {
return false;
}
if (dims_x.d[i] != dims_y.d[i]) { if (dims_x.d[i] != dims_y.d[i]) {
return false; return false;
} }
......
...@@ -143,6 +143,19 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -143,6 +143,19 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings")); BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
if (paddings.size() > 2) return false; if (paddings.size() > 2) return false;
// strides > 1 is only supported by trt7.0 above
#if !IS_TRT_VERSION_GE(7000)
if (desc.HasAttr("strides")) {
const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides"));
// there is no issue if strides.size() less than 2
if (strides.size() > 1) {
for (size_t i = 0; i < strides.size(); i++) {
if (strides[i] > 1) return false;
}
}
}
#endif
} }
if (op_type == "pool2d") { if (op_type == "pool2d") {
...@@ -225,6 +238,20 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -225,6 +238,20 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
<< desc.Output("Output").size() << " output."; << desc.Output("Output").size() << " output.";
return false; return false;
} }
// strides > 1 is only supported by trt7.0 above
#if !IS_TRT_VERSION_GE(7000)
if (desc.HasAttr("strides")) {
const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("strides"));
// there is no issue if strides.size() less than 2
if (strides.size() > 1) {
for (size_t i = 0; i < strides.size(); i++) {
if (strides[i] > 1) return false;
}
}
}
#endif
} }
if (op_type == "matmul") { if (op_type == "matmul") {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册