Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4cd1a2cb
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4cd1a2cb
编写于
5月 24, 2023
作者:
L
liangjianzhong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
revise syntax
上级
1dcb80ea
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
92 addition
and
70 deletion
+92
-70
paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
+9
-9
paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
+33
-11
paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
.../distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
+49
-49
paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h
...d/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h
+1
-1
未找到文件。
paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
浏览文件 @
4cd1a2cb
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
distributed
{
namespace
distributed
{
namespace
auto_parallel
{
namespace
auto_parallel
{
std
::
vector
<
DistTensorSpec
>
SPMDRuleBase
::
InferForward
(
std
::
vector
<
TensorDistAttr
>
SPMDRuleBase
::
InferForward
(
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
PADDLE_THROW
(
PADDLE_THROW
(
...
@@ -26,7 +26,7 @@ std::vector<DistTensorSpec> SPMDRuleBase::InferForward(
...
@@ -26,7 +26,7 @@ std::vector<DistTensorSpec> SPMDRuleBase::InferForward(
"derived class of SPMDRuleBase !"
));
"derived class of SPMDRuleBase !"
));
}
}
std
::
vector
<
DistTensorSpec
>
SPMDRuleBase
::
InferBackward
(
std
::
vector
<
TensorDistAttr
>
SPMDRuleBase
::
InferBackward
(
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
PADDLE_THROW
(
PADDLE_THROW
(
...
@@ -36,12 +36,12 @@ std::vector<DistTensorSpec> SPMDRuleBase::InferBackward(
...
@@ -36,12 +36,12 @@ std::vector<DistTensorSpec> SPMDRuleBase::InferBackward(
std
::
unordered_map
<
std
::
string
,
int64_t
>
ShardingMergeForTensors
(
std
::
unordered_map
<
std
::
string
,
int64_t
>
ShardingMergeForTensors
(
const
std
::
vector
<
std
::
pair
<
const
std
::
string
,
const
std
::
vector
<
int64_t
>>>&
const
std
::
vector
<
std
::
pair
<
const
std
::
string
,
const
std
::
vector
<
int64_t
>>>&
tensor_
notation
_to_dim_pairs
)
{
tensor_
axes
_to_dim_pairs
)
{
std
::
unordered_map
<
std
::
string
,
int64_t
>
axis_to_dim_map
;
std
::
unordered_map
<
std
::
string
,
int64_t
>
axis_to_dim_map
;
std
::
unordered_map
<
int64_t
,
std
::
string
>
dim_to_axis_map
;
std
::
unordered_map
<
int64_t
,
std
::
string
>
dim_to_axis_map
;
int64_t
merge_dim
;
int64_t
merge_dim
;
for
(
auto
&
pair
:
tensor_
notation
_to_dim_pairs
)
{
for
(
auto
&
pair
:
tensor_
axes
_to_dim_pairs
)
{
for
(
int
i
=
0
;
i
<
pair
.
second
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
pair
.
second
.
size
();
i
++
)
{
auto
tensor_axis
=
pair
.
first
.
substr
(
i
,
1
);
auto
tensor_axis
=
pair
.
first
.
substr
(
i
,
1
);
auto
mesh_dim
=
pair
.
second
[
i
];
auto
mesh_dim
=
pair
.
second
[
i
];
...
@@ -84,9 +84,9 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
...
@@ -84,9 +84,9 @@ std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// (TODO trigger heuristics cost model and reshard to handle axis sharded by
// (TODO trigger heuristics cost model and reshard to handle axis sharded by
// multiple dimension case.)
// multiple dimension case.)
int64_t
ShardingMergeForAxis
(
const
std
::
string
axis
,
int64_t
ShardingMergeForAxis
(
const
std
::
string
&
axis
,
const
int64_t
mesh_dim1
,
const
int64_t
&
mesh_dim1
,
const
int64_t
mesh_dim2
)
{
const
int64_t
&
mesh_dim2
)
{
if
(
mesh_dim1
!=
mesh_dim2
)
{
if
(
mesh_dim1
!=
mesh_dim2
)
{
if
(
mesh_dim1
==
-
1
)
{
if
(
mesh_dim1
==
-
1
)
{
return
mesh_dim2
;
return
mesh_dim2
;
...
@@ -118,8 +118,8 @@ TensorDistAttr CopyTensorDistAttrForOutput(
...
@@ -118,8 +118,8 @@ TensorDistAttr CopyTensorDistAttrForOutput(
}
}
std
::
vector
<
int64_t
>
ResoluteOutputPartialDimension
(
std
::
vector
<
int64_t
>
ResoluteOutputPartialDimension
(
const
std
::
unordered_map
<
std
::
string
,
int64_t
>&
in_
axis_to_dim_map
,
const
std
::
unordered_map
<
std
::
string
,
int64_t
>&
axis_to_dim_map
,
const
std
::
string
&
out_axi
s
)
{
const
std
::
string
&
tensor_axe
s
)
{
std
::
vector
<
int64_t
>
partial_on_dims
;
std
::
vector
<
int64_t
>
partial_on_dims
;
for
(
auto
&
it
:
in_axis_to_dim_map
)
{
for
(
auto
&
it
:
in_axis_to_dim_map
)
{
...
...
paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
浏览文件 @
4cd1a2cb
...
@@ -19,7 +19,9 @@ limitations under the License. */
...
@@ -19,7 +19,9 @@ limitations under the License. */
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
...
@@ -29,11 +31,26 @@ class SPMDRuleBase {
...
@@ -29,11 +31,26 @@ class SPMDRuleBase {
public:
public:
virtual
~
SPMDRuleBase
()
{}
virtual
~
SPMDRuleBase
()
{}
virtual
std
::
vector
<
DistTensorSpec
>
InferForward
(
// Merge the DistAttr of input tensors and infer the DistAttr of the output
// tensors from the merged input information. The inputs are the DistAttr and
// Shape (wrapped as DistTensorSpec) of the input tensors (tensors follow the
// same order defined in the Op's Phi API) and the Op Attributes of the current
// op. The outputs are the merged DistAttr of the input tensors and the
// inferred DistAttr of the output tensors. The merged DistAttr might differ
// from the original input DistAttrs, which means that the corresponding input
// tensor needs to be resharded.
virtual
std
::
vector
<
TensorDistAttr
>
InferForward
(
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
);
const
paddle
::
framework
::
AttributeMap
&
attrs
);
virtual
std
::
vector
<
DistTensorSpec
>
InferBackward
(
// Merge the DistAttr of output tensors and infer the DistAttr of the input
// tensors from the merged output information. The inputs are the DistAttr and
// Shape (wrapped as DistTensorSpec) of the output tensors and the Op
// Attributes of the current op. The outputs are the merged DistAttr of the
// output tensors and the inferred DistAttr of the input tensors. This function
// will be used in Static Graph mode only, where we have the whole computation
// graph for sharding propagation.
virtual
std
::
vector
<
TensorDistAttr
>
InferBackward
(
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
);
const
paddle
::
framework
::
AttributeMap
&
attrs
);
...
@@ -44,9 +61,8 @@ class SPMDRuleBase {
...
@@ -44,9 +61,8 @@ class SPMDRuleBase {
return
PADDLE_GET_CONST
(
T
,
GetAttr
(
name
,
attrs
));
return
PADDLE_GET_CONST
(
T
,
GetAttr
(
name
,
attrs
));
}
}
virtual
const
Attribute
&
GetAttr
(
const
Attribute
&
GetAttr
(
const
std
::
string
&
name
,
const
std
::
string
&
name
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
const
{
const
paddle
::
framework
::
AttributeMap
&
attrs
)
const
{
auto
iter
=
attrs
.
find
(
name
);
auto
iter
=
attrs
.
find
(
name
);
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
iter
,
iter
,
...
@@ -56,23 +72,29 @@ class SPMDRuleBase {
...
@@ -56,23 +72,29 @@ class SPMDRuleBase {
}
}
};
};
// Merge sharding specification (dims mapping) of given tensors.
// The same axes of different tensors will be merged.
std
::
unordered_map
<
std
::
string
,
int64_t
>
ShardingMergeForTensors
(
std
::
unordered_map
<
std
::
string
,
int64_t
>
ShardingMergeForTensors
(
const
std
::
vector
<
std
::
pair
<
const
std
::
string
,
const
std
::
vector
<
int64_t
>>>&
const
std
::
vector
<
std
::
pair
<
const
std
::
string
,
const
std
::
vector
<
int64_t
>>>&
tensor_
notation
_to_dim_pairs
);
tensor_
axes
_to_dim_pairs
);
// Merge the sharding specification (dims mapping) for one tensor Axis.
// Rule1: A replicated dimension could be merged by any sharded dimension.
// Rule1: A replicated dimension could be merged by any sharded dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// (TODO trigger heuristics cost model and reshard to handle axis sharded by
// (TODO trigger heuristics cost model and reshard to handle axis sharded by
// multiple dimension case.)
// multiple dimension case.)
int64_t
ShardingMergeForAxis
(
const
std
::
string
axis
,
int64_t
ShardingMergeForAxis
(
const
std
::
string
&
axis
,
const
int64_t
mesh_dim1
,
const
int64_t
&
mesh_dim1
,
const
int64_t
mesh_dim2
);
const
int64_t
&
mesh_dim2
);
TensorDistAttr
CopyTensorDistAttrForOutput
(
const
TensorDistAttr
&
src_dist_attr
);
TensorDistAttr
CopyTensorDistAttrForOutput
(
const
TensorDistAttr
&
src_dist_attr
);
// Resolve the partial mesh dimensions of an output tensor, given the
// merged sharding specification of the input tensors and the axis names of
// the output tensor. Inputs are the merged axis-to-dim map and the axes string.
std
::
vector
<
int64_t
>
ResoluteOutputPartialDimension
(
std
::
vector
<
int64_t
>
ResoluteOutputPartialDimension
(
const
std
::
unordered_map
<
std
::
string
,
int64_t
>&
in_
axis_to_dim_map
,
const
std
::
unordered_map
<
std
::
string
,
int64_t
>&
axis_to_dim_map
,
const
std
::
string
&
out_axi
s
);
const
std
::
string
&
tensor_axe
s
);
}
// namespace auto_parallel
}
// namespace auto_parallel
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc
浏览文件 @
4cd1a2cb
...
@@ -18,10 +18,18 @@ namespace paddle {
...
@@ -18,10 +18,18 @@ namespace paddle {
namespace
distributed
{
namespace
distributed
{
namespace
auto_parallel
{
namespace
auto_parallel
{
std
::
vector
<
DistTensorSpec
>
MatmulSPMDRule
::
InferForward
(
std
::
vector
<
TensorDistAttr
>
MatmulSPMDRule
::
InferForward
(
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
// step0: verify input args based on matmul logic
// step0: verify input args based on matmul logic
auto
input_specs_size
=
input_specs
.
size
();
PADDLE_ENFORCE_EQ
(
input_specs_size
,
2
,
phi
::
errors
::
InvalidArgument
(
"The size of InputSpec of matmul should be 2, but got [%d]."
,
input_specs_size
));
int
x_ndim
=
input_specs
[
0
].
shape
.
size
();
int
x_ndim
=
input_specs
[
0
].
shape
.
size
();
int
y_ndim
=
input_specs
[
1
].
shape
.
size
();
int
y_ndim
=
input_specs
[
1
].
shape
.
size
();
std
::
vector
<
int64_t
>
x_dims_mapping
=
input_specs
[
0
].
DistAttr
.
dims_mapping
;
std
::
vector
<
int64_t
>
x_dims_mapping
=
input_specs
[
0
].
DistAttr
.
dims_mapping
;
...
@@ -42,54 +50,44 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
...
@@ -42,54 +50,44 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
bool
trans_x
=
ExtractAttr
<
bool
>
(
"trans_x"
);
bool
trans_x
=
ExtractAttr
<
bool
>
(
"trans_x"
);
bool
trans_y
=
ExtractAttr
<
bool
>
(
"trans_y"
);
bool
trans_y
=
ExtractAttr
<
bool
>
(
"trans_y"
);
auto
input_specs_size
=
input_specs
.
size
()
PADDLE_ENFORCE_EQ
(
// step1: build Einsum Notation
input_specs_size
,
2
,
phi
::
errors
::
InvalidArgument
(
"The size of InputSpec of matmul should be 2, but got [%d]."
,
input_specs_size
));
// step1: Einsum Notation
int
max_ndim
=
std
::
max
(
x_ndim
,
y_ndim
);
// reserve the char k, m, n for matrix product notation: mk,kn -> mn
// reserve the char k, m, n for matrix product notation: mk,kn -> mn
int
max_ndim
=
std
::
max
(
x_ndim
,
y_ndim
);
std
::
string
alphabet
=
"abcdefghijlopqrstuvwxyz"
;
std
::
string
alphabet
=
"abcdefghijlopqrstuvwxyz"
;
std
::
string
x_
string
;
std
::
string
x_
axes
;
std
::
string
y_
string
;
std
::
string
y_
axes
;
std
::
string
out_
string
;
std
::
string
out_
axes
;
// Handle 4 different matmul cases in Paddle
// vector * vector = scalar
// vector * vector = scalar
if
(
x_ndim
==
1
&&
y_ndim
==
1
)
{
if
(
x_ndim
==
1
&&
y_ndim
==
1
)
{
x_
string
=
"k"
;
x_
axes
=
"k"
;
y_
string
=
"k"
;
y_
axes
=
"k"
;
out_
string
=
""
;
out_
axes
=
""
;
// vector * batched matrix
// vector * batched matrix
}
else
if
(
x_ndim
==
1
&&
y_ndim
>
1
)
{
}
else
if
(
x_ndim
==
1
&&
y_ndim
>
1
)
{
x_string
=
"k"
;
x_axes
=
"k"
;
std
::
string
y_broadcast_string
=
std
::
string
y_broadcast_axes
=
GetBroadcastAxes
(
y_ndim
,
max_ndim
,
alphabet
);
GetBroadcastNotationString
(
y_ndim
,
max_ndim
,
alphabet
);
y_axes
=
y_broadcast_axes
+
"kn"
;
y_string
=
y_broadcast_string
+
"kn"
;
out_axes
=
y_broadcast_axes
+
"n"
;
out_string
=
y_broadcast_string
+
"n"
;
// batched matrix * vector
// batched matrix * vector
}
else
if
(
x_ndim
>
1
&&
y_ndim
==
1
)
{
}
else
if
(
x_ndim
>
1
&&
y_ndim
==
1
)
{
y_string
=
"k"
;
y_axes
=
"k"
;
std
::
string
x_broadcast_string
=
std
::
string
x_broadcast_axes
=
GetBroadcastAxes
(
x_ndim
,
max_ndim
,
alphabet
);
GetBroadcastNotationString
(
x_ndim
,
max_ndim
,
alphabet
);
x_axes
=
x_broadcast_axes
+
"mk"
;
x_string
=
x_broadcast_string
+
"mk"
;
out_axes
=
x_broadcast_axes
+
"m"
;
out_string
=
x_broadcast_string
+
"m"
;
// batched matrix * batched matrix
// batched matrix * batched matrix
}
else
if
(
x_ndim
>
1
&&
y_ndim
>
1
)
{
}
else
if
(
x_ndim
>
1
&&
y_ndim
>
1
)
{
std
::
string
x_broadcast_string
=
std
::
string
x_broadcast_axes
=
GetBroadcastAxes
(
x_ndim
,
max_ndim
,
alphabet
);
GetBroadcastNotationString
(
x_ndim
,
max_ndim
,
alphabet
);
std
::
string
y_broadcast_axes
=
GetBroadcastAxes
(
y_ndim
,
max_ndim
,
alphabet
);
std
::
string
y_broadcast_string
=
x_axes
=
x_broadcast_axes
+
"mk"
;
GetBroadcastNotationString
(
y_ndim
,
max_ndim
,
alphabet
);
y_axes
=
y_broadcast_axes
+
"kn"
;
x_string
=
x_broadcast_string
+
"mk"
;
y_string
=
y_broadcast_string
+
"kn"
;
if
(
x_ndim
>
y_ndim
)
{
if
(
x_ndim
>
y_ndim
)
{
out_
string
=
x_broadcast_string
+
"mn"
;
out_
axes
=
x_broadcast_axes
+
"mn"
;
}
else
{
}
else
{
out_
string
=
y_broadcast_string
+
"mn"
;
out_
axes
=
y_broadcast_axes
+
"mn"
;
}
}
}
else
{
}
else
{
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
...
@@ -98,8 +96,8 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
...
@@ -98,8 +96,8 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
y_ndim
));
y_ndim
));
}
}
VLOG
(
4
)
<<
"MatmulSPMDRule build Einsum notation: ["
<<
x_
string
<<
","
VLOG
(
4
)
<<
"MatmulSPMDRule build Einsum notation: ["
<<
x_
axes
<<
","
<<
y_
string
<<
" --> "
<<
out_string
<<
"]."
;
<<
y_
axes
<<
" --> "
<<
out_axes
<<
"]."
;
// step2: Sharding Propagation
// step2: Sharding Propagation
if
(
trans_x
)
{
if
(
trans_x
)
{
...
@@ -121,34 +119,34 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
...
@@ -121,34 +119,34 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
std
::
iter_swap
(
y_dims_mapping
.
end
()
-
2
,
y_dims_mapping
.
end
()
-
1
);
std
::
iter_swap
(
y_dims_mapping
.
end
()
-
2
,
y_dims_mapping
.
end
()
-
1
);
}
}
// step2.1: Sharding Merge
// step2.1: Sharding Merge
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>
x_pair
(
x_
string
,
x_dims_mapping
);
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>
x_pair
(
x_
axes
,
x_dims_mapping
);
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>
y_pair
(
y_
string
,
y_dims_mapping
);
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>
y_pair
(
y_
axes
,
y_dims_mapping
);
std
::
vector
<
std
::
pair
<
const
std
::
string
,
const
std
::
vector
<
int64_t
>>>
std
::
vector
<
std
::
pair
<
const
std
::
string
,
const
std
::
vector
<
int64_t
>>>
input_pairs
;
input_pairs
;
input_pairs
.
push_back
(
x_pair
);
input_pairs
.
push_back
(
x_pair
);
input_pairs
.
push_back
(
y_pair
);
input_pairs
.
push_back
(
y_pair
);
auto
axis_to_dim_map
=
ShardingMergeForTensors
(
input_pairs
);
auto
axis_to_dim_map
=
ShardingMergeForTensors
(
input_pairs
);
// step2.2:
fill output's dim m
apping.
// step2.2:
Infer Output's Dims M
apping.
TensorDistAttr
output_dist_attr_dst
=
TensorDistAttr
output_dist_attr_dst
=
CopyTensorDistAttrForOutput
(
input_specs
[
0
].
DistAttr
)
std
::
vector
<
int64_t
>
CopyTensorDistAttrForOutput
(
input_specs
[
0
].
DistAttr
)
;
out_dims_mapping
;
std
::
vector
<
int64_t
>
out_dims_mapping
;
out_dims_mapping
.
reserve
(
out_
string
.
size
());
out_dims_mapping
.
reserve
(
out_
axes
.
size
());
for
(
int
i
=
0
;
i
<
out_
string
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
out_
axes
.
size
();
++
i
)
{
out_dims_mapping
.
push_back
(
axis_to_dim_map
[
out_
string
.
substr
(
i
,
1
)]);
out_dims_mapping
.
push_back
(
axis_to_dim_map
[
out_
axes
.
substr
(
i
,
1
)]);
}
}
output_dist_attr_dst
.
set_dims_mapping
(
out_dims_mapping
);
output_dist_attr_dst
.
set_dims_mapping
(
out_dims_mapping
);
// step2.3:
fill input's dim m
apping.
// step2.3:
Merge and get Inputs' New Dims M
apping.
TensorDistAttr
x_dist_attr_dst
=
GetInferedDistAttr
(
TensorDistAttr
x_dist_attr_dst
=
GetInferedDistAttr
(
input_specs
[
0
].
DistAttr
,
input_specs
[
0
].
shape
,
x_
string
,
axis_to_dim_map
);
input_specs
[
0
].
DistAttr
,
input_specs
[
0
].
shape
,
x_
axes
,
axis_to_dim_map
);
TensorDistAttr
y_dist_attr_dst
=
GetInferedDistAttr
(
TensorDistAttr
y_dist_attr_dst
=
GetInferedDistAttr
(
input_specs
[
1
].
DistAttr
,
input_specs
[
1
].
shape
,
y_
string
,
axis_to_dim_map
);
input_specs
[
1
].
DistAttr
,
input_specs
[
1
].
shape
,
y_
axes
,
axis_to_dim_map
);
// step2.3: Handle Partial
// step2.3: Handle Partial
// Step2.3.1 Output Partial
// Step2.3.1 Output Partial
std
::
vector
<
int64_t
>
partial_on_dims
=
std
::
vector
<
int64_t
>
partial_on_dims
=
ResoluteOutputPartialDimension
(
axis_to_dim_map
,
out_
string
);
ResoluteOutputPartialDimension
(
axis_to_dim_map
,
out_
axes
);
// Step2.3.2 handle input tensor partial (TODO)
// Step2.3.2 handle input tensor partial (TODO)
...
@@ -161,6 +159,8 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
...
@@ -161,6 +159,8 @@ std::vector<DistTensorSpec> MatmulSPMDRule::InferForward(
<<
", dst_dims_mapping: "
<<
y_dist_attr_dst
.
dims_mapping
<<
", dst_dims_mapping: "
<<
y_dist_attr_dst
.
dims_mapping
<<
"; Output dims_mapping: "
<<
out_dims_mapping
<<
"; Output dims_mapping: "
<<
out_dims_mapping
<<
", partial_on_dims: "
<<
partial_on_dims
;
<<
", partial_on_dims: "
<<
partial_on_dims
;
return
{
x_dist_attr_dst
,
y_dist_attr_dst
,
output_dist_attr_dst
}
}
}
TensorDistAttr
GetInferedDistAttr
(
TensorDistAttr
GetInferedDistAttr
(
...
@@ -184,7 +184,7 @@ TensorDistAttr GetInferedDistAttr(
...
@@ -184,7 +184,7 @@ TensorDistAttr GetInferedDistAttr(
return
dist_attr_
;
return
dist_attr_
;
}
}
std
::
vector
<
DistTensorSpec
>
MatmulSPMDRule
::
InferBackward
(
std
::
vector
<
TensorDistAttr
>
MatmulSPMDRule
::
InferBackward
(
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{}
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{}
...
...
paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h
浏览文件 @
4cd1a2cb
...
@@ -28,7 +28,7 @@ namespace auto_parallel {
...
@@ -28,7 +28,7 @@ namespace auto_parallel {
TensorDistAttr
GetInferedDistAttr
(
TensorDistAttr
GetInferedDistAttr
(
const
TensorDistAttr
&
origin_dist_attr
,
const
TensorDistAttr
&
origin_dist_attr
,
const
std
::
vector
<
int64_t
>&
shape
,
const
std
::
vector
<
int64_t
>&
shape
,
const
std
::
string
&
tensor_ax
i
s
,
const
std
::
string
&
tensor_ax
e
s
,
const
std
::
unordered_map
<
std
::
string
,
int64_t
>&
axis_to_dim_map
);
const
std
::
unordered_map
<
std
::
string
,
int64_t
>&
axis_to_dim_map
);
class
MatmulSPMDRule
:
public
SPMDRuleBase
{
class
MatmulSPMDRule
:
public
SPMDRuleBase
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录