Unverified commit cf76e7ae, authored by Yichen Zhang, committed by GitHub

[Semi-Auto] add split spmd rule (#55397)

* add split spmd rule

* add pytest in cmake file

* small fix
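
For context, a minimal usage sketch of the new rule. It mirrors the unit test added below, so the Python API calls (get_spmd_rule, DistTensorSpec, TensorDistAttr, infer_forward) are taken from that test rather than invented:

from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto

# A [64, 32, 48] tensor on a 4-process mesh, sharded along axis 0.
dist_attr = TensorDistAttr()
dist_attr.process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
dist_attr.dims_mapping = [0, -1, -1]
x_spec = DistTensorSpec([64, 32, 48], dist_attr)

# Split into 2 parts along axis 1; the sharding on axis 0 is kept,
# because only the split axis itself must stay replicated.
rule = get_spmd_rule("split_with_num")
input_attrs, output_attrs = rule.infer_forward([x_spec], {"num": 2, "axis": 1})
# input_attrs[0].dims_mapping  -> [0, -1, -1]
# output_attrs[i].dims_mapping -> [0, -1, -1] for both outputs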
Parent 1f3e6ec4
@@ -182,8 +182,8 @@ TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr) {
void VerifySpecs(const std::vector<DistTensorSpec>& specs,
const std::string& op_name) {
for (size_t i = 0, n = specs.size(); i < n; ++i) {
-    std::vector<int64_t> shape = specs[i].shape();
-    std::vector<int64_t> dims_mapping = specs[i].dims_mapping();
+    const std::vector<int64_t>& shape = specs[i].shape();
+    const std::vector<int64_t>& dims_mapping = specs[i].dims_mapping();
PADDLE_ENFORCE_EQ(shape.size(),
dims_mapping.size(),
phi::errors::InvalidArgument(
@@ -23,6 +23,7 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
// TODO(ljz) Automate this process in the cmake file.
namespace paddle {
@@ -150,6 +151,10 @@ REGISTER_SPMD_RULE(log_softmax, SoftmaxSPMDRule);
REGISTER_SPMD_RULE(cross_entropy_with_softmax, CrossEntropyWithSoftmaxSPMDRule);
REGISTER_SPMD_RULE(softmax_with_cross_entropy, CrossEntropyWithSoftmaxSPMDRule);
+// split rule: both split (with the sections attribute) and
+// split_with_num (with the num attribute) map to the same rule
+REGISTER_SPMD_RULE(split, SplitSPMDRule);
+REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc (new file):
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
#include <algorithm>
#include <typeinfo>
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SplitSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Split Logic
int64_t ninputs = input_specs.size();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in split must "
"be equal to 1, but got [%d].",
ninputs));
VerifySpecs(input_specs, "split");
// step1: Build Einsum Notation
int64_t ndim = input_specs[0].shape().size();
int64_t noutput = 0;
// the split api uses either num or sections as its attribute
if (attrs.find("num") != attrs.end()) {
noutput = ExtractAttr<int64_t>("num", attrs);
} else if (attrs.find("sections") != attrs.end()) {
std::vector<int64_t> sections =
ExtractAttr<std::vector<int64_t>>("sections", attrs);
noutput = sections.size();
}
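  // e.g. num = 2 gives noutput = 2; sections = [15, 16, 17] gives noutput = 3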
int64_t axis = ExtractAttr<int>("axis", attrs);
if (axis < 0) {
axis += ndim;
}
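  // e.g. with ndim = 3, axis = -2 is normalized to axis = 1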
  std::string alphabet = "abcdefghijlmnopqrstuvwxyz";  // note: 'k' is omitted
  // get the einsum notation for the input; the special notation 'k'
  // (excluded from the alphabet above) marks the split axis in the input
std::vector<std::string> input_axes_vec;
std::string input_axes = alphabet.substr(0, ndim);
input_axes[axis] = 'k';
input_axes_vec.emplace_back(input_axes);
// get einsum notation for output
std::string output_axes(input_axes);
  // the split axis cannot be sharded; mark it with the special
  // notation '1' so that its dim mapping is set to -1.
output_axes[axis] = '1';
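  // e.g. for ndim = 3 and axis = 1: input_axes = "akc", output_axes = "a1c"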
// step2: Sharding Propagation
// step2.1: merge input shardings
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);
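  // e.g. merging {"akc": [0, -1, -1]} yields axis_to_dim_map = {a: 0, k: -1, c: -1}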
// step2.2: infer the output dims mapping from the merged input dims mapping
std::vector<int64_t> output_dims_mapping =
GetDimsMappingForAxes(output_axes, axis_to_dim_map);
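  // e.g. output_axes "a1c" with the map above gives
  // output_dims_mapping = [0, -1, -1]; the '1' notation always maps to -1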
// get the dist attributes for all outputs; the
// dist attributes are the same for every output.
std::vector<TensorDistAttr> output_dist_attrs;
for (int64_t i = 0; i < noutput; i++) {
output_dist_attrs.emplace_back(
CopyTensorDistAttrForOutput(input_specs[0].dist_attr()));
output_dist_attrs[i].set_dims_mapping(output_dims_mapping);
}
// step2.3: get the new dist attribute for the input. The split
// axis cannot be sharded; if it is sharded, set it to replicated (-1).
std::vector<TensorDistAttr> new_input_dist_attrs;
new_input_dist_attrs.emplace_back(input_specs[0].dist_attr());
std::vector<int64_t> new_input_dims_mapping(input_specs[0].dims_mapping());
new_input_dims_mapping[axis] = -1;
new_input_dist_attrs[0].set_dims_mapping(new_input_dims_mapping);
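  // e.g. an input sharded on the split axis, dims_mapping = [-1, -1, 0]
  // with axis = 2, becomes fully unsharded: [-1, -1, -1]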
// step2.4: handle input tensor partial (TODO)
VLOG(4) << "SplitSPMDRule InferForward: ";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(input_specs[i].shape()) << "] "
<< "einsum_notation: " << input_axes << " src_dims_mapping: ["
<< str_join(input_specs[i].dims_mapping()) << "] "
<< "dst_dims_mapping: ["
<< str_join(new_input_dist_attrs[i].dims_mapping()) << "]";
}
for (int64_t i = 0; i < noutput; i++) {
VLOG(4) << "Output" << std::to_string(i) << " dims_mapping: ["
<< str_join(output_dims_mapping) << "]";
}
return {new_input_dist_attrs, output_dist_attrs};
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SplitSPMDRule::InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of SplitPMDRule is NOT implemented yet."));
return {};
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h (new file):
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iterator>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
class SplitSPMDRule : public SPMDRuleBase {
public:
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
@@ -8,6 +8,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_embedding_rule MODULES test_embedding_rule)
py_test_modules(test_replicated_rule MODULES test_replicated_rule)
py_test_modules(test_softmax_rule MODULES test_softmax_rule)
+py_test_modules(test_split_rule MODULES test_split_rule)
# End of unittests WITH single card WITHOUT timeout
endif()
test_split_rule.py (new file):
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
)
from paddle.distributed.fleet import auto
class TestSplitSPMDRule(unittest.TestCase):
    """
    Unit tests for the split SPMD rule.
    """
def setUp(self):
self.rule = get_spmd_rule("split")
x_shape = [64, 32, 48]
process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
x_tensor_dist_attr = TensorDistAttr()
        x_tensor_dist_attr.dims_mapping = [-1, -1, -1]
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
        self.attrs = {
            'num': 2,
            'axis': 1,
        }
def test_single_mesh_dim(self):
# num_or_sections = 2, axis = 1
# [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1]
self.rule = get_spmd_rule("split_with_num")
self.attrs = {}
self.attrs['num'] = 2
self.attrs['axis'] = 1
self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 2)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, -1, -1])
# num_or_sections = [15, 16, 17], axis = 2
# [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1], [0, -1, -1]
self.rule = get_spmd_rule("split")
self.attrs = {}
self.attrs['sections'] = [15, 16, 17]
self.attrs['axis'] = 2
self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 3)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[2].dims_mapping, [0, -1, -1])
# num_or_sections = [15, 16, 17], axis = 2
        # [-1, -1, 0] --> [-1, -1, -1], [-1, -1, -1], [-1, -1, -1], [-1, -1, -1]
self.attrs = {}
self.attrs['sections'] = [15, 16, 17]
self.attrs['axis'] = 2
self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 3)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1])
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[2].dims_mapping, [-1, -1, -1]
)
# num_or_sections = 2, axis = -2
# [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1]
self.rule = get_spmd_rule("split_with_num")
self.attrs = {}
self.attrs['num'] = 2
self.attrs['axis'] = -2
self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 2)
self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, -1])
self.assertEqual(infered_output_dist_attrs[1].dims_mapping, [0, -1, -1])
def test_multi_mesh_dim(self):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
self.x_dist_tensor_spec.set_process_mesh(process_mesh)
self.x_dist_tensor_spec.shape = [96, 32, 48, 24]
# num_or_sections = 3, axis = -1
# [0, 1, -1, -1] --> [0, 1, -1, -1], [0, 1, -1, -1], [0, 1, -1, -1], [0, 1, -1, -1]
self.rule = get_spmd_rule("split_with_num")
self.attrs = {}
self.attrs['num'] = 3
self.attrs['axis'] = -1
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 3)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [0, 1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [0, 1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[2].dims_mapping, [0, 1, -1, -1]
)
# num_or_sections = [32, 32, 32], axis = 0
# [0, 1, -1, -1] --> [-1, 1, -1, -1], [-1, 1, -1, -1], [-1, 1, -1, -1], [-1, 1, -1, -1]
self.rule = get_spmd_rule("split")
self.attrs = {}
self.attrs['sections'] = [32, 32, 32]
self.attrs['axis'] = 0
self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
result_dist_attrs = self.rule.infer_forward(
[self.x_dist_tensor_spec], self.attrs
)
infered_input_dist_attrs = result_dist_attrs[0]
infered_output_dist_attrs = result_dist_attrs[1]
self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(infered_input_dist_attrs), 1)
self.assertEqual(len(infered_output_dist_attrs), 3)
self.assertEqual(
infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[1].dims_mapping, [-1, 1, -1, -1]
)
self.assertEqual(
infered_output_dist_attrs[2].dims_mapping, [-1, 1, -1, -1]
)
if __name__ == "__main__":
unittest.main()