未验证 提交 dbc63555 编写于 作者: W wenbin 提交者: GitHub

Support int32 input for the scale op in the TensorRT converter, dynamic shape only (#48044)

* int scale

* round

* revert commit
上级 5329187d
...@@ -49,9 +49,12 @@ class ScaleOpConverter : public OpConverter { ...@@ -49,9 +49,12 @@ class ScaleOpConverter : public OpConverter {
PADDLE_GET_CONST(bool, op_desc.GetAttr("bias_after_scale")); PADDLE_GET_CONST(bool, op_desc.GetAttr("bias_after_scale"));
float bias = PADDLE_GET_CONST(float, op_desc.GetAttr("bias")); float bias = PADDLE_GET_CONST(float, op_desc.GetAttr("bias"));
float scale = PADDLE_GET_CONST(float, op_desc.GetAttr("scale")); float scale = PADDLE_GET_CONST(float, op_desc.GetAttr("scale"));
bool is_int = input->getType() == nvinfer1::DataType::kINT32;
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
nvinfer1::ITensor* bias_tensor = Add1DConstantLayer(bias); nvinfer1::ITensor* bias_tensor =
is_int ? Add1DConstantLayer(static_cast<int>(bias))
: Add1DConstantLayer(bias);
bool is_bias_0 = (bias < 1e-06 && bias > -1e-06); bool is_bias_0 = (bias < 1e-06 && bias > -1e-06);
std::vector<int32_t> bias_shapes(input->getDimensions().nbDims, 1); std::vector<int32_t> bias_shapes(input->getDimensions().nbDims, 1);
...@@ -72,7 +75,8 @@ class ScaleOpConverter : public OpConverter { ...@@ -72,7 +75,8 @@ class ScaleOpConverter : public OpConverter {
is_scale_1 = false; is_scale_1 = false;
} else { } else {
has_scale_tensor = false; has_scale_tensor = false;
scale_tensor = Add1DConstantLayer(scale); scale_tensor = is_int ? Add1DConstantLayer(static_cast<int>(scale))
: Add1DConstantLayer(scale);
is_scale_1 = ((scale - 1.0) < 1e-06 && (scale - 1.0) > -1e-06); is_scale_1 = ((scale - 1.0) < 1e-06 && (scale - 1.0) > -1e-06);
} }
......
...@@ -1076,14 +1076,25 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -1076,14 +1076,25 @@ struct SimpleOpTypeSetTeller : public Teller {
auto* x_var_desc = block->FindVar(x_var_name); auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
auto dtype = x_var_desc->GetDataType(); auto dtype = x_var_desc->GetDataType();
if (!with_dynamic_shape) {
// At present, only support float32 or float16 into trt. // At present, only support float32 or float16 into trt.
if (!(dtype == 5 || dtype == 4)) { if (!(dtype == framework::proto::VarType::FP32 ||
dtype == framework::proto::VarType::FP16)) {
return false; return false;
} }
if (!with_dynamic_shape && x_shape.size() == 1) { if (x_shape.size() == 1) {
VLOG(3) << "Scale op does not support 1-dimensional input in tensorrt"; VLOG(3)
<< "Scale op does not support 1-dimensional input in tensorrt";
return false; return false;
} }
} else {
// At present, only support float32 or float16 or int32 into trt.
if (!(dtype == framework::proto::VarType::FP32 ||
dtype == framework::proto::VarType::FP16 ||
dtype == framework::proto::VarType::INT32)) {
return false;
}
}
} }
if (op_type == "roll") { if (op_type == "roll") {
......
...@@ -26,18 +26,24 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest): ...@@ -26,18 +26,24 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
return True return True
def sample_program_configs(self): def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]], batch): def generate_input1(attrs: List[Dict[str, Any]], batch, is_int):
if self.dims == 4: if self.dims == 4:
return np.ones([batch, 3, 24, 24]).astype(np.float32) return np.ones([batch, 3, 24, 24]).astype(
np.int32 if is_int else np.float32
)
elif self.dims == 3: elif self.dims == 3:
return np.ones([batch, 3, 24]).astype(np.float32) return np.ones([batch, 3, 24]).astype(
np.int32 if is_int else np.float32
)
elif self.dims == 2: elif self.dims == 2:
return np.ones([batch, 24]).astype(np.float32) return np.ones([batch, 24]).astype(
np.int32 if is_int else np.float32
)
elif self.dims == 1: elif self.dims == 1:
return np.ones([24]).astype(np.float32) return np.ones([24]).astype(np.int32 if is_int else np.float32)
def generate_weight1(attrs: List[Dict[str, Any]]): def generate_weight1(attrs: List[Dict[str, Any]], is_int):
return np.ones([1]).astype(np.float32) return np.ones([1]).astype(np.int32 if is_int else np.float32)
for num_input in [0, 1]: for num_input in [0, 1]:
for dims in [1, 2, 3, 4]: for dims in [1, 2, 3, 4]:
...@@ -45,8 +51,10 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest): ...@@ -45,8 +51,10 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
for scale in [0.1, -1.0]: for scale in [0.1, -1.0]:
for bias in [0.0, 1.2]: for bias in [0.0, 1.2]:
for bias_after_scale in [False, True]: for bias_after_scale in [False, True]:
for is_int in [False, True]:
self.num_input = num_input self.num_input = num_input
self.dims = dims self.dims = dims
self.is_int = is_int
dics = [ dics = [
{ {
"scale": scale, "scale": scale,
...@@ -67,7 +75,9 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest): ...@@ -67,7 +75,9 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
{ {
"ScaleTensor": TensorConfig( "ScaleTensor": TensorConfig(
data_gen=partial( data_gen=partial(
generate_weight1, dics generate_weight1,
dics,
is_int,
) )
) )
}, },
...@@ -78,7 +88,9 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest): ...@@ -78,7 +88,9 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
{ {
"op_type": "scale", "op_type": "scale",
"op_inputs": dics_intput[num_input], "op_inputs": dics_intput[num_input],
"op_outputs": {"Out": ["scale_out"]}, "op_outputs": {
"Out": ["scale_out"]
},
"op_attrs": dics[0], "op_attrs": dics[0],
} }
] ]
...@@ -89,7 +101,10 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest): ...@@ -89,7 +101,10 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
inputs={ inputs={
"scale_input": TensorConfig( "scale_input": TensorConfig(
data_gen=partial( data_gen=partial(
generate_input1, dics, batch generate_input1,
dics,
batch,
is_int,
) )
) )
}, },
...@@ -182,6 +197,17 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest): ...@@ -182,6 +197,17 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
"INPUT DIM EQUAL TO 1 OF STATIC SHAPE NOT SUPPORT", "INPUT DIM EQUAL TO 1 OF STATIC SHAPE NOT SUPPORT",
) )
# Skip predicate for the scale-op TensorRT test: reports True when the current
# sample uses integer input (self.is_int) and self.dynamic_shape.min_input_shape
# is empty; reports False in every other case.
# NOTE(review): an empty min_input_shape presumably means the static-shape run,
# matching the "INTEGER INPUT OF STATIC SHAPE NOT SUPPORT" reason this predicate
# is registered with -- confirm against the enclosing method.
# program_config and predictor_config are accepted to fit the predicate
# signature but are not read here; the decision relies only on state held on self.
def teller3(program_config, predictor_config):
if self.is_int and len(self.dynamic_shape.min_input_shape) == 0:
return True
return False
self.add_skip_case(
teller3,
SkipReasons.TRT_NOT_SUPPORT,
"INTEGER INPUT OF STATIC SHAPE NOT SUPPORT",
)
def test(self): def test(self):
self.add_skip_trt_case() self.add_skip_trt_case()
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册