未验证 提交 1aef42eb 编写于 作者: Y Yichen Zhang 提交者: GitHub

[Semi-Auto] fix bug in reshape spmd rule (#56593)

* fix small bug in reshape spmd rule

* small fix in unit test
上级 63b70740
......@@ -214,7 +214,8 @@ DimTrans* GetDimTrans(DimTrans* dim_trans,
for (int64_t i = 1, n = inputs.size(); i < n; i++) {
DimTrans* input = inputs[i];
if (input->type() == DimTrans::Type::INPUTDIM) {
(*shardable)[i].assign(nmesh, false);
InputDim* inputdim = dynamic_cast<InputDim*>(input);
(*shardable)[inputdim->input_dim()].assign(nmesh, false);
}
GetDimTrans(input,
......
......@@ -88,7 +88,7 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
if (tgt_idx >= tgt_len) {
t = 1;
} else {
t = tgt_shape[tgt_idx];
t = inferred_tgt_shape[tgt_idx];
tgt_splitted_shape.emplace_back(t);
tgt_idx++;
}
......
......@@ -208,6 +208,40 @@ class TestReshapeSPMDRule(unittest.TestCase):
infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1]
)
# shape: [1, 72, 48, 4, 6] --> [6, 12, 48, 24]
# dims_mapping: [-1, 1, -1, 0, -1] --> [-1, 1, -1, 0, -1] [1, -1, -1, 0]
self.x_dist_tensor_spec.shape = [1, 72, 48, 4, 6]
self.attrs["shape"] = [6, 12, 48, 24]
self.x_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, 0, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, 0]
)
# shape: [8, 1024, 3072] --> [0, 0, -1, 192]
# dims_mapping: [0, 1, -1] --> [0, 1, -1], [0, 1, -1, -1]
self.x_dist_tensor_spec.shape = [8, 1024, 3072]
self.attrs["shape"] = [0, 0, -1, 192]
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)
# shape: [6, 12, 48, 24] --> [3, 24, 6, -1, -1]
# raise error
self.attrs["shape"] = [3, 24, 6, -1, -1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册