Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
87099d12
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
87099d12
编写于
2月 17, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into move_slice_to_pten
上级
75aca6d4
b4d3597a
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
94 addition
and
85 deletion
+94
-85
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
...d/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+1
-1
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
...e/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
+2
-2
paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
.../framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
+1
-1
paddle/fluid/operators/distribution_helper.h
paddle/fluid/operators/distribution_helper.h
+21
-14
paddle/pten/kernels/primitive/compute_primitives.h
paddle/pten/kernels/primitive/compute_primitives.h
+53
-0
python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py
.../unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py
+3
-5
python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
...unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
+6
-8
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py
...nference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py
+0
-9
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py
...unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py
+0
-22
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py
...s/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py
+0
-9
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py
...sts/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py
+0
-9
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py
...ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py
+7
-5
未找到文件。
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
浏览文件 @
87099d12
...
@@ -157,7 +157,7 @@ ConvActivationFusePass::ConvActivationFusePass() {
...
@@ -157,7 +157,7 @@ ConvActivationFusePass::ConvActivationFusePass() {
// IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
// IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
.
AddAttr
(
"data_format"
)
.
AddAttr
(
"data_format"
)
.
IsOptional
()
.
IsOptional
()
.
IsStringIn
({
"NCHW"
,
"AnyLayout"
})
.
IsStringIn
({
"NCHW"
,
"
NHWC"
,
"
AnyLayout"
})
.
End
();
.
End
();
AddOpCompat
(
OpCompat
(
"relu"
))
AddOpCompat
(
OpCompat
(
"relu"
))
...
...
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
浏览文件 @
87099d12
...
@@ -115,7 +115,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
...
@@ -115,7 +115,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.
IsStringIn
({
"EXPLICIT"
,
"SAME"
,
"VALID"
})
.
IsStringIn
({
"EXPLICIT"
,
"SAME"
,
"VALID"
})
.
End
()
.
End
()
.
AddAttr
(
"data_format"
)
.
AddAttr
(
"data_format"
)
.
IsStringIn
({
"NCHW"
})
.
IsStringIn
({
"NCHW"
,
"NHWC"
,
"AnyLayout"
})
.
End
();
.
End
();
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
AddOpCompat
(
OpCompat
(
"elementwise_add"
))
...
@@ -129,7 +129,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
...
@@ -129,7 +129,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.
IsTensor
()
.
IsTensor
()
.
End
()
.
End
()
.
AddAttr
(
"axis"
)
.
AddAttr
(
"axis"
)
.
IsIntIn
({
1
})
.
IsIntIn
({
1
,
3
})
.
End
();
.
End
();
}
}
...
...
paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
浏览文件 @
87099d12
...
@@ -59,7 +59,7 @@ ConvConcatReLUFusePass::ConvConcatReLUFusePass() {
...
@@ -59,7 +59,7 @@ ConvConcatReLUFusePass::ConvConcatReLUFusePass() {
.
IsType
<
std
::
vector
<
int
>>
()
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
End
()
.
AddAttr
(
"data_format"
)
.
AddAttr
(
"data_format"
)
.
IsStringIn
({
"NCHW"
})
.
IsStringIn
({
"NCHW"
,
"NHWC"
,
"AnyLayout"
})
.
End
();
.
End
();
AddOpCompat
(
OpCompat
(
"concat"
))
AddOpCompat
(
OpCompat
(
"concat"
))
...
...
paddle/fluid/operators/distribution_helper.h
浏览文件 @
87099d12
...
@@ -28,6 +28,10 @@ limitations under the License. */
...
@@ -28,6 +28,10 @@ limitations under the License. */
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/core/hostdevice.h"
#include "paddle/pten/core/hostdevice.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/pten/kernels/primitive/kernel_primitives.h"
#endif
#if !defined(_WIN32)
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
#else
...
@@ -91,6 +95,8 @@ struct normal_transform {
...
@@ -91,6 +95,8 @@ struct normal_transform {
#if defined(__NVCC__) || defined(__HIPCC__)
#if defined(__NVCC__) || defined(__HIPCC__)
namespace
kps
=
pten
::
kps
;
/*********************** Distribution Function *************************/
/*********************** Distribution Function *************************/
template
<
typename
T
>
template
<
typename
T
>
struct
uniform_distribution
;
struct
uniform_distribution
;
...
@@ -176,25 +182,26 @@ template <typename T, typename DistOp, typename TransformOp>
...
@@ -176,25 +182,26 @@ template <typename T, typename DistOp, typename TransformOp>
__global__
void
DistributionKernel
(
size_t
size
,
uint64_t
seed
,
uint64_t
offset
,
__global__
void
DistributionKernel
(
size_t
size
,
uint64_t
seed
,
uint64_t
offset
,
DistOp
dist
,
TransformOp
trans
,
DistOp
dist
,
TransformOp
trans
,
T
*
out_data
)
{
T
*
out_data
)
{
size_t
idx
=
static_cast
<
size_t
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
);
size_t
idx
=
static_cast
<
size_t
>
(
BLOCK_ID_X
*
BLOCK_NUM_X
);
int32_t
returns_c
ount
=
DistOp
::
kReturnsCount
;
static
constexpr
int
kC
ount
=
DistOp
::
kReturnsCount
;
#if defined(__NVCC__)
#if defined(__NVCC__)
curandStatePhilox4_32_10_t
state
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
offset
,
&
state
);
curand_init
(
seed
,
idx
+
THREAD_ID_X
,
offset
,
&
state
);
using
SType
=
curandStatePhilox4_32_10_t
;
#else
#else
hiprandStatePhilox4_32_10_t
state
;
hiprandStatePhilox4_32_10_t
state
;
hiprand_init
(
seed
,
idx
,
offset
,
&
state
);
hiprand_init
(
seed
,
idx
+
THREAD_ID_X
,
offset
,
&
state
);
using
SType
=
hiprandStatePhilox4_32_10_t
;
#endif
#endif
size_t
total_thread
=
gridDim
.
x
*
blockDim
.
x
;
size_t
total_thread
=
GRID_NUM_X
*
BLOCK_NUM_X
;
for
(
size_t
i
=
idx
;
i
<
size
;
i
+=
total_thread
*
returns_count
)
{
T
args
[
kCount
];
auto
random_tuple
=
dist
(
&
state
);
T
result
[
kCount
];
for
(
size_t
j
=
0
;
j
<
returns_count
;
j
++
)
{
for
(
size_t
i
=
idx
;
i
<
size
;
i
+=
total_thread
*
kCount
)
{
size_t
index
=
i
+
j
*
total_thread
;
kps
::
ElementwiseRandom
<
SType
,
T
,
kCount
,
1
,
DistOp
>
(
&
args
[
0
],
dist
,
&
state
);
if
(
index
<
size
)
{
kps
::
ElementwiseUnary
<
T
,
T
,
kCount
,
1
,
1
,
TransformOp
>
(
&
result
[
0
],
&
args
[
0
],
auto
random
=
(
&
random_tuple
.
x
)[
j
];
trans
);
out_data
[
index
]
=
static_cast
<
T
>
(
trans
(
random
));
kps
::
WriteData
<
T
,
T
,
kCount
,
1
,
1
,
true
>
(
out_data
+
i
,
&
result
[
0
],
size
-
i
,
}
1
,
total_thread
,
1
);
}
}
}
}
}
...
...
paddle/pten/kernels/primitive/compute_primitives.h
浏览文件 @
87099d12
...
@@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
...
@@ -428,5 +428,58 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
}
}
}
}
template
<
typename
StateType
,
typename
OutT
,
int
ReturnsCount
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
ElementwiseRandom
(
OutT
*
out
,
OpFunc
compute
,
StateType
*
state
)
{
auto
random_tuple
=
compute
(
state
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ReturnsCount
;
i
++
)
{
out
[
i
]
=
static_cast
<
OutT
>
((
&
random_tuple
.
x
)[
i
]);
}
}
// attention please set share_size = blockDim.x;
// data and b are the register pointer
#define shared_size 64
template
<
typename
InT
,
typename
OutT
,
int
NX
,
int
NY
,
int
BlockSize
,
class
OpFunc
>
__device__
__forceinline__
void
Cumsum
(
OutT
*
out
,
const
InT
*
in
,
OpFunc
compute
)
{
__shared__
InT
temp
[
shared_size
*
2
+
(
shared_size
*
2
)
/
32
];
int
tidx
=
threadIdx
.
x
;
temp
[
tidx
+
tidx
/
32
]
=
in
[
0
];
temp
[
shared_size
+
tidx
+
(
shared_size
+
tidx
)
/
32
]
=
in
[
1
];
for
(
int
stride
=
1
;
stride
<=
blockDim
.
x
;
stride
*=
2
)
{
__syncthreads
();
int
index
=
(
tidx
+
1
)
*
2
*
stride
-
1
;
if
(
index
<
(
blockDim
.
x
*
2
))
{
temp
[
index
+
index
/
32
]
+=
temp
[
index
-
stride
+
(
index
-
stride
)
/
32
];
}
}
for
(
int
stride
=
(
blockDim
.
x
*
2
)
/
4
;
stride
>
0
;
stride
/=
2
)
{
__syncthreads
();
int
index
=
(
tidx
+
1
)
*
2
*
stride
-
1
;
if
((
index
+
stride
)
<
(
blockDim
.
x
*
2
))
{
temp
[
index
+
stride
+
(
stride
+
index
)
/
32
]
+=
temp
[
index
+
(
index
)
/
32
];
}
}
__syncthreads
();
out
[
0
]
=
static_cast
<
OutT
>
(
temp
[
tidx
+
tidx
/
32
]);
out
[
1
]
=
static_cast
<
OutT
>
(
temp
[
tidx
+
shared_size
+
(
tidx
+
shared_size
)
/
32
]);
}
}
// namespace kps
}
// namespace kps
}
// namespace pten
}
// namespace pten
python/paddle/fluid/tests/unittests/ir/inference/test_conv_act_mkldnn_fuse_pass.py
浏览文件 @
87099d12
...
@@ -53,8 +53,6 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
...
@@ -53,8 +53,6 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
data_format
=
prog_config
.
ops
[
0
].
attrs
[
"data_format"
]
data_format
=
prog_config
.
ops
[
0
].
attrs
[
"data_format"
]
filter_shape
=
prog_config
.
weights
[
"filter"
].
shape
filter_shape
=
prog_config
.
weights
[
"filter"
].
shape
input_shape
=
prog_config
.
inputs
[
"input_x"
].
shape
input_shape
=
prog_config
.
inputs
[
"input_x"
].
shape
if
data_format
!=
"NCHW"
:
return
False
if
padding_algorithm
==
"VALID"
:
if
padding_algorithm
==
"VALID"
:
if
((
input_shape
[
2
]
-
(
dilations
[
0
]
*
(
filter_shape
[
2
]
-
1
)
+
1
))
/
strides
[
0
]
+
1
)
<=
1
or
\
if
((
input_shape
[
2
]
-
(
dilations
[
0
]
*
(
filter_shape
[
2
]
-
1
)
+
1
))
/
strides
[
0
]
+
1
)
<=
1
or
\
((
input_shape
[
3
]
-
(
dilations
[
1
]
*
(
filter_shape
[
3
]
-
1
)
+
1
))
/
strides
[
1
]
+
1
)
<=
1
:
((
input_shape
[
3
]
-
(
dilations
[
1
]
*
(
filter_shape
[
3
]
-
1
)
+
1
))
/
strides
[
1
]
+
1
)
<=
1
:
...
@@ -80,8 +78,8 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
...
@@ -80,8 +78,8 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
x_shape
=
draw
(
x_shape
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
st
.
integers
(
min_value
=
1
,
max_value
=
100
),
min_size
=
4
,
max_size
=
4
))
min_value
=
5
,
max_value
=
100
),
min_size
=
4
,
max_size
=
4
))
x_shape
[
1
]
=
draw
(
st
.
integers
(
min_value
=
1
,
max_value
=
10
))
x_shape
[
1
]
=
draw
(
st
.
integers
(
min_value
=
5
,
max_value
=
10
))
# 2. Generate legal attr:data_format of conv2d
# 2. Generate legal attr:data_format of conv2d
data_format
=
draw
(
st
.
sampled_from
([
"NCHW"
,
"NHWC"
]))
data_format
=
draw
(
st
.
sampled_from
([
"NCHW"
,
"NHWC"
]))
...
@@ -90,7 +88,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
...
@@ -90,7 +88,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
f_shape
=
draw
(
f_shape
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
st
.
integers
(
min_value
=
1
,
max_value
=
7
),
min_size
=
4
,
max_size
=
4
))
min_value
=
1
,
max_value
=
5
),
min_size
=
4
,
max_size
=
4
))
if
data_format
==
"NCHW"
:
if
data_format
==
"NCHW"
:
f_shape
[
1
]
=
x_shape
[
1
]
f_shape
[
1
]
=
x_shape
[
1
]
else
:
else
:
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
浏览文件 @
87099d12
...
@@ -53,8 +53,6 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
...
@@ -53,8 +53,6 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
data_format
=
prog_config
.
ops
[
0
].
attrs
[
"data_format"
]
data_format
=
prog_config
.
ops
[
0
].
attrs
[
"data_format"
]
filter_shape
=
prog_config
.
weights
[
"filter"
].
shape
filter_shape
=
prog_config
.
weights
[
"filter"
].
shape
input_shape
=
prog_config
.
inputs
[
"input_x"
].
shape
input_shape
=
prog_config
.
inputs
[
"input_x"
].
shape
if
data_format
!=
"NCHW"
:
return
False
if
padding_algorithm
==
"VALID"
:
if
padding_algorithm
==
"VALID"
:
if
((
input_shape
[
2
]
-
(
dilations
[
0
]
*
(
filter_shape
[
2
]
-
1
)
+
1
))
/
strides
[
0
]
+
1
)
<=
1
or
\
if
((
input_shape
[
2
]
-
(
dilations
[
0
]
*
(
filter_shape
[
2
]
-
1
)
+
1
))
/
strides
[
0
]
+
1
)
<=
1
or
\
((
input_shape
[
3
]
-
(
dilations
[
1
]
*
(
filter_shape
[
3
]
-
1
)
+
1
))
/
strides
[
1
]
+
1
)
<=
1
:
((
input_shape
[
3
]
-
(
dilations
[
1
]
*
(
filter_shape
[
3
]
-
1
)
+
1
))
/
strides
[
1
]
+
1
)
<=
1
:
...
@@ -80,8 +78,8 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
...
@@ -80,8 +78,8 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
x_shape
=
draw
(
x_shape
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
st
.
integers
(
min_value
=
1
,
max_value
=
100
),
min_size
=
4
,
max_size
=
4
))
min_value
=
5
,
max_value
=
100
),
min_size
=
4
,
max_size
=
4
))
x_shape
[
1
]
=
draw
(
st
.
integers
(
min_value
=
1
,
max_value
=
10
))
x_shape
[
1
]
=
draw
(
st
.
integers
(
min_value
=
5
,
max_value
=
10
))
# 2. Generate legal attr:data_format of conv2d
# 2. Generate legal attr:data_format of conv2d
data_format
=
draw
(
st
.
sampled_from
([
"NCHW"
,
"NHWC"
]))
data_format
=
draw
(
st
.
sampled_from
([
"NCHW"
,
"NHWC"
]))
...
@@ -90,7 +88,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
...
@@ -90,7 +88,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
f_shape
=
draw
(
f_shape
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
st
.
integers
(
min_value
=
1
,
max_value
=
7
),
min_size
=
4
,
max_size
=
4
))
min_value
=
1
,
max_value
=
4
),
min_size
=
4
,
max_size
=
4
))
if
data_format
==
"NCHW"
:
if
data_format
==
"NCHW"
:
f_shape
[
1
]
=
x_shape
[
1
]
f_shape
[
1
]
=
x_shape
[
1
]
else
:
else
:
...
@@ -100,7 +98,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
...
@@ -100,7 +98,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
strides
=
draw
(
strides
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
st
.
integers
(
min_value
=
1
,
max_value
=
5
),
min_size
=
2
,
max_size
=
2
))
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
2
))
# 5. Generate legal attr:padding_algorithm of conv2d
# 5. Generate legal attr:padding_algorithm of conv2d
padding_algorithm
=
draw
(
st
.
sampled_from
([
"EXPLICIT"
,
"SAME"
,
"VALID"
]))
padding_algorithm
=
draw
(
st
.
sampled_from
([
"EXPLICIT"
,
"SAME"
,
"VALID"
]))
...
@@ -109,7 +107,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
...
@@ -109,7 +107,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
padding
=
draw
(
padding
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
st
.
integers
(
min_value
=
1
,
max_value
=
5
),
min_size
=
4
,
max_size
=
4
))
min_value
=
1
,
max_value
=
4
),
min_size
=
4
,
max_size
=
4
))
# 7. Generate legal attr:groups of conv2d
# 7. Generate legal attr:groups of conv2d
groups
=
draw
(
st
.
integers
(
min_value
=
1
,
max_value
=
3
))
groups
=
draw
(
st
.
integers
(
min_value
=
1
,
max_value
=
3
))
...
@@ -118,7 +116,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
...
@@ -118,7 +116,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
dilations
=
draw
(
dilations
=
draw
(
st
.
lists
(
st
.
lists
(
st
.
integers
(
st
.
integers
(
min_value
=
1
,
max_value
=
5
),
min_size
=
2
,
max_size
=
2
))
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
2
))
# 9. Generate legal shape of input:bias of elementwise_add
# 9. Generate legal shape of input:bias of elementwise_add
bias_shape
=
[
f_shape
[
0
]]
bias_shape
=
[
f_shape
[
0
]]
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py
浏览文件 @
87099d12
...
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
...
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class
TestConvConcatReluMkldnnFusePass
(
PassAutoScanTest
):
class
TestConvConcatReluMkldnnFusePass
(
PassAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if
attrs
[
0
][
'data_format'
]
==
"NHWC"
:
return
False
return
True
return
True
def
sample_program_config
(
self
,
draw
):
def
sample_program_config
(
self
,
draw
):
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py
浏览文件 @
87099d12
...
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
...
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class
TestConvGeluMkldnnFusePass
(
PassAutoScanTest
):
class
TestConvGeluMkldnnFusePass
(
PassAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if
attrs
[
0
][
'data_format'
]
==
"NHWC"
:
return
False
return
True
return
True
def
sample_program_config
(
self
,
draw
):
def
sample_program_config
(
self
,
draw
):
...
@@ -108,19 +99,6 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
...
@@ -108,19 +99,6 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
config
=
self
.
create_inference_config
(
use_mkldnn
=
True
)
config
=
self
.
create_inference_config
(
use_mkldnn
=
True
)
yield
config
,
[
"conv2d"
],
(
1e-5
,
1e-5
)
yield
config
,
[
"conv2d"
],
(
1e-5
,
1e-5
)
# If the problem has been fixed, the judgment
# needs to be deleted!!!
def
add_ignore_pass_case
(
self
):
def
teller1
(
program_config
,
predictor_config
):
if
program_config
.
ops
[
0
].
attrs
[
'data_format'
]
==
"NHWC"
:
return
True
return
False
self
.
add_ignore_check_case
(
teller1
,
SkipReasons
.
PASS_ACCURACY_ERROR
,
"The output format of conv2d is wrong when data_format attribute is NHWC"
)
def
test
(
self
):
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
passes
=
[
"conv_gelu_mkldnn_fuse_pass"
])
self
.
run_and_statis
(
quant
=
False
,
passes
=
[
"conv_gelu_mkldnn_fuse_pass"
])
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_sigmoid_fuse_pass.py
浏览文件 @
87099d12
...
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
...
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class
TestConvHardSigmoidMkldnnFusePass
(
PassAutoScanTest
):
class
TestConvHardSigmoidMkldnnFusePass
(
PassAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if
attrs
[
0
][
'data_format'
]
==
"NHWC"
:
return
False
return
True
return
True
def
sample_program_config
(
self
,
draw
):
def
sample_program_config
(
self
,
draw
):
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_hard_swish_fuse_pass.py
浏览文件 @
87099d12
...
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
...
@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class
TestConvHardSwishMkldnnFusePass
(
PassAutoScanTest
):
class
TestConvHardSwishMkldnnFusePass
(
PassAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if
attrs
[
0
][
'data_format'
]
==
"NHWC"
:
return
False
return
True
return
True
def
sample_program_config
(
self
,
draw
):
def
sample_program_config
(
self
,
draw
):
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_transpose_bias_fuse_pass.py
浏览文件 @
87099d12
...
@@ -32,9 +32,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
...
@@ -32,9 +32,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
for
i
in
range
(
len
(
program_config
.
ops
))
for
i
in
range
(
len
(
program_config
.
ops
))
]
]
# If the problem has been fixed, the judgment
if
attrs
[
0
][
'data_format'
]
==
"NCHW"
and
attrs
[
1
][
"axis"
]
==
3
:
# needs to be deleted!!!
return
False
if
attrs
[
0
][
'data_format'
]
==
"NHWC"
:
if
attrs
[
0
][
'data_format'
]
==
"NHWC"
and
attrs
[
1
][
"axis"
]
==
1
:
return
False
return
False
return
True
return
True
...
@@ -46,7 +46,7 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
...
@@ -46,7 +46,7 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
groups
=
draw
(
st
.
sampled_from
([
1
,
2
,
4
,
8
]))
groups
=
draw
(
st
.
sampled_from
([
1
,
2
,
4
,
8
]))
paddings
=
draw
(
st
.
sampled_from
([[
0
,
3
],
[
1
,
2
,
3
,
4
]]))
paddings
=
draw
(
st
.
sampled_from
([[
0
,
3
],
[
1
,
2
,
3
,
4
]]))
strides
=
draw
(
st
.
sampled_from
([[
1
,
1
],
[
2
,
2
],
[
1
,
2
]]))
strides
=
draw
(
st
.
sampled_from
([[
1
,
1
],
[
2
,
2
],
[
1
,
2
]]))
axis
=
draw
(
st
.
sampled_from
([
1
]))
axis
=
draw
(
st
.
sampled_from
([
1
,
3
]))
batch_size
=
draw
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
))
batch_size
=
draw
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
))
def
generate_input
():
def
generate_input
():
...
@@ -110,7 +110,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
...
@@ -110,7 +110,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
def
test
(
self
):
def
test
(
self
):
self
.
run_and_statis
(
self
.
run_and_statis
(
quant
=
False
,
passes
=
[
"conv_transpose_bias_mkldnn_fuse_pass"
])
quant
=
False
,
max_duration
=
300
,
passes
=
[
"conv_transpose_bias_mkldnn_fuse_pass"
])
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录