提交 48e54dcf 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1056 [Auto parallel] Memory calculation in the inference phase

Merge pull request !1056 from Xiaoda/memory-estimation-in-inference-phase
......@@ -53,7 +53,7 @@ struct Cost {
communication_redis_backward_ = 0.0;
communication_forward_ = 0.0;
}
// 'memory_with_reuse_' calculates the peak memory usage in a training phase
// 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase
double memory_with_reuse_;
// 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated
// by ONLY forward phase
......
......@@ -300,5 +300,20 @@ Status Edge::CalculateMemoryCost() {
return SUCCESS;
}
Status Edge::CalculateMemoryCostForInference() {
// Currently, memory cost is NOT calculated for redistribution
if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
return FAILED;
}
for (auto &cost_kv : cost_map_) {
auto &cost_v = cost_kv.second;
if (!cost_v.empty()) {
cost_v[0]->memory_with_reuse_ = 0;
}
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
......@@ -131,9 +131,13 @@ class Edge {
void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; }
const CostPtr &selected_cost() const { return selected_cost_; }
void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; }
// When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
// should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
// In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving
// WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
// at the end of forward phase.
Status CalculateMemoryCost();
// In the inference phase,
Status CalculateMemoryCostForInference();
void mark_output_critical() { is_output_critical_ = 1; }
private:
std::string edge_name_;
......@@ -156,7 +160,11 @@ class Edge {
// If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor.
bool is_identity_edge;
CostPtr selected_cost_;
// In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator
// is parameter-involved
int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
// In the inference phase, this is used to mark whether the output of the previous operator is critical.
int is_output_critical_ = 0;
};
} // namespace parallel
} // namespace mindspore
......
......@@ -369,7 +369,7 @@ CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list,
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
<< ", communication_cost_: " << ret->communication_cost_
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
......@@ -422,7 +422,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, d
<< ", communication_with_partial_para_: " << ret->communication_with_partial_para_
<< ", communication_cost_: " << ret->communication_cost_
<< ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
MS_LOG(INFO) << "Cost 0: totoal_cost: " << minimum;
MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
for (size_t i = 1; i < after_mem_filter.size(); ++i) {
MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
......@@ -1351,6 +1351,14 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo
return succ_edges;
}
size_t CostGraph::GetNumEdges() const {
size_t sum = 0;
for (const auto &kv : edges_) {
auto &edges = kv.second;
sum += edges.size();
}
return sum;
}
Status CostGraph::InitSelectedStrategy() {
for (auto &op : ops_) {
MS_EXCEPTION_IF_NULL(op);
......@@ -1416,6 +1424,122 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
return SUCCESS;
}
void CostGraph::DFSForTopoOrder(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
std::vector<OperatorInfoPtr> *topo_order) {
MS_EXCEPTION_IF_NULL(current_op);
MS_EXCEPTION_IF_NULL(visited);
MS_EXCEPTION_IF_NULL(topo_order);
visited->at(current_op) = true;
for (const auto &s_edge : current_op->succ_edges()) {
if (!visited->at(s_edge->next_operator())) {
DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
}
}
topo_order->push_back(current_op);
}
// Compute a topological order of the costgraph
void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
std::map<OperatorInfoPtr, bool> visited;
for (auto &op : ops_) {
visited[op] = false;
}
for (auto &op : ops_) {
if (!visited[op]) {
DFSForTopoOrder(op, &visited, topo_order);
}
}
}
void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) {
for (auto &op : ops_) {
auto search = candidate_ops.find(op);
if (search != candidate_ops.end()) {
// Mark the critical operators
op->mark_output_critical();
// Mark the successive edges
for (auto &s_edge : op->succ_edges()) {
s_edge->mark_output_critical();
}
} else {
op->mark_output_not_critical();
}
}
}
Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
if (topo_order.size() == 0) {
MS_LOG(ERROR) << "0 operator in costgraph.";
return FAILED;
}
auto &first_op = topo_order[0];
if (first_op->prev_edges().size() > 0) {
MS_LOG(ERROR) << "The first operator in the first of topological order of "
"costgraph should have 0 incoming edge, but has "
<< first_op->prev_edges() << "edges.";
return FAILED;
}
// The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
// of the output of OperatorInfo that currently has not been used
std::map<OperatorInfoPtr, int> curr_memory_state;
(void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size())));
std::map<OperatorInfoPtr, int> max_memory_state = curr_memory_state;
// The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
// not been used
double curr_memory_size = first_op->GetOutputsTotalSize();
double max_memory_size = curr_memory_size;
for (size_t finished = 1; finished < topo_order.size(); ++finished) {
// Produce
(void)curr_memory_state.emplace(
std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size())));
curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
// Consume
for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
const auto &prev_op = prev_edge->prev_operator();
curr_memory_state[prev_op]--;
}
for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
const auto &prev_op = prev_edge->prev_operator();
if (curr_memory_state[prev_op] < 0) {
MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
return FAILED;
} else if (curr_memory_state[prev_op] == 0) {
curr_memory_state.erase(prev_op);
curr_memory_size -= prev_op->GetOutputsTotalSize();
}
}
if (curr_memory_size < 0) {
MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
}
// Modify the max
if (curr_memory_size > max_memory_size) {
max_memory_size = curr_memory_size;
max_memory_state = curr_memory_state;
}
}
// Mark those critical operators
MarkCriticalOpsAndEdges(max_memory_state);
return SUCCESS;
}
Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
// Two steps to do:
// 1. Compute a topological order of the costgraph
// 2. Determine and mark the operators (and necessary edges) that are critical
std::vector<OperatorInfoPtr> topo_order;
TopologyOrder(&topo_order);
std::reverse(std::begin(topo_order), std::end(topo_order));
if (DetermineCriticalOps(topo_order) != SUCCESS) {
MS_LOG(ERROR) << "Determining critical operators failed.";
return FAILED;
}
return SUCCESS;
}
Status CostGraph::CalculateOpsMemoryCost() {
for (auto &op : ops_) {
MS_EXCEPTION_IF_NULL(op);
......@@ -1427,6 +1551,17 @@ Status CostGraph::CalculateOpsMemoryCost() {
return SUCCESS;
}
Status CostGraph::CalculateOpsMemoryCostForInference() {
for (auto &op : ops_) {
MS_EXCEPTION_IF_NULL(op);
if (op->CalculateMemoryCostForInference() != SUCCESS) {
MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
return FAILED;
}
}
return SUCCESS;
}
Status CostGraph::CalculateEdgesMemoryCost() {
for (auto &edge_pair : edges_) {
const auto &edges = edge_pair.second;
......@@ -1440,6 +1575,19 @@ Status CostGraph::CalculateEdgesMemoryCost() {
return SUCCESS;
}
Status CostGraph::CalculateEdgesMemoryCostForInference() {
for (auto &edge_pair : edges_) {
const auto &edges = edge_pair.second;
for (auto &one_edge : edges) {
if (one_edge->CalculateMemoryCostForInference() != SUCCESS) {
MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
return FAILED;
}
}
}
return SUCCESS;
}
OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
for (auto one_op : ops_) {
if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
......@@ -1480,5 +1628,49 @@ Status CostGraph::CorrectOpsMemoryCost() {
}
return SUCCESS;
}
Status CostGraph::CalculateMemoryCost() {
if (RUN_PHASE == TRAINING_PHASE) {
// training phase
if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
// Calculate operators' memory usage
if (CalculateOpsMemoryCost() != SUCCESS) {
MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
return FAILED;
}
// Calculate edges' memory usage
if (CalculateEdgesMemoryCost() != SUCCESS) {
MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
return FAILED;
}
// Correct memory usage caused by TmpIdentity
if (CorrectOpsMemoryCost() != SUCCESS) {
MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
return FAILED;
}
} else {
MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
return FAILED;
}
} else {
// inference phase
if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
// Calculate operators' memory usage
if (CalculateOpsMemoryCostForInference() != SUCCESS) {
MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
return FAILED;
}
// Calculate edges's memory usage
if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
return FAILED;
}
} else {
MS_LOG(ERROR) << "Computing operators' critical flag failed.";
return FAILED;
}
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
......@@ -179,16 +179,24 @@ class CostGraph {
void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
CostPtrList &, CostPtrList &, CostPtrList *);
// Calculate memory cost for training phase or inference phase.
Status CalculateMemoryCost();
// When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
// the memory cost can be resused.
// the memory cost can be resused. This is used to calculate memory in the training phase.
Status CalculateOpsMemoryCost();
// When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
// the memory cost can be resused.
// the memory cost can be reused. This is used to calculate memory in the training phase.
Status CalculateEdgesMemoryCost();
// Calculate memory cost of operators in the inference phase.
Status CalculateOpsMemoryCostForInference();
// Calculate memory cost of edges in the inference phase.
Status CalculateEdgesMemoryCostForInference();
Status ComputeOpsAndEdgesParameterInvolved();
// Compute for each operator whether the output is critical.
Status ComputeOpsAndEdgesOutputCritical();
std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
size_t GetNumPairs() const { return edges_.size(); }
size_t GetNumEdges() const;
Status InitSelectedStrategy();
OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
// When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only
......@@ -208,6 +216,10 @@ class CostGraph {
const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; }
private:
void TopologyOrder(std::vector<OperatorInfoPtr> *);
void DFSForTopoOrder(const OperatorInfoPtr &, std::map<OperatorInfoPtr, bool> *, std::vector<OperatorInfoPtr> *);
Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &);
void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &);
// Needed by rec_parser
std::vector<std::vector<std::string>> inputs_tensor_name_list_;
std::map<std::string, std::string> tuple_getitem_list_;
......
......@@ -37,6 +37,8 @@ void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_
outputs_type_lengths_ = output_lengths;
}
void OperatorCost::set_output_critical(int critical) { is_outputs_critical_ = critical; }
double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
const std::vector<TensorInfo> &outputs) const {
double result = 0.0;
......@@ -63,6 +65,20 @@ double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
return result;
}
double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &,
const std::vector<TensorInfo> &outputs) const {
double result = 0.0;
if (is_outputs_critical_ == -1) {
MS_LOG(EXCEPTION) << "The critical flag is not set.";
}
if (is_outputs_critical_ == 1) {
for (size_t i = 0; i < outputs.size(); ++i) {
result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
}
}
return result;
}
// return the per device communication cost in the forward phase.
double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int32_t) const {
......
......@@ -70,6 +70,7 @@ class OperatorCost {
void set_is_parameter(const std::vector<bool> &is_parameter);
void set_is_parameter_involve(const std::vector<bool> &);
void set_output_parameter_involve(int);
void set_output_critical(int);
void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths);
std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
......@@ -92,6 +93,8 @@ class OperatorCost {
// Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled),
// plus necessary inputs.
virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
// per device memory cost in a inference phase
double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
protected:
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
......@@ -106,6 +109,9 @@ class OperatorCost {
// for each input and output, the followings record the number of bytes of each element
std::vector<size_t> inputs_type_lengths_;
std::vector<size_t> outputs_type_lengths_;
// Whether the output is critical, which means that this output is included in calculating peak memory cost
// in the inference phase.
int is_outputs_critical_ = -1;
};
using OperatorCostPtr = std::shared_ptr<OperatorCost>;
......
......@@ -1119,6 +1119,21 @@ Status OperatorInfo::CalculateMemoryCost() {
return SUCCESS;
}
Status OperatorInfo::CalculateMemoryCostForInference() {
// First, set the 'is_outputs_critical_' flag into OperatorCost.
if (is_output_critical_ == -1) {
MS_LOG(EXCEPTION) << "The critical flag is not set.";
return FAILED;
}
operator_cost()->set_output_critical(is_output_critical_);
// Set the memory cost in the 'strategy_cost_'
for (auto &swc : strategy_cost_) {
auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr);
swc->cost_list[0]->memory_with_reuse_ = mem_cost;
}
return SUCCESS;
}
Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
for (auto &swc : strategy_cost_) {
double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
......@@ -1230,6 +1245,25 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &inpu
return SUCCESS;
}
double OperatorInfo::GetOutputsTotalSize() {
if (is_calculated_outputs_size_) {
return outputs_total_size_;
}
if (outputs_type_lengths_.size() != outputs_shape_.size()) {
MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size()
<< " do not have the same number of outputs shape: " << outputs_shape_.size();
}
double sum = 0.0;
for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) {
auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast<double>(1.0),
std::multiplies<double>());
sum += size * static_cast<double>(outputs_type_lengths_[i]);
}
is_calculated_outputs_size_ = true;
outputs_total_size_ = sum;
return outputs_total_size_;
}
Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) {
if (outputs_type.size() != outputs_shape_.size()) {
MS_LOG(ERROR) << "Outputs type: " << outputs_type.size()
......
......@@ -72,6 +72,7 @@ class OperatorInfo {
Status set_is_parameter(const std::vector<bool> &is_parameter);
Status SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
const std::vector<size_t> &output_lengths);
double GetOutputsTotalSize();
// Set outputs dtype.
// If only one output, outputs_type.size() is 1.
// If output is tuple, outputs_type.size() is greater than 1.
......@@ -96,9 +97,13 @@ class OperatorInfo {
// is checked
Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
// When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input
// should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase.
// In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving
// WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory
// at the end of forward phase.
Status CalculateMemoryCost();
// In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated
// by the output
Status CalculateMemoryCostForInference();
int ComputeOpAndPrevEdgeParameterInvolved();
ForwardOp forward_op() const { return forward_op_; }
......@@ -147,6 +152,9 @@ class OperatorInfo {
// multiple times. This method is to correct this, and makes the cost is calulated only once.
Status CorrectMemoryCost(size_t input_index);
int is_output_parameter_involve() const { return is_output_parameter_involve_; }
int is_output_critical() const { return is_output_critical_; }
void mark_output_critical() { is_output_critical_ = 1; }
void mark_output_not_critical() { is_output_critical_ = 0; }
int used_devices() const { return used_devices_; }
// needed by rec_parser
void set_type(const std::string &type) { type_ = type; }
......@@ -220,7 +228,16 @@ class OperatorInfo {
// For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of
// pre-operator that has parameters as input.
std::vector<bool> is_parameter_involve_;
int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved
// If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating
// peak memory cost in the training phase.
// -1: unset; 0: not parameter_involved; 1: parameter_involved
int is_output_parameter_involve_ = -1;
// Whether this output is critical, which means that this output is included in calculating peak memory cost
// in the inference phase.
// -1 : unset; 0: not critical; 1: critical
int is_output_critical_ = -1;
double outputs_total_size_ = 0.0;
bool is_calculated_outputs_size_ = false;
// for each input and output, the followings record the number of bytes of each element
std::vector<size_t> inputs_type_lengths_;
std::vector<size_t> outputs_type_lengths_;
......
......@@ -1055,6 +1055,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
// Step 1: Traverse the ANF graph, and create NODEs for costgraph:
// create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
// for each OperatorInfo;
// Step 1.1: Deal with 'Reshape':
// For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's
// layout as its output layout.
// Step 2: Traverse the ANF graph, and create EDGES for costgraph:
// create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies
// for each edge, based on the strategies of two OperatorInfos;
......@@ -1062,7 +1065,8 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
// taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity
// operator for this Parameter, and add an edge for the use of this Parameter by each
// subsequent operator;
// Step 3.1: Calculate memory usage
// Step 3.1: Calculate memory usage:
// note the memory usage calculation is different in training phase and inference phase.
// Step 4: Run the Dynamic Programming algorithm:
// in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
// cost is caused by the redistribution of a operator's output tensor layout to the next operator's input
......@@ -1087,35 +1091,21 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
}
}
// reshape operator needs the next node's input_layout as its output_layout.
// and needs the previous node's output_layout as its input_layout.
// Step 1.1
ReshapeCostCompute(all_nodes);
// Step 2
ConstructCostGraphEdges(all_nodes);
MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
<< " operators, and " << entire_costgraph->GetNumPairs() << " edges.",
<< " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
// Step 3: Augment the costgraph.
AugmentCostGraph(all_nodes);
// Step 3: Augment the costgraph.
AugmentCostGraph(all_nodes);
MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size()
<< " operators, and " << entire_costgraph->GetNumPairs() << " edges.";
<< " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
// Step 3.1: Calculate the memory usage
if (entire_costgraph->ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
// Calculate operators' memory usage
if (entire_costgraph->CalculateOpsMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Calculating operators' cost for memory cost failed.";
}
// Calculate edges' memory usage
if (entire_costgraph->CalculateEdgesMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Calculating edges' cost for memory cost failed.";
}
// Correct memory usage caused by TmpIdentity
if (entire_costgraph->CorrectOpsMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Correcting operators' cost for memory cost failed.";
}
} else {
MS_LOG(EXCEPTION) << "Computing operators' parameter_involved failed.";
if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
}
// Step 4: run DP algorithm on the costgraph.
......
......@@ -32,5 +32,6 @@ def test_inference_phase():
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
train_network.set_train()
train_network.set_auto_parallel()
output = train_network(predict, label)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册