提交 9cb129ac 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1720 add reducemean's special kernel filter rule

Merge pull request !1720 from lianliguang/r0.3
...@@ -178,7 +178,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co ...@@ -178,7 +178,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)}; std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format); builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type); builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
} }
......
...@@ -582,6 +582,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for ...@@ -582,6 +582,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
const size_t kCAxis = 1; const size_t kCAxis = 1;
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
...@@ -594,6 +595,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: ...@@ -594,6 +595,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) { if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
return false; return false;
} }
if (kernel_name == "ReduceMean") {
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
if (keep_dims == false && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) {
return false;
}
}
} }
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
...@@ -606,6 +613,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: ...@@ -606,6 +613,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
} }
return false; return false;
} }
if (kernel_name == "ReduceMean") {
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
if (keep_dims == false && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) {
return false;
}
}
} }
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
......
...@@ -315,7 +315,10 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod ...@@ -315,7 +315,10 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
if (node->isa<ValueNode>() || node->isa<Parameter>()) { if (node->isa<ValueNode>()) {
return true;
}
if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
return true; return true;
} }
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册