未验证 提交 4faac179 编写于 作者: Y Yichen Zhang 提交者: GitHub

add elementwise backward rule (#56506)

上级 fa1d0e39
......@@ -25,7 +25,7 @@ ElementwiseSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Elementwise Logic
int64_t ninputs = static_cast<int64_t>(input_specs.size());
int64_t ninputs = input_specs.size();
PADDLE_ENFORCE_GT(
ninputs,
0,
......@@ -39,7 +39,7 @@ ElementwiseSPMDRule::InferForward(
std::vector<std::string> input_axes_vec;
int64_t max_ndim = 0;
for (int64_t i = 0; i < ninputs; ++i) {
int64_t ndim = static_cast<int64_t>(input_specs[i].shape().size());
int64_t ndim = input_specs[i].shape().size();
if (ndim > max_ndim) {
max_ndim = ndim;
}
......@@ -49,7 +49,7 @@ ElementwiseSPMDRule::InferForward(
std::vector<int64_t> broadcast_axis_count(max_ndim, 0);
for (int64_t i = 0; i < ninputs; ++i) {
std::vector<int64_t> shape = input_specs[i].shape();
int64_t ndim = static_cast<int64_t>(shape.size());
int64_t ndim = shape.size();
int64_t start_dim = max_ndim - ndim;
std::string axes_notation = GetBroadcastAxes(ndim, max_ndim, alphabet);
if (ninputs > 1) {
......@@ -108,8 +108,8 @@ ElementwiseSPMDRule::InferForward(
new_input_dist_attrs.emplace_back(dist_attr);
}
// step2.4: handle partial
// Step2.3.2 handle input tensor partial (TODO)
// step3: handle partial
// handle input tensor partial (TODO)
VLOG(4) << "ElementwiseSPMDRule InferForward:";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
......@@ -127,12 +127,85 @@ ElementwiseSPMDRule::InferForward(
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
ElementwiseSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of ElementwiseSPMDRule is NOT implemented yet."));
// step0: Verify Input Args Based on Elementwise Logic
int64_t ninputs = input_specs.size();
int64_t noutputs = output_specs.size();
PADDLE_ENFORCE_GT(
ninputs,
0,
phi::errors::InvalidArgument("The size of InputSpec in elementwise must "
"be greater than 0, but got [%d].",
ninputs));
PADDLE_ENFORCE_EQ(
noutputs,
1,
phi::errors::InvalidArgument("The size of OutputSpec in elementwise must "
"be equal to 1, but got [%d].",
noutputs));
VerifySpecs(output_specs, "elementwise_backward");
// step1: Build Einsum Notation
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
std::vector<std::string> input_axes_vec;
int64_t output_ndim = output_specs[0].shape().size();
std::string output_axes =
GetBroadcastAxes(output_ndim, output_ndim, alphabet);
// get einsum notation for each input, deal with broadcast
for (int64_t i = 0; i < ninputs; ++i) {
const std::vector<int64_t>& shape = input_specs[i].shape();
int64_t ndim = shape.size();
int64_t start_dim = output_ndim - ndim;
std::string axes_notation = GetBroadcastAxes(ndim, output_ndim, alphabet);
if (ninputs > 1) {
for (int64_t idim = 0; idim < output_ndim; idim++) {
// deal with the broadcast axes
if (idim >= start_dim && shape[idim - start_dim] == 1) {
// mark the broadcast axis to a special "1"
axes_notation[idim - start_dim] = '1';
}
}
}
input_axes_vec.emplace_back(axes_notation);
}
// step2: Sharding Propogation
// step2.1: get dim mapping for each output axis
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors({{output_axes, output_specs[0].dims_mapping()}});
// step2.2: infer input dims mappings from output dims mapping
// and get the input distributed attributes to return
std::vector<TensorDistAttr> input_dist_attrs;
std::vector<TensorDistAttr> output_dist_attrs;
for (int64_t i = 0; i < ninputs; ++i) {
const DistTensorSpec& spec = input_specs[i];
TensorDistAttr dist_attr(spec.dist_attr());
std::vector<int64_t> dims_mapping =
GetDimsMappingForAxes(input_axes_vec[i], axis_to_dim_map);
dist_attr.set_dims_mapping(dims_mapping);
input_dist_attrs.emplace_back(dist_attr);
}
output_dist_attrs.emplace_back(output_specs[0].dist_attr());
// step3: handle partial (TODO)
VLOG(4) << "ElementwiseSPMDRule InferBackward:";
VLOG(4) << "Output shape: [" << str_join(output_specs[0].shape())
<< "] dims_mapping: [" << str_join(output_specs[0].dims_mapping())
<< "]";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(input_specs[i].shape()) << "] "
<< "dims_mapping: [" << str_join(input_dist_attrs[i].dims_mapping())
<< "]";
}
return {};
return {input_dist_attrs, output_dist_attrs};
}
} // namespace auto_parallel
......
......@@ -32,7 +32,8 @@ class ElementwiseSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
......
......@@ -40,6 +40,8 @@ class TestElementwiseSPMDRule(unittest.TestCase):
y_tensor_dist_attr.process_mesh = process_mesh
self.y_dist_tensor_spec = DistTensorSpec(y_shape, y_tensor_dist_attr)
self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec)
self.attrs = {}
def test_single_mesh_dim(self):
......@@ -87,7 +89,7 @@ class TestElementwiseSPMDRule(unittest.TestCase):
self.x_dist_tensor_spec.set_dims_mapping([-1, 0])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
......@@ -309,6 +311,253 @@ class TestElementwiseSPMDRule(unittest.TestCase):
infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
def test_backward_single_mesh_dim(self):
# [0, -1] --> [0, -1], [0, -1], [0, -1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1])
# [-1, -1] --> [-1, -1], [-1, -1], [-1, -1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([-1, -1])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
# [-1, 0]--> [-1, 0], [-1, 0] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([-1, 0])
result_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0])
def test_backward_single_mesh_dim_broadcast(self):
self.x_dist_tensor_spec.shape = [64, 36, 12]
self.y_dist_tensor_spec.shape = [12]
self.out_dist_tensor_spec.shape = [64, 36, 12]
# [0, -1, -1] --> [0, -1, -1], [-1], [0, -1, -1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([0, -1, -1])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(len(resulted_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 2)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
# [-1, 0, -1] --> [-1, 0, -1], [-1], [-1, 0, -1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1])
self.assertEqual((infered_input_dist_attrs[1].dims_mapping), [-1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1])
# [-1, -1, 0] --> [-1, -1, 0], [0], [-1, -1, 0] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
self.assertEqual((infered_input_dist_attrs[1].dims_mapping), [0])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])
self.x_dist_tensor_spec.shape = [64, 36, 12]
self.y_dist_tensor_spec.shape = [1, 12]
self.out_dist_tensor_spec.shape = [64, 36, 12]
# [-1, 0, -1] --> [-1, 0, -1], [-1, -1], [-1, 0, -1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1])
self.x_dist_tensor_spec.shape = [64, 1, 1, 12]
self.y_dist_tensor_spec.shape = [64, 32, 12]
self.out_dist_tensor_spec.shape = [64, 64, 32, 12]
# [0, -1, -1, -1] --> [0, -1, -1, -1], [-1, -1, -1], [0, -1, -1, -1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([0, -1, -1, -1])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1]
)
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1, -1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, -1, -1, -1]
)
# [-1, 0, -1, -1] --> [-1, -1, -1, -1], [0, -1, -1], [-1, 0, -1, -1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, -1])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1]
)
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -0, -1, -1]
)
def test_backward_multi_mesh_dim(self):
process_mesh = auto.ProcessMesh([[0, 1, 2], [3, 4, 5]])
self.x_dist_tensor_spec.set_process_mesh(process_mesh)
self.y_dist_tensor_spec.set_process_mesh(process_mesh)
self.x_dist_tensor_spec.shape = [96, 24, 48]
self.y_dist_tensor_spec.shape = [96, 24, 48]
self.out_dist_tensor_spec.shape = [96, 24, 48]
# [0, 1, -1] --> [0, 1, -1], [0, 1, -1], [0, 1, -1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([0, 1, -1])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(len(resulted_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 2)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, 1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
def test_backward_multi_mesh_dim_broadcast(self):
process_mesh = auto.ProcessMesh([[0, 1, 2], [3, 4, 5]])
self.x_dist_tensor_spec.set_process_mesh(process_mesh)
self.y_dist_tensor_spec.set_process_mesh(process_mesh)
self.x_dist_tensor_spec.shape = [96, 24, 48]
self.y_dist_tensor_spec.shape = [48]
self.out_dist_tensor_spec.shape = [96, 24, 48]
# [0, -1, 1] --> [0, -1, 1], [1], [0, -1, 1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([0, -1, 1])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(len(resulted_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 2)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, 1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1])
# [0, 1, -1] --> [0, 1, -1], [-1], [0, 1, -1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([0, 1, -1])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [-1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1])
self.x_dist_tensor_spec.shape = [96, 1, 1, 48]
self.y_dist_tensor_spec.shape = [96, 24, 48]
self.out_dist_tensor_spec.shape = [96, 96, 24, 48]
# [-1, 0, -1, 1] --> [-1, -1, -1, 1], [0, -1, 1], [-1, 0, -1, 1] (output --> inputs, output)
self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
resulted_dist_attrs = self.rule.infer_backward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec],
[self.out_dist_tensor_spec],
self.attrs,
)
infered_input_dist_attrs = resulted_dist_attrs[0]
infered_output_dist_attrs = resulted_dist_attrs[1]
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 1]
)
self.assertEqual(infered_input_dist_attrs[1].dims_mapping, [0, -1, 1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 0, -1, 1]
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册