Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
80cc4f0d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
80cc4f0d
编写于
8月 08, 2022
作者:
Y
Yulong Ao
提交者:
GitHub
8月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Add the C++ ProcessMesh and DistributedMapper (#44963)
上级
a1da4f2f
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
620 addition
and
28 deletion
+620
-28
paddle/fluid/distributed/auto_parallel/CMakeLists.txt
paddle/fluid/distributed/auto_parallel/CMakeLists.txt
+16
-28
paddle/fluid/distributed/auto_parallel/auto_parallel.proto
paddle/fluid/distributed/auto_parallel/auto_parallel.proto
+32
-0
paddle/fluid/distributed/auto_parallel/dist_mapper.cc
paddle/fluid/distributed/auto_parallel/dist_mapper.cc
+146
-0
paddle/fluid/distributed/auto_parallel/dist_mapper.h
paddle/fluid/distributed/auto_parallel/dist_mapper.h
+73
-0
paddle/fluid/distributed/auto_parallel/dist_mapper_test.cc
paddle/fluid/distributed/auto_parallel/dist_mapper_test.cc
+72
-0
paddle/fluid/distributed/auto_parallel/process_mesh.cc
paddle/fluid/distributed/auto_parallel/process_mesh.cc
+134
-0
paddle/fluid/distributed/auto_parallel/process_mesh.h
paddle/fluid/distributed/auto_parallel/process_mesh.h
+94
-0
paddle/fluid/distributed/auto_parallel/process_mesh_test.cc
paddle/fluid/distributed/auto_parallel/process_mesh_test.cc
+53
-0
未找到文件。
paddle/fluid/distributed/auto_parallel/CMakeLists.txt
浏览文件 @
80cc4f0d
...
...
@@ -7,34 +7,22 @@ cc_test(
SRCS device_mesh_test.cc
DEPS device_mesh
)
# cc_library(
# process_mesh
# SRCS process_mesh.cc
# DEPS auto_parallel_proto)
# cc_test(
# process_mesh_test
# SRCS process_mesh_test.cc
# DEPS process_mesh)
# cc_library(
# dist_attr
# SRCS dist_attr.cc
# DEPS process_mesh auto_parallel_proto proto_desc)
# cc_test(
# dist_attr_test
# SRCS dist_attr_test.cc
# DEPS dist_attr)
cc_library
(
process_mesh
SRCS process_mesh.cc
DEPS auto_parallel_proto
)
cc_test
(
process_mesh_test
SRCS process_mesh_test.cc
DEPS process_mesh
)
#
cc_library(
#
dist_mapper
#
SRCS dist_mapper.cc
#
DEPS device_mesh auto_parallel_proto)
#
cc_test(
#
dist_mapper_test
#
SRCS dist_mapper_test.cc
#
DEPS dist_mapper)
cc_library
(
dist_mapper
SRCS dist_mapper.cc
DEPS device_mesh auto_parallel_proto
)
cc_test
(
dist_mapper_test
SRCS dist_mapper_test.cc
DEPS dist_mapper
)
proto_library
(
auto_parallel_proto SRCS auto_parallel.proto
)
# cc_library(auto_parallel DEPS process_mesh device_mesh dist_attr dist_mapper
# auto_parallel_proto)
paddle/fluid/distributed/auto_parallel/auto_parallel.proto
浏览文件 @
80cc4f0d
...
...
@@ -16,6 +16,20 @@ syntax = "proto2";
package
paddle
.
distributed.auto_parallel
;
// ProcessMesh is used to organize processes and like n-dimension array.
message
ProcessMeshProto
{
// The size of each dimension.
repeated
int64
shape
=
1
;
// These process ids are stored by a row-major way.
// There are no duplicate process ids within one process mesh.
repeated
int64
process_ids
=
2
;
// The name of each dimension.
repeated
string
dim_names
=
3
;
}
// This proto describes the capability of one device such as the computation and memory.
message
DeviceCapabilityProto
{
optional
double
single_precision_flops
=
1
;
...
...
@@ -86,3 +100,21 @@ message DeviceMeshProto {
// The links are between devices.
repeated
LinkProto
links
=
6
;
}
// Record the mapping between the logical processes and the physical devices.
message
DistributedMapperProto
{
// The device meshes used by this distributed computation,
// which may be shared by different multiple device meshes.
repeated
DeviceMeshProto
device_meshes
=
1
;
message
MapperEntryProto
{
optional
int64
process_id
=
1
;
optional
string
device_mesh_name
=
2
;
repeated
int64
device_ids
=
3
;
}
// The mapping from process ids to device ids.
// It is also possible for one process to use multiple devices.
// It is possible for one device shared by multiple processes.
repeated
MapperEntryProto
process_id_to_device_ids
=
2
;
}
paddle/fluid/distributed/auto_parallel/dist_mapper.cc
0 → 100644
浏览文件 @
80cc4f0d
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
namespace
paddle
{
namespace
distributed
{
namespace
auto_parallel
{
void
DistributedMapper
::
set_process_id_to_device_ids
(
const
std
::
map
<
int64_t
,
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>>&
process_id_to_device_ids
)
{
std
::
vector
<
std
::
string
>
device_mesh_names
;
for
(
const
auto
&
item
:
device_meshes_
)
{
device_mesh_names
.
push_back
(
item
.
first
);
}
for
(
const
auto
&
item
:
process_id_to_device_ids
)
{
PADDLE_ENFORCE_GE
(
item
.
first
,
0
,
platform
::
errors
::
InvalidArgument
(
"The process id %d must be greater than or equal to 0."
,
item
.
first
));
std
::
string
device_mesh_name
=
item
.
second
.
first
;
const
std
::
vector
<
int64_t
>&
device_ids
=
item
.
second
.
second
;
PADDLE_ENFORCE_EQ
(
device_meshes_
.
count
(
device_mesh_name
),
1
,
platform
::
errors
::
InvalidArgument
(
"Cannot find the device mesh %d in device_mesh ids [%s]."
,
device_mesh_name
,
str_join
(
device_mesh_names
)));
PADDLE_ENFORCE_EQ
(
has_duplicates
(
device_ids
),
false
,
platform
::
errors
::
InvalidArgument
(
"The mapped device ids [%s] of process_mesh %d must be unique."
,
str_join
(
device_ids
),
item
.
first
));
const
DeviceMesh
&
device_mesh
=
device_meshes_
[
device_mesh_name
];
const
std
::
vector
<
int64_t
>
cur_device_ids
=
device_mesh
.
device_ids
();
for
(
int64_t
device_id
:
device_ids
)
{
bool
found
=
std
::
find
(
cur_device_ids
.
begin
(),
cur_device_ids
.
end
(),
device_id
)
!=
cur_device_ids
.
end
();
PADDLE_ENFORCE_EQ
(
found
,
true
,
platform
::
errors
::
InvalidArgument
(
"The device id %d cannot be find in the device mesh [%s]."
,
device_id
,
str_join
(
cur_device_ids
)));
}
}
process_id_to_device_ids_
=
process_id_to_device_ids
;
}
DistributedMapper
DistributedMapper
::
from_proto
(
const
DistributedMapperProto
&
proto
)
{
DistributedMapper
dist_mapper
;
for
(
int64_t
i
=
0
;
i
<
proto
.
device_meshes_size
();
++
i
)
{
dist_mapper
.
device_meshes_
[
proto
.
device_meshes
(
i
).
name
()]
=
DeviceMesh
::
from_proto
(
proto
.
device_meshes
(
i
));
}
for
(
int64_t
i
=
0
;
i
<
proto
.
process_id_to_device_ids_size
();
++
i
)
{
int64_t
process_id
=
proto
.
process_id_to_device_ids
(
i
).
process_id
();
std
::
string
device_mesh_name
=
proto
.
process_id_to_device_ids
(
i
).
device_mesh_name
();
std
::
vector
<
int64_t
>
device_ids
;
int64_t
num_devices
=
proto
.
process_id_to_device_ids
(
i
).
device_ids_size
();
for
(
int64_t
j
=
0
;
j
<
num_devices
;
++
j
)
{
device_ids
.
push_back
(
proto
.
process_id_to_device_ids
(
i
).
device_ids
(
j
));
}
dist_mapper
.
process_id_to_device_ids_
[
process_id
].
first
=
device_mesh_name
;
dist_mapper
.
process_id_to_device_ids_
[
process_id
].
second
=
device_ids
;
}
return
dist_mapper
;
}
DistributedMapperProto
DistributedMapper
::
to_proto
()
const
{
DistributedMapperProto
proto
;
for
(
const
auto
&
item
:
device_meshes_
)
{
proto
.
mutable_device_meshes
()
->
Add
()
->
CopyFrom
(
item
.
second
.
to_proto
());
}
for
(
const
auto
&
outer
:
process_id_to_device_ids_
)
{
auto
proto_item
=
proto
.
mutable_process_id_to_device_ids
()
->
Add
();
proto_item
->
set_process_id
(
outer
.
first
);
proto_item
->
set_device_mesh_name
(
outer
.
second
.
first
);
for
(
const
auto
&
inner
:
outer
.
second
.
second
)
{
proto_item
->
add_device_ids
(
inner
);
}
}
return
proto
;
}
std
::
string
DistributedMapper
::
to_string
()
const
{
std
::
string
mapper_str
=
"{device_meshes: ["
;
for
(
const
auto
&
item
:
device_meshes_
)
{
mapper_str
+=
item
.
second
.
to_string
()
+
", "
;
}
mapper_str
.
replace
(
mapper_str
.
size
()
-
2
,
2
,
"]"
);
mapper_str
+=
"
\n
process_id_to_device_ids: ["
;
for
(
const
auto
&
item
:
process_id_to_device_ids_
)
{
mapper_str
+=
"{"
;
mapper_str
+=
"process_id: "
+
std
::
to_string
(
item
.
first
)
+
", device_ids: ["
;
for
(
const
auto
&
device_id
:
item
.
second
.
second
)
{
mapper_str
+=
"{"
+
item
.
second
.
first
+
", "
+
std
::
to_string
(
device_id
)
+
"}, "
;
}
mapper_str
.
replace
(
mapper_str
.
size
()
-
2
,
2
,
"]"
);
mapper_str
+=
"}, "
;
}
mapper_str
.
replace
(
mapper_str
.
size
()
-
2
,
2
,
"]"
);
mapper_str
+=
"}"
;
return
mapper_str
;
}
bool
operator
==
(
const
DistributedMapper
&
lhs
,
const
DistributedMapper
&
rhs
)
{
if
(
lhs
.
device_meshes
()
!=
rhs
.
device_meshes
())
{
return
false
;
}
if
(
lhs
.
process_id_to_device_ids
()
!=
rhs
.
process_id_to_device_ids
())
{
return
false
;
}
return
true
;
}
}
// namespace auto_parallel
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/auto_parallel/dist_mapper.h
0 → 100644
浏览文件 @
80cc4f0d
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <utility>
#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
namespace
paddle
{
namespace
distributed
{
namespace
auto_parallel
{
class
DistributedMapper
{
public:
DistributedMapper
()
=
default
;
const
std
::
map
<
std
::
string
,
DeviceMesh
>&
device_meshes
()
const
{
return
device_meshes_
;
}
const
DeviceMesh
&
device_mesh
(
const
std
::
string
&
name
)
const
{
return
device_meshes_
.
at
(
name
);
}
void
add_device_mesh
(
const
DeviceMesh
&
device_mesh
)
{
device_meshes_
[
device_mesh
.
name
()]
=
device_mesh
;
}
const
std
::
map
<
int64_t
,
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>>&
process_id_to_device_ids
()
const
{
return
process_id_to_device_ids_
;
}
void
set_process_id_to_device_ids
(
const
std
::
map
<
int64_t
,
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>>&
process_id_to_device_ids
);
// DistributedMapper from_string(const std::string& mapper_str);
std
::
string
to_string
()
const
;
static
DistributedMapper
from_proto
(
const
DistributedMapperProto
&
proto
);
DistributedMapperProto
to_proto
()
const
;
private:
std
::
map
<
std
::
string
,
DeviceMesh
>
device_meshes_
;
std
::
map
<
int64_t
,
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>>
process_id_to_device_ids_
;
};
bool
operator
==
(
const
DistributedMapper
&
lhs
,
const
DistributedMapper
&
rhs
);
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
DistributedMapper
&
obj
)
{
os
<<
obj
.
to_string
();
return
os
;
}
}
// namespace auto_parallel
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/auto_parallel/dist_mapper_test.cc
0 → 100644
浏览文件 @
80cc4f0d
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/dist_mapper.h"
#include <map>
#include <sstream>
#include "gtest/gtest.h"
namespace
paddle
{
namespace
distributed
{
namespace
auto_parallel
{
TEST
(
DistributedMapper
,
Ctor
)
{
std
::
vector
<
int64_t
>
shape
=
{
2
,
3
};
std
::
vector
<
int64_t
>
device_ids
=
{
0
,
1
,
2
,
3
,
4
,
5
};
std
::
vector
<
std
::
string
>
dim_names
=
{
"x"
,
"y"
};
std
::
string
device_type
=
"GPU"
;
int64_t
size
=
shape
[
0
]
*
shape
[
1
];
DeviceMesh
device_mesh
(
"device_mesh"
,
shape
,
device_ids
,
dim_names
);
for
(
int64_t
i
=
0
;
i
<
shape
[
0
];
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
shape
[
1
];
++
j
)
{
int64_t
global_id
=
i
*
shape
[
1
]
+
j
;
int64_t
local_id
=
j
;
int64_t
machine_id
=
i
;
device_mesh
.
add_device
(
Device
(
global_id
,
local_id
,
machine_id
,
device_type
));
}
}
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
size
;
++
j
)
{
device_mesh
.
add_link
(
Link
(
i
,
j
,
"NVL"
));
}
}
DistributedMapper
dist_mapper
;
dist_mapper
.
add_device_mesh
(
device_mesh
);
std
::
map
<
int64_t
,
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>>
process_id_to_device_ids
;
process_id_to_device_ids
[
0
]
=
{
"device_mesh"
,
{
5
}};
process_id_to_device_ids
[
1
]
=
{
"device_mesh"
,
{
4
}};
process_id_to_device_ids
[
2
]
=
{
"device_mesh"
,
{
3
}};
process_id_to_device_ids
[
3
]
=
{
"device_mesh"
,
{
2
}};
process_id_to_device_ids
[
4
]
=
{
"device_mesh"
,
{
1
}};
process_id_to_device_ids
[
5
]
=
{
"device_mesh"
,
{
0
}};
dist_mapper
.
set_process_id_to_device_ids
(
process_id_to_device_ids
);
EXPECT_EQ
(
dist_mapper
.
device_meshes
().
at
(
"device_mesh"
),
device_mesh
);
EXPECT_EQ
(
dist_mapper
.
device_mesh
(
"device_mesh"
),
device_mesh
);
EXPECT_EQ
(
dist_mapper
.
process_id_to_device_ids
(),
process_id_to_device_ids
);
std
::
stringstream
sstream
;
sstream
<<
dist_mapper
;
EXPECT_EQ
(
sstream
.
str
(),
dist_mapper
.
to_string
());
auto
proto
=
dist_mapper
.
to_proto
();
DistributedMapper
new_dist_mapper
=
DistributedMapper
::
from_proto
(
proto
);
EXPECT_EQ
(
dist_mapper
,
new_dist_mapper
);
}
}
// namespace auto_parallel
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/auto_parallel/process_mesh.cc
0 → 100644
浏览文件 @
80cc4f0d
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <iterator>
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
namespace
paddle
{
namespace
distributed
{
namespace
auto_parallel
{
ProcessMesh
::
ProcessMesh
(
const
std
::
vector
<
int64_t
>
&
shape
,
const
std
::
vector
<
int64_t
>
&
process_ids
,
const
std
::
vector
<
std
::
string
>
&
dim_names
)
{
shape_
=
shape
;
int64_t
size
=
this
->
size
();
PADDLE_ENFORCE_EQ
(
size
,
process_ids
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of this process mesh must be "
"equal to the size of its process ids."
,
size
,
process_ids
.
size
()));
PADDLE_ENFORCE_EQ
(
has_duplicates
(
process_ids
),
false
,
platform
::
errors
::
InvalidArgument
(
"The process ids [%s] must be unique."
,
str_join
(
process_ids_
)));
process_ids_
=
process_ids
;
PADDLE_ENFORCE_EQ
(
shape_
.
size
(),
dim_names
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of mesh shape must be equal to the size "
"of the dimension names."
,
shape_
.
size
(),
dim_names_
.
size
()));
PADDLE_ENFORCE_EQ
(
has_duplicates
(
dim_names
),
false
,
platform
::
errors
::
InvalidArgument
(
"The names [%s] of each dimension must be unique."
,
str_join
(
dim_names
)));
dim_names_
=
dim_names
;
}
int64_t
ProcessMesh
::
size
()
const
{
if
(
shape_
.
empty
())
return
0
;
int64_t
size
=
1
;
for
(
const
int64_t
dim_size
:
shape_
)
size
*=
dim_size
;
return
size
;
}
bool
ProcessMesh
::
contains
(
int64_t
process_id
)
const
{
auto
result
=
std
::
find
(
std
::
begin
(
process_ids_
),
std
::
end
(
process_ids_
),
process_id
);
if
(
result
!=
std
::
end
(
process_ids_
))
{
return
true
;
}
else
{
return
false
;
}
}
std
::
string
ProcessMesh
::
to_string
()
const
{
std
::
string
mesh_str
=
"{shape: ["
+
str_join
(
shape_
)
+
"], "
;
mesh_str
+=
"process_ids: ["
+
str_join
(
process_ids_
)
+
"], "
;
mesh_str
+=
"dim_names: ["
+
str_join
(
dim_names_
)
+
"]}"
;
return
mesh_str
;
}
ProcessMesh
ProcessMesh
::
from_proto
(
const
ProcessMeshProto
&
proto
)
{
ProcessMesh
mesh
;
mesh
.
shape_
.
resize
(
proto
.
shape_size
());
for
(
int64_t
i
=
0
;
i
<
proto
.
shape_size
();
++
i
)
{
mesh
.
shape_
[
i
]
=
proto
.
shape
(
i
);
}
mesh
.
process_ids_
.
resize
(
proto
.
process_ids_size
());
for
(
int64_t
i
=
0
;
i
<
proto
.
process_ids_size
();
++
i
)
{
mesh
.
process_ids_
[
i
]
=
proto
.
process_ids
(
i
);
}
mesh
.
dim_names_
.
resize
(
proto
.
dim_names_size
());
for
(
int64_t
i
=
0
;
i
<
proto
.
dim_names_size
();
++
i
)
{
mesh
.
dim_names_
[
i
]
=
proto
.
dim_names
(
i
);
}
return
mesh
;
}
ProcessMeshProto
ProcessMesh
::
to_proto
()
const
{
ProcessMeshProto
proto
;
for
(
const
auto
&
i
:
shape_
)
{
proto
.
add_shape
(
i
);
}
for
(
const
auto
&
i
:
process_ids_
)
{
proto
.
add_process_ids
(
i
);
}
for
(
const
auto
&
i
:
dim_names_
)
{
proto
.
add_dim_names
(
i
);
}
return
proto
;
}
bool
operator
==
(
const
ProcessMesh
&
lhs
,
const
ProcessMesh
&
rhs
)
{
if
(
lhs
.
shape
()
!=
rhs
.
shape
())
{
return
false
;
}
if
(
lhs
.
process_ids
()
!=
rhs
.
process_ids
())
{
return
false
;
}
return
true
;
}
}
// namespace auto_parallel
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/auto_parallel/process_mesh.h
0 → 100644
浏览文件 @
80cc4f0d
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <pybind11/pybind11.h>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/fluid/distributed/auto_parallel/device_mesh.h"
#include "paddle/fluid/distributed/auto_parallel/utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
namespace
auto_parallel
{
class
ProcessMesh
{
public:
ProcessMesh
()
=
default
;
ProcessMesh
(
const
std
::
vector
<
int64_t
>&
shape
,
const
std
::
vector
<
int64_t
>&
process_ids
,
const
std
::
vector
<
std
::
string
>&
dim_names
);
const
std
::
vector
<
int64_t
>&
shape
()
const
{
return
shape_
;
}
const
std
::
vector
<
int64_t
>&
process_ids
()
const
{
return
process_ids_
;
}
const
std
::
vector
<
std
::
string
>&
dim_names
()
const
{
return
dim_names_
;
}
int64_t
size
()
const
;
int64_t
ndim
()
const
{
return
shape_
.
size
();
}
int64_t
dim_size
(
int64_t
dim
)
const
{
int64_t
cdim
=
canonical_dim
(
dim
,
shape_
.
size
());
return
shape_
[
cdim
];
}
int64_t
dim_size
(
const
std
::
string
&
dim_name
)
const
{
for
(
std
::
size_t
i
=
0
;
i
<
dim_names_
.
size
();
++
i
)
{
if
(
dim_names_
[
i
]
==
dim_name
)
{
return
shape_
[
i
];
}
}
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Cannot find the dimension of %s in this process mesh."
,
dim_name
));
}
bool
empty
()
const
{
return
(
shape_
.
empty
()
||
process_ids_
.
empty
());
}
bool
contains
(
int64_t
process_id
)
const
;
// ProcessMesh from_string(const std::string& mesh_str);
std
::
string
to_string
()
const
;
static
ProcessMesh
from_proto
(
const
ProcessMeshProto
&
proto
);
ProcessMeshProto
to_proto
()
const
;
private:
std
::
vector
<
int64_t
>
shape_
;
std
::
vector
<
int64_t
>
process_ids_
;
std
::
vector
<
std
::
string
>
dim_names_
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ProcessMesh
&
obj
)
{
os
<<
obj
.
to_string
();
return
os
;
}
bool
operator
==
(
const
ProcessMesh
&
lhs
,
const
ProcessMesh
&
rhs
);
inline
bool
operator
!=
(
const
ProcessMesh
&
lhs
,
const
ProcessMesh
&
rhs
)
{
return
!
operator
==
(
lhs
,
rhs
);
}
}
// namespace auto_parallel
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/auto_parallel/process_mesh_test.cc
0 → 100644
浏览文件 @
80cc4f0d
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/process_mesh.h"
#include <iostream>
#include <sstream>
#include "gtest/gtest.h"
namespace
paddle
{
namespace
distributed
{
namespace
auto_parallel
{
TEST
(
ProcessMesh
,
Ctor
)
{
std
::
vector
<
int64_t
>
shape
=
{
2
,
3
};
std
::
vector
<
int64_t
>
process_ids
=
{
0
,
1
,
2
,
3
,
4
,
5
};
std
::
vector
<
std
::
string
>
dim_names
=
{
"x"
,
"y"
};
int64_t
size
=
shape
[
0
]
*
shape
[
1
];
ProcessMesh
process_mesh
(
shape
,
process_ids
,
dim_names
);
EXPECT_EQ
(
process_mesh
.
shape
(),
shape
);
EXPECT_EQ
(
process_mesh
.
process_ids
(),
process_ids
);
EXPECT_EQ
(
process_mesh
.
dim_names
()[
0
],
"x"
);
EXPECT_EQ
(
process_mesh
.
dim_names
()[
1
],
"y"
);
EXPECT_EQ
(
process_mesh
.
size
(),
size
);
EXPECT_EQ
(
process_mesh
.
ndim
(),
static_cast
<
int64_t
>
(
shape
.
size
()));
EXPECT_EQ
(
process_mesh
.
dim_size
(
0
),
shape
[
0
]);
EXPECT_EQ
(
process_mesh
.
dim_size
(
-
1
),
shape
[
1
]);
EXPECT_EQ
(
process_mesh
.
dim_size
(
"x"
),
shape
[
0
]);
EXPECT_EQ
(
process_mesh
.
dim_size
(
"y"
),
shape
[
1
]);
EXPECT_EQ
(
process_mesh
.
empty
(),
false
);
EXPECT_EQ
(
process_mesh
.
contains
(
0
),
true
);
EXPECT_EQ
(
process_mesh
.
contains
(
6
),
false
);
std
::
stringstream
sstream
;
sstream
<<
process_mesh
;
EXPECT_EQ
(
sstream
.
str
(),
process_mesh
.
to_string
());
auto
proto
=
process_mesh
.
to_proto
();
ProcessMesh
new_process_mesh
=
ProcessMesh
::
from_proto
(
proto
);
EXPECT_EQ
(
process_mesh
,
new_process_mesh
);
}
}
// namespace auto_parallel
}
// namespace distributed
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录