Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a1e7f2d5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a1e7f2d5
编写于
6月 29, 2018
作者:
C
chenweihang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into unsqueeze_op
上级
70729ad6
7b54f168
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
69 addition
and
50 deletion
+69
-50
README.md
README.md
+0
-1
cmake/generic.cmake
cmake/generic.cmake
+14
-0
cmake/inference_lib.cmake
cmake/inference_lib.cmake
+0
-13
paddle/fluid/framework/data_layout_transform.cc
paddle/fluid/framework/data_layout_transform.cc
+2
-2
paddle/fluid/framework/data_layout_transform.h
paddle/fluid/framework/data_layout_transform.h
+0
-6
paddle/fluid/framework/data_transform.cc
paddle/fluid/framework/data_transform.cc
+6
-2
paddle/fluid/framework/tensor_util.cc
paddle/fluid/framework/tensor_util.cc
+2
-15
paddle/fluid/framework/tensor_util.h
paddle/fluid/framework/tensor_util.h
+15
-0
paddle/fluid/operators/batch_norm_mkldnn_op.cc
paddle/fluid/operators/batch_norm_mkldnn_op.cc
+19
-10
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+11
-1
未找到文件。
README.md
浏览文件 @
a1e7f2d5
...
...
@@ -4,7 +4,6 @@
[
![Build Status
](
https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop
)
](https://travis-ci.org/PaddlePaddle/Paddle)
[
![Documentation Status
](
https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat
)
](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html)
[
![Documentation Status
](
https://img.shields.io/badge/中文文档-最新-brightgreen.svg
)
](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html)
[
![Coverage Status
](
https://coveralls.io/repos/github/PaddlePaddle/Paddle/badge.svg?branch=develop
)
](https://coveralls.io/github/PaddlePaddle/Paddle?branch=develop)
[
![Release
](
https://img.shields.io/github/release/PaddlePaddle/Paddle.svg
)
](https://github.com/PaddlePaddle/Paddle/releases)
[
![License
](
https://img.shields.io/badge/license-Apache%202-blue.svg
)
](LICENSE)
...
...
cmake/generic.cmake
浏览文件 @
a1e7f2d5
...
...
@@ -96,6 +96,20 @@ if(NOT APPLE AND NOT ANDROID)
set
(
CMAKE_CXX_LINK_EXECUTABLE
"
${
CMAKE_CXX_LINK_EXECUTABLE
}
-pthread -ldl -lrt"
)
endif
(
NOT APPLE AND NOT ANDROID
)
set_property
(
GLOBAL PROPERTY FLUID_MODULES
""
)
# find all fluid modules is used for paddle fluid static library
# for building inference libs
function
(
find_fluid_modules TARGET_NAME
)
get_filename_component
(
__target_path
${
TARGET_NAME
}
ABSOLUTE
)
string
(
REGEX REPLACE
"^
${
PADDLE_SOURCE_DIR
}
/"
""
__target_path
${
__target_path
}
)
string
(
FIND
"
${
__target_path
}
"
"fluid"
pos
)
if
(
pos GREATER 1
)
get_property
(
fluid_modules GLOBAL PROPERTY FLUID_MODULES
)
set
(
fluid_modules
${
fluid_modules
}
${
TARGET_NAME
}
)
set_property
(
GLOBAL PROPERTY FLUID_MODULES
"
${
fluid_modules
}
"
)
endif
()
endfunction
(
find_fluid_modules
)
function
(
merge_static_libs TARGET_NAME
)
set
(
libs
${
ARGN
}
)
list
(
REMOVE_DUPLICATES libs
)
...
...
cmake/inference_lib.cmake
浏览文件 @
a1e7f2d5
...
...
@@ -12,19 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
set_property
(
GLOBAL PROPERTY FLUID_MODULES
""
)
# find all fluid modules is used for paddle fluid static library
function
(
find_fluid_modules TARGET_NAME
)
get_filename_component
(
__target_path
${
TARGET_NAME
}
ABSOLUTE
)
string
(
REGEX REPLACE
"^
${
PADDLE_SOURCE_DIR
}
/"
""
__target_path
${
__target_path
}
)
string
(
FIND
"
${
__target_path
}
"
"fluid"
pos
)
if
(
pos GREATER 1
)
get_property
(
fluid_modules GLOBAL PROPERTY FLUID_MODULES
)
set
(
fluid_modules
${
fluid_modules
}
${
TARGET_NAME
}
)
set_property
(
GLOBAL PROPERTY FLUID_MODULES
"
${
fluid_modules
}
"
)
endif
()
endfunction
(
find_fluid_modules
)
# make package for paddle fluid shared and static library
function
(
copy TARGET
)
set
(
options
""
)
...
...
paddle/fluid/framework/data_layout_transform.cc
浏览文件 @
a1e7f2d5
...
...
@@ -147,9 +147,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"Input tensor type is not supported: "
,
in
.
type
().
name
());
memory
::
data_type
out_type
=
in_type
;
auto
in_format
=
MKLDNNFormatForSize
(
in_tz
.
size
(),
in
.
format
());
auto
in_format
=
platform
::
MKLDNNFormatForSize
(
in_tz
.
size
(),
in
.
format
());
auto
out_format
=
MKLDNNFormatForSize
(
in_tz
.
size
(),
ToMKLDNNFormat
(
out_layout
));
platform
::
MKLDNNFormatForSize
(
in_tz
.
size
(),
ToMKLDNNFormat
(
out_layout
));
void
*
in_data
=
GetDataFromTensor
(
in
,
in_type
);
...
...
paddle/fluid/framework/data_layout_transform.h
浏览文件 @
a1e7f2d5
...
...
@@ -62,12 +62,6 @@ inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
return
MKLDNNDataType
::
data_undef
;
}
inline
MKLDNNFormat
MKLDNNFormatForSize
(
size_t
dims_size
,
MKLDNNFormat
default_format
)
{
return
(
dims_size
==
1
?
mkldnn
::
memory
::
format
::
x
:
dims_size
==
2
?
mkldnn
::
memory
::
format
::
nc
:
default_format
);
}
#endif
void
TransDataLayoutFromMKLDNN
(
const
OpKernelType
&
kernel_type_for_var
,
...
...
paddle/fluid/framework/data_transform.cc
浏览文件 @
a1e7f2d5
...
...
@@ -18,6 +18,10 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/data_type_transform.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
...
...
@@ -48,8 +52,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur
auto
out_format
=
MKLDNNFormatForSize
(
in
.
dims
().
size
(),
ToMKLDNNFormat
(
lin
));
auto
out_format
=
platform
::
MKLDNNFormatForSize
(
in
.
dims
().
size
(),
ToMKLDNNFormat
(
lin
));
out
.
ShareDataWith
(
input_tensor
);
out
.
set_layout
(
DataLayout
::
kMKLDNN
);
...
...
paddle/fluid/framework/tensor_util.cc
浏览文件 @
a1e7f2d5
...
...
@@ -73,18 +73,12 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
stream
);
}
else
{
// NOTE(zcd): Because TensorCopy is an async operation, when the src_place
// and dst_place are two different GPU, to ensure that the operation can
// be carried out correctly, we should make ctx wait.
// If ctx_place and src_place are the same, we should add ctx.Wait()
// after memory::Copy; if ctx_place and dst_place are the same, we should
// add ctx.Wait() before memory::Copy.
if
(
platform
::
is_same_place
(
ctx_place
,
src_place
))
{
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
stream
);
ctx
.
Wait
();
platform
::
DeviceContextPool
::
Instance
().
Get
(
src
.
place
())
->
Wait
();
}
else
if
(
platform
::
is_same_place
(
ctx_place
,
dst_place
))
{
ctx
.
Wait
();
platform
::
DeviceContextPool
::
Instance
().
Get
(
src
.
place
())
->
Wait
();
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
stream
);
}
else
{
...
...
@@ -97,13 +91,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
Tensor
*
dst
)
{
// NOTE(zcd): If the src.place() and dst_place are two different GPU,
// the copy operation is carried out on the dst_place's stream. This is
// very important, because TensorCopy is an async operator, and in most
// case, once this copy operator returns, dst is to be used in dst_place's
// stream, if this copy operation is carried out on the src_place's stream,
// when dst is used in dst_place's stream the copy operation may be
// not completed.
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
const
platform
::
DeviceContext
*
dev_ctx
;
if
(
platform
::
is_gpu_place
(
dst_place
))
{
...
...
paddle/fluid/framework/tensor_util.h
浏览文件 @
a1e7f2d5
...
...
@@ -23,10 +23,25 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
// NOTE(zcd): Because TensorCopy is an async operation, when the src_place
// and dst_place are two different GPU, to ensure that the operation can
// be carried out correctly, there is a src_ctx wait operation in TensorCopy.
// If ctx_place and src_place are the same, src_ctx.Wait() is added
// after memory::Copy; if ctx_place and dst_place are the same,
// src_ctx.Wait() is added before memory::Copy.
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
);
// NOTE(zcd): If the src.place() and dst_place are two different GPU,
// the copy operation is carried out on the dst_place's stream. This is
// very important, because TensorCopy is an async operator, and in most
// case, once this copy operator returns, dst is to be used in dst_place's
// stream, if this copy operation is carried out on the src_place's stream,
// when dst is used in dst_place's stream the copy operation may be
// not completed.
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
Tensor
*
dst
);
void
TensorCopySync
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
Tensor
*
dst
);
...
...
paddle/fluid/operators/batch_norm_mkldnn_op.cc
浏览文件 @
a1e7f2d5
...
...
@@ -115,9 +115,12 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if
(
fuse_with_relu
)
flags
|=
mkldnn
::
fuse_bn_relu
;
// create mkldnn memory from input x tensor
auto
src_memory
=
memory
({{{
src_tz
},
memory
::
data_type
::
f32
,
x
->
format
()},
mkldnn_engine
},
to_void_cast
(
x_data
));
mkldnn
::
memory
::
format
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
auto
src_memory
=
memory
(
{{{
src_tz
},
memory
::
data_type
::
f32
,
input_format
},
mkldnn_engine
},
to_void_cast
(
x_data
));
// create primitive descriptor for batch norm forward
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
...
...
@@ -251,15 +254,21 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
using
bn_bwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_backward
>
;
// create mkldnn memory from input diff_y tensor
auto
user_diff_dst_memory
=
memory
({{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
diff_y
->
format
()},
mkldnn_engine
},
to_void_cast
(
diff_y_data
));
mkldnn
::
memory
::
format
dst_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
diff_y
->
format
());
auto
user_diff_dst_memory
=
memory
(
{{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
dst_format
},
mkldnn_engine
},
to_void_cast
(
diff_y_data
));
// create mkldnn memory from input x tensor
auto
src_memory
=
memory
({{{
src_tz
},
memory
::
data_type
::
f32
,
x
->
format
()},
mkldnn_engine
},
to_void_cast
(
x_data
));
mkldnn
::
memory
::
format
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
auto
src_memory
=
memory
(
{{{
src_tz
},
memory
::
data_type
::
f32
,
input_format
},
mkldnn_engine
},
to_void_cast
(
x_data
));
// for diff_dst, try to use same format as dst in forward pass
auto
diff_dst_pd
=
batch_norm_fwd_pd
.
get
()
->
dst_primitive_desc
();
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
a1e7f2d5
...
...
@@ -228,7 +228,7 @@ class MKLDNNHandler {
return
dstr
;
};
return
dims2str
(
operand_dims
)
+
suffix
;
}
;
}
protected:
const
MKLDNNDeviceContext
&
dev_ctx_
;
...
...
@@ -237,5 +237,15 @@ class MKLDNNHandler {
bool
is_reusing_
;
};
inline
mkldnn
::
memory
::
format
MKLDNNFormatForSize
(
size_t
dims_size
,
mkldnn
::
memory
::
format
data_format
)
{
if
(
dims_size
==
1
)
{
return
mkldnn
::
memory
::
format
::
x
;
}
else
if
(
dims_size
==
2
)
{
return
mkldnn
::
memory
::
format
::
nc
;
}
return
data_format
;
}
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录