Unverified commit af9ddeb7, authored by wenbin, committed by GitHub

fix shuffle_channel_detect_pass (#39242)

* shuffle channel pass

* add ut

* timeout fix

* makefile fix
Parent f2226441
@@ -94,6 +94,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
     auto* input_node = subgraph.at(x);
     auto reshape1_desc = reshape1_op->Op();
     auto reshape2_desc = reshape2_op->Op();
+    auto trans_desc = transpose_op->Op();
     std::string input_name = input_node->Name();
     std::string output_name = reshape2_out->Name();
@@ -101,25 +102,102 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
         BOOST_GET_CONST(std::vector<int>, reshape1_desc->GetAttr("shape"));
     auto reshape2_shape =
         BOOST_GET_CONST(std::vector<int>, reshape2_desc->GetAttr("shape"));
-    // shuffle_channel dosen't change shape
-    auto* block = reshape1_desc->Block();
-    if (block) {
+    auto trans_axis =
+        BOOST_GET_CONST(std::vector<int>, trans_desc->GetAttr("axis"));
+    auto* block1 = reshape1_desc->Block();
+    auto* block2 = reshape2_desc->Block();
+    if (block1 && block2) {
       auto x_var_name = reshape1_desc->Input("X")[0];
-      auto* x_var_desc = block->FindVar(x_var_name);
-      const auto x_shape = x_var_desc->GetShape();
-      if (x_shape.size() != reshape2_shape.size()) {
+      auto* x_var_desc = block1->FindVar(x_var_name);
+      auto x_shape1 = x_var_desc->GetShape();
+      x_var_name = reshape2_desc->Input("X")[0];
+      x_var_desc = block2->FindVar(x_var_name);
+      auto x_shape2 = x_var_desc->GetShape();
+      // now shuffle_channel is 4D(NCHW) only.
+      if (x_shape1.size() != 4 || reshape1_shape.size() != 5 ||
+          reshape2_shape.size() != 4 || trans_axis.size() != 5) {
         return;
       }
-      for (size_t i = 0; i < x_shape.size(); i++) {
-        if (x_shape[i] != reshape2_shape[i]) return;
+      // process 0 and -1 in reshape.
+      constexpr int64_t copy_dim_val = 0;
+      for (size_t i = 0; i < reshape1_shape.size(); i++) {
+        if (reshape1_shape[i] == copy_dim_val) {
+          reshape1_shape[i] = x_shape1[i];
+        }
+      }
+      for (size_t i = 0; i < reshape2_shape.size(); i++) {
+        if (reshape2_shape[i] == copy_dim_val) {
+          reshape2_shape[i] = x_shape2[i];
+        }
+      }
+      constexpr int64_t unk_dim_idx = -1;
+      bool all_positive = std::all_of(x_shape1.cbegin(), x_shape1.cend(),
+                                      [](int64_t i) { return i > 0; });
+      for (size_t i = 0; i < reshape1_shape.size(); ++i) {
+        // if -1 is not in batch dim, try to calculate number
+        if ((reshape1_shape[i] == unk_dim_idx) && (i != 0)) {
+          // there is no sufficient info
+          if (!all_positive) return;
+          reshape1_shape[i] =
+              std::accumulate(x_shape1.begin(), x_shape1.end(),
+                              static_cast<int64_t>(1),
+                              std::multiplies<int64_t>()) /
+              std::accumulate(reshape1_shape.begin(), reshape1_shape.end(),
+                              static_cast<int64_t>(-1),
+                              std::multiplies<int64_t>());
+          break;
+        }
+      }
+      all_positive = std::all_of(x_shape2.cbegin(), x_shape2.cend(),
+                                 [](int64_t i) { return i > 0; });
+      for (size_t i = 0; i < reshape2_shape.size(); ++i) {
+        // if -1 is not in batch dim, try to calculate number
+        if ((reshape2_shape[i] == unk_dim_idx) && (i != 0)) {
+          // there is no sufficient info
+          if (!all_positive) return;
+          reshape2_shape[i] =
+              std::accumulate(x_shape2.begin(), x_shape2.end(),
+                              static_cast<int64_t>(1),
+                              std::multiplies<int64_t>()) /
+              std::accumulate(reshape2_shape.begin(), reshape2_shape.end(),
+                              static_cast<int64_t>(-1),
+                              std::multiplies<int64_t>());
+          break;
+        }
       }
+      // shuffle_channel dosen't change shape
+      if ((reshape2_shape[0] != -1) && (x_shape1[0] != reshape2_shape[0])) {
+        return;
+      }
+      for (size_t i = 1; i < x_shape1.size(); i++) {
+        if (x_shape1[i] != reshape2_shape[i]) {
+          return;
+        }
+      }
+      if ((reshape2_shape[3] != reshape1_shape[4]) ||
+          (reshape2_shape[2] != reshape1_shape[3])) {
+        return;
+      }
+    } else {
+      return;  // conservative judgement
     }
     int i_c = reshape1_shape[2];
     int o_c = reshape2_shape[1];
     int group = o_c / i_c;
+    // should split on channel dim
+    if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return;
+    // trans on channel dim
+    if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return;
+    if (group != 1) {
+      if (trans_axis[1] != 2 && trans_axis[2] != 1) {
+        return;
+      }
+    }

     framework::OpDesc new_op_desc;
     new_op_desc.SetType("shuffle_channel");
......
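For context, the core of the new check is resolving the 0 (copy input dim) and -1 (inferred dim) placeholders in the reshape attributes before comparing shapes. Below is a minimal NumPy sketch of that resolution logic, mirroring the pass's copy_dim_val/unk_dim_idx handling; the shapes and values are assumed examples, not taken from the commit:

import numpy as np

def resolve_reshape_attr(shape, x_shape):
    # 0 copies the corresponding input dim, as in the pass's copy_dim_val loop.
    shape = [x_shape[i] if s == 0 else s for i, s in enumerate(shape)]
    # A single -1 outside the batch dim is inferred from the element count,
    # matching the std::accumulate quotient in the pass (unk_dim_idx logic).
    if -1 in shape[1:]:
        idx = shape[1:].index(-1) + 1
        shape[idx] = int(np.prod(x_shape) // -np.prod(shape))
    return shape

# x is NCHW [2, 8, 4, 4]; reshape attr [0, 2, 4, -1, 4] resolves to [2, 2, 4, 4, 4].
print(resolve_reshape_attr([0, 2, 4, -1, 4], [2, 8, 4, 4]))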
@@ -39,12 +39,7 @@ class ShuffleChannelOpConverter : public OpConverter {
     // Declare inputs
     auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
     auto input_dims = input->getDimensions();
-    PADDLE_ENFORCE_EQ(
-        input_dims.nbDims, 3,
-        platform::errors::InvalidArgument("ShuffleChannel TRT op converter "
-                                          "input dims is invalid. The input "
-                                          "dims size should be 3, but got %d.",
-                                          input_dims.nbDims));
     int c = input_dims.d[0];
     int h = input_dims.d[1];
     int w = input_dims.d[2];
......
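A note on the deleted PADDLE_ENFORCE_EQ: in TensorRT's implicit-batch mode (the static-shape path this op is restricted to), the converter sees the framework's 4-D NCHW input as a 3-D CHW tensor, so the rank guarantee is now established once in OpTeller (next hunk) instead of failing hard inside the converter. A rough sketch of the dimension relationship, with an assumed example shape:

# Assumed example shape; not taken from the commit.
framework_shape = [8, 16, 32, 32]   # N, C, H, W -- op_teller now requires rank 4
trt_dims = framework_shape[1:]      # implicit-batch TRT dims: C, H, W
c, h, w = trt_dims                  # what the converter reads from d[0], d[1], d[2]
assert len(trt_dims) == 3           # the invariant the removed enforce asserted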
@@ -1295,6 +1295,20 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
                 "the shuffle_channel op does not support dynamic shape yet";
         return false;
       }
+      auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
+      auto* input_desc = block->FindVar(desc.Input("X").front());
+      const auto input_shape = input_desc->GetShape();
+      if (input_shape.size() != 4) {
+        VLOG(3) << "input dims is invalid. The input "
+                   "dims size should be 4.";
+        return false;
+      }
     }

     if (op_type == "skip_layernorm") {
@@ -1606,7 +1620,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
   }
-  VLOG(3) << "trt unsupported op " << op_type;
   return false;
 }
......
@@ -102,6 +102,7 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU)
   set_tests_properties(test_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
   set_tests_properties(test_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
   set_tests_properties(test_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
+  set_tests_properties(test_shuffle_channel_detect_pass PROPERTIES TIMEOUT 120)
   if (WIN32)
     set_tests_properties(test_matmul_scale_fuse_pass PROPERTIES TIMEOUT 300)
     set_tests_properties(test_matmul_v2_scale_fuse_pass PROPERTIES TIMEOUT 300)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st


class TestShuffleChannelDetectPass(PassAutoScanTest):
    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        attrs = [
            program_config.ops[i].attrs for i in range(len(program_config.ops))
        ]
        if attrs[0]['input_shape'] != attrs[2]['shape']:
            return False
        return True

    def sample_program_config(self, draw):
        batch_size = draw(st.integers(min_value=1, max_value=4))
        out_channel = draw(st.integers(min_value=1, max_value=16))
        group = draw(st.integers(min_value=1, max_value=4))
        in_channel = group * out_channel
        x_shape = [batch_size, in_channel, 64, 64]
        shape = [0, group, out_channel, -1, 64]
        axis_v = [0, 2, 1, 3, 4]

        def generate_reshape2_Input():
            return np.random.random(x_shape).astype(np.float32)

        reshape2_op1 = OpConfig(
            "reshape2",
            inputs={"X": ["reshape2_input1"], },
            outputs={
                "Out": ["reshape2_output1"],
                "XShape": ["reshape2_xshape1"]
            },
            shape=shape,
            input_shape=x_shape)
        transpose2_op = OpConfig(
            "transpose2",
            inputs={"X": ["reshape2_output1"], },
            outputs={
                "Out": ["transpose2_ouput"],
                "XShape": ["transpose2_xshape"]
            },
            axis=axis_v)
        reshape2_op2 = OpConfig(
            "reshape2",
            inputs={"X": ["transpose2_ouput"], },
            outputs={
                "Out": ["reshape2_output2"],
                "XShape": ["reshape2_xshape2"]
            },
            shape=x_shape)
        ops = [reshape2_op1, transpose2_op, reshape2_op2]

        program_config = ProgramConfig(
            ops=ops,
            inputs={
                "reshape2_input1":
                TensorConfig(data_gen=partial(generate_reshape2_Input)),
            },
            weights={},
            outputs=["reshape2_output2"])
        return program_config

    def sample_predictor_configs(self, program_config):
        config = self.create_trt_inference_config()
        config.enable_tensorrt_engine(
            workspace_size=1 << 20,
            max_batch_size=4,
            min_subgraph_size=1,
            precision_mode=paddle_infer.PrecisionType.Float32,
            use_static=False,
            use_calib_mode=False)
        yield config, ['shuffle_channel'], (1e-5, 1e-5)

    def test(self):
        self.run_and_statis(
            quant=False,
            passes=["shuffle_channel_detect_pass"], )


if __name__ == "__main__":
    unittest.main()
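As a sanity check on the generated program: with the attributes drawn above, the subgraph really is a channel shuffle, so the fused shuffle_channel op must reproduce it exactly. A small NumPy walk-through under assumed draws (batch_size=2, out_channel=3, group=4; any values within the strategies work the same way):

import numpy as np

batch_size, out_channel, group = 2, 3, 4   # assumed draws within the strategies above
in_channel = group * out_channel
x = np.random.random((batch_size, in_channel, 64, 64)).astype(np.float32)

# shape = [0, group, out_channel, -1, 64] resolves to
# [batch_size, group, out_channel, 64, 64] against x.
y = x.reshape(batch_size, group, out_channel, -1, 64)
y = y.transpose(0, 2, 1, 3, 4)              # axis_v = [0, 2, 1, 3, 4]
y = y.reshape(batch_size, in_channel, 64, 64)
assert y.shape == x.shape                   # shuffle_channel keeps the shape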