Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e3b6e02f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e3b6e02f
编写于
8月 04, 2023
作者:
J
JZ-LIANG
提交者:
GitHub
8月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Semi AutoParall] Support Partial Semantic I (#55508)
上级
dd1379e4
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
221 addition
and
6 deletion
+221
-6
paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
.../distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
+1
-0
paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc
...stributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc
+5
-3
paddle/fluid/pybind/auto_parallel_py.cc
paddle/fluid/pybind/auto_parallel_py.cc
+5
-1
paddle/phi/core/distributed/auto_parallel/dist_attr.cc
paddle/phi/core/distributed/auto_parallel/dist_attr.cc
+74
-1
paddle/phi/core/distributed/auto_parallel/dist_attr.h
paddle/phi/core/distributed/auto_parallel/dist_attr.h
+47
-1
test/auto_parallel/spmd_rules/test_matmul_rule.py
test/auto_parallel/spmd_rules/test_matmul_rule.py
+21
-0
test/auto_parallel/spmd_rules/test_reduction_rule.py
test/auto_parallel/spmd_rules/test_reduction_rule.py
+21
-0
test/cpp/auto_parallel/spmd_rule_test.cc
test/cpp/auto_parallel/spmd_rule_test.cc
+47
-0
未找到文件。
paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
浏览文件 @
e3b6e02f
...
@@ -160,6 +160,7 @@ MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
...
@@ -160,6 +160,7 @@ MatmulSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
// Step2.3.1 Output Partial
// Step2.3.1 Output Partial
std
::
vector
<
int64_t
>
partial_on_dims
=
std
::
vector
<
int64_t
>
partial_on_dims
=
ResoluteOutputPartialDimension
(
axis_to_dim_map
,
out_axes
);
ResoluteOutputPartialDimension
(
axis_to_dim_map
,
out_axes
);
output_dist_attr_dst
.
set_partial_status
(
partial_on_dims
);
// Step2.3.2 handle input tensor partial (TODO)
// Step2.3.2 handle input tensor partial (TODO)
VLOG
(
4
)
<<
"MatmulSPMDRule InferForward: "
VLOG
(
4
)
<<
"MatmulSPMDRule InferForward: "
...
...
paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc
浏览文件 @
e3b6e02f
...
@@ -88,13 +88,15 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
...
@@ -88,13 +88,15 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
CopyTensorDistAttrForOutput
(
input_specs
[
0
].
dist_attr
());
CopyTensorDistAttrForOutput
(
input_specs
[
0
].
dist_attr
());
output_dist_attr
.
set_dims_mapping
(
output_dims_mapping
);
output_dist_attr
.
set_dims_mapping
(
output_dims_mapping
);
std
::
vector
<
TensorDistAttr
>
output_dist_attrs
;
output_dist_attrs
.
emplace_back
(
output_dist_attr
);
// step2.4: handle partial
// step2.4: handle partial
// Step2.4.1 Output Partial
// Step2.4.1 Output Partial
std
::
vector
<
int64_t
>
partial_on_dims
=
std
::
vector
<
int64_t
>
partial_on_dims
=
ResoluteOutputPartialDimension
(
axis_to_dim_map
,
output_axes
);
ResoluteOutputPartialDimension
(
axis_to_dim_map
,
output_axes
);
output_dist_attr
.
set_partial_status
(
partial_on_dims
/*, handle reduce_type in future */
);
std
::
vector
<
TensorDistAttr
>
output_dist_attrs
;
output_dist_attrs
.
emplace_back
(
output_dist_attr
);
// Step2.4.2 handle input tensor partial (TODO)
// Step2.4.2 handle input tensor partial (TODO)
// If the op is a linear op, i.e. `linearity` is true, it supports
// If the op is a linear op, i.e. `linearity` is true, it supports
...
...
paddle/fluid/pybind/auto_parallel_py.cc
浏览文件 @
e3b6e02f
...
@@ -293,7 +293,11 @@ void BindAutoParallel(py::module *m) {
...
@@ -293,7 +293,11 @@ void BindAutoParallel(py::module *m) {
return
TensorDistAttr
(
self
);
return
TensorDistAttr
(
self
);
},
},
py
::
arg
(
"memo"
))
py
::
arg
(
"memo"
))
.
def
(
"__str__"
,
&
TensorDistAttr
::
to_string
);
.
def
(
"__str__"
,
&
TensorDistAttr
::
to_string
)
.
def
(
"_is_partial"
,
&
TensorDistAttr
::
is_partial
)
.
def
(
"_partial_dims"
,
&
TensorDistAttr
::
partial_dims
)
.
def
(
"_clean_partial_dims"
,
&
TensorDistAttr
::
clean_partial_dims
)
.
def
(
"_clean_partial_status"
,
&
TensorDistAttr
::
clean_partial_status
);
py
::
class_
<
SPMDRuleBase
>
(
*
m
,
"SPMDRuleBase"
)
py
::
class_
<
SPMDRuleBase
>
(
*
m
,
"SPMDRuleBase"
)
.
def
(
"infer_forward"
,
&
SPMDRuleBase
::
InferForward
)
.
def
(
"infer_forward"
,
&
SPMDRuleBase
::
InferForward
)
...
...
paddle/phi/core/distributed/auto_parallel/dist_attr.cc
浏览文件 @
e3b6e02f
...
@@ -24,6 +24,7 @@ namespace phi {
...
@@ -24,6 +24,7 @@ namespace phi {
namespace
distributed
{
namespace
distributed
{
namespace
auto_parallel
{
namespace
auto_parallel
{
// partial is not allow annotated by user by now.
std
::
vector
<
std
::
string
>
TensorDistAttr
::
fields_
{
std
::
vector
<
std
::
string
>
TensorDistAttr
::
fields_
{
"process_mesh"
,
"dims_mapping"
,
"batch_dim"
,
"dynamic_dims"
};
"process_mesh"
,
"dims_mapping"
,
"batch_dim"
,
"dynamic_dims"
};
...
@@ -44,6 +45,7 @@ TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
...
@@ -44,6 +45,7 @@ TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
std
::
swap
(
this
->
batch_dim_
,
tmp
.
batch_dim_
);
std
::
swap
(
this
->
batch_dim_
,
tmp
.
batch_dim_
);
std
::
swap
(
this
->
dynamic_dims_
,
tmp
.
dynamic_dims_
);
std
::
swap
(
this
->
dynamic_dims_
,
tmp
.
dynamic_dims_
);
std
::
swap
(
this
->
annotated_
,
tmp
.
annotated_
);
std
::
swap
(
this
->
annotated_
,
tmp
.
annotated_
);
std
::
swap
(
this
->
partial_status_
,
tmp
.
partial_status_
);
return
*
this
;
return
*
this
;
}
}
...
@@ -53,6 +55,7 @@ void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
...
@@ -53,6 +55,7 @@ void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_batch_dim
(
dist_attr
.
batch_dim
());
set_batch_dim
(
dist_attr
.
batch_dim
());
set_dynamic_dims
(
dist_attr
.
dynamic_dims
());
set_dynamic_dims
(
dist_attr
.
dynamic_dims
());
set_annotated
(
dist_attr
.
annotated
());
set_annotated
(
dist_attr
.
annotated
());
set_partial_status
(
dist_attr
.
partial_status
());
}
}
void
TensorDistAttr
::
set_process_mesh
(
const
ProcessMesh
&
process_mesh
)
{
void
TensorDistAttr
::
set_process_mesh
(
const
ProcessMesh
&
process_mesh
)
{
...
@@ -77,6 +80,44 @@ void TensorDistAttr::set_annotated(
...
@@ -77,6 +80,44 @@ void TensorDistAttr::set_annotated(
annotated_
=
annotated
;
annotated_
=
annotated
;
}
}
const
std
::
set
<
int64_t
>
TensorDistAttr
::
partial_dims
()
const
{
std
::
set
<
int64_t
>
keys
;
for
(
auto
&
kv
:
partial_status_
)
{
keys
.
emplace
(
kv
.
first
);
}
return
keys
;
}
void
TensorDistAttr
::
set_partial_status
(
const
paddle
::
flat_hash_map
<
int64_t
,
ReduceType
>&
partial_status
)
{
partial_status_
=
partial_status
;
}
void
TensorDistAttr
::
set_partial_status
(
const
std
::
vector
<
int64_t
>&
dims
,
const
ReduceType
&
type
)
{
for
(
const
auto
&
dim
:
dims
)
{
if
(
partial_status_
.
count
(
dim
)
!=
0
)
{
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"Trying to Set dim %d as Partial which is already a Partial dim."
,
dim
));
}
partial_status_
.
emplace
(
dim
,
type
);
}
}
void
TensorDistAttr
::
clean_partial_status
()
{
partial_status_
.
clear
();
}
void
TensorDistAttr
::
clean_partial_dims
(
const
std
::
vector
<
int64_t
>&
dims
)
{
for
(
const
auto
&
dim
:
dims
)
{
if
(
partial_status_
.
count
(
dim
)
==
0
)
{
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"Trying to clean Partial on dim %d but it is not Partial."
,
dim
));
}
else
{
partial_status_
.
erase
(
dim
);
}
}
}
void
TensorDistAttr
::
set_default_dims_mapping
(
void
TensorDistAttr
::
set_default_dims_mapping
(
const
std
::
vector
<
int64_t
>&
tensor_shape
)
{
const
std
::
vector
<
int64_t
>&
tensor_shape
)
{
if
(
!
tensor_shape
.
empty
())
{
if
(
!
tensor_shape
.
empty
())
{
...
@@ -178,6 +219,20 @@ bool TensorDistAttr::verify_annotated(
...
@@ -178,6 +219,20 @@ bool TensorDistAttr::verify_annotated(
return
true
;
return
true
;
}
}
bool
TensorDistAttr
::
verify_partial_status
()
const
{
VLOG
(
4
)
<<
"[TensorDistAttr verify_partial_status] "
<<
partial_status_string
();
for
(
auto
&
itr
:
partial_status_
)
{
if
(
itr
.
first
<
0
||
itr
.
first
>=
process_mesh_
.
ndim
())
{
return
false
;
}
if
(
itr
.
second
<
ReduceType
::
SUM
||
itr
.
second
<=
ReduceType
::
ALL
)
{
return
false
;
}
}
return
true
;
}
bool
TensorDistAttr
::
verify
(
const
std
::
vector
<
int64_t
>&
tensor_shape
)
const
{
bool
TensorDistAttr
::
verify
(
const
std
::
vector
<
int64_t
>&
tensor_shape
)
const
{
if
(
!
verify_process_mesh
(
process_mesh_
))
{
if
(
!
verify_process_mesh
(
process_mesh_
))
{
return
false
;
return
false
;
...
@@ -194,6 +249,9 @@ bool TensorDistAttr::verify(const std::vector<int64_t>& tensor_shape) const {
...
@@ -194,6 +249,9 @@ bool TensorDistAttr::verify(const std::vector<int64_t>& tensor_shape) const {
if
(
!
verify_annotated
(
annotated_
))
{
if
(
!
verify_annotated
(
annotated_
))
{
return
false
;
return
false
;
}
}
if
(
!
verify_partial_status
())
{
return
false
;
}
return
true
;
return
true
;
}
}
...
@@ -203,7 +261,8 @@ std::string TensorDistAttr::to_string() const {
...
@@ -203,7 +261,8 @@ std::string TensorDistAttr::to_string() const {
dist_str
+=
"dims_mappings: ["
+
str_join
(
dims_mapping_
)
+
"], "
;
dist_str
+=
"dims_mappings: ["
+
str_join
(
dims_mapping_
)
+
"], "
;
dist_str
+=
"batch_dim: "
+
std
::
to_string
(
batch_dim_
)
+
", "
;
dist_str
+=
"batch_dim: "
+
std
::
to_string
(
batch_dim_
)
+
", "
;
dist_str
+=
"dynamic_dims: ["
+
str_join
(
dynamic_dims_
)
+
"], "
;
dist_str
+=
"dynamic_dims: ["
+
str_join
(
dynamic_dims_
)
+
"], "
;
dist_str
+=
"annotated: ["
+
str_join
(
annotated_
)
+
"]}"
;
dist_str
+=
"annotated: ["
+
str_join
(
annotated_
)
+
"], "
;
dist_str
+=
"partial: "
+
partial_status_string
()
+
".}"
;
return
dist_str
;
return
dist_str
;
}
}
...
@@ -267,9 +326,23 @@ bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
...
@@ -267,9 +326,23 @@ bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if
(
lhs
.
dynamic_dims
()
!=
rhs
.
dynamic_dims
())
{
if
(
lhs
.
dynamic_dims
()
!=
rhs
.
dynamic_dims
())
{
return
false
;
return
false
;
}
}
if
(
lhs
.
partial_status
()
!=
rhs
.
partial_status
())
{
return
false
;
}
return
true
;
return
true
;
}
}
std
::
string
TensorDistAttr
::
partial_status_string
()
const
{
std
::
string
partial_status_str
=
"["
;
for
(
auto
&
itr
:
partial_status_
)
{
partial_status_str
+=
"Partial(dims:"
+
std
::
to_string
(
itr
.
first
)
+
", "
+
ReduceTypeStrings
[
static_cast
<
int
>
(
itr
.
second
)]
+
"), "
;
}
partial_status_str
+=
"]"
;
return
partial_status_str
;
}
}
// namespace auto_parallel
}
// namespace auto_parallel
}
// namespace distributed
}
// namespace distributed
}
// namespace phi
}
// namespace phi
paddle/phi/core/distributed/auto_parallel/dist_attr.h
浏览文件 @
e3b6e02f
...
@@ -25,6 +25,7 @@ limitations under the License. */
...
@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/flat_hash_map.h"
namespace
phi
{
namespace
phi
{
namespace
distributed
{
namespace
distributed
{
...
@@ -32,13 +33,25 @@ namespace auto_parallel {
...
@@ -32,13 +33,25 @@ namespace auto_parallel {
constexpr
const
char
*
kDefault
=
"default"
;
constexpr
const
char
*
kDefault
=
"default"
;
enum
class
ReduceType
:
std
::
uint8_t
{
SUM
=
0
,
AVG
,
MAX
,
MIN
,
PRODUCT
,
ANY
,
ALL
};
constexpr
const
char
*
ReduceTypeStrings
[]
=
{
"SUM"
,
"AVG"
,
"MAX"
,
"MIN"
,
"PRODUCT"
,
"ANY"
,
"ALL"
};
class
TensorDistAttr
{
class
TensorDistAttr
{
public:
public:
TensorDistAttr
()
=
default
;
TensorDistAttr
()
=
default
;
explicit
TensorDistAttr
(
const
std
::
vector
<
int64_t
>&
tensor_shape
);
explicit
TensorDistAttr
(
const
std
::
vector
<
int64_t
>&
tensor_shape
);
TensorDistAttr
(
const
TensorDistAttr
&
tenso
r
);
TensorDistAttr
(
const
TensorDistAttr
&
dist_att
r
);
TensorDistAttr
&
operator
=
(
const
TensorDistAttr
&
dist_attr
);
TensorDistAttr
&
operator
=
(
const
TensorDistAttr
&
dist_attr
);
...
@@ -52,6 +65,29 @@ class TensorDistAttr {
...
@@ -52,6 +65,29 @@ class TensorDistAttr {
void
set_dims_mapping
(
const
std
::
vector
<
int64_t
>&
dims_mapping
);
void
set_dims_mapping
(
const
std
::
vector
<
int64_t
>&
dims_mapping
);
// true if tensor is partial on any mesh dim.
bool
is_partial
()
const
{
return
!
partial_status_
.
empty
();
}
// return vector of mesh dims on which the this tensor is partial on
const
std
::
set
<
int64_t
>
partial_dims
()
const
;
const
paddle
::
flat_hash_map
<
int64_t
,
ReduceType
>&
partial_status
()
const
{
return
partial_status_
;
}
// by map
void
set_partial_status
(
const
paddle
::
flat_hash_map
<
int64_t
,
ReduceType
>&
partial_status
);
// by each dim
void
set_partial_status
(
const
std
::
vector
<
int64_t
>&
dims
,
const
ReduceType
&
type
=
ReduceType
::
SUM
);
// all
void
clean_partial_status
();
// clean by dims
void
clean_partial_dims
(
const
std
::
vector
<
int64_t
>&
dims
);
void
set_default_dims_mapping
(
const
std
::
vector
<
int64_t
>&
tensor_shape
);
void
set_default_dims_mapping
(
const
std
::
vector
<
int64_t
>&
tensor_shape
);
int64_t
batch_dim
()
const
{
return
batch_dim_
;
}
int64_t
batch_dim
()
const
{
return
batch_dim_
;
}
...
@@ -89,11 +125,17 @@ class TensorDistAttr {
...
@@ -89,11 +125,17 @@ class TensorDistAttr {
bool
verify_annotated
(
const
std
::
map
<
std
::
string
,
bool
>&
annotated
)
const
;
bool
verify_annotated
(
const
std
::
map
<
std
::
string
,
bool
>&
annotated
)
const
;
bool
verify_partial_status
()
const
;
bool
verify
(
const
std
::
vector
<
int64_t
>&
tensor_shape
)
const
;
bool
verify
(
const
std
::
vector
<
int64_t
>&
tensor_shape
)
const
;
// TensorDistAttr from_string(const std::string& dist_str);
// TensorDistAttr from_string(const std::string& dist_str);
std
::
string
to_string
()
const
;
std
::
string
to_string
()
const
;
std
::
string
partial_status_string
()
const
;
// in partial-support-stage-I partial will always be a runtime attribute,
// there is not need to serialize it. support the partial serialization in
// future partial-support-stage-II.
void
from_proto
(
const
TensorDistAttrProto
&
proto
);
void
from_proto
(
const
TensorDistAttrProto
&
proto
);
TensorDistAttrProto
to_proto
()
const
;
TensorDistAttrProto
to_proto
()
const
;
...
@@ -109,6 +151,10 @@ class TensorDistAttr {
...
@@ -109,6 +151,10 @@ class TensorDistAttr {
int64_t
batch_dim_
{
0
};
int64_t
batch_dim_
{
0
};
std
::
vector
<
bool
>
dynamic_dims_
;
std
::
vector
<
bool
>
dynamic_dims_
;
std
::
map
<
std
::
string
,
bool
>
annotated_
;
std
::
map
<
std
::
string
,
bool
>
annotated_
;
// partial map would be small (less than mesh.size)
// iterate operation (copy and comparision) would more frequency than random
// element access. <key: dim on mesh, value: reduce type>
paddle
::
flat_hash_map
<
int64_t
,
ReduceType
>
partial_status_
;
};
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
TensorDistAttr
&
obj
)
{
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
TensorDistAttr
&
obj
)
{
...
...
test/auto_parallel/spmd_rules/test_matmul_rule.py
浏览文件 @
e3b6e02f
...
@@ -60,6 +60,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -60,6 +60,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
0
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
0
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
])
# test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[]
# test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[]
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
])
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
])
...
@@ -73,6 +75,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -73,6 +75,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[]
# test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[]
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
])
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
])
...
@@ -85,6 +88,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -85,6 +88,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# test n parallel: mk[-1, -1],kn[-1, 0] --> mk[-1, -1],kn[-1, 0] = nm[-1, 0] partial[]
# test n parallel: mk[-1, -1],kn[-1, 0] --> mk[-1, -1],kn[-1, 0] = nm[-1, 0] partial[]
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
-
1
])
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
-
1
])
...
@@ -97,6 +101,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -97,6 +101,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
-
1
,
0
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
-
1
,
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# test partial with propogation: mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]
# test partial with propogation: mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
0
])
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
0
])
...
@@ -109,6 +114,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -109,6 +114,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
0
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
0
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
])
# mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]:
# mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]:
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
-
1
])
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
-
1
])
...
@@ -121,6 +128,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -121,6 +128,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
1
,
0
])
self
.
assertEqual
(
infered_input_dist_attrs
[
1
].
dims_mapping
,
[
1
,
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
1
])
# abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = abcmn[1, 0, -1, -1] partial[]: done
# abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = abcmn[1, 0, -1, -1] partial[]: done
self
.
x_dist_tensor_spec
.
shape
=
[
512
,
48
,
64
,
32
]
self
.
x_dist_tensor_spec
.
shape
=
[
512
,
48
,
64
,
32
]
...
@@ -138,6 +147,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -138,6 +147,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
,
1
,
-
1
,
-
1
]
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
,
1
,
-
1
,
-
1
]
)
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,-1, -1, -1] partial[0]
# abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,-1, -1, -1] partial[0]
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
,
-
1
,
0
])
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
,
-
1
,
0
])
...
@@ -154,6 +164,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -154,6 +164,8 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
,
-
1
]
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
,
-
1
]
)
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
])
# trans_x = True, abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = abcmn[1, -1, 0, -1] partial[]
# trans_x = True, abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = abcmn[1, -1, 0, -1] partial[]
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
,
-
1
,
0
])
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
,
-
1
,
0
])
...
@@ -171,6 +183,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -171,6 +183,7 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
0
,
-
1
]
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
0
,
-
1
]
)
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# trans_y = True, abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = abcmn[-1, -1, -1, 1] partial[0]: done
# trans_y = True, abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = abcmn[-1, -1, -1, 1] partial[0]: done
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
-
1
,
-
1
,
-
1
])
self
.
x_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
-
1
,
-
1
,
-
1
])
...
@@ -189,6 +202,10 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -189,6 +202,10 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
-
1
,
1
]
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
-
1
,
1
]
)
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
])
infered_output_dist_attrs
[
0
].
_clean_partial_dims
([
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# trans_y = True, trans_x = True, abcmk[-1, -1, 0, 1], kn[1, 0] --> abcmk[-1, -1, 0, 1]],kn[-1, 0] = abcmn[-1, -1, 1, -1] partial[0]
# trans_y = True, trans_x = True, abcmk[-1, -1, 0, 1], kn[1, 0] --> abcmk[-1, -1, 0, 1]],kn[-1, 0] = abcmn[-1, -1, 1, -1] partial[0]
# multiple mesh dim shard same tensor axis
# multiple mesh dim shard same tensor axis
...
@@ -208,6 +225,10 @@ class TestMatmulSPMDRule(unittest.TestCase):
...
@@ -208,6 +225,10 @@ class TestMatmulSPMDRule(unittest.TestCase):
self
.
assertEqual
(
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
1
,
-
1
]
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
1
,
-
1
]
)
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
])
infered_output_dist_attrs
[
0
].
_clean_partial_status
()
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# trans_y = True, trans_x = True, abcmk[-1, -1, 1, 0], kn[1, 0] --> error:
# trans_y = True, trans_x = True, abcmk[-1, -1, 1, 0], kn[1, 0] --> error:
# one mesh dim shard multiple tensor axes
# one mesh dim shard multiple tensor axes
...
...
test/auto_parallel/spmd_rules/test_reduction_rule.py
浏览文件 @
e3b6e02f
...
@@ -62,6 +62,8 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -62,6 +62,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
])
# reduce on dim 0, keep_dim = true
# reduce on dim 0, keep_dim = true
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
...
@@ -76,6 +78,8 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -76,6 +78,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
])
# reduce on dim 1, keep_dim = false
# reduce on dim 1, keep_dim = false
# [0, -1] --> [0, -1], [0], partial_on_dim:[]
# [0, -1] --> [0, -1], [0], partial_on_dim:[]
...
@@ -90,6 +94,7 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -90,6 +94,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# reduce on dim 1, keep_dim = true
# reduce on dim 1, keep_dim = true
# [0, -1] --> [0, -1], [0, -1], partial_on_dim:[]
# [0, -1] --> [0, -1], [0, -1], partial_on_dim:[]
...
@@ -104,6 +109,7 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -104,6 +109,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# reduce on dim 0 and 1, keep_dim = false
# reduce on dim 0 and 1, keep_dim = false
# [0, -1] --> [0, -1], [], partial_on_dim:[0]
# [0, -1] --> [0, -1], [], partial_on_dim:[0]
...
@@ -118,6 +124,8 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -118,6 +124,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
])
# reduce on dim 0 and 1, keep_dim = true
# reduce on dim 0 and 1, keep_dim = true
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
...
@@ -132,6 +140,8 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -132,6 +140,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
])
def
test_multi_mesh_dim
(
self
):
def
test_multi_mesh_dim
(
self
):
process_mesh
=
auto
.
ProcessMesh
(
mesh
=
[[
0
,
1
,
2
],
[
3
,
4
,
5
]])
process_mesh
=
auto
.
ProcessMesh
(
mesh
=
[[
0
,
1
,
2
],
[
3
,
4
,
5
]])
...
@@ -170,6 +180,10 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -170,6 +180,10 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
0
,
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
0
,
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
0
,
1
])
infered_output_dist_attrs
[
0
].
_clean_partial_status
()
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# reduction on dim 1, 2, keep_dim = false
# reduction on dim 1, 2, keep_dim = false
# [1, -1, -1] --> [1, -1, -1], [1], partial_on_dim:[]
# [1, -1, -1] --> [1, -1, -1], [1], partial_on_dim:[]
self
.
attrs
[
'keep_dim'
]
=
False
self
.
attrs
[
'keep_dim'
]
=
False
...
@@ -183,6 +197,7 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -183,6 +197,7 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# reduction on dim 1, 2, keep_dim = false
# reduction on dim 1, 2, keep_dim = false
# [0, 1, -1] --> [0, 1, -1], [0], partial_on_dim:[1]
# [0, 1, -1] --> [0, 1, -1], [0], partial_on_dim:[1]
...
@@ -197,6 +212,10 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -197,6 +212,10 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
1
])
infered_output_dist_attrs
[
0
].
_clean_partial_status
()
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
False
)
# reduction on dim 1, 2, keep_dim = true
# reduction on dim 1, 2, keep_dim = true
# [0, 1, -1] --> [0, 1, -1], [0, -1, -1], partial_on_dim:[1]
# [0, 1, -1] --> [0, 1, -1], [0, -1, -1], partial_on_dim:[1]
...
@@ -211,6 +230,8 @@ class TestReductionSPMDRule(unittest.TestCase):
...
@@ -211,6 +230,8 @@ class TestReductionSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
1
,
-
1
])
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_is_partial
(),
True
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
_partial_dims
(),
[
1
])
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
test/cpp/auto_parallel/spmd_rule_test.cc
浏览文件 @
e3b6e02f
...
@@ -70,6 +70,7 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -70,6 +70,7 @@ TEST(MatmulSPMDRule, Ctor) {
std
::
vector
<
int64_t
>
({
-
1
,
-
1
}));
std
::
vector
<
int64_t
>
({
-
1
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
std
::
vector
<
int64_t
>
({
1
,
-
1
}));
std
::
vector
<
int64_t
>
({
1
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
false
);
VLOG
(
4
)
<<
"test1 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test1 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = nm[-1,0] partial[]
// mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = nm[-1,0] partial[]
...
@@ -83,6 +84,7 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -83,6 +84,7 @@ TEST(MatmulSPMDRule, Ctor) {
std
::
vector
<
int64_t
>
({
-
1
,
0
}));
std
::
vector
<
int64_t
>
({
-
1
,
0
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
std
::
vector
<
int64_t
>
({
-
1
,
0
}));
std
::
vector
<
int64_t
>
({
-
1
,
0
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
false
);
VLOG
(
4
)
<<
"test2 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test2 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]: done
// mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]: done
...
@@ -96,6 +98,9 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -96,6 +98,9 @@ TEST(MatmulSPMDRule, Ctor) {
std
::
vector
<
int64_t
>
({
0
,
-
1
}));
std
::
vector
<
int64_t
>
({
0
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
std
::
vector
<
int64_t
>
({
1
,
-
1
}));
std
::
vector
<
int64_t
>
({
1
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
true
);
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
partial_dims
(),
std
::
set
<
int64_t
>
({
0
}));
VLOG
(
4
)
<<
"test3 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test3 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: done
// mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: done
...
@@ -109,6 +114,9 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -109,6 +114,9 @@ TEST(MatmulSPMDRule, Ctor) {
std
::
vector
<
int64_t
>
({
1
,
0
}));
std
::
vector
<
int64_t
>
({
1
,
0
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
std
::
vector
<
int64_t
>
({
-
1
,
0
}));
std
::
vector
<
int64_t
>
({
-
1
,
0
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
true
);
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
partial_dims
(),
std
::
set
<
int64_t
>
({
1
}));
VLOG
(
4
)
<<
"test4 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test4 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] =
// abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] =
...
@@ -124,6 +132,7 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -124,6 +132,7 @@ TEST(MatmulSPMDRule, Ctor) {
std
::
vector
<
int64_t
>
({
-
1
,
-
1
}));
std
::
vector
<
int64_t
>
({
-
1
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
std
::
vector
<
int64_t
>
({
0
,
1
,
-
1
,
-
1
}));
std
::
vector
<
int64_t
>
({
0
,
1
,
-
1
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
false
);
VLOG
(
4
)
<<
"test5 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test5 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,
// abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,
...
@@ -138,6 +147,9 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -138,6 +147,9 @@ TEST(MatmulSPMDRule, Ctor) {
std
::
vector
<
int64_t
>
({
0
,
-
1
}));
std
::
vector
<
int64_t
>
({
0
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
std
::
vector
<
int64_t
>
({
1
,
-
1
,
-
1
,
-
1
}));
std
::
vector
<
int64_t
>
({
1
,
-
1
,
-
1
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
true
);
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
partial_dims
(),
std
::
set
<
int64_t
>
({
0
}));
VLOG
(
4
)
<<
"test6 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test6 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] =
// abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] =
...
@@ -153,6 +165,7 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -153,6 +165,7 @@ TEST(MatmulSPMDRule, Ctor) {
std
::
vector
<
int64_t
>
({
-
1
,
-
1
}));
std
::
vector
<
int64_t
>
({
-
1
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
std
::
vector
<
int64_t
>
({
1
,
-
1
,
0
,
-
1
}));
std
::
vector
<
int64_t
>
({
1
,
-
1
,
0
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
false
);
VLOG
(
4
)
<<
"test7 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test7 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
...
@@ -169,6 +182,11 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -169,6 +182,11 @@ TEST(MatmulSPMDRule, Ctor) {
std
::
vector
<
int64_t
>
({
1
,
0
}));
std
::
vector
<
int64_t
>
({
1
,
0
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
std
::
vector
<
int64_t
>
({
-
1
,
-
1
,
-
1
,
1
}));
std
::
vector
<
int64_t
>
({
-
1
,
-
1
,
-
1
,
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
true
);
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
partial_dims
(),
std
::
set
<
int64_t
>
({
0
}));
infered_dist_attrs
.
second
[
0
].
clean_partial_dims
(
std
::
vector
<
int64_t
>
({
0
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
false
);
VLOG
(
4
)
<<
"test8 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test8 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
...
@@ -185,6 +203,13 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -185,6 +203,13 @@ TEST(MatmulSPMDRule, Ctor) {
std
::
vector
<
int64_t
>
({
-
1
,
0
}));
std
::
vector
<
int64_t
>
({
-
1
,
0
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
dims_mapping
(),
std
::
vector
<
int64_t
>
({
-
1
,
-
1
,
1
,
-
1
}));
std
::
vector
<
int64_t
>
({
-
1
,
-
1
,
1
,
-
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
partial_dims
(),
std
::
set
<
int64_t
>
({
0
}));
VLOG
(
4
)
<<
infered_dist_attrs
.
second
[
0
].
to_string
();
infered_dist_attrs
.
second
[
0
].
clean_partial_status
();
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
false
);
infered_dist_attrs
.
second
[
0
].
set_partial_status
(
std
::
vector
<
int64_t
>
({
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
verify_partial_status
(),
false
);
VLOG
(
4
)
<<
"test9 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test9 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
// abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
...
@@ -197,6 +222,28 @@ TEST(MatmulSPMDRule, Ctor) {
...
@@ -197,6 +222,28 @@ TEST(MatmulSPMDRule, Ctor) {
{
x_dist_tensor_spec
,
y_dist_tensor_spec
},
attrs
));
{
x_dist_tensor_spec
,
y_dist_tensor_spec
},
attrs
));
// Error
// Error
VLOG
(
4
)
<<
"test10 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
VLOG
(
4
)
<<
"test10 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
// abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
// abcmn[-1, -1, -1, 1] partial[0]:
x_dist_tensor_spec
.
set_dims_mapping
({
-
1
,
-
1
,
0
,
1
});
y_dist_tensor_spec
.
set_dims_mapping
({
1
,
0
});
attrs
[
"trans_y"
]
=
true
;
attrs
[
"trans_x"
]
=
true
;
infered_dist_attrs
=
matmul_rule
->
InferForward
(
{
x_dist_tensor_spec
,
y_dist_tensor_spec
},
attrs
);
EXPECT_ANY_THROW
(
infered_dist_attrs
.
second
[
0
].
clean_partial_dims
(
std
::
vector
<
int64_t
>
({
1
})));
infered_dist_attrs
.
second
[
0
].
set_partial_status
(
std
::
vector
<
int64_t
>
({
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
true
);
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
partial_dims
(),
std
::
set
<
int64_t
>
({
0
,
1
}));
infered_dist_attrs
.
second
[
0
].
clean_partial_dims
(
std
::
vector
<
int64_t
>
({
1
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
partial_dims
(),
std
::
set
<
int64_t
>
({
0
}));
infered_dist_attrs
.
second
[
0
].
clean_partial_dims
(
std
::
vector
<
int64_t
>
({
0
}));
EXPECT_EQ
(
infered_dist_attrs
.
second
[
0
].
is_partial
(),
false
);
VLOG
(
4
)
<<
"test11 done."
<<
std
::
endl
<<
std
::
endl
<<
std
::
endl
;
}
}
TEST
(
LayerNormSPMDRule
,
Ctor
)
{
TEST
(
LayerNormSPMDRule
,
Ctor
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录