Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d31a174f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
d31a174f
编写于
4月 24, 2020
作者:
A
arlesniak
提交者:
GitHub
4月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added fusing matmul-transpose-reshape pass (#23866)
上级
46f3139c
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
727 addition
and
38 deletion
+727
-38
paddle/fluid/framework/ddim.cc
paddle/fluid/framework/ddim.cc
+62
-0
paddle/fluid/framework/ddim.h
paddle/fluid/framework/ddim.h
+4
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+37
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+18
-0
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc
...framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc
+100
-0
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h
.../framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h
+35
-0
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc
...rk/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc
+93
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+6
-5
paddle/fluid/operators/matmul_op.cc
paddle/fluid/operators/matmul_op.cc
+50
-1
paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
+80
-32
python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
.../fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
+1
-0
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py
...ts/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py
+110
-0
python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
...dle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
+129
-0
未找到文件。
paddle/fluid/framework/ddim.cc
浏览文件 @
d31a174f
...
...
@@ -131,5 +131,67 @@ DDim stride_numel(const DDim& ddim) {
return
strides
;
}
DDim
DDim
::
reshape
(
const
std
::
vector
<
int
>&
shape
)
const
{
const
int64_t
copy_dim_val
=
0
;
const
DDim
&
in_dims
=
*
this
;
DDim
out_dims
;
out_dims
.
rank_
=
shape
.
size
();
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
if
(
shape
[
i
]
==
copy_dim_val
)
{
PADDLE_ENFORCE_LT
(
static_cast
<
int
>
(
i
),
in_dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Index %d of shape under which the value of 0 "
"is stored, must be lower than the number of "
"old dimensions. But received shape[%d] = 0, "
"dimensions = %d, shape = [%s]."
,
i
,
in_dims
.
size
(),
in_dims
));
out_dims
[
i
]
=
in_dims
[
i
];
}
else
{
out_dims
[
i
]
=
shape
[
i
];
}
}
return
out_dims
;
}
DDim
DDim
::
transpose
(
const
std
::
vector
<
int
>&
axis
)
const
{
const
DDim
&
in_dims
=
*
this
;
size_t
in_rank
=
in_dims
.
size
();
size_t
axis_size
=
axis
.
size
();
PADDLE_ENFORCE_EQ
(
in_rank
,
axis_size
,
platform
::
errors
::
InvalidArgument
(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d"
,
in_rank
,
axis_size
));
std
::
vector
<
int
>
count
(
axis_size
,
0
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
PADDLE_ENFORCE_LT
(
axis
[
i
],
static_cast
<
int
>
(
axis_size
),
platform
::
errors
::
InvalidArgument
(
"ValueError: Each element of axis must appear "
"exactly once in the range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"but received axis[%d] is %d, axis_size is %d"
,
i
,
axis
[
i
],
axis_size
));
PADDLE_ENFORCE_EQ
(
++
count
[
axis
[
i
]],
1
,
platform
::
errors
::
InvalidArgument
(
"ValueError: Each element of axis should "
"be a unique value range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"unique value means this axis value can appear only once. "
"But received count[axis[%d]] is %d"
,
i
,
count
[
axis
[
i
]]));
}
DDim
out_dims
(
in_dims
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
out_dims
[
i
]
=
in_dims
[
axis
[
i
]];
}
return
out_dims
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ddim.h
浏览文件 @
d31a174f
...
...
@@ -126,6 +126,10 @@ class DDim {
std
::
string
to_str
()
const
;
DDim
reshape
(
const
std
::
vector
<
int
>&
shape
)
const
;
DDim
transpose
(
const
std
::
vector
<
int
>&
axis
)
const
;
private:
template
<
int
D
>
inline
Dim
<
D
>&
UnsafeCast
()
{
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
d31a174f
...
...
@@ -97,6 +97,7 @@ if(WITH_MKLDNN)
pass_library
(
cpu_quantize_placement_pass base DIR mkldnn
)
pass_library
(
cpu_quantize_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_squash_pass inference DIR mkldnn
)
pass_library
(
matmul_transpose_reshape_fuse_pass inference DIR mkldnn
)
endif
()
cc_library
(
fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector
)
...
...
@@ -144,4 +145,5 @@ if (WITH_MKLDNN)
cc_test
(
test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass
)
cc_test
(
test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor
)
cc_test
(
test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor
)
cc_test
(
test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass
)
endif
()
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
d31a174f
...
...
@@ -2147,6 +2147,43 @@ void patterns::DeleteQuantDequantOpPattern::operator()() {
any_op2
->
LinksFrom
({
quant_dequant_out
});
}
PDNode
*
patterns
::
MatmulTransposeReshapePattern
::
operator
()()
{
auto
reshape_op
=
pattern
->
NewNode
(
reshape_op_repr
())
->
assert_is_op
(
"reshape2"
);
auto
transpose_op
=
pattern
->
NewNode
(
transpose_op_repr
())
->
assert_is_op
(
"transpose2"
);
auto
matmul_op
=
pattern
->
NewNode
(
matmul_op_repr
())
->
assert_is_op
(
"matmul"
);
auto
matmul_out
=
pattern
->
NewNode
(
matmul_out_repr
())
->
AsInput
()
->
assert_is_op_output
(
"matmul"
,
"Out"
)
->
assert_is_op_input
(
"transpose2"
,
"X"
);
auto
transpose_out
=
pattern
->
NewNode
(
transpose_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"transpose2"
,
"Out"
)
->
assert_is_op_input
(
"reshape2"
,
"X"
);
auto
transpose_out_xshape
=
pattern
->
NewNode
(
transpose_out_xshape_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"transpose2"
,
"XShape"
);
auto
reshape_out
=
pattern
->
NewNode
(
reshape_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"reshape2"
);
auto
reshape_out_xshape
=
pattern
->
NewNode
(
reshape_out_xshape_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"reshape2"
,
"XShape"
);
matmul_op
->
LinksTo
({
matmul_out
});
transpose_op
->
LinksTo
({
transpose_out_xshape
});
reshape_op
->
LinksTo
({
reshape_out_xshape
});
transpose_op
->
LinksFrom
({
matmul_out
}).
LinksTo
({
transpose_out
});
reshape_op
->
LinksFrom
({
transpose_out
}).
LinksTo
({
reshape_out
});
return
reshape_out
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
d31a174f
...
...
@@ -1210,6 +1210,24 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
PATTERN_DECL_NODE
(
any_op2
);
};
// Matmul + Transpose + Reshape
struct
MatmulTransposeReshapePattern
:
public
PatternBase
{
MatmulTransposeReshapePattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"matmul_transpose_reshape"
)
{}
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
matmul_op
);
PATTERN_DECL_NODE
(
matmul_out
);
PATTERN_DECL_NODE
(
transpose_op
);
PATTERN_DECL_NODE
(
transpose_out
);
PATTERN_DECL_NODE
(
transpose_out_xshape
);
PATTERN_DECL_NODE
(
reshape_op
);
PATTERN_DECL_NODE
(
reshape_out
);
PATTERN_DECL_NODE
(
reshape_out_xshape
);
};
}
// namespace patterns
// Link two ir::Nodes from each other.
...
...
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc
0 → 100644
浏览文件 @
d31a174f
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <paddle/fluid/string/pretty_log.h>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
MatmulTransposeReshapeMKLDNNPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Pointer to graph argument should not be NULL."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
GraphPatternDetector
gpd
;
patterns
::
MatmulTransposeReshapePattern
mtrp
(
gpd
.
mutable_pattern
(),
name_scope_
);
mtrp
();
int
found_matmul_transpose_reshape_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle matmul_transpose_reshape fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_op
,
matmul_op
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_out
,
matmul_out
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose_op
,
transpose_op
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose_out
,
transpose_out
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose_out_xshape
,
transpose_out_xshape
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape_op
,
reshape_op
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape_out
,
reshape_out
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape_out_xshape
,
reshape_out_xshape
,
mtrp
);
auto
reshape_shape
=
boost
::
get
<
std
::
vector
<
int
>>
(
reshape_op
->
Op
()
->
GetAttr
(
"shape"
));
auto
transpose_axis
=
boost
::
get
<
std
::
vector
<
int
>>
(
transpose_op
->
Op
()
->
GetAttr
(
"axis"
));
auto
reshape_out_size
=
reshape_shape
.
size
();
auto
transpose_out_size
=
transpose_axis
.
size
();
const
std
::
vector
<
int
>
supported_axis
{
0
,
2
,
1
,
3
};
const
bool
supported_transpose_axis
=
std
::
equal
(
transpose_axis
.
begin
(),
transpose_axis
.
end
(),
supported_axis
.
begin
());
if
(
transpose_out_size
!=
4
)
{
VLOG
(
3
)
<<
"do not perform matmul_transpose_reshape fuse: "
<<
"supported rank is 4, received "
<<
transpose_out_size
;
return
;
}
if
(
!
supported_transpose_axis
)
{
VLOG
(
3
)
<<
"do not perform matmul_transpose_reshape fuse: "
<<
"supported transpose axis for the fuse are {0, 2, 1, 3}"
;
return
;
}
if
(
reshape_out_size
!=
3
)
{
VLOG
(
3
)
<<
"do not perform matmul_transpose_reshape fuse: "
<<
"reshape_out supported rank is 3, received "
<<
reshape_out_size
;
return
;
}
OpDesc
*
matmul_desc
=
matmul_op
->
Op
();
matmul_desc
->
SetOutput
(
"Out"
,
{
reshape_out
->
Name
()});
matmul_desc
->
SetAttr
(
"fused_reshape_Out"
,
reshape_shape
);
matmul_desc
->
SetAttr
(
"fused_transpose_Out"
,
transpose_axis
);
GraphSafeRemoveNodes
(
graph
,
{
matmul_out
,
transpose_op
,
transpose_out
,
reshape_op
,
transpose_out_xshape
,
reshape_out_xshape
});
IR_OP_VAR_LINK
(
matmul_op
,
reshape_out
);
found_matmul_transpose_reshape_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_matmul_transpose_reshape_count
);
std
::
stringstream
msg_ss
;
msg_ss
<<
"--- Fused "
<<
found_matmul_transpose_reshape_count
<<
" MatmulTransposeReshape patterns"
;
paddle
::
string
::
PrettyLogDetail
(
msg_ss
.
str
().
c_str
());
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
matmul_transpose_reshape_fuse_pass
,
paddle
::
framework
::
ir
::
MatmulTransposeReshapeMKLDNNPass
);
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h
0 → 100644
浏览文件 @
d31a174f
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
MatmulTransposeReshapeMKLDNNPass
:
public
FusePassBase
{
public:
virtual
~
MatmulTransposeReshapeMKLDNNPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
const
std
::
string
name_scope_
{
"matmul_transpose_reshape_fuse"
};
};
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc
0 → 100644
浏览文件 @
d31a174f
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
SetOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
type
,
const
std
::
vector
<
std
::
string
>
&
inputs
,
const
std
::
vector
<
std
::
string
>
&
outputs
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
op
->
SetOutput
(
"Out"
,
{
outputs
[
0
]});
if
(
type
==
"transpose2"
)
{
op
->
SetAttr
(
"axis"
,
std
::
vector
<
int
>
({
0
,
2
,
1
,
3
}));
op
->
SetOutput
(
"XShape"
,
{
outputs
[
1
]});
}
if
(
type
==
"reshape2"
)
{
op
->
SetAttr
(
"shape"
,
std
::
vector
<
int
>
({
4
,
5
,
6
}));
op
->
SetOutput
(
"XShape"
,
{
outputs
[
1
]});
}
if
(
type
==
"matmul"
)
{
op
->
SetInput
(
"Y"
,
{
inputs
[
1
]});
op
->
SetAttr
(
"use_mkldnn"
,
true
);
}
}
ProgramDesc
BuildProgramDesc
()
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
initializer_list
<
std
::
string
>
(
{
"a1"
,
"a2"
,
"b"
,
"c"
,
"cx"
,
"d"
,
"dx"
,
"e"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
}
SetOp
(
&
prog
,
"matmul"
,
{
"a1"
,
"a2"
},
{
"b"
});
SetOp
(
&
prog
,
"transpose2"
,
{
"b"
},
{
"c"
,
"cx"
});
SetOp
(
&
prog
,
"reshape2"
,
{
"c"
},
{
"d"
,
"dx"
});
SetOp
(
&
prog
,
"fc"
,
{
"d"
},
{
"e"
});
return
prog
;
}
void
MainTest
(
const
ProgramDesc
&
prog
)
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
int
original_nodes_num
=
graph
->
Nodes
().
size
();
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"matmul_transpose_reshape_fuse_pass"
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_EQ
(
original_nodes_num
-
6
,
current_nodes_num
);
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
auto
*
op
=
node
->
Op
();
if
(
op
->
Type
()
==
"matmul"
)
{
EXPECT_EQ
(
op
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"fused_reshape_Out"
),
std
::
vector
<
int
>
({
4
,
5
,
6
}));
EXPECT_EQ
(
op
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"fused_transpose_Out"
),
std
::
vector
<
int
>
({
0
,
2
,
1
,
3
}));
}
}
}
}
TEST
(
MatmulTransposeReshapeFusePass
,
matmul_inputs
)
{
auto
prog
=
BuildProgramDesc
();
MainTest
(
prog
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
matmul_transpose_reshape_fuse_pass
);
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
d31a174f
...
...
@@ -191,11 +191,12 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv3d_bias_mkldnn_fuse_pass"
,
//
"conv_elementwise_add_mkldnn_fuse_pass"
,
"conv_concat_relu_mkldnn_fuse_pass"
,
"conv_relu_mkldnn_fuse_pass"
,
//
"conv_leaky_relu_mkldnn_fuse_pass"
,
//
"conv_relu6_mkldnn_fuse_pass"
,
//
"conv_swish_mkldnn_fuse_pass"
,
//
"scale_matmul_fuse_pass"
,
//
"conv_relu_mkldnn_fuse_pass"
,
//
"conv_leaky_relu_mkldnn_fuse_pass"
,
//
"conv_relu6_mkldnn_fuse_pass"
,
//
"conv_swish_mkldnn_fuse_pass"
,
//
"scale_matmul_fuse_pass"
,
//
"matmul_transpose_reshape_fuse_pass"
,
//
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
"mkldnn_inplace_pass"
,
// This pass should be activated after
...
...
paddle/fluid/operators/matmul_op.cc
浏览文件 @
d31a174f
...
...
@@ -407,7 +407,45 @@ class MatMulOp : public framework::OperatorWithKernel {
if
(
dim_out
.
empty
())
{
dim_out
=
{
1
};
}
context
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
dim_out
));
framework
::
DDim
ddim_out
=
framework
::
make_ddim
(
dim_out
);
#ifdef PADDLE_WITH_MKLDNN
// if mkldnn matmul+transpose+reshape fuse activated
auto
reshape_out
=
context
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"fused_reshape_Out"
);
auto
transpose_out
=
context
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"fused_transpose_Out"
);
if
(
!
reshape_out
.
empty
()
&&
!
transpose_out
.
empty
())
{
auto
reshape_out_size
=
reshape_out
.
size
();
auto
transpose_out_size
=
transpose_out
.
size
();
PADDLE_ENFORCE_EQ
(
transpose_out_size
,
4
,
platform
::
errors
::
InvalidArgument
(
"transpose_out supported rank is 4, "
"received %d"
,
transpose_out_size
));
const
std
::
vector
<
int
>
supported_axis
{
0
,
2
,
1
,
3
};
const
bool
supported_transpose_axis
=
std
::
equal
(
transpose_out
.
begin
(),
transpose_out
.
end
(),
supported_axis
.
begin
());
PADDLE_ENFORCE_EQ
(
supported_transpose_axis
,
true
,
platform
::
errors
::
InvalidArgument
(
"supported transpose axis for the fuse are {0, 2, 1, 3}"
));
PADDLE_ENFORCE_EQ
(
reshape_out_size
,
3
,
platform
::
errors
::
InvalidArgument
(
"reshape_out supported rank is 3, "
"received %d"
,
reshape_out_size
));
framework
::
DDim
shape_out
=
ddim_out
.
transpose
(
transpose_out
).
reshape
(
reshape_out
);
context
->
SetOutputDim
(
"Out"
,
shape_out
);
}
else
{
context
->
SetOutputDim
(
"Out"
,
ddim_out
);
}
#else
context
->
SetOutputDim
(
"Out"
,
ddim_out
);
#endif
context
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
@@ -446,6 +484,16 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
"use_mkldnn"
,
"(bool, default false) Indicates if MKL-DNN kernel will be used"
)
.
SetDefault
(
false
);
AddAttr
<
std
::
vector
<
int
>>
(
"fused_reshape_Out"
,
R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, "
"it's a shape atribute of fused reshape for `Out` output.)DOC"
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
int
>>
(
"fused_transpose_Out"
,
R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, "
"it's a axis atribute of fused transpose for `Out` output.)DOC"
)
.
SetDefault
({});
/* int8 parameters */
AddAttr
<
bool
>
(
"use_quantizer"
,
"(bool, default false) "
...
...
@@ -466,6 +514,7 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Force INT8 kernel output FP32, only "
"used in MKL-DNN INT8"
)
.
SetDefault
(
false
);
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
AddAttr
<
int
>
(
"head_number"
,
"The number of heads of the matrix"
)
.
SetDefault
(
1
);
...
...
paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
浏览文件 @
d31a174f
...
...
@@ -31,6 +31,11 @@ using platform::MKLDNNDeviceContext;
using
framework
::
ExecutionContext
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
constexpr
bool
IsInt8
()
{
return
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
uint8_t
>::
value
;
}
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
static
framework
::
DDim
RowMatrixDimsFromVector
(
const
framework
::
DDim
&
x_dim
)
{
...
...
@@ -64,7 +69,8 @@ class MatMulFactory {
private:
struct
MatMulDims
{
const
memory
::
dim
BS
,
M
,
N
,
K
;
const
memory
::
dims
x_dims
,
y_dims
,
out_dims
,
x_strides
,
y_strides
,
out_strides
;
};
void
SetDNNLEngine
(
const
ExecutionContext
&
ctx
)
{
...
...
@@ -80,6 +86,19 @@ class MatMulFactory {
return
dnnl
::
memory
(
md
,
engine_
,
to_void_cast
(
data
));
}
bool
IsOutputFused
(
const
ExecutionContext
&
ctx
)
const
{
auto
&
fused_reshape_Out
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"fused_reshape_Out"
);
auto
&
fused_transpose_Out
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"fused_transpose_Out"
);
return
!
fused_reshape_Out
.
empty
()
&&
!
fused_transpose_Out
.
empty
();
}
void
CorrectStridesWhenFloatOutputFused
(
const
ExecutionContext
&
ctx
,
const
memory
::
dim
N
,
memory
::
dim
b
,
memory
::
dims
*
out_strides
)
const
{
if
(
!
IsInt8
<
OT
>
()
&&
IsOutputFused
(
ctx
))
*
out_strides
=
{
N
,
b
*
N
,
1
};
}
MatMulDims
GetMatmulDims
(
const
ExecutionContext
&
ctx
)
{
auto
mat_dim_x
=
math
::
CreateMatrixDescriptor
(
RowMatrixDimsFromVector
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
()),
0
,
...
...
@@ -100,34 +119,45 @@ class MatMulFactory {
const
memory
::
dim
M
=
mat_dim_x
.
height_
;
const
memory
::
dim
N
=
mat_dim_y
.
width_
;
const
memory
::
dim
K
=
mat_dim_x
.
width_
;
return
{
BS
,
M
,
N
,
K
};
batch_size_
=
1
;
auto
b
=
BS
;
if
(
BS
>
1
&&
IsOutputFused
(
ctx
))
{
batch_size_
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
()[
0
];
b
=
BS
/
batch_size_
;
}
memory
::
dims
x_dims
=
{
b
,
M
,
K
};
memory
::
dims
y_dims
=
{
b
,
K
,
N
};
memory
::
dims
out_dims
=
{
b
,
M
,
N
};
size_t
x_size
=
b
*
M
*
K
*
sizeof
(
XT
);
size_t
y_size
=
b
*
K
*
N
*
sizeof
(
YT
);
size_t
out_size
=
b
*
M
*
N
*
sizeof
(
OT
);
offsets_
=
{
x_size
,
y_size
,
out_size
};
// Translate transA and transB
memory
::
dims
strides_x
=
!
ctx
.
Attr
<
bool
>
(
"transpose_X"
)
?
memory
::
dims
{
M
*
K
,
K
,
1
}
:
memory
::
dims
{
M
*
K
,
1
,
M
};
memory
::
dims
strides_y
=
!
ctx
.
Attr
<
bool
>
(
"transpose_Y"
)
?
memory
::
dims
{
N
*
K
,
N
,
1
}
:
memory
::
dims
{
N
*
K
,
1
,
K
};
memory
::
dims
out_strides
=
memory
::
dims
{
M
*
N
,
N
,
1
};
CorrectStridesWhenFloatOutputFused
(
ctx
,
N
,
b
,
&
out_strides
);
return
{
x_dims
,
y_dims
,
out_dims
,
strides_x
,
strides_y
,
out_strides
};
}
void
CreateMemories
(
const
ExecutionContext
&
ctx
)
{
auto
matmul_dims
=
GetMatmulDims
(
ctx
);
auto
BS
=
matmul_dims
.
BS
;
auto
M
=
matmul_dims
.
M
;
auto
N
=
matmul_dims
.
N
;
auto
K
=
matmul_dims
.
K
;
bool
x_trans
=
ctx
.
Attr
<
bool
>
(
"transpose_X"
);
bool
y_trans
=
ctx
.
Attr
<
bool
>
(
"transpose_Y"
);
typedef
memory
::
dims
dims
;
dims
x_dims
=
{
BS
,
M
,
K
};
dims
y_dims
=
{
BS
,
K
,
N
};
dims
out_dims
=
{
BS
,
M
,
N
};
// Translate transA and transB
dims
x_strides
=
!
x_trans
?
dims
{
M
*
K
,
K
,
1
}
:
dims
{
M
*
K
,
1
,
M
};
dims
y_strides
=
!
y_trans
?
dims
{
N
*
K
,
N
,
1
}
:
dims
{
N
*
K
,
1
,
K
};
dims
out_strides
=
{
M
*
N
,
N
,
1
};
x_mem_
=
CreateMemory
<
XT
>
(
x_dims
,
x_strides
,
ctx
.
Input
<
Tensor
>
(
"X"
)
->
data
<
XT
>
());
y_mem_
=
CreateMemory
<
YT
>
(
y_dims
,
y_strides
,
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
data
<
YT
>
());
x_mem_
=
CreateMemory
<
XT
>
(
matmul_dims
.
x_dims
,
matmul_dims
.
x_strides
,
ctx
.
Input
<
Tensor
>
(
"X"
)
->
data
<
XT
>
());
y_mem_
=
CreateMemory
<
YT
>
(
matmul_dims
.
y_dims
,
matmul_dims
.
y_strides
,
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
data
<
YT
>
());
out_mem_
=
CreateMemory
<
OT
>
(
out_dims
,
out_strides
,
matmul_dims
.
out_dims
,
matmul_dims
.
out_strides
,
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
mutable_data
<
OT
>
(
ctx
.
GetPlace
()));
}
...
...
@@ -156,11 +186,25 @@ class MatMulFactory {
void
Execute
()
{
dnnl
::
stream
stream
(
engine_
);
matmul_prim_
.
execute
(
stream
,
{
{
MKLDNN_ARG_SRC
,
x_mem_
},
{
MKLDNN_ARG_WEIGHTS
,
y_mem_
},
{
MKLDNN_ARG_DST
,
out_mem_
},
});
auto
offsets
=
offsets_
;
unsigned
bs
=
batch_size_
;
void
*
x_ptr
=
x_mem_
.
get_data_handle
();
void
*
y_ptr
=
y_mem_
.
get_data_handle
();
void
*
out_ptr
=
out_mem_
.
get_data_handle
();
for
(
unsigned
i
=
0
;
i
<
bs
;
i
++
)
{
x_mem_
.
set_data_handle
(
x_ptr
);
y_mem_
.
set_data_handle
(
y_ptr
);
out_mem_
.
set_data_handle
(
out_ptr
);
matmul_prim_
.
execute
(
stream
,
{
{
MKLDNN_ARG_SRC
,
x_mem_
},
{
MKLDNN_ARG_WEIGHTS
,
y_mem_
},
{
MKLDNN_ARG_DST
,
out_mem_
},
});
x_ptr
=
static_cast
<
char
*>
(
x_ptr
)
+
offsets
.
x_offset
;
y_ptr
=
static_cast
<
char
*>
(
y_ptr
)
+
offsets
.
y_offset
;
out_ptr
=
static_cast
<
char
*>
(
out_ptr
)
+
offsets
.
out_offset
;
}
stream
.
wait
();
}
...
...
@@ -188,11 +232,19 @@ class MatMulFactory {
void
SetInitialized
()
{
initialized_
=
true
;
}
private:
struct
memory_offsets
{
size_t
x_offset
;
size_t
y_offset
;
size_t
out_offset
;
};
dnnl
::
engine
engine_
;
dnnl
::
memory
x_mem_
;
dnnl
::
memory
y_mem_
;
dnnl
::
memory
out_mem_
;
dnnl
::
matmul
matmul_prim_
;
memory_offsets
offsets_
;
unsigned
batch_size_
;
bool
initialized_
=
false
;
};
...
...
@@ -217,10 +269,6 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
return
factory
;
}
template
<
typename
T
>
constexpr
bool
IsInt8
()
{
return
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
uint8_t
>::
value
;
}
// Choose appropriate primitive factory implementation based on inferred
// output type (uint8, int8 or float).
template
<
typename
XT
,
typename
YT
>
...
...
python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
浏览文件 @
d31a174f
...
...
@@ -371,6 +371,7 @@ class Qat2Int8MkldnnPass(object):
[
'use_gpu'
,
'use_fc_padding'
],
[
False
,
False
])
graph
=
self
.
_apply_pass
(
graph
,
'fc_mkldnn_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'matmul_transpose_reshape_fuse_pass'
)
return
graph
def
_apply_pass
(
self
,
graph
,
pass_name
,
attrs
=
None
,
attr_values
=
None
):
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py
0 → 100644
浏览文件 @
d31a174f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
inference_pass_test
import
InferencePassTest
class
TestMKLDNNMatmulFuseOp
(
InferencePassTest
):
def
init_data
(
self
):
self
.
bs
=
8
self
.
d_type
=
np
.
float32
self
.
shape_x
=
[
12
,
128
,
128
]
self
.
shape_y
=
[
12
,
128
,
64
]
self
.
enable_mkldnn
=
True
def
make_network
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
-
1
]
+
self
.
shape_x
,
dtype
=
self
.
d_type
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
-
1
]
+
self
.
shape_y
,
dtype
=
self
.
d_type
)
out
=
fluid
.
layers
.
matmul
(
x
,
y
)
out
=
fluid
.
layers
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
fluid
.
layers
.
reshape
(
out
,
[
0
,
0
,
self
.
shape_y
[
0
]
*
self
.
shape_y
[
2
]])
out
=
fluid
.
layers
.
fc
(
out
,
size
=
1
)
return
out
def
setUp
(
self
):
self
.
init_data
()
out
=
self
.
make_network
()
self
.
set_feeds
(
out
)
def
set_feeds
(
self
,
out
):
self
.
feeds
=
{
"x"
:
np
.
random
.
random
([
self
.
bs
]
+
self
.
shape_x
).
astype
(
self
.
d_type
),
"y"
:
np
.
random
.
random
([
self
.
bs
]
+
self
.
shape_y
).
astype
(
self
.
d_type
)
}
self
.
fetch_list
=
[
out
]
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
class
TestMKLDNNMatmulOtherDimsFuseOp
(
TestMKLDNNMatmulFuseOp
):
def
init_data
(
self
):
self
.
bs
=
8
self
.
d_type
=
np
.
float32
self
.
shape_x
=
[
12
,
1
,
1
]
self
.
shape_y
=
[
12
,
1
,
64
]
self
.
enable_mkldnn
=
True
class
TestMKLDNNMatmulOpNotFusedWrongTransposeAxis
(
TestMKLDNNMatmulFuseOp
):
def
make_network
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
-
1
]
+
self
.
shape_x
,
dtype
=
self
.
d_type
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
-
1
]
+
self
.
shape_y
,
dtype
=
self
.
d_type
)
out
=
fluid
.
layers
.
matmul
(
x
,
y
)
out
=
fluid
.
layers
.
transpose
(
out
,
perm
=
[
0
,
1
,
2
,
3
])
out
=
fluid
.
layers
.
reshape
(
out
,
[
0
,
0
,
0
,
0
])
out
=
fluid
.
layers
.
fc
(
out
,
size
=
1
)
return
out
class
TestMKLDNNMatmulOpNotFusedBreakPattern
(
TestMKLDNNMatmulFuseOp
):
def
init_data
(
self
):
self
.
bs
=
7
self
.
d_type
=
np
.
float32
self
.
shape_x
=
[
12
,
128
,
128
]
self
.
shape_y
=
[
12
,
128
,
64
]
self
.
enable_mkldnn
=
True
def
make_network
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
-
1
]
+
self
.
shape_x
,
dtype
=
self
.
d_type
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
-
1
]
+
self
.
shape_y
,
dtype
=
self
.
d_type
)
out
=
fluid
.
layers
.
matmul
(
x
,
y
)
out
=
fluid
.
layers
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
fluid
.
layers
.
transpose
(
out
,
perm
=
[
0
,
1
,
2
,
3
])
# breaks pattern
out
=
fluid
.
layers
.
reshape
(
out
,
[
0
,
0
,
self
.
shape_y
[
0
]
*
self
.
shape_y
[
2
]])
out
=
fluid
.
layers
.
fc
(
out
,
size
=
1
)
return
out
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
浏览文件 @
d31a174f
...
...
@@ -161,5 +161,134 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp):
self
.
attrs
=
{
'force_fp32_output'
:
True
}
@
skip_check_grad_ci
(
reason
=
"Tests inference only optimization."
)
class
TestMatMulOpTransposeReshapeEmptyFloat
(
OpTest
):
def
init_data_type
(
self
):
self
.
data_type_
=
np
.
float32
def
generate_data
(
self
):
self
.
bs
=
1
self
.
x
=
np
.
random
.
random
([
self
.
bs
,
128
,
128
]).
astype
(
self
.
data_type_
)
self
.
y
=
np
.
random
.
random
([
self
.
bs
,
128
,
64
]).
astype
(
self
.
data_type_
)
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[]
self
.
reshape_out
=
[]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
)
def
setUp
(
self
):
os
.
environ
[
"DNNL_MAX_CPU_ISA"
]
=
"AVX"
self
.
op_type
=
"matmul"
self
.
_cpu_only
=
True
self
.
use_mkldnn
=
True
self
.
init_data_type
()
self
.
generate_data
()
self
.
init_params_and_out
()
self
.
inputs
=
{
'X'
:
self
.
x
,
'Y'
:
self
.
y
}
self
.
attrs
=
{
'use_mkldnn'
:
self
.
use_mkldnn
}
if
len
(
self
.
reshape_out
)
>
0
:
self
.
attrs
[
'fused_reshape_Out'
]
=
self
.
reshape_out
if
len
(
self
.
transpose_out
)
>
0
:
self
.
attrs
[
'fused_transpose_Out'
]
=
self
.
transpose_out
self
.
inputs
=
{
'X'
:
self
.
x
,
'Y'
:
self
.
y
}
self
.
outputs
=
{
'Out'
:
self
.
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
check_raise_error
(
self
,
msg
):
try
:
self
.
check_output
()
except
Exception
as
e
:
if
msg
in
str
(
e
):
raise
AttributeError
else
:
print
(
e
)
class
TestMatMulOpTransposeReshapeIntEmptyInt
(
TestMatMulOpTransposeReshapeEmptyFloat
):
def
init_data_type
(
self
):
self
.
data_type_
=
np
.
int8
class
TestMatMulOpTransposeReshapeBasicFloat
(
TestMatMulOpTransposeReshapeEmptyFloat
):
def
generate_data
(
self
):
self
.
bs
=
8
self
.
x
=
np
.
random
.
random
(
[
self
.
bs
,
12
,
128
,
128
]).
astype
(
self
.
data_type_
)
self
.
y
=
np
.
random
.
random
(
[
self
.
bs
,
12
,
128
,
64
]).
astype
(
self
.
data_type_
)
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[
0
,
2
,
1
,
3
]
self
.
reshape_out
=
[
0
,
0
,
self
.
x
.
shape
[
1
]
*
self
.
y
.
shape
[
-
1
]]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
).
transpose
([
0
,
2
,
1
,
3
]).
reshape
(
[
self
.
bs
,
-
1
,
self
.
x
.
shape
[
1
]
*
self
.
y
.
shape
[
-
1
]])
class
TestMatMulOpTransposeReshapeBasicInt
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
init_data_type
(
self
):
self
.
data_type_
=
np
.
int8
class
TestMatMulOpTransposeReshapeOtherDimFloat
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
generate_data
(
self
):
self
.
bs
=
11
self
.
x
=
np
.
random
.
random
([
self
.
bs
,
12
,
14
,
18
]).
astype
(
self
.
data_type_
)
self
.
y
=
np
.
random
.
random
([
self
.
bs
,
12
,
18
,
13
]).
astype
(
self
.
data_type_
)
class
TestMatMulOpTransposeReshapeOtherDimInt
(
TestMatMulOpTransposeReshapeOtherDimFloat
):
def
init_data_type
(
self
):
self
.
data_type_
=
np
.
int8
class
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[
0
,
1
,
2
,
3
]
self
.
reshape_out
=
[
0
,
0
,
self
.
x
.
shape
[
1
]
*
self
.
y
.
shape
[
-
1
]]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
)
def
test_check_output
(
self
):
self
.
assertRaises
(
AttributeError
,
self
.
check_raise_error
,
'InvalidArgumentError: supported transpose axis '
'for the fuse are {0, 2, 1, 3}'
)
class
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[
0
,
2
,
1
]
self
.
reshape_out
=
[
0
,
0
,
self
.
x
.
shape
[
1
]
*
self
.
y
.
shape
[
-
1
]]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
)
def
test_check_output
(
self
):
self
.
assertRaises
(
AttributeError
,
self
.
check_raise_error
,
'InvalidArgumentError: transpose_out supported rank is 4'
)
class
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[
0
,
2
,
1
,
3
]
self
.
reshape_out
=
[
0
,
0
]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
)
def
test_check_output
(
self
):
self
.
assertRaises
(
AttributeError
,
self
.
check_raise_error
,
'InvalidArgumentError: reshape_out supported rank is 3'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录