Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f3c14762
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f3c14762
编写于
8月 05, 2022
作者:
J
joanna.wozna.intel
提交者:
GitHub
8月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add int8 support for matmulV2 (#44908)
上级
075d7219
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
231 addition
and
187 deletion
+231
-187
paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
+167
-161
paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
+10
-10
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+54
-16
未找到文件。
paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
浏览文件 @
f3c14762
...
@@ -659,7 +659,7 @@ float ComputeOutputScale(const ExecutionContext &ctx) {
...
@@ -659,7 +659,7 @@ float ComputeOutputScale(const ExecutionContext &ctx) {
return
alpha
*
scale_out
/
(
scale_x
*
scale_y
);
return
alpha
*
scale_out
/
(
scale_x
*
scale_y
);
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
T_out
>
void
ExecuteMatMulV2
(
const
ExecutionContext
&
ctx
,
void
ExecuteMatMulV2
(
const
ExecutionContext
&
ctx
,
const
MKLDNNDeviceContext
&
dev_ctx
,
const
MKLDNNDeviceContext
&
dev_ctx
,
const
dnnl
::
engine
onednn_engine
,
const
dnnl
::
engine
onednn_engine
,
...
@@ -675,16 +675,16 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
...
@@ -675,16 +675,16 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
int
execution_number
=
0
)
{
int
execution_number
=
0
)
{
std
::
vector
<
int64_t
>
x_strides_override
=
GetInputStrides
(
ctx
,
"X"
);
std
::
vector
<
int64_t
>
x_strides_override
=
GetInputStrides
(
ctx
,
"X"
);
std
::
vector
<
int64_t
>
y_strides_override
=
GetInputStrides
(
ctx
,
"Y"
);
std
::
vector
<
int64_t
>
y_strides_override
=
GetInputStrides
(
ctx
,
"Y"
);
MatMulV2MKLDNNHandler
<
T
>
handler
(
ctx
,
MatMulV2MKLDNNHandler
<
T
,
T
,
T_out
>
handler
(
ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
x_dims
,
x_dims
,
trans_x
,
trans_x
,
y_dims
,
y_dims
,
trans_y
,
trans_y
,
IsOutputFused
(
ctx
),
IsOutputFused
(
ctx
),
x_strides_override
,
x_strides_override
,
y_strides_override
);
y_strides_override
);
const
auto
src_memory_p
=
handler
.
AcquireSrcMemory
(
x
);
const
auto
src_memory_p
=
handler
.
AcquireSrcMemory
(
x
);
const
auto
weights_memory_p
=
handler
.
AcquireWeightsMemory
(
y
);
const
auto
weights_memory_p
=
handler
.
AcquireWeightsMemory
(
y
);
...
@@ -707,17 +707,41 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
...
@@ -707,17 +707,41 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
auto
&
astream
=
MKLDNNDeviceContext
::
tls
().
get_stream
();
auto
&
astream
=
MKLDNNDeviceContext
::
tls
().
get_stream
();
matmul_p
->
execute
(
astream
,
matmul_args
);
matmul_p
->
execute
(
astream
,
matmul_args
);
astream
.
wait
();
astream
.
wait
();
auto
format
=
auto
format
=
paddle
::
platform
::
MKLDNNFormatForSize
(
MKLDNNFormatForSize
(
out
->
dims
().
size
(),
dnnl
::
memory
::
format_tag
::
nchw
);
out
->
dims
().
size
(),
dnnl
::
memory
::
format_tag
::
nchw
);
out
->
set_layout
(
paddle
::
framework
::
DataLayout
::
kMKLDNN
);
out
->
set_format
(
format
);
out
->
set_format
(
format
);
out
->
set_layout
(
DataLayout
::
kMKLDNN
);
}
}
template
<
typename
T
>
template
<
typename
T
>
class
MatMulV2MKLDNNKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
class
MatMulV2MKLDNNKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
ExecutionContext
&
ctx
)
const
override
{
RunKernel
(
ctx
);
}
void
Compute
(
const
ExecutionContext
&
ctx
)
const
override
{
if
(
ctx
.
HasAttr
(
"head_number"
))
{
PADDLE_ENFORCE_EQ
(
ctx
.
Attr
<
int
>
(
"head_number"
),
1
,
paddle
::
platform
::
errors
::
Unimplemented
(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d"
,
ctx
.
Attr
<
int
>
(
"head_number"
)));
}
constexpr
bool
is_int8
=
IsInt8
<
T
>
();
constexpr
bool
is_bfloat16
=
IsBfloat16
<
T
>
();
const
bool
force_fp32_output
=
ctx
.
HasAttr
(
"force_fp32_output"
)
?
ctx
.
Attr
<
bool
>
(
"force_fp32_output"
)
:
false
;
constexpr
bool
fuse_relu
=
false
;
// TODO(intel): Enable eltwise fuses
if
(
force_fp32_output
||
((
!
is_int8
)
&&
(
!
is_bfloat16
)))
{
RunKernel
<
float
>
(
ctx
);
}
else
if
(
is_bfloat16
)
{
RunKernel
<
paddle
::
platform
::
bfloat16
>
(
ctx
);
}
else
if
(
fuse_relu
)
{
RunKernel
<
uint8_t
>
(
ctx
);
}
else
{
RunKernel
<
int8_t
>
(
ctx
);
}
}
private:
private:
void
CalculateMatrixDims
(
const
ExecutionContext
&
ctx
,
void
CalculateMatrixDims
(
const
ExecutionContext
&
ctx
,
...
@@ -768,6 +792,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
...
@@ -768,6 +792,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
}
}
}
}
template
<
typename
T_out
>
void
RunKernel
(
const
ExecutionContext
&
ctx
)
const
{
void
RunKernel
(
const
ExecutionContext
&
ctx
)
const
{
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
auto
&
onednn_engine
=
dev_ctx
.
GetEngine
();
const
auto
&
onednn_engine
=
dev_ctx
.
GetEngine
();
...
@@ -793,18 +818,18 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
...
@@ -793,18 +818,18 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
CalculateMatrixDims
(
CalculateMatrixDims
(
ctx
,
x_dims
,
y_dims
,
&
x_bd_dims
,
&
y_bd_dims
,
&
out_dims
,
out
);
ctx
,
x_dims
,
y_dims
,
&
x_bd_dims
,
&
y_bd_dims
,
&
out_dims
,
out
);
ExecuteMatMulV2
<
T
>
(
ctx
,
ExecuteMatMulV2
<
T
,
T_out
>
(
ctx
,
dev_ctx
,
dev_ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
x
,
x
,
x_bd_dims
,
x_bd_dims
,
trans_x
,
trans_x
,
y
,
y
,
y_bd_dims
,
y_bd_dims
,
trans_y
,
trans_y
,
out
,
out
,
out_dims
);
out_dims
);
}
}
};
};
...
@@ -939,113 +964,113 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
...
@@ -939,113 +964,113 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
ctx
,
&
dx_tmp
,
&
dy_tmp
,
x_dims
,
y_dims
,
&
dx_bd_dims
,
&
dy_bd_dims
);
ctx
,
&
dx_tmp
,
&
dy_tmp
,
x_dims
,
y_dims
,
&
dx_bd_dims
,
&
dy_bd_dims
);
if
(
trans_x
&&
trans_y
)
{
if
(
trans_x
&&
trans_y
)
{
ExecuteMatMulV2
<
T
>
(
ctx
,
ExecuteMatMulV2
<
T
,
T
>
(
ctx
,
dev_ctx
,
dev_ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
y
,
y
,
y_dims
,
y_dims
,
true
,
true
,
dout
,
dout
,
dout_dims
,
dout_dims
,
true
,
true
,
&
dx_tmp
,
&
dx_tmp
,
dx_bd_dims
,
dx_bd_dims
,
1
);
1
);
ExecuteMatMulV2
<
T
>
(
ctx
,
ExecuteMatMulV2
<
T
,
T
>
(
ctx
,
dev_ctx
,
dev_ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
dout
,
dout
,
dout_dims
,
dout_dims
,
true
,
true
,
x
,
x
,
x_dims
,
x_dims
,
true
,
true
,
&
dy_tmp
,
&
dy_tmp
,
dy_bd_dims
,
dy_bd_dims
,
2
);
2
);
}
else
if
(
trans_x
)
{
}
else
if
(
trans_x
)
{
ExecuteMatMulV2
<
T
>
(
ctx
,
ExecuteMatMulV2
<
T
,
T
>
(
ctx
,
dev_ctx
,
dev_ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
y
,
y
,
y_dims
,
y_dims
,
false
,
false
,
dout
,
dout
,
dout_dims
,
dout_dims
,
true
,
true
,
&
dx_tmp
,
&
dx_tmp
,
dx_bd_dims
,
dx_bd_dims
,
1
);
1
);
ExecuteMatMulV2
<
T
>
(
ctx
,
ExecuteMatMulV2
<
T
,
T
>
(
ctx
,
dev_ctx
,
dev_ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
x
,
x
,
x_dims
,
x_dims
,
false
,
false
,
dout
,
dout
,
dout_dims
,
dout_dims
,
false
,
false
,
&
dy_tmp
,
&
dy_tmp
,
dy_bd_dims
,
dy_bd_dims
,
2
);
2
);
}
else
if
(
trans_y
)
{
}
else
if
(
trans_y
)
{
ExecuteMatMulV2
<
T
>
(
ctx
,
ExecuteMatMulV2
<
T
,
T
>
(
ctx
,
dev_ctx
,
dev_ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
dout
,
dout
,
dout_dims
,
dout_dims
,
false
,
false
,
y
,
y
,
y_dims
,
y_dims
,
false
,
false
,
&
dx_tmp
,
&
dx_tmp
,
dx_bd_dims
,
dx_bd_dims
,
1
);
1
);
ExecuteMatMulV2
<
T
>
(
ctx
,
ExecuteMatMulV2
<
T
,
T
>
(
ctx
,
dev_ctx
,
dev_ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
dout
,
dout
,
dout_dims
,
dout_dims
,
true
,
true
,
x
,
x
,
x_dims
,
x_dims
,
false
,
false
,
&
dy_tmp
,
&
dy_tmp
,
dy_bd_dims
,
dy_bd_dims
,
2
);
2
);
}
else
{
}
else
{
ExecuteMatMulV2
<
T
>
(
ctx
,
ExecuteMatMulV2
<
T
,
T
>
(
ctx
,
dev_ctx
,
dev_ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
dout
,
dout
,
dout_dims
,
dout_dims
,
false
,
false
,
y
,
y
,
y_dims
,
y_dims
,
true
,
true
,
&
dx_tmp
,
&
dx_tmp
,
dx_bd_dims
,
dx_bd_dims
,
1
);
1
);
ExecuteMatMulV2
<
T
>
(
ctx
,
ExecuteMatMulV2
<
T
,
T
>
(
ctx
,
dev_ctx
,
dev_ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
x
,
x
,
x_dims
,
x_dims
,
true
,
true
,
dout
,
dout
,
dout_dims
,
dout_dims
,
false
,
false
,
&
dy_tmp
,
&
dy_tmp
,
dy_bd_dims
,
dy_bd_dims
,
2
);
2
);
}
}
if
(
x_dims
!=
dx_bd_dims
)
{
if
(
x_dims
!=
dx_bd_dims
)
{
...
@@ -1234,34 +1259,13 @@ template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;
...
@@ -1234,34 +1259,13 @@ template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
matmul
,
REGISTER_OP_KERNEL
(
matmul
,
MKLDNN
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
::
paddle
::
platform
::
CPUPlace
,
S8
,
MatMulV2MKLDNNKernel
<
float
>
,
0
,
MatMulV2MKLDNNKernel
<
paddle
::
platform
::
bfloat16
>
,
MatMulMKLDNNKernel
<
int8_t
>
);
MatMulV2MKLDNNKernel
<
int8_t
>
,
MatMulV2MKLDNNKernel
<
uint8_t
>
);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
matmul
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
U8
,
0
,
MatMulMKLDNNKernel
<
uint8_t
>
);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
matmul
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
FP32
,
0
,
MatMulV2MKLDNNKernel
<
float
>
);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
matmul
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
BF16
,
0
,
MatMulV2MKLDNNKernel
<
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_KERNEL
(
matmul_grad
,
REGISTER_OP_KERNEL
(
matmul_grad
,
MKLDNN
,
MKLDNN
,
...
@@ -1273,7 +1277,9 @@ REGISTER_OP_KERNEL(matmul_v2,
...
@@ -1273,7 +1277,9 @@ REGISTER_OP_KERNEL(matmul_v2,
MKLDNN
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
::
paddle
::
platform
::
CPUPlace
,
MatMulV2MKLDNNKernel
<
float
>
,
MatMulV2MKLDNNKernel
<
float
>
,
MatMulV2MKLDNNKernel
<
paddle
::
platform
::
bfloat16
>
);
MatMulV2MKLDNNKernel
<
paddle
::
platform
::
bfloat16
>
,
MatMulV2MKLDNNKernel
<
int8_t
>
,
MatMulV2MKLDNNKernel
<
uint8_t
>
);
REGISTER_OP_KERNEL
(
matmul_v2_grad
,
REGISTER_OP_KERNEL
(
matmul_v2_grad
,
MKLDNN
,
MKLDNN
,
...
...
paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
浏览文件 @
f3c14762
...
@@ -416,16 +416,16 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
...
@@ -416,16 +416,16 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
bool
trans_y
,
bool
trans_y
,
Tensor
*
out
)
const
{
Tensor
*
out
)
const
{
static
const
std
::
vector
<
int64_t
>
vec_placeholder
;
static
const
std
::
vector
<
int64_t
>
vec_placeholder
;
MatMulV2MKLDNNHandler
<
XT
>
handler
(
ctx
,
MatMulV2MKLDNNHandler
<
XT
,
YT
,
XT
>
handler
(
ctx
,
onednn_engine
,
onednn_engine
,
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
x_dims
,
x_dims
,
trans_x
,
trans_x
,
y_dims
,
y_dims
,
trans_y
,
trans_y
,
false
,
false
,
vec_placeholder
,
vec_placeholder
,
vec_placeholder
);
vec_placeholder
);
const
auto
src_memory_p
=
handler
.
AcquireSrcMemory
(
x
);
const
auto
src_memory_p
=
handler
.
AcquireSrcMemory
(
x
);
const
auto
weights_memory_p
=
handler
.
AcquireWeightsMemory
(
y
);
const
auto
weights_memory_p
=
handler
.
AcquireWeightsMemory
(
y
);
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
f3c14762
...
@@ -860,8 +860,18 @@ class ReductionMKLDNNHandler
...
@@ -860,8 +860,18 @@ class ReductionMKLDNNHandler
};
};
template
<
typename
T
>
template
<
typename
T
>
constexpr
bool
IsInt8
()
{
return
std
::
is_same
<
T
,
int8_t
>::
value
||
std
::
is_same
<
T
,
uint8_t
>::
value
;
}
template
<
typename
T
>
constexpr
bool
IsBfloat16
()
{
return
std
::
is_same
<
T
,
paddle
::
platform
::
bfloat16
>::
value
;
}
template
<
typename
XT
,
typename
YT
,
typename
OT
>
class
MatMulV2MKLDNNHandler
class
MatMulV2MKLDNNHandler
:
public
paddle
::
platform
::
MKLDNNHandlerNoCachingT
<
T
,
dnnl
::
matmul
>
{
:
public
paddle
::
platform
::
MKLDNNHandlerNoCachingT
<
X
T
,
dnnl
::
matmul
>
{
public:
public:
MatMulV2MKLDNNHandler
(
const
framework
::
ExecutionContext
&
ctx
,
MatMulV2MKLDNNHandler
(
const
framework
::
ExecutionContext
&
ctx
,
const
dnnl
::
engine
engine
,
const
dnnl
::
engine
engine
,
...
@@ -873,8 +883,8 @@ class MatMulV2MKLDNNHandler
...
@@ -873,8 +883,8 @@ class MatMulV2MKLDNNHandler
bool
is_output_fused
,
bool
is_output_fused
,
const
std
::
vector
<
int64_t
>&
x_strides_override
,
const
std
::
vector
<
int64_t
>&
x_strides_override
,
const
std
::
vector
<
int64_t
>&
y_strides_override
)
const
std
::
vector
<
int64_t
>&
y_strides_override
)
:
paddle
::
platform
::
MKLDNNHandlerNoCachingT
<
T
,
dnnl
::
matmul
>
(
engine
,
:
paddle
::
platform
::
MKLDNNHandlerNoCachingT
<
X
T
,
dnnl
::
matmul
>
(
engine
,
cpu_place
)
{
cpu_place
)
{
// M X K * K X N
// M X K * K X N
std
::
vector
<
int64_t
>
x_dims
(
x_org_dims
);
std
::
vector
<
int64_t
>
x_dims
(
x_org_dims
);
std
::
vector
<
int64_t
>
y_dims
(
y_org_dims
);
std
::
vector
<
int64_t
>
y_dims
(
y_org_dims
);
...
@@ -934,28 +944,42 @@ class MatMulV2MKLDNNHandler
...
@@ -934,28 +944,42 @@ class MatMulV2MKLDNNHandler
out_strides
[
i
]
=
out_ddims
[
i
+
1
]
*
out_strides
[
i
+
1
];
out_strides
[
i
]
=
out_ddims
[
i
+
1
]
*
out_strides
[
i
+
1
];
}
}
if
(
is_output_fused
)
{
if
(
!
IsInt8
<
OT
>
()
&&
!
IsBfloat16
<
OT
>
()
&&
is_output_fused
)
{
out_strides
=
FakeTransposeStrides
(
out_ddims
);
out_strides
=
FakeTransposeStrides
(
out_ddims
);
}
}
auto
x_md
=
memory
::
desc
(
x_dims
,
MKLDNNGetDataType
<
T
>
(),
x_strides
);
auto
x_md
=
memory
::
desc
(
x_dims
,
MKLDNNGetDataType
<
X
T
>
(),
x_strides
);
auto
y_md
=
memory
::
desc
(
y_dims
,
MKLDNNGetDataType
<
T
>
(),
y_strides
);
auto
y_md
=
memory
::
desc
(
y_dims
,
MKLDNNGetDataType
<
Y
T
>
(),
y_strides
);
auto
out_md
=
memory
::
desc
(
out_ddims
,
MKLDNNGetDataType
<
T
>
(),
out_strides
);
auto
out_md
=
memory
::
desc
(
out_ddims
,
MKLDNNGetDataType
<
O
T
>
(),
out_strides
);
const
dnnl
::
primitive_attr
matmul_attrs
=
CreateMatmulAttrs
(
ctx
);
const
dnnl
::
primitive_attr
matmul_attrs
=
CreateMatmulAttrs
(
ctx
);
this
->
AcquireForwardPrimitiveDescriptor
(
matmul_attrs
,
x_md
,
y_md
,
out_md
);
this
->
AcquireForwardPrimitiveDescriptor
(
matmul_attrs
,
x_md
,
y_md
,
out_md
);
}
}
// TODO(jczaja) : Adapt to int8
float
ComputeOutputScale
(
const
framework
::
ExecutionContext
&
ctx
)
{
float
alpha
=
ctx
.
HasAttr
(
"alpha"
)
?
ctx
.
Attr
<
float
>
(
"alpha"
)
:
1.0
f
;
if
(
ctx
.
HasAttr
(
"Scale_x"
)
&&
ctx
.
HasAttr
(
"Scale_y"
)
&&
ctx
.
HasAttr
(
"Scale_out"
))
{
float
scale_x
=
ctx
.
Attr
<
float
>
(
"Scale_x"
);
float
scale_y
=
ctx
.
Attr
<
float
>
(
"Scale_y"
);
bool
force_fp32_out
=
ctx
.
HasAttr
(
"force_fp32_output"
)
?
ctx
.
Attr
<
bool
>
(
"force_fp32_output"
)
:
false
;
float
scale_out
=
force_fp32_out
?
1.
f
:
ctx
.
Attr
<
float
>
(
"Scale_out"
);
alpha
*=
scale_out
/
(
scale_x
*
scale_y
);
}
return
alpha
;
}
dnnl
::
primitive_attr
CreateMatmulAttrs
(
dnnl
::
primitive_attr
CreateMatmulAttrs
(
const
framework
::
ExecutionContext
&
ctx
)
{
const
framework
::
ExecutionContext
&
ctx
)
{
dnnl
::
primitive_attr
matmul_attrs
;
dnnl
::
primitive_attr
matmul_attrs
;
dnnl
::
post_ops
post_operations
;
dnnl
::
post_ops
post_operations
;
float
alpha
=
ctx
.
HasAttr
(
"alpha"
)
?
ctx
.
Attr
<
float
>
(
"alpha"
)
:
1.0
f
;
float
scale_out
=
ComputeOutputScale
(
ctx
)
;
if
(
alpha
!=
1.0
f
)
{
if
(
scale_out
!=
1.0
f
)
{
matmul_attrs
.
set_output_scales
(
0
,
{
alpha
});
matmul_attrs
.
set_output_scales
(
0
,
{
scale_out
});
}
}
if
(
ctx
.
HasInput
(
"ResidualData"
))
{
if
(
ctx
.
HasInput
(
"ResidualData"
))
{
...
@@ -993,9 +1017,23 @@ class MatMulV2MKLDNNHandler
...
@@ -993,9 +1017,23 @@ class MatMulV2MKLDNNHandler
}
}
std
::
shared_ptr
<
memory
>
AcquireWeightsMemory
(
const
Tensor
*
input
)
{
std
::
shared_ptr
<
memory
>
AcquireWeightsMemory
(
const
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
const
YT
*
input_data
=
input
->
data
<
Y
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
weights_desc
(),
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
weights_desc
(),
to_void_cast
<
T
>
(
input_data
));
to_void_cast
<
YT
>
(
input_data
));
}
std
::
shared_ptr
<
dnnl
::
memory
>
AcquireDstMemory
(
paddle
::
framework
::
Tensor
*
output
)
{
// We cannot use base AcquireDstMemory as it makes an allocation request
// base on DST memory primitive size. This is fine in general, but in MatMul
// we have primitive that covers only one batch of Data and then shift
// pointer for every new batch. Hence Tensor size is bigger that dst memory
// primitive size. So would we request less memory that is there and it
// triggers an
// assertion. So as there is no 'any' format here we can leave default size
// of Tensor as computed in ComputeInferShape
OT
*
ptr
=
output
->
mutable_data
<
OT
>
(
this
->
place_
);
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
dst_desc
(),
ptr
);
}
}
};
};
...
@@ -1099,11 +1137,11 @@ class ActivationMKLDNNHandler
...
@@ -1099,11 +1137,11 @@ class ActivationMKLDNNHandler
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetAttributeMap
(
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetAttributeMap
(
std
::
string
act_type
)
{
std
::
string
act_type
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attr_map
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attr_map
;
if
(
act_type
==
"swish"
)
if
(
act_type
==
"swish"
)
{
attr_map
.
emplace
(
"beta"
,
"fuse_alpha"
);
attr_map
.
emplace
(
"beta"
,
"fuse_alpha"
);
else
if
(
act_type
==
"relu6"
)
}
else
if
(
act_type
==
"relu6"
)
{
attr_map
.
emplace
(
"threshold"
,
"fuse_alpha"
);
attr_map
.
emplace
(
"threshold"
,
"fuse_alpha"
);
else
if
(
act_type
==
"hard_sigmoid"
)
{
}
else
if
(
act_type
==
"hard_sigmoid"
)
{
attr_map
.
emplace
(
"slope"
,
"fuse_alpha"
);
attr_map
.
emplace
(
"slope"
,
"fuse_alpha"
);
attr_map
.
emplace
(
"offset"
,
"fuse_beta"
);
attr_map
.
emplace
(
"offset"
,
"fuse_beta"
);
}
else
if
(
act_type
==
"clip"
)
{
}
else
if
(
act_type
==
"clip"
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录