Commit ea66324a authored by mindspore-ci-bot, committed by Gitee

!5197 add attr for transdata node

Merge pull request !5197 from lianliguang/master
@@ -108,10 +108,7 @@ std::string KernelBuildInfo::ToString() const {
   return output_buffer.str();
 }
-bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
-  if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) {
-    return false;
-  }
+bool KernelBuildInfo::IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const {
   if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) {
     if (op_pattern_ != kFormatAgnosticPattern) {
       return false;
@@ -123,6 +120,13 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
   return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
 }
+bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
+  if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) {
+    return false;
+  }
+  return IsSimilarityKernelBuildInfo(other);
+}
 bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
 bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
......
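Taken together, these two hunks split the old all-in-one comparison: IsSimilarityKernelBuildInfo now compares only formats and device types (with the kFormatAgnosticPattern escape hatch), while operator== keeps the stricter kernel_type_/fusion_type_/processor_ check and then delegates to the similarity test. A minimal standalone sketch of the resulting semantics, using a simplified stand-in type rather than the real KernelBuildInfo:

#include <cassert>
#include <string>
#include <vector>

// BuildInfoSketch is an illustrative stand-in, not the MindSpore class.
struct BuildInfoSketch {
  int kernel_type = 0;
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;

  // Similarity ignores kernel-level fields such as kernel_type.
  bool IsSimilar(const BuildInfoSketch &other) const {
    return inputs_format == other.inputs_format && outputs_format == other.outputs_format;
  }

  // Full equality additionally requires kernel_type to match.
  bool operator==(const BuildInfoSketch &other) const {
    return kernel_type == other.kernel_type && IsSimilar(other);
  }
};

int main() {
  BuildInfoSketch a{/*kernel_type=*/0, {"NCHW"}, {"NC1HWC0"}};
  BuildInfoSketch b{/*kernel_type=*/1, {"NCHW"}, {"NC1HWC0"}};
  assert(a.IsSimilar(b));  // same formats: similar
  assert(!(a == b));       // different kernel_type: not equal
  return 0;
}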
@@ -91,6 +91,8 @@ class KernelBuildInfo {
   std::string ToString() const;
+  bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const;
   bool operator==(const KernelBuildInfo &other) const;
+  bool operator!=(const KernelBuildInfo &other) const;
......
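The header also gains an operator!= declaration whose definition does not appear in this diff. Assuming the conventional negation idiom (a guess, not confirmed by the commit), it would be a one-liner:

// Hypothetical definition, assuming operator!= simply negates operator==.
bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !(*this == other); }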
@@ -130,6 +130,7 @@ void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel:
   AicpuMetadataInfo(kernel_node, kernel_info_list);
   FilterInvalidKernelInfo(kernel_node, kernel_info_list);
 }
 bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   MS_EXCEPTION_IF_NULL(select_kernel_build_info);
@@ -140,7 +141,7 @@ bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr
   return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
                      [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
                        MS_EXCEPTION_IF_NULL(item);
-                       return *item == *select_kernel_build_info;
+                       return item->IsSimilarityKernelBuildInfo(*select_kernel_build_info);
                      });
 }
......
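Switching the lambda from operator== to IsSimilarityKernelBuildInfo relaxes the AICPU support check: a candidate now matches the selected build info as long as formats and device types agree, even if kernel_type_, fusion_type_, or processor_ differ. A condensed sketch of the any_of scan (IsSupportedSketch is illustrative and reuses the BuildInfoSketch stand-in defined above):

#include <algorithm>
#include <memory>
#include <vector>

// Assumes the BuildInfoSketch type from the earlier sketch.
bool IsSupportedSketch(const std::vector<std::shared_ptr<BuildInfoSketch>> &candidates,
                       const BuildInfoSketch &selected) {
  return std::any_of(candidates.begin(), candidates.end(),
                     [&selected](const std::shared_ptr<BuildInfoSketch> &item) {
                       // Similarity, not equality: kernel-level fields may differ.
                       return item->IsSimilar(selected);
                     });
}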
@@ -178,22 +178,6 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) {
   }
 }
-void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  if (AnfAlgo::GetCNodeName(anf_node) == kTransDataOpName) {
-    std::string input_format = AnfAlgo::GetInputFormat(anf_node, 0);
-    std::string output_format = AnfAlgo::GetOutputFormat(anf_node, 0);
-    if (input_format == kOpFormat_DEFAULT) {
-      input_format = kOpFormat_NCHW;
-    }
-    if (output_format == kOpFormat_DEFAULT) {
-      output_format = kOpFormat_NCHW;
-    }
-    AnfAlgo::SetNodeAttr("src_format", MakeValue(input_format), anf_node);
-    AnfAlgo::SetNodeAttr("dst_format", MakeValue(output_format), anf_node);
-  }
-}
 std::unordered_set<std::string> input_order_adjusted_ops = {
   "Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop",
   "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"};
......
@@ -36,7 +36,7 @@ class TbeAdapter {
   TbeAdapter() = default;
   ~TbeAdapter() = default;
   static void NormalizeFuncName(std::string *func_name);
-  static void SetTbeAttrsForTransDataOp(const AnfNodePtr &anf_node);
   static void InputOrderPass(const std::string &op_name, std::vector<std::vector<nlohmann::json>> const &inputs_list,
                              nlohmann::json *inputs_json);
   static bool RunAttrPass(const AnfNodePtr &anf_node, const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
......
@@ -75,7 +75,6 @@ bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
   set<std::string> processed_kernel;
   for (const auto &anf_node : anf_nodes) {
     // gen kernel json
-    tbe::TbeAdapter::SetTbeAttrsForTransDataOp(anf_node);
     if (AnfAlgo::GetKernelMod(anf_node) != nullptr) {
       continue;
     }
......
@@ -48,6 +48,22 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
   return reshape;
 }
+void SetTransNodeAttr(const CNodePtr &trans_node) {
+  MS_EXCEPTION_IF_NULL(trans_node);
+  if (AnfAlgo::GetCNodeName(trans_node) == kTransDataOpName) {
+    std::string input_format = AnfAlgo::GetInputFormat(trans_node, 0);
+    std::string output_format = AnfAlgo::GetOutputFormat(trans_node, 0);
+    if (input_format == kOpFormat_DEFAULT) {
+      input_format = kOpFormat_NCHW;
+    }
+    if (output_format == kOpFormat_DEFAULT) {
+      output_format = kOpFormat_NCHW;
+    }
+    AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(input_format), trans_node);
+    AnfAlgo::SetNodeAttr(kAttrDstFormat, MakeValue(output_format), trans_node);
+  }
+}
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                  const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
   AnfNodePtr trans_node = nullptr;
@@ -173,6 +189,7 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
     builder->SetInputsDeviceType({type_id});
   }
   AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
+  SetTransNodeAttr(trans_data->cast<CNodePtr>());
 }
 CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
......
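SetTransNodeAttr stamps the conversion pair onto each inserted TransData node, normalizing kOpFormat_DEFAULT to NCHW on both sides before recording src_format and dst_format. The defaulting rule in isolation, as a hypothetical helper (the literal "DefaultFormat" is an assumed value of kOpFormat_DEFAULT, not confirmed by this diff):

#include <cassert>
#include <string>
#include <utility>

// Hypothetical helper mirroring the defaulting logic in SetTransNodeAttr.
std::pair<std::string, std::string> MakeTransDataFormats(std::string input_format,
                                                         std::string output_format) {
  const std::string kDefault = "DefaultFormat";  // assumed value of kOpFormat_DEFAULT
  if (input_format == kDefault) input_format = "NCHW";
  if (output_format == kDefault) output_format = "NCHW";
  return {input_format, output_format};  // stored as {src_format, dst_format}
}

int main() {
  auto [src, dst] = MakeTransDataFormats("DefaultFormat", "NC1HWC0");
  assert(src == "NCHW" && dst == "NC1HWC0");
  return 0;
}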
@@ -224,6 +224,7 @@ constexpr auto kAttrEventId = "event_id";
 constexpr auto kAttrDynInput = "dynamic";
 constexpr auto kAttrDynInputSizes = "dyn_input_sizes";
 constexpr auto kAttrSrcFormat = "src_format";
+constexpr auto kAttrDstFormat = "dst_format";
 constexpr auto kAttrMultiples = "multiples";
 constexpr auto kAttrFixPrecision = "fix_precision";
 constexpr auto kAttrOutputPrecision = "output_precision";
......