提交 7bc2cee3 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!167 add_squeeze_distributed_op

Merge pull request !167 from lichen/add_squeeze_distributed_op
...@@ -125,6 +125,7 @@ REGISTER(GetNextInfo); ...@@ -125,6 +125,7 @@ REGISTER(GetNextInfo);
REGISTER(NegInfo); REGISTER(NegInfo);
REGISTER(BatchMatMulInfo); REGISTER(BatchMatMulInfo);
REGISTER(ExpandDimsInfo); REGISTER(ExpandDimsInfo);
REGISTER(SqueezeInfo);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <utility>
#include "ir/value.h" #include "ir/value.h"
#include "parallel/auto_parallel/costmodel.h" #include "parallel/auto_parallel/costmodel.h"
...@@ -544,5 +545,160 @@ Status ExpandDimsInfo::InferMirrorOps() { ...@@ -544,5 +545,160 @@ Status ExpandDimsInfo::InferMirrorOps() {
MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name(); MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
return SUCCESS; return SUCCESS;
} }
Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) {
std::vector<int32_t> axis;
auto axis_list = value_tuple->value();
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED;
}
Shape input_shape = inputs_shape_.at(0);
size_t input_size = input_shape.size();
// if axis tuple is empty, we should exclude the axis that the corresponding slice shape is 1.
if (axis_list.empty()) {
for (size_t i = 0; i < input_size; ++i) {
if (input_shape[i] == 1) {
axis.push_back(i);
}
}
axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
return SUCCESS;
}
// convert negative axis to positive.
for (auto& dim : axis_list) {
if (!dim->isa<Int32Imm>()) {
MS_LOG(ERROR) << name_ << ": The type of axis is not int";
return FAILED;
}
int32_t dim_value = GetValue<int32_t>(dim);
int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value;
axis.push_back(positive_value);
}
axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
return SUCCESS;
}
Status SqueezeInfo::GetAttrs() {
auto iter = attrs_.find(AXIS);
if (iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can't find axis attribute.";
return FAILED;
}
MS_EXCEPTION_IF_NULL(iter->second);
auto value_tuple = iter->second->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
InferAxis(value_tuple);
attrs_[AXIS] = axis_;
return SUCCESS;
}
Status SqueezeInfo::InferReplaceOps(const StrategyPtr& strategy) {
Attr attr = std::make_pair(AXIS, axis_);
OperatorAttrs attrs = {attr};
OperatorParams params;
OperatorArgs args = std::make_pair(attrs, params);
replace_op_ = {std::make_pair(SQUEEZE, args)};
return SUCCESS;
}
Status SqueezeInfo::InferTensorMap() {
// for example: if the shape of input is [32, 32, 1], and the axis is (2, ),
// then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1]
std::vector<int32_t> input_tensor_map, output_tensor_map;
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED;
}
size_t size = inputs_shape_[0].size();
std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
for (size_t i = 0; i < size; ++i) {
size_t index = size - i - 1;
auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
if (iter == axis.end()) {
output_tensor_map.push_back(SizeToInt(index));
}
input_tensor_map.push_back(SizeToInt(index));
}
inputs_tensor_map_.push_back(input_tensor_map);
outputs_tensor_map_.push_back(output_tensor_map);
MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
<< ", and the tensor map of output is " << ShapeToString(output_tensor_map);
return SUCCESS;
}
Status SqueezeInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
return FAILED;
}
if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
return FAILED;
}
Shape input_shape = inputs_shape_[0];
Shape output_shape = outputs_shape_[0];
// infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape;
Strategys inputs_strategy = strategy_->GetInputDim();
Dimensions output_strategy;
std::vector<int32_t> axis = GetValue<const std::vector<int>>(axis_);
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
if (iter == axis.end()) {
output_strategy.push_back(inputs_strategy[0].at(i));
}
}
Strategys outputs_strategy = {output_strategy};
if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
return FAILED;
}
if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
return FAILED;
}
Shape input_slice_shape = inputs_slice_shape[0];
Shape output_slice_shape = outputs_slice_shape[0];
// infer tensor layout
TensorLayout input_tensor_layout, output_tensor_layout;
if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
return FAILED;
}
if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
return FAILED;
}
TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
inputs_tensor_info_.push_back(input_tensor_info);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
Status SqueezeInfo::Init(const StrategyPtr& strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
}
if (InferReplaceOps(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Infer replace ops failed";
}
MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
...@@ -184,6 +184,25 @@ class ExpandDimsInfo : public ActivationOther { ...@@ -184,6 +184,25 @@ class ExpandDimsInfo : public ActivationOther {
Strategys inputs_strategy_; Strategys inputs_strategy_;
Strategys outputs_strategy_; Strategys outputs_strategy_;
}; };
class SqueezeInfo : public ActivationOther {
public:
SqueezeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SqueezeInfo() override = default;
protected:
Status InferAxis(const ValueTuplePtr& value_tuple);
Status GetAttrs() override;
Status InferReplaceOps(const StrategyPtr& strategy);
Status InferTensorMap() override;
Status InferTensorInfo() override;
Status Init(const StrategyPtr& strategy) override;
private:
ValueTuplePtr axis_;
};
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ACTIVATION_INFO_H_ #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
...@@ -123,4 +123,4 @@ class AssignSubInfo : public ArithmeticBase { ...@@ -123,4 +123,4 @@ class AssignSubInfo : public ArithmeticBase {
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ARITHMETIC_INFO_H_ #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_
...@@ -53,4 +53,4 @@ class MaximumInfo : public ArithmeticBase { ...@@ -53,4 +53,4 @@ class MaximumInfo : public ArithmeticBase {
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_COMPARISON_FUNCTION_INFO_H_ #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_
...@@ -65,4 +65,4 @@ class OneHotInfo : public OperatorInfo { ...@@ -65,4 +65,4 @@ class OneHotInfo : public OperatorInfo {
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPS_INFO_PARALLEL_ONEHOT_INFO_H_ #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_
...@@ -47,8 +47,8 @@ using mindspore::tensor::Tensor; ...@@ -47,8 +47,8 @@ using mindspore::tensor::Tensor;
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER}; static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS};
// g_RefMap, for CNode B input i is a RefKey[Parameter C], // g_RefMap, for CNode B input i is a RefKey[Parameter C],
// it will be one item in map with key: C, and value: (B, i) // it will be one item in map with key: C, and value: (B, i)
static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap; static std::map<AnfNodePtr, std::pair<AnfNodePtr, int>> g_RefMap;
...@@ -1840,7 +1840,6 @@ void ParallelCommunication(const FuncGraphPtr& root, const std::vector<AnfNodePt ...@@ -1840,7 +1840,6 @@ void ParallelCommunication(const FuncGraphPtr& root, const std::vector<AnfNodePt
if (cnode == loss_cnode) { if (cnode == loss_cnode) {
is_loss_cnode = true; is_loss_cnode = true;
} }
// insert forward ops // insert forward ops
InsertForwardOps(distribute_operator, cnode); InsertForwardOps(distribute_operator, cnode);
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.common.api import _executor
class Net(Cell):
def __init__(self, strategy1=None, strategy2=None, axis=()):
super().__init__()
self.squeeze = P.Squeeze(axis=axis).set_strategy(strategy1)
self.mul = P.Mul().set_strategy(strategy2)
def construct(self, x, b):
out = self.squeeze(x)
out = self.mul(out, b)
return out
_x = Tensor(np.ones([64, 1, 32, 1]), dtype=ms.float32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)
def compile(net):
_executor.compile(net, _x, _b)
context.reset_auto_parallel_context()
def test_squeeze_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((16, 1, 1, 1), )
strategy2 = ((16, 1), (16, 1))
net = Net(strategy1, strategy2)
compile(net)
def test_squeeze_model_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 1, 16, 1), )
strategy2 = ((1, 16), (1, 16))
net = Net(strategy1, strategy2)
compile(net)
def test_squeeze_specified_axis():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((4, 1, 4, 1), )
strategy2 = ((8, 2), (8, 2))
net = Net(strategy1, strategy2, (1, 3))
compile(net)
def test_squeeze_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
net = Net()
compile(net)
def test_squeeze_repeat_calc():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 1, 8, 1), )
strategy2 = ((2, 8), (2, 8))
net = Net(strategy1, strategy2)
compile(net)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册