Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2ba256df
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2ba256df
编写于
2月 10, 2019
作者:
X
xuezhong
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into fix_bug_for_lstmp
上级
dff7461e
bec68fa0
变更
40
隐藏空白更改
内联
并排
Showing
40 changed file
with
329 addition
and
203 deletion
+329
-203
CMakeLists.txt
CMakeLists.txt
+6
-0
cmake/configure.cmake
cmake/configure.cmake
+6
-1
cmake/cuda.cmake
cmake/cuda.cmake
+19
-18
cmake/external/glog.cmake
cmake/external/glog.cmake
+3
-1
cmake/external/mkldnn.cmake
cmake/external/mkldnn.cmake
+2
-1
cmake/external/snappy.cmake
cmake/external/snappy.cmake
+7
-1
cmake/flags.cmake
cmake/flags.cmake
+2
-9
cmake/version.cmake
cmake/version.cmake
+17
-2
paddle/fluid/framework/details/inplace_op_pass.cc
paddle/fluid/framework/details/inplace_op_pass.cc
+11
-9
paddle/fluid/framework/details/memory_optimize_pass.cc
paddle/fluid/framework/details/memory_optimize_pass.cc
+17
-11
paddle/fluid/framework/details/memory_optimize_pass.h
paddle/fluid/framework/details/memory_optimize_pass.h
+4
-3
paddle/fluid/framework/inplace_op_inference_test.cc
paddle/fluid/framework/inplace_op_inference_test.cc
+1
-0
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+2
-1
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+5
-1
paddle/fluid/imperative/CMakeLists.txt
paddle/fluid/imperative/CMakeLists.txt
+2
-2
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+2
-1
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
+3
-0
paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
...e/fluid/inference/analysis/passes/memory_optimize_pass.cc
+6
-1
paddle/fluid/memory/allocation/legacy_allocator.cc
paddle/fluid/memory/allocation/legacy_allocator.cc
+15
-15
paddle/fluid/operators/detection/box_coder_op.cc
paddle/fluid/operators/detection/box_coder_op.cc
+6
-14
paddle/fluid/operators/detection/box_coder_op.cu
paddle/fluid/operators/detection/box_coder_op.cu
+2
-8
paddle/fluid/operators/detection/box_coder_op.h
paddle/fluid/operators/detection/box_coder_op.h
+44
-33
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+1
-1
paddle/fluid/operators/ngraph/ngraph_bridge.cc
paddle/fluid/operators/ngraph/ngraph_bridge.cc
+1
-0
paddle/fluid/operators/ngraph/ngraph_ops.h
paddle/fluid/operators/ngraph/ngraph_ops.h
+2
-1
paddle/fluid/operators/ngraph/ops/accuracy_op.h
paddle/fluid/operators/ngraph/ops/accuracy_op.h
+65
-0
paddle/fluid/operators/ngraph/ops/binary_unary_op.h
paddle/fluid/operators/ngraph/ops/binary_unary_op.h
+0
-0
paddle/fluid/operators/ngraph/ops/top_k_op.h
paddle/fluid/operators/ngraph/ops/top_k_op.h
+0
-5
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+4
-4
paddle/fluid/operators/reader/ctr_reader.cc
paddle/fluid/operators/reader/ctr_reader.cc
+2
-2
paddle/fluid/operators/reader/ctr_reader_test.cc
paddle/fluid/operators/reader/ctr_reader_test.cc
+1
-1
paddle/fluid/operators/reduce_ops/CMakeLists.txt
paddle/fluid/operators/reduce_ops/CMakeLists.txt
+5
-1
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+2
-2
paddle/fluid/platform/ngraph_helper.h
paddle/fluid/platform/ngraph_helper.h
+24
-13
paddle/fluid/platform/place.cc
paddle/fluid/platform/place.cc
+0
-6
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+1
-1
python/CMakeLists.txt
python/CMakeLists.txt
+1
-1
python/paddle/fluid/layers/detection.py
python/paddle/fluid/layers/detection.py
+4
-4
python/paddle/fluid/tests/unittests/ngraph/test_accuracy_ngraph_op.py
...e/fluid/tests/unittests/ngraph/test_accuracy_ngraph_op.py
+30
-0
python/paddle/fluid/tests/unittests/test_box_coder_op.py
python/paddle/fluid/tests/unittests/test_box_coder_op.py
+4
-29
未找到文件。
CMakeLists.txt
浏览文件 @
2ba256df
...
...
@@ -25,12 +25,18 @@ message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
message
(
STATUS
"C compiler:
${
CMAKE_C_COMPILER
}
, version: "
"
${
CMAKE_C_COMPILER_ID
}
${
CMAKE_C_COMPILER_VERSION
}
"
)
if
(
WIN32
)
set
(
CMAKE_SUPPRESS_REGENERATION ON
)
set
(
CMAKE_STATIC_LIBRARY_PREFIX lib
)
add_definitions
(
"/DGOOGLE_GLOG_DLL_DECL="
)
set
(
CMAKE_C_FLAGS_DEBUG
"
${
CMAKE_C_FLAGS_DEBUG
}
/bigobj /MTd"
)
set
(
CMAKE_C_FLAGS_RELEASE
"
${
CMAKE_C_FLAGS_RELEASE
}
/bigobj /MT"
)
set
(
CMAKE_CXX_FLAGS_DEBUG
"
${
CMAKE_CXX_FLAGS_DEBUG
}
/bigobj /MTd"
)
set
(
CMAKE_CXX_FLAGS_RELEASE
"
${
CMAKE_CXX_FLAGS_RELEASE
}
/bigobj /MT"
)
add_compile_options
(
/wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838
)
set
(
PADDLE_LINK_FLAGS
"/IGNORE:4006 /IGNORE:4098 /IGNORE:4217 /IGNORE:4221"
)
set
(
CMAKE_STATIC_LINKER_FLAGS
"
${
CMAKE_STATIC_LINKER_FLAGS
}
${
PADDLE_LINK_FLAGS
}
"
)
set
(
CMAKE_SHARED_LINKER_FLAGS
"
${
CMAKE_SHARED_LINKER_FLAGS
}
${
PADDLE_LINK_FLAGS
}
"
)
set
(
CMAKE_EXE_LINKER_FLAGS
"
${
CMAKE_EXE_LINKER_FLAGS
}
${
PADDLE_LINK_FLAGS
}
"
)
endif
(
WIN32
)
find_package
(
CUDA QUIET
)
...
...
cmake/configure.cmake
浏览文件 @
2ba256df
...
...
@@ -152,7 +152,12 @@ endif()
if
(
WITH_MKLML AND MKLML_IOMP_LIB
)
message
(
STATUS
"Enable Intel OpenMP with
${
MKLML_IOMP_LIB
}
"
)
set
(
OPENMP_FLAGS
"-fopenmp"
)
if
(
WIN32
)
# openmp not support well for now on windows
set
(
OPENMP_FLAGS
""
)
else
(
WIN32
)
set
(
OPENMP_FLAGS
"-fopenmp"
)
endif
(
WIN32
)
set
(
CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS
${
OPENMP_FLAGS
}
)
set
(
CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS
${
OPENMP_FLAGS
}
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
${
OPENMP_FLAGS
}
"
)
...
...
cmake/cuda.cmake
浏览文件 @
2ba256df
...
...
@@ -203,25 +203,26 @@ list(APPEND CUDA_NVCC_FLAGS "-w")
list
(
APPEND CUDA_NVCC_FLAGS
"--expt-relaxed-constexpr"
)
if
(
NOT WIN32
)
if
(
CMAKE_BUILD_TYPE STREQUAL
"Debug"
)
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_DEBUG
}
)
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"Release"
)
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_RELEASE
}
)
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"RelWithDebInfo"
)
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_RELWITHDEBINFO
}
)
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"MinSizeRel"
)
# nvcc 9 does not support -Os. Use Release flags instead
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_RELEASE
}
)
endif
()
if
(
CMAKE_BUILD_TYPE STREQUAL
"Debug"
)
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_DEBUG
}
)
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"Release"
)
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_RELEASE
}
)
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"RelWithDebInfo"
)
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_RELWITHDEBINFO
}
)
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"MinSizeRel"
)
# nvcc 9 does not support -Os. Use Release flags instead
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_RELEASE
}
)
endif
()
else
(
NOT WIN32
)
list
(
APPEND CUDA_NVCC_FLAGS
"--compiler-options;/bigobj"
)
if
(
CMAKE_BUILD_TYPE STREQUAL
"Debug"
)
list
(
APPEND CUDA_NVCC_FLAGS
"-g -G"
)
# match the cl's _ITERATOR_DEBUG_LEVEL
list
(
APPEND CUDA_NVCC_FLAGS
"-D_DEBUG"
)
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"Release"
)
list
(
APPEND CUDA_NVCC_FLAGS
"-O3 -DNDEBUG"
)
else
()
list
(
APPEND CUDA_NVCC_FLAGS
"-Xcompiler
\"
/wd 4244 /wd 4267 /wd 4819
\"
"
)
list
(
APPEND CUDA_NVCC_FLAGS
"--compiler-options;/bigobj"
)
if
(
CMAKE_BUILD_TYPE STREQUAL
"Debug"
)
list
(
APPEND CUDA_NVCC_FLAGS
"-g -G"
)
# match the cl's _ITERATOR_DEBUG_LEVEL
list
(
APPEND CUDA_NVCC_FLAGS
"-D_DEBUG"
)
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"Release"
)
list
(
APPEND CUDA_NVCC_FLAGS
"-O3 -DNDEBUG"
)
else
()
message
(
FATAL
"Windows only support Release or Debug build now. Please set visual studio build type to Release/Debug, x64 build."
)
endif
()
endif
(
NOT WIN32
)
...
...
cmake/external/glog.cmake
浏览文件 @
2ba256df
...
...
@@ -20,8 +20,10 @@ SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include dire
IF
(
WIN32
)
SET
(
GLOG_LIBRARIES
"
${
GLOG_INSTALL_DIR
}
/lib/libglog.lib"
CACHE FILEPATH
"glog library."
FORCE
)
SET
(
GLOG_CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
/wd4267 /wd4530"
)
ELSE
(
WIN32
)
SET
(
GLOG_LIBRARIES
"
${
GLOG_INSTALL_DIR
}
/lib/libglog.a"
CACHE FILEPATH
"glog library."
FORCE
)
SET
(
GLOG_CMAKE_CXX_FLAGS
${
CMAKE_CXX_FLAGS
}
)
ENDIF
(
WIN32
)
INCLUDE_DIRECTORIES
(
${
GLOG_INCLUDE_DIR
}
)
...
...
@@ -39,7 +41,7 @@ ExternalProject_Add(
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
-DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-DCMAKE_CXX_FLAGS=
${
GLOG_
CMAKE_CXX_FLAGS
}
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
...
...
cmake/external/mkldnn.cmake
浏览文件 @
2ba256df
...
...
@@ -49,6 +49,8 @@ IF(NOT WIN32)
SET
(
MKLDNN_FLAG
"
${
MKLDNN_FLAG
}
-Wno-unused-result -Wno-unused-value"
)
SET
(
MKLDNN_CFLAG
"
${
CMAKE_C_FLAGS
}
${
MKLDNN_FLAG
}
"
)
SET
(
MKLDNN_CXXFLAG
"
${
CMAKE_CXX_FLAGS
}
${
MKLDNN_FLAG
}
"
)
ELSE
()
SET
(
MKLDNN_CXXFLAG
"
${
CMAKE_CXX_FLAGS
}
/EHsc"
)
ENDIF
(
NOT WIN32
)
ExternalProject_Add
(
...
...
@@ -61,7 +63,6 @@ ExternalProject_Add(
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
CMAKE_ARGS -DCMAKE_C_COMPILER=
${
CMAKE_C_COMPILER
}
CMAKE_ARGS -DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
CMAKE_ARGS -DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
CMAKE_ARGS -DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
CMAKE_ARGS -DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
...
...
cmake/external/snappy.cmake
浏览文件 @
2ba256df
...
...
@@ -20,6 +20,12 @@ set(SNAPPY_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy)
set
(
SNAPPY_INSTALL_DIR
${
THIRD_PARTY_PATH
}
/install/snappy
)
set
(
SNAPPY_INCLUDE_DIR
"
${
SNAPPY_INSTALL_DIR
}
/include"
CACHE PATH
"snappy include directory."
FORCE
)
if
(
WIN32
)
SET
(
SNAPPY_CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
/wd4244 /wd4267"
)
else
()
SET
(
SNAPPY_CMAKE_CXX_FLAGS
${
CMAKE_CXX_FLAGS
}
)
endif
()
ExternalProject_Add
(
extern_snappy
GIT_REPOSITORY
"https://github.com/google/snappy"
...
...
@@ -31,7 +37,7 @@ ExternalProject_Add(
-DCMAKE_C_FLAGS=
${
CMAKE_C_FLAGS
}
-DCMAKE_C_FLAGS_DEBUG=
${
CMAKE_C_FLAGS_DEBUG
}
-DCMAKE_C_FLAGS_RELEASE=
${
CMAKE_C_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS=
${
CMAKE_CXX_FLAGS
}
-DCMAKE_CXX_FLAGS=
${
SNAPPY_
CMAKE_CXX_FLAGS
}
-DCMAKE_CXX_FLAGS_RELEASE=
${
CMAKE_CXX_FLAGS_RELEASE
}
-DCMAKE_CXX_FLAGS_DEBUG=
${
CMAKE_CXX_FLAGS_DEBUG
}
-DCMAKE_INSTALL_PREFIX=
${
SNAPPY_INSTALL_DIR
}
...
...
cmake/flags.cmake
浏览文件 @
2ba256df
...
...
@@ -147,12 +147,6 @@ set(GPU_COMMON_FLAGS
-Wno-error=unused-function
# Warnings in Numpy Header.
-Wno-error=array-bounds
# Warnings in Eigen::array
)
else
(
NOT WIN32
)
set
(
COMMON_FLAGS
"/w"
)
#disable all warnings.
set
(
GPU_COMMON_FLAGS
"/w"
)
#disable all warnings
endif
(
NOT WIN32
)
if
(
APPLE
)
...
...
@@ -193,8 +187,7 @@ safe_set_static_flag()
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO
)
if
(
${
flag_var
}
MATCHES
"/W3"
)
string
(
REGEX REPLACE
"/W3"
"/w"
${
flag_var
}
"
${${
flag_var
}}
"
)
endif
(
${
flag_var
}
MATCHES
"/W3"
)
string
(
REGEX REPLACE
"(^| )/W[0-9]( |$)"
" "
${
flag_var
}
"
${${
flag_var
}}
"
)
set
(
flag_var
"
${
flag_var
}
/w"
)
endforeach
(
flag_var
)
endif
(
WIN32
)
cmake/version.cmake
浏览文件 @
2ba256df
...
...
@@ -31,8 +31,23 @@ while ("${PADDLE_VERSION}" STREQUAL "")
set
(
tmp_version
"
${
GIT_TAG_NAME
}
~1"
)
endif
()
else
()
# otherwise, we always set PADDLE_VERSION to 0.0.0 to represent latest
set
(
PADDLE_VERSION
"0.0.0"
)
execute_process
(
COMMAND
${
GIT_EXECUTABLE
}
describe --exact-match --tags
${
tmp_version
}
WORKING_DIRECTORY
${
PADDLE_SOURCE_DIR
}
OUTPUT_VARIABLE GIT_EXACT_TAG_NAME
RESULT_VARIABLE GIT_EXACT_TAG_RESULT
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE
)
if
(
NOT
${
GIT_EXACT_TAG_NAME
}
)
# Check if current branch is tag branch
if
(
${
GIT_EXACT_TAG_NAME
}
MATCHES
"v
${
TAG_VERSION_REGEX
}
"
)
string
(
REPLACE
"v"
""
PADDLE_VERSION
${
GIT_EXACT_TAG_NAME
}
)
else
()
set
(
PADDLE_VERSION
"0.0.0"
)
endif
()
else
()
# otherwise, we always set PADDLE_VERSION to 0.0.0 to represent latest
set
(
PADDLE_VERSION
"0.0.0"
)
endif
()
endif
()
else
()
set
(
PADDLE_VERSION
"0.0.0"
)
...
...
paddle/fluid/framework/details/inplace_op_pass.cc
浏览文件 @
2ba256df
...
...
@@ -403,18 +403,20 @@ void GraphView::Build(ir::Graph* g) {
// 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer
// pserver can not find each other name.
for
(
auto
&
node
:
g
->
Nodes
())
{
if
(
!
node
->
IsOp
())
continue
;
if
(
node
->
Name
()
==
"send"
)
{
for
(
auto
&
in
:
node
->
inputs
)
{
dup_nodes_
.
emplace
(
in
->
Name
());
}
auto
update_skip_set
=
[
&
](
ir
::
Node
*
node
)
{
for
(
auto
&
in
:
node
->
inputs
)
{
if
(
in
->
IsVar
()
&&
in
->
Var
()
!=
nullptr
)
dup_nodes_
.
emplace
(
in
->
Name
());
}
if
(
node
->
Name
()
==
"recv"
)
{
for
(
auto
&
out
:
node
->
outputs
)
{
for
(
auto
&
out
:
node
->
outputs
)
{
if
(
out
->
IsVar
()
&&
out
->
Var
()
!=
nullptr
)
dup_nodes_
.
emplace
(
out
->
Name
());
}
}
};
for
(
auto
&
node
:
g
->
Nodes
())
{
if
(
!
node
->
IsOp
())
continue
;
if
(
node
->
Name
()
==
"send"
)
update_skip_set
(
node
);
if
(
node
->
Name
()
==
"recv"
)
update_skip_set
(
node
);
if
(
node
->
Name
()
==
"prefetch"
)
update_skip_set
(
node
);
}
}
...
...
paddle/fluid/framework/details/memory_optimize_pass.cc
浏览文件 @
2ba256df
...
...
@@ -51,8 +51,7 @@ static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) {
std
::
unique_ptr
<
ir
::
Graph
>
MemoryOptimizePass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
nodes
=
graph
->
Nodes
();
auto
subblock_vars
=
GetSubBlockVars
(
nodes
);
skip_set_
.
insert
(
subblock_vars
.
begin
(),
subblock_vars
.
end
());
CollectSkipVarsSet
(
nodes
);
cfg_
.
reset
(
new
details
::
ControlFlowGraph
(
*
graph
));
cfg_
->
LiveVariableAnalysis
();
...
...
@@ -224,20 +223,27 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
}
}
std
::
unordered_set
<
std
::
string
>
MemoryOptimizePass
::
GetSubBlockVars
(
void
MemoryOptimizePass
::
CollectSkipVarsSet
(
const
std
::
unordered_set
<
ir
::
Node
*>&
nodes
)
const
{
std
::
unordered_set
<
std
::
string
>
vars
;
auto
update_skip_set
=
[
&
](
OpDesc
*
op_desc
)
{
auto
inputs
=
op_desc
->
InputArgumentNames
();
auto
outputs
=
op_desc
->
OutputArgumentNames
();
skip_set_
.
insert
(
inputs
.
begin
(),
inputs
.
end
());
skip_set_
.
insert
(
outputs
.
begin
(),
outputs
.
end
());
};
for
(
auto
&
op
:
nodes
)
{
if
(
!
op
->
IsOp
()
||
op
->
Op
()
==
nullptr
)
continue
;
auto
*
op_desc
=
op
->
Op
();
if
(
OpHasSubBlock
(
op_desc
))
{
auto
inputs
=
op_desc
->
InputArgumentNames
();
auto
outputs
=
op_desc
->
OutputArgumentNames
();
vars
.
insert
(
inputs
.
begin
(),
inputs
.
end
());
vars
.
insert
(
outputs
.
begin
(),
outputs
.
end
());
}
// NOTE(dzhwinter):
// current block can not reuse next level block vars.
if
(
OpHasSubBlock
(
op_desc
))
update_skip_set
(
op_desc
);
// NOTE(dzhwinter):
// distributed ops input/output name need to
// keep same bettwen trainer/pserver
if
(
op_desc
->
Type
()
==
"send"
)
update_skip_set
(
op_desc
);
if
(
op_desc
->
Type
()
==
"recv"
)
update_skip_set
(
op_desc
);
if
(
op_desc
->
Type
()
==
"prefetch"
)
update_skip_set
(
op_desc
);
}
return
vars
;
}
void
MemoryOptimizePass
::
RenameVarInGraphDesc
(
const
std
::
string
&
var
,
...
...
paddle/fluid/framework/details/memory_optimize_pass.h
浏览文件 @
2ba256df
...
...
@@ -55,9 +55,10 @@ class MemoryOptimizePass : public ir::Pass {
ir
::
Graph
*
graph
)
const
;
void
SubGraphOptimize
(
OpDesc
*
op_desc
)
const
;
// scan subblock and collect the output/input variables.
std
::
unordered_set
<
std
::
string
>
GetSubBlockVars
(
const
std
::
unordered_set
<
ir
::
Node
*>&
)
const
;
// 1. scan op with subblock and collect the output/input vars.
// while, while_grad, conditional_block
// 2. scan distributed ops and collect the output/input vars
void
CollectSkipVarsSet
(
const
std
::
unordered_set
<
ir
::
Node
*>&
)
const
;
private:
// Reuse Node Pool, Owned.
...
...
paddle/fluid/framework/inplace_op_inference_test.cc
浏览文件 @
2ba256df
...
...
@@ -276,6 +276,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
auto
&
infer_inplace
=
OpInfoMap
::
Instance
().
Get
(
op
->
Type
()).
infer_inplace_
;
auto
in_to_outs
=
infer_inplace
(
*
op
,
op
->
Block
());
EXPECT_EQ
(
in_to_outs
.
size
(),
3ul
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
expects
=
{
{
"o0"
,
"a0"
},
{
"y0"
,
"b0"
},
{
"z0"
,
"c0"
},
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
2ba256df
...
...
@@ -141,7 +141,8 @@ class Graph {
ir
::
Node
*
CreateControlDepVar
()
{
// TODO(panyx0718): control var name should be really unique.
const
std
::
string
name
=
string
::
Sprintf
(
"%s@%llu"
,
ir
::
Node
::
kControlDepVarName
,
node_set_
.
size
());
"%s@%llu"
,
static_cast
<
const
char
*>
(
ir
::
Node
::
kControlDepVarName
),
node_set_
.
size
());
auto
*
x
=
AddNode
(
new
ir
::
Node
(
name
,
ir
::
Node
::
Type
::
kVariable
));
x
->
SetId
(
num_node_created_
++
);
return
x
;
...
...
paddle/fluid/framework/scope.cc
浏览文件 @
2ba256df
...
...
@@ -22,7 +22,11 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/string/printf.h"
DECLARE_bool
(
benchmark
);
DEFINE_bool
(
benchmark
,
false
,
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode."
);
DEFINE_bool
(
eager_delete_scope
,
true
,
...
...
paddle/fluid/imperative/CMakeLists.txt
浏览文件 @
2ba256df
if
(
WITH_PYTHON
)
cc_library
(
layer SRCS layer.cc DEPS proto_desc operator device_context blas
)
cc_library
(
tracer SRCS tracer.cc DEPS proto_desc device_context
)
cc_library
(
layer SRCS layer.cc DEPS proto_desc operator device_context blas
pybind
)
cc_library
(
tracer SRCS tracer.cc DEPS proto_desc device_context
pybind
)
cc_library
(
engine SRCS engine.cc
)
endif
()
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
2ba256df
...
...
@@ -58,12 +58,13 @@ if(WIN32)
sep_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
DEPS
${
fluid_modules
}
paddle_fluid_api reset_tensor_array
analysis_config paddle_pass_builder
)
target_link_libraries
(
paddle_fluid_shared shlwapi
)
else
(
WIN32
)
cc_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
DEPS
${
fluid_modules
}
paddle_fluid_api reset_tensor_array
analysis_config paddle_pass_builder
)
endif
()
get_property
(
os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES
)
target_link_libraries
(
paddle_fluid_shared
${
os_dependency_modules
}
)
set_target_properties
(
paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid
)
if
(
NOT APPLE AND NOT WIN32
)
...
...
paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
浏览文件 @
2ba256df
cc_library
(
subgraph_detector SRCS subgraph_detector.cc DEPS proto_desc
)
if
(
WITH_TESTING
)
add_dependencies
(
subgraph_detector gtest
)
endif
()
if
(
WITH_GPU AND TENSORRT_FOUND
)
cc_library
(
tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_detector tensorrt_op_teller
)
...
...
paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
浏览文件 @
2ba256df
...
...
@@ -18,6 +18,7 @@
#include <limits>
#include <map>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
...
...
@@ -168,7 +169,11 @@ bool FindSuitableTensorToReuse(
if
(
!
cluster
->
count
(
candidate
))
continue
;
size_t
space
=
space_table
.
at
(
candidate
);
size_t
space_diff
=
std
::
abs
<
size_t
>
(
space
-
space_required
);
PADDLE_ENFORCE
(
space
<=
std
::
numeric_limits
<
std
::
make_signed
<
size_t
>::
type
>::
max
(),
"space overload"
);
size_t
space_diff
=
std
::
abs
((
std
::
make_signed
<
size_t
>::
type
)
space
-
space_required
);
if
(
space_diff
<
best_fit
.
second
)
{
best_fit
.
first
=
candidate
;
best_fit
.
second
=
space_diff
;
...
...
paddle/fluid/memory/allocation/legacy_allocator.cc
浏览文件 @
2ba256df
...
...
@@ -35,7 +35,6 @@ DEFINE_bool(init_allocated_mem, false,
"To find this error in time, we use init_allocated_mem to indicate "
"that initializing the allocated memory with a small value "
"during unit testing."
);
DECLARE_bool
(
benchmark
);
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
namespace
paddle
{
...
...
@@ -188,21 +187,20 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
platform
::
SetDeviceId
(
place
.
device
);
size_t
avail
,
total
;
platform
::
GpuMemoryUsage
(
&
avail
,
&
total
);
LOG
(
WARNING
)
<<
"Cannot allocate "
<<
string
::
HumanReadableSize
(
size
)
<<
" in GPU "
<<
place
.
device
<<
", available "
<<
string
::
HumanReadableSize
(
avail
);
LOG
(
WARNING
)
<<
"total "
<<
total
;
LOG
(
WARNING
)
<<
"GpuMinChunkSize "
<<
string
::
HumanReadableSize
(
buddy_allocator
->
GetMinChunkSize
());
LOG
(
WARNING
)
<<
"GpuMaxChunkSize "
<<
string
::
HumanReadableSize
(
buddy_allocator
->
GetMaxChunkSize
());
LOG
(
WARNING
)
<<
"GPU memory used: "
<<
string
::
HumanReadableSize
(
Used
<
platform
::
CUDAPlace
>
(
place
));
LOG
(
FATAL
)
<<
"Cannot allocate "
<<
string
::
HumanReadableSize
(
size
)
<<
" in GPU "
<<
place
.
device
<<
", available "
<<
string
::
HumanReadableSize
(
avail
)
<<
"total "
<<
total
<<
"GpuMinChunkSize "
<<
string
::
HumanReadableSize
(
buddy_allocator
->
GetMinChunkSize
())
<<
"GpuMaxChunkSize "
<<
string
::
HumanReadableSize
(
buddy_allocator
->
GetMaxChunkSize
())
<<
"GPU memory used: "
<<
string
::
HumanReadableSize
(
Used
<
platform
::
CUDAPlace
>
(
place
));
platform
::
SetDeviceId
(
cur_dev
);
}
else
{
if
(
FLAGS_benchmark
)
allocation
::
GPUMemMonitor
.
Add
(
place
.
device
,
size
);
if
(
VLOG_IS_ON
(
3
))
{
allocation
::
GPUMemMonitor
.
Add
(
place
.
device
,
size
);
}
if
(
FLAGS_init_allocated_mem
)
{
cudaMemset
(
ptr
,
0xEF
,
size
);
}
...
...
@@ -218,7 +216,9 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p,
size_t
size
)
{
#ifdef PADDLE_WITH_CUDA
GetGPUBuddyAllocator
(
place
.
device
)
->
Free
(
p
);
if
(
FLAGS_benchmark
)
allocation
::
GPUMemMonitor
.
Minus
(
place
.
device
,
size
);
if
(
VLOG_IS_ON
(
3
))
{
allocation
::
GPUMemMonitor
.
Minus
(
place
.
device
,
size
);
}
#else
PADDLE_THROW
(
"'CUDAPlace' is not supported in CPU only device."
);
#endif
...
...
paddle/fluid/operators/detection/box_coder_op.cc
浏览文件 @
2ba256df
...
...
@@ -38,20 +38,12 @@ class BoxCoderOp : public framework::OperatorWithKernel {
"The shape of PriorBox is [N, 4]"
);
if
(
ctx
->
HasInput
(
"PriorBoxVar"
))
{
auto
prior_box_var_dims
=
ctx
->
GetInputDim
(
"PriorBoxVar"
);
PADDLE_ENFORCE
(
prior_box_var_dims
.
size
()
==
1
||
prior_box_var_dims
.
size
()
==
2
,
"Input(PriorBoxVar) of BoxCoderOp should be 1 or 2."
);
if
(
prior_box_var_dims
.
size
()
==
1
)
{
PADDLE_ENFORCE_EQ
(
prior_box_var_dims
[
0
],
4
,
"The 1st dimension of Input(PriorBoxVar) should be 4"
"when the rank is 1."
);
}
else
{
PADDLE_ENFORCE_EQ
(
prior_box_dims
,
prior_box_var_dims
,
"The dimension of Input(PriorBoxVar) should be equal to"
"the dimension of Input(PriorBox when the rank is 2.)"
);
}
PADDLE_ENFORCE
(
prior_box_var_dims
.
size
()
==
2
,
"Input(PriorBoxVar) of BoxCoderOp should be 2."
);
PADDLE_ENFORCE_EQ
(
prior_box_dims
,
prior_box_var_dims
,
"The dimension of Input(PriorBoxVar) should be equal to"
"the dimension of Input(PriorBox) when the rank is 2."
);
}
}
...
...
paddle/fluid/operators/detection/box_coder_op.cu
浏览文件 @
2ba256df
...
...
@@ -56,10 +56,7 @@ __global__ void EncodeCenterSizeKernel(
output
[
idx
*
len
+
2
]
=
log
(
fabs
(
target_box_width
/
prior_box_width
));
output
[
idx
*
len
+
3
]
=
log
(
fabs
(
target_box_height
/
prior_box_height
));
if
(
prior_box_var_data
)
{
int
prior_var_offset
=
0
;
if
(
prior_box_var_size
==
2
)
{
prior_var_offset
=
col_idx
*
len
;
}
int
prior_var_offset
=
col_idx
*
len
;
output
[
idx
*
len
]
/=
prior_box_var_data
[
prior_var_offset
];
output
[
idx
*
len
+
1
]
/=
prior_box_var_data
[
prior_var_offset
+
1
];
output
[
idx
*
len
+
2
]
/=
prior_box_var_data
[
prior_var_offset
+
2
];
...
...
@@ -99,10 +96,7 @@ __global__ void DecodeCenterSizeKernel(
T
box_var_x
=
T
(
1
),
box_var_y
=
T
(
1
);
T
box_var_w
=
T
(
1
),
box_var_h
=
T
(
1
);
if
(
prior_box_var_data
)
{
int
prior_var_offset
=
0
;
if
(
prior_box_var_size
==
2
)
{
prior_var_offset
=
axis
==
0
?
col_idx
*
len
:
row_idx
*
len
;
}
int
prior_var_offset
=
axis
==
0
?
col_idx
*
len
:
row_idx
*
len
;
box_var_x
=
prior_box_var_data
[
prior_var_offset
];
box_var_y
=
prior_box_var_data
[
prior_var_offset
+
1
];
box_var_w
=
prior_box_var_data
[
prior_var_offset
+
2
];
...
...
paddle/fluid/operators/detection/box_coder_op.h
浏览文件 @
2ba256df
...
...
@@ -79,10 +79,7 @@ class BoxCoderKernel : public framework::OpKernel<T> {
output
[
offset
+
3
]
=
std
::
log
(
std
::
fabs
(
target_box_height
/
prior_box_height
));
if
(
prior_box_var
)
{
int
prior_var_offset
=
0
;
if
(
prior_box_var
->
dims
().
size
()
==
2
)
{
prior_var_offset
=
j
*
len
;
}
int
prior_var_offset
=
j
*
len
;
output
[
offset
]
/=
prior_box_var_data
[
prior_var_offset
];
output
[
offset
+
1
]
/=
prior_box_var_data
[
prior_var_offset
+
1
];
output
[
offset
+
2
]
/=
prior_box_var_data
[
prior_var_offset
+
2
];
...
...
@@ -95,11 +92,12 @@ class BoxCoderKernel : public framework::OpKernel<T> {
}
}
}
template
<
int
axis
,
int
var_size
>
void
DecodeCenterSize
(
const
framework
::
Tensor
*
target_box
,
const
framework
::
Tensor
*
prior_box
,
const
framework
::
Tensor
*
prior_box_var
,
const
bool
normalized
,
const
int
axis
,
const
std
::
vector
<
float
>
variance
,
T
*
output
)
const
{
const
bool
normalized
,
std
::
vector
<
float
>
variance
,
T
*
output
)
const
{
int64_t
row
=
target_box
->
dims
()[
0
];
int64_t
col
=
target_box
->
dims
()[
1
];
int64_t
len
=
target_box
->
dims
()[
2
];
...
...
@@ -107,19 +105,17 @@ class BoxCoderKernel : public framework::OpKernel<T> {
auto
*
target_box_data
=
target_box
->
data
<
T
>
();
auto
*
prior_box_data
=
prior_box
->
data
<
T
>
();
const
T
*
prior_box_var_data
=
nullptr
;
if
(
prior_box_var
)
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
if
(
var_size
==
2
)
prior_box_var_data
=
prior_box_var
->
data
<
T
>
();
int
prior_box_offset
=
0
;
T
var_data
[
4
]
=
{
1.
,
1.
,
1.
,
1.
};
T
*
var_ptr
=
var_data
;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
for
(
int64_t
i
=
0
;
i
<
row
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
col
;
++
j
)
{
size_t
offset
=
i
*
col
*
len
+
j
*
len
;
if
(
axis
==
0
)
{
prior_box_offset
=
j
*
len
;
}
else
if
(
axis
==
1
)
{
prior_box_offset
=
i
*
len
;
}
prior_box_offset
=
axis
==
0
?
j
*
len
:
i
*
len
;
T
prior_box_width
=
prior_box_data
[
prior_box_offset
+
2
]
-
prior_box_data
[
prior_box_offset
]
+
(
normalized
==
false
);
...
...
@@ -133,26 +129,18 @@ class BoxCoderKernel : public framework::OpKernel<T> {
T
target_box_center_x
=
0
,
target_box_center_y
=
0
;
T
target_box_width
=
0
,
target_box_height
=
0
;
T
box_var_x
=
T
(
1
),
box_var_y
=
T
(
1
);
T
box_var_w
=
T
(
1
),
box_var_h
=
T
(
1
);
if
(
prior_box_var
)
{
int
prior_var_offset
=
0
;
if
(
prior_box_var
->
dims
().
size
()
==
2
)
{
if
(
axis
==
0
)
prior_var_offset
=
j
*
len
;
else
if
(
axis
==
1
)
prior_var_offset
=
i
*
len
;
}
box_var_x
=
prior_box_var_data
[
prior_var_offset
];
box_var_y
=
prior_box_var_data
[
prior_var_offset
+
1
];
box_var_w
=
prior_box_var_data
[
prior_var_offset
+
2
];
box_var_h
=
prior_box_var_data
[
prior_var_offset
+
3
];
}
else
if
(
!
(
variance
.
empty
()))
{
box_var_x
=
static_cast
<
T
>
(
variance
[
0
]);
box_var_y
=
static_cast
<
T
>
(
variance
[
1
]);
box_var_w
=
static_cast
<
T
>
(
variance
[
2
]);
box_var_h
=
static_cast
<
T
>
(
variance
[
3
]);
int
prior_var_offset
=
axis
==
0
?
j
*
len
:
i
*
len
;
if
(
var_size
==
2
)
{
std
::
memcpy
(
var_ptr
,
prior_box_var_data
+
prior_var_offset
,
4
*
sizeof
(
T
));
}
else
if
(
var_size
==
1
)
{
var_ptr
=
reinterpret_cast
<
T
*>
(
variance
.
data
());
}
T
box_var_x
=
*
var_ptr
;
T
box_var_y
=
*
(
var_ptr
+
1
);
T
box_var_w
=
*
(
var_ptr
+
2
);
T
box_var_h
=
*
(
var_ptr
+
3
);
target_box_center_x
=
box_var_x
*
target_box_data
[
offset
]
*
prior_box_width
+
prior_box_center_x
;
...
...
@@ -211,8 +199,31 @@ class BoxCoderKernel : public framework::OpKernel<T> {
EncodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
else
if
(
code_type
==
BoxCodeType
::
kDecodeCenterSize
)
{
DecodeCenterSize
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
axis
,
variance
,
output
);
if
(
prior_box_var
)
{
if
(
axis
==
0
)
{
DecodeCenterSize
<
0
,
2
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
else
{
DecodeCenterSize
<
1
,
2
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
}
else
if
(
!
(
variance
.
empty
()))
{
if
(
axis
==
0
)
{
DecodeCenterSize
<
0
,
1
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
else
{
DecodeCenterSize
<
1
,
1
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
}
else
{
if
(
axis
==
0
)
{
DecodeCenterSize
<
0
,
0
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
else
{
DecodeCenterSize
<
1
,
0
>
(
target_box
,
prior_box
,
prior_box_var
,
normalized
,
variance
,
output
);
}
}
}
}
};
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
2ba256df
...
...
@@ -37,7 +37,7 @@ math_library(concat_and_split)
math_library
(
context_project DEPS im2col math_function
)
math_library
(
cross_entropy
)
math_library
(
cos_sim_functor
)
math_library
(
depthwise_conv
)
math_library
(
depthwise_conv
DEPS cub
)
math_library
(
im2col
)
math_library
(
sampler
)
...
...
paddle/fluid/operators/ngraph/ngraph_bridge.cc
浏览文件 @
2ba256df
...
...
@@ -31,6 +31,7 @@ std::map<std::string,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
)
>>
NgraphBridge
::
NG_NODE_MAP
=
{
{
"accuracy"
,
NG_OPS
::
BuildAccuracyNode
},
{
"conv2d"
,
NG_OPS
::
BuildConv2dNode
},
{
"conv2d_grad"
,
NG_OPS
::
BuildConv2dGradNode
},
{
"elementwise_add"
,
NG_OPS
::
BuildElementwiseAddNode
},
...
...
paddle/fluid/operators/ngraph/ngraph_ops.h
浏览文件 @
2ba256df
...
...
@@ -21,7 +21,8 @@ limitations under the License. */
#pragma once
#include "ops/binary_unnary_op.h"
#include "ops/accuracy_op.h"
#include "ops/binary_unary_op.h"
#include "ops/conv2d_op.h"
#include "ops/elementwise_add_op.h"
#include "ops/fill_constant_op.h"
...
...
paddle/fluid/operators/ngraph/ops/accuracy_op.h
0 → 100644
浏览文件 @
2ba256df
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
ngraphs
{
void
BuildAccuracyNode
(
const
std
::
shared_ptr
<
framework
::
OperatorBase
>&
op
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
auto
indices
=
platform
::
GetInputNode
(
op
,
"Indices"
,
ngb_node_map
);
auto
label
=
platform
::
GetInputNode
(
op
,
"Label"
,
ngb_node_map
);
auto
inference
=
platform
::
GetInputNode
(
op
,
"Out"
,
ngb_node_map
);
auto
inference_shape
=
inference
->
get_shape
();
size_t
num_samples
=
inference_shape
.
at
(
0
);
size_t
k
=
inference_shape
.
at
(
1
);
std
::
shared_ptr
<
ngraph
::
Node
>
label_k
=
label
;
if
(
k
>
1
)
{
auto
label_1d
=
std
::
make_shared
<
ngraph
::
op
::
Reshape
>
(
label
,
ngraph
::
AxisVector
{
0
,
1
},
ngraph
::
Shape
{
num_samples
});
label_k
=
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
label_1d
,
inference_shape
,
ngraph
::
AxisSet
{
1
});
}
auto
node_equal
=
std
::
make_shared
<
ngraph
::
op
::
Equal
>
(
indices
,
label_k
);
auto
node_eq_int
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
node_equal
,
ngraph
::
element
::
i64
);
auto
num_correct_0d
=
std
::
make_shared
<
ngraph
::
op
::
Sum
>
(
node_eq_int
,
ngraph
::
AxisSet
{
0
,
1
});
std
::
shared_ptr
<
ngraph
::
Node
>
num_correct
=
platform
::
NgReshaper
(
num_correct_0d
,
ngraph
::
Shape
{
1
});
std
::
shared_ptr
<
ngraph
::
Node
>
n_samples
=
ngraph
::
op
::
Constant
::
create
(
ngraph
::
element
::
i64
,
ngraph
::
Shape
{
1
},
{
num_samples
});
std
::
shared_ptr
<
ngraph
::
Node
>
accuracy
=
std
::
make_shared
<
ngraph
::
op
::
Divide
>
(
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
num_correct
,
ngraph
::
element
::
f32
),
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
n_samples
,
ngraph
::
element
::
f32
));
platform
::
SetOutputNode
(
op
,
"Accuracy"
,
accuracy
,
ngb_node_map
);
platform
::
SetOutputNode
(
op
,
"Correct"
,
num_correct
,
ngb_node_map
);
platform
::
SetOutputNode
(
op
,
"Total"
,
n_samples
,
ngb_node_map
);
}
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/ngraph/ops/binary_un
n
ary_op.h
→
paddle/fluid/operators/ngraph/ops/binary_unary_op.h
浏览文件 @
2ba256df
文件已移动
paddle/fluid/operators/ngraph/ops/top_k_op.h
浏览文件 @
2ba256df
...
...
@@ -36,11 +36,6 @@ void BuildTopKNode(
std
::
make_shared
<
ngraph
::
op
::
GetOutputElement
>
(
top_k
,
0
);
std
::
shared_ptr
<
ngraph
::
Node
>
out
=
std
::
make_shared
<
ngraph
::
op
::
GetOutputElement
>
(
top_k
,
1
);
auto
dummy_out
=
paddle
::
platform
::
GetOutputNode
(
op
,
"Out"
,
ngb_node_map
);
if
(
dummy_out
&&
dummy_out
->
get_element_type
()
!=
out
->
get_element_type
())
{
out
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
out
,
dummy_out
->
get_element_type
());
}
paddle
::
platform
::
SetOutputNode
(
op
,
"Indices"
,
indices
,
ngb_node_map
);
paddle
::
platform
::
SetOutputNode
(
op
,
"Out"
,
out
,
ngb_node_map
);
}
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
2ba256df
...
...
@@ -259,7 +259,7 @@ Example:
W_{out} = \\frac{(W_{in} - ksize[1] + 2 * paddings[1] + strides[1] - 1)}{strides[1]} + 1
$$
For exclusive =
tru
e:
For exclusive =
fals
e:
$$
hstart = i * strides[0] - paddings[0]
hend = hstart + ksize[0]
...
...
@@ -267,7 +267,7 @@ Example:
wend = wstart + ksize[1]
Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{ksize[0] * ksize[1]}
$$
For exclusive =
fals
e:
For exclusive =
tru
e:
$$
hstart = max(0, i * strides[0] - paddings[0])
hend = min(H, hstart + ksize[0])
...
...
@@ -403,7 +403,7 @@ Example:
H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1] + strides[1] -1)}{strides[1]} + 1 \\
W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2] + strides[2] -1)}{strides[2]} + 1
$$
For exclusive =
tru
e:
For exclusive =
fals
e:
$$
dstart = i * strides[0] - paddings[0]
dend = dstart + ksize[0]
...
...
@@ -413,7 +413,7 @@ Example:
wend = wstart + ksize[2]
Output(i ,j, k) = \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{ksize[0] * ksize[1] * ksize[2]}
$$
For exclusive =
fals
e:
For exclusive =
tru
e:
$$
dstart = max(0, i * strides[0] - paddings[0])
dend = min(D, dstart + ksize[0])
...
...
paddle/fluid/operators/reader/ctr_reader.cc
浏览文件 @
2ba256df
...
...
@@ -213,7 +213,7 @@ void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
framework
::
LoD
lod
{
lod_data
};
lod_tensor
.
set_lod
(
lod
);
int64_t
*
tensor_data
=
lod_tensor
.
mutable_data
<
int64_t
>
(
framework
::
make_ddim
({
1
,
static_cast
<
int64_t
>
(
batch_feasign
.
size
())
}),
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
batch_feasign
.
size
()),
1
}),
platform
::
CPUPlace
());
memcpy
(
tensor_data
,
batch_feasign
.
data
(),
batch_feasign
.
size
()
*
sizeof
(
int64_t
));
...
...
@@ -223,7 +223,7 @@ void ReadSvmData(const DataDesc& data_desc, std::shared_ptr<Reader> reader,
// insert label tensor
framework
::
LoDTensor
label_tensor
;
auto
*
label_tensor_data
=
label_tensor
.
mutable_data
<
int64_t
>
(
framework
::
make_ddim
({
1
,
static_cast
<
int64_t
>
(
batch_label
.
size
())
}),
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
batch_label
.
size
()),
1
}),
platform
::
CPUPlace
());
memcpy
(
label_tensor_data
,
batch_label
.
data
(),
batch_label
.
size
()
*
sizeof
(
int64_t
));
...
...
paddle/fluid/operators/reader/ctr_reader_test.cc
浏览文件 @
2ba256df
...
...
@@ -123,7 +123,7 @@ TEST(CTR_READER, read_data) {
std
::
vector
<
std
::
tuple
<
LoD
,
std
::
vector
<
int64_t
>>>
data_slot_6003
{
b1
,
b2
,
b3
,
b4
};
std
::
vector
<
DDim
>
label_dims
=
{{
1
,
3
},
{
1
,
3
},
{
1
,
3
},
{
1
,
1
}};
std
::
vector
<
DDim
>
label_dims
=
{{
3
,
1
},
{
3
,
1
},
{
3
,
1
},
{
1
,
1
}};
LoDTensorBlockingQueueHolder
queue_holder
;
int
capacity
=
64
;
...
...
paddle/fluid/operators/reduce_ops/CMakeLists.txt
浏览文件 @
2ba256df
include
(
operators
)
register_operators
()
if
(
WITH_GPU
)
register_operators
(
DEPS cub
)
else
()
register_operators
()
endif
()
if
(
WITH_GPU
)
file
(
GLOB OPS RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"*.part.cu"
)
...
...
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
2ba256df
proto_library
(
profiler_proto SRCS profiler.proto DEPS framework_proto
)
proto_library
(
profiler_proto SRCS profiler.proto DEPS framework_proto
simple_threadpool
)
py_proto_compile
(
profiler_py_proto SRCS profiler.proto
)
add_custom_target
(
profiler_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
...
...
@@ -36,7 +36,7 @@ cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)
nv_library
(
gpu_info SRCS gpu_info.cc DEPS gflags glog enforce
)
cc_library
(
place SRCS place.cc DEPS enforce boost
)
cc_library
(
place SRCS place.cc DEPS enforce boost
lib_any
)
cc_test
(
place_test SRCS place_test.cc DEPS place glog gflags
)
add_subdirectory
(
dynload
)
...
...
paddle/fluid/platform/ngraph_helper.h
浏览文件 @
2ba256df
...
...
@@ -43,13 +43,14 @@ std::shared_ptr<ngraph::Node> NgReshaper(std::shared_ptr<ngraph::Node> input,
std
::
shared_ptr
<
ngraph
::
Node
>
GetNode
(
const
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>&
op
,
const
std
::
string
prm
,
const
paddle
::
framework
::
VariableNameMap
&
var_map
,
const
std
::
string
name
,
const
paddle
::
framework
::
VariableNameMap
&
var_map
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
auto
&
var_names
=
var_map
.
at
(
prm
);
auto
&
var_names
=
var_map
.
at
(
name
);
PADDLE_ENFORCE_EQ
(
var_names
.
size
(),
1
,
"op %s prm %s expects one associated var"
,
op
->
Type
(),
prm
);
"op %s name %s expects one associated var"
,
op
->
Type
(),
name
);
if
(
ngb_node_map
->
find
(
var_names
[
0
])
!=
ngb_node_map
->
end
())
{
return
(
*
ngb_node_map
)[
var_names
[
0
]];
}
else
{
...
...
@@ -59,43 +60,53 @@ std::shared_ptr<ngraph::Node> GetNode(
std
::
shared_ptr
<
ngraph
::
Node
>
GetInputNode
(
const
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>&
op
,
const
std
::
string
prm
,
const
std
::
string
name
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
return
GetNode
(
op
,
prm
,
op
->
Inputs
(),
ngb_node_map
);
return
GetNode
(
op
,
name
,
op
->
Inputs
(),
ngb_node_map
);
}
std
::
shared_ptr
<
ngraph
::
Node
>
GetOutputNode
(
const
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>&
op
,
const
std
::
string
prm
,
const
std
::
string
name
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
return
GetNode
(
op
,
prm
,
op
->
Outputs
(),
ngb_node_map
);
return
GetNode
(
op
,
name
,
op
->
Outputs
(),
ngb_node_map
);
}
void
SetOutputNode
(
const
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>&
op
,
const
std
::
string
prm
,
std
::
shared_ptr
<
ngraph
::
Node
>
node
,
const
std
::
string
name
,
std
::
shared_ptr
<
ngraph
::
Node
>
node
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
auto
&
var_names
=
op
->
Outputs
().
at
(
prm
);
auto
&
var_names
=
op
->
Outputs
().
at
(
name
);
if
(
var_names
.
size
()
==
1
)
{
/* */
auto
dummy_out
=
GetOutputNode
(
op
,
name
,
ngb_node_map
);
if
(
dummy_out
&&
dummy_out
->
get_shape
()
!=
node
->
get_shape
())
{
node
=
NgReshaper
(
node
,
dummy_out
->
get_shape
());
}
if
(
dummy_out
&&
dummy_out
->
get_element_type
()
!=
node
->
get_element_type
())
{
node
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
node
,
dummy_out
->
get_element_type
());
}
(
*
ngb_node_map
)[
var_names
[
0
]]
=
node
;
}
else
if
(
var_names
.
size
()
==
0
)
{
(
*
ngb_node_map
)[
""
]
=
node
;
}
else
{
PADDLE_THROW
(
"
prm %s has more than 1 var_names."
,
prm
);
PADDLE_THROW
(
"
name %s has more than 1 var_names."
,
name
);
}
}
bool
HasOutput
(
const
std
::
shared_ptr
<
paddle
::
framework
::
OperatorBase
>&
op
,
const
std
::
string
prm
)
{
const
std
::
string
name
)
{
auto
&
outputs
=
op
->
Outputs
();
if
(
outputs
.
find
(
prm
)
==
outputs
.
end
())
return
false
;
return
outputs
.
at
(
prm
).
size
()
>
0
;
if
(
outputs
.
find
(
name
)
==
outputs
.
end
())
return
false
;
return
outputs
.
at
(
name
).
size
()
>
0
;
}
inline
void
GetMidDims
(
const
ngraph
::
Shape
&
x_shape
,
...
...
paddle/fluid/platform/place.cc
浏览文件 @
2ba256df
...
...
@@ -14,12 +14,6 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
DEFINE_bool
(
benchmark
,
false
,
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
"Default cuda is asynchronous device, set to True will"
"force op run in synchronous mode."
);
namespace
paddle
{
namespace
platform
{
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
2ba256df
...
...
@@ -26,5 +26,5 @@ if(WITH_PYTHON)
get_property
(
os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES
)
target_link_libraries
(
paddle_pybind
${
os_dependency_modules
}
)
cc_test
(
tensor_py_test SRCS tensor_py_test.cc DEPS python
)
cc_test
(
tensor_py_test SRCS tensor_py_test.cc DEPS python
pybind
)
endif
(
WITH_PYTHON
)
python/CMakeLists.txt
浏览文件 @
2ba256df
...
...
@@ -54,7 +54,7 @@ ELSE(WIN32)
DEPENDS copy_paddle_pybind
${
FLUID_CORE
}
framework_py_proto profiler_py_proto
${
PY_FILES
}
${
external_project_dependencies
}
${
COPY_PADDLE_MASTER
}
)
ENDIF
()
set
(
paddle_python_deps
${
PADDLE_PYTHON_BUILD_DIR
}
/.timestamp
${
MKL_DEPENDS
}
)
set
(
paddle_python_deps
${
PADDLE_PYTHON_BUILD_DIR
}
/.timestamp
${
MKL_DEPENDS
}
${
external_project_dependencies
}
)
add_custom_target
(
paddle_python ALL DEPENDS
${
paddle_python_deps
}
)
set
(
PADDLE_PYTHON_PACKAGE_DIR
${
CMAKE_CURRENT_BINARY_DIR
}
/dist/
)
...
...
python/paddle/fluid/layers/detection.py
浏览文件 @
2ba256df
...
...
@@ -397,10 +397,10 @@ def box_coder(prior_box,
input is image feature map, they are close to
the origin of the coordinate system. [xmax, ymax]
is the right bottom coordinate of the anchor box.
prior_box_var(Variable|list
): prior_box_var supports two types of input.
One is variable with shape [M, 4] holds M group.
The other one is list consist of 4 elements
shared by all boxes.
prior_box_var(Variable|list
|None): prior_box_var supports two types
of input. One is variable with shape [M, 4]
holds M group. The other one is list consist of
4 elements
shared by all boxes.
target_box(Variable): This input can be a 2-D LoDTensor with shape
[N, 4] when code_type is 'encode_center_size'.
This input also can be a 3-D Tensor with shape
...
...
python/paddle/fluid/tests/unittests/ngraph/test_accuracy_ngraph_op.py
0 → 100644
浏览文件 @
2ba256df
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
paddle.fluid.tests.unittests.op_test
import
OpTest
from
paddle.fluid.tests.unittests.test_accuracy_op
import
TestAccuracyOp
class
TestNGRAPHAccuracyOp
(
TestAccuracyOp
):
def
setUp
(
self
):
super
(
TestNGRAPHAccuracyOp
,
self
).
setUp
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_box_coder_op.py
浏览文件 @
2ba256df
...
...
@@ -34,7 +34,9 @@ def box_decoder(t_box, p_box, pb_v, output_box, norm, axis=0):
pb_y
=
pb_y
.
reshape
(
shape
)
if
pb_v
.
ndim
==
2
:
pb_v
=
pb_v
.
reshape
(
1
,
pb_v
.
shape
[
0
],
pb_v
.
shape
[
1
])
var_shape
=
(
1
,
pb_v
.
shape
[
0
],
pb_v
.
shape
[
1
])
if
axis
==
0
else
(
pb_v
.
shape
[
0
],
1
,
pb_v
.
shape
[
1
])
pb_v
=
pb_v
.
reshape
(
var_shape
)
if
pb_v
.
ndim
==
1
:
tb_x
=
pb_v
[
0
]
*
t_box
[:,
:,
0
]
*
pb_w
+
pb_x
tb_y
=
pb_v
[
1
]
*
t_box
[:,
:,
1
]
*
pb_h
+
pb_y
...
...
@@ -125,33 +127,6 @@ class TestBoxCoderOp(OpTest):
self
.
outputs
=
{
'OutputBox'
:
output_box
}
class
TestBoxCoderOpWithOneRankVar
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
def
setUp
(
self
):
self
.
op_type
=
"box_coder"
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
prior_box
=
np
.
random
.
random
((
81
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
20
,
81
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
box_normalized
=
False
output_box
=
batch_box_coder
(
prior_box
,
prior_box_var
,
target_box
,
lod
[
0
],
code_type
,
box_normalized
)
self
.
inputs
=
{
'PriorBox'
:
prior_box
,
'PriorBoxVar'
:
prior_box_var
,
'TargetBox'
:
target_box
,
}
self
.
attrs
=
{
'code_type'
:
'decode_center_size'
,
'box_normalized'
:
False
}
self
.
outputs
=
{
'OutputBox'
:
output_box
}
class
TestBoxCoderOpWithoutBoxVar
(
OpTest
):
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -210,7 +185,7 @@ class TestBoxCoderOpWithAxis(OpTest):
self
.
op_type
=
"box_coder"
lod
=
[[
1
,
1
,
1
,
1
,
1
]]
prior_box
=
np
.
random
.
random
((
30
,
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
4
)).
astype
(
'float32'
)
prior_box_var
=
np
.
random
.
random
((
30
,
4
)).
astype
(
'float32'
)
target_box
=
np
.
random
.
random
((
30
,
81
,
4
)).
astype
(
'float32'
)
code_type
=
"DecodeCenterSize"
box_normalized
=
False
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录