From fa1d0e3939da8c998eb624df91880943f8349a50 Mon Sep 17 00:00:00 2001
From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com>
Date: Fri, 8 Sep 2023 14:00:00 +0800
Subject: [PATCH] add reduction backward rule (#56504)

---
 .../spmd_rules/reduction_spmd_rule.cc         | 120 +++++++++----
 .../spmd_rules/reduction_spmd_rule.h          |   8 +-
 .../spmd_rules/test_reduction_rule.py         | 167 ++++++++++++++++++
 3 files changed, 263 insertions(+), 32 deletions(-)

diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc
index e2c3045d12f..62940545e88 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc
@@ -22,40 +22,23 @@ namespace auto_parallel {
 
 using phi::distributed::auto_parallel::str_join;
 
-std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
-                                const paddle::framework::AttributeMap& attrs) {
-  // step0: Verify Input Args Based on Elementwise Logic
-  int64_t ninputs = static_cast<int64_t>(input_specs.size());
-  PADDLE_ENFORCE_EQ(
-      ninputs,
-      1,
-      phi::errors::InvalidArgument("The size of InputSpec in reduction must "
-                                   "be equal to 1, but got [%d].",
-                                   ninputs));
-  VerifySpecs(input_specs, "reduction");
-
-  // step1: Build Einsum Notation
+std::string ReductionSPMDRule::GetOutputNotation(
+    int64_t input_ndim,
+    const std::string& input_axes,
+    const paddle::framework::AttributeMap& attrs) {
   bool keep_dim = ExtractAttr<bool>("keep_dim", attrs);
   std::vector<int64_t> reduce_dims =
       ExtractAttr<std::vector<int64_t>>("axis", attrs);
-  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
-
-  // get einsum notation for input
-  int64_t ndim = static_cast<int64_t>(input_specs[0].shape().size());
-  std::vector<std::string> input_axes_vec;
-  std::string input_axes = alphabet.substr(0, ndim);
-  input_axes_vec.emplace_back(input_axes);
-
-  // get einsum notation for output
+  // convert the negative dim value to normal dim value
   for (auto& reduce_dim : reduce_dims) {
-    // convert the negative dim value to normal dim value
    if (reduce_dim < 0) {
-      reduce_dim = ndim + reduce_dim;
+      reduce_dim = input_ndim + reduce_dim;
     }
   }
+
   std::string output_axes = "";
-  for (int64_t i = 0; i < ndim; i++) {
+  for (int64_t i = 0; i < input_ndim; i++) {
     std::vector<int64_t>::iterator iter =
         std::find(reduce_dims.begin(), reduce_dims.end(), i);
     if (iter != reduce_dims.end()) {
@@ -71,6 +54,33 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
     }
   }
 
+  return output_axes;
+}
+
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
+                                const paddle::framework::AttributeMap& attrs) {
+  // step0: Verify Input Args Based on Elementwise Logic
+  int64_t ninputs = input_specs.size();
+  PADDLE_ENFORCE_EQ(
+      ninputs,
+      1,
+      phi::errors::InvalidArgument("The size of InputSpec in reduction must "
+                                   "be equal to 1, but got [%d].",
+                                   ninputs));
+  VerifySpecs(input_specs, "reduction");
+
+  // step1: Build Einsum Notation
+  // get einsum notation for input
+  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+  int64_t ndim = input_specs[0].shape().size();
+  std::vector<std::string> input_axes_vec;
+  std::string input_axes = alphabet.substr(0, ndim);
+  input_axes_vec.emplace_back(input_axes);
+
+  // get einsum notation for output
+  std::string output_axes = GetOutputNotation(ndim, input_axes, attrs);
+
   // step2: Sharding Propagation
   // step2.1: merge input shardings
   std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
@@ -88,8 +98,8 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
       CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
   output_dist_attr.set_dims_mapping(output_dims_mapping);
 
-  // step2.4: handle partial
-  // Step2.4.1 Output Partial
+  // step3: handle partial
+  // Step3.1 Output Partial
   std::vector<int64_t> partial_on_dims =
       ResoluteOutputPartialDimension(axis_to_dim_map, output_axes);
   output_dist_attr.set_partial_status(
@@ -98,7 +108,7 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
 
   std::vector<TensorDistAttr> output_dist_attrs;
   output_dist_attrs.emplace_back(output_dist_attr);
 
-  // Step2.4.2 handle input tensor partial (TODO)
+  // Step3.2 handle input tensor partial (TODO)
   // If the op is a linear op, i.e. `linearity` is true, it supports
   // the input to be partial. Otherwise, the input cannot be partial
   // on reduced axes; we should reshard the input when the reduced
@@ -120,12 +130,60 @@ ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
 
 std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
 ReductionSPMDRule::InferBackward(
+    const std::vector<DistTensorSpec>& input_specs,
     const std::vector<DistTensorSpec>& output_specs,
     const paddle::framework::AttributeMap& attrs) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "InferBackward of ReductionSPMDRule is NOT implemented yet."));
+  // step0: Verify Input Args Based on Elementwise Logic
+  int64_t ninputs = input_specs.size();
+  int64_t noutputs = output_specs.size();
+  PADDLE_ENFORCE_EQ(
+      ninputs,
+      1,
+      phi::errors::InvalidArgument("The size of InputSpec in reduction must "
+                                   "be equal to 1, but got [%d].",
+                                   ninputs));
+  PADDLE_ENFORCE_EQ(
+      noutputs,
+      1,
+      phi::errors::InvalidArgument("The size of OutputSpec in reduction must "
+                                   "be equal to 1, but got [%d].",
+                                   noutputs));
+  VerifySpecs(output_specs, "reduction_backward");
+
+  // step1: Build Einsum Notation
+  // get einsum notation for input
+  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+  int64_t ndim = input_specs[0].shape().size();
+  std::string input_axes = alphabet.substr(0, ndim);
+
+  // get einsum notation for output
+  std::string output_axes = GetOutputNotation(ndim, input_axes, attrs);
+
+  // step2: Sharding Propagation
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors({{output_axes, output_specs[0].dims_mapping()}});
+
+  // step2.2: infer input dims mapping from output dims mapping
+  std::vector<int64_t> input_dims_mapping =
+      GetDimsMappingForAxes(input_axes, axis_to_dim_map, true);
+
+  // initialize input dist_attr's process_mesh, batch_dim and dynamic dims
+  // with input dist_attr.
+  TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
+  input_dist_attr.set_dims_mapping(input_dims_mapping);
+
+  // step3: handle partial (TODO)
+
+  VLOG(4) << "ReductionSPMDRule InferBackward: ";
+  VLOG(4) << "Output shape: [" << str_join(output_specs[0].shape())
+          << "] dims_mapping: [" << str_join(output_specs[0].dims_mapping())
+          << "]";
+  VLOG(4) << "Input0: "
+          << " shape: [" << str_join(input_specs[0].shape()) << "] "
+          << "dims_mapping: [" << str_join(input_dist_attr.dims_mapping())
+          << "]";
 
-  return {};
+  return {{input_dist_attr}, {output_specs[0].dist_attr()}};
 }
 
 }  // namespace auto_parallel
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h
index 7039529742d..36e412b7049 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h
@@ -32,8 +32,14 @@ class ReductionSPMDRule : public SPMDRuleBase {
                const paddle::framework::AttributeMap& attrs) override;
 
   std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-  InferBackward(const std::vector<DistTensorSpec>& output_specs,
+  InferBackward(const std::vector<DistTensorSpec>& input_specs,
+                const std::vector<DistTensorSpec>& output_specs,
                 const paddle::framework::AttributeMap& attrs) override;
+
+ private:
+  std::string GetOutputNotation(int64_t input_ndim,
+                                const std::string& input_axes,
+                                const paddle::framework::AttributeMap& attrs);
 };
 }  // namespace auto_parallel
 }  // namespace distributed
diff --git a/test/auto_parallel/spmd_rules/test_reduction_rule.py b/test/auto_parallel/spmd_rules/test_reduction_rule.py
index e3b4012b6bd..f8069ee2265 100644
--- a/test/auto_parallel/spmd_rules/test_reduction_rule.py
+++ b/test/auto_parallel/spmd_rules/test_reduction_rule.py
@@ -38,6 +38,8 @@ class TestReductionSPMDRule(unittest.TestCase):
         x_tensor_dist_attr.process_mesh = process_mesh
         self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
 
+        self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec)
+
         self.attrs = {
             'keep_dim': False,
             'axis': [0],
@@ -233,6 +235,171 @@ class TestReductionSPMDRule(unittest.TestCase):
         self.assertEqual(infered_output_dist_attrs[0]._is_partial(), True)
         self.assertEqual(infered_output_dist_attrs[0]._partial_dims(), {1})
 
+    def test_backward_single_mesh_dim(self):
+        # reduce on dim 0, keep_dim = false
+        # [-1] --> [-1, -1], [-1] (output --> input, output)
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [0]
+        self.out_dist_tensor_spec.shape = [32]
+        self.out_dist_tensor_spec.set_dims_mapping([-1])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(len(result_dist_attrs), 2)
+        self.assertEqual(len(infered_input_dist_attrs), 1)
+        self.assertEqual(len(infered_output_dist_attrs), 1)
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
+
+        # reduce on dim 0, keep_dim = true
+        # [-1, -1] --> [-1, -1], [-1, -1] (output --> input, output)
+        self.attrs['keep_dim'] = True
+        self.attrs['axis'] = [0]
+        self.out_dist_tensor_spec.shape = [1, 32]
+        self.out_dist_tensor_spec.set_dims_mapping([-1, -1])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
+
+        # reduce on dim 1, keep_dim = false
+        # [0] --> [0, -1], [0] (output --> input, output)
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [1]
+        self.out_dist_tensor_spec.shape = [64]
+        self.out_dist_tensor_spec.set_dims_mapping([0])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
+
+        # reduce on dim 1, keep_dim = true
+        # [0, -1] --> [0, -1], [0, -1] (output --> input, output)
+        self.attrs['keep_dim'] = True
+        self.attrs['axis'] = [1]
+        self.out_dist_tensor_spec.shape = [64, 1]
+        self.out_dist_tensor_spec.set_dims_mapping([0, -1])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1])
+
+        # reduce on dim 0 and 1, keep_dim = false
+        # [] --> [-1, -1], [] (output --> input, output)
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [0, 1]
+        self.out_dist_tensor_spec.shape = []
+        self.out_dist_tensor_spec.set_dims_mapping([])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [])
+
+        # reduce on dim 0 and 1, keep_dim = true
+        # [-1, -1] --> [-1, -1], [-1, -1] (output --> input, output)
+        self.attrs['keep_dim'] = True
+        self.attrs['axis'] = [0, 1]
+        self.out_dist_tensor_spec.shape = [1, 1]
+        self.out_dist_tensor_spec.set_dims_mapping([-1, -1])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
+
+    def test_backward_multi_mesh_dim(self):
+        process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
+        self.x_dist_tensor_spec.set_process_mesh(process_mesh)
+        self.x_dist_tensor_spec.shape = [96, 24, 48]
+        self.out_dist_tensor_spec.set_process_mesh(process_mesh)
+
+        # reduce on dim 1, 2, keep_dim = false
+        # [0] --> [0, -1, -1], [0] (output --> input, output)
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [1, 2]
+        self.out_dist_tensor_spec.shape = [96]
+        self.out_dist_tensor_spec.set_dims_mapping([0])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(len(result_dist_attrs), 2)
+        self.assertEqual(len(infered_input_dist_attrs), 1)
+        self.assertEqual(len(infered_output_dist_attrs), 1)
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
+
+        # reduce on dim 1, 2, keep_dim = false
+        # [-1] --> [-1, -1, -1], [-1] (output --> input, output)
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [1, 2]
+        self.out_dist_tensor_spec.shape = [96]
+        self.out_dist_tensor_spec.set_dims_mapping([-1])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
+
+        # reduce on dim 1, 2, keep_dim = false
+        # [1] --> [1, -1, -1], [1] (output --> input, output)
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [1, 2]
+        self.out_dist_tensor_spec.shape = [96]
+        self.out_dist_tensor_spec.set_dims_mapping([1])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1])
+
+        # reduce on dim 1, 2, keep_dim = true
+        # [0, -1, -1] --> [0, -1, -1], [0, -1, -1] (output --> input, output)
+        self.attrs['keep_dim'] = True
+        self.attrs['axis'] = [1, 2]
+        self.out_dist_tensor_spec.shape = [96, 1, 1]
+        self.out_dist_tensor_spec.set_dims_mapping([0, -1, -1])
+        result_dist_attrs = self.rule.infer_backward(
+            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
+
 
 if __name__ == "__main__":
     unittest.main()
--
GitLab
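
A note on the notation trick the patch relies on: GetOutputNotation labels each input axis with a letter from "abc..." and derives the output labels from the keep_dim and axis attrs. The branch bodies are elided by the hunk boundaries, so the following is a minimal Python sketch reconstructed from the surrounding context and the test expectations; get_output_notation and its placeholder-'1' convention are illustrative assumptions, not Paddle API.

# Minimal sketch (not Paddle code) of the einsum-notation building done by
# GetOutputNotation. Assumption: a reduced axis is dropped from the output
# notation when keep_dim is false, and labeled with the placeholder '1'
# (a size-1, never-sharded axis) when keep_dim is true.
def get_output_notation(input_ndim, input_axes, keep_dim, reduce_dims):
    # normalize negative dim values, as the C++ loop does
    reduce_dims = [d + input_ndim if d < 0 else d for d in reduce_dims]
    output_axes = ""
    for i in range(input_ndim):
        if i in reduce_dims:
            if keep_dim:
                output_axes += "1"
        else:
            output_axes += input_axes[i]
    return output_axes

# For a 3-D input labeled "abc" reduced over dims [1, 2]:
# keep_dim=False gives "a"; keep_dim=True gives "a11".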
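With that notation in hand, the backward rule is essentially a dictionary lookup: merge the output labels with the output dims_mapping, then read each input label back out, defaulting to -1 (replicated) for axes that were reduced away, which is what the `true` flag passed to GetDimsMappingForAxes requests. A rough sketch follows; the helper name is invented for illustration.

# Illustrative-only sketch of InferBackward's dims-mapping propagation.
# axis_to_dim_map stands in for the result of ShardingMergeForTensors;
# the comprehension mirrors GetDimsMappingForAxes, with axes absent from
# the output defaulting to -1 (replicated).
def infer_backward_dims_mapping(input_axes, output_axes, out_dims_mapping):
    axis_to_dim_map = dict(zip(output_axes, out_dims_mapping))
    return [axis_to_dim_map.get(axis, -1) for axis in input_axes]

# Matches test_backward_multi_mesh_dim: x labeled "abc", output "a" with
# dims_mapping [0] yields [0, -1, -1] for the input.
print(infer_backward_dims_mapping("abc", "a", [0]))  # [0, -1, -1]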
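The forward direction additionally marks the output as partial: a mesh dimension that shards a reduced input axis holds only a partial result after the reduction, which is what ResoluteOutputPartialDimension resolves and set_partial_status records. The sketch below is an assumed reconstruction of that idea, not the function's actual body.

# Assumed logic: any mesh dim bound to an axis that disappears from the
# output notation leaves the output partial on that mesh dim.
def resolute_output_partial_dimension(axis_to_dim_map, output_axes):
    return sorted(
        mesh_dim
        for axis, mesh_dim in axis_to_dim_map.items()
        if mesh_dim >= 0 and axis not in output_axes
    )

# Mirrors the forward test: input "ab" sharded [-1, 1] and reduced on
# axis 1 ("b") leaves the output partial on mesh dim 1, matching the
# existing assertion that _partial_dims() == {1}.
print(resolute_output_partial_dimension({"a": -1, "b": 1}, "a"))  # [1]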