未验证 提交 0ee230a7 编写于 作者: W wenbin 提交者: GitHub

shuffle_channel pass fix (#39735)

上级 e764cde6
...@@ -94,6 +94,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { ...@@ -94,6 +94,7 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
auto* input_node = subgraph.at(x); auto* input_node = subgraph.at(x);
auto reshape1_desc = reshape1_op->Op(); auto reshape1_desc = reshape1_op->Op();
auto reshape2_desc = reshape2_op->Op(); auto reshape2_desc = reshape2_op->Op();
auto trans_desc = transpose_op->Op();
std::string input_name = input_node->Name(); std::string input_name = input_node->Name();
std::string output_name = reshape2_out->Name(); std::string output_name = reshape2_out->Name();
...@@ -101,10 +102,101 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const { ...@@ -101,10 +102,101 @@ void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
BOOST_GET_CONST(std::vector<int>, reshape1_desc->GetAttr("shape")); BOOST_GET_CONST(std::vector<int>, reshape1_desc->GetAttr("shape"));
auto reshape2_shape = auto reshape2_shape =
BOOST_GET_CONST(std::vector<int>, reshape2_desc->GetAttr("shape")); BOOST_GET_CONST(std::vector<int>, reshape2_desc->GetAttr("shape"));
auto trans_axis =
BOOST_GET_CONST(std::vector<int>, trans_desc->GetAttr("axis"));
auto* block1 = reshape1_desc->Block();
auto* block2 = reshape2_desc->Block();
if (block1 && block2) {
auto x_var_name = reshape1_desc->Input("X")[0];
auto* x_var_desc = block1->FindVar(x_var_name);
auto x_shape1 = x_var_desc->GetShape();
x_var_name = reshape2_desc->Input("X")[0];
x_var_desc = block2->FindVar(x_var_name);
auto x_shape2 = x_var_desc->GetShape();
// now shuffle_channel is 4D(NCHW) only.
if (x_shape1.size() != 4 || reshape1_shape.size() != 5 ||
reshape2_shape.size() != 4 || trans_axis.size() != 5) {
return;
}
// process 0 and -1 in reshape.
constexpr int64_t copy_dim_val = 0;
for (size_t i = 0; i < reshape1_shape.size(); i++) {
if (reshape1_shape[i] == copy_dim_val) {
reshape1_shape[i] = x_shape1[i];
}
}
for (size_t i = 0; i < reshape2_shape.size(); i++) {
if (reshape2_shape[i] == copy_dim_val) {
reshape2_shape[i] = x_shape2[i];
}
}
constexpr int64_t unk_dim_idx = -1;
bool all_positive = std::all_of(x_shape1.cbegin(), x_shape1.cend(),
[](int64_t i) { return i > 0; });
for (size_t i = 0; i < reshape1_shape.size(); ++i) {
// if -1 is not in batch dim, try to calculate number
if ((reshape1_shape[i] == unk_dim_idx) && (i != 0)) {
// there is no sufficient info
if (!all_positive) return;
reshape1_shape[i] =
std::accumulate(x_shape1.begin(), x_shape1.end(),
static_cast<int64_t>(1),
std::multiplies<int64_t>()) /
std::accumulate(reshape1_shape.begin(), reshape1_shape.end(),
static_cast<int64_t>(-1),
std::multiplies<int64_t>());
break;
}
}
all_positive = std::all_of(x_shape2.cbegin(), x_shape2.cend(),
[](int64_t i) { return i > 0; });
for (size_t i = 0; i < reshape2_shape.size(); ++i) {
// if -1 is not in batch dim, try to calculate number
if ((reshape2_shape[i] == unk_dim_idx) && (i != 0)) {
// there is no sufficient info
if (!all_positive) return;
reshape2_shape[i] =
std::accumulate(x_shape2.begin(), x_shape2.end(),
static_cast<int64_t>(1),
std::multiplies<int64_t>()) /
std::accumulate(reshape2_shape.begin(), reshape2_shape.end(),
static_cast<int64_t>(-1),
std::multiplies<int64_t>());
break;
}
}
// shuffle_channel dosen't change shape
if ((reshape2_shape[0] != -1) && (x_shape1[0] != reshape2_shape[0])) {
return;
}
for (size_t i = 1; i < x_shape1.size(); i++) {
if (x_shape1[i] != reshape2_shape[i]) {
return;
}
}
if ((reshape2_shape[3] != reshape1_shape[4]) ||
(reshape2_shape[2] != reshape1_shape[3])) {
return;
}
} else {
return; // conservative judgement
}
int i_c = reshape1_shape[2]; int i_c = reshape1_shape[2];
int o_c = reshape2_shape[1]; int o_c = reshape2_shape[1];
int group = o_c / i_c; int group = o_c / i_c;
// should split on channel dim
if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return;
// trans on channel dim
if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return;
if (group != 1 && i_c != 1) {
if (trans_axis[1] != 2 && trans_axis[2] != 1) {
return;
}
}
framework::OpDesc new_op_desc; framework::OpDesc new_op_desc;
new_op_desc.SetType("shuffle_channel"); new_op_desc.SetType("shuffle_channel");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册