Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
4cc57836
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4cc57836
编写于
8月 25, 2017
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable reorder
上级
780c8d96
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
97 addition
and
32 deletion
+97
-32
paddle/gserver/layers/MKLDNNFcLayer.cpp
paddle/gserver/layers/MKLDNNFcLayer.cpp
+10
-29
paddle/math/MKLDNNMatrix.cpp
paddle/math/MKLDNNMatrix.cpp
+57
-0
paddle/math/MKLDNNMatrix.h
paddle/math/MKLDNNMatrix.h
+30
-3
未找到文件。
paddle/gserver/layers/MKLDNNFcLayer.cpp
浏览文件 @
4cc57836
...
...
@@ -61,39 +61,20 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() {
return
;
}
// TODO(TJ): dst format should get from wgtVal_
int
dstFmt
=
PARAM_FORMAT_MKLDNN_OI
;
int
srcFmt
=
weight_
->
getParameterPtr
()
->
getHeaderFormat
();
if
(
srcFmt
==
dstFmt
)
{
return
;
}
// The weight_ is transposed from initial paddle weight
MatrixPtr
paddleWgt
=
Matrix
::
create
(
weight_
->
getW
()
->
getData
(),
iLayerSize_
,
oc_
,
false
,
false
);
// TODO(TJ): remove this print when do not need differ weights
std
::
ostringstream
ostr
;
paddleWgt
->
print
(
ostr
);
VLOG
(
MKLDNN_ALL
)
<<
"Initial Weight from paddle: "
<<
std
::
endl
<<
ostr
.
str
();
// The mkldnn weight is transposed from initial paddle matrix
MatrixPtr
paddleWgtT
;
paddleWgt
->
transpose
(
paddleWgtT
,
true
);
weight_
->
getW
()
->
copyFrom
(
*
paddleWgtT
);
weight_
->
getParameterPtr
()
->
setHeaderFormat
(
dstFmt
);
CHECK
(
wgtVal_
)
<<
"should have been initialized"
;
bool
hasNoSpatial_
=
ih_
==
1
&&
iw_
==
1
;
auto
targetDim
=
wgtVal_
->
getDims
();
auto
srcFmt
=
hasNoSpatial_
?
memory
::
format
::
io
:
memory
::
format
::
ihwo
;
wgtVal_
->
reorderDataFrom
(
wgtVal_
,
srcFmt
,
targetDim
);
hasInitedWgt_
=
true
;
}
void
MKLDNNFcLayer
::
convertWeightsToPaddle
()
{
MatrixPtr
dnnWgt
=
weight_
->
getW
();
MatrixPtr
paddleWgt
;
dnnWgt
->
transpose
(
paddleWgt
,
true
);
// copy paddle weight and override on weight_
MatrixPtr
dnnWgtT
=
Matrix
::
create
(
dnnWgt
->
getData
(),
dnnWgt
->
getWidth
(),
dnnWgt
->
getHeight
(),
false
,
false
);
dnnWgtT
->
copyFrom
(
*
paddleWgt
);
CHECK
(
wgtVal_
)
<<
"should have been initialized"
;
bool
hasNoSpatial_
=
ih_
==
1
&&
iw_
==
1
;
auto
targetDim
=
wgtVal_
->
getDims
();
auto
dstFmt
=
hasNoSpatial_
?
memory
::
format
::
io
:
memory
::
format
::
ihwo
;
wgtVal_
->
reorderDataTo
(
wgtVal_
,
dstFmt
,
targetDim
);
}
void
MKLDNNFcLayer
::
reshape
()
{
...
...
paddle/math/MKLDNNMatrix.cpp
浏览文件 @
4cc57836
...
...
@@ -56,6 +56,63 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
return
create
(
m
,
pd
);
}
void
MKLDNNMatrix
::
reorderDataFrom
(
const
MKLDNNMatrixPtr
&
m
,
memory
::
format
srcFmt
,
memory
::
dims
targetDim
)
{
memory
::
format
dstFmt
=
getFormat
();
if
(
srcFmt
==
dstFmt
)
{
return
;
}
CHECK_EQ
(
getElementCnt
(),
m
->
getElementCnt
())
<<
"size should equal"
;
real
*
srcData
=
getData
();
real
*
dstData
=
m
->
getData
();
reorderOnce
(
srcData
,
dstData
,
srcFmt
,
dstFmt
,
targetDim
);
}
void
MKLDNNMatrix
::
reorderDataTo
(
const
MKLDNNMatrixPtr
&
m
,
memory
::
format
dstFmt
,
memory
::
dims
targetDim
)
{
memory
::
format
srcFmt
=
getFormat
();
if
(
srcFmt
==
dstFmt
)
{
return
;
}
CHECK_EQ
(
getElementCnt
(),
m
->
getElementCnt
())
<<
"size should equal"
;
real
*
srcData
=
getData
();
real
*
dstData
=
m
->
getData
();
reorderOnce
(
srcData
,
dstData
,
srcFmt
,
dstFmt
,
targetDim
);
}
void
MKLDNNMatrix
::
reorderOnce
(
void
*
srcData
,
void
*
dstData
,
memory
::
format
srcFmt
,
memory
::
format
dstFmt
,
memory
::
dims
dm
)
{
CHECK
(
srcData
);
CHECK
(
dstData
);
MatrixPtr
tmpSrc
;
if
(
dstData
==
srcData
)
{
// inplace data
size_t
sz
=
1
;
for
(
size_t
i
=
0
;
i
<
dm
.
size
();
++
i
)
{
sz
*=
dm
[
i
];
}
tmpSrc
=
Matrix
::
create
(
sz
,
1
,
false
,
false
);
tmpSrc
->
copyFrom
((
real
*
)
srcData
,
sz
);
srcData
=
tmpSrc
->
getData
();
}
auto
dtype
=
this
->
getDtype
();
auto
srcMD
=
memory
::
desc
(
dm
,
dtype
,
srcFmt
);
auto
dstMD
=
memory
::
desc
(
dm
,
dtype
,
dstFmt
);
auto
eg
=
this
->
getEngine
();
auto
src
=
memory
(
memory
::
primitive_desc
(
srcMD
,
eg
),
srcData
);
auto
dst
=
memory
(
memory
::
primitive_desc
(
dstMD
,
eg
),
dstData
);
auto
r
=
reorder
(
src
,
dst
);
stream
(
stream
::
kind
::
eager
).
submit
({
r
}).
wait
();
}
void
MKLDNNMatrix
::
downSpatial
()
{
int
fmt
=
getFormat
();
if
(
!
(
fmt
==
memory
::
format
::
nchw
||
fmt
==
memory
::
format
::
oihw
))
{
...
...
paddle/math/MKLDNNMatrix.h
浏览文件 @
4cc57836
...
...
@@ -21,9 +21,6 @@ limitations under the License. */
namespace
paddle
{
static
const
std
::
map
<
mkldnn
::
memory
::
format
,
PARAM_FORMAT
>
PARAM_FOARMAT_MAP
=
{{
mkldnn
::
memory
::
format
::
oi
,
PARAM_FORMAT_MKLDNN_OI
}};
class
MKLDNNMatrix
;
typedef
std
::
shared_ptr
<
MKLDNNMatrix
>
MKLDNNMatrixPtr
;
...
...
@@ -57,6 +54,26 @@ public:
mkldnn
::
memory
::
data_type
dtype
=
mkldnn
::
memory
::
data_type
::
f32
);
public:
/**
* Reorder this MKLDNNMatrix from other format.
* Support inplace reorder
* Pay attention: this function would only reorder the data layout.
* will NOT change this original dim or format info
*/
void
reorderDataFrom
(
const
MKLDNNMatrixPtr
&
m
,
memory
::
format
srcFmt
,
memory
::
dims
targetDim
);
/**
* Reorder this MKLDNNMatrix to other format.
* Support inplace reorder
* Pay attention: this function would only reorder the data layout.
* will NOT change the dst dim or format info
*/
void
reorderDataTo
(
const
MKLDNNMatrixPtr
&
m
,
memory
::
format
dstFmt
,
memory
::
dims
targetDim
);
/**
* Dimensionality reduction.
* Change format "nchw --> nc" or "oihw --> oi" if the h and w are both 1
...
...
@@ -113,6 +130,16 @@ public:
* Get engine.
*/
mkldnn
::
engine
getEngine
()
{
return
getPD
().
get_engine
();
}
protected:
/**
* Do once reorder supported inplace.
*/
void
reorderOnce
(
void
*
srcData
,
void
*
dstData
,
memory
::
format
srcFmt
,
memory
::
format
dstFmt
,
memory
::
dims
dm
);
};
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录