Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ff92b6ba
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ff92b6ba
编写于
8月 16, 2018
作者:
T
tensor-tang
提交者:
GitHub
8月 16, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12531 from tensor-tang/refine/op/gru
Refine gru cpu forward
上级
d4d8f831
c588c64a
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
373 addition
and
87 deletion
+373
-87
paddle/fluid/operators/gru_op.cc
paddle/fluid/operators/gru_op.cc
+159
-3
paddle/fluid/operators/gru_op.cu.cc
paddle/fluid/operators/gru_op.cu.cc
+90
-0
paddle/fluid/operators/gru_op.h
paddle/fluid/operators/gru_op.h
+0
-84
paddle/fluid/operators/math/blas.h
paddle/fluid/operators/math/blas.h
+41
-0
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+75
-0
paddle/fluid/platform/dynload/mklml.h
paddle/fluid/platform/dynload/mklml.h
+8
-0
未找到文件。
paddle/fluid/operators/gru_op.cc
浏览文件 @
ff92b6ba
...
...
@@ -14,6 +14,11 @@ limitations under the License. */
#include "paddle/fluid/operators/gru_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
DECLARE_int32
(
paddle_num_threads
);
namespace
paddle
{
namespace
operators
{
...
...
@@ -211,6 +216,158 @@ class GRUGradOp : public framework::OperatorWithKernel {
}
};
template
<
typename
T
>
class
GRUCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
const
T
*
weight_data
=
weight
->
data
<
T
>
();
auto
*
bias
=
context
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
batch_gate
=
context
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_reset_hidden_prev
=
context
.
Output
<
LoDTensor
>
(
"BatchResetHiddenPrev"
);
batch_reset_hidden_prev
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_hidden
=
context
.
Output
<
LoDTensor
>
(
"BatchHidden"
);
batch_hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
hidden
=
context
.
Output
<
LoDTensor
>
(
"Hidden"
);
hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
hidden_dims
=
hidden
->
dims
();
bool
is_reverse
=
context
.
Attr
<
bool
>
(
"is_reverse"
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
to_batch
(
dev_ctx
,
*
input
,
batch_gate
,
true
,
is_reverse
);
if
(
bias
)
{
math
::
RowwiseAdd
<
DeviceContext
,
T
>
add_bias
;
add_bias
(
dev_ctx
,
*
batch_gate
,
*
bias
,
batch_gate
);
}
int
frame_size
=
hidden_dims
[
1
];
math
::
GRUMetaValue
<
T
>
gru_value
;
gru_value
.
gate_weight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state_weight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
framework
::
Vector
<
size_t
>
order
(
batch_gate
->
lod
()[
2
]);
if
(
h0
)
{
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
h0
,
order
,
&
ordered_h0
,
true
);
gru_value
.
prev_out_value
=
ordered_h0
.
data
<
T
>
();
}
else
{
gru_value
.
prev_out_value
=
nullptr
;
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
seq_len
=
batch_starts
.
size
()
-
1
;
auto
active_node
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"activation"
));
auto
active_gate
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
));
#ifdef PADDLE_WITH_MKLML
// use MKL packed to speedup GEMM
if
(
FLAGS_paddle_num_threads
>=
4
)
{
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
T
*
packed_gate
=
blas
.
GEMM_ALLOC
(
CblasBMatrix
,
1
/*height of C*/
,
frame_size
*
2
/*width of weight*/
,
frame_size
/*height of height*/
);
PADDLE_ENFORCE
(
packed_gate
);
blas
.
GEMM_PACK
(
CblasBMatrix
,
CblasNoTrans
,
1
/*cur bs?*/
,
frame_size
*
2
,
frame_size
,
T
(
1.0
),
gru_value
.
gate_weight
,
frame_size
*
2
,
packed_gate
);
T
*
packed_state
=
blas
.
GEMM_ALLOC
(
CblasBMatrix
,
1
/*height of C*/
,
frame_size
/*width of weight*/
,
frame_size
/*height of height*/
);
PADDLE_ENFORCE
(
packed_state
);
blas
.
GEMM_PACK
(
CblasBMatrix
,
CblasNoTrans
,
1
/*cur bs?*/
,
frame_size
,
frame_size
,
T
(
1.0
),
gru_value
.
state_weight
,
frame_size
,
packed_state
);
for
(
size_t
n
=
0
;
n
<
seq_len
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output_value
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate_value
=
gate_t
.
data
<
T
>
();
gru_value
.
reset_output_value
=
reset_hidden_prev_t
.
data
<
T
>
();
if
(
gru_value
.
prev_out_value
)
{
blas
.
GEMM_COMPUTE
(
CblasNoTrans
,
CblasPacked
,
cur_batch_size
,
frame_size
*
2
,
frame_size
,
gru_value
.
prev_out_value
,
frame_size
,
packed_gate
,
frame_size
*
2
,
T
(
1
),
gru_value
.
gate_value
,
frame_size
*
3
);
}
math
::
detail
::
forward_reset_output
(
math
::
detail
::
forward
::
gru_resetOutput
<
T
>
(),
gru_value
,
frame_size
,
cur_batch_size
,
active_gate
);
if
(
gru_value
.
prev_out_value
)
{
blas
.
GEMM_COMPUTE
(
CblasNoTrans
,
CblasPacked
,
cur_batch_size
,
frame_size
,
frame_size
,
gru_value
.
reset_output_value
,
frame_size
,
packed_state
,
frame_size
,
T
(
1
),
gru_value
.
gate_value
+
frame_size
*
2
,
frame_size
*
3
);
}
math
::
detail
::
forward_final_output
(
math
::
detail
::
forward
::
gru_finalOutput
<
T
>
(),
gru_value
,
frame_size
,
cur_batch_size
,
active_node
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
blas
.
GEMM_FREE
(
packed_gate
);
blas
.
GEMM_FREE
(
packed_state
);
}
else
{
#endif
for
(
size_t
n
=
0
;
n
<
seq_len
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output_value
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate_value
=
gate_t
.
data
<
T
>
();
gru_value
.
reset_output_value
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
DeviceContext
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
active_node
,
active_gate
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
#ifdef PADDLE_WITH_MKLML
}
#endif
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batch_hidden
->
set_lod
(
batch_gate
->
lod
());
to_seq
(
dev_ctx
,
*
batch_hidden
,
hidden
);
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
BatchCompute
(
context
);
}
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -218,9 +375,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
gru
,
ops
::
GRUOp
,
ops
::
GRUOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
gru_grad
,
ops
::
GRUGradOp
);
REGISTER_OP_CPU_KERNEL
(
gru
,
ops
::
GRUKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GRUKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
gru
,
ops
::
GRUCPUKernel
<
float
>
,
ops
::
GRUCPUKernel
<
double
>
);
REGISTER_OP_CPU_KERNEL
(
gru_grad
,
ops
::
GRUGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
GRUGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/gru_op.cu.cc
浏览文件 @
ff92b6ba
...
...
@@ -14,6 +14,96 @@ limitations under the License. */
#include "paddle/fluid/operators/gru_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
GRUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
const
T
*
weight_data
=
weight
->
data
<
T
>
();
auto
*
bias
=
context
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
batch_gate
=
context
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_reset_hidden_prev
=
context
.
Output
<
LoDTensor
>
(
"BatchResetHiddenPrev"
);
batch_reset_hidden_prev
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_hidden
=
context
.
Output
<
LoDTensor
>
(
"BatchHidden"
);
batch_hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
hidden
=
context
.
Output
<
LoDTensor
>
(
"Hidden"
);
hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
hidden_dims
=
hidden
->
dims
();
bool
is_reverse
=
context
.
Attr
<
bool
>
(
"is_reverse"
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
to_batch
(
dev_ctx
,
*
input
,
batch_gate
,
true
,
is_reverse
);
if
(
bias
)
{
math
::
RowwiseAdd
<
DeviceContext
,
T
>
add_bias
;
add_bias
(
dev_ctx
,
*
batch_gate
,
*
bias
,
batch_gate
);
}
int
frame_size
=
hidden_dims
[
1
];
math
::
GRUMetaValue
<
T
>
gru_value
;
gru_value
.
gate_weight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state_weight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
framework
::
Vector
<
size_t
>
order
(
batch_gate
->
lod
()[
2
]);
if
(
h0
)
{
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
h0
,
order
,
&
ordered_h0
,
true
);
gru_value
.
prev_out_value
=
ordered_h0
.
data
<
T
>
();
}
else
{
gru_value
.
prev_out_value
=
nullptr
;
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
active_node
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"activation"
));
auto
active_gate
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
));
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output_value
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate_value
=
gate_t
.
data
<
T
>
();
gru_value
.
reset_output_value
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
DeviceContext
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
active_node
,
active_gate
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batch_hidden
->
set_lod
(
batch_gate
->
lod
());
to_seq
(
dev_ctx
,
*
batch_hidden
,
hidden
);
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
BatchCompute
(
context
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
gru
,
ops
::
GRUKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
...
...
paddle/fluid/operators/gru_op.h
浏览文件 @
ff92b6ba
...
...
@@ -37,90 +37,6 @@ inline void ReorderInitState(const DeviceContext& ctx,
row_shuffle
(
ctx
,
src
,
index_lod
,
dst
,
indexed_src
);
}
template
<
typename
DeviceContext
,
typename
T
>
class
GRUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
BatchCompute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
input
=
context
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
h0
=
context
.
Input
<
Tensor
>
(
"H0"
);
auto
*
weight
=
context
.
Input
<
Tensor
>
(
"Weight"
);
const
T
*
weight_data
=
weight
->
data
<
T
>
();
auto
*
bias
=
context
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
batch_gate
=
context
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_reset_hidden_prev
=
context
.
Output
<
LoDTensor
>
(
"BatchResetHiddenPrev"
);
batch_reset_hidden_prev
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
batch_hidden
=
context
.
Output
<
LoDTensor
>
(
"BatchHidden"
);
batch_hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
hidden
=
context
.
Output
<
LoDTensor
>
(
"Hidden"
);
hidden
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
hidden_dims
=
hidden
->
dims
();
bool
is_reverse
=
context
.
Attr
<
bool
>
(
"is_reverse"
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
to_batch
(
dev_ctx
,
*
input
,
batch_gate
,
true
,
is_reverse
);
if
(
bias
)
{
math
::
RowwiseAdd
<
DeviceContext
,
T
>
add_bias
;
add_bias
(
dev_ctx
,
*
batch_gate
,
*
bias
,
batch_gate
);
}
int
frame_size
=
hidden_dims
[
1
];
math
::
GRUMetaValue
<
T
>
gru_value
;
gru_value
.
gate_weight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state_weight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
framework
::
Vector
<
size_t
>
order
(
batch_gate
->
lod
()[
2
]);
if
(
h0
)
{
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
*
h0
,
order
,
&
ordered_h0
,
true
);
gru_value
.
prev_out_value
=
ordered_h0
.
data
<
T
>
();
}
else
{
gru_value
.
prev_out_value
=
nullptr
;
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
active_node
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"activation"
));
auto
active_gate
=
math
::
detail
::
GetActivationType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
));
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output_value
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate_value
=
gate_t
.
data
<
T
>
();
gru_value
.
reset_output_value
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
DeviceContext
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
active_node
,
active_gate
);
gru_value
.
prev_out_value
=
gru_value
.
output_value
;
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batch_hidden
->
set_lod
(
batch_gate
->
lod
());
to_seq
(
dev_ctx
,
*
batch_hidden
,
hidden
);
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
BatchCompute
(
context
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
GRUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
paddle/fluid/operators/math/blas.h
浏览文件 @
ff92b6ba
...
...
@@ -90,6 +90,25 @@ class Blas {
void
GEMM
(
bool
transA
,
bool
transB
,
int
M
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
int
lda
,
const
T
*
B
,
int
ldb
,
T
beta
,
T
*
C
,
int
ldc
)
const
;
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
T
*
GEMM_ALLOC
(
const
CBLAS_IDENTIFIER
id
,
const
int
M
,
const
int
N
,
const
int
K
)
const
;
template
<
typename
T
>
void
GEMM_PACK
(
const
CBLAS_IDENTIFIER
id
,
const
CBLAS_TRANSPOSE
trans
,
int
M
,
int
N
,
int
K
,
const
T
alpha
,
const
T
*
src
,
const
int
ld
,
T
*
dst
)
const
;
template
<
typename
T
>
void
GEMM_COMPUTE
(
int
transA
,
int
transB
,
int
M
,
int
N
,
int
K
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
T
beta
,
T
*
C
,
const
int
ldc
)
const
;
template
<
typename
T
>
void
GEMM_FREE
(
T
*
data
)
const
;
#endif
template
<
typename
T
>
void
MatMul
(
const
framework
::
Tensor
&
mat_a
,
bool
trans_a
,
const
framework
::
Tensor
&
mat_b
,
bool
trans_b
,
T
alpha
,
...
...
@@ -146,6 +165,28 @@ class BlasT : private Blas<DeviceContext> {
Base
()
->
template
GEMM
<
T
>(
args
...);
}
#ifdef PADDLE_WITH_MKLML
template
<
typename
...
ARGS
>
T
*
GEMM_ALLOC
(
ARGS
...
args
)
const
{
return
Base
()
->
template
GEMM_ALLOC
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
GEMM_PACK
(
ARGS
...
args
)
const
{
Base
()
->
template
GEMM_PACK
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
GEMM_COMPUTE
(
ARGS
...
args
)
const
{
Base
()
->
template
GEMM_COMPUTE
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
GEMM_FREE
(
ARGS
...
args
)
const
{
Base
()
->
template
GEMM_FREE
<
T
>(
args
...);
}
#endif
template
<
typename
...
ARGS
>
void
MatMul
(
ARGS
...
args
)
const
{
Base
()
->
template
MatMul
<
T
>(
args
...);
...
...
paddle/fluid/operators/math/blas_impl.h
浏览文件 @
ff92b6ba
...
...
@@ -31,6 +31,26 @@ struct CBlas<float> {
platform
::
dynload
::
cblas_sgemm
(
args
...);
}
template
<
typename
...
ARGS
>
static
float
*
GEMM_ALLOC
(
ARGS
...
args
)
{
return
platform
::
dynload
::
cblas_sgemm_alloc
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
GEMM_PACK
(
ARGS
...
args
)
{
platform
::
dynload
::
cblas_sgemm_pack
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
GEMM_COMPUTE
(
ARGS
...
args
)
{
platform
::
dynload
::
cblas_sgemm_compute
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
GEMM_FREE
(
ARGS
...
args
)
{
platform
::
dynload
::
cblas_sgemm_free
(
args
...);
}
#ifdef PADDLE_WITH_LIBXSMM
template
<
typename
...
ARGS
>
static
void
SMM_GEMM
(
ARGS
...
args
)
{
...
...
@@ -71,6 +91,26 @@ struct CBlas<double> {
platform
::
dynload
::
cblas_dgemm
(
args
...);
}
template
<
typename
...
ARGS
>
static
double
*
GEMM_ALLOC
(
ARGS
...
args
)
{
return
platform
::
dynload
::
cblas_dgemm_alloc
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
GEMM_PACK
(
ARGS
...
args
)
{
platform
::
dynload
::
cblas_dgemm_pack
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
GEMM_COMPUTE
(
ARGS
...
args
)
{
platform
::
dynload
::
cblas_dgemm_compute
(
args
...);
}
template
<
typename
...
ARGS
>
static
void
GEMM_FREE
(
ARGS
...
args
)
{
platform
::
dynload
::
cblas_dgemm_free
(
args
...);
}
#ifdef PADDLE_WITH_LIBXSMM
template
<
typename
...
ARGS
>
static
void
SMM_GEMM
(
ARGS
...
args
)
{
...
...
@@ -224,6 +264,41 @@ inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
beta
,
C
,
ldc
);
}
#ifdef PADDLE_WITH_MKLML
template
<
>
template
<
typename
T
>
T
*
Blas
<
platform
::
CPUDeviceContext
>::
GEMM_ALLOC
(
const
CBLAS_IDENTIFIER
id
,
const
int
M
,
const
int
N
,
const
int
K
)
const
{
return
CBlas
<
T
>::
GEMM_ALLOC
(
id
,
M
,
N
,
K
);
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM_PACK
(
const
CBLAS_IDENTIFIER
id
,
const
CBLAS_TRANSPOSE
trans
,
int
M
,
int
N
,
int
K
,
const
T
alpha
,
const
T
*
src
,
const
int
ld
,
T
*
dst
)
const
{
CBlas
<
T
>::
GEMM_PACK
(
CblasRowMajor
,
id
,
trans
,
M
,
N
,
K
,
alpha
,
src
,
ld
,
dst
);
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM_COMPUTE
(
int
transA
,
int
transB
,
int
M
,
int
N
,
int
K
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
T
beta
,
T
*
C
,
const
int
ldc
)
const
{
CBlas
<
T
>::
GEMM_COMPUTE
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM_FREE
(
T
*
data
)
const
{
CBlas
<
T
>::
GEMM_FREE
(
data
);
}
#endif
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
CBLAS_TRANSPOSE
transA
,
...
...
paddle/fluid/platform/dynload/mklml.h
浏览文件 @
ff92b6ba
...
...
@@ -60,6 +60,14 @@ extern void* mklml_dso_handle;
__macro(cblas_dgemm_batch); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(cblas_sgemm_alloc); \
__macro(cblas_sgemm_pack); \
__macro(cblas_sgemm_compute); \
__macro(cblas_sgemm_free); \
__macro(cblas_dgemm_alloc); \
__macro(cblas_dgemm_pack); \
__macro(cblas_dgemm_compute); \
__macro(cblas_dgemm_free); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_MKLML_WRAP
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录