提交 59bb0141 编写于 作者: S sheng 提交者: ch-l

Refine memory control and odd-number control; fuse strategy writeback

上级 3c124889
......@@ -146,7 +146,7 @@ StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node,
std::vector<double> cost_op;
std::vector<std::vector<float>> mode;
if (edge_i < 2) {
if (edge_i < 2 || edge_i % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrConcatDimI(edge_j, edge_k) + CostRedis(node, node_name_to_strategy,
......@@ -154,7 +154,7 @@ StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node,
graph));
}
if (edge_j < 2) {
if (edge_j < 2 || edge_j % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrConcatDimJ(edge_i, edge_k) + CostRedis(node, node_name_to_strategy,
......@@ -162,7 +162,7 @@ StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node,
graph));
}
if (edge_k < 2) {
if (edge_k < 2 || edge_k % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrReduceDimK(edge_i, edge_j) + CostRedis(node, node_name_to_strategy,
......@@ -226,7 +226,7 @@ StrategyRec CostMatMul::ChoseStr(const std::vector<double> &cost_op, StrategyRec
// Get optimal strategy for Conv
StrategyRec CostConvolution::GetOptimalStr(
const Graph::NodeType &node, const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
const Graph &graph) {
const Graph &graph, bool channel_partition) {
const OperatorRec &op = node.apply;
int input_tensor_h = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
......@@ -254,7 +254,7 @@ StrategyRec CostConvolution::GetOptimalStr(
cost_op.reserve(7);
std::vector<std::vector<float>> mode;
if (input_tensor_n < 2) {
if (input_tensor_n < 2 || input_tensor_n % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy,
......@@ -264,7 +264,7 @@ StrategyRec CostConvolution::GetOptimalStr(
cost_op.push_back(DOUBLE_MAX);
cost_op.push_back(DOUBLE_MAX);
if (tensor_filter < 2) {
if (channel_partition == false || tensor_filter < 2 || tensor_filter % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrDimK(tensor_in) + CostRedis(node, node_name_to_strategy,
......@@ -274,7 +274,7 @@ StrategyRec CostConvolution::GetOptimalStr(
cost_op.push_back(DOUBLE_MAX);
cost_op.push_back(DOUBLE_MAX);
if (tensor_filter_c < 2) {
if (channel_partition == false || tensor_filter_c < 2 || tensor_filter_c % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrDimQ(tensor_out) + CostRedis(node, node_name_to_strategy,
......@@ -384,14 +384,14 @@ StrategyRec CostPooling::GetOptimalStr(const Graph::NodeType &node,
std::vector<double> cost_op;
std::vector<std::vector<float>> mode;
if (tensor_n < 2) {
if (tensor_n < 2 || tensor_n % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
}
if (tensor_c < 2) {
if (tensor_c < 2 || tensor_c % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
......@@ -507,7 +507,6 @@ StrategyRec CostBiasAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRe
switch (min_position) {
case 0:
str.inputTensor[0].str_n /= 2.0;
str.inputTensor[1].str_n /= 2.0;
str.outputTensor.str_n /= 2.0;
str.cut_counter += 1;
str.cost = str.cost + cost_in_;
......@@ -515,7 +514,6 @@ StrategyRec CostBiasAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRe
case 1:
str.inputTensor[0].str_c /= 2.0;
str.inputTensor[1].str_c /= 2.0;
str.outputTensor.str_c /= 2.0;
str.cut_counter += 1;
str.cost = str.cost + cost_in_;
......@@ -523,7 +521,6 @@ StrategyRec CostBiasAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRe
case 2:
str.inputTensor[0].str_h /= 2.0;
str.inputTensor[1].str_h /= 2.0;
str.outputTensor.str_h /= 2.0;
str.cut_counter += 1;
str.cost = str.cost + cost_in_;
......@@ -547,36 +544,37 @@ StrategyRec CostBiasAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRe
StrategyRec CostCommon::GetOptimalStr(const Graph::NodeType &node,
const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
const Graph &graph) {
int tensor_n = static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
int tensor_c = static_cast<int>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
int tensor_h = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h);
int tensor_w = static_cast<int>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w);
const OperatorRec &op = node.apply;
int tensor_n = static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
int tensor_c = static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
int tensor_h = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
int tensor_w = static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
std::vector<double> cost_op;
std::vector<std::vector<float>> mode;
if (tensor_n < 2) {
if (tensor_n < 2 || tensor_n % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
}
if (tensor_c < 2) {
if (tensor_c < 2 || tensor_c % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph));
}
if (tensor_h < 2) {
if (tensor_h < 2 || tensor_h % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
mode = {{1, 1, 0.5, 1}, {1, 1, 0.5, 1}, {1, 1, 0.5, 1}}, graph));
}
if (tensor_w < 2) {
if (tensor_w < 2 || tensor_w % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
......@@ -660,33 +658,27 @@ StrategyRec CostBatchNorm::GetOptimalStr(const Graph::NodeType &node,
int output_tensor_h = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h);
int output_tensor_w = static_cast<int>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w);
int output_tensor_n = static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
int output_tensor_c = static_cast<int>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
std::vector<double> cost_op;
std::vector<std::vector<float>> mode;
if (output_tensor_n < 2) {
if (output_tensor_n < 2 || output_tensor_n % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy,
mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
}
if (output_tensor_c < 2 || tensor_filter_c < 2) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrDimC() + CostRedis(node, node_name_to_strategy,
mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph));
}
cost_op.push_back(DOUBLE_MAX);
if (output_tensor_h < 2) {
if (output_tensor_h < 2 || output_tensor_h % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrDimH(tensor_filter) + CostRedis(node, node_name_to_strategy,
mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}}, graph));
}
if (output_tensor_w < 2) {
if (output_tensor_w < 2 || output_tensor_w % 2 != 0) {
cost_op.push_back(DOUBLE_MAX);
} else {
cost_op.push_back(StrDimW(tensor_filter) + CostRedis(node, node_name_to_strategy,
......
......@@ -78,7 +78,7 @@ class CostConvolution {
public:
StrategyRec GetOptimalStr(const Graph::NodeType &node,
const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
const Graph &graph);
const Graph &graph, bool channel_partition);
double GetMinCostIn(const Graph::NodeType &node);
......
......@@ -21,18 +21,16 @@
#include <vector>
#include "ir/value.h"
#include "parallel/auto_parallel/rec_core/rec_parse_graph.h"
#include "parallel/auto_parallel/rec_core/rec_partition.h"
#include "parallel/ops_info/operator_info.h"
#include "parallel/strategy.h"
namespace mindspore {
namespace parallel {
void GenerateStrategy(std::shared_ptr<Graph> graph, bool mask_special_ops,
const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops) {
MS_EXCEPTION_IF_NULL(graph);
if (mask_special_ops) {
MaskSpecialOps(graph);
}
for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
std::vector<std::vector<int32_t>> stra;
for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
......@@ -69,102 +67,61 @@ std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
return s;
}
// Builds the four-entry strategy vector for one input of a Conv2D node.
// Each entry is the reciprocal of the matching str_n / str_c / str_h / str_w
// value of the selected argument, truncated to int32_t.
//   graph          - graph that holds the node.
//   iter_nodes     - index of the node inside graph->nodes.
//   iter_op_inputs - operator input index; 0 reads argument 0, any other
//                    value reads argument 1.
std::vector<int32_t> PrepareConv2D(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
                                   size_t iter_op_inputs) {
  const size_t arg_index = (iter_op_inputs == 0) ? 0 : 1;
  const auto &tensor_str = graph->nodes[iter_nodes].apply.arguments[arg_index].tensor_str;

  std::vector<int32_t> strategy;
  strategy.push_back(static_cast<int32_t>(1.0 / tensor_str.str_n));
  strategy.push_back(static_cast<int32_t>(1.0 / tensor_str.str_c));
  strategy.push_back(static_cast<int32_t>(1.0 / tensor_str.str_h));
  strategy.push_back(static_cast<int32_t>(1.0 / tensor_str.str_w));
  return strategy;
}
// Builds the strategy vector for one input of a BiasAdd node.
// Input 0 yields two entries (reciprocals of str_h and str_w); every other
// input yields a single entry (reciprocal of str_w). Values are truncated to
// int32_t and always read from arguments[iter_op_inputs].
std::vector<int32_t> PrepareBiasAdd(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
                                    const size_t iter_op_inputs) {
  const auto &tensor_str = graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str;

  std::vector<int32_t> strategy;
  if (iter_op_inputs == 0) {
    // Only the first input carries the extra leading h entry.
    strategy.push_back(static_cast<int32_t>(1.0 / tensor_str.str_h));
  }
  strategy.push_back(static_cast<int32_t>(1.0 / tensor_str.str_w));
  return strategy;
}
// Builds the strategy vector for one input of a BatchNorm node.
// Input 0 yields four entries taken from argument 0 (reciprocals of str_n,
// str_c, str_h, str_w); any other input yields one entry, the reciprocal of
// argument 1's str_w. All values are truncated to int32_t.
std::vector<int32_t> PrepareBN(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
                               const size_t iter_op_inputs) {
  std::vector<int32_t> strategy;

  if (iter_op_inputs != 0) {
    const auto &second_str = graph->nodes[iter_nodes].apply.arguments[1].tensor_str;
    strategy.push_back(static_cast<int32_t>(1.0 / second_str.str_w));
    return strategy;
  }

  const auto &first_str = graph->nodes[iter_nodes].apply.arguments[0].tensor_str;
  strategy.push_back(static_cast<int32_t>(1.0 / first_str.str_n));
  strategy.push_back(static_cast<int32_t>(1.0 / first_str.str_c));
  strategy.push_back(static_cast<int32_t>(1.0 / first_str.str_h));
  strategy.push_back(static_cast<int32_t>(1.0 / first_str.str_w));
  return strategy;
}
// Builds the strategy vector for one input of the sparse softmax
// cross-entropy operator. The first entry is always
// g_device_manager->DeviceNum(); input 0 additionally gets a trailing 1.
std::vector<int32_t> PrepareSparse(const size_t iter_op_inputs) {
  std::vector<int32_t> strategy;
  strategy.push_back(g_device_manager->DeviceNum());
  if (iter_op_inputs == 0) {
    strategy.push_back(1);
  }
  return strategy;
}
std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs) {
std::vector<int32_t> s;
std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<Graph> &graph, const size_t iter_ops,
const size_t iter_op_inputs) {
if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
}
if (iter_ops >= ops.size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
}
if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size()) {
StrategyPtr origin_strategy = ops[iter_ops]->strategy();
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
}
size_t input_size = ops[iter_ops]->strategy()->GetInputDim()[iter_op_inputs].size();
for (size_t dim = 0; dim < input_size; dim++) {
s.push_back(1);
// size_t output_size = ops[iter_ops]->outputs_tensor_info()[0].shape().size();
size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
std::vector<int32_t> s = {};
if (output_size == 4) {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_n));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_c));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
} else if (output_size == 2) {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
} else if (output_size == 1) {
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
} else if (output_size == 0) {
return s;
} else {
MS_LOG(ERROR) << "Tensor's output size is unexcepted.";
}
return s;
}
std::vector<int32_t> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_ops,
const size_t iter_op_inputs) {
std::vector<int32_t> s;
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_n));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_c));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h));
s.push_back(static_cast<int32_t>(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w));
return s;
}
std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_op_inputs) {
std::vector<int32_t> s;
if (ops.empty()) {
MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
}
if (iter_ops >= ops.size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
}
StrategyPtr origin_strategy = ops[iter_ops]->strategy();
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
}
std::vector<int32_t> s;
size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
for (size_t dim = 0; dim < input_size; dim++) {
if (dim == 0 && input_size == 4) {
......@@ -175,6 +132,7 @@ std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<
s.push_back(1);
}
}
return s;
}
......@@ -187,47 +145,23 @@ std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
if (iter_ops >= ops.size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
}
auto type = ops[iter_ops]->type();
auto idx = DictOpType.find(type);
if (idx == DictOpType.end()) {
return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
}
if (type == MATMUL) {
return PrepareMatMul(graph, ops, iter_ops, iter_op_inputs);
} else if ((type == MAXPOOL) || (type == SIMPLE_MEAN) || (type == TENSOR_ADD)) {
return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
} else if (type == CONV2D) {
return PrepareConv2D(graph, iter_ops, iter_op_inputs);
} else if (type == BIAS_ADD) {
return PrepareBiasAdd(graph, iter_ops, iter_op_inputs);
} else if (type == RESHAPE) {
return MakeOriginalStrategy(ops, iter_ops, iter_op_inputs);
} else if (type == RELU) {
return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
} else if ((type == BATCH_NORM) || (type == FUSE_BATCH_NORM)) {
return PrepareBN(graph, iter_ops, iter_op_inputs);
} else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
return PrepareSparse(iter_op_inputs);
} else {
return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
} else if (type == DIV || type == SUB || type == MUL) {
return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
} else {
return MakeRecSearchStrategy(ops, graph, iter_ops, iter_op_inputs);
}
}
// use to respect strategy checks of auto parallel
void MaskSpecialOps(std::shared_ptr<Graph> graph) {
size_t iter_nodes = graph->nodes.size();
for (size_t i = 0; i < iter_nodes; i++) {
Graph::NodeType &node = graph->nodes[i];
if (node.apply.op_type == kRecConvolution) { // For convolution
// cover input tensor strategy
node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast<float>(g_device_manager->DeviceNum());
node.apply.arguments[0].tensor_str.str_c = 1;
node.apply.arguments[0].tensor_str.str_h = 1;
node.apply.arguments[0].tensor_str.str_w = 1;
// cover filter tensor strategy
node.apply.arguments[1].tensor_str.str_n = 1;
node.apply.arguments[1].tensor_str.str_c = 1;
node.apply.arguments[1].tensor_str.str_h = 1;
node.apply.arguments[1].tensor_str.str_w = 1;
}
}
}
} // namespace parallel
} // namespace mindspore
......@@ -27,28 +27,18 @@
namespace mindspore {
namespace parallel {
void GenerateStrategy(std::shared_ptr<Graph> graph, bool mask_special_ops,
const std::vector<std::shared_ptr<OperatorInfo>> &ops);
void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops);
std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
const size_t iter_op_inputs);
std::vector<int32_t> PrepareConv2D(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
const size_t iter_op_inputs);
std::vector<int32_t> PrepareBiasAdd(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
const size_t iter_op_inputs);
std::vector<int32_t> PrepareBN(const std::shared_ptr<Graph> &graph, const size_t iter_nodes,
const size_t iter_op_inputs);
std::vector<int32_t> PrepareSparse(const size_t iter_op_inputs);
std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs);
std::vector<int32_t> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_ops,
std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<Graph> &graph, const size_t iter_ops,
const size_t iter_op_inputs);
std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_op_inputs);
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs);
void MaskSpecialOps(std::shared_ptr<Graph> graph);
} // namespace parallel
} // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
......@@ -31,7 +31,7 @@ enum OperatorType {
kRecMatMul,
kRecConvolution,
kRecPooling,
kRecTensorAdd,
kRecElmWiseOp,
kRecReLU,
kRecBatchNorm,
kRecReshape,
......
......@@ -34,20 +34,26 @@ const std::map<std::string, OperatorType> DictOpType{
{MAXPOOL, OperatorType::kRecPooling},
{MAXPOOLV2, OperatorType::kRecPooling},
{SIMPLE_MEAN, OperatorType::kRecPooling},
{TENSOR_ADD, OperatorType::kRecTensorAdd},
{TENSOR_ADD, OperatorType::kRecElmWiseOp},
{RESHAPE, OperatorType::kRecReshape},
{BIAS_ADD, OperatorType::kRecBiasAdd},
{RELU, OperatorType::kRecReLU},
{BATCH_NORM, OperatorType::kRecBatchNorm},
{FUSE_BATCH_NORM, OperatorType::kRecBatchNorm},
{SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
{SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
{ONEHOT, OperatorType::kRecOneHot},
{LOG, OperatorType::kRecLog},
{EXP, OperatorType::kRecExp},
{SUB, OperatorType::kRecSub},
{MUL, OperatorType::kRecMul},
{DIV, OperatorType::kRecDiv},
{SUB, OperatorType::kRecElmWiseOp},
{MUL, OperatorType::kRecElmWiseOp},
{DIV, OperatorType::kRecElmWiseOp},
{SQUEEZE, OperatorType::kRecSqueeze},
{CAST, OperatorType::kRecCast}};
{CAST, OperatorType::kRecCast},
{REDUCE_SUM, OperatorType::kRecCast},
{REDUCE_MAX, OperatorType::kRecCast},
{REDUCE_MIN, OperatorType::kRecCast},
{REDUCE_MEAN, OperatorType::kRecCast}};
const TensorParam MakeTensor(int n, int c, int h, int w);
......
......@@ -48,7 +48,7 @@ double GetWeights(const Graph::NodeType &node) {
auto cost_ptr = std::make_shared<CostPooling>();
return cost_ptr->GetMinCostIn();
} else if (op.op_type == OperatorType::kRecTensorAdd) {
} else if (op.op_type == OperatorType::kRecElmWiseOp) {
// For element-wise ops (TensorAdd, Sub, Mul, Div)
auto cost_ptr = std::make_shared<CostTensorAdd>();
......@@ -124,6 +124,7 @@ std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
StrategyRec PartitionNode(const Graph::NodeType &node,
const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
std::shared_ptr<Graph> graph) {
bool enable_conv_chw_partition = false;
MS_EXCEPTION_IF_NULL(graph);
if (node.apply.op_type == OperatorType::kRecMatMul) {
......@@ -135,13 +136,13 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
// For Convolution
auto cost_ptr = std::make_shared<CostConvolution>();
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph, enable_conv_chw_partition);
} else if (node.apply.op_type == OperatorType::kRecPooling) {
// For Pooling
auto cost_ptr = std::make_shared<CostPooling>();
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
} else if (node.apply.op_type == OperatorType::kRecTensorAdd) {
} else if (node.apply.op_type == OperatorType::kRecElmWiseOp) {
// For element-wise ops (TensorAdd, Sub, Mul, Div)
auto cost_ptr = std::make_shared<CostTensorAdd>();
......@@ -260,11 +261,11 @@ Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> g
MS_EXCEPTION_IF_NULL(graph);
uint64_t iter_nodes = graph->nodes.size();
double used_memory = 0.0;
for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
if (graph->nodes[i_node].info == 0) {
Graph::NodeType &Node = graph->nodes[i_node];
double used_memory = 0.0;
for (int index = 0; index < 2; index++) {
used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
......@@ -279,12 +280,13 @@ Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> g
Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
GetDataTypeSize(Node.tensor_parm.tensor_type);
if (device_memory < used_memory) {
MS_LOG(EXCEPTION) << "Failure: Out of memory!";
return FAILED;
}
}
}
if (device_memory < used_memory) {
MS_LOG(EXCEPTION) << "Failure: Out of memory!";
return FAILED;
}
return SUCCESS;
}
......
......@@ -1023,8 +1023,7 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
return FAILED;
}
bool mask_special_ops = true;
GenerateStrategy(graph, mask_special_ops, ops);
GenerateStrategy(graph, ops);
if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
MS_LOG(INFO) << "Init selected strategy succeeded.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册