/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "session/anf_runtime_algorithm.h"
// NOTE(review): the four stdlib header names were lost in extraction; reconstructed
// from usage (make_shared/make_unique, transform/copy, std::map) — confirm against VCS.
#include <memory>
#include <algorithm>
#include <map>
#include <utility>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "operator/ops.h"
#include "utils/utils.h"
#include "device/kernel_info.h"
#include "device/device_address.h"
#include "pre_activate/common/helper.h"
#include "kernel/kernel.h"
#include "kernel/kernel_build_info.h"
#include "common/utils.h"
#include "common/trans.h"

namespace mindspore {
namespace session {
using abstract::AbstractTensor;
using abstract::AbstractTuple;
using device::KernelInfo;
using device::ascend::AscendDeviceAddress;
using kernel::KernelBuildInfoPtr;
using kernel::KernelMod;
using kernel::KernelModPtr;
namespace {
// Convert an int-based abstract shape into the vector<size_t> form used by the runtime.
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
  MS_EXCEPTION_IF_NULL(shape);
  std::vector<size_t> shape_size_t;
  std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize);
  return shape_size_t;
}
}  // namespace

// Walk through wrapper nodes (MakeTuple / TupleGetItem / Depend / ControlDepend) to find the
// real kernel node that produces output `index`, returning (kernel, output_index).
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      // MakeTuple: the requested output is the (index+1)-th input (input 0 is the primitive).
      auto node = cnode->input(index + IntToSize(1));
      MS_EXCEPTION_IF_NULL(node);
      return VisitKernel(node, 0);
    } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      // The second input of TupleGetItem holds the constant element index.
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
    } else {
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}

// Like VisitKernel, but stops early on any primitive listed in `return_types`, and can
// optionally see through nop nodes (`visit_nop_node`).
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
                                                               bool visit_nop_node,
                                                               const std::vector<PrimitivePtr> &return_types) {
  MS_EXCEPTION_IF_NULL(anf_node);
  for (const auto &prim_type : return_types) {
    if (CheckPrimitiveType(anf_node, prim_type)) {
      return std::make_pair(anf_node, index);
    }
  }
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->inputs().size() != kTupleGetItemInputSize) {
        MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(input2);
      auto value_node = input2->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      int item_idx = GetValue<int>(value_node->value());
      return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
                                       visit_nop_node, return_types);
    } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
      return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
    } else if (opt::IsNopNode(cnode) && visit_nop_node) {
      // A nop node forwards its single real input unchanged.
      if (cnode->inputs().size() == 2) {
        return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
      } else {
        MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
      }
    } else {
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
}

// Flatten a (possibly nested MakeTuple) output into the list of real output nodes.
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
                                                          const std::vector<PrimitivePtr> &return_types) {
  std::vector<AnfNodePtr> ret;
  auto return_prim_type = return_types;
  // if visited make_tuple should return back
  return_prim_type.push_back(prim::kPrimMakeTuple);
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    auto make_tuple = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types);
      (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
    }
    return ret;
  }
  ret.push_back(item_with_index.first);
  return ret;
}

// Return input 0 of a cnode, which holds the primitive value node.
AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->input(kAnfPrimitiveIndex);
}

// Extract the Primitive stored in a cnode's input 0; returns nullptr if input 0 is not a Primitive.
PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto attr_input = GetCNodePrimitiveNode(cnode);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto value_node = attr_input->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  auto primitive = value->cast<PrimitivePtr>();
  return primitive;
}

// True when `node` is a cnode whose primitive equals `primitive_type`.
bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
}

// Name of the primitive of a cnode; raises for non-cnodes.
std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    auto primitive = AnfAlgo::GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(primitive);
    return primitive->name();
  }
  MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString();
}

std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->DebugString();
}

// Set attribute `key` on the primitive of cnode `node`.
void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  primitive->set_attr(key, value);
}

void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
  CopyNodeAttr(key, key, from, to);
}

// Copy attribute `old_key` of `from` onto `to` under the name `new_key`.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
                                       const AnfNodePtr &to) {
  MS_EXCEPTION_IF_NULL(from);
  MS_EXCEPTION_IF_NULL(to);
  if (!from->isa<CNode>() || !to->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
                      << to->DebugString();
  }
  auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  MS_EXCEPTION_IF_NULL(from_primitive);
  auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  MS_EXCEPTION_IF_NULL(to_primitive);
  to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key));
}

// Copy all primitive attributes of `from` onto `to`.
void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
  MS_EXCEPTION_IF_NULL(from);
  MS_EXCEPTION_IF_NULL(to);
  if (!from->isa<CNode>() || !to->isa<CNode>()) {
    // Fix: the original printed from->DebugString() for both nodes; report `to` correctly.
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
                      << to->DebugString();
  }
  auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  MS_EXCEPTION_IF_NULL(from_primitive);
  auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  MS_EXCEPTION_IF_NULL(to_primitive);
  (void)to_primitive->SetAttrs(from_primitive->attrs());
}

void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  primitive->EraseAttr(key);
}

bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  return primitive->HasAttr(key);
}

// Number of real (data) inputs of a cnode, i.e. excluding the primitive at input 0.
size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString();
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  size_t input_num = cnode->inputs().size();
  if (input_num == 0) {
    MS_LOG(EXCEPTION) << "cnode inputs size can't be zero";
  }
  // exclude intputs[0],which is value_node storing attr,inputs left are real input
  return input_num - 1;
}

// Number of outputs deduced from the node's inferred type: tuple -> element count,
// tensor/number -> 1, none/null -> 0, anything else -> 1.
size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  TypePtr type = node->Type();
  if (type == nullptr) {
    return 0;
  }
  if (type->isa<Tuple>()) {
    auto tuple_type = type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    return tuple_type->size();
  } else if (type->isa<TensorType>() || type->isa<Number>()) {
    return 1;
  } else if (type->isa<TypeNone>()) {
    return 0;
  } else {
    return 1;
  }
}

// Device format of output `output_idx`; non-real kernels forward the format of their real input.
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Output index:" << output_idx
                      << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
                      << node->DebugString() << "]";
  }
  if (!AnfAlgo::IsRealKernel(node)) {
    return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetOutputFormat(output_idx);
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid output format";
  }
  return format;
}

// Device format expected at input `input_idx`.
std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Input index :" << input_idx
                      << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
                      << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    // Fix: the original discarded this result (missing `return`) and fell through to
    // dereference a kernel_info that non-real kernels do not populate.
    return GetPrevNodeOutputFormat(node, input_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetInputFormat(input_idx);
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid input format";
  }
  return format;
}

// Resolve the real producer (kernel, output index) feeding input `input_idx` of `anf_node`.
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (input_idx + 1 >= cnode->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number "
                      << GetInputTensorNum(cnode);
  }
  auto node = cnode->input(input_idx + 1);
  MS_EXCEPTION_IF_NULL(node);
  return VisitKernel(node, 0);
}

std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
}

std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node,
                                                                            size_t input_idx) {
  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
}

// Inferred (host) shape of output `output_idx`; NoShape yields an empty vector (scalar).
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  abstract::BaseShapePtr base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>() && output_idx == 0) {
    return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << ".";
    }
    auto b_shp = (*tuple_shape)[output_idx];
    if (b_shp->isa<abstract::Shape>()) {
      return TransShapeToSizet(b_shp->cast<abstract::ShapePtr>());
    } else if (b_shp->isa<abstract::NoShape>()) {
      return std::vector<size_t>();
    } else {
      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    return std::vector<size_t>();
  }
  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString();
}

std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
}

// Device-layout shape of output `output_idx` (pads to 4D first when the format requires it).
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
  auto format = GetOutputFormat(node, output_idx);
  auto infer_shape = GetOutputInferShape(node, output_idx);
  if (infer_shape.empty()) {
    return infer_shape;
  }
  // if format is default_format or NC1KHKWHWC0,device shape = original shape
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
}

// Device-layout shape expected at input `input_idx`.
std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
  auto format = GetInputFormat(node, input_idx);
  auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
  if (infer_shape.empty()) {
    return infer_shape;
  }
  // if format is default_format or NC1KHKWHWC0,device shape = original shape
  if (trans::IsNeedPadding(format, infer_shape.size())) {
    infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
  }
  return trans::TransShapeToDevice(infer_shape, format);
}

// Axis reshape hints for input `input_idx`; empty when default padding applies.
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index:" << input_idx
                      << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
                      << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputReshapeType(node, input_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  if (build_info->IsInputDefaultPadding()) {
    return {};
  }
  return build_info->GetInputReshapeType(input_idx);
}

// Axis reshape hints for output `output_idx`; empty when default padding applies.
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputReshapeType(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  if (build_info->IsOutputDefaultPadding()) {
    return {};
  }
  return build_info->GetOutputReshapeType(output_idx);
}

// Inferred element dtype of output `output_idx`, unwrapping tensor/tuple types.
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  TypePtr type_ptr = node->Type();
  MS_EXCEPTION_IF_NULL(type_ptr);
  if (type_ptr->isa<TensorType>() && output_idx == 0) {
    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    TypePtr elem = tensor_ptr->element();
    MS_EXCEPTION_IF_NULL(elem);
    return elem->type_id();
  } else if (type_ptr->isa<Tuple>()) {
    auto tuple_ptr = type_ptr->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_ptr);
    if (output_idx >= tuple_ptr->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
    }
    auto tuple_i = (*tuple_ptr)[output_idx];
    MS_EXCEPTION_IF_NULL(tuple_i);
    if (tuple_i->isa<TensorType>()) {
      auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      TypePtr elem = tensor_ptr->element();
      MS_EXCEPTION_IF_NULL(elem);
      return elem->type_id();
    } else if (tuple_i->isa<Number>()) {
      return tuple_i->type_id();
    } else {
      MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
      return tuple_i->type_id();
    }
  } else if (type_ptr->isa<Number>()) {
    return type_ptr->type_id();
  }
  return type_ptr->type_id();
}

TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
}

// Selected device dtype of output `output_idx`.
TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    return GetPrevNodeOutputDeviceDataType(node, output_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto dtype = build_info->GetOutputDeviceType(output_idx);
  if (dtype == TypeId::kNumberTypeEnd) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid dtype";
  }
  return dtype;
}

// Selected device dtype expected at input `input_idx`.
TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
                      << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    // Fix: the original forwarded index 0 regardless of input_idx, unlike the sibling
    // GetInputFormat/GetInputReshapeType accessors.
    return GetPrevNodeOutputDeviceDataType(node, input_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto dtype = build_info->GetInputDeviceType(input_idx);
  if (dtype == TypeId::kNumberTypeEnd) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid dtype";
  }
  return dtype;
}

TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
}

// get output device addr of anf_node
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    // A nop node shares the address of its single real input.
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}

DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->inputs().size() == 2) {
      return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
    } else {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
    }
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}

// get output device addr of anf_node
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->OutputAddrExist(output_idx);
}

const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
}

DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
}

// set output device addr of anf_node
void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  if (!kernel_info->SetOutputAddr(addr, output_idx)) {
    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  }
}

// set workspace device addr of anf_node
void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
  }
}

// get workspace device addr of anf_node
DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetWorkspaceAddr(output_idx);
  if (addr == nullptr) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << "] workspace addr is not exist";
  }
  return addr;
}

// set infer shapes and types of anf node
void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
                                                     const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  if (types.size() != shapes.size()) {
    MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size();
  }
  if (shapes.empty()) {
    MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes";
  } else if (shapes.size() == 1) {
    // single output handle
    std::vector<int> shape_int;
    std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt);
    auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
    node->set_abstract(abstract);
  } else {
    // multiple output handle
    std::vector<AbstractBasePtr> abstract_list;
    for (size_t i = 0; i < types.size(); ++i) {
      std::vector<int> shape_int;
      std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt);
      abstract_list.push_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int));
    }
    auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
    node->set_abstract(abstract_tuple);
  }
}

// copy an abstract of a node to another node
void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
  to_node->set_abstract(from_node->abstract());
}

kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  // select_kernel_build_info() has checked whether return pointer is null
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  return build_info->op_pattern();
}

// get KernelBuildType of node, such as ATT,RT,FWK and so on
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  // select_kernel_build_info() has checked whether return pointer is null
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  return build_info->kernel_type();
}

kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  return build_info->processor();
}

kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  return build_info->fusion_type();
}

// set select kernel_build_info
void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
}

// get select kernel_build_info
KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->GetMutableSelectKernelBuildInfo();
}

// get kernelMode
KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->MutableKernelMod();
}

// set kernel mod
void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel_info->set_kernel_mod(kernel_mod);
}

// A node is a real kernel unless it is one of the virtual/bookkeeping primitives below.
bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // parameter and value node is not a real kernel too
  if (!node->isa<CNode>()) {
    return true;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().empty()) {
    MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString();
  }
  auto input = cnode->inputs()[0];
  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
                         IsPrimitive(input, prim::kPrimTensorSummary) ||
                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
                         IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
                         IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
                         IsPrimitive(input, prim::kPrimReturn);
  return !is_virtual_node;
}

bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // parameter and value node is not a real cnode kernel
  if (!node->isa<CNode>()) {
    return false;
  }
  // return considered as a real node
  if (CheckPrimitiveType(node, prim::kPrimReturn)) {
    return true;
  }
  return IsRealKernel(node);
}

bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->has_default();
}

void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel_info->set_stream_id(stream_id);
}

uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->stream_id();
}

void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel_info->set_stream_distinction_label(stream_label);
}

uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->stream_distinction_label();
}

void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel_info->set_graph_id(graph_id);
}

uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->graph_id();
}

bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
  MS_EXCEPTION_IF_NULL(anf);
  TypePtr type = anf->Type();
  MS_EXCEPTION_IF_NULL(type);
  return type->isa<Tuple>();
}

// Real input `index` of a cnode (skips the primitive at input 0).
AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  auto get_input_index = index + 1;
  if (index + 1 > node->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
                      << node->inputs().size();
  }
  // input 0 is primitive node
  return node->input(get_input_index);
}

bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    return false;
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  return kernel_info->is_feature_map();
}

bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input_node = cnode->input(input_index + 1);
  return IsFeatureMapOutput(input_node);
}

// For TBE kernels some ops reorder their inputs; map the logical input index to the
// physical index expected by the selected kernel.
size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
    {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
    {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}},
    {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
    {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
    {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
    {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
    {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
    {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
    {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
    {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
    {prim::kPrimApplyCenteredRMSProp->name(),
     {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}}}};
  size_t ret = cur_index;
  auto node_name = AnfAlgo::GetCNodeName(anf_node);
  if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
    auto find = spec_node_list.find(node_name);
    if (find != spec_node_list.end()) {
      ret = find->second[cur_index];
      MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name;
    }
  }
  return ret;
}

void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(input_node);
  node->set_input(index + 1, input_node);
}

bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto kernel_name = AnfAlgo::GetCNodeName(node);
  if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
      kernel_name == kReduceScatterOpName) {
    return true;
  }
  return false;
}

bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
  auto kernel_name = AnfAlgo::GetCNodeName(node);
  return kernel_name == kGetNextOpName;
}

// FuncGraph stored in a value node, or nullptr when the node is not such a value node.
FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto value_node = node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    return nullptr;
  }
  auto value = value_node->value();
  if (value == nullptr) {
    return nullptr;
  }
  auto func_graph = value->cast<FuncGraphPtr>();
  return func_graph;
}

// Child kernel graphs reachable from a 'call' node: either the directly-called graph,
// or both branch graphs of a switch.
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
  if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
    MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node.";
  }
  MS_EXCEPTION_IF_NULL(call_node);
  auto input1 = call_node->input(1);
  MS_EXCEPTION_IF_NULL(input1);
  if (input1->isa<ValueNode>()) {
    auto value_node = input1->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto kernel_graph = value_node->value();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    return {kernel_graph->cast<KernelGraphPtr>()};
  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
    auto switch_node = input1->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(switch_node);
    // Each switch branch is a Partial whose input 1 holds the branch's kernel graph.
    auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
      auto partial = switch_node->input(input_index);
      MS_EXCEPTION_IF_NULL(partial);
      auto partial_cnode = partial->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(partial_cnode);
      auto graph_node = partial_cnode->input(1);
      MS_EXCEPTION_IF_NULL(graph_node);
      auto graph_value_node = graph_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(graph_value_node);
      auto graph_value = graph_value_node->value();
      MS_EXCEPTION_IF_NULL(graph_value);
      auto child_graph = graph_value->cast<KernelGraphPtr>();
      return child_graph;
    };
    return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
  }
  return {};
}

bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
  MS_EXCEPTION_IF_NULL(call_node);
  if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
    MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
  }
  auto input1 = call_node->input(1);
  if (input1->isa<ValueNode>()) {
    return false;
  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
    return true;
  }
  MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
}

// A scalar input is an empty shape or a 1-D shape of size 1.
bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  if (shape.empty()) {
    return true;
  }
  return shape.size() == kShape1dDims && shape[0] == 1;
}

bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
  if (shape.empty()) {
    return true;
  }
  return shape.size() == kShape1dDims && shape[0] == 1;
}
}  // namespace session
}  // namespace mindspore