提交 2d973c95 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3161 Set output value for dynamic graph.

Merge pull request !3161 from flywind/output
...@@ -608,14 +608,20 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr ...@@ -608,14 +608,20 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
// malloc mem // malloc mem
RunOpMemoryAlloc(input_tensors, graph.get()); RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get());
// load input data to device // load input data to device
LoadInputData(graph, input_tensors); LoadInputData(graph, input_tensors);
// run op // run op
RunOpExecTask(graph); RunOpExecTask(graph);
// get output // get output
VectorRef outputs; VectorRef outputs;
if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(op_run_info.value, &pre_output_tensors);
std::copy(pre_output_tensors.begin(), pre_output_tensors.end(), std::back_inserter(outputs));
} else {
UpdateOutputs(graph, &outputs, input_tensors); UpdateOutputs(graph, &outputs, input_tensors);
}
// trans output to tuple // trans output to tuple
auto output_tensors = TransformBaseRefListToTuple(outputs); auto output_tensors = TransformBaseRefListToTuple(outputs);
if (!utils::isa<PyObjectRef>(output_tensors) || if (!utils::isa<PyObjectRef>(output_tensors) ||
...@@ -744,14 +750,15 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { ...@@ -744,14 +750,15 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
MS_LOG(INFO) << "Finish!"; MS_LOG(INFO) << "Finish!";
} }
void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, void AscendSession::RunOpMemoryAlloc(const ValuePtr &pre_output_value,
const std::vector<tensor::TensorPtr> &input_tensors,
KernelGraph *kernel_graph) const { KernelGraph *kernel_graph) const {
MS_LOG(INFO) << "Start memory alloc!"; MS_LOG(INFO) << "Start memory alloc!";
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
opt::RemoveNopNode(kernel_graph); opt::RemoveNopNode(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance); MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph);
MS_LOG(INFO) << "Finish!"; MS_LOG(INFO) << "Finish!";
} }
......
...@@ -79,7 +79,8 @@ class AscendSession : public SessionBasic { ...@@ -79,7 +79,8 @@ class AscendSession : public SessionBasic {
void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const; void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void MemoryAlloc(KernelGraph *kernel_graph) const; void MemoryAlloc(KernelGraph *kernel_graph) const;
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
KernelGraph *kernel_graph) const;
void RunOpMemoryClear(const KernelGraph *kernel_graph) const; void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
void GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const; void GenerateTaskInfo(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const; void LoadTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
......
...@@ -102,12 +102,13 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const { ...@@ -102,12 +102,13 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const {
runtime_instance->AssignMemory(kernel_graph); runtime_instance->AssignMemory(kernel_graph);
} }
void GPUSession::RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, void GPUSession::RunOpAllocateMemory(const ValuePtr &pre_output_value,
const std::vector<tensor::TensorPtr> &input_tensors,
KernelGraph *kernel_graph) const { KernelGraph *kernel_graph) const {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance); MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph);
} }
void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const { void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const {
...@@ -292,7 +293,7 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph ...@@ -292,7 +293,7 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
// Remove NoOp from execution graph // Remove NoOp from execution graph
opt::RemoveNopNode(kernel_graph.get()); opt::RemoveNopNode(kernel_graph.get());
RunOpAllocateMemory(input_tensors, kernel_graph.get()); RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get());
// Execute the computation // Execute the computation
LoadInputData(kernel_graph, input_tensors); LoadInputData(kernel_graph, input_tensors);
Execute(kernel_graph); Execute(kernel_graph);
......
...@@ -59,7 +59,8 @@ class GPUSession : public SessionBasic { ...@@ -59,7 +59,8 @@ class GPUSession : public SessionBasic {
void AllocateMemory(KernelGraph *kernel_graph) const; void AllocateMemory(KernelGraph *kernel_graph) const;
void RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const; void RunOpAllocateMemory(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
KernelGraph *kernel_graph) const;
void RunOpClearMemory(KernelGraph *kernel_graph) const; void RunOpClearMemory(KernelGraph *kernel_graph) const;
......
...@@ -95,6 +95,38 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { ...@@ -95,6 +95,38 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
} }
return false; return false;
} }
void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::string> *device_formats,
std::vector<TypeId> *device_types) {
MS_EXCEPTION_IF_NULL(value_node);
MS_EXCEPTION_IF_NULL(device_formats);
MS_EXCEPTION_IF_NULL(device_types);
ValuePtr value = value_node->value();
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(value, &tensors);
if (!tensors.empty()) {
if (tensors.size() != AnfAlgo::GetOutputTensorNum(value_node)) {
MS_LOG(EXCEPTION) << "The size of tensors converted from value [" << tensors.size()
<< "] is not equal to output size of value node [" << AnfAlgo::GetOutputTensorNum(value_node)
<< "]";
}
device_formats->clear();
device_types->clear();
for (const auto &tensor : tensors) {
MS_EXCEPTION_IF_NULL(tensor);
auto device_sync = tensor->device_address();
if (device_sync != nullptr) {
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
MS_EXCEPTION_IF_NULL(device_address);
device_formats->emplace_back(device_address->format());
device_types->emplace_back(device_address->type_id());
continue;
}
device_formats->emplace_back(kOpFormat_DEFAULT);
device_types->emplace_back(kTypeUnknown);
}
}
}
} // namespace } // namespace
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) { AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>(); auto value_node = node->cast<ValueNodePtr>();
...@@ -347,10 +379,12 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { ...@@ -347,10 +379,12 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT // set the format of value_node to DEFAULT_FORMAT
std::vector<TypeId> types; std::vector<TypeId> types;
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT}); std::vector<std::string> formats = {kOpFormat_DEFAULT};
if (node->isa<ValueNode>()) { if (node->isa<ValueNode>()) {
kernel_info->SetFeatureMapFlag(false); kernel_info->SetFeatureMapFlag(false);
types.emplace_back(kTypeUnknown); types.emplace_back(kTypeUnknown);
auto value_node = node->cast<ValueNodePtr>();
SyncDeviceInfoToValueNode(value_node, &formats, &types);
} }
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
auto parameter = node->cast<ParameterPtr>(); auto parameter = node->cast<ParameterPtr>();
...@@ -360,6 +394,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { ...@@ -360,6 +394,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0)); types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
} }
// set parameter initaial device data type // set parameter initaial device data type
kernel_build_info_builder->SetOutputsFormat(formats);
kernel_build_info_builder->SetOutputsDeviceType(types); kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get()); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
} }
......
...@@ -216,6 +216,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { ...@@ -216,6 +216,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
auto k_app = k_graph_->NewCNode(inputs); auto k_app = k_graph_->NewCNode(inputs);
TraceManager::EndTrace(); TraceManager::EndTrace();
ReplaceEquivdout(k_app, cnode_morph->forward());
for (size_t i = 0; i < param_adjoints.size(); ++i) { for (size_t i = 0; i < param_adjoints.size(); ++i) {
param_adjoints[i]->RegisterKUser(k_app, i); param_adjoints[i]->RegisterKUser(k_app, i);
} }
...@@ -237,6 +238,37 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { ...@@ -237,6 +238,37 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
return node_adjoint; return node_adjoint;
} }
void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward) {
if (forward == nullptr) {
return;
}
auto &input = cnode->input(0);
if (!IsValueNode<FuncGraph>(input)) {
return;
}
auto fg = GetValueNode<FuncGraphPtr>(input);
auto output = fg->output();
if (!output->isa<CNode>()) {
return;
}
auto cnode_output = output->cast<CNodePtr>();
auto &cnode_input = cnode_output->input(1);
if (!cnode_input->isa<CNode>()) {
return;
}
auto &input_fg = cnode_output->input(2);
if (!IsValueNode<FuncGraph>(input_fg)) {
return;
}
auto equivdout = cnode_input->cast<CNodePtr>();
auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
auto manager = Manage({fg, func_graph}, false);
MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
auto value_node = NewValueNode(forward);
value_node->set_has_new_value(true);
manager->Replace(equivdout, value_node);
}
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) { bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
// Do not care about non-CNode // Do not care about non-CNode
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
......
...@@ -95,6 +95,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> { ...@@ -95,6 +95,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
// Update k hole with adjoint_definition, only applied in recursive case. // Update k hole with adjoint_definition, only applied in recursive case.
void UpdateAdjoint(const AdjointPtr &adjoint_definition); void UpdateAdjoint(const AdjointPtr &adjoint_definition);
void CallDoutHoleOnTape(); void CallDoutHoleOnTape();
void ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward);
std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_; std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
// Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
......
...@@ -88,7 +88,9 @@ class GetitemConstEliminater : public AnfVisitor { ...@@ -88,7 +88,9 @@ class GetitemConstEliminater : public AnfVisitor {
AnfVisitor::Match(prim::kPrimListGetItem, {IsVNode, IsVNode})(node); AnfVisitor::Match(prim::kPrimListGetItem, {IsVNode, IsVNode})(node);
if (is_match_) { if (is_match_) {
return NewValueNode((*tuple_)[id_]); auto out = NewValueNode((*tuple_)[id_]);
out->set_has_new_value(has_new_value_);
return out;
} }
return nullptr; return nullptr;
} }
...@@ -96,6 +98,7 @@ class GetitemConstEliminater : public AnfVisitor { ...@@ -96,6 +98,7 @@ class GetitemConstEliminater : public AnfVisitor {
void Visit(const ValueNodePtr &vnode) override { void Visit(const ValueNodePtr &vnode) override {
if (IsValueNode<ValueTuple>(vnode)) { if (IsValueNode<ValueTuple>(vnode)) {
tuple_ = GetValueNode<ValueTuplePtr>(vnode); tuple_ = GetValueNode<ValueTuplePtr>(vnode);
has_new_value_ = vnode->has_new_value();
} }
if (tuple_ != nullptr && IsValueNode<Int32Imm>(vnode)) { if (tuple_ != nullptr && IsValueNode<Int32Imm>(vnode)) {
id_ = IntToSize(GetValue<int>(vnode->value())); id_ = IntToSize(GetValue<int>(vnode->value()));
...@@ -115,6 +118,7 @@ class GetitemConstEliminater : public AnfVisitor { ...@@ -115,6 +118,7 @@ class GetitemConstEliminater : public AnfVisitor {
bool is_match_{false}; bool is_match_{false};
size_t id_{0}; size_t id_{0};
ValueTuplePtr tuple_{nullptr}; ValueTuplePtr tuple_{nullptr};
bool has_new_value_{false};
}; };
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...) // setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
......
...@@ -205,7 +205,11 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { ...@@ -205,7 +205,11 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
return ToAbstract(value_node->value(), conf->context(), conf); auto out = ToAbstract(value_node->value(), conf->context(), conf);
if (value_node->has_new_value()) {
out = out->Broaden();
}
return out;
} }
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <unordered_set> #include <unordered_set>
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "ir/anf.h"
#include "ir/primitive_py.h" #include "ir/primitive_py.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
...@@ -51,6 +52,7 @@ struct OpExecInfo { ...@@ -51,6 +52,7 @@ struct OpExecInfo {
PrimitivePyPtr py_primitive; PrimitivePyPtr py_primitive;
std::string op_name; std::string op_name;
AbstractBasePtr abstract; AbstractBasePtr abstract;
ValuePtr value = nullptr;
py::tuple op_inputs; py::tuple op_inputs;
py::tuple inputs_mask; py::tuple inputs_mask;
......
...@@ -111,7 +111,7 @@ inline ValuePtr PyAttrValue(const py::object &obj) { ...@@ -111,7 +111,7 @@ inline ValuePtr PyAttrValue(const py::object &obj) {
return converted_ret; return converted_ret;
} }
std::string GetId(const py::object &obj) { static std::string GetId(const py::object &obj) {
py::object to_process = obj; py::object to_process = obj;
std::string prefix = ""; std::string prefix = "";
if (py::isinstance<py::tuple>(to_process)) { if (py::isinstance<py::tuple>(to_process)) {
...@@ -141,6 +141,11 @@ std::string GetId(const py::object &obj) { ...@@ -141,6 +141,11 @@ std::string GetId(const py::object &obj) {
return py::cast<std::string>(ret); return py::cast<std::string>(ret);
} }
static std::string GetOpId(const OpExecInfoPtr &op_exec_info) {
auto id = GetId(op_exec_info->py_primitive->GetPyObj());
return id;
}
py::object GetTupleObj(const py::object &obj) { py::object GetTupleObj(const py::object &obj) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj); py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
...@@ -317,6 +322,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) ...@@ -317,6 +322,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args)
} }
op_exec_info->py_primitive = prim; op_exec_info->py_primitive = prim;
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
op_exec_info->value = PynativeExecutor::GetInstance()->GetForwardValue(op_exec_info);
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask"; MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
return nullptr; return nullptr;
...@@ -606,7 +612,20 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn ...@@ -606,7 +612,20 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
return result; return result;
} }
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) { ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
auto id = GetOpId(op_exec_info);
auto op = id;
op.append(std::to_string(op_id_map_[id]));
auto iter = op_forward_map_.find(op);
if (iter != op_forward_map_.end()) {
++op_id_map_[id];
MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second;
return iter->second;
}
return nullptr;
}
CNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
if (!grad_flag_ || graph_info_map_.empty()) { if (!grad_flag_ || graph_info_map_.empty()) {
return nullptr; return nullptr;
} }
...@@ -645,6 +664,34 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const ...@@ -645,6 +664,34 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const
return cnode; return cnode;
} }
void PynativeExecutor::SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value) {
auto id = GetOpId(op_exec_info);
auto op = id;
op.append(std::to_string(op_id_map_[id]));
auto iter = op_forward_map_.find(op);
if (iter != op_forward_map_.end()) {
return;
}
op_forward_map_[op] = value;
++op_id_map_[id];
MS_LOG(DEBUG) << "Save: " << op_exec_info->op_name << "(" << op << "), " << value;
}
void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {
if (!grad_flag_ || op_exec_info->value != nullptr) {
return;
}
py::object out_real = out;
if (out.size() == 1) {
out_real = out[0];
}
auto value = PyAttrValue(out_real);
if (cnode != nullptr) {
cnode->set_forward(value);
}
SaveOpForwardValue(op_exec_info, value);
}
AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)]; auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)];
if (out.second.size() == 1 && out.second[0] == -1) { if (out.second.size() == 1 && out.second[0] == -1) {
...@@ -657,6 +704,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { ...@@ -657,6 +704,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
node = curr_g_->NewCNode(tuple_get_item_inputs); node = curr_g_->NewCNode(tuple_get_item_inputs);
} }
MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6); MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
node->cast<CNodePtr>()->set_forward(PyAttrValue(obj));
return node; return node;
} }
...@@ -690,11 +738,12 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) { ...@@ -690,11 +738,12 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
return err_ret; return err_ret;
} }
auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result); auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
if (node != nullptr) { if (cnode != nullptr) {
node->set_abstract(op_exec_info->abstract); cnode->set_abstract(op_exec_info->abstract);
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString(); MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << cnode->DebugString();
} }
PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode, result);
MS_LOG(DEBUG) << "RunOp end"; MS_LOG(DEBUG) << "RunOp end";
return result; return result;
} }
...@@ -1072,7 +1121,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje ...@@ -1072,7 +1121,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
void PynativeExecutor::Clear(const std::string &flag) { void PynativeExecutor::Clear(const std::string &flag) {
if (!flag.empty()) { if (!flag.empty()) {
MS_LOG(INFO) << "Clear res"; MS_LOG(DEBUG) << "Clear res";
(void)graph_map_.erase(flag); (void)graph_map_.erase(flag);
(void)cell_graph_map_.erase(flag); (void)cell_graph_map_.erase(flag);
Clean(); Clean();
...@@ -1084,17 +1133,19 @@ void PynativeExecutor::Clear(const std::string &flag) { ...@@ -1084,17 +1133,19 @@ void PynativeExecutor::Clear(const std::string &flag) {
return; return;
} }
MS_LOG(INFO) << "Clear"; MS_LOG(DEBUG) << "Clear";
top_g_ = nullptr; top_g_ = nullptr;
curr_g_ = nullptr; curr_g_ = nullptr;
graph_info_map_.clear(); graph_info_map_.clear();
op_id_map_.clear();
std::stack<FuncGraphPtr>().swap(graph_p_); std::stack<FuncGraphPtr>().swap(graph_p_);
} }
void PynativeExecutor::Clean() { void PynativeExecutor::Clean() {
MS_LOG(INFO) << "Clean all res"; MS_LOG(DEBUG) << "Clean all res";
Clear(); Clear();
grad_flag_ = false; grad_flag_ = false;
op_forward_map_.clear();
df_builder_ = nullptr; df_builder_ = nullptr;
ad::CleanRes(); ad::CleanRes();
pipeline::ReclaimOptimizer(); pipeline::ReclaimOptimizer();
......
...@@ -95,7 +95,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { ...@@ -95,7 +95,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) { void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index); graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
} }
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out); CNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
void SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value);
void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
py::object Run(const py::tuple &args, const py::object &phase); py::object Run(const py::tuple &args, const py::object &phase);
void Pushp(); void Pushp();
...@@ -116,6 +120,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { ...@@ -116,6 +120,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::unordered_map<std::string, FuncGraphPtr> graph_map_; std::unordered_map<std::string, FuncGraphPtr> graph_map_;
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_; std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_; std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
std::unordered_map<std::string, ValuePtr> op_forward_map_;
std::unordered_map<std::string, size_t> op_id_map_;
std::stack<FuncGraphPtr> graph_p_; std::stack<FuncGraphPtr> graph_p_;
FuncGraphPtr top_g_; FuncGraphPtr top_g_;
FuncGraphPtr df_builder_; FuncGraphPtr df_builder_;
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/oplib/oplib.h" #include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/optimizer/common/helper.h"
#include "ir/value.h" #include "ir/value.h"
using mindspore::kernel::Address; using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr; using mindspore::kernel::AddressPtr;
...@@ -150,11 +151,13 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) { ...@@ -150,11 +151,13 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
UpdateRefNodeOutputMem(graph); UpdateRefNodeOutputMem(graph);
} }
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value,
const std::vector<tensor::TensorPtr> &input_tensors,
session::KernelGraph *graph) { session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
RunOpAssignInputMemory(input_tensors, graph); RunOpAssignInputMemory(input_tensors, graph);
AssignStaticMemoryValueNode(graph); AssignStaticMemoryValueNode(graph);
RunOpAssignOutputNodeMemory(pre_output_value, graph);
for (const auto &cnode : graph->execution_order()) { for (const auto &cnode : graph->execution_order()) {
RunOpAssignOutputMemory(cnode); RunOpAssignOutputMemory(cnode);
RunOpAssignWorkSpaceMemory(cnode); RunOpAssignWorkSpaceMemory(cnode);
...@@ -322,6 +325,45 @@ void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) { ...@@ -322,6 +325,45 @@ void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
} }
} }
void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph) {
if (pre_output_value == nullptr) {
return;
}
std::vector<tensor::TensorPtr> pre_output_tensors;
TensorValueToTensor(pre_output_value, &pre_output_tensors);
MS_EXCEPTION_IF_NULL(graph);
auto output_nodes = graph->outputs();
if (pre_output_tensors.size() != output_nodes.size()) {
MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size()
<< "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]";
}
// share output address with pre output tensors
for (size_t i = 0; i < output_nodes.size(); ++i) {
auto output_node_with_index = AnfAlgo::VisitKernel(output_nodes[i], 0);
if (!output_node_with_index.first->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The output node should be a cnode , but it is "
<< output_node_with_index.first->DebugString();
}
auto real_output_cnode = output_node_with_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_output_cnode);
MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
if (pre_output_tensors[i]->device_address() == nullptr) {
MS_LOG(EXCEPTION) << "The address of pre output tensor [" << i << "] is a nullptr!";
}
if (opt::IsNopNode(real_output_cnode)) {
if (real_output_cnode->inputs().size() < 2) {
MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString()
<< " should large than one!";
}
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
output_node_with_index.second, real_output_cnode->input(1).get());
} else {
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
output_node_with_index.second, output_node_with_index.first.get());
}
}
}
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(mem_manager_); MS_EXCEPTION_IF_NULL(mem_manager_);
...@@ -573,11 +615,18 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ...@@ -573,11 +615,18 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
MS_EXCEPTION_IF_NULL(mem_manager_); MS_EXCEPTION_IF_NULL(mem_manager_);
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
auto tensor = node_value->cast<TensorPtr>(); std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (const auto &tensor : tensors) {
if (tensor == nullptr) { if (tensor == nullptr) {
MS_LOG(WARNING) << "Tensor is null"; MS_LOG(WARNING) << "Tensor is null";
return; return;
} }
if (tensor->device_address() != nullptr) {
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
continue;
}
size_t tensor_size = tensor->data().nbytes(); size_t tensor_size = tensor->data().nbytes();
auto node_size = CountNodeDeviceMemorySize(value_node, output_idx); auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
...@@ -596,9 +645,10 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ...@@ -596,9 +645,10 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
tensor->data_c())) { tensor->data_c())) {
MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is" MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString()
<< AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is " << "node format is" << AnfAlgo::GetOutputFormat(value_node, output_idx)
<< AnfAlgo::GetOutputInferDataType(value_node, output_idx); << "node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
}
} }
} }
...@@ -615,7 +665,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { ...@@ -615,7 +665,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
} }
auto &node_value = value_node->value(); auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value); MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<Tensor>()) { if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>()) {
AssignValueNodeTensor(value_node, node_value, 0); AssignValueNodeTensor(value_node, node_value, 0);
} else if (node_value->isa<StringImm>()) { } else if (node_value->isa<StringImm>()) {
auto value = GetValue<std::string>(node_value); auto value = GetValue<std::string>(node_value);
......
...@@ -53,7 +53,8 @@ class KernelRuntime { ...@@ -53,7 +53,8 @@ class KernelRuntime {
virtual ~KernelRuntime(); virtual ~KernelRuntime();
virtual bool Init() = 0; virtual bool Init() = 0;
virtual void AssignMemory(session::KernelGraph *graph); virtual void AssignMemory(session::KernelGraph *graph);
void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph); void RunOpAssignMemory(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
session::KernelGraph *graph);
void RunOpClearMemory(const session::KernelGraph *graph); void RunOpClearMemory(const session::KernelGraph *graph);
bool DumpDataEnabled(); bool DumpDataEnabled();
bool DumpDataEnabledIteration(); bool DumpDataEnabledIteration();
...@@ -108,6 +109,7 @@ class KernelRuntime { ...@@ -108,6 +109,7 @@ class KernelRuntime {
void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph); void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
void RunOpAssignOutputMemory(const AnfNodePtr &kernel); void RunOpAssignOutputMemory(const AnfNodePtr &kernel);
void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph);
void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
......
...@@ -607,4 +607,25 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) { ...@@ -607,4 +607,25 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
return tensor; return tensor;
} }
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors) {
MS_EXCEPTION_IF_NULL(value);
MS_EXCEPTION_IF_NULL(tensors);
if (value->isa<ValueTuple>()) {
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
for (size_t i = 0; i < value_tuple->size(); ++i) {
ValuePtr element = value_tuple->value()[i];
if (element->isa<tensor::Tensor>()) {
auto tensor = element->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
tensors->push_back(tensor);
}
}
} else if (value->isa<tensor::Tensor>()) {
tensor::TensorPtr tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
tensors->push_back(tensor);
}
}
} // namespace mindspore } // namespace mindspore
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <stack> #include <stack>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -69,6 +70,8 @@ using NodeMapEquiv = std::unordered_map<AnfNodePtr, AnfNodePtr>; ...@@ -69,6 +70,8 @@ using NodeMapEquiv = std::unordered_map<AnfNodePtr, AnfNodePtr>;
bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node); bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node);
tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar); tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar);
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_ #endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_
...@@ -50,8 +50,13 @@ using BaseShapePtr = std::shared_ptr<abstract::BaseShape>; ...@@ -50,8 +50,13 @@ using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>; using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
using AbstractBasePtrList = std::vector<AbstractBasePtr>; using AbstractBasePtrList = std::vector<AbstractBasePtr>;
class Value;
using ValuePtr = std::shared_ptr<Value>;
using ValuePtrList = std::vector<ValuePtr>;
class ValueNode; class ValueNode;
using ValueNodePtr = std::shared_ptr<ValueNode>; using ValueNodePtr = std::shared_ptr<ValueNode>;
class CNode; class CNode;
using CNodePtr = std::shared_ptr<CNode>; using CNodePtr = std::shared_ptr<CNode>;
...@@ -225,6 +230,9 @@ class CNode : public AnfNode { ...@@ -225,6 +230,9 @@ class CNode : public AnfNode {
void set_input(size_t i, const AnfNodePtr &input); void set_input(size_t i, const AnfNodePtr &input);
void set_inputs(const std::vector<AnfNodePtr> &inputs) { inputs_ = inputs; } void set_inputs(const std::vector<AnfNodePtr> &inputs) { inputs_ = inputs; }
void set_forward(const ValuePtr &forward) { forward_ = forward; }
const ValuePtr &forward() const { return forward_; }
bool stop_gradient() const { return stop_gradient_; } bool stop_gradient() const { return stop_gradient_; }
void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; } void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
...@@ -243,6 +251,7 @@ class CNode : public AnfNode { ...@@ -243,6 +251,7 @@ class CNode : public AnfNode {
VarPtr func_graph_as_var_; VarPtr func_graph_as_var_;
bool stop_gradient_; bool stop_gradient_;
bool in_forward_flag_ = false; bool in_forward_flag_ = false;
ValuePtr forward_ = nullptr;
}; };
// ANode represents the atomic node. It's derived Parameter and ValueNode. // ANode represents the atomic node. It's derived Parameter and ValueNode.
...@@ -321,8 +330,6 @@ class Value : public Base { ...@@ -321,8 +330,6 @@ class Value : public Base {
protected: protected:
TypePtr type_{nullptr}; TypePtr type_{nullptr};
}; };
using ValuePtr = std::shared_ptr<Value>;
using ValuePtrList = std::vector<ValuePtr>;
// ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode // ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode
// does not belong to any particular function graph. // does not belong to any particular function graph.
...@@ -333,9 +340,13 @@ class ValueNode : public ANode { ...@@ -333,9 +340,13 @@ class ValueNode : public ANode {
MS_DECLARE_PARENT(ValueNode, ANode); MS_DECLARE_PARENT(ValueNode, ANode);
void accept(AnfIrVisitor *v) override; void accept(AnfIrVisitor *v) override;
void set_value(const ValuePtr &value) { value_ = value; }
const ValuePtr &value() const { return value_; } const ValuePtr &value() const { return value_; }
std::string fullname_with_scope() override; std::string fullname_with_scope() override;
void set_has_new_value(bool flag) { has_new_value_ = flag; }
bool has_new_value() const { return has_new_value_; }
std::string ToString() const override; std::string ToString() const override;
std::string DebugString(int recursive_level = 1) const override; std::string DebugString(int recursive_level = 1) const override;
std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
...@@ -355,6 +366,7 @@ class ValueNode : public ANode { ...@@ -355,6 +366,7 @@ class ValueNode : public ANode {
private: private:
ValuePtr value_; ValuePtr value_;
bool has_new_value_ = false;
}; };
template <typename T> template <typename T>
......
...@@ -88,6 +88,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { ...@@ -88,6 +88,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target); CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
auto old_node = node->cast<CNodePtr>(); auto old_node = node->cast<CNodePtr>();
new_node->set_abstract(old_node->abstract()); new_node->set_abstract(old_node->abstract());
new_node->set_forward(old_node->forward());
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope); new_node->set_scope(scope);
new_node->set_kernel_info(old_node->kernel_info_ptr()); new_node->set_kernel_info(old_node->kernel_info_ptr());
...@@ -103,6 +104,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) { ...@@ -103,6 +104,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) {
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_const->set_scope(scope); new_const->set_scope(scope);
new_const->set_abstract(node->abstract()); new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
repl_node_[node] = new_const; repl_node_[node] = new_const;
TraceManager::EndTrace(); TraceManager::EndTrace();
} }
...@@ -115,6 +117,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) ...@@ -115,6 +117,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target)
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_const->set_scope(scope); new_const->set_scope(scope);
new_const->set_abstract(node->abstract()); new_const->set_abstract(node->abstract());
new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
repl_node_[node] = new_const; repl_node_[node] = new_const;
TraceManager::EndTrace(); TraceManager::EndTrace();
} }
......
...@@ -19,7 +19,7 @@ from mindspore.ops.composite import grad, grad_all, grad_all_with_sens ...@@ -19,7 +19,7 @@ from mindspore.ops.composite import grad, grad_all, grad_all_with_sens
def setup_module(module): def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE) context.set_context(mode=context.PYNATIVE_MODE, check_bprop=False)
def single(x): def single(x):
......
...@@ -554,9 +554,7 @@ def softmax_cross_entropy_with_logits(logits, labels): ...@@ -554,9 +554,7 @@ def softmax_cross_entropy_with_logits(logits, labels):
sample_num = labels.shape[0] sample_num = labels.shape[0]
prob = softmax(logits) prob = softmax(logits)
log_likelihood = -np.log(prob[range(sample_num)]) * labels log_likelihood = -np.log(prob[range(sample_num)]) * labels
# loss = np.sum(log_likelihood) loss = np.sum(log_likelihood)
loss = log_likelihood
dx = prob.copy() dx = prob.copy()
dx[range(sample_num)] -= labels dx[range(sample_num)] -= labels
return loss, dx return loss, dx
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册