提交 dde03ce9 编写于 作者: C chujinjin

add async ops execute for pynative

上级 f6b5b273
...@@ -92,10 +92,29 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s ...@@ -92,10 +92,29 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
return true; return true;
} }
// Synchronizes the stream of the Ascend kernel runtime that belongs to the
// device id recorded in the global MsContext.
//
// A null context, a null runtime, or a failed synchronization is reported
// through the MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION) macros.
void AscendDeviceAddress::SyncStream() const {
  MS_LOG(INFO) << "Start!";

  // Resolve the device id from the global context.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_id = ms_context->device_id();

  // Fetch the kernel runtime registered for that Ascend device.
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
  MS_EXCEPTION_IF_NULL(runtime_instance);

  // Ask the runtime to synchronize its stream; a false result is an error.
  const auto stream_synced = runtime_instance->SyncStream();
  if (!stream_synced) {
    MS_LOG(EXCEPTION) << "Sync stream error!";
  }

  MS_LOG(INFO) << "Finish!";
}
bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t size, mindspore::TypeId type, bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t size, mindspore::TypeId type,
void *host_ptr) const { void *host_ptr) const {
MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_)
<< ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
SyncStream();
}
bool sync_ok = false; bool sync_ok = false;
std::vector<size_t> host_shape; std::vector<size_t> host_shape;
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
......
...@@ -44,6 +44,7 @@ class AscendDeviceAddress : public DeviceAddress { ...@@ -44,6 +44,7 @@ class AscendDeviceAddress : public DeviceAddress {
bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const; bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
const void *host_ptr) const; const void *host_ptr) const;
void SyncStream() const;
}; };
using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>; using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
} // namespace ascend } // namespace ascend
......
...@@ -41,12 +41,12 @@ class AscendKernelRuntime : public KernelRuntime { ...@@ -41,12 +41,12 @@ class AscendKernelRuntime : public KernelRuntime {
bool RunTask(const session::KernelGraph *graph) override; bool RunTask(const session::KernelGraph *graph) override;
bool LoadTask(const session::KernelGraph *graph) override; bool LoadTask(const session::KernelGraph *graph) override;
void ClearGraphRuntimeResource(uint32_t graph_id) override; void ClearGraphRuntimeResource(uint32_t graph_id) override;
bool SyncStream() override;
protected: protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override; TypeId type_id) override;
bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override; bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
bool SyncStream() override;
private: private:
bool InitDevice(); bool InitDevice();
......
...@@ -28,6 +28,13 @@ namespace device { ...@@ -28,6 +28,13 @@ namespace device {
namespace gpu { namespace gpu {
bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const { bool GPUDeviceAddress::SyncDeviceToHost(const std::vector<int> &, size_t size, TypeId, void *host_ptr) const {
MS_EXCEPTION_IF_NULL(host_ptr); MS_EXCEPTION_IF_NULL(host_ptr);
auto &stream = GPUDeviceManager::GetInstance().default_stream();
MS_EXCEPTION_IF_NULL(stream);
auto ret = GPUDeviceManager::GetInstance().SyncStream(stream);
if (!ret) {
MS_LOG(ERROR) << "SyncStream failed";
return ret;
}
if (size != size_) { if (size != size_) {
MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_; MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_;
return true; return true;
......
...@@ -680,10 +680,6 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) { ...@@ -680,10 +680,6 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) {
MS_LOG(ERROR) << "LaunchKernelMod failed!"; MS_LOG(ERROR) << "LaunchKernelMod failed!";
return false; return false;
} }
if (!SyncStream()) {
MS_LOG(ERROR) << "SyncStream failed!";
return false;
}
return true; return true;
} }
......
...@@ -55,6 +55,7 @@ class KernelRuntime { ...@@ -55,6 +55,7 @@ class KernelRuntime {
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
virtual void ClearGraphRuntimeResource(uint32_t graph_id); virtual void ClearGraphRuntimeResource(uint32_t graph_id);
virtual bool SyncStream() = 0;
#ifdef ENABLE_DUMP_E2E #ifdef ENABLE_DUMP_E2E
DumpConfPtr GetDumpConf(); DumpConfPtr GetDumpConf();
...@@ -68,7 +69,6 @@ class KernelRuntime { ...@@ -68,7 +69,6 @@ class KernelRuntime {
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) = 0; TypeId type_id) = 0;
virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index); virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
virtual bool SyncStream() = 0;
void AssignStaticMemory(session::KernelGraph *graph); void AssignStaticMemory(session::KernelGraph *graph);
void AssignDynamicMemory(session::KernelGraph *graph); void AssignDynamicMemory(session::KernelGraph *graph);
void ReuseAssignDynamicMemory(session::KernelGraph *graph); void ReuseAssignDynamicMemory(session::KernelGraph *graph);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册