diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc index 01cb5cffe65ab93281e4e4543132e80afd4363b2..ec6b3c0e3e0f601b57e4eeaf4d90acc343b2d65c 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.cc @@ -135,6 +135,7 @@ std::vector MakeReshapeDimTrans( return ret; } +// std::pair, std::vector> paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward( const std::vector& input_specs, @@ -195,12 +196,64 @@ paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward( std::pair, std::vector> paddle::distributed::auto_parallel::ReshapeSPMDRule::InferBackward( + const std::vector& input_specs, const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) { - PADDLE_THROW(phi::errors::Unimplemented( - "InferBackward of ReductionSPMDRule is NOT implemented yet.")); + // step0: Verify Input Args Based on Reshape Logic + int64_t ninputs = input_specs.size(); + int64_t noutputs = output_specs.size(); + PADDLE_ENFORCE_EQ( + ninputs, + 1, + phi::errors::InvalidArgument("The size of InputSpec in reshape must " + "be equal to 1, but got [%d].", + ninputs)); + PADDLE_ENFORCE_EQ( + noutputs, + 1, + phi::errors::InvalidArgument("The size of OutputSpec in reshape must " + "be equal to 1, but got [%d].", + noutputs)); + VerifySpecs(output_specs, "reshape"); + + // step1: build the transformation from the output shape + // to original shape. InferBackward infers the dims mapping + // from output to input, we first get the transformation + // from output to input so that we can infer the dims mapping + // with the map from output axes to input axes. + // Shapes in InferBackward don't contain -1 or 0, so they will + // not be modified and we can use ref here. 
+ const std::vector& output_shape = output_specs[0].shape(); + const std::vector& input_shape = input_specs[0].shape(); + + std::vector trans = MakeReshapeDimTrans(output_shape, input_shape); + + // step2: infer the dims mapping of input with + // output's dims_mapping and the transformation. + std::vector> dims_mapping_vec = + InferFromDimTrans(output_specs[0], trans); + + // step3: update the dist attributes of input + // and output with the inferred dims mapping + TensorDistAttr new_output_dist_attr(output_specs[0].dist_attr()); + new_output_dist_attr.set_dims_mapping(dims_mapping_vec[0]); + TensorDistAttr input_dist_attr(input_specs[0].dist_attr()); + input_dist_attr.set_dims_mapping(dims_mapping_vec[1]); + + VLOG(4) << "Reshape Inferbackward: output_shape: [" << str_join(output_shape) + << "] input_shape: [" << str_join(input_shape) << "]"; + VLOG(4) << "Transformation from output to input:"; + for (int64_t i = 0, n = trans.size(); i < n; i++) { + DimTrans* t = trans[i]; + VLOG(4) << "\tInput axis " << i << ": " << t->to_string(); + } + VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[1]) + << "] output_dims_mapping: [" << str_join(dims_mapping_vec[0]) + << "]\n\n"; + + CleanUp(); - return {}; + return {{input_dist_attr}, {new_output_dist_attr}}; } } // namespace auto_parallel diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h index 63b9a5a6f038a2d9fb4aaded8761916c28a4b54f..737455e0be6c8bdd95e763d65918c83ae2e64a02 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h @@ -32,7 +32,8 @@ class ReshapeSPMDRule : public SPMDRuleBase { const paddle::framework::AttributeMap& attrs) override; std::pair, std::vector> - InferBackward(const std::vector& output_specs, + InferBackward(const std::vector& input_specs, + const std::vector& output_specs, 
const paddle::framework::AttributeMap& attrs) override; }; } // namespace auto_parallel diff --git a/test/auto_parallel/spmd_rules/test_reshape_rule.py b/test/auto_parallel/spmd_rules/test_reshape_rule.py index a35fc29389fa09cc206823250cf28f9bf8b05cfe..dd7c248ca42fbdc021ce16f0fdb767d120649b08 100644 --- a/test/auto_parallel/spmd_rules/test_reshape_rule.py +++ b/test/auto_parallel/spmd_rules/test_reshape_rule.py @@ -30,7 +30,7 @@ class TestReshapeSPMDRule(unittest.TestCase): process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) x_tensor_dist_attr = TensorDistAttr() - x_tensor_dist_attr.dims_mapping = [-1, -1] + x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] x_tensor_dist_attr.process_mesh = process_mesh self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) @@ -248,6 +248,171 @@ class TestReshapeSPMDRule(unittest.TestCase): with self.assertRaises(BaseException): self.rule.infer_forward([self.x_dist_tensor_spec], self.attrs) + def test_reshape_infer_backward(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]]) + + output_tensor_dist_attr = TensorDistAttr() + output_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] + output_tensor_dist_attr.process_mesh = process_mesh + + # shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6] (input --> output) + # dims_mapping: [-1, 0, 1, -1, -1] --> [0, -1, 1, -1], [-1, 0, 1, -1, -1] (output --> input, output) + self.output_dist_tensor_spec = DistTensorSpec( + [1, 72, 48, 4, 6], output_tensor_dist_attr + ) + self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1, -1]) + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec], + [self.output_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, 1, -1] + ) + 
self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1, -1, -1] + ) + + # shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6] (input --> output) + # dims_mapping: [-1, -1, -1, -1, -1] --> [-1, -1, -1, -1], [-1, -1, -1, -1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 72, 48, 4, 6] + self.output_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1, -1]) + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec], + [self.output_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1, -1, -1] + ) + + # shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6] (input --> output) + # dims_mapping: [-1, 1, -1, 0, -1] --> [1, -1, -1, 0] [-1, 1, -1, 0, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [1, 72, 48, 4, 6] + self.output_dist_tensor_spec.set_dims_mapping([-1, 1, -1, 0, -1]) + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec], + [self.output_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, 0] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1, 0, -1] + ) + + # shape: [6, 12, 48, 24] --> [3, 24, 6, 8, 24] (input --> output) + # dims_mapping: [1, -1, -1, -1, 0] --> [1, -1, -1, 0], [1, -1, -1, -1, 0] (output --> input, output) + self.output_dist_tensor_spec.shape = [3, 24, 6, 8, 24] + self.output_dist_tensor_spec.set_dims_mapping([1, -1, -1, -1, 0]) + + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec], + [self.output_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + 
+ infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, 0] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1, 0] + ) + + # shape: [6, 12, 48, 24] --> [3, 24, 6, 8, 24] (input --> output) + # dims_mapping: [-1, -1, 0, -1, 1] --> [-1, -1, 0, 1], [-1, -1, 0, -1, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [3, 24, 6, 8, 24] + self.output_dist_tensor_spec.set_dims_mapping([-1, -1, 0, -1, 1]) + + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec], + [self.output_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, 1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0, -1, 1] + ) + + # shape: [6, 12, 48, 24] --> [6, 12, 48, 24] (input --> output) + # dims_mapping: [-1, -1, 0, 1] --> [-1, -1, 0, 1], [-1, -1, 0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [6, 12, 48, 24] + self.output_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1]) + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec], + [self.output_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0, 1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0, 1] + ) + + # shape: [6, 12, 48, 24] --> [72, 3, 16, 24] (input --> output) + # dims_mapping: [0, 1, -1, -1] --> [0, -1, 1, -1], [0, 1, -1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [72, 3, 16, 24] + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec], + 
[self.output_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, 1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + + # shape: [6, 12, 48, 24] --> [72, 3, 16, 24] (input --> output) + # dims_mapping: [1, -1, -1, -1] --> [1, -1, -1, -1], [1, -1, -1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [72, 3, 16, 24] + self.output_dist_tensor_spec.set_dims_mapping([1, -1, -1, -1]) + result_dist_attrs = self.rule.infer_backward( + [self.x_dist_tensor_spec], + [self.output_dist_tensor_spec], + self.attrs, + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, -1] + ) + self.assertEqual( + infered_output_dist_attrs[0].dims_mapping, [1, -1, -1, -1] + ) + if __name__ == "__main__": unittest.main()