diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc
index 922b525abd57a0d4f7a778ca1c00e77f3ded202b..f38a4b2f533b31884321ce798d175cbb01ff1484 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc
@@ -34,7 +34,7 @@ CrossEntropyWithSoftmaxSPMDRule::InferForward(
           input_specs_size));
 
   auto x_shape = input_specs[0].shape();
-  int x_ndim = static_cast<int>(x_shape.size());
+  int x_ndim = x_shape.size();
   auto x_dist_attr_src = input_specs[0].dist_attr();
   std::vector<int64_t> x_dims_mapping_src = x_dist_attr_src.dims_mapping();
 
@@ -173,10 +173,116 @@ CrossEntropyWithSoftmaxSPMDRule::InferForward(
 std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
 CrossEntropyWithSoftmaxSPMDRule::InferBackward(
     const std::vector<DistTensorSpec>& input_specs,
+    const std::vector<DistTensorSpec>& output_specs,
     const paddle::framework::AttributeMap& attrs) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "InferBackward of CrossEntropyWithSoftmaxSPMDRule is NOT implemented "
-      "yet."));
+  // step0: verify input args based on cross_entropy_with_softmax logic
+  int64_t ninputs = input_specs.size();
+  int64_t noutputs = output_specs.size();
+  PADDLE_ENFORCE_EQ(
+      ninputs,
+      2,
+      phi::errors::InvalidArgument("The size of InputSpec of cross entropy "
+                                   "with softmax should be 2, but got [%d].",
+                                   ninputs));
+  PADDLE_ENFORCE_EQ(
+      noutputs,
+      2,
+      phi::errors::InvalidArgument("The size of OutputSpec of cross entropy "
+                                   "with softmax should be 2, but got [%d].",
+                                   noutputs));
+  VerifySpecs(output_specs, "cross_entropy_with_softmax_backward");
+
+  // step1: build Einsum Notation
+  std::vector<int64_t> x_shape = input_specs[0].shape();
+  int64_t x_ndim = x_shape.size();
+  std::vector<int64_t> label_shape = input_specs[1].shape();
+
+  int axis = ExtractAttr<int>("axis", attrs);
+  int ignore_index = ExtractAttr<int>("ignore_index", attrs);
+  bool numeric_stable_mode = ExtractAttr<bool>("numeric_stable_mode", attrs);
+  bool use_softmax = ExtractAttr<bool>("use_softmax", attrs);
+  bool soft_label = ExtractAttr<bool>("soft_label", attrs);
+
+  // normalize axis
+  if (axis < 0) {
+    axis = x_ndim + axis;
+  }
+
+  std::string alphabet =
+      "abcdefghijlmnopqrstuvwxyz";  // k for softmax_normalize axis
+  std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
+  x_axes[axis] = 'k';
+  std::string label_axes = x_axes;
+  if (!soft_label) {
+    label_axes[axis] = '1';
+  }
+  std::string loss_axes = x_axes;
+  loss_axes[axis] = '1';
+  // optional output
+  std::string softmax_out_axes;
+  if (use_softmax) {
+    softmax_out_axes = x_axes;
+  } else {
+    softmax_out_axes = "";
+  }
+
+  // step2: Sharding Propagation
+  // step2.1 merge output dims mappings
+  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
+  axes_sharding_info =
+      GetAxesDimsMappingPair({softmax_out_axes, loss_axes}, output_specs);
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors(axes_sharding_info);
+
+  // step2.2 infer inputs' dims mappings from merged dims mapping
+  std::vector<TensorDistAttr> input_dist_attrs;
+  input_dist_attrs.emplace_back(input_specs[0].dist_attr());
+  input_dist_attrs.emplace_back(input_specs[1].dist_attr());
+  // infer and set input X's dims mapping
+  input_dist_attrs[0].set_dims_mapping(
+      GetDimsMappingForAxes(x_axes, axis_to_dim_map));
+  // infer and set input label's dims mapping
+  input_dist_attrs[1].set_dims_mapping(
+      GetDimsMappingForAxes(label_axes, axis_to_dim_map));
+
+  // step2.3 update outputs' dims mappings with merged dims mapping
+  std::vector<TensorDistAttr> output_dist_attrs;
+  output_dist_attrs.emplace_back(output_specs[0].dist_attr());  // softmax_out
+  output_dist_attrs.emplace_back(output_specs[1].dist_attr());  // loss
+  output_dist_attrs[0].set_dims_mapping(
+      GetDimsMappingForAxes(softmax_out_axes, axis_to_dim_map));
+  output_dist_attrs[1].set_dims_mapping(
+      GetDimsMappingForAxes(loss_axes, axis_to_dim_map));
+
+  // step3: Handle partial state (TODO)
+
+  VLOG(4) << "CrossEntropyWithSoftmaxSPMDRule InferBackward: "
+          << "axis: " << axis << ", ignore_index: " << ignore_index
+          << ", numeric_stable_mode: "
+          << (numeric_stable_mode ? "true" : "false")
+          << ", use_softmax: " << use_softmax
+          << ", soft_label: " << (soft_label ? "true" : "false");
+  VLOG(4) << "Einsum notation: [" << x_axes << "," << label_axes << " --> "
+          << softmax_out_axes << "," << loss_axes << "]. (inputs --> outputs)";
+  for (int64_t i = 0; i < noutputs; i++) {
+    VLOG(4) << "Output" << std::to_string(i) << ": "
+            << "shape: [" << str_join(output_specs[i].shape())
+            << "], src_dims_mapping: ["
+            << str_join(output_specs[i].dims_mapping())
+            << "], dst_dims_mapping: ["
+            << str_join(output_dist_attrs[i].dims_mapping()) << "]";
+  }
+  for (int64_t i = 0; i < ninputs; i++) {
+    VLOG(4) << "Input" << std::to_string(i) << ": "
+            << "shape: [" << str_join(input_specs[i].shape())
+            << "], infered_dims_mapping: ["
+            << str_join(input_dist_attrs[i].dims_mapping()) << "]";
+  }
+  VLOG(4) << std::endl;
+
+  // According to the phi API implementation, the softmax_out tensor will
+  // always be generated no matter the value of use_softmax.
+  return {input_dist_attrs, output_dist_attrs};
 }
 
 }  // namespace auto_parallel
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h
index 212006183767444829529a8c5afed428650a517c..9ff7015f4a898cb1d3fe7b925b34bbbe00b8dee2 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h
@@ -27,7 +27,8 @@ class CrossEntropyWithSoftmaxSPMDRule : public SPMDRuleBase {
                const paddle::framework::AttributeMap& attrs) override;
 
   std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-  InferBackward(const std::vector<DistTensorSpec>& output_specs,
+  InferBackward(const std::vector<DistTensorSpec>& input_specs,
+                const std::vector<DistTensorSpec>& output_specs,
                 const paddle::framework::AttributeMap& attrs) override;
 };
 }  // namespace auto_parallel
diff --git a/test/auto_parallel/spmd_rules/test_cross_entropy_with_softmax_rule.py b/test/auto_parallel/spmd_rules/test_cross_entropy_with_softmax_rule.py
index d5a92213f87c3d78552a5d1c2a723462b546b88e..95c6a33a5f48cef63abbb7d2d76d185aa9c00269 100644
--- a/test/auto_parallel/spmd_rules/test_cross_entropy_with_softmax_rule.py
+++ b/test/auto_parallel/spmd_rules/test_cross_entropy_with_softmax_rule.py
@@ -39,6 +39,9 @@ class TestCrossEntropyWithSoftmaxSPMDRule(unittest.TestCase):
             label_shape, label_tensor_dist_attr
         )
 
+        self.loss_spec = DistTensorSpec(self.lable_dist_tensor_spec)
+        self.softmax_out_spec = DistTensorSpec(self.x_dist_tensor_spec)
+
         self.attrs = {
             'ignore_index': -1,
             'axis': -1,
@@ -147,6 +150,122 @@ class TestCrossEntropyWithSoftmaxSPMDRule(unittest.TestCase):
         )
         self.attrs['axis'] = -1
 
+    def test_cross_entropy_with_softmax_infer_backward(self):
+        # GPT DP case
+        # [1, 0, -1], [1, 0, -1] (outputs) -->
+        # [1, 0, -1], [1, 0, -1], (inputs)
+        # [1, 0, -1], [1, 0, -1] (outputs)
+        self.attrs['axis'] = -1
+        self.attrs['use_softmax'] = True
+        self.attrs['soft_label'] = False
+        self.softmax_out_spec.set_dims_mapping([1, 0, -1])
+        self.loss_spec.set_dims_mapping([1, 0, -1])
+
+        result_dist_attrs = self.rule1.infer_backward(
+            [self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
+            [self.softmax_out_spec, self.loss_spec],
+            self.attrs,
+        )
+        self.assertEqual(len(result_dist_attrs), 2)
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(len(infered_input_dist_attrs), 2)
+        self.assertEqual(len(infered_output_dist_attrs), 2)
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0, -1])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0, -1])
+
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [1, 0, -1]
+        )  # softmax output
+        self.assertEqual(
+            infered_output_dist_attrs[1].dims_mapping, [1, 0, -1]
+        )  # loss
+
+        # GPT MP case, shard normalized axis
+        # [-1, -1, 0], [-1, -1, -1] (outputs) -->
+        # [-1, -1, 0], [-1, -1, -1], (inputs)
+        # [-1, -1, 0], [-1, -1, -1] (outputs)
+        self.attrs['axis'] = -1
+        self.attrs['use_softmax'] = True
+        self.attrs['soft_label'] = False
+        self.softmax_out_spec.set_dims_mapping([-1, -1, 0])
+        self.loss_spec.set_dims_mapping([-1, -1, -1])
+
+        result_dist_attrs = self.rule1.infer_backward(
+            [self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
+            [self.softmax_out_spec, self.loss_spec],
+            self.attrs,
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1, -1])
+
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0]
+        )  # softmax output
+        self.assertEqual(
+            infered_output_dist_attrs[1].dims_mapping, [-1, -1, -1]
+        )  # loss
+
+        # GPT MP-DP case
+        # [-1, -1, 0], [1, -1, -1] (outputs) -->
+        # [1, -1, 0], [1, -1, -1], (inputs)
+        # [1, -1, 0], [1, -1, -1] (outputs)
+        self.attrs['axis'] = -1
+        self.attrs['use_softmax'] = True
+        self.attrs['soft_label'] = False
+        self.softmax_out_spec.set_dims_mapping([-1, -1, 0])
+        self.loss_spec.set_dims_mapping([1, -1, -1])
+
+        result_dist_attrs = self.rule1.infer_backward(
+            [self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
+            [self.softmax_out_spec, self.loss_spec],
+            self.attrs,
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, -1, -1])
+
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]
+        )  # softmax output
+        self.assertEqual(
+            infered_output_dist_attrs[1].dims_mapping, [1, -1, -1]
+        )  # loss
+
+        # Soft Label, normalized axis = 1
+        # [1, -1, 0], [1, -1, -1] (outputs) -->
+        # [1, -1, 0], [1, -1, 0], (inputs)
+        # [1, -1, 0], [1, -1, 0] (outputs)
+        self.attrs['axis'] = 1
+        self.attrs['use_softmax'] = True
+        self.attrs['soft_label'] = True
+        self.softmax_out_spec.set_dims_mapping([1, -1, 0])
+        self.loss_spec.set_dims_mapping([1, -1, -1])
+        result_dist_attrs = self.rule1.infer_backward(
+            [self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
+            [self.softmax_out_spec, self.loss_spec],
+            self.attrs,
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
+        self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, -1, 0])
+
+        self.assertEqual(
+            infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]
+        )  # softmax output
+        self.assertEqual(
+            infered_output_dist_attrs[1].dims_mapping, [1, -1, 0]
+        )  # loss
+
 
 if __name__ == "__main__":
     unittest.main()