...

Commits (2)

- e112461d write shape, fuse sam image encoder attention (#4792), 2023-06-12T11:43:21+08:00, nihui <nihuini@tencent.com>
  https://gitcode.net/wjd2002/ncnn/-/commit/e112461d3019f8714f91fcba52e6c7ec4bd20172
  * write shape, fuse sam image encoder attention
  * set more dynamic shape as static
  * less warning for constant tensor node
- 4a78b6d4 Update HUAWEI KunPeng 920 platform (#4795), 2023-06-12T16:48:11+08:00, Zhang Geng <mobtgzhang@outlook.com>
  https://gitcode.net/wjd2002/ncnn/-/commit/4a78b6d457c82e6d513792d4ab8020369c223d5d
@@ -4095,6 +4095,135 @@ cooling_down = 0
vision_transformer min = 1617.60 max = 1634.13 avg = 1625.87
FastestDet min = 10.19 max = 10.55 avg = 10.36
```
### HUAWEI KunPeng 920 2251K (x8 cores)
test on UOS 1050
```
mobtgzhang@mobtgzhang-PC:~/ncnn/benchmark$ ./benchncnn 10 1 0 -1 0
loop_count = 10
num_threads = 1
powersave = 0
gpu_device = -1
cooling_down = 0
squeezenet min = 12.11 max = 12.40 avg = 12.25
squeezenet_int8 min = 14.24 max = 14.50 avg = 14.36
mobilenet min = 20.52 max = 21.11 avg = 20.63
mobilenet_int8 min = 18.29 max = 18.63 avg = 18.45
mobilenet_v2 min = 13.73 max = 13.90 avg = 13.79
mobilenet_v3 min = 11.37 max = 11.49 avg = 11.41
shufflenet min = 7.90 max = 7.96 avg = 7.92
shufflenet_v2 min = 8.09 max = 8.13 avg = 8.11
mnasnet min = 13.26 max = 13.44 avg = 13.30
proxylessnasnet min = 16.19 max = 16.39 avg = 16.26
efficientnet_b0 min = 34.92 max = 35.22 avg = 35.04
efficientnetv2_b0 min = 43.82 max = 44.39 avg = 43.94
regnety_400m min = 17.55 max = 18.02 avg = 17.65
blazeface min = 3.05 max = 3.08 avg = 3.07
googlenet min = 58.65 max = 59.26 avg = 58.89
googlenet_int8 min = 60.55 max = 63.00 avg = 61.96
resnet18 min = 34.27 max = 35.43 avg = 34.84
resnet18_int8 min = 60.79 max = 62.15 avg = 61.47
alexnet min = 42.01 max = 44.43 avg = 43.36
vgg16 min = 174.46 max = 177.33 avg = 175.57
vgg16_int8 min = 453.93 max = 457.03 avg = 454.79
resnet50 min = 95.36 max = 96.27 avg = 95.55
resnet50_int8 min = 119.77 max = 121.26 avg = 120.46
squeezenet_ssd min = 39.05 max = 39.69 avg = 39.20
squeezenet_ssd_int8 min = 55.06 max = 56.23 avg = 55.72
mobilenet_ssd min = 45.20 max = 45.96 avg = 45.49
mobilenet_ssd_int8 min = 39.40 max = 40.13 avg = 39.76
mobilenet_yolo min = 98.86 max = 99.85 avg = 99.34
mobilenetv2_yolov3 min = 51.17 max = 52.89 avg = 51.89
yolov4-tiny min = 66.43 max = 67.23 avg = 66.70
nanodet_m min = 20.59 max = 20.79 avg = 20.71
yolo-fastest-1.1 min = 7.90 max = 7.99 avg = 7.93
yolo-fastestv2 min = 7.45 max = 7.49 avg = 7.47
vision_transformer min = 1586.33 max = 1595.34 avg = 1589.76
FastestDet min = 7.45 max = 7.52 avg = 7.47
mobtgzhang@mobtgzhang-PC:~/ncnn/benchmark$ ./benchncnn 10 8 0 -1 0
loop_count = 10
num_threads = 8
powersave = 0
gpu_device = -1
cooling_down = 0
squeezenet min = 2.93 max = 3.10 avg = 3.00
squeezenet_int8 min = 3.47 max = 3.56 avg = 3.52
mobilenet min = 3.89 max = 4.04 avg = 3.94
mobilenet_int8 min = 3.29 max = 3.39 avg = 3.33
mobilenet_v2 min = 3.95 max = 4.08 avg = 3.98
mobilenet_v3 min = 3.45 max = 3.59 avg = 3.49
shufflenet min = 3.42 max = 4.66 avg = 3.62
shufflenet_v2 min = 2.60 max = 2.94 avg = 2.68
mnasnet min = 3.46 max = 3.57 avg = 3.52
proxylessnasnet min = 3.94 max = 12.34 avg = 4.88
efficientnet_b0 min = 7.31 max = 7.60 avg = 7.38
efficientnetv2_b0 min = 9.01 max = 9.22 avg = 9.08
regnety_400m min = 8.56 max = 9.36 avg = 8.70
blazeface min = 1.36 max = 3.52 avg = 1.60
googlenet min = 11.80 max = 12.02 avg = 11.93
googlenet_int8 min = 11.87 max = 23.09 avg = 13.16
resnet18 min = 7.27 max = 7.64 avg = 7.38
resnet18_int8 min = 11.02 max = 11.73 avg = 11.20
alexnet min = 9.05 max = 9.35 avg = 9.17
vgg16 min = 44.13 max = 50.84 avg = 46.89
vgg16_int8 min = 75.15 max = 80.73 avg = 77.52
resnet50 min = 18.72 max = 27.49 avg = 19.96
resnet50_int8 min = 22.72 max = 36.80 avg = 26.78
squeezenet_ssd min = 13.96 max = 27.42 avg = 15.62
squeezenet_ssd_int8 min = 15.01 max = 29.53 avg = 19.51
mobilenet_ssd min = 9.37 max = 13.34 avg = 10.44
mobilenet_ssd_int8 min = 8.07 max = 24.28 avg = 9.83
mobilenet_yolo min = 22.06 max = 24.89 avg = 22.91
mobilenetv2_yolov3 min = 14.41 max = 15.97 avg = 14.78
yolov4-tiny min = 20.71 max = 23.96 avg = 21.42
nanodet_m min = 6.37 max = 6.59 avg = 6.45
yolo-fastest-1.1 min = 4.27 max = 4.52 avg = 4.34
yolo-fastestv2 min = 3.53 max = 3.63 avg = 3.58
vision_transformer min = 435.60 max = 523.43 avg = 479.70
FastestDet min = 3.54 max = 7.95 avg = 5.24
mobtgzhang@mobtgzhang-PC:~/ncnn/benchmark$ ./benchncnn 10 4 2 -1 0
loop_count = 10
num_threads = 4
powersave = 2
gpu_device = -1
cooling_down = 0
squeezenet min = 4.04 max = 4.22 avg = 4.09
squeezenet_int8 min = 4.64 max = 4.76 avg = 4.69
mobilenet min = 6.04 max = 6.06 avg = 6.05
mobilenet_int8 min = 5.23 max = 5.32 avg = 5.25
mobilenet_v2 min = 5.00 max = 5.03 avg = 5.01
mobilenet_v3 min = 4.49 max = 4.69 avg = 4.52
shufflenet min = 3.90 max = 3.94 avg = 3.91
shufflenet_v2 min = 3.27 max = 3.48 avg = 3.33
mnasnet min = 4.80 max = 4.83 avg = 4.82
proxylessnasnet min = 5.20 max = 5.28 avg = 5.23
efficientnet_b0 min = 10.53 max = 11.06 avg = 10.68
efficientnetv2_b0 min = 13.18 max = 13.37 avg = 13.25
regnety_400m min = 9.20 max = 9.25 avg = 9.22
blazeface min = 1.43 max = 1.45 avg = 1.44
googlenet min = 17.63 max = 17.78 avg = 17.71
googlenet_int8 min = 17.63 max = 18.03 avg = 17.85
resnet18 min = 10.34 max = 10.59 avg = 10.40
resnet18_int8 min = 17.93 max = 18.84 avg = 18.25
alexnet min = 13.28 max = 13.37 avg = 13.31
vgg16 min = 55.41 max = 56.60 avg = 55.70
vgg16_int8 min = 123.71 max = 125.34 avg = 124.48
resnet50 min = 27.82 max = 28.22 avg = 27.95
resnet50_int8 min = 34.50 max = 34.89 avg = 34.70
squeezenet_ssd min = 14.67 max = 15.19 avg = 14.85
squeezenet_ssd_int8 min = 19.76 max = 20.32 avg = 19.87
mobilenet_ssd min = 13.15 max = 13.38 avg = 13.21
mobilenet_ssd_int8 min = 11.52 max = 11.70 avg = 11.60
mobilenet_yolo min = 30.95 max = 31.28 avg = 31.05
mobilenetv2_yolov3 min = 20.04 max = 20.36 avg = 20.16
yolov4-tiny min = 25.61 max = 26.73 avg = 25.80
nanodet_m min = 7.93 max = 7.97 avg = 7.95
yolo-fastest-1.1 min = 4.52 max = 4.59 avg = 4.53
yolo-fastestv2 min = 3.74 max = 3.88 avg = 3.77
vision_transformer min = 546.94 max = 726.81 avg = 698.27
FastestDet min = 3.59 max = 3.61 avg = 3.60
```
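The five positional arguments to benchncnn map, in order, to the loop_count, num_threads, powersave, gpu_device and cooling_down fields echoed at the top of each run above. A minimal sketch of that argv mapping (illustrative only, not the actual benchncnn source; defaults invented):
```
#include <cstdio>
#include <cstdlib>

int main(int argc, char** argv)
{
    // usage: ./benchncnn [loop count] [num threads] [powersave] [gpu device] [cooling down]
    int loop_count = argc > 1 ? atoi(argv[1]) : 10;
    int num_threads = argc > 2 ? atoi(argv[2]) : 1;
    int powersave = argc > 3 ? atoi(argv[3]) : 0;   // 0 = all cores, 2 = big cores only
    int gpu_device = argc > 4 ? atoi(argv[4]) : -1; // -1 = CPU only
    int cooling_down = argc > 5 ? atoi(argv[5]) : 0;

    printf("loop_count = %d\n", loop_count);
    printf("num_threads = %d\n", num_threads);
    printf("powersave = %d\n", powersave);
    printf("gpu_device = %d\n", gpu_device);
    printf("cooling_down = %d\n", cooling_down);
    return 0;
}
```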
### Intel Atom x5-Z8350
```
@@ -5738,4 +5867,4 @@ cooling_down = 0
yolo-fastestv2 min = 5.68 max = 7.20 avg = 5.88
vision_transformer min = 600.83 max = 666.35 avg = 617.33
FastestDet min = 6.05 max = 6.72 avg = 6.23
-```
\ No newline at end of file
+```
@@ -269,10 +269,8 @@ Parameter::Parameter(const torch::jit::Node* value_node)
}
else
{
-const int ndim = (int)t.dim();
-fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = %d\n", value_node->kind().toDisplayString(), ndim);
+// constant tensor will become pnnx attribute node later
+type = 8;
}
break;
...
@@ -100,6 +100,50 @@ void GraphRewriterPass::write(Operator* op, const std::map<std::string, Paramete
op->params[x.first] = Parameter::parse_from_string(str);
}
for (size_t i = 0; i < op->inputs.size(); i++)
{
Operand* operand = op->inputs[i];
std::vector<int>& shape = operand->shape;
for (size_t j = 0; j < shape.size(); j++)
{
int ai = shape[j];
if (ai == -233)
{
std::string key = operand->params.at(std::string("__shape_") + std::to_string(j)).s;
if (captured_params.find(key) == captured_params.end())
{
fprintf(stderr, "replace pattern param %%%s missing captured\n", key.c_str());
return;
}
shape[j] = captured_params.at(key).i;
}
}
}
for (size_t i = 0; i < op->outputs.size(); i++)
{
Operand* operand = op->outputs[i];
std::vector<int>& shape = operand->shape;
for (size_t j = 0; j < shape.size(); j++)
{
int ai = shape[j];
if (ai == -233)
{
std::string key = operand->params.at(std::string("__shape_") + std::to_string(j)).s;
if (captured_params.find(key) == captured_params.end())
{
fprintf(stderr, "replace pattern param %%%s missing captured\n", key.c_str());
return;
}
shape[j] = captured_params.at(key).i;
}
}
}
}
void GraphRewriterPass::write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
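For context on the shape-writing loops above: when a matched operand's shape holds the dynamic-dim sentinel -233, the operand records under an `__shape_<dim>` param the name of the capture that dim came from, and write() substitutes the captured integer so the rewritten operator ends up with a static shape. A self-contained sketch of that substitution, using plain containers as stand-ins for the pnnx Operand/Parameter types:
```
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main()
{
    std::vector<int> shape = {-233, -233, 64}; // -233 marks a dynamic dim
    std::map<std::string, std::string> operand_params = {
        {"__shape_0", "batch"}, {"__shape_1", "size"}}; // dim index -> capture name
    std::map<std::string, int> captured_params = {{"batch", 1}, {"size", 4096}};

    for (size_t j = 0; j < shape.size(); j++)
    {
        if (shape[j] != -233)
            continue;

        const std::string& key = operand_params.at("__shape_" + std::to_string(j));
        if (captured_params.count(key) == 0)
        {
            fprintf(stderr, "replace pattern param %%%s missing captured\n", key.c_str());
            return 1;
        }
        shape[j] = captured_params.at(key); // substitute the captured integer
    }

    printf("(%d,%d,%d)\n", shape[0], shape[1], shape[2]); // (1,4096,64)
    return 0;
}
```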
...
@@ -100,7 +100,8 @@ static bool operand_maybe_tensor(const Operand* operand)
|| op->type == "aten::div"
|| op->type == "aten::floor_divide"
|| op->type == "aten::mul"
-|| op->type == "aten::pow")
+|| op->type == "aten::pow"
+|| op->type == "aten::remainder")
{
return operand_maybe_tensor(op->inputs[0]) || operand_maybe_tensor(op->inputs[1]);
}
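The rule this hunk extends to aten::remainder: a binary arithmetic op may yield a tensor only if at least one input may be a tensor, so scalar-only arithmetic keeps folding into a pnnx.Expression. A simplified sketch of that recursion (stand-in Node type, not the pnnx Operand graph):
```
// leaf: an operand that is (or is not) possibly a tensor
// inner: add/sub/mul/div/floor_divide/pow/remainder over two operands
struct Node
{
    bool is_tensor_leaf = false;
    const Node* a = nullptr;
    const Node* b = nullptr;
};

static bool maybe_tensor(const Node* n)
{
    if (!n->a)
        return n->is_tensor_leaf;
    return maybe_tensor(n->a) || maybe_tensor(n->b);
}
```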
...
@@ -31,13 +31,12 @@ static bool token_is_interger_literal(const std::string& t)
return iss.eof() && !iss.fail();
}
-static std::vector<int> build_shape(const std::string& expr)
+static void build_shape(const std::string& expr, std::vector<int>& shape, std::vector<std::string>& expr_tokens)
{
std::string listexpr = expr.substr(1, expr.size() - 2);
-std::vector<int> shape;
std::string t;
+std::string et;
int level = 0;
for (size_t i = 0; i < listexpr.size(); i++)
{
@@ -47,21 +46,26 @@ static std::vector<int> build_shape(const std::string& expr)
{
level += 1;
t = "-1";
et += ch;
}
else if (ch == ')' || ch == ']')
{
level -= 1;
t = "-1";
et += ch;
}
else if (level == 0 && ch == ',')
{
int dimsize = token_is_interger_literal(t) ? std::stoi(t) : -1;
shape.push_back(dimsize);
expr_tokens.push_back(et);
t.clear();
et.clear();
}
else
{
t += ch;
et += ch;
}
}
@@ -71,7 +75,26 @@ static std::vector<int> build_shape(const std::string& expr)
shape.push_back(dimsize);
}
-return shape;
+if (level == 0 && !et.empty())
+{
+expr_tokens.push_back(et);
+}
}
static std::string build_expr(const std::vector<std::string>& expr_tokens)
{
std::string expr;
expr += '[';
for (int i = 0; i < (int)expr_tokens.size(); i++)
{
expr += expr_tokens[i];
if (i != (int)expr_tokens.size() - 1)
expr += ',';
}
expr += ']';
return expr;
}
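To see what the reworked helpers produce, take an invented shape expression. build_shape splits it into a numeric shape, with every non-integer token becoming -1, plus the verbatim text of each top-level element; build_expr re-joins those tokens after some have been pinned to static values. A hypothetical round trip, assuming the two functions above:
```
// illustrative fragment; the expression is invented
std::vector<int> shape;
std::vector<std::string> expr_tokens;
build_shape("[int(size(x,0)),197,768]", shape, expr_tokens);
// shape       = {-1, 197, 768}
// expr_tokens = {"int(size(x,0))", "197", "768"}

shape[0] = 1; // a later pass pins the batch dim
expr_tokens[0] = "1";
std::string expr = build_expr(expr_tokens);
// expr = "[1,197,768]"
```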
void eliminate_reshape_shape_expression(Graph& graph)
@@ -98,18 +121,21 @@ void eliminate_reshape_shape_expression(Graph& graph)
if (expr.empty() || expr[0] != '[')
continue;
-std::vector<int> shape = build_shape(expr);
+std::vector<int> outshape = op->outputs[0]->shape;
+if (outshape.empty())
+continue;
+std::vector<int> shape;
+std::vector<std::string> expr_tokens;
+build_shape(expr, shape, expr_tokens);
// replace -1 with static dim-size
-std::vector<int> outshape = op->outputs[0]->shape;
-if (!outshape.empty())
-{
-for (size_t j = 0; j < outshape.size(); j++)
-{
-if (outshape[j] != -1)
-{
-shape[j] = outshape[j];
-}
-}
-}
+for (size_t j = 0; j < outshape.size(); j++)
+{
+if (outshape[j] != -1)
+{
+shape[j] = outshape[j];
+expr_tokens[j] = std::to_string(outshape[j]);
+}
+}
@@ -124,7 +150,10 @@ void eliminate_reshape_shape_expression(Graph& graph)
}
if (dynamic_dim_count > 1)
{
op_expr->params["expr"] = build_expr(expr_tokens);
continue;
}
matched = true;
@@ -156,6 +185,34 @@ void eliminate_reshape_shape_expression(Graph& graph)
if (!matched)
break;
}
for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];
if (op->type != "Tensor.view" && op->type != "Tensor.reshape")
continue;
if (op->inputs.size() != 1)
continue;
std::vector<int> outshape = op->outputs[0]->shape;
if (outshape.empty())
continue;
std::vector<int> shape = op->params.at("shape").ai;
// replace -1 with static dim-size
for (size_t j = 0; j < outshape.size(); j++)
{
if (outshape[j] != -1)
{
shape[j] = outshape[j];
}
}
op->params["shape"] = shape;
}
}
} // namespace pnnx
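The new trailing loop applies the same -1 elimination to views whose target shape is already a plain parameter rather than an expression. A small standalone illustration of that fix-up (shapes invented):
```
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> shape = {-1, 197, 768};   // the op's shape=... param
    std::vector<int> outshape = {1, 197, 768}; // statically inferred output shape

    for (size_t j = 0; j < outshape.size(); j++)
    {
        if (outshape[j] != -1)
            shape[j] = outshape[j]; // pin every known dim; true unknowns stay -1
    }

    printf("(%d,%d,%d)\n", shape[0], shape[1], shape[2]); // (1,197,768)
    return 0;
}
```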
@@ -56,7 +56,7 @@ public:
pnnx.Input input 0 1 input
Tensor.view op_0 1 1 input 13 shape=(%batch,%groups,%channels_per_group,%h,%w)
torch.transpose op_1 1 1 13 14 dim0=1 dim1=2
-Tensor.reshape op_2 1 1 14 out shape=(%batch,-1,%h,%w)
+Tensor.reshape op_2 1 1 14 out shape=(%batch,%channels,%h,%w)
pnnx.Output output 1 0 out
)PNNXIR";
}
...
@@ -1060,14 +1060,14 @@ nn.Linear op_1 1 1 input 4 bias=%kbias in_features=%embed_d
nn.Linear op_2 1 1 input 6 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Expression op_3 1 1 2 3 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
Tensor.view op_4 1 1 3 8 shape=(%batch,%size,%num_heads,%feat_per_head)
-Tensor.view op_5 1 1 4 5 shape=(%batch,-1,%num_heads,%feat_per_head)
-Tensor.view op_6 1 1 6 7 shape=(%batch,-1,%num_heads,%feat_per_head)
+Tensor.view op_5 1 1 4 5 shape=(%batch,%size,%num_heads,%feat_per_head)
+Tensor.view op_6 1 1 6 7 shape=(%batch,%size,%num_heads,%feat_per_head)
torch.transpose op_7 1 1 8 9 dim0=1 dim1=2
torch.transpose op_8 1 1 5 10 dim0=1 dim1=2
torch.transpose op_9 1 1 7 11 dim0=1 dim1=2
-Tensor.reshape op_10 1 1 9 14 shape=(%num_heads,-1,%feat_per_head)
-Tensor.reshape op_11 1 1 10 12 shape=(%num_heads,-1,%feat_per_head)
-Tensor.reshape op_12 1 1 11 17 shape=(%num_heads,-1,%feat_per_head)
+Tensor.reshape op_10 1 1 9 14 shape=(%num_heads,%batch_mul_size,%feat_per_head)
+Tensor.reshape op_11 1 1 10 12 shape=(%num_heads,%batch_mul_size,%feat_per_head)
+Tensor.reshape op_12 1 1 11 17 shape=(%num_heads,%batch_mul_size,%feat_per_head)
torch.transpose op_13 1 1 12 13 dim0=1 dim1=2
torch.bmm op_14 2 1 14 13 15
F.softmax op_15 1 1 15 16 dim=-1
@@ -1094,14 +1094,14 @@ nn.Linear op_1 1 1 input 5 bias=%kbias in_features=%embed_d
nn.Linear op_2 1 1 input 7 bias=%vbias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Expression op_3 1 1 3 4 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
Tensor.view op_4 1 1 4 9 shape=(%batch,%size,%num_heads,%feat_per_head)
-Tensor.view op_5 1 1 5 6 shape=(%batch,-1,%num_heads,%feat_per_head)
-Tensor.view op_6 1 1 7 8 shape=(%batch,-1,%num_heads,%feat_per_head)
+Tensor.view op_5 1 1 5 6 shape=(%batch,%size,%num_heads,%feat_per_head)
+Tensor.view op_6 1 1 7 8 shape=(%batch,%size,%num_heads,%feat_per_head)
torch.transpose op_7 1 1 9 10 dim0=1 dim1=2
torch.transpose op_8 1 1 6 11 dim0=1 dim1=2
torch.transpose op_9 1 1 8 12 dim0=1 dim1=2
-Tensor.reshape op_10 1 1 10 15 shape=(%num_heads,-1,%feat_per_head)
-Tensor.reshape op_11 1 1 11 13 shape=(%num_heads,-1,%feat_per_head)
-Tensor.reshape op_12 1 1 12 21 shape=(%num_heads,-1,%feat_per_head)
+Tensor.reshape op_10 1 1 10 15 shape=(%num_heads,%batch_mul_size,%feat_per_head)
+Tensor.reshape op_11 1 1 11 13 shape=(%num_heads,%batch_mul_size,%feat_per_head)
+Tensor.reshape op_12 1 1 12 21 shape=(%num_heads,%batch_mul_size,%feat_per_head)
torch.transpose op_13 1 1 13 14 dim0=1 dim1=2
torch.bmm op_14 2 1 15 14 16
Tensor.view op_15 1 1 16 17 shape=(%batch,%num_heads,%size,%size)
@@ -1301,7 +1301,7 @@ pnnx.Expression op_7 2 1 33 attn_mask 35 expr=add(@0,@1)
Tensor.view op_8 1 1 35 36 shape=(1,%batch,%num_heads,%size,%size)
pnnx.Attribute op_9 0 1 37 @data=(1,%batch,1,%size,%size)f32
pnnx.Expression op_10 2 1 36 37 38 expr=add(@0,@1)
-Tensor.view op_11 1 1 38 39 shape=(-1,%num_heads,%size,%size)
+Tensor.view op_11 1 1 38 39 shape=(%batch,%num_heads,%size,%size)
F.softmax op_12 1 1 39 40 dim=-1
torch.matmul op_13 2 1 40 30 41
torch.transpose op_14 1 1 41 42 dim0=1 dim1=2
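These pattern edits follow the commit's "set more dynamic shape as static" note: each -1 in the view and reshape targets becomes a named capture (%size, %batch_mul_size, %batch). A reshape from (%batch,%num_heads,%size,%feat_per_head) to (%num_heads,%batch_mul_size,%feat_per_head) preserves the element count only when %batch_mul_size equals %batch * %size; a minimal consistency check in that spirit (hypothetical, not part of the diff):
```
// element-count condition for the captured reshape dims
static bool reshape_count_consistent(int batch, int size, int num_heads, int feat_per_head, int batch_mul_size)
{
    return (long)num_heads * batch_mul_size * feat_per_head
           == (long)batch * num_heads * size * feat_per_head;
}
```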
...
@@ -79,13 +79,73 @@ pnnx.Output output 1 0 out
}
};
class fuse_scaled_dot_product_attention_pass_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
14 13
pnnx.Input input_0 0 1 query #query=(%batch,%qsize,%feat_per_head)f32
pnnx.Input input_1 0 1 key #key=(%batch,%kvsize,%feat_per_head)f32
pnnx.Input input_2 0 1 value #value=(%batch,%kvsize,%feat_per_head)f32
pnnx.Input input_Rh 0 1 Rh #Rh=(%batch,%h,%w,%h,1)f32
pnnx.Input input_Rw 0 1 Rw #Rw=(%batch,%h,%w,1,%w)f32
pnnx.Expression op_0 1 1 query 17 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
torch.transpose op_1 1 1 key 22 dim0=-2 dim1=-1
torch.matmul op_2 2 1 17 22 23
Tensor.view op_3 1 1 23 24 shape=(%batch,%h,%w,%h,%w)
pnnx.Expression op_4 3 1 24 Rh Rw 28 expr=add(add(@0,@1),@2)
Tensor.view op_5 1 1 28 29 shape=(%batch,%qsize,%qsize)
F.softmax op_6 1 1 29 30 dim=-1
torch.matmul op_7 2 1 30 value out
pnnx.Output output 1 0 out
)PNNXIR";
}
const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
9 8
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
pnnx.Input input_Rh 0 1 Rh
pnnx.Input input_Rw 0 1 Rw
pnnx.Expression RhRw 2 1 Rh Rw RhRw expr=add(@0,@1) #RhRw=(%batch,%h,%w,%h,%w)f32
Tensor.reshape attn_mask 1 1 RhRw attn_mask shape=(%batch,%qsize,%qsize) #attn_mask=(%batch,%qsize,%qsize)f32
F.scaled_dot_product_attention op_0 4 1 query key value attn_mask out dropout_p=0.0 is_causal=False $attn_mask=attn_mask
pnnx.Output output 1 0 out
)PNNXIR";
}
bool match(const std::map<std::string, Parameter>& captured_params) const
{
const int qsize = captured_params.at("qsize").i;
const int h = captured_params.at("h").i;
const int w = captured_params.at("w").i;
const int feat_per_head = captured_params.at("feat_per_head").i;
const float inv_sqrt_embed_dim_per_head = captured_params.at("inv_sqrt_embed_dim_per_head").f;
if (qsize != h * w)
return false;
if (!NearlyEqual(inv_sqrt_embed_dim_per_head, 1.f / sqrt(feat_per_head), 0.001))
return false;
return true;
}
};
void fuse_scaled_dot_product_attention(Graph& graph)
{
#if TORCH_VERSION_MAJOR >= 2
fuse_scaled_dot_product_attention_pass a;
fuse_scaled_dot_product_attention_pass_1 b;
int opindex = 0;
pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
#endif
}
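For reference, this fusion targets the SAM image encoder attention, where decomposed relative-position terms Rh and Rw are added to the (h*w) x (h*w) attention map before softmax; the replacement folds add(Rh,Rw), reshaped to (%batch,%qsize,%qsize), into the attn_mask input of F.scaled_dot_product_attention. What that fused call computes, as a naive reference (single matrix, no dropout; a sketch, not ncnn's or PyTorch's implementation):
```
#include <algorithm>
#include <cmath>
#include <vector>

// out = softmax(q x k^T x scale + mask) x v
// q: qsize x dk, k: kvsize x dk, v: kvsize x dv, mask: qsize x kvsize
static std::vector<std::vector<float> > sdpa(const std::vector<std::vector<float> >& q,
                                             const std::vector<std::vector<float> >& k,
                                             const std::vector<std::vector<float> >& v,
                                             const std::vector<std::vector<float> >& mask)
{
    const size_t qsize = q.size();
    const size_t kvsize = k.size();
    const size_t dk = q[0].size();
    const size_t dv = v[0].size();
    const float scale = 1.f / sqrtf((float)dk); // the %inv_sqrt_embed_dim_per_head factor

    std::vector<std::vector<float> > out(qsize, std::vector<float>(dv, 0.f));
    for (size_t i = 0; i < qsize; i++)
    {
        // scores for query row i, plus the additive mask
        std::vector<float> s(kvsize, 0.f);
        float smax = -1e30f;
        for (size_t j = 0; j < kvsize; j++)
        {
            for (size_t d = 0; d < dk; d++)
                s[j] += q[i][d] * k[j][d];
            s[j] = s[j] * scale + mask[i][j];
            smax = std::max(smax, s[j]);
        }

        // numerically stable softmax over the key axis
        float sum = 0.f;
        for (size_t j = 0; j < kvsize; j++)
        {
            s[j] = expf(s[j] - smax);
            sum += s[j];
        }

        // weighted sum of value rows
        for (size_t j = 0; j < kvsize; j++)
            for (size_t d = 0; d < dv; d++)
                out[i][d] += s[j] / sum * v[j][d];
    }
    return out;
}
```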
...
@@ -35,51 +35,58 @@ public:
{
return R"PNNXIR(7767517
6 6
-pnnx.Input input 0 1 input
-Tensor.reshape op_0 1 1 input a shape=%shape
-torch.permute op_1 1 1 a b dims=%dims
-Tensor.reshape op_2 1 1 b c shape=%shape2
+pnnx.Input input 0 1 input #input=(%batch,%c,%h,%w)f32
+Tensor.reshape op_0 1 1 input a shape=(%batch_mul_ch_per_group,%groups,%h_mul_w)
+torch.permute op_1 1 1 a b dims=(1,0,2)
+Tensor.reshape op_2 1 1 b c shape=(%groups,%batch,%ch_per_group,%h,%w)
torch.unbind op_3 1 2 c out0 out1 dim=0
pnnx.Output output 2 0 out0 out1
)PNNXIR";
}
-const char* type_str() const
-{
-return "ncnn._shufflechannel_slice";
-}
-const char* name_str() const
-{
-return "shufflechannel_slice";
-}
+const char* replace_pattern_graph() const
+{
+return R"PNNXIR(7767517
+4 4
+pnnx.Input input 0 1 input
+ShuffleChannel shufflechannel 1 1 input a 0=%groups 1=1 #a=(%batch,%c,%h,%w)f32
+Slice slice 1 2 a out0 out1 0=(-233,-233) 1=0
+pnnx.Output output 2 0 out0 out1
+)PNNXIR";
+}
bool match(const std::map<std::string, Parameter>& captured_params) const
{
-// (116,2,1024)
-// (1,0,2)
-// (2,-1,116,32,32)
-const std::vector<int>& shape = captured_params.at("shape").ai;
-const std::vector<int>& dims = captured_params.at("dims").ai;
-const std::vector<int>& shape2 = captured_params.at("shape2").ai;
-if (dims != std::vector<int>{1, 0, 2})
-return false;
-if (shape[0] != shape2[2] || shape[1] != shape2[0] || shape[2] != shape2[3] * shape2[4] || shape[1] != 2 || shape2[1] != -1)
-return false;
+const int groups = captured_params.at("groups").i;
+const int batch = captured_params.at("batch").i;
+const int batch_mul_ch_per_group = captured_params.at("batch_mul_ch_per_group").i;
+const int ch_per_group = captured_params.at("ch_per_group").i;
+const int h_mul_w = captured_params.at("h_mul_w").i;
+const int c = captured_params.at("c").i;
+const int h = captured_params.at("h").i;
+const int w = captured_params.at("w").i;
+if (groups != 2 || groups * ch_per_group != c)
+return false;
+if (batch_mul_ch_per_group != batch * ch_per_group)
+return false;
+if (h_mul_w != h * w)
+return false;
return true;
}
-void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
-const std::vector<int>& shape = captured_params.at("shape").ai;
-int groups = shape[1];
-op->params["0"] = groups;
-op->params["1"] = 1;
+GraphRewriterPass::write(ops, captured_params, captured_attrs);
+const int batch_index = ops.at("shufflechannel")->inputs[0]->params["__batch_index"].i;
+ops.at("slice")->inputs[0]->params["__batch_index"] = batch_index;
+ops.at("slice")->outputs[0]->params["__batch_index"] = batch_index;
+ops.at("slice")->outputs[1]->params["__batch_index"] = batch_index;
}
};
@@ -90,10 +97,10 @@ public:
{
return R"PNNXIR(7767517
6 6
-pnnx.Input input 0 1 input
-Tensor.reshape op_0 1 1 input a shape=%shape
-Tensor.permute op_1 1 1 a b dims=%dims
-Tensor.reshape op_2 1 1 b c shape=%shape2
+pnnx.Input input 0 1 input #input=(%batch,%c,%h,%w)f32
+Tensor.reshape op_0 1 1 input a shape=(%batch_mul_ch_per_group,%groups,%h_mul_w)
+Tensor.permute op_1 1 1 a b dims=(1,0,2)
+Tensor.reshape op_2 1 1 b c shape=(%groups,%batch,%ch_per_group,%h,%w)
torch.unbind op_3 1 2 c out0 out1 dim=0
pnnx.Output output 2 0 out0 out1
)PNNXIR";
@@ -108,57 +115,6 @@ void fuse_convert_shufflechannel_slice(Graph& graph)
pnnx_graph_rewrite(graph, &a, opindex);
pnnx_graph_rewrite(graph, &b, opindex);
-int op_index = 0;
-while (1)
-{
-bool matched = false;
-for (Operator* op : graph.ops)
-{
-if (op->type != "ncnn._shufflechannel_slice")
-continue;
-matched = true;
-const int batch_index = op->inputs[0]->params["__batch_index"].i;
-op->type = "ShuffleChannel";
-op->name = std::string("shufflechannel_") + std::to_string(op_index++);
-Operand* out0 = op->outputs[0];
-Operand* out1 = op->outputs[1];
-Operator* slice = graph.new_operator_after("Slice", op->name + "_slice", op);
-Operand* slice_in = graph.new_operand(op->name + "_slice_in");
-slice_in->params["__batch_index"] = batch_index;
-out0->params["__batch_index"] = batch_index;
-out1->params["__batch_index"] = batch_index;
-slice->inputs.push_back(slice_in);
-slice->outputs.push_back(out0);
-slice->outputs.push_back(out1);
-op->outputs.clear();
-op->outputs.push_back(slice_in);
-out0->producer = slice;
-out1->producer = slice;
-slice_in->producer = op;
-slice_in->consumers.push_back(slice);
-slice->params["0"] = std::vector<int>{-233, -233};
-slice->params["1"] = 0;
-break;
-}
-if (!matched)
-break;
-}
}
} // namespace ncnn
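The identity behind this pass, for reference: the matched reshape/permute/reshape is a channel shuffle in its reverse orientation, followed by a two-way split, which is why the replacement emits ShuffleChannel with 0=%groups and 1=1 (reverse) feeding a Slice. A standalone sketch of the two shuffle directions on one pixel's channels:
```
#include <cstdio>
#include <vector>

// view c channels as (groups, c/groups), transpose, flatten (forward);
// the reverse direction starts from the (c/groups, groups) view instead
static std::vector<int> shuffle(const std::vector<int>& in, int groups, bool reverse)
{
    const int c = (int)in.size();
    const int cpg = c / groups; // channels per group

    std::vector<int> out(c);
    for (int k = 0; k < c; k++)
    {
        const int dst = reverse ? (k % groups) * cpg + k / groups
                                : (k % cpg) * groups + k / cpg;
        out[dst] = in[k];
    }
    return out;
}

int main()
{
    // 6 channels, 2 groups
    std::vector<int> x = {0, 1, 2, 3, 4, 5};
    std::vector<int> y = shuffle(x, 2, false); // forward: 0 3 1 4 2 5
    std::vector<int> z = shuffle(y, 2, true);  // reverse undoes it: 0 1 2 3 4 5
    for (size_t i = 0; i < z.size(); i++)
        printf("%d ", z[i]);
    printf("\n");
    return 0;
}
```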
...