diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h
index e6e1b41d76848a80f83679f7950fef942e8e6b09..1270116f50c3f9dbda20182c14c299d1af0457e7 100644
--- a/mindspore/ccsrc/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/parallel/dynamic_creator.h
@@ -125,6 +125,7 @@ REGISTER(SqrtInfo);
 REGISTER(GetNextInfo);
 REGISTER(NegInfo);
 REGISTER(BatchMatMulInfo);
+REGISTER(ExpandDimsInfo);
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/parallel/ops_info/activation_info.cc
index 13155ee4f1b8a4ffb48e0c5e4bd4ccd5f88a72c8..9ba3624b01ca6647186459e6a910aacddd1bd342 100644
--- a/mindspore/ccsrc/parallel/ops_info/activation_info.cc
+++ b/mindspore/ccsrc/parallel/ops_info/activation_info.cc
@@ -381,5 +381,168 @@ Status CastInfo::InferMirrorOps() {
   return SUCCESS;
 }
+
+Status ExpandDimsInfo::GetAttrs() {
+  if (input_value_.size() != EXPANDDIMS_INPUT_SIZE) {
+    MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size();
+    return FAILED;
+  }
+
+  if (!input_value_.back()->isa<Int32Imm>()) {
+    MS_LOG(ERROR) << name_ << ": The type of axis is not int";
+    return FAILED;
+  }
+
+  int32_t axis = GetValue<int32_t>(input_value_.back());
+
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  int32_t dim = SizeToInt(inputs_shape_[0].size());
+  if ((axis > dim) || (axis < -dim - 1)) {
+    MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim - 1 << ", " << dim << "]";
+    return FAILED;
+  }
+
+  if (axis < 0) {
+    positive_axis_ = dim + axis + 1;
+  } else {
+    positive_axis_ = axis;
+  }
+  MS_LOG(INFO) << name_ << ": The axis is " << axis << ", and the positive axis is " << positive_axis_;
+  return SUCCESS;
+}
+
+Status ExpandDimsInfo::InferTensorMap() {
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  // for example: if the dimension of input is 3, and the axis is 2,
+  // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0]
+  std::vector<int32_t> input_tensor_map, output_tensor_map;
+  size_t size = inputs_shape_[0].size();
+  for (size_t i = 0; i < size; ++i) {
+    input_tensor_map.push_back(SizeToInt(size - i - 1));
+  }
+
+  inputs_tensor_map_.push_back(input_tensor_map);
+
+  output_tensor_map = input_tensor_map;
+  if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(size))) {
+    MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
+    return FAILED;
+  }
+  (void)output_tensor_map.insert(output_tensor_map.begin() + positive_axis_, NO_SPLIT_MAP);
+  outputs_tensor_map_.push_back(output_tensor_map);
+
+  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
+               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
+  return SUCCESS;
+}
+
+Status ExpandDimsInfo::InferTensorStrategy() {
+  if (strategy_ == nullptr) {
+    MS_LOG(ERROR) << name_ << ": The strategy is null";
+    return FAILED;
+  }
+
+  inputs_strategy_ = strategy_->GetInputDim();
+  if (inputs_strategy_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The strategy is empty";
+    return FAILED;
+  }
+
+  Shape output_strategy = inputs_strategy_[0];
+  if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(output_strategy.size()))) {
+    MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
+    return FAILED;
+  }
+  (void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY);
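+  // NO_SPLIT_STRATEGY (1) is inserted above because the newly added dimension is never split;
+  // the strategies recorded below are reused by InferTensorInfo() to derive slice shapes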
+  outputs_strategy_ = {output_strategy};
+  return SUCCESS;
+}
+
+Status ExpandDimsInfo::InferTensorInfo() {
+  if (inputs_shape_.empty() || outputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  Shape input_shape = inputs_shape_[0];
+  Shape output_shape = outputs_shape_[0];
+
+  // infer slice shape
+  if (InferTensorStrategy() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer tensor strategy failed";
+    return FAILED;
+  }
+  Shapes inputs_slice_shape, outputs_slice_shape;
+  if (InferSliceShape(inputs_strategy_, outputs_strategy_, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
+    return FAILED;
+  }
+
+  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
+    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  Shape input_slice_shape = inputs_slice_shape[0];
+  Shape output_slice_shape = outputs_slice_shape[0];
+
+  TensorLayout input_tensor_layout, output_tensor_layout;
+  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
+    return FAILED;
+  }
+
+  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
+    return FAILED;
+  }
+
+  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
+  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
+
+  inputs_tensor_info_.push_back(input_tensor_info);
+  outputs_tensor_info_.push_back(output_tensor_info);
+  return SUCCESS;
+}
+
+Status ExpandDimsInfo::InferMirrorOps() {
+  mirror_ops_.clear();
+
+  if (inputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The tensor map of inputs is empty";
+    return FAILED;
+  }
+
+  std::vector<Group> group;
+  if (CreateGroupByTensorMap(inputs_tensor_map_[0], &group) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Create group failed";
+    return FAILED;
+  }
+
+  if (group.empty()) {
+    MS_LOG(INFO) << name_ << ": No need to create mirror ops";
+    return SUCCESS;
+  }
+
+  OperatorVector mirror_op, placeholder_op;
+  mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
+  mirror_ops_.push_back(mirror_op);
+  mirror_ops_.push_back(placeholder_op);
+  MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
+  return SUCCESS;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h
index 3cadad6b80f62aa44c5f337674f3d4d0c122abc1..183b593e2389e62972b53fe387865d4bf0de5fe0 100644
--- a/mindspore/ccsrc/parallel/ops_info/activation_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h
@@ -174,6 +174,26 @@ class NegInfo : public ActivationOther {
     : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
   ~NegInfo() override = default;
 };
+
+class ExpandDimsInfo : public ActivationOther {
+ public:
+  ExpandDimsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
+                 const PrimitiveAttrs& attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~ExpandDimsInfo() override = default;
+
+ protected:
+  Status GetAttrs() override;
+  Status InferTensorMap() override;
+  Status InferTensorInfo() override;
+  Status InferMirrorOps() override;
+  Status InferTensorStrategy();
+
+ private:
+  int32_t positive_axis_ = -1;
+  Strategys inputs_strategy_;
+  Strategys outputs_strategy_;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ACTIVATION_INFO_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
index 4062847d73f8a5c134f9a0b509be4f8fe23e7ed1..fe2a5d2c868033713a13a0f9a2d57813f7c4d477 100644
--- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h
@@ -24,6 +24,8 @@ constexpr size_t PRELU_OUTPUTS_SIZE = 1;
 constexpr size_t PRELU_SECOND_INPUT_SIZE = 1;
 constexpr int32_t PRELU_CHANNEL_INDEX = 1;
 constexpr int32_t PRELU_CHANNEL_STRATEGY = 1;
+constexpr int32_t NO_SPLIT_MAP = -1;
+constexpr int32_t NO_SPLIT_STRATEGY = 1;
 constexpr size_t MATMUL_ATTRS_SIZE = 2;
 constexpr size_t MATMUL_INPUTS_SIZE = 2;
 constexpr size_t MATMUL_OUTPUTS_SIZE = 1;
@@ -31,6 +33,7 @@ constexpr size_t ACTIVATION_ATTR_SIZE = 1;
 constexpr size_t SOFTMAX_ATTR_SIZE = 1;
 constexpr size_t ACTIVATION_INPUTS_SIZE = 1;
 constexpr size_t ACTIVATION_OUTPUTS_SIZE = 1;
+constexpr size_t EXPANDDIMS_INPUT_SIZE = 2;
 constexpr size_t SoftmaxCrossEntropyWithLogitsAttrSize = 1;
 constexpr size_t SoftmaxCrossEntropyWithLogitsInputsSize = 2;
 constexpr size_t SoftmaxCrossEntropyWithLogitsOutputsSize = 2;
@@ -191,6 +194,7 @@ constexpr char GET_NEXT[] = "GetNext";
 constexpr char SQUEEZE[] = "Squeeze";
 constexpr char Neg[] = "Neg";
 constexpr char BATCH_MATMUL[] = "BatchMatMul";
+constexpr char EXPAND_DIMS[] = "ExpandDims";
 
 // Parallel don't care
 constexpr char TUPLE_GETITEM[] = "tuple_getitem";
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index a359905494f429e9d094b228d1653c3c39c33f00..50e6a1e84e56b30abdf32a6d6946d32a8c4e3e2e 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -104,6 +104,7 @@ std::vector<std::string> splittable_op_ = {MATMUL,
                                            CAST,
                                            Neg,
                                            BATCH_MATMUL,
+                                           EXPAND_DIMS,
                                            SQUEEZE};
 
 std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT,
diff --git a/tests/ut/python/parallel/test_expand_dims.py b/tests/ut/python/parallel/test_expand_dims.py
new file mode 100644
index 0000000000000000000000000000000000000000..676e9ed523bfdd7caac7cce2e11a225270c464d8
--- /dev/null
+++ b/tests/ut/python/parallel/test_expand_dims.py
@@ -0,0 +1,110 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
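+# Tests that ExpandDims parallel strategies compile under semi-auto and auto parallel modes.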
+
+import numpy as np
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+from mindspore.common.api import _executor
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
+        super().__init__()
+        self.mul = P.Mul().set_strategy(strategy1)
+        self.expand_dims = P.ExpandDims().set_strategy(strategy2)
+        self.mul2 = P.Mul().set_strategy(strategy3)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.expand_dims(out, -1)
+        out = self.mul2(out, b)
+        return out
+
+
+class Net2(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.expand_dims = P.ExpandDims().set_strategy(strategy1)
+        self.mul = P.Mul().set_strategy(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.expand_dims(self.mul_weight, -1)
+        out = self.mul(out, b)
+        return out
+
+
+_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32, 1]), dtype=ms.float32)
+
+
+def compile(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_expand_dims_data_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((16, 1, 1),)
+    strategy3 = ((16, 1, 1, 1), (16, 1, 1, 1))
+    net = Net(_w1, strategy1, strategy2, strategy3)
+    compile(net)
+
+
+def test_expand_dims_model_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((1, 1, 16), (1, 1, 16))
+    strategy2 = ((1, 1, 16),)
+    strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1))
+    net = Net(_w1, strategy1, strategy2, strategy3)
+    compile(net)
+
+
+def test_expand_dims_hybrid_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((2, 2, 4),)
+    strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
+    net = Net(_w1, strategy1, strategy2, strategy3)
+    compile(net)
+
+
+def test_expand_dims_auto_parallel():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
+    net = Net(_w1)
+    compile(net)
+
+
+def test_expand_dims_repeat_calc():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((1, 2, 2),)
+    strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
+    net = Net(_w1, strategy1, strategy2, strategy3)
+    compile(net)
+
+
+def test_expand_dims_parameter():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((1, 2, 2),)
+    strategy2 = ((2, 2, 4, 1), (2, 2, 4, 1))
+    net = Net2(_w1, strategy1, strategy2)
+    compile(net)
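
Illustration (not part of the patch): a minimal Python sketch of the tensor-map
transformation that ExpandDimsInfo::GetAttrs() and InferTensorMap() perform; the
helper name expand_dims_tensor_map is hypothetical and exists only for this note.

    NO_SPLIT_MAP = -1  # mirrors the constant added to ops_utils.h

    def expand_dims_tensor_map(input_dim, axis):
        # normalize a negative axis into [0, input_dim], as GetAttrs() does
        positive_axis = axis if axis >= 0 else input_dim + axis + 1
        # the input map counts down from input_dim - 1, e.g. [2, 1, 0] for 3 dims
        input_map = list(range(input_dim - 1, -1, -1))
        # the output map is the input map with NO_SPLIT_MAP inserted at the new axis
        output_map = input_map[:]
        output_map.insert(positive_axis, NO_SPLIT_MAP)
        return input_map, output_map

    # matches the example in InferTensorMap's comment: dim 3, axis 2
    assert expand_dims_tensor_map(3, 2) == ([2, 1, 0], [2, 1, -1, 0])
    # a negative axis normalizes to one past the end: axis -1 appends the new dimension
    assert expand_dims_tensor_map(3, -1) == ([2, 1, 0], [2, 1, 0, -1])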