Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
815d8884
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
815d8884
编写于
5月 03, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Clean MatMul
上级
9d7279b9
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
156 addition
and
299 deletion
+156
-299
paddle/fluid/operators/conv_op.h
paddle/fluid/operators/conv_op.h
+6
-8
paddle/fluid/operators/conv_transpose_op.h
paddle/fluid/operators/conv_transpose_op.h
+5
-9
paddle/fluid/operators/lstm_op.h
paddle/fluid/operators/lstm_op.h
+14
-18
paddle/fluid/operators/lstmp_op.h
paddle/fluid/operators/lstmp_op.h
+26
-37
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+14
-16
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+34
-10
paddle/fluid/operators/math/math_function.cc
paddle/fluid/operators/math/math_function.cc
+0
-67
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/math_function.cu
+0
-87
paddle/fluid/operators/math/math_function.h
paddle/fluid/operators/math/math_function.h
+28
-12
paddle/fluid/operators/math/math_function_test.cu
paddle/fluid/operators/math/math_function_test.cu
+17
-20
paddle/fluid/operators/mul_op.h
paddle/fluid/operators/mul_op.h
+7
-7
paddle/fluid/operators/sequence_conv_op.h
paddle/fluid/operators/sequence_conv_op.h
+5
-8
未找到文件。
paddle/fluid/operators/conv_op.h
浏览文件 @
815d8884
...
...
@@ -161,6 +161,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
math
::
Im2ColFunctor
<
math
::
ColFormat
::
kCFO
,
DeviceContext
,
T
>
im2col
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
Tensor
in_batch
=
input
->
Slice
(
i
,
i
+
1
).
Resize
(
input_shape
);
Tensor
out_batch
=
output
->
Slice
(
i
,
i
+
1
).
Resize
(
output_matrix_shape
);
...
...
@@ -186,8 +187,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
// gemm
Tensor
out_slice
=
out_batch
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
Tensor
filter_slice
=
filter
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
filter_slice
,
false
,
col_matrix
,
false
,
T
(
1.0
),
&
out_slice
,
T
(
0.0
));
blas
.
MatMul
(
filter_slice
,
col_matrix
,
&
out_slice
);
}
}
}
...
...
@@ -274,6 +274,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
if
(
input_grad
)
{
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
@@ -303,9 +304,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix
.
ShareDataWith
(
in_grad_slice
);
col_matrix
.
Resize
(
col_matrix_shape
);
}
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
filter_slice
,
true
,
out_grad_slice
,
false
,
T
(
1.0
),
&
col_matrix
,
T
(
0.0
));
blas
.
MatMul
(
filter_slice
,
true
,
out_grad_slice
,
false
,
&
col_matrix
);
if
(
is_expand
&&
data_dim
==
2U
)
{
col2im
(
dev_ctx
,
col
,
dilations
,
strides
,
...
...
@@ -352,9 +351,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// gemm
Tensor
filter_grad_slice
=
filter_grad_
.
Slice
(
g
*
out_step
,
(
g
+
1
)
*
out_step
);
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
out_grad_slice
,
false
,
col_matrix
,
true
,
T
(
1.0
),
&
filter_grad_slice
,
T
(
1.0
));
blas
.
MatMul
(
out_grad_slice
,
false
,
col_matrix
,
true
,
&
filter_grad_slice
);
}
}
}
...
...
paddle/fluid/operators/conv_transpose_op.h
浏览文件 @
815d8884
...
...
@@ -118,6 +118,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
set_zero
(
dev_ctx
,
output
,
static_cast
<
T
>
(
0
));
math
::
Col2ImFunctor
<
math
::
ColFormat
::
kCFO
,
DeviceContext
,
T
>
col2im
;
...
...
@@ -134,9 +135,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// col_matrix = filter * input_batch
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
filter
,
true
,
input_batch
,
false
,
static_cast
<
T
>
(
1.0
),
&
col_matrix
,
static_cast
<
T
>
(
0.0
));
blas
.
MatMul
(
filter
,
true
,
input_batch
,
false
,
&
col_matrix
);
if
(
data_dim
==
2U
)
{
// col2im: col_matrix -> dy
...
...
@@ -213,6 +212,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// im2col + gemm (similar to conv-forward)
// input need to compute gradient
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
if
(
input_grad
||
filter_grad
)
{
Tensor
col
;
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
...
...
@@ -267,9 +267,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
// d, h, w)
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
filter
,
false
,
col_matrix
,
false
,
static_cast
<
T
>
(
1.0
),
&
input_grad_batch
,
static_cast
<
T
>
(
0.0
));
blas
.
MatMul
(
filter
,
false
,
col_matrix
,
false
,
&
input_grad_batch
);
}
if
(
filter_grad
)
{
// input batch
...
...
@@ -279,9 +277,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
// k_h * k_w)
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
in_batch
,
false
,
col_matrix
,
true
,
static_cast
<
T
>
(
1.0
),
&
filter_grad_
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
in_batch
,
false
,
col_matrix
,
true
,
&
filter_grad_
);
}
}
}
...
...
paddle/fluid/operators/lstm_op.h
浏览文件 @
815d8884
...
...
@@ -114,6 +114,7 @@ class LSTMKernel : public framework::OpKernel<T> {
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
device_ctx
);
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
...
...
@@ -129,9 +130,8 @@ class LSTMKernel : public framework::OpKernel<T> {
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
pre_hidden_t
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
pre_hidden_t
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
else
if
(
hidden_t0
)
{
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
...
...
@@ -143,9 +143,8 @@ class LSTMKernel : public framework::OpKernel<T> {
Tensor
ordered_h0
;
ReorderInitState
<
DeviceContext
,
T
>
(
device_ctx
,
*
hidden_t0
,
order
,
&
ordered_h0
,
true
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
ordered_h0
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
ordered_h0
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
lstm_value
.
gate_value
=
gate_t
.
data
<
T
>
();
...
...
@@ -282,6 +281,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
device_ctx
);
for
(
int
n
=
static_cast
<
int
>
(
num_batch
)
-
1
;
n
>=
0
;
n
--
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
...
...
@@ -320,29 +320,25 @@ class LSTMGradKernel : public framework::OpKernel<T> {
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_g
=
batch_hidden_g
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
pre_hidden_g
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
pre_hidden_g
,
static_cast
<
T
>
(
1.0
));
if
(
weight_g
)
{
/* backward weight */
auto
pre_hidden
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
pre_hidden
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
pre_hidden
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
}
}
else
{
if
(
h0
&&
weight_g
)
{
ReorderInitState
<
DeviceContext
,
T
>
(
device_ctx
,
*
h0
,
order
,
&
ordered_h0
,
true
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
ordered_h0
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
ordered_h0
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
}
if
(
h0
&&
h0_g
)
{
ordered_h0_g
.
mutable_data
<
T
>
(
h0_g
->
dims
(),
ctx
.
GetPlace
());
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
ordered_h0_g
,
static_cast
<
T
>
(
0.0
));
blas
.
MatMul
(
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
ordered_h0_g
,
static_cast
<
T
>
(
0.0
));
}
}
}
...
...
paddle/fluid/operators/lstmp_op.h
浏览文件 @
815d8884
...
...
@@ -143,7 +143,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
auto
proj_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"proj_activation"
));
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
device_ctx
);
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
...
...
@@ -160,9 +160,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_proj_t
=
batch_proj
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
pre_proj_t
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
pre_proj_t
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
else
if
(
hidden_t0
)
{
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
...
...
@@ -176,16 +175,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
ordered_proj0
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ReorderInitState
<
DeviceContext
,
T
>
(
device_ctx
,
*
hidden_t0
,
order
,
&
ordered_h0
,
true
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
ordered_h0
,
false
,
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
ordered_proj0
,
static_cast
<
T
>
(
0.0
));
blas
.
MatMul
(
ordered_h0
,
false
,
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
ordered_proj0
,
static_cast
<
T
>
(
0.0
));
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
proj0_dev
=
EigenMatrix
<
T
>::
From
(
*
ordered_proj0
);
ActCompute
(
cell_act
,
place
,
proj0_dev
,
proj0_dev
);
}
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
*
ordered_proj0
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
*
ordered_proj0
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
lstmp_value
.
gate_value
=
gate_t
.
data
<
T
>
();
...
...
@@ -196,9 +193,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
device_ctx
,
lstmp_value
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
lstmp_value
.
prev_state_value
=
lstmp_value
.
state_value
;
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
hidden_t
,
false
,
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
proj_t
,
static_cast
<
T
>
(
0.0
));
blas
.
MatMul
(
hidden_t
,
false
,
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
proj_t
,
static_cast
<
T
>
(
0.0
));
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
proj_t_dev
=
EigenMatrix
<
T
>::
From
(
proj_t
);
ActCompute
(
cell_act
,
place
,
proj_t_dev
,
proj_t_dev
);
...
...
@@ -361,6 +357,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
device_ctx
);
for
(
int
n
=
static_cast
<
int
>
(
num_batch
)
-
1
;
n
>=
0
;
n
--
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
...
...
@@ -375,15 +372,13 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
}
/* hidden state backwarad */
Tensor
out_g
=
batch_hidden_g
.
Slice
(
bstart
,
bend
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
proj_g
,
false
,
*
proj_weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
out_g
,
static_cast
<
T
>
(
0.0
));
blas
.
MatMul
(
proj_g
,
false
,
*
proj_weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
out_g
,
static_cast
<
T
>
(
0.0
));
/* projection weight backward*/
if
(
proj_weight_g
)
{
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
hidden_t
,
true
,
proj_g
,
false
,
static_cast
<
T
>
(
1.0
),
proj_weight_g
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
hidden_t
,
true
,
proj_g
,
false
,
static_cast
<
T
>
(
1.0
),
proj_weight_g
,
static_cast
<
T
>
(
1.0
));
}
Tensor
gate
=
batch_gate
->
Slice
(
bstart
,
bend
);
...
...
@@ -419,24 +414,21 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_proj_g
=
batch_proj_g
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
pre_proj_g
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
pre_proj_g
,
static_cast
<
T
>
(
1.0
));
if
(
weight_g
)
{
/* weight backward*/
auto
pre_proj
=
batch_proj
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
pre_proj
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
pre_proj
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
}
}
else
{
if
(
h0
&&
weight_g
)
{
ReorderInitState
<
DeviceContext
,
T
>
(
device_ctx
,
*
h0
,
order
,
&
ordered_h0
,
true
);
if
(
weight_g
)
{
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
*
ordered_proj0
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
*
ordered_proj0
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
}
}
if
(
h0
&&
(
h0_g
||
proj_weight_g
))
{
...
...
@@ -444,9 +436,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
Tensor
proj0_g
;
proj0_g
.
Resize
({
in_dims
[
0
],
proj_weight
->
dims
()[
1
]});
proj0_g
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
proj0_g
,
static_cast
<
T
>
(
0.0
));
blas
.
MatMul
(
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
proj0_g
,
static_cast
<
T
>
(
0.0
));
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
proj0_dev
=
EigenMatrix
<
T
>::
From
(
*
ordered_proj0
);
auto
proj0_g_dev
=
EigenMatrix
<
T
>::
From
(
proj0_g
);
...
...
@@ -454,14 +445,12 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
proj0_g_dev
);
}
if
(
h0_g
)
{
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
proj0_g
,
false
,
*
proj_weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
ordered_h0_g
,
static_cast
<
T
>
(
0.0
));
blas
.
MatMul
(
proj0_g
,
false
,
*
proj_weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
ordered_h0_g
,
static_cast
<
T
>
(
0.0
));
}
if
(
proj_weight_g
)
{
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
ordered_h0
,
true
,
proj0_g
,
false
,
static_cast
<
T
>
(
1.0
),
proj_weight_g
,
static_cast
<
T
>
(
1.0
));
blas
.
MatMul
(
ordered_h0
,
true
,
proj0_g
,
false
,
static_cast
<
T
>
(
1.0
),
proj_weight_g
,
static_cast
<
T
>
(
1.0
));
}
}
}
...
...
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
815d8884
...
...
@@ -61,12 +61,10 @@ struct CUBlas<platform::float16> {
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
)
const
{
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
const
T
*
B
,
T
beta
,
T
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
...
...
@@ -83,10 +81,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
template
<
>
template
<
>
inline
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
platform
::
float16
alpha
,
const
platform
::
float16
*
A
,
const
platform
::
float16
*
B
,
const
platform
::
float16
beta
,
platform
::
float16
*
C
)
const
{
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
platform
::
float16
alpha
,
const
platform
::
float16
*
A
,
const
platform
::
float16
*
B
,
platform
::
float16
beta
,
platform
::
float16
*
C
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
...
...
@@ -134,14 +132,14 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
const
{
void
Blas
<
platform
::
CUDADeviceContext
>::
GEMM
(
bool
transA
,
bool
transB
,
int
M
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
int
lda
,
const
T
*
B
,
int
ldb
,
T
beta
,
T
*
C
,
int
ldc
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
}
...
...
paddle/fluid/operators/math/blas_impl.h
浏览文件 @
815d8884
...
...
@@ -45,12 +45,10 @@ struct CBlas<platform::float16> {
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
)
const
{
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
const
T
*
B
,
T
beta
,
T
*
C
)
const
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
...
...
@@ -60,15 +58,41 @@ void Blas<platform::CPUDeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
const
{
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
bool
transA
,
bool
transB
,
int
M
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
int
lda
,
const
T
*
B
,
int
ldb
,
T
beta
,
T
*
C
,
int
ldc
)
const
{
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
typename
DeviceContext
>
template
<
typename
T
>
void
Blas
<
DeviceContext
>::
MatMul
(
const
framework
::
Tensor
&
mat_a
,
bool
trans_a
,
const
framework
::
Tensor
&
mat_b
,
bool
trans_b
,
T
alpha
,
framework
::
Tensor
*
mat_out
,
T
beta
)
const
{
auto
dim_a
=
mat_a
.
dims
();
auto
dim_b
=
mat_b
.
dims
();
auto
dim_out
=
mat_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
mat_a
.
place
()
==
mat_b
.
place
()
&&
mat_a
.
place
()
==
mat_out
->
place
(),
"The places of matrices must be same"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
!
trans_a
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
!
trans_a
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
!
trans_b
?
CblasNoTrans
:
CblasTrans
;
this
->
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
mat_a
.
data
<
T
>
(),
mat_b
.
data
<
T
>
(),
beta
,
mat_out
->
data
<
T
>
());
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/math_function.cc
浏览文件 @
815d8884
...
...
@@ -24,73 +24,6 @@ namespace math {
using
float16
=
paddle
::
platform
::
float16
;
template
<
>
void
matmul
<
platform
::
CPUDeviceContext
,
float16
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float16
alpha
,
framework
::
Tensor
*
matrix_out
,
float16
beta
)
{
PADDLE_THROW
(
"float16 matmul not supported on CPU"
);
}
template
<
>
void
matmul
<
platform
::
CPUDeviceContext
,
float
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float
alpha
,
framework
::
Tensor
*
matrix_out
,
float
beta
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
matrix_a
.
place
())
&&
platform
::
is_cpu_place
(
matrix_b
.
place
())
&&
platform
::
is_cpu_place
(
matrix_out
->
place
()),
"Matrix must all be in CPUPlace"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
(
trans_a
==
false
)
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
());
}
template
<
>
void
matmul
<
platform
::
CPUDeviceContext
,
double
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
double
alpha
,
framework
::
Tensor
*
matrix_out
,
double
beta
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
matrix_a
.
place
())
&&
platform
::
is_cpu_place
(
matrix_b
.
place
())
&&
platform
::
is_cpu_place
(
matrix_out
->
place
()),
"Matrix must all be in CPUPlace"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
(
trans_a
==
false
)
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
Blas
<
platform
::
CPUDeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
template
<
>
void
batched_gemm
<
platform
::
CPUDeviceContext
,
float16
>
(
const
platform
::
CPUDeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
...
...
paddle/fluid/operators/math/math_function.cu
浏览文件 @
815d8884
...
...
@@ -25,93 +25,6 @@ namespace math {
using
float16
=
paddle
::
platform
::
float16
;
template
<
>
void
matmul
<
platform
::
CUDADeviceContext
,
float16
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float16
alpha
,
framework
::
Tensor
*
matrix_out
,
float16
beta
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
matrix_a
.
place
())
&&
platform
::
is_gpu_place
(
matrix_b
.
place
())
&&
platform
::
is_gpu_place
(
matrix_out
->
place
()),
"Matrix must all be in CUDAPlace"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
(
trans_a
==
false
)
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
Blas
<
platform
::
CUDADeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float16
>
(),
matrix_b
.
data
<
float16
>
(),
beta
,
matrix_out
->
data
<
float16
>
());
}
template
<
>
void
matmul
<
platform
::
CUDADeviceContext
,
float
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
float
alpha
,
framework
::
Tensor
*
matrix_out
,
float
beta
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
matrix_a
.
place
())
&&
platform
::
is_gpu_place
(
matrix_b
.
place
())
&&
platform
::
is_gpu_place
(
matrix_out
->
place
()),
"Matrix must all be in CUDAPlace"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
(
trans_a
==
false
)
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
Blas
<
platform
::
CUDADeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
float
>
(),
matrix_b
.
data
<
float
>
(),
beta
,
matrix_out
->
data
<
float
>
());
}
template
<
>
void
matmul
<
platform
::
CUDADeviceContext
,
double
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
double
alpha
,
framework
::
Tensor
*
matrix_out
,
double
beta
)
{
auto
dim_a
=
matrix_a
.
dims
();
auto
dim_b
=
matrix_b
.
dims
();
auto
dim_out
=
matrix_out
->
dims
();
PADDLE_ENFORCE
(
dim_a
.
size
()
==
2
&&
dim_b
.
size
()
==
2
&&
dim_out
.
size
()
==
2
,
"The input and output of matmul be matrix"
);
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
matrix_a
.
place
())
&&
platform
::
is_gpu_place
(
matrix_b
.
place
())
&&
platform
::
is_gpu_place
(
matrix_out
->
place
()),
"Matrix must all be in CUDAPlace"
);
int
M
=
dim_out
[
0
];
int
N
=
dim_out
[
1
];
int
K
=
(
trans_a
==
false
)
?
dim_a
[
1
]
:
dim_a
[
0
];
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
(
trans_b
==
false
)
?
CblasNoTrans
:
CblasTrans
;
Blas
<
platform
::
CUDADeviceContext
>
(
context
).
GEMM
(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
matrix_a
.
data
<
double
>
(),
matrix_b
.
data
<
double
>
(),
beta
,
matrix_out
->
data
<
double
>
());
}
template
<
>
void
batched_gemm
<
platform
::
CUDADeviceContext
,
float16
>
(
const
platform
::
CUDADeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
...
...
paddle/fluid/operators/math/math_function.h
浏览文件 @
815d8884
...
...
@@ -64,14 +64,31 @@ class Blas {
explicit
Blas
(
const
DeviceContext
&
context
)
:
context_
(
context
)
{}
template
<
typename
T
>
void
GEMM
(
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
)
const
;
void
GEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
const
T
*
B
,
T
beta
,
T
*
C
)
const
;
template
<
typename
T
>
void
GEMM
(
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
)
const
;
void
GEMM
(
bool
transA
,
bool
transB
,
int
M
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
int
lda
,
const
T
*
B
,
int
ldb
,
T
beta
,
T
*
C
,
int
ldc
)
const
;
template
<
typename
T
>
void
MatMul
(
const
framework
::
Tensor
&
mat_a
,
bool
trans_a
,
const
framework
::
Tensor
&
mat_b
,
bool
trans_b
,
T
alpha
,
framework
::
Tensor
*
mat_out
,
T
beta
)
const
;
template
<
typename
T
>
void
MatMul
(
const
framework
::
Tensor
&
mat_a
,
bool
trans_a
,
const
framework
::
Tensor
&
mat_b
,
bool
trans_b
,
framework
::
Tensor
*
mat_out
)
const
{
MatMul
(
mat_a
,
trans_a
,
mat_b
,
trans_b
,
static_cast
<
T
>
(
1.0
),
mat_out
,
static_cast
<
T
>
(
0.0
));
}
template
<
typename
T
>
void
MatMul
(
const
framework
::
Tensor
&
mat_a
,
const
framework
::
Tensor
&
mat_b
,
framework
::
Tensor
*
mat_out
)
const
{
this
->
template
MatMul
<
T
>(
mat_a
,
false
,
mat_b
,
false
,
mat_out
);
}
private:
const
DeviceContext
&
context_
;
...
...
@@ -86,6 +103,11 @@ class BlasT : private Blas<DeviceContext> {
void
GEMM
(
ARGS
...
args
)
const
{
static_cast
<
const
Blas
<
DeviceContext
>*>
(
this
)
->
template
GEMM
<
T
>(
args
...);
}
template
<
typename
...
ARGS
>
void
MatMul
(
ARGS
...
args
)
const
{
static_cast
<
const
Blas
<
DeviceContext
>*>
(
this
)
->
template
MatMul
<
T
>(
args
...);
}
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -100,12 +122,6 @@ inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
return
BlasT
<
DeviceContext
,
T
>
(
dev_ctx
);
}
// matrix multiply with continuous memory
template
<
typename
DeviceContext
,
typename
T
>
void
matmul
(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
bool
trans_a
,
const
framework
::
Tensor
&
matrix_b
,
bool
trans_b
,
T
alpha
,
framework
::
Tensor
*
matrix_out
,
T
beta
);
// Batched gemm
template
<
typename
DeviceContext
,
typename
T
>
void
batched_gemm
(
const
DeviceContext
&
context
,
const
CBLAS_TRANSPOSE
transA
,
...
...
paddle/fluid/operators/math/math_function_test.cu
浏览文件 @
815d8884
...
...
@@ -23,6 +23,13 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
}
}
template
<
typename
T
>
inline
paddle
::
operators
::
math
::
BlasT
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
GetBlas
(
const
paddle
::
platform
::
CUDADeviceContext
&
context
)
{
return
paddle
::
operators
::
math
::
GetBlas
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
(
context
);
}
TEST
(
math_function
,
notrans_mul_trans_fp32
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input1_gpu
;
...
...
@@ -42,9 +49,8 @@ TEST(math_function, notrans_mul_trans_fp32) {
paddle
::
framework
::
TensorCopySync
(
input1
,
gpu_place
,
&
input2_gpu
);
out_gpu
.
mutable_data
<
float
>
({
2
,
2
},
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
(
context
,
input1_gpu
,
false
,
input2_gpu
,
true
,
1
,
&
out_gpu
,
0
);
GetBlas
<
float
>
(
context
).
MatMul
(
input1_gpu
,
false
,
input2_gpu
,
true
,
1
,
&
out_gpu
,
0
);
paddle
::
framework
::
TensorCopySync
(
out_gpu
,
cpu_place
,
&
out
);
...
...
@@ -81,10 +87,9 @@ TEST(math_function, notrans_mul_trans_fp16) {
out_gpu
.
mutable_data
<
paddle
::
platform
::
float16
>
({
2
,
2
},
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
(
context
,
input1_gpu
,
false
,
input2_gpu
,
true
,
paddle
::
platform
::
float16
(
1
),
&
out_gpu
,
paddle
::
platform
::
float16
(
0
));
GetBlas
<
paddle
::
platform
::
float16
>
(
context
).
MatMul
(
input1_gpu
,
false
,
input2_gpu
,
true
,
paddle
::
platform
::
float16
(
1
),
&
out_gpu
,
paddle
::
platform
::
float16
(
0
));
paddle
::
framework
::
TensorCopySync
(
out_gpu
,
cpu_place
,
&
out
);
...
...
@@ -116,8 +121,8 @@ TEST(math_function, trans_mul_notrans_fp32) {
out_gpu
.
mutable_data
<
float
>
({
3
,
3
},
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
(
context
,
input1_gpu
,
true
,
input2_gpu
,
false
,
1
,
&
out_gpu
,
0
);
GetBlas
<
float
>
(
context
).
MatMul
(
input1_gpu
,
true
,
input2_gpu
,
false
,
1
,
&
out_gpu
,
0
);
paddle
::
framework
::
TensorCopySync
(
out_gpu
,
cpu_place
,
&
out
);
...
...
@@ -159,10 +164,9 @@ TEST(math_function, trans_mul_notrans_fp16) {
out_gpu
.
mutable_data
<
paddle
::
platform
::
float16
>
({
3
,
3
},
gpu_place
);
paddle
::
operators
::
math
::
matmul
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>
(
context
,
input1_gpu
,
true
,
input2_gpu
,
false
,
paddle
::
platform
::
float16
(
1
),
&
out_gpu
,
paddle
::
platform
::
float16
(
0
));
GetBlas
<
paddle
::
platform
::
float16
>
(
context
).
MatMul
(
input1_gpu
,
true
,
input2_gpu
,
false
,
paddle
::
platform
::
float16
(
1
),
&
out_gpu
,
paddle
::
platform
::
float16
(
0
));
paddle
::
framework
::
TensorCopySync
(
out_gpu
,
cpu_place
,
&
out
);
...
...
@@ -179,13 +183,6 @@ TEST(math_function, trans_mul_notrans_fp16) {
EXPECT_EQ
(
static_cast
<
float
>
(
out_ptr
[
8
]),
29
);
}
template
<
typename
T
>
inline
paddle
::
operators
::
math
::
BlasT
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
GetBlas
(
const
paddle
::
platform
::
CUDADeviceContext
&
context
)
{
return
paddle
::
operators
::
math
::
GetBlas
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
(
context
);
}
TEST
(
math_function
,
gemm_notrans_cublas_fp32
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
...
...
paddle/fluid/operators/mul_op.h
浏览文件 @
815d8884
...
...
@@ -46,9 +46,10 @@ class MulKernel : public framework::OpKernel<T> {
if
(
z_dim
.
size
()
!=
2
)
{
z
->
Resize
({
x_matrix
.
dims
()[
0
],
y_matrix
.
dims
()[
1
]});
}
math
::
matmul
<
DeviceContext
,
T
>
(
context
.
template
device_context
<
DeviceContext
>(),
x_matrix
,
false
,
y_matrix
,
false
,
static_cast
<
T
>
(
1
),
z
,
static_cast
<
T
>
(
0
));
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
context
);
blas
.
MatMul
(
x_matrix
,
y_matrix
,
z
);
if
(
z_dim
.
size
()
!=
2
)
{
z
->
Resize
(
z_dim
);
}
...
...
@@ -79,6 +80,7 @@ class MulGradKernel : public framework::OpKernel<T> {
Tensor
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
Tensor
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
if
(
dx
)
{
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
Tensor
dx_matrix
=
dx
->
dims
().
size
()
>
2
...
...
@@ -86,8 +88,7 @@ class MulGradKernel : public framework::OpKernel<T> {
:
*
dx
;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
dout_mat
,
false
,
y_matrix
,
true
,
1
,
&
dx_matrix
,
0
);
blas
.
MatMul
(
dout_mat
,
false
,
y_matrix
,
true
,
&
dx_matrix
);
}
if
(
dy
)
{
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
@@ -95,8 +96,7 @@ class MulGradKernel : public framework::OpKernel<T> {
?
framework
::
ReshapeToMatrix
(
*
dy
,
y_num_col_dims
)
:
*
dy
;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
x_matrix
,
true
,
dout_mat
,
false
,
1
,
&
dy_matrix
,
0
);
blas
.
MatMul
(
x_matrix
,
true
,
dout_mat
,
false
,
&
dy_matrix
);
}
}
};
...
...
paddle/fluid/operators/sequence_conv_op.h
浏览文件 @
815d8884
...
...
@@ -58,17 +58,15 @@ class SequenceConvKernel : public framework::OpKernel<T> {
// Because if padding_trainable is false, padding data should be zeros.
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
set_zero
(
dev_ctx
,
&
col
,
static_cast
<
T
>
(
0
));
math
::
ContextProjectFunctor
<
DeviceContext
,
T
>
seq_project_functor
;
seq_project_functor
(
dev_ctx
,
*
in
,
*
padding_data
,
padding_trainable
,
context_start
,
context_length
,
context_stride
,
up_pad
,
down_pad
,
&
col
);
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
col
,
false
,
filter
,
false
,
static_cast
<
T
>
(
1.0
),
out
,
static_cast
<
T
>
(
0.0
));
blas
.
MatMul
(
col
,
filter
,
out
);
}
};
...
...
@@ -99,6 +97,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
// use col_shape in the im2col calculation
framework
::
DDim
col_shape
=
{
in
->
dims
()[
0
],
sequence_width
*
context_length
};
...
...
@@ -108,8 +107,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
col
.
mutable_data
<
T
>
(
col_shape
,
context
.
GetPlace
());
// Because if padding_trainable is false, padding data should be zeros.
set_zero
(
dev_ctx
,
&
col
,
static_cast
<
T
>
(
0
));
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
*
out_g
,
false
,
*
filter
,
true
,
T
(
1.0
),
&
col
,
T
(
1.0
));
blas
.
MatMul
(
*
out_g
,
false
,
*
filter
,
true
,
&
col
);
}
math
::
ContextProjectFunctor
<
DeviceContext
,
T
>
seq_project_functor
;
math
::
ContextProjectGradFunctor
<
DeviceContext
,
T
>
seq_project_grad_functor
;
...
...
@@ -150,8 +148,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
context_start
,
context_length
,
context_stride
,
up_pad
,
down_pad
,
&
col
);
math
::
matmul
<
DeviceContext
,
T
>
(
dev_ctx
,
col
,
true
,
out_grad
,
false
,
T
(
1.0
),
&
filter_grad
,
T
(
1.0
));
blas
.
MatMul
(
col
,
true
,
out_grad
,
false
,
&
filter_grad
);
}
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录