Unverified commit cc231cb3 authored by JZ-LIANG, committed by GitHub

[Semi-Auto] Replicated Parallel Rule (#54810)

* base rule

* add sharding merge

* add sharding axis merge

* define a unified data class for inferring dist_attr

---------
Co-authored-by: Yichen Zhang <zhangyichen03@baidu.com>
Parent: d7e04ed6
cc_library(
spmd_rule
- SRCS common.cc dist_tensor_spec.cc matmul_spmd_rule.cc
+ SRCS common.cc dist_tensor_spec.cc matmul_spmd_rule.cc replicated_spmd_rule.cc
DEPS phi)
......@@ -160,6 +160,14 @@ std::string GetBroadcastAxes(const int64_t& tensor_ndim,
return alphabet.substr(broadcast_ndim - tensor_ndim, tensor_ndim);
}
TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr) {
TensorDistAttr replicated_dist_attr = src_dist_attr;
replicated_dist_attr.clear_annotated();
size_t tensor_ndim = replicated_dist_attr.dims_mapping().size();
replicated_dist_attr.set_dims_mapping(std::vector<int64_t>(tensor_ndim, -1));
return replicated_dist_attr;
}
// SPMDRuleMap
SPMDRuleMap& SPMDRuleMap::Instance() {
static SPMDRuleMap g_spmd_rule_map;
......
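For intuition, here is a minimal Python sketch of what ReplicatedOnMesh does to a tensor's dims_mapping (replicated_on_mesh is a hypothetical stand-in for illustration, not the actual binding; the attribute copy and annotation clearing are elided):

# Illustrative sketch only: mirrors the dims_mapping reset performed by
# ReplicatedOnMesh above, where -1 means "not sharded on any mesh dim".
def replicated_on_mesh(dims_mapping):
    return [-1] * len(dims_mapping)

assert replicated_on_mesh([-1, 1, 0, -1]) == [-1, -1, -1, -1]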
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <iterator>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
......@@ -106,6 +107,10 @@ int64_t ShardingMergeForAxis(const std::string& axis,
const int64_t& mesh_dim1,
const int64_t& mesh_dim2);
// Intended for generating the TensorDistAttr of an output based on the input
// activation TensorDistAttr. The process_mesh, batch_dim, and dynamic_dims
// are copied, annotated is forced to false, and dims_mapping is left
// unset.
TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);
// Resolve the partial mesh dimension of an output tensor, given the
......@@ -124,6 +129,10 @@ std::string GetBroadcastAxes(const int64_t& tensor_ndim,
const int64_t& broadcast_ndim,
const std::string& alphabet);
// Return a NEW TensorDistAttr whose dims_mapping consists entirely of "-1"
// (unsharded).
TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr);
// The static map that stores and initializes all the registered SPMD rules.
class SPMDRuleMap {
public:
......
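To make the CopyTensorDistAttrForOutput contract above concrete, here is a small Python sketch under the assumption that TensorDistAttr behaves like a plain record (DistAttrSketch and copy_dist_attr_for_output are illustrative names, not real bindings):

from dataclasses import dataclass, field

@dataclass
class DistAttrSketch:  # hypothetical stand-in for TensorDistAttr
    process_mesh: object = None
    batch_dim: int = 0
    dynamic_dims: list = field(default_factory=list)
    dims_mapping: list = field(default_factory=list)
    annotated: bool = False

def copy_dist_attr_for_output(src):
    # Copy process_mesh, batch_dim, and dynamic_dims; force annotated to
    # False and leave dims_mapping empty for the rule to fill in later.
    return DistAttrSketch(src.process_mesh, src.batch_dim,
                          list(src.dynamic_dims), [], False)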
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
ReplicatedSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
std::vector<TensorDistAttr> input_dist_attrs;
std::vector<TensorDistAttr> output_dist_attrs;
input_dist_attrs.reserve(input_specs.size());
for (auto& input_spec : input_specs) {
input_dist_attrs.push_back(ReplicatedOnMesh(input_spec.dist_attr()));
}
// TODO(ljz): we need to know the number of outputs and the size of each
// output before generating the exact replicated dist tensor attrs for the
// current op. Here we just assume a single output tensor with the same
// size as the first input tensor.
return {input_dist_attrs, {ReplicatedOnMesh(input_specs[0].dist_attr())}};
}
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
ReplicatedSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(phi::errors::Unimplemented(
"InferBackward of ReplicatedSPMDRule is NOT implemented yet."));
}
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
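Concretely, given the two inputs used in the unit test further below (dims_mapping [-1, 1, 0, -1] and [0, -1]), InferForward replicates every input and emits one output shaped like the first input. A Python sketch of that contract (replicated_infer_forward is illustrative, not the C++ entry point):

# Sketch of the InferForward contract: all inputs become replicated, and a
# single output is assumed, shaped like the first input (see the TODO above).
def replicated_infer_forward(input_dims_mappings):
    inputs = [[-1] * len(m) for m in input_dims_mappings]
    outputs = [[-1] * len(input_dims_mappings[0])]
    return inputs, outputs

ins, outs = replicated_infer_forward([[-1, 1, 0, -1], [0, -1]])
assert ins == [[-1, -1, -1, -1], [-1, -1]]
assert outs == [[-1, -1, -1, -1]]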
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
namespace paddle {
namespace distributed {
namespace auto_parallel {
// A bottom-line rule that forces the input(s) and output(s) of the Op to be
// replicated across the given mesh.
class ReplicatedSPMDRule : public SPMDRuleBase {
public:
// The dims_mapping of ALL returned TensorDistAttrs consists entirely of "-1"
// (unsharded).
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) override;
// The dims_mapping of ALL returned TensorDistAttrs consists entirely of "-1"
// (unsharded).
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
......@@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
// TODO(ljz): Automate this process in the cmake file.
namespace paddle {
......@@ -25,6 +26,9 @@ namespace auto_parallel {
// matmul rule
REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
// replicated rule
REGISTER_SPMD_RULE(replicated, ReplicatedSPMDRule);
} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
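Once registered under the key "replicated", the rule is reachable from Python through the completion module, which is exactly how the unit test below obtains it:

# Look up the registered rule from Python (mirrors the test's setUp).
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule

rule = get_spmd_rule("replicated")
# rule.infer_forward(input_specs, attrs) returns a pair:
# (input_dist_attrs, output_dist_attrs), all fully replicated.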
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
from paddle.distributed.auto_parallel.static.dist_attribute import (
DistTensorSpec,
TensorDistAttr,
)
from paddle.distributed.fleet import auto
class TestReplicatedSPMDRule(unittest.TestCase):
def setUp(self):
self.rule = get_spmd_rule("replicated")
x_shape = [64, 32, 10, 10]
y_shape = [32, 48]
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.dims_mapping = [-1, 1, 0, -1]
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
y_tensor_dist_attr = TensorDistAttr()
y_tensor_dist_attr.dims_mapping = [0, -1]
y_tensor_dist_attr.process_mesh = process_mesh
self.y_dist_tensor_spec = DistTensorSpec(y_shape, y_tensor_dist_attr)
def test_replicated_infer_forward(self):
# return all -1
result_tensor_specs = self.rule.infer_forward(
[self.x_dist_tensor_spec, self.y_dist_tensor_spec], {}
)
self.assertEqual(len(result_tensor_specs), 2)
self.assertEqual(len(result_tensor_specs[0]), 2)
self.assertEqual(len(result_tensor_specs[1]), 1)
self.assertEqual(
result_tensor_specs[0][0].dims_mapping, [-1, -1, -1, -1]
)
self.assertEqual(result_tensor_specs[0][1].dims_mapping, [-1, -1])
self.assertEqual(
result_tensor_specs[1][0].dims_mapping, [-1, -1, -1, -1]
)
if __name__ == "__main__":
unittest.main()