提交 179d8a27 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4064 fix pass fail when ReduceMin's axis set with one int number

Merge pull request !4064 from huanghui/reduce-min-fission-pass
......@@ -32,7 +32,7 @@ CNodePtr CreateReduceMin(const FuncGraphPtr &graph, const AnfNodePtr &input, con
return reduce_min;
}
bool NeedOptmize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) {
bool NeedOptimize(const TypeId &dtype, const std::vector<size_t> &shape, const std::vector<int> &axis) {
if (dtype != kNumberTypeFloat32) {
MS_LOG(INFO) << "ReduceMin's input Dtype is not float32, no need optimize!";
return false;
......@@ -84,7 +84,7 @@ std::vector<size_t> GetInferShape(const std::vector<size_t> &shape, const std::v
for (size_t item = 0; item < shape.size(); ++item) {
if (axis_first.end() != std::find(axis_first.begin(), axis_first.end(), item)) {
if (keep_dims) {
// If keep_dims is true, curretn dimesion set to 1
// If keep_dims is true, current dimension set to 1
shape_first.push_back(1);
}
} else {
......@@ -110,28 +110,31 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN
CheckCNodeInputSize(cnode, 2);
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0);
if (!AnfAlgo::HasNodeAttr(kAttrAxis, cnode)) {
MS_LOG(INFO) << "ReduceMin has no axis, no need optimize!";
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
if (!prim->HasAttr(kAttrAxis) || !prim->HasAttr(kAttrKeepDims)) {
MS_LOG(INFO) << "ReduceMin has no axis or keep_dims, no need optimize!";
return nullptr;
}
auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis);
if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
MS_LOG(INFO) << "ReduceMin has no keep_dims, no need optimize!";
auto axis_value = prim->GetAttr(kAttrAxis);
MS_EXCEPTION_IF_NULL(axis_value);
if (!axis_value->isa<ValueSequeue>()) {
return nullptr;
}
auto axis = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrAxis);
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrKeepDims);
if (!NeedOptmize(dtype, shape, axis)) {
if (!NeedOptimize(dtype, shape, axis)) {
MS_LOG(INFO) << "No need optimize for this ReduceMin. " << cnode->DebugString();
return nullptr;
}
// Create reduce_min1
CNodePtr reduce_min1 = CreateReduceMin(graph, cnode->input(1), cnode);
std::vector<int> axis_fisrt = CalFirstAxis(shape, axis);
std::vector<size_t> shape_first = GetInferShape(shape, axis_fisrt, keep_dims);
std::vector<int> axis_first = CalFirstAxis(shape, axis);
std::vector<size_t> shape_first = GetInferShape(shape, axis_first, keep_dims);
AnfAlgo::SetOutputInferTypeAndShape({dtype}, {shape_first}, reduce_min1.get());
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_fisrt), reduce_min1);
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis_first), reduce_min1);
// Create reduce_min2
CNodePtr reduce_min2 = CreateReduceMin(graph, reduce_min1, cnode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册