提交 caac6bce 编写于 作者: C ch-l

adjustements w.r.t. distributed execution

上级 69ab46e6
......@@ -296,10 +296,10 @@ double CostConvolution::GetMinCostIn(const Graph::NodeType &node) {
static_cast<int>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n) *
static_cast<int>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w) *
static_cast<int>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);
int tensor_out = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_shape.shape_w) *
static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_shape.shape_c) *
static_cast<int>(node.tensor_parm.tensor_str.str_h * node.tensor_parm.tensor_str.str_w) *
static_cast<int>(node.tensor_parm.tensor_str.str_n * node.tensor_parm.tensor_str.str_c);
int tensor_out = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h) *
static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n) *
static_cast<int>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w) *
static_cast<int>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
std::vector<double> cost_in;
cost_in.push_back(StrDimB(tensor_filter));
......@@ -628,6 +628,22 @@ StrategyRec CostCommon::ChoseStr(const std::vector<double> &cost_op, StrategyRec
return str;
}
// Get weight for BN
double CostBatchNorm::GetMinCostIn(const OperatorRec &op) {
int tensor = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) *
static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) *
static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) *
static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
std::vector<double> cost_in;
cost_in.push_back(StrDimB(tensor) * 1.2);
cost_in.push_back(DOUBLE_MAX);
cost_in.push_back(StrDimH(tensor) * 1.2);
cost_in.push_back(StrDimW(tensor) * 1.2);
return *min_element(cost_in.begin(), cost_in.end());
}
// Get optimal strategy for BN
StrategyRec CostBatchNorm::GetOptimalStr(const Graph::NodeType &node,
const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
......
......@@ -213,7 +213,7 @@ class CostBatchNorm {
const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
const Graph &graph);
double GetMinCostIn() const { return 0.0; }
double GetMinCostIn(const OperatorRec &op);
private:
double StrDimB(int32_t Tensor) {
......
......@@ -132,8 +132,9 @@ std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<Oper
if (iter_ops >= ops.size()) {
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
}
if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size())
if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size()) {
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
}
size_t input_size = ops[iter_ops]->strategy()->GetInputDim()[iter_op_inputs].size();
for (size_t dim = 0; dim < input_size; dim++) {
s.push_back(1);
......@@ -161,8 +162,9 @@ std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
}
StrategyPtr origin_strategy = ops[iter_ops]->strategy();
if (iter_op_inputs >= origin_strategy->GetInputDim().size())
if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
}
size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
for (size_t dim = 0; dim < input_size; dim++) {
if (dim == 0 && input_size == 4) {
......@@ -198,9 +200,9 @@ std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
return MakeOriginalStrategy(ops, iter_ops, iter_op_inputs);
} else if (type == RELU) {
return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
} else if (type == BATCH_NORM || (type == FUSE_BATCH_NORM)) {
} else if ((type == BATCH_NORM) || (type == FUSE_BATCH_NORM)) {
return PrepareBN(graph, iter_ops, iter_op_inputs);
} else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
} else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
return PrepareSparse(iter_op_inputs);
} else {
return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
......@@ -224,12 +226,6 @@ void MaskSpecialOps(std::shared_ptr<Graph> graph) {
node.apply.arguments[1].tensor_str.str_c = 1;
node.apply.arguments[1].tensor_str.str_h = 1;
node.apply.arguments[1].tensor_str.str_w = 1;
} else if (node.apply.op_type == kRecBiasAdd || node.apply.op_type == kRecMatMul) {
// For MatMul and BiasAdd
node.apply.arguments[0].tensor_str.str_h = 1;
node.apply.arguments[0].tensor_str.str_w = 1;
node.apply.arguments[1].tensor_str.str_h = 1;
node.apply.arguments[1].tensor_str.str_w = 1;
}
}
}
......
......@@ -58,7 +58,8 @@ Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops,
ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
NewOp.tensor_parm = Fill2DTensor(ops, iter_ops, NewOp);
NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) {
NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0]);
} else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) {
......@@ -71,29 +72,6 @@ Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops,
return NewOp;
}
TensorParam Fill2DTensor(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Graph::NodeType NewTensor) {
if (NewTensor.apply.op_type == OperatorType::kRecMatMul) {
auto attrs = ops[iter_ops]->attrs();
bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
if (transpose_a) {
NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
ops[iter_ops]->outputs_tensor_info()[0].shape()[0]);
} else if (transpose_b) {
NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
ops[iter_ops]->outputs_tensor_info()[0].shape()[0]);
} else {
NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
}
} else {
NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
}
return NewTensor.tensor_parm;
}
OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Graph::NodeType NewTensor) {
for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size();
......
......@@ -53,9 +53,6 @@ const TensorParam MakeTensor(int n, int c, int h, int w);
Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops);
TensorParam Fill2DTensor(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Graph::NodeType NewTensor);
OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
Graph::NodeType NewTensor);
......
......@@ -73,7 +73,7 @@ double GetWeights(const Graph::NodeType &node) {
// For BatchNorm
auto cost_ptr = std::make_shared<CostBatchNorm>();
return cost_ptr->GetMinCostIn();
return cost_ptr->GetMinCostIn(op);
} else if (op.op_type == OperatorType::kRecOneHot || op.op_type == OperatorType::kRecLog ||
op.op_type == OperatorType::kRecExp || op.op_type == OperatorType::kRecAdd ||
op.op_type == OperatorType::kRecSub || op.op_type == OperatorType::kRecMul ||
......@@ -108,8 +108,8 @@ std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
}
}
// Do sorting.
sort(weight_to_node_index.begin(), weight_to_node_index.end());
// Ordering ops aka nodes of the graph
std::sort(weight_to_node_index.begin(), weight_to_node_index.end());
// Store the result in node_index_by_weights.
uint64_t size = weight_to_node_index.size();
......@@ -231,7 +231,6 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
}
}
InferUndecideStrategy(graph);
if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
return FAILED;
} else {
......@@ -257,80 +256,6 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
return Node;
}
// Check Strategy for the same tensor between op.
void InferUndecideStrategy(std::shared_ptr<Graph> graph) {
MS_EXCEPTION_IF_NULL(graph);
uint64_t iter_nodes = graph->nodes.size();
// For all the nodes in the graph
for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
// If this target node is an operator, find it's adjecent op's strategy;
if (graph->nodes[i_node].info == 0) {
// Try to apply last op's strategy.
ApplyLastStrategy(i_node, graph);
// Try to apply next op's strategy.
ApplyNextStrategy(i_node, graph);
}
}
}
void ApplyLastStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph) {
Graph::NodeType &target_node = graph->nodes[node_index];
// Number of node-in
size_t num_node_in = target_node.node_in.size();
// Find forward op and copy strategy if meets the limits.
for (size_t index = 0; index < num_node_in; index++) {
if (graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_n <=
target_node.apply.arguments[0].tensor_str.str_n &&
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_c <=
target_node.apply.arguments[0].tensor_str.str_c &&
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_h <=
target_node.apply.arguments[0].tensor_str.str_h &&
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_w <=
target_node.apply.arguments[0].tensor_str.str_w) {
target_node.apply.arguments[0].tensor_str.str_n =
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_n;
target_node.apply.arguments[0].tensor_str.str_c =
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_c;
target_node.apply.arguments[0].tensor_str.str_h =
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_h;
target_node.apply.arguments[0].tensor_str.str_w =
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_w;
}
}
}
void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph) {
Graph::NodeType &target_node = graph->nodes[node_index];
// Number of node-out
size_t num_node_out = target_node.node_out.size();
// Find backward op and copy strategy if meets the limits.
for (size_t index = 0; index < num_node_out; index++) {
if (graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_n <=
target_node.tensor_parm.tensor_str.str_n &&
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_c <=
target_node.tensor_parm.tensor_str.str_c &&
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_h <=
target_node.tensor_parm.tensor_str.str_h &&
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_w <=
target_node.tensor_parm.tensor_str.str_w) {
target_node.tensor_parm.tensor_str.str_n =
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_n;
target_node.tensor_parm.tensor_str.str_c =
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_c;
target_node.tensor_parm.tensor_str.str_h =
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_h;
target_node.tensor_parm.tensor_str.str_w =
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_w;
}
}
}
Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
MS_EXCEPTION_IF_NULL(graph);
......
......@@ -44,12 +44,6 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
void InferUndecideStrategy(std::shared_ptr<Graph> graph);
void ApplyLastStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);
void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);
Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph);
size_t GetDataTypeSize(const TensorType &type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册