提交 64abbeaa 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!705 add pynative cache

Merge pull request !705 from chujinjin/add_pynative_cache
......@@ -22,7 +22,7 @@ namespace mindspore {
namespace device {
namespace ascend {
const uint64_t kAscendDeviceMemGB = 20;
const uint64_t kAscendMemPoolGB = 5;
const uint64_t kAscendMemPoolGB = 10;
const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30);
const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30);
......
......@@ -38,6 +38,7 @@
#include "parallel/graph_util/get_parallel_info.h"
#include "device/kernel_runtime_manager.h"
#include "debug/trace.h"
#include "pynative/pynative_execute.h"
#if (ENABLE_GE || ENABLE_D)
#include "pipeline/pipeline_ge.h"
......@@ -829,6 +830,7 @@ void FinalizeBackend() {
void ClearResAtexit() {
MS_LOG(DEBUG) << "Pipeline clear all resource";
pynative::ClearPyNativeSession();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
ad::g_k_prims.clear();
......
......@@ -44,6 +44,7 @@ const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "ze
namespace mindspore {
namespace pynative {
static std::shared_ptr<session::SessionBasic> session = nullptr;
inline ValuePtr PyAttrValue(const py::object &obj) {
ValuePtr converted_ret = nullptr;
bool converted = parse::ConvertData(obj, &converted_ret);
......@@ -310,7 +311,11 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
if (device_target != kAscendDevice && device_target != kGPUDevice) {
MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
}
std::shared_ptr<session::SessionBasic> session = session::SessionFactory::Get().Create(device_target);
if (session == nullptr) {
session = session::SessionFactory::Get().Create(device_target);
}
MS_EXCEPTION_IF_NULL(session);
session->Init(ms_context->device_id());
......@@ -407,5 +412,7 @@ py::tuple RunOp(const py::args &args) {
MS_LOG(INFO) << "RunOp end";
return result;
}
void ClearPyNativeSession() { session = nullptr; }
} // namespace pynative
} // namespace mindspore
......@@ -36,6 +36,9 @@ namespace py = pybind11;
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::tuple RunOp(const py::args &args);
void ClearPyNativeSession();
} // namespace pynative
} // namespace mindspore
......
......@@ -249,10 +249,23 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
MS_LOG(INFO) << "Finish!";
}
bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
return true;
}
return false;
}
void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors,
const std::vector<bool> &tensors_mask) {
MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
if (GraphCacheExist(graph_info)) {
MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
return;
}
// construct graph include one op
auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(graph);
......@@ -267,6 +280,7 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph
RunOpAdjustKernel(graph);
BuildKernel(graph);
run_op_graphs_[graph_info] = graph;
MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
}
py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
......@@ -291,7 +305,6 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
}
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
run_op_graphs_.clear();
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
return tuple_tensors;
}
......
......@@ -111,6 +111,8 @@ class AscendSession : public SessionBasic {
std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
// copy output of if and else
void CopyOutputOfIf(GraphId false_graph_id);
// check if graph cache exist
bool GraphCacheExist(const GraphInfo &graph_info) const;
// member variables
// key is final_graph_id,value is child graph execute order of final graph
......
......@@ -125,7 +125,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
// if in paynative mode,data only copyed to host when user want to print data
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->enable_pynative_infer()) {
if (ms_context->execution_mode() == kPynativeMode) {
tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
} else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册