提交 13a2d6d4 编写于 作者: G gong chen

Init GraphKernel.

- It provides a unified style to express graph and kernel for user.
- It provides a unified IR to represent graph and kernel for developer.
- It breaks the boundary between graph and kernel.
- It provides more opportunities to do compile optimization.
上级 dc9a51aa
...@@ -13,3 +13,6 @@ ...@@ -13,3 +13,6 @@
[submodule "graphengine"] [submodule "graphengine"]
path = graphengine path = graphengine
url = https://gitee.com/mindspore/graphengine.git url = https://gitee.com/mindspore/graphengine.git
[submodule "akg"]
path = akg
url = https://gitee.com/mindspore/akg.git
...@@ -83,10 +83,27 @@ if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES) ...@@ -83,10 +83,27 @@ if (ENABLE_GE OR ENABLE_D OR ENABLE_TESTCASES)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/graphengine/third_party/fwkacllib/inc/toolchain)
endif() endif()
# When both ENABLE_AKG and ENABLE_D are on, configure the akg sub-project and
# add the AKG and TVM header directories to this directory's include path.
if (ENABLE_AKG AND ENABLE_D)
    set(AKG_PATH "${CMAKE_SOURCE_DIR}/akg/mindspore/ccsrc/akg")
    # The sub-project is added before the include path is extended, so the
    # directories listed below are not inherited by akg itself.
    add_subdirectory(${AKG_PATH})
    set(TVM_PATH "${CMAKE_CURRENT_BINARY_DIR}/akg/mindspore/ccsrc/akg/incubator-tvm")
    # A single call appends the directories in the order given: the TVM
    # headers first, then AKG's own.
    include_directories(
        "${TVM_PATH}/include"
        "${TVM_PATH}"
        "${TVM_PATH}/src"
        "${TVM_PATH}/src/schedule"
        "${TVM_PATH}/3rdparty/dmlc-core/include"
        "${TVM_PATH}/3rdparty/dlpack/include"
        "${TVM_PATH}/3rdparty/compiler-rt"
        "${TVM_PATH}/3rdparty/rang/include"
        "${TVM_PATH}/3rdparty/picojson"
        "${AKG_PATH}"
        "${AKG_PATH}/include"
    )
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
add_subdirectory(mindspore/ccsrc) add_subdirectory(mindspore/ccsrc)
if (ENABLE_TESTCASES) if (ENABLE_TESTCASES)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
include(cmake/package.cmake) include(cmake/package.cmake)
\ No newline at end of file
akg @ 02133cc7
Subproject commit 02133cc76374b8c2d6c73f24a58d3c28e1e4d14e
...@@ -245,6 +245,9 @@ checkopts "$@" ...@@ -245,6 +245,9 @@ checkopts "$@"
echo "---------------- mindspore: build start ----------------" echo "---------------- mindspore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib" mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine git submodule update --init graphengine
# Fetch the akg submodule, including its nested submodules, only when both
# ENABLE_AKG and ENABLE_D are set to "on".
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
git submodule update --init --recursive akg
fi
build_exit() build_exit()
{ {
...@@ -307,7 +310,7 @@ build_mindspore() ...@@ -307,7 +310,7 @@ build_mindspore()
if [[ "X$USE_GLOG" = "Xon" ]]; then if [[ "X$USE_GLOG" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON" CMAKE_ARGS="${CMAKE_ARGS} -DUSE_GLOG=ON"
fi fi
if [[ "X$ENABLE_AKG" = "Xon" ]]; then if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON" CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_AKG=ON"
fi fi
echo "${CMAKE_ARGS}" echo "${CMAKE_ARGS}"
...@@ -451,4 +454,11 @@ fi ...@@ -451,4 +454,11 @@ fi
cp -rf ${BUILD_PATH}/package/mindspore/lib ${BUILD_PATH}/../mindspore cp -rf ${BUILD_PATH}/package/mindspore/lib ${BUILD_PATH}/../mindspore
cp -rf ${BUILD_PATH}/package/mindspore/*.so ${BUILD_PATH}/../mindspore cp -rf ${BUILD_PATH}/package/mindspore/*.so ${BUILD_PATH}/../mindspore
# Copy the shared libraries produced by the akg sub-build into the package's
# lib directory. Runs only when both ENABLE_AKG and ENABLE_D are "on".
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
    so_lib_dir="${BUILD_PATH}/package/mindspore/lib"
    akg_build_dir="${BUILD_PATH}/mindspore/akg/mindspore/ccsrc/akg"
    # Quote the expansions so a BUILD_PATH containing whitespace stays one
    # argument; the glob is kept outside the quotes so it still expands.
    mkdir -p "${so_lib_dir}"
    cp "${akg_build_dir}"/*.so "${so_lib_dir}"
fi
echo "---------------- mindspore: build end ----------------" echo "---------------- mindspore: build end ----------------"
...@@ -222,6 +222,24 @@ if (ENABLE_GPU) ...@@ -222,6 +222,24 @@ if (ENABLE_GPU)
endif () endif ()
endif () endif ()
# Install the AKG python package next to INSTALL_PY_DIR, and the TVM / TOPI
# python packages inside the installed akg directory. Only done when both
# ENABLE_D and ENABLE_AKG are on.
if (ENABLE_D AND ENABLE_AKG)
    set (AKG_PATH ${CMAKE_SOURCE_DIR}/akg)
    # incubator-tvm is looked up in the build tree of the akg sub-project.
    # Derive that location from CMAKE_BINARY_DIR rather than assuming the
    # build directory is always <source>/build/mindspore, so out-of-tree
    # builds in other directories package the right files.
    set (TVM_PATH ${CMAKE_BINARY_DIR}/akg/mindspore/ccsrc/akg/incubator-tvm)
    install(
        DIRECTORY
            ${AKG_PATH}/mindspore/akg
        DESTINATION ${INSTALL_PY_DIR}/..
        COMPONENT mindspore
    )
    install(
        DIRECTORY
            ${TVM_PATH}/topi/python/topi
            ${TVM_PATH}/python/tvm
        DESTINATION ${INSTALL_PY_DIR}/../akg
        COMPONENT mindspore
    )
endif()
if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset) if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset)
install( install(
DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset DIRECTORY ${CMAKE_SOURCE_DIR}/mindspore/dataset
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing akg compile with json"""
import sys
def run_compiler(op_json):
    """
    Compile one op with the AKG compiler.

    The ``akg`` package is imported at call time and its
    ``ms.compilewithjson`` entry point is invoked with ``op_json``.

    Args:
        op_json (str): json string of the op

    Returns:
        None

    Raises:
        ValueError: if ``compilewithjson`` returns a falsy result.
    """
    # Import by name at call time so that loading this module does not
    # itself require akg to be importable.
    p = __import__("akg", globals(), locals(), ['ms'], 0)
    func = getattr(p.ms, "compilewithjson")
    res = func(op_json)
    if not res:
        raise ValueError("Compile error")


if __name__ == "__main__":
    # Script entry point: the first command-line argument is the op json.
    run_compiler(sys.argv[1])
...@@ -13,95 +13,59 @@ ...@@ -13,95 +13,59 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Providing multi process compile with json""" """Providing multi process compile with json"""
import json
import math
import os import os
import subprocess import subprocess
import sys import sys
from multiprocessing import Pool from multiprocessing import Pool, cpu_count
def _compiletask(platform, *jsons): def _compile_akg_task(*json_strs):
""" """
compile func called in single process compile func called in single process
Parameters: Parameters:
platform: str. AKG platform or TBE platform json_strs: list. List contains multiple kernel infos, suitable for json compile api.
*jsons: str. json str contain kernel info, suitable for json compile """
api akg_compiler = os.path.join(os.path.split(
os.path.realpath(__file__))[0], "compiler.py")
""" for json_str in json_strs:
if platform == "AKG": res = subprocess.run(
p = __import__("_akg", globals(), locals(), ['ms'], 0) [sys.executable, akg_compiler, json_str], text=True)
func = getattr(p.ms, "compilewithjson") if res.returncode != 0:
for json_item in jsons: raise ValueError("Failed, args: {}!".format(json_str))
res = func(json_item)
if not res:
raise ValueError("Compile error")
if platform == "TBE":
tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "tbe_compiler", "compiler.py")
for json_item in jsons:
res = subprocess.run([sys.executable, tbe_compiler], input=json_item, text=True)
if res.returncode != 0:
raise ValueError("Tbe compile error")
def compilekernelparallel(jsons, process, waitime): def compile_akg_kernel_parallel(json_infos, process, waitime):
""" """
compile kernel use multi processes compile kernel use multi processes
Parameters: Parameters:
jsons: list. json str list contain kernel info json_infos: list. list contain kernel info(task id and json str)
process: int. processes num process: int. processes num
waittime: int. max time the function blocked waittime: int. max time the function blocked
Returns:
True for all compile success, False for some failed.
""" """
if not isinstance(jsons, list): if not isinstance(json_infos, list):
raise ValueError("jsons must be a list") raise ValueError("json_infos must be a list")
if not isinstance(process, int): if not isinstance(process, int):
raise ValueError("process must be a num") raise ValueError("process must be a num")
if not isinstance(waitime, int): if not isinstance(waitime, int):
raise ValueError("waittime must be a num") raise ValueError("waittime must be a num")
jsons_akg = [] if process == 0 and json_infos:
jsons_tbe = [] process = 1
for json_ in jsons:
j = json.loads(json_) cpu_proc_num = cpu_count()
if j["platform"] == "TBE": max_proc_num = 16
jsons_tbe.append(json_) process = min([cpu_proc_num, max_proc_num, process])
continue
if j["platform"] == "AKG":
jsons_akg.append(json_)
continue
raise RuntimeError(
"not support this platform {0}".format(j["platform"]))
if jsons_akg:
process_akg = math.floor(len(jsons)/len(jsons_akg)*process)
else:
process_akg = 0
if process_akg == 0 and jsons_akg: args = [[] for _ in range(process)]
process_akg = 1 for p, info in enumerate(json_infos):
process_tbe = process-process_akg args[p % process].append(info)
if process_tbe == 0 and jsons_tbe:
process_tbe = 1
raise RuntimeWarning("we add a process for compile more operator")
args = [[] for _ in range(process_akg+process_tbe)]
args_lens = len(args)
for p in range(args_lens):
if p < process_tbe:
args[p].append("TBE")
else:
args[p].append("AKG")
jsons_tbe_lens = len(jsons_tbe)
for p in range(jsons_tbe_lens):
args[p % process_tbe].append(jsons_tbe[p])
jsons_akg_lens = len(jsons_akg)
for p in range(jsons_akg_lens):
args[process-p % process_akg-1].append(jsons_akg[p])
for p in range(args_lens):
args[p] = tuple(args[p])
with Pool(processes=process) as pool: with Pool(processes=process) as pool:
res = pool.starmap_async(_compiletask, args) res = pool.starmap_async(_compile_akg_task, args)
res.get(timeout=waitime) res.get(timeout=waitime)
return True return True
...@@ -426,6 +426,10 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s ...@@ -426,6 +426,10 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
auto temp_shape = shape; auto temp_shape = shape;
std::vector<size_t> device_shape; std::vector<size_t> device_shape;
if (format == kOpFormat_FRAC_NZ) { if (format == kOpFormat_FRAC_NZ) {
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
// For [1] and [1024] shape we can treat it as NZ shape
return shape;
}
if (shape.size() < 2) { if (shape.size() < 2) {
MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size(); MS_LOG(EXCEPTION) << "Format" << format << " is not support shape " << shape.size();
} else { } else {
......
...@@ -111,9 +111,15 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer) ...@@ -111,9 +111,15 @@ void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &buffer)
} }
buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl; buffer << "#IR entry : @" << graph->ToString() << "." << graph->debug_info()->get_id() << std::endl;
buffer << "#flags :" << std::endl; buffer << "#attrs :" << std::endl;
for (const auto &flag : graph->flags()) { for (const auto &attr : graph->attrs()) {
buffer << flag.first << " : " << flag.second << std::endl; buffer << attr.first << " : ";
if (attr.second->isa<BoolImm>()) {
buffer << GetValue<bool>(attr.second);
} else if (attr.second->isa<StringImm>()) {
buffer << GetValue<std::string>(attr.second);
}
buffer << std::endl;
} }
} }
...@@ -417,10 +423,16 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo> ...@@ -417,10 +423,16 @@ void DumpSubgraph(const OrderedMap<FuncGraphPtr, std::shared_ptr<SubGraphIRInfo>
fout << std::endl; fout << std::endl;
for (const auto &sg : *sub_graphs) { for (const auto &sg : *sub_graphs) {
fout << "subgraph flag:" << std::endl; fout << "subgraph attr:" << std::endl;
MS_EXCEPTION_IF_NULL(sg.first); MS_EXCEPTION_IF_NULL(sg.first);
for (const auto &flag : sg.first->flags()) { for (const auto &attr : sg.first->attrs()) {
fout << flag.first << " : " << flag.second << std::endl; fout << attr.first << " : ";
if (attr.second->isa<BoolImm>()) {
fout << GetValue<bool>(attr.second);
} else if (attr.second->isa<StringImm>()) {
fout << GetValue<std::string>(attr.second);
}
fout << std::endl;
} }
fout << "subgraph @" << sg.first->ToString() << "."; fout << "subgraph @" << sg.first->ToString() << ".";
fout << sg.first->debug_info()->get_id() << "("; fout << sg.first->debug_info()->get_id() << "(";
......
...@@ -512,9 +512,15 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr ...@@ -512,9 +512,15 @@ void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGr
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i]; cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
ValuePtr value_ptr = nullptr;
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
MS_EXCEPTION_IF_NULL(primitive); if (primitive != nullptr) {
auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
} else {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cur_cnode_ptr);
MS_EXCEPTION_IF_NULL(func_graph);
value_ptr = func_graph->get_attr(kStreamNeedActivedFirst);
}
if (value_ptr == nullptr) { if (value_ptr == nullptr) {
continue; continue;
} }
...@@ -774,6 +780,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se ...@@ -774,6 +780,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
} }
std::set<CNode *> processed; std::set<CNode *> processed;
for (size_t i = 0; i < others.size(); i++) { for (size_t i = 0; i < others.size(); i++) {
auto begin = others.begin() + i; auto begin = others.begin() + i;
auto end = begin + 1; auto end = begin + 1;
...@@ -781,6 +788,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se ...@@ -781,6 +788,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
for (size_t j = 0; j < independents.size(); j++) { for (size_t j = 0; j < independents.size(); j++) {
auto cur_independent = independents[j]; auto cur_independent = independents[j];
auto it = std::find(processed.begin(), processed.end(), cur_independent.get()); auto it = std::find(processed.begin(), processed.end(), cur_independent.get());
if (it != processed.end()) { if (it != processed.end()) {
continue; continue;
} }
......
...@@ -26,10 +26,12 @@ ...@@ -26,10 +26,12 @@
#include "kernel/kernel.h" #include "kernel/kernel.h"
#include "kernel/tbe/tbe_kernel_build.h" #include "kernel/tbe/tbe_kernel_build.h"
#include "kernel/tbe/tbe_kernel_parallel_build.h" #include "kernel/tbe/tbe_kernel_parallel_build.h"
#include "kernel/akg/ascend/akg_ascend_kernel_build.h"
#include "kernel/aicpu/aicpu_kernel_build.h" #include "kernel/aicpu/aicpu_kernel_build.h"
#include "kernel/hccl/hccl_kernel_build.h" #include "kernel/hccl/hccl_kernel_build.h"
#include "kernel/rts/rt_kernel_build.h" #include "kernel/rts/rt_kernel_build.h"
#include "kernel/tbe/tbe_utils.h" #include "kernel/tbe/tbe_utils.h"
#include "kernel/common_utils.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "./common.h" #include "./common.h"
...@@ -65,6 +67,7 @@ static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) { ...@@ -65,6 +67,7 @@ static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr); MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
std::vector<AnfNodePtr> tbe_nodes; std::vector<AnfNodePtr> tbe_nodes;
std::vector<AnfNodePtr> akg_nodes;
std::vector<AnfNodePtr> other_nodes; std::vector<AnfNodePtr> other_nodes;
for (const auto &anf_node : kernel_graph_ptr->execution_order()) { for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
...@@ -79,19 +82,26 @@ static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *ke ...@@ -79,19 +82,26 @@ static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *ke
} }
break; break;
} }
case KernelType::AKG_KERNEL: {
akg_nodes.push_back(anf_node);
break;
}
default: { default: {
other_nodes.push_back(anf_node); other_nodes.push_back(anf_node);
break; break;
} }
} }
} }
bool ret = kernel::TbeOpParallelBuild(tbe_nodes); bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes);
bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes);
auto bin_map = kernel::tbe::KernelMeta::GetInstance();
(void)bin_map->ReadIndex(kernel::kCceKernelMeta);
for (const auto &anf_node : other_nodes) { for (const auto &anf_node : other_nodes) {
kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node); kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr); MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
} }
return ret; return tbe_ret && akg_ret;
} }
static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) { static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) {
...@@ -202,7 +212,7 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { ...@@ -202,7 +212,7 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
for (const auto &anf_node : kernel_graph->execution_order()) { for (const auto &anf_node : kernel_graph->execution_order()) {
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
if (apply_function_name == prim::kPrimMaxPoolGrad->name() && if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
AnfAlgo::GetKernelType(anf_node) == KernelType::AUTO_DIFF_KERNEL) { AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName); auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim); MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim); auto new_value_node = NewValueNode(clear_zero_prim);
......
...@@ -15,18 +15,27 @@ ...@@ -15,18 +15,27 @@
*/ */
#include "device/ascend/kernel_select_ascend.h" #include "device/ascend/kernel_select_ascend.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <algorithm>
#include <map> #include <map>
#include "kernel/oplib/oplib.h" #include <unordered_map>
#include "kernel/kernel_query.h" #include <unordered_set>
#include "common/utils.h"
#include "debug/anf_ir_dump.h"
#include "operator/ops.h"
#include "ir/func_graph.h"
#include "utils/context/ms_context.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "kernel/common_utils.h"
#include "kernel/kernel_query.h"
#include "kernel/oplib/oplib.h"
#include "kernel/kernel_build_info.h" #include "kernel/kernel_build_info.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
#include "debug/anf_ir_dump.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
...@@ -124,12 +133,23 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons ...@@ -124,12 +133,23 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
} }
auto pri_match_format = GetPriorityMatchFormat(kernel_node); auto pri_match_format = GetPriorityMatchFormat(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_anf_node = kernel_node->input(input_index + 1);
// we do not take ValueNode into consideration in composite op.
if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) {
if (input_anf_node->isa<ValueNode>() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
continue;
}
}
auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore; auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score; (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
} }
if (kernel_build_info.GetInputDeviceType(input_index) == // we match output fix precision first.
AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) { auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
if (prev_device_type == kTypeUnknown) {
prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
}
if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
(*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score; (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
} }
if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
...@@ -149,42 +169,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons ...@@ -149,42 +169,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
} }
} }
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
if (real_input_node->isa<CNode>()) {
continue;
}
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue;
}
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
bool is_ref = false;
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
if (op_info != nullptr) {
is_ref = op_info->is_ref();
}
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
}
}
void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) { void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
MS_EXCEPTION_IF_NULL(support_index); MS_EXCEPTION_IF_NULL(support_index);
int index = kUnSupportMixedDataTypeIndex; int index = kUnSupportMixedDataTypeIndex;
...@@ -469,6 +453,51 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis ...@@ -469,6 +453,51 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
} }
} // namespace } // namespace
// For every input tensor of kernel_node, copy the input format and device data
// type chosen in selected_kernel_info onto the node that feeds that input, when
// that node is not a CNode and not a non-weight Parameter.
// Throws (via MS_EXCEPTION_IF_NULL) if kernel_node, an input node, the visited
// node or the MsContext instance is null.
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(input_kernel_node);
// NOTE(review): VisitKernel presumably resolves the input to the node that
// really produces the data - confirm against its definition.
auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
MS_EXCEPTION_IF_NULL(input_with_index.first);
auto real_input_node = input_with_index.first;
// Inputs produced by another CNode are left untouched.
if (real_input_node->isa<CNode>()) {
continue;
}
// Parameters that are not weights are left untouched as well.
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue;
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// A tensor value node whose device data type is still unknown takes the
// selected format/type directly; note the build info is set on
// input_kernel_node here, not on real_input_node.
if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
continue;
}
// Remaining inputs: decide whether the selected device info must be applied.
// is_ref is taken from the TBE op info of kernel_node, when such info exists.
bool is_ref = false;
auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
if (op_info != nullptr) {
is_ref = op_info->is_ref();
}
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
// In PyNative mode an input that already has a known device data type is
// never overwritten, even for ref ops.
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
// Apply the selected format/type when nothing is set yet or the op is a ref op.
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
}
}
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) { const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
...@@ -500,11 +529,17 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, ...@@ -500,11 +529,17 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
return select_status; return select_status;
} }
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) { KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list; std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
kernel::KernelQuery(kernel_node, &kernel_info_list); if (AnfAlgo::IsCompositeKernel(kernel_node)) {
auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
MS_EXCEPTION_IF_NULL(func_graph);
SelectCompositeKernelInfo(kernel_node, func_graph);
return kStatusAllMatched;
}
kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
// If aicore not find valid kernel info reloading aicpu kernel info list to find it // If aicore not find valid kernel info reloading aicpu kernel info list to find it
if (select_status == kNoMatched) { if (select_status == kNoMatched) {
......
...@@ -27,7 +27,10 @@ enum KernelSelectStatus { ...@@ -27,7 +27,10 @@ enum KernelSelectStatus {
kStatusReducePrecision = 1, kStatusReducePrecision = 1,
kStatusRaisePrecision = 2, kStatusRaisePrecision = 2,
}; };
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node); KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node,
KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node);
void SelectCompositeKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph);
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore
......
...@@ -31,7 +31,7 @@ void TaskDescReporter::ReportData() { ...@@ -31,7 +31,7 @@ void TaskDescReporter::ReportData() {
size_t task_index = 0; size_t task_index = 0;
for (const auto &node : cnode_list_) { for (const auto &node : cnode_list_) {
if (AnfAlgo::GetKernelType(node) != TBE_KERNEL) { if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) {
MS_LOG(WARNING) << "Skip non tbe kernel"; MS_LOG(WARNING) << "Skip non tbe kernel";
++task_index; ++task_index;
continue; continue;
......
...@@ -43,7 +43,37 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve ...@@ -43,7 +43,37 @@ bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::ve
void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) {
MS_EXCEPTION_IF_NULL(anf_node_ptr); MS_EXCEPTION_IF_NULL(anf_node_ptr);
if (anf_node_ptr->inputs().size() != 2) { if (anf_node_ptr->inputs().size() != 2) {
MS_LOG(EXCEPTION) << "atomic Addr clean Node Input nodes not equal 2."; // akg process
// set atomic clean addr
if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, anf_node_ptr)) {
auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAutomicOutputIndexs);
auto graph = anf_node_ptr->func_graph();
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_users = manager->node_users();
if (node_users[anf_node_ptr].empty()) {
MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty.";
}
auto depend_node = node_users[anf_node_ptr].pop().first;
if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) {
MS_LOG(EXCEPTION) << "Checking Depend node failed";
}
if (node_users[depend_node].empty()) {
MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty.";
}
auto post_node = node_users[depend_node].pop().first;
for (auto index : clean_output_indexs) {
auto device_address = AnfAlgo::GetOutputAddr(post_node, index);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
input->addr = device_address->ptr_;
MS_EXCEPTION_IF_NULL(input->addr);
input->size = device_address->size_;
kernel_inputs->push_back(input);
}
MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size();
}
return;
} }
MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]);
auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>(); auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>();
...@@ -59,7 +89,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP ...@@ -59,7 +89,7 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP
input->size = device_address->size_; input->size = device_address->size_;
kernel_inputs->push_back(input); kernel_inputs->push_back(input);
} }
MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
} }
// set clean workspace address // set clean workspace address
if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) { if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) {
......
...@@ -38,7 +38,7 @@ void GpuBuild(const KernelGraphPtr &kernel_graph) { ...@@ -38,7 +38,7 @@ void GpuBuild(const KernelGraphPtr &kernel_graph) {
continue; continue;
} }
if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AUTO_DIFF_KERNEL) { if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) {
auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel);
if (!gpu_kernel_ptr) { if (!gpu_kernel_ptr) {
MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed";
......
...@@ -179,7 +179,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) { ...@@ -179,7 +179,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
if (!result) { if (!result) {
result = SelectAkgKernel(kernel_node, builder->Build()); result = SelectAkgKernel(kernel_node, builder->Build());
kernel_type = AUTO_DIFF_KERNEL; kernel_type = AKG_KERNEL;
} }
if (!result) { if (!result) {
......
...@@ -43,7 +43,8 @@ void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_g ...@@ -43,7 +43,8 @@ void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_g
std::vector<CNodePtr> momentum_list; std::vector<CNodePtr> momentum_list;
std::vector<CNodePtr> other_list; std::vector<CNodePtr> other_list;
for (const auto &cnode : origin_cnode_list) { for (const auto &cnode : origin_cnode_list) {
if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(cnode)) != kOptOperatorSet.end()) { if (!AnfAlgo::IsCompositeKernel(cnode) &&
kOptOperatorSet.find(AnfAlgo::GetCNodeName(cnode)) != kOptOperatorSet.end()) {
momentum_list.emplace_back(cnode); momentum_list.emplace_back(cnode);
} else { } else {
other_list.emplace_back(cnode); other_list.emplace_back(cnode);
......
...@@ -26,6 +26,8 @@ ...@@ -26,6 +26,8 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/primitive.h" #include "ir/primitive.h"
#include "operator/ops.h"
namespace mindspore { namespace mindspore {
// namespace to support intermediate representation definition // namespace to support intermediate representation definition
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph) CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
...@@ -106,10 +108,14 @@ std::string ValueNode::fullname_with_scope() { ...@@ -106,10 +108,14 @@ std::string ValueNode::fullname_with_scope() {
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr) { if (cnode == nullptr) {
return false;
}
if (value != nullptr) {
return cnode->IsApply(value); return cnode->IsApply(value);
} }
return false; const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0));
return prim != nullptr;
} }
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
......
...@@ -124,6 +124,7 @@ class AnfNode : public Base { ...@@ -124,6 +124,7 @@ class AnfNode : public Base {
const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); } const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); }
KernelInfoDevice *kernel_info() { return kernel_info_.get(); } KernelInfoDevice *kernel_info() { return kernel_info_.get(); }
const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; }
void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; } void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; }
AbstractBasePtr abstract() const { return abstract_; } AbstractBasePtr abstract() const { return abstract_; }
...@@ -395,9 +396,9 @@ static S GetValue(const ValuePtr &value) { ...@@ -395,9 +396,9 @@ static S GetValue(const ValuePtr &value) {
std::string GetCNodeFuncName(CNodePtr cnode); std::string GetCNodeFuncName(CNodePtr cnode);
// used to check whether an AnfNode is a cnode with a kind of Primitive as first input // used to check whether an AnfNode is a cnode with a kind of Primitive as first input
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value); bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);
// used to check whether an AnfNode is a cnode with a Primitive as first input // used to get PrimitivePtr from a cnode first input
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
// used to check whether an AnfNode is a valuenode having some Primitive value // used to check whether an AnfNode is a valuenode having some Primitive value
......
...@@ -69,7 +69,7 @@ std::string CNode::fullname_with_scope() { ...@@ -69,7 +69,7 @@ std::string CNode::fullname_with_scope() {
} }
fullname_with_scope_ = name; fullname_with_scope_ = name;
} else { } else {
// cnode input 0 should be primitive ptr // cnode input 0 should be primitive ptr or funcgraph ptr
auto value_ptr = input(0)->cast<ValueNodePtr>(); auto value_ptr = input(0)->cast<ValueNodePtr>();
if (value_ptr == nullptr) { if (value_ptr == nullptr) {
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
...@@ -83,11 +83,23 @@ std::string CNode::fullname_with_scope() { ...@@ -83,11 +83,23 @@ std::string CNode::fullname_with_scope() {
return fullname_with_scope_; return fullname_with_scope_;
} }
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value); auto prim = input_value->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(scope()); MS_EXCEPTION_IF_NULL(scope());
MS_EXCEPTION_IF_NULL(prim); fullname_with_scope_ = scope()->name() + "/";
fullname_with_scope_ = if (prim != nullptr) {
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>()); fullname_with_scope_ += prim->name();
} else {
auto func_graph = input_value->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(func_graph);
auto fg_flag = func_graph->get_attr(FUNC_GRAPH_FLAG_COMPOSITE);
if (fg_flag != nullptr) {
auto fg_name = GetValue<std::string>(fg_flag);
fullname_with_scope_ += "composite_" + fg_name;
} else {
fullname_with_scope_ += func_graph->ToString();
}
}
fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base<CNode>());
} }
return fullname_with_scope_; return fullname_with_scope_;
......
...@@ -77,9 +77,9 @@ class Bool : public Number { ...@@ -77,9 +77,9 @@ class Bool : public Number {
TypeId generic_type_id() const override { return kNumberTypeBool; } TypeId generic_type_id() const override { return kNumberTypeBool; }
TypePtr DeepCopy() const override { return std::make_shared<Bool>(); } TypePtr DeepCopy() const override { return std::make_shared<Bool>(); }
std::string ToString() const override { return "Bool_"; } std::string ToString() const override { return "Bool"; }
std::string ToReprString() const override { return "bool_"; } std::string ToReprString() const override { return "bool"; }
std::string DumpText() const override { return "Bool_"; } std::string DumpText() const override { return "Bool"; }
}; };
// Int // Int
......
...@@ -44,7 +44,7 @@ using mindspore::abstract::VirtualAbstractClosure; ...@@ -44,7 +44,7 @@ using mindspore::abstract::VirtualAbstractClosure;
* Methods of Graph * Methods of Graph
*/ */
FuncGraph::FuncGraph() FuncGraph::FuncGraph()
: flags_(), : attrs_(),
transforms_(), transforms_(),
parameter_default_value_(), parameter_default_value_(),
seen_(0), seen_(0),
...@@ -155,13 +155,27 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { ...@@ -155,13 +155,27 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
return p; return p;
} }
bool FuncGraph::has_flag(const std::string &flag) { bool FuncGraph::has_flag(const std::string &key) {
if (flags_.count(flag)) { auto iter = attrs_.find(key);
return flags_[flag]; if (iter != attrs_.cend()) {
if (iter->second->isa<BoolImm>()) {
return GetValue<bool>(iter->second);
}
MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function.";
} }
return false; return false;
} }
// Reports whether an attribute named `key` is stored on this graph, whatever its value type.
bool FuncGraph::has_attr(const std::string &key) {
  return attrs_.find(key) != attrs_.cend();
}
// Returns the value stored under `key`, or nullptr when no such attribute exists.
ValuePtr FuncGraph::get_attr(const std::string &key) {
  auto found = attrs_.find(key);
  if (found == attrs_.cend()) {
    return nullptr;
  }
  return found->second;
}
CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { CNodePtr FuncGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>()); CNodePtr cnode = std::make_shared<CNode>(inputs, shared_from_base<FuncGraph>());
if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) {
...@@ -979,8 +993,8 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { ...@@ -979,8 +993,8 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() {
depend_inputs.push_back(*iter); depend_inputs.push_back(*iter);
} }
} }
set_flags(GRAPH_FLAG_HAS_EFFECT, false); set_flag(GRAPH_FLAG_HAS_EFFECT, false);
set_flags(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true);
if (!depend_inputs.empty()) { if (!depend_inputs.empty()) {
SetEffectDepends(depend_inputs); SetEffectDepends(depend_inputs);
} }
......
...@@ -48,6 +48,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>; ...@@ -48,6 +48,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
const char FUNC_GRAPH_FLAG_CORE[] = "core"; const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_FLAG_COMPOSITE[] = "composite";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
// ANF transform class // ANF transform class
...@@ -162,10 +163,19 @@ class FuncGraph : public FuncGraphBase { ...@@ -162,10 +163,19 @@ class FuncGraph : public FuncGraphBase {
void set_is_generate(bool generated) { is_generated_ = generated; } void set_is_generate(bool generated) { is_generated_ = generated; }
bool is_generated() const { return is_generated_; } bool is_generated() const { return is_generated_; }
bool has_flag(const std::string &flag); std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; }
std::unordered_map<std::string, bool> &flags() { return flags_; } void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
void set_flags(const std::unordered_map<std::string, bool> &flags) { flags_ = flags; } for (auto &attr : attrs) {
void set_flags(const std::string &key, const bool value) { flags_[key] = value; } attrs_[attr.first] = attr.second;
}
}
bool has_flag(const std::string &key);
void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
bool has_attr(const std::string &key);
ValuePtr get_attr(const std::string &key);
void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; } std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) { void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
...@@ -284,7 +294,7 @@ class FuncGraph : public FuncGraphBase { ...@@ -284,7 +294,7 @@ class FuncGraph : public FuncGraphBase {
std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; } std::unordered_map<AnfNodePtr, AnfNodePtr> &make_ref_params() { return make_ref_params_; }
std::unordered_map<std::string, bool> flags_; std::unordered_map<std::string, ValuePtr> attrs_;
std::unordered_map<std::string, FuncGraphTransform> transforms_; std::unordered_map<std::string, FuncGraphTransform> transforms_;
// parameter default value // parameter default value
std::map<std::string, AnfNodePtr> parameter_default_value_; std::map<std::string, AnfNodePtr> parameter_default_value_;
......
...@@ -89,6 +89,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { ...@@ -89,6 +89,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
new_node->set_abstract(old_node->abstract()); new_node->set_abstract(old_node->abstract());
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope); new_node->set_scope(scope);
new_node->set_kernel_info(old_node->kernel_info_ptr());
repl_node_[old_node] = new_node; repl_node_[old_node] = new_node;
nodes_.emplace_back(old_node, new_node); nodes_.emplace_back(old_node, new_node);
TraceManager::EndTrace(); TraceManager::EndTrace();
...@@ -210,7 +211,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons ...@@ -210,7 +211,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(target_func_graph);
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); TraceManager::DebugTrace(func_graph->debug_info(), target_relation_);
*target_func_graph = std::make_shared<FuncGraph>(); *target_func_graph = std::make_shared<FuncGraph>();
(*target_func_graph)->set_flags(func_graph->flags()); (*target_func_graph)->set_attrs(func_graph->attrs());
(*target_func_graph)->set_transforms(func_graph->transforms()); (*target_func_graph)->set_transforms(func_graph->transforms());
(*target_func_graph)->set_has_vararg(func_graph->has_vararg()); (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
(*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
...@@ -635,9 +636,14 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP ...@@ -635,9 +636,14 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
if (MsContext::GetInstance()->is_multi_graph_sink()) { if (MsContext::GetInstance()->is_multi_graph_sink()) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
new_func_graph->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
} }
} }
if (func_graph->has_attr(FUNC_GRAPH_FLAG_COMPOSITE)) {
new_func_graph->set_attr(FUNC_GRAPH_FLAG_COMPOSITE, func_graph->get_attr(FUNC_GRAPH_FLAG_COMPOSITE));
}
return new_func_graph; return new_func_graph;
} }
} // namespace mindspore } // namespace mindspore
...@@ -9,6 +9,10 @@ if (ENABLE_D) ...@@ -9,6 +9,10 @@ if (ENABLE_D)
file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"kernel_query.cc" "kernel_query.cc"
"kernel_fusion.cc" "kernel_fusion.cc"
"akg/ascend/*.cc"
"akg/akgkernelbuild.cc"
"akg/akg_kernel_attrs_process.cc"
"akg/akg_kernel_metadata.cc"
"tbe/*.cc" "tbe/*.cc"
"aicpu/*.cc" "aicpu/*.cc"
"rts/*.cc" "rts/*.cc"
......
...@@ -79,6 +79,10 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) { ...@@ -79,6 +79,10 @@ void SetAkgAttrsForCast(const AnfNodePtr &anf_node) {
dst_type = "float32"; dst_type = "float32";
} else if (output_type == kFloat16->type_id()) { } else if (output_type == kFloat16->type_id()) {
dst_type = "float16"; dst_type = "float16";
} else if (output_type == kInt32->type_id()) {
dst_type = "int32";
} else {
MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString();
} }
AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node);
} }
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/akg/akg_kernel_metadata.h"
#include <memory>
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
// Collects the AKG kernel build infos for `kernel_node` into `kernel_info_list`.
//
// The op description is looked up once in OpLib under the AKG implementation type. Every entry of
// `support_devices` is then tried in order, with the loop index converted to a Processor value, and
// the search stops at the first device for which ParseMetadata succeeds. A warning is logged when
// `kernel_info_list` is still empty at the end.
//
// Both `kernel_node` and `kernel_info_list` must be non-null (checked with MS_EXCEPTION_IF_NULL).
void AkgMetadataInfo(const CNodePtr &kernel_node,
                     std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);

  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
  // The lookup does not depend on the device index, so it is done once rather than once per device.
  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
  if (op_info_ptr != nullptr) {
    for (size_t i = 0; i < support_devices.size(); i++) {
      if (ParseMetadata(kernel_node, op_info_ptr, Processor(i), kernel_info_list)) {
        MS_LOG(DEBUG) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "].";
        break;
      }
      MS_LOG(WARNING) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "] failed.";
    }
  }

  if (kernel_info_list->empty()) {
    MS_LOG(WARNING) << "Akg dose not has metadata of op[" << op_name << "].";
  }
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
#include <string>
#include <vector>
#include <unordered_map>
#include <memory>
#include "kernel/kernel_build_info.h"
namespace mindspore {
namespace kernel {
// Looks up the AKG op description for the op of `kernel_node` and hands `kernel_info_list` to
// ParseMetadata so it can collect the build infos for that op.
// Both arguments must be non-null; the implementation checks them with MS_EXCEPTION_IF_NULL.
void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
...@@ -43,7 +43,9 @@ namespace kernel { ...@@ -43,7 +43,9 @@ namespace kernel {
constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200;
constexpr int32_t ARGS_SIZE = 1; constexpr int32_t ARGS_SIZE = 1;
constexpr auto kCompileWithJsonFunc = "compilewithjson"; constexpr auto kCompileWithJsonFunc = "compilewithjson";
// json key // json key
constexpr auto kOpDesc = "op_desc";
constexpr auto kInputDesc = "input_desc"; constexpr auto kInputDesc = "input_desc";
constexpr auto kShape = "shape"; constexpr auto kShape = "shape";
constexpr auto kDataType = "data_type"; constexpr auto kDataType = "data_type";
...@@ -51,13 +53,24 @@ constexpr auto kOutputDesc = "output_desc"; ...@@ -51,13 +53,24 @@ constexpr auto kOutputDesc = "output_desc";
constexpr auto kName = "name"; constexpr auto kName = "name";
constexpr auto kTensorName = "tensor_name"; constexpr auto kTensorName = "tensor_name";
constexpr auto kValue = "value"; constexpr auto kValue = "value";
constexpr auto KInpputNames = "input_names"; constexpr auto KDynInputSizes = "dyn_input_sizes";
constexpr auto KInputNames = "input_names";
constexpr auto KInput = "input"; constexpr auto KInput = "input";
constexpr auto KDtype = "dtype"; constexpr auto KDtype = "dtype";
int AkgKernelBuild::op_cnt_ = 0; namespace {
std::mutex AkgKernelBuild::op_cnt_mtx_; template <typename T>
// Joins the elements of `inputs` with ", " using their stream output form.
// An empty vector yields an empty string.
std::string Vector2Str(const std::vector<T> &inputs) {
  std::ostringstream joined;
  for (auto iter = inputs.begin(); iter != inputs.end(); ++iter) {
    if (iter != inputs.begin()) {
      joined << ", ";
    }
    joined << *iter;
  }
  return joined.str();
}
} // namespace
std::string PyObjectToStr(PyObject *const PyObj) { std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) {
char *pChar = nullptr; char *pChar = nullptr;
std::string str_res; std::string str_res;
if (PyObj == nullptr) { if (PyObj == nullptr) {
...@@ -76,6 +89,72 @@ std::string PyObjectToStr(PyObject *const PyObj) { ...@@ -76,6 +89,72 @@ std::string PyObjectToStr(PyObject *const PyObj) {
return str_res; return str_res;
} }
// Reads the "tensor_name" of a single tensor description inside `node_json`.
//
// `tag` selects the description list to look in. When `tag` is kOutputDesc the list is indexed
// directly with `position.second`; for any other tag the list is first indexed with
// `position.first` and the resulting array with `position.second`.
//
// Returns the tensor name, or an empty string (after logging an error) when the tag is missing,
// an index is out of range, or the selected entry has no "tensor_name" key.
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
                          const std::pair<size_t, size_t> &position) {
  if (node_json.count(tag) == 0) {
    MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "].";
    return "";
  }

  auto const &tag_desc = node_json[tag];
  // Navigate by pointer so the selected json subtree is inspected in place instead of deep-copied.
  const nlohmann::json *first_index = nullptr;
  if (tag == kOutputDesc) {
    first_index = &tag_desc;
  } else if (!tag_desc.is_array() || tag_desc.size() <= position.first) {
    MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "].";
    return "";
  } else {
    first_index = &tag_desc[position.first];
  }

  if (!first_index->is_array() || first_index->size() <= position.second) {
    MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "].";
    return "";
  }
  auto const &second_index = (*first_index)[position.second];
  if (second_index.count(kTensorName) == 0) {
    MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "].";
    return "";
  }

  return second_index[kTensorName];
}
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
nlohmann::json *const node_json) {
MS_EXCEPTION_IF_NULL(node_json);
if (node_json->count(tag) == 0) {
MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "].";
return;
}
nlohmann::json *tag_desc = &((*node_json)[tag]);
nlohmann::json *first_index;
if (tag == kOutputDesc) {
first_index = tag_desc;
} else if (!tag_desc->is_array() || tag_desc->size() <= position.first) {
MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "].";
return;
} else {
first_index = &((*tag_desc)[position.first]);
}
if (!first_index->is_array() || first_index->size() <= position.second) {
MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "].";
return;
}
nlohmann::json *second_index = &((*first_index)[position.second]);
if (second_index->count(kTensorName) == 0) {
MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "].";
return;
}
(*second_index)[kTensorName] = new_name;
return;
}
int AkgKernelBuild::op_cnt_ = 0;
std::mutex AkgKernelBuild::op_cnt_mtx_;
std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) { std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
std::string device; std::string device;
...@@ -187,10 +266,7 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j ...@@ -187,10 +266,7 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j
for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
// dtype : float16 // dtype : float16
auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index); auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index);
TypePtr type_ptr = TypeIdToType(type_id); std::string dtype = TypeId2String(type_id);
MS_EXCEPTION_IF_NULL(type_ptr);
std::string dtype = type_ptr->ToString();
dtype = Dtype2String(dtype);
if (dtype.empty()) { if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. "; MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. ";
return false; return false;
...@@ -198,13 +274,23 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j ...@@ -198,13 +274,23 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j
nlohmann::json input_desc_json; nlohmann::json input_desc_json;
input_desc_json[kDataType] = dtype; input_desc_json[kDataType] = dtype;
input_desc_json[kName] = op_input_name; input_desc_json[kName] = op_input_name;
input_desc_json[kTensorName] = input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
op_input_name + "_" + std::to_string(real_input_index) + "_" + std::to_string(input_i); auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
<< "] as const tensor, shape: [" << Vector2Str(input_shape)
<< "], value: " << input_desc_json[kValue];
input_shape.clear();
}
if (input_shape.empty()) {
input_shape.push_back(1);
}
input_desc_json[kShape] = input_shape;
input_list.emplace_back(input_desc_json); input_list.emplace_back(input_desc_json);
real_input_index++;
} }
inputs_json->emplace_back(input_list); inputs_json->emplace_back(input_list);
real_input_index++;
} }
return true; return true;
} }
...@@ -220,10 +306,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann:: ...@@ -220,10 +306,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::
for (size_t i = 0; i < output_tensor_num; i++) { for (size_t i = 0; i < output_tensor_num; i++) {
nlohmann::json output_json; nlohmann::json output_json;
auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i); auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i);
TypePtr type_ptr = TypeIdToType(type_id); std::string dtype = TypeId2String(type_id);
MS_EXCEPTION_IF_NULL(type_ptr);
std::string dtype = type_ptr->ToString();
dtype = Dtype2String(dtype);
if (dtype.empty()) { if (dtype.empty()) {
MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. "; MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. ";
return false; return false;
...@@ -232,7 +315,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann:: ...@@ -232,7 +315,7 @@ bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::
std::string output_name = outputs[i]->name(); std::string output_name = outputs[i]->name();
output_json[kDataType] = dtype; output_json[kDataType] = dtype;
output_json[kName] = output_name; output_json[kName] = output_name;
output_json[kTensorName] = output_name + "_" + std::to_string(i); output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i); output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i);
outputs_json->push_back(output_json); outputs_json->push_back(output_json);
} }
...@@ -358,15 +441,14 @@ bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const ...@@ -358,15 +441,14 @@ bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const
MS_EXCEPTION_IF_NULL(op_info_ptr); MS_EXCEPTION_IF_NULL(op_info_ptr);
// get basic params from currentNodeOpDesc // get basic params from currentNodeOpDesc
(*node_json)["platform"] = "AKG";
(*node_json)[kName] = op_name; (*node_json)[kName] = op_name;
(*node_json)["fusion_type"] = AnfAlgo::GetFusionType(anf_node);
(*node_json)["impl_path"] = op_info_ptr->impl_path(); (*node_json)["impl_path"] = op_info_ptr->impl_path();
(*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node); (*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node);
(*node_json)["composite"] = false;
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
ValuePtr input_names_v = primitive->GetAttr(KInpputNames); ValuePtr input_names_v = primitive->GetAttr(KInputNames);
if (input_names_v == nullptr) { if (input_names_v == nullptr) {
MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "]."; MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "].";
return false; return false;
...@@ -465,12 +547,12 @@ KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNod ...@@ -465,12 +547,12 @@ KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNod
(void)alarm(0); (void)alarm(0);
if (pRes == nullptr) { if (pRes == nullptr) {
MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n("
<< PyObjectToStr(pArg) << ")."; << AkgKernelBuild::PyObjectToStr(pArg) << ").";
return nullptr; return nullptr;
} }
if (PyObject_IsTrue(pRes) != 1) { if (PyObject_IsTrue(pRes) != 1) {
MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n("
<< PyObjectToStr(pArg) << ")."; << AkgKernelBuild::PyObjectToStr(pArg) << ").";
return nullptr; return nullptr;
} }
...@@ -513,5 +595,29 @@ KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vecto ...@@ -513,5 +595,29 @@ KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vecto
<< "]"; << "]";
return kernel_pack; return kernel_pack;
} }
// Returns the index assigned to the node feeding input `input_idx` of `anf_node`.
//
// The first time a given input node is seen it is assigned the current size of
// `input_tensor_idx_` and recorded there; later calls for the same node return the recorded index.
// `input_idx` is offset by one when indexing the cnode inputs. Raises ArgumentError when
// `input_idx + 1` is not a valid position in the cnode inputs.
size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (input_idx + 1 >= cnode->inputs().size()) {
    MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
                                << cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]";
  }

  auto input_node = cnode->input(input_idx + 1);
  auto known = input_tensor_idx_.find(input_node);
  if (known != input_tensor_idx_.end()) {
    return known->second;
  }

  // Unseen input node: hand out the next free index and remember it.
  size_t next_index = input_tensor_idx_.size();
  input_tensor_idx_[input_node] = next_index;
  return next_index;
}
// Hands out the current value of output_tensor_idx_ and then advances the counter by one.
size_t AkgKernelBuild::GetOutputTensorIdxInc() { return output_tensor_idx_++; }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -32,29 +32,45 @@ namespace mindspore { ...@@ -32,29 +32,45 @@ namespace mindspore {
namespace kernel { namespace kernel {
class AkgKernelBuild { class AkgKernelBuild {
public: public:
AkgKernelBuild() = default; AkgKernelBuild() {
input_tensor_idx_ = {};
output_tensor_idx_ = 0;
}
~AkgKernelBuild() = default; ~AkgKernelBuild() = default;
KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size, KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size); std::vector<size_t> *const output_size);
static std::string GetProcessor(const AnfNodePtr &anf_node);
static std::string PyObjectToStr(PyObject *const PyObj);
private: protected:
bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json); bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json);
bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json); bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json);
bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name,
const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json); const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json);
KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node);
int GetOpCntInc();
size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx);
size_t GetOutputTensorIdxInc();
bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name,
nlohmann::json *const node_json); nlohmann::json *const node_json);
KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node);
int GetOpCntInc();
std::string GetProcessor(const AnfNodePtr &anf_node);
static int op_cnt_; static int op_cnt_;
// lock for variable fusionOpCnt in singleton mode // lock for variable fusionOpCnt in singleton mode
static std::mutex op_cnt_mtx_; static std::mutex op_cnt_mtx_;
std::string json_name_; std::string json_name_;
std::string json_info_; std::string json_info_;
std::unordered_map<AnfNodePtr, size_t> input_tensor_idx_;
size_t output_tensor_idx_;
}; };
bool GetIOSize(const nlohmann::json &node_json, std::vector<size_t> *const input_size,
std::vector<size_t> *const output_size);
void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair<size_t, size_t> &position,
nlohmann::json *const node_json);
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/akg/ascend/akg_ascend_kernel_build.h"
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <Python.h>
#include "ir/dtype.h"
#include "ir/func_graph.h"
#include "kernel/kernel.h"
#include "kernel/common_utils.h"
#include "kernel/tbe/tbe_utils.h"
#include "kernel/akg/ascend/akg_ascend_kernel_mod.h"
#include "kernel/akg/akg_kernel_attrs_process.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
constexpr int32_t PARALLEL_ARGS_SIZE = 3;
constexpr int32_t PROCESS_NUM = 16;
constexpr int32_t TIME_OUT = 300;
constexpr auto kOpDesc = "op_desc";
constexpr auto kShape = "shape";
constexpr auto kDataType = "data_type";
constexpr auto kInputDesc = "input_desc";
constexpr auto kOutputDesc = "output_desc";
constexpr auto kTensorName = "tensor_name";
constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel";
constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler";
// Builds the AKG json description for a single (non-composite) kernel node.
//
// Steps:
//   1. Run the op-specific attribute pre-processing registered in kAkgKernelAttrsProcessMap, if any.
//   2. Generate the kernel json via GenerateSingleKernelJson and store its dump in kernel_json_.
//   3. Fill input_size_list_ / output_size_list_ from the json via GetIOSize.
//
// Returns false (after logging an error) if json generation or size calculation fails.
// Throws if `anf_node` is null.
bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";

  // Optional per-op attribute fix-up before the json is generated.
  auto it = kAkgKernelAttrsProcessMap.find(op_name);
  if (it != kAkgKernelAttrsProcessMap.end()) {
    it->second(anf_node);
  }

  nlohmann::json node_json;
  if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
    MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed.";
    // Do not continue with an incomplete json: report the failure to the caller.
    return false;
  }

  kernel_json_ = node_json.dump();

  if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) {
    MS_LOG(ERROR) << "Cal mem size failed.";
    return false;
  }
  return true;
}
// Builds the AKG json description for a fused (composite) kernel made of `anf_nodes`.
//
// anf_nodes   : the kernel nodes inside the composite graph; must not be empty.
// input_list  : the inputs of the composite graph; must not be empty.
// output_list : the outputs of the composite graph.
//
// The function
//   1. generates one json per node (dropping the "id", "op" and "composite" keys),
//   2. renames each node input that is produced by another node of the group to that producer's
//      output tensor name,
//   3. assembles the fused json (op_desc / input_desc / output_desc plus meta fields),
//   4. derives json_name_ from a hash of the fused json and stores the dump in kernel_json_,
//   5. fills input_size_list_ / output_size_list_ via GetIOSize.
//
// Returns false (after logging an error) on invalid arguments, on a node that is not a real kernel,
// or when json generation / size calculation fails. Throws on null nodes.
bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
                                              const std::vector<AnfNodePtr> &input_list,
                                              const std::vector<AnfNodePtr> &output_list) {
  if (anf_nodes.empty() || input_list.empty()) {
    MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size()
                  << "].";
    return false;
  }
  // Each label is followed by the size of the container it names.
  MS_LOG(INFO) << "anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size() << "], output_list ["
               << output_list.size() << "].";

  // Step 1: one json per node.
  std::map<AnfNodePtr, nlohmann::json> node_json_map;
  for (auto const &anf_node : anf_nodes) {
    MS_EXCEPTION_IF_NULL(anf_node);
    std::string op_name = AnfAlgo::GetCNodeName(anf_node);
    if (!AnfAlgo::IsRealKernel(anf_node)) {
      MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "].";
      return false;
    }

    auto it = kAkgKernelAttrsProcessMap.find(op_name);
    if (it != kAkgKernelAttrsProcessMap.end()) {
      it->second(anf_node);
    }

    nlohmann::json node_json;
    if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
      MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed.";
      return false;
    }
    // No need for composite op.
    node_json.erase("id");
    node_json.erase("op");
    node_json.erase("composite");

    auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
    MS_EXCEPTION_IF_NULL(primitive);
    if (primitive->GetAttr("fusion") != nullptr) {
      node_json["fusion"] = primitive->GetAttr("fusion")->ToString();
    }

    node_json_map[anf_node] = node_json;
  }

  // Step 2: connect the nodes by reusing the producer's output tensor name for every input that is
  // produced inside the group.
  for (auto const &anf_node : anf_nodes) {
    std::vector<int> dyn_input_sizes;
    auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
    MS_EXCEPTION_IF_NULL(primitive);
    if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
      dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
    }

    bool is_dynamic_input = !dyn_input_sizes.empty();
    size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
    // (i, j) addresses the json entry; real_input_index is the flat index over all tensors of the node.
    size_t real_input_index = 0;
    for (size_t i = 0; i < input_num; ++i) {
      size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1;
      for (size_t j = 0; j < input_tensor_num; ++j) {
        auto tmp_input = GetKernelInput(anf_node, real_input_index);
        std::string tensor_name = GetTensorName(node_json_map[anf_node], kInputDesc, std::make_pair(i, j));
        if (node_json_map.find(tmp_input.first) != node_json_map.end()) {
          std::string new_tensor_name =
            GetTensorName(node_json_map[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second));
          SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &(node_json_map[anf_node]));
          MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of ["
                        << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output ["
                        << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "].";
        } else {
          MS_LOG(DEBUG) << "[" << real_input_index << "] input [" << tensor_name << "] of ["
                        << anf_node->fullname_with_scope() << "] is out input.";
        }
        real_input_index++;
      }
    }
  }

  // Step 3: assemble the fused json.
  nlohmann::json fused_node_json;
  std::vector<nlohmann::json> node_json_desc;
  std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc),
                 [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; });
  fused_node_json[kOpDesc] = node_json_desc;

  // Inputs of the fused kernel: one entry (wrapped in a one-element list) per element of input_index.
  nlohmann::json inputs_json;
  auto input_index = GetInputIndex(anf_nodes, input_list);
  for (size_t i = 0; i < input_index.size(); ++i) {
    auto tmp_input = input_index[i];
    auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first);
    std::string dtype = TypeId2String(type_id);
    nlohmann::json input_desc_json;
    input_desc_json[kTensorName] = GetTensorName(node_json_map[tmp_input.first], kInputDesc, tmp_input.second);
    input_desc_json[kDataType] = dtype;
    input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first);
    inputs_json.emplace_back(std::vector<nlohmann::json>{input_desc_json});
  }
  fused_node_json[kInputDesc] = inputs_json;

  // Outputs of the fused kernel. An output that is itself one of the graph inputs reuses that
  // input's description; otherwise the description is built from the producing node.
  nlohmann::json outputs_json;
  auto output_index = GetOutputIndex(anf_nodes, input_list, output_list);
  for (size_t i = 0; i < output_index.size(); ++i) {
    auto tmp_output = output_index[i];
    bool found = false;
    nlohmann::json output_desc_json;
    for (size_t input_i = 0; input_i < input_list.size(); ++input_i) {
      if (tmp_output.first == input_list[input_i]) {
        output_desc_json = inputs_json[input_i][0];
        found = true;
        break;
      }
    }

    if (!found) {
      auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second);
      std::string dtype = TypeId2String(type_id);
      output_desc_json[kTensorName] =
        GetTensorName(node_json_map[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second));
      output_desc_json[kDataType] = dtype;
      auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second);
      // An empty shape is emitted as shape [1].
      if (output_shape.empty()) {
        output_shape.push_back(1);
      }
      output_desc_json[kShape] = output_shape;
    }
    outputs_json.emplace_back(output_desc_json);
  }
  fused_node_json[kOutputDesc] = outputs_json;

  // Step 4: the kernel name is "Fused_" + optional composite attribute + hash of the json built so far.
  size_t hash_id = std::hash<std::string>()(fused_node_json.dump());
  json_name_ = "Fused_";
  auto fg = anf_nodes[0]->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto attr_val = fg->get_attr(FUNC_GRAPH_FLAG_COMPOSITE);
  if (attr_val != nullptr) {
    auto fg_attr = GetValue<std::string>(attr_val);
    (void)json_name_.append(fg_attr).append("_");
  }
  (void)json_name_.append(std::to_string(hash_id));
  // The fields below are added after hashing, so they do not influence hash_id.
  fused_node_json["composite_graph"] = fg->ToString();
  fused_node_json["op"] = json_name_;
  fused_node_json["platform"] = "AKG";
  fused_node_json["process"] = "aicore";
  fused_node_json["composite"] = true;

  kernel_json_ = fused_node_json.dump();

  // Step 5: memory sizes of inputs / outputs.
  if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) {
    MS_LOG(ERROR) << "Cal mem size failed.";
    return false;
  }

  return true;
}
void GenParallelCompileFuncArgs(const std::vector<std::string> &kernel_jsons, PyObject **p_args) {
MS_EXCEPTION_IF_NULL(p_args);
*p_args = PyTuple_New(PARALLEL_ARGS_SIZE);
PyObject *arg1 = PyList_New(kernel_jsons.size());
for (int i = 0; i < PyList_Size(arg1); ++i) {
PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str()));
}
PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM);
PyObject *arg3 = Py_BuildValue("i", TIME_OUT);
(void)PyTuple_SetItem(*p_args, 0, arg1);
(void)PyTuple_SetItem(*p_args, 1, arg2);
(void)PyTuple_SetItem(*p_args, 2, arg3);
}
bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) {
// Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess.
std::vector<std::string> jsons;
std::unordered_set<std::string> json_name_set;
std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> repeat_nodes;
for (const auto &[builder, anf_node] : build_args) {
MS_EXCEPTION_IF_NULL(anf_node);
auto json_name = builder.json_name();
MS_LOG(DEBUG) << "Akg start compile op: " << json_name;
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
if (cached_kernel_pack != nullptr) {
MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
continue;
}
if (json_name_set.count(json_name) != 0) {
repeat_nodes.push_back({builder, anf_node});
continue;
}
json_name_set.insert(json_name);
auto node_json = builder.kernel_json();
kernel::SaveJsonInfo(json_name, node_json);
jsons.push_back(node_json);
}
// No nodes need to be compiled!
if (jsons.empty()) {
return true;
}
// Try to call python method to compile nodes parallely.
PyObject *p_module = nullptr;
PyObject *p_func = nullptr;
PyObject *p_arg = nullptr;
PyObject *p_res = nullptr;
p_module = PyImport_ImportModule(kMultiProcModule);
if (p_module == nullptr) {
MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "].";
return false;
}
p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc);
GenParallelCompileFuncArgs(jsons, &p_arg);
MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size()
<< " Akg kernels parallelly.";
p_res = PyEval_CallObject(p_func, p_arg);
if (p_res == nullptr) {
PyErr_Print();
MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
return false;
}
if (PyObject_IsTrue(p_res) != 1) {
PyErr_Print();
MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
return false;
}
// All unique done here, cache them and set kernel.
for (const auto &[builder, anf_node] : build_args) {
auto json_name = builder.json_name();
auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
if (new_kernel_pack == nullptr) {
MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
return false;
}
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!";
}
// Handle repeated nodes.
for (const auto &[builder, anf_node] : repeat_nodes) {
auto node_json = builder.kernel_json();
auto json_name = builder.json_name();
auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
if (cached_kernel_pack == nullptr) return false;
MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope["
<< anf_node->fullname_with_scope() << "].";
auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
}
return true;
}
// Entry point for building AKG kernels for a list of nodes.
//
// For every node a builder collects the kernel json: composite kernels go through
// CollectFusedJson on the nodes of their sub graph, all others through CollectJson. The collected
// (builder, node) pairs are then compiled together by AkgOpParallelBuild.
//
// Returns the result of AkgOpParallelBuild, or true when `anf_nodes` is empty.
// Throws on null nodes / graphs and when json collection fails for a node.
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
  std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> json_and_node;
  for (const auto &anf_node : anf_nodes) {
    MS_EXCEPTION_IF_NULL(anf_node);
    AkgAscendKernelBuilder akg_cce_kernel_builder;
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (AnfAlgo::IsCompositeKernel(cnode)) {
      auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
      // Validate the graph before it is dereferenced below.
      MS_EXCEPTION_IF_NULL(func_graph);
      auto mng = func_graph->manager();
      if (mng == nullptr) {
        // The sub graph has no manager yet: create one and attach it.
        mng = Manage(func_graph, true);
        func_graph->set_manager(mng);
      }
      std::vector<AnfNodePtr> node_list;
      std::vector<AnfNodePtr> input_list;
      std::vector<AnfNodePtr> output_list;
      std::string op_name = AnfAlgo::GetCNodeName(anf_node);
      MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]";
      GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
      if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) {
        MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "].";
      }
    } else {
      if (!akg_cce_kernel_builder.CollectJson(anf_node)) {
        MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "].";
      }
    }
    json_and_node.push_back({akg_cce_kernel_builder, anf_node});
  }

  if (json_and_node.empty()) {
    MS_LOG(DEBUG) << "There is no kernel needed to be compiled.";
    return true;
  }

  return AkgOpParallelBuild(json_and_node);
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
#include <string>
#include <memory>
#include <vector>
#include "ir/anf.h"
#include "kernel/kernel.h"
#include "kernel/akg/akgkernelbuild.h"
namespace mindspore {
namespace kernel {
// Ascend flavour of AkgKernelBuild: collects the kernel json and the input / output size lists for
// either a single node or a fused group of nodes, and exposes them through read-only accessors.
class AkgAscendKernelBuilder : public AkgKernelBuild {
 public:
  AkgAscendKernelBuilder() = default;
  ~AkgAscendKernelBuilder() = default;

  // Collect the kernel description of one node; returns false on failure.
  bool CollectJson(const AnfNodePtr &anf_node);
  // Collect the kernel description of a fused group of nodes; returns false on failure.
  bool CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes, const std::vector<AnfNodePtr> &input_list,
                        const std::vector<AnfNodePtr> &output_list);

  // Sizes filled in by the Collect* functions.
  const std::vector<size_t> &input_size_list() const { return input_size_list_; }
  const std::vector<size_t> &output_size_list() const { return output_size_list_; }

  // Dumped kernel json and the kernel name stored in the base class.
  std::string kernel_json() const { return kernel_json_; }
  std::string json_name() const { return json_name_; }

 private:
  std::string kernel_json_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
};
bool AkgAscendKernelParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/akg/ascend/akg_ascend_kernel_mod.h"
#include <algorithm>
#include <fstream>
#include <map>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "nlohmann/json.hpp"
#include "runtime/rt.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils.h"
namespace mindspore {
namespace kernel {
using std::fstream;
using std::map;
using std::mutex;
using std::string;
using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
using tbe::KernelManager;
constexpr uint32_t DEFAULT_BLOCK_DIM = 1;
/**
 * @brief infotable contains func_stub, blockdim and the kernel file buffer
 */
// Stores the kernel pack that Launch / GenTask will use.
AkgKernelMod::AkgKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {}

// Setters: each one replaces the corresponding stored size list with a copy of `size_list`.
void AkgKernelMod::SetInputSizeList(const std::vector<size_t> &size_list) {
  input_size_list_ = size_list;
}

void AkgKernelMod::SetOutputSizeList(const std::vector<size_t> &size_list) {
  output_size_list_ = size_list;
}

void AkgKernelMod::SetWorkspaceSizeList(const std::vector<size_t> &size_list) {
  workspace_size_list_ = size_list;
}

// Getters: each one returns a reference to the stored size list.
const std::vector<size_t> &AkgKernelMod::GetInputSizeList() const {
  return input_size_list_;
}

const std::vector<size_t> &AkgKernelMod::GetOutputSizeList() const {
  return output_size_list_;
}

const std::vector<size_t> &AkgKernelMod::GetWorkspaceSizeList() const {
  return workspace_size_list_;
}
// Copies every buffer of `addrs` from device to host and writes it to the binary file
// "<prefix><index>" in the current working directory.
// Returns false (after logging) as soon as one copy or one file open fails.
static bool DumpAddressList(const std::vector<AddressPtr> &addrs, const std::string &prefix) {
  int idx = 0;
  for (const auto &x : addrs) {
    std::vector<char> buf(x->size);
    if (RT_ERROR_NONE != rtMemcpy(buf.data(), buf.size(), reinterpret_cast<const void *>(x->addr), x->size,
                                  RT_MEMCPY_DEVICE_TO_HOST)) {
      MS_LOG(WARNING) << "Call runtime rtMemcpy error.";
      return false;
    }

    std::string file_name(prefix);
    file_name += std::to_string(idx);
    std::ofstream file(file_name, std::ios::binary);
    if (!file.is_open()) {
      MS_LOG(ERROR) << "Open file failed.";
      return false;
    }
    (void)file.write(buf.data(), SizeToLong(buf.size()));
    file.close();
    idx++;
  }
  return true;
}

// Debug helper: when the environment variable MS_KERNEL_DUMP_DATA is set, dumps all input buffers
// to files "input_<i>" and all output buffers to files "output_<i>".
// The first failure stops the dump; outputs are not dumped when dumping the inputs failed.
// NOTE(review): Launch calls this right after rtKernelLaunch; whether the kernel has finished
// writing its outputs at that point is not visible here - confirm.
void DumpData(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
  const char *dump_data = getenv("MS_KERNEL_DUMP_DATA");
  if (dump_data == nullptr) {
    return;
  }
  if (!DumpAddressList(inputs, "input_")) {
    return;
  }
  (void)DumpAddressList(outputs, "output_");
}
bool AkgKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (stream_ptr == 0) {
MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
return false;
}
if (kernel_pack_ == nullptr) {
MS_LOG(ERROR) << "kernel pack should not be nullptr.";
return false;
}
uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1.
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
if (func_stub == 0) {
MS_LOG(ERROR) << "GenFuncStub failed.";
return false;
}
// pack all addresses into a vector.
std::vector<void *> runtime_args;
(void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtime_args),
[](const AddressPtr &input) -> void * { return input->addr; });
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args),
[](const AddressPtr &output) -> void * { return output->addr; });
rtL2Ctrl_t *l2ctrl = nullptr;
auto stream = reinterpret_cast<rtStream_t *>(stream_ptr);
if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast<void *>(func_stub), block_dim, runtime_args.data(),
SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) {
MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";
return false;
}
DumpData(inputs, outputs);
return true;
}
// Creates the task description (a single TbeTaskInfo) for the kernel held by kernel_pack_.
//
// inputs / outputs : their addresses are passed to the task as input / output data addresses.
// (second param)   : workspace list, not used; the task is created with no workspace addresses.
// stream_id        : forwarded to the task info.
//
// Returns a vector with exactly one task info.
// Raises through MS_LOG(EXCEPTION) when kernel_pack_ is null or no function stub can be generated.
std::vector<TaskInfoPtr> AkgKernelMod::GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                               const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
  if (kernel_pack_ == nullptr) {
    MS_LOG(EXCEPTION) << "kernel pack should not be nullptr.";
  }

  // These fields are passed to TbeTaskInfo in their empty / zero state.
  std::vector<uint8_t> args;
  uint32_t args_size = 0;
  std::vector<uint8_t> sm_desc;
  void *binary = nullptr;
  uint32_t binary_size = 0;
  std::vector<uint8_t> meta_data;
  std::vector<void *> input_data_addrs;
  std::vector<void *> output_data_addrs;
  std::vector<void *> workspace_addrs;

  // pack all addresses into a vector.
  (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs),
                       [](const AddressPtr &input) -> void * { return input->addr; });
  (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
                       [](const AddressPtr &output) -> void * { return output->addr; });

  uint32_t block_dim = DEFAULT_BLOCK_DIM;  // default blockdim equal to 1.
  auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
  if (func_stub == 0) {
    MS_LOG(EXCEPTION) << "GenFuncStub failed.";
  }

  std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_);

  MS_LOG(DEBUG) << "The block_dim is:" << block_dim;

  // Qualified explicitly so the call does not depend on a using-declaration from another header.
  TbeTaskInfoPtr task_info_ptr = std::make_shared<ge::model_runner::TbeTaskInfo>(
    stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, input_data_addrs,
    output_data_addrs, workspace_addrs);
  return {task_info_ptr};
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_
#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_
#include <string>
#include <vector>
#include <memory>
#include "kernel/ascend_kernel_mod.h"
#include "kernel/tbe/tbe_utils.h"
namespace mindspore {
namespace kernel {
// Kernel mod for AKG kernels on Ascend: wraps a kernel pack together with the input / output /
// workspace size lists, and implements launching and task generation for it.
class AkgKernelMod : public AscendKernelMod {
 public:
  explicit AkgKernelMod(const KernelPackPtr &kernel_pack);
  ~AkgKernelMod() final {}

  // Launch the kernel on `stream_ptr`; returns false on failure.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
  // Build the task info list for this kernel.
  std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                   const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;

  // Size list setters.
  void SetInputSizeList(const std::vector<size_t> &size_list);
  void SetOutputSizeList(const std::vector<size_t> &size_list);
  void SetWorkspaceSizeList(const std::vector<size_t> &size_list);

  // Size list getters.
  const std::vector<size_t> &GetInputSizeList() const override;
  const std::vector<size_t> &GetOutputSizeList() const override;
  const std::vector<size_t> &GetWorkspaceSizeList() const override;

 private:
  KernelPackPtr kernel_pack_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
using AkgKernelModPtr = std::shared_ptr<AkgKernelMod>;
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_
...@@ -22,6 +22,11 @@ ...@@ -22,6 +22,11 @@
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "common/utils.h" #include "common/utils.h"
#include "ir/manager.h"
#include "ir/meta_tensor.h"
#include "ir/func_graph.h"
#include "operator/ops.h"
#include "utils/graph_utils.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
...@@ -47,12 +52,6 @@ const std::map<TypeId, std::string> type_id_str_map = { ...@@ -47,12 +52,6 @@ const std::map<TypeId, std::string> type_id_str_map = {
{TypeId::kNumberTypeBool, "bool"}, {TypeId::kNumberTypeBool, "bool"},
}; };
const std::map<std::string, std::string> DATATYPE_STRING_MAP{
{"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"},
{"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"},
{"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "bool"}, {"Float64", "double"},
};
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = { const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
{"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"}, {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
{"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"}, {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
...@@ -242,14 +241,6 @@ TypeId DtypeToTypeId(const std::string &dtypes) { ...@@ -242,14 +241,6 @@ TypeId DtypeToTypeId(const std::string &dtypes) {
} }
} }
std::string Dtype2String(const std::string &dtypes) {
auto iter = DATATYPE_STRING_MAP.find(dtypes);
if (iter == DATATYPE_STRING_MAP.end()) {
MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
}
return iter->second;
}
std::string TypeId2String(TypeId type_id) { std::string TypeId2String(TypeId type_id) {
auto iter = type_id_str_map.find(type_id); auto iter = type_id_str_map.find(type_id);
if (iter == type_id_str_map.end()) { if (iter == type_id_str_map.end()) {
...@@ -360,7 +351,7 @@ bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou ...@@ -360,7 +351,7 @@ bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
output_num = 1; output_num = 1;
} else { } else {
if (output_idx < real_output_num) { if (output_idx < real_output_num) {
MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
output_num = 1; output_num = 1;
} }
} }
...@@ -402,7 +393,7 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu ...@@ -402,7 +393,7 @@ void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBu
} }
if (imply_type == kAKG) { if (imply_type == kAKG) {
builder->SetKernelType(AUTO_DIFF_KERNEL); builder->SetKernelType(AKG_KERNEL);
} else if (imply_type == kAICPU) { } else if (imply_type == kAICPU) {
builder->SetKernelType(AICPU_KERNEL); builder->SetKernelType(AICPU_KERNEL);
} else { } else {
...@@ -537,5 +528,256 @@ bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> & ...@@ -537,5 +528,256 @@ bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &
} }
return true; return true;
} }
// Resolves input `index` of `anf_node` through AnfAlgo::VisitKernel and returns its
// (node, index) result.
// For a CNode the visited node is the CNode's input at position index + 1; for any other node
// the node itself is visited.
// Throws ArgumentError when `index` is not smaller than the node's input tensor number.
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);

  if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
    MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs.";
  }

  auto cnode = anf_node->cast<CNodePtr>();
  if (cnode != nullptr) {
    return AnfAlgo::VisitKernel(cnode->input(index + 1), 0);
  }
  return AnfAlgo::VisitKernel(anf_node, 0);
}
// For every graph input in `input_list`, finds one node of `node_list` that uses it and reports
// where that node consumes it.
//
// Returns one entry per input, in the order of `input_list`:
//   (user node, (input position, offset inside that position)).
// For a user without dynamic inputs the position is the user's input slot minus one and the offset
// is 0. For a user with the kAttrDynInputSizes attribute the flat slot is converted into
// (dynamic input group, offset inside the group).
//
// Throws when an input (or its func graph / manager) is null, when an input has no users, and when
// none of its users is a node of `node_list` that yields an index.
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                            const std::vector<AnfNodePtr> &input_list) {
  std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  for (size_t i = 0; i < input_list.size(); ++i) {
    auto const &input = input_list[i];
    MS_EXCEPTION_IF_NULL(input);
    // Check the func graph before it is dereferenced to reach the manager.
    auto input_fg = input->func_graph();
    MS_EXCEPTION_IF_NULL(input_fg);

    bool found = false;
    // using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>;
    auto mng = input_fg->manager();
    MS_EXCEPTION_IF_NULL(mng);
    const NodeUsersMap &users = mng->node_users();
    auto input_users = users.find(input);
    if (input_users == users.end() || input_users->second.empty()) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input_fg->ToString() << "] has no users.";
    }

    for (auto const &input_user : input_users->second) {
      for (auto const &anf_node : node_list) {
        // Only users that belong to node_list are of interest.
        if (anf_node != input_user.first) {
          continue;
        }

        std::vector<int> dyn_input_sizes;
        auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
          dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes));
        }

        if (dyn_input_sizes.empty()) {
          // The recorded user slot is shifted by one relative to the reported input position.
          input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
          found = true;
          break;
        } else {
          // Walk the dynamic input groups until the flat index falls inside one of them.
          int used_as_idx = input_user.second - 1;
          int accum_idx = 0;
          size_t dyn_i = 0;
          for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
            accum_idx += dyn_input_sizes[dyn_i];
            if (used_as_idx < accum_idx) {
              input_index.push_back(std::make_pair(
                anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
              break;
            }
          }
          // dyn_i stops short of the end only when a group was matched above.
          if (dyn_i != dyn_input_sizes.size()) {
            found = true;
            break;
          }
        }
      }
      if (found) {
        break;
      }
    }

    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input_fg->ToString() << "] found no related kernel info.";
    }
  }
  return input_index;
}
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list) {
std::vector<std::pair<AnfNodePtr, size_t>> output_index;
for (size_t i = 0; i < output_list.size(); ++i) {
auto const &output = output_list[i];
MS_EXCEPTION_IF_NULL(output);
bool found = false;
auto pree_node = AnfAlgo::VisitKernel(output, 0);
auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
if (pos != std::end(node_list)) {
output_index.push_back(pree_node);
continue;
}
auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
if (ret != std::end(input_list)) {
output_index.push_back(std::make_pair(pree_node.first, 0));
found = true;
}
if (!found) {
MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
<< output->func_graph()->ToString() << "] found no related kernel info.";
}
}
return output_index;
}
// Appends to `node_list` the nodes of `func_graph` that are usable as kernels.
//
// The graph is walked in the order returned by TopoSort starting from the return node. A node is kept when
// AnfAlgo::IsRealKernel accepts it, it is a CNode, and its input at kAnfPrimitiveIndex is a Primitive value node.
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  MS_EXCEPTION_IF_NULL(node_list);
  MS_EXCEPTION_IF_NULL(func_graph);

  const std::vector<AnfNodePtr> sorted_nodes = TopoSort(func_graph->get_return());
  for (const auto &candidate : sorted_nodes) {
    const bool is_real_cnode = AnfAlgo::IsRealKernel(candidate) && candidate->isa<CNode>();
    if (!is_real_cnode) {
      continue;
    }
    auto candidate_cnode = candidate->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(candidate_cnode);
    if (!IsValueNode<Primitive>(candidate_cnode->input(kAnfPrimitiveIndex))) {
      continue;
    }
    node_list->push_back(candidate);
  }
}
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
MS_EXCEPTION_IF_NULL(node_list);
MS_EXCEPTION_IF_NULL(input_list);
MS_EXCEPTION_IF_NULL(output_list);
MS_EXCEPTION_IF_NULL(func_graph);
GetValidKernelNodes(func_graph, node_list);
auto parameters = func_graph->parameters();
input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
auto func_output = func_graph->output();
MS_EXCEPTION_IF_NULL(func_output);
if (func_output->isa<CNode>()) {
// multi output.
auto cnode = func_output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input0 = cnode->input(kAnfPrimitiveIndex);
MS_EXCEPTION_IF_NULL(input0);
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
auto input_node = cnode->input(input_idx);
MS_EXCEPTION_IF_NULL(input_node);
output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
}
} else {
// single output.
output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
}
} else {
// single output.
output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
}
}
// Writes the first element of a constant tensor input into (*node_json)["value"].
//
// Parameters:
//   anf_node  - node whose input is inspected; must be a CNode.
//   input_idx - tensor input position; the CNode input at `input_idx + 1` is read.
//   node_json - JSON object that receives the "value" entry.
//
// Returns true when the input is a tensor value node of type float32, float16 or int32 and its first element was
// stored (float16 is stored as float). Returns false when the input is not a tensor value node, the tensor holds no
// element, or the element type is not one of the three supported types.
// Raises ArgumentError when `input_idx + 1` is not a valid input position of the CNode.
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(node_json);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (input_idx + 1 >= cnode->size()) {
    MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
                                << cnode->inputs().size() << "][" << cnode->DebugString() << "]";
  }

  auto input_node = cnode->input(input_idx + 1);
  if (!IsValueNode<tensor::Tensor>(input_node)) {
    return false;
  }

  auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
  if (tensor == nullptr) {
    return false;
  }

  auto type_id = tensor->data_type();
  auto *data = tensor->data_c();
  MS_EXCEPTION_IF_NULL(data);
  // Element 0 is read below, so a tensor without any element must be rejected here; otherwise the read would go
  // past the end of the tensor's buffer.
  if (tensor->DataSize() < 1) {
    MS_LOG(WARNING) << "Tensor has no element, no value can be taken, [" << input_node->DebugString(2) << "]";
    return false;
  }
  if (tensor->DataDim() > 1 || tensor->DataSize() != 1) {
    // not const tensor.
    MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]";
  }

  // `data` was validated above, so the typed pointers below need no further null checks.
  if (type_id == kFloat32->type_id()) {
    float *val = static_cast<float *>(data);
    (*node_json)["value"] = val[0];
    MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
    return true;
  }
  if (type_id == kFloat16->type_id()) {
    float16 *val = static_cast<float16 *>(data);
    (*node_json)["value"] = static_cast<float>(val[0]);
    MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
    return true;
  }
  if (type_id == kInt32->type_id()) {
    int *val = static_cast<int *>(data);
    (*node_json)["value"] = val[0];
    MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
    return true;
  }

  MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
  return false;
}
// Appends the real output(s) of `func_graph` to `node_list` as {node, output index} pairs.
//
// When the graph output satisfies AnfAlgo::IsRealKernel, the pair {output, 0} is appended. When it is a
// prim::kPrimMakeTuple CNode, every input from index 1 onwards is resolved with AnfAlgo::VisitKernel and appended.
// Raises ArgumentError for any other kind of output.
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node_list);
  auto graph_output = func_graph->output();
  MS_EXCEPTION_IF_NULL(graph_output);

  // single output.
  if (AnfAlgo::IsRealKernel(graph_output)) {
    node_list->push_back(std::make_pair(graph_output, 0));
    return;
  }

  if (!IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
    MS_EXCEPTION(ArgumentError) << "Unknown output type: " << graph_output->DebugString(2)
                                << " of graph: " << func_graph->ToString();
  }

  // multi output.
  auto tuple_cnode = graph_output->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(tuple_cnode);
  auto &tuple_inputs = tuple_cnode->inputs();
  for (size_t pos = 1; pos < tuple_inputs.size(); ++pos) {
    node_list->push_back(AnfAlgo::VisitKernel(tuple_inputs[pos], 0));
  }
}
bool IsWeightBoundary(const AnfNodePtr &node) {
if (node->isa<ValueNode>()) {
return true;
}
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
return true;
}
return false;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -20,9 +20,12 @@ ...@@ -20,9 +20,12 @@
#include <dirent.h> #include <dirent.h>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
#include <nlohmann/json.hpp>
#include "kernel/kernel.h" #include "kernel/kernel.h"
#include "kernel/oplib/opinfo.h" #include "kernel/oplib/opinfo.h"
#include "kernel/kernel_build_info.h" #include "kernel/kernel_build_info.h"
...@@ -73,16 +76,26 @@ bool CheckCache(const std::string &kernel_name); ...@@ -73,16 +76,26 @@ bool CheckCache(const std::string &kernel_name);
KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor);
KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor);
TypeId DtypeToTypeId(const std::string &dtypes); TypeId DtypeToTypeId(const std::string &dtypes);
std::string Dtype2String(const std::string &dtypes);
std::string Dtype2ShortType(const std::string &dtypes); std::string Dtype2ShortType(const std::string &dtypes);
std::string TypeId2String(TypeId type_id); std::string TypeId2String(TypeId type_id);
size_t GetDtypeNbyte(const std::string &dtypes); size_t GetDtypeNbyte(const std::string &dtypes);
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor, bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list); std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list);
bool IsAtomicNode(const CNodePtr &kernel_node);
void SaveJsonInfo(const std::string &json_name, const std::string &info); void SaveJsonInfo(const std::string &json_name, const std::string &info);
std::string GetProcessor(const AnfNodePtr &anf_node); std::string GetProcessor(const AnfNodePtr &anf_node);
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b); bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index);
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list);
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list);
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list);
void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list);
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json);
void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list);
bool IsWeightBoundary(const AnfNodePtr &node);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AUTO_DIFF_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL };
namespace kernel { namespace kernel {
......
...@@ -31,7 +31,7 @@ class KernelBuildInfo { ...@@ -31,7 +31,7 @@ class KernelBuildInfo {
class KernelBuildInfoBuilder; class KernelBuildInfoBuilder;
KernelBuildInfo() { KernelBuildInfo() {
kernel_type_ = AUTO_DIFF_KERNEL; kernel_type_ = AKG_KERNEL;
fusion_type_ = OPAQUE; fusion_type_ = OPAQUE;
processor_ = AICORE; processor_ = AICORE;
op_pattern_ = kCommonPattern; op_pattern_ = kCommonPattern;
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "kernel/rts/rt_kernel_info.h" #include "kernel/rts/rt_kernel_info.h"
#include "kernel/hccl/hccl_kernel_metadata.h" #include "kernel/hccl/hccl_kernel_metadata.h"
#include "kernel/tbe/tbe_kernel_select.h" #include "kernel/tbe/tbe_kernel_select.h"
#include "kernel/akg/akg_kernel_metadata.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
namespace mindspore { namespace mindspore {
...@@ -50,10 +51,14 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, ...@@ -50,10 +51,14 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
} }
} }
} // namespace } // namespace
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
void KernelQueryAll(const CNodePtr &kernel_node,
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list); MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list); TbeMetadataInfo(kernel_node, kernel_info_list);
if (kernel_info_list->empty()) { if (kernel_info_list->empty()) {
AicpuMetadataInfo(kernel_node, kernel_info_list); AicpuMetadataInfo(kernel_node, kernel_info_list);
if (!kernel_info_list->empty()) { if (!kernel_info_list->empty()) {
...@@ -73,6 +78,28 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel ...@@ -73,6 +78,28 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
if (kernel_info_list->empty()) { if (kernel_info_list->empty()) {
MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!"; MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!";
} }
}
// Collects the candidate kernel build infos of `kernel_node` into `kernel_info_list`.
//
// `kernel_type` selects the metadata source: KernelType::AKG_KERNEL uses AkgMetadataInfo, every other value falls
// back to KernelQueryAll. Raises NotExistsError when no candidate was found.
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
KernelType kernel_type) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
// NOTE(review): op_name is not read anywhere in the visible body - confirm whether the call is still needed.
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
switch (kernel_type) {
case KernelType::AKG_KERNEL:
AkgMetadataInfo(kernel_node, kernel_info_list);
break;
default:
KernelQueryAll(kernel_node, kernel_info_list);
break;
}
if (kernel_info_list->empty()) {
MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!";
}
// check output
FilterInvalidKernelInfo(kernel_node, kernel_info_list); FilterInvalidKernelInfo(kernel_node, kernel_info_list);
} }
......
...@@ -25,7 +25,8 @@ ...@@ -25,7 +25,8 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list); void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
......
...@@ -264,8 +264,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im ...@@ -264,8 +264,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
bool is_gpu = (context->device_target() == kGPUDevice); bool is_gpu = (context->device_target() == kGPUDevice);
if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) || if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) {
(!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) {
MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
<< ", current op num: " << op_info_.size(); << ", current op num: " << op_info_.size();
return nullptr; return nullptr;
......
...@@ -307,7 +307,7 @@ static int TypeStrToDstType(const std::string &type_str) { ...@@ -307,7 +307,7 @@ static int TypeStrToDstType(const std::string &type_str) {
ret = 4; ret = 4;
} else if (type_str == "UInt64") { } else if (type_str == "UInt64") {
ret = 10; ret = 10;
} else if (type_str == "Bool_") { } else if (type_str == "Bool") {
ret = 12; ret = 12;
} else { } else {
MS_EXCEPTION(ArgumentError) << "type str is invailed: " << type_str; MS_EXCEPTION(ArgumentError) << "type str is invailed: " << type_str;
......
...@@ -51,7 +51,7 @@ const std::map<TypeId, std::string> type_id_str_maps = { ...@@ -51,7 +51,7 @@ const std::map<TypeId, std::string> type_id_str_maps = {
const std::map<std::string, std::string> type_str_maps = { const std::map<std::string, std::string> type_str_maps = {
{"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"},
{"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"},
{"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool_", "int8"}, {"Float64", "float64"}, {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool", "int8"}, {"Float64", "float64"},
}; };
const std::unordered_map<std::string, size_t> type_nbyte_maps = { const std::unordered_map<std::string, size_t> type_nbyte_maps = {
......
...@@ -334,8 +334,8 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL ...@@ -334,8 +334,8 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptrGraph->debug_info()->set_name("hyper_map"); ptrGraph->debug_info()->set_name("hyper_map");
AnfNodePtr ptrFnArg = nullptr; AnfNodePtr ptrFnArg = nullptr;
...@@ -389,7 +389,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tu ...@@ -389,7 +389,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tu
MS_EXCEPTION_IF_NULL(a_tuple); MS_EXCEPTION_IF_NULL(a_tuple);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("tail"); ret->debug_info()->set_name("tail");
AnfNodePtr ptrTup = ret->add_parameter(); AnfNodePtr ptrTup = ret->add_parameter();
...@@ -409,7 +409,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list ...@@ -409,7 +409,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list
MS_EXCEPTION_IF_NULL(a_list); MS_EXCEPTION_IF_NULL(a_list);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("tail"); ret->debug_info()->set_name("tail");
AnfNodePtr ptrList = ret->add_parameter(); AnfNodePtr ptrList = ret->add_parameter();
...@@ -481,10 +481,10 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg ...@@ -481,10 +481,10 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
} }
b->set_flags(FUNC_GRAPH_FLAG_CORE, true); b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
b->set_output(b->NewCNode(grads)); b->set_output(b->NewCNode(grads));
fg->set_flags(FUNC_GRAPH_FLAG_CORE, true); fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
(void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
return fg; return fg;
...@@ -503,7 +503,7 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_ ...@@ -503,7 +503,7 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
const std::vector<AnfNodePtr> &params_list, bool applyJ) { const std::vector<AnfNodePtr> &params_list, bool applyJ) {
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ValueNodePtr opsJ = NewValueNode(prim::kPrimJ); ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem); ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
...@@ -619,7 +619,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp ...@@ -619,7 +619,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
std::ostringstream ss; std::ostringstream ss;
ss << "grad{" << nparam << "}"; ss << "grad{" << nparam << "}";
dfBuilder->set_flags(FUNC_GRAPH_FLAG_CORE, true); dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
dfBuilder->debug_info()->set_name(ss.str()); dfBuilder->debug_info()->set_name(ss.str());
ParameterPtr param_graph = dfBuilder->add_parameter(); ParameterPtr param_graph = dfBuilder->add_parameter();
...@@ -774,7 +774,7 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_lis ...@@ -774,7 +774,7 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_lis
} }
FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>(); FuncGraphPtr fg_ptr = std::make_shared<FuncGraph>();
fg_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
fg_ptr->debug_info()->set_name("list_map"); fg_ptr->debug_info()->set_name("list_map");
AnfNodePtr fn = fg_ptr->add_parameter(); AnfNodePtr fn = fg_ptr->add_parameter();
...@@ -844,7 +844,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr ...@@ -844,7 +844,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr
// cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts)
FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>(); FuncGraphPtr fgtrue_ptr = std::make_shared<FuncGraph>();
fgtrue_ptr->debug_info()->set_name("ftrue"); fgtrue_ptr->debug_info()->set_name("ftrue");
fgtrue_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl});
auto inputs = fgtrue_output_cnode->inputs(); auto inputs = fgtrue_output_cnode->inputs();
...@@ -854,7 +854,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr ...@@ -854,7 +854,7 @@ void ListMap::MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr
FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>(); FuncGraphPtr fgfalse_ptr = std::make_shared<FuncGraph>();
fgfalse_ptr->debug_info()->set_name("ffalse"); fgfalse_ptr->debug_info()->set_name("ffalse");
fgfalse_ptr->set_flags(FUNC_GRAPH_FLAG_CORE, true); fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true);
fgfalse_ptr->set_output(resl); fgfalse_ptr->set_output(resl);
AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")),
...@@ -911,7 +911,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li ...@@ -911,7 +911,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
} }
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr p_tup_a = ret->add_parameter(); AnfNodePtr p_tup_a = ret->add_parameter();
AnfNodePtr p_tup_b = ret->add_parameter(); AnfNodePtr p_tup_b = ret->add_parameter();
...@@ -1015,7 +1015,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ ...@@ -1015,7 +1015,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value); GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr p_tuple = ret->add_parameter(); AnfNodePtr p_tuple = ret->add_parameter();
(void)ret->add_parameter(); (void)ret->add_parameter();
...@@ -1186,7 +1186,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec ...@@ -1186,7 +1186,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0); AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr tensor_node = ret_graph->add_parameter(); AnfNodePtr tensor_node = ret_graph->add_parameter();
(void)ret_graph->add_parameter(); (void)ret_graph->add_parameter();
...@@ -1244,7 +1244,7 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar ...@@ -1244,7 +1244,7 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar
AbstractBasePtrList branches = branches_abs->elements(); AbstractBasePtrList branches = branches_abs->elements();
if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) { if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr functions = ret_graph->add_parameter(); AnfNodePtr functions = ret_graph->add_parameter();
auto index = ret_graph->add_parameter(); auto index = ret_graph->add_parameter();
......
...@@ -225,7 +225,7 @@ FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrLi ...@@ -225,7 +225,7 @@ FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrLi
} }
auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters()); auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters());
func_graph->set_output(new_cnode); func_graph->set_output(new_cnode);
func_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
return func_graph; return func_graph;
} }
} // namespace prim } // namespace prim
......
...@@ -35,7 +35,7 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList & ...@@ -35,7 +35,7 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &
MS_EXCEPTION_IF_NULL(arg0_list); MS_EXCEPTION_IF_NULL(arg0_list);
FuncGraphPtr ret = std::make_shared<FuncGraph>(); FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("append"); ret->debug_info()->set_name("append");
AnfNodePtr arg0_node = ret->add_parameter(); AnfNodePtr arg0_node = ret->add_parameter();
......
...@@ -51,7 +51,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ ...@@ -51,7 +51,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); (void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
auto ret_graph = std::make_shared<FuncGraph>(); auto ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr fnNode = ret_graph->add_parameter(); AnfNodePtr fnNode = ret_graph->add_parameter();
std::vector<AnfNodePtr> elems; std::vector<AnfNodePtr> elems;
......
...@@ -57,7 +57,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe ...@@ -57,7 +57,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe
return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size()); return (x->cast<AbstractTuplePtr>()->size() < y->cast<AbstractTuplePtr>()->size());
}); });
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
for (size_t idx = 0; idx < args_spec_list.size(); idx++) { for (size_t idx = 0; idx < args_spec_list.size(); idx++) {
(void)ret_graph->add_parameter(); (void)ret_graph->add_parameter();
} }
......
...@@ -52,6 +52,12 @@ const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not"); ...@@ -52,6 +52,12 @@ const PrimitivePtr kPrimBoolNot = std::make_shared<Primitive>("bool_not");
const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and"); const PrimitivePtr kPrimBoolAnd = std::make_shared<Primitive>("bool_and");
const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or"); const PrimitivePtr kPrimBoolOr = std::make_shared<Primitive>("bool_or");
const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq"); const PrimitivePtr kPrimBoolEq = std::make_shared<Primitive>("bool_eq");
const PrimitivePtr kPrimGreater = std::make_shared<Primitive>("Greater");
const PrimitivePtr kPrimGreaterEqual = std::make_shared<Primitive>("GreaterEqual");
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
const PrimitivePtr kPrimNotEqual = std::make_shared<Primitive>("NotEqual");
// Type introspection // Type introspection
const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof"); const PrimitivePtr kPrimTypeOf = std::make_shared<Primitive>("typeof");
...@@ -165,14 +171,17 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul"); ...@@ -165,14 +171,17 @@ const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum"); const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum"); const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square"); const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
const PrimitivePtr kPrimEqual = std::make_shared<Primitive>("Equal");
const PrimitivePtr kPrimLess = std::make_shared<Primitive>("Less");
const PrimitivePtr kPrimLessEqual = std::make_shared<Primitive>("LessEqual");
const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum"); const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd"); const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
// NN // NN
const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten"); const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax"); const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad"); const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh"); const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
...@@ -211,6 +220,7 @@ const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot"); ...@@ -211,6 +220,7 @@ const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu"); const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad"); const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU"); const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2"); const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor"); const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop"); const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
...@@ -244,6 +254,7 @@ const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); ...@@ -244,6 +254,7 @@ const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict"); const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
// Comm ops // Comm ops
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator"); const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
......
...@@ -58,6 +58,12 @@ extern const PrimitivePtr kPrimBoolNot; ...@@ -58,6 +58,12 @@ extern const PrimitivePtr kPrimBoolNot;
extern const PrimitivePtr kPrimBoolAnd; extern const PrimitivePtr kPrimBoolAnd;
extern const PrimitivePtr kPrimBoolOr; extern const PrimitivePtr kPrimBoolOr;
extern const PrimitivePtr kPrimBoolEq; extern const PrimitivePtr kPrimBoolEq;
extern const PrimitivePtr kPrimGreater;
extern const PrimitivePtr kPrimGreaterEqual;
extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual;
extern const PrimitivePtr kPrimEqual;
extern const PrimitivePtr kPrimNotEqual;
// Type introspection // Type introspection
extern const PrimitivePtr kPrimTypeOf; extern const PrimitivePtr kPrimTypeOf;
...@@ -153,6 +159,10 @@ extern const PrimitivePtr kPrimAddN; ...@@ -153,6 +159,10 @@ extern const PrimitivePtr kPrimAddN;
extern const PrimitivePtr KPrimTransData; extern const PrimitivePtr KPrimTransData;
extern const PrimitivePtr kPrimNMSWithMask; extern const PrimitivePtr kPrimNMSWithMask;
extern const PrimitivePtr kPrimPad; extern const PrimitivePtr kPrimPad;
extern const PrimitivePtr kPrimRealDiv;
extern const PrimitivePtr kPrimSqrt;
extern const PrimitivePtr kPrimReciprocal;
extern const PrimitivePtr kPrimExpandDims;
// Maths // Maths
extern const PrimitivePtr kPrimTensorAdd; extern const PrimitivePtr kPrimTensorAdd;
...@@ -176,9 +186,11 @@ extern const PrimitivePtr kPrimLess; ...@@ -176,9 +186,11 @@ extern const PrimitivePtr kPrimLess;
extern const PrimitivePtr kPrimLessEqual; extern const PrimitivePtr kPrimLessEqual;
extern const PrimitivePtr kPrimCumSum; extern const PrimitivePtr kPrimCumSum;
extern const PrimitivePtr kPrimCumProd; extern const PrimitivePtr kPrimCumProd;
extern const PrimitivePtr kPrimPow;
// NN // NN
extern const PrimitivePtr kPrimFlatten; extern const PrimitivePtr kPrimFlatten;
extern const PrimitivePtr kPrimSoftmax;
extern const PrimitivePtr kPrimLogSoftmax; extern const PrimitivePtr kPrimLogSoftmax;
extern const PrimitivePtr kPrimLogSoftmaxGrad; extern const PrimitivePtr kPrimLogSoftmaxGrad;
extern const PrimitivePtr kPrimTanh; extern const PrimitivePtr kPrimTanh;
...@@ -217,6 +229,7 @@ extern const PrimitivePtr kPrimGeluGrad; ...@@ -217,6 +229,7 @@ extern const PrimitivePtr kPrimGeluGrad;
extern const PrimitivePtr kPrimRelu; extern const PrimitivePtr kPrimRelu;
extern const PrimitivePtr kPrimReluV2; extern const PrimitivePtr kPrimReluV2;
extern const PrimitivePtr kPrimActivation; extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLike;
extern const PrimitivePtr kPrimZerosLikeTensor; extern const PrimitivePtr kPrimZerosLikeTensor;
extern const PrimitivePtr kPrimFakeBprop; extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut; extern const PrimitivePtr kPrimBpropCut;
...@@ -251,6 +264,7 @@ extern const PrimitivePtr kPrimIs_; ...@@ -251,6 +264,7 @@ extern const PrimitivePtr kPrimIs_;
extern const PrimitivePtr kPrimIsNot; extern const PrimitivePtr kPrimIsNot;
extern const PrimitivePtr kPrimInDict; extern const PrimitivePtr kPrimInDict;
extern const PrimitivePtr kPrimNotInDict; extern const PrimitivePtr kPrimNotInDict;
extern const PrimitivePtr kPrimEquivFormat;
// Comm ops // Comm ops
extern const PrimitivePtr kPrimAllReduce; extern const PrimitivePtr kPrimAllReduce;
......
...@@ -110,7 +110,7 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, ...@@ -110,7 +110,7 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
ValuePtr v = cond->GetValueTrack(); ValuePtr v = cond->GetValueTrack();
MS_EXCEPTION_IF_NULL(v); MS_EXCEPTION_IF_NULL(v);
if (v->isa<AnyValue>()) { if (v->isa<AnyValue>() || cond->isa<AbstractTensor>()) {
MS_EXCEPTION_IF_NULL(tb); MS_EXCEPTION_IF_NULL(tb);
return tb->Join(fb); return tb->Join(fb);
} }
......
...@@ -45,10 +45,19 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas ...@@ -45,10 +45,19 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
k_graph_ = std::make_shared<FuncGraph>(); k_graph_ = std::make_shared<FuncGraph>();
if (primal_graph->has_attr(FUNC_GRAPH_FLAG_COMPOSITE)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_FLAG_COMPOSITE));
k_graph_->set_attr(FUNC_GRAPH_FLAG_COMPOSITE, MakeValue(grad_op_name));
}
TraceManager::EndTrace(); TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
tape_ = std::make_shared<FuncGraph>(); tape_ = std::make_shared<FuncGraph>();
// Add "_Grad" postfix
if (primal_graph->has_attr(FUNC_GRAPH_FLAG_COMPOSITE)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_FLAG_COMPOSITE)) + "_Grad";
tape_->set_attr(FUNC_GRAPH_FLAG_COMPOSITE, MakeValue(grad_op_name));
}
TraceManager::EndTrace(); TraceManager::EndTrace();
dout_ = tape_->add_parameter(); dout_ = tape_->add_parameter();
...@@ -368,7 +377,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { ...@@ -368,7 +377,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
(void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg))); (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
(void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal))); (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
// Reset defer_inline to enable successive inlining // Reset defer_inline to enable successive inlining
primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
auto functor = std::make_shared<DFunctor>(primal, resources_); auto functor = std::make_shared<DFunctor>(primal, resources_);
functor->Init(functor); functor->Init(functor);
......
...@@ -37,7 +37,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt ...@@ -37,7 +37,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
if (MsContext::GetInstance()->is_multi_graph_sink()) { if (MsContext::GetInstance()->is_multi_graph_sink()) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
f->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
} }
} }
}; };
......
...@@ -77,7 +77,10 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { ...@@ -77,7 +77,10 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(cons); MS_EXCEPTION_IF_NULL(cons);
auto dt = data->abstract(); auto dt = data->abstract();
MS_EXCEPTION_IF_NULL(dt); if (dt == nullptr) {
return nullptr;
}
if (!dt->isa<AbstractClass>()) { if (!dt->isa<AbstractClass>()) {
MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << "."; MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << ".";
} }
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "optimizer/composite_op_reuse.h"
#include <vector>
#include <algorithm>
#include <string>
#include "./common.h"
#include "utils/graph_utils.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
// Report whether nodes `a` and `b` carry matching inferred type and shape.
//
// Returns false when either node has no abstract, when the two type tracks differ,
// or when `a` has no shape track. Otherwise the shapes match if they are the very
// same shape object, or if both are abstract::Shape with equal dimension lists.
//
// NOTE(review): the type tracks are compared with `!=` on the pointers returned by
// GetTypeTrack(), i.e. by identity rather than by value -- confirm this is intended.
bool CompositeReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) {
  const auto lhs_abs = a->abstract();
  const auto rhs_abs = b->abstract();
  // Without inferred abstracts on both sides nothing can be proven equal.
  if (!lhs_abs || !rhs_abs) {
    return false;
  }
  if (lhs_abs->GetTypeTrack() != rhs_abs->GetTypeTrack()) {
    return false;
  }

  const auto lhs_shape = lhs_abs->GetShapeTrack();
  const auto rhs_shape = rhs_abs->GetShapeTrack();
  if (lhs_shape == nullptr) {
    return false;
  }
  // Fast path: both nodes share one shape object.
  if (lhs_shape == rhs_shape) {
    return true;
  }
  // Slow path: only plain abstract::Shape objects can be compared dimension by dimension.
  if (rhs_shape == nullptr || !lhs_shape->isa<abstract::Shape>() || !rhs_shape->isa<abstract::Shape>()) {
    return false;
  }
  return lhs_shape->cast<abstract::ShapePtr>()->shape() == rhs_shape->cast<abstract::ShapePtr>()->shape();
}
// Walk every func graph known to `manager` and, for each graph tagged with
// FUNC_GRAPH_FLAG_COMPOSITE, try to substitute an already recorded graph that has
// the same composite key and is considered equivalent.
//
// Two graphs are considered equivalent when:
//   * their topological orders (from the return node) have the same length,
//   * constant tensors sit at the same positions and hold equal values,
//   * they have the same number of parameters and each pair passes CompareNode,
//   * their outputs pass CompareNode.
//
// When an equivalent graph is found, every user CNode of the current graph is
// rebuilt to reference the recorded graph instead. Otherwise the current graph is
// recorded under its key for later comparisons.
//
// Returns true if at least one user node was replaced.
bool CompositeReuse::DoReplace(const FuncGraphManagerPtr manager) {
  bool changed = false;
  auto fgs = manager->func_graphs();
  for (FuncGraphPtr &fg : fgs) {
    if (!fg->has_attr(FUNC_GRAPH_FLAG_COMPOSITE)) {
      continue;
    }
    std::string key = GetValue<std::string>(fg->get_attr(FUNC_GRAPH_FLAG_COMPOSITE));

    auto entry = composite_ops.find(key);
    if (entry == composite_ops.end()) {
      // First graph seen for this key: just remember it.
      composite_ops[key] = {fg};
      continue;
    }
    auto &candidates = entry->second;
    if (std::find(candidates.begin(), candidates.end(), fg) != candidates.end()) {
      // Already recorded, nothing to compare against itself.
      continue;
    }

    // These only depend on `fg`, so compute them once instead of once per candidate.
    const auto fg_topos = TopoSort(fg->get_return());
    const auto fg_input = fg->parameters();

    FuncGraphPtr new_fg = nullptr;
    for (auto &cfg : candidates) {
      // If two graphs have different size then continue
      auto cfg_topos = TopoSort(cfg->get_return());
      if (fg_topos.size() != cfg_topos.size()) {
        continue;
      }
      // Compare const tensor. The check is symmetric: a constant tensor on either
      // side must be matched by an equal constant tensor at the same position.
      bool has_same = true;
      for (size_t i = 0; i < fg_topos.size(); ++i) {
        const bool fg_is_tensor = IsValueNode<tensor::Tensor>(fg_topos[i]);
        const bool cfg_is_tensor = IsValueNode<tensor::Tensor>(cfg_topos[i]);
        if (fg_is_tensor != cfg_is_tensor) {
          has_same = false;
          break;
        }
        if (!fg_is_tensor) {
          continue;
        }
        auto tensor1 = GetValueNode<tensor::TensorPtr>(fg_topos[i]);
        auto tensor2 = GetValueNode<tensor::TensorPtr>(cfg_topos[i]);
        if (!tensor1->ValueEqual(*tensor2)) {
          has_same = false;
          break;
        }
      }
      if (!has_same) {
        continue;
      }
      auto cfg_input = cfg->parameters();
      if (fg_input.size() != cfg_input.size()) {
        continue;
      }
      // Compare input
      for (size_t i = 0; i < fg_input.size(); ++i) {
        if (!CompareNode(fg_input[i], cfg_input[i])) {
          has_same = false;
          break;
        }
      }
      if (!has_same) {
        continue;
      }
      // Compare output
      if (!CompareNode(fg->output(), cfg->output())) {
        continue;
      }
      // Find reusable fg
      new_fg = cfg;
      break;
    }

    if (new_fg == nullptr) {
      // Add current fg to map
      candidates.push_back(fg);
      continue;
    }

    // Replace current fg with existing fg
    auto users = fg->func_graph_cnodes_index();
    for (auto &iter : users) {
      auto cnode = iter.first->first->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      auto new_input = cnode->inputs();
      auto main_graph = cnode->func_graph();
      MS_EXCEPTION_IF_NULL(main_graph);
      // A Partial node keeps the graph in input 1; any other user is treated as a
      // direct call with the graph in input 0.
      // NOTE(review): a user holding the graph at another input position would be
      // rewritten incorrectly here -- confirm such users cannot occur.
      if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
        new_input[1] = NewValueNode(new_fg);
      } else {
        new_input[0] = NewValueNode(new_fg);
      }
      auto new_cnode = main_graph->NewCNode(new_input);
      manager->Replace(iter.first->first, new_cnode);
      changed = true;
    }
  }
  return changed;
}
// Entry point of the pass: register `root` with `manager`, then run DoReplace over
// the manager's func graphs. Raises through MS_EXCEPTION_IF_NULL when `manager` is
// null. Returns whatever DoReplace reports (true if the graph was modified).
bool CompositeReuse::ReuseComposite(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(root);
  const bool graph_changed = DoReplace(manager);
  return graph_changed;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_COMPOSITE_OP_REUSE_H
#define MINDSPORE_CCSRC_OPTIMIZER_COMPOSITE_OP_REUSE_H
#include <mindspore/ccsrc/session/anf_runtime_algorithm.h>
#include <unordered_map>
#include <string>
#include <vector>
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
// Composite-op reuse pass.
// Func graphs tagged with FUNC_GRAPH_FLAG_COMPOSITE are grouped by the string value
// of that attribute; when a graph is found equivalent to one already recorded under
// the same key, its users are redirected to the recorded graph.
class CompositeReuse {
 public:
  CompositeReuse() : count(0) {}
  virtual ~CompositeReuse() = default;

  // Optimizer hook: runs the pass on `root` with the manager owned by the
  // optimizer's resource. Returns true if any node was replaced.
  bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
    return ReuseComposite(root, optimizer->resource()->manager());
  }

  // True when both nodes carry matching inferred type and shape.
  bool CompareNode(const AnfNodePtr a, const AnfNodePtr b);
  // Scans the manager's func graphs and replaces duplicated composite graphs.
  bool DoReplace(const FuncGraphManagerPtr manager);
  // Adds `root` to `manager` and then calls DoReplace.
  bool ReuseComposite(const FuncGraphPtr root, const FuncGraphManagerPtr manager);

 private:
  // Composite key -> distinct graphs recorded so far for that key.
  std::unordered_map<std::string, std::vector<FuncGraphPtr>> composite_ops;
  // NOTE(review): initialised to 0 and not read anywhere in this header -- confirm whether it is still needed.
  int count;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_COMPOSITE_OP_REUSE_H
...@@ -41,6 +41,8 @@ ...@@ -41,6 +41,8 @@
#include "optimizer/irpass/incorporate_call.h" #include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/grad_var_prepare.h" #include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/param_replace.h" #include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/opt.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -48,7 +50,7 @@ namespace irpass { ...@@ -48,7 +50,7 @@ namespace irpass {
OptimizeIRPassLib::OptimizeIRPassLib() { OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul}); prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
special_op_eliminate_ = special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType, {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
...@@ -88,7 +90,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { ...@@ -88,7 +90,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
replace_refkey_by_param_ = replace_refkey_by_param_ =
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
get_ref_value_eliminate_ =
MakeSubstitution(GetRefValueEliminater(), "get_ref_value_eliminate", prim::kPrimGetRefValue);
// Gradient transforms // Gradient transforms
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
stop_gradient_eliminate_ = stop_gradient_eliminate_ =
...@@ -114,6 +117,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { ...@@ -114,6 +117,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Incorporation // Incorporation
incorporate_getitem_ = MakeSubstitution(IncorporateGetitem(), "incorporate_getitem", prim::kPrimTupleGetItem); incorporate_getitem_ = MakeSubstitution(IncorporateGetitem(), "incorporate_getitem", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ =
MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeComposite);
incorporate_getitem_switch_ = incorporate_getitem_switch_ =
MakeSubstitution(IncorporateGetitemSwitch(), "incorporate_getitem_switch", prim::kPrimTupleGetItem); MakeSubstitution(IncorporateGetitemSwitch(), "incorporate_getitem_switch", prim::kPrimTupleGetItem);
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
...@@ -125,6 +130,17 @@ OptimizeIRPassLib::OptimizeIRPassLib() { ...@@ -125,6 +130,17 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Convert // Convert
print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint);
// Unused parameter eliminate
unused_parameter_eliminate_ =
MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeComposite);
unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeComposite);
// AddN eliminate
addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeComposite);
// Mark interface fusion
mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect);
} }
ResolveIRPassLib::ResolveIRPassLib() { ResolveIRPassLib::ResolveIRPassLib() {
......
...@@ -61,6 +61,7 @@ class OptimizeIRPassLib { ...@@ -61,6 +61,7 @@ class OptimizeIRPassLib {
SubstitutionPtr get_make_ref_eliminate_; SubstitutionPtr get_make_ref_eliminate_;
SubstitutionPtr replace_refkey_by_param_; SubstitutionPtr replace_refkey_by_param_;
SubstitutionPtr replace_old_param_; SubstitutionPtr replace_old_param_;
SubstitutionPtr get_ref_value_eliminate_;
// Branch culling // Branch culling
SubstitutionPtr switch_simplify_; SubstitutionPtr switch_simplify_;
...@@ -84,6 +85,7 @@ class OptimizeIRPassLib { ...@@ -84,6 +85,7 @@ class OptimizeIRPassLib {
// Incorporation // Incorporation
SubstitutionPtr incorporate_getitem_; SubstitutionPtr incorporate_getitem_;
SubstitutionPtr incorporate_getitem_from_param_;
SubstitutionPtr incorporate_getitem_switch_; SubstitutionPtr incorporate_getitem_switch_;
SubstitutionPtr incorporate_call_; SubstitutionPtr incorporate_call_;
SubstitutionPtr incorporate_call_switch_; SubstitutionPtr incorporate_call_switch_;
...@@ -93,6 +95,16 @@ class OptimizeIRPassLib { ...@@ -93,6 +95,16 @@ class OptimizeIRPassLib {
// Convert // Convert
SubstitutionPtr print_tuple_wrapper_; SubstitutionPtr print_tuple_wrapper_;
// Unused parameter eliminate
SubstitutionPtr unused_parameter_eliminate_;
SubstitutionPtr unused_output_eliminate_;
// AddN eliminate
SubstitutionPtr addn_eliminate_;
// Fusion
SubstitutionPtr mark_interface_fusion_;
}; };
// the collection of irpass for resolve action // the collection of irpass for resolve action
...@@ -149,6 +161,23 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) { ...@@ -149,6 +161,23 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
return false; return false;
} }
// Predicate: `node` is a CNode whose input 0 is a ValueNode holding a FuncGraph
// that carries the FUNC_GRAPH_FLAG_COMPOSITE attribute.
// Returns false for a null node, a non-CNode, a non-graph callee or a null graph.
inline bool IsCNodeComposite(const AnfNodePtr &node) {
  if (node == nullptr || !node->isa<CNode>()) {
    return false;
  }
  auto callee = node->cast<CNodePtr>()->input(0);
  if (!IsValueNode<FuncGraph>(callee)) {
    return false;
  }
  auto graph = GetValueNode<FuncGraphPtr>(callee);
  return graph != nullptr && graph->has_attr(FUNC_GRAPH_FLAG_COMPOSITE);
}
// Check if CNode Input 0 is CNode // Check if CNode Input 0 is CNode
inline bool IsCNodeDup(const AnfNodePtr &node) { inline bool IsCNodeDup(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) { if (node == nullptr || !node->isa<CNode>()) {
......
...@@ -83,6 +83,216 @@ class MultiplyByZeroOrOne : public AnfVisitor { ...@@ -83,6 +83,216 @@ class MultiplyByZeroOrOne : public AnfVisitor {
AnfNodePtr x_{nullptr}; AnfNodePtr x_{nullptr};
}; };
// Support class used for checking if all values of a Tensor are equal to `check_value_`.
// Supported data types: double (float64), float/float32, int/int32.
// Floating point elements are compared with a tolerance of FLT_EPSILON / DBL_EPSILON,
// integer elements are compared exactly. Any other element type yields false.
class CheckTensorConstant {
 public:
  explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {}
  ~CheckTensorConstant() = default;

  // Returns true only if `value` is a tensor of a supported type and EVERY element
  // equals `check_value_`. An empty tensor is reported as not constant.
  bool IsTensorConstant(const ValuePtr &value) {
    if (!value->isa<tensor::Tensor>()) {
      return false;
    }
    auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
    TypeId tensor_type = tensor_ptr->Dtype()->type_id();
    const auto data_size = tensor_ptr->DataSize();

    if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
      float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
      for (int i = 0; i < data_size; i++) {
        if (fabs(data2[i] - check_value_) > FLT_EPSILON) {
          return false;
        }
      }
      // Only reached once the whole buffer has been inspected.
      return data_size > 0;
    }
    if (tensor_type == TypeId::kNumberTypeFloat64) {
      double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
      for (int i = 0; i < data_size; i++) {
        if (fabs(data2[i] - check_value_) > DBL_EPSILON) {
          return false;
        }
      }
      return data_size > 0;
    }
    if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) {
      int *data2 = reinterpret_cast<int *>(tensor_ptr->data_c());
      for (int i = 0; i < data_size; i++) {
        if (data2[i] != check_value_) {
          return false;
        }
      }
      return data_size > 0;
    }
    // Un-support Data Types
    return false;
  }

  // Same as IsTensorConstant, but additionally requires the tensor to hold at most
  // one element and to have zero dimensions.
  bool IsTensorScalarConstant(const ValuePtr &value) {
    if (!value->isa<tensor::Tensor>()) {
      return false;
    }
    auto tensor_ptr = dyn_cast<tensor::Tensor>(value);
    if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) {
      return false;
    }
    return IsTensorConstant(value);
  }

 private:
  // Value every element is compared against.
  int check_value_;
};
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class TensorMultiplyByZeroOrOne : public AnfVisitor {
public:
TensorMultiplyByZeroOrOne() : zero_(MakeValue(0)) {}
~TensorMultiplyByZeroOrOne() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimMul)(node);
if (is_zero_) {
if (x_->func_graph() != node->func_graph()) {
return nullptr;
}
return NewTensorFilledWithData(node);
}
if (is_one_) {
return NewTensorFilledWithData(node, x_);
}
return nullptr;
}
void Visit(const AnfNodePtr &node) override {
if (is_zero_ || is_one_) {
x_ = node;
return;
}
if (IsParam(node)) {
x_ = node;
return;
}
if (IsCNode(node)) {
CNodePtr cnode = node->cast<CNodePtr>();
if (IsPrimitive(cnode->input(0), prim::kPrimZerosLikeTensor)) {
is_zero_ = true;
return;
}
x_ = node;
return;
}
auto value = node->cast<ValueNodePtr>()->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
} else if (CheckTensorConstant(1).IsTensorConstant(value)) {
is_one_ = true;
return;
}
x_ = node;
}
void Visit(const ValueNodePtr &vnode) override {
auto value = vnode->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
} else if (CheckTensorConstant(1).IsTensorConstant(value)) {
is_one_ = true;
return;
}
x_ = vnode;
}
void Reset() {
x_ = nullptr;
is_one_ = false;
is_zero_ = false;
}
void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) {
if (!node->isa<ValueNode>()) {
return nullptr;
}
auto value = node->cast<ValueNodePtr>()->value();
if (!value->isa<tensor::Tensor>()) {
return nullptr;
}
tensor::TensorPtr tensor_ptr = dyn_cast<tensor::Tensor>(value);
return tensor_ptr->data_c(writable);
}
// Make a new tensor (when possible) with the same shape as of `node`
// If x is nullptr then fill new tensor will "0"
// If x is a tensor with empty shape then fill new tensor with the single value of x
// If x is a tensor with same shape as `node` then return x as result
AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) {
if ((node->abstract() == nullptr) || !node->abstract()->isa<abstract::AbstractTensor>()) {
return nullptr;
}
auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
if (x == nullptr) {
std::memset(data, 0, mem_size);
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
}
// x is not nullptr
if (x->isa<CNode>()) {
if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
return nullptr;
}
auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
std::vector<int> x_shape = x_abstract->shape()->shape();
if (x_shape != tensor_shape) {
return nullptr;
}
return x;
}
if (!x->isa<ValueNode>()) {
return nullptr;
}
auto x_value = x->cast<ValueNodePtr>()->value();
if (!x_value->isa<tensor::Tensor>()) {
return nullptr;
}
auto x_tensor_ptr = dyn_cast<tensor::Tensor>(x_value);
if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
return nullptr;
}
char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x));
if (x_tensor_ptr->DataSize() == 1) {
for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
memcpy(source_data, data + i * GetTypeByte(tensor_type_ptr), GetTypeByte(tensor_type_ptr));
}
} else {
memcpy(source_data, data, mem_size);
}
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
}
private:
bool is_zero_{false}, is_one_{false};
ValuePtr zero_;
AnfNodePtr x_{nullptr};
};
// {prim::kPrimScalarAdd, X, 0} // {prim::kPrimScalarAdd, X, 0}
// {prim::kPrimScalarAdd, 0, X} // {prim::kPrimScalarAdd, 0, X}
class AddByZero : public AnfVisitor { class AddByZero : public AnfVisitor {
...@@ -101,7 +311,8 @@ class AddByZero : public AnfVisitor { ...@@ -101,7 +311,8 @@ class AddByZero : public AnfVisitor {
} }
void Visit(const AnfNodePtr &node) override { void Visit(const AnfNodePtr &node) override {
if (node->isa<ValueNode>() && *GetValueNode(node) == *zero_) { if (node->isa<ValueNode>() &&
((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) {
is_zero_ = true; is_zero_ = true;
return; return;
} }
...@@ -139,10 +350,22 @@ class TensorAddByZero : public AnfVisitor { ...@@ -139,10 +350,22 @@ class TensorAddByZero : public AnfVisitor {
is_zero_ = true; is_zero_ = true;
return; return;
} }
if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) {
is_zero_ = true;
return;
}
x_ = node; x_ = node;
} }
void Visit(const ValueNodePtr &vnode) override {
auto value = vnode->value();
if (CheckTensorConstant(0).IsTensorConstant(value)) {
is_zero_ = true;
return;
}
}
void Reset() { void Reset() {
x_ = nullptr; x_ = nullptr;
is_zero_ = false; is_zero_ = false;
...@@ -183,29 +406,143 @@ class OptUpdateZeroTensor : public AnfVisitor { ...@@ -183,29 +406,143 @@ class OptUpdateZeroTensor : public AnfVisitor {
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} // {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
class ConstantDuplicateMul : public AnfVisitor { class ConstantDuplicateMul : public AnfVisitor {
public: public:
// Element-wise product of two raw constant buffers holding elements of type T.
// An input whose size is 1 is broadcast over the whole output; any other input is
// read at indices [0, out_data_size), so the caller must supply that many elements.
// The result is allocated with `new T[out_data_size]` and handed back through
// `*out_data`.
template <typename T>
void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
              int out_data_size) {
  const T *lhs = reinterpret_cast<T *>(in_data_1);
  const T *rhs = reinterpret_cast<T *>(in_data_2);
  T *product = new T[out_data_size];

  const bool lhs_is_scalar = (in_data_1_size == 1);
  const bool rhs_is_scalar = (in_data_2_size == 1);
  for (int idx = 0; idx < out_data_size; ++idx) {
    // Seed with the first operand, then scale by the second one.
    product[idx] = lhs_is_scalar ? lhs[0] : lhs[idx];
    product[idx] *= rhs_is_scalar ? rhs[0] : rhs[idx];
  }

  *out_data = reinterpret_cast<void *>(product);
}
// Folds the product of two constant tensor value nodes into a single new constant value node.
//
// vnode_1, vnode_2: the candidate constant operands.
// node_3: the remaining operand; its abstract supplies the element type and the shape of the folded constant.
//
// Returns the new ValueNode, or nullptr when folding is not possible:
//   - an operand is not a ValueNode holding a tensor, or an abstract is missing / is not a tensor abstract;
//   - the element types of the three nodes differ;
//   - an operand has more than one element but not exactly as many as the output;
//   - the element type is not one of Float32/Float, Float64, Int32/Int.
AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) {
  if (!vnode_1->isa<ValueNode>() || !vnode_2->isa<ValueNode>() || (vnode_1->abstract() == nullptr) ||
      (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) {
    return nullptr;
  }
  auto value_1 = GetValueNode(vnode_1);
  auto value_2 = GetValueNode(vnode_2);
  if (!value_1->isa<tensor::Tensor>() || !value_2->isa<tensor::Tensor>()) {
    return nullptr;
  }
  auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
  auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);

  auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
  // Each operand must be described by its own abstract (the second one used to be read from vnode_1).
  auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>();
  auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
  // An abstract that is not a tensor abstract cannot be folded.
  if ((tensor_1_abstract == nullptr) || (tensor_2_abstract == nullptr) || (tensor_3_abstract == nullptr)) {
    return nullptr;
  }

  TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
  TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
  TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
  if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
      (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
    return nullptr;
  }

  std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape();
  int data_out_size = 1;
  for (auto dim : tensor_out_shape) {
    if (dim < 0) {
      // A negative dimension would lead to a negative allocation size below.
      return nullptr;
    }
    data_out_size *= dim;
  }
  // Only "single element" or "exactly the output size" operands are supported by Multiply().
  if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
    return nullptr;
  }
  if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
    return nullptr;
  }

  const auto out_type_id = tensor_3_type_ptr->type_id();
  const bool is_float32 = (out_type_id == TypeId::kNumberTypeFloat32) || (out_type_id == TypeId::kNumberTypeFloat);
  const bool is_float64 = (out_type_id == TypeId::kNumberTypeFloat64);
  const bool is_int32 = (out_type_id == TypeId::kNumberTypeInt32) || (out_type_id == TypeId::kNumberTypeInt);
  if (!is_float32 && !is_float64 && !is_int32) {
    // Un-support data types
    return nullptr;
  }

  void *data_out = nullptr;
  if (is_float32) {
    Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
                    tensor_ptr_2->DataSize(), &data_out, data_out_size);
  } else if (is_float64) {
    Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
                     tensor_ptr_2->DataSize(), &data_out, data_out_size);
  } else {
    Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
                  tensor_ptr_2->DataSize(), &data_out, data_out_size);
  }

  auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
  size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
  char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
  memcpy(data, data_out, mem_size);

  // Multiply<T>() returns a `new T[]` buffer; release it with the matching element type now that the values
  // have been copied into the tensor.
  if (is_float32) {
    delete[] reinterpret_cast<float *>(data_out);
  } else if (is_float64) {
    delete[] reinterpret_cast<double *>(data_out);
  } else {
    delete[] reinterpret_cast<int *>(data_out);
  }

  auto new_vnode = NewValueNode(new_tensor_ptr);
  new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
  return new_vnode;
}
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
// {prim::kPrimMul, Tensor1, {...}} // {prim::kPrimMul, Tensor1, {...}}
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
if (vnode_ == nullptr || cnode_ == nullptr) { if (vnode_ == nullptr || c_p_node_ == nullptr) {
return nullptr;
}
if (!IsCNode(c_p_node_)) {
return nullptr; return nullptr;
} }
auto tensor1 = vnode_; auto tensor1 = vnode_;
auto mul = cnode_; auto mul = c_p_node_->cast<CNodePtr>();
Reset(); Reset();
// {prim::kPrimMul, Tensor2, {...}} // {prim::kPrimMul, Tensor2, {...}}
AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
if (vnode_ == nullptr || cnode_ == nullptr) { if (vnode_ == nullptr || c_p_node_ == nullptr) {
return nullptr; return nullptr;
} }
auto tensor2 = vnode_; auto tensor2 = vnode_;
auto cnode = cnode_; auto c_p_node = c_p_node_;
auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0)); auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
auto fg = node->func_graph(); auto fg = node->func_graph();
auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
return NewCNode({NewValueNode(PrimMul), cnode, ttmul}, fg); auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node);
if (new_mul_tensor == nullptr) {
auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg);
}
return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg);
} }
void Visit(const AnfNodePtr &node) override { void Visit(const AnfNodePtr &node) override {
...@@ -213,19 +550,40 @@ class ConstantDuplicateMul : public AnfVisitor { ...@@ -213,19 +550,40 @@ class ConstantDuplicateMul : public AnfVisitor {
vnode_ = node; vnode_ = node;
} }
if (IsCNode(node)) { if (IsCNode(node) || IsParam(node)) {
cnode_ = node->cast<CNodePtr>(); c_p_node_ = node;
} }
} }
void Reset() { void Reset() {
vnode_ = nullptr; vnode_ = nullptr;
cnode_ = nullptr; c_p_node_ = nullptr;
} }
private: private:
AnfNodePtr vnode_; AnfNodePtr vnode_;
CNodePtr cnode_; AnfNodePtr c_p_node_;
};
// {prim::kPrimPow, X, 1} -> X
// Replaces a Pow node whose exponent is the scalar constant 1 (float or integer immediate) with its base.
// Returns nullptr (no rewrite) for every other node.
class PowerOneEliminate : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto &inputs = node->cast<CNodePtr>()->inputs();
    // inputs is {Pow, base, exponent}; make sure the exponent slot exists before indexing it.
    if (inputs.size() < 3 || !IsValueNode<Scalar>(inputs[2])) {
      return nullptr;
    }
    auto scalar = GetValueNode<ScalarPtr>(inputs[2]);
    if (scalar->isa<FloatImm>() && GetValue<float>(scalar) == 1.0) {
      return inputs[1];
    }
    if (scalar->isa<IntergerImm>() && GetValue<int>(scalar) == 1) {
      return inputs[1];
    }
    return nullptr;
  }
};
// grad = AllReduce(grad) / worker_number // grad = AllReduce(grad) / worker_number
...@@ -341,17 +699,21 @@ class ArithmeticSimplify { ...@@ -341,17 +699,21 @@ class ArithmeticSimplify {
public: public:
ArithmeticSimplify() ArithmeticSimplify()
: multiply_by_zero_or_one_(), : multiply_by_zero_or_one_(),
tensor_multiply_by_zero_or_one_(),
add_by_zero_(), add_by_zero_(),
tensor_add_by_zero_(), tensor_add_by_zero_(),
identity_(prim::kPrimIdentity), identity_(prim::kPrimIdentity),
opt_update_zero_tensor_(), opt_update_zero_tensor_(),
constant_duplicate_mul_() { constant_duplicate_mul_(),
power_one_() {
eliminaters_.emplace_back(multiply_by_zero_or_one_); eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_zero_or_one_);
eliminaters_.emplace_back(add_by_zero_); eliminaters_.emplace_back(add_by_zero_);
eliminaters_.emplace_back(tensor_add_by_zero_); eliminaters_.emplace_back(tensor_add_by_zero_);
eliminaters_.emplace_back(identity_); eliminaters_.emplace_back(identity_);
eliminaters_.emplace_back(opt_update_zero_tensor_); eliminaters_.emplace_back(opt_update_zero_tensor_);
eliminaters_.emplace_back(constant_duplicate_mul_); eliminaters_.emplace_back(constant_duplicate_mul_);
eliminaters_.emplace_back(power_one_);
} }
~ArithmeticSimplify() = default; ~ArithmeticSimplify() = default;
...@@ -368,11 +730,13 @@ class ArithmeticSimplify { ...@@ -368,11 +730,13 @@ class ArithmeticSimplify {
private: private:
MultiplyByZeroOrOne multiply_by_zero_or_one_; MultiplyByZeroOrOne multiply_by_zero_or_one_;
TensorMultiplyByZeroOrOne tensor_multiply_by_zero_or_one_;
AddByZero add_by_zero_; AddByZero add_by_zero_;
TensorAddByZero tensor_add_by_zero_; TensorAddByZero tensor_add_by_zero_;
PrimEliminater identity_; PrimEliminater identity_;
OptUpdateZeroTensor opt_update_zero_tensor_; OptUpdateZeroTensor opt_update_zero_tensor_;
ConstantDuplicateMul constant_duplicate_mul_; ConstantDuplicateMul constant_duplicate_mul_;
PowerOneEliminate power_one_;
std::vector<TransformFuncType> eliminaters_{}; std::vector<TransformFuncType> eliminaters_{};
}; };
} // namespace irpass } // namespace irpass
......
...@@ -80,7 +80,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { ...@@ -80,7 +80,8 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
using NodeInputReplMap = std::unordered_map<std::pair<AnfNodePtr, size_t>, AnfNodePtr, PairHasher>; using NodeInputReplMap = std::unordered_map<std::pair<AnfNodePtr, size_t>, AnfNodePtr, PairHasher>;
// replace the nodes which should be changed // replace the nodes which should be changed
void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed, void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::pair<CNodePtr, CNodePtr>> nodes_changed,
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs) { std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node, NodeInputReplMap repl_node_inputs,
const FuncGraphPtr &func_graph) {
for (auto &node_pair : nodes_changed) { for (auto &node_pair : nodes_changed) {
CNodePtr old_node = node_pair.first; CNodePtr old_node = node_pair.first;
CNodePtr new_node = node_pair.second; CNodePtr new_node = node_pair.second;
...@@ -99,9 +100,11 @@ void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::p ...@@ -99,9 +100,11 @@ void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector<std::p
} }
for (auto &item : repl_node) { for (auto &item : repl_node) {
if (!manager->Replace(item.first, item.second)) { if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) {
MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString() func_graph->set_output(item.second->cast<CNodePtr>()->input(1));
<< " to new: " << item.second->DebugString(); } else if (!manager->Replace(item.first, item.second)) {
MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2)
<< " to new: " << item.second->DebugString(2);
} }
} }
} }
...@@ -154,7 +157,7 @@ FuncGraphPtr TransformGraphCondBranchNodes( ...@@ -154,7 +157,7 @@ FuncGraphPtr TransformGraphCondBranchNodes(
nodes_changed.emplace_back(node->cast<CNodePtr>(), new_node); nodes_changed.emplace_back(node->cast<CNodePtr>(), new_node);
} }
} }
RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs); RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph);
return graph; return graph;
} }
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <memory> #include <memory>
#include <unordered_set>
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
...@@ -28,7 +29,6 @@ ...@@ -28,7 +29,6 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "operator/ops.h" #include "operator/ops.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
...@@ -81,13 +81,32 @@ class IncorporateGetitem : public AnfVisitor { ...@@ -81,13 +81,32 @@ class IncorporateGetitem : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); Reset();
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node); AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node);
if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr) {
return nullptr;
}
if (node->func_graph() != nullptr && idx_ >= 0 && fg_ != nullptr) { if (fg_->has_attr(FUNC_GRAPH_FLAG_COMPOSITE)) {
auto new_fg = getitem_transform_(fg_, idx_); // If composite has muti output, do not split.
(void)args_.insert(args_.begin(), NewValueNode(new_fg)); // some composite output has EnvInstance node or DeadCode node should split.
return node->func_graph()->NewCNode(args_); auto output = fg_->output();
if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
auto output_cnode = output->cast<CNodePtr>();
auto outputs = output_cnode->inputs();
int real_output_cnt = 0;
for (size_t i = 1; i < outputs.size(); ++i) {
if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) {
real_output_cnt++;
if (real_output_cnt > 1) {
return nullptr;
}
}
}
}
} }
return nullptr;
auto new_fg = getitem_transform_(fg_, idx_);
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(args_);
} }
void Visit(const CNodePtr &cnode) override { void Visit(const CNodePtr &cnode) override {
...@@ -115,6 +134,172 @@ class IncorporateGetitem : public AnfVisitor { ...@@ -115,6 +134,172 @@ class IncorporateGetitem : public AnfVisitor {
internal::GetitemTransform getitem_transform_; internal::GetitemTransform getitem_transform_;
}; };
class IncorporateGetitemFromParam : public AnfVisitor {
public:
void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &param, size_t input_idx) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto &node_users = mng->node_users();
if (node_users.find(param) == node_users.end() || node_users[param].empty()) {
args_.push_back(cnode->input(input_idx + 1));
return;
}
for (auto &user : node_users[param]) {
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
// we do not process this case.
args_.push_back(cnode->input(input_idx + 1));
return;
}
}
// update new args.
if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) {
// case 1
replace_parameters_[input_idx] = true;
need_update_ = true;
auto make_tuple_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs();
inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1;
args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end());
} else {
// case 2
auto prev_cnode = cnode->input(input_idx + 1)->cast<CNodePtr>();
auto prev_fg = GetValueNode<FuncGraphPtr>(prev_cnode->input(0));
auto fg_output = prev_fg->output();
if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) {
MS_LOG(ERROR) << "The return of: " << prev_fg->ToString()
<< " should be a make tuple, but got: " << fg_output->DebugString();
return;
}
replace_parameters_[input_idx] = true;
need_update_ = true;
auto make_tuple_cnode = fg_output->cast<CNodePtr>();
inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1;
for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) {
auto new_getitem =
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToInt(output_i))});
auto aptr = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(SizeToInt(output_i)));
new_getitem->input(2)->set_abstract(aptr);
new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract());
args_.push_back(new_getitem);
}
}
}
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (node->func_graph() == nullptr) {
return nullptr;
}
Reset();
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
return nullptr;
}
auto &inputs = cnode->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg == nullptr) {
return nullptr;
}
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
auto parameters = fg->parameters();
if (parameters.size() != inputs.size() - 1) {
return nullptr;
}
replace_parameters_ = std::vector<bool>(parameters.size(), false);
inputs_num_ = std::vector<size_t>(parameters.size(), 1);
auto node_fg = node->func_graph();
for (size_t i = 1; i < inputs.size(); ++i) {
if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeComposite(inputs[i])) {
Process(node_fg, cnode, parameters[i - 1], i - 1);
} else {
args_.push_back(inputs[i]);
}
}
if (!need_update_) {
return nullptr;
}
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("sp"));
mng->AddFuncGraph(new_fg);
auto node_users = mng->node_users();
std::vector<AnfNodePtr> new_fg_parameters = new_fg->parameters();
std::vector<AnfNodePtr> new_parameters;
size_t curr_input_idx{0};
for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) {
if (!replace_parameters_[param_i]) {
if (parameters[param_i]->abstract() != nullptr) {
new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract());
}
new_parameters.push_back(new_fg_parameters[param_i]);
curr_input_idx++;
continue;
}
// make a new parameter.
for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) {
auto new_param = std::make_shared<Parameter>(new_fg);
new_param->set_abstract(args_.at(curr_input_idx)->abstract());
// update users of new parameter.
for (auto &user : node_users[new_fg_parameters[param_i]]) {
idx_ = -1;
AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode<Int32Imm>})(user.first);
if (idx_ == -1) {
MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString()
<< " must be tuple getitem here, but got: " << user.first->DebugString();
return nullptr;
}
if (input_i == IntToSize(idx_)) {
for (auto &sub_user : node_users[user.first]) {
auto sub_user_cnode = sub_user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sub_user_cnode);
sub_user_cnode->set_input(sub_user.second, new_param);
(void)mng->Replace(sub_user.first, sub_user_cnode);
}
}
}
// (void)mng->Replace(new_fg_parameters[param_i], new_param);
new_parameters.push_back(new_param);
curr_input_idx++;
}
}
mng->SetParameters(new_fg, new_parameters);
(void)args_.insert(args_.begin(), NewValueNode(new_fg));
auto new_call = node_fg->NewCNode(args_);
new_call->set_abstract(node->abstract());
return new_call;
}
void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int>(vnode->value()); }
void Visit(const CNodePtr &cnode) override {}
void Reset() {
replace_parameters_.clear();
args_.clear();
inputs_num_.clear();
need_update_ = false;
idx_ = -1;
}
private:
std::vector<bool> replace_parameters_{};
std::vector<AnfNodePtr> args_{};
std::vector<size_t> inputs_num_{};
bool need_update_{false};
int idx_{-1};
};
// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} // {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C}
class IncorporateGetitemSwitch : public AnfVisitor { class IncorporateGetitemSwitch : public AnfVisitor {
public: public:
......
...@@ -90,20 +90,10 @@ bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { ...@@ -90,20 +90,10 @@ bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { bool IsInside(FuncGraphPtr, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
auto &flags = node->func_graph()->flags(); return node->func_graph()->has_flag("inline_inside");
if (flags.find("inline_inside") != flags.end()) {
return flags["inline_inside"];
}
return false;
} }
bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
auto &flags = fg->flags();
if (flags.find("core") != flags.end()) {
return flags["core"];
}
return false;
}
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
...@@ -127,6 +117,13 @@ class InlinerBase : public AnfVisitor { ...@@ -127,6 +117,13 @@ class InlinerBase : public AnfVisitor {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
return nullptr; return nullptr;
} }
// Do not inline composite op to Cell.
if (fg->has_attr(FUNC_GRAPH_FLAG_COMPOSITE) && !node->func_graph()->has_attr(FUNC_GRAPH_FLAG_COMPOSITE)) {
// If the composite op only contains a return node, we make it inlined.
if (fg->nodes().size() - fg->parameters().size() > 1) {
return nullptr;
}
}
Reset(); Reset();
bool is_match = false; bool is_match = false;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
#include <iomanip>
#include <sstream>
#include <string>
#include <unordered_map>
#include "session/anf_runtime_algorithm.h"
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "utils/graph_utils.h"
#include "operator/composite/composite.h"
namespace mindspore {
namespace opt {
namespace irpass {
// Returns "_" followed by a running number, zero-padded to at least four digits ("_0000", "_0001", ...).
// Each call advances the number by one; the increment is not synchronized.
//
// The function is `inline` and keeps its counter as a function-local static: this header can then be included
// from several translation units without a multiple-definition link error, and all of them share one counter,
// so the generated suffixes do not repeat across translation units.
inline std::string GetFusionNumber() {
  static int count = 0;
  std::stringstream ss;
  ss << std::setw(4) << std::setfill('0') << count;
  ++count;
  return "_" + ss.str();
}
// Mark CNodes which can be merged in kernel build.
// For a Select node inside a composite graph whose condition is one of the comparison primitives listed below,
// sets the "fusion" attribute on the comparison, on the Select itself (with an "_end" suffix) and on every
// ZerosLike input of the Select. The graph is never rewritten: operator() always returns nullptr.
class MarkInterfaceFusion : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimSelect)) {
      return nullptr;
    }
    // Make sure the node belongs to a graph before looking at that graph's attributes.
    auto fg = node->func_graph();
    if (fg == nullptr || !fg->has_attr(FUNC_GRAPH_FLAG_COMPOSITE)) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    auto condition = cnode->input(1);
    if (!IsPrimitiveCNode(condition)) {
      return nullptr;
    }
    // Comparison primitive name -> short tag used in the fusion attribute value. Built once.
    static const std::unordered_map<std::string, std::string> cmp_list = {
      {"GreaterEqual", "GE"}, {"Greater", "GT"}, {"LessEqual", "LE"},
      {"Less", "LT"},         {"Equal", "EQ"},   {"NotEqual", "NE"}};
    auto iter = cmp_list.find(GetCNodeFuncName(condition->cast<CNodePtr>()));
    if (iter == cmp_list.end()) {
      return nullptr;
    }
    // Mark Select and compare node; all nodes of one group share the same numbered name.
    const std::string fusion_name = "Select" + iter->second + GetFusionNumber();
    AnfAlgo::SetNodeAttr("fusion", MakeValue(fusion_name), condition);
    AnfAlgo::SetNodeAttr("fusion", MakeValue(fusion_name + "_end"), node);
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) {
        AnfAlgo::SetNodeAttr("fusion", MakeValue(fusion_name), cnode->input(i));
      }
    }
    return nullptr;
  }

  void Visit(const AnfNodePtr &) override {}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <memory>
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
...@@ -196,6 +197,131 @@ class AddNZeroFilter : public AnfVisitor { ...@@ -196,6 +197,131 @@ class AddNZeroFilter : public AnfVisitor {
std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{}; std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
bool has_zero_like_{false}; bool has_zero_like_{false};
}; };
// {PrimAddN, {kPrimMakeTuple, Xs}}
// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd.
// case0: AddN(inputs)(inputs size < 2) -> error
// case1: AddN(inputs)(all inputs is ValueNode) -> error
// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor)
// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input)
// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
class AddNEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
MS_EXCEPTION_IF_NULL(fg);
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
if (fg->recursive()) {
return nullptr;
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
mng->AddFuncGraph(new_fg);
need_update_ = false;
bool changed = false;
do {
changed = false;
changed |= Process(new_fg);
} while (changed);
if (!need_update_) {
return nullptr;
} else {
auto new_sx = inputs;
new_sx[0] = NewValueNode(new_fg);
return node->func_graph()->NewCNode(new_sx);
}
}
bool Process(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto nodes = TopoSort(func_graph->output());
bool changed = false;
for (size_t i = 0; i < nodes.size(); ++i) {
auto node = nodes[i];
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &tuple_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(tuple_input);
auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_input_cnode);
auto &tuple_inputs = tuple_input_cnode->inputs();
if (tuple_inputs.size() < 3) {
// case0: inputs size < 2, error
MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
}
int valuenode_num =
std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) {
if (IsValueNode<tensor::Tensor>(node)) {
return accumulator + 1;
} else {
return accumulator;
}
});
if (IntToSize(valuenode_num) == tuple_inputs.size()) {
// case1: all inputs is ValueNode, error
MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2);
}
if (tuple_inputs.size() == 3) {
// case2: inputs size = 2, -> TensorAdd(Tensor, Tensor)
MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2);
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
std::vector<AnfNodePtr> new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1],
tuple_inputs[2]};
mng->Replace(node, func_graph->NewCNode(new_xs));
changed = true;
continue;
}
auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(),
[](const AnfNodePtr &node) { return IsValueNode<tensor::Tensor>(node); });
if (first_valuenode == tuple_inputs.end()) {
// no ValueNode input found.
continue;
} else {
// case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
std::vector<AnfNodePtr> make_tuple_new_xs{
NewValueNode(prim::kPrimMakeTuple),
};
std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(),
[&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) {
if (node != *first_valuenode) {
make_tuple_new_xs.push_back(node);
}
});
ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations");
auto new_addn = func_graph->NewCNode(
{func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)});
ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations");
auto new_add =
func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn});
(void)mng->Replace(node, new_add);
changed = true;
continue;
}
}
need_update_ |= changed;
return changed;
}
private:
bool need_update_{false};
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
...@@ -79,7 +79,7 @@ class ReduceOneEliminater : public AnfVisitor { ...@@ -79,7 +79,7 @@ class ReduceOneEliminater : public AnfVisitor {
} }
void Visit(const AnfNodePtr &node) override { void Visit(const AnfNodePtr &node) override {
if (x_ == nullptr) { if (!IsVNode(node) && x_ == nullptr) {
if (IsValueNode<tensor::Tensor>(node)) { if (IsValueNode<tensor::Tensor>(node)) {
is_tensor_ = true; is_tensor_ = true;
} }
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "utils/graph_utils.h"
#include "operator/composite/composite.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -36,6 +38,7 @@ class MakeRefEliminater : public AnfVisitor { ...@@ -36,6 +38,7 @@ class MakeRefEliminater : public AnfVisitor {
this->y_ = node; this->y_ = node;
return true; return true;
}; };
AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node); AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node);
return y_; return y_;
} }
...@@ -75,6 +78,32 @@ class GetMakeRefEliminater : public AnfVisitor { ...@@ -75,6 +78,32 @@ class GetMakeRefEliminater : public AnfVisitor {
} }
}; };
// {prim::kPrimGetRefValue, {X}} -> X
// Only applies when X is a Parameter; for any other input, or any other node, nullptr is returned.
class GetRefValueEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    const bool is_cnode_in_graph = node->isa<CNode>() && node->func_graph() != nullptr;
    if (!is_cnode_in_graph) {
      return nullptr;
    }
    // Forget whatever an earlier invocation captured before matching again.
    y_ = nullptr;
    // Predicate for the single GetRefValue operand: accept Parameters only and remember the accepted one.
    auto capture_parameter = [this](const AnfNodePtr &candidate) -> bool {
      const bool is_parameter = candidate->isa<Parameter>();
      if (is_parameter) {
        this->y_ = candidate;
      }
      return is_parameter;
    };
    AnfVisitor::Match(prim::kPrimGetRefValue, {capture_parameter})(node);
    return y_;
  }

  void Visit(const AnfNodePtr &) override {}

 private:
  // Operand captured by the most recent match, or nullptr.
  AnfNodePtr y_{nullptr};
};
// IsValueNode<RefKey> // IsValueNode<RefKey>
class ReplaceRefkeyByParam : public AnfVisitor { class ReplaceRefkeyByParam : public AnfVisitor {
public: public:
......
...@@ -137,7 +137,7 @@ class ResetDeferInline : public AnfVisitor { ...@@ -137,7 +137,7 @@ class ResetDeferInline : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (IsValueNode<FuncGraph>(node)) { if (IsValueNode<FuncGraph>(node)) {
auto fg = GetValueNode<FuncGraphPtr>(node); auto fg = GetValueNode<FuncGraphPtr>(node);
fg->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
} }
return nullptr; return nullptr;
} }
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
...@@ -41,7 +42,7 @@ class SpecializeTransform { ...@@ -41,7 +42,7 @@ class SpecializeTransform {
~SpecializeTransform() = default; ~SpecializeTransform() = default;
FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args, FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector<FuncGraphPtr> graph_args,
std::vector<PrimitivePtr> prim_args) { std::vector<PrimitivePtr> prim_args, std::vector<tensor::TensorPtr> value_args) {
if (cache_.count(func_graph) == 0) { if (cache_.count(func_graph) == 0) {
cache_[func_graph] = {}; cache_[func_graph] = {};
} }
...@@ -69,6 +70,13 @@ class SpecializeTransform { ...@@ -69,6 +70,13 @@ class SpecializeTransform {
(void)mng->Replace(params[i], arg); (void)mng->Replace(params[i], arg);
continue; continue;
} }
if (value_args[i] != nullptr) {
auto const_tensor = *value_args[i];
auto const_tensor_ptr = std::make_shared<tensor::Tensor>(const_tensor);
AnfNodePtr arg = NewValueNode(const_tensor_ptr);
(void)mng->Replace(params[i], arg);
continue;
}
new_params.push_back(params[i]); new_params.push_back(params[i]);
} }
...@@ -108,6 +116,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { ...@@ -108,6 +116,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
std::vector<FuncGraphPtr> graph_args; std::vector<FuncGraphPtr> graph_args;
std::vector<PrimitivePtr> prim_args; std::vector<PrimitivePtr> prim_args;
std::vector<tensor::TensorPtr> value_node_args;
std::vector<AnfNodePtr> new_xs; std::vector<AnfNodePtr> new_xs;
bool hasVNode = false; bool hasVNode = false;
for (size_t i = 1; i < inputs.size(); i++) { for (size_t i = 1; i < inputs.size(); i++) {
...@@ -115,15 +124,24 @@ class SpecializeOnGraphArguments : public AnfVisitor { ...@@ -115,15 +124,24 @@ class SpecializeOnGraphArguments : public AnfVisitor {
auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]); auto fg_vnode = GetValueNode<FuncGraphPtr>(inputs[i]);
graph_args.push_back(fg_vnode); graph_args.push_back(fg_vnode);
prim_args.emplace_back(nullptr); prim_args.emplace_back(nullptr);
value_node_args.emplace_back(nullptr);
hasVNode = true; hasVNode = true;
} else if (IsValueNode<Primitive>(inputs[i])) { } else if (IsValueNode<Primitive>(inputs[i])) {
auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]); auto p_vnode = GetValueNode<PrimitivePtr>(inputs[i]);
graph_args.emplace_back(nullptr); graph_args.emplace_back(nullptr);
prim_args.push_back(p_vnode); prim_args.push_back(p_vnode);
value_node_args.emplace_back(nullptr);
hasVNode = true;
} else if (IsValueNode<tensor::Tensor>(inputs[i])) {
tensor::TensorPtr t_vnode = GetValueNode<tensor::TensorPtr>(inputs[i]);
graph_args.emplace_back(nullptr);
prim_args.emplace_back(nullptr);
value_node_args.emplace_back(t_vnode);
hasVNode = true; hasVNode = true;
} else { } else {
graph_args.emplace_back(nullptr); graph_args.emplace_back(nullptr);
prim_args.emplace_back(nullptr); prim_args.emplace_back(nullptr);
value_node_args.emplace_back(nullptr);
new_xs.push_back(inputs[i]); new_xs.push_back(inputs[i]);
} }
} }
...@@ -132,7 +150,7 @@ class SpecializeOnGraphArguments : public AnfVisitor { ...@@ -132,7 +150,7 @@ class SpecializeOnGraphArguments : public AnfVisitor {
return nullptr; return nullptr;
} }
auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args); auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args);
(void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg));
return node->func_graph()->NewCNode(new_xs); return node->func_graph()->NewCNode(new_xs);
...@@ -141,6 +159,146 @@ class SpecializeOnGraphArguments : public AnfVisitor { ...@@ -141,6 +159,146 @@ class SpecializeOnGraphArguments : public AnfVisitor {
private: private:
internal::SpecializeTransform specialize_transform_; internal::SpecializeTransform specialize_transform_;
}; };
// Drops the arguments of a graph call whose matching callee parameters have no users.
// Pattern: {G, Xs}
class UnusedParasEliminater : public AnfVisitor {
 public:
  // Returns a call to a clone of G that only declares the used parameters, or nullptr when `node` is not a
  // CNode owned by a graph, when the argument count differs from G's parameter count, or when every
  // parameter of G has at least one user.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    const bool is_bound_cnode = node->isa<CNode>() && node->func_graph() != nullptr;
    if (!is_bound_cnode) {
      return nullptr;
    }
    auto call = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(call);
    auto &call_inputs = call->inputs();
    auto callee = GetValueNode<FuncGraphPtr>(call_inputs[0]);
    MS_EXCEPTION_IF_NULL(callee);

    std::vector<AnfNodePtr> callee_params = callee->parameters();
    const size_t param_num = callee_params.size();
    // Input 0 is the callee itself, so the call carries inputs.size() - 1 arguments.
    if (call_inputs.size() - 1 != param_num) {
      return nullptr;
    }

    auto manager = callee->manager();
    MS_EXCEPTION_IF_NULL(manager);
    auto &users_of = manager->node_users();

    // First pass: record which parameters are used and keep the arguments that feed them.
    std::vector<AnfNodePtr> kept_args;
    std::vector<bool> is_used;
    bool all_used = true;
    for (size_t idx = 0; idx < param_num; ++idx) {
      auto found = users_of.find(callee_params[idx]);
      const bool used = (found != users_of.end()) && !found->second.empty();
      is_used.push_back(used);
      if (used) {
        kept_args.push_back(call_inputs[idx + 1]);
      } else {
        all_used = false;
      }
    }
    if (all_used) {
      return nullptr;
    }

    // Second pass: clone the callee and keep only the parameters marked as used.
    FuncGraphPtr slim_fg = TransformableClone(callee, std::make_shared<TraceTransform>("sp"));
    manager->AddFuncGraph(slim_fg);
    std::vector<AnfNodePtr> cloned_params = slim_fg->parameters();
    std::vector<AnfNodePtr> kept_params;
    for (size_t idx = 0; idx < param_num; ++idx) {
      if (!is_used[idx]) {
        continue;
      }
      auto source_abstract = callee_params[idx]->abstract();
      if (source_abstract != nullptr) {
        cloned_params[idx]->set_abstract(source_abstract);
      }
      kept_params.push_back(cloned_params[idx]);
    }
    manager->SetParameters(slim_fg, kept_params);

    (void)kept_args.insert(kept_args.begin(), NewValueNode(slim_fg));
    return node->func_graph()->NewCNode(kept_args);
  }
};
// Eliminate unused outputs.
// {G, Xs}
//
// Applies when G's output is a MakeTuple, every user of the call is a TupleGetItem, and only a strict,
// non-empty subset of the tuple elements is read. G is then cloned, the clone's output is narrowed to the
// elements that are read, and the TupleGetItem users are updated to match.
class UnusedOutputEliminater : public AnfVisitor {
 public:
  // Returns a call to the narrowed clone of G, or nullptr when the pattern does not apply.
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto &inputs = node->cast<CNodePtr>()->inputs();
    auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
    MS_EXCEPTION_IF_NULL(fg);
    auto mng = fg->manager();
    MS_EXCEPTION_IF_NULL(mng);
    if (fg->recursive()) {
      return nullptr;
    }

    // Run every read-only applicability check against the original graph first, so that a clone is only
    // created (and registered with the manager) when the rewrite is actually going to happen.
    auto fg_output = fg->output();
    if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) {
      return nullptr;
    }
    // Input 0 of MakeTuple is the primitive itself; the remaining inputs are the tuple elements.
    const size_t output_num = fg_output->cast<CNodePtr>()->inputs().size() - 1;

    auto &node_users = mng->node_users();
    if (node_users.count(node) == 0 || node_users[node].empty()) {
      return nullptr;
    }
    std::unordered_set<int> used_output_idx;
    std::vector<std::pair<AnfNodePtr, int>> all_users;
    for (auto &node_user : node_users[node]) {
      if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
        return nullptr;
      }
      auto user_cnode = node_user.first->cast<CNodePtr>();
      // The index must be a constant; otherwise the set of used outputs cannot be determined here.
      auto idx_vnode = user_cnode->input(2)->cast<ValueNodePtr>();
      if (idx_vnode == nullptr) {
        return nullptr;
      }
      int used_idx = GetValue<int>(idx_vnode->value());
      // Leave calls with out-of-range (including negative) indices untouched.
      if (used_idx < 0 || static_cast<size_t>(used_idx) >= output_num) {
        return nullptr;
      }
      (void)used_output_idx.insert(used_idx);
      all_users.push_back(std::make_pair(node_user.first, used_idx));
    }
    if (used_output_idx.empty() || used_output_idx.size() >= output_num) {
      // either nothing is read (not processed) or all outputs have users.
      return nullptr;
    }

    auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
    mng->AddFuncGraph(new_fg);
    auto new_fg_output = new_fg->output();
    // NOTE(review): the clone is expected to keep the same MakeTuple output; re-checked defensively.
    if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) {
      return nullptr;
    }
    auto output_cnode = new_fg_output->cast<CNodePtr>();
    if (output_cnode->inputs().size() - 1 != output_num) {
      return nullptr;
    }

    if (used_output_idx.size() == 1) {
      // after eliminate, only one output left.
      new_fg->set_output(output_cnode->input(static_cast<size_t>(*used_output_idx.begin()) + 1));
      // update users: each TupleGetItem is replaced by the call itself.
      for (auto &ret_user : all_users) {
        (void)mng->Replace(ret_user.first, node);
      }
    } else {
      // after eliminate, create new multi output.
      // Walk the original positions in ascending order so the narrowed tuple keeps a deterministic layout
      // instead of following the unspecified iteration order of the unordered_set.
      std::vector<AnfNodePtr> new_output_inputs{output_cnode->input(0)};
      std::unordered_map<int, int> new_idx_map;
      for (size_t pos = 0; pos < output_num; ++pos) {
        int old_idx = SizeToInt(pos);
        if (used_output_idx.count(old_idx) == 0) {
          continue;
        }
        new_idx_map[old_idx] = SizeToInt(new_output_inputs.size() - 1);
        new_output_inputs.push_back(output_cnode->input(pos + 1));
      }
      new_fg->set_output(new_fg->NewCNode(new_output_inputs));
      // update users: remap each TupleGetItem index to its position in the narrowed tuple.
      for (auto &ret_user : all_users) {
        auto ret_user_cnode = ret_user.first->cast<CNodePtr>();
        ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second]));
      }
    }
    auto new_sx = inputs;
    new_sx[0] = NewValueNode(new_fg);
    return node->func_graph()->NewCNode(new_sx);
  }
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
...@@ -88,7 +88,7 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; ...@@ -88,7 +88,7 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
class Optimizer : public std::enable_shared_from_this<Optimizer> { class Optimizer : public std::enable_shared_from_this<Optimizer> {
public: public:
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
: name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {} : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {}
virtual ~Optimizer() = default; virtual ~Optimizer() = default;
void Init(const OptPassGroupMap &passes, bool run_only_once) { void Init(const OptPassGroupMap &passes, bool run_only_once) {
...@@ -131,6 +131,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -131,6 +131,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
} }
FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
if (!is_enable_) {
return func_graph;
}
// Optimizer step counter; // Optimizer step counter;
int counter = -1; int counter = -1;
bool changes = true; bool changes = true;
...@@ -170,7 +173,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -170,7 +173,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
}; };
use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func();
if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) { if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) {
MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
auto fg_name = auto fg_name =
"opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
func_graph->DumpFuncGraph(fg_name); func_graph->DumpFuncGraph(fg_name);
...@@ -209,6 +212,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -209,6 +212,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
void enable_watch_renormalize() { is_watch_renormalize_ = true; } void enable_watch_renormalize() { is_watch_renormalize_ = true; }
void disable_watch_renormalize() { is_watch_renormalize_ = false; } void disable_watch_renormalize() { is_watch_renormalize_ = false; }
bool is_watch_renormalize() { return is_watch_renormalize_; } bool is_watch_renormalize() { return is_watch_renormalize_; }
// Switches this optimizer on or off.
void set_enable(bool enable) {
  is_enable_ = enable;
}
private: private:
const std::string name_; const std::string name_;
...@@ -218,6 +222,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -218,6 +222,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
bool run_only_once_; bool run_only_once_;
std::vector<AnfNodePtr> untyped_nodes_; std::vector<AnfNodePtr> untyped_nodes_;
bool is_watch_renormalize_; bool is_watch_renormalize_;
bool is_enable_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
...@@ -64,7 +64,7 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti ...@@ -64,7 +64,7 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
DumpGraph(root, std::string(ALLREDUCE_FUSION_END)); DumpGraph(root, std::string(ALLREDUCE_FUSION_END));
// allreduce fusion only run once // allreduce fusion only run once
root->flags()[ALLREDUCE_FUSION_RUN_ONCE_ONLY] = true; root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true);
res->results()[pipeline::kStepParallelGraph] = root; res->results()[pipeline::kStepParallelGraph] = root;
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
auto end_time = std::chrono::steady_clock::now(); auto end_time = std::chrono::steady_clock::now();
......
...@@ -155,8 +155,8 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, ...@@ -155,8 +155,8 @@ void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node); MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr); MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) || if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) ||
func_graph->flags()[TRAINING]) { func_graph->has_flag(TRAINING)) {
return; return;
} }
......
...@@ -175,7 +175,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { ...@@ -175,7 +175,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us"; MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY] = true; root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
return changes; return changes;
} }
......
...@@ -2258,10 +2258,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) ...@@ -2258,10 +2258,10 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
(root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) { (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) { if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) {
if (HasStrategy(root)) { if (HasStrategy(root)) {
MS_LOG(INFO) << "strategies ignored in " << parallel_mode MS_LOG(INFO) << "Strategies ignored in " << parallel_mode
<< ", set_strategy() only valid in [semi_]auto_parallel."; << ", set_strategy() only valid in [semi_]auto_parallel.";
} }
root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
} }
return changes; return changes;
...@@ -2318,11 +2318,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) ...@@ -2318,11 +2318,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
DumpGraph(root, std::string(STEP_PARALLEL_END)); DumpGraph(root, std::string(STEP_PARALLEL_END));
// step parallel only run once // step parallel only run once
root->flags()[SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY] = true; root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true);
res->results()[pipeline::kStepParallelGraph] = root; res->results()[pipeline::kStepParallelGraph] = root;
// in auto parallel mode, no need to check if strategies set // in auto parallel mode, no need to check if strategies set
root->flags()[CHECK_SET_STRATEGY_VALID_ONCE_ONLY] = true; root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
(void)gettimeofday(&end_time, nullptr); (void)gettimeofday(&end_time, nullptr);
uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
......
...@@ -141,7 +141,10 @@ PYBIND11_MODULE(_c_expression, m) { ...@@ -141,7 +141,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.") .def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.")
.def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.") .def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.")
.def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.") .def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.")
.def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling."); .def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.")
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
"Set the GraphKernel switch to on or off.")
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.");
(void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext") (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
.def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.") .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
......
...@@ -278,7 +278,7 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) { ...@@ -278,7 +278,7 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
if (bprop_graph != nullptr) { if (bprop_graph != nullptr) {
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph))); (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
} }
} }
*data = func_graph; *data = func_graph;
......
...@@ -1437,15 +1437,23 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) { ...@@ -1437,15 +1437,23 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) {
} }
py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG); py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG);
for (auto &item : flags) { for (auto &item : flags) {
if (!py::isinstance<py::str>(item.first) || !py::isinstance<py::bool_>(item.second)) { if (!py::isinstance<py::str>(item.first)) {
MS_LOG(ERROR) << "Type error in flags dict convert"; MS_LOG(ERROR) << "Type error in flags dict convert";
return false; return false;
} }
auto name = py::cast<std::string>(item.first); auto name = py::cast<std::string>(item.first);
auto value = py::cast<bool>(item.second); if (py::isinstance<py::bool_>(item.second)) {
MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; auto value = py::cast<bool>(item.second);
MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
func_graph->set_flags(name, value); func_graph->set_flag(name, value);
} else if (py::isinstance<py::str>(item.second)) {
auto value = py::cast<std::string>(item.second);
MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
func_graph->set_attr(name, MakeValue(value));
} else {
MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
return false;
}
} }
return true; return true;
......
...@@ -223,8 +223,8 @@ class Parser { ...@@ -223,8 +223,8 @@ class Parser {
FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse); FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse);
// In order to keep effect order in the sub-graphs which generated by control flow. // In order to keep effect order in the sub-graphs which generated by control flow.
// We copy the flags from the top graph to the sub-graphs. // We copy the flags from the top graph to the sub-graphs.
if (func_graph_ && !func_graph_->flags().empty()) { if (func_graph_ && !func_graph_->attrs().empty()) {
block->func_graph()->set_flags(func_graph_->flags()); block->func_graph()->set_attrs(func_graph_->attrs());
} }
func_block_list_.push_back(block); func_block_list_.push_back(block);
return block; return block;
......
...@@ -25,12 +25,14 @@ ...@@ -25,12 +25,14 @@
#include <functional> #include <functional>
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "debug/anf_ir_utils.h"
#include "pipeline/parse/parse_base.h" #include "pipeline/parse/parse_base.h"
#include "pipeline/parse/data_converter.h" #include "pipeline/parse/data_converter.h"
#include "pipeline/resource.h" #include "pipeline/resource.h"
#include "pipeline/validator.h" #include "pipeline/validator.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
#include "optimizer/cse.h" #include "optimizer/cse.h"
#include "optimizer/composite_op_reuse.h"
#include "optimizer/clean.h" #include "optimizer/clean.h"
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/control_depend.h" #include "optimizer/control_depend.h"
...@@ -38,6 +40,7 @@ ...@@ -38,6 +40,7 @@
#include "parallel/step_auto_parallel.h" #include "parallel/step_auto_parallel.h"
#include "parallel/allreduce_fusion/step_allreduce_fusion.h" #include "parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "utils/any.h" #include "utils/any.h"
#include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace pipeline { namespace pipeline {
...@@ -151,6 +154,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { ...@@ -151,6 +154,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_2 = opt::OptPassConfig({ opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_, irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_, irpass.make_ref_eliminate_,
irpass.get_ref_value_eliminate_,
}); });
OptPassGroupMap map({ OptPassGroupMap map({
{"b_1", b_1}, {"b_1", b_1},
...@@ -161,6 +165,40 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { ...@@ -161,6 +165,40 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
return map; return map;
} }
// Builds the ordered pass groups of the "opt_composite_a" optimizer:
// composite reuse, interface-fusion marking, renormalize, then CSE.
OptPassGroupMap GetOptPassesCompositeA(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig interface_fusion = opt::OptPassConfig({irpass.mark_interface_fusion_});
  return OptPassGroupMap({
    {"composite_reuse", opt::OptPassConfig(opt::CompositeReuse())},
    {"interface_fusion", interface_fusion},
    {"renormalize", opt::OptPassConfig::Renormalize()},
    {"cse", opt::OptPassConfig(opt::CSE(false))},
  });
}
// Builds the ordered pass groups of the "opt_composite_b" optimizer:
// a first elimination group, a renormalize step, then the unused parameter/output elimination group.
OptPassGroupMap GetOptPassesCompositeB(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig elim_1 = opt::OptPassConfig({irpass.addn_eliminate_, irpass.incorporate_getitem_from_param_});
  opt::OptPassConfig elim_2 =
    opt::OptPassConfig({irpass.unused_parameter_eliminate_, irpass.unused_output_eliminate_});
  return OptPassGroupMap({
    {"elim_1", elim_1},
    {"renormalize", opt::OptPassConfig::Renormalize()},
    {"elim_2", elim_2},
  });
}
// Builds a pass list that contains a single "renormalize" step; `irpass` is not consulted.
OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) {
  OptPassGroupMap renormalize_only({{"renormalize", opt::OptPassConfig::Renormalize()}});
  return renormalize_only;
}
OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}); opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_});
OptPassGroupMap map({ OptPassGroupMap map({
...@@ -190,8 +228,19 @@ void InitOpt(const ResourcePtr &res) { ...@@ -190,8 +228,19 @@ void InitOpt(const ResourcePtr &res) {
opt::irpass::OptimizeIRPassLib irpass; opt::irpass::OptimizeIRPassLib irpass;
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_composite_a"] =
Optimizer::MakeOptimizer("opt_composite_a", res, GetOptPassesCompositeA(irpass), true);
g_pass_opts["opt_composite_b"] =
Optimizer::MakeOptimizer("opt_composite_b", res, GetOptPassesCompositeB(irpass), false);
g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (!(context_ptr->enable_graph_kernel())) {
g_pass_opts["opt_composite_a"]->set_enable(false);
g_pass_opts["opt_composite_b"]->set_enable(false);
}
} }
} }
} // namespace } // namespace
...@@ -223,9 +272,13 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) { ...@@ -223,9 +272,13 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
// Delegates to OptPassGroup for the optimizer registered under "opt_composite_a".
bool OptPassCompositeGroupA(const ResourcePtr &res) {
  const std::string group_name = "opt_composite_a";
  return OptPassGroup(res, group_name);
}
// Delegates to OptPassGroup for the optimizer registered under "opt_composite_b".
bool OptPassCompositeGroupB(const ResourcePtr &res) {
  const std::string group_name = "opt_composite_b";
  return OptPassGroup(res, group_name);
}
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
// Delegates to OptPassGroup for the optimizer registered under "renormal".
bool OptPassRNGroup(const ResourcePtr &res) {
  const std::string group_name = "renormal";
  return OptPassGroup(res, group_name);
}
bool AddControlDependPass(const ResourcePtr &res) { bool AddControlDependPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
...@@ -269,8 +322,10 @@ bool InferenceOptPreparePass(const ResourcePtr &res) { ...@@ -269,8 +322,10 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup}, {"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup}, {"opt_b", OptPassBGroup},
{"add_control_depend", AddControlDependPass}, {"cconv", CconvPass},
{"cconv", CconvPass}}; {"opt_composite_a", OptPassCompositeGroupA},
{"opt_composite_b", OptPassCompositeGroupB},
{"add_control_depend", AddControlDependPass}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup}, {"opt_a", OptPassAGroup},
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册