diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc
new file mode 100644
index 0000000000000000000000000000000000000000..851313bedfa018135f4c7130b3bf493d1af589de
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.cc
@@ -0,0 +1,131 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
+#include <algorithm>
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+using phi::distributed::auto_parallel::str_join;
+
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+ReductionSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
+                                const paddle::framework::AttributeMap& attrs) {
+  // step0: Verify Input Args Based on Reduction Logic
+  int64_t ninputs = input_specs.size();
+  PADDLE_ENFORCE_EQ(
+      ninputs,
+      1,
+      phi::errors::InvalidArgument("The size of InputSpec in reduction must "
+                                   "be equal to 1, but got [%d].",
+                                   ninputs));
+  VerifySpecs(input_specs, "reduction");
+
+  // step1: Build Einsum Notation
+  bool keep_dim = ExtractAttr<bool>("keep_dim", attrs);
+  std::vector<int64_t> reduce_dims =
+      ExtractAttr<std::vector<int64_t>>("axis", attrs);
+  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
+
+  // get einsum notation for input
+  int64_t ndim = input_specs[0].shape().size();
+  std::vector<std::string> input_axes_vec;
+  std::string input_axes = alphabet.substr(0, ndim);
+  input_axes_vec.emplace_back(input_axes);
+
+  // get einsum notation for output
+  for (int64_t i = 0, n = reduce_dims.size(); i < n; ++i) {
+    // convert negative dim values to their normalized form
+    if (reduce_dims[i] < 0) {
+      reduce_dims[i] = ndim + reduce_dims[i];
+    }
+  }
+  std::string output_axes = "";
+  for (int64_t i = 0; i < ndim; i++) {
+    std::vector<int64_t>::iterator iter =
+        std::find(reduce_dims.begin(), reduce_dims.end(), i);
+    if (iter != reduce_dims.end()) {
+      // if i is a reduced dim, the corresponding input axis
+      // will not be appended at the end of output_axes
+      if (keep_dim) {
+        output_axes.append(1, '1');
+      }
+    } else {
+      // otherwise, the corresponding input axis
+      // will be appended at the end of output_axes
+      output_axes.append(1, input_axes[i]);
+    }
+  }
+
+  // step2: Sharding Propagation
+  // step2.1: merge input shardings
+  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
+  axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs);
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors(axes_sharding_info);
+
+  // step2.2: infer output dims mapping from merged input dims mapping
+  std::vector<int64_t> output_dims_mapping =
+      GetDimsMappingForAxes(output_axes, axis_to_dim_map);
+
+  // initialize output dist_attr's process_mesh, batch_dim and dynamic dims
+  // with input dist_attr.
+  TensorDistAttr output_dist_attr =
+      CopyTensorDistAttrForOutput(input_specs[0].dist_attr());
+  output_dist_attr.set_dims_mapping(output_dims_mapping);
+
+  std::vector<TensorDistAttr> output_dist_attrs;
+  output_dist_attrs.emplace_back(output_dist_attr);
+
+  // step2.4: handle partial
+  // Step2.4.1 Output Partial
+  std::vector<int64_t> partial_on_dims =
+      ResoluteOutputPartialDimension(axis_to_dim_map, output_axes);
+
+  // Step2.4.2 handle input tensor partial (TODO)
+  // If the op is a linear op, i.e. `linearity` is true, it supports
+  // the input to be partial. Otherwise, the input cannot be partial
+  // on the reduced axes; we should reshard the input when the
+  // reduced axes are partial.
+  VLOG(4) << "ReductionSPMDRule InferForward: ";
+  for (int64_t i = 0; i < ninputs; i++) {
+    VLOG(4) << "Input" << std::to_string(i) << " shape: ["
+            << str_join(input_specs[i].shape()) << "] "
+            << "src_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
+            << "] "
+            << "dst_dims_mapping: [" << str_join(input_specs[i].dims_mapping())
+            << "]";
+  }
+  VLOG(4) << "Output dims_mapping: [" + str_join(output_dims_mapping) + "] "
+          << "partial_on_dims: [" + str_join(partial_on_dims) + "]\n\n";
+
+  return {{input_specs[0].dist_attr()}, output_dist_attrs};
+}
+
+std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+ReductionSPMDRule::InferBackward(
+    const std::vector<DistTensorSpec>& output_specs,
+    const paddle::framework::AttributeMap& attrs) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "InferBackward of ReductionSPMDRule is NOT implemented yet."));
+
+  return {};
+}
+
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h
new file mode 100644
index 0000000000000000000000000000000000000000..7039529742d4022cf2ff0bf8ca2333e9d3d11ce1
--- /dev/null
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <iterator>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
+
+namespace paddle {
+namespace distributed {
+namespace auto_parallel {
+
+class ReductionSPMDRule : public SPMDRuleBase {
+ public:
+  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+  InferForward(const std::vector<DistTensorSpec>& input_specs,
+               const paddle::framework::AttributeMap& attrs) override;
+
+  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+  InferBackward(const std::vector<DistTensorSpec>& output_specs,
+                const paddle::framework::AttributeMap& attrs) override;
+};
+}  // namespace auto_parallel
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
index f75379cc18950c4019628bee29d76bef38a16396..1ce1dabb0b87506ba12e8f7b368257c93ed26d94 100644
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
+++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
@@ -18,6 +18,7 @@
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
+#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
 
 // TODO(ljz) Automatic this process in cmake file.
@@ -28,6 +29,18 @@ namespace auto_parallel {
 // matmul rule
 REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
 
+// reduction rules
+REGISTER_SPMD_RULE(all, ReductionSPMDRule);
+REGISTER_SPMD_RULE(amax, ReductionSPMDRule);
+REGISTER_SPMD_RULE(amin, ReductionSPMDRule);
+REGISTER_SPMD_RULE(any, ReductionSPMDRule);
+REGISTER_SPMD_RULE(frobenius_norm, ReductionSPMDRule);
+REGISTER_SPMD_RULE(max, ReductionSPMDRule);
+REGISTER_SPMD_RULE(mean, ReductionSPMDRule);
+REGISTER_SPMD_RULE(min, ReductionSPMDRule);
+REGISTER_SPMD_RULE(prod, ReductionSPMDRule);
+REGISTER_SPMD_RULE(sum, ReductionSPMDRule);
+
 // elementwise rule
 REGISTER_SPMD_RULE(add, ElementwiseSPMDRule);
 REGISTER_SPMD_RULE(assign, ElementwiseSPMDRule);
diff --git a/test/auto_parallel/spmd_rules/test_reduction_rule.py b/test/auto_parallel/spmd_rules/test_reduction_rule.py
new file mode 100644
index 0000000000000000000000000000000000000000..a21528e781e66b8e4a50f38a6d22c8b456c84d7e
--- /dev/null
+++ b/test/auto_parallel/spmd_rules/test_reduction_rule.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    DistTensorSpec,
+    TensorDistAttr,
+)
+from paddle.distributed.fleet import auto
+
+
+class TestReductionSPMDRule(unittest.TestCase):
+    """
+    Unit tests for reduction spmd rule.
+ """ + + def setUp(self): + self.rule = get_spmd_rule("max") + + x_shape = [64, 32] + process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [1, 0] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + + self.attrs = { + 'keep_dim': False, + 'axis': [0], + 'linearity': False, + } + + def test_single_mesh_dim(self): + # reduce on dim 0, keep_dim = false + # [0, -1] --> [0, -1], [-1], partial_on_dim:[0] + self.attrs['keep_dim'] = False + self.attrs['axis'] = [0] + self.x_dist_tensor_spec.set_dims_mapping([0, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(result_dist_attrs), 2) + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1]) + + # reduce on dim 0, keep_dim = true + # [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0] + self.attrs['keep_dim'] = True + self.attrs['axis'] = [0] + self.x_dist_tensor_spec.set_dims_mapping([0, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) + + # reduce on dim 1, keep_dim = false + # [0, -1] --> [0, -1], [0], partial_on_dim:[] + self.attrs['keep_dim'] = False + self.attrs['axis'] = [1] + self.x_dist_tensor_spec.set_dims_mapping([0, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0]) + + # reduce on dim 1, keep_dim = true + # [0, -1] --> [0, -1], [0, -1], partial_on_dim:[] + self.attrs['keep_dim'] = True + self.attrs['axis'] = [1] + self.x_dist_tensor_spec.set_dims_mapping([0, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1]) + + # reduce on dim 0 and 1, keep_dim = false + # [0, -1] --> [0, -1], [], partial_on_dim:[0] + self.attrs['keep_dim'] = False + self.attrs['axis'] = [0, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, -1]) + result_dist_attrs = self.rule.infer_forward( + [self.x_dist_tensor_spec], self.attrs + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, []) + + # reduce on dim 0 and 1, keep_dim = true + # [0, -1] --> [0, -1], [-1, -1], partial_on_dim:[0] + self.attrs['keep_dim'] = True + self.attrs['axis'] = [0, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, -1]) + result_dist_attrs = self.rule.infer_forward( 
+            [self.x_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1])
+
+    def test_multi_mesh_dim(self):
+        process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
+        self.x_dist_tensor_spec.set_process_mesh(process_mesh)
+        self.x_dist_tensor_spec.shape = [96, 24, 48]
+
+        # reduce on dim 1, 2, keep_dim = false
+        # [0, -1, -1] --> [0, -1, -1], [0], partial_on_dim:[]
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [1, 2]
+        self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(len(result_dist_attrs), 2)
+        self.assertEqual(len(infered_input_dist_attrs), 1)
+        self.assertEqual(len(infered_output_dist_attrs), 1)
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
+
+        # reduce on dim 1, 2, keep_dim = false
+        # [-1, 0, 1] --> [-1, 0, 1], [-1], partial_on_dim:[0, 1]
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [1, 2]
+        self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 0, 1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1])
+
+        # reduce on dim 1, 2, keep_dim = false
+        # [1, -1, -1] --> [1, -1, -1], [1], partial_on_dim:[]
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [1, 2]
+        self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1])
+
+        # reduce on dim 1, 2, keep_dim = false
+        # [0, 1, -1] --> [0, 1, -1], [0], partial_on_dim:[1]
+        self.attrs['keep_dim'] = False
+        self.attrs['axis'] = [1, 2]
+        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0])
+
+        # reduce on dim 1, 2, keep_dim = true
+        # [0, 1, -1] --> [0, 1, -1], [0, -1, -1], partial_on_dim:[1]
+        self.attrs['keep_dim'] = True
+        self.attrs['axis'] = [1, 2]
+        self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1])
+        result_dist_attrs = self.rule.infer_forward(
+            [self.x_dist_tensor_spec], self.attrs
+        )
+        infered_input_dist_attrs = result_dist_attrs[0]
+        infered_output_dist_attrs = result_dist_attrs[1]
+
+        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1, -1])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
+
+
+if __name__ == "__main__":
+    unittest.main()
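
For reviewers, a minimal Python sketch of the inference this patch implements, kept outside the diff. It assumes plain Python lists for dims mappings; the helper name below is hypothetical and is not part of the patch or of Paddle's API. It mirrors what ReductionSPMDRule::InferForward derives: non-reduced axes keep their mesh dims, reduced axes are dropped (or replicated when keep_dim is true), and any mesh dim a reduced axis was sharded on becomes a partial dim of the output.

# Illustrative sketch only; not part of the patch.
def infer_reduction_dims_mapping(input_dims_mapping, reduce_dims, keep_dim):
    ndim = len(input_dims_mapping)
    # normalize negative reduce dims, as the C++ rule does
    reduce_dims = [d + ndim if d < 0 else d for d in reduce_dims]
    output_dims_mapping = []
    partial_on_dims = []
    for i, mesh_dim in enumerate(input_dims_mapping):
        if i in reduce_dims:
            # a reduced axis is dropped (or kept as replicated when keep_dim is true);
            # if it was sharded, the output becomes partial on that mesh dim
            if keep_dim:
                output_dims_mapping.append(-1)
            if mesh_dim != -1:
                partial_on_dims.append(mesh_dim)
        else:
            output_dims_mapping.append(mesh_dim)
    return output_dims_mapping, partial_on_dims

# e.g. infer_reduction_dims_mapping([0, 1, -1], [1, 2], False) -> ([0], [1]),
# matching the [0, 1, -1] --> [0], partial_on_dim:[1] case in test_multi_mesh_dim above.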