Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
000ba1ac
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
000ba1ac
编写于
7月 30, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into port_python3_syntax
上级
0c7d6eb8
4b8ae523
变更
83
显示空白变更内容
内联
并排
Showing
83 changed file
with
2439 addition
and
599 deletion
+2439
-599
AUTHORS.md
AUTHORS.md
+1
-0
benchmark/fluid/fluid_benchmark.py
benchmark/fluid/fluid_benchmark.py
+1
-2
cmake/external/grpc.cmake
cmake/external/grpc.cmake
+1
-1
cmake/generic.cmake
cmake/generic.cmake
+2
-2
cmake/inference_lib.cmake
cmake/inference_lib.cmake
+2
-9
doc/fluid/design/ir/draft.md
doc/fluid/design/ir/draft.md
+97
-1
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+4
-4
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+0
-3
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+82
-69
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+12
-25
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
...uid/framework/details/scope_buffered_ssa_graph_executor.h
+3
-0
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+7
-6
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+6
-2
paddle/fluid/framework/details/ssa_graph_builder_factory.h
paddle/fluid/framework/details/ssa_graph_builder_factory.h
+0
-71
paddle/fluid/framework/details/ssa_graph_checker.cc
paddle/fluid/framework/details/ssa_graph_checker.cc
+10
-3
paddle/fluid/framework/details/ssa_graph_checker.h
paddle/fluid/framework/details/ssa_graph_checker.h
+4
-16
paddle/fluid/framework/details/ssa_graph_executor.h
paddle/fluid/framework/details/ssa_graph_executor.h
+3
-1
paddle/fluid/framework/details/ssa_graph_printer.cc
paddle/fluid/framework/details/ssa_graph_printer.cc
+6
-3
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+9
-30
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+4
-4
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+1
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+6
-3
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+120
-0
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+8
-1
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+72
-0
paddle/fluid/framework/ir/graph_viz_pass.h
paddle/fluid/framework/ir/graph_viz_pass.h
+38
-0
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+28
-1
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+168
-2
paddle/fluid/framework/ir/pass_test.cc
paddle/fluid/framework/ir/pass_test.cc
+112
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+2
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+88
-17
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+0
-1
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+21
-5
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+1
-28
paddle/fluid/inference/api/api.map
paddle/fluid/inference/api/api.map
+0
-6
paddle/fluid/inference/api/api.sym
paddle/fluid/inference/api/api.sym
+0
-1
paddle/fluid/inference/api/demo_ci/CMakeLists.txt
paddle/fluid/inference/api/demo_ci/CMakeLists.txt
+0
-2
paddle/fluid/inference/check_symbol.sh
paddle/fluid/inference/check_symbol.sh
+2
-2
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+5
-2
paddle/fluid/inference/tensorrt/convert/mul_op.cc
paddle/fluid/inference/tensorrt/convert/mul_op.cc
+0
-1
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
+80
-0
paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
...le/fluid/inference/tensorrt/convert/test_activation_op.cc
+1
-1
paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
+2
-3
paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
+2
-2
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
+60
-0
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+29
-7
paddle/fluid/inference/tensorrt/test_engine.cc
paddle/fluid/inference/tensorrt/test_engine.cc
+32
-1
paddle/fluid/inference/tests/book/CMakeLists.txt
paddle/fluid/inference/tests/book/CMakeLists.txt
+2
-2
paddle/fluid/inference/tests/book/test_inference_nlp.cc
paddle/fluid/inference/tests/book/test_inference_nlp.cc
+2
-9
paddle/fluid/operators/.flatten_op.cc.swp
paddle/fluid/operators/.flatten_op.cc.swp
+0
-0
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+3
-0
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+3
-3
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+2
-1
paddle/fluid/operators/distributed/grpc_client.h
paddle/fluid/operators/distributed/grpc_client.h
+3
-1
paddle/fluid/operators/distributed/rpc_server_test.cc
paddle/fluid/operators/distributed/rpc_server_test.cc
+12
-13
paddle/fluid/operators/extract_rows_op.cc
paddle/fluid/operators/extract_rows_op.cc
+103
-0
paddle/fluid/operators/flatten_op.cc
paddle/fluid/operators/flatten_op.cc
+169
-0
paddle/fluid/operators/lookup_table_op.cc
paddle/fluid/operators/lookup_table_op.cc
+16
-30
paddle/fluid/operators/lookup_table_op.cu
paddle/fluid/operators/lookup_table_op.cu
+28
-45
paddle/fluid/operators/lookup_table_op.h
paddle/fluid/operators/lookup_table_op.h
+5
-35
paddle/fluid/operators/tensorrt_engine_op.cc
paddle/fluid/operators/tensorrt_engine_op.cc
+0
-3
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+4
-0
paddle/fluid/platform/cpu_helper.cc
paddle/fluid/platform/cpu_helper.cc
+2
-0
paddle/fluid/platform/cuda_device_function.h
paddle/fluid/platform/cuda_device_function.h
+21
-0
paddle/fluid/platform/cuda_helper_test.cu
paddle/fluid/platform/cuda_helper_test.cu
+118
-0
paddle/fluid/platform/cuda_primitives.h
paddle/fluid/platform/cuda_primitives.h
+69
-6
paddle/fluid/platform/float16.h
paddle/fluid/platform/float16.h
+27
-0
paddle/fluid/platform/float16_test.cc
paddle/fluid/platform/float16_test.cc
+26
-0
paddle/fluid/platform/float16_test.cu
paddle/fluid/platform/float16_test.cu
+69
-1
paddle/fluid/platform/init.cc
paddle/fluid/platform/init.cc
+4
-1
patches/grpc/completion_queue.h
patches/grpc/completion_queue.h
+386
-0
patches/grpc/fix_too_early_destory.patch
patches/grpc/fix_too_early_destory.patch
+0
-47
patches/grpc/grpc_library.h
patches/grpc/grpc_library.h
+64
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-1
python/paddle/fluid/regularizer.py
python/paddle/fluid/regularizer.py
+14
-2
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-1
python/paddle/fluid/tests/unittests/dist_se_resnext.py
python/paddle/fluid/tests/unittests/dist_se_resnext.py
+2
-6
python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
+22
-6
python/paddle/fluid/tests/unittests/test_extract_rows_op.py
python/paddle/fluid/tests/unittests/test_extract_rows_op.py
+58
-0
python/paddle/fluid/tests/unittests/test_flatten_op.py
python/paddle/fluid/tests/unittests/test_flatten_op.py
+68
-0
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
+0
-47
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+2
-0
tools/codestyle/cpplint_pre_commit.hook
tools/codestyle/cpplint_pre_commit.hook
+1
-1
未找到文件。
AUTHORS.md
浏览文件 @
000ba1ac
...
@@ -46,6 +46,7 @@
...
@@ -46,6 +46,7 @@
| tianbingsz | Tian-Bing Xu |
| tianbingsz | Tian-Bing Xu |
| tpatejko | Tomasz Patejko |
| tpatejko | Tomasz Patejko |
| typhoonzero | Yi Wu |
| typhoonzero | Yi Wu |
| velconia | Qi-Yang Min |
| wanghaoshuang | Hao-Shuang Wang |
| wanghaoshuang | Hao-Shuang Wang |
| wangyang59 | Yang Wang |
| wangyang59 | Yang Wang |
| wangzhen-nlp | Zhen Wang |
| wangzhen-nlp | Zhen Wang |
...
...
benchmark/fluid/fluid_benchmark.py
浏览文件 @
000ba1ac
...
@@ -85,8 +85,7 @@ def dist_transpile(trainer_id, args):
...
@@ -85,8 +85,7 @@ def dist_transpile(trainer_id, args):
trainer_id
,
trainer_id
,
pservers
=
pserver_endpoints
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
,
trainers
=
trainers
,
sync_mode
=
not
args
.
async_mode
,
sync_mode
=
not
args
.
async_mode
)
slice_var_up
=
not
args
.
no_split_var
)
if
training_role
==
"PSERVER"
:
if
training_role
==
"PSERVER"
:
pserver_program
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_program
=
t
.
get_pserver_program
(
current_endpoint
)
pserver_startup_program
=
t
.
get_startup_program
(
current_endpoint
,
pserver_startup_program
=
t
.
get_startup_program
(
current_endpoint
,
...
...
cmake/external/grpc.cmake
浏览文件 @
000ba1ac
...
@@ -50,7 +50,7 @@ ExternalProject_Add(
...
@@ -50,7 +50,7 @@ ExternalProject_Add(
UPDATE_COMMAND
""
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
CONFIGURE_COMMAND
""
BUILD_IN_SOURCE 1
BUILD_IN_SOURCE 1
PATCH_COMMAND
git apply
${
PADDLE_SOURCE_DIR
}
/patches/grpc/fix_too_early_destory.patc
h
PATCH_COMMAND
cp
${
PADDLE_SOURCE_DIR
}
/patches/grpc/grpc_library.h
${
GRPC_SOURCES_DIR
}
/src/extern_grpc/include/grpcpp/impl/codegen/grpc_library.h && cp
${
PADDLE_SOURCE_DIR
}
/patches/grpc/completion_queue.h
${
GRPC_SOURCES_DIR
}
/src/extern_grpc/include/grpcpp/impl/codegen/completion_queue.
h
# NOTE(yuyang18):
# NOTE(yuyang18):
# Disable -Werror, otherwise the compile will fail in MacOS.
# Disable -Werror, otherwise the compile will fail in MacOS.
# It seems that we cannot configure that by make command.
# It seems that we cannot configure that by make command.
...
...
cmake/generic.cmake
浏览文件 @
000ba1ac
...
@@ -263,7 +263,7 @@ function(cc_test TARGET_NAME)
...
@@ -263,7 +263,7 @@ function(cc_test TARGET_NAME)
COMMAND
${
TARGET_NAME
}
${
cc_test_ARGS
}
COMMAND
${
TARGET_NAME
}
${
cc_test_ARGS
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
if
(
${
cc_test_SERIAL
}
)
if
(
${
cc_test_SERIAL
}
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY SERIAL 1
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY
RUN_
SERIAL 1
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true
)
endif
()
endif
()
endif
()
endif
()
...
@@ -328,7 +328,7 @@ function(nv_test TARGET_NAME)
...
@@ -328,7 +328,7 @@ function(nv_test TARGET_NAME)
add_dependencies
(
${
TARGET_NAME
}
${
nv_test_DEPS
}
paddle_gtest_main lod_tensor memory gtest gflags glog
)
add_dependencies
(
${
TARGET_NAME
}
${
nv_test_DEPS
}
paddle_gtest_main lod_tensor memory gtest gflags glog
)
add_test
(
${
TARGET_NAME
}
${
TARGET_NAME
}
)
add_test
(
${
TARGET_NAME
}
${
TARGET_NAME
}
)
if
(
nv_test_SERIAL
)
if
(
nv_test_SERIAL
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY SERIAL 1
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY
RUN_
SERIAL 1
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true
)
endif
()
endif
()
endif
()
endif
()
...
...
cmake/inference_lib.cmake
浏览文件 @
000ba1ac
...
@@ -148,18 +148,11 @@ if (WITH_ANAKIN AND WITH_GPU)
...
@@ -148,18 +148,11 @@ if (WITH_ANAKIN AND WITH_GPU)
list
(
APPEND inference_deps anakin_inference_lib
)
list
(
APPEND inference_deps anakin_inference_lib
)
endif
()
endif
()
copy
(
inference_api_lib DEPS paddle_inference_api paddle_inference_api_shared
SRCS
${
src_dir
}
/
${
module
}
/paddle_inference_api.h
${
src_dir
}
/
${
module
}
/demo_ci
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/api/libpaddle_inference_api*
DSTS
${
dst_dir
}
/inference
${
dst_dir
}
/inference
${
dst_dir
}
/inference
)
list
(
APPEND inference_deps inference_api_lib
)
set
(
module
"inference"
)
set
(
module
"inference"
)
copy
(
inference_lib DEPS
${
inference_deps
}
copy
(
inference_lib DEPS
${
inference_deps
}
SRCS
${
src_dir
}
/
${
module
}
/*.h
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/libpaddle_fluid.*
SRCS
${
src_dir
}
/
${
module
}
/*.h
${
PADDLE_BINARY_DIR
}
/paddle/fluid/inference/libpaddle_fluid.*
DSTS
${
dst_dir
}
/
${
module
}
${
dst_dir
}
/
${
module
}
${
src_dir
}
/
${
module
}
/api/paddle_inference_api.h
${
src_dir
}
/
${
module
}
/api/demo_ci
DSTS
${
dst_dir
}
/
${
module
}
${
dst_dir
}
/
${
module
}
${
dst_dir
}
/
${
module
}
${
dst_dir
}
/
${
module
}
)
)
set
(
module
"platform"
)
set
(
module
"platform"
)
...
...
doc/fluid/design/ir/draft.md
浏览文件 @
000ba1ac
...
@@ -64,6 +64,41 @@ can also contain other things that describe some properties of
...
@@ -64,6 +64,41 @@ can also contain other things that describe some properties of
the
`Graph`
or
`Graph`
nodes.
`Attribute`
can be passed
the
`Graph`
or
`Graph`
nodes.
`Attribute`
can be passed
across
`Pass`
. However, it should be used with care.
across
`Pass`
. However, it should be used with care.
```
cpp
class
Graph
{
public:
explicit
Graph
(
const
ProgramDesc
&
program
);
bool
Has
(
const
std
::
string
&
attr_name
)
const
;
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
;
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
);
const
std
::
unordered_set
<
ir
::
Node
*>
&
Nodes
()
const
;
// Create a normal variable with non-null VarDesc.
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
);
// Create a normal runnable operator with OpDesc.
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
);
// Create a control dependency var that connects 2 operations. The
// var doesn't hold any data. Other than that, it's no different from
// other var, considering dependency analysis.
ir
::
Node
*
CreateControlDepVar
();
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
);
// Clear all node information of the graph and return the ownership of the
// nodes.
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
ReleaseNodes
();
};
```
#### Pass
#### Pass
`Pass`
represents a transformation of
`Graph`
. Its input
`Pass`
represents a transformation of
`Graph`
. Its input
...
@@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example,
...
@@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example,
a
`Pass`
can simply print out the
`Graph`
. A
`Pass`
a
`Pass`
can simply print out the
`Graph`
. A
`Pass`
can also fuse some
`Graph`
's
`Node`
s.
can also fuse some
`Graph`
's
`Node`
s.
```
cpp
class
Pass
{
public:
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
// Some correctness check.
auto
new_graph
=
ApplyImpl
(
std
::
move
(
graph
));
// Some correctness check.
return
new_graph
;
}
// Get a reference to the attributed previously set.
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
;
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
;
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
);
protected:
virtual
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
};
// In my_pass.cc
class
MyPass
:
public
Pass
{
protected:
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
{
// do something.
return
graph
;
}
}
REGISTER_PASS
(
my_pass
,
MyPass
)
.
RequirePassAttr
(
"places"
)
.
RequireGraphAttr
(
"dep_vars"
);
// To use the pass.
auto
my_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"my_pass"
);
graph
=
my_pass
->
Apply
(
std
::
move
(
graph
));
// Note: to force link my_pass.cc, in the code:
USE_PASS
(
my_pass
);
```
#### Optimize
#### Optimize
`Optimize`
contains a series of
`Pass`
with defined order.
`Optimize`
contains a series of
`Pass`
with defined order.
...
@@ -86,4 +169,17 @@ maintaining the original modeling logic.
...
@@ -86,4 +169,17 @@ maintaining the original modeling logic.
*
Graph is transformed from raw model logic to a
*
Graph is transformed from raw model logic to a
form that is efficient to execute.
form that is efficient to execute.
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
```
// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
auto graph = Graph(program);
graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(grah));
// For more complex Pass, Optimize Process can provide Pass attributes.
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
mem_opt_pass->Apply(std::move(graph));
graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(grah));
graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(grah));
Executor exe;
exe.Run(graph);
```
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -8,9 +8,9 @@ cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
...
@@ -8,9 +8,9 @@ cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test
(
dim_test SRCS dim_test.cu DEPS ddim
)
nv_test
(
dim_test SRCS dim_test.cu DEPS ddim
)
cc_library
(
data_type SRCS data_type.cc DEPS framework_proto ddim device_context
)
cc_library
(
data_type SRCS data_type.cc DEPS framework_proto ddim device_context
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
nv_library
(
tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type
)
nv_library
(
tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type
device_context
)
else
()
else
()
cc_library
(
tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type
)
cc_library
(
tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type
device_context
)
endif
()
endif
()
cc_test
(
tensor_test SRCS tensor_test.cc DEPS tensor
)
cc_test
(
tensor_test SRCS tensor_test.cc DEPS tensor
)
...
@@ -99,7 +99,7 @@ else()
...
@@ -99,7 +99,7 @@ else()
endif
()
endif
()
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
...
@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
cc_library
(
ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context
)
simple_threadpool device_context
)
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
000ba1ac
...
@@ -34,30 +34,22 @@ namespace paddle {
...
@@ -34,30 +34,22 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
static
const
char
kLossVarName
[]
=
"loss_var_name"
;
static
const
char
kPlaces
[]
=
"places"
;
static
const
char
kParams
[]
=
"params"
;
static
const
char
kLocalScopes
[]
=
"local_scopes"
;
static
const
char
kStrategy
[]
=
"strategy"
;
void
MultiDevSSAGraphBuilder
::
Init
()
const
{
loss_var_name_
=
Get
<
const
std
::
string
>
(
kLossVarName
);
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
);
strategy_
=
Get
<
const
BuildStrategy
>
(
kStrategy
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
);
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
const
BuildStrategy
&
strategy
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
nccl_ctxs_
(
nccl_ctxs
),
strategy_
(
strategy
)
{
#else
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{
#endif
#endif
for
(
auto
&
p
:
params
)
{
for
(
auto
&
p
:
Get
<
const
std
::
unordered_set
<
std
::
string
>>
(
kParams
))
{
grad_names_
.
insert
(
GradVarName
(
p
));
grad_names_
.
insert
(
GradVarName
(
p
));
}
}
balance_vars_
.
resize
(
places_
.
size
(),
0
);
balance_vars_
.
resize
(
places_
.
size
(),
0
);
...
@@ -72,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
...
@@ -72,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
op_handle
->
SetDeviceContext
(
p
,
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
@@ -239,8 +231,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
...
@@ -239,8 +231,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
return
sorted_ret
;
return
sorted_ret
;
}
}
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
Impl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
Init
();
// Give the topology sort order and rebuild the graph structure.
// Give the topology sort order and rebuild the graph structure.
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
auto
nodes
=
graph
->
ReleaseNodes
();
auto
nodes
=
graph
->
ReleaseNodes
();
...
@@ -254,9 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -254,9 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
// We cannot invoke resize. It is a bug of GCC 4.8
// We cannot invoke resize. It is a bug of GCC 4.8
result
.
Set
(
"vars"
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
kGraphVars
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
"dep_vars"
,
new
GraphDepVars
);
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
"ops"
,
new
GraphOps
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
result
.
Set
(
kShardedVarDevice
,
new
ShardedVarDevice
);
// find send/recv vars so that we can place the distributed training
// find send/recv vars so that we can place the distributed training
// related op in the place 0
// related op in the place 0
...
@@ -289,11 +283,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -289,11 +283,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
// the block.
// the block.
is_forwarding
=
false
;
is_forwarding
=
false
;
}
else
{
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
node
);
int
op_dev_id
=
GetOpDeviceID
(
result
,
node
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
var_name_on_devices_
.
emplace
(
n
->
Name
(),
op_dev_id
);
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
n
->
Name
(),
op_dev_id
);
}
}
}
else
{
}
else
{
// This op runs on all devices, and its output may have parameter's
// This op runs on all devices, and its output may have parameter's
...
@@ -330,7 +325,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -330,7 +325,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
var_name_on_devices_
.
emplace
(
g_name
,
cur_device_id
);
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
g_name
,
cur_device_id
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
break
;
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
...
@@ -416,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
...
@@ -416,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
);
local_scopes_
,
places_
);
#endif
#endif
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
auto
*
in
=
auto
*
in
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
src_dev_id
).
at
(
p_name
).
back
().
get
();
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
src_dev_id
).
at
(
p_name
).
back
().
get
();
op_handle
->
AddInput
(
in
);
op_handle
->
AddInput
(
in
);
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
i
).
at
(
p_name
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
i
).
at
(
p_name
);
auto
*
out_var
=
new
VarHandle
(
auto
*
out_var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
p_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
result
->
CreateEmptyNode
(
p_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
p_name
,
p
);
i
,
p_name
,
p
);
...
@@ -437,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
...
@@ -437,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
int
dev_id
)
const
{
int
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
...
@@ -446,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
...
@@ -446,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
{
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
op_handle
->
AddInput
(
prev_grad
.
get
());
...
@@ -475,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
...
@@ -475,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
for
(
const
std
::
string
&
d_name
:
datas
)
{
for
(
const
std
::
string
&
d_name
:
datas
)
{
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
d_name
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
auto
var
=
new
VarHandle
(
auto
var
=
new
VarHandle
(
...
@@ -512,7 +508,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
...
@@ -512,7 +508,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return
is_pg_once
;
return
is_pg_once
;
}
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
)
const
{
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
return
-
1
;
}
}
...
@@ -525,15 +522,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
...
@@ -525,15 +522,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
]);
int
dev_id
=
GetVarDeviceID
(
graph
,
param_grad
[
1
]);
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
return
dev_id
;
return
dev_id
;
}
}
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
{
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
auto
got
=
var_name_on_devices_
.
find
(
varname
);
const
std
::
string
&
varname
)
const
{
return
got
==
var_name_on_devices_
.
end
()
?
-
1
:
got
->
second
;
auto
&
sharded_var_device
=
graph
.
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
);
auto
got
=
sharded_var_device
.
find
(
varname
);
return
got
==
sharded_var_device
.
end
()
?
-
1
:
got
->
second
;
}
}
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
{
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
{
...
@@ -551,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
...
@@ -551,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
communication_dev_ctx
);
communication_dev_ctx
);
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// factor. So it does not depend on any other operators.
...
@@ -572,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
...
@@ -572,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
auto
p
=
places_
[
scope_idx
];
auto
p
=
places_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
));
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
));
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
}
}
...
@@ -582,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -582,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const
std
::
string
&
og
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
{
int
dst_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
op_handle
->
AddInput
(
prev_grad
.
get
());
}
}
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
dst_dev_id
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
dst_dev_id
][
og
];
auto
var
=
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
...
@@ -613,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -613,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
// on it.
// on it.
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
{
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateControlDepVar
());
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateControlDepVar
());
prev_op
->
AddOutput
(
dep_var
);
prev_op
->
AddOutput
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
op
->
AddInput
(
dep_var
);
op
->
AddInput
(
dep_var
);
}
}
}
}
...
@@ -638,20 +637,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -638,20 +637,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
node
->
Op
()
->
Type
()
==
"split_selected_rows"
)
{
node
->
Op
()
->
Type
()
==
"split_selected_rows"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
for
(
auto
&
varname
:
input_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
}
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
{
}
else
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
...
@@ -665,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -665,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"fetch_barrier"
);
"fetch_barrier"
);
}
}
}
}
...
@@ -676,7 +678,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -676,7 +678,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
());
op_dev_id
=
GetVarDeviceID
(
*
result
,
node
->
inputs
[
0
]
->
Name
());
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
"This hack no longer holds, please fix."
);
"This hack no longer holds, please fix."
);
// the variable name which contains .block means it was splited by
// the variable name which contains .block means it was splited by
...
@@ -691,7 +693,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -691,7 +693,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
}
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
for
(
auto
&
varname
:
input_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
...
@@ -701,7 +704,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -701,7 +704,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
}
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
{
}
else
{
// send_barrier and fetch_barrier op can be scheduled on device 0
// send_barrier and fetch_barrier op can be scheduled on device 0
...
@@ -711,17 +715,18 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -711,17 +715,18 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"can not find the right place for rpc op: %s"
,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"can not find the right place for rpc op: %s"
,
node
->
Op
()
->
Type
());
node
->
Op
()
->
Type
());
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
RPCOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
RPCOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
*
node
->
Op
(),
local_scopes_
[
op_dev_id
],
result
->
CreateOpNode
(
node
->
Op
()),
*
node
->
Op
(),
local_scopes_
[
op_dev_id
],
node
->
Op
()
->
Type
(),
places_
[
op_dev_id
]));
node
->
Op
()
->
Type
(),
places_
[
op_dev_id
]));
// TODO(panyx0718): This might not be needed anymore.
if
(
node
->
Op
()
->
Type
()
==
"send_barrier"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send_barrier"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"send"
);
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"send"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"send_barrier"
);
"send_barrier"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"recv"
);
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"recv"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// do nothing
// do nothing
}
else
{
}
else
{
...
@@ -743,3 +748,11 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
...
@@ -743,3 +748,11 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
multi_device_pass
,
paddle
::
framework
::
details
::
MultiDevSSAGraphBuilder
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLossVarName
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kParams
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kStrategy
);
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
000ba1ac
...
@@ -31,39 +31,27 @@ class Scope;
...
@@ -31,39 +31,27 @@ class Scope;
namespace
details
{
namespace
details
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
public:
protected:
#ifdef PADDLE_WITH_CUDA
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
const
BuildStrategy
&
strategy
);
#else
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
);
#endif
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
private:
private:
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
device_id
)
const
;
size_t
device_id
)
const
;
void
Init
()
const
;
private:
private:
std
::
string
loss_var_name_
;
mutable
std
::
string
loss_var_name_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
mutable
std
::
vector
<
platform
::
Place
>
places_
;
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
mutable
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
unordered_set
<
std
::
string
>
grad_names_
;
mutable
std
::
unordered_set
<
std
::
string
>
grad_names_
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
#endif
int
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
...
@@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
string
&
og
,
const
std
::
string
&
og
,
std
::
unordered_set
<
std
::
string
>
*
og_has_been_broadcast
)
const
;
std
::
unordered_set
<
std
::
string
>
*
og_has_been_broadcast
)
const
;
int
GetOpDeviceID
(
ir
::
Node
*
node
)
const
;
int
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
...
@@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
;
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
;
private:
private:
BuildStrategy
strategy_
;
mutable
BuildStrategy
strategy_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
var_name_on_devices_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
浏览文件 @
000ba1ac
...
@@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -40,6 +40,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy
strategy
,
std
::
vector
<
Scope
*>
local_scopes
,
ExecutionStrategy
strategy
,
std
::
vector
<
Scope
*>
local_scopes
,
std
::
vector
<
VariableInfo
>
var_infos
,
std
::
vector
<
platform
::
Place
>
places
,
std
::
vector
<
VariableInfo
>
var_infos
,
std
::
vector
<
platform
::
Place
>
places
,
std
::
unique_ptr
<
SSAGraphExecutor
>&&
underlying_executor
);
std
::
unique_ptr
<
SSAGraphExecutor
>&&
underlying_executor
);
const
ir
::
Graph
&
Graph
()
const
{
return
underlying_executor_
->
Graph
();
}
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
override
;
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
override
;
private:
private:
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
000ba1ac
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
continue
;
...
@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
...
@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
}
}
}
}
}
...
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
...
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
];
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
if
(
var_holder
.
empty
())
{
...
@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
...
@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir
::
Node
*
new_node
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
][
new_node
->
Name
()];
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
size_t
version
=
vars
.
size
();
size_t
version
=
vars
.
size
();
auto
var
=
auto
var
=
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
...
@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
...
@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
}
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
!
op
->
Outputs
().
empty
())
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
continue
;
}
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
}
}
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
000ba1ac
...
@@ -39,21 +39,25 @@ namespace details {
...
@@ -39,21 +39,25 @@ namespace details {
typedef
std
::
vector
<
typedef
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
GraphVars
;
GraphVars
;
const
char
kGraphVars
[]
=
"vars"
;
// aux variables to represent dependency. Useful to resolve data hazard.
// aux variables to represent dependency. Useful to resolve data hazard.
typedef
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
GraphDepVars
;
typedef
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
GraphDepVars
;
const
char
kGraphDepVars
[]
=
"dep_vars"
;
// all operators. NOTE that even we use a vector here, the operators is
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
// unordered.
typedef
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
GraphOps
;
typedef
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
GraphOps
;
const
char
kGraphOps
[]
=
"ops"
;
typedef
std
::
unordered_map
<
std
::
string
,
int
>
ShardedVarDevice
;
const
char
kShardedVarDevice
[]
=
"sharded_var_device"
;
class
SSAGraphBuilder
:
public
ir
::
Pass
{
class
SSAGraphBuilder
:
public
ir
::
Pass
{
public:
public:
SSAGraphBuilder
()
{}
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
virtual
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
=
0
;
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
protected:
protected:
...
...
paddle/fluid/framework/details/ssa_graph_builder_factory.h
已删除
100644 → 0
浏览文件 @
0c7d6eb8
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
details
{
class
SSAGraphBuilderFactory
{
public:
SSAGraphBuilderFactory
(
const
std
::
vector
<
platform
::
Place
>&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>&
param_names
,
const
std
::
vector
<
Scope
*>&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
places_
(
places
),
loss_var_name_
(
loss_var_name
),
param_names_
(
param_names
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_
=
nullptr
;
#endif
}
#ifdef PADDLE_WITH_CUDA
void
SetNCCLContextMap
(
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
nccl_ctxs_
=
nccl_ctxs
;
}
#endif
std
::
unique_ptr
<
SSAGraphBuilder
>
Create
();
private:
std
::
vector
<
platform
::
Place
>
places_
;
std
::
string
loss_var_name_
;
std
::
unordered_set
<
std
::
string
>
param_names_
;
std
::
vector
<
Scope
*>
local_scopes_
;
BuildStrategy
strategy_
;
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_checker.cc
浏览文件 @
000ba1ac
...
@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
...
@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
}
};
};
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
insert_pending_var
(
version_pair
.
get
());
insert_pending_var
(
version_pair
.
get
());
...
@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
...
@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
}
}
}
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
insert_pending_var
(
var
.
get
());
insert_pending_var
(
var
.
get
());
}
}
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
if
(
op
->
Inputs
().
empty
())
{
ready_ops
.
insert
(
op
.
get
());
ready_ops
.
insert
(
op
.
get
());
}
else
{
}
else
{
...
@@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
...
@@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
multi_device_check_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphDepVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphOps
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kShardedVarDevice
);
paddle/fluid/framework/details/ssa_graph_checker.h
浏览文件 @
000ba1ac
...
@@ -23,26 +23,14 @@ namespace framework {
...
@@ -23,26 +23,14 @@ namespace framework {
namespace
details
{
namespace
details
{
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
public:
protected:
explicit
SSAGraghBuilderWithChecker
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
builder_
(
std
::
move
(
builder
))
{}
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
PADDLE_ENFORCE
(
IsValidGraph
(
graph
.
get
()));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
return
graph
;
return
new_graph
;
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
builder_
->
GetVarDeviceID
(
var_name
);
}
}
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
private:
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/ssa_graph_executor.h
浏览文件 @
000ba1ac
...
@@ -32,7 +32,9 @@ class SSAGraphExecutor {
...
@@ -32,7 +32,9 @@ class SSAGraphExecutor {
virtual
~
SSAGraphExecutor
();
virtual
~
SSAGraphExecutor
();
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
=
0
;
virtual
const
ir
::
Graph
&
Graph
()
const
=
0
;
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
=
0
;
};
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/details/ssa_graph_printer.cc
浏览文件 @
000ba1ac
...
@@ -22,7 +22,7 @@ namespace details {
...
@@ -22,7 +22,7 @@ namespace details {
template
<
typename
Callback
>
template
<
typename
Callback
>
static
inline
void
IterAllVar
(
const
ir
::
Graph
&
graph
,
Callback
callback
)
{
static
inline
void
IterAllVar
(
const
ir
::
Graph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
callback
(
*
pair2
);
callback
(
*
pair2
);
...
@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
...
@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
}
}
for
(
auto
&
var
:
graph
.
Get
<
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph
.
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
callback
(
*
var
);
callback
(
*
var
);
}
}
}
}
...
@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
...
@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
});
size_t
op_id
=
0
;
size_t
op_id
=
0
;
for
(
auto
&
op
:
graph
.
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
.
Get
<
GraphOps
>
(
kGraphOps
))
{
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
<<
std
::
endl
;
...
@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
...
@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
multi_device_print_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithPrinter
);
paddle/fluid/framework/details/ssa_graph_printer.h
浏览文件 @
000ba1ac
...
@@ -14,7 +14,9 @@
...
@@ -14,7 +14,9 @@
#pragma once
#pragma once
#include <fstream>
#include <iosfwd>
#include <iosfwd>
#include <ostream>
#include <string>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
...
@@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
...
@@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
};
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
public:
protected:
SSAGraghBuilderWithPrinter
(
std
::
ostream
&
sout
,
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ref_
(
sout
)
{}
SSAGraghBuilderWithPrinter
(
std
::
unique_ptr
<
std
::
ostream
>&&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ptr_
(
std
::
move
(
sout
)),
stream_ref_
(
*
stream_ptr_
)
{}
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
std
::
unique_ptr
<
std
::
ostream
>
fout
(
printer_
->
Print
(
*
new_graph
,
stream_ref_
);
new
std
::
ofstream
(
Get
<
const
std
::
string
>
(
"debug_graphviz_path"
)));
return
new_graph
;
PADDLE_ENFORCE
(
fout
->
good
());
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
graph
,
*
fout
);
return
graph
;
}
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
builder_
->
GetVarDeviceID
(
var_name
);
}
private:
std
::
unique_ptr
<
SSAGraphPrinter
>
printer_
;
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
std
::
unique_ptr
<
std
::
ostream
>
stream_ptr_
;
std
::
ostream
&
stream_ref_
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
000ba1ac
...
@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
// Transform SSAGraph to pending_ops & pending_vars
// Transform SSAGraph to pending_ops & pending_vars
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
version_pair
.
get
());
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
version_pair
.
get
());
}
}
}
}
}
}
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
))
{
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
var
.
get
());
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
var
.
get
());
}
}
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
details
::
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
ready_ops
.
insert
(
op
.
get
());
ready_ops
.
insert
(
op
.
get
());
}
else
{
}
else
{
...
@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
...
@@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
000ba1ac
...
@@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -42,6 +42,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
);
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
);
const
ir
::
Graph
&
Graph
()
const
{
return
*
graph_
;
}
// Run a SSAGraph by a thread pool
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
// Use topological sort algorithm
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
000ba1ac
cc_library
(
node SRCS node.cc DEPS proto_desc
)
cc_library
(
node SRCS node.cc DEPS proto_desc
)
cc_library
(
graph SRCS graph.cc DEPS node
)
cc_library
(
graph SRCS graph.cc DEPS node
)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
pass SRCS pass.cc DEPS graph node
)
cc_library
(
pass SRCS pass.cc DEPS graph node graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph op_registry
)
cc_library
(
graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
paddle/fluid/framework/ir/graph.cc
浏览文件 @
000ba1ac
...
@@ -24,6 +24,68 @@ namespace paddle {
...
@@ -24,6 +24,68 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
std
::
vector
<
std
::
string
>
FindDistTrainSendVars
(
const
std
::
vector
<
ir
::
Node
*>
&
nodes
)
{
std
::
vector
<
std
::
string
>
send_vars
;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for
(
auto
&
node
:
nodes
)
{
auto
op_vars
=
node
->
Op
()
->
InputArgumentNames
();
send_vars
.
reserve
(
send_vars
.
size
()
+
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
send_vars
.
insert
(
send_vars
.
end
(),
op_vars
.
begin
(),
op_vars
.
end
());
}
return
send_vars
;
}
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
const
std
::
vector
<
ir
::
Node
*>
&
nodes
)
{
std
::
vector
<
std
::
string
>
recv_vars
;
for
(
auto
&
node
:
nodes
)
{
auto
op_vars
=
node
->
Op
()
->
OutputArgumentNames
();
recv_vars
.
reserve
(
recv_vars
.
size
()
+
std
::
distance
(
op_vars
.
begin
(),
op_vars
.
end
()));
recv_vars
.
insert
(
recv_vars
.
end
(),
op_vars
.
begin
(),
op_vars
.
end
());
}
return
recv_vars
;
}
bool
IsDistTrainOp
(
ir
::
Node
*
node
,
const
std
::
vector
<
std
::
string
>
&
send_vars
,
const
std
::
vector
<
std
::
string
>
&
recv_vars
)
{
if
(
send_vars
.
size
()
==
0
||
recv_vars
.
size
()
==
0
)
{
return
false
;
}
/**
* Check any of opvars contains `.block` and in sendvars
*/
auto
checker
=
[](
const
std
::
vector
<
std
::
string
>
&
opvars
,
const
std
::
vector
<
std
::
string
>
&
rpc_vars
)
->
bool
{
for
(
auto
&
var
:
opvars
)
{
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if
(
var
.
find
(
".block"
)
!=
std
::
string
::
npos
&&
std
::
find
(
rpc_vars
.
begin
(),
rpc_vars
.
end
(),
var
)
!=
rpc_vars
.
end
())
{
return
true
;
}
}
return
false
;
};
std
::
vector
<
std
::
string
>
input_var_names
;
std
::
vector
<
std
::
string
>
output_var_names
;
for
(
ir
::
Node
*
input
:
node
->
inputs
)
{
input_var_names
.
push_back
(
input
->
Name
());
}
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
output_var_names
.
push_back
(
output
->
Name
());
}
return
checker
(
output_var_names
,
send_vars
)
||
checker
(
input_var_names
,
recv_vars
);
}
Graph
::
Graph
(
const
ProgramDesc
&
program
)
:
program_
(
program
)
{
Graph
::
Graph
(
const
ProgramDesc
&
program
)
:
program_
(
program
)
{
VLOG
(
3
)
<<
"block in program:"
<<
program_
.
Size
();
VLOG
(
3
)
<<
"block in program:"
<<
program_
.
Size
();
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars
;
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars
;
...
@@ -61,6 +123,64 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
...
@@ -61,6 +123,64 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
var
->
inputs
.
push_back
(
node
);
var
->
inputs
.
push_back
(
node
);
}
}
}
}
std
::
vector
<
ir
::
Node
*>
send_ops
;
ir
::
Node
*
send_bar
=
nullptr
;
std
::
vector
<
ir
::
Node
*>
recv_ops
;
ir
::
Node
*
fetch_bar
=
nullptr
;
for
(
ir
::
Node
*
node
:
Nodes
())
{
if
(
node
->
Name
()
==
"send"
)
{
send_ops
.
push_back
(
node
);
}
else
if
(
node
->
Name
()
==
"send_barrier"
)
{
PADDLE_ENFORCE
(
!
send_bar
,
"only has one send barrier"
);
send_bar
=
node
;
}
else
if
(
node
->
Name
()
==
"recv"
)
{
recv_ops
.
push_back
(
node
);
}
else
if
(
node
->
Name
()
==
"fetch_barrier"
)
{
PADDLE_ENFORCE
(
!
fetch_bar
,
"only has one fetch barrier"
);
fetch_bar
=
node
;
}
}
if
(
send_bar
)
{
for
(
ir
::
Node
*
send
:
send_ops
)
{
ir
::
Node
*
dep_var
=
CreateControlDepVar
();
send
->
outputs
.
push_back
(
dep_var
);
dep_var
->
inputs
.
push_back
(
send
);
send_bar
->
inputs
.
push_back
(
dep_var
);
dep_var
->
outputs
.
push_back
(
send_bar
);
}
for
(
ir
::
Node
*
recv
:
recv_ops
)
{
ir
::
Node
*
dep_var
=
CreateControlDepVar
();
recv
->
inputs
.
push_back
(
dep_var
);
dep_var
->
outputs
.
push_back
(
recv
);
send_bar
->
outputs
.
push_back
(
dep_var
);
dep_var
->
inputs
.
push_back
(
send_bar
);
}
}
if
(
fetch_bar
)
{
for
(
ir
::
Node
*
recv
:
recv_ops
)
{
ir
::
Node
*
dep_var
=
CreateControlDepVar
();
recv
->
outputs
.
push_back
(
dep_var
);
dep_var
->
inputs
.
push_back
(
recv
);
fetch_bar
->
inputs
.
push_back
(
dep_var
);
dep_var
->
outputs
.
push_back
(
fetch_bar
);
}
}
std
::
vector
<
std
::
string
>
send_vars
=
FindDistTrainSendVars
(
send_ops
);
std
::
vector
<
std
::
string
>
recv_vars
=
FindDistTrainRecvVars
(
recv_ops
);
for
(
ir
::
Node
*
node
:
Nodes
())
{
if
(
IsDistTrainOp
(
node
,
send_vars
,
recv_vars
))
{
if
(
fetch_bar
&&
node
->
Name
()
==
"concat"
)
{
ir
::
Node
*
dep_var
=
CreateControlDepVar
();
fetch_bar
->
outputs
.
push_back
(
dep_var
);
dep_var
->
inputs
.
push_back
(
fetch_bar
);
node
->
inputs
.
push_back
(
dep_var
);
dep_var
->
outputs
.
push_back
(
node
);
}
}
}
/**
/**
* We only handle write after read(WAR), since it should not have a write
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* after write in program. If there are write after write operators, we need
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
000ba1ac
...
@@ -40,14 +40,21 @@ class Graph {
...
@@ -40,14 +40,21 @@ class Graph {
attr_dels_
.
clear
();
attr_dels_
.
clear
();
}
}
bool
Has
(
const
std
::
string
&
attr_name
)
const
{
return
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
();
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
Has
(
attr_name
),
"%s attr not registered for graph."
,
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
,
"%s already set in the graph"
,
attr_name
);
attrs_
[
attr_name
]
=
attr
;
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
...
...
paddle/fluid/framework/ir/graph_viz_pass.cc
0 → 100644
浏览文件 @
000ba1ac
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
static
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
std
::
unique_ptr
<
ir
::
Graph
>
GraphVizPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
const
std
::
string
graph_viz_path
=
Get
<
std
::
string
>
(
kGraphVizPath
);
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
graph_viz_path
));
PADDLE_ENFORCE
(
fout
->
good
());
std
::
ostream
&
sout
=
*
fout
;
size_t
var_id
=
0
;
std
::
unordered_map
<
const
ir
::
Node
*
,
size_t
>
vars
;
sout
<<
"digraph G {
\n
"
;
for
(
const
ir
::
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
NodeType
()
!=
ir
::
Node
::
Type
::
kVariable
)
continue
;
size_t
cur_var_id
=
var_id
++
;
vars
[
n
]
=
cur_var_id
;
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
"
<<
n
->
Name
()
<<
"
\"
]"
<<
std
::
endl
;
}
size_t
op_id
=
0
;
for
(
const
ir
::
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
n
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
for
(
auto
in
:
n
->
inputs
)
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
in
]);
sout
<<
var_name
<<
" -> "
<<
op_name
<<
std
::
endl
;
}
for
(
auto
out
:
n
->
outputs
)
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
out
]);
sout
<<
op_name
<<
" -> "
<<
var_name
<<
std
::
endl
;
}
}
sout
<<
"}
\n
"
;
return
graph
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
)
.
RequirePassAttr
(
paddle
::
framework
::
ir
::
kGraphVizPath
);
paddle/fluid/framework/ir/graph_viz_pass.h
0 → 100644
浏览文件 @
000ba1ac
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
GraphVizPass
:
public
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/pass.cc
浏览文件 @
000ba1ac
...
@@ -13,7 +13,34 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,34 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{}
// namespace framework
namespace
framework
{
namespace
ir
{
std
::
unique_ptr
<
Graph
>
Pass
::
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
PADDLE_ENFORCE
(
!
applied_
,
"Pass can only Apply() once."
);
PADDLE_ENFORCE
(
graph
.
get
(),
"graph passed to Pass::Apply() cannot be empty."
);
for
(
const
std
::
string
&
attr
:
required_pass_attrs_
)
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr
)
!=
attrs_
.
end
(),
"Required pass atrribute %s not set."
,
attr
);
}
for
(
const
std
::
string
&
attr
:
required_graph_attrs_
)
{
PADDLE_ENFORCE
(
graph
->
Has
(
attr
),
"Required graph atrribute %s not set."
,
attr
);
}
auto
applied_graph
=
ApplyImpl
(
std
::
move
(
graph
));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE
(
!
HasCircle
(
*
applied_graph
),
"Illegal Pass. Generated graph shouldn't has cycle."
);
applied_
=
true
;
return
applied_graph
;
}
PassRegistry
&
PassRegistry
::
Instance
()
{
static
PassRegistry
g_pass_info_map
;
return
g_pass_info_map
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/pass.h
浏览文件 @
000ba1ac
...
@@ -14,21 +14,187 @@ limitations under the License. */
...
@@ -14,21 +14,187 @@ limitations under the License. */
#pragma once
#pragma once
#include <functional>
#include <map>
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
template
<
typename
PassType
>
struct
PassRegistrar
;
class
Pass
{
class
Pass
{
public:
public:
Pass
()
=
default
;
Pass
()
=
default
;
virtual
~
Pass
()
{}
virtual
~
Pass
()
{
for
(
auto
&
attr
:
attrs_
)
{
if
(
attr_dels_
.
find
(
attr
.
first
)
!=
attr_dels_
.
end
())
{
attr_dels_
[
attr
.
first
]();
}
}
attrs_
.
clear
();
attr_dels_
.
clear
();
}
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
;
// Get a reference to the attributed previously set.
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
(),
"%s attr not registered for pass."
,
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
,
"%s already set in the pass"
,
attr_name
);
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
delete
attr
;
};
}
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
attrs_
[
attr_name
]
=
attr
;
}
protected:
virtual
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
private:
template
<
typename
PassType
>
friend
struct
PassRegistrar
;
void
RegisterRequiredPassAttrs
(
const
std
::
unordered_set
<
std
::
string
>
&
attrs
)
{
required_pass_attrs_
.
insert
(
attrs
.
begin
(),
attrs
.
end
());
}
void
RegisterRequiredGraphAttrs
(
const
std
::
unordered_set
<
std
::
string
>
&
attrs
)
{
required_graph_attrs_
.
insert
(
attrs
.
begin
(),
attrs
.
end
());
}
mutable
bool
applied_
{
false
};
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
};
using
PassCreator
=
std
::
function
<
std
::
unique_ptr
<
Pass
>
()
>
;
class
Registrar
{
public:
// In our design, various kinds of passes,
// have their corresponding registry and registrar. The action of
// registration is in the constructor of a global registrar variable, which
// are not used in the code that calls package framework, and would
// be removed from the generated binary file by the linker. To avoid such
// removal, we add Touch to all registrar classes and make USE_PASS macros to
// call this method. So, as long as the callee code calls USE_PASS, the global
// registrar variable won't be removed by the linker.
void
Touch
()
{}
};
virtual
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
class
PassRegistry
{
public:
static
PassRegistry
&
Instance
();
bool
Has
(
const
std
::
string
&
pass_type
)
const
{
return
map_
.
find
(
pass_type
)
!=
map_
.
end
();
}
void
Insert
(
const
std
::
string
&
pass_type
,
const
PassCreator
&
pass_creator
)
{
PADDLE_ENFORCE
(
!
Has
(
pass_type
),
"Pass %s has been registered"
,
pass_type
);
map_
.
insert
({
pass_type
,
pass_creator
});
}
std
::
unique_ptr
<
Pass
>
Get
(
const
std
::
string
&
pass_type
)
const
{
PADDLE_ENFORCE
(
Has
(
pass_type
),
"Pass %s has not been registered"
,
pass_type
);
return
map_
.
at
(
pass_type
)();
}
private:
PassRegistry
()
=
default
;
std
::
unordered_map
<
std
::
string
,
PassCreator
>
map_
;
DISABLE_COPY_AND_ASSIGN
(
PassRegistry
);
};
};
template
<
typename
PassType
>
struct
PassRegistrar
:
public
Registrar
{
explicit
PassRegistrar
(
const
char
*
pass_type
)
{
PADDLE_ENFORCE
(
!
PassRegistry
::
Instance
().
Has
(
pass_type
),
"'%s' is registered more than once."
,
pass_type
);
PassRegistry
::
Instance
().
Insert
(
pass_type
,
[
this
]()
->
std
::
unique_ptr
<
Pass
>
{
std
::
unique_ptr
<
Pass
>
pass
(
new
PassType
());
pass
->
RegisterRequiredPassAttrs
(
this
->
required_pass_attrs_
);
pass
->
RegisterRequiredGraphAttrs
(
this
->
required_graph_attrs_
);
return
pass
;
});
}
PassRegistrar
<
PassType
>
&
RequirePassAttr
(
const
std
::
string
&
attr
)
{
required_pass_attrs_
.
insert
(
attr
);
return
*
this
;
}
PassRegistrar
<
PassType
>
&
RequireGraphAttr
(
const
std
::
string
&
attr
)
{
required_graph_attrs_
.
insert
(
attr
);
return
*
this
;
}
private:
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \
TouchPassRegistrar_##pass_type()
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/pass_test.cc
0 → 100644
浏览文件 @
000ba1ac
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
BuildCircleGraph
(
Graph
*
g
)
{
ir
::
Node
*
o1
=
g
->
CreateEmptyNode
(
"op1"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o2
=
g
->
CreateEmptyNode
(
"op2"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
v1
=
g
->
CreateEmptyNode
(
"var1"
,
Node
::
Type
::
kVariable
);
ir
::
Node
*
v2
=
g
->
CreateEmptyNode
(
"var2"
,
Node
::
Type
::
kVariable
);
o1
->
outputs
.
push_back
(
v1
);
o2
->
inputs
.
push_back
(
v1
);
v1
->
inputs
.
push_back
(
o1
);
v1
->
outputs
.
push_back
(
o2
);
o2
->
outputs
.
push_back
(
v2
);
o1
->
inputs
.
push_back
(
v2
);
v2
->
inputs
.
push_back
(
o2
);
v2
->
outputs
.
push_back
(
o1
);
}
class
TestPass
:
public
Pass
{
protected:
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
graph
->
Set
<
int
>
(
"copy_test_pass_attr"
,
new
int
);
graph
->
Set
<
int
>
(
"copy_test_graph_attr"
,
new
int
);
int
test_pass_attr
=
this
->
Get
<
int
>
(
"test_pass_attr"
);
graph
->
Get
<
int
>
(
"copy_test_pass_attr"
)
=
test_pass_attr
+
1
;
int
test_graph_attr
=
graph
->
Get
<
int
>
(
"test_graph_attr"
);
graph
->
Get
<
int
>
(
"copy_test_graph_attr"
)
=
test_graph_attr
+
1
;
return
graph
;
}
};
TEST
(
PassTest
,
TestPassAttrCheck
)
{
ProgramDesc
prog
;
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass"
);
std
::
unique_ptr
<
Graph
>
graph
(
new
Graph
(
prog
));
std
::
string
exception
;
try
{
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"test_pass_attr not set"
)
!=
exception
.
npos
);
int
val
=
1
;
graph
.
reset
(
new
Graph
(
prog
));
pass
->
SetNotOwned
<
int
>
(
"test_pass_attr"
,
&
val
);
try
{
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"test_graph_attr not set"
)
!=
exception
.
npos
);
graph
.
reset
(
new
Graph
(
prog
));
graph
->
Set
<
int
>
(
"test_graph_attr"
,
new
int
);
graph
->
Get
<
int
>
(
"test_graph_attr"
)
=
1
;
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
ASSERT_EQ
(
graph
->
Get
<
int
>
(
"copy_test_pass_attr"
),
2
);
ASSERT_EQ
(
graph
->
Get
<
int
>
(
"copy_test_graph_attr"
),
2
);
try
{
graph
=
pass
->
Apply
(
std
::
move
(
graph
));
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"Pass can only Apply() once"
)
!=
exception
.
npos
);
pass
=
PassRegistry
::
Instance
().
Get
(
"test_pass"
);
pass
->
SetNotOwned
<
int
>
(
"test_pass_attr"
,
&
val
);
graph
.
reset
(
new
Graph
(
prog
));
BuildCircleGraph
(
graph
.
get
());
graph
->
Set
<
int
>
(
"test_graph_attr"
,
new
int
);
graph
->
Get
<
int
>
(
"test_graph_attr"
)
=
2
;
try
{
auto
tmp
=
pass
->
Apply
(
std
::
move
(
graph
));
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
exception
=
std
::
string
(
e
.
what
());
}
ASSERT_TRUE
(
exception
.
find
(
"shouldn't has cycle"
)
!=
exception
.
npos
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
test_pass
,
paddle
::
framework
::
ir
::
TestPass
)
.
RequirePassAttr
(
"test_pass_attr"
)
.
RequireGraphAttr
(
"test_graph_attr"
);
paddle/fluid/framework/operator.cc
浏览文件 @
000ba1ac
...
@@ -679,6 +679,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
...
@@ -679,6 +679,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if
(
var
==
nullptr
)
continue
;
if
(
var
==
nullptr
)
continue
;
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
CheckTensorNANOrInf
(
vname
,
var
->
Get
<
framework
::
LoDTensor
>
());
CheckTensorNANOrInf
(
vname
,
var
->
Get
<
framework
::
LoDTensor
>
());
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
CheckTensorNANOrInf
(
vname
,
var
->
Get
<
framework
::
SelectedRows
>
().
value
());
}
}
}
}
}
}
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
000ba1ac
...
@@ -19,19 +19,80 @@ limitations under the License. */
...
@@ -19,19 +19,80 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#endif
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
unique_ptr
<
ir
::
Graph
>
ApplyParallelExecutorPass
(
const
ProgramDesc
&
main_program
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
param_names
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
bool
use_cuda
,
#ifdef PADDLE_WITH_CUDA
const
BuildStrategy
&
strategy
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
#else
const
BuildStrategy
&
strategy
)
{
#endif
// Convert the program to graph.
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
// Apply a graph viz pass to record a graph.
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
const
std
::
string
graph_path
=
string
::
Sprintf
(
"%s%s"
,
strategy
.
debug_graphviz_path_
.
c_str
(),
"_original_graph"
);
viz_pass
->
Set
<
std
::
string
>
(
"graph_viz_path"
,
new
std
::
string
(
graph_path
));
graph
=
viz_pass
->
Apply
(
std
::
move
(
graph
));
}
// Convert graph to run on multi-devices.
auto
multi_device_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_pass"
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places
);
multi_device_pass
->
SetNotOwned
<
const
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name
);
multi_device_pass
->
SetNotOwned
<
const
std
::
unordered_set
<
std
::
string
>>
(
"params"
,
&
param_names
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
&
local_scopes
);
multi_device_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy
);
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
multi_device_pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
#endif
graph
=
multi_device_pass
->
Apply
(
std
::
move
(
graph
));
// Apply a graph print pass to record a graph with device info.
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
multi_device_print_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_print_pass"
);
multi_device_print_pass
->
SetNotOwned
<
const
std
::
string
>
(
"debug_graphviz_path"
,
&
strategy
.
debug_graphviz_path_
);
multi_device_print_pass
->
Set
<
details
::
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
new
details
::
GraphvizSSAGraphPrinter
);
graph
=
multi_device_print_pass
->
Apply
(
std
::
move
(
graph
));
}
// Verify that the graph is correct for multi-device executor.
auto
multi_device_check_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_check_pass"
);
graph
=
multi_device_check_pass
->
Apply
(
std
::
move
(
graph
));
return
graph
;
}
class
ParallelExecutorPrivate
{
class
ParallelExecutorPrivate
{
public:
public:
explicit
ParallelExecutorPrivate
(
const
std
::
vector
<
platform
::
Place
>
&
places
)
explicit
ParallelExecutorPrivate
(
const
std
::
vector
<
platform
::
Place
>
&
places
)
...
@@ -119,21 +180,19 @@ ParallelExecutor::ParallelExecutor(
...
@@ -119,21 +180,19 @@ ParallelExecutor::ParallelExecutor(
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
}
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
// ncclOp
details
::
SSAGraphBuilderFactory
builder_factory
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
build_strategy
);
if
(
member_
->
use_cuda_
)
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
builder_factory
.
SetNCCLContextMap
(
member_
->
nccl_ctxs_
.
get
());
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
ApplyParallelExecutorPass
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
build_strategy
,
member_
->
nccl_ctxs_
.
get
());
#else
#else
PADDLE_THROW
(
"Not compiled with CUDA."
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
ApplyParallelExecutorPass
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
build_strategy
);
#endif
#endif
}
builder_
=
builder_factory
.
Create
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
...
@@ -146,11 +205,18 @@ void ParallelExecutor::BCastParamsToDevices(
...
@@ -146,11 +205,18 @@ void ParallelExecutor::BCastParamsToDevices(
// the initializing bcast, all vars would be bcast from device(0),
// the initializing bcast, all vars would be bcast from device(0),
// otherwise
// otherwise
// bcast from the specified device.
// bcast from the specified device.
bool
initializing
=
builder_
.
get
()
==
nullptr
?
true
:
false
;
bool
initializing
=
member_
->
executor_
?
false
:
true
;
for
(
auto
&
var
:
vars
)
{
for
(
auto
&
var
:
vars
)
{
int
var_dev_id
=
int
var_dev_id
=
-
1
;
builder_
.
get
()
==
nullptr
?
-
1
:
builder_
->
GetVarDeviceID
(
var
);
if
(
member_
->
executor_
)
{
auto
&
sharded_var_device
=
member_
->
executor_
->
Graph
().
Get
<
details
::
ShardedVarDevice
>
(
details
::
kShardedVarDevice
);
if
(
sharded_var_device
.
find
(
var
)
!=
sharded_var_device
.
end
())
{
var_dev_id
=
sharded_var_device
.
at
(
var
);
}
}
if
(
!
initializing
&&
var_dev_id
==
-
1
)
continue
;
if
(
!
initializing
&&
var_dev_id
==
-
1
)
continue
;
framework
::
Variable
*
main_var
=
nullptr
;
framework
::
Variable
*
main_var
=
nullptr
;
...
@@ -286,3 +352,8 @@ ParallelExecutor::~ParallelExecutor() {
...
@@ -286,3 +352,8 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
multi_device_pass
);
USE_PASS
(
multi_device_check_pass
);
USE_PASS
(
multi_device_print_pass
);
paddle/fluid/framework/parallel_executor.h
浏览文件 @
000ba1ac
...
@@ -70,7 +70,6 @@ class ParallelExecutor {
...
@@ -70,7 +70,6 @@ class ParallelExecutor {
private:
private:
ParallelExecutorPrivate
*
member_
;
ParallelExecutorPrivate
*
member_
;
std
::
unique_ptr
<
details
::
SSAGraphBuilder
>
builder_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -14,8 +14,15 @@ cc_library(paddle_fluid_api
...
@@ -14,8 +14,15 @@ cc_library(paddle_fluid_api
get_property
(
fluid_modules GLOBAL PROPERTY FLUID_MODULES
)
get_property
(
fluid_modules GLOBAL PROPERTY FLUID_MODULES
)
# paddle_fluid_origin exclude inference api interface
cc_library
(
paddle_fluid_origin DEPS
${
fluid_modules
}
paddle_fluid_api
)
if
(
NOT APPLE
)
add_subdirectory
(
api
)
endif
()
# Create static library
# Create static library
cc_library
(
paddle_fluid DEPS
${
fluid_modules
}
paddle_fluid_api
)
cc_library
(
paddle_fluid DEPS
${
fluid_modules
}
paddle_fluid_api
paddle_inference_api
)
if
(
NOT APPLE
)
if
(
NOT APPLE
)
# TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac.
# TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac.
set
(
LINK_FLAGS
"-Wl,--retain-symbols-file
${
CMAKE_CURRENT_SOURCE_DIR
}
/paddle_fluid.sym"
)
set
(
LINK_FLAGS
"-Wl,--retain-symbols-file
${
CMAKE_CURRENT_SOURCE_DIR
}
/paddle_fluid.sym"
)
...
@@ -24,7 +31,7 @@ endif()
...
@@ -24,7 +31,7 @@ endif()
# Create shared library
# Create shared library
cc_library
(
paddle_fluid_shared SHARED
cc_library
(
paddle_fluid_shared SHARED
SRCS io.cc
SRCS io.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/api/api.cc
${
CMAKE_CURRENT_SOURCE_DIR
}
/api/api_impl.cc
DEPS
${
fluid_modules
}
paddle_fluid_api
)
DEPS
${
fluid_modules
}
paddle_fluid_api
)
set_target_properties
(
paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid
)
set_target_properties
(
paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid
)
...
@@ -32,12 +39,21 @@ if(NOT APPLE)
...
@@ -32,12 +39,21 @@ if(NOT APPLE)
# TODO(liuyiqun): Temporarily disable the link flag because it is not support on Mac.
# TODO(liuyiqun): Temporarily disable the link flag because it is not support on Mac.
set
(
LINK_FLAGS
"-Wl,--version-script
${
CMAKE_CURRENT_SOURCE_DIR
}
/paddle_fluid.map"
)
set
(
LINK_FLAGS
"-Wl,--version-script
${
CMAKE_CURRENT_SOURCE_DIR
}
/paddle_fluid.map"
)
set_target_properties
(
paddle_fluid_shared PROPERTIES LINK_FLAGS
"
${
LINK_FLAGS
}
"
)
set_target_properties
(
paddle_fluid_shared PROPERTIES LINK_FLAGS
"
${
LINK_FLAGS
}
"
)
# check symbol hidden
FILE
(
WRITE
${
CMAKE_CURRENT_BINARY_DIR
}
/check_symbol.cmake
"execute_process(COMMAND bash -c
\"
${
CMAKE_CURRENT_SOURCE_DIR
}
/check_symbol.sh"
"
${
CMAKE_CURRENT_BINARY_DIR
}
/libpaddle_fluid.so
\"
RESULT_VARIABLE symbol_res)
\n
"
"if(NOT
\"\$
{symbol_res}
\"
STREQUAL
\"
0
\"
)
\n
"
" message(FATAL_ERROR
\"
Check symbol failed.
\"
)
\n
"
"endif()
\n
"
)
add_custom_command
(
OUTPUT
"
${
CMAKE_CURRENT_BINARY_DIR
}
/.check_symbol"
COMMAND
${
CMAKE_COMMAND
}
-P
"
${
CMAKE_CURRENT_BINARY_DIR
}
/check_symbol.cmake"
DEPENDS paddle_fluid_shared
)
add_custom_target
(
check_symbol ALL DEPENDS
"
${
CMAKE_CURRENT_BINARY_DIR
}
/.check_symbol"
)
endif
()
endif
()
if
(
WITH_TESTING
)
if
(
WITH_TESTING
)
# both tests/book and analysis depends the models that generated by python/paddle/fluid/tests/book
# both tests/book and analysis depends the models that generated by python/paddle/fluid/tests/book
add_subdirectory
(
tests/book
)
add_subdirectory
(
tests/book
)
endif
()
endif
()
if
(
NOT APPLE
)
add_subdirectory
(
api
)
endif
()
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -42,35 +42,8 @@ function(inference_api_test TARGET_NAME)
...
@@ -42,35 +42,8 @@ function(inference_api_test TARGET_NAME)
endif
(
WITH_TESTING
)
endif
(
WITH_TESTING
)
endfunction
(
inference_api_test
)
endfunction
(
inference_api_test
)
cc_library
(
paddle_inference_api
cc_library
(
paddle_inference_api SRCS api.cc api_impl.cc DEPS lod_tensor
)
SRCS api.cc api_impl.cc
DEPS
${
FLUID_CORE_MODULES
}
${
GLOB_OP_LIB
}
)
if
(
NOT APPLE
)
set
(
LINK_FLAGS
"-Wl,--retain-symbols-file
${
CMAKE_CURRENT_SOURCE_DIR
}
/api.sym"
)
set_target_properties
(
paddle_inference_api PROPERTIES LINK_FLAGS
"
${
LINK_FLAGS
}
"
)
endif
()
# Here the shared library doesn't depend on other fluid libraries, or double free will occur.
cc_library
(
paddle_inference_api_shared SHARED
SRCS api.cc api_impl.cc
)
add_dependencies
(
paddle_inference_api_shared
${
FLUID_CORE_MODULES
}
${
GLOB_OP_LIB
}
)
set_target_properties
(
paddle_inference_api_shared PROPERTIES OUTPUT_NAME paddle_inference_api
)
if
(
NOT APPLE
)
set
(
LINK_FLAGS
"-Wl,--version-script
${
CMAKE_CURRENT_SOURCE_DIR
}
/api.map"
)
set_target_properties
(
paddle_inference_api_shared PROPERTIES LINK_FLAGS
"
${
LINK_FLAGS
}
"
)
FILE
(
WRITE
${
CMAKE_CURRENT_BINARY_DIR
}
/check_symbol.cmake
"execute_process(COMMAND bash -c
\"
${
CMAKE_CURRENT_SOURCE_DIR
}
/check_symbol.sh"
"
${
CMAKE_CURRENT_BINARY_DIR
}
/libpaddle_inference_api.so
\"
RESULT_VARIABLE symbol_res)
\n
"
"if(NOT
\"\$
{symbol_res}
\"
STREQUAL
\"
0
\"
)
\n
"
" message(FATAL_ERROR
\"
Check symbol failed.
\"
)
\n
"
"endif()
\n
"
)
add_custom_command
(
OUTPUT
"
${
CMAKE_CURRENT_BINARY_DIR
}
/.check_symbol"
COMMAND
${
CMAKE_COMMAND
}
-P
"
${
CMAKE_CURRENT_BINARY_DIR
}
/check_symbol.cmake"
DEPENDS paddle_inference_api_shared
)
add_custom_target
(
check_symbol ALL DEPENDS
"
${
CMAKE_CURRENT_BINARY_DIR
}
/.check_symbol"
)
endif
()
cc_test
(
test_paddle_inference_api
cc_test
(
test_paddle_inference_api
SRCS api_tester.cc
SRCS api_tester.cc
...
...
paddle/fluid/inference/api/api.map
已删除
100644 → 0
浏览文件 @
0c7d6eb8
{
global:
*paddle*;
local:
*;
};
paddle/fluid/inference/api/api.sym
已删除
100644 → 0
浏览文件 @
0c7d6eb8
*paddle*
paddle/fluid/inference/api/demo_ci/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -55,11 +55,9 @@ endif()
...
@@ -55,11 +55,9 @@ endif()
# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a
# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a
if
(
WITH_STATIC_LIB
)
if
(
WITH_STATIC_LIB
)
set
(
DEPS
set
(
DEPS
${
PADDLE_LIB
}
/paddle/fluid/inference/libpaddle_inference_api.a
${
PADDLE_LIB
}
/paddle/fluid/inference/libpaddle_fluid.a
)
${
PADDLE_LIB
}
/paddle/fluid/inference/libpaddle_fluid.a
)
else
()
else
()
set
(
DEPS
set
(
DEPS
${
PADDLE_LIB
}
/paddle/fluid/inference/libpaddle_inference_api.so
${
PADDLE_LIB
}
/paddle/fluid/inference/libpaddle_fluid.so
)
${
PADDLE_LIB
}
/paddle/fluid/inference/libpaddle_fluid.so
)
endif
()
endif
()
set
(
EXTERNAL_LIB
"-lrt -ldl -lpthread"
)
set
(
EXTERNAL_LIB
"-lrt -ldl -lpthread"
)
...
...
paddle/fluid/inference/
api/
check_symbol.sh
→
paddle/fluid/inference/check_symbol.sh
浏览文件 @
000ba1ac
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
lib
=
$1
lib
=
$1
if
[
$#
-ne
1
]
;
then
echo
"No input library"
;
exit
-1
;
fi
if
[
$#
-ne
1
]
;
then
echo
"No input library"
;
exit
-1
;
fi
num_paddle_syms
=
$(
nm
-D
--defined-only
${
lib
}
|
grep
paddle |
wc
-l
)
num_paddle_syms
=
$(
nm
-D
${
lib
}
|
grep
paddle |
wc
-l
)
num_google_syms
=
$(
nm
-D
--defined-only
${
lib
}
|
grep
google
|
wc
-l
)
num_google_syms
=
$(
nm
-D
${
lib
}
|
grep
google |
grep
-v
paddle |
grep
T
|
wc
-l
)
if
[
$num_paddle_syms
-le
0
]
;
then
echo
"Have no paddle symbols"
;
exit
-1
;
fi
if
[
$num_paddle_syms
-le
0
]
;
then
echo
"Have no paddle symbols"
;
exit
-1
;
fi
if
[
$num_google_syms
-ge
1
]
;
then
echo
"Have some google symbols"
;
exit
-1
;
fi
if
[
$num_google_syms
-ge
1
]
;
then
echo
"Have some google symbols"
;
exit
-1
;
fi
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
000ba1ac
# Add TRT tests
# Add TRT tests
nv_library
(
tensorrt_converter
nv_library
(
tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc
SRCS mul_op.cc conv2d_op.cc fc_op.cc
pool2d_op.cc
DEPS tensorrt_engine
mul_op
)
DEPS tensorrt_engine
operator scope framework_proto op_registry
)
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
nv_test
(
test_op_converter SRCS test_op_converter.cc DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine tensorrt_converter
)
${
FLUID_CORE_MODULES
}
tensorrt_engine tensorrt_converter
)
...
@@ -13,3 +13,6 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
...
@@ -13,3 +13,6 @@ nv_test(test_trt_fc_op SRCS test_fc_op.cc fc_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine mul_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine mul_op SERIAL
)
nv_test
(
test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
nv_test
(
test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine activation_op SERIAL
)
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine activation_op SERIAL
)
nv_test
(
test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
DEPS
${
FLUID_CORE_MODULES
}
tensorrt_engine pool_op SERIAL
)
paddle/fluid/inference/tensorrt/convert/mul_op.cc
浏览文件 @
000ba1ac
...
@@ -49,5 +49,4 @@ class MulOpConverter : public OpConverter {
...
@@ -49,5 +49,4 @@ class MulOpConverter : public OpConverter {
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
USE_OP
(
mul
);
REGISTER_TRT_OP_CONVERTER
(
mul
,
MulOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
mul
,
MulOpConverter
);
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
0 → 100644
浏览文件 @
000ba1ac
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* Pool2dOp, IPoolingLayer in TRT. This Layer doesn't has weights.
*/
class
Pool2dOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
4
)
<<
"convert a fluid pool2d op to tensorrt pool2d layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
std
::
string
pool_type
=
boost
::
get
<
std
::
string
>
(
op_desc
.
GetAttr
(
"pooling_type"
));
std
::
vector
<
int
>
ksize
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"ksize"
));
std
::
vector
<
int
>
strides
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"strides"
));
std
::
vector
<
int
>
paddings
=
boost
::
get
<
std
::
vector
<
int
>>
(
op_desc
.
GetAttr
(
"paddings"
));
const
nvinfer1
::
DimsHW
nv_ksize
(
ksize
[
0
],
ksize
[
1
]);
const
nvinfer1
::
DimsHW
nv_strides
(
strides
[
0
],
strides
[
1
]);
const
nvinfer1
::
DimsHW
nv_paddings
(
paddings
[
0
],
paddings
[
1
]);
PADDLE_ENFORCE_EQ
(
input1
->
getDimensions
().
nbDims
,
3UL
);
nvinfer1
::
PoolingType
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
if
(
pool_type
==
"max"
)
{
nv_pool_type
=
nvinfer1
::
PoolingType
::
kMAX
;
}
else
if
(
pool_type
==
"avg"
)
{
nv_pool_type
=
nvinfer1
::
PoolingType
::
kAVERAGE
;
}
else
{
PADDLE_THROW
(
"TensorRT unsupported pooling type!"
);
}
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
nv_pool_type
,
nv_ksize
);
PADDLE_ENFORCE_NOT_NULL
(
layer
,
"pool layer could not be created."
);
layer
->
setStride
(
nv_strides
);
layer
->
setPadding
(
nv_paddings
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
test_mode
)
{
engine_
->
DeclareOutput
(
output_name
);
}
}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
USE_OP
(
pool2d
);
REGISTER_TRT_OP_CONVERTER
(
pool2d
,
Pool2dOpConverter
);
paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
浏览文件 @
000ba1ac
...
@@ -37,7 +37,7 @@ TEST(ReluOpConverter, main) {
...
@@ -37,7 +37,7 @@ TEST(ReluOpConverter, main) {
validator
.
SetOp
(
*
desc
.
Proto
());
validator
.
SetOp
(
*
desc
.
Proto
());
LOG
(
INFO
)
<<
"execute"
;
LOG
(
INFO
)
<<
"execute"
;
validator
.
Execute
(
1
);
validator
.
Execute
(
5
);
}
}
}
// namespace tensorrt
}
// namespace tensorrt
...
...
paddle/fluid/inference/tensorrt/convert/test_fc_op.cc
浏览文件 @
000ba1ac
...
@@ -24,9 +24,8 @@ TEST(fc_op, test) {
...
@@ -24,9 +24,8 @@ TEST(fc_op, test) {
std
::
unordered_set
<
std
::
string
>
parameters
({
"mul-Y"
});
std
::
unordered_set
<
std
::
string
>
parameters
({
"mul-Y"
});
framework
::
Scope
scope
;
framework
::
Scope
scope
;
TRTConvertValidation
validator
(
10
,
parameters
,
scope
,
1000
);
TRTConvertValidation
validator
(
10
,
parameters
,
scope
,
1000
);
validator
.
DeclInputVar
(
"mul-X"
,
nvinfer1
::
Dims
4
(
1
,
10
,
1
,
1
));
validator
.
DeclInputVar
(
"mul-X"
,
nvinfer1
::
Dims
3
(
10
,
1
,
1
));
validator
.
DeclParamVar
(
"mul-Y"
,
nvinfer1
::
Dims2
(
10
,
2
));
validator
.
DeclParamVar
(
"mul-Y"
,
nvinfer1
::
Dims2
(
10
,
2
));
// validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2));
validator
.
DeclOutputVar
(
"mul-Out"
,
nvinfer1
::
Dims2
(
1
,
2
));
validator
.
DeclOutputVar
(
"mul-Out"
,
nvinfer1
::
Dims2
(
1
,
2
));
// Prepare Op description
// Prepare Op description
...
@@ -38,7 +37,7 @@ TEST(fc_op, test) {
...
@@ -38,7 +37,7 @@ TEST(fc_op, test) {
validator
.
SetOp
(
*
desc
.
Proto
());
validator
.
SetOp
(
*
desc
.
Proto
());
validator
.
Execute
(
1
);
validator
.
Execute
(
1
0
);
}
}
}
// namespace tensorrt
}
// namespace tensorrt
...
...
paddle/fluid/inference/tensorrt/convert/test_mul_op.cc
浏览文件 @
000ba1ac
...
@@ -23,7 +23,7 @@ namespace tensorrt {
...
@@ -23,7 +23,7 @@ namespace tensorrt {
TEST
(
MulOpConverter
,
main
)
{
TEST
(
MulOpConverter
,
main
)
{
framework
::
Scope
scope
;
framework
::
Scope
scope
;
std
::
unordered_set
<
std
::
string
>
parameters
;
std
::
unordered_set
<
std
::
string
>
parameters
;
TRTConvertValidation
validator
(
10
,
parameters
,
scope
,
1000
);
TRTConvertValidation
validator
(
10
,
parameters
,
scope
,
1000
,
false
);
validator
.
DeclInputVar
(
"mul-X"
,
nvinfer1
::
Dims2
(
10
,
6
));
validator
.
DeclInputVar
(
"mul-X"
,
nvinfer1
::
Dims2
(
10
,
6
));
validator
.
DeclInputVar
(
"mul-Y"
,
nvinfer1
::
Dims2
(
6
,
10
));
validator
.
DeclInputVar
(
"mul-Y"
,
nvinfer1
::
Dims2
(
6
,
10
));
validator
.
DeclOutputVar
(
"mul-Out"
,
nvinfer1
::
Dims2
(
10
,
10
));
validator
.
DeclOutputVar
(
"mul-Out"
,
nvinfer1
::
Dims2
(
10
,
10
));
...
@@ -39,7 +39,7 @@ TEST(MulOpConverter, main) {
...
@@ -39,7 +39,7 @@ TEST(MulOpConverter, main) {
validator
.
SetOp
(
*
desc
.
Proto
());
validator
.
SetOp
(
*
desc
.
Proto
());
LOG
(
INFO
)
<<
"execute"
;
LOG
(
INFO
)
<<
"execute"
;
validator
.
Execute
(
1
);
validator
.
Execute
(
2
);
}
}
}
// namespace tensorrt
}
// namespace tensorrt
...
...
paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc
0 → 100644
浏览文件 @
000ba1ac
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <fstream>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
TEST
(
Pool2dOpConverter
,
main
)
{
framework
::
Scope
scope
;
std
::
unordered_set
<
std
::
string
>
parameters
;
TRTConvertValidation
validator
(
5
,
parameters
,
scope
,
1
<<
15
);
// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
validator
.
DeclInputVar
(
"pool2d-X"
,
nvinfer1
::
Dims3
(
3
,
4
,
4
));
validator
.
DeclOutputVar
(
"pool2d-Out"
,
nvinfer1
::
Dims3
(
3
,
2
,
2
));
// Prepare Op description
framework
::
OpDesc
desc
;
desc
.
SetType
(
"pool2d"
);
desc
.
SetInput
(
"X"
,
{
"pool2d-X"
});
desc
.
SetOutput
(
"Out"
,
{
"pool2d-Out"
});
std
::
vector
<
int
>
ksize
({
2
,
2
});
std
::
vector
<
int
>
strides
({
2
,
2
});
std
::
vector
<
int
>
paddings
({
0
,
0
});
std
::
string
pooling_t
=
"max"
;
desc
.
SetAttr
(
"pooling_type"
,
pooling_t
);
desc
.
SetAttr
(
"ksize"
,
ksize
);
desc
.
SetAttr
(
"strides"
,
strides
);
desc
.
SetAttr
(
"paddings"
,
paddings
);
LOG
(
INFO
)
<<
"set OP"
;
validator
.
SetOp
(
*
desc
.
Proto
());
LOG
(
INFO
)
<<
"execute"
;
validator
.
Execute
(
3
);
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
USE_OP
(
pool2d
);
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
000ba1ac
...
@@ -63,13 +63,16 @@ class TRTConvertValidation {
...
@@ -63,13 +63,16 @@ class TRTConvertValidation {
public:
public:
TRTConvertValidation
()
=
delete
;
TRTConvertValidation
()
=
delete
;
TRTConvertValidation
(
int
batch_size
,
TRTConvertValidation
(
int
max_
batch_size
,
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
const
std
::
unordered_set
<
std
::
string
>&
parameters
,
framework
::
Scope
&
scope
,
// NOLINT
framework
::
Scope
&
scope
,
// NOLINT
int
workspace_size
=
1
<<
10
)
int
workspace_size
=
1
<<
10
,
bool
if_add_batch
=
true
)
:
parameters_
(
parameters
),
scope_
(
scope
)
{
:
parameters_
(
parameters
),
scope_
(
scope
),
if_add_batch_
(
if_add_batch
),
max_batch_size_
(
max_batch_size
)
{
// create engine.
// create engine.
engine_
.
reset
(
new
TensorRTEngine
(
batch_size
,
workspace_size
,
&
stream_
));
engine_
.
reset
(
new
TensorRTEngine
(
max_
batch_size
,
workspace_size
,
&
stream_
));
engine_
->
InitNetwork
();
engine_
->
InitNetwork
();
PADDLE_ENFORCE_EQ
(
cudaStreamCreate
(
&
stream_
),
0
);
PADDLE_ENFORCE_EQ
(
cudaStreamCreate
(
&
stream_
),
0
);
...
@@ -84,7 +87,7 @@ class TRTConvertValidation {
...
@@ -84,7 +87,7 @@ class TRTConvertValidation {
// Declare a parameter varaible in the scope.
// Declare a parameter varaible in the scope.
void
DeclParamVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
void
DeclParamVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
DeclVar
(
name
,
dims
);
DeclVar
(
name
,
dims
,
true
);
}
}
void
DeclOutputVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
void
DeclOutputVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
...
@@ -92,12 +95,18 @@ class TRTConvertValidation {
...
@@ -92,12 +95,18 @@ class TRTConvertValidation {
}
}
// Declare a variable in a fluid Scope.
// Declare a variable in a fluid Scope.
void
DeclVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
)
{
void
DeclVar
(
const
std
::
string
&
name
,
const
nvinfer1
::
Dims
&
dims
,
bool
is_param
=
false
)
{
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
platform
::
CPUDeviceContext
ctx
(
place
);
// Init Fluid tensor.
// Init Fluid tensor.
std
::
vector
<
int
>
dim_vec
(
dims
.
d
,
dims
.
d
+
dims
.
nbDims
);
std
::
vector
<
int
>
dim_vec
(
dims
.
d
,
dims
.
d
+
dims
.
nbDims
);
// There is no batchsize in ITensor's shape, but We should add it to
// tensor's shape of fluid. If the variable is not parameter and the
// if_add_batch_ flag is true, add the max batchsize to dim_vec.
if
(
is_param
!=
true
&&
if_add_batch_
==
true
)
dim_vec
.
insert
(
dim_vec
.
begin
(),
max_batch_size_
);
auto
*
x
=
scope_
.
Var
(
name
);
auto
*
x
=
scope_
.
Var
(
name
);
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
x_tensor
=
x
->
GetMutable
<
framework
::
LoDTensor
>
();
x_tensor
->
Resize
(
framework
::
make_ddim
(
dim_vec
));
x_tensor
->
Resize
(
framework
::
make_ddim
(
dim_vec
));
...
@@ -131,6 +140,7 @@ class TRTConvertValidation {
...
@@ -131,6 +140,7 @@ class TRTConvertValidation {
void
Execute
(
int
batch_size
)
{
void
Execute
(
int
batch_size
)
{
// Execute Fluid Op
// Execute Fluid Op
PADDLE_ENFORCE_LE
(
batch_size
,
max_batch_size_
);
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
platform
::
CPUDeviceContext
ctx
(
place
);
op_
->
Run
(
scope_
,
place
);
op_
->
Run
(
scope_
,
place
);
...
@@ -149,9 +159,15 @@ class TRTConvertValidation {
...
@@ -149,9 +159,15 @@ class TRTConvertValidation {
auto
*
var
=
scope_
.
FindVar
(
output
);
auto
*
var
=
scope_
.
FindVar
(
output
);
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
TensorToVector
(
*
tensor
,
ctx
,
&
fluid_out
);
framework
::
TensorToVector
(
*
tensor
,
ctx
,
&
fluid_out
);
size_t
fluid_out_size
=
fluid_out
.
size
();
if
(
if_add_batch_
==
true
)
{
fluid_out_size
=
batch_size
*
(
framework
::
product
(
tensor
->
dims
())
/
max_batch_size_
);
}
// Compare two output
// Compare two output
ASSERT_FALSE
(
fluid_out
.
empty
());
ASSERT_FALSE
(
fluid_out
.
empty
());
for
(
size_t
i
=
0
;
i
<
fluid_out
.
size
()
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
fluid_out
_size
;
i
++
)
{
// Loose the threshold for CI in different machine model.
// Loose the threshold for CI in different machine model.
EXPECT_LT
(
std
::
abs
(
fluid_out
[
i
]
-
trt_out
[
i
]),
2e-5
);
EXPECT_LT
(
std
::
abs
(
fluid_out
[
i
]
-
trt_out
[
i
]),
2e-5
);
}
}
...
@@ -167,6 +183,12 @@ class TRTConvertValidation {
...
@@ -167,6 +183,12 @@ class TRTConvertValidation {
std
::
unique_ptr
<
framework
::
OpDesc
>
op_desc_
;
std
::
unique_ptr
<
framework
::
OpDesc
>
op_desc_
;
const
std
::
unordered_set
<
std
::
string
>&
parameters_
;
const
std
::
unordered_set
<
std
::
string
>&
parameters_
;
framework
::
Scope
&
scope_
;
framework
::
Scope
&
scope_
;
// The ITensor of trt does not cotain the batch size,
// bug, in most cases, we need to set batch size for
// fluid's tensor shape. This variable indicates
// whether to add batch size to tensor shape of fluid.
bool
if_add_batch_
;
int
max_batch_size_
;
};
};
}
// namespace tensorrt
}
// namespace tensorrt
...
...
paddle/fluid/inference/tensorrt/test_engine.cc
浏览文件 @
000ba1ac
...
@@ -113,7 +113,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
...
@@ -113,7 +113,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
ASSERT_EQ
(
y_cpu
[
1
],
14.5
);
ASSERT_EQ
(
y_cpu
[
1
],
14.5
);
}
}
TEST_F
(
TensorRTEngineTest
,
test_conv2d
_temp
)
{
TEST_F
(
TensorRTEngineTest
,
test_conv2d
)
{
// Weight in CPU memory.
// Weight in CPU memory.
float
raw_weight
[
9
]
=
{
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
};
float
raw_weight
[
9
]
=
{
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
};
float
raw_bias
[
1
]
=
{
0
};
float
raw_bias
[
1
]
=
{
0
};
...
@@ -146,6 +146,37 @@ TEST_F(TensorRTEngineTest, test_conv2d_temp) {
...
@@ -146,6 +146,37 @@ TEST_F(TensorRTEngineTest, test_conv2d_temp) {
ASSERT_EQ
(
y_cpu
[
1
],
6.0
);
ASSERT_EQ
(
y_cpu
[
1
],
6.0
);
}
}
TEST_F
(
TensorRTEngineTest
,
test_pool2d
)
{
// Weight in CPU memory.
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
nvinfer1
::
Dims3
{
1
,
2
,
2
});
nvinfer1
::
PoolingType
pool_t
=
nvinfer1
::
PoolingType
::
kAVERAGE
;
auto
*
pool_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Pooling
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
x
),
pool_t
,
nvinfer1
::
DimsHW
{
2
,
2
});
PADDLE_ENFORCE
(
pool_layer
!=
nullptr
);
pool_layer
->
setStride
(
nvinfer1
::
DimsHW
{
1
,
1
});
pool_layer
->
setPadding
(
nvinfer1
::
DimsHW
{
0
,
0
});
engine_
->
DeclareOutput
(
pool_layer
,
0
,
"y"
);
engine_
->
FreezeNetwork
();
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
2
);
float
x_v
[
8
]
=
{
1.0
,
2.0
,
5.0
,
0.0
,
2.0
,
3.0
,
5.0
,
10.0
};
engine_
->
SetInputFromCPU
(
"x"
,
reinterpret_cast
<
void
*>
(
&
x_v
),
8
*
sizeof
(
float
));
engine_
->
Execute
(
2
);
LOG
(
INFO
)
<<
"to get output"
;
float
*
y_cpu
=
new
float
[
2
];
engine_
->
GetOutputInCPU
(
"y"
,
&
y_cpu
[
0
],
2
*
sizeof
(
float
));
ASSERT_EQ
(
y_cpu
[
0
],
2.0
);
ASSERT_EQ
(
y_cpu
[
1
],
5.0
);
}
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/tests/book/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -17,7 +17,7 @@ function(inference_test TARGET_NAME)
...
@@ -17,7 +17,7 @@ function(inference_test TARGET_NAME)
string
(
REGEX REPLACE
"^_$"
""
arg
"
${
arg
}
"
)
string
(
REGEX REPLACE
"^_$"
""
arg
"
${
arg
}
"
)
cc_test
(
test_inference_
${
TARGET_NAME
}${
arg
}
cc_test
(
test_inference_
${
TARGET_NAME
}${
arg
}
SRCS test_inference_
${
TARGET_NAME
}
.cc
SRCS test_inference_
${
TARGET_NAME
}
.cc
DEPS paddle_fluid
DEPS paddle_fluid
_origin
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book/
${
TARGET_NAME
}${
arg
}
.inference.model
)
ARGS --dirname=
${
PYTHON_TESTS_DIR
}
/book/
${
TARGET_NAME
}${
arg
}
.inference.model
)
set_tests_properties
(
test_inference_
${
TARGET_NAME
}${
arg
}
set_tests_properties
(
test_inference_
${
TARGET_NAME
}${
arg
}
PROPERTIES DEPENDS test_
${
TARGET_NAME
}
)
PROPERTIES DEPENDS test_
${
TARGET_NAME
}
)
...
@@ -43,6 +43,6 @@ inference_test(word2vec)
...
@@ -43,6 +43,6 @@ inference_test(word2vec)
# TODO(TJ): clean me up
# TODO(TJ): clean me up
cc_test
(
test_inference_nlp
cc_test
(
test_inference_nlp
SRCS test_inference_nlp.cc
SRCS test_inference_nlp.cc
DEPS paddle_fluid
DEPS paddle_fluid
_origin
ARGS
ARGS
--model_path=
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/tests/book/recognize_digits_mlp.inference.model
)
--model_path=
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/tests/book/recognize_digits_mlp.inference.model
)
paddle/fluid/inference/tests/book/test_inference_nlp.cc
浏览文件 @
000ba1ac
...
@@ -20,9 +20,6 @@ limitations under the License. */
...
@@ -20,9 +20,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#ifdef PADDLE_WITH_MKLML
#include <omp.h>
#endif
DEFINE_string
(
model_path
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
model_path
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
data_file
,
""
,
"File of input index data."
);
DEFINE_string
(
data_file
,
""
,
"File of input index data."
);
...
@@ -30,6 +27,7 @@ DEFINE_int32(repeat, 100, "Running the inference program repeat times");
...
@@ -30,6 +27,7 @@ DEFINE_int32(repeat, 100, "Running the inference program repeat times");
DEFINE_bool
(
prepare_vars
,
true
,
"Prepare variables before executor"
);
DEFINE_bool
(
prepare_vars
,
true
,
"Prepare variables before executor"
);
DEFINE_int32
(
num_threads
,
1
,
"Number of threads should be used"
);
DEFINE_int32
(
num_threads
,
1
,
"Number of threads should be used"
);
DECLARE_bool
(
use_mkldnn
);
DECLARE_bool
(
use_mkldnn
);
DECLARE_int32
(
paddle_num_threads
);
inline
double
GetCurrentMs
()
{
inline
double
GetCurrentMs
()
{
struct
timeval
time
;
struct
timeval
time
;
...
@@ -160,12 +158,7 @@ TEST(inference, nlp) {
...
@@ -160,12 +158,7 @@ TEST(inference, nlp) {
std
::
unique_ptr
<
paddle
::
framework
::
Scope
>
scope
(
std
::
unique_ptr
<
paddle
::
framework
::
Scope
>
scope
(
new
paddle
::
framework
::
Scope
());
new
paddle
::
framework
::
Scope
());
#ifdef PADDLE_WITH_MKLML
paddle
::
platform
::
SetNumThreads
(
FLAGS_paddle_num_threads
);
// only use 1 thread number per std::thread
omp_set_dynamic
(
0
);
omp_set_num_threads
(
1
);
paddle
::
platform
::
SetNumThreads
(
1
);
#endif
double
start_ms
=
0
,
stop_ms
=
0
;
double
start_ms
=
0
,
stop_ms
=
0
;
if
(
FLAGS_num_threads
>
1
)
{
if
(
FLAGS_num_threads
>
1
)
{
...
...
paddle/fluid/operators/.flatten_op.cc.swp
0 → 100644
浏览文件 @
000ba1ac
文件已添加
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -270,6 +270,9 @@ op_library(cos_sim_op DEPS cos_sim_functor)
...
@@ -270,6 +270,9 @@ op_library(cos_sim_op DEPS cos_sim_functor)
op_library
(
parallel_do_op DEPS executor
)
op_library
(
parallel_do_op DEPS executor
)
op_library
(
unsqueeze_op DEPS reshape_op
)
op_library
(
unsqueeze_op DEPS reshape_op
)
op_library
(
squeeze_op DEPS reshape_op
)
op_library
(
squeeze_op DEPS reshape_op
)
op_library
(
extract_rows_op DEPS memory
)
op_library
(
flatten_op DEPS reshape_op
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
op_library
(
conv_op DEPS vol2col depthwise_conv im2col
)
op_library
(
conv_op DEPS vol2col depthwise_conv im2col
)
...
...
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -19,7 +19,7 @@ if(WITH_GRPC)
...
@@ -19,7 +19,7 @@ if(WITH_GRPC)
cc_test
(
grpc_serde_test SRCS grpc_serde_test.cc
cc_test
(
grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL
)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL
)
cc_test
(
rpc_server_test SRCS rpc_server_test.cc
cc_test
(
rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL
)
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_
sparse_
table_op SERIAL
)
return
()
return
()
endif
()
endif
()
...
...
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
000ba1ac
...
@@ -49,6 +49,7 @@ void GRPCClient::SendComplete() {
...
@@ -49,6 +49,7 @@ void GRPCClient::SendComplete() {
}
}
GRPCClient
::~
GRPCClient
()
{
GRPCClient
::~
GRPCClient
()
{
stopped_
=
true
;
Wait
();
Wait
();
cq_
.
Shutdown
();
cq_
.
Shutdown
();
{
{
...
@@ -275,7 +276,7 @@ void GRPCClient::Proceed() {
...
@@ -275,7 +276,7 @@ void GRPCClient::Proceed() {
void
*
tag
=
nullptr
;
void
*
tag
=
nullptr
;
bool
ok
=
false
;
bool
ok
=
false
;
while
(
cq_
.
Next
(
&
tag
,
&
ok
))
{
while
(
!
stopped_
&&
cq_
.
Next
(
&
tag
,
&
ok
))
{
BaseProcessor
*
c
=
static_cast
<
BaseProcessor
*>
(
tag
);
BaseProcessor
*
c
=
static_cast
<
BaseProcessor
*>
(
tag
);
GPR_ASSERT
(
ok
);
GPR_ASSERT
(
ok
);
PADDLE_ENFORCE
(
c
);
PADDLE_ENFORCE
(
c
);
...
...
paddle/fluid/operators/distributed/grpc_client.h
浏览文件 @
000ba1ac
...
@@ -174,7 +174,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
...
@@ -174,7 +174,7 @@ class CheckpointNotifyProcessor : public BaseProcessor {
class
GRPCClient
:
public
RPCClient
{
class
GRPCClient
:
public
RPCClient
{
public:
public:
GRPCClient
()
:
ok_
(
true
),
completed_
(
false
)
{}
GRPCClient
()
:
ok_
(
true
),
completed_
(
false
)
,
stopped_
(
false
)
{}
virtual
~
GRPCClient
();
virtual
~
GRPCClient
();
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
bool
AsyncSendVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
...
@@ -237,6 +237,8 @@ class GRPCClient : public RPCClient {
...
@@ -237,6 +237,8 @@ class GRPCClient : public RPCClient {
// mutex for sending complete message only once
// mutex for sending complete message only once
std
::
mutex
completed_mutex_
;
std
::
mutex
completed_mutex_
;
bool
completed_
;
bool
completed_
;
volatile
bool
stopped_
;
};
};
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/operators/distributed/rpc_server_test.cc
浏览文件 @
000ba1ac
...
@@ -30,7 +30,7 @@ namespace framework = paddle::framework;
...
@@ -30,7 +30,7 @@ namespace framework = paddle::framework;
namespace
platform
=
paddle
::
platform
;
namespace
platform
=
paddle
::
platform
;
namespace
distributed
=
paddle
::
operators
::
distributed
;
namespace
distributed
=
paddle
::
operators
::
distributed
;
USE_
OP
(
lookup
_table
);
USE_
NO_KERNEL_OP
(
lookup_sparse
_table
);
std
::
unique_ptr
<
distributed
::
RPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
distributed
::
RPCServer
>
g_rpc_service
;
std
::
unique_ptr
<
distributed
::
RequestHandler
>
g_req_handler
;
std
::
unique_ptr
<
distributed
::
RequestHandler
>
g_req_handler
;
...
@@ -42,13 +42,13 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
...
@@ -42,13 +42,13 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
framework
::
VariableNameMap
input
({{
"W"
,
{
"w"
}},
{
"Ids"
,
{
"ids"
}}});
framework
::
VariableNameMap
input
({{
"W"
,
{
"w"
}},
{
"Ids"
,
{
"ids"
}}});
framework
::
VariableNameMap
output
({{
"Output"
,
{
"out"
}}});
framework
::
VariableNameMap
output
({{
"Output"
,
{
"out"
}}});
auto
op
=
block
->
AppendOp
();
auto
op
=
block
->
AppendOp
();
op
->
SetType
(
"lookup_table"
);
op
->
SetType
(
"lookup_
sparse_
table"
);
op
->
SetInput
(
"W"
,
{
"w"
});
op
->
SetInput
(
"W"
,
{
"w"
});
op
->
SetInput
(
"Ids"
,
{
"ids"
});
op
->
SetInput
(
"Ids"
,
{
"ids"
});
op
->
SetOutput
(
"Out"
,
{
"out"
});
op
->
SetOutput
(
"Out"
,
{
"out"
});
auto
&
out
=
*
root_block
->
Var
(
"out"
);
auto
&
out
=
*
root_block
->
Var
(
"out"
);
out
.
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
out
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
out
.
SetShape
({
10
,
10
});
out
.
SetShape
({
10
,
10
});
return
block
;
return
block
;
...
@@ -59,20 +59,19 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
...
@@ -59,20 +59,19 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
w_var
->
GetMutable
<
framework
::
SelectedRows
>
();
w_var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
out_var
=
scope
->
Var
(
"out"
);
auto
out_var
=
scope
->
Var
(
"out"
);
out_var
->
GetMutable
<
framework
::
SelectedRows
>
();
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
ids_var
=
scope
->
Var
(
"ids"
);
auto
ids_var
=
scope
->
Var
(
"ids"
);
ids_var
->
GetMutable
<
framework
::
SelectedRows
>
();
ids_var
->
GetMutable
<
framework
::
LoDTensor
>
();
}
}
void
InitTensorsOnClient
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
void
InitTensorsOnClient
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
int64_t
rows_numel
)
{
int64_t
rows_numel
)
{
CreateVarsOnScope
(
scope
,
place
);
CreateVarsOnScope
(
scope
,
place
);
auto
ids_var
=
scope
->
Var
(
"ids"
)
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
ids_var
=
scope
->
Var
(
"ids"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
rows
=
ids_var
->
mutable_rows
();
int64_t
*
ids_ptr
=
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
rows
->
push_back
(
i
*
2
);
ids_var
->
mutable_data
<
int64_t
>
(
framework
::
DDim
({
rows_numel
,
1
}),
*
place
);
ids_var
->
mutable_value
()
->
Resize
({
rows_numel
,
1
});
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
ids_ptr
[
i
]
=
i
*
2
;
ids_var
->
mutable_value
()
->
mutable_data
<
float
>
(
*
place
);
}
}
void
InitTensorsOnServer
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
void
InitTensorsOnServer
(
framework
::
Scope
*
scope
,
platform
::
CPUPlace
*
place
,
...
@@ -148,11 +147,11 @@ TEST(PREFETCH, CPU) {
...
@@ -148,11 +147,11 @@ TEST(PREFETCH, CPU) {
client
->
AsyncPrefetchVar
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
AsyncPrefetchVar
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
Wait
();
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
value
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
auto
ptr
=
value
->
mutable_data
<
float
>
(
place
);
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
rows_numel
;
++
i
)
{
EXPECT_EQ
(
ptr
[
0
+
i
*
value
.
dims
()[
1
]],
static_cast
<
float
>
(
i
*
2
));
EXPECT_EQ
(
ptr
[
0
+
i
*
value
->
dims
()[
1
]],
static_cast
<
float
>
(
i
*
2
));
}
}
}
}
...
...
paddle/fluid/operators/extract_rows_op.cc
0 → 100644
浏览文件 @
000ba1ac
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
class
ExtractRowsOpInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ExtractRowsOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of ExtractRowsOp should not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputsVarType
(
"X"
)[
0
],
framework
::
proto
::
VarType
::
SELECTED_ROWS
,
"The type of input(X) must be SelectedRows."
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
std
::
vector
<
int64_t
>
{
in_dims
[
0
],
1
}));
}
};
class
ExtractRowsOp
:
public
framework
::
OperatorBase
{
public:
ExtractRowsOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
in
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
SelectedRows
>
();
auto
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
in_rows
=
in
.
rows
();
auto
out_dim
=
framework
::
make_ddim
(
std
::
vector
<
int64_t
>
{
static_cast
<
int64_t
>
(
in_rows
.
size
()),
1
});
auto
dst_ptr
=
out
->
mutable_data
<
int64_t
>
(
out_dim
,
in
.
place
());
if
(
paddle
::
platform
::
is_gpu_place
(
in
.
place
()))
{
#ifdef PADDLE_WITH_CUDA
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
dev_ctx
=
pool
.
Get
(
in
.
place
());
auto
src_ptr
=
in_rows
.
Data
(
in
.
place
());
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
*
dev_ctx
)
.
stream
();
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
out
->
place
()),
dst_ptr
,
boost
::
get
<
platform
::
CUDAPlace
>
(
in
.
place
()),
src_ptr
,
in_rows
.
size
()
*
sizeof
(
int64_t
),
stream
);
#else
PADDLE_THROW
(
"Not compiled with CUDA."
);
#endif
}
else
{
memory
::
Copy
(
platform
::
CPUPlace
(),
dst_ptr
,
platform
::
CPUPlace
(),
in_rows
.
data
(),
in_rows
.
size
()
*
sizeof
(
int64_t
));
}
}
};
class
ExtractRowsOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(SelectedRows). The input tensor of extract_rows operator,"
" and its type is SelectedRows."
);
AddOutput
(
"Out"
,
"(Tensor). The the rows of input(X)."
);
AddComment
(
R"DOC(
ExtractRows Operator.
The function of extract_rows_op is extracting the rows from the input(X)
whose type is SelectedRows.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
extract_rows
,
ops
::
ExtractRowsOp
,
ops
::
ExtractRowsOpMaker
,
ops
::
ExtractRowsOpInferShape
);
paddle/fluid/operators/flatten_op.cc
0 → 100644
浏览文件 @
000ba1ac
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
FlattenOpInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input (X) of Flatten op should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output (Output) of Flatten op should not be null."
);
const
auto
&
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
const
auto
&
in_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE
(
axis
>=
0
,
"The axis should be greater than or equal to 0."
);
PADDLE_ENFORCE
(
axis
<=
in_dims
.
size
(),
"The axis should be less than or equal to input tensor's rank."
);
const
auto
&
out_dims
=
GetOutputShape
(
axis
,
in_dims
);
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
out_dims
));
if
(
in_dims
[
0
]
==
out_dims
[
0
])
{
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx
->
ShareLoD
(
"X"
,
"Out"
);
}
}
static
std
::
vector
<
int32_t
>
GetOutputShape
(
const
int
axis
,
const
framework
::
DDim
&
in_dims
)
{
int64_t
outer
=
1
,
inner
=
1
;
for
(
int
i
=
0
;
i
<
in_dims
.
size
();
++
i
)
{
if
(
i
<
axis
)
{
outer
*=
in_dims
[
i
];
}
else
{
inner
*=
in_dims
[
i
];
}
}
std
::
vector
<
int32_t
>
out_shape
(
2
);
out_shape
[
0
]
=
outer
;
out_shape
[
1
]
=
inner
;
return
out_shape
;
}
};
class
FlattenOp
:
public
framework
::
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
&
axis
=
Attr
<
int
>
(
"axis"
);
auto
in_dims
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
().
dims
();
const
auto
&
out_dims
=
FlattenOpInferShape
::
GetOutputShape
(
axis
,
in_dims
);
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
out_dims
;
attrs
[
"inplace"
]
=
false
;
// Invoke Reshape Op
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
Output
(
"Out"
)}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
class
FlattenOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) A tensor of rank >= axis."
);
AddOutput
(
"Out"
,
"A 2D tensor is reshaped input tensor. The input dimensions"
"up to axis are flattened to the outer dimension of the output"
"and the remaining input dimensions are flattened into the inner"
"dimension of the output."
);
AddAttr
<
int
>
(
"axis"
,
"(int)"
"Indicate up to which input dimensions (exclusive) should be"
"flattened to the outer dimension of the output. The value"
"for axis must be in the range [0, R], where R is the rank of"
"the input tensor. When axis = 0, the shape of the output"
"tensor is (1, (d_0 X d_1 ... d_n), where the shape of the"
"input tensor is (d_0, d_1, ... d_n)."
)
.
SetDefault
(
1
);
AddComment
(
R"DOC(
Flatten Operator
Flattens the input tensor into a 2D matrix.
Examples:
Case 1:
Given
X.shape = (3, 100, 100, 4)
and
axis = 2
We get:
Out.shape = (3 * 100, 4 * 100)
Case 2:
Given
X.shape = (3, 100, 100, 4)
and
axis = 0
We get:
Out.shape = (1, 3 * 100 * 100 * 4)
)DOC"
);
}
};
class
FlattenGradInferShape
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
context
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
context
->
GetInputDim
(
"X"
));
context
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
}
};
class
FlattenGradOp
:
public
framework
::
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
auto
dx_name
=
Output
(
framework
::
GradVarName
(
"X"
));
auto
dout_name
=
Input
(
framework
::
GradVarName
(
"Out"
));
auto
in_dims
=
scope
.
FindVar
(
Input
(
"X"
))
->
Get
<
framework
::
LoDTensor
>
().
dims
();
framework
::
AttributeMap
attrs
;
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
in_dims
);
attrs
[
"inplace"
]
=
false
;
auto
reshape_op
=
framework
::
OpRegistry
::
CreateOp
(
"reshape"
,
{{
"X"
,
{
dout_name
}},
{
"Shape"
,
{}}},
{{
"Out"
,
{
dx_name
}}},
attrs
);
reshape_op
->
Run
(
scope
,
place
);
}
};
}
// namespace operators
}
// namespace paddle
USE_OP
(
reshape
);
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
flatten
,
ops
::
FlattenOp
,
ops
::
FlattenOpMaker
,
ops
::
FlattenOpInferShape
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
flatten_grad
,
ops
::
FlattenGradOp
,
ops
::
FlattenGradInferShape
);
paddle/fluid/operators/lookup_table_op.cc
浏览文件 @
000ba1ac
...
@@ -33,20 +33,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
...
@@ -33,20 +33,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
auto
ids_var_type
=
ctx
->
GetInputsVarType
(
"Ids"
).
front
();
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W
// and it must be a column vector with rank = 2 while the 2nd dimension
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
// contains the ids to be looked up in W;
if
(
ids_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
PADDLE_ENFORCE_EQ
(
ids_dims
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
ids_dims
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
ids_dims
[
1
],
1
);
PADDLE_ENFORCE_EQ
(
ids_dims
[
1
],
1
);
}
ctx
->
SetOutputDim
(
"Out"
,
{
ids_dims
[
0
],
table_dims
[
1
]});
ctx
->
SetOutputDim
(
"Out"
,
{
ids_dims
[
0
],
table_dims
[
1
]});
if
(
ctx
->
GetOutputsVarType
(
"Out"
)[
0
]
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
}
}
}
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
...
@@ -62,17 +58,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -62,17 +58,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"W"
,
AddInput
(
"W"
,
"(Tensor) The input represents embedding tensors, "
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter."
);
"which is a learnable parameter."
);
AddInput
(
AddInput
(
"Ids"
,
"Ids"
,
"An input with type int32 or int64 "
"(Tensor or SelectedRows) Ids's type can be Tensor or "
"contains the ids to be looked up in W. "
"SelectedRows, when Ids's type is Tensor, this tensor contains "
"Ids must be a column vector with rank = 2. "
"the ids to be looked up in W and it must be a column vector with "
"The 2nd dimension size must be 1."
);
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
AddOutput
(
"Out"
,
"The lookup results, which have the same type as W."
);
"SelectedRows, the rows of Ids contains the ids to be looked up "
"in W."
);
AddOutput
(
"Out"
,
"(Tensor or SelectedRows) The lookup results, which have the "
"same type as W."
);
AddAttr
<
bool
>
(
"is_sparse"
,
AddAttr
<
bool
>
(
"is_sparse"
,
"(boolean, default false) "
"(boolean, default false) "
"Sparse update."
)
"Sparse update."
)
...
@@ -90,15 +81,10 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -90,15 +81,10 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
Lookup Table Operator.
Lookup Table Operator.
This operator is used to perform lookups on the parameter W,
This operator is used to perform lookups on the parameter W,
then concatenated into a dense or sparse tensor.
then concatenated into a dense tensor.
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
at this time, Ids can carry the LoD (Level of Details) information, or not, and
the output only shares the LoD information with input Ids.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC"
);
)DOC"
);
}
}
...
...
paddle/fluid/operators/lookup_table_op.cu
浏览文件 @
000ba1ac
...
@@ -23,7 +23,7 @@ namespace operators {
...
@@ -23,7 +23,7 @@ namespace operators {
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
,
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
,
bool
PaddingFlag
>
bool
PaddingFlag
>
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int64_t
*
ids
,
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int64_t
*
ids
,
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
,
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
,
const
int64_t
padding_idx
)
{
const
int64_t
padding_idx
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
...
@@ -33,8 +33,8 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
...
@@ -33,8 +33,8 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
int64_t
id
=
ids
[
idy
];
int64_t
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
PADDLE_ASSERT
(
id
<
N
);
T
*
out
=
output
+
idy
*
D
;
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
if
(
PaddingFlag
)
{
if
(
PaddingFlag
)
{
if
(
id
==
padding_idx
)
if
(
id
==
padding_idx
)
...
@@ -50,7 +50,7 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
...
@@ -50,7 +50,7 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
}
}
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
__global__
void
LookupTableGrad
(
T
*
table
,
const
T
*
output
,
const
int64_t
*
ids
,
__global__
void
LookupTableGrad
(
T
*
table
,
const
T
*
output
,
const
int64_t
*
ids
,
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
)
{
const
int64_t
D
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
...
@@ -60,8 +60,8 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
...
@@ -60,8 +60,8 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
int
id
=
ids
[
idy
];
int
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
PADDLE_ASSERT
(
id
<
N
);
const
T
*
out
=
output
+
idy
*
D
;
const
T
*
out
=
output
+
idy
*
D
;
T
*
tab
=
table
+
id
*
D
;
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
],
out
[
i
]);
}
}
...
@@ -72,36 +72,19 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
...
@@ -72,36 +72,19 @@ __global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
template
<
typename
T
>
template
<
typename
T
>
class
LookupTableCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
LookupTableCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
auto
*
ids_var
=
context
.
InputVar
(
"Ids"
);
Tensor
*
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
int64_t
*
ids
;
int64_t
K
;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if
(
ids_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
K
=
ids_t
->
numel
();
}
else
if
(
ids_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
ids_t
=
context
.
Input
<
framework
::
SelectedRows
>
(
"Ids"
);
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
rows
().
CUDAData
(
context
.
GetPlace
()));
K
=
ids_t
->
rows
().
size
();
output_t
->
Resize
({
K
,
table_t
->
dims
()[
1
]});
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Ids"
);
}
size_t
N
=
table_t
->
dims
()[
0
];
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
D
=
table_t
->
dims
()[
1
];
auto
*
table
=
table_t
->
data
<
T
>
();
size_t
K
=
ids_t
->
numel
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
dim3
grids
(
8
,
1
);
...
@@ -122,19 +105,19 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
...
@@ -122,19 +105,19 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
template
<
typename
T
>
template
<
typename
T
>
class
LookupTableGradCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
LookupTableGradCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
&
dev_ctx
=
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
// Since paddings are not trainable and fixed in forward, the gradient of
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
// paddings makes no sense and we don't deal with it in backward.
if
(
is_sparse
)
{
if
(
is_sparse
)
{
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
d_output
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_output
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_table
=
context
.
Output
<
SelectedRows
>
(
framework
::
GradVarName
(
"W"
));
auto
*
d_table
=
context
.
Output
<
SelectedRows
>
(
framework
::
GradVarName
(
"W"
));
auto
*
ids_data
=
ids
->
data
<
int64_t
>
();
auto
*
ids_data
=
ids
->
data
<
int64_t
>
();
auto
ids_dim
=
ids
->
dims
();
auto
ids_dim
=
ids
->
dims
();
auto
stream
=
dev_ctx
.
stream
();
auto
stream
=
dev_ctx
.
stream
();
...
@@ -150,12 +133,12 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -150,12 +133,12 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
d_table
->
set_rows
(
new_rows
);
d_table
->
set_rows
(
new_rows
);
auto
*
d_table_value
=
d_table
->
mutable_value
();
auto
*
d_table_value
=
d_table
->
mutable_value
();
d_table_value
->
Resize
({
ids_dim
[
0
],
table
->
dims
()[
1
]});
d_table_value
->
Resize
({
ids_dim
[
0
],
table
->
dims
()[
1
]});
d_table_value
->
mutable_data
<
T
>
(
context
.
GetPlace
());
d_table_value
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
d_table_data
=
d_table_value
->
data
<
T
>
();
auto
*
d_table_data
=
d_table_value
->
data
<
T
>
();
auto
*
d_output_data
=
d_output
->
data
<
T
>
();
auto
*
d_output_data
=
d_output
->
data
<
T
>
();
PADDLE_ENFORCE_EQ
(
d_table_value
->
dims
(),
d_output
->
dims
());
PADDLE_ENFORCE_EQ
(
d_table_value
->
dims
(),
d_output
->
dims
());
memory
::
Copy
(
gpu_place
,
d_table_data
,
gpu_place
,
d_output_data
,
memory
::
Copy
(
gpu_place
,
d_table_data
,
gpu_place
,
d_output_data
,
d_output
->
numel
()
*
sizeof
(
T
),
stream
);
d_output
->
numel
()
*
sizeof
(
T
),
stream
);
...
@@ -168,9 +151,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -168,9 +151,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
int
N
=
d_table_t
->
dims
()[
0
];
int
N
=
d_table_t
->
dims
()[
0
];
int
D
=
d_table_t
->
dims
()[
1
];
int
D
=
d_table_t
->
dims
()[
1
];
int
K
=
ids_t
->
numel
();
int
K
=
ids_t
->
numel
();
const
int64_t
*
ids
=
ids_t
->
data
<
int64_t
>
();
const
int64_t
*
ids
=
ids_t
->
data
<
int64_t
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_table_t
);
t
.
device
(
*
dev_ctx
.
eigen_device
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
t
.
device
(
*
dev_ctx
.
eigen_device
())
=
t
.
constant
(
static_cast
<
T
>
(
0
));
...
...
paddle/fluid/operators/lookup_table_op.h
浏览文件 @
000ba1ac
...
@@ -36,43 +36,13 @@ template <typename T>
...
@@ -36,43 +36,13 @@ template <typename T>
class
LookupTableKernel
:
public
framework
::
OpKernel
<
T
>
{
class
LookupTableKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
// int tensor
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
// float tensor
auto
*
table_var
=
context
.
InputVar
(
"W"
);
auto
*
table_var
=
context
.
InputVar
(
"W"
);
auto
*
ids_var
=
context
.
InputVar
(
"Ids"
);
Tensor
*
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
DDim
table_dim
;
if
(
table_var
->
IsType
<
LoDTensor
>
())
{
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
table_dim
=
context
.
Input
<
LoDTensor
>
(
"W"
)
->
dims
();
int64_t
*
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
}
else
if
(
table_var
->
IsType
<
SelectedRows
>
())
{
int64_t
ids_numel
=
ids_t
->
numel
();
auto
*
table_t
=
context
.
Input
<
SelectedRows
>
(
"W"
);
table_dim
=
table_t
->
value
().
dims
();
}
else
{
PADDLE_THROW
(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows"
);
}
int64_t
*
ids
;
int64_t
ids_numel
;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if
(
ids_var
->
IsType
<
LoDTensor
>
())
{
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
ids_numel
=
ids_t
->
numel
();
}
else
if
(
ids_var
->
IsType
<
SelectedRows
>
())
{
auto
*
ids_t
=
context
.
Input
<
SelectedRows
>
(
"Ids"
);
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
rows
().
data
());
ids_numel
=
ids_t
->
rows
().
size
();
output_t
->
Resize
({
ids_numel
,
table_dim
[
1
]});
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Ids"
);
}
if
(
table_var
->
IsType
<
LoDTensor
>
())
{
if
(
table_var
->
IsType
<
LoDTensor
>
())
{
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
...
...
paddle/fluid/operators/tensorrt_engine_op.cc
浏览文件 @
000ba1ac
...
@@ -163,7 +163,4 @@ REGISTER_OP_CPU_KERNEL(
...
@@ -163,7 +163,4 @@ REGISTER_OP_CPU_KERNEL(
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
ops
::
TensorRTEngineKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
// A trick to compile with the needed TensorRT op converter.
USE_TRT_CONVERTER
(
mul
)
#endif // PADDLE_WITH_CUDA
#endif // PADDLE_WITH_CUDA
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -60,3 +60,7 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
...
@@ -60,3 +60,7 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
nv_test
(
float16_gpu_test SRCS float16_test.cu DEPS lod_tensor
)
nv_test
(
float16_gpu_test SRCS float16_test.cu DEPS lod_tensor
)
cc_test
(
float16_test SRCS float16_test.cc DEPS lod_tensor
)
cc_test
(
float16_test SRCS float16_test.cc DEPS lod_tensor
)
IF
(
WITH_GPU
)
nv_test
(
cuda_helper_test SRCS cuda_helper_test.cu
)
ENDIF
()
paddle/fluid/platform/cpu_helper.cc
浏览文件 @
000ba1ac
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
#include <omp.h>
#include "paddle/fluid/platform/dynload/mklml.h"
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#endif
...
@@ -33,6 +34,7 @@ void SetNumThreads(int num_threads) {
...
@@ -33,6 +34,7 @@ void SetNumThreads(int num_threads) {
#elif defined(PADDLE_WITH_MKLML)
#elif defined(PADDLE_WITH_MKLML)
int
real_num_threads
=
num_threads
>
1
?
num_threads
:
1
;
int
real_num_threads
=
num_threads
>
1
?
num_threads
:
1
;
platform
::
dynload
::
MKL_Set_Num_Threads
(
real_num_threads
);
platform
::
dynload
::
MKL_Set_Num_Threads
(
real_num_threads
);
omp_set_num_threads
(
num_threads
);
#else
#else
PADDLE_ENFORCE
(
false
,
"To be implemented."
);
PADDLE_ENFORCE
(
false
,
"To be implemented."
);
#endif
#endif
...
...
paddle/fluid/platform/cuda_device_function.h
浏览文件 @
000ba1ac
...
@@ -14,6 +14,10 @@ limitations under the License. */
...
@@ -14,6 +14,10 @@ limitations under the License. */
#pragma once
#pragma once
#include <cuda.h>
#include <cuda.h>
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
@@ -36,6 +40,18 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
...
@@ -36,6 +40,18 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
#endif
#endif
}
}
// CUDA 9.0 have native compatible float16 shfl_down
#if CUDA_VERSION < 9000
template
<
>
__forceinline__
__device__
float16
CudaShuffleDownSync
(
unsigned
mask
,
float16
val
,
int
delta
,
int
width
)
{
half
tmp
=
static_cast
<
half
>
(
val
);
__shfl_down
(
tmp
,
static_cast
<
unsigned
>
(
delta
),
width
);
return
float16
(
tmp
);
}
#endif
template
<
typename
T
>
template
<
typename
T
>
__forceinline__
__device__
T
CudaShuffleSync
(
unsigned
mask
,
T
val
,
int
src_line
,
__forceinline__
__device__
T
CudaShuffleSync
(
unsigned
mask
,
T
val
,
int
src_line
,
int
width
=
32
)
{
int
width
=
32
)
{
...
@@ -46,6 +62,11 @@ __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
...
@@ -46,6 +62,11 @@ __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
#endif
#endif
}
}
template
<
typename
T
>
HOSTDEVICE
T
Infinity
()
{
return
INFINITY
;
}
template
<
typename
T
>
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// NOTE(zcd): The warp size should be taken from the
// NOTE(zcd): The warp size should be taken from the
...
...
paddle/fluid/platform/cuda_helper_test.cu
0 → 100644
浏览文件 @
000ba1ac
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include <random>
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
paddle
::
platform
::
float16
;
#define CUDA_ATOMIC_KERNEL(op, T) \
__global__ void op##Kernel(const T* data_a, T* data_b, size_t num) { \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; \
i += blockDim.x * gridDim.x) { \
paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]); \
} \
}
template
<
typename
T
>
struct
AddFunctor
{
T
operator
()(
const
T
&
a
,
const
T
&
b
)
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
SubFunctor
{
T
operator
()(
const
T
&
a
,
const
T
&
b
)
{
return
a
-
b
;
}
};
// NOTE(dzhwinter): the float16 add has small underflow/overflow
// so we use EXPECT_NEAR to check the result.
#define ARITHMETIC_KERNEL_LAUNCH(op, T) \
void Test##T##op(size_t num) { \
T *in1, *in2, *out; \
T *d_in1, *d_in2; \
size_t size = sizeof(T) * num; \
cudaMalloc(reinterpret_cast<void**>(&d_in1), size); \
cudaMalloc(reinterpret_cast<void**>(&d_in2), size); \
in1 = reinterpret_cast<T*>(malloc(size)); \
in2 = reinterpret_cast<T*>(malloc(size)); \
out = reinterpret_cast<T*>(malloc(size)); \
std::minstd_rand engine; \
std::uniform_real_distribution<double> dist(0.0, 1.0); \
for (size_t i = 0; i < num; ++i) { \
in1[i] = static_cast<T>(dist(engine)); \
in2[i] = static_cast<T>(dist(engine)); \
} \
cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num); \
cudaDeviceSynchronize(); \
cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost); \
cudaDeviceSynchronize(); \
for (size_t i = 0; i < num; ++i) { \
EXPECT_NEAR(static_cast<float>(out[i]), \
static_cast<float>(op##Functor<T>()(in1[i], in2[i])), \
0.001); \
} \
free(in1); \
free(in2); \
free(out); \
cudaFree(d_in1); \
cudaFree(d_in2); \
}
CUDA_ATOMIC_KERNEL
(
Add
,
float
);
CUDA_ATOMIC_KERNEL
(
Add
,
double
);
CUDA_ATOMIC_KERNEL
(
Add
,
float16
);
ARITHMETIC_KERNEL_LAUNCH
(
Add
,
float
);
ARITHMETIC_KERNEL_LAUNCH
(
Add
,
double
);
ARITHMETIC_KERNEL_LAUNCH
(
Add
,
float16
);
namespace
paddle
{
namespace
platform
{
USE_CUDA_ATOMIC
(
Sub
,
int
);
};
};
CUDA_ATOMIC_KERNEL
(
Sub
,
int
);
ARITHMETIC_KERNEL_LAUNCH
(
Sub
,
int
);
// cuda primitives
TEST
(
CudaAtomic
,
Add
)
{
TestfloatAdd
(
static_cast
<
size_t
>
(
10
));
TestfloatAdd
(
static_cast
<
size_t
>
(
1024
*
1024
));
TestdoubleAdd
(
static_cast
<
size_t
>
(
10
));
TestdoubleAdd
(
static_cast
<
size_t
>
(
1024
*
1024
));
}
TEST
(
CudaAtomic
,
Sub
)
{
TestintSub
(
static_cast
<
size_t
>
(
10
));
TestintSub
(
static_cast
<
size_t
>
(
1024
*
1024
));
}
TEST
(
CudaAtomic
,
float16
)
{
using
paddle
::
platform
::
float16
;
Testfloat16Add
(
static_cast
<
size_t
>
(
1
));
Testfloat16Add
(
static_cast
<
size_t
>
(
2
));
Testfloat16Add
(
static_cast
<
size_t
>
(
3
));
Testfloat16Add
(
static_cast
<
size_t
>
(
10
));
Testfloat16Add
(
static_cast
<
size_t
>
(
1024
*
1024
));
}
paddle/fluid/platform/cuda_primitives.h
浏览文件 @
000ba1ac
...
@@ -14,12 +14,14 @@ limitations under the License. */
...
@@ -14,12 +14,14 @@ limitations under the License. */
#pragma once
#pragma once
#include <cuda.h>
#include <cuda.h>
#include <stdio.h>
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
#define CUDA_ATOMIC_WRAPPER(op, T) \
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T
*
address, const T val)
__device__ __forceinline__ T CudaAtomic##op(T
*
address, const T val)
#define USE_CUDA_ATOMIC(op, T) \
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
...
@@ -42,7 +44,7 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
...
@@ -42,7 +44,7 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
static_assert
(
sizeof
(
int64_t
)
==
sizeof
(
long
long
int
),
// NOLINT
static_assert
(
sizeof
(
int64_t
)
==
sizeof
(
long
long
int
),
// NOLINT
"long long should be int64"
);
"long long should be int64"
);
return
CudaAtomicAdd
(
return
CudaAtomicAdd
(
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
),
// NOLINT
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
),
// NOLINT
static_cast
<
unsigned
long
long
int
>
(
val
));
// NOLINT
static_cast
<
unsigned
long
long
int
>
(
val
));
// NOLINT
}
}
...
@@ -50,8 +52,8 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
...
@@ -50,8 +52,8 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
USE_CUDA_ATOMIC
(
Add
,
double
);
USE_CUDA_ATOMIC
(
Add
,
double
);
#else
#else
CUDA_ATOMIC_WRAPPER
(
Add
,
double
)
{
CUDA_ATOMIC_WRAPPER
(
Add
,
double
)
{
unsigned
long
long
int
*
address_as_ull
=
// NOLINT
unsigned
long
long
int
*
address_as_ull
=
// NOLINT
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
);
// NOLINT
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
);
// NOLINT
unsigned
long
long
int
old
=
*
address_as_ull
,
assumed
;
// NOLINT
unsigned
long
long
int
old
=
*
address_as_ull
,
assumed
;
// NOLINT
do
{
do
{
...
@@ -64,6 +66,67 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
...
@@ -64,6 +66,67 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
return
__longlong_as_double
(
old
);
return
__longlong_as_double
(
old
);
}
}
#endif
#ifdef PADDLE_CUDA_FP16
// NOTE(dzhwinter): cuda do not have atomicCAS for half.
// Just use the half address as a unsigned value address and
// do the atomicCAS. According to the value store at high 16 bits
// or low 16 bits, then do a different sum and CAS.
// Given most warp-threads will failed on the atomicCAS, so this
// implemented should be avoided in high concurrency. It's will be
// slower than the way convert value into 32bits and do a full atomicCAS.
// convert the value into float and do the add arithmetic.
// then store the result into a uint32.
inline
__device__
uint32_t
add_to_low_half
(
uint32_t
val
,
float
x
)
{
float16
low_half
;
// the float16 in lower 16bits
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xffffu
);
low_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
low_half
)
+
x
);
return
(
val
&
0xffff0000u
)
|
low_half
.
x
;
}
inline
__device__
uint32_t
add_to_high_half
(
uint32_t
val
,
float
x
)
{
float16
high_half
;
// the float16 in higher 16bits
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
high_half
)
+
x
);
return
(
val
&
0xffffu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
}
CUDA_ATOMIC_WRAPPER
(
Add
,
float16
)
{
// concrete packed float16 value may exsits in lower or higher 16bits
// of the 32bits address.
uint32_t
*
address_as_ui
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
char
*>
(
address
)
-
(
reinterpret_cast
<
size_t
>
(
address
)
&
2
));
float
val_f
=
static_cast
<
float
>
(
val
);
uint32_t
old
=
*
address_as_ui
;
uint32_t
sum
;
uint32_t
newval
;
uint32_t
assumed
;
if
(((
size_t
)
address
&
2
)
==
0
)
{
// the float16 value stay at lower 16 bits of the address.
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_low_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
float16
ret
;
ret
.
x
=
old
&
0xffffu
;
return
ret
;
}
else
{
// the float16 value stay at higher 16 bits of the address.
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_high_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
float16
ret
;
ret
.
x
=
old
>>
16
;
return
ret
;
}
}
#endif
#endif
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
paddle/fluid/platform/float16.h
浏览文件 @
000ba1ac
...
@@ -67,8 +67,11 @@ struct float16;
...
@@ -67,8 +67,11 @@ struct float16;
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
// NOTE():
// Do not move the eigen.h header, otherwise the eigen_vector<bool> will failed.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
@@ -898,6 +901,30 @@ struct is_pod<paddle::platform::float16> {
...
@@ -898,6 +901,30 @@ struct is_pod<paddle::platform::float16> {
is_standard_layout
<
paddle
::
platform
::
float16
>::
value
;
is_standard_layout
<
paddle
::
platform
::
float16
>::
value
;
};
};
template
<
>
struct
is_floating_point
<
paddle
::
platform
::
float16
>
:
std
::
integral_constant
<
bool
,
std
::
is_same
<
paddle
::
platform
::
float16
,
typename
std
::
remove_cv
<
paddle
::
platform
::
float16
>::
type
>::
value
>
{};
template
<
>
struct
is_signed
<
paddle
::
platform
::
float16
>
{
static
const
bool
value
=
true
;
};
template
<
>
struct
is_unsigned
<
paddle
::
platform
::
float16
>
{
static
const
bool
value
=
false
;
};
inline
bool
isnan
(
const
paddle
::
platform
::
float16
&
a
)
{
return
paddle
::
platform
::
isnan
(
a
);
}
inline
bool
isinf
(
const
paddle
::
platform
::
float16
&
a
)
{
return
paddle
::
platform
::
isinf
(
a
);
}
template
<
>
template
<
>
struct
numeric_limits
<
paddle
::
platform
::
float16
>
{
struct
numeric_limits
<
paddle
::
platform
::
float16
>
{
static
const
bool
is_specialized
=
true
;
static
const
bool
is_specialized
=
true
;
...
...
paddle/fluid/platform/float16_test.cc
浏览文件 @
000ba1ac
...
@@ -141,10 +141,36 @@ TEST(float16, lod_tensor_cpu) {
...
@@ -141,10 +141,36 @@ TEST(float16, lod_tensor_cpu) {
}
}
}
}
TEST
(
float16
,
floating
)
{
// compile time assert.
PADDLE_ASSERT
(
std
::
is_floating_point
<
float16
>::
value
);
}
TEST
(
float16
,
print
)
{
TEST
(
float16
,
print
)
{
float16
a
=
float16
(
1.0
f
);
float16
a
=
float16
(
1.0
f
);
std
::
cout
<<
a
<<
std
::
endl
;
std
::
cout
<<
a
<<
std
::
endl
;
}
}
// CPU test
TEST
(
float16
,
isinf
)
{
float16
a
;
a
.
x
=
0x7c00
;
float16
b
=
float16
(
INFINITY
);
float16
c
=
static_cast
<
float16
>
(
INFINITY
);
EXPECT_EQ
(
std
::
isinf
(
a
),
true
);
EXPECT_EQ
(
std
::
isinf
(
b
),
true
);
EXPECT_EQ
(
std
::
isinf
(
c
),
true
);
}
TEST
(
float16
,
isnan
)
{
float16
a
;
a
.
x
=
0x7fff
;
float16
b
=
float16
(
NAN
);
float16
c
=
static_cast
<
float16
>
(
NAN
);
EXPECT_EQ
(
std
::
isnan
(
a
),
true
);
EXPECT_EQ
(
std
::
isnan
(
b
),
true
);
EXPECT_EQ
(
std
::
isnan
(
c
),
true
);
}
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
paddle/fluid/platform/float16_test.cu
浏览文件 @
000ba1ac
...
@@ -11,11 +11,13 @@ limitations under the License. */
...
@@ -11,11 +11,13 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/legacy/utils/Logging.h"
#define ARITHMETIC_KERNEL(op_type, sign) \
#define ARITHMETIC_KERNEL(op_type, sign) \
__global__ void op_type(const half* in1, const half* in2, half* out) { \
__global__ void op_type(const half* in1, const half* in2, half* out) { \
...
@@ -241,6 +243,72 @@ TEST(float16, lod_tensor_on_gpu) {
...
@@ -241,6 +243,72 @@ TEST(float16, lod_tensor_on_gpu) {
}
}
}
}
template
<
typename
T
>
struct
Functor
{
bool
operator
()(
const
T
&
val
)
{
return
std
::
type_index
(
typeid
(
T
))
==
std
::
type_index
(
typeid
(
platform
::
float16
));
}
};
TEST
(
float16
,
typeid
)
{
// the framework heavily used typeid hash
Functor
<
float16
>
functor
;
float16
a
=
float16
(
.0
f
);
Functor
<
int
>
functor2
;
int
b
(
0
);
// compile time assert
PADDLE_ASSERT
(
functor
(
a
)
==
true
);
PADDLE_ASSERT
(
functor2
(
b
)
==
false
);
}
// GPU test
TEST
(
float16
,
isinf
)
{
float16
a
;
a
.
x
=
0x7c00
;
float16
b
=
float16
(
INFINITY
);
// underflow to 0
float16
native_a
(
5e-40
f
);
// overflow to inf
float16
native_b
(
5e40
f
);
EXPECT_EQ
(
std
::
isinf
(
a
),
true
);
EXPECT_EQ
(
std
::
isinf
(
b
),
true
);
EXPECT_EQ
(
std
::
isinf
(
native_b
),
true
);
EXPECT_EQ
(
native_a
,
float16
(
0
));
}
TEST
(
float16
,
isnan
)
{
float16
a
;
a
.
x
=
0x7fff
;
float16
b
=
float16
(
NAN
);
float16
c
=
float16
(
5e40
);
// inf * +-0 will get a nan
float16
d
=
c
*
float16
(
0
);
EXPECT_EQ
(
std
::
isnan
(
a
),
true
);
EXPECT_EQ
(
std
::
isnan
(
b
),
true
);
EXPECT_EQ
(
std
::
isnan
(
d
),
true
);
}
TEST
(
float16
,
cast
)
{
float16
a
;
a
.
x
=
0x0070
;
auto
b
=
a
;
{
// change semantic, keep the same value
float16
c
=
reinterpret_cast
<
float16
&>
(
reinterpret_cast
<
unsigned
&>
(
b
));
EXPECT_EQ
(
b
,
c
);
}
{
// use uint32 low 16 bit store float16
uint32_t
c
=
reinterpret_cast
<
uint32_t
&>
(
b
);
float16
d
;
d
.
x
=
c
;
EXPECT_EQ
(
b
,
d
);
}
}
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
#endif // PADDLE_CUDA_FP16
#endif // PADDLE_CUDA_FP16
paddle/fluid/platform/init.cc
浏览文件 @
000ba1ac
...
@@ -23,6 +23,9 @@ limitations under the License. */
...
@@ -23,6 +23,9 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/piece.h"
DEFINE_int32
(
paddle_num_threads
,
1
,
"Number of threads for each paddle instance."
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -115,7 +118,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
...
@@ -115,7 +118,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
places
.
emplace_back
(
platform
::
CPUPlace
());
places
.
emplace_back
(
platform
::
CPUPlace
());
platform
::
DeviceContextPool
::
Init
(
places
);
platform
::
DeviceContextPool
::
Init
(
places
);
#ifndef PADDLE_WITH_MKLDNN
#ifndef PADDLE_WITH_MKLDNN
platform
::
SetNumThreads
(
1
);
platform
::
SetNumThreads
(
FLAGS_paddle_num_threads
);
#endif
#endif
}
}
...
...
patches/grpc/completion_queue.h
0 → 100644
浏览文件 @
000ba1ac
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// A completion queue implements a concurrent producer-consumer queue, with
/// two main API-exposed methods: \a Next and \a AsyncNext. These
/// methods are the essential component of the gRPC C++ asynchronous API.
/// There is also a \a Shutdown method to indicate that a given completion queue
/// will no longer have regular events. This must be called before the
/// completion queue is destroyed.
/// All completion queue APIs are thread-safe and may be used concurrently with
/// any other completion queue API invocation; it is acceptable to have
/// multiple threads calling \a Next or \a AsyncNext on the same or different
/// completion queues, or to call these methods concurrently with a \a Shutdown
/// elsewhere.
/// \remark{All other API calls on completion queue should be completed before
/// a completion queue destructor is called.}
#ifndef GRPCPP_IMPL_CODEGEN_COMPLETION_QUEUE_H
#define GRPCPP_IMPL_CODEGEN_COMPLETION_QUEUE_H
#include <typeinfo>
#include <grpc/impl/codegen/atm.h>
#include <grpcpp/impl/codegen/completion_queue_tag.h>
#include <grpcpp/impl/codegen/core_codegen_interface.h>
#include <grpcpp/impl/codegen/grpc_library.h>
#include <grpcpp/impl/codegen/status.h>
#include <grpcpp/impl/codegen/time.h>
struct
grpc_completion_queue
;
namespace
grpc
{
template
<
class
R
>
class
ClientReader
;
template
<
class
W
>
class
ClientWriter
;
template
<
class
W
,
class
R
>
class
ClientReaderWriter
;
template
<
class
R
>
class
ServerReader
;
template
<
class
W
>
class
ServerWriter
;
namespace
internal
{
template
<
class
W
,
class
R
>
class
ServerReaderWriterBody
;
}
// namespace internal
class
Channel
;
class
ChannelInterface
;
class
ClientContext
;
class
CompletionQueue
;
class
Server
;
class
ServerBuilder
;
class
ServerContext
;
class
ServerInterface
;
namespace
internal
{
class
CompletionQueueTag
;
class
RpcMethod
;
template
<
class
ServiceType
,
class
RequestType
,
class
ResponseType
>
class
RpcMethodHandler
;
template
<
class
ServiceType
,
class
RequestType
,
class
ResponseType
>
class
ClientStreamingHandler
;
template
<
class
ServiceType
,
class
RequestType
,
class
ResponseType
>
class
ServerStreamingHandler
;
template
<
class
ServiceType
,
class
RequestType
,
class
ResponseType
>
class
BidiStreamingHandler
;
class
UnknownMethodHandler
;
template
<
class
Streamer
,
bool
WriteNeeded
>
class
TemplatedBidiStreamingHandler
;
template
<
class
InputMessage
,
class
OutputMessage
>
class
BlockingUnaryCallImpl
;
}
// namespace internal
extern
CoreCodegenInterface
*
g_core_codegen_interface
;
/// A thin wrapper around \ref grpc_completion_queue (see \ref
/// src/core/lib/surface/completion_queue.h).
/// See \ref doc/cpp/perf_notes.md for notes on best practices for high
/// performance servers.
class
CompletionQueue
:
private
GrpcLibraryCodegen
{
public:
/// Default constructor. Implicitly creates a \a grpc_completion_queue
/// instance.
CompletionQueue
()
:
CompletionQueue
(
grpc_completion_queue_attributes
{
GRPC_CQ_CURRENT_VERSION
,
GRPC_CQ_NEXT
,
GRPC_CQ_DEFAULT_POLLING
})
{}
/// Wrap \a take, taking ownership of the instance.
///
/// \param take The completion queue instance to wrap. Ownership is taken.
explicit
CompletionQueue
(
grpc_completion_queue
*
take
);
/// Destructor. Destroys the owned wrapped completion queue / instance.
~
CompletionQueue
()
{
if
(
typeid
(
*
g_core_codegen_interface
).
hash_code
()
!=
typeid
(
CoreCodegenInterface
).
hash_code
())
{
g_core_codegen_interface
->
grpc_completion_queue_destroy
(
cq_
);
}
}
/// Tri-state return for AsyncNext: SHUTDOWN, GOT_EVENT, TIMEOUT.
enum
NextStatus
{
SHUTDOWN
,
///< The completion queue has been shutdown and fully-drained
GOT_EVENT
,
///< Got a new event; \a tag will be filled in with its
///< associated value; \a ok indicating its success.
TIMEOUT
///< deadline was reached.
};
/// Read from the queue, blocking until an event is available or the queue is
/// shutting down.
///
/// \param tag[out] Updated to point to the read event's tag.
/// \param ok[out] true if read a successful event, false otherwise.
///
/// Note that each tag sent to the completion queue (through RPC operations
/// or alarms) will be delivered out of the completion queue by a call to
/// Next (or a related method), regardless of whether the operation succeeded
/// or not. Success here means that this operation completed in the normal
/// valid manner.
///
/// Server-side RPC request: \a ok indicates that the RPC has indeed
/// been started. If it is false, the server has been Shutdown
/// before this particular call got matched to an incoming RPC.
///
/// Client-side StartCall/RPC invocation: \a ok indicates that the RPC is
/// going to go to the wire. If it is false, it not going to the wire. This
/// would happen if the channel is either permanently broken or
/// transiently broken but with the fail-fast option. (Note that async unary
/// RPCs don't post a CQ tag at this point, nor do client-streaming
/// or bidi-streaming RPCs that have the initial metadata corked option set.)
///
/// Client-side Write, Client-side WritesDone, Server-side Write,
/// Server-side Finish, Server-side SendInitialMetadata (which is
/// typically included in Write or Finish when not done explicitly):
/// \a ok means that the data/metadata/status/etc is going to go to the
/// wire. If it is false, it not going to the wire because the call
/// is already dead (i.e., canceled, deadline expired, other side
/// dropped the channel, etc).
///
/// Client-side Read, Server-side Read, Client-side
/// RecvInitialMetadata (which is typically included in Read if not
/// done explicitly): \a ok indicates whether there is a valid message
/// that got read. If not, you know that there are certainly no more
/// messages that can ever be read from this stream. For the client-side
/// operations, this only happens because the call is dead. For the
/// server-sider operation, though, this could happen because the client
/// has done a WritesDone already.
///
/// Client-side Finish: \a ok should always be true
///
/// Server-side AsyncNotifyWhenDone: \a ok should always be true
///
/// Alarm: \a ok is true if it expired, false if it was canceled
///
/// \return true if got an event, false if the queue is fully drained and
/// shut down.
bool
Next
(
void
**
tag
,
bool
*
ok
)
{
return
(
AsyncNextInternal
(
tag
,
ok
,
g_core_codegen_interface
->
gpr_inf_future
(
GPR_CLOCK_REALTIME
))
!=
SHUTDOWN
);
}
/// Read from the queue, blocking up to \a deadline (or the queue's shutdown).
/// Both \a tag and \a ok are updated upon success (if an event is available
/// within the \a deadline). A \a tag points to an arbitrary location usually
/// employed to uniquely identify an event.
///
/// \param tag[out] Upon sucess, updated to point to the event's tag.
/// \param ok[out] Upon sucess, true if a successful event, false otherwise
/// See documentation for CompletionQueue::Next for explanation of ok
/// \param deadline[in] How long to block in wait for an event.
///
/// \return The type of event read.
template
<
typename
T
>
NextStatus
AsyncNext
(
void
**
tag
,
bool
*
ok
,
const
T
&
deadline
)
{
TimePoint
<
T
>
deadline_tp
(
deadline
);
return
AsyncNextInternal
(
tag
,
ok
,
deadline_tp
.
raw_time
());
}
/// EXPERIMENTAL
/// First executes \a F, then reads from the queue, blocking up to
/// \a deadline (or the queue's shutdown).
/// Both \a tag and \a ok are updated upon success (if an event is available
/// within the \a deadline). A \a tag points to an arbitrary location usually
/// employed to uniquely identify an event.
///
/// \param F[in] Function to execute before calling AsyncNext on this queue.
/// \param tag[out] Upon sucess, updated to point to the event's tag.
/// \param ok[out] Upon sucess, true if read a regular event, false otherwise.
/// \param deadline[in] How long to block in wait for an event.
///
/// \return The type of event read.
template
<
typename
T
,
typename
F
>
NextStatus
DoThenAsyncNext
(
F
&&
f
,
void
**
tag
,
bool
*
ok
,
const
T
&
deadline
)
{
CompletionQueueTLSCache
cache
=
CompletionQueueTLSCache
(
this
);
f
();
if
(
cache
.
Flush
(
tag
,
ok
))
{
return
GOT_EVENT
;
}
else
{
return
AsyncNext
(
tag
,
ok
,
deadline
);
}
}
/// Request the shutdown of the queue.
///
/// \warning This method must be called at some point if this completion queue
/// is accessed with Next or AsyncNext. \a Next will not return false
/// until this method has been called and all pending tags have been drained.
/// (Likewise for \a AsyncNext returning \a NextStatus::SHUTDOWN .)
/// Only once either one of these methods does that (that is, once the queue
/// has been \em drained) can an instance of this class be destroyed.
/// Also note that applications must ensure that no work is enqueued on this
/// completion queue after this method is called.
void
Shutdown
();
/// Returns a \em raw pointer to the underlying \a grpc_completion_queue
/// instance.
///
/// \warning Remember that the returned instance is owned. No transfer of
/// owership is performed.
grpc_completion_queue
*
cq
()
{
return
cq_
;
}
protected:
/// Private constructor of CompletionQueue only visible to friend classes
CompletionQueue
(
const
grpc_completion_queue_attributes
&
attributes
)
{
cq_
=
g_core_codegen_interface
->
grpc_completion_queue_create
(
g_core_codegen_interface
->
grpc_completion_queue_factory_lookup
(
&
attributes
),
&
attributes
,
NULL
);
InitialAvalanching
();
// reserve this for the future shutdown
}
private:
// Friend synchronous wrappers so that they can access Pluck(), which is
// a semi-private API geared towards the synchronous implementation.
template
<
class
R
>
friend
class
::
grpc
::
ClientReader
;
template
<
class
W
>
friend
class
::
grpc
::
ClientWriter
;
template
<
class
W
,
class
R
>
friend
class
::
grpc
::
ClientReaderWriter
;
template
<
class
R
>
friend
class
::
grpc
::
ServerReader
;
template
<
class
W
>
friend
class
::
grpc
::
ServerWriter
;
template
<
class
W
,
class
R
>
friend
class
::
grpc
::
internal
::
ServerReaderWriterBody
;
template
<
class
ServiceType
,
class
RequestType
,
class
ResponseType
>
friend
class
::
grpc
::
internal
::
RpcMethodHandler
;
template
<
class
ServiceType
,
class
RequestType
,
class
ResponseType
>
friend
class
::
grpc
::
internal
::
ClientStreamingHandler
;
template
<
class
ServiceType
,
class
RequestType
,
class
ResponseType
>
friend
class
::
grpc
::
internal
::
ServerStreamingHandler
;
template
<
class
Streamer
,
bool
WriteNeeded
>
friend
class
::
grpc
::
internal
::
TemplatedBidiStreamingHandler
;
friend
class
::
grpc
::
internal
::
UnknownMethodHandler
;
friend
class
::
grpc
::
Server
;
friend
class
::
grpc
::
ServerContext
;
friend
class
::
grpc
::
ServerInterface
;
template
<
class
InputMessage
,
class
OutputMessage
>
friend
class
::
grpc
::
internal
::
BlockingUnaryCallImpl
;
/// EXPERIMENTAL
/// Creates a Thread Local cache to store the first event
/// On this completion queue queued from this thread. Once
/// initialized, it must be flushed on the same thread.
class
CompletionQueueTLSCache
{
public:
CompletionQueueTLSCache
(
CompletionQueue
*
cq
);
~
CompletionQueueTLSCache
();
bool
Flush
(
void
**
tag
,
bool
*
ok
);
private:
CompletionQueue
*
cq_
;
bool
flushed_
;
};
NextStatus
AsyncNextInternal
(
void
**
tag
,
bool
*
ok
,
gpr_timespec
deadline
);
/// Wraps \a grpc_completion_queue_pluck.
/// \warning Must not be mixed with calls to \a Next.
bool
Pluck
(
internal
::
CompletionQueueTag
*
tag
)
{
auto
deadline
=
g_core_codegen_interface
->
gpr_inf_future
(
GPR_CLOCK_REALTIME
);
auto
ev
=
g_core_codegen_interface
->
grpc_completion_queue_pluck
(
cq_
,
tag
,
deadline
,
nullptr
);
bool
ok
=
ev
.
success
!=
0
;
void
*
ignored
=
tag
;
GPR_CODEGEN_ASSERT
(
tag
->
FinalizeResult
(
&
ignored
,
&
ok
));
GPR_CODEGEN_ASSERT
(
ignored
==
tag
);
// Ignore mutations by FinalizeResult: Pluck returns the C API status
return
ev
.
success
!=
0
;
}
/// Performs a single polling pluck on \a tag.
/// \warning Must not be mixed with calls to \a Next.
///
/// TODO: sreek - This calls tag->FinalizeResult() even if the cq_ is already
/// shutdown. This is most likely a bug and if it is a bug, then change this
/// implementation to simple call the other TryPluck function with a zero
/// timeout. i.e:
/// TryPluck(tag, gpr_time_0(GPR_CLOCK_REALTIME))
void
TryPluck
(
internal
::
CompletionQueueTag
*
tag
)
{
auto
deadline
=
g_core_codegen_interface
->
gpr_time_0
(
GPR_CLOCK_REALTIME
);
auto
ev
=
g_core_codegen_interface
->
grpc_completion_queue_pluck
(
cq_
,
tag
,
deadline
,
nullptr
);
if
(
ev
.
type
==
GRPC_QUEUE_TIMEOUT
)
return
;
bool
ok
=
ev
.
success
!=
0
;
void
*
ignored
=
tag
;
// the tag must be swallowed if using TryPluck
GPR_CODEGEN_ASSERT
(
!
tag
->
FinalizeResult
(
&
ignored
,
&
ok
));
}
/// Performs a single polling pluck on \a tag. Calls tag->FinalizeResult if
/// the pluck() was successful and returned the tag.
///
/// This exects tag->FinalizeResult (if called) to return 'false' i.e expects
/// that the tag is internal not something that is returned to the user.
void
TryPluck
(
internal
::
CompletionQueueTag
*
tag
,
gpr_timespec
deadline
)
{
auto
ev
=
g_core_codegen_interface
->
grpc_completion_queue_pluck
(
cq_
,
tag
,
deadline
,
nullptr
);
if
(
ev
.
type
==
GRPC_QUEUE_TIMEOUT
||
ev
.
type
==
GRPC_QUEUE_SHUTDOWN
)
{
return
;
}
bool
ok
=
ev
.
success
!=
0
;
void
*
ignored
=
tag
;
GPR_CODEGEN_ASSERT
(
!
tag
->
FinalizeResult
(
&
ignored
,
&
ok
));
}
/// Manage state of avalanching operations : completion queue tags that
/// trigger other completion queue operations. The underlying core completion
/// queue should not really shutdown until all avalanching operations have
/// been finalized. Note that we maintain the requirement that an avalanche
/// registration must take place before CQ shutdown (which must be maintained
/// elsehwere)
void
InitialAvalanching
()
{
gpr_atm_rel_store
(
&
avalanches_in_flight_
,
static_cast
<
gpr_atm
>
(
1
));
}
void
RegisterAvalanching
()
{
gpr_atm_no_barrier_fetch_add
(
&
avalanches_in_flight_
,
static_cast
<
gpr_atm
>
(
1
));
}
void
CompleteAvalanching
();
grpc_completion_queue
*
cq_
;
// owned
gpr_atm
avalanches_in_flight_
;
};
/// A specific type of completion queue used by the processing of notifications
/// by servers. Instantiated by \a ServerBuilder.
class
ServerCompletionQueue
:
public
CompletionQueue
{
public:
bool
IsFrequentlyPolled
()
{
return
polling_type_
!=
GRPC_CQ_NON_LISTENING
;
}
private:
grpc_cq_polling_type
polling_type_
;
friend
class
ServerBuilder
;
/// \param is_frequently_polled Informs the GRPC library about whether the
/// server completion queue would be actively polled (by calling Next() or
/// AsyncNext()). By default all server completion queues are assumed to be
/// frequently polled.
ServerCompletionQueue
(
grpc_cq_polling_type
polling_type
)
:
CompletionQueue
(
grpc_completion_queue_attributes
{
GRPC_CQ_CURRENT_VERSION
,
GRPC_CQ_NEXT
,
polling_type
}),
polling_type_
(
polling_type
)
{}
};
}
// namespace grpc
#endif // GRPCPP_IMPL_CODEGEN_COMPLETION_QUEUE_H
patches/grpc/fix_too_early_destory.patch
已删除
100644 → 0
浏览文件 @
0c7d6eb8
diff --git a/include/grpcpp/impl/codegen/completion_queue.h b/include/grpcpp/impl/codegen/completion_queue.h
index 80c7c41982..3f7d8a7714 100644
--- a/include/grpcpp/impl/codegen/completion_queue.h
+++ b/include/grpcpp/impl/codegen/completion_queue.h
@@ -32,6 +32,8 @@
#ifndef GRPCPP_IMPL_CODEGEN_COMPLETION_QUEUE_H
#define GRPCPP_IMPL_CODEGEN_COMPLETION_QUEUE_H
+#include <typeinfo>
+
#include <grpc/impl/codegen/atm.h>
#include <grpcpp/impl/codegen/completion_queue_tag.h>
#include <grpcpp/impl/codegen/core_codegen_interface.h>
@@ -106,7 +108,9 @@
class CompletionQueue : private GrpcLibraryCodegen {
/// Destructor. Destroys the owned wrapped completion queue / instance.
~CompletionQueue() {
- g_core_codegen_interface->grpc_completion_queue_destroy(cq_);
+ if (typeid(*g_core_codegen_interface).hash_code() != typeid(CoreCodegenInterface).hash_code()) {
+ g_core_codegen_interface->grpc_completion_queue_destroy(cq_);
+ }
}
/// Tri-state return for AsyncNext: SHUTDOWN, GOT_EVENT, TIMEOUT.
diff --git a/include/grpcpp/impl/codegen/grpc_library.h b/include/grpcpp/impl/codegen/grpc_library.h
index 17c904d71a..a092b2204d 100644
--- a/include/grpcpp/impl/codegen/grpc_library.h
+++ b/include/grpcpp/impl/codegen/grpc_library.h
@@ -19,6 +19,8 @@
#ifndef GRPCPP_IMPL_CODEGEN_GRPC_LIBRARY_H
#define GRPCPP_IMPL_CODEGEN_GRPC_LIBRARY_H
+#include <typeinfo>
+
#include <grpcpp/impl/codegen/core_codegen_interface.h>
namespace grpc {
@@ -47,7 +49,8 @@
class GrpcLibraryCodegen {
}
}
virtual ~GrpcLibraryCodegen() {
- if (grpc_init_called_) {
+ if (grpc_init_called_ &&
+ typeid(*g_glip).hash_code() != typeid(GrpcLibraryInterface).hash_code()) {
GPR_CODEGEN_ASSERT(g_glip &&
"gRPC library not initialized. See "
"grpc::internal::GrpcLibraryInitializer.");
pa
ddle/fluid/framework/details/ssa_graph_builder_factory.cc
→
pa
tches/grpc/grpc_library.h
浏览文件 @
000ba1ac
...
@@ -12,39 +12,53 @@
...
@@ -12,39 +12,53 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#ifndef GRPCPP_IMPL_CODEGEN_GRPC_LIBRARY_H
#include <fstream>
#define GRPCPP_IMPL_CODEGEN_GRPC_LIBRARY_H
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include <typeinfo>
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <grpcpp/impl/codegen/core_codegen_interface.h>
namespace
paddle
{
namespace
framework
{
namespace
grpc
{
namespace
details
{
std
::
unique_ptr
<
SSAGraphBuilder
>
SSAGraphBuilderFactory
::
Create
()
{
class
GrpcLibraryInterface
{
std
::
unique_ptr
<
SSAGraphBuilder
>
res
(
public:
#ifdef PADDLE_WITH_CUDA
virtual
~
GrpcLibraryInterface
()
=
default
;
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
virtual
void
init
()
=
0
;
local_scopes_
,
nccl_ctxs_
,
strategy_
)
virtual
void
shutdown
()
=
0
;
#else
};
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
local_scopes_
,
strategy_
)
/// Initialized by \a grpc::GrpcLibraryInitializer from
#endif
/// <grpcpp/impl/grpc_library.h>
);
// NOLINT
extern
GrpcLibraryInterface
*
g_glip
;
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
/// Classes that require gRPC to be initialized should inherit from this class.
std
::
unique_ptr
<
std
::
ostream
>
fout
(
class
GrpcLibraryCodegen
{
new
std
::
ofstream
(
strategy_
.
debug_graphviz_path_
));
public:
PADDLE_ENFORCE
(
fout
->
good
());
GrpcLibraryCodegen
(
bool
call_grpc_init
=
true
)
:
grpc_init_called_
(
false
)
{
std
::
unique_ptr
<
GraphvizSSAGraphPrinter
>
graphviz_printer
(
if
(
call_grpc_init
)
{
new
GraphvizSSAGraphPrinter
());
GPR_CODEGEN_ASSERT
(
g_glip
&&
res
.
reset
(
new
SSAGraghBuilderWithPrinter
(
"gRPC library not initialized. See "
std
::
move
(
fout
),
std
::
move
(
graphviz_printer
),
std
::
move
(
res
)));
"grpc::internal::GrpcLibraryInitializer."
);
g_glip
->
init
();
grpc_init_called_
=
true
;
}
}
res
.
reset
(
new
SSAGraghBuilderWithChecker
(
std
::
move
(
res
)));
}
virtual
~
GrpcLibraryCodegen
()
{
if
(
grpc_init_called_
&&
typeid
(
*
g_glip
).
hash_code
()
!=
typeid
(
GrpcLibraryInterface
).
hash_code
())
{
GPR_CODEGEN_ASSERT
(
g_glip
&&
"gRPC library not initialized. See "
"grpc::internal::GrpcLibraryInitializer."
);
g_glip
->
shutdown
();
}
}
private:
bool
grpc_init_called_
;
};
}
// namespace grpc
return
res
;
#endif // GRPCPP_IMPL_CODEGEN_GRPC_LIBRARY_H
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
python/paddle/fluid/__init__.py
浏览文件 @
000ba1ac
...
@@ -123,7 +123,7 @@ def __bootstrap__():
...
@@ -123,7 +123,7 @@ def __bootstrap__():
read_env_flags
=
[
read_env_flags
=
[
'use_pinned_memory'
,
'check_nan_inf'
,
'benchmark'
,
'warpctc_dir'
,
'use_pinned_memory'
,
'check_nan_inf'
,
'benchmark'
,
'warpctc_dir'
,
'eager_delete_scope'
,
'use_mkldnn'
,
'initial_cpu_memory_in_mb'
,
'eager_delete_scope'
,
'use_mkldnn'
,
'initial_cpu_memory_in_mb'
,
'init_allocated_mem'
,
'free_idle_memory'
'init_allocated_mem'
,
'free_idle_memory'
,
'paddle_num_threads'
]
]
if
core
.
is_compiled_with_dist
():
if
core
.
is_compiled_with_dist
():
read_env_flags
.
append
(
'rpc_deadline'
)
read_env_flags
.
append
(
'rpc_deadline'
)
...
...
python/paddle/fluid/regularizer.py
浏览文件 @
000ba1ac
...
@@ -142,14 +142,20 @@ class L2DecayRegularizer(WeightDecayRegularizer):
...
@@ -142,14 +142,20 @@ class L2DecayRegularizer(WeightDecayRegularizer):
dtype
=
"float32"
,
shape
=
param
.
shape
,
lod_level
=
param
.
lod_level
)
dtype
=
"float32"
,
shape
=
param
.
shape
,
lod_level
=
param
.
lod_level
)
if
grad
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
if
grad
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
idx
=
block
.
create_var
(
dtype
=
"int64"
,
shape
=
param
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
decay
=
block
.
create_var
(
decay
=
block
.
create_var
(
dtype
=
"float32"
,
dtype
=
"float32"
,
shape
=
param
.
shape
,
shape
=
param
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
)
type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
)
block
.
append_op
(
type
=
'extract_rows'
,
inputs
=
{
'X'
:
grad
},
outputs
=
{
'Out'
:
idx
})
block
.
append_op
(
block
.
append_op
(
type
=
'lookup_table'
,
type
=
'lookup_table'
,
inputs
=
{
'W'
:
param
,
inputs
=
{
'W'
:
param
,
'Ids'
:
grad
},
'Ids'
:
idx
},
outputs
=
{
'Out'
:
decay
},
outputs
=
{
'Out'
:
decay
},
attrs
=
{
'is_sparse'
:
True
})
attrs
=
{
'is_sparse'
:
True
})
param
=
decay
param
=
decay
...
@@ -216,14 +222,20 @@ class L1DecayRegularizer(WeightDecayRegularizer):
...
@@ -216,14 +222,20 @@ class L1DecayRegularizer(WeightDecayRegularizer):
dtype
=
"float32"
,
shape
=
param
.
shape
,
lod_level
=
param
.
lod_level
)
dtype
=
"float32"
,
shape
=
param
.
shape
,
lod_level
=
param
.
lod_level
)
if
grad
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
if
grad
.
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
idx
=
block
.
create_var
(
dtype
=
"int64"
,
shape
=
param
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
)
decay
=
block
.
create_var
(
decay
=
block
.
create_var
(
dtype
=
"float32"
,
dtype
=
"float32"
,
shape
=
param
.
shape
,
shape
=
param
.
shape
,
type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
)
type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
)
block
.
append_op
(
type
=
'extract_rows'
,
inputs
=
{
'X'
:
grad
},
outputs
=
{
'Out'
:
idx
})
block
.
append_op
(
block
.
append_op
(
type
=
'lookup_table'
,
type
=
'lookup_table'
,
inputs
=
{
'W'
:
param
,
inputs
=
{
'W'
:
param
,
'Ids'
:
grad
},
'Ids'
:
idx
},
outputs
=
{
'Out'
:
decay
},
outputs
=
{
'Out'
:
decay
},
attrs
=
{
'is_sparse'
:
True
})
attrs
=
{
'is_sparse'
:
True
})
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
000ba1ac
...
@@ -40,7 +40,7 @@ function(py_test_modules TARGET_NAME)
...
@@ -40,7 +40,7 @@ function(py_test_modules TARGET_NAME)
${
PYTHON_EXECUTABLE
}
${
PADDLE_SOURCE_DIR
}
/tools/test_runner.py
${
py_test_modules_MODULES
}
${
PYTHON_EXECUTABLE
}
${
PADDLE_SOURCE_DIR
}
/tools/test_runner.py
${
py_test_modules_MODULES
}
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
if
(
py_test_modules_SERIAL
)
if
(
py_test_modules_SERIAL
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY SERIAL 1
)
set_property
(
TEST
${
TARGET_NAME
}
PROPERTY
RUN_
SERIAL 1
)
endif
()
endif
()
endif
()
endif
()
endfunction
()
endfunction
()
...
...
python/paddle/fluid/tests/unittests/dist_se_resnext.py
浏览文件 @
000ba1ac
...
@@ -278,7 +278,7 @@ class DistSeResneXt2x2:
...
@@ -278,7 +278,7 @@ class DistSeResneXt2x2:
def
run_trainer
(
self
,
place
,
endpoints
,
trainer_id
,
trainers
,
is_dist
=
True
):
def
run_trainer
(
self
,
place
,
endpoints
,
trainer_id
,
trainers
,
is_dist
=
True
):
test_program
,
avg_cost
,
train_reader
,
test_reader
,
batch_acc
,
predict
=
get_model
(
test_program
,
avg_cost
,
train_reader
,
test_reader
,
batch_acc
,
predict
=
get_model
(
batch_size
=
2
0
)
batch_size
=
2
)
if
is_dist
:
if
is_dist
:
t
=
get_transpiler
(
trainer_id
,
t
=
get_transpiler
(
trainer_id
,
fluid
.
default_main_program
(),
endpoints
,
fluid
.
default_main_program
(),
endpoints
,
...
@@ -294,11 +294,7 @@ class DistSeResneXt2x2:
...
@@ -294,11 +294,7 @@ class DistSeResneXt2x2:
strategy
.
num_threads
=
1
strategy
.
num_threads
=
1
strategy
.
allow_op_delay
=
False
strategy
.
allow_op_delay
=
False
exe
=
fluid
.
ParallelExecutor
(
exe
=
fluid
.
ParallelExecutor
(
True
,
True
,
loss_name
=
avg_cost
.
name
,
exec_strategy
=
strategy
)
loss_name
=
avg_cost
.
name
,
exec_strategy
=
strategy
,
num_trainers
=
trainers
,
trainer_id
=
trainer_id
)
feed_var_list
=
[
feed_var_list
=
[
var
for
var
in
trainer_prog
.
global_block
().
vars
.
values
()
var
for
var
in
trainer_prog
.
global_block
().
vars
.
values
()
...
...
python/paddle/fluid/tests/unittests/test_dist_se_resnext.py
浏览文件 @
000ba1ac
...
@@ -19,6 +19,7 @@ import math
...
@@ -19,6 +19,7 @@ import math
import
unittest
import
unittest
import
os
import
os
import
sys
import
signal
import
signal
import
subprocess
import
subprocess
...
@@ -56,7 +57,7 @@ class TestDistSeResneXt2x2(unittest.TestCase):
...
@@ -56,7 +57,7 @@ class TestDistSeResneXt2x2(unittest.TestCase):
except
os
.
error
:
except
os
.
error
:
retry_times
-=
1
retry_times
-=
1
def
non_
test_with_place
(
self
):
def
test_with_place
(
self
):
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
required_envs
=
{
required_envs
=
{
"PATH"
:
os
.
getenv
(
"PATH"
),
"PATH"
:
os
.
getenv
(
"PATH"
),
...
@@ -70,9 +71,15 @@ class TestDistSeResneXt2x2(unittest.TestCase):
...
@@ -70,9 +71,15 @@ class TestDistSeResneXt2x2(unittest.TestCase):
local_cmd
=
"%s dist_se_resnext.py trainer %s 0 %s %d FLASE"
%
\
local_cmd
=
"%s dist_se_resnext.py trainer %s 0 %s %d FLASE"
%
\
(
self
.
_python_interp
,
"127.0.0.1:1234"
,
"127.0.0.1:1234"
,
1
)
(
self
.
_python_interp
,
"127.0.0.1:1234"
,
"127.0.0.1:1234"
,
1
)
local_proc
=
subprocess
.
Popen
(
local_proc
=
subprocess
.
Popen
(
local_cmd
.
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
env
=
env_local
)
local_cmd
.
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
env
=
env_local
)
local_proc
.
wait
()
local_proc
.
wait
()
local_ret
=
local_proc
.
stdout
.
read
()
out
,
err
=
local_proc
.
communicate
()
local_ret
=
out
sys
.
stderr
.
write
(
'local_loss: %s
\n
'
%
local_ret
)
sys
.
stderr
.
write
(
'local_stderr: %s
\n
'
%
err
)
# Run dist train to compare with local results
# Run dist train to compare with local results
ps0
,
ps1
=
self
.
start_pserver
()
ps0
,
ps1
=
self
.
start_pserver
()
...
@@ -92,13 +99,22 @@ class TestDistSeResneXt2x2(unittest.TestCase):
...
@@ -92,13 +99,22 @@ class TestDistSeResneXt2x2(unittest.TestCase):
FNULL
=
open
(
os
.
devnull
,
'w'
)
FNULL
=
open
(
os
.
devnull
,
'w'
)
tr0_proc
=
subprocess
.
Popen
(
tr0_proc
=
subprocess
.
Popen
(
tr0_cmd
.
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stderr
=
FNULL
,
env
=
env0
)
tr0_cmd
.
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
env
=
env0
)
tr1_proc
=
subprocess
.
Popen
(
tr1_proc
=
subprocess
.
Popen
(
tr1_cmd
.
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stderr
=
FNULL
,
env
=
env1
)
tr1_cmd
.
split
(
" "
),
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
,
env
=
env1
)
tr0_proc
.
wait
()
tr0_proc
.
wait
()
tr1_proc
.
wait
()
tr1_proc
.
wait
()
loss_data0
=
tr0_proc
.
stdout
.
read
()
out
,
err
=
tr0_proc
.
communicate
()
sys
.
stderr
.
write
(
'dist_stderr: %s
\n
'
%
err
)
loss_data0
=
out
sys
.
stderr
.
write
(
'dist_loss: %s
\n
'
%
loss_data0
)
lines
=
loss_data0
.
split
(
"
\n
"
)
lines
=
loss_data0
.
split
(
"
\n
"
)
dist_first_loss
=
eval
(
lines
[
0
].
replace
(
" "
,
","
))[
0
]
dist_first_loss
=
eval
(
lines
[
0
].
replace
(
" "
,
","
))[
0
]
dist_last_loss
=
eval
(
lines
[
1
].
replace
(
" "
,
","
))[
0
]
dist_last_loss
=
eval
(
lines
[
1
].
replace
(
" "
,
","
))[
0
]
...
...
python/paddle/fluid/tests/unittests/test_extract_rows_op.py
0 → 100644
浏览文件 @
000ba1ac
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
op_test
import
OpTest
class
TestExtractRows
(
OpTest
):
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
# create and initialize Variable
feature_len
=
12
rows
=
[
0
,
4
,
4
,
7
]
np_array
=
np
.
ones
((
len
(
rows
),
feature_len
)).
astype
(
"float32"
)
in_x
=
scope
.
var
(
'X'
).
get_selected_rows
()
in_x
.
set_height
(
len
(
rows
))
in_x
.
set_rows
(
rows
)
in_x_tensor
=
in_x
.
get_tensor
()
in_x_tensor
.
set
(
np_array
,
place
)
# create Out Variable
out_tensor
=
scope
.
var
(
'Out'
).
get_tensor
()
# create and run lookup_table operator
extract_rows_op
=
Operator
(
"extract_rows"
,
X
=
'X'
,
Out
=
'Out'
)
extract_rows_op
.
run
(
scope
,
place
)
# get result from Out
result_array
=
np
.
array
(
out_tensor
)
result_array
=
[
ele
[
0
]
for
ele
in
result_array
]
assert
result_array
==
rows
def
test_concat_rows
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
check_with_place
(
place
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_flatten_op.py
0 → 100644
浏览文件 @
000ba1ac
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestFlattenOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"flatten"
self
.
init_test_case
()
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
self
.
in_shape
).
astype
(
"float32"
)}
self
.
init_attrs
()
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
new_shape
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
def
init_test_case
(
self
):
self
.
in_shape
=
(
3
,
2
,
2
,
5
)
self
.
axis
=
1
self
.
new_shape
=
(
3
,
20
)
def
init_attrs
(
self
):
self
.
attrs
=
{
"axis"
:
self
.
axis
}
class
TestFlattenOp
(
TestFlattenOp
):
def
init_test_case
(
self
):
self
.
in_shape
=
(
3
,
2
,
2
,
3
)
self
.
axis
=
0
self
.
new_shape
=
(
1
,
36
)
class
TestFlattenOpWithDefaultAxis
(
TestFlattenOp
):
def
init_test_case
(
self
):
self
.
in_shape
=
(
3
,
2
,
2
,
3
)
self
.
new_shape
=
(
3
,
12
)
def
init_attrs
(
self
):
self
.
attrs
=
{}
class
TestFlattenOpSixDims
(
TestFlattenOp
):
def
init_test_case
(
self
):
self
.
in_shape
=
(
3
,
2
,
3
,
2
,
4
,
4
)
self
.
axis
=
4
self
.
new_shape
=
(
36
,
16
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
浏览文件 @
000ba1ac
...
@@ -49,53 +49,6 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
...
@@ -49,53 +49,6 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
pass
pass
class
TestLookupTableIdsIsSelectedRows
(
OpTest
):
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
# create and initialize Variable
height
=
10
rows
=
[
0
,
4
,
4
,
7
]
row_numel
=
12
# create and initialize W Variable
W
=
scope
.
var
(
'W'
).
get_tensor
()
W_array
=
np
.
full
((
height
,
row_numel
),
1.0
).
astype
(
"float32"
)
for
i
in
range
(
height
):
W_array
[
i
]
*=
i
W
.
set
(
W_array
,
place
)
# create and initialize Ids Variable
ids_selected_rows
=
scope
.
var
(
'Ids'
).
get_selected_rows
()
ids_selected_rows
.
set_height
(
len
(
rows
))
ids_selected_rows
.
set_rows
(
rows
)
np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
ids_tensor
=
ids_selected_rows
.
get_tensor
()
ids_tensor
.
set
(
np_array
,
place
)
# create Out Variable
Out
=
scope
.
var
(
'Out'
).
get_selected_rows
()
# create and run lookup_table operator
concat_rows_op
=
Operator
(
"lookup_table"
,
W
=
'W'
,
Ids
=
'Ids'
,
Out
=
'Out'
)
concat_rows_op
.
run
(
scope
,
place
)
# get result from Out
Out_tensor
=
Out
.
get_tensor
()
result_array
=
np
.
array
(
Out_tensor
)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for
idx
,
row
in
enumerate
(
rows
):
assert
(
row
==
result_array
[
idx
]).
all
()
def
test_concat_rows
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
check_with_place
(
place
)
class
TestLookupTableWIsSelectedRows
(
OpTest
):
class
TestLookupTableWIsSelectedRows
(
OpTest
):
def
check_with_place
(
self
,
place
):
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
scope
=
core
.
Scope
()
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
000ba1ac
...
@@ -346,6 +346,7 @@ class DistributeTranspiler(object):
...
@@ -346,6 +346,7 @@ class DistributeTranspiler(object):
# step1
# step1
pserver_program
=
Program
()
pserver_program
=
Program
()
pserver_program
.
random_seed
=
self
.
origin_program
.
random_seed
# step2: Create vars to receive vars at parameter servers.
# step2: Create vars to receive vars at parameter servers.
recv_inputs
=
[]
recv_inputs
=
[]
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
for
v
in
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]:
...
@@ -543,6 +544,7 @@ class DistributeTranspiler(object):
...
@@ -543,6 +544,7 @@ class DistributeTranspiler(object):
"""
"""
s_prog
=
Program
()
s_prog
=
Program
()
orig_s_prog
=
default_startup_program
()
orig_s_prog
=
default_startup_program
()
s_prog
.
random_seed
=
orig_s_prog
.
random_seed
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
params
=
self
.
param_grad_ep_mapping
[
endpoint
][
"params"
]
def
_get_splited_name_and_shape
(
varname
):
def
_get_splited_name_and_shape
(
varname
):
...
...
tools/codestyle/cpplint_pre_commit.hook
浏览文件 @
000ba1ac
...
@@ -4,7 +4,7 @@ TOTAL_ERRORS=0
...
@@ -4,7 +4,7 @@ TOTAL_ERRORS=0
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
for
file
in
$(
git diff
--cached
--name-status
|
awk
'$1 != "D" {print $2}'
)
;
do
for
file
in
$(
git diff
--cached
--name-status
|
awk
'$1 != "D" {print $2}'
)
;
do
if
[[
$file
=
~ ^
(
paddle/legacy/api/.
*
|paddle/legacy/capi/.
*
|paddle/contrib/.
*
|paddle/legacy/cuda/.
*
|paddle/legacy/function/.
*
|paddle/legacy/gserver/.
*
|paddle/legacy/math/.
*
|paddle/legacy/optimizer/.
*
|paddle/legacy/parameter/.
*
|paddle/legacy/pserver/.
*
|paddle/legacy/trainer/.
*
|paddle/legacy/utils/.
*
|paddle/testing/TestUtil.
*
)
]]
;
then
if
[[
$file
=
~ ^
(
paddle/legacy/api/.
*
|paddle/legacy/capi/.
*
|paddle/contrib/.
*
|paddle/legacy/cuda/.
*
|paddle/legacy/function/.
*
|paddle/legacy/gserver/.
*
|paddle/legacy/math/.
*
|paddle/legacy/optimizer/.
*
|paddle/legacy/parameter/.
*
|paddle/legacy/pserver/.
*
|paddle/legacy/trainer/.
*
|paddle/legacy/utils/.
*
|paddle/testing/TestUtil.
*
|patches/grpc/.
*
)
]]
;
then
continue
;
continue
;
else
else
cpplint
--filter
=
-readability
/fn_size
$file
;
cpplint
--filter
=
-readability
/fn_size
$file
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录