Unverified commit f2968742, authored by Yichen Zhang, committed by GitHub

add transpose backward rule (#56509)

Parent f839e821
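For context: the forward rule labels each input axis with a letter of an einsum-style notation and permutes those letters with `perm` to form the output notation; the backward rule added in this commit runs the same mapping in reverse, merging the output's dims_mapping and reading it back onto the input axes. Below is a minimal Python sketch of that propagation for the single-input transpose case; the helper names and standalone lists are illustrative only, not Paddle APIs.

def build_output_notation(ndim, perm):
    # Same idea as GetOutputNotation: normalize negative dims, then permute the letters.
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    input_axes = alphabet[:ndim]
    perm = [p + ndim if p < 0 else p for p in perm]
    output_axes = "".join(input_axes[p] for p in perm)
    return input_axes, output_axes

def infer_backward_dims_mapping(output_dims_mapping, perm):
    # Mirrors ShardingMergeForTensors + GetDimsMappingForAxes for the
    # single-output transpose case: pair each output letter with its mesh
    # dim, then read the input notation back letter by letter.
    ndim = len(output_dims_mapping)
    input_axes, output_axes = build_output_notation(ndim, perm)
    axis_to_dim = dict(zip(output_axes, output_dims_mapping))
    return [axis_to_dim[axis] for axis in input_axes]

# perm = [1, 0]: an output dims_mapping of [-1, 0] maps back to an input
# dims_mapping of [0, -1] -- the first case in test_backward_single_mesh_dim below.
assert infer_backward_dims_mapping([-1, 0], [1, 0]) == [0, -1]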
@@ -23,7 +23,7 @@ std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
TransposeSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Transpose Logic
int64_t ninputs = static_cast<int64_t>(input_specs.size());
int64_t ninputs = input_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
1,
@@ -33,27 +33,15 @@ TransposeSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
VerifySpecs(input_specs, "transpose");
// step1: Build Einsum Notation
std::vector<int64_t> perm_dims =
ExtractAttr<std::vector<int64_t>>("perm", attrs);
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
// get einsum notation for input
int64_t ndim = static_cast<int64_t>(input_specs[0].shape().size());
int64_t ndim = input_specs[0].shape().size();
std::vector<std::string> input_axes_vec;
std::string input_axes = alphabet.substr(0, ndim);
input_axes_vec.emplace_back(input_axes);
// get einsum notation for output
for (int64_t i = 0, n = static_cast<int64_t>(perm_dims.size()); i < n; ++i) {
// convert the negative dim value to normal dim value
if (perm_dims[i] < 0) {
perm_dims[i] = ndim + perm_dims[i];
}
}
std::string output_axes = "";
for (int64_t i = 0; i < ndim; i++) {
output_axes.append(1, input_axes[perm_dims[i]]);
}
std::string output_axes = GetOutputNotation(ndim, input_axes, attrs);
// step2: Sharding Propagation
// step2.1: merge input shardings
@@ -72,17 +60,19 @@ TransposeSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(output_dims_mapping);
// Step2.3 handle input tensor partial (TODO)
// step3 Handle partial (TODO)
VLOG(4) << "TransposeSPMDRule InferForward:";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(input_specs[i].shape()) << "] "
<< "src_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
<< "] "
<< "perm: [" << str_join(perm_dims) << "] "
<< "dst_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
<< "]";
}
VLOG(4) << "Perm: ["
<< str_join(ExtractAttr<std::vector<int64_t>>("perm", attrs)) << "]";
VLOG(4) << "Output dims_mapping: [" + str_join(output_dims_mapping) + "]\n\n";
return {{input_specs[0].dist_attr()}, {output_dist_attr}};
@@ -90,12 +80,92 @@ TransposeSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
TransposeSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of TransposeSPMDRule is NOT implemented yet."));
// step0: Verify Input Args Based on Transpose Logic
int64_t ninputs = input_specs.size();
int64_t noutputs = output_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in transpose must "
"be equal to 1, but got [%d].",
ninputs));
PADDLE_ENFORCE_EQ(
noutputs,
1,
phi::errors::InvalidArgument("The size of OutputSpec in transpose must "
"be equal to 1, but got [%d].",
noutputs));
VerifySpecs(output_specs, "transpose_backward");
// step1: Build Einsum Notation
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
// get einsum notation for input
int64_t ndim = input_specs[0].shape().size();
std::string input_axes = alphabet.substr(0, ndim);
// get einsum notation for output
std::string output_axes = GetOutputNotation(ndim, input_axes, attrs);
std::vector<std::string> output_axes_vec;
output_axes_vec.emplace_back(output_axes);
// step2: Sharding Propagation
// step2.1: merge input shardings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info = GetAxesDimsMappingPair(output_axes_vec, output_specs);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);
// step2.2: infer input dims mapping from merged output dims mapping
std::vector<int64_t> input_dims_mapping =
GetDimsMappingForAxes(input_axes, axis_to_dim_map);
// initialize the inferred input dist_attr's process_mesh, batch_dim and
// dynamic dims with the original input dist_attr.
TensorDistAttr input_dist_attr =
CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
input_dist_attr.set_dims_mapping(input_dims_mapping);
// Step3 Handle partial (TODO)
VLOG(4) << "TransposeSPMDRule InferBackward:";
VLOG(4) << "Output shape: [" << str_join(output_specs[0].shape()) << "] "
<< "dims_mapping: [" << str_join(output_specs[0].dims_mapping())
<< "]";
VLOG(4) << "Perm: ["
<< str_join(ExtractAttr<std::vector<int64_t>>("perm", attrs)) << "]";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(input_specs[i].shape()) << "] "
<< "dims_mapping: [" << str_join(input_dims_mapping) << "]";
}
VLOG(4) << std::endl;
return {{input_dist_attr}, {output_specs[0].dist_attr()}};
}
std::string TransposeSPMDRule::GetOutputNotation(
int64_t input_ndim,
const std::string& input_axes,
const paddle::framework::AttributeMap& attrs) {
std::vector<int64_t> perm_dims =
ExtractAttr<std::vector<int64_t>>("perm", attrs);
// convert the negative dim value to normal dim value
for (int64_t i = 0, n = perm_dims.size(); i < n; ++i) {
if (perm_dims[i] < 0) {
perm_dims[i] = input_ndim + perm_dims[i];
}
}
std::string output_axes = "";
for (int64_t i = 0; i < input_ndim; i++) {
output_axes.append(1, input_axes[perm_dims[i]]);
}
return output_axes;
}
} // namespace auto_parallel
@@ -32,8 +32,14 @@ class TransposeSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
private:
std::string GetOutputNotation(int64_t input_ndim,
const std::string& input_axes,
const paddle::framework::AttributeMap& attrs);
};
} // namespace auto_parallel
} // namespace distributed
@@ -38,6 +38,8 @@ class TestTransposeSPMDRule(unittest.TestCase):
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec)
self.attrs = {
'perm': [0, 1, 2, 3],
}
@@ -149,6 +151,121 @@ class TestTransposeSPMDRule(unittest.TestCase):
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1]
)
def test_backward_single_mesh_dim(self):
# perm = [1, 0]
# [-1, 0] --> [0, -1], [-1, 0] (output --> input, output)
self.attrs['perm'] = [1, 0]
self.out_dist_tensor_spec.shape = [36, 64]
self.out_dist_tensor_spec.set_dims_mapping([-1, 0])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
# perm = [0, 1]
# [0, -1] --> [0, -1], [0, -1] (output --> input, output)
self.attrs['perm'] = [0, 1]
self.out_dist_tensor_spec.shape = [64, 36]
self.out_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1])
# perm = [0, 2, 3, 1]
# [-1, 0, -1, -1] --> [-1, -1, 0, -1], [-1, 0, -1, -1] (output --> input, output)
self.x_dist_tensor_spec.shape = [64, 48, 36, 24]
self.attrs['perm'] = [0, 2, 3, 1]
self.out_dist_tensor_spec.shape = [64, 36, 24, 48]
self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, -1]
)
def test_backward_multi_mesh_dim(self):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
self.x_dist_tensor_spec.set_process_mesh(process_mesh)
self.x_dist_tensor_spec.shape = [64, 48, 36, 24]
self.out_dist_tensor_spec.set_process_mesh(process_mesh)
# perm = [0, 2, 3, 1]
# [-1, 1, -1, 0] --> [-1, 0, 1, -1], [-1, 1, -1, 0] (output --> input, output)
self.attrs['perm'] = [0, 2, 3, 1]
self.out_dist_tensor_spec.shape = [64, 36, 24, 48]
self.out_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 0, 1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1, 0]
)
# perm = [0, 2, 3, 1]
# [-1, -1, -1, -1] --> [-1, -1, -1, -1], [-1, -1, -1, -1] (output --> input, output)
self.attrs['perm'] = [0, 2, 3, 1]
self.out_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
)
# perm = [-1, 0, -2, 1]
# [1, -1, 0, -1] --> [-1, -1, 0, 1], [1, -1, 0, -1] (output --> input, output)
self.x_dist_tensor_spec.shape = [64, 48, 36, 24]
self.attrs['perm'] = [-1, 0, -2, 1]
self.out_dist_tensor_spec.shape = [24, 64, 36, 48]
self.out_dist_tensor_spec.set_dims_mapping([1, -1, 0, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, 1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0, -1]
)
if __name__ == "__main__":
unittest.main()
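As a hand check of the negative-perm case in test_backward_multi_mesh_dim: perm = [-1, 0, -2, 1] over 4 dims normalizes to [3, 0, 2, 1], so the output notation becomes "dacb"; the output dims_mapping [1, -1, 0, -1] then assigns d→1, a→-1, c→0, b→-1, and reading the input notation "abcd" back gives [-1, -1, 0, 1], which matches the expected infered_input_dist_attrs value in that test.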