Commit 64abbeaa authored by mindspore-ci-bot, committed by Gitee

!705 add pynative cache

Merge pull request !705 from chujinjin/add_pynative_cache
@@ -22,7 +22,7 @@ namespace mindspore {
 namespace device {
 namespace ascend {
 const uint64_t kAscendDeviceMemGB = 20;
-const uint64_t kAscendMemPoolGB = 5;
+const uint64_t kAscendMemPoolGB = 10;
 const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);
 const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);
...
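The only change in this hunk doubles kAscendMemPoolGB from 5 to 10; the byte sizes below it follow automatically because a gigabyte count is converted to bytes by shifting it left 30 bits. A minimal standalone sketch of that conversion (the constant names here are illustrative, not the ones in the repository):

```cpp
#include <cstdint>
#include <iostream>

// 1 GB = 2^30 bytes, so shifting a GB count left by 30 bits yields a byte count.
constexpr uint64_t kDeviceMemGB = 20;
constexpr uint64_t kMemPoolGB = 10;  // raised from 5 by this commit
constexpr uint64_t kDeviceMemSize = kDeviceMemGB << 30;  // 21474836480 bytes
constexpr uint64_t kMemPoolSize = kMemPoolGB << 30;      // 10737418240 bytes

int main() {
  std::cout << "device memory: " << kDeviceMemSize << " bytes\n"
            << "memory pool:   " << kMemPoolSize << " bytes\n";
}
```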
@@ -38,6 +38,7 @@
 #include "parallel/graph_util/get_parallel_info.h"
 #include "device/kernel_runtime_manager.h"
 #include "debug/trace.h"
+#include "pynative/pynative_execute.h"
 #if (ENABLE_GE || ENABLE_D)
 #include "pipeline/pipeline_ge.h"
@@ -829,6 +830,7 @@ void FinalizeBackend() {
 void ClearResAtexit() {
   MS_LOG(DEBUG) << "Pipeline clear all resource";
+  pynative::ClearPyNativeSession();
   device::KernelRuntimeManager::Instance().ClearRuntimeResource();
   ad::g_k_prims.clear();
...
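The new include and the pynative::ClearPyNativeSession() call make the process-exit cleanup drop the cached PyNative session before the kernel runtime resources are released. A minimal sketch of this teardown-ordering pattern, using hypothetical stand-in types rather than the real MindSpore classes:

```cpp
#include <iostream>
#include <memory>

// Hypothetical stand-ins for the cached session and the runtime manager.
struct Session {
  ~Session() { std::cout << "session released\n"; }
};
static std::shared_ptr<Session> g_session;

void ClearPyNativeSession() { g_session = nullptr; }  // mirrors the call added above

void ClearRuntimeResource() { std::cout << "runtime resources cleared\n"; }

// Sketch of ClearResAtexit(): first drop objects that hold device handles,
// then tear down the runtime they were created from.
void ClearResAtexit() {
  ClearPyNativeSession();
  ClearRuntimeResource();
}

int main() {
  g_session = std::make_shared<Session>();
  ClearResAtexit();  // prints "session released" then "runtime resources cleared"
}
```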
@@ -44,6 +44,7 @@ const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "ze
 namespace mindspore {
 namespace pynative {
+static std::shared_ptr<session::SessionBasic> session = nullptr;
 inline ValuePtr PyAttrValue(const py::object &obj) {
   ValuePtr converted_ret = nullptr;
   bool converted = parse::ConvertData(obj, &converted_ret);
@@ -310,7 +311,11 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   if (device_target != kAscendDevice && device_target != kGPUDevice) {
     MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
   }
-  std::shared_ptr<session::SessionBasic> session = session::SessionFactory::Get().Create(device_target);
+  if (session == nullptr) {
+    session = session::SessionFactory::Get().Create(device_target);
+  }
   MS_EXCEPTION_IF_NULL(session);
   session->Init(ms_context->device_id());
@@ -407,5 +412,7 @@ py::tuple RunOp(const py::args &args) {
   MS_LOG(INFO) << "RunOp end";
   return result;
 }
+void ClearPyNativeSession() { session = nullptr; }
 } // namespace pynative
 } // namespace mindspore
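The RunOpInMs change above replaces a per-call session construction with a file-scope session that is created on first use, reused for every subsequent op, and reset by the new ClearPyNativeSession(). A self-contained sketch of that lazy-init-and-reuse pattern follows; the Session type and GetSession helper are illustrative stand-ins, not MindSpore APIs:

```cpp
#include <iostream>
#include <memory>
#include <string>

// Illustrative stand-in for session::SessionBasic.
struct Session {
  explicit Session(std::string target) : target_(std::move(target)) {
    std::cout << "creating session for " << target_ << "\n";
  }
  std::string target_;
};

// File-scope cache: constructed lazily, reused across op calls.
static std::shared_ptr<Session> session = nullptr;

std::shared_ptr<Session> GetSession(const std::string &device_target) {
  if (session == nullptr) {
    session = std::make_shared<Session>(device_target);  // built only once
  }
  return session;
}

void ClearPyNativeSession() { session = nullptr; }  // explicit teardown at exit

int main() {
  GetSession("Ascend");  // prints "creating session for Ascend"
  GetSession("Ascend");  // cache hit: no new construction
  ClearPyNativeSession();
}
```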
@@ -36,6 +36,9 @@ namespace py = pybind11;
 py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
 py::tuple RunOp(const py::args &args);
+void ClearPyNativeSession();
 } // namespace pynative
 } // namespace mindspore
...
@@ -249,10 +249,23 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
   MS_LOG(INFO) << "Finish!";
 }
+bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
+  if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
+    return true;
+  }
+  return false;
+}
 void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                             const std::vector<tensor::TensorPtr> &input_tensors,
                             const std::vector<bool> &tensors_mask) {
   MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
+  if (GraphCacheExist(graph_info)) {
+    MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
+    return;
+  }
   // construct graph include one op
   auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
   MS_EXCEPTION_IF_NULL(graph);
@@ -267,6 +280,7 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph
   RunOpAdjustKernel(graph);
   BuildKernel(graph);
   run_op_graphs_[graph_info] = graph;
+  MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
 }
 py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
@@ -291,7 +305,6 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
   }
   py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
   py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
-  run_op_graphs_.clear();
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
   return tuple_tensors;
 }
...
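BuildOp now checks run_op_graphs_ before constructing a graph, and RunOp no longer clears that map after every call, so a repeated single op reuses its previously built graph instead of compiling it again. A minimal sketch of the lookup-before-build pattern, with GraphInfo and KernelGraph reduced to hypothetical stand-ins:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Illustrative stand-ins: GraphInfo is a string key, KernelGraph a built graph.
using GraphInfo = std::string;
struct KernelGraph {
  GraphInfo info;
};

class OpGraphCache {
 public:
  bool GraphCacheExist(const GraphInfo &graph_info) const {
    return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
  }

  // Build the single-op graph only when it is not already cached.
  std::shared_ptr<KernelGraph> BuildOp(const GraphInfo &graph_info) {
    if (GraphCacheExist(graph_info)) {
      return run_op_graphs_[graph_info];  // cache hit: reuse the built graph
    }
    auto graph = std::make_shared<KernelGraph>();
    graph->info = graph_info;
    std::cout << "built graph for " << graph_info << "\n";
    run_op_graphs_[graph_info] = graph;
    return graph;
  }

 private:
  std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
};

int main() {
  OpGraphCache cache;
  cache.BuildOp("Add-fp16-[2,3]");  // builds and caches
  cache.BuildOp("Add-fp16-[2,3]");  // cache hit, no rebuild
}
```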
@@ -111,6 +111,8 @@ class AscendSession : public SessionBasic {
   std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
   // copy output of if and else
   void CopyOutputOfIf(GraphId false_graph_id);
+  // check if graph cache exist
+  bool GraphCacheExist(const GraphInfo &graph_info) const;
   // member variables
   // key is final_graph_id,value is child graph execute order of final graph
...
@@ -125,7 +125,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
   // if in paynative mode,data only copyed to host when user want to print data
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
-  if (ms_context->enable_pynative_infer()) {
+  if (ms_context->execution_mode() == kPynativeMode) {
     tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
   } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
                                         LongToSize(tensor->data().nbytes()), tensor->data_type(),
...
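In this hunk the output tensor keeps a reference to its device address in PyNative mode, so data is copied to host only if the user later reads it; the condition now looks at the context's execution mode rather than the enable_pynative_infer flag. A small sketch of that lazy-copy decision, again with illustrative stand-in types rather than the real ones:

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Illustrative stand-ins for the execution mode and a device-side buffer.
enum ExecutionMode { kGraphMode, kPynativeMode };

struct DeviceAddress {
  std::vector<float> device_data{1.f, 2.f, 3.f};
  bool SyncDeviceToHost(std::vector<float> *host) const {
    *host = device_data;  // eager copy, used in graph mode
    return true;
  }
};

struct Tensor {
  std::vector<float> host_data;
  std::shared_ptr<DeviceAddress> device_address;  // lazy handle in PyNative mode
};

// Sketch of the decision above: keep only the device handle in PyNative mode,
// otherwise synchronize to host immediately.
void FillOutput(ExecutionMode mode, const std::shared_ptr<DeviceAddress> &addr, Tensor *tensor) {
  if (mode == kPynativeMode) {
    tensor->device_address = addr;  // copied later, only if the data is actually read
  } else if (!addr->SyncDeviceToHost(&tensor->host_data)) {
    std::cout << "sync failed\n";
  }
}

int main() {
  auto addr = std::make_shared<DeviceAddress>();
  Tensor t;
  FillOutput(kPynativeMode, addr, &t);
  std::cout << "host bytes copied so far: " << t.host_data.size() << "\n";  // 0, no copy yet
}
```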