Unverified commit 1642e84b, authored by Yichen Zhang, committed by GitHub

add cross_entropy backward rule (#56507)

Parent commit: 4faac179
@@ -34,7 +34,7 @@ CrossEntropyWithSoftmaxSPMDRule::InferForward(
          input_specs_size));
  auto x_shape = input_specs[0].shape();
-  int x_ndim = static_cast<int>(x_shape.size());
  int x_ndim = x_shape.size();
  auto x_dist_attr_src = input_specs[0].dist_attr();
  std::vector<int64_t> x_dims_mapping_src = x_dist_attr_src.dims_mapping();
@@ -173,10 +173,116 @@ CrossEntropyWithSoftmaxSPMDRule::InferForward(
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
CrossEntropyWithSoftmaxSPMDRule::InferBackward(
    const std::vector<DistTensorSpec>& input_specs,
    const std::vector<DistTensorSpec>& output_specs,
    const paddle::framework::AttributeMap& attrs) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "InferBackward of CrossEntropyWithSoftmaxSPMDRule is NOT implemented "
-      "yet."));
  // step0: verify input args based on cross_entropy_with_softmax logic
  int64_t ninputs = input_specs.size();
  int64_t noutputs = output_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
2,
phi::errors::InvalidArgument("The size of InputSpec of cross entropy "
"with softmax should be 2, but got [%d].",
ninputs));
PADDLE_ENFORCE_EQ(
noutputs,
2,
phi::errors::InvalidArgument("The size of OutputSpec of cross entropy "
"with softmax should be 2, but got [%d].",
noutputs));
VerifySpecs(output_specs, "cross_entropy_with_softmax_backward");
// step1: build Einsum Notation
std::vector<int64_t> x_shape = input_specs[0].shape();
int64_t x_ndim = x_shape.size();
std::vector<int64_t> label_shape = input_specs[1].shape();
int axis = ExtractAttr<int>("axis", attrs);
int ignore_index = ExtractAttr<int>("ignore_index", attrs);
bool numeric_stable_mode = ExtractAttr<bool>("numeric_stable_mode", attrs);
bool use_softmax = ExtractAttr<bool>("use_softmax", attrs);
bool soft_label = ExtractAttr<bool>("soft_label", attrs);
// normalize axis
if (axis < 0) {
axis = x_ndim + axis;
}
std::string alphabet =
    "abcdefghijlmnopqrstuvwxyz";  // 'k' is reserved for the softmax normalization axis
std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
x_axes[axis] = 'k';
std::string label_axes = x_axes;
if (!soft_label) {
label_axes[axis] = '1';
}
std::string loss_axes = x_axes;
loss_axes[axis] = '1';
// optional output
std::string softmax_out_axes;
if (use_softmax) {
softmax_out_axes = x_axes;
} else {
softmax_out_axes = "";
}
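  // At this point the notation ties x and softmax_out to the same axes, while
  // loss_axes (and, for hard labels, label_axes) replace the softmax axis with
  // '1', i.e. those tensors can never be sharded along the normalized axis.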
// step2: Sharding Propagation
// step2.1 merge output dims mappings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info =
GetAxesDimsMappingPair({softmax_out_axes, loss_axes}, output_specs);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);
// step2.2 infer inputs' dims mappings from merged dims mapping
std::vector<TensorDistAttr> input_dist_attrs;
input_dist_attrs.emplace_back(input_specs[0].dist_attr());
input_dist_attrs.emplace_back(input_specs[1].dist_attr());
// infer and set input X's dims mapping
input_dist_attrs[0].set_dims_mapping(
GetDimsMappingForAxes(x_axes, axis_to_dim_map));
// infer and set input label's dims mapping
input_dist_attrs[1].set_dims_mapping(
GetDimsMappingForAxes(label_axes, axis_to_dim_map));
// step2.3 update outputs' dims mappings with merged dims mapping
std::vector<TensorDistAttr> output_dist_attrs;
output_dist_attrs.emplace_back(output_specs[0].dist_attr()); // softmax_out
output_dist_attrs.emplace_back(output_specs[1].dist_attr()); // loss
output_dist_attrs[0].set_dims_mapping(
GetDimsMappingForAxes(softmax_out_axes, axis_to_dim_map));
output_dist_attrs[1].set_dims_mapping(
GetDimsMappingForAxes(loss_axes, axis_to_dim_map));
// step3: Handle partial state (TODO)
VLOG(4) << "CrossEntropyWithSoftmaxSPMDRule InferBackward: "
<< "axis: " << axis << ", ignore_index: " << ignore_index
<< ", numeric_stable_mode: "
<< (numeric_stable_mode ? "true" : "false")
<< ", use_softmax: " << use_softmax
<< ", soft_label: " << (soft_label ? "true" : "false");
VLOG(4) << "Einsum notation: [" << x_axes << "," << label_axes << " --> "
<< softmax_out_axes << "," << loss_axes << "]. (inputs --> outputs)";
for (int64_t i = 0; i < noutputs; i++) {
VLOG(4) << "Output" << std::to_string(i) << ": "
<< "shape: [" << str_join(output_specs[i].shape())
<< "], src_dims_mapping: ["
<< str_join(output_specs[i].dims_mapping())
<< "], dst_dims_mapping: ["
<< str_join(output_dist_attrs[i].dims_mapping()) << "]";
}
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << ": "
<< "shape: [" << str_join(input_specs[i].shape())
<< "], infered_dims_mapping: ["
<< str_join(input_dist_attrs[i].dims_mapping()) << "]";
}
VLOG(4) << std::endl;
// according to the phi api implementation, the softmax_out tensor will always
// be generated no matter the value of use_softmax.
return {input_dist_attrs, output_dist_attrs};
}

}  // namespace auto_parallel
......
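The dims-mapping inference in InferBackward reduces to a per-axis merge of the two outputs' shardings followed by a per-axis lookup for each input. Below is a minimal sketch of that idea only: merge_axes and dims_mapping_for are made-up stand-ins for ShardingMergeForTensors and GetDimsMappingForAxes, and the axis letters "ijk"/"ij1" are illustrative rather than what GetBroadcastAxes actually emits.

# Minimal sketch of the backward sharding inference, with simplified conflict
# handling; helper names and axis letters are illustrative, not Paddle APIs.
def merge_axes(axes_list, dims_mappings):
    """Merge the outputs' shardings per einsum axis (-1 means replicated)."""
    axis_to_dim = {}
    for axes, mapping in zip(axes_list, dims_mappings):
        for axis, mesh_dim in zip(axes, mapping):
            if axis == '1':                      # size-1 axis, never sharded
                continue
            if axis_to_dim.get(axis, -1) == -1:  # keep an already-sharded axis
                axis_to_dim[axis] = mesh_dim
    return axis_to_dim


def dims_mapping_for(axes, axis_to_dim):
    """Read a tensor's dims mapping back from the merged axis-to-dim map."""
    return [-1 if axis == '1' else axis_to_dim.get(axis, -1) for axis in axes]


# GPT MP-DP case from the unit test: softmax_out is [-1, -1, 0], loss is [1, -1, -1]
merged = merge_axes(["ijk", "ij1"], [[-1, -1, 0], [1, -1, -1]])
print(dims_mapping_for("ijk", merged))  # x / softmax_out -> [1, -1, 0]
print(dims_mapping_for("ij1", merged))  # label / loss    -> [1, -1, -1]

The printed mappings match the ones asserted for that case in test_cross_entropy_with_softmax_infer_backward below; the real rule additionally verifies the specs and covers the soft_label / use_softmax variants.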
@@ -27,7 +27,8 @@ class CrossEntropyWithSoftmaxSPMDRule : public SPMDRuleBase {
                const paddle::framework::AttributeMap& attrs) override;
  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-  InferBackward(const std::vector<DistTensorSpec>& output_specs,
  InferBackward(const std::vector<DistTensorSpec>& input_specs,
                const std::vector<DistTensorSpec>& output_specs,
                const paddle::framework::AttributeMap& attrs) override;
};
}  // namespace auto_parallel
......
@@ -39,6 +39,9 @@ class TestCrossEntropyWithSoftmaxSPMDRule(unittest.TestCase):
            label_shape, label_tensor_dist_attr
        )
self.loss_spec = DistTensorSpec(self.lable_dist_tensor_spec)
self.softmax_out_spec = DistTensorSpec(self.x_dist_tensor_spec)
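        # Note: softmax_out reuses x's spec and loss reuses label's spec, so the
        # backward cases below only need to reset their dims mappings.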
        self.attrs = {
            'ignore_index': -1,
            'axis': -1,
@@ -147,6 +150,122 @@ class TestCrossEntropyWithSoftmaxSPMDRule(unittest.TestCase):
        )
        self.attrs['axis'] = -1
def test_cross_entropy_with_softmax_infer_backward(self):
# GPT DP case
# [1, 0, -1], [1, 0, -1] (outputs) -->
# [1, 0, -1], [1, 0, -1], (inputs)
# [1, 0, -1], [1, 0, -1] (outputs)
self.attrs['axis'] = -1
self.attrs['use_softmax'] = True
self.attrs['soft_label'] = False
self.softmax_out_spec.set_dims_mapping([1, 0, -1])
self.loss_spec.set_dims_mapping([1, 0, -1])
result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
[self.softmax_out_spec, self.loss_spec],
self.attrs,
)
self.assertEqual(len(result_dist_attrs), 2)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(infered_input_dist_attrs), 2)
self.assertEqual(len(infered_output_dist_attrs), 2)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, 0, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, 0, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, 0, -1]
) # softmax output
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [1, 0, -1]
) # loss
# GPT MP case, shard normalized axis
# [-1, -1, 0], [-1, -1, -1] (outputs) -->
# [-1, -1, 0], [-1, -1, -1], (inputs)
# [-1, -1, 0], [-1, -1, -1] (outputs)
self.attrs['axis'] = -1
self.attrs['use_softmax'] = True
self.attrs['soft_label'] = False
self.softmax_out_spec.set_dims_mapping([-1, -1, 0])
self.loss_spec.set_dims_mapping([-1, -1, -1])
result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
[self.softmax_out_spec, self.loss_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0]
) # softmax output
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, -1, -1]
) # loss
# GPT MP-DP case
# [-1, -1, 0], [1, -1, -1] (outputs) -->
# [1, -1, 0], [1, -1, -1], (inputs)
# [1, -1, 0], [1, -1, -1] (outputs)
self.attrs['axis'] = -1
self.attrs['use_softmax'] = True
self.attrs['soft_label'] = False
self.softmax_out_spec.set_dims_mapping([-1, -1, 0])
self.loss_spec.set_dims_mapping([1, -1, -1])
result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
[self.softmax_out_spec, self.loss_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, -1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]
) # softmax output
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [1, -1, -1]
) # loss
# Soft Label, normalized axis = 1
# [1, -1, 0], [1, -1, -1] (outputs) -->
# [1, -1, 0], [1, -1, 0], (inputs)
# [1, -1, 0], [1, -1, 0] (outputs)
self.attrs['axis'] = 1
self.attrs['use_softmax'] = True
self.attrs['soft_label'] = True
self.softmax_out_spec.set_dims_mapping([1, -1, 0])
self.loss_spec.set_dims_mapping([1, -1, -1])
result_dist_attrs = self.rule1.infer_backward(
[self.x_dist_tensor_spec, self.lable_dist_tensor_spec],
[self.softmax_out_spec, self.loss_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, 0])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1, -1, 0])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [1, -1, 0]
) # softmax output
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [1, -1, 0]
) # loss
if __name__ == "__main__":
    unittest.main()