Commit 4145098e authored by S Sheng

fix broadcast

Parent e4c8365d
@@ -614,7 +614,6 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops,
std::vector<int32_t> basic_stra) {
std::vector<int32_t> s_empty = {};
std::vector<std::vector<int32_t>> stra;
MS_EXCEPTION_IF_NULL(ops[iter_ops]);
@@ -636,9 +635,99 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
if (ops[iter_ops]->type() == L2_NORMALIZE) {
return PrepareL2Normalize(ops, iter_ops, basic_stra);
}
if (ops[iter_ops]->type() == TENSOR_ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL ||
ops[iter_ops]->type() == DIV) {
return CheckBroadcast(ops, iter_ops, basic_stra);
}
return CheckDivisible(ops, iter_ops, basic_stra);
}
// Handle operators that support broadcasting, such as TensorAdd/Sub/Mul/Div.
std::vector<std::vector<int32_t>> CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s) {
std::vector<std::vector<int32_t>> stra;
size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size();
// Broadcast the second tensor.
if (second_tensor_dim < first_tensor_dim) {
bool broadcast_first_tensor = false;
// Push back the first tensor's strategy unchanged.
stra.push_back(s);
// Push back the second tensor's strategy after applying broadcast.
stra.push_back(ApplyBroadcast(ops, iter_ops, s, second_tensor_dim, first_tensor_dim, broadcast_first_tensor));
} else if (second_tensor_dim > first_tensor_dim) { // Broadcast the first tensor.
bool broadcast_first_tensor = true;
// Push back the first tensor's strategy after applying broadcast.
stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor));
// Push back the second tensor's strategy unchanged.
stra.push_back(s);
} else { // The two tensors have the same rank, so no broadcasting is needed.
stra = CheckDivisible(ops, iter_ops, s);
}
return stra;
}
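To make the branch logic concrete, here is a minimal standalone sketch, using plain shape vectors and a hypothetical BroadcastStrategySketch helper instead of the OperatorInfo API: the lower-rank input receives a broadcast-adjusted strategy while the other input keeps the basic strategy, and equal ranks keep the basic strategy for both (the real code falls back to CheckDivisible in that case).

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for ApplyBroadcast: conservatively cut nothing (all 1s).
std::vector<int32_t> BroadcastStrategySketch(size_t target_dim) {
  return std::vector<int32_t>(target_dim, 1);
}

std::vector<std::vector<int32_t>> CheckBroadcastSketch(const std::vector<int32_t> &first_shape,
                                                       const std::vector<int32_t> &second_shape,
                                                       const std::vector<int32_t> &s) {
  std::vector<std::vector<int32_t>> stra;
  if (second_shape.size() < first_shape.size()) {         // broadcast the second input
    stra.push_back(s);
    stra.push_back(BroadcastStrategySketch(second_shape.size()));
  } else if (second_shape.size() > first_shape.size()) {  // broadcast the first input
    stra.push_back(BroadcastStrategySketch(first_shape.size()));
    stra.push_back(s);
  } else {  // same rank: keep the basic strategy (the real code runs the divisibility check)
    stra.push_back(s);
    stra.push_back(s);
  }
  return stra;
}

int main() {
  // TensorAdd with input shapes [64, 32] and [32] and basic strategy [8, 1]:
  // the first input keeps [8, 1], the second gets the broadcast strategy [1].
  for (const auto &st : CheckBroadcastSketch({64, 32}, {32}, {8, 1})) {
    for (int32_t cut : st) std::cout << cut << ' ';
    std::cout << '\n';
  }
  return 0;
}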
std::vector<int32_t> ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
std::vector<int32_t> s, size_t target_tensor_dim, size_t refer_tensor_dim,
bool broadcast_first_tensor) {
std::vector<int32_t> s_empty = {};
std::vector<int32_t> s_broadcast;
int target_tensor_index = 0;
int refer_tensor_index = 0;
// Determine which input is the broadcast target and which is the reference.
if (broadcast_first_tensor) {
target_tensor_index = 0;
refer_tensor_index = 1;
} else {
target_tensor_index = 1;
refer_tensor_index = 0;
}
// The target tensor is a scalar (empty shape): return an empty strategy.
if (target_tensor_dim == 0) {
return s_empty;
} else if (target_tensor_dim == 1) { // The target tensor has a single dimension.
bool broadcast_dim_found = false;
for (size_t iter = 0; iter < refer_tensor_dim; iter++) {
// Find the matching dimension in the reference tensor and copy its strategy.
if ((ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] ==
ops[iter_ops]->inputs_tensor_info()[target_tensor_index].shape()[0]) &&
(ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] > 1) &&
(refer_tensor_dim == s.size())) {
s_broadcast.push_back(s.at(iter));
broadcast_dim_found = true;
break;
}
}
// No matching dimension was found, so do not cut this dimension (use 1).
if (!broadcast_dim_found) {
s_broadcast.push_back(1);
}
} else {
// The broadcast dimensions cannot be determined here, so do not cut any of them (all 1s).
for (size_t iter = 0; iter < target_tensor_dim; iter++) {
s_broadcast.push_back(1);
}
}
return s_broadcast;
}
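As a worked illustration of the single-dimension branch, the following standalone sketch (hypothetical names, plain shape vectors rather than the OperatorInfo API) captures the decision rule: a scalar target gets an empty strategy, a rank-1 target reuses the cut of the matching reference dimension (or 1 when no match is found), and higher-rank targets are left uncut.

#include <cstdint>
#include <vector>

std::vector<int32_t> ApplyBroadcastSketch(const std::vector<int32_t> &refer_shape,
                                          const std::vector<int32_t> &target_shape,
                                          const std::vector<int32_t> &refer_strategy) {
  if (target_shape.empty()) {
    return {};  // scalar target: empty strategy
  }
  if (target_shape.size() == 1) {
    for (size_t i = 0; i < refer_shape.size(); ++i) {
      // Reuse the cut of the reference dimension the target broadcasts against.
      if (refer_shape[i] == target_shape[0] && refer_shape[i] > 1 &&
          refer_shape.size() == refer_strategy.size()) {
        return {refer_strategy[i]};
      }
    }
    return {1};  // no matching dimension: do not cut
  }
  // Higher-rank targets: the broadcast axes cannot be identified here, so cut nothing.
  return std::vector<int32_t>(target_shape.size(), 1);
}

// Example: reference shape [64, 32] with strategy [2, 4] and target shape [32]
// yield the target strategy [4], i.e. the cut of the matching last dimension.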
// Check whether each input tensor of the operator can be divided by the current strategy.
std::vector<std::vector<int32_t>> CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> basic_stra) {
std::vector<int32_t> s_empty = {};
std::vector<std::vector<int32_t>> stra;
// Iterate over all of the operator's input tensors.
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size();
iter_op_inputs++) {
// If the input tensor is a scalar (empty shape), push back an empty strategy.
if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) {
stra.push_back(s_empty);
continue;
@@ -646,6 +735,8 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
std::vector<int32_t> tmp_stra = basic_stra;
bool modified = false;
// A dimension of size 1 cannot be split, so force its strategy entry back to 1.
for (size_t j = 0; j < ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) {
if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) {
tmp_stra[j] = 1;
@@ -658,6 +749,7 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
stra.push_back(basic_stra);
}
}
return stra;
}
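For reference, the divisibility rule can be summarized in a standalone sketch (hypothetical names, plain shape vectors): every dimension of size 1 forces the corresponding strategy entry back to 1, and an empty shape yields an empty strategy.

#include <cstdint>
#include <vector>

std::vector<std::vector<int32_t>> CheckDivisibleSketch(const std::vector<std::vector<int32_t>> &input_shapes,
                                                       const std::vector<int32_t> &basic_stra) {
  std::vector<std::vector<int32_t>> stra;
  for (const auto &shape : input_shapes) {
    if (shape.empty()) {
      stra.emplace_back();  // scalar input: empty strategy
      continue;
    }
    std::vector<int32_t> tmp_stra = basic_stra;
    for (size_t j = 0; j < shape.size(); ++j) {
      if (shape[j] == 1) {
        tmp_stra[j] = 1;  // a size-1 dimension cannot be split
      }
    }
    stra.push_back(tmp_stra);
  }
  return stra;
}

// Example: input shapes {[32, 1], [32, 16]} with basic strategy [4, 2]
// produce the strategies [4, 1] and [4, 2].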
@@ -42,6 +42,13 @@ std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_
const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
std::vector<int32_t> s, size_t target_tensor_dim, size_t refer_tensor_dim,
bool broadcast_first_tensor);
std::vector<std::vector<int32_t>> CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops);