Unverified commit 35b72e87, authored by Yichen Zhang, committed by GitHub

[Semi-Auto] Add reduction spmd rule (#54991)

* add reduction spmd rule for auto parallel

* fix the logic of handling partial

* fix code style

* fix the partial handling
Parent a02e6dbd
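A minimal usage sketch of the new rule, mirroring the unit test added in this PR (it assumes the same static auto-parallel helpers the test imports): reducing over a sharded axis yields a replicated output that is partial on the corresponding mesh dim.

from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto

# a [64, 32] tensor sharded on dim 0 across a 4-process mesh
dist_attr = TensorDistAttr()
dist_attr.process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
dist_attr.dims_mapping = [0, -1]
x_spec = DistTensorSpec([64, 32], dist_attr)

# max-reduce over the sharded dim 0: the output drops that dim,
# becomes replicated ([-1]), and is partial on mesh dim 0
rule = get_spmd_rule("max")
attrs = {'keep_dim': False, 'axis': [0], 'linearity': False}
input_attrs, output_attrs = rule.infer_forward([x_spec], attrs)
print(output_attrs[0].dims_mapping)  # [-1]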
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
#include <algorithm>
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Reduction Logic
int64_t ninputs = input_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in reduction must "
"be equal to 1, but got [%d].",
ninputs));
VerifySpecs(input_specs, "reduction");
// step1: Build Einsum Notation
bool keep_dim = ExtractAttr<bool>("keep_dim", attrs);
std::vector<int64_t> reduce_dims =
ExtractAttr<std::vector<int64_t>>("axis", attrs);
std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
// get einsum notation for input
int64_t ndim = input_specs[0].shape().size();
std::vector<std::string> input_axes_vec;
std::string input_axes = alphabet.substr(0, ndim);
input_axes_vec.emplace_back(input_axes);
// get einsum notation for output
for (int64_t i = 0, n = reduce_dims.size(); i < n; ++i) {
// convert the negative dim value to normal dim value
if (reduce_dims[i] < 0) {
reduce_dims[i] = ndim + reduce_dims[i];
}
}
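// e.g. with ndim = 3, reduce_dims = [-1] is normalized to [2]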
std::string output_axes = "";
for (int64_t i = 0; i < ndim; i++) {
std::vector<int64_t>::iterator iter =
std::find(reduce_dims.begin(), reduce_dims.end(), i);
if (iter != reduce_dims.end()) {
// if i is reduce dim, the corresponding input axis
// will not be appended at the end of output_axes
if (keep_dim) {
output_axes.append(1, '1');
}
} else {
// otherwise, the corresponding input axis
// will be appended at the end of output_axes
output_axes.append(1, input_axes[i]);
}
}
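// e.g. input_axes = "abc" with reduce_dims = [1] gives output_axes = "ac"
// when keep_dim is false, and "a1c" when keep_dim is true ('1' marks a
// kept size-1 axis that is not mapped to any mesh dim)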
// step2: Sharding Propagation
// step2.1: merge input shardings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);
// step2.2: infer output dims_mapping from merged input dims_mapping
std::vector<int64_t> output_dims_mapping =
GetDimsMappingForAxes(output_axes, axis_to_dim_map);
// initialize output dist_attr's process_mesh, batch_dim and dynamic dims with
// input dist_attr.
TensorDistAttr output_dist_attr =
CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(output_dims_mapping);
std::vector<TensorDistAttr> output_dist_attrs;
output_dist_attrs.emplace_back(output_dist_attr);
// step2.3: handle partial
// Step2.3.1 Output Partial
std::vector<int64_t> partial_on_dims =
ResoluteOutputPartialDimension(axis_to_dim_map, output_axes);
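// mesh dims that shard a reduced input axis are absent from output_axes,
// so the output is partial on them: each rank holds a partial result that
// needs a cross-rank reduction (e.g. allreduce) to become complete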
// Step2.3.2 handle input tensor partial (TODO)
// If the op is a linear op, i.e. `linearity` is true, it supports
// partial inputs. Otherwise the input cannot be partial on the
// reduced axes, and we should reshard the input when those axes
// are partial.
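// note: the input dist_attr is returned unchanged, so the src and dst
// dims_mapping logged below are identical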
VLOG(4) << "ReductionSPMDRule InferForward: ";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(input_specs[i].shape()) << "] "
<< "src_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
<< "] "
<< "dst_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
<< "]";
}
VLOG(4) << "Output dims_mapping: [" + str_join(output_dims_mapping) + "] "
<< "partial_on_dims: [" + str_join(partial_on_dims) + "]\n\n";
return {{input_specs[0].dist_attr()}, output_dist_attrs};
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
ReductionSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of ReductionSPMDRule is NOT implemented yet."));
return {};
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iterator>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
class ReductionSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
// TODO(ljz) Automatic this process in cmake file. // TODO(ljz) Automatic this process in cmake file.
...@@ -28,6 +29,18 @@ namespace auto_parallel { ...@@ -28,6 +29,18 @@ namespace auto_parallel {
// matmul rule // matmul rule
REGISTER_SPMD_RULE(matmul, MatmulSPMDRule); REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
// reduction rules
REGISTER_SPMD_RULE(all, ReductionSPMDRule);
REGISTER_SPMD_RULE(amax, ReductionSPMDRule);
REGISTER_SPMD_RULE(amin, ReductionSPMDRule);
REGISTER_SPMD_RULE(any, ReductionSPMDRule);
REGISTER_SPMD_RULE(frobenius_norm, ReductionSPMDRule);
REGISTER_SPMD_RULE(max, ReductionSPMDRule);
REGISTER_SPMD_RULE(mean, ReductionSPMDRule);
REGISTER_SPMD_RULE(min, ReductionSPMDRule);
REGISTER_SPMD_RULE(prod, ReductionSPMDRule);
REGISTER_SPMD_RULE(sum, ReductionSPMDRule);
// elementwise rule
REGISTER_SPMD_RULE(add, ElementwiseSPMDRule);
REGISTER_SPMD_RULE(assign, ElementwiseSPMDRule);
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
)
from paddle.distributed.fleet import auto
class TestReductionSPMDRule(unittest.TestCase):
"""
Unit tests for reduction spmd rule.
"""
def setUp(self):
self.rule = get_spmd_rule("max")
x_shape = [64, 32]
process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.dims_mapping = [1, 0]
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
self.attrs = {
'keep_dim': False,
'axis': [0],
'linearity': False,
}
def test_single_mesh_dim(self):
# reduce on dim 0, keep_dim = false
# [0, -1] --> [0, -1], [-1], partial_on_dim:[0]
self.attrs['keep_dim'] = False
self.attrs['axis'] = [0]
self.x_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
# reduce on dim 0, keep_dim = true
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
self.attrs['keep_dim'] = True
self.attrs['axis'] = [0]
self.x_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
# reduce on dim 1, keep_dim = false
# [0, -1] --> [0, -1], [0], partial_on_dim:[]
self.attrs['keep_dim'] = False
self.attrs['axis'] = [1]
self.x_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
# reduce on dim 1, keep_dim = true
# [0, -1] --> [0, -1], [0, -1], partial_on_dim:[]
self.attrs['keep_dim'] = True
self.attrs['axis'] = [1]
self.x_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1])
# reduce on dim 0 and 1, keep_dim = false
# [0, -1] --> [0, -1], [], partial_on_dim:[0]
self.attrs['keep_dim'] = False
self.attrs['axis'] = [0, 1]
self.x_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [])
# reduce on dim 0 and 1, keep_dim = true
# [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0]
self.attrs['keep_dim'] = True
self.attrs['axis'] = [0, 1]
self.x_dist_tensor_spec.set_dims_mapping([0, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
def test_multi_mesh_dim(self):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
self.x_dist_tensor_spec.set_process_mesh(process_mesh)
self.x_dist_tensor_spec.shape = [96, 24, 48]
# reduce on dim 1, 2, keep_dim = false
# [0, -1, -1] --> [0, -1, -1], [0], partial_on_dim:[]
self.attrs['keep_dim'] = False
self.attrs['axis'] = [1, 2]
self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 1)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
# reduce on dim 1, 2, keep_dim = false
# [-1, 0, 1] --> [-1, 0, 1], [-1], partial_on_dim:[0, 1]
self.attrs['keep_dim'] = False
self.attrs['axis'] = [1, 2]
self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, 1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
# reduction on dim 1, 2, keep_dim = false
# [1, -1, -1] --> [1, -1, -1], [1], partial_on_dim:[]
self.attrs['keep_dim'] = False
self.attrs['axis'] = [1, 2]
self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1])
# reduction on dim 1, 2, keep_dim = false
# [0, 1, -1] --> [0, 1, -1], [0], partial_on_dim:[1]
self.attrs['keep_dim'] = False
self.attrs['axis'] = [1, 2]
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
# reduction on dim 1, 2, keep_dim = true
# [0, 1, -1] --> [0, 1, -1], [0, -1, -1], partial_on_dim:[1]
self.attrs['keep_dim'] = True
self.attrs['axis'] = [1, 2]
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
if __name__ == "__main__":
unittest.main()