Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
89cb3a24
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
89cb3a24
编写于
1月 03, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments, refine comment and function name
上级
adf79faa
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
8 addition
and
8 deletion
+8
-8
paddle/gserver/layers/MKLPackedRecurrentLayer.cpp
paddle/gserver/layers/MKLPackedRecurrentLayer.cpp
+2
-2
paddle/gserver/layers/MKLPackedRecurrentLayer.h
paddle/gserver/layers/MKLPackedRecurrentLayer.h
+3
-3
paddle/gserver/layers/MKLPackedWeight.h
paddle/gserver/layers/MKLPackedWeight.h
+3
-3
未找到文件。
paddle/gserver/layers/MKLPackedRecurrentLayer.cpp
浏览文件 @
89cb3a24
...
@@ -59,7 +59,7 @@ void MKLPackedRecurrentLayer::forwardBatch(int batchSize,
...
@@ -59,7 +59,7 @@ void MKLPackedRecurrentLayer::forwardBatch(int batchSize,
MatrixPtr
preBatchValue
=
MatrixPtr
preBatchValue
=
batchValue_
->
getBatchValue
(
n
-
1
,
batchValue
->
getHeight
());
batchValue_
->
getBatchValue
(
n
-
1
,
batchValue
->
getHeight
());
packed_weight_
->
compute
(
batchValue
,
preB
atchValue
);
packed_weight_
->
gemm_compute
(
preBatchValue
,
b
atchValue
);
}
}
Argument
arg
;
Argument
arg
;
arg
.
value
=
batchValue
;
arg
.
value
=
batchValue
;
...
@@ -96,7 +96,7 @@ void MKLPackedRecurrentLayer::backwardBatch(int batchSize,
...
@@ -96,7 +96,7 @@ void MKLPackedRecurrentLayer::backwardBatch(int batchSize,
if
(
n
!=
0
)
{
if
(
n
!=
0
)
{
batchValue
=
batchGrad_
->
getBatchValue
(
n
-
1
,
batchGrad
->
getHeight
());
batchValue
=
batchGrad_
->
getBatchValue
(
n
-
1
,
batchGrad
->
getHeight
());
packed_weightT_
->
compute
(
batchValue
,
batchGrad
);
packed_weightT_
->
gemm_compute
(
batchGrad
,
batchValue
);
}
}
if
(
backwardByBatch
&&
weight_
->
getWGrad
())
{
if
(
backwardByBatch
&&
weight_
->
getWGrad
())
{
...
...
paddle/gserver/layers/MKLPackedRecurrentLayer.h
浏览文件 @
89cb3a24
...
@@ -22,8 +22,8 @@ DECLARE_bool(rnn_use_batch);
...
@@ -22,8 +22,8 @@ DECLARE_bool(rnn_use_batch);
namespace
paddle
{
namespace
paddle
{
/**
/**
* @brief MKLPackedRecurrentLayer is
same with RecurrentLayer but is optimized
* @brief MKLPackedRecurrentLayer is
almost the same with RecurrentLayer
* with MKL cblas packed gemm.
*
but is optimized
with MKL cblas packed gemm.
* More details:
* More details:
* https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/mkl/mkl_packed.md
* https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/mkl/mkl_packed.md
*/
*/
...
@@ -48,7 +48,7 @@ protected:
...
@@ -48,7 +48,7 @@ protected:
const
int
*
starts
)
override
;
const
int
*
starts
)
override
;
protected:
protected:
/// packed_weight_
is
contains same data with
/// packed_weight_ contains same data with
/// RecurrentLayer::weight_ but is packed
/// RecurrentLayer::weight_ but is packed
std
::
unique_ptr
<
MKLPackedWeight
>
packed_weight_
;
std
::
unique_ptr
<
MKLPackedWeight
>
packed_weight_
;
/// packed_weightT_ is the transposition matrix of packed_weight_
/// packed_weightT_ is the transposition matrix of packed_weight_
...
...
paddle/gserver/layers/MKLPackedWeight.h
浏览文件 @
89cb3a24
...
@@ -22,9 +22,9 @@ namespace paddle {
...
@@ -22,9 +22,9 @@ namespace paddle {
class
MKLPackedWeight
{
class
MKLPackedWeight
{
protected:
protected:
/// The point
o
r of weight
/// The point
e
r of weight
real
*
weight_
;
real
*
weight_
;
/// The point
o
r of cblas packed gemm to weight
/// The point
e
r of cblas packed gemm to weight
real
*
packedWeight_
;
real
*
packedWeight_
;
size_t
height_
;
size_t
height_
;
size_t
width_
;
size_t
width_
;
...
@@ -43,7 +43,7 @@ public:
...
@@ -43,7 +43,7 @@ public:
void
pack
()
{
pack_
(
weight_
);
}
void
pack
()
{
pack_
(
weight_
);
}
void
compute
(
MatrixPtr
dst
,
const
MatrixPtr
src
)
{
void
gemm_compute
(
const
MatrixPtr
src
,
MatrixPtr
dst
)
{
cblas_sgemm_compute
(
CblasRowMajor
,
cblas_sgemm_compute
(
CblasRowMajor
,
CblasNoTrans
,
CblasNoTrans
,
CblasPacked
,
CblasPacked
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录