未验证 提交 dbc63555 编写于 作者: W wenbin 提交者: GitHub

support int input for scale (#48044)

* int scale

* round

* revert commit
上级 5329187d
......@@ -49,9 +49,12 @@ class ScaleOpConverter : public OpConverter {
PADDLE_GET_CONST(bool, op_desc.GetAttr("bias_after_scale"));
float bias = PADDLE_GET_CONST(float, op_desc.GetAttr("bias"));
float scale = PADDLE_GET_CONST(float, op_desc.GetAttr("scale"));
bool is_int = input->getType() == nvinfer1::DataType::kINT32;
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
nvinfer1::ITensor* bias_tensor = Add1DConstantLayer(bias);
nvinfer1::ITensor* bias_tensor =
is_int ? Add1DConstantLayer(static_cast<int>(bias))
: Add1DConstantLayer(bias);
bool is_bias_0 = (bias < 1e-06 && bias > -1e-06);
std::vector<int32_t> bias_shapes(input->getDimensions().nbDims, 1);
......@@ -72,7 +75,8 @@ class ScaleOpConverter : public OpConverter {
is_scale_1 = false;
} else {
has_scale_tensor = false;
scale_tensor = Add1DConstantLayer(scale);
scale_tensor = is_int ? Add1DConstantLayer(static_cast<int>(scale))
: Add1DConstantLayer(scale);
is_scale_1 = ((scale - 1.0) < 1e-06 && (scale - 1.0) > -1e-06);
}
......
......@@ -1076,14 +1076,25 @@ struct SimpleOpTypeSetTeller : public Teller {
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
auto dtype = x_var_desc->GetDataType();
if (!with_dynamic_shape) {
// At present, only support float32 or float16 into trt.
if (!(dtype == 5 || dtype == 4)) {
if (!(dtype == framework::proto::VarType::FP32 ||
dtype == framework::proto::VarType::FP16)) {
return false;
}
if (!with_dynamic_shape && x_shape.size() == 1) {
VLOG(3) << "Scale op does not support 1-dimensional input in tensorrt";
if (x_shape.size() == 1) {
VLOG(3)
<< "Scale op does not support 1-dimensional input in tensorrt";
return false;
}
} else {
// At present, only support float32 or float16 or int32 into trt.
if (!(dtype == framework::proto::VarType::FP32 ||
dtype == framework::proto::VarType::FP16 ||
dtype == framework::proto::VarType::INT32)) {
return false;
}
}
}
if (op_type == "roll") {
......
......@@ -26,18 +26,24 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
return True
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]], batch):
def generate_input1(attrs: List[Dict[str, Any]], batch, is_int):
if self.dims == 4:
return np.ones([batch, 3, 24, 24]).astype(np.float32)
return np.ones([batch, 3, 24, 24]).astype(
np.int32 if is_int else np.float32
)
elif self.dims == 3:
return np.ones([batch, 3, 24]).astype(np.float32)
return np.ones([batch, 3, 24]).astype(
np.int32 if is_int else np.float32
)
elif self.dims == 2:
return np.ones([batch, 24]).astype(np.float32)
return np.ones([batch, 24]).astype(
np.int32 if is_int else np.float32
)
elif self.dims == 1:
return np.ones([24]).astype(np.float32)
return np.ones([24]).astype(np.int32 if is_int else np.float32)
def generate_weight1(attrs: List[Dict[str, Any]]):
return np.ones([1]).astype(np.float32)
def generate_weight1(attrs: List[Dict[str, Any]], is_int):
return np.ones([1]).astype(np.int32 if is_int else np.float32)
for num_input in [0, 1]:
for dims in [1, 2, 3, 4]:
......@@ -45,8 +51,10 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
for scale in [0.1, -1.0]:
for bias in [0.0, 1.2]:
for bias_after_scale in [False, True]:
for is_int in [False, True]:
self.num_input = num_input
self.dims = dims
self.is_int = is_int
dics = [
{
"scale": scale,
......@@ -67,7 +75,9 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
{
"ScaleTensor": TensorConfig(
data_gen=partial(
generate_weight1, dics
generate_weight1,
dics,
is_int,
)
)
},
......@@ -78,7 +88,9 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
{
"op_type": "scale",
"op_inputs": dics_intput[num_input],
"op_outputs": {"Out": ["scale_out"]},
"op_outputs": {
"Out": ["scale_out"]
},
"op_attrs": dics[0],
}
]
......@@ -89,7 +101,10 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
inputs={
"scale_input": TensorConfig(
data_gen=partial(
generate_input1, dics, batch
generate_input1,
dics,
batch,
is_int,
)
)
},
......@@ -182,6 +197,17 @@ class TrtConvertScaleTest(TrtLayerAutoScanTest):
"INPUT DIM EQUAL TO 1 OF STATIC SHAPE NOT SUPPORT",
)
def teller3(program_config, predictor_config):
if self.is_int and len(self.dynamic_shape.min_input_shape) == 0:
return True
return False
self.add_skip_case(
teller3,
SkipReasons.TRT_NOT_SUPPORT,
"INTEGER INPUT OF STATIC SHAPE NOT SUPPORT",
)
def test(self):
self.add_skip_trt_case()
self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册