Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8e9de875
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8e9de875
编写于
9月 08, 2023
作者:
Y
Yichen Zhang
提交者:
GitHub
9月 08, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add reshape backward rule (#56443)
上级
f2968742
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
224 addition
and
5 deletion
+224
-5
paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc
...distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc
+56
-3
paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h
.../distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h
+2
-1
test/auto_parallel/spmd_rules/test_reshape_rule.py
test/auto_parallel/spmd_rules/test_reshape_rule.py
+166
-1
未找到文件。
paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc
浏览文件 @
8e9de875
...
@@ -135,6 +135,7 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
...
@@ -135,6 +135,7 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
return
ret
;
return
ret
;
}
}
//
std
::
pair
<
std
::
vector
<
TensorDistAttr
>
,
std
::
vector
<
TensorDistAttr
>>
std
::
pair
<
std
::
vector
<
TensorDistAttr
>
,
std
::
vector
<
TensorDistAttr
>>
paddle
::
distributed
::
auto_parallel
::
ReshapeSPMDRule
::
InferForward
(
paddle
::
distributed
::
auto_parallel
::
ReshapeSPMDRule
::
InferForward
(
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
...
@@ -195,12 +196,64 @@ paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
...
@@ -195,12 +196,64 @@ paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
std
::
pair
<
std
::
vector
<
TensorDistAttr
>
,
std
::
vector
<
TensorDistAttr
>>
std
::
pair
<
std
::
vector
<
TensorDistAttr
>
,
std
::
vector
<
TensorDistAttr
>>
paddle
::
distributed
::
auto_parallel
::
ReshapeSPMDRule
::
InferBackward
(
paddle
::
distributed
::
auto_parallel
::
ReshapeSPMDRule
::
InferBackward
(
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
// step0: Verify Input Args Based on Reshape Logic
"InferBackward of ReductionSPMDRule is NOT implemented yet."
));
int64_t
ninputs
=
input_specs
.
size
();
int64_t
noutputs
=
output_specs
.
size
();
PADDLE_ENFORCE_EQ
(
ninputs
,
1
,
phi
::
errors
::
InvalidArgument
(
"The size of InputSpec in reshape must "
"be equal to 1, but got [%d]."
,
ninputs
));
PADDLE_ENFORCE_EQ
(
noutputs
,
1
,
phi
::
errors
::
InvalidArgument
(
"The size of OutputSpec in reshape must "
"be equal to 1, but got [%d]."
,
noutputs
));
VerifySpecs
(
output_specs
,
"reshape"
);
// step1: build the transformation from the output shape
// to original shape. Inferbackward infers the dims mapping
// from output to input, we first get the transformation
// from output to input so that we can infer the dims mapping
// with the map from output axes to input axes.
// Shapes in Inferbackward don't contain -1 or 0, so they will
// not be modified and we can use ref here.
const
std
::
vector
<
int64_t
>&
output_shape
=
output_specs
[
0
].
shape
();
const
std
::
vector
<
int64_t
>&
input_shape
=
input_specs
[
0
].
shape
();
std
::
vector
<
DimTrans
*>
trans
=
MakeReshapeDimTrans
(
output_shape
,
input_shape
);
// step2: infer the dims mapping of input with
// output's dims_mapping and the transformation.
std
::
vector
<
std
::
vector
<
int64_t
>>
dims_mapping_vec
=
InferFromDimTrans
(
output_specs
[
0
],
trans
);
// step3: update the dist attributes of input
// and output with the inferred dims mapping
TensorDistAttr
new_output_dist_attr
(
output_specs
[
0
].
dist_attr
());
new_output_dist_attr
.
set_dims_mapping
(
dims_mapping_vec
[
0
]);
TensorDistAttr
input_dist_attr
(
input_specs
[
0
].
dist_attr
());
input_dist_attr
.
set_dims_mapping
(
dims_mapping_vec
[
1
]);
VLOG
(
4
)
<<
"Reshape Inferbackward: output_shape: ["
<<
str_join
(
output_shape
)
<<
"] input_shape: ["
<<
str_join
(
input_shape
)
<<
"]"
;
VLOG
(
4
)
<<
"Transformation from output to input:"
;
for
(
int64_t
i
=
0
,
n
=
trans
.
size
();
i
<
n
;
i
++
)
{
DimTrans
*
t
=
trans
[
i
];
VLOG
(
4
)
<<
"
\t
Input axis "
<<
i
<<
": "
<<
t
->
to_string
();
}
VLOG
(
4
)
<<
"input_dims_mapping: ["
<<
str_join
(
dims_mapping_vec
[
1
])
<<
"] output_dims_mapping: ["
<<
str_join
(
dims_mapping_vec
[
0
])
<<
"]
\n\n
"
;
CleanUp
();
return
{};
return
{
{
input_dist_attr
},
{
new_output_dist_attr
}
};
}
}
}
// namespace auto_parallel
}
// namespace auto_parallel
...
...
paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h
浏览文件 @
8e9de875
...
@@ -32,7 +32,8 @@ class ReshapeSPMDRule : public SPMDRuleBase {
...
@@ -32,7 +32,8 @@ class ReshapeSPMDRule : public SPMDRuleBase {
const
paddle
::
framework
::
AttributeMap
&
attrs
)
override
;
const
paddle
::
framework
::
AttributeMap
&
attrs
)
override
;
std
::
pair
<
std
::
vector
<
TensorDistAttr
>
,
std
::
vector
<
TensorDistAttr
>>
std
::
pair
<
std
::
vector
<
TensorDistAttr
>
,
std
::
vector
<
TensorDistAttr
>>
InferBackward
(
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
InferBackward
(
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
override
;
const
paddle
::
framework
::
AttributeMap
&
attrs
)
override
;
};
};
}
// namespace auto_parallel
}
// namespace auto_parallel
...
...
test/auto_parallel/spmd_rules/test_reshape_rule.py
浏览文件 @
8e9de875
...
@@ -30,7 +30,7 @@ class TestReshapeSPMDRule(unittest.TestCase):
...
@@ -30,7 +30,7 @@ class TestReshapeSPMDRule(unittest.TestCase):
process_mesh
=
auto
.
ProcessMesh
(
mesh
=
[[
0
,
1
,
2
],
[
3
,
4
,
5
]])
process_mesh
=
auto
.
ProcessMesh
(
mesh
=
[[
0
,
1
,
2
],
[
3
,
4
,
5
]])
x_tensor_dist_attr
=
TensorDistAttr
()
x_tensor_dist_attr
=
TensorDistAttr
()
x_tensor_dist_attr
.
dims_mapping
=
[
-
1
,
-
1
]
x_tensor_dist_attr
.
dims_mapping
=
[
-
1
,
-
1
,
-
1
,
-
1
]
x_tensor_dist_attr
.
process_mesh
=
process_mesh
x_tensor_dist_attr
.
process_mesh
=
process_mesh
self
.
x_dist_tensor_spec
=
DistTensorSpec
(
x_shape
,
x_tensor_dist_attr
)
self
.
x_dist_tensor_spec
=
DistTensorSpec
(
x_shape
,
x_tensor_dist_attr
)
...
@@ -248,6 +248,171 @@ class TestReshapeSPMDRule(unittest.TestCase):
...
@@ -248,6 +248,171 @@ class TestReshapeSPMDRule(unittest.TestCase):
with
self
.
assertRaises
(
BaseException
):
with
self
.
assertRaises
(
BaseException
):
self
.
rule
.
infer_forward
([
self
.
x_dist_tensor_spec
],
self
.
attrs
)
self
.
rule
.
infer_forward
([
self
.
x_dist_tensor_spec
],
self
.
attrs
)
def
test_reshape_infer_backward
(
self
):
process_mesh
=
auto
.
ProcessMesh
(
mesh
=
[[
0
,
1
,
2
],
[
3
,
4
,
5
]])
output_tensor_dist_attr
=
TensorDistAttr
()
output_tensor_dist_attr
.
dims_mapping
=
[
-
1
,
-
1
,
-
1
,
-
1
]
output_tensor_dist_attr
.
process_mesh
=
process_mesh
# shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6] (input --> output)
# dims_mapping: [-1, 0, 1, -1, -1] --> [0, -1, 1, -1], [-1, 0, 1, -1, -1] (output --> input, output)
self
.
output_dist_tensor_spec
=
DistTensorSpec
(
[
1
,
72
,
48
,
4
,
6
],
output_tensor_dist_attr
)
self
.
output_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
0
,
1
,
-
1
,
-
1
])
result_dist_attrs
=
self
.
rule
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
output_dist_tensor_spec
],
self
.
attrs
,
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
len
(
infered_input_dist_attrs
),
1
)
self
.
assertEqual
(
len
(
infered_output_dist_attrs
),
1
)
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
,
1
,
-
1
]
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
0
,
1
,
-
1
,
-
1
]
)
# shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6] (input --> output)
# dims_mapping: [-1, -1, -1, -1, -1] --> [-1, -1, -1, -1], [-1, -1, -1, -1, -1] (output --> input, output)
self
.
output_dist_tensor_spec
.
shape
=
[
1
,
72
,
48
,
4
,
6
]
self
.
output_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
-
1
,
-
1
,
-
1
,
-
1
])
result_dist_attrs
=
self
.
rule
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
output_dist_tensor_spec
],
self
.
attrs
,
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
-
1
,
-
1
]
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
]
)
# shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6] (input --> output)
# dims_mapping: [-1, 1, -1, 0, -1] --> [1, -1, -1, 0] [-1, 1, -1, 0, -1] (output --> input, output)
self
.
output_dist_tensor_spec
.
shape
=
[
1
,
72
,
48
,
4
,
6
]
self
.
output_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
1
,
-
1
,
0
,
-
1
])
result_dist_attrs
=
self
.
rule
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
output_dist_tensor_spec
],
self
.
attrs
,
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
,
0
]
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
1
,
-
1
,
0
,
-
1
]
)
# shape: [6, 12, 48, 24] --> [3, 24, 6, 8, 24] (input --> output)
# dims_mapping: [1, -1, -1, -1, 0] --> [1, -1, -1, 0], [1, -1, -1, -1, 0] (output --> input, output)
self
.
output_dist_tensor_spec
.
shape
=
[
3
,
24
,
6
,
8
,
24
]
self
.
output_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
,
-
1
,
-
1
,
0
])
result_dist_attrs
=
self
.
rule
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
output_dist_tensor_spec
],
self
.
attrs
,
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
,
0
]
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
,
-
1
,
0
]
)
# shape: [6, 12, 48, 24] --> [3, 24, 6, 8, 24] (input --> output)
# dims_mapping: [-1, -1, 0, -1, 1] --> [-1, -1, 0, 1], [-1, -1, 0, -1, 1] (output --> input, output)
self
.
output_dist_tensor_spec
.
shape
=
[
3
,
24
,
6
,
8
,
24
]
self
.
output_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
-
1
,
0
,
-
1
,
1
])
result_dist_attrs
=
self
.
rule
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
output_dist_tensor_spec
],
self
.
attrs
,
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
,
1
]
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
,
-
1
,
1
]
)
# shape: [6, 12, 48, 24] --> [6, 12, 48, 24] (intput --> output)
# dims_mapping: [-1, -1, 0, 1] --> [-1, -1, 0, 1], [-1, -1, 0, 1] (output --> input, output)
self
.
output_dist_tensor_spec
.
shape
=
[
6
,
12
,
48
,
24
]
self
.
output_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
-
1
,
0
,
1
])
result_dist_attrs
=
self
.
rule
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
output_dist_tensor_spec
],
self
.
attrs
,
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
,
1
]
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
,
1
]
)
# shape: [6, 12, 48, 24] --> [72, 3, 16, 24] (intput --> output)
# dims_mapping: [0, 1, -1, -1] --> [0, -1, 1, -1], [0, 1, -1, -1] (output --> input, output)
self
.
output_dist_tensor_spec
.
shape
=
[
72
,
3
,
16
,
24
]
self
.
output_dist_tensor_spec
.
set_dims_mapping
([
0
,
1
,
-
1
,
-
1
])
result_dist_attrs
=
self
.
rule
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
output_dist_tensor_spec
],
self
.
attrs
,
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
0
,
-
1
,
1
,
-
1
]
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
0
,
1
,
-
1
,
-
1
]
)
# shape: [6, 12, 48, 24] --> [72, 3, 16, 24] (intput --> output)
# dims_mapping: [1, -1, -1, -1] --> [1, -1, -1, -1], [1, -1, -1, -1] (output --> input, output)
self
.
output_dist_tensor_spec
.
shape
=
[
72
,
3
,
16
,
24
]
self
.
output_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
,
-
1
,
-
1
])
result_dist_attrs
=
self
.
rule
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
output_dist_tensor_spec
],
self
.
attrs
,
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
,
-
1
]
)
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
,
-
1
]
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录