Commit 2031710d authored by H hongxing

fix bug and optimize code

Parent 26d05be8
......@@ -39,17 +39,18 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
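// Strategy generation runs in three passes: a forward pass that derives strategies
// for eliminated operators from their producers, a backward pass that derives them
// from their consumers, and a final pass for any operators still left without one.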
std::shared_ptr<std::vector<size_t>> no_stra_op_list(new std::vector<size_t>);
GenerateEliminatedOperatorStrategyForward(graph, ops, eli_list, input_tensor_names, index_list, no_stra_op_list);
GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);
}
std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops) {
std::vector<std::vector<int32_t>> strategies;
auto attrs = ops[iter_ops]->attrs();
bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
std::vector<int32_t> s;
auto attrs = ops[iter_ops]->attrs();
bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
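// For a transposed input the two matrix dimensions are swapped, so the cut of the
// leading dimension is taken from the width stride (str_w) rather than the height stride.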
if (transpose_a && (iter_op_inputs == 0)) {
s.push_back(
static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
......@@ -71,43 +72,20 @@ std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &gr
return strategies;
}
std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops) {
std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(ops, iter_ops);
strategies[1][0] = strategies[0][0];
std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops) {
std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
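// The second input of PReLU is the per-channel weight; keep it unsplit by forcing its cut to 1.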
strategies[1][0] = 1;
return strategies;
}
std::vector<std::vector<int32_t>> PrepareScalarInputOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s) {
std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s) {
std::vector<std::vector<int32_t>> strategies;
auto dev_num = g_device_manager->DeviceNum();
size_t cut_num = 1;
for (size_t iter_s = 0; iter_s < s.size(); iter_s++) {
cut_num *= s[iter_s];
}
if (cut_num != dev_num) {
std::vector<int32_t> s_max = s;
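// Greedily double the cut along every dimension that is still evenly divisible by 2
// until the product of all cuts matches the device count.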
for (size_t dim = 0; dim < (size_t)ops[iter_ops]->inputs_tensor_info()[0].shape().size(); dim++) {
size_t shape = ops[iter_ops]->inputs_tensor_info()[0].shape()[dim] / s[dim];
while (cut_num < dev_num && shape % 2 == 0) {
shape = shape / 2;
s_max[dim] = s_max[dim] * 2;
cut_num = cut_num * 2;
}
if (cut_num == dev_num) {
break;
}
}
s = s_max;
}
strategies.push_back(s);
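// The 1-D bias input must be cut the same way as the channel dimension (dim 1) of the data input.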
std::vector<int32_t> s_biasadd;
s_biasadd.push_back(s[1]);
strategies.push_back(s_biasadd);
return strategies;
}
......@@ -131,16 +109,13 @@ std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Gr
}
StrategyPtr origin_strategy = ops[iter_ops]->strategy();
std::vector<std::vector<int32_t>> strategies;
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
}
// size_t output_size = ops[iter_ops]->outputs_tensor_info()[0].shape().size();
size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
std::vector<int32_t> s;
if (output_size == 4) {
s.push_back(
......@@ -164,14 +139,14 @@ std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Gr
} else {
MS_LOG(ERROR) << "Tensor's output size is unexcepted.";
}
strategies.push_back(s);
}
return strategies;
}
std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops) {
std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops) {
if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
}
......@@ -180,8 +155,9 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std
}
StrategyPtr origin_strategy = ops[iter_ops]->strategy();
std::vector<std::vector<int32_t>> strategies;
size_t max_device_num = g_device_manager->DeviceNum();
size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
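// Data-parallel strategy: only the batch dimension (dim 0) is split, by at most
// min(device count, batch size); every other dimension stays whole.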
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
......@@ -192,8 +168,6 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std
for (size_t dim = 0; dim < input_size; dim++) {
if (input_size == 1 || input_size == 2 || input_size == 4) {
if (dim == 0) {
size_t max_device_num = g_device_manager->DeviceNum();
size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0];
s.push_back(std::min(max_device_num, target_tensor_batch));
} else {
s.push_back(1);
......@@ -202,9 +176,21 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std
MS_LOG(ERROR) << "Tensor's shape is unknown.";
}
}
strategies.push_back(s);
}
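// Write the chosen batch cut back into the graph as reciprocal strides: tensor_str
// stores 1.0 / cut (1.0 means unsplit) on the dimension acting as the batch axis
// for 1-D, 2-D, or 4-D outputs; e.g. a batch split across 8 devices stores 0.125.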
graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0;
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) {
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) {
graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch);
}
return strategies;
}
......@@ -217,20 +203,18 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
if (iter_ops >= ops.size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
}
MS_EXCEPTION_IF_NULL(ops[iter_ops]);
auto type = ops[iter_ops]->type();
if (type == VIRTUAL_DATA_SET) {
return PrepareVirtualDataset(ops, iter_ops);
}
auto idx = DictOpType.find(type);
if (idx == DictOpType.end()) {
return MakeDataParallelStrategy(ops, iter_ops);
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
}
if (type == MATMUL) {
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
} else if (type == RESHAPE) {
return MakeDataParallelStrategy(ops, iter_ops);
} else if (type == PRELU) {
return PreparePReLU(graph, ops, iter_graph, iter_ops);
} else {
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
}
......@@ -242,28 +226,25 @@ void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
std::vector<std::vector<int32_t>> strategies;
size_t iter_graph = index_list->at(iter_ops);
if (iter_graph == SIZE_MAX) {
StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
continue;
if (iter_graph != SIZE_MAX) {
strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
}
strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
}
}
int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops) {
int incoming_op_index = -1;
for (size_t i = 1; i < (size_t)input_tensor_names[iter_ops].size(); i++) {
for (size_t j = 0; j < (size_t)input_tensor_names.size(); j++) {
size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops) {
size_t incoming_op_index = SIZE_MAX;
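// input_tensor_names[k][0] is operator k's own (output) name and the remaining entries
// are its inputs, so this looks for the operator producing one of iter_ops' input tensors.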
for (size_t i = 1; i < input_tensor_names[iter_ops].size(); i++) {
for (size_t j = 0; j < input_tensor_names.size(); j++) {
if (input_tensor_names[iter_ops][i] == input_tensor_names[j][0]) {
incoming_op_index = j;
break;
}
}
if (incoming_op_index != -1) {
if (incoming_op_index != SIZE_MAX) {
break;
}
}
......@@ -298,12 +279,16 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr
}
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index) {
const size_t incoming_op_index) {
std::vector<int32_t> s;
if (ops[incoming_op_index]->type() == RESHAPE) {
return s;
}
auto strategy = ops[incoming_op_index]->selected_strategy();
if (strategy->GetInputNumber() == 0) {
return s;
}
for (size_t i = 0; i < (size_t)ops[incoming_op_index]->inputs_tensor_info().size(); i++) {
if (ops[incoming_op_index]->inputs_tensor_info()[i].shape().size() == 0) {
continue;
......@@ -327,6 +312,7 @@ std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>
} else {
MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list." << std::endl;
}
for (auto &element : elements) {
if (!element->isa<Int32Imm>()) {
MS_LOG(EXCEPTION) << "Failure: Dimension indexes is not Int32." << std::endl;
......@@ -338,12 +324,13 @@ std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>
}
std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, std::vector<int32_t> s) {
const size_t incoming_op_index, std::vector<int32_t> s) {
std::vector<int32_t> s_Squeeze;
std::vector<int32_t> stra_dim_list;
for (size_t i = 0; i < s.size(); i++) {
stra_dim_list.push_back(i);
}
auto axis_list = GetAxisList(ops, incoming_op_index);
for (auto axis : axis_list) {
auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis);
......@@ -355,6 +342,7 @@ std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shar
}
stra_dim_list.erase(it);
}
for (size_t i = 0; i < (size_t)stra_dim_list.size(); i++) {
s_Squeeze.push_back(s[stra_dim_list[i]]);
}
......@@ -391,12 +379,13 @@ std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>>
}
std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, std::vector<int32_t> s) {
const size_t incoming_op_index, std::vector<int32_t> s) {
std::vector<int32_t> s_Reduce;
std::vector<int32_t> axis_list;
for (size_t i = 0; i < s.size(); i++) {
axis_list.push_back(i);
}
auto dim_list = GetDimList(ops, incoming_op_index);
for (auto axis : dim_list) {
auto it = find(axis_list.begin(), axis_list.end(), axis);
......@@ -405,6 +394,7 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
}
axis_list.erase(it);
}
for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
s_Reduce.push_back(s[axis_list[i]]);
}
......@@ -412,10 +402,10 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
}
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, const size_t iter_ops,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
const size_t iter_ops, const size_t incoming_op_index) {
std::vector<int32_t> s;
s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index);
if (s.size() != 0) {
if (ops[incoming_op_index]->type() == SQUEEZE) {
s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s);
......@@ -429,27 +419,27 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh
}
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s) {
const size_t iter_ops,
std::vector<int32_t> basic_stra) {
std::vector<int32_t> s_empty = {};
std::vector<std::vector<int32_t>> stra;
MS_EXCEPTION_IF_NULL(ops[iter_ops]);
if (s.size() == 0) {
if (basic_stra.size() == 0) {
for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
iter_op_inputs++) {
stra.push_back(s);
stra.push_back(basic_stra);
}
return stra;
}
MS_EXCEPTION_IF_NULL(ops[iter_ops]);
if (ops[iter_ops]->type() == BIAS_ADD || ops[iter_ops]->type() == PRELU) {
return PrepareScalarInputOperator(ops, iter_ops, s);
if (ops[iter_ops]->type() == BIAS_ADD) {
return PrepareBiasAdd(basic_stra);
}
if (ops[iter_ops]->type() == ONEHOT) {
return PrepareOneHot(s);
return PrepareOneHot(basic_stra);
}
auto dev_num = g_device_manager->DeviceNum();
for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
iter_op_inputs++) {
if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) {
......@@ -457,41 +447,19 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
continue;
}
size_t cut_num = 1;
for (size_t iter_s = 0; iter_s < s.size(); iter_s++) {
cut_num *= s[iter_s];
}
if (cut_num == dev_num) {
std::vector<int32_t> s_1 = s;
bool modified = false;
for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) {
if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) {
s_1[j] = 1;
modified = true;
}
std::vector<int32_t> tmp_stra = basic_stra;
bool modified = false;
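// A dimension of size 1 cannot be sliced, so its cut is reset to 1 in a copy of the
// inherited strategy.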
for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) {
if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) {
tmp_stra[j] = 1;
modified = true;
}
if (modified) {
stra.push_back(s_1);
} else {
stra.push_back(s);
}
continue;
}
std::vector<int32_t> s_max = s;
for (size_t dim = 0; dim < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); dim++) {
size_t shape = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[dim] / s[dim];
while (cut_num < dev_num && shape % 2 == 0) {
shape = shape / 2;
s_max[dim] = s_max[dim] * 2;
cut_num = cut_num * 2;
}
if (cut_num == dev_num) {
break;
}
if (modified) {
stra.push_back(tmp_stra);
} else {
stra.push_back(basic_stra);
}
stra.push_back(s_max);
}
return stra;
}
......@@ -502,17 +470,17 @@ void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> grap
const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
for (int eli_index = eli_list->size() - 1; eli_index >= 0; eli_index--) {
size_t iter_ops = eli_list->at(eli_index)[0];
for (size_t eli_index = eli_list->size(); eli_index > 0; eli_index--) {
size_t iter_ops = eli_list->at(eli_index - 1)[0];
std::vector<std::vector<int32_t>> stra;
std::vector<int32_t> s;
int incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
if (incoming_op_index != -1) {
size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
if (incoming_op_index != SIZE_MAX && ops[iter_ops]->type() != ONEHOT) {
auto iter_graph = index_list->at(incoming_op_index);
if (iter_graph != SIZE_MAX) {
s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph);
} else {
s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index, iter_ops, no_stra_op_list);
s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index);
}
}
......@@ -534,7 +502,7 @@ std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shar
size_t s_index = 0;
size_t axis_list_index = 0;
for (size_t i = 0; i < (size_t)(s.size() + axis_list.size()); i++) {
if ((i) == (size_t)axis_list[axis_list_index]) {
if (i == (size_t)axis_list[axis_list_index]) {
s_Squeeze.push_back(1);
axis_list_index++;
} else {
......@@ -542,46 +510,49 @@ std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shar
s_index++;
}
}
return s_Squeeze;
}
std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s) {
std::vector<int32_t> dim_list = GetDimList(ops, iter_ops);
if (dim_list.size() == 0) {
return s;
size_t cut = 1;
for (size_t i = 0; i < s_Squeeze.size(); i++) {
cut *= s_Squeeze[i];
}
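// If the surviving dimensions no longer use every device, discard the squeezed
// strategy so a later pass can regenerate a complete one.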
std::vector<int32_t> s_Reduce;
size_t s_index = 0;
size_t dim_list_index = 0;
for (size_t i = 0; i < (size_t)(s.size() + dim_list.size()); i++) {
if (i == (size_t)dim_list[dim_list_index]) {
s_Reduce.push_back(1);
dim_list_index++;
} else {
s_Reduce.push_back(s[s_index]);
s_index++;
}
if (cut != g_device_manager->DeviceNum()) {
s_Squeeze.clear();
}
return s_Reduce;
return s_Squeeze;
}
std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops) {
std::vector<int32_t> s;
if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE) {
return s;
}
bool found = false;
for (size_t i = 0; i < (size_t)input_tensor_names.size(); i++) {
for (size_t j = 1; j < (size_t)input_tensor_names[i].size(); j++) {
if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0]) {
for (size_t k = 0; k < ops[i]->selected_strategy()->GetInputDim()[j - 1].size(); ++k) {
s.push_back(ops[i]->selected_strategy()->GetInputDim()[j - 1][k]);
}
size_t outgoing_op_index = SIZE_MAX;
size_t iter_op_inputs = SIZE_MAX;
for (size_t i = 0; i < input_tensor_names.size(); i++) {
for (size_t j = 1; j < input_tensor_names[i].size(); j++) {
if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] &&
ops[i]->selected_strategy()->GetInputNumber() != 0) {
outgoing_op_index = i;
iter_op_inputs = j - 1;
found = true;
break;
}
}
if (found) break;
if (found) {
break;
}
}
if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) {
for (size_t k = 0; k < ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs].size(); ++k) {
s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]);
}
}
return s;
}
......@@ -589,23 +560,70 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
MS_EXCEPTION_IF_NULL(no_stra_op_list);
for (int iter_list = no_stra_op_list->size() - 1; iter_list >= 0; iter_list--) {
auto iter_ops = no_stra_op_list->at(iter_list);
if (no_stra_op_list->size() == 0) {
return;
}
std::vector<size_t> no_stra_op_list_bis;
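// Operators that still cannot inherit a strategy from a consumer are collected here
// and retried later by GenerateRemainingOperatorStrategy.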
for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
auto iter_ops = no_stra_op_list->at(iter_list - 1);
std::vector<std::vector<int32_t>> stra;
std::vector<int32_t> s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops);
if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) {
s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
}
if (s.size() != 0) {
stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
} else {
no_stra_op_list_bis.push_back(iter_ops);
}
StrategyPtr sp = std::make_shared<Strategy>(0, stra);
ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
}
no_stra_op_list->clear();
for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) {
no_stra_op_list->push_back(no_stra_op_list_bis[i]);
}
}
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list) {
if (no_stra_op_list->size() == 0) {
return;
}
for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
auto iter_ops = no_stra_op_list->at(iter_list - 1);
std::vector<std::vector<int32_t>> stra;
std::vector<int32_t> s;
size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
if (incoming_op_index != SIZE_MAX) {
auto iter_graph = index_list->at(incoming_op_index);
if (iter_graph != SIZE_MAX) {
s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph);
} else {
s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index);
}
}
if (s.size() == 0) {
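// Nothing could be inherited from the incoming operator: fall back to no cutting
// (all 1s), sized to the operator's widest input.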
for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) {
size_t max_dim_num = 0;
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() > max_dim_num) {
max_dim_num = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size();
}
}
for (size_t i = 0; i < max_dim_num; i++) {
s.push_back(1);
}
}
if (ops[iter_ops]->type() == SQUEEZE) {
s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
}
if (ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MAX ||
ops[iter_ops]->type() == REDUCE_MIN || ops[iter_ops]->type() == REDUCE_MEAN) {
s = ModifyStrategyIfReduceOutgoing(ops, iter_ops, s);
}
stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
StrategyPtr sp = std::make_shared<Strategy>(0, stra);
ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
......
......@@ -34,37 +34,38 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareScalarInputOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s);
std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops);
std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<size_t>> index_list);
int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, const size_t iter_ops);
size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops);
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_graph);
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index);
const size_t incoming_op_index);
std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, std::vector<int32_t> s);
const size_t incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, std::vector<int32_t> s);
const size_t incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, const size_t iter_ops,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
const size_t iter_ops, const size_t incoming_op_index);
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
......@@ -75,14 +76,17 @@ void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops);
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list);
} // namespace parallel
} // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
......@@ -47,7 +47,8 @@ enum OperatorType {
kRecDiv,
kRecSqueeze,
kRecCast,
kRecReduce
kRecReduce,
kRecPReLU
};
enum InfoType { kApplication, kConstant };
......
......@@ -199,7 +199,7 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph,
OperatorType::kRecOneHot, OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp,
OperatorType::kRecAdd, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub,
OperatorType::kRecMul, OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce,
OperatorType::kRecCast};
OperatorType::kRecCast, OperatorType::kRecReshape};
for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
auto type = graph->nodes[node_index].apply.op_type;
if (type_list.find(type) != type_list.end()) {
......
......@@ -55,7 +55,8 @@ const std::map<std::string, OperatorType> DictOpType{
{"HSigmoid", OperatorType::kRecReLU},
{GELU, OperatorType::kRecReLU},
{TANH, OperatorType::kRecReLU},
{PRELU, OperatorType::kRecReLU},
{PRELU, OperatorType::kRecPReLU},
{TENSOR_ADD, OperatorType::kRecElmWiseOp},
{SUB, OperatorType::kRecElmWiseOp},
......
......@@ -83,7 +83,7 @@ double GetWeights(const Graph::NodeType &node) {
auto cost_ptr = std::make_shared<CostCommon>();
return cost_ptr->GetMinCostIn();
} else if (op.op_type == OperatorType::kRecUnkownType) {
} else if (op.op_type == OperatorType::kRecUnkownType || op.op_type == OperatorType::kRecPReLU) {
// For unknown type
return 0.0;
} else {
......@@ -177,7 +177,7 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
auto cost_ptr = std::make_shared<CostCommon>();
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
} else if (node.apply.op_type == OperatorType::kRecUnkownType) {
} else if (node.apply.op_type == OperatorType::kRecUnkownType || node.apply.op_type == OperatorType::kRecPReLU) {
// For unknown type
StrategyRec default_strategy;
return default_strategy;
......
......@@ -464,6 +464,11 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsAutoParallelCareNode(cnode)) {
// Needed by rec_parser
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == TUPLE_GETITEM) {
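// Record the mapping from the TupleGetItem node to the tuple it indexes so the
// rec parser can trace tensors through tuple outputs.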
entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
}
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
......@@ -522,6 +527,11 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsAutoParallelCareNode(cnode)) {
// Needed by rec_parser
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == TUPLE_GETITEM) {
entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), cnode->input(1)->UniqueId()));
}
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
......@@ -1153,6 +1163,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
MS_LOG(ERROR) << "Constructing nodes for cost graph failed.";
return FAILED;
}
auto ops = entire_costgraph->GetOperators();
std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
......