提交 617b98f1 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3966 [AutoParallel]Add dropout distributed op

Merge pull request !3966 from lichen/add_dropout_distributed_op
......@@ -135,6 +135,7 @@ REGISTER(GatherV2PInfo);
REGISTER(EmbeddingLookupInfo);
REGISTER(TileInfo);
REGISTER(StridedSliceInfo);
REGISTER(DropoutInfo);
} // namespace parallel
} // namespace mindspore
......
......@@ -20,6 +20,8 @@
#include <memory>
#include <vector>
#include <utility>
#include <functional>
#include <numeric>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/costmodel.h"
......@@ -54,6 +56,29 @@ Status Activation::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS;
}
Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
} else {
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
}
return FAILED;
}
// dropout don't support repeated calculation
CheckGlobalDeviceManager();
auto input_strategy = strategy->GetInputDim().at(0);
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
auto product_p = std::accumulate(input_strategy.begin(), input_strategy.end(), 1, std::multiplies<int>());
if (IntToSize(product_p) != dev_num) {
MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc.";
return FAILED;
}
return SUCCESS;
}
Status ActivationInfo::GetAttrs() {
if (attrs_.size() < ACTIVATION_ATTR_SIZE) {
MS_LOG(ERROR) << name_ << " : The size of attrs small than 1.";
......@@ -120,6 +145,27 @@ Status Activation::GenerateStrategies(int32_t stage_id) {
return SUCCESS;
}
Status DropoutInfo::GenerateStrategies(int32_t stage_id) {
is_auto_parallel_ = true;
Shape input0_split(inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split};
std::vector<StrategyPtr> sp_vector;
if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed.";
return FAILED;
}
size_t success = 0;
for (auto &sp : sp_vector) {
if (SetCostUnderStrategy(sp) == SUCCESS) {
success++;
MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
PrintStrategy(sp);
}
}
return SUCCESS;
}
Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (is_auto_parallel_) {
......@@ -334,6 +380,32 @@ Status ActivationBase::InferTensorInfo() {
return SUCCESS;
}
Status DropoutInfo::InferTensorInfo() {
// infer tensor shape
Shape input_shape = inputs_shape_.at(0);
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
// dropout has two outputs
Strategys outputs_strategy = {inputs_strategy.at(0), inputs_strategy.at(0)};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
return FAILED;
}
Shape input_slice_shape = inputs_slice_shape.at(0);
TensorLayout input_tensor_layout;
if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
inputs_tensor_info_.push_back(input_tensor_info);
// the two outputs of dropout all have the same tensor_info as input
outputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(input_tensor_info);
return SUCCESS;
}
Status ActivationBase::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
......
......@@ -219,6 +219,20 @@ class SigmoidInfo : public ActivationOther {
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SigmoidInfo() override = default;
};
class DropoutInfo : public ActivationOther {
public:
DropoutInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~DropoutInfo() override = default;
Status GenerateStrategies(int32_t stage_id) override;
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status GetAttrs() override { return SUCCESS; }
Status InferTensorInfo() override;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
......@@ -238,6 +238,7 @@ constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative";
constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D";
constexpr char ADD[] = "Add";
constexpr char DROPOUT[] = "Dropout";
constexpr char KStridedSlice[] = "StridedSlice";
// Parallel don't care
......
......@@ -256,7 +256,7 @@ bool IsSplittableOperator(const std::string &op_name) {
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE,
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
// clang-format on
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return C.grad_all(self.network)(x, y)
class Net(nn.Cell):
def __init__(self, strategy1=None, strategy2=None):
super().__init__()
self.dropout = P.Dropout(keep_prob=0.6).set_strategy(strategy1)
self.matmul = P.MatMul().set_strategy(strategy2)
def construct(self, x, y):
out = self.matmul(x, y)
out, _ = self.dropout(out)
return out
def test_dropout_semi_auto():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
net = GradWrap(NetWithLoss(Net()))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_dropout_semi_auto2():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((8, 1),)
strategy2 = ((4, 2), (2, 1))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_dropout_semi_auto3():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
strategy1 = ((2, 4),)
strategy2 = ((4, 2), (2, 1))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
_executor.compile(net, x, y)
def test_dropout_auto():
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
net = GradWrap(NetWithLoss(Net()))
net.set_auto_parallel()
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
_executor.compile(net, x, y)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册