提交 86d47973 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1079 Convert AiCpu kernel when aicore not supported

Merge pull request !1079 from lianliguang/convert-to-AICPU-when-AiCore-not-supported-kernel
......@@ -85,7 +85,7 @@ const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberType
} while (0)
template <typename T>
T Ceil(T n1, T n2) {
T DivCeil(T n1, T n2) {
return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
}
......@@ -371,15 +371,48 @@ std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t c0 = 4;
size_t first_dim = DivCeil(c0 * shape[2] * shape[3], kCubeSize);
size_t no = DivCeil(DivCeil(shape[0], kCubeSize) * kCubeSize, kCubeSize);
device_shape.push_back(first_dim);
device_shape.push_back(no);
device_shape.push_back(kCubeSize);
device_shape.push_back(kCubeSize);
return device_shape;
}
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
if (!CheckDims(shape)) {
MS_LOG(EXCEPTION) << "Check dims failed.";
}
std::vector<size_t> device_shape;
size_t C1 = 1;
size_t C0 = 4;
device_shape.push_back(shape[0]);
device_shape.push_back(C1);
device_shape.push_back(shape[2]);
device_shape.push_back(shape[3]);
device_shape.push_back(C0);
return device_shape;
}
} // namespace
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
const std::map<std::string, DeviceShapeTransfer> device_shape_map{
{kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape},
{kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape},
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
};
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
{kOpFormat_NHWC, NhwcDeviceShape},
{kOpFormat_HWCN, HwchDeviceShape},
{kOpFormat_FRAC_Z, FracZDeviceShape},
{kOpFormat_NC1HWC0, Nc1hwc0DeviceShape},
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape}};
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
return shape;
......@@ -506,13 +539,13 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t c1 = Ceil(c, c0);
size_t c1 = DivCeil(c, c0);
size_t hw = h * w;
size_t chw = c * hw;
size_t hwc0 = hw * c0;
size_t nchw = n * chw;
size_t hf_cnt = Ceil(n, kCubeSize);
size_t hf_cnt = DivCeil(n, kCubeSize);
size_t vf_cnt = c1 * hw;
size_t fractal_ele_cnt = c0 * kCubeSize;
size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
......@@ -775,7 +808,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
size_t c1 = Ceil(c, c0);
size_t c1 = DivCeil(c, c0);
size_t hw = h * w;
size_t chw = c * hw;
size_t c1hwc0 = c1 * hw * c0;
......
......@@ -34,6 +34,7 @@ namespace ascend {
namespace {
const float kWegihtBaseScore = 1;
const float kFeatureMapBaseScore = 10;
constexpr auto kPriChoosenFormat = "pri_format";
enum MatchCountPriority : int {
MATCH_COUNT_PRIORITY_BEGIN = 0,
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
......@@ -85,6 +86,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
if (need_change_nd) {
priority_matched_format = kOpFormat_DEFAULT;
}
AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
return priority_matched_format;
}
/**
......@@ -394,9 +396,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
std::ostringstream buffer;
buffer << cnode->DebugString();
if (precision_reduce) {
buffer << " reduce precision, node datatype: ";
buffer << " reduce precision, node datatype: \n";
} else {
buffer << " raise precision, node datatype: ";
buffer << " raise precision, node datatype: \n";
}
PrintInputAndOutputInferType(buffer, cnode);
buffer << ", select kernel:" << selected_kernel_build_info->ToString();
......@@ -464,66 +466,57 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
}
} // namespace
std::shared_ptr<kernel::KernelBuildInfo> CanHitKernelInfo(
int *status, const CNodePtr &kernel_node,
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
KernelSelectStatus select_status = kNoMatched;
bool precision_reduce = false;
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
// Matched kernel info
// Filter kernel info matched with me infered type
auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list);
if (!filtered_kernel_info_list.empty()) {
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
select_status = kStatusAllMatched;
} else {
// selected kernel info using raised precision or reduce precision
filtered_kernel_info_list =
FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
if (selected_kernel_info == nullptr) {
return nullptr;
return select_status;
} else {
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
*status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
}
}
return selected_kernel_info;
// Set kernel info to the anfnode
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
return select_status;
}
int SelectKernelInfo(const CNodePtr &kernel_node) {
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
int status = kStatusAllMatched;
MS_EXCEPTION_IF_NULL(kernel_node);
kernel::KernelQuery(kernel_node, &kernel_info_list);
// filter kernel info matched with me infered type
auto selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
if (selected_kernel_info == nullptr) {
auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
// If aicore not find valid kernel info reloading aicpu kernel info list to find it
if (select_status == kNoMatched) {
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
kernel::AicpuQuery(kernel_node, &kernel_info_list);
selected_kernel_info = CanHitKernelInfo(&status, kernel_node, kernel_info_list);
kernel::AICpuQuery(kernel_node, &kernel_info_list);
select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
}
if (selected_kernel_info == nullptr) {
// The kernel info not finded both in the aicpu kernel list & aicore kernel list
if (select_status == kNoMatched) {
std::ostringstream buffer;
PrintInputAndOutputInferType(buffer, kernel_node);
MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid kernel info, not supported the type " << buffer.str();
}
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
return status;
}
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
const kernel::KernelBuildInfoPtr &new_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
kernel::KernelQuery(kernel_node, &kernel_info_list);
auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(),
[&new_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
return *item == *new_kernel_build_info;
});
return result != kernel_info_list.end();
return select_status;
}
} // namespace ascend
} // namespace device
......
......@@ -21,8 +21,13 @@
namespace mindspore {
namespace device {
namespace ascend {
int SelectKernelInfo(const CNodePtr &kernel_node);
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info);
enum KernelSelectStatus {
kNoMatched = -1,
kStatusAllMatched = 0,
kStatusReducePrecision = 1,
kStatusRaisePrecision = 2,
};
KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node);
} // namespace ascend
} // namespace device
} // namespace mindspore
......
......@@ -35,7 +35,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std::vector<std::string> input_format, output_format;
std::vector<TypeId> input_type, output_type;
for (const auto &data_type : data_type_list) {
for (const auto &format : k4DSupportFormat) {
for (const auto &format : kOpFormatList) {
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
input_format.clear();
input_format.push_back(format);
......
......@@ -35,14 +35,18 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
});
kernel_info_list->clear();
if (!filtered_list.empty()) {
kernel_info_list->clear();
(void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
} else {
MS_LOG(EXCEPTION) << "node" << kernel_node->DebugString() << "'s output size : ["
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node)
<< "] cannot match any kernelInfo !";
MS_LOG(WARNING) << "All kernel Info list does not match any kernel info ";
for (size_t index; index < kernel_info_list->size(); ++index) {
MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString();
}
MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : ["
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
}
}
} // namespace
......@@ -50,7 +54,6 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list);
if (kernel_info_list->empty()) {
AicpuMetadataInfo(kernel_node, kernel_info_list);
}
......@@ -68,12 +71,41 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
kernel_info_list->clear();
AicpuMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
auto cnode = kernel_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
AicpuMetadataInfo(cnode, &kernel_info_list);
FilterInvalidKernelInfo(cnode, &kernel_info_list);
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
return *item == *select_kernel_build_info;
});
}
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(select_kernel_build_info);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
auto cnode = kernel_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
TbeMetadataInfo(cnode, &kernel_info_list);
FilterInvalidKernelInfo(cnode, &kernel_info_list);
return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
return *item == *select_kernel_build_info;
});
}
} // namespace kernel
} // namespace mindspore
......@@ -26,7 +26,9 @@
namespace mindspore {
namespace kernel {
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
void AicpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
void AICpuQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
bool IsSupportedByAiCpu(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
bool IsSupportedByAiCore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
......@@ -551,11 +551,6 @@ bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpIn
}
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
// if format is default, it remarkes support all format
if (kOpFormatList.find(format) == kOpFormatList.end()) {
MS_LOG(EXCEPTION) << "Got the unknown format " << format;
......
......@@ -54,6 +54,7 @@
#include "pre_activate/pass/optimize_dependence.h"
#include "pre_activate/pass/erase_visit_attr.h"
#include "pre_activate/ascend/format_type/insert_cast.h"
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
#include "pre_activate/pass/eliminate_redundant_op.h"
#include "pre_activate/pass/common_subexpression_elimination.h"
#include "pre_activate/ascend/format_type/merge_cast_to_op.h"
......@@ -172,6 +173,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
optimizer->AddPassManager(mixed_precision_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
......
......@@ -268,6 +268,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
}
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
return cast;
}
......
......@@ -30,10 +30,6 @@ class KernelSelect {
KernelSelect() = default;
virtual ~KernelSelect() = default;
virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); }
virtual bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
const kernel::KernelBuildInfoPtr &new_kernel_build_info) {
return device::ascend::CheckKernelAccuracySupported(kernel_node, new_kernel_build_info);
}
};
using KernelSelectPtr = std::shared_ptr<KernelSelect>;
......@@ -41,8 +37,13 @@ class SupportedChecker {
public:
SupportedChecker() = default;
virtual ~SupportedChecker() = default;
virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::CheckSupported(anf_node, select_kernel_build_info);
virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAiCore(anf_node, select_kernel_build_info);
}
virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
return kernel::IsSupportedByAiCpu(anf_node, select_kernel_build_info);
}
};
using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h"
#include <memory>
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
#include "kernel/kernel_query.h"
namespace mindspore {
namespace opt {
const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({X, Xs});
}
const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &,
const mindspore::AnfNodePtr &node,
const mindspore::EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) {
return nullptr;
}
auto node_name = AnfAlgo::GetCNodeName(node);
if (node_name != prim::KPrimTransData->name() || node_name != prim::kPrimCast->name()) {
return nullptr;
}
auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node);
if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) {
return node;
} else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) {
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_builder_info);
builder->SetKernelType(AICPU_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
} else {
MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node ["
<< node->DebugString() << "]";
}
return node;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
namespace mindspore {
namespace opt {
class ConvertUnSupportNodeToAICPU : public PatternProcessPass {
public:
explicit ConvertUnSupportNodeToAICPU(bool multigraph = true)
: PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph),
supported_checker_(std::make_shared<SupportedChecker>()) {}
~ConvertUnSupportNodeToAICPU() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
SupportedCheckerPtr supported_checker_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
......@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
#include <string>
#include "pre_activate/common/optimizer.h"
......@@ -32,4 +32,4 @@ class RunOpInsertCast : public PatternProcessPass {
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_CAST_FOR_RUNOP_H_
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_
#include <string>
#include <utility>
......@@ -41,4 +41,4 @@ class RunOpInsertTransData : public PatternProcessPass {
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEVICE_OPTIMIZER_FORMAT_TYPE_PASS_INSERT_TRANSDATA_FOR_RUNOP_H_
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_
......@@ -128,7 +128,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
auto indices_const = CreateValueNode(new_cnode);
new_cnode->add_input(indices_const);
MS_EXCEPTION_IF_NULL(supported_checker_);
if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) {
if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) {
return nullptr;
}
......
......@@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap
new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
if (kernel_select_->CheckKernelAccuracySupported(transdata_cnode, new_transdata_builder->Build())) {
if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) {
std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);
......
......@@ -34,7 +34,7 @@ class TransposeTransDataFusion : public PatternProcessPass {
explicit TransposeTransDataFusion(bool multigraph = true)
: PatternProcessPass("transpose_transdata_fusion", multigraph) {
input_varptr_ = std::make_shared<Var>();
kernel_select_ = std::make_shared<KernelSelect>();
supported_checker_ = std::make_shared<SupportedChecker>();
}
~TransposeTransDataFusion() override = default;
const BaseRef DefinePattern() const override;
......@@ -42,7 +42,9 @@ class TransposeTransDataFusion : public PatternProcessPass {
private:
VarPtr input_varptr_;
KernelSelectPtr kernel_select_;
private:
SupportedCheckerPtr supported_checker_;
};
} // namespace opt
} // namespace mindspore
......
......@@ -329,9 +329,9 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
size_t reduce_precision_count = 0;
for (const auto &cnode : kernel_graph.execution_order()) {
auto status = device::ascend::SelectKernelInfo(cnode);
if (status == kStatusRaisePrecision) {
if (status == device::ascend::kStatusRaisePrecision) {
raise_precision_count++;
} else if (status == kStatusReducePrecision) {
} else if (status == device::ascend::kStatusReducePrecision) {
reduce_precision_count++;
}
MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
......
......@@ -27,6 +27,8 @@
namespace mindspore {
namespace session {
namespace {
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(que);
......@@ -180,11 +182,24 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
// create kernel_info from new parameter
auto kernel_info = std::make_shared<device::KernelInfo>();
std::vector<size_t> feature_map_input_indexs;
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(),
[&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) {
for (size_t index = 1; index < inputs.size(); ++index) {
auto node = inputs[index];
if (AnfAlgo::IsFeatureMapOutput(node)) {
feature_map_input_indexs.push_back(index);
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
}
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(true), cnode);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
} else {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(false), cnode);
}
cnode->set_kernel_info(kernel_info);
AnfAlgo::SetGraphId(graph_id_, cnode.get());
......
......@@ -142,6 +142,7 @@ constexpr auto kLabelGotoOpName = "LabelGoto";
// attr key name
constexpr auto kAttrInputNames = "input_names";
constexpr auto kIsBackendCast = "is_backed_cast";
constexpr auto kAttrOutputNames = "output_names";
constexpr auto kAttrVisited = "visited";
constexpr auto kAttrShape = "shape";
......@@ -201,10 +202,6 @@ constexpr auto kControlDependBehindIndex = 2;
// index define of depend
constexpr auto kRealInputIndexInDepend = 1;
constexpr auto kDependAttachNodeIndex = 2;
// status of kernel select result
const int kStatusReducePrecision = -1;
const int kStatusRaisePrecision = 1;
const int kStatusAllMatched = 0;
// format
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
......@@ -218,18 +215,11 @@ constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_C1HWNCoC0};
const std::set<std::string> k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z,
kOpFormat_NC1KHKWHWC0};
const std::set<std::string> k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0};
const std::set<std::string> k4DSupportFormat = k1DSupportFormat;
const std::vector<std::set<std::string>> kShapeSupportFormatMap = {k1DSupportFormat, k2DSupportFormat, k3DSupportFormat,
k4DSupportFormat};
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
const std::set<std::string> kOptOperatorSet = {
kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName,
kApplyAdagradOpName, kApplyAdagradDAName, kApplyAdamOpName,
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/ccsrc/device/ascend/kernel_select_ascend.h"
#include "common/common_test.h"
#include "session/kernel_graph.h"
#include "kernel/kernel.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "operator/ops.h"
#include "mindspore/ccsrc/device/kernel_info.h"
#include "mindspore/ccsrc/kernel/kernel_build_info.h"
#include <vector>
namespace mindspore {
namespace device {
namespace ascend {
namespace {
using KernelInfo = device::KernelInfo;
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
using KernelBuildInfo = kernel::KernelBuildInfo;
using KernelGraph = session::KernelGraph;
using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
using KernelBuilderPtr = std::shared_ptr<KernelBuildInfoBuilder>;
using Shape = std::vector<size_t>;
using ShapeList = std::vector<Shape>;
enum MatchCountPriority {
MATCH_COUNT_PRIORITY_BEGIN = 0,
MATCH_FORMAT_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
MATCH_DTYPE_COUNT,
MATCH_NZ_FORMAT_COUNT,
MATCH_5D_FORMAT_COUNT,
MATCH_OUTPUT_DTYPE_COUNT,
MATCH_COUNT_PRIORITY_END
};
const std::set<std::string> kOpFormatList = {
kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ};
bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
// if format is default,it remarkes support all format
if (kOpFormatList.find(format) == kOpFormatList.end()) {
MS_EXCEPTION(ArgumentError) << "got the unknow format " << format;
}
if (format == kOpFormat_DEFAULT) {
return true;
}
// if shape size is 0,the shape will be a scalar
if (shape.empty()) {
return true;
}
if (shape.size() > kShapeSupportFormatMap.size()) {
return false;
}
if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
return shape[shape.size() - 1] % 16 != 0 && shape[shape.size() - 2] % 16 != 0;
}
return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
}
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
if (!IsShapeMatchFormat(shape, format)) {
return false;
}
for (auto shape_value : shape) {
if (shape_value == 0) {
MS_EXCEPTION(ArgumentError) << "dimension size of the tensor shape should be a positive integer, but got ["
<< shape_value << "]";
}
}
return true;
};
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
return false;
}
}
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
return false;
}
}
return true;
}
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(cnode);
// Check input data type
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
AnfNodePtr cur_input = cnode->input(input_index + 1);
MS_EXCEPTION_IF_NULL(cur_input);
TypeId input_origin_type;
if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
// weight
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
} else {
// feature map
input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
}
if (input_origin_type == kTypeUnknown) {
continue;
}
if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
return false;
}
}
// Check output data type
for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
return false;
}
}
return true;
}
/**
* compare too vector by priority,select a better vector,like compare too num,first compare highest num location,if
* equal then next num location
* example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
*/
bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
MS_EXCEPTION_IF_NULL(best_item);
if (cur_item.size() != best_item->size()) {
MS_LOG(ERROR) << "item size should be same!";
return false;
}
// Update the best_item by comparing the cur_item and best_item
for (size_t i = 0; i < cur_item.size(); i++) {
if (cur_item[i] > best_item->at(i)) {
*best_item = cur_item;
return true;
} else if (cur_item[i] == best_item->at(i)) {
continue;
} else {
return false;
}
}
return false;
}
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
std::vector<int> *const cur_kernelinfo_match_counts) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
MS_EXCEPTION(ArgumentError) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
}
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
AnfNodePtr input_anf_node = kernel_node->input(input_index + 1);
MS_EXCEPTION_IF_NULL(input_anf_node);
// if a input parameter is a weight with default format, the input shouldn't participate the judge
if (input_anf_node->isa<Parameter>()) {
auto para = input_anf_node->cast<ParameterPtr>();
if (AnfAlgo::IsParameterWeight(para) && AnfAlgo::GetOutputDeviceDataType(para, 0) == kTypeUnknown) {
continue;
}
}
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
}
if (kernel_build_info.GetInputDeviceType(input_index) ==
AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
(*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT]++;
}
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_FRAC_NZ) {
(*cur_kernelinfo_match_counts)[MATCH_NZ_FORMAT_COUNT]++;
}
if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_NC1HWC0) {
(*cur_kernelinfo_match_counts)[MATCH_5D_FORMAT_COUNT]++;
}
}
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
// cal count of same output dtype between abstract and kernel info
if (kernel_build_info.GetOutputDeviceType(output_index) ==
AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
(*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
}
}
}
void SetKernelBuildInfo(KernelBuilderPtr builder) {
builder->SetFusionType(kernel::OPAQUE);
builder->SetKernelType(AUTO_DIFF_KERNEL);
builder->SetProcessor(kernel::AICORE);
}
void test_select(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list) {
std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
int selected_index = -1;
for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
if (!IsValidKernelInfo(kernel_node, *(kernel_info_list[info_index]))) {
continue;
}
if (!MatchInferOutputDataType(kernel_node, *(kernel_info_list[info_index]))) {
continue;
}
std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
// Currently the selection policy is the match format count first, and then is datatype counts.
if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
selected_index = SizeToInt(info_index);
}
}
if (selected_index == -1) {
MS_EXCEPTION(NotExistsError) << "" << kernel_node->DebugString() << " Cannot find valid kernel Info !";
}
auto index = IntToSize(selected_index);
if (index >= kernel_info_list.size()) {
MS_EXCEPTION(ArgumentError) << "index outof range";
}
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info_ptr = kernel_info_list[index];
MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get());
}
void SetParentAbstract(std::vector<AnfNodePtr> parent_list, std::vector<std::vector<size_t>> shapes,
std::vector<TypeId> types) {
for (const auto &node : parent_list) {
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, node.get());
}
}
} // namespace
class AscendKernelSelctTest : public UT::Common {
public:
AscendKernelSelctTest() = default;
void SetUp() override {}
void TearDown() override {}
};
TEST_F(AscendKernelSelctTest, TestSelect) {
std::vector<KernelBuilderPtr> build_list;
std::vector<TypeId> type_list = {kNumberTypeFloat32};
for (size_t i = 0; i <= 4; ++i) {
build_list.push_back(std::make_shared<KernelBuildInfoBuilder>());
SetKernelBuildInfo(build_list[i]);
build_list[i]->SetInputsDeviceType(type_list);
build_list[i]->SetOutputsDeviceType(type_list);
}
std::vector<std::string> nd_fmt = {kOpFormat_DEFAULT};
std::vector<std::string> nz_fmt = {kOpFormat_FRAC_NZ};
auto anf_graph = std::make_shared<KernelGraph>();
// 16's multiple should not chose format NZ
Shape nd_shapes = {2, 32, 224, 224};
Shape nz_shapes = {3, 3, 5, 5};
auto add_value = NewValueNode(prim::kPrimTensorAdd);
auto a_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
auto b_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
std::vector<AnfNodePtr> parent_list = {add_value, a_node, b_node};
auto c_node = anf_graph->NewCNode(parent_list);
// a b
// \ /
// c
// a & b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
// infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
// c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3,224, 224}}
// set a & b's info
SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
// set abstract c
AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nd_shapes}, c_node.get());
// set format of kernel info
build_list[0]->SetOutputsFormat(nz_fmt);
build_list[1]->SetOutputsFormat(nz_fmt);
build_list[2]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nd_fmt[0]});
build_list[3]->SetInputsFormat(std::vector<std::string>{nz_fmt[0], nz_fmt[0]});
build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
build_list[2]->SetOutputsFormat(nd_fmt);
build_list[3]->SetOutputsFormat(nz_fmt);
std::vector<KernelBuildInfoPtr> select_info_list;
// set select info list
select_info_list.emplace_back(build_list[2]->Build());
select_info_list.emplace_back(build_list[3]->Build());
// set device info for a & b
AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());
test_select(c_node, select_info_list);
EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_DEFAULT);
// set a & b's info
// a b
// \ /
// c
// a: kernel_info:{output_format:{5d},dtype:{kNumberTypeFloat32}}
// infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
// b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
// infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
// c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
// set a & b's info
SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
// set abstract c
AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nz_shapes}, c_node.get());
// set format of kernel info
build_list[0]->SetOutputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0});
build_list[1]->SetOutputsFormat(nz_fmt);
build_list[2]->SetInputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0, nd_fmt[0]});
build_list[3]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nz_fmt[0]});
build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
build_list[2]->SetOutputsFormat(nd_fmt);
build_list[3]->SetOutputsFormat(nz_fmt);
// set select info list
select_info_list.emplace_back(build_list[2]->Build());
select_info_list.emplace_back(build_list[3]->Build());
// set device info for a & b
AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());
test_select(c_node, select_info_list);
EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_FRAC_NZ);
}
} // namespace ascend
} // namespace device
} // namespace mindspore
\ No newline at end of file
......@@ -39,7 +39,7 @@ class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker() = default;
~MockSupportedChecker() override = default;
bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true;
}
}; // namespace opt
......
......@@ -37,6 +37,15 @@ class TestHWTransposeTransdataFusion : public BackendCommon {
UT::PyFuncGraphFetcher get_py_fun_;
};
class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker() = default;
~MockSupportedChecker() override = default;
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true;
}
};
class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
public:
MockInsertTransOpKernelSelectTrans4Dto5D() = default;
......@@ -60,37 +69,6 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
}
};
class MockTransposeTransdataFusionKernelSelect : public KernelSelect {
public:
MockTransposeTransdataFusionKernelSelect() = default;
~MockTransposeTransdataFusionKernelSelect() override = default;
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
const kernel::KernelBuildInfoPtr &new_kernel_build_info) override {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_NCHW});
builder.SetOutputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kNumberTypeFloat16});
builder.SetOutputsDeviceType({kNumberTypeFloat16});
builder.SetKernelType(KernelType::AUTO_DIFF_KERNEL);
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetProcessor(kernel::Processor::AICORE);
kernel_info_list.push_back(builder.Build());
MS_LOG(INFO) << "transpose transdata fusion success";
MS_LOG(INFO) << "new transdata build info input format:" << new_kernel_build_info->GetInputFormat(0)
<< ",outputformat:" << new_kernel_build_info->GetOutputFormat(0)
<< ",kerneltype:" << new_kernel_build_info->kernel_type()
<< ",fusiontype:" << new_kernel_build_info->fusion_type()
<< ",process:" << new_kernel_build_info->processor();
auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(),
[&new_kernel_build_info](kernel::KernelBuildInfoPtr item) {
MS_EXCEPTION_IF_NULL(item);
return *item == *new_kernel_build_info;
});
return result != kernel_info_list.end();
}
};
TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
/*
* def before(input0, input1):
......@@ -128,7 +106,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
insert_trans_op_pass->kernel_select_ = std::make_shared<MockInsertTransOpKernelSelectTrans4Dto5D>();
pm->AddPass(insert_trans_op_pass);
auto transpose_transdata_pass = std::make_shared<opt::TransposeTransDataFusion>();
transpose_transdata_pass->kernel_select_ = std::make_shared<MockTransposeTransdataFusionKernelSelect>();
transpose_transdata_pass->supported_checker_ = std::make_shared<MockSupportedChecker>();
pm->AddPass(transpose_transdata_pass);
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kg);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册