Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0b07eef1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0b07eef1
编写于
5年前
作者:
Y
Yan Xu
提交者:
GitHub
5年前
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ParallelDyGraph with GPU collective mode (#16827)
implement dygraph.parallel.DataParallel to hook reduce op.
上级
1a4a51db
develop
2.0.1-rocm-post
Ligoml-patch-1
OliverLPH-patch-1
OliverLPH-patch-2
PaddlePM-patch-1
PaddlePM-patch-2
ZHUI-patch-1
add_default_att
add_model_benchmark_ci
add_some_yaml_config
addfile
all_new_design_exec
ascendrc
ascendrelease
cherry_undefined_var
compile_windows
delete_2.0.1-rocm-post
delete_add_default_att
delete_all_new_design_exec
delete_ascendrc
delete_compile_windows
delete_delete_addfile
delete_disable_iterable_dataset_unittest
delete_fix_dataloader_memory_leak
delete_fix_imperative_dygraph_error
delete_fix_retry_ci
delete_fix_undefined_var
delete_improve_sccache
delete_incubate/lite
delete_paddle_tiny_install
delete_paralleltest
delete_prv-disable-more-cache
delete_revert-31068-fix_conv3d_windows
delete_revert-31562-mean
delete_revert-33630-bug-fix
delete_revert-34159-add_npu_bce_logical_dev
delete_revert-34910-spinlocks_for_allocator
delete_revert-35069-revert-34910-spinlocks_for_allocator
delete_revert-36057-dev/read_flags_in_ut
dingjiaweiww-patch-1
disable_iterable_dataset_unittest
dy2static
enable_eager_model_test
final_state_gen_python_c
final_state_intermediate
fix-numpy-issue
fix_concat_slice
fix_dataloader_memory_leak
fix_imperative_dygraph_error
fix_npu_ci
fix_op_flops
fix_retry_ci
fix_rnn_docs
fix_tensor_type
fix_undefined_var
fixiscan
fixiscan1
fixiscan2
fixiscan3
github/fork/123malin/netifaces
github/fork/123malin/tdm_abacus
github/fork/AshburnLee/dev_unique
github/fork/ForFishes/fix_memory_matmul
github/fork/ForFishes/rm_fluid
github/fork/LielinJiang/move-2.0-api
github/fork/LielinJiang/visual-dl-cb
github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api
github/fork/LiuChiachi/fix-example-code-for-hapi-Model
github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model
github/fork/MrChengmo/fix_ps_profiler
github/fork/MrChengmo/update_ps_heter
github/fork/PWhiddy/patch-1
github/fork/Shixiaowei02/dev/save_load_upgrade
github/fork/TCChenlong/fix_hapi
github/fork/TCChenlong/fix_inden
github/fork/Thunderbrook/xpu_slice
github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var
github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2
github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3
github/fork/XieYunshen/timeout_20S_ut
github/fork/ZeyuChen/remove-nltk
github/fork/arlesniak/arlesniak/selective__mkldnn_flags
github/fork/baiyfbupt/code_doc_mig
github/fork/chalsliu/set_timeout
github/fork/chen-zhiyu/develop
github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error
github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads
github/fork/chenwhql/saveload/add_get_inference_program
github/fork/chenwhql/saveload/remove_save_load_config
github/fork/cryoco/pass-compatibility-trt
github/fork/danleifeng/isempty_api2.0
github/fork/frankwhzhang/api_transfer
github/fork/hbwx24/error_msg/cuda_kernel_error_msg
github/fork/heavengate/cherry_yolo_box
github/fork/heavengate/update_yolo_box
github/fork/iclementine/rnn_fix
github/fork/iducn/testestse
github/fork/jczaja/prv-25537-fix
github/fork/jeff41404/release/1.8
github/fork/jiweibo/api_2.0
github/fork/jiweibo/fix_lite_resnet50_test
github/fork/juncaipeng/fix_doc_1
github/fork/lfchener/sample_code
github/fork/littletomatodonkey/fix_reg_doc
github/fork/liym27/dy2stat_update_assign_to_rc20
github/fork/luotao1/profiler_ut
github/fork/mapingshuo/add_wait
github/fork/mapingshuo/doc_2.0
github/fork/mapingshuo/zero-0.5
github/fork/miraiwk/dev
github/fork/pangyoki/add-Categorical-class-branch
github/fork/pangyoki/add-multinomial-op-branch
github/fork/pangyoki/fix-test_distritbution-CI
github/fork/qjing666/doublegrad
github/fork/qjing666/fix_hdfs_download
github/fork/sandyhouse/add_gather_etc
github/fork/sandyhouse/add_send_recv_alltoall_etc
github/fork/sandyhouse/pipeline_exe_run
github/fork/seiriosPlus/feature/large_scale_kv_save_delta
github/fork/seiriosPlus/fix/paddle_errors_fix
github/fork/seiriosPlus/fix/paddle_op_errors
github/fork/shangzhizhou/fix_test_activation_op_random_bug
github/fork/smallv0221/yxp0924
github/fork/smallv0221/yxp0925
github/fork/swtkiwi/del-matplotlib
github/fork/tianshuo78520a/kunlun_test
github/fork/tianshuo78520a/update_dockerfile
github/fork/wanghaoshuang/bert_fuse
github/fork/wanghaoshuang/label_smooth
github/fork/wanghuancoder/develop_CUDASynchronize
github/fork/wanghuancoder/develop_Layer_doc
github/fork/wanghuancoder/develop_ParameterList_doc
github/fork/wanghuancoder/develop_Sequential_doc
github/fork/wanghuancoder/develop_bilinear_tensor_product
github/fork/wanghuancoder/develop_coverage_build_sh
github/fork/wanghuancoder/develop_in_dynamic_mode_doc
github/fork/wanghuancoder/develop_unique_name_doc
github/fork/wangxicoding/fleet_meta_combine
github/fork/wawltor/error_message_fix_5
github/fork/willthefrog/remove_l2_norm
github/fork/windstamp/momentum_op
github/fork/windstamp/mv_op_5
github/fork/windstamp/normal_api
github/fork/wojtuss/wojtuss/fusion_gru_quantization
github/fork/wojtuss/wojtuss/quantization-with-shift
github/fork/wzzju/fix_err_info
github/fork/wzzju/pure_fp16
github/fork/xiemoyuan/op_error_message
github/fork/xiemoyuan/optimize_error_message
github/fork/yaoxuefeng6/fix_doc
github/fork/yaoxuefeng6/mod_dataset_v2
github/fork/yongqiangma/lod
github/fork/ysh329/fix-clip-by-norm-error
github/fork/ysh329/fix-error-clip-by-value
github/fork/yukavio/error_info
github/fork/zhangting2020/conv_filter_grad
github/fork/zhangting2020/is_compile_with_cuda
github/fork/zhangting2020/place_doc
github/fork/zhangting2020/program
github/fork/zhhsplendid/fix_any
github/fork/zhhsplendid/refine_api2
github/fork/zhhsplendid/refine_api2_test
github/fork/zhhsplendid/refine_api_test_ptb_lm
github/fork/zhhsplendid/refine_api_test_resnet
github/fork/zhhsplendid/refine_api_test_simnet
github/fork/zhiqiu/dev/refine_initializer
github/fork/zhiqiu/dev/remove_inplace_argument
github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11
improve_sccache
incubate/infrt
incubate/lite
inplace_addto
make_flag_adding_easier
move_embedding_to_phi
move_histogram_to_pten
move_sgd_to_phi
move_slice_to_pten
move_temporal_shift_to_phi
move_yolo_box_to_phi
npu_fix_alloc
numel
paddle_tiny_install
paralleltest
preln_ernie
prv-disable-more-cache
prv-md-even-more
prv-onednn-2.5
pten_tensor_refactor
release/1.5
release/1.6
release/1.7
release/1.8
release/2.0
release/2.0-alpha
release/2.0-beta
release/2.0-rc
release/2.0-rc1
release/2.1
release/2.2
release/2.3
release/2.3-fc-ernie-fix
release/2.4
release/lite-0.1
revert-24981-add_device_attr_for_regulization
revert-26856-strategy_example2
revert-27520-disable_pr
revert-31068-fix_conv3d_windows
revert-31562-mean
revert-32290-develop-hardlabel
revert-33037-forci
revert-33475-fix_cifar_label_dimension
revert-33630-bug-fix
revert-34159-add_npu_bce_logical_dev
revert-34406-add_copy_from_tensor
revert-34910-spinlocks_for_allocator
revert-35069-revert-34910-spinlocks_for_allocator
revert-36057-dev/read_flags_in_ut
revert-36201-refine_fast_threaded_ssa_graph_executor
revert-36985-add_license
revert-37318-refactor_dygraph_to_eager
revert-37926-eager_coreops_500
revert-37956-revert-37727-pylayer_support_tuple
revert-38100-mingdong
revert-38301-allocation_rearrange_pr
revert-38703-numpy_bf16_package_reupload
revert-38732-remove_useless_header_in_elementwise_mul_grad
revert-38959-Reduce_Grad
revert-39143-adjust_empty
revert-39227-move_trace_op_to_pten
revert-39268-dev/remove_concat_fluid_kernel
revert-40170-support_partial_grad
revert-41056-revert-40727-move_some_activaion_to_phi
revert-41065-revert-40993-mv_ele_floordiv_pow
revert-41068-revert-40790-phi_new
revert-41944-smaller_inference_api_test
revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator
revert-43155-fix_ut_tempfile
revert-43882-revert-41944-smaller_inference_api_test
revert-45808-phi/simplify_size_op
revert-46827-deform_comment
rocm_dev_0217
support_weight_transpose
test_benchmark_ci
test_feature_precision_test_c
test_model_benchmark
test_model_benchmark_ci
zhiqiu-patch-1
v2.4.0-rc0
v2.3.2
v2.3.1
v2.3.0
v2.3.0-rc0
v2.2.2
v2.2.1
v2.2.0
v2.2.0-rc0
v2.2.0-bak0
v2.1.3
v2.1.2
v2.1.1
v2.1.0
v2.1.0-rc0
v2.0.2
v2.0.1
v2.0.0
v2.0.0-rc1
v2.0.0-rc0
v2.0.0-beta0
v2.0.0-alpha0
v1.8.5
v1.8.4
v1.8.3
v1.8.2
v1.8.1
v1.8.0
v1.7.2
v1.7.1
v1.7.0
v1.6.3
v1.6.2
v1.6.1
v1.6.0
v1.6.0-rc0
v1.5.2
v1.5.1
v1.5.0
lite-v0.1
无相关合并请求
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
436 addition
and
97 deletion
+436
-97
paddle/fluid/imperative/layer.cc
paddle/fluid/imperative/layer.cc
+6
-2
paddle/fluid/imperative/layer.h
paddle/fluid/imperative/layer.h
+1
-1
paddle/fluid/operators/distributed_ops/allreduce_op.cc
paddle/fluid/operators/distributed_ops/allreduce_op.cc
+24
-87
paddle/fluid/operators/distributed_ops/allreduce_op.cu.cc
paddle/fluid/operators/distributed_ops/allreduce_op.cu.cc
+25
-0
paddle/fluid/operators/distributed_ops/allreduce_op.h
paddle/fluid/operators/distributed_ops/allreduce_op.h
+87
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+5
-3
python/paddle/fluid/dygraph/parallel.py
python/paddle/fluid/dygraph/parallel.py
+46
-1
python/paddle/fluid/layers/collective.py
python/paddle/fluid/layers/collective.py
+3
-2
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py
...on/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py
+136
-0
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+70
-1
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
...ddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
+32
-0
未找到文件。
paddle/fluid/imperative/layer.cc
浏览文件 @
0b07eef1
...
...
@@ -336,11 +336,15 @@ void OpBase::InvokeBackwardHooks() {
}
}
void
OpBase
::
RegisterBackwardHooks
(
const
py
::
object
&
callable
)
{
void
OpBase
::
RegisterBackwardHooks
(
const
py
::
object
&
callable
,
bool
front
)
{
VLOG
(
3
)
<<
"Register backward hooks "
<<
trace_id_
;
// TODO(minqiyang): check the callable format
backward_hooks_
.
push_back
(
callable
);
if
(
front
)
{
backward_hooks_
.
insert
(
backward_hooks_
.
begin
(),
callable
);
}
else
{
backward_hooks_
.
push_back
(
callable
);
}
}
void
VarBase
::
RunBackward
()
{
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/imperative/layer.h
浏览文件 @
0b07eef1
...
...
@@ -310,7 +310,7 @@ class PYBIND11_HIDDEN OpBase {
return
grad_op_descs_
[
index
]
->
Type
();
}
void
RegisterBackwardHooks
(
const
py
::
object
&
callable
);
void
RegisterBackwardHooks
(
const
py
::
object
&
callable
,
bool
front
=
false
);
void
InvokeBackwardHooks
();
...
...
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/distributed_ops/allreduce_op.cc
浏览文件 @
0b07eef1
...
...
@@ -15,91 +15,22 @@ limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"
namespace
paddle
{
namespace
operators
{
struct
MutableDataFunctor
{
MutableDataFunctor
(
void
**
data
,
framework
::
LoDTensor
*
tensor
,
const
platform
::
Place
&
place
)
:
data_
(
data
),
tensor_
(
tensor
),
place_
(
place
)
{}
template
<
typename
T
>
void
apply
()
{
*
data_
=
tensor_
->
mutable_data
<
T
>
(
place_
);
}
class
AllReduceOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
**
data_
;
framework
::
LoDTensor
*
tensor_
;
platform
::
Place
place_
;
};
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
class
AllReduceOp
:
public
framework
::
OperatorBase
{
using
OperatorBase
::
OperatorBase
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
PADDLE_ENFORCE
(
is_gpu_place
(
place
),
"AllReduce op can run on gpu place only for now."
);
#ifdef PADDLE_WITH_CUDA
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
ctx
=
pool
.
Get
(
place
);
auto
in_names
=
Inputs
(
"X"
);
auto
out_names
=
Outputs
(
"Out"
);
PADDLE_ENFORCE_EQ
(
in_names
.
size
(),
1
,
"Only support one input"
);
PADDLE_ENFORCE_EQ
(
out_names
.
size
(),
1
,
"Only support one output"
);
auto
*
in
=
scope
.
FindVar
(
in_names
[
0
]);
auto
*
out
=
scope
.
FindVar
(
out_names
[
0
]);
PADDLE_ENFORCE
(
in
->
IsType
<
framework
::
LoDTensor
>
()
||
out
->
IsType
<
framework
::
LoDTensor
>
(),
"Only support allreduce LoDTensors"
);
int
dtype
=
-
1
;
auto
in_tensor
=
in
->
Get
<
framework
::
LoDTensor
>
();
dtype
=
platform
::
ToNCCLDataType
(
in_tensor
.
type
());
int64_t
numel
=
in_tensor
.
numel
();
auto
*
sendbuff
=
in_tensor
.
data
<
void
>
();
auto
*
out_tensor
=
out
->
GetMutable
<
framework
::
LoDTensor
>
();
out_tensor
->
Resize
(
in_tensor
.
dims
());
void
*
recvbuff
=
nullptr
;
framework
::
VisitDataType
(
in_tensor
.
type
(),
MutableDataFunctor
(
&
recvbuff
,
out_tensor
,
place
));
auto
cuda_ctx
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
ctx
);
auto
*
comm
=
cuda_ctx
->
nccl_comm
();
// FIXME(typhoonzero): should use nccl stream here.
auto
stream
=
cuda_ctx
->
stream
();
int
reduce_type
=
Attr
<
int
>
(
"reduce_type"
);
ncclRedOp_t
red_type
=
ncclSum
;
switch
(
reduce_type
)
{
case
0
:
red_type
=
ncclSum
;
break
;
case
1
:
red_type
=
ncclProd
;
break
;
case
2
:
red_type
=
ncclMax
;
break
;
case
3
:
red_type
=
ncclMin
;
break
;
}
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclAllReduce
(
sendbuff
,
recvbuff
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
red_type
,
comm
,
stream
));
#endif
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
ctx
.
GetPlace
());
}
};
...
...
@@ -110,6 +41,10 @@ class AllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"Out"
,
"(Tensor) the result of allreduced."
);
AddAttr
<
int
>
(
"reduce_type"
,
"(int) determin the reduce type."
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"sync_mode"
,
"(bool) whether to synchronize the CUDA stream after nccl call."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
***AllReduce Operator***
...
...
@@ -128,16 +63,18 @@ If input and output are the same variable, in-place allreduce will be used.
}
};
class
AllReduceOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_WITHOUT_GRADIENT
(
allreduce
,
ops
::
AllReduceOp
,
ops
::
AllReduceOpMaker
);
REGISTER_OPERATOR
(
allreduce
,
ops
::
AllReduceOp
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
AllReduceOpMaker
,
ops
::
AllReduceOpShapeInference
);
REGISTER_OP_CPU_KERNEL
(
allreduce
,
ops
::
AllReduceOpKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
AllReduceOpKernel
<
plat
::
CPUDeviceContext
,
double
>
,
ops
::
AllReduceOpKernel
<
plat
::
CPUDeviceContext
,
int
>
,
ops
::
AllReduceOpKernel
<
plat
::
CPUDeviceContext
,
int64_t
>
,
ops
::
AllReduceOpKernel
<
plat
::
CPUDeviceContext
,
plat
::
float16
>
);
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/distributed_ops/allreduce_op.cu.cc
0 → 100644
浏览文件 @
0b07eef1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
allreduce
,
ops
::
AllReduceOpKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
AllReduceOpKernel
<
plat
::
CUDADeviceContext
,
double
>
,
ops
::
AllReduceOpKernel
<
plat
::
CUDADeviceContext
,
int
>
,
ops
::
AllReduceOpKernel
<
plat
::
CUDADeviceContext
,
int64_t
>
,
ops
::
AllReduceOpKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
This diff is collapsed.
Click to expand it.
paddle/fluid/operators/distributed_ops/allreduce_op.h
0 → 100644
浏览文件 @
0b07eef1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
AllReduceOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
is_gpu_place
(
place
),
"AllReduce op can run on gpu place only for now."
);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
int
dtype
=
platform
::
ToNCCLDataType
(
in
->
type
());
int64_t
numel
=
in
->
numel
();
auto
*
sendbuff
=
in
->
data
<
void
>
();
out
->
Resize
(
in
->
dims
());
void
*
recvbuff
=
out
->
mutable_data
<
T
>
(
place
);
auto
*
comm
=
dev_ctx
.
nccl_comm
();
// FIXME(typhoonzero): should use nccl stream here.
auto
stream
=
dev_ctx
.
stream
();
PADDLE_ENFORCE_NOT_NULL
(
stream
,
"Should initialize NCCL firstly."
);
int
reduce_type
=
ctx
.
Attr
<
int
>
(
"reduce_type"
);
ncclRedOp_t
red_type
=
ncclSum
;
switch
(
reduce_type
)
{
case
0
:
red_type
=
ncclSum
;
break
;
case
1
:
red_type
=
ncclProd
;
break
;
case
2
:
red_type
=
ncclMax
;
break
;
case
3
:
red_type
=
ncclMin
;
break
;
}
VLOG
(
0
)
<<
"call allreduce with type: "
<<
reduce_type
;
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclAllReduce
(
sendbuff
,
recvbuff
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
red_type
,
comm
,
stream
));
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
VLOG
(
0
)
<<
"sync allreduce..."
;
cudaError_t
e_sync
=
cudaStreamSynchronize
(
stream
);
if
(
e_sync
!=
0
)
{
LOG
(
FATAL
)
<<
"cudaStreamSynchronize "
<<
cudaGetErrorString
(
e_sync
);
}
}
#else
PADDLE_THROW
(
"PaddlePaddle should compile with GPU."
);
#endif
}
};
}
// namespace operators
}
// namespace paddle
This diff is collapsed.
Click to expand it.
paddle/fluid/pybind/pybind.cc
浏览文件 @
0b07eef1
...
...
@@ -236,9 +236,11 @@ PYBIND11_MODULE(core, m) {
py
::
class_
<
imperative
::
OpBase
,
PyOpBase
>
(
m
,
"OpBase"
,
R"DOC()DOC"
)
.
def
(
py
::
init
<
const
std
::
string
&>
())
.
def
(
"register_backward_hooks"
,
[](
imperative
::
OpBase
&
self
,
const
py
::
object
&
callable
)
{
self
.
RegisterBackwardHooks
(
callable
);
})
[](
imperative
::
OpBase
&
self
,
const
py
::
object
&
callable
,
bool
front
=
false
)
{
self
.
RegisterBackwardHooks
(
callable
,
front
);
},
py
::
arg
(
"callable"
),
py
::
arg
(
"front"
)
=
false
)
.
def_property
(
"_trace_id"
,
[](
const
imperative
::
OpBase
&
self
)
{
pybind11
::
gil_scoped_release
release
;
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/dygraph/parallel.py
浏览文件 @
0b07eef1
...
...
@@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
six
from
..
import
core
from
.
import
layers
from
..
import
framework
from
..layers
import
collective
__all__
=
[
"prepare_context"
]
...
...
@@ -21,9 +27,13 @@ ParallelStrategy = core.ParallelStrategy
__parallel_ctx__clz__
=
None
def
prepare_context
(
parallel_strategy
,
place
):
def
prepare_context
(
parallel_strategy
):
global
__parallel_ctx__clz__
assert
__parallel_ctx__clz__
is
None
,
"ParallelContext can only be initialized once."
assert
framework
.
in_dygraph_mode
(
)
is
True
,
"dygraph.parallel.prepare_context should be used with dygrahp mode."
place
=
framework
.
_current_expected_place
()
assert
place
is
not
None
,
"dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard."
if
isinstance
(
place
,
core
.
CUDAPlace
):
__parallel_ctx__clz__
=
core
.
NCCLParallelContext
(
parallel_strategy
,
...
...
@@ -58,3 +68,38 @@ class Env(object):
@
property
def
current_endpoint
(
self
):
return
self
.
_current_endpoint
@
property
def
trainer_endpoints
(
self
):
return
self
.
_trainer_endpoints
class
DataParallel
(
layers
.
Layer
):
def
__init__
(
self
,
layers
):
super
(
DataParallel
,
self
).
__init__
(
layers
.
full_name
()
+
"_data_parallel"
)
self
.
_layers
=
layers
def
build_once
(
self
,
*
inputs
,
**
kwargs
):
#TODO(Yancey1989): broadcast all the paramters
pass
def
forward
(
self
,
*
inputs
,
**
kwargs
):
def
_collective_hook
(
iop
):
op
=
framework
.
_dygraph_tracer
().
_ops
[
iop
.
_trace_id
]
for
k
,
v
in
six
.
iteritems
(
op
.
inputs
):
for
ivar
in
v
:
g
=
ivar
.
_grad_ivar
()
if
g
:
g_var
=
framework
.
Variable
(
block
=
self
.
_helper
.
main_program
.
current_block
(),
name
=
ivar
.
_grad_name
(),
stop_gradient
=
True
,
ivar
=
g
)
collective
.
_allreduce
(
g_var
,
g_var
,
sync_mode
=
True
)
outs
=
self
.
_layers
(
*
inputs
,
**
kwargs
)
for
_
,
op
in
six
.
iteritems
(
framework
.
_dygraph_tracer
().
_ops
):
# hook collective ops
op
.
iop
.
register_backward_hooks
(
_collective_hook
,
front
=
True
)
return
outs
This diff is collapsed.
Click to expand it.
python/paddle/fluid/layers/collective.py
浏览文件 @
0b07eef1
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
from
..layer_helper
import
LayerHelper
,
unique_name
def
_allreduce
(
x
,
out
=
None
,
reduce_type
=
"sum"
):
def
_allreduce
(
x
,
out
=
None
,
reduce_type
=
"sum"
,
sync_mode
=
False
):
helper
=
LayerHelper
(
"allreduce"
,
**
locals
())
# Convert string reduce type to op int type
red_typ_int
=
0
...
...
@@ -43,5 +43,6 @@ def _allreduce(x, out=None, reduce_type="sum"):
type
=
'allreduce'
,
inputs
=
{
'X'
:
[
x
]},
outputs
=
{
'Out'
:
[
out
]},
attrs
=
{
"reduce_type"
:
red_typ_int
})
attrs
=
{
"reduce_type"
:
red_typ_int
,
"sync_mode"
:
sync_mode
})
return
out
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
0b07eef1
...
...
@@ -19,6 +19,7 @@ endif(NOT WITH_DISTRIBUTE)
if
(
NOT
${
WITH_GPU
}
)
LIST
(
REMOVE_ITEM TEST_OPS test_conv2d_fusion_op
)
LIST
(
REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist
)
# TODO(Yancey1989): parallel dygraph support CPU device in future
elseif
(
${
CUDNN_VERSION
}
VERSION_LESS 7100
)
LIST
(
REMOVE_ITEM TEST_OPS test_conv2d_fusion_op
)
endif
()
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py
0 → 100644
浏览文件 @
0b07eef1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
contextlib
import
unittest
import
numpy
as
np
import
six
import
pickle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.dygraph
as
dygraph
from
paddle.fluid
import
core
from
paddle.fluid.optimizer
import
SGDOptimizer
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
FC
from
paddle.fluid.dygraph.base
import
to_variable
from
test_dist_base
import
runtime_main
,
TestParallelDyGraphRunnerBase
class
SimpleImgConvPool
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
num_channels
,
num_filters
,
filter_size
,
pool_size
,
pool_stride
,
pool_padding
=
0
,
pool_type
=
'max'
,
global_pooling
=
False
,
conv_stride
=
1
,
conv_padding
=
0
,
conv_dilation
=
1
,
conv_groups
=
1
,
act
=
None
,
use_cudnn
=
False
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
SimpleImgConvPool
,
self
).
__init__
(
name_scope
)
self
.
_conv2d
=
Conv2D
(
self
.
full_name
(),
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
conv_stride
,
padding
=
conv_padding
,
dilation
=
conv_dilation
,
groups
=
conv_groups
,
param_attr
=
None
,
bias_attr
=
None
,
use_cudnn
=
use_cudnn
)
self
.
_pool2d
=
Pool2D
(
self
.
full_name
(),
pool_size
=
pool_size
,
pool_type
=
pool_type
,
pool_stride
=
pool_stride
,
pool_padding
=
pool_padding
,
global_pooling
=
global_pooling
,
use_cudnn
=
use_cudnn
)
def
forward
(
self
,
inputs
):
x
=
self
.
_conv2d
(
inputs
)
x
=
self
.
_pool2d
(
x
)
return
x
class
MNIST
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
):
super
(
MNIST
,
self
).
__init__
(
name_scope
)
self
.
_simple_img_conv_pool_1
=
SimpleImgConvPool
(
self
.
full_name
(),
1
,
20
,
5
,
2
,
2
,
act
=
"relu"
)
self
.
_simple_img_conv_pool_2
=
SimpleImgConvPool
(
self
.
full_name
(),
20
,
50
,
5
,
2
,
2
,
act
=
"relu"
)
pool_2_shape
=
50
*
4
*
4
SIZE
=
10
scale
=
(
2.0
/
(
pool_2_shape
**
2
*
SIZE
))
**
0.5
self
.
_fc
=
FC
(
self
.
full_name
(),
10
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NormalInitializer
(
loc
=
0.0
,
scale
=
scale
)),
act
=
"softmax"
)
def
forward
(
self
,
inputs
):
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
self
.
_fc
(
x
)
return
x
class
TestMnist
(
TestParallelDyGraphRunnerBase
):
def
get_model
(
self
):
model
=
MNIST
(
"mnist"
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
2
,
drop_last
=
True
)
opt
=
SGDOptimizer
(
learning_rate
=
1e-3
)
return
model
,
train_reader
,
opt
def
run_one_loop
(
self
,
model
,
opt
,
data
):
batch_size
=
len
(
data
)
dy_x_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
batch_size
,
1
)
img
=
to_variable
(
dy_x_data
)
label
=
to_variable
(
y_data
)
label
.
stop_gradient
=
True
cost
=
model
(
img
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
return
avg_loss
if
__name__
==
"__main__"
:
runtime_main
(
TestMnist
)
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
0b07eef1
...
...
@@ -27,6 +27,9 @@ import numpy as np
import
paddle.fluid
as
fluid
from
paddle.fluid
import
compiler
import
paddle.fluid.dygraph
as
dygraph
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.dygraph.parallel
import
DataParallel
RUN_STEP
=
10
DEFAULT_BATCH_SIZE
=
2
...
...
@@ -187,6 +190,68 @@ class TestDistRunnerBase(object):
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
class
TestParallelDyGraphRunnerBase
(
object
):
def
get_model
(
self
):
raise
NotImplementedError
(
"get_model should be implemented by child classes."
)
def
run_one_loop
(
self
,
model
,
opt
,
data
):
raise
NotImplementedError
(
"train_one_loop should be implemented by the child classes."
)
def
run_trainer
(
self
,
args
):
seed
=
90
device_id
=
int
(
os
.
getenv
(
"FLAGS_selected_gpus"
,
"0"
))
place
=
fluid
.
CUDAPlace
(
device_id
)
def
_get_data
(
batch
):
if
args
.
update_method
!=
"local"
:
new_batch
=
[]
for
offset
,
item
in
enumerate
(
batch
):
if
offset
%
2
==
args
.
trainer_id
:
new_batch
.
append
(
item
)
return
new_batch
else
:
return
batch
with
fluid
.
dygraph
.
guard
(
place
):
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
model
,
train_reader
,
opt
=
self
.
get_model
()
nranks
=
len
(
args
.
endpoints
.
split
(
","
))
if
args
.
endpoints
else
1
if
args
.
update_method
==
"nccl2"
:
sys
.
stderr
.
write
(
""
)
model
=
dygraph
.
parallel
.
DataParallel
(
model
)
strategy
=
dygraph
.
parallel
.
ParallelStrategy
()
strategy
.
nranks
=
nranks
strategy
.
local_rank
=
args
.
trainer_id
strategy
.
trainer_endpoints
=
args
.
endpoints
.
split
(
","
)
strategy
.
current_endpoint
=
args
.
current_endpoint
dygraph
.
parallel
.
prepare_context
(
strategy
)
out_losses
=
[]
for
step_id
,
data
in
enumerate
(
train_reader
()):
data
=
_get_data
(
data
)
if
step_id
==
RUN_STEP
:
break
loss
=
self
.
run_one_loop
(
model
,
opt
,
data
)
# FIXME(Yancey1989): scale the loss inplace
loss
.
stop_gradient
=
True
loss_scale
=
to_variable
(
np
.
array
([
nranks
]).
astype
(
"float32"
))
loss
=
loss
/
loss_scale
out_losses
.
append
(
loss
.
numpy
())
loss
.
backward
()
opt
.
minimize
(
loss
)
model
.
clear_gradients
()
if
six
.
PY2
:
print
(
pickle
.
dumps
(
out_losses
))
else
:
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
def
runtime_main
(
test_class
):
parser
=
argparse
.
ArgumentParser
(
description
=
'Run dist test.'
)
parser
.
add_argument
(
...
...
@@ -275,6 +340,7 @@ class TestDistBase(unittest.TestCase):
self
.
_nccl2_reduce_layer
=
False
self
.
_lr
=
0.001
self
.
_use_dgc
=
False
self
.
_dygraph
=
False
self
.
_setup_config
()
self
.
_after_setup_config
()
...
...
@@ -597,6 +663,9 @@ class TestDistBase(unittest.TestCase):
local_loss
=
local_losses
[
step_id
]
tr0_loss
=
tr0_losses
[
step_id
]
tr1_loss
=
tr1_losses
[
step_id
]
dist_loss
=
(
np
.
array
([
tr0_loss
])
+
np
.
array
([
tr1_loss
]))
/
2
dist_loss
=
(
np
.
array
([
tr0_loss
])
+
np
.
array
([
tr1_loss
]))
if
not
self
.
_dygraph
:
# Parallel DyGraph already scaled the loss in training
dist_loss
=
dist_loss
/
2
print
(
"======="
,
local_loss
,
":"
,
dist_loss
[
0
],
"======="
)
self
.
assertAlmostEqual
(
local_loss
,
dist_loss
[
0
],
delta
=
delta
)
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
0 → 100644
浏览文件 @
0b07eef1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
from
test_dist_base
import
TestDistBase
class
TestParallelDygraphMnist
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
self
.
_nccl2_mode
=
True
self
.
_dygraph
=
True
def
test_mnist
(
self
):
self
.
check_with_place
(
"parallel_dygraph_mnist.py"
,
delta
=
1e-5
,
check_error_log
=
True
)
if
__name__
==
"__main__"
:
unittest
.
main
()
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录