Commit 8e80f80f authored by mindspore-ci-bot, committed by Gitee

!4572 resolve output twice out of memory issue

Merge pull request !4572 from wangqiuliang/resolve-output-twice-out-of-memory
@@ -25,6 +25,7 @@
 #include "ir/func_graph_cloner.h"
 #include "ir/manager.h"
 #include "pipeline/jit/resource.h"
+#include "pipeline/pynative/pynative_execute.h"
 #include "frontend/optimizer/ad/adjoint.h"
 #include "frontend/operator/ops.h"
 #include "utils/symbolic.h"
@@ -218,7 +219,8 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
   auto k_app = k_graph_->NewCNode(inputs);
   TraceManager::EndTrace();
-  ReplaceEquivdout(k_app, cnode_morph->forward());
+  ReplaceEquivdout(k_app, cnode_morph);
+  cnode_morph->set_forward(nullptr, "");
   for (size_t i = 0; i < param_adjoints.size(); ++i) {
     param_adjoints[i]->RegisterKUser(k_app, i);
   }
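The change above is a hand-off-then-release pattern: ReplaceEquivdout now receives the whole morph CNode, so it can read both the cached forward value and its id, and the caller immediately resets the node's forward pair so the node itself stops pinning the output tensor. A minimal standalone sketch of that lifecycle, using simplified stand-in types rather than MindSpore's real classes:

#include <memory>
#include <string>
#include <utility>

struct Value {};  // stand-in for a forward-output tensor

struct MorphNode {
  std::pair<std::shared_ptr<Value>, std::string> forward_;
  void set_forward(std::shared_ptr<Value> v, std::string id) { forward_ = {std::move(v), std::move(id)}; }
  const std::pair<std::shared_ptr<Value>, std::string> &forward() const { return forward_; }
};

// The consumer takes its own reference while it rewires the graph.
void ReplaceWithValue(MorphNode *node) {
  auto held = node->forward().first;
  // ... build a value node from `held` and splice it into the graph ...
}

int main() {
  MorphNode node;
  node.set_forward(std::make_shared<Value>(), "op_0");
  ReplaceWithValue(&node);
  node.set_forward(nullptr, "");  // the node no longer keeps the tensor alive
  return 0;
}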
@@ -240,7 +242,9 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }
 
-void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward) {
+void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
+  auto forward = cnode_morph->forward().first;
+  auto forward_id = cnode_morph->forward().second;
   if (forward == nullptr) {
     return;
   }
@@ -265,10 +269,44 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward)
   auto equivdout = cnode_input->cast<CNodePtr>();
   auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
   auto manager = Manage({fg, func_graph}, false);
+  auto ref_size = manager->node_users()[equivdout].size();
+  auto forward_value = forward;
+  if (!forward_id.empty() && ref_size > 1) {
+    auto inst = pynative::PynativeExecutor::GetInstance();
+    inst->SaveOpForwardValue(forward_id, forward_value);
+  }
+  if (ref_size < 2) {
+    auto tensor = forward->cast<tensor::TensorPtr>();
+    if (tensor != nullptr) {
+      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape());
+      forward_value = new_tensor;
+    }
+  }
   MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
-  auto value_node = NewValueNode(forward);
+  auto value_node = NewValueNode(forward_value);
   value_node->set_has_new_value(true);
   manager->Replace(equivdout, value_node);
+  auto paras = fg->parameters();
+  auto inputs_value = cnode_morph->inputs_value();
+  if (inputs_value.size() == 0) {
+    return;
+  }
+  if (inputs_value.size() != paras.size()) {
+    MS_LOG(EXCEPTION) << "Parameter size:" << paras.size() << " is not equal to inputs size:" << inputs_value.size();
+  }
+  for (size_t i = 0; i < paras.size(); i++) {
+    auto para_ref_size = manager->node_users()[paras[i]].size();
+    auto input_value = inputs_value[i];
+    if (para_ref_size > 0 && input_value.first != nullptr) {
+      MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
+      auto inst = pynative::PynativeExecutor::GetInstance();
+      inst->SaveOpForwardValue(input_value.second, input_value.first);
+      auto input_value_node = NewValueNode(input_value.first);
+      manager->Replace(paras[i], input_value_node);
+    }
+  }
+  cnode_morph->clear_inputs_value();
+  return;
 }
 
 bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
......
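The memory win lives in the two new branches: when the forward output has more than one user, its real value is parked exactly once in the executor's forward map under forward_id; when it has a single user, the value node is built from a placeholder tensor carrying only dtype and shape, so the original data buffer is no longer referenced twice. A rough standalone illustration of the single-user branch, using a hypothetical simplified Tensor rather than mindspore::tensor::Tensor:

#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical tensor: metadata plus an optionally allocated data buffer.
struct Tensor {
  int data_type;
  std::vector<int64_t> shape;
  std::shared_ptr<std::vector<uint8_t>> data;  // the expensive part
  Tensor(int dt, std::vector<int64_t> s) : data_type(dt), shape(std::move(s)) {}  // no buffer allocated
};

// Mirror of the ref_size < 2 branch: keep only dtype/shape so the
// original buffer can be freed once its last real user lets go of it.
std::shared_ptr<Tensor> MakePlaceholder(const Tensor &t) {
  return std::make_shared<Tensor>(t.data_type, t.shape);
}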
@@ -95,7 +95,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
   // Update k hole with adjoint_definition, only applied in recursive case.
   void UpdateAdjoint(const AdjointPtr &adjoint_definition);
   void CallDoutHoleOnTape();
-  void ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward);
+  void ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph);
   std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
   // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
......
@@ -724,18 +724,14 @@ void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::ob
   set_pyobj(curr_g_, obj_id);
 }
 
-void PynativeExecutor::SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) {
-  auto id = GetOpId(op_exec_info);
-  int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
-  auto op = std::to_string(graph_id) + id;
-  op.append(std::to_string(op_id_map_[id]));
-  auto iter = op_forward_map_.find(op);
+void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value) {
+  auto iter = op_forward_map_.find(id);
   if (iter != op_forward_map_.end()) {
     return;
   }
-  op_forward_map_[op] = value;
-  ++op_id_map_[id];
-  MS_LOG(DEBUG) << "Save: " << op_exec_info->op_name << "(" << op << "), " << value;
+  op_forward_map_[id] = value;
+  MS_LOG(DEBUG) << "Save op forward value: "
+                << "(" << id << "), " << value;
 }
 
 void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {
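SaveOpForwardValue is now a pure first-write-wins memoization: the caller passes in a fully composed id, and an existing entry is never overwritten. In standard C++ the find-then-insert above collapses to a single call, e.g.:

#include <memory>
#include <string>
#include <unordered_map>

struct Value {};
using ValuePtr = std::shared_ptr<Value>;

std::unordered_map<std::string, ValuePtr> op_forward_map_;

void SaveOpForwardValue(const std::string &id, const ValuePtr &value) {
  op_forward_map_.try_emplace(id, value);  // inserts only if `id` is absent (C++17)
}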
@@ -748,9 +744,25 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN
   }
   auto value = PyAttrValue(out_real);
   if (cnode != nullptr) {
-    cnode->set_forward(value);
+    size_t size = op_exec_info->op_inputs.size();
+    for (size_t i = 0; i < size; i++) {
+      auto obj = op_exec_info->op_inputs[i];
+      auto obj_id = GetId(obj);
+      if (obj_to_forward_id_.find(obj_id) != obj_to_forward_id_.end()) {
+        cnode->add_input_value(PyAttrValue(obj), obj_to_forward_id_[obj_id]);
+      } else {
+        cnode->add_input_value(nullptr, "");
+      }
+    }
+    std::string id = GetOpId(op_exec_info);
+    int graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int>();
+    auto op_id = std::to_string(graph_id) + id;
+    op_id.append(std::to_string(op_id_map_[id]));
+    cnode->set_forward(value, op_id);
+    ++op_id_map_[id];
+    auto out_id = GetId(out_real);
+    obj_to_forward_id_[out_id] = op_id;
   }
-  SaveOpForwardValue(op_exec_info, value);
 }
 
 AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
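SaveAllResult now owns the id scheme that SaveOpForwardValue gave up: an op id is the graph id, the op's own id, and a per-op execution counter concatenated, so repeated runs of the same op get distinct keys; the output object is then remembered in obj_to_forward_id_ so that later ops can tag their inputs with the producing op's id. A compact sketch of that id generation, with the same composition as above folded into one hypothetical helper:

#include <string>
#include <unordered_map>

std::unordered_map<std::string, size_t> op_id_map_;

// Produces "<graph_id><op_id><per-op counter>"; the counter is bumped on
// every execution so the same op in the same graph never collides.
std::string MakeOpId(int graph_id, const std::string &id) {
  auto op_id = std::to_string(graph_id) + id;
  op_id.append(std::to_string(op_id_map_[id]++));
  return op_id;
}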
@@ -775,7 +787,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
     node_abs_map_[id] = node->abstract();
   }
   MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
-  node->cast<CNodePtr>()->set_forward(PyAttrValue(obj));
+  node->cast<CNodePtr>()->set_forward(PyAttrValue(obj), "");
   return node;
 }
@@ -1131,6 +1143,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
     }
   }
   auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
+  graph_info_map_.erase(curr_g_);
   if (curr_g_ != top_g_) {
     Popp();
     for (size_t i = 0; i < args.size(); i++) {
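Erasing the graph's entry as soon as ad::Grad has consumed it releases the per-graph bookkeeping (and any values it still references) early, instead of waiting for the blanket Clear further down. The idiom is plain eager map cleanup, sketched here with stand-in types:

#include <memory>
#include <unordered_map>

struct GraphInfo {};  // stand-in for per-graph bookkeeping that can hold values
struct FuncGraph {};
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;

void ReleaseGraphInfo(const FuncGraphPtr &g) {
  graph_info_map_.erase(g);  // drops the entry and whatever it kept alive
}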
@@ -1300,6 +1313,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
   curr_g_ = nullptr;
   graph_info_map_.clear();
   op_id_map_.clear();
+  obj_to_forward_id_.clear();
   std::stack<FuncGraphPtr>().swap(graph_p_);
   ConfigManager::GetInstance().ResetIterNum();
 }
......
@@ -108,7 +108,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                 abstract::AbstractBasePtrList *args_spec_list);
   void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
   ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
-  void SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value);
+  void SaveOpForwardValue(const std::string &id, const ValuePtr &value);
   void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
   void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
@@ -138,6 +138,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
   std::unordered_map<std::string, size_t> op_id_map_;
+  std::unordered_map<std::string, std::string> obj_to_forward_id_;
   std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   std::stack<FuncGraphPtr> graph_p_;
   FuncGraphPtr top_g_;
......
@@ -31,7 +31,7 @@
 namespace mindspore {
 // namespace to support intermediate representation definition
 CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
-    : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
+    : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false), output_value_(std::make_pair(nullptr, "")) {}
 
 // Check if CNode is an apply with the specific Primitive.
 bool CNode::IsApply(const PrimitivePtr &value) const {
......
@@ -232,8 +232,15 @@ class CNode : public AnfNode {
   void set_input(size_t i, const AnfNodePtr &input);
   void set_inputs(const std::vector<AnfNodePtr> &inputs) { inputs_ = inputs; }
-  void set_forward(const ValuePtr &forward) { forward_ = forward; }
-  const ValuePtr &forward() const { return forward_; }
+  void add_input_value(const ValuePtr &input_value, const std::string &id) {
+    inputs_value_.push_back(std::make_pair(input_value, id));
+  }
+  void clear_inputs_value() { inputs_value_.clear(); }
+  void set_inputs_value(const std::vector<std::pair<ValuePtr, std::string>> &values) { inputs_value_ = values; }
+  const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }
+  void set_forward(const ValuePtr &forward, const std::string &id) { output_value_ = std::make_pair(forward, id); }
+  const std::pair<ValuePtr, std::string> &forward() const { return output_value_; }
 
   bool stop_gradient() const { return stop_gradient_; }
   void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
@@ -253,7 +260,10 @@ class CNode : public AnfNode {
   VarPtr func_graph_as_var_;
   bool stop_gradient_;
   bool in_forward_flag_ = false;
-  ValuePtr forward_ = nullptr;
+  // inputs_value_ stores the cnode's input values and their ids in pynative mode.
+  // output_value_ stores the cnode's output value and its id in pynative mode.
+  std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
+  std::pair<ValuePtr, std::string> output_value_;
 };
 
 // ANode represents the atomic node. It's derived Parameter and ValueNode.
......
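On the node side, the single forward_ value has become a (value, id) pair for the output plus a parallel vector of (value, id) pairs for the inputs, which is what lets ReplaceEquivdout restore and then clear everything in one place. A reduced sketch of the same interface on a plain struct, not the real mindspore::CNode:

#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Value {};
using ValuePtr = std::shared_ptr<Value>;

struct MiniCNode {
  void add_input_value(const ValuePtr &v, const std::string &id) { inputs_value_.emplace_back(v, id); }
  void clear_inputs_value() { inputs_value_.clear(); }
  void set_inputs_value(const std::vector<std::pair<ValuePtr, std::string>> &values) { inputs_value_ = values; }
  const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }
  void set_forward(const ValuePtr &v, const std::string &id) { output_value_ = std::make_pair(v, id); }
  const std::pair<ValuePtr, std::string> &forward() const { return output_value_; }

 private:
  std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
  std::pair<ValuePtr, std::string> output_value_{nullptr, ""};
};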
@@ -88,7 +88,8 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
   CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
   auto old_node = node->cast<CNodePtr>();
   new_node->set_abstract(old_node->abstract());
-  new_node->set_forward(old_node->forward());
+  new_node->set_forward(old_node->forward().first, old_node->forward().second);
+  new_node->set_inputs_value(old_node->inputs_value());
   ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
   new_node->set_scope(scope);
   new_node->set_kernel_info(old_node->kernel_info_ptr());
......
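Cloning has to carry both caches across, or the clone would present an empty forward pair and the replacement in ReplaceEquivdout would silently do nothing. With the MiniCNode sketch above, the copy is two assignments:

// Assumes the MiniCNode sketch above; a hypothetical helper, not Cloner's API.
void CopyPynativeCaches(const MiniCNode &old_node, MiniCNode *new_node) {
  new_node->set_forward(old_node.forward().first, old_node.forward().second);
  new_node->set_inputs_value(old_node.inputs_value());
}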