Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8c2f0770
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8c2f0770
编写于
4月 25, 2020
作者:
A
arlesniak
提交者:
GitHub
4月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[DNNL] Added mkldnn_matmul-transpose-reshape fuse pass (#23866) (#24140)
test=release/2.0
上级
cae30c02
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
727 addition
and
38 deletion
+727
-38
paddle/fluid/framework/ddim.cc
paddle/fluid/framework/ddim.cc
+62
-0
paddle/fluid/framework/ddim.h
paddle/fluid/framework/ddim.h
+4
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+37
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+18
-0
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc
...framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc
+100
-0
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h
.../framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h
+35
-0
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc
...rk/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc
+93
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+6
-5
paddle/fluid/operators/matmul_op.cc
paddle/fluid/operators/matmul_op.cc
+50
-1
paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
+80
-32
python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
.../fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
+1
-0
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py
...ts/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py
+110
-0
python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
...dle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
+129
-0
未找到文件。
paddle/fluid/framework/ddim.cc
浏览文件 @
8c2f0770
...
...
@@ -131,5 +131,67 @@ DDim stride_numel(const DDim& ddim) {
return
strides
;
}
DDim
DDim
::
reshape
(
const
std
::
vector
<
int
>&
shape
)
const
{
const
int64_t
copy_dim_val
=
0
;
const
DDim
&
in_dims
=
*
this
;
DDim
out_dims
;
out_dims
.
rank_
=
shape
.
size
();
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
if
(
shape
[
i
]
==
copy_dim_val
)
{
PADDLE_ENFORCE_LT
(
static_cast
<
int
>
(
i
),
in_dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Index %d of shape under which the value of 0 "
"is stored, must be lower than the number of "
"old dimensions. But received shape[%d] = 0, "
"dimensions = %d, shape = [%s]."
,
i
,
in_dims
.
size
(),
in_dims
));
out_dims
[
i
]
=
in_dims
[
i
];
}
else
{
out_dims
[
i
]
=
shape
[
i
];
}
}
return
out_dims
;
}
DDim
DDim
::
transpose
(
const
std
::
vector
<
int
>&
axis
)
const
{
const
DDim
&
in_dims
=
*
this
;
size_t
in_rank
=
in_dims
.
size
();
size_t
axis_size
=
axis
.
size
();
PADDLE_ENFORCE_EQ
(
in_rank
,
axis_size
,
platform
::
errors
::
InvalidArgument
(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d"
,
in_rank
,
axis_size
));
std
::
vector
<
int
>
count
(
axis_size
,
0
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
PADDLE_ENFORCE_LT
(
axis
[
i
],
static_cast
<
int
>
(
axis_size
),
platform
::
errors
::
InvalidArgument
(
"ValueError: Each element of axis must appear "
"exactly once in the range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"but received axis[%d] is %d, axis_size is %d"
,
i
,
axis
[
i
],
axis_size
));
PADDLE_ENFORCE_EQ
(
++
count
[
axis
[
i
]],
1
,
platform
::
errors
::
InvalidArgument
(
"ValueError: Each element of axis should "
"be a unique value range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"unique value means this axis value can appear only once. "
"But received count[axis[%d]] is %d"
,
i
,
count
[
axis
[
i
]]));
}
DDim
out_dims
(
in_dims
);
for
(
size_t
i
=
0
;
i
<
axis_size
;
i
++
)
{
out_dims
[
i
]
=
in_dims
[
axis
[
i
]];
}
return
out_dims
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ddim.h
浏览文件 @
8c2f0770
...
...
@@ -126,6 +126,10 @@ class DDim {
std
::
string
to_str
()
const
;
DDim
reshape
(
const
std
::
vector
<
int
>&
shape
)
const
;
DDim
transpose
(
const
std
::
vector
<
int
>&
axis
)
const
;
private:
template
<
int
D
>
inline
Dim
<
D
>&
UnsafeCast
()
{
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
8c2f0770
...
...
@@ -97,6 +97,7 @@ if(WITH_MKLDNN)
pass_library
(
cpu_quantize_placement_pass base DIR mkldnn
)
pass_library
(
cpu_quantize_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_squash_pass inference DIR mkldnn
)
pass_library
(
matmul_transpose_reshape_fuse_pass inference DIR mkldnn
)
endif
()
cc_library
(
fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector
)
...
...
@@ -144,4 +145,5 @@ if (WITH_MKLDNN)
cc_test
(
test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass
)
cc_test
(
test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor
)
cc_test
(
test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor
)
cc_test
(
test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass
)
endif
()
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
8c2f0770
...
...
@@ -2147,6 +2147,43 @@ void patterns::DeleteQuantDequantOpPattern::operator()() {
any_op2
->
LinksFrom
({
quant_dequant_out
});
}
PDNode
*
patterns
::
MatmulTransposeReshapePattern
::
operator
()()
{
auto
reshape_op
=
pattern
->
NewNode
(
reshape_op_repr
())
->
assert_is_op
(
"reshape2"
);
auto
transpose_op
=
pattern
->
NewNode
(
transpose_op_repr
())
->
assert_is_op
(
"transpose2"
);
auto
matmul_op
=
pattern
->
NewNode
(
matmul_op_repr
())
->
assert_is_op
(
"matmul"
);
auto
matmul_out
=
pattern
->
NewNode
(
matmul_out_repr
())
->
AsInput
()
->
assert_is_op_output
(
"matmul"
,
"Out"
)
->
assert_is_op_input
(
"transpose2"
,
"X"
);
auto
transpose_out
=
pattern
->
NewNode
(
transpose_out_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"transpose2"
,
"Out"
)
->
assert_is_op_input
(
"reshape2"
,
"X"
);
auto
transpose_out_xshape
=
pattern
->
NewNode
(
transpose_out_xshape_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"transpose2"
,
"XShape"
);
auto
reshape_out
=
pattern
->
NewNode
(
reshape_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"reshape2"
);
auto
reshape_out_xshape
=
pattern
->
NewNode
(
reshape_out_xshape_repr
())
->
AsIntermediate
()
->
assert_is_op_output
(
"reshape2"
,
"XShape"
);
matmul_op
->
LinksTo
({
matmul_out
});
transpose_op
->
LinksTo
({
transpose_out_xshape
});
reshape_op
->
LinksTo
({
reshape_out_xshape
});
transpose_op
->
LinksFrom
({
matmul_out
}).
LinksTo
({
transpose_out
});
reshape_op
->
LinksFrom
({
transpose_out
}).
LinksTo
({
reshape_out
});
return
reshape_out
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
8c2f0770
...
...
@@ -1210,6 +1210,24 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
PATTERN_DECL_NODE
(
any_op2
);
};
// Matmul + Transpose + Reshape
struct
MatmulTransposeReshapePattern
:
public
PatternBase
{
MatmulTransposeReshapePattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"matmul_transpose_reshape"
)
{}
PDNode
*
operator
()();
PATTERN_DECL_NODE
(
matmul_op
);
PATTERN_DECL_NODE
(
matmul_out
);
PATTERN_DECL_NODE
(
transpose_op
);
PATTERN_DECL_NODE
(
transpose_out
);
PATTERN_DECL_NODE
(
transpose_out_xshape
);
PATTERN_DECL_NODE
(
reshape_op
);
PATTERN_DECL_NODE
(
reshape_out
);
PATTERN_DECL_NODE
(
reshape_out_xshape
);
};
}
// namespace patterns
// Link two ir::Nodes from each other.
...
...
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc
0 → 100644
浏览文件 @
8c2f0770
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <paddle/fluid/string/pretty_log.h>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
MatmulTransposeReshapeMKLDNNPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Pointer to graph argument should not be NULL."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
GraphPatternDetector
gpd
;
patterns
::
MatmulTransposeReshapePattern
mtrp
(
gpd
.
mutable_pattern
(),
name_scope_
);
mtrp
();
int
found_matmul_transpose_reshape_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"handle matmul_transpose_reshape fuse"
;
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_op
,
matmul_op
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
matmul_out
,
matmul_out
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose_op
,
transpose_op
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose_out
,
transpose_out
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
transpose_out_xshape
,
transpose_out_xshape
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape_op
,
reshape_op
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape_out
,
reshape_out
,
mtrp
);
GET_IR_NODE_FROM_SUBGRAPH
(
reshape_out_xshape
,
reshape_out_xshape
,
mtrp
);
auto
reshape_shape
=
boost
::
get
<
std
::
vector
<
int
>>
(
reshape_op
->
Op
()
->
GetAttr
(
"shape"
));
auto
transpose_axis
=
boost
::
get
<
std
::
vector
<
int
>>
(
transpose_op
->
Op
()
->
GetAttr
(
"axis"
));
auto
reshape_out_size
=
reshape_shape
.
size
();
auto
transpose_out_size
=
transpose_axis
.
size
();
const
std
::
vector
<
int
>
supported_axis
{
0
,
2
,
1
,
3
};
const
bool
supported_transpose_axis
=
std
::
equal
(
transpose_axis
.
begin
(),
transpose_axis
.
end
(),
supported_axis
.
begin
());
if
(
transpose_out_size
!=
4
)
{
VLOG
(
3
)
<<
"do not perform matmul_transpose_reshape fuse: "
<<
"supported rank is 4, received "
<<
transpose_out_size
;
return
;
}
if
(
!
supported_transpose_axis
)
{
VLOG
(
3
)
<<
"do not perform matmul_transpose_reshape fuse: "
<<
"supported transpose axis for the fuse are {0, 2, 1, 3}"
;
return
;
}
if
(
reshape_out_size
!=
3
)
{
VLOG
(
3
)
<<
"do not perform matmul_transpose_reshape fuse: "
<<
"reshape_out supported rank is 3, received "
<<
reshape_out_size
;
return
;
}
OpDesc
*
matmul_desc
=
matmul_op
->
Op
();
matmul_desc
->
SetOutput
(
"Out"
,
{
reshape_out
->
Name
()});
matmul_desc
->
SetAttr
(
"fused_reshape_Out"
,
reshape_shape
);
matmul_desc
->
SetAttr
(
"fused_transpose_Out"
,
transpose_axis
);
GraphSafeRemoveNodes
(
graph
,
{
matmul_out
,
transpose_op
,
transpose_out
,
reshape_op
,
transpose_out_xshape
,
reshape_out_xshape
});
IR_OP_VAR_LINK
(
matmul_op
,
reshape_out
);
found_matmul_transpose_reshape_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_matmul_transpose_reshape_count
);
std
::
stringstream
msg_ss
;
msg_ss
<<
"--- Fused "
<<
found_matmul_transpose_reshape_count
<<
" MatmulTransposeReshape patterns"
;
paddle
::
string
::
PrettyLogDetail
(
msg_ss
.
str
().
c_str
());
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
matmul_transpose_reshape_fuse_pass
,
paddle
::
framework
::
ir
::
MatmulTransposeReshapeMKLDNNPass
);
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h
0 → 100644
浏览文件 @
8c2f0770
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
MatmulTransposeReshapeMKLDNNPass
:
public
FusePassBase
{
public:
virtual
~
MatmulTransposeReshapeMKLDNNPass
()
{}
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
const
std
::
string
name_scope_
{
"matmul_transpose_reshape_fuse"
};
};
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc
0 → 100644
浏览文件 @
8c2f0770
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
SetOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
type
,
const
std
::
vector
<
std
::
string
>
&
inputs
,
const
std
::
vector
<
std
::
string
>
&
outputs
)
{
auto
*
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
type
);
op
->
SetInput
(
"X"
,
{
inputs
[
0
]});
op
->
SetOutput
(
"Out"
,
{
outputs
[
0
]});
if
(
type
==
"transpose2"
)
{
op
->
SetAttr
(
"axis"
,
std
::
vector
<
int
>
({
0
,
2
,
1
,
3
}));
op
->
SetOutput
(
"XShape"
,
{
outputs
[
1
]});
}
if
(
type
==
"reshape2"
)
{
op
->
SetAttr
(
"shape"
,
std
::
vector
<
int
>
({
4
,
5
,
6
}));
op
->
SetOutput
(
"XShape"
,
{
outputs
[
1
]});
}
if
(
type
==
"matmul"
)
{
op
->
SetInput
(
"Y"
,
{
inputs
[
1
]});
op
->
SetAttr
(
"use_mkldnn"
,
true
);
}
}
ProgramDesc
BuildProgramDesc
()
{
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
initializer_list
<
std
::
string
>
(
{
"a1"
,
"a2"
,
"b"
,
"c"
,
"cx"
,
"d"
,
"dx"
,
"e"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
}
SetOp
(
&
prog
,
"matmul"
,
{
"a1"
,
"a2"
},
{
"b"
});
SetOp
(
&
prog
,
"transpose2"
,
{
"b"
},
{
"c"
,
"cx"
});
SetOp
(
&
prog
,
"reshape2"
,
{
"c"
},
{
"d"
,
"dx"
});
SetOp
(
&
prog
,
"fc"
,
{
"d"
},
{
"e"
});
return
prog
;
}
void
MainTest
(
const
ProgramDesc
&
prog
)
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
prog
));
int
original_nodes_num
=
graph
->
Nodes
().
size
();
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"matmul_transpose_reshape_fuse_pass"
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_EQ
(
original_nodes_num
-
6
,
current_nodes_num
);
for
(
auto
*
node
:
graph
->
Nodes
())
{
if
(
node
->
IsOp
())
{
auto
*
op
=
node
->
Op
();
if
(
op
->
Type
()
==
"matmul"
)
{
EXPECT_EQ
(
op
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"fused_reshape_Out"
),
std
::
vector
<
int
>
({
4
,
5
,
6
}));
EXPECT_EQ
(
op
->
GetAttrIfExists
<
std
::
vector
<
int
>>
(
"fused_transpose_Out"
),
std
::
vector
<
int
>
({
0
,
2
,
1
,
3
}));
}
}
}
}
TEST
(
MatmulTransposeReshapeFusePass
,
matmul_inputs
)
{
auto
prog
=
BuildProgramDesc
();
MainTest
(
prog
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
matmul_transpose_reshape_fuse_pass
);
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
8c2f0770
...
...
@@ -191,11 +191,12 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv3d_bias_mkldnn_fuse_pass"
,
//
"conv_elementwise_add_mkldnn_fuse_pass"
,
"conv_concat_relu_mkldnn_fuse_pass"
,
"conv_relu_mkldnn_fuse_pass"
,
//
"conv_leaky_relu_mkldnn_fuse_pass"
,
//
"conv_relu6_mkldnn_fuse_pass"
,
//
"conv_swish_mkldnn_fuse_pass"
,
//
"scale_matmul_fuse_pass"
,
//
"conv_relu_mkldnn_fuse_pass"
,
//
"conv_leaky_relu_mkldnn_fuse_pass"
,
//
"conv_relu6_mkldnn_fuse_pass"
,
//
"conv_swish_mkldnn_fuse_pass"
,
//
"scale_matmul_fuse_pass"
,
//
"matmul_transpose_reshape_fuse_pass"
,
//
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
"mkldnn_inplace_pass"
,
// This pass should be activated after
...
...
paddle/fluid/operators/matmul_op.cc
浏览文件 @
8c2f0770
...
...
@@ -407,7 +407,45 @@ class MatMulOp : public framework::OperatorWithKernel {
if
(
dim_out
.
empty
())
{
dim_out
=
{
1
};
}
context
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
dim_out
));
framework
::
DDim
ddim_out
=
framework
::
make_ddim
(
dim_out
);
#ifdef PADDLE_WITH_MKLDNN
// if mkldnn matmul+transpose+reshape fuse activated
auto
reshape_out
=
context
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"fused_reshape_Out"
);
auto
transpose_out
=
context
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"fused_transpose_Out"
);
if
(
!
reshape_out
.
empty
()
&&
!
transpose_out
.
empty
())
{
auto
reshape_out_size
=
reshape_out
.
size
();
auto
transpose_out_size
=
transpose_out
.
size
();
PADDLE_ENFORCE_EQ
(
transpose_out_size
,
4
,
platform
::
errors
::
InvalidArgument
(
"transpose_out supported rank is 4, "
"received %d"
,
transpose_out_size
));
const
std
::
vector
<
int
>
supported_axis
{
0
,
2
,
1
,
3
};
const
bool
supported_transpose_axis
=
std
::
equal
(
transpose_out
.
begin
(),
transpose_out
.
end
(),
supported_axis
.
begin
());
PADDLE_ENFORCE_EQ
(
supported_transpose_axis
,
true
,
platform
::
errors
::
InvalidArgument
(
"supported transpose axis for the fuse are {0, 2, 1, 3}"
));
PADDLE_ENFORCE_EQ
(
reshape_out_size
,
3
,
platform
::
errors
::
InvalidArgument
(
"reshape_out supported rank is 3, "
"received %d"
,
reshape_out_size
));
framework
::
DDim
shape_out
=
ddim_out
.
transpose
(
transpose_out
).
reshape
(
reshape_out
);
context
->
SetOutputDim
(
"Out"
,
shape_out
);
}
else
{
context
->
SetOutputDim
(
"Out"
,
ddim_out
);
}
#else
context
->
SetOutputDim
(
"Out"
,
ddim_out
);
#endif
context
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
@@ -446,6 +484,16 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
"use_mkldnn"
,
"(bool, default false) Indicates if MKL-DNN kernel will be used"
)
.
SetDefault
(
false
);
AddAttr
<
std
::
vector
<
int
>>
(
"fused_reshape_Out"
,
R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, "
"it's a shape atribute of fused reshape for `Out` output.)DOC"
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
int
>>
(
"fused_transpose_Out"
,
R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, "
"it's a axis atribute of fused transpose for `Out` output.)DOC"
)
.
SetDefault
({});
/* int8 parameters */
AddAttr
<
bool
>
(
"use_quantizer"
,
"(bool, default false) "
...
...
@@ -466,6 +514,7 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Force INT8 kernel output FP32, only "
"used in MKL-DNN INT8"
)
.
SetDefault
(
false
);
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
AddAttr
<
int
>
(
"head_number"
,
"The number of heads of the matrix"
)
.
SetDefault
(
1
);
...
...
paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
浏览文件 @
8c2f0770
...
...
@@ -31,6 +31,11 @@ using platform::MKLDNNDeviceContext;
using
framework
::
ExecutionContext
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
constexpr
bool
IsInt8
()
{
return
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
uint8_t
>::
value
;
}
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
static
framework
::
DDim
RowMatrixDimsFromVector
(
const
framework
::
DDim
&
x_dim
)
{
...
...
@@ -64,7 +69,8 @@ class MatMulFactory {
private:
struct
MatMulDims
{
const
memory
::
dim
BS
,
M
,
N
,
K
;
const
memory
::
dims
x_dims
,
y_dims
,
out_dims
,
x_strides
,
y_strides
,
out_strides
;
};
void
SetDNNLEngine
(
const
ExecutionContext
&
ctx
)
{
...
...
@@ -80,6 +86,19 @@ class MatMulFactory {
return
dnnl
::
memory
(
md
,
engine_
,
to_void_cast
(
data
));
}
bool
IsOutputFused
(
const
ExecutionContext
&
ctx
)
const
{
auto
&
fused_reshape_Out
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"fused_reshape_Out"
);
auto
&
fused_transpose_Out
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"fused_transpose_Out"
);
return
!
fused_reshape_Out
.
empty
()
&&
!
fused_transpose_Out
.
empty
();
}
void
CorrectStridesWhenFloatOutputFused
(
const
ExecutionContext
&
ctx
,
const
memory
::
dim
N
,
memory
::
dim
b
,
memory
::
dims
*
out_strides
)
const
{
if
(
!
IsInt8
<
OT
>
()
&&
IsOutputFused
(
ctx
))
*
out_strides
=
{
N
,
b
*
N
,
1
};
}
MatMulDims
GetMatmulDims
(
const
ExecutionContext
&
ctx
)
{
auto
mat_dim_x
=
math
::
CreateMatrixDescriptor
(
RowMatrixDimsFromVector
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
()),
0
,
...
...
@@ -100,34 +119,45 @@ class MatMulFactory {
const
memory
::
dim
M
=
mat_dim_x
.
height_
;
const
memory
::
dim
N
=
mat_dim_y
.
width_
;
const
memory
::
dim
K
=
mat_dim_x
.
width_
;
return
{
BS
,
M
,
N
,
K
};
batch_size_
=
1
;
auto
b
=
BS
;
if
(
BS
>
1
&&
IsOutputFused
(
ctx
))
{
batch_size_
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
()[
0
];
b
=
BS
/
batch_size_
;
}
memory
::
dims
x_dims
=
{
b
,
M
,
K
};
memory
::
dims
y_dims
=
{
b
,
K
,
N
};
memory
::
dims
out_dims
=
{
b
,
M
,
N
};
size_t
x_size
=
b
*
M
*
K
*
sizeof
(
XT
);
size_t
y_size
=
b
*
K
*
N
*
sizeof
(
YT
);
size_t
out_size
=
b
*
M
*
N
*
sizeof
(
OT
);
offsets_
=
{
x_size
,
y_size
,
out_size
};
// Translate transA and transB
memory
::
dims
strides_x
=
!
ctx
.
Attr
<
bool
>
(
"transpose_X"
)
?
memory
::
dims
{
M
*
K
,
K
,
1
}
:
memory
::
dims
{
M
*
K
,
1
,
M
};
memory
::
dims
strides_y
=
!
ctx
.
Attr
<
bool
>
(
"transpose_Y"
)
?
memory
::
dims
{
N
*
K
,
N
,
1
}
:
memory
::
dims
{
N
*
K
,
1
,
K
};
memory
::
dims
out_strides
=
memory
::
dims
{
M
*
N
,
N
,
1
};
CorrectStridesWhenFloatOutputFused
(
ctx
,
N
,
b
,
&
out_strides
);
return
{
x_dims
,
y_dims
,
out_dims
,
strides_x
,
strides_y
,
out_strides
};
}
void
CreateMemories
(
const
ExecutionContext
&
ctx
)
{
auto
matmul_dims
=
GetMatmulDims
(
ctx
);
auto
BS
=
matmul_dims
.
BS
;
auto
M
=
matmul_dims
.
M
;
auto
N
=
matmul_dims
.
N
;
auto
K
=
matmul_dims
.
K
;
bool
x_trans
=
ctx
.
Attr
<
bool
>
(
"transpose_X"
);
bool
y_trans
=
ctx
.
Attr
<
bool
>
(
"transpose_Y"
);
typedef
memory
::
dims
dims
;
dims
x_dims
=
{
BS
,
M
,
K
};
dims
y_dims
=
{
BS
,
K
,
N
};
dims
out_dims
=
{
BS
,
M
,
N
};
// Translate transA and transB
dims
x_strides
=
!
x_trans
?
dims
{
M
*
K
,
K
,
1
}
:
dims
{
M
*
K
,
1
,
M
};
dims
y_strides
=
!
y_trans
?
dims
{
N
*
K
,
N
,
1
}
:
dims
{
N
*
K
,
1
,
K
};
dims
out_strides
=
{
M
*
N
,
N
,
1
};
x_mem_
=
CreateMemory
<
XT
>
(
x_dims
,
x_strides
,
ctx
.
Input
<
Tensor
>
(
"X"
)
->
data
<
XT
>
());
y_mem_
=
CreateMemory
<
YT
>
(
y_dims
,
y_strides
,
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
data
<
YT
>
());
x_mem_
=
CreateMemory
<
XT
>
(
matmul_dims
.
x_dims
,
matmul_dims
.
x_strides
,
ctx
.
Input
<
Tensor
>
(
"X"
)
->
data
<
XT
>
());
y_mem_
=
CreateMemory
<
YT
>
(
matmul_dims
.
y_dims
,
matmul_dims
.
y_strides
,
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
data
<
YT
>
());
out_mem_
=
CreateMemory
<
OT
>
(
out_dims
,
out_strides
,
matmul_dims
.
out_dims
,
matmul_dims
.
out_strides
,
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
mutable_data
<
OT
>
(
ctx
.
GetPlace
()));
}
...
...
@@ -156,11 +186,25 @@ class MatMulFactory {
void
Execute
()
{
dnnl
::
stream
stream
(
engine_
);
matmul_prim_
.
execute
(
stream
,
{
{
MKLDNN_ARG_SRC
,
x_mem_
},
{
MKLDNN_ARG_WEIGHTS
,
y_mem_
},
{
MKLDNN_ARG_DST
,
out_mem_
},
});
auto
offsets
=
offsets_
;
unsigned
bs
=
batch_size_
;
void
*
x_ptr
=
x_mem_
.
get_data_handle
();
void
*
y_ptr
=
y_mem_
.
get_data_handle
();
void
*
out_ptr
=
out_mem_
.
get_data_handle
();
for
(
unsigned
i
=
0
;
i
<
bs
;
i
++
)
{
x_mem_
.
set_data_handle
(
x_ptr
);
y_mem_
.
set_data_handle
(
y_ptr
);
out_mem_
.
set_data_handle
(
out_ptr
);
matmul_prim_
.
execute
(
stream
,
{
{
MKLDNN_ARG_SRC
,
x_mem_
},
{
MKLDNN_ARG_WEIGHTS
,
y_mem_
},
{
MKLDNN_ARG_DST
,
out_mem_
},
});
x_ptr
=
static_cast
<
char
*>
(
x_ptr
)
+
offsets
.
x_offset
;
y_ptr
=
static_cast
<
char
*>
(
y_ptr
)
+
offsets
.
y_offset
;
out_ptr
=
static_cast
<
char
*>
(
out_ptr
)
+
offsets
.
out_offset
;
}
stream
.
wait
();
}
...
...
@@ -188,11 +232,19 @@ class MatMulFactory {
void
SetInitialized
()
{
initialized_
=
true
;
}
private:
struct
memory_offsets
{
size_t
x_offset
;
size_t
y_offset
;
size_t
out_offset
;
};
dnnl
::
engine
engine_
;
dnnl
::
memory
x_mem_
;
dnnl
::
memory
y_mem_
;
dnnl
::
memory
out_mem_
;
dnnl
::
matmul
matmul_prim_
;
memory_offsets
offsets_
;
unsigned
batch_size_
;
bool
initialized_
=
false
;
};
...
...
@@ -217,10 +269,6 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
return
factory
;
}
template
<
typename
T
>
constexpr
bool
IsInt8
()
{
return
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
uint8_t
>::
value
;
}
// Choose appropriate primitive factory implementation based on inferred
// output type (uint8, int8 or float).
template
<
typename
XT
,
typename
YT
>
...
...
python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
浏览文件 @
8c2f0770
...
...
@@ -371,6 +371,7 @@ class Qat2Int8MkldnnPass(object):
[
'use_gpu'
,
'use_fc_padding'
],
[
False
,
False
])
graph
=
self
.
_apply_pass
(
graph
,
'fc_mkldnn_pass'
)
graph
=
self
.
_apply_pass
(
graph
,
'matmul_transpose_reshape_fuse_pass'
)
return
graph
def
_apply_pass
(
self
,
graph
,
pass_name
,
attrs
=
None
,
attr_values
=
None
):
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py
0 → 100644
浏览文件 @
8c2f0770
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
inference_pass_test
import
InferencePassTest
class
TestMKLDNNMatmulFuseOp
(
InferencePassTest
):
def
init_data
(
self
):
self
.
bs
=
8
self
.
d_type
=
np
.
float32
self
.
shape_x
=
[
12
,
128
,
128
]
self
.
shape_y
=
[
12
,
128
,
64
]
self
.
enable_mkldnn
=
True
def
make_network
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
-
1
]
+
self
.
shape_x
,
dtype
=
self
.
d_type
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
-
1
]
+
self
.
shape_y
,
dtype
=
self
.
d_type
)
out
=
fluid
.
layers
.
matmul
(
x
,
y
)
out
=
fluid
.
layers
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
fluid
.
layers
.
reshape
(
out
,
[
0
,
0
,
self
.
shape_y
[
0
]
*
self
.
shape_y
[
2
]])
out
=
fluid
.
layers
.
fc
(
out
,
size
=
1
)
return
out
def
setUp
(
self
):
self
.
init_data
()
out
=
self
.
make_network
()
self
.
set_feeds
(
out
)
def
set_feeds
(
self
,
out
):
self
.
feeds
=
{
"x"
:
np
.
random
.
random
([
self
.
bs
]
+
self
.
shape_x
).
astype
(
self
.
d_type
),
"y"
:
np
.
random
.
random
([
self
.
bs
]
+
self
.
shape_y
).
astype
(
self
.
d_type
)
}
self
.
fetch_list
=
[
out
]
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
class
TestMKLDNNMatmulOtherDimsFuseOp
(
TestMKLDNNMatmulFuseOp
):
def
init_data
(
self
):
self
.
bs
=
8
self
.
d_type
=
np
.
float32
self
.
shape_x
=
[
12
,
1
,
1
]
self
.
shape_y
=
[
12
,
1
,
64
]
self
.
enable_mkldnn
=
True
class
TestMKLDNNMatmulOpNotFusedWrongTransposeAxis
(
TestMKLDNNMatmulFuseOp
):
def
make_network
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
-
1
]
+
self
.
shape_x
,
dtype
=
self
.
d_type
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
-
1
]
+
self
.
shape_y
,
dtype
=
self
.
d_type
)
out
=
fluid
.
layers
.
matmul
(
x
,
y
)
out
=
fluid
.
layers
.
transpose
(
out
,
perm
=
[
0
,
1
,
2
,
3
])
out
=
fluid
.
layers
.
reshape
(
out
,
[
0
,
0
,
0
,
0
])
out
=
fluid
.
layers
.
fc
(
out
,
size
=
1
)
return
out
class
TestMKLDNNMatmulOpNotFusedBreakPattern
(
TestMKLDNNMatmulFuseOp
):
def
init_data
(
self
):
self
.
bs
=
7
self
.
d_type
=
np
.
float32
self
.
shape_x
=
[
12
,
128
,
128
]
self
.
shape_y
=
[
12
,
128
,
64
]
self
.
enable_mkldnn
=
True
def
make_network
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
-
1
]
+
self
.
shape_x
,
dtype
=
self
.
d_type
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
-
1
]
+
self
.
shape_y
,
dtype
=
self
.
d_type
)
out
=
fluid
.
layers
.
matmul
(
x
,
y
)
out
=
fluid
.
layers
.
transpose
(
out
,
perm
=
[
0
,
2
,
1
,
3
])
out
=
fluid
.
layers
.
transpose
(
out
,
perm
=
[
0
,
1
,
2
,
3
])
# breaks pattern
out
=
fluid
.
layers
.
reshape
(
out
,
[
0
,
0
,
self
.
shape_y
[
0
]
*
self
.
shape_y
[
2
]])
out
=
fluid
.
layers
.
fc
(
out
,
size
=
1
)
return
out
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
浏览文件 @
8c2f0770
...
...
@@ -161,5 +161,134 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp):
self
.
attrs
=
{
'force_fp32_output'
:
True
}
@
skip_check_grad_ci
(
reason
=
"Tests inference only optimization."
)
class
TestMatMulOpTransposeReshapeEmptyFloat
(
OpTest
):
def
init_data_type
(
self
):
self
.
data_type_
=
np
.
float32
def
generate_data
(
self
):
self
.
bs
=
1
self
.
x
=
np
.
random
.
random
([
self
.
bs
,
128
,
128
]).
astype
(
self
.
data_type_
)
self
.
y
=
np
.
random
.
random
([
self
.
bs
,
128
,
64
]).
astype
(
self
.
data_type_
)
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[]
self
.
reshape_out
=
[]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
)
def
setUp
(
self
):
os
.
environ
[
"DNNL_MAX_CPU_ISA"
]
=
"AVX"
self
.
op_type
=
"matmul"
self
.
_cpu_only
=
True
self
.
use_mkldnn
=
True
self
.
init_data_type
()
self
.
generate_data
()
self
.
init_params_and_out
()
self
.
inputs
=
{
'X'
:
self
.
x
,
'Y'
:
self
.
y
}
self
.
attrs
=
{
'use_mkldnn'
:
self
.
use_mkldnn
}
if
len
(
self
.
reshape_out
)
>
0
:
self
.
attrs
[
'fused_reshape_Out'
]
=
self
.
reshape_out
if
len
(
self
.
transpose_out
)
>
0
:
self
.
attrs
[
'fused_transpose_Out'
]
=
self
.
transpose_out
self
.
inputs
=
{
'X'
:
self
.
x
,
'Y'
:
self
.
y
}
self
.
outputs
=
{
'Out'
:
self
.
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
check_raise_error
(
self
,
msg
):
try
:
self
.
check_output
()
except
Exception
as
e
:
if
msg
in
str
(
e
):
raise
AttributeError
else
:
print
(
e
)
class
TestMatMulOpTransposeReshapeIntEmptyInt
(
TestMatMulOpTransposeReshapeEmptyFloat
):
def
init_data_type
(
self
):
self
.
data_type_
=
np
.
int8
class
TestMatMulOpTransposeReshapeBasicFloat
(
TestMatMulOpTransposeReshapeEmptyFloat
):
def
generate_data
(
self
):
self
.
bs
=
8
self
.
x
=
np
.
random
.
random
(
[
self
.
bs
,
12
,
128
,
128
]).
astype
(
self
.
data_type_
)
self
.
y
=
np
.
random
.
random
(
[
self
.
bs
,
12
,
128
,
64
]).
astype
(
self
.
data_type_
)
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[
0
,
2
,
1
,
3
]
self
.
reshape_out
=
[
0
,
0
,
self
.
x
.
shape
[
1
]
*
self
.
y
.
shape
[
-
1
]]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
).
transpose
([
0
,
2
,
1
,
3
]).
reshape
(
[
self
.
bs
,
-
1
,
self
.
x
.
shape
[
1
]
*
self
.
y
.
shape
[
-
1
]])
class
TestMatMulOpTransposeReshapeBasicInt
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
init_data_type
(
self
):
self
.
data_type_
=
np
.
int8
class
TestMatMulOpTransposeReshapeOtherDimFloat
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
generate_data
(
self
):
self
.
bs
=
11
self
.
x
=
np
.
random
.
random
([
self
.
bs
,
12
,
14
,
18
]).
astype
(
self
.
data_type_
)
self
.
y
=
np
.
random
.
random
([
self
.
bs
,
12
,
18
,
13
]).
astype
(
self
.
data_type_
)
class
TestMatMulOpTransposeReshapeOtherDimInt
(
TestMatMulOpTransposeReshapeOtherDimFloat
):
def
init_data_type
(
self
):
self
.
data_type_
=
np
.
int8
class
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[
0
,
1
,
2
,
3
]
self
.
reshape_out
=
[
0
,
0
,
self
.
x
.
shape
[
1
]
*
self
.
y
.
shape
[
-
1
]]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
)
def
test_check_output
(
self
):
self
.
assertRaises
(
AttributeError
,
self
.
check_raise_error
,
'InvalidArgumentError: supported transpose axis '
'for the fuse are {0, 2, 1, 3}'
)
class
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[
0
,
2
,
1
]
self
.
reshape_out
=
[
0
,
0
,
self
.
x
.
shape
[
1
]
*
self
.
y
.
shape
[
-
1
]]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
)
def
test_check_output
(
self
):
self
.
assertRaises
(
AttributeError
,
self
.
check_raise_error
,
'InvalidArgumentError: transpose_out supported rank is 4'
)
class
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException
(
TestMatMulOpTransposeReshapeBasicFloat
):
def
init_params_and_out
(
self
):
self
.
transpose_out
=
[
0
,
2
,
1
,
3
]
self
.
reshape_out
=
[
0
,
0
]
self
.
out
=
np
.
matmul
(
self
.
x
,
self
.
y
)
def
test_check_output
(
self
):
self
.
assertRaises
(
AttributeError
,
self
.
check_raise_error
,
'InvalidArgumentError: reshape_out supported rank is 3'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录