提交 faa1084b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2865 asymmetric param row split support for GatherV2

Merge pull request !2865 from yihuaijie/dev
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <numeric> #include <numeric>
#include <functional> #include <functional>
#include <utility> #include <utility>
#include <algorithm>
#include "parallel/device_matrix.h" #include "parallel/device_matrix.h"
#include "parallel/graph_util/generate_graph.h" #include "parallel/graph_util/generate_graph.h"
...@@ -62,6 +63,55 @@ Status GatherV2PInfo::GetAttrs() { ...@@ -62,6 +63,55 @@ Status GatherV2PInfo::GetAttrs() {
return FAILED; return FAILED;
} }
auto manual_split_iter = attrs_.find("manual_split");
if (manual_split_iter != attrs_.end()) {
param_split_shapes_.clear();
manual_split_ = true;
auto var = manual_split_iter->second->cast<ValueTuplePtr>();
MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString();
if (var->size() > 0) {
std::vector<ValuePtr> elements = var->value();
for (auto &ele : elements) {
if (ele->isa<ValueSequeue>()) {
auto value_tuple = ele->cast<ValueTuplePtr>();
std::vector<ValuePtr> value_vector = value_tuple->value();
if (value_vector.size() != 2) {
MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2.";
return FAILED;
}
param_split_shapes_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[0])));
index_offsets_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[1])));
} else {
MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue";
return FAILED;
}
}
if (param_split_shapes_.empty()) {
MS_LOG(ERROR) << "Failed to extract param split strategy.";
return FAILED;
}
}
}
return SUCCESS;
}
Status GatherV2PInfo::CheckManualSplit() {
auto param_shape = inputs_shape_.at(0);
int32_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
[](int32_t s, int32_t shape) { return s + shape; });
if (split_shape_sum < param_shape.at(0)) {
MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape.";
return FAILED;
}
if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int32_t &offset) { return offset < 0; })) {
MS_LOG(ERROR) << "Failure: Index offset must not less than 0.";
return FAILED;
}
return SUCCESS; return SUCCESS;
} }
...@@ -103,6 +153,14 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -103,6 +153,14 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
if (manual_split_) {
if (CheckManualSplit() != SUCCESS) {
return FAILED;
}
// when using manual_split, no need to check belowings.
return SUCCESS;
}
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) { if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
...@@ -130,6 +188,11 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -130,6 +188,11 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status GatherV2PInfo::InferMirrorOps() { Status GatherV2PInfo::InferMirrorOps() {
// There is no mirror operators for manual split
if (manual_split_) {
return SUCCESS;
}
mirror_ops_.clear(); mirror_ops_.clear();
Shape input_a_tensor_map = inputs_tensor_map_.at(0); Shape input_a_tensor_map = inputs_tensor_map_.at(0);
std::vector<Group> input_a_group; std::vector<Group> input_a_group;
...@@ -160,6 +223,13 @@ Status GatherV2PInfo::InferDevMatrixShape() { ...@@ -160,6 +223,13 @@ Status GatherV2PInfo::InferDevMatrixShape() {
// infer input dev_matrix_shape // infer input dev_matrix_shape
auto param_strategy = strategy_->GetInputDim().at(0); auto param_strategy = strategy_->GetInputDim().at(0);
auto index_strategy = strategy_->GetInputDim().at(1); auto index_strategy = strategy_->GetInputDim().at(1);
if (manual_split_) {
dev_matrix_shape_ = param_strategy;
out_dev_matrix_shape_ = dev_matrix_shape_;
return SUCCESS;
}
dev_matrix_shape_ = param_strategy; dev_matrix_shape_ = param_strategy;
// param_strategy(axis)!=1, // param_strategy(axis)!=1,
...@@ -195,6 +265,12 @@ Status GatherV2PInfo::InferDevMatrixShape() { ...@@ -195,6 +265,12 @@ Status GatherV2PInfo::InferDevMatrixShape() {
} }
Status GatherV2PInfo::InferTensorMap() { Status GatherV2PInfo::InferTensorMap() {
if (manual_split_) {
inputs_tensor_map_.push_back({1, 0});
inputs_tensor_map_.push_back({-1, 1});
outputs_tensor_map_.push_back({-1, 1, 0});
return SUCCESS;
}
// infer input tensor map // infer input tensor map
// param_strategy(axis) != 1 // param_strategy(axis) != 1
size_t param_size = inputs_shape_.at(0).size(); size_t param_size = inputs_shape_.at(0).size();
...@@ -261,8 +337,13 @@ Status GatherV2PInfo::InferTensorInfo() { ...@@ -261,8 +337,13 @@ Status GatherV2PInfo::InferTensorInfo() {
Shape input_shape = inputs_shape_.at(0); Shape input_shape = inputs_shape_.at(0);
Shape input_index_shape = inputs_shape_.at(1); Shape input_index_shape = inputs_shape_.at(1);
Shape output_shape = outputs_shape_.at(0); Shape output_shape = outputs_shape_.at(0);
int32_t rank = g_device_manager->global_rank();
// infer tensor layout // infer tensor layout
TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
if (manual_split_) {
input_shape[0] = param_split_shapes_[rank / dev_matrix_shape_[1]];
input_shape[0] = input_shape[0] * dev_matrix_shape_[0];
}
if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
(input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
(output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) !=
...@@ -274,6 +355,9 @@ Status GatherV2PInfo::InferTensorInfo() { ...@@ -274,6 +355,9 @@ Status GatherV2PInfo::InferTensorInfo() {
TensorInfo input_index_info(input_index_layout); TensorInfo input_index_info(input_index_layout);
TensorInfo output_tensor_info(output_tensor_layout); TensorInfo output_tensor_info(output_tensor_layout);
Shape slice_shape = input_tensor_info.slice_shape();
MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape);
inputs_tensor_info_.push_back(input_tensor_info); inputs_tensor_info_.push_back(input_tensor_info);
inputs_tensor_info_.push_back(input_index_info); inputs_tensor_info_.push_back(input_index_info);
outputs_tensor_info_.push_back(output_tensor_info); outputs_tensor_info_.push_back(output_tensor_info);
...@@ -312,6 +396,19 @@ Status GatherV2PInfo::InferBias() { ...@@ -312,6 +396,19 @@ Status GatherV2PInfo::InferBias() {
return FAILED; return FAILED;
} }
Status GatherV2PInfo::InferOffset() {
CheckGlobalDeviceManager();
size_t rank = g_device_manager->global_rank();
if (rank < index_offsets_.size()) {
index_offset_ = index_offsets_.at(rank);
MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
return SUCCESS;
}
MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is" << index_offsets_.size();
return FAILED;
}
Status GatherV2PInfo::InferGroup() { Status GatherV2PInfo::InferGroup() {
auto param_strategy = strategy_->GetInputDim().at(0); auto param_strategy = strategy_->GetInputDim().at(0);
size_t dim = IntToSize(axis_); size_t dim = IntToSize(axis_);
...@@ -410,6 +507,19 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { ...@@ -410,6 +507,19 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
MS_LOG(ERROR) << "GenerateGraph Init failed"; MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED; return FAILED;
} }
if (manual_split_) {
if (InferOffset() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
return FAILED;
}
auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)});
auto gather_v2 =
gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub, CreatInt32Imm(axis_)});
std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
std::make_pair(input_nodes, gather_v2));
return SUCCESS;
}
if (InferBias() != SUCCESS) { if (InferBias() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer Bias failed."; MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
return FAILED; return FAILED;
...@@ -444,6 +554,14 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { ...@@ -444,6 +554,14 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
} }
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
if (manual_split_) {
if (ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return nullptr;
}
return replace_graph_;
}
auto param_strategy = strategy_->GetInputDim().at(0); auto param_strategy = strategy_->GetInputDim().at(0);
// target_ == CPU, no need to raplace graph // target_ == CPU, no need to raplace graph
if (target_ == CPU) { if (target_ == CPU) {
......
...@@ -36,6 +36,7 @@ class GatherV2PInfo : public OperatorInfo { ...@@ -36,6 +36,7 @@ class GatherV2PInfo : public OperatorInfo {
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()), : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
axis_(0), axis_(0),
bias_(0), bias_(0),
index_offset_(0),
slice_size_(0) {} slice_size_(0) {}
~GatherV2PInfo() override = default; ~GatherV2PInfo() override = default;
Status Init(const StrategyPtr &strategy) override; Status Init(const StrategyPtr &strategy) override;
...@@ -57,20 +58,26 @@ class GatherV2PInfo : public OperatorInfo { ...@@ -57,20 +58,26 @@ class GatherV2PInfo : public OperatorInfo {
private: private:
Status ComputeReplaceGraph(const CNodePtr &cnode); Status ComputeReplaceGraph(const CNodePtr &cnode);
Status CheckManualSplit();
Status ComputeReplaceOp(); Status ComputeReplaceOp();
Status InferBias(); Status InferBias();
Status InferOffset();
Status InferGroup(); Status InferGroup();
int32_t axis_; int32_t axis_;
std::string target_; std::string target_;
std::string replace_op_name_ = GATHERV2; std::string replace_op_name_ = GATHERV2;
int32_t bias_; int32_t bias_;
int32_t index_offset_;
int32_t slice_size_; int32_t slice_size_;
Shape out_dev_matrix_shape_; Shape out_dev_matrix_shape_;
Group group_; Group group_;
bool reduce_scatter_flag_ = false; bool reduce_scatter_flag_ = false;
int32_t split_num_ = 1; int32_t split_num_ = 1;
bool host_reduce_scatter_ = false; bool host_reduce_scatter_ = false;
bool manual_split_ = false;
std::vector<int32_t> param_split_shapes_;
std::vector<int32_t> index_offsets_;
}; };
class SparseGatherV2Info : public GatherV2PInfo { class SparseGatherV2Info : public GatherV2PInfo {
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
class Net(Cell):
def __init__(self, strategy1=None, strategy2=None, strategy3=None):
super().__init__()
self.gatherv2 = P.GatherV2().set_strategy(strategy1)
self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1)))
self.mul = P.Mul().set_strategy(strategy2)
self.reshape = P.Reshape()
self.matmul = P.MatMul().set_strategy(strategy3)
self.matmul.add_prim_attr("forward_reduce_scatter", True)
self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param")
self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight")
self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight")
def construct(self, x, b):
out = self.gatherv2(self.param, x, 0)
out = self.mul(out, self.mul_weight)
out = self.reshape(out, (2, 256))
out = self.matmul(out, self.matmul_weight)
return out
_x = Tensor(np.ones([2, 4]), dtype=ms.int32)
_b = Tensor(np.ones([64, 8]), dtype=ms.float32)
def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_neg_data_parallel():
context.set_context(save_graphs=True)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0)
strategy1 = ((2, 1), (1, 2))
strategy2 = ((1, 2, 1), (1, 2, 1))
strategy3 = ((1, 2), (2, 1))
net = Net(strategy1, strategy2, strategy3)
compile_net(net)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册