Unverified commit fa1d0e39, authored by Yichen Zhang, committed by GitHub

add reduction backward rule (#56504)

Parent cf652101
@@ -22,40 +22,23 @@ namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;

std::string ReductionSPMDRule::GetOutputNotation(
    int64_t input_ndim,
    const std::string& input_axes,
    const paddle::framework::AttributeMap& attrs) {
  bool keep_dim = ExtractAttr<bool>("keep_dim", attrs);
  std::vector<int64_t> reduce_dims =
      ExtractAttr<std::vector<int64_t>>("axis", attrs);

  // convert the negative dim value to normal dim value
  for (auto& reduce_dim : reduce_dims) {
    if (reduce_dim < 0) {
      reduce_dim = input_ndim + reduce_dim;
    }
  }

  std::string output_axes = "";
  for (int64_t i = 0; i < input_ndim; i++) {
    std::vector<int64_t>::iterator iter =
        std::find(reduce_dims.begin(), reduce_dims.end(), i);
    if (iter != reduce_dims.end()) {
@@ -71,6 +54,33 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
    }
  }

  return output_axes;
}

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
                                const paddle::framework::AttributeMap& attrs) {
  // step0: Verify Input Args Based on Elementwise Logic
  int64_t ninputs = input_specs.size();
  PADDLE_ENFORCE_EQ(
      ninputs,
      1,
      phi::errors::InvalidArgument("The size of InputSpec in reduction must "
                                   "be equal to 1, but got [%d].",
                                   ninputs));
  VerifySpecs(input_specs, "reduction");

  // step1: Build Einsum Notation
  // get einsum notation for input
  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  int64_t ndim = input_specs[0].shape().size();
  std::vector<std::string> input_axes_vec;
  std::string input_axes = alphabet.substr(0, ndim);
  input_axes_vec.emplace_back(input_axes);

  // get einsum notation for output
  std::string output_axes = GetOutputNotation(ndim, alphabet, attrs);

  // step2: Sharding Propogation
  // step2.1: merge input shardings
  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
@@ -88,8 +98,8 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
      CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
  output_dist_attr.set_dims_mapping(output_dims_mapping);

  // step3: handle partial
  // Step3.1 Output Partial
  std::vector<int64_t> partial_on_dims =
      ResoluteOutputPartialDimension(axis_to_dim_map, output_axes);
  output_dist_attr.set_partial_status(
@@ -98,7 +108,7 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
  std::vector<TensorDistAttr> output_dist_attrs;
  output_dist_attrs.emplace_back(output_dist_attr);

  // Step3.2 handle input tensor partial (TODO)
  // If the op is a linear op, i.e. `linearity` is true, it supports
  // the input to be partial. Otherwise, the input cannot be partial
  // on reduced axes, we should reshard the input when the reduced
@@ -120,12 +130,60 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
ReductionSPMDRule::InferBackward(
    const std::vector<DistTensorSpec>& input_specs,
    const std::vector<DistTensorSpec>& output_specs,
    const paddle::framework::AttributeMap& attrs) {
  // step0: Verify Input Args Based on Elementwise Logic
  int64_t ninputs = input_specs.size();
  int64_t noutputs = output_specs.size();
  PADDLE_ENFORCE_EQ(
      ninputs,
      1,
      phi::errors::InvalidArgument("The size of InputSpec in reduction must "
                                   "be equal to 1, but got [%d].",
                                   ninputs));
  PADDLE_ENFORCE_EQ(
      noutputs,
      1,
      phi::errors::InvalidArgument("The size of OutputSpec in reduction must "
                                   "be equal to 1, but got [%d].",
                                   noutputs));
  VerifySpecs(output_specs, "reduction_backward");

  // step1: Build Einsum Notation
  // get einsum notation for input
  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  int64_t ndim = input_specs[0].shape().size();
  std::string input_axes = alphabet.substr(0, ndim);

  // get einsum notation for output
  std::string output_axes = GetOutputNotation(ndim, alphabet, attrs);

  // step2: Sharding Propogation
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{output_axes, output_specs[0].dims_mapping()}});

  // step2.2: infer input dims mapping from output dims mapping
  std::vector<int64_t> input_dims_mapping =
      GetDimsMappingForAxes(input_axes, axis_to_dim_map, true);

  // initialize input dist_attr's process_mesh, batch_dim and dynamic dims with
  // input dist_attr.
  TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
  input_dist_attr.set_dims_mapping(input_dims_mapping);

  // step3: handle partial (TODO)

  VLOG(4) << "ReductionSPMDRule InferBackward: ";
  VLOG(4) << "Output shape:[" << str_join(output_specs[0].shape())
          << "] dims_mapping: [" << str_join(output_specs[0].dims_mapping())
          << "]";
  VLOG(4) << "Input0: "
          << " shape: [" << str_join(input_specs[0].shape()) << "] "
          << "dims_mapping: [" << str_join(input_dist_attr.dims_mapping())
          << "]";

  return {{input_dist_attr}, {output_specs[0].dist_attr()}};
}

} // namespace auto_parallel
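To make the rule above easier to follow outside the C++ sources, here is a minimal Python sketch of the same einsum-notation idea (illustration only, not the Paddle API; the helper names get_output_notation and infer_backward_dims_mapping, and the '1' placeholder for kept-but-reduced axes, are assumptions of this sketch): the output notation drops reduced axes, and the backward rule maps each surviving output axis back onto the matching input axis, leaving reduced axes replicated (-1).

# Minimal sketch of the reduction SPMD notation and dims-mapping logic.
# Plain Python for illustration only; names and details are assumptions,
# not the Paddle implementation.

ALPHABET = "abcdefghijklmnopqrstuvwxyz"


def get_output_notation(input_ndim, input_axes, reduce_dims, keep_dim):
    # Normalize negative reduce dims, then drop reduced axes from the
    # notation (or keep them as the placeholder '1' when keep_dim is set).
    reduce_dims = [d + input_ndim if d < 0 else d for d in reduce_dims]
    output_axes = ""
    for i in range(input_ndim):
        if i in reduce_dims:
            if keep_dim:
                output_axes += "1"  # size-1 axis, never sharded
        else:
            output_axes += input_axes[i]
    return output_axes


def infer_backward_dims_mapping(input_axes, output_axes, output_dims_mapping):
    # Record which mesh dim each surviving output axis is sharded on, then
    # read the mapping back for every input axis; reduced axes (absent from
    # the output notation) fall back to -1, i.e. replicated.
    axis_to_dim = {
        a: d for a, d in zip(output_axes, output_dims_mapping) if a != "1"
    }
    return [axis_to_dim.get(a, -1) for a in input_axes]


if __name__ == "__main__":
    ndim = 3
    input_axes = ALPHABET[:ndim]  # "abc"
    out_axes = get_output_notation(ndim, input_axes, [1, 2], keep_dim=False)
    print(out_axes)  # "a"
    print(infer_backward_dims_mapping(input_axes, out_axes, [0]))  # [0, -1, -1]

Under these assumptions, reducing dims 1 and 2 of a 3-D input with the output sharded on mesh dim 0 yields the input mapping [0, -1, -1], which matches the multi-mesh-dim backward test below.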
...
@@ -32,8 +32,14 @@ class ReductionSPMDRule : public SPMDRuleBase {
               const paddle::framework::AttributeMap& attrs) override;

  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
  InferBackward(const std::vector<DistTensorSpec>& input_specs,
                const std::vector<DistTensorSpec>& output_specs,
                const paddle::framework::AttributeMap& attrs) override;

 private:
  std::string GetOutputNotation(int64_t input_ndim,
                                const std::string& input_axes,
                                const paddle::framework::AttributeMap& attrs);
};

} // namespace auto_parallel
} // namespace distributed
...
@@ -38,6 +38,8 @@ class TestReductionSPMDRule(unittest.TestCase):
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)

        self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec)

        self.attrs = {
            'keep_dim': False,
            'axis': [0],
@@ -233,6 +235,171 @@ class TestReductionSPMDRule(unittest.TestCase):
        self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
        self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1})
    def test_backward_single_mesh_dim(self):
        # reduce on dim 0, keep_dim = false
        # [-1] --> [-1, -1], [-1] (output --> input, output)
        self.attrs['keep_dim'] = False
        self.attrs['axis'] = [0]
        self.out_dist_tensor_spec.shape = [32]
        self.out_dist_tensor_spec.set_dims_mapping([-1])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(result_dist_attrs), 2)
        self.assertEqual(len(infered_input_dist_attrs), 1)
        self.assertEqual(len(infered_output_dist_attrs), 1)
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])

        # reduce on dim 0, keep_dim = true
        # [-1, -1] --> [-1, -1], [-1, -1] (output --> input, output)
        self.attrs['keep_dim'] = True
        self.attrs['axis'] = [0]
        self.out_dist_tensor_spec.shape = [1, 32]
        self.out_dist_tensor_spec.set_dims_mapping([-1, -1])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])

        # reduce on dim 1, keep_dim = false
        # [0] --> [0, -1], [0] (output --> input, output)
        self.attrs['keep_dim'] = False
        self.attrs['axis'] = [1]
        self.out_dist_tensor_spec.shape = [64]
        self.out_dist_tensor_spec.set_dims_mapping([0])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])

        # reduce on dim 1, keep_dim = true
        # [0, -1] --> [0, -1], [0, -1] (output --> input, output)
        self.attrs['keep_dim'] = True
        self.attrs['axis'] = [1]
        self.out_dist_tensor_spec.shape = [64, 1]
        self.out_dist_tensor_spec.set_dims_mapping([0, -1])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1])

        # reduce on dim 0 and 1, keep_dim = false
        # [] --> [-1, -1], [] (output --> input, output)
        self.attrs['keep_dim'] = False
        self.attrs['axis'] = [0, 1]
        self.out_dist_tensor_spec.shape = []
        self.out_dist_tensor_spec.set_dims_mapping([])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [])

        # reduce on dim 0 and 1, keep_dim = true
        # [-1, -1] --> [-1, -1], [-1, -1] (output --> input, output)
        self.attrs['keep_dim'] = True
        self.attrs['axis'] = [0, 1]
        self.out_dist_tensor_spec.shape = [1, 1]
        self.out_dist_tensor_spec.set_dims_mapping([-1, -1])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
    def test_backward_multi_mesh_dim(self):
        process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
        self.x_dist_tensor_spec.set_process_mesh(process_mesh)
        self.x_dist_tensor_spec.shape = [96, 24, 48]
        self.out_dist_tensor_spec.set_process_mesh(process_mesh)

        # reduce on dim 1, 2, keep_dim = false
        # [0] --> [0, -1, -1], [0] (output --> input, output)
        self.attrs['keep_dim'] = False
        self.attrs['axis'] = [1, 2]
        self.out_dist_tensor_spec.shape = [96]
        self.out_dist_tensor_spec.set_dims_mapping([0])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(result_dist_attrs), 2)
        self.assertEqual(len(infered_input_dist_attrs), 1)
        self.assertEqual(len(infered_output_dist_attrs), 1)
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])

        # reduce on dim 1, 2, keep_dim = false
        # [-1] --> [-1, -1, -1], [-1] (output --> input, output)
        self.attrs['keep_dim'] = False
        self.attrs['axis'] = [1, 2]
        self.out_dist_tensor_spec.shape = [96]
        self.out_dist_tensor_spec.set_dims_mapping([-1])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])

        # reduction on dim 1, 2, keep_dim = false
        # [1] --> [1, -1, -1], [1] (output --> input, output)
        self.attrs['keep_dim'] = False
        self.attrs['axis'] = [1, 2]
        self.out_dist_tensor_spec.shape = [96]
        self.out_dist_tensor_spec.set_dims_mapping([1])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1])

        # reduction on dim 1, 2, keep_dim = true
        # [0, -1, -1] --> [0, -1, -1], [0, -1, -1] (output --> input, output)
        self.attrs['keep_dim'] = True
        self.attrs['axis'] = [1, 2]
        self.out_dist_tensor_spec.shape = [96, 1, 1]
        self.out_dist_tensor_spec.set_dims_mapping([0, -1, -1])

        result_dist_attrs = self.rule.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
if __name__ == "__main__":
    unittest.main()
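As a complement to the forward partial-status assertions referenced above (_is_partial(), _partial_dims()), here is a small illustrative Python sketch of that idea (the name resolute_output_partial_dims mirrors the C++ helper ResoluteOutputPartialDimension but is an assumption of this sketch, not the Paddle Python API): any mesh dimension that the input was sharded on along a reduced axis disappears from the output notation, so the reduced output must be marked partial on that mesh dimension until an allreduce materializes it.

def resolute_output_partial_dims(axis_to_dim_map, output_axes):
    # Mesh dims carried by axes that no longer appear in the output
    # notation are exactly the dims the reduction result is partial on.
    partial = []
    for axis, mesh_dim in axis_to_dim_map.items():
        if mesh_dim >= 0 and axis not in output_axes:
            partial.append(mesh_dim)
    return sorted(partial)


# Input "abc" sharded as [-1, 1, -1]; reducing dims 1 and 2 leaves "a",
# so mesh dim 1 (carried by the reduced axis 'b') becomes partial.
print(resolute_output_partial_dims({"a": -1, "b": 1, "c": -1}, "a"))  # [1]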