PaddlePaddle / Paddle — commit b1026f64 (unverified)
Authored by WangXi on Feb 03, 2021; committed via GitHub on Feb 03, 2021.

【kunlun】dygraph supports multi xpu card training (#30671)

Parent commit: 3a3ff75c
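This commit adds a BKCL (Baidu Kunlun Communication Library) backend to the imperative/dygraph parallel context, together with the launcher, pybind, and Python-API plumbing, so dynamic-graph data-parallel training can run across multiple Kunlun XPU cards. Based on the launch.py changes below, such a job would be started with the new --xpus flag, e.g. python -m paddle.distributed.fleet.launch --xpus="0,1,2,3" train.py (train.py is a placeholder script name; this invocation is a sketch inferred from the diff, not text from the commit itself).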
Showing 33 changed files with 1,225 additions and 120 deletions (+1225, -120).
Changed files:
paddle/fluid/imperative/CMakeLists.txt (+4, -0)
paddle/fluid/imperative/bkcl_context.cc (+172, -0)
paddle/fluid/imperative/bkcl_context.h (+53, -0)
paddle/fluid/imperative/reducer.cc (+84, -3)
paddle/fluid/imperative/reducer.h (+1, -1)
paddle/fluid/imperative/tests/CMakeLists.txt (+4, -1)
paddle/fluid/imperative/tests/bkcl_context_test.cc (+66, -0)
paddle/fluid/imperative/tests/test_group.cc (+18, -3)
paddle/fluid/operators/collective/CMakeLists.txt (+4, -0)
paddle/fluid/operators/collective/broadcast_op_xpu.cc (+96, -0)
paddle/fluid/operators/math/concat_and_split.cc (+98, -0)
paddle/fluid/operators/math/math_function.h (+8, -13)
paddle/fluid/operators/math/math_function_impl.h (+3, -2)
paddle/fluid/platform/collective_helper.cc (+128, -4)
paddle/fluid/platform/collective_helper.h (+97, -3)
paddle/fluid/platform/device_context.cc (+3, -0)
paddle/fluid/platform/gen_comm_id_helper.cc (+2, -2)
paddle/fluid/platform/gen_comm_id_helper.h (+1, -1)
paddle/fluid/platform/xpu_info.h (+23, -0)
paddle/fluid/pybind/CMakeLists.txt (+10, -0)
paddle/fluid/pybind/imperative.cc (+20, -7)
paddle/fluid/pybind/tensor_py.h (+37, -1)
python/paddle/distributed/fleet/launch.py (+31, -18)
python/paddle/distributed/fleet/launch_utils.py (+63, -10)
python/paddle/distributed/parallel.py (+20, -8)
python/paddle/fluid/dygraph/parallel.py (+11, -4)
python/paddle/fluid/tests/unittests/detected_xpu.py (+25, -0)
python/paddle/fluid/tests/unittests/nproc_process.py (+7, -3)
python/paddle/fluid/tests/unittests/test_dist_base.py (+58, -19)
python/paddle/fluid/tests/unittests/test_dist_mnist_fleet_save.py (+2, -2)
python/paddle/fluid/tests/unittests/test_dist_sharding_save.py (+9, -6)
python/paddle/fluid/tests/unittests/test_fleet_launch_nproc.sh (+48, -9)
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py (+19, -0)
paddle/fluid/imperative/CMakeLists.txt
@@ -14,6 +14,10 @@ if(NOT WIN32)
    cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context imperative_all_reduce var_type_traits)
    cc_library(reducer SRCS reducer.cc DEPS layer imperative_all_reduce)
  endif()
  if (WITH_XPU_BKCL)
    cc_library(bkcl_context SRCS bkcl_context.cc DEPS collective_helper device_context tensor var_type_traits)
    cc_library(reducer SRCS reducer.cc DEPS layer)
  endif()
  cc_library(data_loader SRCS data_loader.cc DEPS enforce)
endif(NOT WIN32)
paddle/fluid/imperative/bkcl_context.cc (new file, 0 → 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_XPU_BKCL)

#include "paddle/fluid/imperative/bkcl_context.h"

#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/platform/bkcl_helper.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"

#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace imperative {

static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
                      const XPUStream stream, const platform::BKCLComm *comm) {
  const auto &place = src.place();
  PADDLE_ENFORCE_EQ(
      platform::is_xpu_place(place), true,
      platform::errors::Unimplemented(
          "Dynamic graph mode does not support multi-CPU training yet."));

  const void *src_ptr = src.data<void>();
  dst->Resize(src.dims());
  auto *dst_ptr = dst->mutable_data(src.place(), src.type());
  auto bkcl_dtype = platform::ToBKCLDataType(src.type());

  PADDLE_ENFORCE_EQ(
      bkcl_all_reduce(comm->comm(), src_ptr, dst_ptr, src.numel(), bkcl_dtype,
                      BKCL_ADD, stream),
      BKCL_SUCCESS,
      platform::errors::PreconditionNotMet("BKCL all reduce failed"));
}

/*
Baidu Kunlun Communication Library(BKCL) is designed for multi Baidu Kunlun
cards communication as NVIDIA Collective Communications Library(NCCL) in
multi Nvidia GPU cards.
Please refer to bkcl.h in xpu.tar.gz linked in cmake/external/xpu.cmake.
*/
void BKCLParallelContext::BcastBKCLId(
    std::vector<BKCLUniqueId> &bkcl_ids,  // NOLINT
    int root) {
  if (strategy_.local_rank_ == root) {
    std::vector<std::string> other_trainers;
    for (auto &ep : strategy_.trainer_endpoints_) {
      if (ep != strategy_.current_endpoint_) {
        other_trainers.push_back(ep);
      }
    }
    platform::SendBroadCastCommID(other_trainers, &bkcl_ids);
  } else {
    platform::RecvBroadCastCommID(strategy_.current_endpoint_, &bkcl_ids);
  }
}

void BKCLParallelContext::Init() {
  std::vector<BKCLUniqueId> bkcl_ids;
  bkcl_ids.resize(strategy_.nrings_);

  if (strategy_.local_rank_ == 0) {
    // generate the unique ncclid on the root worker
    for (size_t i = 0; i < bkcl_ids.size(); ++i) {
      auto ret = bkcl_get_unique_id(&bkcl_ids[i]);
      PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret,
                        platform::errors::PreconditionNotMet(
                            "BKCL get unique id failed [%d]", ret));
    }
  }
  BcastBKCLId(bkcl_ids, 0);

  int xpu_id = BOOST_GET_CONST(platform::XPUPlace, place_).device;
  for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
    VLOG(0) << "init BKCL context nranks: " << strategy_.nranks_
            << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
            << " ring id: " << ring_id;
    // it will assign bkcl_comm in XPUDeviceContext within ring_id
    platform::BKCLCommContext::Instance().CreateBKCLComm(
        &bkcl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, xpu_id,
        ring_id);
  }
}

void BKCLParallelContext::AllReduceByStream(const framework::Variable &src,
                                            framework::Variable *dst,
                                            int ring_id,
                                            bool use_calc_stream) {
  PADDLE_ENFORCE_EQ(
      platform::is_xpu_place(place_), true,
      platform::errors::Unimplemented(
          "Dynamic graph mode does not support multi-CPU training yet."));
  auto place = place_;

  auto *dev_ctx = static_cast<platform::XPUDeviceContext *>(
      platform::DeviceContextPool::Instance().Get(place));
  platform::BKCLComm *comm =
      platform::BKCLCommContext::Instance().Get(ring_id, place);
  XPUStream stream =
      use_calc_stream ? dev_ctx->x_context()->xpu_stream : comm->stream();

  if (src.IsType<framework::LoDTensor>()) {
    if (!dst->IsType<framework::LoDTensor>()) {
      dst->Clear();
    }
    AllReduce(src.Get<framework::LoDTensor>(),
              dst->GetMutable<framework::LoDTensor>(), stream, comm);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "XPU unsupported variable type %s for imperative allreduce, only "
        "LoDTensor are supported.",
        platform::demangle(framework::ToTypeName(src.Type()))));
  }
}

paddle::platform::DeviceContext *BKCLParallelContext::GetDeviceContext(
    int ring_id) {
  return static_cast<platform::DeviceContext *>(
      platform::BKCLCommContext::Instance()
          .Get(ring_id, place_)
          ->dev_context());
}

void BKCLParallelContext::WaitCompute(int ring_id) {
  PADDLE_ENFORCE_GE(ring_id, 0,
                    platform::errors::OutOfRange(
                        "Ring id expected >= 0, but got %d", ring_id));
  PADDLE_ENFORCE_LT(ring_id, strategy_.nrings_,
                    platform::errors::OutOfRange(
                        "Ring id expected < nrings,"
                        "but got ring id = %d, nrings = %d",
                        ring_id, strategy_.nrings_));
  // TODO(wangxi16): [Performance optimize] Maybe need to put Wait and
  // bkcl_allreduce to comm thread, for bkcl_allreduce is blocking now.
  auto compute_dev_ctx = static_cast<platform::XPUDeviceContext *>(
      platform::DeviceContextPool::Instance().Get(place_));
  compute_dev_ctx->Wait();
}

void BKCLParallelContext::WaitComm(int ring_id) {
  PADDLE_ENFORCE_GE(ring_id, 0,
                    platform::errors::OutOfRange(
                        "Ring id expected >= 0, but got %d", ring_id));
  PADDLE_ENFORCE_LT(ring_id, strategy_.nrings_,
                    platform::errors::OutOfRange(
                        "Ring id expected < nrings,"
                        "but got ring id = %d, nrings = %d",
                        ring_id, strategy_.nrings_));
  auto comm_dev_ctx =
      platform::BKCLCommContext::Instance().Get(ring_id, place_)->dev_context();
  comm_dev_ctx->Wait();
}

}  // namespace imperative
}  // namespace paddle

#endif
paddle/fluid/imperative/bkcl_context.h (new file, 0 → 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#if defined(PADDLE_WITH_XPU_BKCL)
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/imperative/parallel_context.h"
#include "xpu/bkcl.h"

namespace paddle {
namespace imperative {

class BKCLParallelContext : public ParallelContext {
 public:
  explicit BKCLParallelContext(const ParallelStrategy &strategy,
                               const platform::Place &place)
      : ParallelContext(strategy, place) {}

  ~BKCLParallelContext() override = default;

  void BcastBKCLId(std::vector<BKCLUniqueId> &bkcl_ids, int root);  // NOLINT

  void Init() override;

  void AllReduceByStream(const framework::Variable &src,
                         framework::Variable *dst, int ring_id,
                         bool use_calc_stream) override;

  paddle::platform::DeviceContext *GetDeviceContext(int ring_id) override;

  void WaitCompute(int ring_id) override;

  void WaitComm(int ring_id) override;
};

}  // namespace imperative
}  // namespace paddle

#endif
paddle/fluid/imperative/reducer.cc
@@ -30,17 +30,15 @@
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/string/string_helper.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#endif
#include "paddle/fluid/imperative/parallel_context.h"

namespace paddle {
namespace imperative {

-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
template <typename DeviceContext, typename T>
static void ConcatTensorsForAllReduce(
    const DeviceContext &context,
@@ -130,6 +128,69 @@ static void SplitTensorsWithType(
  }
}

#ifdef PADDLE_WITH_XPU_BKCL
template <>
void SplitTensorsForAllReduce<platform::XPUDeviceContext, float>(
    const platform::XPUDeviceContext &context,
    framework::Variable *p_dense_contents,
    std::vector<framework::Tensor> *p_dense_tensors) {
  auto *in = p_dense_contents->GetMutable<framework::LoDTensor>();
  std::vector<framework::Tensor *> outs;
  std::vector<const framework::Tensor *> shape_refer;

  outs.reserve(p_dense_tensors->size());
  shape_refer.reserve(p_dense_tensors->size());

  for (auto &tensor : *p_dense_tensors) {
    outs.emplace_back(&tensor);
    shape_refer.emplace_back(&tensor);
  }
  operators::math::SplitFunctor<platform::XPUDeviceContext, float>
      split_functor_;
  split_functor_(context, *in, shape_refer, 0, &outs);
}

// context is used to select the stream for concat
template <>
void ConcatTensorsWithType<platform::XPUDeviceContext>(
    const platform::XPUDeviceContext &context,
    const std::vector<framework::Tensor> &dense_tensors_,
    framework::Variable *p_dense_contents,
    framework::proto::VarType::Type type) {
  switch (type) {
    case framework::proto::VarType::FP32:
      ConcatTensorsForAllReduce<platform::XPUDeviceContext, float>(
          context, dense_tensors_, p_dense_contents);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when it concats tensors for "
          "allreduce.",
          framework::DataTypeToString(type)));
  }
}

// context is used to select the stream for split
template <>
void SplitTensorsWithType<platform::XPUDeviceContext>(
    const platform::XPUDeviceContext &context,
    framework::Variable *p_dense_contents,
    std::vector<framework::Tensor> *p_dense_tensors,
    framework::proto::VarType::Type type) {
  switch (type) {
    case framework::proto::VarType::FP32:
      SplitTensorsForAllReduce<platform::XPUDeviceContext, float>(
          context, p_dense_contents, p_dense_tensors);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when it splits tensors for "
          "allreduce.",
          framework::DataTypeToString(type)));
  }
}
#endif

void Group::ConcatTensors(const platform::DeviceContext &context) {
  VLOG(3) << "Before concat, set output tensor size is " << all_length_;
  auto tensor = dense_contents_.GetMutable<framework::LoDTensor>();
@@ -146,6 +207,16 @@ void Group::ConcatTensors(const platform::DeviceContext &context) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat grad tensors since it's not compiled with NCCL,"
        "Please recompile or reinstall Paddle with NCCL support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU_BKCL
    ConcatTensorsWithType(
        static_cast<const platform::XPUDeviceContext &>(context),
        dense_tensors_, &dense_contents_, dtype_);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat xpu grads since it's not compiled with BKCL,"
        "Please recompile or reinstall Paddle with BKCL support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    ConcatTensorsWithType(
@@ -168,6 +239,16 @@ void Group::SplitTensors(const platform::DeviceContext &context) {
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split grad tensor since it's not compiled with NCCL,"
        "Please recompile or reinstall Paddle with NCCL support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU_BKCL
    SplitTensorsWithType(
        static_cast<const platform::XPUDeviceContext &>(context),
        &dense_contents_, &dense_tensors_, dtype_);
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split xpu grad since it's not compiled with BKCL,"
        "Please recompile or reinstall Paddle with BKCL support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    SplitTensorsWithType(
paddle/fluid/imperative/reducer.h
@@ -44,7 +44,7 @@ class VariableWrapper;
namespace paddle {
namespace imperative {

-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
class Group {
 public:
  // Here, we use dense_contents_ & sparse_contents_ to
paddle/fluid/imperative/tests/CMakeLists.txt
@@ -4,6 +4,9 @@ else()
  if (WITH_NCCL)
    cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context)
  endif()
  if (WITH_XPU_BKCL)
    cc_test(bkcl_context_test SRCS bkcl_context_test.cc DEPS bkcl_context)
  endif()
endif(WIN32)
@@ -13,6 +16,6 @@ cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info s
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)

-if (WITH_NCCL)
+if (WITH_NCCL OR WITH_XPU_BKCL)
  cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy)
endif()
paddle/fluid/imperative/tests/bkcl_context_test.cc (new file, 0 → 100644)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thread>  // NOLINT

#include "paddle/fluid/imperative/bkcl_context.h"

#include "gtest/gtest.h"

namespace imperative = paddle::imperative;
namespace platform = paddle::platform;

int nrings = 2;
imperative::ParallelStrategy GetStrategy(int local_rank) {
  std::vector<std::string> eps = {"127.0.0.1:9866", "localhost:9867"};
  imperative::ParallelStrategy strategy;
  strategy.trainer_endpoints_ = eps;
  strategy.current_endpoint_ = eps[local_rank];
  strategy.nranks_ = 2;
  strategy.local_rank_ = local_rank;
  strategy.nrings_ = nrings;
  return strategy;
}

#if defined(PADDLE_WITH_XPU_BKCL)
void BcastBKCLId(int local_rank, std::vector<BKCLUniqueId>* bkcl_ids) {
  auto strategy = GetStrategy(local_rank);
  platform::XPUPlace xpu(local_rank);
  imperative::BKCLParallelContext ctx(strategy, xpu);
  ctx.BcastBKCLId(*bkcl_ids, 0);
}

TEST(BcastBKCLId, Run) {
  std::vector<BKCLUniqueId> bkcl_ids;
  bkcl_ids.resize(nrings);
  for (int i = 0; i < nrings; ++i) {
    bkcl_get_unique_id(&bkcl_ids[i]);
  }

  std::thread t(BcastBKCLId, 0, &bkcl_ids);

  std::vector<BKCLUniqueId> recv_bkcl_ids;
  recv_bkcl_ids.resize(nrings);
  for (int i = 0; i < nrings; ++i) {
    bkcl_get_unique_id(&recv_bkcl_ids[i]);
  }
  BcastBKCLId(1, &recv_bkcl_ids);

  t.join();
  for (int i = 0; i < nrings; ++i) {
    EXPECT_EQ(
        0, std::memcmp(&bkcl_ids[i], &recv_bkcl_ids[i], BKCL_UNIQUE_ID_BYTES));
  }
}
#endif
paddle/fluid/imperative/tests/test_group.cc
@@ -20,14 +20,11 @@
#include "glog/logging.h"
#include "gtest/gtest.h"

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/imperative/reducer.h"
#endif

namespace paddle {
namespace imperative {

#if defined(PADDLE_WITH_NCCL)
TEST(TestGroup, TestPrintGroupMessage) {
  Group group;
  std::stringstream stream1, stream2;
@@ -80,8 +77,10 @@ void GroupConcatSplit(Place place, size_t size) {
  }

  if (std::is_same<Place, platform::CUDAPlace>::value) {
#if defined(PADDLE_WITH_NCCL)
    paddle::memory::Copy(place, data, cpu_place, value.data(),
                         sizeof(T) * value.size(), 0);
#endif
  } else {
    paddle::memory::Copy(place, data, cpu_place, value.data(),
                         sizeof(T) * value.size());
@@ -134,6 +133,7 @@ void GroupConcatSplit(Place place, size_t size) {
  }
}

#if defined(PADDLE_WITH_NCCL)
TEST(TestGroup, TestConcatSplit) {
  platform::CUDAPlace cuda_place(0);
  platform::CPUPlace cpu_place;
@@ -165,5 +165,20 @@ TEST(TestGroup, TestConcatSplitException) {
}
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
TEST(TestGroup, TestXPUConcatSplit) {
  platform::XPUPlace xpu_place(0);
  platform::CPUPlace cpu_place;

  int size = 3;
  GroupConcatSplit<float>(cpu_place, size);
  GroupConcatSplit<float>(xpu_place, size);

  size = 15;
  GroupConcatSplit<float>(cpu_place, size);
  GroupConcatSplit<float>(xpu_place, size);
}
#endif

}  // namespace imperative
}  // namespace paddle
paddle/fluid/operators/collective/CMakeLists.txt
@@ -19,6 +19,10 @@ if(WITH_NCCL)
  op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
endif()

if(WITH_BKCL)
  set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
endif()

if(WITH_GLOO)
  set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper)
endif()
paddle/fluid/operators/collective/broadcast_op_xpu.cc (new file, 0 → 100644)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/bkcl_helper.h"
#include "paddle/fluid/platform/collective_helper.h"
#endif

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

template <typename T>
class BKCLBroadcastOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_xpu_place(ctx.GetPlace()), true,
                      platform::errors::PreconditionNotMet(
                          "The place of ExecutionContext should be XPUPlace."));

#if defined(PADDLE_WITH_XPU_BKCL)
    int dev_id = BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()).device;
    int root_dev_id = ctx.Attr<int>("root");

    auto in = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    PADDLE_ENFORCE_EQ(
        out->IsInitialized(), true,
        platform::errors::PreconditionNotMet(
            "Currently, the output of broadcast op must be initialized,"
            "because this op can only be an In-Place operation."));
    void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
    PADDLE_ENFORCE_EQ(
        send_recv_buffer, in->data<void>(),
        platform::errors::PreconditionNotMet("Currently, the broadcast op can "
                                             "only be an In-Place operation."));

    auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
    auto comm = dev_ctx.bkcl_context();
    auto stream = dev_ctx.x_context()->xpu_stream;

    // TODO(wangxi16): bkcl_broadcast only support float type,
    // need to converted other type to float before broadcasting.
    // Broadcast is equivalent to no type of operation, does not affect
    // correctness.
    // Once bkcl_broadcast support other type, need chang to:
    // BKCLDataType data_type = platform::ToBKCLDataType(in->type());
    BKCLDataType data_type = BKCL_FLOAT;
    size_t scale = sizeof(T) / sizeof(float);
    auto ret = bkcl_broadcast(comm, send_recv_buffer, send_recv_buffer,
                              static_cast<size_t>(in->numel()) * scale,
                              data_type, root_dev_id, stream);
    PADDLE_ENFORCE_EQ(ret, BKCL_SUCCESS,
                      platform::errors::Unavailable("bkcl_broadcast failed"));

    VLOG(3) << "Bcast " << ctx.InputNames("X")[0] << ", (" << in->numel() << ")"
            << " From " << root_dev_id << " to " << dev_id;

    if (ctx.Attr<bool>("sync_mode")) {
      dev_ctx.Wait();
    }
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with XPU."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_XPU_KERNEL(broadcast, ops::BKCLBroadcastOpKernel<float>,
                       ops::BKCLBroadcastOpKernel<double>,
                       ops::BKCLBroadcastOpKernel<int>,
                       ops::BKCLBroadcastOpKernel<int64_t>);
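A note on the element-count scaling in the kernel above: because bkcl_broadcast is always invoked with BKCL_FLOAT regardless of the registered element type T, the kernel rescales the element count so that the same number of bytes is moved. A minimal sketch of that arithmetic (illustrative Python only, not Paddle code):

    # number of float words that carry the bytes of `numel` elements of a wider type
    sizeof_T = 8          # e.g. int64_t or double
    sizeof_float = 4
    numel = 6
    scale = sizeof_T // sizeof_float   # 2
    print(numel * scale)               # 12 float words == 48 bytes == 6 int64 values

For 4-byte types such as float or int the scale is 1 and the call degenerates to a plain element-count broadcast.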
paddle/fluid/operators/math/concat_and_split.cc
@@ -119,12 +119,110 @@ class SplitFunctor<platform::CPUDeviceContext, T> {
    }
  }
};

#ifdef PADDLE_WITH_XPU
/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
template <typename T>
class ConcatFunctor<platform::XPUDeviceContext, T> {
 public:
  void operator()(const platform::XPUDeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    int dev_id =
        BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()).GetDeviceId();
    platform::XPUDeviceGuard guard(dev_id);

    int num = input.size();
    auto input_dims = input[0].dims();

    std::vector<std::vector<int>> xdims_list(num);
    for (int i = 0; i < num; ++i) {
      std::vector<int> tmp_dims(input_dims.size());
      for (int j = 0; j < input_dims.size(); ++j) {
        tmp_dims[j] = input[i].dims()[j];
      }
      xdims_list[i] = tmp_dims;
    }

    std::vector<const T*> ptrs;
    for (int i = 0; i < num; ++i) {
      ptrs.push_back(input[i].data<T>());
    }

    auto r = xpu::concat<T>(context.x_context(), ptrs, output->data<T>(),
                            xdims_list, axis);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d %s], please check whether "
            "Baidu Kunlun Card is properly installed.",
            r, XPUAPIErrorMsg[r]));
  }
};

template <typename T>
class SplitFunctor<platform::XPUDeviceContext, T> {
 public:
  void operator()(const platform::XPUDeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  const int axis, std::vector<framework::Tensor*>* outputs) {
    int dev_id =
        BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()).GetDeviceId();
    platform::XPUDeviceGuard guard(dev_id);

    auto& ins = ref_inputs;

    int num = ins.size();
    auto input_dims = ins[0]->dims();
    std::vector<int> split_list(num);
    std::vector<int> xdims_list(input_dims.size());
    int total_length = 0;
    for (int i = 0; i < num; ++i) {
      split_list[i] = ins[i]->dims()[axis];
      total_length += ins[i]->dims()[axis];
    }

    for (int i = 0; i < input_dims.size(); ++i) {
      if (i == axis) continue;
      xdims_list[i] = input_dims[i];
    }
    xdims_list[axis] = total_length;

    std::vector<T*> ptrs(num);
    for (int i = 0; i < num; ++i) {
      ptrs[i] = outputs->at(i)->data<T>();
    }

    auto r = xpu::split<T>(context.x_context(), input.data<T>(), ptrs,
                           xdims_list, split_list, axis);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d %s], please check whether "
            "Baidu Kunlun Card is properly installed.",
            r, XPUAPIErrorMsg[r]));
  }
};
#endif

#define DEFINE_FUNCTOR(type)                                      \
  template class ConcatFunctor<platform::CPUDeviceContext, type>; \
  template class SplitFunctor<platform::CPUDeviceContext, type>;

FOR_ALL_TYPES(DEFINE_FUNCTOR);

#ifdef PADDLE_WITH_XPU
#define DEFINE_XPU_FUNCTOR(type)                                  \
  template class ConcatFunctor<platform::XPUDeviceContext, type>; \
  template class SplitFunctor<platform::XPUDeviceContext, type>;

DEFINE_XPU_FUNCTOR(float)
#endif

}  // namespace math
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/math/math_function.h
@@ -88,27 +88,22 @@ struct RowwiseMean {
#ifdef PADDLE_WITH_XPU
template <typename U>
struct TensorSetConstantXPU {
-  TensorSetConstantXPU(framework::Tensor* tensor, U value)
-      : tensor_(tensor), value_(value) {}
+  TensorSetConstantXPU(framework::Tensor* tensor, U value,
+                       platform::Place place)
+      : tensor_(tensor), value_(value), place_(place) {}
  template <typename T>
  void apply() const {
-    int dev_id = -1;
-    xpu_current_device(&dev_id);
-    if (dev_id >= 64) {
-      // if dev_id >= 64, the device is a simulator device, -64 to get real
-      // dev_id
-      dev_id -= 64;
-    }
-    auto xpu = platform::XPUPlace(dev_id);
-    auto* begin = tensor_->mutable_data<T>(xpu);
+    auto* begin = tensor_->mutable_data<T>(place_);
    int numel = tensor_->numel();
    std::unique_ptr<T[]> data_cpu(new T[numel]);
    std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast<T>(value_));
-    memory::Copy(xpu, begin, platform::CPUPlace(),
-                 static_cast<void*>(data_cpu.get()), numel * sizeof(T));
+    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, place_), begin,
+                 platform::CPUPlace(), static_cast<void*>(data_cpu.get()),
+                 numel * sizeof(T));
  }
  framework::Tensor* tensor_;
  U value_;
+  platform::Place place_;
};
#endif
paddle/fluid/operators/math/math_function_impl.h
@@ -32,8 +32,9 @@ void SetConstant<DeviceContext, T>::operator()(const DeviceContext& context,
#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(context.GetPlace())) {
    xpu_place = true;
-    framework::VisitDataType(tensor->type(),
-                             TensorSetConstantXPU<T>(tensor, num));
+    framework::VisitDataType(
+        tensor->type(),
+        TensorSetConstantXPU<T>(tensor, num, context.GetPlace()));
  }
#endif
  if (!xpu_place) {
paddle/fluid/platform/collective_helper.cc
@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"

#include <utility>

namespace paddle {
namespace platform {

+#if defined(PADDLE_WITH_NCCL)
class NCCLCommImpl : public NCCLComm {
 public:
  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
@@ -159,7 +158,132 @@ void NCCLCommContext::ReleaseNCCLComms() {
    }
  }
}
-}  // namespace platform
-}  // namespace paddle
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
class BKCLCommImpl : public BKCLComm {
 public:
  void set_ring_id(int ring_id) { ring_id_ = ring_id; }
  int ring_id() const override { return ring_id_; }

  void set_nranks(int nranks) { nranks_ = nranks; }
  int nranks() const override { return nranks_; }

  void set_rank(int rank) { rank_ = rank; }
  int rank() const override { return rank_; }

  int device_id() const override {
    return BOOST_GET_CONST(XPUPlace, dev_ctx_->GetPlace()).device;
  }

  void set_comm(BKCLContext_t comm) { comm_ = comm; }
  BKCLContext_t comm() const override { return comm_; }

  XPUStream stream() const override {
    return dev_ctx_->x_context()->xpu_stream;
  }

  void set_dev_ctx(std::unique_ptr<XPUDeviceContext>&& dev_ctx) {
    dev_ctx_ = std::move(dev_ctx);
  }
  XPUDeviceContext* dev_context() const override { return dev_ctx_.get(); }

 private:
  int ring_id_;
  int nranks_;
  int rank_;
  BKCLContext_t comm_;
  std::unique_ptr<XPUDeviceContext> dev_ctx_;
};

BKCLComm* BKCLCommContext::CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks,
                                          int rank, int dev_id, int ring_id) {
  PADDLE_ENFORCE_NOT_NULL(bkcl_id,
                          platform::errors::InvalidArgument(
                              "The bkcl unique id should not be null."));
  PADDLE_ENFORCE_GT(
      nranks, 1,
      platform::errors::InvalidArgument(
          "Expected nranks > 1. But received nranks is %d.", nranks));
  PADDLE_ENFORCE_GE(rank, 0,
                    platform::errors::InvalidArgument(
                        "Expected rank >= 0. But received rank is %d.", rank));
  PADDLE_ENFORCE_LT(
      rank, nranks,
      platform::errors::InvalidArgument(
          "Expected rank < nranks. But received rank is %d, nranks is %d.",
          rank, nranks));
  PADDLE_ENFORCE_GE(
      dev_id, 0,
      platform::errors::InvalidArgument(
          "Expected dev_id >= 0. But received dev_id is %d.", dev_id));

  BKCLContext_t comm = nullptr;
  auto ret = xpu_set_device(dev_id);
  PADDLE_ENFORCE_EQ(
      ret, XPU_SUCCESS,
      platform::errors::PreconditionNotMet(
          "XPU API return wrong value[%d %s], please check whether "
          "Baidu Kunlun Card is properly installed.",
          ret, XPUAPIErrorMsg[ret]));

  ret = bkcl_init_rank(&comm, rank, nranks, bkcl_id);
  PADDLE_ENFORCE_EQ(ret, BKCL_SUCCESS,
                    platform::errors::PreconditionNotMet(
                        "bkcl_init_rank failed, got wrong value [%d].", ret));

  auto* comm_wrapper = AssignBKCLComm(comm, nranks, rank, dev_id, ring_id);

  VLOG(1) << "bkcl communicator of rank " << rank << " in ring " << ring_id
          << " has been created on device " << dev_id;

  std::call_once(once_flag_, []() {
    std::atexit([]() { BKCLCommContext::Instance().ReleaseBKCLComms(); });
  });

  return comm_wrapper;
}

BKCLComm* BKCLCommContext::AssignBKCLComm(BKCLContext_t comm, int nranks,
                                          int rank, int dev_id, int ring_id) {
  std::unique_ptr<XPUDeviceContext> dev_ctx(
      new XPUDeviceContext(XPUPlace(dev_id)));

  BKCLCommImpl* c = new BKCLCommImpl;
  c->set_ring_id(ring_id);
  c->set_nranks(nranks);
  c->set_rank(rank);
  c->set_comm(comm);
  c->set_dev_ctx(std::move(dev_ctx));

  comm_map_mutex_.lock();
  if (comm_map_.count(ring_id) == 0) {
    comm_map_.emplace(ring_id, std::map<int, std::unique_ptr<BKCLComm>>());
  }
  auto& dev2comm = comm_map_[ring_id];

  dev2comm.emplace(dev_id, std::unique_ptr<BKCLComm>(c));
  comm_map_mutex_.unlock();

  if (ring_id == 0) {
    auto* dev_ctx = static_cast<platform::XPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(
            platform::XPUPlace(dev_id)));
    dev_ctx->set_bkcl_context(comm);
  }

  return comm_map_[ring_id][dev_id].get();
}

void BKCLCommContext::ReleaseBKCLComms() {
  for (auto& p : comm_map_) {
    for (auto& q : p.second) {
      q.second.reset();
    }
  }
}

#endif
}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/collective_helper.h
@@ -14,7 +14,6 @@
#pragma once

-#if defined(PADDLE_WITH_NCCL)
#include <map>
#include <memory>
#include <string>
@@ -28,6 +27,7 @@
namespace paddle {
namespace platform {

+#if defined(PADDLE_WITH_NCCL)
// In order to apply hierarchical communication with NCCL, we need
// a communication ring contains NCCL communicators associated to a global
// ncclUniqueId. E.g. for a hierarchical case,
@@ -120,8 +120,102 @@ class NCCLCommContext {
  NCCLCommContext() = default;
  DISABLE_COPY_AND_ASSIGN(NCCLCommContext);
};
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
// In order to apply hierarchical communication with BKCL, we need
// a communication ring contains BKCL communicators associated to a global
// BKCLUniqueId. E.g. for a hierarchical case,
//
//    11 - 12   21 - 22
//     |    |    |    |
//    13 - 14 - 23 - 24
//          |    |
//    31 - 32 - 41 - 42
//     |    |    |    |
//    33 - 34   43 - 44
//
// we group (14,23,32,41) as the top, and (11,12,13,14), (21,22,23,24),
// (31,32,33,34), (41,42,43,44) as bottoms respectively.
//
// We could also use a single communication ring for the flatten case
//
// The BKCLComm instance is created and reversed in the BKCLCommContext
// singleton with a global user specified group id.
class BKCLComm {
 public:
  virtual int ring_id() const = 0;
  virtual int nranks() const = 0;
  virtual int rank() const = 0;
  virtual int device_id() const = 0;
  virtual BKCLContext_t comm() const = 0;
  virtual XPUStream stream() const = 0;
  virtual XPUDeviceContext* dev_context() const = 0;
  virtual ~BKCLComm() = default;
};

// A singleton BKCL communicator context reserves communication ring ids
class BKCLCommContext {
 public:
  static BKCLCommContext& Instance() {
    static BKCLCommContext comm_ctx;
    return comm_ctx;
  }

  BKCLComm* CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks, int rank,
                           int dev_id, int ring_id = 0);

  void CreateAllBKCLComms(const std::vector<int>& dev_ids, int ring_id = 0);

  // a latter comm with the same dev_id and the same ring_id
  // will override the former
  BKCLComm* AssignBKCLComm(BKCLContext_t comm, int nranks, int rank, int dev_id,
                           int ring_id = 0);

  // retrieve a communicator by the ring id in multiprocessing mode
  BKCLComm* Get(int ring_id) const {
    PADDLE_ENFORCE_GT(
        comm_map_.count(ring_id), 0,
        platform::errors::InvalidArgument(
            "Communicator in ring id %d has not been initialized.", ring_id));
    PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(), 1,
                      platform::errors::InvalidArgument(
                          "One device id should be specified to retrieve from "
                          "multiple communicators."));
    return comm_map_.at(ring_id).begin()->second.get();
  }

  // retrieve a communicator by the ring id and the device id
  BKCLComm* Get(int ring_id, int dev_id) const {
    PADDLE_ENFORCE_GT(
        comm_map_.count(ring_id), 0,
        platform::errors::InvalidArgument(
            "Communicator of ring id %d has not been initialized.", ring_id));
    PADDLE_ENFORCE_GT(
        comm_map_.at(ring_id).count(dev_id), 0,
        platform::errors::InvalidArgument(
            "Communicator at device id %d has not been initialized in ring %d.",
            dev_id, ring_id));
    return comm_map_.at(ring_id).at(dev_id).get();
  }

  // retrieve a communicator by the ring id and place
  BKCLComm* Get(int ring_id, Place place) const {
    return Get(ring_id, BOOST_GET_CONST(XPUPlace, place).device);
  }

 private:
  std::once_flag once_flag_;
  std::mutex comm_map_mutex_;
  // ring id to dev-BKCLComm
  std::map<int, std::map<int, std::unique_ptr<BKCLComm>>> comm_map_;

  void ReleaseBKCLComms();

  BKCLCommContext() = default;
  DISABLE_COPY_AND_ASSIGN(BKCLCommContext);
};
#endif

}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/device_context.cc
@@ -188,6 +188,9 @@ XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) {
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << place_.device;

  context_ = xpu::create_context();
  const int MAX_XPU_NUM = 16;
  const int l3_size = 13.5 * 1024 * 1024;
paddle/fluid/platform/gen_comm_id_helper.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#ifdef PADDLE_WITH_NCCL
+#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/gen_comm_id_helper.h"

#include <arpa/inet.h>
@@ -339,7 +339,7 @@ void RecvBroadCastCommID(int server_fd, std::string endpoint,
INSTANT_TEMPLATE(ncclUniqueId)
#endif
#ifdef PADDLE_WITH_XPU_BKCL
-INSTANT_TEMPLATE(bkclUniqueId)
+INSTANT_TEMPLATE(BKCLUniqueId)
#endif
}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/gen_comm_id_helper.h
@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once

-#ifdef PADDLE_WITH_NCCL
+#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
#include <functional>
#include <string>
#include <vector>
paddle/fluid/platform/xpu_info.h
@@ -28,6 +28,29 @@ std::vector<int> GetXPUSelectedDevices();
//! Set the XPU device id for next execution.
void SetXPUDeviceId(int device_id);

class XPUDeviceGuard {
 public:
  explicit inline XPUDeviceGuard(int dev_id) {
    int prev_id = platform::GetXPUCurrentDeviceId();
    if (prev_id != dev_id) {
      prev_id_ = prev_id;
      platform::SetXPUDeviceId(dev_id);
    }
  }

  inline ~XPUDeviceGuard() {
    if (prev_id_ != -1) {
      platform::SetXPUDeviceId(prev_id_);
    }
  }

  XPUDeviceGuard(const XPUDeviceGuard& o) = delete;
  XPUDeviceGuard& operator=(const XPUDeviceGuard& o) = delete;

 private:
  int prev_id_{-1};
};

}  // namespace platform
}  // namespace paddle
#endif
paddle/fluid/pybind/CMakeLists.txt
@@ -5,6 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
if(WITH_GPU)
  set(PYBIND_DEPS ${PYBIND_DEPS} dynload_cuda)
  set(PYBIND_DEPS ${PYBIND_DEPS} cuda_device_guard)
endif()

if (WITH_NCCL)
@@ -12,6 +13,11 @@ if (WITH_NCCL)
  set(PYBIND_DEPS ${PYBIND_DEPS} reducer)
endif()

if (WITH_XPU_BKCL)
  set(PYBIND_DEPS ${PYBIND_DEPS} reducer)
  set(PYBIND_DEPS ${PYBIND_DEPS} bkcl_context)
endif()

if(NOT WIN32)
  set(PYBIND_DEPS ${PYBIND_DEPS} data_loader)
  set(PYBIND_DEPS ${PYBIND_DEPS} mmap_allocator)
@@ -79,6 +85,10 @@ if(WITH_PYTHON)
    list(APPEND OP_FUNCTION_GENERETOR_DEPS nccl_context)
  endif(WITH_NCCL)

  if(WITH_XPU_BKCL)
    list(APPEND OP_FUNCTION_GENERETOR_DEPS bkcl_context)
  endif(WITH_XPU_BKCL)

  add_executable(op_function_generator op_function_generator.cc)
  target_link_libraries(op_function_generator ${OP_FUNCTION_GENERETOR_DEPS})
  get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
paddle/fluid/pybind/imperative.cc
@@ -32,6 +32,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/all_reduce.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/bkcl_context.h"
#include "paddle/fluid/imperative/data_loader.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/nccl_context.h"
@@ -1377,16 +1378,10 @@ void BindImperative(py::module *m_ptr) {
      },
      py::call_guard<py::gil_scoped_release>());

-#if defined(PADDLE_WITH_NCCL)
+#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
  py::class_<imperative::ParallelContext,
             std::shared_ptr<imperative::ParallelContext>>(m,
                                                           "ParallelContext");
-  py::class_<imperative::NCCLParallelContext, imperative::ParallelContext,
-             std::shared_ptr<imperative::NCCLParallelContext>>(
-      m, "NCCLParallelContext")
-      .def(py::init<const imperative::ParallelStrategy &,
-                    const platform::CUDAPlace &>())
-      .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); });

  py::class_<imperative::Reducer, std::shared_ptr<imperative::Reducer>>(
      m, "Reducer", R"DOC()DOC")
@@ -1404,6 +1399,24 @@ void BindImperative(py::module *m_ptr) {
      py::arg("tensor_indices") = std::vector<int64_t>{},
      py::call_guard<py::gil_scoped_release>());
#endif

#if defined(PADDLE_WITH_NCCL)
  py::class_<imperative::NCCLParallelContext, imperative::ParallelContext,
             std::shared_ptr<imperative::NCCLParallelContext>>(
      m, "NCCLParallelContext")
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CUDAPlace &>())
      .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); });
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
  py::class_<imperative::BKCLParallelContext, imperative::ParallelContext,
             std::shared_ptr<imperative::BKCLParallelContext>>(
      m, "BKCLParallelContext")
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::XPUPlace &>())
      .def("init", [](imperative::BKCLParallelContext &self) { self.Init(); });
#endif
}

}  // namespace pybind
paddle/fluid/pybind/tensor_py.h
@@ -27,6 +27,9 @@ limitations under the License. */
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/bfloat16.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
@@ -256,6 +259,38 @@ void TensorSetElement(framework::Tensor *self, size_t offset, T elem) {
  }
}

// NOTE(wangxi): When copying data to the accelerator card,
// we need set_device(dev_id) first.
template <typename P>
static int GetDeviceId(const P &place) {
  // for CPUPlace and CUDAPinnedPlace.
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Paddle can't Get CPUPlace or CUDAPinnedPlace Device Id."));
}

template <>
int GetDeviceId<platform::CUDAPlace>(const platform::CUDAPlace &place) {
  return place.GetDeviceId();
}

template <>
int GetDeviceId<platform::XPUPlace>(const platform::XPUPlace &place) {
  return place.GetDeviceId();
}

// NOTE(wangxi16): Used by VarBase __setitem__
template <>
int GetDeviceId<platform::Place>(const platform::Place &place) {
  if (paddle::platform::is_gpu_place(place)) {
    return GetDeviceId(BOOST_GET_CONST(platform::CUDAPlace, place));
  } else if (paddle::platform::is_xpu_place(place)) {
    return GetDeviceId(BOOST_GET_CONST(platform::XPUPlace, place));
  }
  // for CPUPlace and CUDAPinnedPlace.
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Paddle can't Get CPUPlace or CUDAPinnedPlace Device Id."));
}

template <typename T, typename P>
void SetTensorFromPyArrayT(framework::Tensor *self,
@@ -279,6 +314,7 @@ void SetTensorFromPyArrayT(
    }
  } else if (paddle::platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    platform::XPUDeviceGuard guard(GetDeviceId(place));
    auto dst = self->mutable_data<T>(place);
    xpu_memcpy(dst, array.data(), array.nbytes(),
               XPUMemcpyKind::XPU_HOST_TO_DEVICE);
@@ -290,7 +326,7 @@ void SetTensorFromPyArrayT(
  } else {
#ifdef PADDLE_WITH_CUDA
    if (paddle::platform::is_gpu_place(place)) {
      // TODO(zhiqiu): set SetDeviceId before calling cuda APIs.
      platform::CUDADeviceGuard guard(GetDeviceId(place));
      auto dst = self->mutable_data<T>(place);
      paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(),
                                      cudaMemcpyHostToDevice);
python/paddle/distributed/fleet/launch.py
@@ -108,6 +108,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
        "In gpu training, it should be less or equal to the gpus number of you system(or you set by --gpus). And so each process can"
        " bound to one or average number of gpus.")

    if fluid.core.is_compiled_with_cuda():
        base_group.add_argument(
            "--gpus",
            type=str,
@@ -116,9 +117,18 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
            "For example:"
            "--gpus=\"0,1,2,3\" will launch four training processes each bound to one gpu.")
        base_group.add_argument("--selected_gpus", dest="gpus")

    if fluid.core.is_compiled_with_xpu():
        base_group.add_argument(
            "--xpus",
            type=str,
            default=None,
            help="It's for xpu training. For example: "
            "--xpus=\"0,1,2,3\" will launch four training processes each bound to one xpu.")
        base_group.add_argument("--selected_xpus", dest="xpus")

    base_group.add_argument(
        "training_script",
        type=str,
@@ -288,14 +298,16 @@ def which_distributed_mode(args):
        )

    if fluid.core.is_compiled_with_cuda():
-        cuda_device_num = fluid.core.get_cuda_device_count()
+        device_count = fluid.core.get_cuda_device_count()
+    elif fluid.core.is_compiled_with_xpu():
+        device_count = fluid.core.get_xpu_device_count()
    else:
-        cuda_device_num = 0
+        device_count = 0

    if len(has_ps_args) > 0:
        logger.info(
-            "Run parameter-sever mode. pserver arguments:{}, cuda count:{}".
-            format(has_ps_args, cuda_device_num))
+            "Run parameter-sever mode. pserver arguments:{}, cuda or xpu count:{}".
+            format(has_ps_args, device_count))
        has_ps_heter_args = list(set(has_ps_args) & set(ps_heter_args))
        if len(has_ps_heter_args) > 0:
            return DistributeMode.PS_HETER
@@ -303,17 +315,18 @@ def which_distributed_mode(args):
        return DistributeMode.PS
    elif len(has_collective_args) > 0:
        logger.info("Run collective gpu mode. gpu arguments:{}, cuda count:{}".
-                    format(has_collective_args, cuda_device_num))
+                    format(has_collective_args, device_count))
        return DistributeMode.COLLECTIVE
    else:
-        if not fluid.core.is_compiled_with_cuda():
+        if not fluid.core.is_compiled_with_cuda(
+        ) and not fluid.core.is_compiled_with_xpu():
            logger.warning(
-                "Not found distinct arguments and not compiled with cuda. Default use ps mode"
+                "Not found distinct arguments and not compiled with cuda or xpu. Default use ps mode"
            )
            return DistributeMode.PS
        else:
            logger.warning(
-                "Not found distinct arguments and compiled with cuda. Default use collective mode"
+                "Not found distinct arguments and compiled with cuda or xpu. Default use collective mode"
            )
            return DistributeMode.COLLECTIVE
python/paddle/distributed/fleet/launch_utils.py
@@ -47,10 +47,11 @@ class DeviceMode():
    """
    Training devices type
    """
+    UNKNOWN = -1
    CPU = 0
    GPU = 1
    KUNLUN = 2
-    UNKNOWN = 3
+    XPU = 2


class Cluster(object):
@@ -275,6 +276,11 @@ def get_cluster(node_ips, node_ip, trainer_endpoints, device_mode,
                    trainer.gpus.extend(devices_per_proc[i])
                else:
                    trainer.gpus.append(devices_per_proc[i])
            elif device_mode == DeviceMode.XPU:
                if isinstance(devices_per_proc[i], (list, tuple)):
                    trainer.gpus.extend(devices_per_proc[i])
                else:
                    trainer.gpus.extend(devices_per_proc[i])
            trainer.endpoint = "%s" % (cur_node_endpoints[i])
            trainer.rank = trainer_rank
            trainer_rank += 1
@@ -454,9 +460,12 @@ def start_local_trainers(cluster,
            "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
        }

-        if len(t.gpus) > 0:
+        if fluid.core.is_compiled_with_cuda() and len(t.gpus) > 0:
            proc_env["FLAGS_selected_gpus"] = "%s" % ",".join(
                [str(g) for g in t.gpus])
+        elif fluid.core.is_compiled_with_xpu() and len(t.gpus) > 0:
+            proc_env["FLAGS_selected_xpus"] = "%s" % ",".join(
+                [str(g) for g in t.gpus])

        current_env.update(proc_env)
@@ -584,15 +593,47 @@ def get_gpus(gpus):
    return res_gpus

-def get_device_mode():
-    #TODO(gongwb):Add XPU supported
-    if not fluid.core.is_compiled_with_cuda(
-    ) or fluid.core.get_cuda_device_count() <= 0:
-        print("launch train in CPU mode")
-        return DeviceMode.CPU

+def get_xpus(xpus):
+    if xpus is None:
+        xpus_num = fluid.core.get_xpu_device_count()
+        res_xpus = [str(x) for x in range(0, xpus_num)]
+    else:
+        xpu_visible_devices = os.getenv("XPU_VISIBLE_DEVICES")
+        if xpu_visible_devices is None or xpu_visible_devices == "":
+            res_xpus = [x.strip() for x in xpus.split(',')]
+        else:
+            # change xpus into relative values
+            # e.g. XPU_VISIBLE_DEVICES=4,5,6,7; args.xpus=4,5,6,7;
+            # therefore xpus=0,1,2,3
+            xpu_visible_devices_list = xpu_visible_devices.split(',')
+            for x in xpus.split(','):
+                assert x in xpu_visible_devices_list, "Can't find " \
+                    "your xpus %s in XPU_VISIBLE_DEVICES[%s]." \
+                    % (x, xpu_visible_devices)
+            res_xpus = [
+                xpu_visible_devices_list.index(x.strip())
+                for x in xpus.split(',')
+            ]
+            logger.info("Change selected_xpus into reletive values. --ips:{} "
+                        "will change into relative_ips:{} according to your "
+                        "XPU_VISIBLE_DEVICES:{}".format(
+                            xpus, res_xpus, xpu_visible_devices_list))
+
+    return res_xpus
+
+
+def get_device_mode():
+    if fluid.core.is_compiled_with_cuda() and fluid.core.get_cuda_device_count(
+    ) > 0:
+        print("launch train in GPU mode")
+        return DeviceMode.GPU
+    elif fluid.core.is_compiled_with_xpu() and fluid.core.get_xpu_device_count(
+    ) > 0:
+        print("launch train in XPU mode")
+        return DeviceMode.XPU

    print("launch train in CPU mode")
    return DeviceMode.CPU


def get_device_proc_info(args):
@@ -613,13 +654,25 @@ def get_device_proc_info(args):
            ]
        else:
            devices_per_proc = gpus
    elif device_mode == DeviceMode.XPU:
        xpus = get_xpus(args.xpus)
        if args.nproc_per_node is not None:
            assert (len(xpus) % int(args.nproc_per_node)) == 0, \
                "xpus' number:{} mod args.nproc_per_node:{} must == 0".format(
                    len(xpus), arg.nproc_per_node)

            n = int(len(xpus) / int(args.nproc_per_node))
            devices_per_proc = [
                xpus[i:i + n] for i in six.moves.range(0, len(xpus), n)
            ]
        else:
            devices_per_proc = xpus
    elif device_mode == DeviceMode.CPU:
        if args.nproc_per_node is None:
            devices_per_proc = [0]
        else:
            devices_per_proc = [x for x in range(0, args.nproc_per_node)]
    else:
-        assert False, "Can't support device_mode:{}, support only cpu and gpu now.".format(
+        assert False, "Can't support device_mode:{}, support only cpu|gpu|xpu now.".format(
            device_mode)

    return (device_mode, devices_per_proc)
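The relative-id mapping performed by get_xpus() above can be illustrated in isolation. A minimal sketch (the helper name map_to_relative is made up for illustration and is not part of Paddle):

    def map_to_relative(xpus, xpu_visible_devices):
        # mirror of the remapping logic in get_xpus(): indices are taken
        # relative to the XPU_VISIBLE_DEVICES list
        visible = xpu_visible_devices.split(',')
        for x in xpus.split(','):
            assert x in visible, \
                "Can't find your xpus %s in XPU_VISIBLE_DEVICES[%s]." % (x, xpu_visible_devices)
        return [visible.index(x.strip()) for x in xpus.split(',')]

    # e.g. XPU_VISIBLE_DEVICES=4,5,6,7 with --xpus=4,5 selects relative devices [0, 1]
    print(map_to_relative("4,5", "4,5,6,7"))  # -> [0, 1]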
python/paddle/distributed/parallel.py
@@ -120,12 +120,12 @@ def init_parallel_env():
        )
        return

-    # 1. gpu check
-    if not core.is_compiled_with_cuda():
+    # 1. gpu xpu check, must be gpu or xpu
+    if not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu():
        raise NotImplementedError(
            "Cannot initialize parallel environment in CPU-only version, now only "
-            "supports initializing the GPU parallel environment. Please recompile "
-            "or reinstall paddle with GPU support.")
+            "supports initializing the GPU and XPU parallel environment. Please recompile "
+            "or reinstall paddle with GPU or XPU support.")

    # 2. check env
    def _check_var_exists(var_name):
@@ -135,7 +135,11 @@ def init_parallel_env():
                "environment variable %s is needed, but not set." % var_name)

    if core.is_compiled_with_cuda():
        _check_var_exists("FLAGS_selected_gpus")
    elif core.is_compiled_with_xpu():
        _check_var_exists('FLAGS_selected_xpus')

    _check_var_exists("PADDLE_TRAINER_ID")
    _check_var_exists("PADDLE_CURRENT_ENDPOINT")
    _check_var_exists("PADDLE_TRAINERS_NUM")
@@ -176,11 +180,19 @@ def init_parallel_env():
    # directly, if they want to switch default place,
    # they need to call a function to change default place,
    # here just set correctly place to users
    if core.is_compiled_with_cuda():
        place = core.CUDAPlace(parallel_env.device_id)
    elif core.is_compiled_with_xpu():
        place = core.XPUPlace(parallel_env.device_id)
    _set_expected_place(place)

-    # init nccl context
-    parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place))
+    # init nccl or bkcl context
+    if core.is_compiled_with_cuda():
+        parallel_helper._set_parallel_ctx(
+            core.NCCLParallelContext(strategy, place))
+    elif core.is_compiled_with_xpu():
+        parallel_helper._set_parallel_ctx(
+            core.BKCLParallelContext(strategy, place))
+
    parallel_helper._init_parallel_ctx()

    # 5: init gloo context (step 2: gloo init)
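Putting the Python-side pieces together, a dygraph data-parallel script on XPU would look roughly like the sketch below. This is a hedged usage example assuming a Paddle build with XPU/BKCL support and a process started through fleet.launch (which sets FLAGS_selected_xpus and the PADDLE_* variables checked above); it is not code from this commit:

    import paddle
    import paddle.distributed as dist

    def main():
        # With the changes above, this selects an XPUPlace and a
        # BKCLParallelContext when paddle is compiled with XPU support.
        dist.init_parallel_env()

        layer = paddle.nn.Linear(10, 10)
        dp_layer = paddle.DataParallel(layer)  # gradients are all-reduced via BKCL

        opt = paddle.optimizer.SGD(learning_rate=0.01,
                                   parameters=dp_layer.parameters())
        x = paddle.randn([4, 10], dtype='float32')
        loss = dp_layer(x).mean()
        loss.backward()
        opt.step()
        opt.clear_grad()

    if __name__ == "__main__":
        main()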
python/paddle/fluid/dygraph/parallel.py
浏览文件 @
b1026f64
...
...
@@ -55,9 +55,12 @@ def prepare_context(strategy=None):
    if isinstance(place, core.CUDAPlace):
        parallel_helper._set_parallel_ctx(
            core.NCCLParallelContext(strategy, place))
+    elif isinstance(place, core.XPUPlace):
+        parallel_helper._set_parallel_ctx(
+            core.BKCLParallelContext(strategy, place))
    else:
        # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
-        assert ("Only support CUDAPlace for now.")
+        assert ("Only support CUDAPlace or XPUPlace for now.")
    parallel_helper._init_parallel_ctx()
    return strategy
...
...
@@ -108,9 +111,13 @@ class ParallelEnv(object):
        self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

-        # imperative only support one gpu
+        # imperative only support one gpu or xpu
        if core.is_compiled_with_cuda():
            selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
            self._device_id = int(selected_gpus[0])
        elif core.is_compiled_with_xpu():
            selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
            self._device_id = int(selected_xpus[0])

        self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(",")
...
...
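The effect of the ParallelEnv change is that, on an XPU build, the local device id is taken from FLAGS_selected_xpus rather than FLAGS_selected_gpus. A small sketch of that lookup follows; the environment values are examples that paddle.distributed.launch would normally export per process, and the property names assume the public ParallelEnv API of this release.

    import os
    import paddle.distributed as dist

    # Example values; on an XPU build the launcher would export FLAGS_selected_xpus.
    os.environ["PADDLE_TRAINER_ID"] = "1"
    os.environ["PADDLE_TRAINERS_NUM"] = "2"
    os.environ["FLAGS_selected_gpus"] = "3"   # or FLAGS_selected_xpus on Kunlun

    env = dist.ParallelEnv()
    print(env.rank, env.world_size, env.device_id)  # expected: 1 2 3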
python/paddle/fluid/tests/unittests/detected_xpu.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import sys
import paddle.fluid as fluid

print("compile with xpu:", fluid.core.is_compiled_with_xpu())
print("get_xpu_device_count:", fluid.core.get_xpu_device_count())

if fluid.core.is_compiled_with_xpu() and fluid.core.get_xpu_device_count() > 0:
    sys.exit(0)
else:
    sys.exit(1)
python/paddle/fluid/tests/unittests/nproc_process.py
...
...
@@ -15,18 +15,22 @@
import os
import sys
import time
import paddle.fluid as fluid


def train(prefix):
-    selected_gpus = os.getenv("FLAGS_selected_gpus")
+    if fluid.core.is_compiled_with_xpu():
+        selected_devices = os.getenv("FLAGS_selected_xpus")
+    else:
+        selected_devices = os.getenv("FLAGS_selected_gpus")
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
    worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
    current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
    worker_endpoints = worker_endpoints_env
    trainers_num = len(worker_endpoints.split(','))

-    name = "selected_gpus:{} worker_endpoints:{} trainers_num:{} current_endpoint:{} trainer_id:{}" \
-        .format(selected_gpus, worker_endpoints, trainers_num, current_endpoint, trainer_id)
+    name = "selected_devices:{} worker_endpoints:{} trainers_num:{} current_endpoint:{} trainer_id:{}" \
+        .format(selected_devices, worker_endpoints, trainers_num, current_endpoint, trainer_id)

    print(name)
    with open("{}.check_{}.log".format(prefix, trainer_id), "w") as f:
...
...
python/paddle/fluid/tests/unittests/test_dist_base.py
...
...
@@ -464,8 +464,14 @@ class TestParallelDyGraphRunnerBase(object):
    def run_trainer(self, args):
        seed = 90
        if fluid.core.is_compiled_with_cuda():
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
            place = fluid.CUDAPlace(device_id)
        elif fluid.core.is_compiled_with_xpu():
            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
            place = fluid.XPUPlace(device_id)
        else:
            assert ("Only support CUDAPlace or XPUPlace for now.")

        with fluid.dygraph.guard(place):
            fluid.default_startup_program().random_seed = seed
...
...
@@ -476,7 +482,8 @@ class TestParallelDyGraphRunnerBase(object):
            model, train_reader, opt = self.get_model()
            nranks = len(args.endpoints.split(",")) if args.endpoints else 1

-            if args.update_method == "nccl2":
+            #if args.update_method == "nccl2":
+            if args.update_method == "nccl2" or args.update_method == "bkcl":
                strategy = dygraph.parallel.ParallelStrategy()
                strategy.nranks = nranks
                strategy.local_rank = args.trainer_id
...
...
@@ -592,7 +599,7 @@ def runtime_main(test_class):
        '--update_method',
        type=str,
        default="local",
-        choices=["pserver", "nccl2", "local", "nccl2_reduce_layer"])
+        choices=["pserver", "nccl2", "bkcl", "local", "nccl2_reduce_layer"])
    parser.add_argument('--trainer_id', type=int, required=False, default=0)
    parser.add_argument('--trainers', type=int, required=False, default=1)
    parser.add_argument('--nccl_comm_num', type=int, required=False, default=1)
...
...
@@ -608,6 +615,7 @@ def runtime_main(test_class):
        '--current_endpoint', type=str, required=False, default="")
    parser.add_argument('--sync_mode', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
+    parser.add_argument('--use_xpu', action='store_true')
    parser.add_argument('--use_dgc', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')
    parser.add_argument('--dc_asgd', action='store_true')
...
...
@@ -656,9 +664,15 @@ class TestDistBase(unittest.TestCase):
    def _after_setup_config(self):
        if self._enforce_place == "CPU":
            self.__use_cuda = False
            self.__use_xpu = False
            self._use_dgc = False
        elif self._enforce_place == "GPU":
            self.__use_cuda = True
            self.__use_xpu = False
        elif self._enforce_place == "XPU":
            self.__use_cuda = False
            self.__use_xpu = True
            self._use_dgc = False
        else:
            if fluid.core.is_compiled_with_cuda():
                self.__use_cuda = True
...
...
@@ -681,6 +695,7 @@ class TestDistBase(unittest.TestCase):
        self._dc_asgd = False  # must use with async mode
        self._use_reader_alloc = True
        self._nccl2_mode = False
+        self._bkcl_mode = False
        self._pipeline_mode = False
        self._mp_mode = False
        # FIXME(typhoonzero): I added this stupid argument to enable
...
...
@@ -783,7 +798,7 @@ class TestDistBase(unittest.TestCase):
                   batch_size=DEFAULT_BATCH_SIZE,
                   batch_merge_repeat=1,
                   log_name="",
-                   gpus="0"):
+                   devices="0"):

        cmd = self._python_interp
...
...
@@ -804,7 +819,14 @@ class TestDistBase(unittest.TestCase):
        if self.__use_cuda:
            cmd += " --use_cuda"
            env_local = {
-                "CUDA_VISIBLE_DEVICES": gpus,
+                "CUDA_VISIBLE_DEVICES": devices,
                "PADDLE_TRAINERS_NUM": "1",
                "PADDLE_TRAINER_ID": "0"
            }
+        elif self.__use_xpu:
+            cmd += " --use_xpu"
+            env_local = {
+                "FLAGS_selected_xpus": devices,
+                "PADDLE_TRAINERS_NUM": "1",
+                "PADDLE_TRAINER_ID": "0"
+            }
...
...
@@ -812,7 +834,7 @@ class TestDistBase(unittest.TestCase):
            env_local = {'CPU_NUM': '1'}

        # not use dgc in single card
-        if len(gpus) > 1 and self._use_dgc:
+        if len(devices) > 1 and self._use_dgc:
            cmd += " --use_dgc"

        env_local.update(envs)
...
...
@@ -962,6 +984,19 @@ class TestDistBase(unittest.TestCase):
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": ep,
            })
+        # TODO(liuyuhui): XPU_VISIBLE_DEVICES is not working right now,
+        # will update it after Baidu Kunlun partners' support.
+        elif self.__use_xpu:
+            tr_cmd += " --use_xpu"
+            env.update({
+                "FLAGS_selected_xpus": "{}".format(trainer_id),
+                #"XPU_VISIBLE_DEVICES": "{}".format(trainer_id + 1),
+                "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
+                "PADDLE_TRAINER_ID": "{}".format(trainer_id),
+                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
+                "PADDLE_CURRENT_ENDPOINT": ep,
+                "GLOG_v": "2",
+            })
        else:
            env.update({'CPU_NUM': '1'})
...
...
@@ -999,8 +1034,8 @@ class TestDistBase(unittest.TestCase):
        return tr_cmd, env

-    def _run_cluster_nccl2(self, model, envs, nccl2_reduce_layer,
-                           check_error_log, log_name):
+    def _run_cluster_nccl2(self, model, envs, update_method, check_error_log,
+                           log_name):
        if self._use_hallreduce:
            self._ps_endpoints = ""
...
...
@@ -1018,10 +1053,6 @@ class TestDistBase(unittest.TestCase):
        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")
-        if nccl2_reduce_layer:
-            update_method = "nccl2_reduce_layer"
-        else:
-            update_method = "nccl2"

        trainer_num = len(worker_endpoints)
...
...
@@ -1150,16 +1181,24 @@ class TestDistBase(unittest.TestCase):
                tr0_losses, tr1_losses = self._run_cluster_nccl2(
                    model_file,
                    required_envs,
-                    True,
-                    check_error_log,
+                    update_method="nccl2_reduce_layer",
+                    check_error_log=check_error_log,
                    log_name=log_name)
            else:
                tr0_losses, tr1_losses = self._run_cluster_nccl2(
                    model_file,
                    required_envs,
-                    False,
-                    check_error_log,
+                    update_method='nccl2',
+                    check_error_log=check_error_log,
                    log_name=log_name)
+        elif self._bkcl_mode:
+            tr0_losses, tr1_losses = self._run_cluster_nccl2(
+                model_file,
+                required_envs,
+                update_method='bkcl',
+                check_error_log=check_error_log,
+                log_name=log_name)
        elif self._pipeline_mode:
            tr0_losses, tr1_losses = self._run_pipeline(
                model_file, required_envs, check_error_log, log_name=log_name)
...
...
@@ -1196,7 +1235,7 @@ class TestDistBase(unittest.TestCase):
            required_envs,
            check_error_log,
            log_name=log_name + "_dgc_2cards",
-            gpus="0,1")
+            devices="0,1")

        self._use_dgc = False
        base_losses = self._run_local(
...
...
@@ -1204,7 +1243,7 @@ class TestDistBase(unittest.TestCase):
            required_envs,
            check_error_log,
            log_name=log_name + "_base_2cards",
-            gpus="0,1")
+            devices="0,1")

        self._use_dgc = True
...
...
python/paddle/fluid/tests/unittests/test_dist_mnist_fleet_save.py
...
...
@@ -89,8 +89,8 @@ class TestDistMnistFleetSave(TestDistBase):
        tr0_losses, tr1_losses = self._run_cluster_nccl2(
            model_file,
            required_envs,
-            False,
-            check_error_log,
+            update_method='nccl2',
+            check_error_log=check_error_log,
            log_name=log_name)

        dirname = '/tmp'
...
...
python/paddle/fluid/tests/unittests/test_dist_sharding_save.py
...
...
@@ -32,7 +32,6 @@ class TestDistMnistFleetSave(TestDistBase):
        self._sharding_save = True
        self._enforce_place = "GPU"

    def _rm_temp_files(self, dirname):
        shutil.rmtree(dirname)
...
...
@@ -40,9 +39,13 @@ class TestDistMnistFleetSave(TestDistBase):
        sharding_save_files = sorted(os.listdir(dirname))

        check_files = [
            'fc_0.b_0', 'fc_0.b_0_velocity_0', 'fc_0.w_0',
            'fc_0.w_0_velocity_0', 'fc_1.b_0', 'fc_1.b_0_velocity_0',
            'fc_1.w_0', 'fc_1.w_0_velocity_0', 'fc_2.b_0',
            'fc_2.b_0_velocity_0', 'fc_2.w_0', 'fc_2.w_0_velocity_0',
            'learning_rate_0'
        ]

        if sharding_save_files != check_files:
            self._rm_temp_files(dirname)
...
...
@@ -62,8 +65,8 @@ class TestDistMnistFleetSave(TestDistBase):
        tr0_losses, tr1_losses = self._run_cluster_nccl2(
            model_file,
            required_envs,
-            False,
-            check_error_log,
+            update_method='nccl2',
+            check_error_log=check_error_log,
            log_name=log_name)

        dirname = './ut_sharding_save_model'
...
...
python/paddle/fluid/tests/unittests/test_fleet_launch_nproc.sh
...
...
@@ -27,7 +27,7 @@ function test_nproc_0(){
    # nproc_per_node=1, each with 2 gpus
    python -m paddle.distributed.launch ${distributed_args} nproc_process.py fleet_nproc_0

-    str0="selected_gpus:${gpus} worker_endpoints:127.0.0.1:35789 trainers_num:1 current_endpoint:127.0.0.1:35789 trainer_id:0"
+    str0="selected_devices:${gpus} worker_endpoints:127.0.0.1:35789 trainers_num:1 current_endpoint:127.0.0.1:35789 trainer_id:0"
    if grep -q "$str0" "$file_0"; then
        echo "find trainer 0"
    else
...
...
@@ -50,6 +50,12 @@ if ! python detected_gpu.py ; then
    test_nproc_0 ""
fi

+# unittest3: xpu
+if python detected_xpu.py ; then
+    echo "begin ut 3:"
+    export XPU_VISIBLE_DEVICES=0,1
+    test_nproc_0 "0,1"
+fi

function test_nproc_1_gpu(){
    file_0="fleet_nproc_1.check_0.log"
...
...
@@ -59,7 +65,7 @@ function test_nproc_1_gpu(){
    distributed_args="--log_dir=testlog --nproc_per_node=2"
    python -m paddle.distributed.launch ${distributed_args} nproc_process.py fleet_nproc_1

-    str0="selected_gpus:0 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0"
+    str0="selected_devices:0 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0"
    if grep -q "$str0" "$file_0"; then
        echo "find trainer 0"
    else
...
...
@@ -67,7 +73,7 @@ function test_nproc_1_gpu(){
        exit -1
    fi

-    str1="selected_gpus:1 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1"
+    str1="selected_devices:1 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1"
    if grep -q "$str1" "$file_1"; then
        echo "find trainer 1"
    else
...
...
@@ -76,9 +82,9 @@ function test_nproc_1_gpu(){
    fi
}

-# unittest3: nproc_per_node=2, each with 1 gpus
+# unittest4: nproc_per_node=2, each with 1 gpus
if python detected_gpu.py ; then
-    echo "begin ut 3:"
+    echo "begin ut 4:"
    export CUDA_VISIBLE_DEVICES=0,1
    test_nproc_1_gpu
fi
...
...
@@ -91,7 +97,7 @@ function test_nproc_1_cpu(){
    distributed_args="--log_dir=testlog --nproc_per_node=2"
    python -m paddle.distributed.launch ${distributed_args} nproc_process.py fleet_nproc_1

-    str0="selected_gpus: worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0"
+    str0="selected_devices: worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0"
    if grep -q "$str0" "$file_0"; then
        echo "find trainer 0"
    else
...
...
@@ -99,7 +105,7 @@ function test_nproc_1_cpu(){
        exit -1
    fi

-    str1="selected_gpus: worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1"
+    str1="selected_devices: worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1"
    if grep -q "$str1" "$file_1"; then
        echo "find trainer 1"
    else
...
...
@@ -108,9 +114,42 @@ function test_nproc_1_cpu(){
    fi
}

-# unittest4: nproc_per_node=2, cpu
+# unittest5: nproc_per_node=2, cpu
if ! python detected_gpu.py ; then
-    echo "begin ut 4:"
+    echo "begin ut 5:"
    export CUDA_VISIBLE_DEVICES=""
    test_nproc_1_cpu
fi

+function test_nproc_1_xpu(){
+    file_0="fleet_nproc_1.check_0.log"
+    file_1="fleet_nproc_1.check_1.log"
+    rm -f ${file_0} ${file_1}
+
+    distributed_args="--log_dir=testlog --nproc_per_node=2"
+    python -m paddle.distributed.launch ${distributed_args} nproc_process.py fleet_nproc_1
+
+    str0="selected_devices:0 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0"
+    if grep -q "$str0" "$file_0"; then
+        echo "find trainer 0"
+    else
+        echo "not find trainer 0"
+        exit -1
+    fi
+
+    str1="selected_devices:1 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1"
+    if grep -q "$str1" "$file_1"; then
+        echo "find trainer 1"
+    else
+        echo "not find trainer 1"
+        exit -1
+    fi
+}
+
+# unittest6: nproc_per_node=2, each with 1 xpu
+if python detected_xpu.py ; then
+    echo "begin ut 6:"
+    export XPU_VISIBLE_DEVICES=0,1
+    test_nproc_1_xpu
+fi
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
...
...
@@ -41,6 +41,25 @@ class TestParallelDygraphMnist(TestDistBase):
            log_name=flag_name)


+# TODO(liuyuhui): Multi-card Baidu Kunlun XPU training has accuracy problems.
+# It is difficult to find out immediately where the problem is,
+# and we will work with the framework developers to fix it.
+class TestParallelDygraphMnistXPU(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+        self._bkcl_mode = True
+        self._dygraph = True
+        self._enforce_place = "XPU"
+
+    def test_mnist_xpu(self):
+        if fluid.core.is_compiled_with_xpu():
+            self.check_with_place(
+                "parallel_dygraph_mnist.py",
+                delta=1e-1,
+                check_error_log=True,
+                log_name=flag_name)
+
+
class TestParallelDygraphMnistSpawn(TestDistSpawnRunner):
    def test_mnist_with_spawn(self):
        if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
...
...