PaddlePaddle / Paddle · Commit afa26a59

Unverified commit afa26a59, authored Mar 13, 2023 by TaoTao Li; committed via GitHub on Mar 13, 2023.

Add phi operator all_gather (#51420)
* add all_gather and fix conflicts
* fix code format
* fix ut
* fix broadcast ut
Parent: d08a1a0d

Showing 15 changed files with 397 additions and 13 deletions (+397 −13):
paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc (+2 −2)
paddle/phi/api/yaml/ops.yaml (+10 −0)
paddle/phi/core/distributed/gloo_comm_context.cc (+11 −0)
paddle/phi/core/distributed/gloo_comm_context.h (+3 −0)
paddle/phi/core/distributed/nccl_comm_context.cc (+12 −0)
paddle/phi/core/distributed/nccl_comm_context.h (+4 −0)
paddle/phi/infermeta/unary.cc (+8 −0)
paddle/phi/infermeta/unary.h (+2 −0)
paddle/phi/kernels/all_gather_kernel.h (+27 −0, new file)
paddle/phi/kernels/cpu/all_gather_kernel.cc (+65 −0, new file)
paddle/phi/kernels/gpu/all_gather_kernel.cu (+89 −0, new file)
python/paddle/fluid/tests/unittests/collective/collective_allgather_api.py (+78 −4)
python/paddle/fluid/tests/unittests/collective/collective_broadcast_api.py (+10 −3)
python/paddle/fluid/tests/unittests/collective/test_collective_allgather_api.py (+47 −1)
python/paddle/fluid/tests/unittests/test_collective_api_base.py (+29 −3)
paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc

@@ -1150,8 +1150,8 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base,
       dev_ctx->SetCommContext(comm_context);
     }
   } else {
-    LOG(WARNING) << "op: " << operator_base->Type() << ", ring_id: " << ring_id
-                 << ", get comm_context failed!";
+    VLOG(3) << "op: " << operator_base->Type() << ", ring_id: " << ring_id
+            << ", get comm_context failed!";
   }
 }
paddle/phi/api/yaml/ops.yaml

@@ -26,6 +26,16 @@
     data_type : x
   backward : addmm_grad

+- op : all_gather
+  args : (Tensor X, int ring_id = 0, int nranks=0)
+  output : Tensor(Out)
+  infer_meta :
+    func : AllGatherInferMeta
+    param : [X, nranks]
+  kernel :
+    func : all_gather
+    param : [X, nranks]
+
 - op : allclose
   args : (Tensor x, Tensor y, Scalar rtol="1e-5", Scalar atol="1e-8", bool equal_nan=false)
   output : Tensor(out)
paddle/phi/core/distributed/gloo_comm_context.cc

@@ -14,6 +14,7 @@
 #include "paddle/phi/core/distributed/gloo_comm_context.h"

+#include <gloo/allgather.h>
 #include <gloo/broadcast.h>
 #include <gloo/types.h>

@@ -56,5 +57,15 @@ void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
   gloo::broadcast(opts);
 }

+void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
+                                const phi::DenseTensor& in_tensor) {
+  // gloo only uses CPU now
+  gloo::AllgatherOptions opts(gloo_context_);
+  const auto& dtype = in_tensor.dtype();
+  GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
+  GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
+  gloo::allgather(opts);
+}
+
 }  // namespace distributed
 }  // namespace phi
paddle/phi/core/distributed/gloo_comm_context.h

@@ -37,6 +37,9 @@ class GlooCommContext final : public CommContext {
                  const phi::DenseTensor& in_tensor,
                  int root);

+  void AllGather(phi::DenseTensor* out_tensor,
+                 const phi::DenseTensor& in_tensor);
+
  private:
  DISABLE_COPY_AND_ASSIGN(GlooCommContext);
paddle/phi/core/distributed/nccl_comm_context.cc

@@ -53,5 +53,17 @@ void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
       stream));
 }

+void NCCLCommContext::AllGather(phi::DenseTensor* out_tensor,
+                                const phi::DenseTensor& in_tensor,
+                                gpuStream_t stream) {
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      phi::dynload::ncclAllGather(in_tensor.data(),
+                                  out_tensor->data(),
+                                  in_tensor.numel(),
+                                  ToNCCLDataType(in_tensor.type()),
+                                  nccl_comm_,
+                                  stream));
+}
+
 }  // namespace distributed
 }  // namespace phi
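
For reference, ncclAllGather takes in_tensor.numel() elements from every rank and writes each rank's contribution back-to-back in rank order, so out_tensor must provide nranks times that many elements. A minimal NumPy sketch of that semantics (illustrative only; names and values are ours, not Paddle code):

    import numpy as np

    def all_gather_reference(rank_inputs):
        # Every rank contributes its input; every rank receives the inputs
        # of all ranks concatenated in rank order along axis 0.
        out = np.concatenate(rank_inputs, axis=0)
        return [out.copy() for _ in rank_inputs]  # identical copy per rank

    inputs = [np.full((2, 3), r, dtype=np.float32) for r in range(4)]  # 4 ranks
    outputs = all_gather_reference(inputs)
    assert outputs[0].shape == (8, 3)        # nranks * dim0 rows
    assert (outputs[1] == outputs[3]).all()  # same result on every rank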
paddle/phi/core/distributed/nccl_comm_context.h

@@ -36,6 +36,10 @@ class NCCLCommContext final : public CommContext {
                  int root,
                  gpuStream_t stream);

+  void AllGather(phi::DenseTensor* out_tensor,
+                 const phi::DenseTensor& in_tensor,
+                 gpuStream_t stream);
+
  private:
  DISABLE_COPY_AND_ASSIGN(NCCLCommContext);
paddle/phi/infermeta/unary.cc

@@ -120,6 +120,14 @@ void AffineGridInferMeta(const MetaTensor& input,
   output->share_lod(input);
 }

+void AllGatherInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
+  auto dim = x.dims();
+  dim[0] = dim[0] * nranks;
+  if (dim[0] < 0) dim[0] = -1;
+  out->set_dtype(x.dtype());
+  out->set_dims(dim);
+}
+
 void ArgMinMaxInferMeta(const MetaTensor& x,
                         const Scalar& axis,
                         bool keepdims,
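
The shape rule above only scales the leading dimension by nranks and keeps a dynamic (negative) leading dimension dynamic. A small Python transcription of AllGatherInferMeta's logic (ours, for illustration):

    def all_gather_infer_shape(x_shape, nranks):
        # Mirror of AllGatherInferMeta: out[0] = x[0] * nranks, with an
        # unknown (negative) leading dimension normalized back to -1.
        out = list(x_shape)
        out[0] = out[0] * nranks
        if out[0] < 0:
            out[0] = -1
        return out

    assert all_gather_infer_shape([10, 1000], 2) == [20, 1000]
    assert all_gather_infer_shape([-1, 1000], 2) == [-1, 1000]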
paddle/phi/infermeta/unary.h

@@ -39,6 +39,8 @@ void AffineGridInferMeta(const MetaTensor& input,
                          bool align_corners,
                          MetaTensor* output);

+void AllGatherInferMeta(const MetaTensor& x, int nranks, MetaTensor* out);
+
 void ArgMinMaxInferMeta(const MetaTensor& x,
                         const Scalar& axis,
                         bool keepdims,
paddle/phi/kernels/all_gather_kernel.h (new file, mode 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void AllGatherKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     int nranks,
                     DenseTensor* out);

}  // namespace phi
paddle/phi/kernels/cpu/all_gather_kernel.cc (new file, mode 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/all_gather_kernel.h"

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"

#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif

namespace phi {

template <typename T, typename Context>
void AllGatherKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     int nranks,
                     DenseTensor* out) {
#if defined(PADDLE_WITH_GLOO)
  dev_ctx.template Alloc<T>(out);
  auto out_dims = x.dims();
  out_dims[0] *= nranks;
  out->Resize(out_dims);

  auto comm_ctx =
      static_cast<distributed::GlooCommContext*>(dev_ctx.GetCommContext());
  PADDLE_ENFORCE_EQ(
      nranks,
      comm_ctx->GetSize(),
      errors::InvalidArgument(
          "nranks: %s should equal to %s", nranks, comm_ctx->GetSize()));

  comm_ctx->AllGather(out, x);
#else
  PADDLE_THROW(errors::Unavailable(
      "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
}

}  // namespace phi

PD_REGISTER_KERNEL(all_gather,
                   CPU,
                   ALL_LAYOUT,
                   phi::AllGatherKernel,
                   float,
                   double,
                   int,
                   bool,
                   int8_t,
                   uint8_t,
                   int64_t,
                   phi::dtype::float16) {}
paddle/phi/kernels/gpu/all_gather_kernel.cu (new file, mode 100644)

// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/all_gather_kernel.h"

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif

namespace phi {

template <typename T, typename Context>
void AllGatherKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     int nranks,
                     DenseTensor* out) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto out_dims = x.dims();
  out_dims[0] *= nranks;
  out->Resize(out_dims);
  dev_ctx.template Alloc<T>(out);

  auto comm_ctx =
      static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
  PADDLE_ENFORCE_NE(
      comm_ctx,
      nullptr,
      errors::Unavailable("NCCLCommContext is nullptr, collective op should "
                          "has ring_id attr."));
  PADDLE_ENFORCE_EQ(
      nranks,
      comm_ctx->GetSize(),
      errors::InvalidArgument(
          "nranks: %s should equal to %s", nranks, comm_ctx->GetSize()));

  gpuStream_t stream = dev_ctx.stream();
  comm_ctx->AllGather(out, x, stream);
#else
  PADDLE_THROW(
      errors::PreconditionNotMet("PaddlePaddle should compile with GPU."));
#endif
}

}  // namespace phi

// TODO(yuwentao01) the embedded macro definition will get an error under
// windows, need to be solved in phi
#if NCCL_VERSION_CODE >= 21000
PD_REGISTER_KERNEL(all_gather,
                   GPU,
                   ALL_LAYOUT,
                   phi::AllGatherKernel,
                   float,
                   double,
                   int,
                   uint8_t,
                   int8_t,
                   int64_t,
                   bool,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(all_gather,
                   GPU,
                   ALL_LAYOUT,
                   phi::AllGatherKernel,
                   float,
                   double,
                   int,
                   uint8_t,
                   int8_t,
                   int64_t,
                   bool,
                   phi::dtype::float16) {}
#endif
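
With CPU and GPU kernels registered, the op remains reachable through the public collective API. A minimal dygraph usage sketch (assuming two processes started with python -m paddle.distributed.launch; the tensor contents are ours):

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    tensor_list = []
    data = paddle.full([10, 1000], float(dist.get_rank()))
    dist.all_gather(tensor_list, data)
    # tensor_list holds one [10, 1000] tensor per rank, in rank order; their
    # concatenation matches the kernel's [nranks * 10, 1000] fused output.
    print(len(tensor_list), tensor_list[0].shape)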
python/paddle/fluid/tests/unittests/collective/collective_allgather_api.py

@@ -19,11 +19,65 @@ import sys
 import test_collective_api_base as test_base

 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
+import paddle.fluid.data_feeder as data_feeder
+import paddle.framework as framework

 paddle.enable_static()


+def all_gather_new(tensor_list, tensor, group=None):
+    op_type = 'all_gather'
+    helper = framework.LayerHelper(op_type, **locals())
+    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
+    for elem in tensor_list:
+        data_feeder.check_variable_and_dtype(
+            elem,
+            'tensor_list',
+            [
+                'float16',
+                'float32',
+                'float64',
+                'int32',
+                'int64',
+                'bool',
+                'int8',
+                'uint8',
+            ],
+            op_type,
+        )
+    data_feeder.check_variable_and_dtype(
+        tensor,
+        'tensor',
+        [
+            'float16',
+            'float32',
+            'float64',
+            'int32',
+            'int64',
+            'bool',
+            'int8',
+            'uint8',
+        ],
+        op_type,
+    )
+
+    ring_id = 0 if group is None else group.id
+    nranks = dist.get_world_size()
+    helper.append_op(
+        type=op_type,
+        inputs={'X': [tensor]},
+        outputs={'Out': [out]},
+        attrs={
+            'ring_id': ring_id,
+            'nranks': nranks,
+        },
+    )
+    tensor_list.clear()
+    tensor_list.extend(paddle.split(out, nranks, 0))
+
+
 class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
     def __init__(self):
         self.global_ring_id = 0

@@ -33,11 +87,22 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
         with fluid.program_guard(main_prog, startup_program):
             tensor_list = []
             tindata = paddle.static.data(
-                name="tindata", shape=[-1, 10, 1000], dtype=dtype
+                name="tindata", shape=[10, 1000], dtype=dtype
             )
             paddle.distributed.all_gather(tensor_list, tindata)
             return tensor_list

+    def get_model_new(
+        self, main_prog, startup_program, rank, dtype=None, reduce_type=None
+    ):
+        with fluid.program_guard(main_prog, startup_program):
+            tensor_list = []
+            tindata = paddle.static.data(
+                name="tindata", shape=[10, 1000], dtype=dtype
+            )
+            all_gather_new(tensor_list, tindata)
+            return tensor_list
+
     def run_trainer(self, args):
         train_prog = fluid.Program()
         startup_prog = fluid.Program()

@@ -45,7 +110,10 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
         rank = args["trainerid"]
         current_endpoint = args["currentendpoint"]
         nranks = 2
-        paddle.distributed.init_parallel_env()
+        if args["use_comm_context"]:
+            paddle.distributed.collective._init_parallel_env(args["backend"])
+        else:
+            paddle.distributed.init_parallel_env()
         if args['backend'] == 'nccl':
             device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
             place = fluid.CUDAPlace(

@@ -62,8 +130,14 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
         assert (
             args['static_mode'] == 1
         ), "collective_allgather_api only support static graph mode"
-        result = self.get_model(
-            train_prog, startup_prog, rank, dtype=args["dtype"]
+        result = (
+            self.get_model_new(
+                train_prog, startup_prog, rank, dtype=args["dtype"]
+            )
+            if args["use_comm_context"]
+            else self.get_model(
+                train_prog, startup_prog, rank, dtype=args["dtype"]
+            )
         )
         exe = fluid.Executor(place)
         exe.run(startup_prog)
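
Note how all_gather_new returns the op's single fused output and rebuilds the tensor_list contract with paddle.split. A quick NumPy sketch of why that round-trip is lossless (illustrative, not test code):

    import numpy as np

    nranks = 2
    per_rank = [np.full((10, 1000), r, dtype=np.float32) for r in range(nranks)]
    fused = np.concatenate(per_rank, axis=0)  # what the all_gather op emits
    pieces = np.split(fused, nranks, axis=0)  # what paddle.split recovers
    assert all((p == q).all() for p, q in zip(pieces, per_rank))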
python/paddle/fluid/tests/unittests/collective/collective_broadcast_api.py

@@ -58,16 +58,23 @@ class TestCollectiveBroadcastAPI(TestCollectiveAPIRunnerBase):
     def __init__(self):
         self.global_ring_id = 0

-    def get_model(self, main_prog, startup_program, rank):
+    def get_model(self, main_prog, startup_program, rank, dtype='float32'):
         with fluid.program_guard(main_prog, startup_program):
             tindata = paddle.static.data(
-                name="tindata", shape=[-1, 10, 1000], dtype='float32'
+                name="tindata", shape=[-1, 10, 1000], dtype=dtype
             )
             tindata.desc.set_need_check_feed(False)
             paddle.distributed.broadcast(tindata, src=1)
             return [tindata]

-    def get_model_new(self, main_prog, startup_program, rank, dtype=None):
+    def get_model_new(
+        self,
+        main_prog,
+        startup_program,
+        rank,
+        dtype='float32',
+        reduce_type=None,
+    ):
         with fluid.program_guard(main_prog, startup_program):
             tindata = paddle.static.data(
                 name="tindata", shape=[-1, 10, 1000], dtype=dtype
python/paddle/fluid/tests/unittests/collective/test_collective_allgather_api.py

@@ -38,7 +38,32 @@ class TestCollectiveAllgatherAPI(TestDistBase):
         ]
         for dtype in dtypes_to_test:
             self.check_with_place(
-                "collective_allgather_api.py", "allgather", "nccl", dtype=dtype
+                "collective_allgather_api.py",
+                "allgather",
+                "nccl",
+                dtype=dtype,
             )

+    def test_allgather_nccl_with_comm_context(self):
+        dtypes_to_test = [
+            "float16",
+            "float32",
+            "float64",
+            "int32",
+            "int64",
+            "int8",
+            "uint8",
+            "bool",
+        ]
+        if self._nccl_version >= 2100:
+            dtypes_to_test.append("bfloat16")
+        for dtype in dtypes_to_test:
+            self.check_with_place(
+                "collective_allgather_api.py",
+                "allgather",
+                "nccl",
+                dtype=dtype,
+                need_envs={"USE_COMM_CONTEXT": "1"},
+            )
+
     def test_allgather_gloo(self):

@@ -61,6 +86,27 @@ class TestCollectiveAllgatherAPI(TestDistBase):
                 dtype=dtype,
             )

+    def test_allgather_gloo_with_comm_context(self):
+        dtypes_to_test = [
+            "float16",
+            "float32",
+            "float64",
+            "int32",
+            "int64",
+            "int8",
+            "uint8",
+            "bool",
+        ]
+        for dtype in dtypes_to_test:
+            self.check_with_place(
+                "collective_allgather_api.py",
+                "allgather",
+                "gloo",
+                "3",
+                dtype=dtype,
+                need_envs={"USE_COMM_CONTEXT": "1"},
+            )
+
     def test_allgather_nccl_dygraph(self):
         dtypes_to_test = [
             "float16",
python/paddle/fluid/tests/unittests/test_collective_api_base.py

@@ -25,6 +25,7 @@ import numpy as np
 from paddle_bfloat import bfloat16

 import paddle
+import paddle.distributed as dist
 import paddle.fluid as fluid
 from paddle.distributed.utils.nccl_utils import get_nccl_version_str
 from paddle.fluid import core

@@ -47,7 +48,7 @@ def create_float_test_data(shape=None, dtype=None, seed=None):
 def create_int_test_data(shape=None, dtype=None, seed=None):
     if seed:
         np.random.seed(seed)
-    data = np.random.randint(0, high=100, size=shape).astype(dtype)
+    data = np.random.randint(0, high=12, size=shape).astype(dtype)
     return data

@@ -128,7 +129,13 @@ class TestCollectiveAPIRunnerBase:
         if args['static_mode']:
             result = (
-                self.get_model_new(train_prog, startup_prog, rank)
+                self.get_model_new(
+                    train_prog,
+                    startup_prog,
+                    rank,
+                    dtype=args["dtype"],
+                    reduce_type=args['reduce_type'],
+                )
                 if args["use_comm_context"]
                 else self.get_model(train_prog, startup_prog, rank)
             )

@@ -158,6 +165,7 @@ def runtime_main(test_class, col_type):
     args["path_id"] = int(os.getenv("PATH_ID"))
     args["static_mode"] = int(os.getenv("STATIC_MODE"))
     args["dtype"] = os.getenv("DTYPE")
+    args["reduce_type"] = os.getenv("REDUCE_TYPE")
     args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0")))
     model.run_trainer(args)

@@ -298,6 +306,7 @@ class TestDistBase(unittest.TestCase):
         need_envs={},
         eager_mode=True,
         dtype=None,
+        reduce_type=None,
     ):
         if backend == "nccl" or backend == "bkcl":
             with_gloo = '0'

@@ -305,6 +314,7 @@ class TestDistBase(unittest.TestCase):
             with_gloo = '1'
         required_envs = os.environ.copy()
         dtype = "float32" if dtype is None else dtype
+        reduce_type = dist.ReduceOp.SUM if reduce_type is None else reduce_type
         additional_envs = {
             "NCCL_P2P_DISABLE": "1",
             "STATIC_MODE": static_mode,

@@ -313,6 +323,7 @@ class TestDistBase(unittest.TestCase):
             "BACKEND": backend,
             "PATH_ID": path_id,
             "DTYPE": dtype,
+            "REDUCE_TYPE": str(reduce_type),
         }
         required_envs.update(additional_envs)
         required_envs.update(need_envs)

@@ -354,6 +365,14 @@ class TestDistBase(unittest.TestCase):
             self.assertEqual(need_result, tr0_out)
             self.assertEqual(need_result, tr1_out)
         elif col_type == "reduce":
-            need_result = input1 + input2
+            if reduce_type == dist.ReduceOp.SUM:
+                need_result = input1 + input2
+            elif reduce_type == dist.ReduceOp.MAX:
+                need_result = np.amax([input1, input2], 0)
+            elif reduce_type == dist.ReduceOp.MIN:
+                need_result = np.amin([input1, input2], 0)
+            elif reduce_type == dist.ReduceOp.PROD:
+                need_result = np.prod([input1, input2], 0)
             # bfloat16 precision loss comes from truncating the last 16 bits of float32,
             # which sums (\sum_{i=-23}^{-8}2^{i}) to about 0.0078

@@ -385,7 +404,14 @@ class TestDistBase(unittest.TestCase):
             np.testing.assert_allclose(tr0_out[0], need_result1, rtol=rtol)
             np.testing.assert_allclose(tr1_out[0], need_result2, rtol=rtol)
         elif col_type == "allreduce":
-            need_result = input1 + input2
+            if reduce_type == dist.ReduceOp.SUM:
+                need_result = input1 + input2
+            elif reduce_type == dist.ReduceOp.MAX:
+                need_result = np.amax([input1, input2], 0)
+            elif reduce_type == dist.ReduceOp.MIN:
+                need_result = np.amin([input1, input2], 0)
+            elif reduce_type == dist.ReduceOp.PROD:
+                need_result = np.prod([input1, input2], 0)
             if dtype == "bfloat16":
                 rtol = 8e-03
                 atol = 8e-03
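
The bfloat16 tolerance quoted in the test comment can be checked by hand: truncating the low 16 mantissa bits of a float32 loses at most \sum_{i=-23}^{-8} 2^i relative to the leading bit, just under 2^{-7}. A one-liner confirms the 0.0078 figure:

    # Worst-case relative loss from dropping float32 mantissa bits 2^-8..2^-23,
    # the \sum_{i=-23}^{-8} 2^i bound quoted in the test comment.
    loss = sum(2.0 ** i for i in range(-23, -7))  # i = -23..-8 inclusive
    print(round(loss, 4))  # 0.0078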