Commit ee1510da authored by: H He Wei

Eliminate circular dependency between 'ir' and 'device/kernel'

Parent c99cc0df
......@@ -38,7 +38,7 @@ void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr
}
std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) {
auto kernel_info = apply_kernel->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(kernel_build_Info);
......
......@@ -137,7 +137,7 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &
}
GpuKernel *GpuKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) {
auto kernel_info = apply_kernel->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(kernel_build_Info);
......
......@@ -63,7 +63,7 @@ const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr
kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type,
TypeId output_type) {
MS_EXCEPTION_IF_NULL(cast);
auto kernel_info = cast->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(cast->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto cast_build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(cast_build_info);
......
......@@ -23,8 +23,8 @@ namespace {
bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
auto main_kernel_info = main->kernel_info();
auto node_kernel_info = node->kernel_info();
auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info());
auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
if (main_kernel_info == nullptr && node_kernel_info == nullptr) {
return true;
}
......
......@@ -338,7 +338,7 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
if (!AnfAlgo::IsRealKernel(node)) {
return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
......@@ -360,7 +360,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
if (!IsRealKernel(node)) {
GetPrevNodeOutputFormat(node, input_idx);
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
......@@ -467,7 +467,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
if (!IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, input_idx);
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
......@@ -486,7 +486,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod
if (!IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, output_idx);
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
......@@ -546,7 +546,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
if (!IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, output_idx);
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
......@@ -567,7 +567,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
if (!IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, 0);
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
......@@ -597,7 +597,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
}
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetOutputAddr(output_idx);
if (addr == nullptr) {
......@@ -619,7 +619,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
}
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetMutableOutputAddr(output_idx);
if (addr == nullptr) {
......@@ -636,7 +636,7 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
<< GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->OutputAddrExist(output_idx);
}
......@@ -656,7 +656,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode
// set output device addr of anf_node
void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
if (!kernel_info->SetOutputAddr(addr, output_idx)) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
......@@ -666,7 +666,7 @@ void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t out
// set workspace device addr of anf_node
void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
......@@ -676,7 +676,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t
// get workspace device addr of anf_node
DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto addr = kernel_info->GetWorkspaceAddr(output_idx);
if (addr == nullptr) {
......@@ -720,7 +720,7 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_
kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
// select_kernel_build_info() has checked whether return pointer is null
auto build_info = kernel_info->select_kernel_build_info();
......@@ -731,7 +731,7 @@ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
// get KernelBuildType of node, such as ATT,RT,FWK and so on
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
// select_kernel_build_info() has checked whether return pointer is null
auto build_info = kernel_info->select_kernel_build_info();
......@@ -741,7 +741,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
......@@ -750,7 +750,7 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
......@@ -760,7 +760,7 @@ kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
// set select kernel_build_info
void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
}
......@@ -768,7 +768,7 @@ void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &sel
// get select kernel_build_info
// Get the selected (mutable) kernel build info from the node's device-side
// KernelInfo. The stale pre-change declaration of `kernel_info` (diff
// residue causing a duplicate declaration) is removed; the downcast form
// introduced by this commit is kept.
KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->GetMutableSelectKernelBuildInfo();
}
......@@ -776,7 +776,7 @@ KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePt
// get kernelMode
// Get the kernel module (KernelMod) attached to the node.
// Removed the stale pre-change `node->kernel_info()` declaration (diff
// residue: duplicate declaration of `kernel_info` would not compile).
KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // Downcast from the abstract KernelInfoDevice to device::KernelInfo.
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->MutableKernelMod();
}
......@@ -784,7 +784,7 @@ KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
// set kernel mod
void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
kernel_info->set_kernel_mod(kernel_mod);
}
......@@ -850,42 +850,42 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
kernel_info->set_stream_id(stream_id);
}
uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->stream_id();
}
void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
kernel_info->set_stream_distinction_label(stream_label);
}
// Get the stream-distinction label from the node's device-side KernelInfo.
// The node is const here, so the downcast must target a pointer-to-const.
// Removed the stale pre-change `node->kernel_info()` declaration (diff
// residue: duplicate declaration of `kernel_info`).
uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->stream_distinction_label();
}
void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
kernel_info->set_graph_id(graph_id);
}
// Get the owning graph id from the node's device-side KernelInfo.
// const node ⇒ downcast to pointer-to-const. Removed the stale pre-change
// `node->kernel_info()` declaration (diff residue: duplicate declaration).
uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->graph_id();
}
......@@ -913,7 +913,7 @@ bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
if (node->isa<ValueNode>()) {
return false;
}
auto kernel_info = node->kernel_info();
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->is_feature_map();
}
......
......@@ -38,6 +38,8 @@ namespace mindspore {
namespace session {
using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
class AnfRuntimeAlgorithm {
public:
// get input_anf_node's real kernel by recurse
......
......@@ -121,7 +121,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
auto pk_node = input_node->cast<ParameterPtr>();
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
auto tensor_address = tensor->device_address();
auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
bool need_sync = false;
if (ms_context->enable_pynative_infer()) {
if (tensor_address == nullptr || tensor_address != device_address) {
......
......@@ -230,13 +230,14 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
// set the kernel info of parameter
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(input_tensor);
if (input_tensor->device_address().get() == nullptr) {
auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
if (device_address == nullptr) {
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
} else {
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()});
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
}
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
// construct abstract of parameter
......@@ -319,7 +320,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
node_graph->IsFinalOutputKernel(ref_real_node)) {
auto kernel_info = ref_real_node->kernel_info();
if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) {
if (kernel_info == nullptr || !kernel_info->has_build_info()) {
MS_LOG(INFO) << "No kernel info";
return;
}
......@@ -330,9 +331,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
}
auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
auto d_kernel_info = parameter->kernel_info();
auto d_kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(d_kernel_info);
parameter->set_kernel_info(d_kernel_info);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetOutputsDeviceType({type});
builder.SetOutputsFormat({format});
......
......@@ -128,7 +128,7 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo>
return;
}
auto kernel_info = node->kernel_info();
if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) {
if (kernel_info == nullptr || !kernel_info->has_build_info()) {
return;
}
......@@ -179,7 +179,7 @@ void DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMa
// print parameters' type and shape
PrintNodeOutputType(buffer, p);
auto kernel_info = p->kernel_info();
if (kernel_info != nullptr && kernel_info->select_kernel_build_info() != nullptr) {
if (kernel_info != nullptr && kernel_info->has_build_info()) {
buffer << " : ";
auto type = AnfAlgo::GetOutputDeviceDataType(p, 0);
auto format = AnfAlgo::GetOutputFormat(p, 0);
......
......@@ -362,8 +362,7 @@ void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNod
continue;
}
for (auto &node_user : iter->second) {
if (node_user.first->kernel_info() == nullptr ||
node_user.first->kernel_info()->select_kernel_build_info() == nullptr) {
if (node_user.first->kernel_info() == nullptr || !node_user.first->kernel_info()->has_build_info()) {
// maybe not a real kernel.
continue;
}
......
......@@ -21,8 +21,7 @@
#include <vector>
#include <memory>
#include "ir/dtype.h"
using std::string;
#include "ir/device_sync.h"
namespace mindspore {
namespace device {
......@@ -51,15 +50,12 @@ namespace device {
enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice };
enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU };
class DeviceAddress {
class DeviceAddress : public mindspore::DeviceSync {
public:
explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {}
explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
: ptr_(ptr), size_(size), format_(format), type_id_(type_id) {}
virtual ~DeviceAddress() { ptr_ = nullptr; }
virtual bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const = 0;
virtual bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
const void *host_ptr) const = 0;
const void *GetPtr() const { return ptr_; }
size_t GetSize() const { return size_; }
std::string format() const { return format_; }
......
......@@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include "ir/kernel_info_dev.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "runtime/device/ascend/ascend_device_address.h"
#include "backend/kernel_compiler/kernel.h"
......@@ -27,7 +28,7 @@ namespace mindspore {
const uint32_t kInvalidGraphId = UINT32_MAX;
const uint32_t kInvalidDistincLabel = UINT32_MAX;
namespace device {
class KernelInfo {
class KernelInfo : public KernelInfoDevice {
public:
KernelInfo() {
kernel_mod_ = nullptr;
......@@ -41,6 +42,7 @@ class KernelInfo {
}
virtual ~KernelInfo() = default;
bool has_build_info() const override { return select_kernel_build_info() != nullptr; }
const kernel::KernelBuildInfo *select_kernel_build_info() const;
kernel::KernelBuildInfoPtr GetMutableSelectKernelBuildInfo() const;
void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
......
......@@ -214,8 +214,10 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
auto output_size = AnfAlgo::GetOutputTensorNum(item);
for (size_t index = 0; index < output_size; index++) {
MS_EXCEPTION_IF_NULL(input_tensors[input_index]);
if (input_tensors[input_index]->device_address().get() != nullptr) {
AnfAlgo::SetOutputAddr(input_tensors[input_index]->device_address(), index, item.get());
auto output_address =
std::dynamic_pointer_cast<device::DeviceAddress>(input_tensors[input_index]->device_address());
if (output_address != nullptr) {
AnfAlgo::SetOutputAddr(output_address, index, item.get());
continue;
}
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
......
......@@ -27,8 +27,9 @@
#include <utility>
#include "base/base.h"
#include "debug/info.h"
#include "ir/kernel_info_dev.h"
#include "ir/scope.h"
#include "debug/info.h"
// A MindSpore ANF IR defined here.
// with BNF followed:
......@@ -71,12 +72,6 @@ class BaseRef;
class Var;
using VarPtr = std::shared_ptr<Var>;
namespace device {
class KernelInfo;
} // namespace device
using KernelInfoDevice = device::KernelInfo;
using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
class AnfVisitor;
class ParamValue;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_
#define MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_
#include <vector>
#include <memory>
#include <string>
#include "ir/dtype/type.h"
using std::string;
namespace mindspore {
// Interface for data synchornize between device and host.
// Interface implemented by device-memory holders (e.g. DeviceAddress) so the
// 'ir' layer can synchronize tensor data with a device without depending on
// 'device/kernel' — this breaks the circular dependency this commit targets.
class DeviceSync {
 public:
  // Polymorphic interface: a virtual destructor is required so that deleting
  // a derived object through a DeviceSync pointer is well-defined.
  virtual ~DeviceSync() = default;
  // Copy `size` bytes of device data of element type `type` (logical `shape`)
  // into the host buffer `host_ptr`. Returns false on failure.
  virtual bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const = 0;
  // Copy `size` bytes from the host buffer `host_ptr` to device memory.
  // Returns false on failure.
  virtual bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
                                const void *host_ptr) const = 0;
};
using DeviceSyncPtr = std::shared_ptr<DeviceSync>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_
#define MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_
#include <memory>
namespace mindspore {
// Interface for device kernel program information.
// Abstract handle for device kernel program information, kept in 'ir' so AnfNode
// does not need to know the concrete device::KernelInfo type (breaks the
// ir ↔ device/kernel circular dependency).
class KernelInfoDevice {
 public:
  // Polymorphic interface: a virtual destructor makes deletion through a
  // KernelInfoDevice pointer well-defined for derived implementations.
  virtual ~KernelInfoDevice() = default;
  // If kernel program was built and build info is set.
  virtual bool has_build_info() const = 0;
};
using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_
......@@ -326,7 +326,7 @@ Tensor::Tensor(const Tensor &tensor)
data_(tensor.data_),
dirty_(tensor.dirty_),
id_(tensor.id_),
device_address_(tensor.device_address_) {}
device_sync_(tensor.device_sync_) {}
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
: MetaTensor(data_type, tensor.shape_),
......@@ -334,7 +334,7 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
dirty_(tensor.dirty_),
id_(tensor.id_),
device_address_(tensor.device_address_) {}
device_sync_(tensor.device_sync_) {}
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data)
: MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
......@@ -379,10 +379,10 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
// Assign another tensor's value (metadata, dirtiness, device handle, data
// buffer, and id) into this tensor. Self-assignment is a no-op.
// Note: the diff rendering had left both the pre-change accessor-based
// assignments (is_dirty()/device_address()/id()) and the post-change direct
// member assignments in place; only the post-change member-access form —
// which uses the new device_sync_ member — is kept.
Tensor &Tensor::AssignValue(const Tensor &tensor) {
  if (this != &tensor) {
    MetaTensor::operator=(tensor);
    dirty_ = tensor.dirty_;
    // device_sync_ replaces the old device_address_ member so Tensor depends
    // only on the ir-level DeviceSync interface.
    device_sync_ = tensor.device_sync_;
    data_ = tensor.data_;
    id_ = tensor.id_;
  }
  return *this;
}
......@@ -425,8 +425,8 @@ std::string Tensor::ToStringRepr() const {
}
void Tensor::data_sync() const {
if (device_address_ != nullptr) {
if (!device_address_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
if (device_sync_ != nullptr) {
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost when asnumpy.";
}
}
......
......@@ -23,15 +23,13 @@
#include <numeric>
#include "Eigen/Core"
#include "runtime/device/device_address.h"
#include "ir/device_sync.h"
#include "ir/meta_tensor.h"
#include "include/ms_tensor.h"
#include "utils/log_adapter.h"
using float16 = Eigen::half;
using mindspore::device::DeviceAddress;
using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of MindSpore project.
......@@ -222,8 +220,8 @@ class Tensor : public MetaTensor {
bool is_dirty() const { return dirty_; }
void set_dirty(const bool dirty) { dirty_ = dirty; }
DeviceAddressPtr device_address() const { return device_address_; }
void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
DeviceSyncPtr device_address() const { return device_sync_; }
void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; }
std::string id() const { return id_; }
......@@ -234,7 +232,7 @@ class Tensor : public MetaTensor {
TensorDataPtr data_{nullptr};
bool dirty_{true};
std::string id_{""};
DeviceAddressPtr device_address_{nullptr};
DeviceSyncPtr device_sync_{nullptr};
};
using TensorPtr = std::shared_ptr<Tensor>;
using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
......
......@@ -22,7 +22,6 @@
#include <sstream>
#include <string>
#include "runtime/device/device_address.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "abstract/abstract_value.h"
......
......@@ -81,8 +81,6 @@ struct type_caster<float16> : public npy_scalar_caster<float16> {
} // namespace detail
} // namespace pybind11
using mindspore::device::DeviceAddress;
using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>;
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of Mindsporeession project.
......
......@@ -255,7 +255,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {
AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get());
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat16->type_id()});
......@@ -274,7 +274,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()});
......@@ -293,7 +293,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputFormat) {
auto pre_add = kernel_graph->NewCNode(pre_node_inputs);
MS_EXCEPTION_IF_NULL(pre_add);
pre_add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = pre_add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(pre_add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetOutputsDeviceType({kFloat32->type_id()});
......@@ -373,7 +373,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceShape) {
MS_EXCEPTION_IF_NULL(add);
add->set_abstract(tuple_abstract);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetOutputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_NZ});
......@@ -404,7 +404,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC});
......@@ -457,7 +457,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetOutputsDeviceType({kFloat32->type_id()});
......@@ -474,7 +474,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()});
......@@ -492,7 +492,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputDeviceDataType) {
auto pre_add = kernel_graph->NewCNode(pre_add_inputs);
MS_EXCEPTION_IF_NULL(pre_add);
pre_add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = pre_add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(pre_add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetOutputsDeviceType({kFloat32->type_id()});
......@@ -513,7 +513,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputAddr) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
int *addr = nullptr;
auto device_address = std::make_shared<AscendDeviceAddress>(addr, 1);
......@@ -528,7 +528,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputAddr) {
auto pre_add = kernel_graph->NewCNode(pre_add_inputs);
MS_EXCEPTION_IF_NULL(pre_add);
pre_add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = pre_add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(pre_add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
int *addr = nullptr;
auto device_address = std::make_shared<AscendDeviceAddress>(addr, 1);
......@@ -561,7 +561,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetWorkspaceAddr) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
int *addr = nullptr;
auto device_address = std::make_shared<AscendDeviceAddress>(addr, 1);
......@@ -643,7 +643,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelType) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetKernelType(AKG_KERNEL);
......@@ -659,7 +659,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetProcessor) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetProcessor(kernel::AICORE);
......@@ -675,7 +675,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetFusionType) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetFusionType(kernel::CONVLUTION);
......@@ -703,7 +703,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelMod) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
d_kernel_info->set_kernel_mod(nullptr);
EXPECT_EQ(AnfAlgo::GetKernelMod(add), nullptr);
......@@ -779,7 +779,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetStreamId) {
auto add = kernel_graph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(add);
add->set_kernel_info(std::make_shared<KernelInfo>());
auto d_kernel_info = add->kernel_info();
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
d_kernel_info->set_stream_id(0);
EXPECT_EQ(AnfAlgo::GetStreamId(add), 0);
......
......@@ -42,7 +42,7 @@ TEST_F(KernelGraphTest, NewValueNode) {
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
add_value->set_abstract(x_abstract);
add_value->set_kernel_info(std::make_shared<KernelInfo>());
auto mutable_kernel_info = add_value->kernel_info();
auto mutable_kernel_info = dynamic_cast<device::KernelInfo *>(add_value->kernel_info());
MS_EXCEPTION_IF_NULL(mutable_kernel_info);
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
builder->SetOutputsFormat({kOpFormat_FRAC_Z});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register