提交 657b5471 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3103 change type of Shape from int32 to int64

Merge pull request !3103 from yihuaijie/dev
...@@ -402,31 +402,36 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti ...@@ -402,31 +402,36 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
for (std::size_t i = 0; i < x_shape->size(); ++i) { for (std::size_t i = 0; i < x_shape->size(); ++i) {
auto value_track = x_shape_data[i]->GetValueTrack(); auto value_track = x_shape_data[i]->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track); MS_EXCEPTION_IF_NULL(value_track);
if (!value_track->isa<Int32Imm>()) { int64_t e_value = 0;
MS_LOG(EXCEPTION) << "DropOutGenMask input x_shape elements is not int32, but " << value_track->ToString() << "."; if (value_track->isa<Int64Imm>()) {
e_value = GetValue<int64_t>(value_track);
} else if (value_track->isa<Int32Imm>()) {
e_value = static_cast<int64_t>(GetValue<int>(value_track));
} else {
MS_LOG(EXCEPTION) << "DropOutGenMask input x_shape elements is not int64 or int32, but "
<< value_track->ToString() << ".";
} }
int e_value = GetValue<int>(value_track);
if (e_value <= 0) { if (e_value <= 0) {
MS_LOG(EXCEPTION) << "DropOutGenMask product of x_shape should be > 0"; MS_LOG(EXCEPTION) << "DropOutGenMask product of x_shape should be > 0";
} }
if (std::numeric_limits<int>::max() / count / e_value < 1) { if (std::numeric_limits<int64_t>::max() / count / e_value < 1) {
MS_LOG(EXCEPTION) << "integer multiply integer overflow"; MS_LOG(EXCEPTION) << "integer multiply integer overflow";
} }
count = count * e_value; count = count * e_value;
} }
// convert to bytes(8 bits) mask, using round up // convert to bytes(8 bits) mask, using round up
int n128s = count / 128; int64_t n128s = count / 128;
if ((count % 128) != 0) { if ((count % 128) != 0) {
n128s++; n128s++;
} }
int bytes_count = n128s * 16; int64_t bytes_count = n128s * 16;
std::vector<int> shape_y{bytes_count}; std::vector<int64_t> shape_y{bytes_count};
primitive->set_attr("T", kInt32); primitive->set_attr("T", kInt32);
return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8), return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
std::make_shared<Shape>(std::vector<int>{shape_y})); std::make_shared<Shape>(std::vector<int64_t>{shape_y}));
} }
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore
...@@ -1580,7 +1580,7 @@ Status CostGraph::InitSelectedStrategy() { ...@@ -1580,7 +1580,7 @@ Status CostGraph::InitSelectedStrategy() {
if (stra.empty()) { if (stra.empty()) {
MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed"; MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
} }
std::vector<Dimensions> stra_inputs = {stra}; Strategys stra_inputs = {stra};
StrategyPtr reshape_stra = StrategyPtr reshape_stra =
std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
reshape_info->set_strategy(reshape_stra); reshape_info->set_strategy(reshape_stra);
......
...@@ -31,68 +31,60 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std ...@@ -31,68 +31,60 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> &index_list); const std::shared_ptr<std::vector<size_t>> &index_list);
std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph, Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const size_t iter_ops);
const size_t iter_graph, const size_t iter_ops); Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s); Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph, const size_t iter_graph, const size_t iter_ops);
const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
const size_t iter_graph, const size_t iter_ops); Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions s);
const size_t iter_ops, std::vector<int32_t> s); Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops, std::vector<int32_t> s); const size_t iter_ops);
std::vector<std::vector<int32_t>> CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
const size_t iter_ops, std::vector<int32_t> s); Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s,
std::vector<int32_t> ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, size_t target_tensor_dim, size_t refer_tensor_dim, bool braoadcast_first_tensor);
std::vector<int32_t> s, size_t target_tensor_dim, size_t refer_tensor_dim, Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
bool braoadcast_first_tensor); Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
std::vector<std::vector<int32_t>> CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops, std::vector<int32_t> s); const size_t iter_ops);
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const size_t iter_ops);
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph, void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<size_t>> &index_list); const std::shared_ptr<std::vector<size_t>> &index_list);
size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops); const size_t iter_ops);
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph, Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_graph); const size_t iter_ops, const size_t iter_graph);
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index); const size_t incoming_op_index);
std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops); Dimensions GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s); const size_t incoming_op_index, Dimensions s);
bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops); bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops); Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s); const size_t incoming_op_index, Dimensions s);
std::vector<int32_t> GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops); Dimensions GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s); const size_t incoming_op_index, Dimensions s);
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t incoming_op_index); const size_t iter_ops, const size_t incoming_op_index);
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_ops, Dimensions basic_stra);
std::vector<int32_t> basic_stra);
void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph, void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> &index_list, const std::shared_ptr<std::vector<size_t>> &index_list,
const std::shared_ptr<std::vector<size_t>> &no_stra_op_list); const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_ops, std::vector<int32_t> s); Dimensions s);
std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops); const size_t iter_ops);
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops, void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> &no_stra_op_list); const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
static std::map<std::string, std::vector<int>> param_shapes; static std::map<std::string, Shape> param_shapes;
std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
AUTO_PARALLEL}; AUTO_PARALLEL};
...@@ -173,7 +173,7 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, ...@@ -173,7 +173,7 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph,
MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name(); MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name();
return; return;
} }
std::vector<int> shape = iter->second; Shape shape = iter->second;
std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape); std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
ptr->set_shape(base_shape); ptr->set_shape(base_shape);
MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
...@@ -189,7 +189,10 @@ void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, cons ...@@ -189,7 +189,10 @@ void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, cons
return; return;
} }
std::vector<int> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape(); std::vector<int> shape_int = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
Shape shape;
(void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape),
[](const int &value) { return static_cast<int64_t>(value); });
auto ret = param_shapes.try_emplace(param_node->name(), shape); auto ret = param_shapes.try_emplace(param_node->name(), shape);
if (!ret.second) { if (!ret.second) {
MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed"; MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed";
......
...@@ -159,7 +159,7 @@ std::string ShapeToString(const Shape &shape) { ...@@ -159,7 +159,7 @@ std::string ShapeToString(const Shape &shape) {
return str + "]"; return str + "]";
} }
std::string ListToString(const std::vector<int32_t> &list) { std::string ListToString(const RankList &list) {
std::string str = "["; std::string str = "[";
for (auto &element : list) { for (auto &element : list) {
str += std::to_string(element) + ", "; str += std::to_string(element) + ", ";
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
using RankList = std::vector<int32_t>; using RankList = std::vector<int32_t>;
using Shape = std::vector<int32_t>; using Shape = std::vector<int64_t>;
class DeviceMatrix { class DeviceMatrix {
public: public:
...@@ -48,7 +48,7 @@ class DeviceMatrix { ...@@ -48,7 +48,7 @@ class DeviceMatrix {
}; };
std::string ShapeToString(const Shape &shape); std::string ShapeToString(const Shape &shape);
std::string ListToString(const std::vector<int32_t> &list); std::string ListToString(const RankList &list);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
......
...@@ -45,13 +45,13 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { ...@@ -45,13 +45,13 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
auto tensor_map = tensor_layout->tensor_map().array(); auto tensor_map = tensor_layout->tensor_map().array();
auto slice_shape = tensor_layout->slice_shape().array(); auto slice_shape = tensor_layout->slice_shape().array();
int32_t _field_size = tensor_layout->get_field_size(); int32_t _field_size = tensor_layout->get_field_size();
std::vector<int32_t> field_size; Shape field_size;
if (_field_size != 0) { if (_field_size != 0) {
field_size.push_back(_field_size); field_size.push_back(_field_size);
} else { } else {
field_size = {0}; field_size = {0};
} }
std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape, field_size}; std::vector<Shape> layout = {device_arrangement, tensor_map, slice_shape, field_size};
dict[py::str(name)] = layout; dict[py::str(name)] = layout;
MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString(); MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
} }
......
...@@ -130,7 +130,7 @@ Status Softmax::CheckStrategy(const StrategyPtr &strategy) { ...@@ -130,7 +130,7 @@ Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
for (auto &element : axis_) { for (auto &element : axis_) {
...@@ -181,7 +181,7 @@ Status Softmax::GetAttrs() { ...@@ -181,7 +181,7 @@ Status Softmax::GetAttrs() {
MS_LOG(ERROR) << name_ << " : The axis tuple is empty."; MS_LOG(ERROR) << name_ << " : The axis tuple is empty.";
return FAILED; return FAILED;
} }
MS_LOG(INFO) << name_ << " : The axis is tuple, value is " << ShapeToString(axis_); MS_LOG(INFO) << name_ << " : The axis is tuple, value is " << ListToString(axis_);
} else { } else {
MS_LOG(ERROR) << name_ << " : The value of axis is not int or tuple int."; MS_LOG(ERROR) << name_ << " : The value of axis is not int or tuple int.";
return FAILED; return FAILED;
...@@ -258,7 +258,7 @@ Status Softmax::GenerateStrategies(int32_t stage_id) { ...@@ -258,7 +258,7 @@ Status Softmax::GenerateStrategies(int32_t stage_id) {
} }
Status ActivationBase::InferDevMatrixShape() { Status ActivationBase::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
...@@ -296,11 +296,11 @@ Status ActivationBase::InferForwardCommunication() { ...@@ -296,11 +296,11 @@ Status ActivationBase::InferForwardCommunication() {
} }
Status ActivationBase::InferTensorMap() { Status ActivationBase::InferTensorMap() {
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
size_t size = inputs_shape_.at(0).size(); size_t size = inputs_shape_.at(0).size();
// such as 4: tensor_map_index [3,2,1,0] // such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
tensor_map_index.push_back((int32_t)(size - i - 1)); tensor_map_index.push_back((int64_t)(size - i - 1));
} }
inputs_tensor_map_.push_back(tensor_map_index); inputs_tensor_map_.push_back(tensor_map_index);
...@@ -425,7 +425,7 @@ Status ExpandDimsInfo::InferTensorMap() { ...@@ -425,7 +425,7 @@ Status ExpandDimsInfo::InferTensorMap() {
// for example: if the dimension of input is 3, and the axis is 2, // for example: if the dimension of input is 3, and the axis is 2,
// then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0] // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0]
std::vector<int32_t> input_tensor_map, output_tensor_map; Shape input_tensor_map, output_tensor_map;
size_t size = inputs_shape_[0].size(); size_t size = inputs_shape_[0].size();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
input_tensor_map.push_back(SizeToInt(size - i - 1)); input_tensor_map.push_back(SizeToInt(size - i - 1));
...@@ -607,7 +607,7 @@ Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) { ...@@ -607,7 +607,7 @@ Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) {
Status SqueezeInfo::InferTensorMap() { Status SqueezeInfo::InferTensorMap() {
// for example: if the shape of input is [32, 32, 1], and the axis is (2, ), // for example: if the shape of input is [32, 32, 1], and the axis is (2, ),
// then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1] // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1]
std::vector<int32_t> input_tensor_map, output_tensor_map; Shape input_tensor_map, output_tensor_map;
if (inputs_shape_.empty()) { if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED; return FAILED;
......
...@@ -54,9 +54,9 @@ Shapes ArithmeticBase::InferExpendShape() { ...@@ -54,9 +54,9 @@ Shapes ArithmeticBase::InferExpendShape() {
return input_shapes; return input_shapes;
} }
std::vector<Dimensions> ExpendStrategy(const StrategyPtr &strategy) { Strategys ExpendStrategy(const StrategyPtr &strategy) {
std::vector<Dimensions> expend_strategy; Strategys expend_strategy;
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1); Dimensions sub_b_strategy = stra.at(1);
size_t input_a_size = sub_a_strategy.size(); size_t input_a_size = sub_a_strategy.size();
...@@ -83,7 +83,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { ...@@ -83,7 +83,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
Shapes input_shapes = InferExpendShape(); Shapes input_shapes = InferExpendShape();
std::vector<Dimensions> expend_strategy = ExpendStrategy(strategy); Strategys expend_strategy = ExpendStrategy(strategy);
Dimensions sub_a_strategy = expend_strategy.at(0); Dimensions sub_a_strategy = expend_strategy.at(0);
Dimensions sub_b_strategy = expend_strategy.at(1); Dimensions sub_b_strategy = expend_strategy.at(1);
Shape input_a_shape = input_shapes.at(0); Shape input_a_shape = input_shapes.at(0);
...@@ -103,7 +103,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { ...@@ -103,7 +103,7 @@ Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
} }
Status ArithmeticBase::InferDevMatrixShape() { Status ArithmeticBase::InferDevMatrixShape() {
std::vector<Dimensions> expend_strategy = ExpendStrategy(strategy_); Strategys expend_strategy = ExpendStrategy(strategy_);
Dimensions sub_a_strategy = expend_strategy.at(0); Dimensions sub_a_strategy = expend_strategy.at(0);
Dimensions sub_b_strategy = expend_strategy.at(1); Dimensions sub_b_strategy = expend_strategy.at(1);
Shape dev_shape; Shape dev_shape;
...@@ -123,7 +123,7 @@ TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shap ...@@ -123,7 +123,7 @@ TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shap
TensorMap tensor_map_index; TensorMap tensor_map_index;
for (size_t i = 0; i < strategy.size(); ++i) { for (size_t i = 0; i < strategy.size(); ++i) {
if (strategy[i] == dev_matrix_shape[i]) { if (strategy[i] == dev_matrix_shape[i]) {
tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(strategy.size())) - i)); tensor_map_index.push_back((int64_t)(LAST_INDEX(strategy.size()) - i));
} else { } else {
tensor_map_index.push_back(-1); tensor_map_index.push_back(-1);
} }
...@@ -159,15 +159,15 @@ void ArithmeticBase::ReComputeBatchSplitFlagList() { ...@@ -159,15 +159,15 @@ void ArithmeticBase::ReComputeBatchSplitFlagList() {
} }
Status ArithmeticBase::InferTensorMap() { Status ArithmeticBase::InferTensorMap() {
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
std::vector<Dimensions> expend_strategy = ExpendStrategy(strategy_); Strategys expend_strategy = ExpendStrategy(strategy_);
Dimensions sub_a_expend_strategy = expend_strategy.at(0); Dimensions sub_a_expend_strategy = expend_strategy.at(0);
Dimensions sub_b_expend_strategy = expend_strategy.at(1); Dimensions sub_b_expend_strategy = expend_strategy.at(1);
Strategys stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1); Dimensions sub_b_strategy = stra.at(1);
for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) {
tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_expend_strategy.size())) - i)); tensor_map_index.push_back((int64_t)(LAST_INDEX(sub_a_expend_strategy.size()) - i));
} }
Shape dev_shape; Shape dev_shape;
...@@ -261,7 +261,7 @@ Status ArithmeticBase::InferTensorInfo() { ...@@ -261,7 +261,7 @@ Status ArithmeticBase::InferTensorInfo() {
// infer slice shape // infer slice shape
Shapes inputs_slice_shape, outputs_slice_shape; Shapes inputs_slice_shape, outputs_slice_shape;
std::vector<Dimensions> expend_strategy = ExpendStrategy(strategy_); Strategys expend_strategy = ExpendStrategy(strategy_);
Dimensions sub_a_expend_strategy = expend_strategy.at(0); Dimensions sub_a_expend_strategy = expend_strategy.at(0);
Dimensions sub_b_expend_strategy = expend_strategy.at(1); Dimensions sub_b_expend_strategy = expend_strategy.at(1);
Strategys inputs_strategy = strategy_->GetInputDim(); Strategys inputs_strategy = strategy_->GetInputDim();
......
...@@ -43,13 +43,13 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -43,13 +43,13 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
dev_num_ = dev_num; dev_num_ = dev_num;
size_t strategy_size = strategy->GetInputNumber(); size_t strategy_size = strategy->GetInputNumber();
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
for (size_t i = 0; i < strategy_size; ++i) { for (size_t i = 0; i < strategy_size; ++i) {
Shape sub_strategy = stra.at(i); Shape sub_strategy = stra.at(i);
size_t strategy_len = sub_strategy.size(); size_t strategy_len = sub_strategy.size();
bool flag = false; bool flag = false;
for (size_t j = 0; j < strategy_len; ++j) { for (size_t j = 0; j < strategy_len; ++j) {
int32_t strategy_value = sub_strategy.at(j); int64_t strategy_value = sub_strategy.at(j);
if (strategy_value > 1) { if (strategy_value > 1) {
if (flag || strategy_value != dev_num_) { if (flag || strategy_value != dev_num_) {
if (is_auto_parallel_) { if (is_auto_parallel_) {
...@@ -95,7 +95,7 @@ Status BatchParallelInfo::InferTensorMap() { ...@@ -95,7 +95,7 @@ Status BatchParallelInfo::InferTensorMap() {
return FAILED; return FAILED;
} }
for (size_t i = 0; i < inputs_shape_.size(); i++) { for (size_t i = 0; i < inputs_shape_.size(); i++) {
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
for (size_t j = 0; j < inputs_shape_[i].size(); ++j) { for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) { if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) {
tensor_map_index.push_back(0); tensor_map_index.push_back(0);
...@@ -106,7 +106,7 @@ Status BatchParallelInfo::InferTensorMap() { ...@@ -106,7 +106,7 @@ Status BatchParallelInfo::InferTensorMap() {
inputs_tensor_map_.push_back(tensor_map_index); inputs_tensor_map_.push_back(tensor_map_index);
} }
for (size_t i = 0; i < outputs_shape_.size(); i++) { for (size_t i = 0; i < outputs_shape_.size(); i++) {
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
if (i == 0 && j == 0) { if (i == 0 && j == 0) {
tensor_map_index.push_back(0); tensor_map_index.push_back(0);
...@@ -123,7 +123,7 @@ Strategys BatchParallelInfo::GetOutputsStrategy() { ...@@ -123,7 +123,7 @@ Strategys BatchParallelInfo::GetOutputsStrategy() {
Strategys outputs_strategy; Strategys outputs_strategy;
for (size_t i = 0; i < outputs_shape_.size(); ++i) { for (size_t i = 0; i < outputs_shape_.size(); ++i) {
std::vector<int32_t> strategy; Dimensions strategy;
for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
if (i == 0 && j == 0) { if (i == 0 && j == 0) {
strategy.push_back(dev_num_); strategy.push_back(dev_num_);
...@@ -201,7 +201,7 @@ Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) { ...@@ -201,7 +201,7 @@ Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) {
is_auto_parallel_ = true; is_auto_parallel_ = true;
size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
StrategyPtr sp; StrategyPtr sp;
std::vector<Dimensions> strategy; Strategys strategy;
for (size_t i = 0; i < inputs_shape_.size(); i++) { for (size_t i = 0; i < inputs_shape_.size(); i++) {
Shape temp(inputs_shape_[i].size(), 1); Shape temp(inputs_shape_[i].size(), 1);
if (split_flag_list_[i]) { if (split_flag_list_[i]) {
......
...@@ -36,11 +36,11 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -36,11 +36,11 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
Dimensions sub_b_strategy = stra.at(1); Dimensions sub_b_strategy = stra.at(1);
int32_t channel_a_strategy = sub_a_strategy.at(1); int64_t channel_a_strategy = sub_a_strategy.at(1);
int32_t channel_b_strategy = sub_b_strategy.at(0); int64_t channel_b_strategy = sub_b_strategy.at(0);
if (channel_a_strategy != channel_b_strategy) { if (channel_a_strategy != channel_b_strategy) {
if (is_auto_parallel_) { if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << " : Invalid strategy."; MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
...@@ -53,7 +53,7 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -53,7 +53,7 @@ Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status BiasAddInfo::InferDevMatrixShape() { Status BiasAddInfo::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
dev_matrix_shape_ = sub_a_strategy; dev_matrix_shape_ = sub_a_strategy;
return SUCCESS; return SUCCESS;
...@@ -67,13 +67,13 @@ void BiasAddInfo::ReComputeBatchSplitFlagList() { ...@@ -67,13 +67,13 @@ void BiasAddInfo::ReComputeBatchSplitFlagList() {
Status BiasAddInfo::InferTensorMap() { Status BiasAddInfo::InferTensorMap() {
TensorMap sub_a_tensor_map; TensorMap sub_a_tensor_map;
TensorMap sub_b_tensor_map; TensorMap sub_b_tensor_map;
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions sub_a_strategy = stra.at(0); Dimensions sub_a_strategy = stra.at(0);
size_t sub_a_strategy_size = sub_a_strategy.size(); size_t sub_a_strategy_size = sub_a_strategy.size();
for (size_t i = 0; i < sub_a_strategy_size; ++i) { for (size_t i = 0; i < sub_a_strategy_size; ++i) {
sub_a_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - i)); sub_a_tensor_map.push_back((int32_t)(LAST_INDEX(sub_a_strategy_size) - i));
} }
sub_b_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - 1)); sub_b_tensor_map.push_back((int32_t)(LAST_INDEX(sub_a_strategy_size) - 1));
inputs_tensor_map_.push_back(sub_a_tensor_map); inputs_tensor_map_.push_back(sub_a_tensor_map);
inputs_tensor_map_.push_back(sub_b_tensor_map); inputs_tensor_map_.push_back(sub_b_tensor_map);
...@@ -213,7 +213,7 @@ Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { ...@@ -213,7 +213,7 @@ Status BiasAddInfo::GenerateStrategies(int32_t stage_id) {
MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success.";
for (auto &sp : sp_vector) { for (auto &sp : sp_vector) {
std::vector<Dimensions> tmp_strategy; Strategys tmp_strategy;
Dimensions input0_strategy = sp->GetInputDim()[0]; Dimensions input0_strategy = sp->GetInputDim()[0];
tmp_strategy.push_back(input0_strategy); // input0 tmp_strategy.push_back(input0_strategy); // input0
......
...@@ -38,7 +38,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -38,7 +38,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
if (stra.size() != 1) { if (stra.size() != 1) {
MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1"; MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1";
return FAILED; return FAILED;
...@@ -68,7 +68,7 @@ Status DropoutDoMaskInfo::InferDevMatrixShape() { ...@@ -68,7 +68,7 @@ Status DropoutDoMaskInfo::InferDevMatrixShape() {
return FAILED; return FAILED;
} }
std::vector<Dimensions> strategy = strategy_->GetInputDim(); Strategys strategy = strategy_->GetInputDim();
if (strategy.empty()) { if (strategy.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty"; MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED; return FAILED;
...@@ -84,7 +84,7 @@ Status DropoutDoMaskInfo::InferTensorMap() { ...@@ -84,7 +84,7 @@ Status DropoutDoMaskInfo::InferTensorMap() {
return FAILED; return FAILED;
} }
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
size_t size = inputs_shape_[0].size(); size_t size = inputs_shape_[0].size();
// if the dimension of input is 4, and tensor_map_index is [3, 2, 1, 0] // if the dimension of input is 4, and tensor_map_index is [3, 2, 1, 0]
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
...@@ -169,13 +169,13 @@ Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { ...@@ -169,13 +169,13 @@ Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) {
return SUCCESS; return SUCCESS;
} }
std::shared_ptr<std::vector<std::vector<int32_t>>> DropoutDoMaskInfo::GenerateBatchStrategies() { std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions strategy(inputs_shape_[0].size() - 1, 1); Dimensions strategy(inputs_shape_[0].size() - 1, 1);
(void)strategy.insert(strategy.begin(), SizeToInt(dev_num)); (void)strategy.insert(strategy.begin(), SizeToInt(dev_num));
std::vector<Dimensions> strategy_v = {strategy}; Strategys strategy_v = {strategy};
return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v); return std::make_shared<Strategys>(strategy_v);
} }
Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) { Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) {
......
...@@ -40,7 +40,7 @@ class DropoutDoMaskInfo : public OperatorInfo { ...@@ -40,7 +40,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
Status GenerateStrategies(int32_t stage_id) override; Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; std::shared_ptr<Strategys> GenerateBatchStrategies() override;
std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
protected: protected:
......
...@@ -109,7 +109,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { ...@@ -109,7 +109,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) {
} }
Status GatherV2Info::InferDevMatrixShape() { Status GatherV2Info::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
dev_matrix_shape_ = stra.at(0); dev_matrix_shape_ = stra.at(0);
return SUCCESS; return SUCCESS;
} }
...@@ -129,8 +129,8 @@ Status GatherV2Info::InferTensorMap() { ...@@ -129,8 +129,8 @@ Status GatherV2Info::InferTensorMap() {
<< outputs_shape_.size(); << outputs_shape_.size();
return FAILED; return FAILED;
} }
std::vector<int32_t> tensor_map_in; Shape tensor_map_in;
std::vector<int32_t> tensor_map_out; Shape tensor_map_out;
size_t size = inputs_shape_.at(0).size(); size_t size = inputs_shape_.at(0).size();
// such as 4: tensor_map_index [3,2,1,0] // such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
...@@ -149,7 +149,7 @@ Status GatherV2Info::InferTensorMap() { ...@@ -149,7 +149,7 @@ Status GatherV2Info::InferTensorMap() {
return FAILED; return FAILED;
} }
std::vector<int32_t> tensor_map_in_index; Shape tensor_map_in_index;
if (index_size_ >= 1) { if (index_size_ >= 1) {
tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1)); tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1));
} }
...@@ -323,7 +323,7 @@ Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { ...@@ -323,7 +323,7 @@ Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SUCCESS; return SUCCESS;
} }
std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2Info::GenerateBatchStrategies() { std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
<< inputs_shape_.size(); << inputs_shape_.size();
...@@ -343,8 +343,8 @@ std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2Info::GenerateBatchSt ...@@ -343,8 +343,8 @@ std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2Info::GenerateBatchSt
for (size_t i = 1; i < inputs_shape_[0].size(); i++) { for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
strategy.push_back(1); strategy.push_back(1);
} }
std::vector<Dimensions> strategy_v = {strategy}; Strategys strategy_v = {strategy};
return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v); return std::make_shared<Strategys>(strategy_v);
} }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
...@@ -50,7 +50,7 @@ class GatherV2Info : public OperatorInfo { ...@@ -50,7 +50,7 @@ class GatherV2Info : public OperatorInfo {
Status GenerateStrategies(int32_t stage_id) override; Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; std::shared_ptr<Strategys> GenerateBatchStrategies() override;
protected: protected:
Status CheckStrategy(const StrategyPtr &strategy) override; Status CheckStrategy(const StrategyPtr &strategy) override;
......
...@@ -73,8 +73,8 @@ Status GatherV2PInfo::GetAttrs() { ...@@ -73,8 +73,8 @@ Status GatherV2PInfo::GetAttrs() {
MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2."; MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2.";
return FAILED; return FAILED;
} }
param_split_shapes_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[0]))); param_split_shapes_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[0])));
index_offsets_.push_back(static_cast<int32_t>(GetValue<int>(value_vector[1]))); index_offsets_.push_back(static_cast<int64_t>(GetValue<int>(value_vector[1])));
} else { } else {
MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue"; MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue";
return FAILED; return FAILED;
...@@ -93,14 +93,14 @@ Status GatherV2PInfo::GetAttrs() { ...@@ -93,14 +93,14 @@ Status GatherV2PInfo::GetAttrs() {
Status GatherV2PInfo::CheckManualSplit() { Status GatherV2PInfo::CheckManualSplit() {
auto param_shape = inputs_shape_.at(0); auto param_shape = inputs_shape_.at(0);
int32_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0, int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
[](int32_t s, int32_t shape) { return s + shape; }); [](int64_t s, int64_t shape) { return s + shape; });
if (split_shape_sum < param_shape.at(0)) { if (split_shape_sum < param_shape.at(0)) {
MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape."; MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape.";
return FAILED; return FAILED;
} }
if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int32_t &offset) { return offset < 0; })) { if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
MS_LOG(ERROR) << "Failure: Index offset must not less than 0."; MS_LOG(ERROR) << "Failure: Index offset must not less than 0.";
return FAILED; return FAILED;
} }
...@@ -269,8 +269,8 @@ Status GatherV2PInfo::InferTensorMap() { ...@@ -269,8 +269,8 @@ Status GatherV2PInfo::InferTensorMap() {
size_t param_size = inputs_shape_.at(0).size(); size_t param_size = inputs_shape_.at(0).size();
size_t index_size = inputs_shape_.at(1).size(); size_t index_size = inputs_shape_.at(1).size();
size_t total_size = param_size + index_size; size_t total_size = param_size + index_size;
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
std::vector<int32_t> tensor_map_params; Shape tensor_map_params;
auto param_strategy = strategy_->GetInputDim().at(0); auto param_strategy = strategy_->GetInputDim().at(0);
if (param_strategy.at(IntToSize(axis_)) != 1) { if (param_strategy.at(IntToSize(axis_)) != 1) {
tensor_map_index.insert(tensor_map_index.begin(), index_size, -1); tensor_map_index.insert(tensor_map_index.begin(), index_size, -1);
...@@ -288,7 +288,7 @@ Status GatherV2PInfo::InferTensorMap() { ...@@ -288,7 +288,7 @@ Status GatherV2PInfo::InferTensorMap() {
} }
// infer output tensor map // infer output tensor map
std::vector<int32_t> tensor_map_out; Shape tensor_map_out;
if (param_strategy.at(IntToSize(axis_)) == 1) { if (param_strategy.at(IntToSize(axis_)) == 1) {
// param_strategy(axis) == 1 // param_strategy(axis) == 1
for (size_t i = 0; i < param_size; ++i) { for (size_t i = 0; i < param_size; ++i) {
...@@ -427,8 +427,8 @@ Status GatherV2PInfo::InferGroup() { ...@@ -427,8 +427,8 @@ Status GatherV2PInfo::InferGroup() {
return SUCCESS; return SUCCESS;
} }
std::vector<int32_t> GetRankFromGroup(const Group &group) { RankList GetRankFromGroup(const Group &group) {
std::vector<int32_t> rank_list; RankList rank_list;
auto device_list = group.GetDevicesList(); auto device_list = group.GetDevicesList();
for (auto &device : device_list) { for (auto &device : device_list) {
rank_list.insert(rank_list.end(), device.rank() % 8); rank_list.insert(rank_list.end(), device.rank() % 8);
...@@ -634,7 +634,7 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { ...@@ -634,7 +634,7 @@ Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
return SUCCESS; return SUCCESS;
} }
std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2PInfo::GenerateBatchStrategies() { std::shared_ptr<Strategys> GatherV2PInfo::GenerateBatchStrategies() {
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions param_strategy(inputs_shape_[0].size(), 1); Dimensions param_strategy(inputs_shape_[0].size(), 1);
...@@ -643,8 +643,8 @@ std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2PInfo::GenerateBatchS ...@@ -643,8 +643,8 @@ std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2PInfo::GenerateBatchS
for (size_t i = 1; i < inputs_shape_[1].size(); i++) { for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
index_strategy.push_back(1); index_strategy.push_back(1);
} }
std::vector<Dimensions> strategy_v = {param_strategy, index_strategy}; Strategys strategy_v = {param_strategy, index_strategy};
return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v); return std::make_shared<Strategys>(strategy_v);
} }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
...@@ -45,7 +45,7 @@ class GatherV2PInfo : public OperatorInfo { ...@@ -45,7 +45,7 @@ class GatherV2PInfo : public OperatorInfo {
Status GenerateStrategies(int32_t stage_id) override; Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; std::shared_ptr<Strategys> GenerateBatchStrategies() override;
protected: protected:
Status CheckStrategy(const StrategyPtr &strategy) override; Status CheckStrategy(const StrategyPtr &strategy) override;
...@@ -67,13 +67,13 @@ class GatherV2PInfo : public OperatorInfo { ...@@ -67,13 +67,13 @@ class GatherV2PInfo : public OperatorInfo {
std::string target_ = DEVICE; std::string target_ = DEVICE;
std::string replace_op_name_ = GATHERV2; std::string replace_op_name_ = GATHERV2;
int32_t bias_; int32_t bias_;
int32_t index_offset_; int64_t index_offset_;
int32_t slice_size_; int32_t slice_size_;
Shape out_dev_matrix_shape_; Shape out_dev_matrix_shape_;
Group group_; Group group_;
bool manual_split_ = false; bool manual_split_ = false;
std::vector<int32_t> param_split_shapes_; std::vector<int64_t> param_split_shapes_;
std::vector<int32_t> index_offsets_; std::vector<int64_t> index_offsets_;
}; };
class SparseGatherV2Info : public GatherV2PInfo { class SparseGatherV2Info : public GatherV2PInfo {
......
...@@ -118,7 +118,7 @@ Status GetNextInfo::Init(const StrategyPtr &strategy) { ...@@ -118,7 +118,7 @@ Status GetNextInfo::Init(const StrategyPtr &strategy) {
} }
Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) {
std::vector<Dimensions> stras = strategy->GetInputDim(); Strategys stras = strategy->GetInputDim();
for (Dimensions stra : stras) { for (Dimensions stra : stras) {
if (stra.size() != 0) { if (stra.size() != 0) {
if (is_auto_parallel_) { if (is_auto_parallel_) {
...@@ -254,7 +254,7 @@ Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { ...@@ -254,7 +254,7 @@ Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
Status GetNextInfo::GenerateStrategies(int32_t stage_id) { Status GetNextInfo::GenerateStrategies(int32_t stage_id) {
is_auto_parallel_ = true; is_auto_parallel_ = true;
std::vector<Dimensions> stra; Strategys stra;
StrategyPtr sp = std::make_shared<Strategy>(stage_id, stra); StrategyPtr sp = std::make_shared<Strategy>(stage_id, stra);
if (SetCostUnderStrategy(sp) == SUCCESS) { if (SetCostUnderStrategy(sp) == SUCCESS) {
MS_LOG(INFO) << name_ << " : Successfully generated strategy."; MS_LOG(INFO) << name_ << " : Successfully generated strategy.";
......
...@@ -37,7 +37,7 @@ Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -37,7 +37,7 @@ Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
int32_t axis_index = axis_; int32_t axis_index = axis_;
if (axis_ < 0) { if (axis_ < 0) {
......
...@@ -49,7 +49,7 @@ Status LayerNormInfo::GetAttrs() { ...@@ -49,7 +49,7 @@ Status LayerNormInfo::GetAttrs() {
Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) { Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy); MS_EXCEPTION_IF_NULL(strategy);
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
if (stra.size() != LAYER_NORM_INPUT_SIZE) { if (stra.size() != LAYER_NORM_INPUT_SIZE) {
MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size(); MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size();
return FAILED; return FAILED;
...@@ -104,7 +104,7 @@ Status LayerNormInfo::InferDevMatrixShape() { ...@@ -104,7 +104,7 @@ Status LayerNormInfo::InferDevMatrixShape() {
MS_LOG(ERROR) << name_ << ": The strategy is null"; MS_LOG(ERROR) << name_ << ": The strategy is null";
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
if (stra.empty()) { if (stra.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty"; MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED; return FAILED;
...@@ -228,7 +228,7 @@ Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyP ...@@ -228,7 +228,7 @@ Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyP
MS_LOG(ERROR) << name_ << ": Invalid strategy"; MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED; return FAILED;
} }
std::vector<Dimensions> tmp_strategy; Strategys tmp_strategy;
Dimensions input_strategy = sp->GetInputDim()[0]; Dimensions input_strategy = sp->GetInputDim()[0];
Dimensions gamma_strategy = input_strategy; Dimensions gamma_strategy = input_strategy;
(void)gamma_strategy.erase(gamma_strategy.begin(), (void)gamma_strategy.erase(gamma_strategy.begin(),
......
...@@ -38,7 +38,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::paralle ...@@ -38,7 +38,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::paralle
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
Dimensions label_strategy = stra.at(1); Dimensions label_strategy = stra.at(1);
if (input_strategy != label_strategy) { if (input_strategy != label_strategy) {
...@@ -52,8 +52,8 @@ Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::paralle ...@@ -52,8 +52,8 @@ Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::paralle
axis_index = static_cast<int32_t>(input_dim) + axis_; axis_index = static_cast<int32_t>(input_dim) + axis_;
} }
int32_t input_axis_strategy = input_strategy.at(IntToSize(axis_index)); int64_t input_axis_strategy = input_strategy.at(IntToSize(axis_index));
int32_t label_axis_strategy = label_strategy.at(IntToSize(axis_index)); int64_t label_axis_strategy = label_strategy.at(IntToSize(axis_index));
// Dimension corresponding to axis is un-splittable // Dimension corresponding to axis is un-splittable
if ((input_axis_strategy != MIN_SLICE_NUM) && (label_axis_strategy != MIN_SLICE_NUM)) { if ((input_axis_strategy != MIN_SLICE_NUM) && (label_axis_strategy != MIN_SLICE_NUM)) {
if (is_auto_parallel_) { if (is_auto_parallel_) {
...@@ -82,21 +82,21 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() { ...@@ -82,21 +82,21 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() {
} }
Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() { Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
return SUCCESS; return SUCCESS;
} }
Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorMap() { Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorMap() {
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
size_t size = inputs_shape_[0].size(); size_t size = inputs_shape_[0].size();
// such as 4: tensor_map_index [3,2,1,0] // such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
tensor_map_index.push_back((int32_t)(size - i - 1)); tensor_map_index.push_back((int64_t)(size - i - 1));
} }
std::vector<int32_t> first_output_tensor_map = {tensor_map_index[0]}; Shape first_output_tensor_map = {tensor_map_index[0]};
inputs_tensor_map_.push_back(tensor_map_index); // input inputs_tensor_map_.push_back(tensor_map_index); // input
inputs_tensor_map_.push_back(tensor_map_index); // label inputs_tensor_map_.push_back(tensor_map_index); // label
outputs_tensor_map_.push_back(first_output_tensor_map); // output-0 outputs_tensor_map_.push_back(first_output_tensor_map); // output-0
......
...@@ -158,7 +158,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) { ...@@ -158,7 +158,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
Dimensions mat_a_strategy = stra.at(0); Dimensions mat_a_strategy = stra.at(0);
Dimensions mat_b_strategy = stra.at(1); Dimensions mat_b_strategy = stra.at(1);
...@@ -207,7 +207,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) { ...@@ -207,7 +207,7 @@ Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
} }
Status MatMulBase::InferDevMatrixShape() { Status MatMulBase::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions mat_a_strategy = stra.at(0); Dimensions mat_a_strategy = stra.at(0);
Dimensions mat_b_strategy = stra.at(1); Dimensions mat_b_strategy = stra.at(1);
...@@ -279,10 +279,10 @@ Status MatMulBase::InferTensorMap() { ...@@ -279,10 +279,10 @@ Status MatMulBase::InferTensorMap() {
size = dev_matrix_shape_.size() - 1; size = dev_matrix_shape_.size() - 1;
} }
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
// such as 5: tensor_map_index [4,3,2,1,0] // such as 5: tensor_map_index [4,3,2,1,0]
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
} }
// infer output tensor map: [4,3,2,0], delete the second-from-end element // infer output tensor map: [4,3,2,0], delete the second-from-end element
...@@ -309,7 +309,7 @@ Status MatMulBase::InferTensorMap() { ...@@ -309,7 +309,7 @@ Status MatMulBase::InferTensorMap() {
mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_b_dimension_)); mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_b_dimension_));
if (transpose_b_) { if (transpose_b_) {
// swap the last two elements // swap the last two elements
int32_t last_value = mat_b_tensor_map.back(); int64_t last_value = mat_b_tensor_map.back();
mat_b_tensor_map.pop_back(); mat_b_tensor_map.pop_back();
(void)mat_b_tensor_map.insert( (void)mat_b_tensor_map.insert(
mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(mat_b_tensor_map.size())), last_value); mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(mat_b_tensor_map.size())), last_value);
...@@ -436,7 +436,7 @@ Status MatMulBase::GenerateStrategies(int32_t stage_id) { ...@@ -436,7 +436,7 @@ Status MatMulBase::GenerateStrategies(int32_t stage_id) {
return FAILED; return FAILED;
} }
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
std::vector<int32_t> dev_list = g_device_manager->GetDeviceListByStageId(stage_id); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
size_t dev_num = dev_list.size(); size_t dev_num = dev_list.size();
Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1]; Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1];
if (transpose_a_) { if (transpose_a_) {
...@@ -503,13 +503,14 @@ Status MatMulBase::GenerateStrategies(int32_t stage_id) { ...@@ -503,13 +503,14 @@ Status MatMulBase::GenerateStrategies(int32_t stage_id) {
Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num,
mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size,
size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) {
int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int>()); int64_t product =
std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>());
if (!FULLY_USE_DEVICES) { if (!FULLY_USE_DEVICES) {
if (IntToSize(product) > dev_num) { if (LongToSize(product) > dev_num) {
return FAILED; return FAILED;
} }
} else { } else {
if (IntToSize(product) != dev_num) { if (LongToSize(product) != dev_num) {
return FAILED; return FAILED;
} }
} }
...@@ -550,7 +551,7 @@ Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, ...@@ -550,7 +551,7 @@ Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num,
MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
} }
} }
std::vector<Dimensions> stras; Strategys stras;
stras.push_back(input0_partitions); stras.push_back(input0_partitions);
stras.push_back(input1_partitions); stras.push_back(input1_partitions);
(*sp) = std::make_shared<Strategy>(stage_id, stras); (*sp) = std::make_shared<Strategy>(stage_id, stras);
......
...@@ -77,7 +77,7 @@ Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -77,7 +77,7 @@ Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status OneHotInfo::InferDevMatrixShape() { Status OneHotInfo::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
// Now input only support 1-D tensor, so the output is a 2-D tensor // Now input only support 1-D tensor, so the output is a 2-D tensor
...@@ -96,16 +96,16 @@ Status OneHotInfo::InferDevMatrixShape() { ...@@ -96,16 +96,16 @@ Status OneHotInfo::InferDevMatrixShape() {
} }
Status OneHotInfo::InferTensorMap() { Status OneHotInfo::InferTensorMap() {
std::vector<int32_t> input_tensor_map_index, output_tensor_map_index; Shape input_tensor_map_index, output_tensor_map_index;
size_t size = outputs_shape_[0].size(); size_t size = outputs_shape_[0].size();
// such as 2: tensor_map_index [1,0] // such as 2: tensor_map_index [1,0]
if (axis_ == 0) { if (axis_ == 0) {
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
output_tensor_map_index.push_back((int32_t)(i)); output_tensor_map_index.push_back((int64_t)(i));
} }
} else { } else {
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
output_tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); output_tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
} }
} }
outputs_tensor_map_.push_back(output_tensor_map_index); outputs_tensor_map_.push_back(output_tensor_map_index);
...@@ -299,13 +299,13 @@ Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { ...@@ -299,13 +299,13 @@ Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SUCCESS; return SUCCESS;
} }
std::shared_ptr<std::vector<std::vector<int32_t>>> OneHotInfo::GenerateBatchStrategies() { std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
Dimensions strategy = {SizeToInt(dev_num), 1}; Dimensions strategy = {SizeToInt(dev_num), 1};
Dimensions empty_strategy; Dimensions empty_strategy;
std::vector<Dimensions> strategy_v = {strategy, empty_strategy, empty_strategy}; Strategys strategy_v = {strategy, empty_strategy, empty_strategy};
return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v); return std::make_shared<Strategys>(strategy_v);
} }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
...@@ -41,7 +41,7 @@ class OneHotInfo : public OperatorInfo { ...@@ -41,7 +41,7 @@ class OneHotInfo : public OperatorInfo {
Status GenerateStrategies(int32_t stage_id) override; Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; std::shared_ptr<Strategys> GenerateBatchStrategies() override;
protected: protected:
Status CheckStrategy(const StrategyPtr &strategy) override; Status CheckStrategy(const StrategyPtr &strategy) override;
......
...@@ -52,7 +52,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap ...@@ -52,7 +52,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
for (size_t i = 0; i < strategy_size; ++i) { for (size_t i = 0; i < strategy_size; ++i) {
Shape sub_strategy = stra.at(i); Shape sub_strategy = stra.at(i);
Shape sub_input_shape = inputs_shape.at(i); Shape sub_input_shape = inputs_shape.at(i);
...@@ -70,7 +70,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap ...@@ -70,7 +70,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
} }
for (size_t j = 0; j < strategy_len; ++j) { for (size_t j = 0; j < strategy_len; ++j) {
int32_t strategy_value = sub_strategy.at(j); int64_t strategy_value = sub_strategy.at(j);
if (strategy_value < MIN_SLICE_NUM) { if (strategy_value < MIN_SLICE_NUM) {
if (is_auto_parallel) { if (is_auto_parallel) {
MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value; MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value;
...@@ -89,7 +89,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap ...@@ -89,7 +89,7 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
return FAILED; return FAILED;
} }
int32_t shape_value = sub_input_shape.at(j); int64_t shape_value = sub_input_shape.at(j);
if ((shape_value % strategy_value) != 0) { if ((shape_value % strategy_value) != 0) {
if (is_auto_parallel) { if (is_auto_parallel) {
MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value;
...@@ -138,9 +138,9 @@ void OperatorInfo::SetDeviceListByStrategy() { ...@@ -138,9 +138,9 @@ void OperatorInfo::SetDeviceListByStrategy() {
} }
Status OperatorInfo::InferRepeatedCalcInfo() { Status OperatorInfo::InferRepeatedCalcInfo() {
int32_t g_dev_list_size = SizeToInt(global_device_list_.size()); int64_t g_dev_list_size = SizeToLong(global_device_list_.size());
int32_t dev_matrix_size = int64_t dev_matrix_size =
std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int>()); std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
if (dev_matrix_size == 0) { if (dev_matrix_size == 0) {
MS_LOG(ERROR) << name_ << ": The dev matrix size is 0"; MS_LOG(ERROR) << name_ << ": The dev matrix size is 0";
return FAILED; return FAILED;
...@@ -149,7 +149,7 @@ Status OperatorInfo::InferRepeatedCalcInfo() { ...@@ -149,7 +149,7 @@ Status OperatorInfo::InferRepeatedCalcInfo() {
if (g_dev_list_size == dev_matrix_size) { if (g_dev_list_size == dev_matrix_size) {
repeated_calc_num_ = 1; repeated_calc_num_ = 1;
} else if (g_dev_list_size % dev_matrix_size == 0) { } else if (g_dev_list_size % dev_matrix_size == 0) {
repeated_calc_num_ = g_dev_list_size / dev_matrix_size; repeated_calc_num_ = ((int32_t)(g_dev_list_size / dev_matrix_size));
} else { } else {
MS_LOG(ERROR) << name_ << ": Dev list size " << g_dev_list_size << " can not be divisible by dev matrix size " MS_LOG(ERROR) << name_ << ": Dev list size " << g_dev_list_size << " can not be divisible by dev matrix size "
<< dev_matrix_size; << dev_matrix_size;
...@@ -326,7 +326,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) { ...@@ -326,7 +326,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) { Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) {
Shape slice_shape; Shape slice_shape;
if (std::any_of(strategy.begin(), strategy.end(), [](int32_t value) { return value <= 0; })) { if (std::any_of(strategy.begin(), strategy.end(), [](int64_t value) { return value <= 0; })) {
MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0"; MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0";
return slice_shape; return slice_shape;
} }
...@@ -430,7 +430,8 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat ...@@ -430,7 +430,8 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
return FAILED; return FAILED;
} }
used_devices_ = std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int32_t>()); used_devices_ =
((int32_t)(std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>())));
// must be after InferDevMatrixShape // must be after InferDevMatrixShape
if (InferRepeatedCalcInfo() != SUCCESS) { if (InferRepeatedCalcInfo() != SUCCESS) {
...@@ -646,8 +647,8 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, con ...@@ -646,8 +647,8 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, con
succ_edges_ = new_succ_edges; succ_edges_ = new_succ_edges;
} }
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategiesBySplitFlag( std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
const Shapes &shapes, const std::vector<bool> &split_flag_list) { const std::vector<bool> &split_flag_list) {
if (shapes.size() != split_flag_list.size()) { if (shapes.size() != split_flag_list.size()) {
MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : " MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : "
<< shapes.size(); << shapes.size();
...@@ -655,21 +656,21 @@ std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategiesBySpli ...@@ -655,21 +656,21 @@ std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategiesBySpli
} }
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
std::vector<std::vector<int32_t>> strategy_v; Strategys strategy_v;
for (size_t i = 0; i != shapes.size(); i++) { for (size_t i = 0; i != shapes.size(); i++) {
if (shapes[i].empty()) { if (shapes[i].empty()) {
MS_LOG(INFO) << "Elements of shapes is empty."; MS_LOG(INFO) << "Elements of shapes is empty.";
std::vector<int32_t> empty_element; Dimensions empty_element;
strategy_v.push_back(empty_element); strategy_v.push_back(empty_element);
} else { } else {
std::vector<int32_t> element(shapes[i].size(), 1); Dimensions element(shapes[i].size(), 1);
if (split_flag_list[i]) { if (split_flag_list[i]) {
element[0] = dev_num; element[0] = dev_num;
} }
strategy_v.push_back(element); strategy_v.push_back(element);
} }
} }
return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v); return std::make_shared<Strategys>(strategy_v);
} }
void OperatorInfo::ReComputeBatchSplitFlagList() { void OperatorInfo::ReComputeBatchSplitFlagList() {
...@@ -692,26 +693,26 @@ Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &input ...@@ -692,26 +693,26 @@ Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &input
MS_LOG(ERROR) << "The strategy is null."; MS_LOG(ERROR) << "The strategy is null.";
return FAILED; return FAILED;
} }
int32_t product = 1; int64_t product = 1;
for (auto &input_partition : inputs_partitions) { for (auto &input_partition : inputs_partitions) {
product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int>()); product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>());
} }
if (!FULLY_USE_DEVICES) { if (!FULLY_USE_DEVICES) {
if (IntToSize(product) > dev_num) { if (LongToSize(product) > dev_num) {
return FAILED; return FAILED;
} }
} else { } else {
if ((product != 1) && (IntToSize(product) != dev_num)) { if ((product != 1) && (LongToSize(product) != dev_num)) {
return FAILED; return FAILED;
} }
} }
std::vector<Dimensions> stras(inputs_partitions); Strategys stras(inputs_partitions);
(*sp) = std::make_shared<Strategy>(stage_id, stras); (*sp) = std::make_shared<Strategy>(stage_id, stras);
return SUCCESS; return SUCCESS;
} }
std::shared_ptr<std::vector<std::vector<int32_t>>> OperatorInfo::GenerateBatchStrategies() { std::shared_ptr<Strategys> OperatorInfo::GenerateBatchStrategies() {
ComputeBatchSplitFlagList(); ComputeBatchSplitFlagList();
return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
} }
...@@ -793,7 +794,7 @@ Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs ...@@ -793,7 +794,7 @@ Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs
// second, get the correct strategy for input0 // second, get the correct strategy for input0
for (auto &sp : *sp_vector) { for (auto &sp : *sp_vector) {
std::vector<Dimensions> tmp_strategy; Strategys tmp_strategy;
Dimensions input0_strategy = sp->GetInputDim()[0]; Dimensions input0_strategy = sp->GetInputDim()[0];
size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size(); size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size();
...@@ -842,7 +843,7 @@ Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &input ...@@ -842,7 +843,7 @@ Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &input
// second, get the correct strategy for input1 // second, get the correct strategy for input1
for (auto &sp : *sp_vector) { for (auto &sp : *sp_vector) {
std::vector<Dimensions> tmp_strategy; Strategys tmp_strategy;
tmp_strategy.push_back(sp->GetInputDim()[0]); // input0 tmp_strategy.push_back(sp->GetInputDim()[0]); // input0
Dimensions input1_strategy = sp->GetInputDim()[1]; Dimensions input1_strategy = sp->GetInputDim()[1];
...@@ -1175,7 +1176,7 @@ int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const S ...@@ -1175,7 +1176,7 @@ int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const S
// The number of repetitions is equal to the number of all devices divided by the number of devices use for // The number of repetitions is equal to the number of all devices divided by the number of devices use for
// tensor map. // tensor map.
int32_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies<int>()); int64_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies<int64_t>());
for (auto &element : tensor_map) { for (auto &element : tensor_map) {
// -1 means the corresponding dimension is not split. // -1 means the corresponding dimension is not split.
if (element == MAP_NONE) { if (element == MAP_NONE) {
...@@ -1194,7 +1195,7 @@ int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const S ...@@ -1194,7 +1195,7 @@ int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const S
} }
} }
return device_num; return (int32_t)device_num;
} }
Status OperatorInfo::InferAsLossDivisor() { Status OperatorInfo::InferAsLossDivisor() {
......
...@@ -43,11 +43,10 @@ using ForwardOp = OperatorVector; ...@@ -43,11 +43,10 @@ using ForwardOp = OperatorVector;
using MirrorOps = std::vector<OperatorVector>; using MirrorOps = std::vector<OperatorVector>;
using Ops = std::vector<OperatorVector>; using Ops = std::vector<OperatorVector>;
using VirtualDivOp = OperatorVector; using VirtualDivOp = OperatorVector;
using TensorMaps = std::vector<std::vector<int32_t>>; using TensorMaps = std::vector<Shape>;
using TensorLayouts = std::vector<TensorLayout>; using TensorLayouts = std::vector<TensorLayout>;
using different_type = std::vector<int32_t>::difference_type; using different_type = std::vector<int32_t>::difference_type;
using PrimitiveAttrs = std::unordered_map<std::string, ValuePtr>; using PrimitiveAttrs = std::unordered_map<std::string, ValuePtr>;
using Strategys = std::vector<Dimensions>;
using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>; using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
class Edge; class Edge;
...@@ -88,7 +87,7 @@ class OperatorInfo { ...@@ -88,7 +87,7 @@ class OperatorInfo {
void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; } void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; }
virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0; virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0;
virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies(); virtual std::shared_ptr<Strategys> GenerateBatchStrategies();
virtual void ReComputeBatchSplitFlagList(); virtual void ReComputeBatchSplitFlagList();
void ComputeBatchSplitFlagList(); void ComputeBatchSplitFlagList();
...@@ -271,8 +270,8 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string & ...@@ -271,8 +270,8 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &
Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategiesBySplitFlag( std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
const Shapes &shapes, const std::vector<bool> &split_flag_list); const std::vector<bool> &split_flag_list);
void PrintStrategy(const StrategyPtr &strategy); void PrintStrategy(const StrategyPtr &strategy);
// generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d]) // generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d])
......
...@@ -43,7 +43,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -43,7 +43,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) { if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) {
if (is_auto_parallel_) { if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Invalid strategy size."; MS_LOG(DEBUG) << name_ << ": Invalid strategy size.";
...@@ -67,7 +67,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -67,7 +67,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
* device matrix is same with the strategy matrix * device matrix is same with the strategy matrix
*/ */
Status PReLUInfo::InferDevMatrixShape() { Status PReLUInfo::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
input_strategy_ = input_strategy; input_strategy_ = input_strategy;
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
...@@ -103,7 +103,7 @@ Status PReLUInfo::InferTensorMap() { ...@@ -103,7 +103,7 @@ Status PReLUInfo::InferTensorMap() {
TensorMap input_tensor_map; TensorMap input_tensor_map;
// such as 4: input_tensor_map [3,2,1,0] // such as 4: input_tensor_map [3,2,1,0]
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
input_tensor_map.push_back((int32_t)(inputs_shape_[0].size() - i - 1)); input_tensor_map.push_back((int64_t)(inputs_shape_[0].size() - i - 1));
} }
TensorMap param_tensor_map; TensorMap param_tensor_map;
......
...@@ -43,7 +43,7 @@ Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { ...@@ -43,7 +43,7 @@ Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) {
} }
Status ReduceMethod::InferDevMatrixShape() { Status ReduceMethod::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
...@@ -119,11 +119,12 @@ Status ReduceMethod::GetAttrs() { ...@@ -119,11 +119,12 @@ Status ReduceMethod::GetAttrs() {
} }
Status ReduceMethod::InferTensorMap() { Status ReduceMethod::InferTensorMap() {
std::vector<int32_t> tensor_map_index, dim_list, output_tensor_map; Shape tensor_map_index, output_tensor_map;
std::vector<int32_t> dim_list;
size_t size = inputs_shape_.at(0).size(); size_t size = inputs_shape_.at(0).size();
// such as 4: tensor_map_index [3,2,1,0] // such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
tensor_map_index.push_back((int32_t)(size - 1 - i)); tensor_map_index.push_back((int64_t)(size - 1 - i));
} }
dim_list = reduce_dim(); dim_list = reduce_dim();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
...@@ -462,7 +463,7 @@ Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -462,7 +463,7 @@ Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) {
std::vector<int32_t> dim_list = reduce_dim(); std::vector<int32_t> dim_list = reduce_dim();
MS_ASSERT(dim_list.size() == 1); MS_ASSERT(dim_list.size() == 1);
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
MS_ASSERT(stra.size() == 1); MS_ASSERT(stra.size() == 1);
Shape input_strategy = stra.at(0); Shape input_strategy = stra.at(0);
MS_ASSERT(dim_list.at(0) < input_strategy.size()); MS_ASSERT(dim_list.at(0) < input_strategy.size());
......
...@@ -57,7 +57,7 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -57,7 +57,7 @@ Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) {
* only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number)
*/ */
Status ReshapeInfo::InferDevMatrixShape() { Status ReshapeInfo::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
input_strategy_ = stra.at(0); input_strategy_ = stra.at(0);
dev_matrix_shape_.push_back(input_strategy_[0]); dev_matrix_shape_.push_back(input_strategy_[0]);
return SUCCESS; return SUCCESS;
...@@ -181,7 +181,7 @@ Status ReshapeInfo::InferTensorMap() { ...@@ -181,7 +181,7 @@ Status ReshapeInfo::InferTensorMap() {
return FAILED; return FAILED;
} }
std::vector<int32_t> tensor_map_index_input; Shape tensor_map_index_input;
tensor_map_index_input.push_back(0); tensor_map_index_input.push_back(0);
for (size_t j = 1; j < inputs_shape_[0].size(); ++j) { for (size_t j = 1; j < inputs_shape_[0].size(); ++j) {
...@@ -189,7 +189,7 @@ Status ReshapeInfo::InferTensorMap() { ...@@ -189,7 +189,7 @@ Status ReshapeInfo::InferTensorMap() {
} }
inputs_tensor_map_.push_back(tensor_map_index_input); inputs_tensor_map_.push_back(tensor_map_index_input);
std::vector<int32_t> tensor_map_index_output; Shape tensor_map_index_output;
tensor_map_index_output.push_back(0); tensor_map_index_output.push_back(0);
for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { for (size_t j = 1; j < outputs_shape_[0].size(); ++j) {
...@@ -205,7 +205,7 @@ Status ReshapeInfo::InferTensorMap() { ...@@ -205,7 +205,7 @@ Status ReshapeInfo::InferTensorMap() {
*/ */
Strategys ReshapeInfo::GetOutputsStrategy() { Strategys ReshapeInfo::GetOutputsStrategy() {
Strategys outputs_strategy; Strategys outputs_strategy;
std::vector<int32_t> strategy; Dimensions strategy;
strategy.push_back(input_strategy_[0]); strategy.push_back(input_strategy_[0]);
for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { for (size_t j = 1; j < outputs_shape_[0].size(); ++j) {
strategy.push_back(1); strategy.push_back(1);
...@@ -325,7 +325,7 @@ void ReshapeInfo::device_number(const StrategyPtr &strategy) { ...@@ -325,7 +325,7 @@ void ReshapeInfo::device_number(const StrategyPtr &strategy) {
} }
Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) { Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) {
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
for (size_t i = 0; i < shape.size(); i++) { for (size_t i = 0; i < shape.size(); i++) {
tensor_map_index.push_back(MAP_NONE); tensor_map_index.push_back(MAP_NONE);
} }
...@@ -504,7 +504,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra ...@@ -504,7 +504,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
MS_LOG(ERROR) << "Infer strategy by tensor_info failed"; MS_LOG(ERROR) << "Infer strategy by tensor_info failed";
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra_inputs = {stra}; Strategys stra_inputs = {stra};
StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
if (next_stra_costs.empty()) { if (next_stra_costs.empty()) {
if (Init(nullptr) == FAILED) { if (Init(nullptr) == FAILED) {
......
...@@ -227,7 +227,7 @@ Status StridedSliceInfo::InferTensorInfo() { ...@@ -227,7 +227,7 @@ Status StridedSliceInfo::InferTensorInfo() {
} }
// Note: if the batch dimension is not fully fetched, the batch strategy may not work. // Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std::shared_ptr<std::vector<std::vector<int32_t>>> StridedSliceInfo::GenerateBatchStrategies() { std::shared_ptr<Strategys> StridedSliceInfo::GenerateBatchStrategies() {
split_flag_list_ = {true}; split_flag_list_ = {true};
return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
} }
......
...@@ -41,7 +41,7 @@ class StridedSliceInfo : public OperatorInfo { ...@@ -41,7 +41,7 @@ class StridedSliceInfo : public OperatorInfo {
Status InitForCostModel(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int32_t) override; Status GenerateStrategies(int32_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override; Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; std::shared_ptr<Strategys> GenerateBatchStrategies() override;
protected: protected:
Status GetAttrs() override; Status GetAttrs() override;
......
...@@ -54,7 +54,7 @@ Status TileInfo::GetAttrs() { ...@@ -54,7 +54,7 @@ Status TileInfo::GetAttrs() {
for (auto &element : elements) { for (auto &element : elements) {
MS_EXCEPTION_IF_NULL(element); MS_EXCEPTION_IF_NULL(element);
if (element->isa<Int32Imm>()) { if (element->isa<Int32Imm>()) {
int32_t axis = element->cast<Int32ImmPtr>()->value(); int64_t axis = static_cast<int64_t>(element->cast<Int32ImmPtr>()->value());
full_multiples_.push_back(axis); full_multiples_.push_back(axis);
} else { } else {
MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; MS_LOG(ERROR) << name_ << ": The value of axis must be int32.";
...@@ -180,12 +180,15 @@ void TileInfo::UpdateMultiples(const CNodePtr &cnode) { ...@@ -180,12 +180,15 @@ void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
auto manager = func_graph->manager(); auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
ValuePtr new_multiples = MakeValue(slice_multiples_); std::vector<int32_t> slice_multiples_int;
(void)std::transform(slice_multiples_.begin(), slice_multiples_.end(), std::back_inserter(slice_multiples_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr new_multiples = MakeValue(slice_multiples_int);
AnfNodePtr val = NewValueNode(new_multiples); AnfNodePtr val = NewValueNode(new_multiples);
(void)manager->Replace(cnode->input(2), val); (void)manager->Replace(cnode->input(2), val);
} }
std::shared_ptr<std::vector<std::vector<int32_t>>> TileInfo::GenerateBatchStrategies() { std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() {
if (InferAttrs() != SUCCESS) { if (InferAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed"; MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
} }
......
...@@ -41,7 +41,7 @@ class TileInfo : public OperatorInfo { ...@@ -41,7 +41,7 @@ class TileInfo : public OperatorInfo {
Status InitForCostModel(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int32_t) override; Status GenerateStrategies(int32_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override; Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; std::shared_ptr<Strategys> GenerateBatchStrategies() override;
void UpdateMultiples(const CNodePtr &cnode); void UpdateMultiples(const CNodePtr &cnode);
protected: protected:
...@@ -54,8 +54,8 @@ class TileInfo : public OperatorInfo { ...@@ -54,8 +54,8 @@ class TileInfo : public OperatorInfo {
Status InferTensorMap() override; Status InferTensorMap() override;
private: private:
std::vector<int32_t> full_multiples_; std::vector<int64_t> full_multiples_;
std::vector<int32_t> slice_multiples_; std::vector<int64_t> slice_multiples_;
}; };
using TileInfoPtr = std::shared_ptr<TileInfo>; using TileInfoPtr = std::shared_ptr<TileInfo>;
......
...@@ -37,18 +37,18 @@ Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &st ...@@ -37,18 +37,18 @@ Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &st
} }
Status TmpIdentityInfo::InferDevMatrixShape() { Status TmpIdentityInfo::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions input_strategy = stra.at(0); Dimensions input_strategy = stra.at(0);
dev_matrix_shape_ = input_strategy; dev_matrix_shape_ = input_strategy;
return SUCCESS; return SUCCESS;
} }
Status TmpIdentityInfo::InferTensorMap() { Status TmpIdentityInfo::InferTensorMap() {
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
size_t size = inputs_shape_[0].size(); size_t size = inputs_shape_[0].size();
// such as 4: tensor_map_index [3,2,1,0] // such as 4: tensor_map_index [3,2,1,0]
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
tensor_map_index.push_back((int32_t)(size - 1 - i)); tensor_map_index.push_back((int64_t)(size - 1 - i));
} }
inputs_tensor_map_.push_back(tensor_map_index); inputs_tensor_map_.push_back(tensor_map_index);
......
...@@ -41,7 +41,7 @@ Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -41,7 +41,7 @@ Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status TransposeInfo::InferDevMatrixShape() { Status TransposeInfo::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
input_strategy_ = stra.at(0); input_strategy_ = stra.at(0);
for (auto &iter : input_strategy_) { for (auto &iter : input_strategy_) {
dev_matrix_shape_.push_back(iter); dev_matrix_shape_.push_back(iter);
...@@ -105,13 +105,13 @@ Status TransposeInfo::InferTensorMap() { ...@@ -105,13 +105,13 @@ Status TransposeInfo::InferTensorMap() {
return FAILED; return FAILED;
} }
std::vector<int32_t> tensor_map_index_input; Shape tensor_map_index_input;
for (size_t j = 0; j < inputs_shape_[0].size(); ++j) { for (size_t j = 0; j < inputs_shape_[0].size(); ++j) {
tensor_map_index_input.push_back(SizeToInt(inputs_shape_[0].size() - j - 1)); tensor_map_index_input.push_back(SizeToInt(inputs_shape_[0].size() - j - 1));
} }
inputs_tensor_map_.push_back(tensor_map_index_input); inputs_tensor_map_.push_back(tensor_map_index_input);
std::vector<int32_t> tensor_map_index_output = tensor_map_index_input; Shape tensor_map_index_output = tensor_map_index_input;
for (uint32_t i = 0; i < tensor_map_index_output.size(); i++) { for (uint32_t i = 0; i < tensor_map_index_output.size(); i++) {
tensor_map_index_output[i] = tensor_map_index_input[IntToUint(axis_v_[i])]; tensor_map_index_output[i] = tensor_map_index_input[IntToUint(axis_v_[i])];
} }
...@@ -122,7 +122,7 @@ Status TransposeInfo::InferTensorMap() { ...@@ -122,7 +122,7 @@ Status TransposeInfo::InferTensorMap() {
// the output tensor strategy is the permutation of input tensor strategy, the permutation is axis_v // the output tensor strategy is the permutation of input tensor strategy, the permutation is axis_v
Strategys TransposeInfo::GetOutputsStrategy() { Strategys TransposeInfo::GetOutputsStrategy() {
Strategys outputs_strategy; Strategys outputs_strategy;
std::vector<int32_t> strategy = input_strategy_; Dimensions strategy = input_strategy_;
for (uint32_t i = 0; i < strategy.size(); i++) { for (uint32_t i = 0; i < strategy.size(); i++) {
strategy[i] = input_strategy_[IntToUint(axis_v_[i])]; strategy[i] = input_strategy_[IntToUint(axis_v_[i])];
} }
......
...@@ -38,7 +38,7 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -38,7 +38,7 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
std::vector<Dimensions> stra = strategy->GetInputDim(); Strategys stra = strategy->GetInputDim();
if (stra.size() < 1) { if (stra.size() < 1) {
if (is_auto_parallel_) { if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Strategy size must be larger than 1."; MS_LOG(DEBUG) << name_ << ": Strategy size must be larger than 1.";
...@@ -80,12 +80,12 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { ...@@ -80,12 +80,12 @@ Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
} }
Status VirtualDatasetInfo::InferDevMatrixShape() { Status VirtualDatasetInfo::InferDevMatrixShape() {
std::vector<Dimensions> stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
Dimensions strategy_first = stra.at(0); Dimensions strategy_first = stra.at(0);
int32_t stage = strategy_->GetInputStage(); int32_t stage = strategy_->GetInputStage();
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();
int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size());
int32_t batch_split_num = strategy_first.at(0); int32_t batch_split_num = ((int32_t)(strategy_first.at(0)));
dev_matrix_shape_.push_back(batch_split_num); dev_matrix_shape_.push_back(batch_split_num);
if (dev_num > batch_split_num) { if (dev_num > batch_split_num) {
dev_matrix_shape_.push_back(dev_num / batch_split_num); dev_matrix_shape_.push_back(dev_num / batch_split_num);
...@@ -103,11 +103,11 @@ Status VirtualDatasetInfo::InferTensorMap() { ...@@ -103,11 +103,11 @@ Status VirtualDatasetInfo::InferTensorMap() {
bool full_batch = ParallelContext::GetInstance()->full_batch(); bool full_batch = ParallelContext::GetInstance()->full_batch();
for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { for (size_t i = 0; i < strategy_->GetInputNumber(); i++) {
std::vector<int32_t> tensor_map_index; Shape tensor_map_index;
if (full_batch) { if (full_batch) {
tensor_map_index.push_back(MAP_NONE); tensor_map_index.push_back(MAP_NONE);
} else { } else {
tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); tensor_map_index.push_back((int64_t)(LAST_INDEX(dev_matrix_shape_.size())));
} }
for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) { for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) {
tensor_map_index.push_back(MAP_NONE); tensor_map_index.push_back(MAP_NONE);
...@@ -193,7 +193,7 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { ...@@ -193,7 +193,7 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {
total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
} }
StrategyPtr sp; StrategyPtr sp;
std::vector<Dimensions> strategy; Strategys strategy;
for (auto &shape : inputs_shape_) { for (auto &shape : inputs_shape_) {
Shape temp; Shape temp;
temp.emplace_back(SizeToInt(total_dev_num)); temp.emplace_back(SizeToInt(total_dev_num));
......
...@@ -1019,14 +1019,16 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) { ...@@ -1019,14 +1019,16 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
} }
if (var->size() > 0) { if (var->size() > 0) {
std::vector<ValuePtr> elements = var->value(); std::vector<ValuePtr> elements = var->value();
std::vector<Dimensions> strategy; Strategys strategy;
for (uint32_t index = 0; index < elements.size(); ++index) { for (uint32_t index = 0; index < elements.size(); ++index) {
Dimensions dim; Dimensions dim;
if (elements[index]->isa<ValueSequeue>()) { if (elements[index]->isa<ValueSequeue>()) {
ValueTuplePtr value_tuple = elements[index]->cast<ValueTuplePtr>(); ValueTuplePtr value_tuple = elements[index]->cast<ValueTuplePtr>();
std::vector<ValuePtr> value_vector = value_tuple->value(); std::vector<ValuePtr> value_vector = value_tuple->value();
(void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim), (void)std::transform(
[](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int>(value)); }); value_vector.begin(), value_vector.end(), std::back_inserter(dim), [](const ValuePtr &value) {
return value->isa<Int64Imm>() ? GetValue<int64_t>(value) : static_cast<int64_t>(GetValue<int>(value));
});
strategy.push_back(dim); strategy.push_back(dim);
} else { } else {
MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue"; MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue";
...@@ -1075,12 +1077,20 @@ Shapes GetNodeShape(const AnfNodePtr &node) { ...@@ -1075,12 +1077,20 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
for (auto &shape : tuple_shape) { for (auto &shape : tuple_shape) {
auto each_shape = dyn_cast<abstract::Shape>(shape); auto each_shape = dyn_cast<abstract::Shape>(shape);
MS_EXCEPTION_IF_NULL(each_shape); MS_EXCEPTION_IF_NULL(each_shape);
shapes.push_back(each_shape->shape()); std::vector<int> shape_int = each_shape->shape();
Shape new_shape;
(void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(new_shape),
[](const int &value) { return static_cast<int64_t>(value); });
shapes.push_back(new_shape);
} }
} else { } else {
auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr); auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
MS_EXCEPTION_IF_NULL(shape_ptr); MS_EXCEPTION_IF_NULL(shape_ptr);
shapes.push_back(shape_ptr->shape()); std::vector<int> shape_int = shape_ptr->shape();
Shape new_shape;
(void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(new_shape),
[](const int &value) { return static_cast<int64_t>(value); });
shapes.push_back(new_shape);
} }
return shapes; return shapes;
} }
...@@ -1412,7 +1422,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { ...@@ -1412,7 +1422,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
if (shape_list[0][i].empty()) { if (shape_list[0][i].empty()) {
MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero"; MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
} }
std::vector<int32_t> input_strategy = {dev_num}; Dimensions input_strategy = {dev_num};
for (size_t j = 1; j < shape_list[0][i].size(); j++) { for (size_t j = 1; j < shape_list[0][i].size(); j++) {
input_strategy.push_back(1); input_strategy.push_back(1);
} }
...@@ -1476,7 +1486,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -1476,7 +1486,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel"; << " is empty, using batch parallel";
std::shared_ptr<std::vector<Dimensions>> strategy_v_ptr = operator_->GenerateBatchStrategies(); std::shared_ptr<Strategys> strategy_v_ptr = operator_->GenerateBatchStrategies();
if (strategy_v_ptr == nullptr) { if (strategy_v_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed"; MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed";
} }
......
...@@ -24,19 +24,20 @@ ...@@ -24,19 +24,20 @@
#include <vector> #include <vector>
#include "frontend/parallel/status.h" #include "frontend/parallel/status.h"
#include "frontend/parallel/device_matrix.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
#define MIN_SLICE_NUM 1 #define MIN_SLICE_NUM 1
using Dimensions = std::vector<int32_t>; using Dimensions = Shape;
using Strategys = std::vector<Dimensions>;
class Strategy; class Strategy;
using StrategyPtr = std::shared_ptr<Strategy>; using StrategyPtr = std::shared_ptr<Strategy>;
class Strategy { class Strategy {
public: public:
Strategy(int32_t stage, std::vector<Dimensions> inputs) Strategy(int32_t stage, Strategys inputs)
: stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {} : stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) { Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
...@@ -51,14 +52,14 @@ class Strategy { ...@@ -51,14 +52,14 @@ class Strategy {
~Strategy() = default; ~Strategy() = default;
size_t GetInputNumber() const { return inputs_.size(); } size_t GetInputNumber() const { return inputs_.size(); }
std::vector<Dimensions> GetInputDim() const { return inputs_; } Strategys GetInputDim() const { return inputs_; }
int32_t GetInputStage() const { return stage_; } int32_t GetInputStage() const { return stage_; }
void ExpandInputDimFromOneToTwo() { void ExpandInputDimFromOneToTwo() {
if (inputs_.size() == 1) { if (inputs_.size() == 1) {
inputs_.push_back(inputs_[0]); inputs_.push_back(inputs_[0]);
} }
} }
void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; } void ResetInputs(const Strategys &input) { inputs_ = input; }
std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; } std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; }
size_t GetInternalSize() const { return internal_size_; } size_t GetInternalSize() const { return internal_size_; }
...@@ -83,12 +84,12 @@ class Strategy { ...@@ -83,12 +84,12 @@ class Strategy {
const int32_t stage_; const int32_t stage_;
// The size of Dimensions must equal to inputs_ tensor dimension. // The size of Dimensions must equal to inputs_ tensor dimension.
std::vector<Dimensions> inputs_; Strategys inputs_;
size_t internal_size_ = 0; size_t internal_size_ = 0;
std::vector<StrategyPtr> internal_stragies_; std::vector<StrategyPtr> internal_stragies_;
}; };
inline StrategyPtr NewStrategy(const int32_t stage, const std::vector<Dimensions> &inputs) { inline StrategyPtr NewStrategy(const int32_t stage, const Strategys &inputs) {
return std::make_shared<Strategy>(stage, inputs); return std::make_shared<Strategy>(stage, inputs);
} }
} // namespace parallel } // namespace parallel
......
...@@ -66,10 +66,10 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { ...@@ -66,10 +66,10 @@ Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys(); straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
auto stage = (int32_t)parallel_strategys.stage(); auto stage = (int32_t)parallel_strategys.stage();
size_t strategys_num = IntToSize(parallel_strategys.parallel_strategy_size()); size_t strategys_num = IntToSize(parallel_strategys.parallel_strategy_size());
std::vector<std::vector<int32_t>> strategy_inputs; Strategys strategy_inputs;
for (size_t j = 0; j < strategys_num; j++) { for (size_t j = 0; j < strategys_num; j++) {
straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j)); straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
std::vector<int32_t> dimension; Dimensions dimension;
size_t dim_num = IntToSize(parallel_strategy.dim_size()); size_t dim_num = IntToSize(parallel_strategy.dim_size());
for (size_t k = 0; k < dim_num; k++) { for (size_t k = 0; k < dim_num; k++) {
dimension.push_back(parallel_strategy.dim(SizeToInt(k))); dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status Arrangement::Init(const std::vector<int32_t> &array) { Status Arrangement::Init(const Shape &array) {
Status status = Array::Init(array); Status status = Array::Init(array);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return Status::FAILED; return Status::FAILED;
...@@ -40,7 +40,7 @@ Status Arrangement::Init(const std::vector<int32_t> &array) { ...@@ -40,7 +40,7 @@ Status Arrangement::Init(const std::vector<int32_t> &array) {
} }
bool Arrangement::IsValidArrangement() { bool Arrangement::IsValidArrangement() {
return !std::any_of(array_.begin(), array_.end(), [](int32_t value) { return value <= 0; }); return !std::any_of(array_.begin(), array_.end(), [](int64_t value) { return value <= 0; });
} }
void Arrangement::ComputeSize() { void Arrangement::ComputeSize() {
...@@ -57,14 +57,14 @@ void Arrangement::ComputeSize() { ...@@ -57,14 +57,14 @@ void Arrangement::ComputeSize() {
* where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1], * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1],
* if value > size_, return [] * if value > size_, return []
*/ */
std::vector<int32_t> Arrangement::GetFrontElementByValue(int32_t value) const { Shape Arrangement::GetFrontElementByValue(int64_t value) const {
std::vector<int32_t> out; Shape out;
if (GetDimSize() == 0) { if (GetDimSize() == 0) {
return out; return out;
} }
if (value <= size_) { if (value <= size_) {
int32_t size = 1; int64_t size = 1;
uint32_t shape_list_idx = 0; size_t shape_list_idx = 0;
while (size < value) { while (size < value) {
size *= array_[shape_list_idx]; size *= array_[shape_list_idx];
if (size <= value) { if (size <= value) {
...@@ -88,9 +88,9 @@ std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft ...@@ -88,9 +88,9 @@ std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft
if (expand_list.size() != GetDimSize()) { if (expand_list.size() != GetDimSize()) {
return nullptr; return nullptr;
} }
std::vector<int32_t> new_shape; Shape new_shape;
for (uint32_t i = 0; i < expand_list.size(); i++) { for (size_t i = 0; i < expand_list.size(); i++) {
std::vector<int32_t> expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i)); Shape expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i));
if (expand_shape.empty()) { if (expand_shape.empty()) {
new_shape.push_back(GetDimByIdx(i)); new_shape.push_back(GetDimByIdx(i));
} else { } else {
...@@ -109,11 +109,11 @@ std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft ...@@ -109,11 +109,11 @@ std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft
* arrangement_list = [[4, 2], [2, 2]] * arrangement_list = [[4, 2], [2, 2]]
*/ */
std::shared_ptr<std::vector<Arrangement>> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const { std::shared_ptr<std::vector<Arrangement>> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const {
int32_t size = 1; int64_t size = 1;
uint32_t ind = 0; size_t ind = 0;
std::vector<Arrangement> arrangement_list; std::vector<Arrangement> arrangement_list;
std::vector<int32_t> shape; Shape shape;
for (uint32_t i = 0; i < expand_shape.GetDimSize(); i++) { for (size_t i = 0; i < expand_shape.GetDimSize(); i++) {
size *= expand_shape.GetDimByIdx(i); size *= expand_shape.GetDimByIdx(i);
if (size > GetDimByIdx(ind)) { if (size > GetDimByIdx(ind)) {
MS_LOG(ERROR) << "invalid expand_shape"; MS_LOG(ERROR) << "invalid expand_shape";
...@@ -145,7 +145,7 @@ std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::G ...@@ -145,7 +145,7 @@ std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::G
if (expand_shape_list_ptr == nullptr) { if (expand_shape_list_ptr == nullptr) {
return nullptr; return nullptr;
} }
std::vector<int32_t> expand_num_list_shape; Shape expand_num_list_shape;
(void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(), (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(),
std::back_inserter(expand_num_list_shape), std::back_inserter(expand_num_list_shape),
[](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); }); [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); });
...@@ -158,9 +158,9 @@ std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::G ...@@ -158,9 +158,9 @@ std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::G
return std::make_shared<std::pair<std::vector<Arrangement>, Arrangement>>(out_value); return std::make_shared<std::pair<std::vector<Arrangement>, Arrangement>>(out_value);
} }
std::vector<int32_t> Arrangement::ComputeReverseAccumulateSumInReverseOrder() const { Shape Arrangement::ComputeReverseAccumulateSumInReverseOrder() const {
std::vector<int32_t> shape_accum; Shape shape_accum;
int32_t size = 0; int64_t size = 0;
for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) { for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) {
shape_accum.push_back(size); shape_accum.push_back(size);
size += *iter; size += *iter;
...@@ -173,11 +173,11 @@ std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListReserveLef ...@@ -173,11 +173,11 @@ std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListReserveLef
if (expand_list.size() != GetDimSize()) { if (expand_list.size() != GetDimSize()) {
return nullptr; return nullptr;
} }
std::vector<int32_t> new_shape; Shape new_shape;
for (uint32_t i = 0; i < expand_list.size(); i++) { for (size_t i = 0; i < expand_list.size(); i++) {
if (expand_list[i].GetDimSize() >= 1) { if (expand_list[i].GetDimSize() >= 1) {
int32_t size = 1; int64_t size = 1;
for (uint32_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) { for (size_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) {
new_shape.push_back(expand_list[i].GetDimByIdx(k)); new_shape.push_back(expand_list[i].GetDimByIdx(k));
size *= expand_list[i].GetDimByIdx(k); size *= expand_list[i].GetDimByIdx(k);
} }
...@@ -207,7 +207,7 @@ std::shared_ptr<Arrangement> Arrangement::GetUnifiedShape(const Arrangement &in2 ...@@ -207,7 +207,7 @@ std::shared_ptr<Arrangement> Arrangement::GetUnifiedShape(const Arrangement &in2
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return nullptr; return nullptr;
} }
std::vector<int32_t> out_shape; Shape out_shape;
status = AccumulateProductToShape(out_accum, &out_shape); status = AccumulateProductToShape(out_accum, &out_shape);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return nullptr; return nullptr;
...@@ -231,8 +231,8 @@ std::vector<size_t> Arrangement::GetSqueezeIdx() const { ...@@ -231,8 +231,8 @@ std::vector<size_t> Arrangement::GetSqueezeIdx() const {
} }
Arrangement Arrangement::GetSqueezeArrangement() const { Arrangement Arrangement::GetSqueezeArrangement() const {
std::vector<int32_t> out_shape(array_.size()); Shape out_shape(array_.size());
auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int32_t value) { return value != 1; }); auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int64_t value) { return value != 1; });
out_shape.resize(LongToSize(std::distance(out_shape.begin(), it))); out_shape.resize(LongToSize(std::distance(out_shape.begin(), it)));
// if all elements are 1, out_shape = {1} // if all elements are 1, out_shape = {1}
......
...@@ -32,11 +32,11 @@ class Arrangement : public Array { ...@@ -32,11 +32,11 @@ class Arrangement : public Array {
public: public:
Arrangement() : size_(1) {} Arrangement() : size_(1) {}
~Arrangement() override = default; ~Arrangement() override = default;
Status Init(const std::vector<int32_t> &array) override; Status Init(const Shape &array) override;
int32_t size() const { return size_; } int64_t size() const { return size_; }
std::vector<int32_t> GetFrontElementByValue(int32_t value) const; Shape GetFrontElementByValue(int64_t value) const;
std::shared_ptr<std::vector<Arrangement>> GetExpandShapeList(const Arrangement &expand_shape) const; std::shared_ptr<std::vector<Arrangement>> GetExpandShapeList(const Arrangement &expand_shape) const;
std::vector<int32_t> ComputeReverseAccumulateSumInReverseOrder() const; Shape ComputeReverseAccumulateSumInReverseOrder() const;
std::shared_ptr<Arrangement> GetExpandedShapeByExpandListReserveLeft( std::shared_ptr<Arrangement> GetExpandedShapeByExpandListReserveLeft(
const std::vector<Arrangement> &expand_list) const; const std::vector<Arrangement> &expand_list) const;
std::shared_ptr<Arrangement> GetExpandedShapeByExpandListRemoveLeft( std::shared_ptr<Arrangement> GetExpandedShapeByExpandListRemoveLeft(
...@@ -50,7 +50,7 @@ class Arrangement : public Array { ...@@ -50,7 +50,7 @@ class Arrangement : public Array {
private: private:
bool IsValidArrangement(); bool IsValidArrangement();
void ComputeSize(); void ComputeSize();
int32_t size_; int64_t size_;
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
......
...@@ -31,14 +31,14 @@ std::string Array::ToString() const { ...@@ -31,14 +31,14 @@ std::string Array::ToString() const {
return buffer.str(); return buffer.str();
} }
Status Array::Init(const std::vector<int32_t> &array) { Status Array::Init(const Shape &array) {
array_ = array; array_ = array;
return IsvalidArray() ? Status::SUCCESS : Status::FAILED; return IsvalidArray() ? Status::SUCCESS : Status::FAILED;
} }
bool Array::IsvalidArray() const { return true; } bool Array::IsvalidArray() const { return true; }
int32_t Array::GetDimByIdx(uint32_t idx) const { int64_t Array::GetDimByIdx(size_t idx) const {
size_t mod_idx = idx; size_t mod_idx = idx;
if (idx >= GetDimSize()) { if (idx >= GetDimSize()) {
MS_LOG(EXCEPTION) << "idx is " << idx << ", but array size is " << GetDimSize(); MS_LOG(EXCEPTION) << "idx is " << idx << ", but array size is " << GetDimSize();
...@@ -46,7 +46,7 @@ int32_t Array::GetDimByIdx(uint32_t idx) const { ...@@ -46,7 +46,7 @@ int32_t Array::GetDimByIdx(uint32_t idx) const {
return array_[mod_idx]; return array_[mod_idx];
} }
int32_t Array::GetDimByReverseIdx(uint32_t idx) const { int64_t Array::GetDimByReverseIdx(size_t idx) const {
size_t mod_idx = idx; size_t mod_idx = idx;
if (idx >= GetDimSize()) { if (idx >= GetDimSize()) {
MS_LOG(EXCEPTION) << "idx is " << idx << " but array size is " << GetDimSize(); MS_LOG(EXCEPTION) << "idx is " << idx << " but array size is " << GetDimSize();
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "frontend/parallel/status.h" #include "frontend/parallel/status.h"
#include "frontend/parallel/device_matrix.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
...@@ -31,16 +32,16 @@ class Array { ...@@ -31,16 +32,16 @@ class Array {
Array() = default; Array() = default;
virtual ~Array() = default; virtual ~Array() = default;
std::string ToString() const; std::string ToString() const;
virtual Status Init(const std::vector<int32_t> &array); virtual Status Init(const Shape &array);
bool IsvalidArray() const; bool IsvalidArray() const;
std::vector<int32_t> array() const { return array_; } Shape array() const { return array_; }
size_t GetDimSize() const { return array_.size(); } size_t GetDimSize() const { return array_.size(); }
int32_t GetDimByIdx(uint32_t idx) const; int64_t GetDimByIdx(size_t idx) const;
int32_t GetDimByReverseIdx(uint32_t idx) const; int64_t GetDimByReverseIdx(size_t idx) const;
bool operator==(const Array &a1) const; bool operator==(const Array &a1) const;
protected: protected:
std::vector<int32_t> array_; Shape array_;
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <functional> #include <functional>
#include <numeric> #include <numeric>
#include <algorithm>
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
...@@ -42,8 +43,8 @@ OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) { ...@@ -42,8 +43,8 @@ OperatorVector ConstructOperator::SkipRedisReshapeOP(Shape shape) {
} }
Status ConstructOperator::ReshapeOP(Shape shape) { Status ConstructOperator::ReshapeOP(Shape shape) {
int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); int64_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int>()); int64_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int64_t>());
if (prod != prod_expect) { if (prod != prod_expect) {
ValuePtr ptr = MakeValue(shape); ValuePtr ptr = MakeValue(shape);
MS_EXCEPTION_IF_NULL(ptr); MS_EXCEPTION_IF_NULL(ptr);
...@@ -68,12 +69,21 @@ Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &en ...@@ -68,12 +69,21 @@ Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &en
Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value); Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value);
OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask}; OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask};
ValuePtr param_begin_value = MakeValue(begin); std::vector<int32_t> begin_int;
(void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr param_begin_value = MakeValue(begin_int);
Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2); Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2);
ValuePtr param_end_value = MakeValue(end); std::vector<int32_t> end_int;
(void)std::transform(end.begin(), end.end(), std::back_inserter(end_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr param_end_value = MakeValue(end_int);
Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3); Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3);
ValuePtr param_strides_value = MakeValue(strides); std::vector<int32_t> strides_int;
(void)std::transform(strides.begin(), strides.end(), std::back_inserter(strides_int),
[](const int64_t &value) { return static_cast<int32_t>(value); });
ValuePtr param_strides_value = MakeValue(strides_int);
Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4); Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4);
OperatorParams params = {param_begin, param_end, param_strides}; OperatorParams params = {param_begin, param_end, param_strides};
OperatorArgs op_args = std::make_pair(attrs, params); OperatorArgs op_args = std::make_pair(attrs, params);
...@@ -86,16 +96,16 @@ Status ConstructOperator::StridedSliceOP(Args args) { ...@@ -86,16 +96,16 @@ Status ConstructOperator::StridedSliceOP(Args args) {
MS_LOG(ERROR) << "args size should not be less than 3!"; MS_LOG(ERROR) << "args size should not be less than 3!";
return Status::FAILED; return Status::FAILED;
} }
int32_t split_count = args[0]; int64_t split_count = args[0];
if (split_count <= 0) { if (split_count <= 0) {
MS_LOG(ERROR) << "split_count should not be less than 0!"; MS_LOG(ERROR) << "split_count should not be less than 0!";
return Status::FAILED; return Status::FAILED;
} }
int32_t split_dim = args[1]; int64_t split_dim = args[1];
int32_t dev_dim = args[2]; int64_t dev_dim = args[2];
std::vector<Group> group_list; std::vector<Group> group_list;
if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
MS_LOG(ERROR) << "stride slice op: create group failed"; MS_LOG(ERROR) << "stride slice op: create group failed";
return FAILED; return FAILED;
} else if (group_list.empty()) { // this group only has one device, don't need do StridedSlice } else if (group_list.empty()) { // this group only has one device, don't need do StridedSlice
...@@ -114,7 +124,7 @@ Status ConstructOperator::StridedSliceOP(Args args) { ...@@ -114,7 +124,7 @@ Status ConstructOperator::StridedSliceOP(Args args) {
Shape strides(size, 1); Shape strides(size, 1);
size_t index = 0; size_t index = 0;
for (auto num : tensor_shape_) { for (auto num : tensor_shape_) {
if (index != IntToSize(split_dim)) { if (index != LongToSize(split_dim)) {
begin[index] = 0; begin[index] = 0;
end[index] = num; end[index] = num;
} else { } else {
...@@ -123,9 +133,9 @@ Status ConstructOperator::StridedSliceOP(Args args) { ...@@ -123,9 +133,9 @@ Status ConstructOperator::StridedSliceOP(Args args) {
<< "! when construct StridedSlice operator"; << "! when construct StridedSlice operator";
return Status::INVALID_ARGUMENT; return Status::INVALID_ARGUMENT;
} }
int32_t count = num / split_count; int64_t count = num / split_count;
begin[index] = SizeToInt(rank) * count; begin[index] = SizeToLong(rank) * count;
end[index] = (SizeToInt(rank) + 1) * count; end[index] = (SizeToLong(rank) + 1) * count;
} }
index++; index++;
} }
...@@ -135,7 +145,7 @@ Status ConstructOperator::StridedSliceOP(Args args) { ...@@ -135,7 +145,7 @@ Status ConstructOperator::StridedSliceOP(Args args) {
return Status::SUCCESS; return Status::SUCCESS;
} }
Status ConstructOperator::AllGatherOP(int32_t dev_dim) { Status ConstructOperator::AllGatherOP(int64_t dev_dim) {
if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AllGather operator!"; MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AllGather operator!";
return Status::INVALID_ARGUMENT; return Status::INVALID_ARGUMENT;
...@@ -160,7 +170,7 @@ Status ConstructOperator::AllGatherOP(int32_t dev_dim) { ...@@ -160,7 +170,7 @@ Status ConstructOperator::AllGatherOP(int32_t dev_dim) {
return Status::SUCCESS; return Status::SUCCESS;
} }
Status ConstructOperator::ConcatOP(int32_t concat_dim) { Status ConstructOperator::ConcatOP(int64_t concat_dim) {
if (IntToSize(concat_dim) >= tensor_shape_.size()) { if (IntToSize(concat_dim) >= tensor_shape_.size()) {
MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when construct Concat operator!"; MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when construct Concat operator!";
return Status::INVALID_ARGUMENT; return Status::INVALID_ARGUMENT;
...@@ -174,7 +184,7 @@ Status ConstructOperator::ConcatOP(int32_t concat_dim) { ...@@ -174,7 +184,7 @@ Status ConstructOperator::ConcatOP(int32_t concat_dim) {
return Status::SUCCESS; return Status::SUCCESS;
} }
Status ConstructOperator::SplitOP(int32_t split_count) { Status ConstructOperator::SplitOP(int64_t split_count) {
if (split_count <= 0) { if (split_count <= 0) {
MS_LOG(ERROR) << "Invalid split count when construct Split operator!"; MS_LOG(ERROR) << "Invalid split count when construct Split operator!";
return Status::FAILED; return Status::FAILED;
...@@ -196,30 +206,30 @@ Status ConstructOperator::AlltoAllOP(Args args) { ...@@ -196,30 +206,30 @@ Status ConstructOperator::AlltoAllOP(Args args) {
MS_LOG(ERROR) << "args size should not be less than 4!"; MS_LOG(ERROR) << "args size should not be less than 4!";
return Status::FAILED; return Status::FAILED;
} }
int32_t split_count = args[0]; int64_t split_count = args[0];
int32_t split_dim = args[1]; int64_t split_dim = args[1];
int32_t concat_dim = args[2]; int64_t concat_dim = args[2];
int32_t dev_dim = args[3]; int64_t dev_dim = args[3];
if (split_count <= 0) { if (split_count <= 0) {
MS_LOG(ERROR) << "Invalid split count when construct AlltoAll operator!"; MS_LOG(ERROR) << "Invalid split count when construct AlltoAll operator!";
return Status::FAILED; return Status::FAILED;
} }
if (tensor_shape_[IntToSize(split_dim)] % split_count != 0) { if (tensor_shape_[LongToSize(split_dim)] % split_count != 0) {
MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim
<< "when construct AlltoAll operator!"; << "when construct AlltoAll operator!";
return Status::INVALID_ARGUMENT; return Status::INVALID_ARGUMENT;
} }
if (IntToSize(concat_dim) >= tensor_shape_.size()) { if (LongToSize(concat_dim) >= tensor_shape_.size()) {
MS_LOG(ERROR) << "Invalid split count " << split_count << " when construct AlltoAll operator!"; MS_LOG(ERROR) << "Invalid split count " << split_count << " when construct AlltoAll operator!";
return Status::INVALID_ARGUMENT; return Status::INVALID_ARGUMENT;
} }
if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { if ((LongToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AlltoAll operator!"; MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AlltoAll operator!";
return Status::INVALID_ARGUMENT; return Status::INVALID_ARGUMENT;
} }
std::vector<Group> group_list; std::vector<Group> group_list;
if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { if (CreateGroupByDim(dev_size_ - LongToSize(dev_dim) - 1, &group_list) != SUCCESS) {
MS_LOG(ERROR) << "AlltoAll op: create group failed"; MS_LOG(ERROR) << "AlltoAll op: create group failed";
return FAILED; return FAILED;
} else if (group_list.empty()) { // this group only has one device, don't need do alltoall } else if (group_list.empty()) { // this group only has one device, don't need do alltoall
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
using Args = std::vector<std::int32_t>; using Args = std::vector<std::int64_t>;
class ConstructOperator { class ConstructOperator {
public: public:
...@@ -38,9 +38,9 @@ class ConstructOperator { ...@@ -38,9 +38,9 @@ class ConstructOperator {
OperatorVector SkipRedisReshapeOP(Shape shape); OperatorVector SkipRedisReshapeOP(Shape shape);
Status ReshapeOP(Shape shape); Status ReshapeOP(Shape shape);
Status StridedSliceOP(Args args); Status StridedSliceOP(Args args);
Status AllGatherOP(int32_t dev_dim); Status AllGatherOP(int64_t dev_dim);
Status SplitOP(int32_t split_count); Status SplitOP(int64_t split_count);
Status ConcatOP(int32_t concat_dim); Status ConcatOP(int64_t concat_dim);
Status AlltoAllOP(Args args); Status AlltoAllOP(Args args);
Operator GetOperator() const { return op_; } Operator GetOperator() const { return op_; }
void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; } void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; }
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status Map::Init(const std::vector<int32_t> &array) { Status Map::Init(const Shape &array) {
Status status = Array::Init(array); Status status = Array::Init(array);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return Status::FAILED; return Status::FAILED;
...@@ -39,11 +39,11 @@ Status Map::Init(const std::vector<int32_t> &array) { ...@@ -39,11 +39,11 @@ Status Map::Init(const std::vector<int32_t> &array) {
} }
bool Map::IsValidMap() { bool Map::IsValidMap() {
if (std::any_of(array_.begin(), array_.end(), [](int32_t value) { return ((value < 0) && (value != MAP_NONE)); })) { if (std::any_of(array_.begin(), array_.end(), [](int64_t value) { return ((value < 0) && (value != MAP_NONE)); })) {
return false; return false;
} }
// check that all none -1 value in array_ is different // check that all none -1 value in array_ is different
std::vector<int32_t> sorted_array = array_; Shape sorted_array = array_;
std::sort(sorted_array.begin(), sorted_array.end()); std::sort(sorted_array.begin(), sorted_array.end());
int32_t value = MAP_NONE; int32_t value = MAP_NONE;
for (auto &element : sorted_array) { for (auto &element : sorted_array) {
...@@ -58,7 +58,7 @@ bool Map::IsValidMap() { ...@@ -58,7 +58,7 @@ bool Map::IsValidMap() {
return true; return true;
} }
int32_t Map::GetMaxItem() const { int64_t Map::GetMaxItem() const {
if (!array_.empty()) { if (!array_.empty()) {
return *std::max_element(array_.begin(), array_.end()); return *std::max_element(array_.begin(), array_.end());
} else { } else {
...@@ -66,7 +66,7 @@ int32_t Map::GetMaxItem() const { ...@@ -66,7 +66,7 @@ int32_t Map::GetMaxItem() const {
} }
} }
int32_t Map::GetIndexByValue(int32_t value) const { int32_t Map::GetIndexByValue(int64_t value) const {
auto iter = find(array_.begin(), array_.end(), value); auto iter = find(array_.begin(), array_.end(), value);
if (iter != array_.end()) { if (iter != array_.end()) {
return static_cast<int32_t>(std::distance(array_.begin(), iter)); return static_cast<int32_t>(std::distance(array_.begin(), iter));
...@@ -82,15 +82,15 @@ std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) co ...@@ -82,15 +82,15 @@ std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) co
if (expand_num_list.GetDimSize() != GetDimSize()) { if (expand_num_list.GetDimSize() != GetDimSize()) {
return nullptr; return nullptr;
} }
std::vector<int32_t> new_shape; Shape new_shape;
for (uint32_t i = 0; i != GetDimSize(); i++) { for (size_t i = 0; i != GetDimSize(); i++) {
if (GetDimByIdx(i) == MAP_NONE) { if (GetDimByIdx(i) == MAP_NONE) {
for (int32_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) { for (int64_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) {
new_shape.push_back(MAP_NONE); new_shape.push_back(MAP_NONE);
} }
} else { } else {
new_shape.push_back(GetDimByIdx(i)); new_shape.push_back(GetDimByIdx(i));
int32_t j = 1; int64_t j = 1;
while (j < expand_num_list.GetDimByIdx(i)) { while (j < expand_num_list.GetDimByIdx(i)) {
new_shape.push_back(MAP_NONE); new_shape.push_back(MAP_NONE);
j++; j++;
...@@ -106,17 +106,17 @@ std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) co ...@@ -106,17 +106,17 @@ std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) co
* expand.size() should be equal to array_.size() * expand.size() should be equal to array_.size()
*/ */
std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const { std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const {
if (GetMaxItem() >= static_cast<int32_t>(expand_num_list.GetDimSize())) { if (GetMaxItem() >= static_cast<int64_t>(expand_num_list.GetDimSize())) {
return nullptr; return nullptr;
} }
std::vector<int32_t> new_shape; Shape new_shape;
for (uint32_t i = 0; i < GetDimSize(); i++) { for (size_t i = 0; i < GetDimSize(); i++) {
if (GetDimByIdx(i) == MAP_NONE) { if (GetDimByIdx(i) == MAP_NONE) {
new_shape.push_back(MAP_NONE); new_shape.push_back(MAP_NONE);
} else { } else {
int32_t start_map = int64_t start_map =
expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast<uint32_t>(GetDimByIdx(i))]; expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast<size_t>(GetDimByIdx(i))];
for (int32_t k = expand_num_list.GetDimByReverseIdx(static_cast<uint32_t>(GetDimByIdx(i))) - 1; k >= 0; k--) { for (int32_t k = expand_num_list.GetDimByReverseIdx(static_cast<size_t>(GetDimByIdx(i))) - 1; k >= 0; k--) {
new_shape.push_back(k + start_map); new_shape.push_back(k + start_map);
} }
} }
...@@ -127,16 +127,16 @@ std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_nu ...@@ -127,16 +127,16 @@ std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_nu
} }
std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arrangement> &input_vector) const { std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arrangement> &input_vector) const {
if (GetMaxItem() >= static_cast<int32_t>(input_vector.size())) { if (GetMaxItem() >= static_cast<int64_t>(input_vector.size())) {
return nullptr; return nullptr;
} }
std::vector<Arrangement> out; std::vector<Arrangement> out;
Arrangement empty_arrangement; Arrangement empty_arrangement;
for (uint32_t i = 0; i < GetDimSize(); i++) { for (size_t i = 0; i < GetDimSize(); i++) {
if (GetDimByIdx(i) == MAP_NONE) { if (GetDimByIdx(i) == MAP_NONE) {
out.push_back(empty_arrangement); out.push_back(empty_arrangement);
} else { } else {
out.push_back(input_vector[IntToUint(SizeToInt(input_vector.size()) - 1 - GetDimByIdx(i))]); out.push_back(input_vector[input_vector.size() - 1 - LongToSize(GetDimByIdx(i))]);
} }
} }
return std::make_shared<std::vector<Arrangement>>(out); return std::make_shared<std::vector<Arrangement>>(out);
...@@ -144,7 +144,7 @@ std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arr ...@@ -144,7 +144,7 @@ std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arr
bool Map::CheckNoneByIdxList(std::vector<size_t> idx_list) const { bool Map::CheckNoneByIdxList(std::vector<size_t> idx_list) const {
for (auto &value : idx_list) { for (auto &value : idx_list) {
if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) { if (GetDimByIdx(value) != MAP_NONE) {
return false; return false;
} }
} }
...@@ -152,11 +152,11 @@ bool Map::CheckNoneByIdxList(std::vector<size_t> idx_list) const { ...@@ -152,11 +152,11 @@ bool Map::CheckNoneByIdxList(std::vector<size_t> idx_list) const {
} }
Map Map::SqueezeMapByIdxList(std::vector<size_t> idx_list) const { Map Map::SqueezeMapByIdxList(std::vector<size_t> idx_list) const {
std::vector<int32_t> out_shape; Shape out_shape;
for (size_t i = 0; i < GetDimSize(); i++) { for (size_t i = 0; i < GetDimSize(); i++) {
auto it = std::find(idx_list.begin(), idx_list.end(), i); auto it = std::find(idx_list.begin(), idx_list.end(), i);
if (it == idx_list.end()) { if (it == idx_list.end()) {
out_shape.push_back(GetDimByIdx(SizeToUint(i))); out_shape.push_back(GetDimByIdx(i));
} }
} }
if (out_shape.empty()) { if (out_shape.empty()) {
......
...@@ -34,9 +34,9 @@ class Map : public Array { ...@@ -34,9 +34,9 @@ class Map : public Array {
public: public:
Map() = default; Map() = default;
~Map() override = default; ~Map() override = default;
Status Init(const std::vector<int32_t> &array) override; Status Init(const Shape &array) override;
int32_t GetMaxItem() const; int64_t GetMaxItem() const;
int32_t GetIndexByValue(int32_t value) const; int32_t GetIndexByValue(int64_t value) const;
std::shared_ptr<Map> ExpandMapByNone(const Arrangement &expand_num_list) const; std::shared_ptr<Map> ExpandMapByNone(const Arrangement &expand_num_list) const;
std::shared_ptr<Map> ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const; std::shared_ptr<Map> ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const;
std::shared_ptr<std::vector<Arrangement>> ReMapVector(const std::vector<Arrangement> &input_vector) const; std::shared_ptr<std::vector<Arrangement>> ReMapVector(const std::vector<Arrangement> &input_vector) const;
......
...@@ -47,8 +47,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, cons ...@@ -47,8 +47,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, cons
constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array());
size_t key = 0; size_t key = 0;
std::vector<int32_t> map = in_tensor_map_.array(); Shape map = in_tensor_map_.array();
for (int32_t item : map) { for (int64_t item : map) {
map_[key++] = item; map_[key++] = item;
} }
...@@ -83,9 +83,9 @@ Status RedistributionOperatorInfer::InferRedistributionOperator() { ...@@ -83,9 +83,9 @@ Status RedistributionOperatorInfer::InferRedistributionOperator() {
// break loop structure with concat_by_axis // break loop structure with concat_by_axis
if (len_global == operator_list_.size() && !map_.empty()) { if (len_global == operator_list_.size() && !map_.empty()) {
size_t index = map_.begin()->first; size_t index = map_.begin()->first;
int32_t in_dim = map_[index]; int64_t in_dim = map_[index];
map_[index] = NONE; map_[index] = NONE;
Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
return Status::FAILED; return Status::FAILED;
} }
...@@ -97,8 +97,8 @@ Status RedistributionOperatorInfer::InferRedistributionOperator() { ...@@ -97,8 +97,8 @@ Status RedistributionOperatorInfer::InferRedistributionOperator() {
Status RedistributionOperatorInfer::InferSplitByAxis() { Status RedistributionOperatorInfer::InferSplitByAxis() {
for (auto iter = map_.begin(); iter != map_.end();) { for (auto iter = map_.begin(); iter != map_.end();) {
uint32_t index = iter->first; uint32_t index = iter->first;
int32_t in_dim = iter->second; int64_t in_dim = iter->second;
int32_t out_dim = out_tensor_map_.GetDimByIdx(index); int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
if (in_dim == out_dim) { if (in_dim == out_dim) {
(void)map_.erase(iter++); (void)map_.erase(iter++);
continue; continue;
...@@ -122,8 +122,8 @@ Status RedistributionOperatorInfer::InferSplitByAxis() { ...@@ -122,8 +122,8 @@ Status RedistributionOperatorInfer::InferSplitByAxis() {
Status RedistributionOperatorInfer::InferPermuteByAxis() { Status RedistributionOperatorInfer::InferPermuteByAxis() {
for (auto iter = map_.begin(); iter != map_.end();) { for (auto iter = map_.begin(); iter != map_.end();) {
uint32_t index = iter->first; uint32_t index = iter->first;
int32_t in_dim = map_[index]; int64_t in_dim = map_[index];
int32_t out_dim = out_tensor_map_.GetDimByIdx(index); int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
if (in_dim == out_dim) { if (in_dim == out_dim) {
(void)map_.erase(iter++); (void)map_.erase(iter++);
continue; continue;
...@@ -132,9 +132,9 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { ...@@ -132,9 +132,9 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() {
std::any_of(map_.begin(), map_.end(), std::any_of(map_.begin(), map_.end(),
[out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim);
int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); int64_t dev_num = dev_mat_.GetDimByReverseIdx(LongToSize(out_dim));
if (is_cost_model_) { if (is_cost_model_) {
int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); int64_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim));
Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim,
dev_num}; dev_num};
if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) {
...@@ -165,10 +165,10 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { ...@@ -165,10 +165,10 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() {
Status RedistributionOperatorInfer::InferConcatByAxis() { Status RedistributionOperatorInfer::InferConcatByAxis() {
for (auto iter = map_.begin(); iter != map_.end();) { for (auto iter = map_.begin(); iter != map_.end();) {
uint32_t index = iter->first; uint32_t index = iter->first;
int32_t in_dim = map_[index]; int64_t in_dim = map_[index];
int32_t out_dim = out_tensor_map_.GetDimByIdx(index); int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) { if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) {
Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(LongToSize(in_dim))};
if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) {
MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
return Status::FAILED; return Status::FAILED;
...@@ -215,7 +215,7 @@ Status RedistributionOperatorInfer::TransferSplitByAxis(Args args) { ...@@ -215,7 +215,7 @@ Status RedistributionOperatorInfer::TransferSplitByAxis(Args args) {
MS_LOG(ERROR) << "args size should not be less than 3!"; MS_LOG(ERROR) << "args size should not be less than 3!";
return Status::FAILED; return Status::FAILED;
} }
uint32_t index = IntToUint(args[1]); size_t index = LongToSize(args[1]);
if (constructor_.StridedSliceOP(args) != Status::SUCCESS) { if (constructor_.StridedSliceOP(args) != Status::SUCCESS) {
return Status::FAILED; return Status::FAILED;
} else { } else {
...@@ -239,11 +239,11 @@ Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) { ...@@ -239,11 +239,11 @@ Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) {
operator_vector_.push_back(constructor_.GetOperator()); operator_vector_.push_back(constructor_.GetOperator());
output_info_vector_.push_back(std::make_pair(false, 0)); output_info_vector_.push_back(std::make_pair(false, 0));
} }
uint32_t index = IntToUint(args[1]); size_t index = LongToSize(args[1]);
int32_t val = args[2]; int64_t val = args[2];
int32_t out_dim = out_tensor_map_.GetDimByIdx(index); int64_t out_dim = out_tensor_map_.GetDimByIdx(index);
if (cur_tensor_layout_.UpdateTensorMap(IntToUint(val), NONE) == Status::FAILED) { if (cur_tensor_layout_.UpdateTensorMap(LongToSize(val), NONE) == Status::FAILED) {
return Status::FAILED; return Status::FAILED;
} }
if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) { if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) {
...@@ -257,9 +257,9 @@ Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) { ...@@ -257,9 +257,9 @@ Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) {
MS_LOG(ERROR) << "args size should not be less than 3!"; MS_LOG(ERROR) << "args size should not be less than 3!";
return Status::FAILED; return Status::FAILED;
} }
int32_t tensor_dim = args[0]; int64_t tensor_dim = args[0];
int32_t dev_dim = args[1]; int64_t dev_dim = args[1];
int32_t split_count = args[2]; int64_t split_count = args[2];
if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) { if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) {
return Status::FAILED; return Status::FAILED;
} else { } else {
...@@ -280,7 +280,7 @@ Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) { ...@@ -280,7 +280,7 @@ Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) {
output_info_vector_.push_back(std::make_pair(false, 0)); output_info_vector_.push_back(std::make_pair(false, 0));
} }
} }
if (cur_tensor_layout_.UpdateTensorMap(IntToUint(tensor_dim), NONE) == Status::FAILED) { if (cur_tensor_layout_.UpdateTensorMap(LongToSize(tensor_dim), NONE) == Status::FAILED) {
return Status::FAILED; return Status::FAILED;
} }
return Status::SUCCESS; return Status::SUCCESS;
......
...@@ -28,10 +28,10 @@ ...@@ -28,10 +28,10 @@
#include "utils/convert_utils.h" #include "utils/convert_utils.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
using DeviceArrangement = std::vector<int32_t>; using DeviceArrangement = Shape;
using TensorMap = std::vector<int32_t>; using TensorMap = Shape;
using TensorShape = std::vector<int32_t>; using TensorShape = Shape;
using RedistributionOperatorMap = std::unordered_map<uint32_t, int32_t>; using RedistributionOperatorMap = std::unordered_map<uint32_t, int64_t>;
using OperatorR = std::pair<OperatorName, Args>; using OperatorR = std::pair<OperatorName, Args>;
using OperatorC = std::pair<OperatorR, Shape>; using OperatorC = std::pair<OperatorR, Shape>;
using OperatorList = std::vector<OperatorC>; using OperatorList = std::vector<OperatorC>;
......
...@@ -26,7 +26,7 @@ namespace parallel { ...@@ -26,7 +26,7 @@ namespace parallel {
* shape = [2, 8, 32] * shape = [2, 8, 32]
* shape_accum = [2, 2 * 8, 2 * 8 * 32] * shape_accum = [2, 2 * 8, 2 * 8 * 32]
*/ */
Status ShapeToAccumulateProduct(const std::vector<int32_t> &shape, std::vector<int64_t> *shape_accum) { Status ShapeToAccumulateProduct(const Shape &shape, Shape *shape_accum) {
MS_EXCEPTION_IF_NULL(shape_accum); MS_EXCEPTION_IF_NULL(shape_accum);
shape_accum->clear(); shape_accum->clear();
int64_t size = 1; int64_t size = 1;
...@@ -47,7 +47,7 @@ Status ShapeToAccumulateProduct(const std::vector<int32_t> &shape, std::vector<i ...@@ -47,7 +47,7 @@ Status ShapeToAccumulateProduct(const std::vector<int32_t> &shape, std::vector<i
* shape_accum = [2 * 8 * 32, 8 * 32, 32] * shape_accum = [2 * 8 * 32, 8 * 32, 32]
* *
*/ */
Status ShapeToAccumulateProductReverse(const std::vector<int32_t> &shape, std::vector<int64_t> *shape_accum) { Status ShapeToAccumulateProductReverse(const Shape &shape, Shape *shape_accum) {
MS_EXCEPTION_IF_NULL(shape_accum); MS_EXCEPTION_IF_NULL(shape_accum);
shape_accum->clear(); shape_accum->clear();
int64_t size = 1; int64_t size = 1;
...@@ -68,7 +68,7 @@ Status ShapeToAccumulateProductReverse(const std::vector<int32_t> &shape, std::v ...@@ -68,7 +68,7 @@ Status ShapeToAccumulateProductReverse(const std::vector<int32_t> &shape, std::v
* shape = [2, 8, 32] * shape = [2, 8, 32]
* *
*/ */
Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::vector<int32_t> *shape) { Status AccumulateProductToShape(const Shape &shape_accum, Shape *shape) {
MS_EXCEPTION_IF_NULL(shape); MS_EXCEPTION_IF_NULL(shape);
shape->clear(); shape->clear();
int64_t value = 1; int64_t value = 1;
...@@ -81,7 +81,7 @@ Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::ve ...@@ -81,7 +81,7 @@ Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::ve
MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order"; MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order";
return Status::FAILED; return Status::FAILED;
} }
shape->push_back(static_cast<int32_t>((*iter) / value)); shape->push_back(static_cast<int64_t>((*iter) / value));
value = (*iter); value = (*iter);
} }
return Status::SUCCESS; return Status::SUCCESS;
...@@ -92,7 +92,7 @@ Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::ve ...@@ -92,7 +92,7 @@ Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::ve
* shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
* shape = [2, 8, 32] * shape = [2, 8, 32]
*/ */
Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_reverse, std::vector<int32_t> *shape) { Status AccumulateProductReverseToShape(const Shape &shape_accum_reverse, Shape *shape) {
MS_EXCEPTION_IF_NULL(shape); MS_EXCEPTION_IF_NULL(shape);
shape->clear(); shape->clear();
int64_t value = 1; int64_t value = 1;
...@@ -105,7 +105,7 @@ Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_r ...@@ -105,7 +105,7 @@ Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_r
MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order"; MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order";
return Status::FAILED; return Status::FAILED;
} }
(void)shape->insert(shape->begin(), static_cast<int32_t>((*iter) / value)); (void)shape->insert(shape->begin(), static_cast<int64_t>((*iter) / value));
value = *iter; value = *iter;
} }
return Status::SUCCESS; return Status::SUCCESS;
...@@ -122,8 +122,7 @@ Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_r ...@@ -122,8 +122,7 @@ Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_r
* in2 = [8, 16] * in2 = [8, 16]
* *out = [2, 4, 8, 16] * *out = [2, 4, 8, 16]
*/ */
Status UnifyAccumulateProduct(const std::vector<int64_t> &in1_accum, const std::vector<int64_t> &in2_accum, Status UnifyAccumulateProduct(const Shape &in1_accum, const Shape &in2_accum, Shape *out_accum) {
std::vector<int64_t> *out_accum) {
MS_EXCEPTION_IF_NULL(out_accum); MS_EXCEPTION_IF_NULL(out_accum);
out_accum->clear(); out_accum->clear();
auto in1_iter = in1_accum.begin(); auto in1_iter = in1_accum.begin();
...@@ -159,19 +158,19 @@ Status UnifyAccumulateProduct(const std::vector<int64_t> &in1_accum, const std:: ...@@ -159,19 +158,19 @@ Status UnifyAccumulateProduct(const std::vector<int64_t> &in1_accum, const std::
* in2 = [2, 16] * in2 = [2, 16]
* out = [2, 4, 4] * out = [2, 4, 4]
*/ */
Status UnifyShape(const std::vector<int32_t> &in1, const std::vector<int32_t> &in2, std::vector<int32_t> *out) { Status UnifyShape(const Shape &in1, const Shape &in2, Shape *out) {
MS_EXCEPTION_IF_NULL(out); MS_EXCEPTION_IF_NULL(out);
std::vector<int64_t> in1_accum; Shape in1_accum;
Status status = ShapeToAccumulateProduct(in1, &in1_accum); Status status = ShapeToAccumulateProduct(in1, &in1_accum);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return status; return status;
} }
std::vector<int64_t> in2_accum; Shape in2_accum;
status = ShapeToAccumulateProduct(in2, &in2_accum); status = ShapeToAccumulateProduct(in2, &in2_accum);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return status; return status;
} }
std::vector<int64_t> out_accum; Shape out_accum;
status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum); status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return status; return status;
...@@ -194,9 +193,8 @@ Status UnifyShape(const std::vector<int32_t> &in1, const std::vector<int32_t> &i ...@@ -194,9 +193,8 @@ Status UnifyShape(const std::vector<int32_t> &in1, const std::vector<int32_t> &i
* expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8]
* out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8]
*/ */
Status ExpandAccumulateProduct(const std::vector<int64_t> &in_accum_reverse, Status ExpandAccumulateProduct(const Shape &in_accum_reverse, const Shape &expand_accum_reverse,
const std::vector<int64_t> &expand_accum_reverse, Shape *out_accum_reverse) {
std::vector<int64_t> *out_accum_reverse) {
MS_EXCEPTION_IF_NULL(out_accum_reverse); MS_EXCEPTION_IF_NULL(out_accum_reverse);
out_accum_reverse->clear(); out_accum_reverse->clear();
auto in_riter = in_accum_reverse.rbegin(); auto in_riter = in_accum_reverse.rbegin();
...@@ -236,19 +234,19 @@ Status ExpandAccumulateProduct(const std::vector<int64_t> &in_accum_reverse, ...@@ -236,19 +234,19 @@ Status ExpandAccumulateProduct(const std::vector<int64_t> &in_accum_reverse,
* expand = [2, 4, 8] * expand = [2, 4, 8]
* out = [2, 4, 2, 4, 8] * out = [2, 4, 2, 4, 8]
*/ */
Status ExpandShape(const std::vector<int32_t> &in, const std::vector<int32_t> &expand, std::vector<int32_t> *out) { Status ExpandShape(const Shape &in, const Shape &expand, Shape *out) {
MS_EXCEPTION_IF_NULL(out); MS_EXCEPTION_IF_NULL(out);
std::vector<int64_t> in_accum_reverse; Shape in_accum_reverse;
Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse); Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return status; return status;
} }
std::vector<int64_t> expand_accum_reverse; Shape expand_accum_reverse;
status = ShapeToAccumulateProductReverse(expand, &expand_accum_reverse); status = ShapeToAccumulateProductReverse(expand, &expand_accum_reverse);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return status; return status;
} }
std::vector<int64_t> out_accum_reverse; Shape out_accum_reverse;
status = ExpandAccumulateProduct(in_accum_reverse, expand_accum_reverse, &out_accum_reverse); status = ExpandAccumulateProduct(in_accum_reverse, expand_accum_reverse, &out_accum_reverse);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return status; return status;
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <vector> #include <vector>
#include "frontend/parallel/status.h" #include "frontend/parallel/status.h"
#include "frontend/parallel/device_matrix.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
...@@ -39,7 +40,7 @@ namespace parallel { ...@@ -39,7 +40,7 @@ namespace parallel {
* shape_accum = [2, 2 * 8, 2 * 8 * 32] * shape_accum = [2, 2 * 8, 2 * 8 * 32]
* *
*/ */
Status ShapeToAccumulateProduct(const std::vector<int32_t> &shape, std::vector<int64_t> *shape_accum); Status ShapeToAccumulateProduct(const Shape &shape, Shape *shape_accum);
/* /*
* compute the accumulating product of all the values in shape from right to left, * compute the accumulating product of all the values in shape from right to left,
...@@ -53,7 +54,7 @@ Status ShapeToAccumulateProduct(const std::vector<int32_t> &shape, std::vector<i ...@@ -53,7 +54,7 @@ Status ShapeToAccumulateProduct(const std::vector<int32_t> &shape, std::vector<i
* shape_accum = [2 * 8 * 32, 8 * 32, 32] * shape_accum = [2 * 8 * 32, 8 * 32, 32]
* *
*/ */
Status ShapeToAccumulateProductReverse(const std::vector<int32_t> &shape, std::vector<int64_t> *shape_accum); Status ShapeToAccumulateProductReverse(const Shape &shape, Shape *shape_accum);
/* /*
* compute the original shape from the accumulating product shape_accum, * compute the original shape from the accumulating product shape_accum,
...@@ -68,7 +69,7 @@ Status ShapeToAccumulateProductReverse(const std::vector<int32_t> &shape, std::v ...@@ -68,7 +69,7 @@ Status ShapeToAccumulateProductReverse(const std::vector<int32_t> &shape, std::v
* shape = [2, 8, 32] * shape = [2, 8, 32]
* *
*/ */
Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::vector<int32_t> *shape); Status AccumulateProductToShape(const Shape &shape_accum, Shape *shape);
/* /*
* compute the original shape from the accumulating product shape_accum, * compute the original shape from the accumulating product shape_accum,
...@@ -83,7 +84,7 @@ Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::ve ...@@ -83,7 +84,7 @@ Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::ve
* shape = [2, 8, 32] * shape = [2, 8, 32]
* *
*/ */
Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_reverse, std::vector<int32_t> *shape); Status AccumulateProductReverseToShape(const Shape &shape_accum_reverse, Shape *shape);
/* /*
* given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum, * given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum,
...@@ -101,8 +102,7 @@ Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_r ...@@ -101,8 +102,7 @@ Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_r
* in2_accum = [8, 16] * in2_accum = [8, 16]
* out_accum = [2, 4, 8, 16] * out_accum = [2, 4, 8, 16]
*/ */
Status UnifyAccumulateProduct(const std::vector<int64_t> &in1_accum, const std::vector<int64_t> &in2_accum, Status UnifyAccumulateProduct(const Shape &in1_accum, const Shape &in2_accum, Shape *out_accum);
std::vector<int64_t> *out_accum);
/* /*
* given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m] * given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m]
...@@ -117,7 +117,7 @@ Status UnifyAccumulateProduct(const std::vector<int64_t> &in1_accum, const std:: ...@@ -117,7 +117,7 @@ Status UnifyAccumulateProduct(const std::vector<int64_t> &in1_accum, const std::
* in2 = [2, 16] * in2 = [2, 16]
* out = [2, 4, 4] * out = [2, 4, 4]
*/ */
Status UnifyShape(const std::vector<int32_t> &in1, const std::vector<int32_t> &in2, std::vector<int32_t> *out); Status UnifyShape(const Shape &in1, const Shape &in2, Shape *out);
/* /*
* given two accumulate product in reverse order of in and expand, * given two accumulate product in reverse order of in and expand,
...@@ -141,9 +141,8 @@ Status UnifyShape(const std::vector<int32_t> &in1, const std::vector<int32_t> &i ...@@ -141,9 +141,8 @@ Status UnifyShape(const std::vector<int32_t> &in1, const std::vector<int32_t> &i
* expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8]
* out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8]
*/ */
Status ExpandAccumulateProduct(const std::vector<int64_t> &in_accum_reverse, Status ExpandAccumulateProduct(const Shape &in_accum_reverse, const Shape &expand_accum_reverse,
const std::vector<int64_t> &expand_accum_reverse, Shape *out_accum_reverse);
std::vector<int64_t> *out_accum_reverse);
/* /*
* given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0], * given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0],
...@@ -165,7 +164,7 @@ Status ExpandAccumulateProduct(const std::vector<int64_t> &in_accum_reverse, ...@@ -165,7 +164,7 @@ Status ExpandAccumulateProduct(const std::vector<int64_t> &in_accum_reverse,
* expand = [2, 4, 8] * expand = [2, 4, 8]
* out = [2, 4, 2, 4, 8] * out = [2, 4, 2, 4, 8]
*/ */
Status ExpandShape(const std::vector<int32_t> &in, const std::vector<int32_t> &expand, std::vector<int32_t> *out); Status ExpandShape(const Shape &in, const Shape &expand, Shape *out);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
......
...@@ -64,8 +64,8 @@ Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tens ...@@ -64,8 +64,8 @@ Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tens
} }
} }
Status TensorLayout::InitFromVector(const std::vector<int32_t> &device_arrangement, Status TensorLayout::InitFromVector(const Shape &device_arrangement, const Shape &tensor_map,
const std::vector<int32_t> &tensor_map, const std::vector<int32_t> &tensor_shape) { const Shape &tensor_shape) {
if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) { if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) {
return FAILED; return FAILED;
} }
...@@ -82,7 +82,7 @@ Status TensorLayout::InitFromVector(const std::vector<int32_t> &device_arrangeme ...@@ -82,7 +82,7 @@ Status TensorLayout::InitFromVector(const std::vector<int32_t> &device_arrangeme
} }
bool TensorLayout::IsValidTensorLayout() const { bool TensorLayout::IsValidTensorLayout() const {
if (tensor_map_origin_.GetMaxItem() >= static_cast<int32_t>(device_arrangement_origin_.GetDimSize())) { if (tensor_map_origin_.GetMaxItem() >= static_cast<int64_t>(device_arrangement_origin_.GetDimSize())) {
MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size!"; MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size!";
return false; return false;
} }
...@@ -114,18 +114,18 @@ bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const { ...@@ -114,18 +114,18 @@ bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const {
} }
void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() {
std::vector<int32_t> device_arrangement_shape; Shape device_arrangement_shape;
std::vector<int32_t> tensor_map_shape = tensor_map_origin_.array(); Shape tensor_map_shape = tensor_map_origin_.array();
uint32_t dev_num = SizeToUint(device_arrangement_origin_.GetDimSize()); size_t dev_num = device_arrangement_origin_.GetDimSize();
int32_t dev_num_left = SizeToInt(device_arrangement_origin_.GetDimSize()); size_t dev_num_left = device_arrangement_origin_.GetDimSize();
for (uint32_t i = 0; i < dev_num; i++) { for (size_t i = 0; i < dev_num; i++) {
if (device_arrangement_origin_.GetDimByIdx(i) == 1) { if (device_arrangement_origin_.GetDimByIdx(i) == 1) {
int32_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast<int32_t>(dev_num - 1 - i)); int32_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast<int64_t>(dev_num - 1 - i));
if (idx != -1) { if (idx != -1) {
tensor_map_shape[static_cast<uint32_t>(idx)] = -1; tensor_map_shape[static_cast<uint32_t>(idx)] = -1;
} }
for (auto &value : tensor_map_shape) { for (auto &value : tensor_map_shape) {
if (value >= dev_num_left - 1 - static_cast<int32_t>(i)) { if (value >= SizeToLong(dev_num_left) - 1 - static_cast<int64_t>(i)) {
value--; value--;
} }
} }
...@@ -139,7 +139,7 @@ void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { ...@@ -139,7 +139,7 @@ void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() {
} }
// if idx is not in tensor_map, return -1 // if idx is not in tensor_map, return -1
int32_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const { int32_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const {
return tensor_map_.GetIndexByValue(idx); return tensor_map_.GetIndexByValue(idx);
} }
...@@ -288,7 +288,7 @@ std::shared_ptr<TensorLayout> TensorLayout::ExpandDeviceArrangement(const Arrang ...@@ -288,7 +288,7 @@ std::shared_ptr<TensorLayout> TensorLayout::ExpandDeviceArrangement(const Arrang
} }
bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const { bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const {
std::vector<int32_t> in_expand_shape_shape; Shape in_expand_shape_shape;
Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return false; return false;
...@@ -297,7 +297,7 @@ bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) con ...@@ -297,7 +297,7 @@ bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) con
} }
std::shared_ptr<Arrangement> TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const { std::shared_ptr<Arrangement> TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const {
std::vector<int32_t> in_expand_shape_shape; Shape in_expand_shape_shape;
Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape);
if (status != Status::SUCCESS) { if (status != Status::SUCCESS) {
return nullptr; return nullptr;
...@@ -311,14 +311,14 @@ std::shared_ptr<Arrangement> TensorLayout::ComputeExpandedTensorShape(const Arra ...@@ -311,14 +311,14 @@ std::shared_ptr<Arrangement> TensorLayout::ComputeExpandedTensorShape(const Arra
} }
Arrangement TensorLayout::slice_shape() const { Arrangement TensorLayout::slice_shape() const {
std::vector<int32_t> shape; Shape shape;
for (uint32_t index = 0; index < tensor_map_.GetDimSize(); index++) { for (size_t index = 0; index < tensor_map_.GetDimSize(); index++) {
int32_t dim = tensor_map_.GetDimByIdx(index); int64_t dim = tensor_map_.GetDimByIdx(index);
int32_t num = tensor_shape_.GetDimByIdx(index); int64_t num = tensor_shape_.GetDimByIdx(index);
if (dim == -1) { if (dim == -1) {
shape.push_back(num); shape.push_back(num);
} else { } else {
int32_t divisor = device_arrangement_.GetDimByReverseIdx(IntToUint(dim)); int64_t divisor = device_arrangement_.GetDimByReverseIdx(IntToUint(dim));
shape.push_back(num / divisor); shape.push_back(num / divisor);
} }
} }
...@@ -331,7 +331,7 @@ Arrangement TensorLayout::slice_shape() const { ...@@ -331,7 +331,7 @@ Arrangement TensorLayout::slice_shape() const {
} }
} }
Status TensorLayout::UpdateTensorMap(uint32_t index, int32_t value) { Status TensorLayout::UpdateTensorMap(size_t index, int64_t value) {
if (index >= tensor_map_.GetDimSize()) { if (index >= tensor_map_.GetDimSize()) {
MS_LOG(ERROR) << "Index is out of the size of the tensor map!"; MS_LOG(ERROR) << "Index is out of the size of the tensor map!";
return Status::FAILED; return Status::FAILED;
......
...@@ -38,8 +38,7 @@ class TensorLayout { ...@@ -38,8 +38,7 @@ class TensorLayout {
std::string StandardToString() const; std::string StandardToString() const;
std::string OriginToString() const; std::string OriginToString() const;
Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape); Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape);
Status InitFromVector(const std::vector<int32_t> &device_arrangement, const std::vector<int32_t> &tensor_map, Status InitFromVector(const Shape &device_arrangement, const Shape &tensor_map, const Shape &tensor_shape);
const std::vector<int32_t> &tensor_shape);
bool skip_redistribution() const { return skip_redistribution_; } bool skip_redistribution() const { return skip_redistribution_; }
...@@ -79,7 +78,7 @@ class TensorLayout { ...@@ -79,7 +78,7 @@ class TensorLayout {
Arrangement slice_shape() const; Arrangement slice_shape() const;
Status UpdateTensorMap(uint32_t index, int32_t value); Status UpdateTensorMap(size_t index, int64_t value);
TensorLayout SqueezeShape() const; TensorLayout SqueezeShape() const;
...@@ -95,7 +94,7 @@ class TensorLayout { ...@@ -95,7 +94,7 @@ class TensorLayout {
int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const; int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const;
int32_t GetSliceNumByTensorDimensionIndex(uint32_t idx) const; int32_t GetSliceNumByTensorDimensionIndex(uint32_t idx) const;
bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const; bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const;
int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const; int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const;
Arrangement device_arrangement_origin_; Arrangement device_arrangement_origin_;
Map tensor_map_origin_; Map tensor_map_origin_;
......
...@@ -48,6 +48,10 @@ py::object ValuePtrToPyData(const ValuePtr &value) { ...@@ -48,6 +48,10 @@ py::object ValuePtrToPyData(const ValuePtr &value) {
MS_LOG(DEBUG) << "int"; MS_LOG(DEBUG) << "int";
py::int_ v = value->cast<Int32ImmPtr>()->value(); py::int_ v = value->cast<Int32ImmPtr>()->value();
ret = v; ret = v;
} else if (value->isa<Int64Imm>()) {
MS_LOG(DEBUG) << "int64";
py::int_ v = value->cast<Int64ImmPtr>()->value();
ret = v;
} else if (value->isa<UInt64Imm>()) { } else if (value->isa<UInt64Imm>()) {
MS_LOG(DEBUG) << "uint64"; MS_LOG(DEBUG) << "uint64";
py::int_ v = value->cast<UInt64ImmPtr>()->value(); py::int_ v = value->cast<UInt64ImmPtr>()->value();
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <unordered_map> #include <unordered_map>
#include <typeindex> #include <typeindex>
#include <memory> #include <memory>
#include <algorithm>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "base/base.h" #include "base/base.h"
...@@ -63,7 +64,16 @@ class Shape : public BaseShape { ...@@ -63,7 +64,16 @@ class Shape : public BaseShape {
static const int SHP_ANY = -1; static const int SHP_ANY = -1;
Shape() : shape_() {} Shape() : shape_() {}
Shape(const std::initializer_list<int> &list) : shape_(list) {} Shape(const std::initializer_list<int> &list) : shape_(list) {}
Shape(const std::initializer_list<int64_t> &list) {
std::vector<int64_t> list_in(list);
(void)std::transform(list_in.begin(), list_in.end(), std::back_inserter(shape_),
[](const int64_t &value) { return static_cast<int>(value); });
}
explicit Shape(const std::vector<int> &list) : shape_(list) {} explicit Shape(const std::vector<int> &list) : shape_(list) {}
explicit Shape(const std::vector<int64_t> &list) {
(void)std::transform(list.begin(), list.end(), std::back_inserter(shape_),
[](const int64_t &value) { return static_cast<int>(value); });
}
~Shape() override = default; ~Shape() override = default;
MS_DECLARE_PARENT(Shape, BaseShape) MS_DECLARE_PARENT(Shape, BaseShape)
std::string ToString() const override; std::string ToString() const override;
......
...@@ -154,13 +154,13 @@ class TestDPAlgo : public UT::Common { ...@@ -154,13 +154,13 @@ class TestDPAlgo : public UT::Common {
void TestDPAlgo::SetUp() { void TestDPAlgo::SetUp() {
cost_graph = std::make_shared<CostGraph>(); cost_graph = std::make_shared<CostGraph>();
cost_graph->SetDeviceMemoryAndCostParameter(); cost_graph->SetDeviceMemoryAndCostParameter();
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 10; i++) { for (int32_t i = 0; i < 10; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(8); stage_map.push_back(8);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -1327,8 +1327,8 @@ TEST_F(TestDPAlgo, test_GetStrategy_for_DoubleStarGraph) { ...@@ -1327,8 +1327,8 @@ TEST_F(TestDPAlgo, test_GetStrategy_for_DoubleStarGraph) {
for (auto &op : cost_graph->GetOperators()) { for (auto &op : cost_graph->GetOperators()) {
StrategyPtr s_strategy = op->selected_strategy(); StrategyPtr s_strategy = op->selected_strategy();
std::vector<int32_t> strategy_0 = s_strategy->GetInputDim()[0]; Dimensions strategy_0 = s_strategy->GetInputDim()[0];
std::vector<int32_t> strategy_1 = s_strategy->GetInputDim()[1]; Dimensions strategy_1 = s_strategy->GetInputDim()[1];
std::string string_strategy_0 = "["; std::string string_strategy_0 = "[";
for (size_t i = 0; i < strategy_0.size(); ++i) { for (size_t i = 0; i < strategy_0.size(); ++i) {
......
...@@ -43,13 +43,13 @@ class TestEdgeCostModel : public UT::Common { ...@@ -43,13 +43,13 @@ class TestEdgeCostModel : public UT::Common {
}; };
void TestEdgeCostModel::SetUp() { void TestEdgeCostModel::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 10; i++) { for (int32_t i = 0; i < 10; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(8); stage_map.push_back(8);
stage_map.push_back(2); stage_map.push_back(2);
......
...@@ -53,13 +53,13 @@ class TestCostGraph : public UT::Common { ...@@ -53,13 +53,13 @@ class TestCostGraph : public UT::Common {
void TestCostGraph::SetUp() { void TestCostGraph::SetUp() {
cost_graph.SetDeviceMemoryAndCostParameter(); cost_graph.SetDeviceMemoryAndCostParameter();
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 10; i++) { for (int32_t i = 0; i < 10; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(8); stage_map.push_back(8);
stage_map.push_back(2); stage_map.push_back(2);
......
...@@ -33,13 +33,13 @@ class TestMatMulCost : public UT::Common { ...@@ -33,13 +33,13 @@ class TestMatMulCost : public UT::Common {
void TestMatMulCost::SetUp() { void TestMatMulCost::SetUp() {
mmcost_ = MatMulCost(); mmcost_ = MatMulCost();
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 1050; i++) { for (int32_t i = 0; i < 1050; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(1024); stage_map.push_back(1024);
stage_map.push_back(26); stage_map.push_back(26);
...@@ -90,13 +90,13 @@ class TestActivationCost : public UT::Common { ...@@ -90,13 +90,13 @@ class TestActivationCost : public UT::Common {
void TestActivationCost::SetUp() { void TestActivationCost::SetUp() {
ac_cost_ = ActivationCost(); ac_cost_ = ActivationCost();
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 1050; i++) { for (int32_t i = 0; i < 1050; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(1024); stage_map.push_back(1024);
stage_map.push_back(26); stage_map.push_back(26);
...@@ -142,13 +142,13 @@ class TestPReLUCost : public UT::Common { ...@@ -142,13 +142,13 @@ class TestPReLUCost : public UT::Common {
void TestPReLUCost::SetUp() { void TestPReLUCost::SetUp() {
prelu_cost_ = PReLUCost(); prelu_cost_ = PReLUCost();
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 1050; i++) { for (int32_t i = 0; i < 1050; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(1024); stage_map.push_back(1024);
stage_map.push_back(26); stage_map.push_back(26);
......
...@@ -69,8 +69,8 @@ void TestDeviceManager::TearDown() { ...@@ -69,8 +69,8 @@ void TestDeviceManager::TearDown() {
} }
TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) { TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) {
std::vector<int32_t> dev_list; RankList dev_list;
std::vector<int32_t> stage_map; RankList stage_map;
int32_t local_dev = 0; int32_t local_dev = 0;
dev_list.push_back(5); dev_list.push_back(5);
...@@ -85,12 +85,12 @@ TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) { ...@@ -85,12 +85,12 @@ TEST_F(TestDeviceManager, test_dm_init_AND_get_device_list) {
ASSERT_EQ(dm_.DeviceNum(), 4); ASSERT_EQ(dm_.DeviceNum(), 4);
ASSERT_EQ(dm_.GetStageNum(), (int32_t)(2)); ASSERT_EQ(dm_.GetStageNum(), (int32_t)(2));
std::vector<int32_t> dev_list_0 = dm_.GetDeviceListByStageId(0); RankList dev_list_0 = dm_.GetDeviceListByStageId(0);
std::vector<int32_t> dev_list_1 = dm_.GetDeviceListByStageId(1); RankList dev_list_1 = dm_.GetDeviceListByStageId(1);
ASSERT_EQ(dev_list_0.size(), 2); ASSERT_EQ(dev_list_0.size(), 2);
ASSERT_EQ(dev_list_1.size(), 2); ASSERT_EQ(dev_list_1.size(), 2);
std::vector<int32_t>::iterator it = dev_list_0.begin(); RankList::iterator it = dev_list_0.begin();
ASSERT_EQ((*it), int32_t(5)); ASSERT_EQ((*it), int32_t(5));
it++; it++;
ASSERT_EQ((*it), int32_t(3)); ASSERT_EQ((*it), int32_t(3));
...@@ -112,7 +112,7 @@ TEST_F(TestDeviceManager, test_CreateNewDeviceByRank) { ...@@ -112,7 +112,7 @@ TEST_F(TestDeviceManager, test_CreateNewDeviceByRank) {
TEST_F(TestDeviceManager, test_CreateDeviceListByRankList) { TEST_F(TestDeviceManager, test_CreateDeviceListByRankList) {
std::vector<Device> dev_list; std::vector<Device> dev_list;
std::vector<int32_t> rlist; RankList rlist;
rlist.push_back(int32_t(2)); rlist.push_back(int32_t(2));
rlist.push_back(int32_t(1)); rlist.push_back(int32_t(1));
dev_list = dm_.CreateDeviceListByRankList(rlist); dev_list = dm_.CreateDeviceListByRankList(rlist);
......
...@@ -38,13 +38,13 @@ class TestActivationInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestActivationInfo : public UT::Common {
}; };
void TestActivationInfo::SetUp() { void TestActivationInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 1050; i++) { for (int32_t i = 0; i < 1050; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(1024); stage_map.push_back(1024);
stage_map.push_back(26); stage_map.push_back(26);
...@@ -64,18 +64,18 @@ void TestActivationInfo::SetUp() { ...@@ -64,18 +64,18 @@ void TestActivationInfo::SetUp() {
} }
TEST_F(TestActivationInfo, InferDevMatrixShape1) { TEST_F(TestActivationInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}}; Strategys inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
activation->Init(strategy); activation->Init(strategy);
std::vector<int32_t> dev_matrix_shape = activation->dev_matrix_shape(); Shape dev_matrix_shape = activation->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8, 16}; Shape expect = {2, 4, 8, 16};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestActivationInfo, InferSliceShape1) { TEST_F(TestActivationInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 8, 16}}; Strategys str = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
activation->Init(strategy); activation->Init(strategy);
...@@ -96,7 +96,7 @@ TEST_F(TestActivationInfo, InferSliceShape1) { ...@@ -96,7 +96,7 @@ TEST_F(TestActivationInfo, InferSliceShape1) {
} }
TEST_F(TestActivationInfo, GetTensorLayout1) { TEST_F(TestActivationInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 8, 16}}; Strategys str = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
activation->Init(strategy); activation->Init(strategy);
...@@ -117,7 +117,7 @@ TEST_F(TestActivationInfo, GetTensorLayout1) { ...@@ -117,7 +117,7 @@ TEST_F(TestActivationInfo, GetTensorLayout1) {
} }
TEST_F(TestActivationInfo, GetForwardOp1) { TEST_F(TestActivationInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}}; Strategys inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
activation->Init(strategy); activation->Init(strategy);
...@@ -128,7 +128,7 @@ TEST_F(TestActivationInfo, GetForwardOp1) { ...@@ -128,7 +128,7 @@ TEST_F(TestActivationInfo, GetForwardOp1) {
} }
TEST_F(TestActivationInfo, GetMirrorOPs1) { TEST_F(TestActivationInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{1, 4, 8, 16}}; Strategys inputs = {{1, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
activation->Init(strategy); activation->Init(strategy);
...@@ -148,7 +148,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs1) { ...@@ -148,7 +148,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs1) {
} }
TEST_F(TestActivationInfo, GetMirrorOPs2) { TEST_F(TestActivationInfo, GetMirrorOPs2) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}}; Strategys inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
activation->Init(strategy); activation->Init(strategy);
...@@ -161,7 +161,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs2) { ...@@ -161,7 +161,7 @@ TEST_F(TestActivationInfo, GetMirrorOPs2) {
TEST_F(TestActivationInfo, CheckStrategy1) { TEST_F(TestActivationInfo, CheckStrategy1) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
std::vector<Dimensions> inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = activation->Init(strategy); Status ret = activation->Init(strategy);
...@@ -170,7 +170,7 @@ TEST_F(TestActivationInfo, CheckStrategy1) { ...@@ -170,7 +170,7 @@ TEST_F(TestActivationInfo, CheckStrategy1) {
TEST_F(TestActivationInfo, CheckStrategy2) { TEST_F(TestActivationInfo, CheckStrategy2) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
std::vector<Dimensions> inputs = {{2, 4, 8}}; Strategys inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = activation->Init(strategy); Status ret = activation->Init(strategy);
......
...@@ -40,13 +40,13 @@ class TestActivation : public UT::Common { ...@@ -40,13 +40,13 @@ class TestActivation : public UT::Common {
}; };
void TestActivation::SetUp() { void TestActivation::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 1050; i++) { for (int32_t i = 0; i < 1050; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(1024); stage_map.push_back(1024);
stage_map.push_back(26); stage_map.push_back(26);
...@@ -101,7 +101,7 @@ TEST_F(TestActivation, test_softmax_strategies) { ...@@ -101,7 +101,7 @@ TEST_F(TestActivation, test_softmax_strategies) {
ASSERT_NE(sp, nullptr); ASSERT_NE(sp, nullptr);
Cost cost = *(swc->cost_list[0]); Cost cost = *(swc->cost_list[0]);
std::vector<Dimensions> stra = sp->GetInputDim(); Strategys stra = sp->GetInputDim();
ASSERT_GT(stra.size(), 0); ASSERT_GT(stra.size(), 0);
Dimensions input0_stra = stra[0]; Dimensions input0_stra = stra[0];
ASSERT_GT(input0_stra.size(), 2); ASSERT_GT(input0_stra.size(), 2);
......
...@@ -38,13 +38,13 @@ class TestGeluInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestGeluInfo : public UT::Common {
}; };
void TestGeluInfo::SetUp() { void TestGeluInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 130; i++) { for (int32_t i = 0; i < 130; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(128); stage_map.push_back(128);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -63,18 +63,18 @@ void TestGeluInfo::SetUp() { ...@@ -63,18 +63,18 @@ void TestGeluInfo::SetUp() {
} }
TEST_F(TestGeluInfo, InferDevMatrixShape1) { TEST_F(TestGeluInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
gelu->Init(strategy); gelu->Init(strategy);
std::vector<int32_t> dev_matrix_shape = gelu->dev_matrix_shape(); Shape dev_matrix_shape = gelu->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 1, 16}; Shape expect = {2, 4, 1, 16};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestGeluInfo, InferSliceShape1) { TEST_F(TestGeluInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 1, 16}}; Strategys str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
gelu->Init(strategy); gelu->Init(strategy);
...@@ -95,7 +95,7 @@ TEST_F(TestGeluInfo, InferSliceShape1) { ...@@ -95,7 +95,7 @@ TEST_F(TestGeluInfo, InferSliceShape1) {
} }
TEST_F(TestGeluInfo, GetTensorLayout1) { TEST_F(TestGeluInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 1, 16}}; Strategys str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
gelu->Init(strategy); gelu->Init(strategy);
...@@ -116,7 +116,7 @@ TEST_F(TestGeluInfo, GetTensorLayout1) { ...@@ -116,7 +116,7 @@ TEST_F(TestGeluInfo, GetTensorLayout1) {
} }
TEST_F(TestGeluInfo, GetForwardOp1) { TEST_F(TestGeluInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
gelu->Init(strategy); gelu->Init(strategy);
...@@ -127,7 +127,7 @@ TEST_F(TestGeluInfo, GetForwardOp1) { ...@@ -127,7 +127,7 @@ TEST_F(TestGeluInfo, GetForwardOp1) {
} }
TEST_F(TestGeluInfo, GetMirrorOPs1) { TEST_F(TestGeluInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
gelu->Init(strategy); gelu->Init(strategy);
...@@ -140,7 +140,7 @@ TEST_F(TestGeluInfo, GetMirrorOPs1) { ...@@ -140,7 +140,7 @@ TEST_F(TestGeluInfo, GetMirrorOPs1) {
TEST_F(TestGeluInfo, CheckStrategy1) { TEST_F(TestGeluInfo, CheckStrategy1) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = gelu->Init(strategy); Status ret = gelu->Init(strategy);
...@@ -149,7 +149,7 @@ TEST_F(TestGeluInfo, CheckStrategy1) { ...@@ -149,7 +149,7 @@ TEST_F(TestGeluInfo, CheckStrategy1) {
TEST_F(TestGeluInfo, CheckStrategy2) { TEST_F(TestGeluInfo, CheckStrategy2) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 4, 8}}; Strategys inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = gelu->Init(strategy); Status ret = gelu->Init(strategy);
...@@ -158,7 +158,7 @@ TEST_F(TestGeluInfo, CheckStrategy2) { ...@@ -158,7 +158,7 @@ TEST_F(TestGeluInfo, CheckStrategy2) {
TEST_F(TestGeluInfo, CheckStrategy3) { TEST_F(TestGeluInfo, CheckStrategy3) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = gelu->Init(strategy); Status ret = gelu->Init(strategy);
......
...@@ -34,13 +34,13 @@ class TestGenerateStrategy : public UT::Common { ...@@ -34,13 +34,13 @@ class TestGenerateStrategy : public UT::Common {
}; };
void TestGenerateStrategy::SetUp() { void TestGenerateStrategy::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 10; i++) { for (int32_t i = 0; i < 10; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(8); stage_map.push_back(8);
stage_map.push_back(2); stage_map.push_back(2);
......
...@@ -38,13 +38,13 @@ class TestGetNextInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestGetNextInfo : public UT::Common {
}; };
void TestGetNextInfo::SetUp() { void TestGetNextInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 8; i++) { for (int32_t i = 0; i < 8; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(8); stage_map.push_back(8);
int32_t local_dev = 0; int32_t local_dev = 0;
// create a new g_device_manager // create a new g_device_manager
...@@ -65,16 +65,16 @@ void TestGetNextInfo::SetUp() { ...@@ -65,16 +65,16 @@ void TestGetNextInfo::SetUp() {
} }
TEST_F(TestGetNextInfo, InferDevMatrixShape1) { TEST_F(TestGetNextInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{}, {}}; Strategys inputs = {{}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
get_next->Init(strategy); get_next->Init(strategy);
std::vector<int32_t> dev_matrix_shape = get_next->dev_matrix_shape(); Shape dev_matrix_shape = get_next->dev_matrix_shape();
std::vector<int32_t> expect = {8, 1}; Shape expect = {8, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestGetNextInfo, InferSliceShape1) { TEST_F(TestGetNextInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{}, {}}; Strategys str = {{}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
get_next->Init(strategy); get_next->Init(strategy);
...@@ -90,7 +90,7 @@ TEST_F(TestGetNextInfo, InferSliceShape1) { ...@@ -90,7 +90,7 @@ TEST_F(TestGetNextInfo, InferSliceShape1) {
} }
TEST_F(TestGetNextInfo, GetTensorLayout1) { TEST_F(TestGetNextInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{}, {}}; Strategys str = {{}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
get_next->Init(strategy); get_next->Init(strategy);
std::vector<TensorInfo> outputs = get_next->outputs_tensor_info(); std::vector<TensorInfo> outputs = get_next->outputs_tensor_info();
...@@ -106,14 +106,14 @@ TEST_F(TestGetNextInfo, GetTensorLayout1) { ...@@ -106,14 +106,14 @@ TEST_F(TestGetNextInfo, GetTensorLayout1) {
} }
TEST_F(TestGetNextInfo, CheckStrategy1) { TEST_F(TestGetNextInfo, CheckStrategy1) {
std::vector<Dimensions> inputs = {}; Strategys inputs = {};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = get_next->Init(strategy); Status ret = get_next->Init(strategy);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
} }
TEST_F(TestGetNextInfo, CheckStrategy2) { TEST_F(TestGetNextInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{8, 1}, {8}}; Strategys inputs = {{8, 1}, {8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = get_next->Init(strategy); Status ret = get_next->Init(strategy);
ASSERT_EQ(ret, FAILED); ASSERT_EQ(ret, FAILED);
......
...@@ -38,13 +38,13 @@ class TestL2NormalizeInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestL2NormalizeInfo : public UT::Common {
}; };
void TestL2NormalizeInfo::SetUp() { void TestL2NormalizeInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 34; i++) { for (int32_t i = 0; i < 34; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(32); stage_map.push_back(32);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -64,18 +64,18 @@ void TestL2NormalizeInfo::SetUp() { ...@@ -64,18 +64,18 @@ void TestL2NormalizeInfo::SetUp() {
} }
TEST_F(TestL2NormalizeInfo, InferDevMatrixShape1) { TEST_F(TestL2NormalizeInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{4, 1, 8}}; Strategys inputs = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
norm->Init(strategy); norm->Init(strategy);
std::vector<int32_t> dev_matrix_shape = norm->dev_matrix_shape(); Shape dev_matrix_shape = norm->dev_matrix_shape();
std::vector<int32_t> expect = {4, 1, 8}; Shape expect = {4, 1, 8};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestL2NormalizeInfo, InferSliceShape1) { TEST_F(TestL2NormalizeInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{4, 1, 8}}; Strategys str = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
norm->Init(strategy); norm->Init(strategy);
...@@ -96,7 +96,7 @@ TEST_F(TestL2NormalizeInfo, InferSliceShape1) { ...@@ -96,7 +96,7 @@ TEST_F(TestL2NormalizeInfo, InferSliceShape1) {
} }
TEST_F(TestL2NormalizeInfo, GetTensorLayout1) { TEST_F(TestL2NormalizeInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{4, 1, 8}}; Strategys str = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
norm->Init(strategy); norm->Init(strategy);
...@@ -117,7 +117,7 @@ TEST_F(TestL2NormalizeInfo, GetTensorLayout1) { ...@@ -117,7 +117,7 @@ TEST_F(TestL2NormalizeInfo, GetTensorLayout1) {
} }
TEST_F(TestL2NormalizeInfo, GetForwardOp1) { TEST_F(TestL2NormalizeInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{4, 1, 8}}; Strategys inputs = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
norm->Init(strategy); norm->Init(strategy);
...@@ -128,7 +128,7 @@ TEST_F(TestL2NormalizeInfo, GetForwardOp1) { ...@@ -128,7 +128,7 @@ TEST_F(TestL2NormalizeInfo, GetForwardOp1) {
} }
TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) { TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{4, 1, 8}}; Strategys inputs = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
norm->Init(strategy); norm->Init(strategy);
...@@ -140,7 +140,7 @@ TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) { ...@@ -140,7 +140,7 @@ TEST_F(TestL2NormalizeInfo, GetMirrorOPs1) {
} }
TEST_F(TestL2NormalizeInfo, CheckStrategy1) { TEST_F(TestL2NormalizeInfo, CheckStrategy1) {
std::vector<Dimensions> inputs = {{4, 1, 8}, {4, 1, 8}}; Strategys inputs = {{4, 1, 8}, {4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = norm->Init(strategy); Status ret = norm->Init(strategy);
...@@ -148,7 +148,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy1) { ...@@ -148,7 +148,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy1) {
} }
TEST_F(TestL2NormalizeInfo, CheckStrategy2) { TEST_F(TestL2NormalizeInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{4, 2, 3}}; Strategys inputs = {{4, 2, 3}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = norm->Init(strategy); Status ret = norm->Init(strategy);
...@@ -156,7 +156,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy2) { ...@@ -156,7 +156,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy2) {
} }
TEST_F(TestL2NormalizeInfo, CheckStrategy3) { TEST_F(TestL2NormalizeInfo, CheckStrategy3) {
std::vector<Dimensions> inputs = {{4, 2, 3, 4}}; Strategys inputs = {{4, 2, 3, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = norm->Init(strategy); Status ret = norm->Init(strategy);
...@@ -164,7 +164,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy3) { ...@@ -164,7 +164,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy3) {
} }
TEST_F(TestL2NormalizeInfo, CheckStrategy4) { TEST_F(TestL2NormalizeInfo, CheckStrategy4) {
std::vector<Dimensions> inputs = {{4, 1, 8}}; Strategys inputs = {{4, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = norm->Init(strategy); Status ret = norm->Init(strategy);
...@@ -172,7 +172,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy4) { ...@@ -172,7 +172,7 @@ TEST_F(TestL2NormalizeInfo, CheckStrategy4) {
} }
TEST_F(TestL2NormalizeInfo, mirror_ops) { TEST_F(TestL2NormalizeInfo, mirror_ops) {
std::vector<Dimensions> inputs = {{2, 1, 8}}; Strategys inputs = {{2, 1, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
norm->Init(strategy); norm->Init(strategy);
......
...@@ -38,13 +38,13 @@ class TestLogSoftmaxInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestLogSoftmaxInfo : public UT::Common {
}; };
void TestLogSoftmaxInfo::SetUp() { void TestLogSoftmaxInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 130; i++) { for (int32_t i = 0; i < 130; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(128); stage_map.push_back(128);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -64,18 +64,18 @@ void TestLogSoftmaxInfo::SetUp() { ...@@ -64,18 +64,18 @@ void TestLogSoftmaxInfo::SetUp() {
} }
TEST_F(TestLogSoftmaxInfo, InferDevMatrixShape1) { TEST_F(TestLogSoftmaxInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
log_softmax->Init(strategy); log_softmax->Init(strategy);
std::vector<int32_t> dev_matrix_shape = log_softmax->dev_matrix_shape(); Shape dev_matrix_shape = log_softmax->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 1, 16}; Shape expect = {2, 4, 1, 16};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestLogSoftmaxInfo, InferSliceShape1) { TEST_F(TestLogSoftmaxInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 1, 16}}; Strategys str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
log_softmax->Init(strategy); log_softmax->Init(strategy);
...@@ -96,7 +96,7 @@ TEST_F(TestLogSoftmaxInfo, InferSliceShape1) { ...@@ -96,7 +96,7 @@ TEST_F(TestLogSoftmaxInfo, InferSliceShape1) {
} }
TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) { TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 1, 16}}; Strategys str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
log_softmax->Init(strategy); log_softmax->Init(strategy);
...@@ -117,7 +117,7 @@ TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) { ...@@ -117,7 +117,7 @@ TEST_F(TestLogSoftmaxInfo, GetTensorLayout1) {
} }
TEST_F(TestLogSoftmaxInfo, GetForwardOp1) { TEST_F(TestLogSoftmaxInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
log_softmax->Init(strategy); log_softmax->Init(strategy);
...@@ -128,7 +128,7 @@ TEST_F(TestLogSoftmaxInfo, GetForwardOp1) { ...@@ -128,7 +128,7 @@ TEST_F(TestLogSoftmaxInfo, GetForwardOp1) {
} }
TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) { TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
log_softmax->Init(strategy); log_softmax->Init(strategy);
...@@ -141,7 +141,7 @@ TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) { ...@@ -141,7 +141,7 @@ TEST_F(TestLogSoftmaxInfo, GetMirrorOPs1) {
TEST_F(TestLogSoftmaxInfo, CheckStrategy1) { TEST_F(TestLogSoftmaxInfo, CheckStrategy1) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = log_softmax->Init(strategy); Status ret = log_softmax->Init(strategy);
...@@ -150,7 +150,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy1) { ...@@ -150,7 +150,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy1) {
TEST_F(TestLogSoftmaxInfo, CheckStrategy2) { TEST_F(TestLogSoftmaxInfo, CheckStrategy2) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 4, 8}}; Strategys inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = log_softmax->Init(strategy); Status ret = log_softmax->Init(strategy);
...@@ -159,7 +159,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy2) { ...@@ -159,7 +159,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy2) {
TEST_F(TestLogSoftmaxInfo, CheckStrategy3) { TEST_F(TestLogSoftmaxInfo, CheckStrategy3) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 4, 8, 16}}; Strategys inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = log_softmax->Init(strategy); Status ret = log_softmax->Init(strategy);
...@@ -167,7 +167,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy3) { ...@@ -167,7 +167,7 @@ TEST_F(TestLogSoftmaxInfo, CheckStrategy3) {
} }
TEST_F(TestLogSoftmaxInfo, GetDeviceList1) { TEST_F(TestLogSoftmaxInfo, GetDeviceList1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
log_softmax->Init(strategy); log_softmax->Init(strategy);
......
...@@ -42,13 +42,13 @@ class TestMatmulInfo : public UT::Common { ...@@ -42,13 +42,13 @@ class TestMatmulInfo : public UT::Common {
}; };
void TestMatmulInfo::SetUp() { void TestMatmulInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 1050; i++) { for (int32_t i = 0; i < 1050; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(1024); stage_map.push_back(1024);
stage_map.push_back(26); stage_map.push_back(26);
...@@ -94,77 +94,77 @@ void TestMatmulInfo::SetUp() { ...@@ -94,77 +94,77 @@ void TestMatmulInfo::SetUp() {
} }
TEST_F(TestMatmulInfo, InferDevMatrixShape1) { TEST_F(TestMatmulInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy); matmul1->Init(strategy);
std::vector<int32_t> dev_matrix_shape = matmul1->dev_matrix_shape(); Shape dev_matrix_shape = matmul1->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8, 16, 1}; Shape expect = {2, 4, 8, 16, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestMatmulInfo, InferDevMatrixShape2) { TEST_F(TestMatmulInfo, InferDevMatrixShape2) {
std::vector<Dimensions> inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}}; Strategys inputs = {{2, 4, 8, 8}, {2, 4, 8, 2}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy); matmul1->Init(strategy);
std::vector<int32_t> dev_matrix_shape = matmul1->dev_matrix_shape(); Shape dev_matrix_shape = matmul1->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8, 8, 2}; Shape expect = {2, 4, 8, 8, 2};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
// matmul2 // matmul2
TEST_F(TestMatmulInfo, InferDevMatrixShape3) { TEST_F(TestMatmulInfo, InferDevMatrixShape3) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {1, 16}}; Strategys inputs = {{2, 4, 8, 16}, {1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul2->Init(strategy); matmul2->Init(strategy);
std::vector<int32_t> dev_matrix_shape = matmul2->dev_matrix_shape(); Shape dev_matrix_shape = matmul2->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8, 16, 1}; Shape expect = {2, 4, 8, 16, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
// matmul2 // matmul2
TEST_F(TestMatmulInfo, InferDevMatrixShape4) { TEST_F(TestMatmulInfo, InferDevMatrixShape4) {
std::vector<Dimensions> inputs = {{2, 4, 8, 8}, {2, 8}}; Strategys inputs = {{2, 4, 8, 8}, {2, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul2->Init(strategy); matmul2->Init(strategy);
std::vector<int32_t> dev_matrix_shape = matmul2->dev_matrix_shape(); Shape dev_matrix_shape = matmul2->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8, 8, 2}; Shape expect = {2, 4, 8, 8, 2};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
// matmul3 // matmul3
TEST_F(TestMatmulInfo, InferDevMatrixShape5) { TEST_F(TestMatmulInfo, InferDevMatrixShape5) {
std::vector<Dimensions> inputs = {{8, 16}, {2, 4, 1, 16}}; Strategys inputs = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul3->Init(strategy); matmul3->Init(strategy);
std::vector<int32_t> dev_matrix_shape = matmul3->dev_matrix_shape(); Shape dev_matrix_shape = matmul3->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8, 16, 1}; Shape expect = {2, 4, 8, 16, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
// matmul3 // matmul3
TEST_F(TestMatmulInfo, InferDevMatrixShape6) { TEST_F(TestMatmulInfo, InferDevMatrixShape6) {
std::vector<Dimensions> inputs = {{8, 8}, {2, 4, 2, 8}}; Strategys inputs = {{8, 8}, {2, 4, 2, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul3->Init(strategy); matmul3->Init(strategy);
std::vector<int32_t> dev_matrix_shape = matmul3->dev_matrix_shape(); Shape dev_matrix_shape = matmul3->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8, 8, 2}; Shape expect = {2, 4, 8, 8, 2};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestMatmulInfo, InferTensorMap1) { TEST_F(TestMatmulInfo, InferTensorMap1) {
std::vector<Dimensions> str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul1->Init(strategy); matmul1->Init(strategy);
...@@ -190,7 +190,7 @@ TEST_F(TestMatmulInfo, InferTensorMap1) { ...@@ -190,7 +190,7 @@ TEST_F(TestMatmulInfo, InferTensorMap1) {
// matmul2 // matmul2
TEST_F(TestMatmulInfo, InferTensorMap2) { TEST_F(TestMatmulInfo, InferTensorMap2) {
std::vector<Dimensions> str = {{2, 4, 8, 16}, {1, 16}}; Strategys str = {{2, 4, 8, 16}, {1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul2->Init(strategy); matmul2->Init(strategy);
...@@ -216,7 +216,7 @@ TEST_F(TestMatmulInfo, InferTensorMap2) { ...@@ -216,7 +216,7 @@ TEST_F(TestMatmulInfo, InferTensorMap2) {
// matmul3 // matmul3
TEST_F(TestMatmulInfo, InferTensorMap3) { TEST_F(TestMatmulInfo, InferTensorMap3) {
std::vector<Dimensions> str = {{8, 16}, {2, 4, 1, 16}}; Strategys str = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul3->Init(strategy); matmul3->Init(strategy);
...@@ -241,7 +241,7 @@ TEST_F(TestMatmulInfo, InferTensorMap3) { ...@@ -241,7 +241,7 @@ TEST_F(TestMatmulInfo, InferTensorMap3) {
} }
TEST_F(TestMatmulInfo, InferSliceShape1) { TEST_F(TestMatmulInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul1->Init(strategy); matmul1->Init(strategy);
...@@ -267,7 +267,7 @@ TEST_F(TestMatmulInfo, InferSliceShape1) { ...@@ -267,7 +267,7 @@ TEST_F(TestMatmulInfo, InferSliceShape1) {
// matmul2 // matmul2
TEST_F(TestMatmulInfo, InferSliceShape2) { TEST_F(TestMatmulInfo, InferSliceShape2) {
std::vector<Dimensions> str = {{2, 4, 8, 16}, {1, 16}}; Strategys str = {{2, 4, 8, 16}, {1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul2->Init(strategy); matmul2->Init(strategy);
...@@ -293,7 +293,7 @@ TEST_F(TestMatmulInfo, InferSliceShape2) { ...@@ -293,7 +293,7 @@ TEST_F(TestMatmulInfo, InferSliceShape2) {
// matmul3 // matmul3
TEST_F(TestMatmulInfo, InferSliceShape3) { TEST_F(TestMatmulInfo, InferSliceShape3) {
std::vector<Dimensions> str = {{8, 16}, {2, 4, 1, 16}}; Strategys str = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul3->Init(strategy); matmul3->Init(strategy);
...@@ -319,7 +319,7 @@ TEST_F(TestMatmulInfo, InferSliceShape3) { ...@@ -319,7 +319,7 @@ TEST_F(TestMatmulInfo, InferSliceShape3) {
// matmul3 // matmul3
TEST_F(TestMatmulInfo, GetTensorLayout3) { TEST_F(TestMatmulInfo, GetTensorLayout3) {
std::vector<Dimensions> str = {{8, 16}, {2, 4, 1, 16}}; Strategys str = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul3->Init(strategy); matmul3->Init(strategy);
...@@ -344,7 +344,7 @@ TEST_F(TestMatmulInfo, GetTensorLayout3) { ...@@ -344,7 +344,7 @@ TEST_F(TestMatmulInfo, GetTensorLayout3) {
} }
TEST_F(TestMatmulInfo, GetForwardOp1) { TEST_F(TestMatmulInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy); matmul1->Init(strategy);
...@@ -370,7 +370,7 @@ TEST_F(TestMatmulInfo, GetForwardOp1) { ...@@ -370,7 +370,7 @@ TEST_F(TestMatmulInfo, GetForwardOp1) {
} }
TEST_F(TestMatmulInfo, GetForwardOp2) { TEST_F(TestMatmulInfo, GetForwardOp2) {
std::vector<Dimensions> inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}}; Strategys inputs = {{2, 4, 8, 1}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy); matmul1->Init(strategy);
...@@ -380,7 +380,7 @@ TEST_F(TestMatmulInfo, GetForwardOp2) { ...@@ -380,7 +380,7 @@ TEST_F(TestMatmulInfo, GetForwardOp2) {
} }
TEST_F(TestMatmulInfo, GetVirtualDivOp1) { TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy); matmul1->Init(strategy);
...@@ -399,7 +399,7 @@ TEST_F(TestMatmulInfo, GetVirtualDivOp1) { ...@@ -399,7 +399,7 @@ TEST_F(TestMatmulInfo, GetVirtualDivOp1) {
} }
TEST_F(TestMatmulInfo, GetMirrorOPs1) { TEST_F(TestMatmulInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy); matmul1->Init(strategy);
...@@ -419,7 +419,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs1) { ...@@ -419,7 +419,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs1) {
// matmul2 // matmul2
TEST_F(TestMatmulInfo, GetMirrorOPs2) { TEST_F(TestMatmulInfo, GetMirrorOPs2) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}, {8, 16}}; Strategys inputs = {{2, 4, 1, 16}, {8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul2->Init(strategy); matmul2->Init(strategy);
...@@ -439,7 +439,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs2) { ...@@ -439,7 +439,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs2) {
// matmul3 // matmul3
TEST_F(TestMatmulInfo, GetMirrorOPs3) { TEST_F(TestMatmulInfo, GetMirrorOPs3) {
std::vector<Dimensions> inputs = {{8, 16}, {2, 4, 1, 16}}; Strategys inputs = {{8, 16}, {2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul3->Init(strategy); matmul3->Init(strategy);
...@@ -457,7 +457,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs3) { ...@@ -457,7 +457,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs3) {
} }
TEST_F(TestMatmulInfo, GetMirrorOPs4) { TEST_F(TestMatmulInfo, GetMirrorOPs4) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}}; Strategys inputs = {{2, 4, 1, 16}, {2, 4, 16, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
matmul1->Init(strategy); matmul1->Init(strategy);
...@@ -467,7 +467,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) { ...@@ -467,7 +467,7 @@ TEST_F(TestMatmulInfo, GetMirrorOPs4) {
} }
TEST_F(TestMatmulInfo, InitTwice) { TEST_F(TestMatmulInfo, InitTwice) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
// init twice // init twice
...@@ -489,7 +489,7 @@ TEST_F(TestMatmulInfo, InitTwice) { ...@@ -489,7 +489,7 @@ TEST_F(TestMatmulInfo, InitTwice) {
TEST_F(TestMatmulInfo, CheckStrategy1) { TEST_F(TestMatmulInfo, CheckStrategy1) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
std::vector<Dimensions> inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy); Status ret = matmul1->Init(strategy);
...@@ -498,7 +498,7 @@ TEST_F(TestMatmulInfo, CheckStrategy1) { ...@@ -498,7 +498,7 @@ TEST_F(TestMatmulInfo, CheckStrategy1) {
TEST_F(TestMatmulInfo, CheckStrategy2) { TEST_F(TestMatmulInfo, CheckStrategy2) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {4, 16, 1}}; Strategys inputs = {{2, 4, 8, 16}, {4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy); Status ret = matmul1->Init(strategy);
...@@ -507,7 +507,7 @@ TEST_F(TestMatmulInfo, CheckStrategy2) { ...@@ -507,7 +507,7 @@ TEST_F(TestMatmulInfo, CheckStrategy2) {
TEST_F(TestMatmulInfo, CheckStrategy3) { TEST_F(TestMatmulInfo, CheckStrategy3) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}}; Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy); Status ret = matmul1->Init(strategy);
...@@ -516,7 +516,7 @@ TEST_F(TestMatmulInfo, CheckStrategy3) { ...@@ -516,7 +516,7 @@ TEST_F(TestMatmulInfo, CheckStrategy3) {
TEST_F(TestMatmulInfo, CheckStrategy4) { TEST_F(TestMatmulInfo, CheckStrategy4) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}}; Strategys inputs = {{2, 4, 8, 16}, {2, 3, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy); Status ret = matmul1->Init(strategy);
...@@ -525,7 +525,7 @@ TEST_F(TestMatmulInfo, CheckStrategy4) { ...@@ -525,7 +525,7 @@ TEST_F(TestMatmulInfo, CheckStrategy4) {
TEST_F(TestMatmulInfo, CheckStrategy5) { TEST_F(TestMatmulInfo, CheckStrategy5) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
std::vector<Dimensions> inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{0, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy); Status ret = matmul1->Init(strategy);
...@@ -534,7 +534,7 @@ TEST_F(TestMatmulInfo, CheckStrategy5) { ...@@ -534,7 +534,7 @@ TEST_F(TestMatmulInfo, CheckStrategy5) {
TEST_F(TestMatmulInfo, CheckStrategy6) { TEST_F(TestMatmulInfo, CheckStrategy6) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
std::vector<Dimensions> inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{-1, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy); Status ret = matmul1->Init(strategy);
...@@ -543,7 +543,7 @@ TEST_F(TestMatmulInfo, CheckStrategy6) { ...@@ -543,7 +543,7 @@ TEST_F(TestMatmulInfo, CheckStrategy6) {
TEST_F(TestMatmulInfo, CheckStrategy7) { TEST_F(TestMatmulInfo, CheckStrategy7) {
// Success: {{2,4,8,16}, {2,4,16,1}} // Success: {{2,4,8,16}, {2,4,16,1}}
std::vector<Dimensions> inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul1->Init(strategy); Status ret = matmul1->Init(strategy);
...@@ -552,7 +552,7 @@ TEST_F(TestMatmulInfo, CheckStrategy7) { ...@@ -552,7 +552,7 @@ TEST_F(TestMatmulInfo, CheckStrategy7) {
TEST_F(TestMatmulInfo, InitFailed) { TEST_F(TestMatmulInfo, InitFailed) {
// matmul4 attr is wrong // matmul4 attr is wrong
std::vector<Dimensions> inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{4, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = matmul4->Init(strategy); Status ret = matmul4->Init(strategy);
......
...@@ -38,13 +38,13 @@ class TestOneHotInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestOneHotInfo : public UT::Common {
}; };
void TestOneHotInfo::SetUp() { void TestOneHotInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 10; i++) { for (int32_t i = 0; i < 10; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(8); stage_map.push_back(8);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -64,43 +64,43 @@ void TestOneHotInfo::SetUp() { ...@@ -64,43 +64,43 @@ void TestOneHotInfo::SetUp() {
} }
TEST_F(TestOneHotInfo, InferDevMatrixShape1) { TEST_F(TestOneHotInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{8, 1}, {}, {}}; Strategys inputs = {{8, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info->Init(strategy); Status status = onehot_info->Init(strategy);
ASSERT_EQ(status, SUCCESS); ASSERT_EQ(status, SUCCESS);
std::vector<int32_t> dev_matrix_shape = onehot_info->dev_matrix_shape(); Shape dev_matrix_shape = onehot_info->dev_matrix_shape();
std::vector<int32_t> expect = {8, 1}; Shape expect = {8, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestOneHotInfo, InferDevMatrixShape2) { TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
std::vector<Dimensions> inputs = {{4, 1}, {}, {}}; Strategys inputs = {{4, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info->Init(strategy); Status status = onehot_info->Init(strategy);
ASSERT_EQ(status, SUCCESS); ASSERT_EQ(status, SUCCESS);
std::vector<int32_t> dev_matrix_shape = onehot_info->dev_matrix_shape(); Shape dev_matrix_shape = onehot_info->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 1}; Shape expect = {2, 4, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestOneHotInfo, InferDevMatrixShape3) { TEST_F(TestOneHotInfo, InferDevMatrixShape3) {
std::vector<Dimensions> inputs = {{4, 2}, {}, {}}; Strategys inputs = {{4, 2}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info->Init(strategy); Status status = onehot_info->Init(strategy);
ASSERT_EQ(status, FAILED); ASSERT_EQ(status, FAILED);
std::vector<int32_t> dev_matrix_shape = onehot_info->dev_matrix_shape(); Shape dev_matrix_shape = onehot_info->dev_matrix_shape();
std::vector<int32_t> expect = {4, 2}; Shape expect = {4, 2};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestOneHotInfo, InferTensorMap2) { TEST_F(TestOneHotInfo, InferTensorMap2) {
std::vector<Dimensions> str = {{8, 1}, {}, {}}; Strategys str = {{8, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info->Init(strategy); Status status = onehot_info->Init(strategy);
...@@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo, InferTensorMap2) { ...@@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo, InferTensorMap2) {
} }
TEST_F(TestOneHotInfo, InferSliceShape1) { TEST_F(TestOneHotInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{8, 1}, {}, {}}; Strategys str = {{8, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info->Init(strategy); Status status = onehot_info->Init(strategy);
...@@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo, InferSliceShape1) { ...@@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo, InferSliceShape1) {
} }
TEST_F(TestOneHotInfo, InferSliceShape2) { TEST_F(TestOneHotInfo, InferSliceShape2) {
std::vector<Dimensions> str = {{4, 2}, {}, {}}; Strategys str = {{4, 2}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info->Init(strategy); Status status = onehot_info->Init(strategy);
...@@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo, InferSliceShape2) { ...@@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo, InferSliceShape2) {
} }
TEST_F(TestOneHotInfo, InferSliceShape3) { TEST_F(TestOneHotInfo, InferSliceShape3) {
std::vector<Dimensions> str = {{2, 2}, {}, {}}; Strategys str = {{2, 2}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info->Init(strategy); Status status = onehot_info->Init(strategy);
...@@ -188,7 +188,7 @@ TEST_F(TestOneHotInfo, InferSliceShape3) { ...@@ -188,7 +188,7 @@ TEST_F(TestOneHotInfo, InferSliceShape3) {
} }
TEST_F(TestOneHotInfo, GetMirrorOPs1) { TEST_F(TestOneHotInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{8, 1}, {}, {}}; Strategys inputs = {{8, 1}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info->Init(strategy); Status status = onehot_info->Init(strategy);
...@@ -199,7 +199,7 @@ TEST_F(TestOneHotInfo, GetMirrorOPs1) { ...@@ -199,7 +199,7 @@ TEST_F(TestOneHotInfo, GetMirrorOPs1) {
} }
TEST_F(TestOneHotInfo, CheckStrategy1) { TEST_F(TestOneHotInfo, CheckStrategy1) {
std::vector<Dimensions> inputs = {{16}, {}, {}}; Strategys inputs = {{16}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = onehot_info->Init(strategy); Status ret = onehot_info->Init(strategy);
......
...@@ -38,13 +38,13 @@ class TestOneHotInfo2 : public UT::Common { ...@@ -38,13 +38,13 @@ class TestOneHotInfo2 : public UT::Common {
}; };
void TestOneHotInfo2::SetUp() { void TestOneHotInfo2::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 10; i++) { for (int32_t i = 0; i < 10; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(8); stage_map.push_back(8);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -64,43 +64,43 @@ void TestOneHotInfo2::SetUp() { ...@@ -64,43 +64,43 @@ void TestOneHotInfo2::SetUp() {
} }
TEST_F(TestOneHotInfo2, InferDevMatrixShape1) { TEST_F(TestOneHotInfo2, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{1, 8}, {}, {}}; Strategys inputs = {{1, 8}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info2->Init(strategy); Status status = onehot_info2->Init(strategy);
ASSERT_EQ(status, SUCCESS); ASSERT_EQ(status, SUCCESS);
std::vector<int32_t> dev_matrix_shape = onehot_info2->dev_matrix_shape(); Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();
std::vector<int32_t> expect = {8, 1}; Shape expect = {8, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestOneHotInfo2, InferDevMatrixShape2) { TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
std::vector<Dimensions> inputs = {{1, 4}, {}, {}}; Strategys inputs = {{1, 4}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info2->Init(strategy); Status status = onehot_info2->Init(strategy);
ASSERT_EQ(status, SUCCESS); ASSERT_EQ(status, SUCCESS);
std::vector<int32_t> dev_matrix_shape = onehot_info2->dev_matrix_shape(); Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 1}; Shape expect = {2, 4, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestOneHotInfo2, InferDevMatrixShape3) { TEST_F(TestOneHotInfo2, InferDevMatrixShape3) {
std::vector<Dimensions> inputs = {{2, 4}, {}, {}}; Strategys inputs = {{2, 4}, {}, {}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status status = onehot_info2->Init(strategy); Status status = onehot_info2->Init(strategy);
ASSERT_EQ(status, FAILED); ASSERT_EQ(status, FAILED);
std::vector<int32_t> dev_matrix_shape = onehot_info2->dev_matrix_shape(); Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();
std::vector<int32_t> expect = {4, 2}; Shape expect = {4, 2};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestOneHotInfo2, InferTensorMap2) { TEST_F(TestOneHotInfo2, InferTensorMap2) {
std::vector<Dimensions> str = {{1, 8}, {}, {}}; Strategys str = {{1, 8}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info2->Init(strategy); Status status = onehot_info2->Init(strategy);
...@@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo2, InferTensorMap2) { ...@@ -122,7 +122,7 @@ TEST_F(TestOneHotInfo2, InferTensorMap2) {
} }
TEST_F(TestOneHotInfo2, InferSliceShape1) { TEST_F(TestOneHotInfo2, InferSliceShape1) {
std::vector<Dimensions> str = {{1, 8}, {}, {}}; Strategys str = {{1, 8}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info2->Init(strategy); Status status = onehot_info2->Init(strategy);
...@@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape1) { ...@@ -144,7 +144,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape1) {
} }
TEST_F(TestOneHotInfo2, InferSliceShape2) { TEST_F(TestOneHotInfo2, InferSliceShape2) {
std::vector<Dimensions> str = {{2, 4}, {}, {}}; Strategys str = {{2, 4}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info2->Init(strategy); Status status = onehot_info2->Init(strategy);
...@@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape2) { ...@@ -166,7 +166,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape2) {
} }
TEST_F(TestOneHotInfo2, InferSliceShape3) { TEST_F(TestOneHotInfo2, InferSliceShape3) {
std::vector<Dimensions> str = {{2, 2}, {}, {}}; Strategys str = {{2, 2}, {}, {}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
Status status = onehot_info2->Init(strategy); Status status = onehot_info2->Init(strategy);
......
...@@ -38,13 +38,13 @@ class TestPowInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestPowInfo : public UT::Common {
}; };
void TestPowInfo::SetUp() { void TestPowInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 66; i++) { for (int32_t i = 0; i < 66; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(64); stage_map.push_back(64);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -63,18 +63,18 @@ void TestPowInfo::SetUp() { ...@@ -63,18 +63,18 @@ void TestPowInfo::SetUp() {
} }
TEST_F(TestPowInfo, InferDevMatrixShape1) { TEST_F(TestPowInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}}; Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy); pow->Init(strategy);
std::vector<int32_t> dev_matrix_shape = pow->dev_matrix_shape(); Shape dev_matrix_shape = pow->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8}; Shape expect = {2, 4, 8};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestPowInfo, InferSliceShape1) { TEST_F(TestPowInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}}; Strategys str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy); pow->Init(strategy);
...@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) { ...@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
} }
TEST_F(TestPowInfo, GetTensorLayout1) { TEST_F(TestPowInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}}; Strategys str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy); pow->Init(strategy);
...@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) { ...@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
} }
TEST_F(TestPowInfo, GetForwardOp1) { TEST_F(TestPowInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}}; Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy); pow->Init(strategy);
...@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) { ...@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
} }
TEST_F(TestPowInfo, GetMirrorOPs1) { TEST_F(TestPowInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}}; Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy); pow->Init(strategy);
...@@ -139,7 +139,7 @@ TEST_F(TestPowInfo, GetMirrorOPs1) { ...@@ -139,7 +139,7 @@ TEST_F(TestPowInfo, GetMirrorOPs1) {
} }
TEST_F(TestPowInfo, CheckStrategy1) { TEST_F(TestPowInfo, CheckStrategy1) {
std::vector<Dimensions> inputs = {{2, 2, 8}, {2, 4, 8}}; Strategys inputs = {{2, 2, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy); Status ret = pow->Init(strategy);
...@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) { ...@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
} }
TEST_F(TestPowInfo, CheckStrategy2) { TEST_F(TestPowInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}}; Strategys inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy); Status ret = pow->Init(strategy);
...@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) { ...@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
} }
TEST_F(TestPowInfo, CheckStrategy3) { TEST_F(TestPowInfo, CheckStrategy3) {
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}}; Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy); Status ret = pow->Init(strategy);
......
...@@ -39,13 +39,13 @@ class TestPReLUInfo : public UT::Common { ...@@ -39,13 +39,13 @@ class TestPReLUInfo : public UT::Common {
}; };
void TestPReLUInfo::SetUp() { void TestPReLUInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 1050; i++) { for (int32_t i = 0; i < 1050; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(1024); stage_map.push_back(1024);
stage_map.push_back(26); stage_map.push_back(26);
int32_t local_dev = 0; int32_t local_dev = 0;
...@@ -64,18 +64,18 @@ void TestPReLUInfo::SetUp() { ...@@ -64,18 +64,18 @@ void TestPReLUInfo::SetUp() {
} }
TEST_F(TestPReLUInfo, InferDevMatrixShape1) { TEST_F(TestPReLUInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 1, 8, 16}, {1}}; Strategys inputs = {{2, 1, 8, 16}, {1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
prelu->Init(strategy); prelu->Init(strategy);
std::vector<int32_t> dev_matrix_shape = prelu->dev_matrix_shape(); Shape dev_matrix_shape = prelu->dev_matrix_shape();
std::vector<int32_t> expect = {4, 2, 1, 8, 16}; Shape expect = {4, 2, 1, 8, 16};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestPReLUInfo, InferSliceShape1) { TEST_F(TestPReLUInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 1, 8, 16}, {1}}; Strategys str = {{2, 1, 8, 16}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu->Init(strategy); prelu->Init(strategy);
...@@ -98,7 +98,7 @@ TEST_F(TestPReLUInfo, InferSliceShape1) { ...@@ -98,7 +98,7 @@ TEST_F(TestPReLUInfo, InferSliceShape1) {
} }
TEST_F(TestPReLUInfo, GetTensorLayout1) { TEST_F(TestPReLUInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 1, 8, 16}, {1}}; Strategys str = {{2, 1, 8, 16}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu->Init(strategy); prelu->Init(strategy);
...@@ -122,7 +122,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout1) { ...@@ -122,7 +122,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout1) {
} }
TEST_F(TestPReLUInfo, GetMirrorOPs1) { TEST_F(TestPReLUInfo, GetMirrorOPs1) {
std::vector<Dimensions> str = {{2, 1, 2, 2}, {1}}; Strategys str = {{2, 1, 2, 2}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu->Init(strategy); prelu->Init(strategy);
MirrorOps mirror_ops = prelu->mirror_ops(); MirrorOps mirror_ops = prelu->mirror_ops();
...@@ -139,14 +139,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs1) { ...@@ -139,14 +139,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs1) {
TEST_F(TestPReLUInfo, CheckStrategy1) { TEST_F(TestPReLUInfo, CheckStrategy1) {
// Success: {{2,1,8,16},{1}} // Success: {{2,1,8,16},{1}}
std::vector<Dimensions> inputs = {{2, 1, 8, 16}}; Strategys inputs = {{2, 1, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = prelu->Init(strategy); Status ret = prelu->Init(strategy);
ASSERT_EQ(ret, FAILED); ASSERT_EQ(ret, FAILED);
} }
TEST_F(TestPReLUInfo, CheckStrategy2) { TEST_F(TestPReLUInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {4}}; Strategys inputs = {{2, 4, 8, 16}, {4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = prelu->Init(strategy); Status ret = prelu->Init(strategy);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
...@@ -169,18 +169,18 @@ TEST_F(TestPReLUInfo, AutoStrategy1) { ...@@ -169,18 +169,18 @@ TEST_F(TestPReLUInfo, AutoStrategy1) {
} }
TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) { TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) {
std::vector<Dimensions> inputs = {{128, 1}, {1}}; Strategys inputs = {{128, 1}, {1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
prelu_2d->Init(strategy); prelu_2d->Init(strategy);
std::vector<int32_t> dev_matrix_shape = prelu_2d->dev_matrix_shape(); Shape dev_matrix_shape = prelu_2d->dev_matrix_shape();
std::vector<int32_t> expect = {8, 128, 1}; Shape expect = {8, 128, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestPReLUInfo, InferSliceShape_2d1) { TEST_F(TestPReLUInfo, InferSliceShape_2d1) {
std::vector<Dimensions> str = {{128, 1}, {1}}; Strategys str = {{128, 1}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu_2d->Init(strategy); prelu_2d->Init(strategy);
...@@ -203,7 +203,7 @@ TEST_F(TestPReLUInfo, InferSliceShape_2d1) { ...@@ -203,7 +203,7 @@ TEST_F(TestPReLUInfo, InferSliceShape_2d1) {
} }
TEST_F(TestPReLUInfo, GetTensorLayout_2d1) { TEST_F(TestPReLUInfo, GetTensorLayout_2d1) {
std::vector<Dimensions> str = {{128, 1}, {1}}; Strategys str = {{128, 1}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu_2d->Init(strategy); prelu_2d->Init(strategy);
...@@ -227,7 +227,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout_2d1) { ...@@ -227,7 +227,7 @@ TEST_F(TestPReLUInfo, GetTensorLayout_2d1) {
} }
TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) { TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) {
std::vector<Dimensions> str = {{128, 1}, {1}}; Strategys str = {{128, 1}, {1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
prelu_2d->Init(strategy); prelu_2d->Init(strategy);
MirrorOps mirror_ops = prelu_2d->mirror_ops(); MirrorOps mirror_ops = prelu_2d->mirror_ops();
...@@ -244,14 +244,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) { ...@@ -244,14 +244,14 @@ TEST_F(TestPReLUInfo, GetMirrorOPs_2d1) {
TEST_F(TestPReLUInfo, CheckStrategy_2d1) { TEST_F(TestPReLUInfo, CheckStrategy_2d1) {
// Success: {{2,1,8,16},{1}} // Success: {{2,1,8,16},{1}}
std::vector<Dimensions> inputs = {{128, 1}}; Strategys inputs = {{128, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = prelu_2d->Init(strategy); Status ret = prelu_2d->Init(strategy);
ASSERT_EQ(ret, FAILED); ASSERT_EQ(ret, FAILED);
} }
TEST_F(TestPReLUInfo, CheckStrategy_2d2) { TEST_F(TestPReLUInfo, CheckStrategy_2d2) {
std::vector<Dimensions> inputs = {{128, 4}, {4}}; Strategys inputs = {{128, 4}, {4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = prelu_2d->Init(strategy); Status ret = prelu_2d->Init(strategy);
ASSERT_EQ(ret, SUCCESS); ASSERT_EQ(ret, SUCCESS);
......
...@@ -39,13 +39,13 @@ class TestReduceSumInfo : public UT::Common { ...@@ -39,13 +39,13 @@ class TestReduceSumInfo : public UT::Common {
void TestReduceSumInfo::SetUp() { void TestReduceSumInfo::SetUp() {
UT::InitPythonPath(); UT::InitPythonPath();
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 34; i++) { for (int32_t i = 0; i < 34; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(32); stage_map.push_back(32);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -68,18 +68,18 @@ void TestReduceSumInfo::SetUp() { ...@@ -68,18 +68,18 @@ void TestReduceSumInfo::SetUp() {
} }
TEST_F(TestReduceSumInfo, InferDevMatrixShape1) { TEST_F(TestReduceSumInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{4, 8, 1}}; Strategys inputs = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy); reduce_sum->Init(strategy);
std::vector<int32_t> dev_matrix_shape = reduce_sum->dev_matrix_shape(); Shape dev_matrix_shape = reduce_sum->dev_matrix_shape();
std::vector<int32_t> expect = {4, 8, 1}; Shape expect = {4, 8, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestReduceSumInfo, InferSliceShape1) { TEST_F(TestReduceSumInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{4, 8, 1}}; Strategys str = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reduce_sum->Init(strategy); reduce_sum->Init(strategy);
...@@ -100,7 +100,7 @@ TEST_F(TestReduceSumInfo, InferSliceShape1) { ...@@ -100,7 +100,7 @@ TEST_F(TestReduceSumInfo, InferSliceShape1) {
} }
TEST_F(TestReduceSumInfo, GetTensorLayout1) { TEST_F(TestReduceSumInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{4, 8, 1}}; Strategys str = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reduce_sum->Init(strategy); reduce_sum->Init(strategy);
...@@ -121,7 +121,7 @@ TEST_F(TestReduceSumInfo, GetTensorLayout1) { ...@@ -121,7 +121,7 @@ TEST_F(TestReduceSumInfo, GetTensorLayout1) {
} }
TEST_F(TestReduceSumInfo, GetForwardOp1) { TEST_F(TestReduceSumInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{4, 8, 1}}; Strategys inputs = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy); reduce_sum->Init(strategy);
...@@ -132,7 +132,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp1) { ...@@ -132,7 +132,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp1) {
} }
TEST_F(TestReduceSumInfo, GetForwardOp2) { TEST_F(TestReduceSumInfo, GetForwardOp2) {
std::vector<Dimensions> inputs = {{4, 4, 2}}; Strategys inputs = {{4, 4, 2}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy); reduce_sum->Init(strategy);
...@@ -156,7 +156,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp2) { ...@@ -156,7 +156,7 @@ TEST_F(TestReduceSumInfo, GetForwardOp2) {
} }
TEST_F(TestReduceSumInfo, GetMirrorOPs1) { TEST_F(TestReduceSumInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{4, 8, 1}}; Strategys inputs = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy); reduce_sum->Init(strategy);
...@@ -168,7 +168,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs1) { ...@@ -168,7 +168,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs1) {
} }
TEST_F(TestReduceSumInfo, GetMirrorOPs2) { TEST_F(TestReduceSumInfo, GetMirrorOPs2) {
std::vector<Dimensions> inputs = {{4, 4, 1}}; Strategys inputs = {{4, 4, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reduce_sum->Init(strategy); reduce_sum->Init(strategy);
...@@ -187,7 +187,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs2) { ...@@ -187,7 +187,7 @@ TEST_F(TestReduceSumInfo, GetMirrorOPs2) {
} }
TEST_F(TestReduceSumInfo, CheckStrategy1) { TEST_F(TestReduceSumInfo, CheckStrategy1) {
std::vector<Dimensions> inputs = {{2, 2, 8, 16}}; Strategys inputs = {{2, 2, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reduce_sum->Init(strategy); Status ret = reduce_sum->Init(strategy);
...@@ -195,7 +195,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy1) { ...@@ -195,7 +195,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy1) {
} }
TEST_F(TestReduceSumInfo, CheckStrategy2) { TEST_F(TestReduceSumInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}}; Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reduce_sum->Init(strategy); Status ret = reduce_sum->Init(strategy);
...@@ -203,7 +203,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy2) { ...@@ -203,7 +203,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy2) {
} }
TEST_F(TestReduceSumInfo, CheckStrategy3) { TEST_F(TestReduceSumInfo, CheckStrategy3) {
std::vector<Dimensions> inputs = {{4, 4, 2}}; Strategys inputs = {{4, 4, 2}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reduce_sum->Init(strategy); Status ret = reduce_sum->Init(strategy);
...@@ -211,7 +211,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy3) { ...@@ -211,7 +211,7 @@ TEST_F(TestReduceSumInfo, CheckStrategy3) {
} }
TEST_F(TestReduceSumInfo, CheckStrategy4) { TEST_F(TestReduceSumInfo, CheckStrategy4) {
std::vector<Dimensions> inputs = {{4, 8, 1}}; Strategys inputs = {{4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reduce_sum->Init(strategy); Status ret = reduce_sum->Init(strategy);
......
...@@ -38,13 +38,13 @@ class TestReshapeInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestReshapeInfo : public UT::Common {
}; };
void TestReshapeInfo::SetUp() { void TestReshapeInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 34; i++) { for (int32_t i = 0; i < 34; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(32); stage_map.push_back(32);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -68,29 +68,29 @@ void TestReshapeInfo::SetUp() { ...@@ -68,29 +68,29 @@ void TestReshapeInfo::SetUp() {
} }
TEST_F(TestReshapeInfo, InferDevMatrixShape1) { TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{4, 1, 1, 1}}; Strategys inputs = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reshape->Init(strategy); reshape->Init(strategy);
std::vector<int32_t> dev_matrix_shape = reshape->dev_matrix_shape(); Shape dev_matrix_shape = reshape->dev_matrix_shape();
std::vector<int32_t> expect = {8, 4}; Shape expect = {8, 4};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestReshapeInfo, InferDevMatrixShape2) { TEST_F(TestReshapeInfo, InferDevMatrixShape2) {
std::vector<Dimensions> inputs = {{32, 1, 1, 1}}; Strategys inputs = {{32, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reshape->Init(strategy); reshape->Init(strategy);
std::vector<int32_t> dev_matrix_shape = reshape->dev_matrix_shape(); Shape dev_matrix_shape = reshape->dev_matrix_shape();
std::vector<int32_t> expect = {32}; Shape expect = {32};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestReshapeInfo, InferSliceShape1) { TEST_F(TestReshapeInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{4, 1, 1, 1}}; Strategys str = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reshape->Init(strategy); reshape->Init(strategy);
...@@ -111,7 +111,7 @@ TEST_F(TestReshapeInfo, InferSliceShape1) { ...@@ -111,7 +111,7 @@ TEST_F(TestReshapeInfo, InferSliceShape1) {
} }
TEST_F(TestReshapeInfo, InferSliceShape2) { TEST_F(TestReshapeInfo, InferSliceShape2) {
std::vector<Dimensions> str = {{32, 1, 1, 1}}; Strategys str = {{32, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reshape->Init(strategy); reshape->Init(strategy);
...@@ -132,7 +132,7 @@ TEST_F(TestReshapeInfo, InferSliceShape2) { ...@@ -132,7 +132,7 @@ TEST_F(TestReshapeInfo, InferSliceShape2) {
} }
TEST_F(TestReshapeInfo, GetTensorLayout1) { TEST_F(TestReshapeInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{4, 1, 1, 1}}; Strategys str = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reshape->Init(strategy); reshape->Init(strategy);
...@@ -153,7 +153,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) { ...@@ -153,7 +153,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) {
} }
TEST_F(TestReshapeInfo, GetTensorLayout2) { TEST_F(TestReshapeInfo, GetTensorLayout2) {
std::vector<Dimensions> str = {{32, 1, 1, 1}}; Strategys str = {{32, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
reshape->Init(strategy); reshape->Init(strategy);
...@@ -174,7 +174,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout2) { ...@@ -174,7 +174,7 @@ TEST_F(TestReshapeInfo, GetTensorLayout2) {
} }
TEST_F(TestReshapeInfo, GetForwardOp1) { TEST_F(TestReshapeInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{4, 1, 1, 1}}; Strategys inputs = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reshape->Init(strategy); reshape->Init(strategy);
...@@ -185,7 +185,7 @@ TEST_F(TestReshapeInfo, GetForwardOp1) { ...@@ -185,7 +185,7 @@ TEST_F(TestReshapeInfo, GetForwardOp1) {
} }
TEST_F(TestReshapeInfo, GetMirrorOPs1) { TEST_F(TestReshapeInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{4, 1, 1, 1}}; Strategys inputs = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
reshape->Init(strategy); reshape->Init(strategy);
...@@ -197,7 +197,7 @@ TEST_F(TestReshapeInfo, GetMirrorOPs1) { ...@@ -197,7 +197,7 @@ TEST_F(TestReshapeInfo, GetMirrorOPs1) {
} }
TEST_F(TestReshapeInfo, CheckStrategy1) { TEST_F(TestReshapeInfo, CheckStrategy1) {
std::vector<Dimensions> inputs = {{1, 4, 8}}; Strategys inputs = {{1, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reshape->Init(strategy); Status ret = reshape->Init(strategy);
...@@ -205,7 +205,7 @@ TEST_F(TestReshapeInfo, CheckStrategy1) { ...@@ -205,7 +205,7 @@ TEST_F(TestReshapeInfo, CheckStrategy1) {
} }
TEST_F(TestReshapeInfo, CheckStrategy2) { TEST_F(TestReshapeInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}}; Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reshape->Init(strategy); Status ret = reshape->Init(strategy);
...@@ -213,7 +213,7 @@ TEST_F(TestReshapeInfo, CheckStrategy2) { ...@@ -213,7 +213,7 @@ TEST_F(TestReshapeInfo, CheckStrategy2) {
} }
TEST_F(TestReshapeInfo, CheckStrategy3) { TEST_F(TestReshapeInfo, CheckStrategy3) {
std::vector<Dimensions> inputs = {{4, 1, 1, 1}}; Strategys inputs = {{4, 1, 1, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = reshape->Init(strategy); Status ret = reshape->Init(strategy);
......
...@@ -38,13 +38,13 @@ class TestSoftmaxLoss : public UT::Common { ...@@ -38,13 +38,13 @@ class TestSoftmaxLoss : public UT::Common {
}; };
void TestSoftmaxLoss::SetUp() { void TestSoftmaxLoss::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 65; i++) { for (int32_t i = 0; i < 65; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(64); stage_map.push_back(64);
stage_map.push_back(1); stage_map.push_back(1);
...@@ -64,18 +64,18 @@ void TestSoftmaxLoss::SetUp() { ...@@ -64,18 +64,18 @@ void TestSoftmaxLoss::SetUp() {
} }
TEST_F(TestSoftmaxLoss, InferDevMatrixShape1) { TEST_F(TestSoftmaxLoss, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
loss->Init(strategy); loss->Init(strategy);
std::vector<int32_t> dev_matrix_shape = loss->dev_matrix_shape(); Shape dev_matrix_shape = loss->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8, 1}; Shape expect = {2, 4, 8, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestSoftmaxLoss, InferSliceShape1) { TEST_F(TestSoftmaxLoss, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategys str = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
loss->Init(strategy); loss->Init(strategy);
...@@ -104,7 +104,7 @@ TEST_F(TestSoftmaxLoss, InferSliceShape1) { ...@@ -104,7 +104,7 @@ TEST_F(TestSoftmaxLoss, InferSliceShape1) {
} }
TEST_F(TestSoftmaxLoss, GetTensorLayout1) { TEST_F(TestSoftmaxLoss, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategys str = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
loss->Init(strategy); loss->Init(strategy);
...@@ -133,7 +133,7 @@ TEST_F(TestSoftmaxLoss, GetTensorLayout1) { ...@@ -133,7 +133,7 @@ TEST_F(TestSoftmaxLoss, GetTensorLayout1) {
} }
TEST_F(TestSoftmaxLoss, GetForwardOp1) { TEST_F(TestSoftmaxLoss, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
loss->Init(strategy); loss->Init(strategy);
...@@ -144,7 +144,7 @@ TEST_F(TestSoftmaxLoss, GetForwardOp1) { ...@@ -144,7 +144,7 @@ TEST_F(TestSoftmaxLoss, GetForwardOp1) {
} }
TEST_F(TestSoftmaxLoss, GetMirrorOPs1) { TEST_F(TestSoftmaxLoss, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}}; Strategys inputs = {{2, 4, 8, 1}, {2, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
loss->Init(strategy); loss->Init(strategy);
...@@ -156,7 +156,7 @@ TEST_F(TestSoftmaxLoss, GetMirrorOPs1) { ...@@ -156,7 +156,7 @@ TEST_F(TestSoftmaxLoss, GetMirrorOPs1) {
} }
TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) { TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) {
std::vector<Dimensions> inputs = {{1, 4, 8, 1}, {1, 4, 8, 1}}; Strategys inputs = {{1, 4, 8, 1}, {1, 4, 8, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
loss->Init(strategy); loss->Init(strategy);
...@@ -176,7 +176,7 @@ TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) { ...@@ -176,7 +176,7 @@ TEST_F(TestSoftmaxLoss, GetVirtualDivOPs1) {
TEST_F(TestSoftmaxLoss, CheckStrategy1) { TEST_F(TestSoftmaxLoss, CheckStrategy1) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
std::vector<Dimensions> inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = loss->Init(strategy); Status ret = loss->Init(strategy);
...@@ -185,7 +185,7 @@ TEST_F(TestSoftmaxLoss, CheckStrategy1) { ...@@ -185,7 +185,7 @@ TEST_F(TestSoftmaxLoss, CheckStrategy1) {
TEST_F(TestSoftmaxLoss, CheckStrategy2) { TEST_F(TestSoftmaxLoss, CheckStrategy2) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
std::vector<Dimensions> inputs = {{2, 4, 8}}; Strategys inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = loss->Init(strategy); Status ret = loss->Init(strategy);
......
...@@ -39,13 +39,13 @@ class TestSoftmaxInfo : public UT::Common { ...@@ -39,13 +39,13 @@ class TestSoftmaxInfo : public UT::Common {
}; };
void TestSoftmaxInfo::SetUp() { void TestSoftmaxInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 130; i++) { for (int32_t i = 0; i < 130; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(128); stage_map.push_back(128);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -68,18 +68,18 @@ void TestSoftmaxInfo::SetUp() { ...@@ -68,18 +68,18 @@ void TestSoftmaxInfo::SetUp() {
} }
TEST_F(TestSoftmaxInfo, InferDevMatrixShape1) { TEST_F(TestSoftmaxInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
softmax->Init(strategy); softmax->Init(strategy);
std::vector<int32_t> dev_matrix_shape = softmax->dev_matrix_shape(); Shape dev_matrix_shape = softmax->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 1, 16}; Shape expect = {2, 4, 1, 16};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestSoftmaxInfo, InferSliceShape1) { TEST_F(TestSoftmaxInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 1, 16}}; Strategys str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
softmax->Init(strategy); softmax->Init(strategy);
...@@ -100,7 +100,7 @@ TEST_F(TestSoftmaxInfo, InferSliceShape1) { ...@@ -100,7 +100,7 @@ TEST_F(TestSoftmaxInfo, InferSliceShape1) {
} }
TEST_F(TestSoftmaxInfo, GetTensorLayout1) { TEST_F(TestSoftmaxInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 1, 16}}; Strategys str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
softmax->Init(strategy); softmax->Init(strategy);
...@@ -121,7 +121,7 @@ TEST_F(TestSoftmaxInfo, GetTensorLayout1) { ...@@ -121,7 +121,7 @@ TEST_F(TestSoftmaxInfo, GetTensorLayout1) {
} }
TEST_F(TestSoftmaxInfo, GetForwardOp1) { TEST_F(TestSoftmaxInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
softmax->Init(strategy); softmax->Init(strategy);
...@@ -132,7 +132,7 @@ TEST_F(TestSoftmaxInfo, GetForwardOp1) { ...@@ -132,7 +132,7 @@ TEST_F(TestSoftmaxInfo, GetForwardOp1) {
} }
TEST_F(TestSoftmaxInfo, GetMirrorOPs1) { TEST_F(TestSoftmaxInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
softmax->Init(strategy); softmax->Init(strategy);
...@@ -145,7 +145,7 @@ TEST_F(TestSoftmaxInfo, GetMirrorOPs1) { ...@@ -145,7 +145,7 @@ TEST_F(TestSoftmaxInfo, GetMirrorOPs1) {
TEST_F(TestSoftmaxInfo, CheckStrategy1) { TEST_F(TestSoftmaxInfo, CheckStrategy1) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax->Init(strategy); Status ret = softmax->Init(strategy);
...@@ -154,7 +154,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy1) { ...@@ -154,7 +154,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy1) {
TEST_F(TestSoftmaxInfo, CheckStrategy2) { TEST_F(TestSoftmaxInfo, CheckStrategy2) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 4, 8}}; Strategys inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax->Init(strategy); Status ret = softmax->Init(strategy);
...@@ -163,7 +163,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy2) { ...@@ -163,7 +163,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy2) {
TEST_F(TestSoftmaxInfo, CheckStrategy3) { TEST_F(TestSoftmaxInfo, CheckStrategy3) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 4, 8, 16}}; Strategys inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax->Init(strategy); Status ret = softmax->Init(strategy);
...@@ -172,7 +172,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy3) { ...@@ -172,7 +172,7 @@ TEST_F(TestSoftmaxInfo, CheckStrategy3) {
TEST_F(TestSoftmaxInfo, InitFailed1) { TEST_F(TestSoftmaxInfo, InitFailed1) {
// softmax2's axis is wrong // softmax2's axis is wrong
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax2->Init(strategy); Status ret = softmax2->Init(strategy);
...@@ -181,7 +181,7 @@ TEST_F(TestSoftmaxInfo, InitFailed1) { ...@@ -181,7 +181,7 @@ TEST_F(TestSoftmaxInfo, InitFailed1) {
TEST_F(TestSoftmaxInfo, InitFailed2) { TEST_F(TestSoftmaxInfo, InitFailed2) {
// dev num is wrong // dev num is wrong
std::vector<Dimensions> inputs = {{2, 4, 1, 100}}; Strategys inputs = {{2, 4, 1, 100}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = softmax2->Init(strategy); Status ret = softmax2->Init(strategy);
......
...@@ -38,13 +38,13 @@ class TestTanhInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestTanhInfo : public UT::Common {
}; };
void TestTanhInfo::SetUp() { void TestTanhInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 130; i++) { for (int32_t i = 0; i < 130; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(128); stage_map.push_back(128);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -63,18 +63,18 @@ void TestTanhInfo::SetUp() { ...@@ -63,18 +63,18 @@ void TestTanhInfo::SetUp() {
} }
TEST_F(TestTanhInfo, InferDevMatrixShape1) { TEST_F(TestTanhInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tanh->Init(strategy); tanh->Init(strategy);
std::vector<int32_t> dev_matrix_shape = tanh->dev_matrix_shape(); Shape dev_matrix_shape = tanh->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 1, 16}; Shape expect = {2, 4, 1, 16};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestTanhInfo, InferSliceShape1) { TEST_F(TestTanhInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 1, 16}}; Strategys str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
tanh->Init(strategy); tanh->Init(strategy);
...@@ -95,7 +95,7 @@ TEST_F(TestTanhInfo, InferSliceShape1) { ...@@ -95,7 +95,7 @@ TEST_F(TestTanhInfo, InferSliceShape1) {
} }
TEST_F(TestTanhInfo, GetTensorLayout1) { TEST_F(TestTanhInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 1, 16}}; Strategys str = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
tanh->Init(strategy); tanh->Init(strategy);
...@@ -116,7 +116,7 @@ TEST_F(TestTanhInfo, GetTensorLayout1) { ...@@ -116,7 +116,7 @@ TEST_F(TestTanhInfo, GetTensorLayout1) {
} }
TEST_F(TestTanhInfo, GetForwardOp1) { TEST_F(TestTanhInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tanh->Init(strategy); tanh->Init(strategy);
...@@ -127,7 +127,7 @@ TEST_F(TestTanhInfo, GetForwardOp1) { ...@@ -127,7 +127,7 @@ TEST_F(TestTanhInfo, GetForwardOp1) {
} }
TEST_F(TestTanhInfo, GetMirrorOPs1) { TEST_F(TestTanhInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tanh->Init(strategy); tanh->Init(strategy);
...@@ -140,7 +140,7 @@ TEST_F(TestTanhInfo, GetMirrorOPs1) { ...@@ -140,7 +140,7 @@ TEST_F(TestTanhInfo, GetMirrorOPs1) {
TEST_F(TestTanhInfo, CheckStrategy1) { TEST_F(TestTanhInfo, CheckStrategy1) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tanh->Init(strategy); Status ret = tanh->Init(strategy);
...@@ -149,7 +149,7 @@ TEST_F(TestTanhInfo, CheckStrategy1) { ...@@ -149,7 +149,7 @@ TEST_F(TestTanhInfo, CheckStrategy1) {
TEST_F(TestTanhInfo, CheckStrategy2) { TEST_F(TestTanhInfo, CheckStrategy2) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 4, 8}}; Strategys inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tanh->Init(strategy); Status ret = tanh->Init(strategy);
...@@ -158,7 +158,7 @@ TEST_F(TestTanhInfo, CheckStrategy2) { ...@@ -158,7 +158,7 @@ TEST_F(TestTanhInfo, CheckStrategy2) {
TEST_F(TestTanhInfo, CheckStrategy3) { TEST_F(TestTanhInfo, CheckStrategy3) {
// Success: {{2,4,1,16}} // Success: {{2,4,1,16}}
std::vector<Dimensions> inputs = {{2, 4, 1, 16}}; Strategys inputs = {{2, 4, 1, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tanh->Init(strategy); Status ret = tanh->Init(strategy);
......
...@@ -38,13 +38,13 @@ class TestTensorAddInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestTensorAddInfo : public UT::Common {
}; };
void TestTensorAddInfo::SetUp() { void TestTensorAddInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 34; i++) { for (int32_t i = 0; i < 34; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(32); stage_map.push_back(32);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -66,18 +66,18 @@ void TestTensorAddInfo::SetUp() { ...@@ -66,18 +66,18 @@ void TestTensorAddInfo::SetUp() {
} }
TEST_F(TestTensorAddInfo, InferDevMatrixShape1) { TEST_F(TestTensorAddInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 4}, {2, 4, 4}}; Strategys inputs = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tensor_add->Init(strategy); tensor_add->Init(strategy);
std::vector<int32_t> dev_matrix_shape = tensor_add->dev_matrix_shape(); Shape dev_matrix_shape = tensor_add->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 4}; Shape expect = {2, 4, 4};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestTensorAddInfo, InferSliceShape1) { TEST_F(TestTensorAddInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 4}, {2, 4, 4}}; Strategys str = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
tensor_add->Init(strategy); tensor_add->Init(strategy);
...@@ -101,7 +101,7 @@ TEST_F(TestTensorAddInfo, InferSliceShape1) { ...@@ -101,7 +101,7 @@ TEST_F(TestTensorAddInfo, InferSliceShape1) {
} }
TEST_F(TestTensorAddInfo, GetTensorLayout1) { TEST_F(TestTensorAddInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 4}, {2, 4, 4}}; Strategys str = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
tensor_add->Init(strategy); tensor_add->Init(strategy);
...@@ -125,7 +125,7 @@ TEST_F(TestTensorAddInfo, GetTensorLayout1) { ...@@ -125,7 +125,7 @@ TEST_F(TestTensorAddInfo, GetTensorLayout1) {
} }
TEST_F(TestTensorAddInfo, GetForwardOp1) { TEST_F(TestTensorAddInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 4}, {2, 4, 4}}; Strategys inputs = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tensor_add->Init(strategy); tensor_add->Init(strategy);
...@@ -136,7 +136,7 @@ TEST_F(TestTensorAddInfo, GetForwardOp1) { ...@@ -136,7 +136,7 @@ TEST_F(TestTensorAddInfo, GetForwardOp1) {
} }
TEST_F(TestTensorAddInfo, GetMirrorOPs1) { TEST_F(TestTensorAddInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 4}, {2, 4, 4}}; Strategys inputs = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tensor_add->Init(strategy); tensor_add->Init(strategy);
...@@ -148,7 +148,7 @@ TEST_F(TestTensorAddInfo, GetMirrorOPs1) { ...@@ -148,7 +148,7 @@ TEST_F(TestTensorAddInfo, GetMirrorOPs1) {
} }
TEST_F(TestTensorAddInfo, CheckStrategy1) { TEST_F(TestTensorAddInfo, CheckStrategy1) {
std::vector<Dimensions> inputs = {{2, 4, 4}, {2, 6, 4}}; Strategys inputs = {{2, 4, 4}, {2, 6, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tensor_add->Init(strategy); Status ret = tensor_add->Init(strategy);
...@@ -156,7 +156,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy1) { ...@@ -156,7 +156,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy1) {
} }
TEST_F(TestTensorAddInfo, CheckStrategy2) { TEST_F(TestTensorAddInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}}; Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tensor_add->Init(strategy); Status ret = tensor_add->Init(strategy);
...@@ -164,7 +164,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy2) { ...@@ -164,7 +164,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy2) {
} }
TEST_F(TestTensorAddInfo, CheckStrategy3) { TEST_F(TestTensorAddInfo, CheckStrategy3) {
std::vector<Dimensions> inputs = {{2, 4, 6}}; Strategys inputs = {{2, 4, 6}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tensor_add->Init(strategy); Status ret = tensor_add->Init(strategy);
...@@ -172,7 +172,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy3) { ...@@ -172,7 +172,7 @@ TEST_F(TestTensorAddInfo, CheckStrategy3) {
} }
TEST_F(TestTensorAddInfo, CheckStrategy4) { TEST_F(TestTensorAddInfo, CheckStrategy4) {
std::vector<Dimensions> inputs = {{2, 4, 4}, {2, 4, 4}}; Strategys inputs = {{2, 4, 4}, {2, 4, 4}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = tensor_add->Init(strategy); Status ret = tensor_add->Init(strategy);
...@@ -224,7 +224,7 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) { ...@@ -224,7 +224,7 @@ TEST_F(TestTensorAddInfo, GenerateStrategies1) {
} }
TEST_F(TestTensorAddInfo, mirror_ops) { TEST_F(TestTensorAddInfo, mirror_ops) {
std::vector<Dimensions> inputs = {{1, 8}, {4, 1}}; Strategys inputs = {{1, 8}, {4, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
tensor_add1->Init(strategy); tensor_add1->Init(strategy);
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "frontend/parallel/device_manager.h" #include "frontend/parallel/device_manager.h"
#include "frontend/parallel/ops_info/operator_info.h" #include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/ops_info/tmp_identity_info.h" #include "frontend/parallel/ops_info/tmp_identity_info.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
...@@ -26,7 +27,6 @@ namespace parallel { ...@@ -26,7 +27,6 @@ namespace parallel {
class TmpIdentityInfo; class TmpIdentityInfo;
using TmpIdentityInfoPtr = std::shared_ptr<TmpIdentityInfo>; using TmpIdentityInfoPtr = std::shared_ptr<TmpIdentityInfo>;
TmpIdentityInfoPtr identity_ptr; TmpIdentityInfoPtr identity_ptr;
using TensorMap = std::vector<int32_t>;
class TestTmpIdentityInfo : public UT::Common { class TestTmpIdentityInfo : public UT::Common {
public: public:
...@@ -38,13 +38,13 @@ class TestTmpIdentityInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestTmpIdentityInfo : public UT::Common {
}; };
void TestTmpIdentityInfo::SetUp() { void TestTmpIdentityInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 1050; i++) { for (int32_t i = 0; i < 1050; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(1024); stage_map.push_back(1024);
stage_map.push_back(26); stage_map.push_back(26);
...@@ -65,18 +65,18 @@ void TestTmpIdentityInfo::SetUp() { ...@@ -65,18 +65,18 @@ void TestTmpIdentityInfo::SetUp() {
} }
TEST_F(TestTmpIdentityInfo, InferDevMatrixShape1) { TEST_F(TestTmpIdentityInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}}; Strategys inputs = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
identity_ptr->Init(strategy); identity_ptr->Init(strategy);
std::vector<int32_t> dev_matrix_shape = identity_ptr->dev_matrix_shape(); Shape dev_matrix_shape = identity_ptr->dev_matrix_shape();
std::vector<int32_t> expect = {2, 4, 8, 16}; Shape expect = {2, 4, 8, 16};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestTmpIdentityInfo, InferSliceShape1) { TEST_F(TestTmpIdentityInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 8, 16}}; Strategys str = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
identity_ptr->Init(strategy); identity_ptr->Init(strategy);
...@@ -97,7 +97,7 @@ TEST_F(TestTmpIdentityInfo, InferSliceShape1) { ...@@ -97,7 +97,7 @@ TEST_F(TestTmpIdentityInfo, InferSliceShape1) {
} }
TEST_F(TestTmpIdentityInfo, GetTensorLayout1) { TEST_F(TestTmpIdentityInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 8, 16}}; Strategys str = {{2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
identity_ptr->Init(strategy); identity_ptr->Init(strategy);
...@@ -119,7 +119,7 @@ TEST_F(TestTmpIdentityInfo, GetTensorLayout1) { ...@@ -119,7 +119,7 @@ TEST_F(TestTmpIdentityInfo, GetTensorLayout1) {
TEST_F(TestTmpIdentityInfo, CheckStrategy1) { TEST_F(TestTmpIdentityInfo, CheckStrategy1) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
std::vector<Dimensions> inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}}; Strategys inputs = {{2, 2, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = identity_ptr->Init(strategy); Status ret = identity_ptr->Init(strategy);
...@@ -128,7 +128,7 @@ TEST_F(TestTmpIdentityInfo, CheckStrategy1) { ...@@ -128,7 +128,7 @@ TEST_F(TestTmpIdentityInfo, CheckStrategy1) {
TEST_F(TestTmpIdentityInfo, CheckStrategy2) { TEST_F(TestTmpIdentityInfo, CheckStrategy2) {
// Success: {{2,4,8,16}} // Success: {{2,4,8,16}}
std::vector<Dimensions> inputs = {{2, 4, 8}}; Strategys inputs = {{2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = identity_ptr->Init(strategy); Status ret = identity_ptr->Init(strategy);
......
...@@ -38,13 +38,13 @@ class TestTransposeInfo : public UT::Common { ...@@ -38,13 +38,13 @@ class TestTransposeInfo : public UT::Common {
}; };
void TestTransposeInfo::SetUp() { void TestTransposeInfo::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 34; i++) { for (int32_t i = 0; i < 34; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(32); stage_map.push_back(32);
stage_map.push_back(2); stage_map.push_back(2);
...@@ -68,29 +68,29 @@ void TestTransposeInfo::SetUp() { ...@@ -68,29 +68,29 @@ void TestTransposeInfo::SetUp() {
} }
TEST_F(TestTransposeInfo, InferDevMatrixShape1) { TEST_F(TestTransposeInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{4, 8}}; Strategys inputs = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
transpose->Init(strategy); transpose->Init(strategy);
std::vector<int32_t> dev_matrix_shape = transpose->dev_matrix_shape(); Shape dev_matrix_shape = transpose->dev_matrix_shape();
std::vector<int32_t> expect = {4, 8}; Shape expect = {4, 8};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestTransposeInfo, InferDevMatrixShape2) { TEST_F(TestTransposeInfo, InferDevMatrixShape2) {
std::vector<Dimensions> inputs = {{4, 1}}; Strategys inputs = {{4, 1}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
transpose->Init(strategy); transpose->Init(strategy);
std::vector<int32_t> dev_matrix_shape = transpose->dev_matrix_shape(); Shape dev_matrix_shape = transpose->dev_matrix_shape();
std::vector<int32_t> expect = {8, 4, 1}; Shape expect = {8, 4, 1};
ASSERT_EQ(dev_matrix_shape, expect); ASSERT_EQ(dev_matrix_shape, expect);
} }
TEST_F(TestTransposeInfo, InferSliceShape1) { TEST_F(TestTransposeInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{4, 8}}; Strategys str = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
transpose->Init(strategy); transpose->Init(strategy);
...@@ -111,7 +111,7 @@ TEST_F(TestTransposeInfo, InferSliceShape1) { ...@@ -111,7 +111,7 @@ TEST_F(TestTransposeInfo, InferSliceShape1) {
} }
TEST_F(TestTransposeInfo, GetTensorLayout1) { TEST_F(TestTransposeInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{4, 8}}; Strategys str = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
transpose->Init(strategy); transpose->Init(strategy);
...@@ -132,7 +132,7 @@ TEST_F(TestTransposeInfo, GetTensorLayout1) { ...@@ -132,7 +132,7 @@ TEST_F(TestTransposeInfo, GetTensorLayout1) {
} }
TEST_F(TestTransposeInfo, GetForwardOp1) { TEST_F(TestTransposeInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{4, 8}}; Strategys inputs = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
transpose->Init(strategy); transpose->Init(strategy);
...@@ -143,7 +143,7 @@ TEST_F(TestTransposeInfo, GetForwardOp1) { ...@@ -143,7 +143,7 @@ TEST_F(TestTransposeInfo, GetForwardOp1) {
} }
TEST_F(TestTransposeInfo, GetMirrorOPs1) { TEST_F(TestTransposeInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{4, 8}}; Strategys inputs = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
transpose->Init(strategy); transpose->Init(strategy);
...@@ -155,7 +155,7 @@ TEST_F(TestTransposeInfo, GetMirrorOPs1) { ...@@ -155,7 +155,7 @@ TEST_F(TestTransposeInfo, GetMirrorOPs1) {
} }
TEST_F(TestTransposeInfo, CheckStrategy1) { TEST_F(TestTransposeInfo, CheckStrategy1) {
std::vector<Dimensions> inputs = {{1, 4, 8}}; Strategys inputs = {{1, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = transpose->Init(strategy); Status ret = transpose->Init(strategy);
...@@ -163,7 +163,7 @@ TEST_F(TestTransposeInfo, CheckStrategy1) { ...@@ -163,7 +163,7 @@ TEST_F(TestTransposeInfo, CheckStrategy1) {
} }
TEST_F(TestTransposeInfo, CheckStrategy2) { TEST_F(TestTransposeInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}}; Strategys inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = transpose->Init(strategy); Status ret = transpose->Init(strategy);
...@@ -171,7 +171,7 @@ TEST_F(TestTransposeInfo, CheckStrategy2) { ...@@ -171,7 +171,7 @@ TEST_F(TestTransposeInfo, CheckStrategy2) {
} }
TEST_F(TestTransposeInfo, CheckStrategy3) { TEST_F(TestTransposeInfo, CheckStrategy3) {
std::vector<Dimensions> inputs = {{4, 8}}; Strategys inputs = {{4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = transpose->Init(strategy); Status ret = transpose->Init(strategy);
......
...@@ -32,13 +32,13 @@ class TestStepAutoParallel : public UT::Common { ...@@ -32,13 +32,13 @@ class TestStepAutoParallel : public UT::Common {
}; };
void TestStepAutoParallel::SetUp() { void TestStepAutoParallel::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 20; i++) { for (int32_t i = 0; i < 20; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(16); stage_map.push_back(16);
stage_map.push_back(4); stage_map.push_back(4);
......
...@@ -33,9 +33,9 @@ class TestStrategy : public UT::Common { ...@@ -33,9 +33,9 @@ class TestStrategy : public UT::Common {
TEST_F(TestStrategy, GetInputNumber) { TEST_F(TestStrategy, GetInputNumber) {
int32_t number = 2; int32_t number = 2;
int32_t stage = 1; int32_t stage = 1;
std::vector<int32_t> dimension1 = {2, 4}; Dimensions dimension1 = {2, 4};
std::vector<int32_t> dimension2 = {2, 2}; Dimensions dimension2 = {2, 2};
std::vector<std::vector<int32_t>> inputs = {dimension1, dimension2}; Strategys inputs = {dimension1, dimension2};
Strategy strategy(stage, inputs); Strategy strategy(stage, inputs);
int32_t number_test = strategy.GetInputNumber(); int32_t number_test = strategy.GetInputNumber();
...@@ -44,9 +44,9 @@ TEST_F(TestStrategy, GetInputNumber) { ...@@ -44,9 +44,9 @@ TEST_F(TestStrategy, GetInputNumber) {
TEST_F(TestStrategy, GetInputStage) { TEST_F(TestStrategy, GetInputStage) {
int32_t stage = 1; int32_t stage = 1;
std::vector<int32_t> dimension1 = {2, 4}; Dimensions dimension1 = {2, 4};
std::vector<int32_t> dimension2 = {2, 2}; Dimensions dimension2 = {2, 2};
std::vector<std::vector<int32_t>> inputs = {dimension1, dimension2}; Strategys inputs = {dimension1, dimension2};
Strategy strategy(stage, inputs); Strategy strategy(stage, inputs);
int32_t stage_test = strategy.GetInputStage(); int32_t stage_test = strategy.GetInputStage();
...@@ -55,23 +55,23 @@ TEST_F(TestStrategy, GetInputStage) { ...@@ -55,23 +55,23 @@ TEST_F(TestStrategy, GetInputStage) {
TEST_F(TestStrategy, GetInputDim) { TEST_F(TestStrategy, GetInputDim) {
int32_t stage = 1; int32_t stage = 1;
std::vector<int32_t> dimension1 = {2, 4}; Dimensions dimension1 = {2, 4};
std::vector<int32_t> dimension2 = {2, 2}; Dimensions dimension2 = {2, 2};
std::vector<std::vector<int32_t>> inputs = {dimension1, dimension2}; Strategys inputs = {dimension1, dimension2};
Strategy strategy(stage, inputs); Strategy strategy(stage, inputs);
std::vector<std::vector<int32_t>> inputs_test = strategy.GetInputDim(); Strategys inputs_test = strategy.GetInputDim();
ASSERT_EQ(inputs, inputs_test); ASSERT_EQ(inputs, inputs_test);
} }
TEST_F(TestStrategy, IsEqual) { TEST_F(TestStrategy, IsEqual) {
int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0; int32_t stage1 = 0, stage2 = 0, stage3 = 1, stage4 = 0;
std::vector<int32_t> dimension1 = {8, 1}; Dimensions dimension1 = {8, 1};
std::vector<int32_t> dimension2 = {1, 8}; Dimensions dimension2 = {1, 8};
std::vector<std::vector<int32_t>> inputs1 = {dimension1}; Strategys inputs1 = {dimension1};
std::vector<std::vector<int32_t>> inputs2 = {dimension1}; Strategys inputs2 = {dimension1};
std::vector<std::vector<int32_t>> inputs3 = {dimension2}; Strategys inputs3 = {dimension2};
std::vector<std::vector<int32_t>> inputs4 = {dimension1, dimension2}; Strategys inputs4 = {dimension1, dimension2};
StrategyPtr stra1 = std::make_shared<Strategy>(stage1, inputs1); StrategyPtr stra1 = std::make_shared<Strategy>(stage1, inputs1);
StrategyPtr stra2 = std::make_shared<Strategy>(stage2, inputs2); StrategyPtr stra2 = std::make_shared<Strategy>(stage2, inputs2);
......
...@@ -39,12 +39,12 @@ class TestConstructOperator : public UT::Common { ...@@ -39,12 +39,12 @@ class TestConstructOperator : public UT::Common {
}; };
void TestConstructOperator::SetUp() { void TestConstructOperator::SetUp() {
std::vector<int32_t> dev_list; RankList dev_list;
for (int32_t i = 0; i < 1050; i++) { for (int32_t i = 0; i < 1050; i++) {
dev_list.push_back(i); dev_list.push_back(i);
} }
std::vector<int32_t> stage_map; RankList stage_map;
stage_map.push_back(1024); stage_map.push_back(1024);
stage_map.push_back(26); stage_map.push_back(26);
...@@ -62,7 +62,7 @@ void TestConstructOperator::SetUp() { ...@@ -62,7 +62,7 @@ void TestConstructOperator::SetUp() {
MatMulInfoPtr matmul = std::make_shared<MatMulInfo>("matmul_info", inputs_shape_1, outputs_shape_1, attr_1); MatMulInfoPtr matmul = std::make_shared<MatMulInfo>("matmul_info", inputs_shape_1, outputs_shape_1, attr_1);
std::vector<Dimensions> str = {{2, 4, 8, 16}, {2, 4, 16, 1}}; Strategys str = {{2, 4, 8, 16}, {2, 4, 16, 1}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
matmul->Init(strategy); matmul->Init(strategy);
Shape tensor_shape = {512, 1024}; Shape tensor_shape = {512, 1024};
...@@ -79,8 +79,8 @@ TEST_F(TestConstructOperator, TestReshapeOP) { ...@@ -79,8 +79,8 @@ TEST_F(TestConstructOperator, TestReshapeOP) {
TEST_F(TestConstructOperator, TestStridedSliceOP) { TEST_F(TestConstructOperator, TestStridedSliceOP) {
Args args = {1, 2, 3}; Args args = {1, 2, 3};
int32_t split_count = args[0]; int64_t split_count = args[0];
int32_t split_dim = args[1]; int64_t split_dim = args[1];
Shape device_arrangement = {8, 4}; Shape device_arrangement = {8, 4};
Arrangement dev_mat; Arrangement dev_mat;
dev_mat.Init(device_arrangement); dev_mat.Init(device_arrangement);
...@@ -98,12 +98,18 @@ TEST_F(TestConstructOperator, TestStridedSliceOP) { ...@@ -98,12 +98,18 @@ TEST_F(TestConstructOperator, TestStridedSliceOP) {
OperatorParams params = op.second.second; OperatorParams params = op.second.second;
ValuePtr begin_ptr = params[0].first.second; ValuePtr begin_ptr = params[0].first.second;
ValuePtr end_ptr = params[1].first.second; ValuePtr end_ptr = params[1].first.second;
Shape begin = GetValue<const std::vector<int>>(begin_ptr); std::vector<int32_t> begin_int = GetValue<const std::vector<int32_t>>(begin_ptr);
Shape end = GetValue<const std::vector<int>>(end_ptr); std::vector<int32_t> end_int = GetValue<const std::vector<int32_t>>(end_ptr);
Shape begin;
Shape end;
(void)std::transform(begin_int.begin(), begin_int.end(), std::back_inserter(begin),
[](const int32_t &value) { return static_cast<int64_t>(value); });
(void)std::transform(end_int.begin(), end_int.end(), std::back_inserter(end),
[](const int32_t &value) { return static_cast<int64_t>(value); });
for (size_t i = 0; i < begin.size(); i++) { for (size_t i = 0; i < begin.size(); i++) {
int32_t diff = end[i] - begin[i]; int64_t diff = end[i] - begin[i];
int32_t num = shape[i]; int64_t num = shape[i];
if (SizeToInt(i) != split_dim) { if (SizeToLong(i) != split_dim) {
ASSERT_EQ(diff, shape[i]); ASSERT_EQ(diff, shape[i]);
} else { } else {
ASSERT_EQ(diff, num / split_count); ASSERT_EQ(diff, num / split_count);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册