Commit 44bf7c93 authored by mindspore-ci-bot, committed by Gitee

!1414 fix issue: use Reshape as the FlattenGrad implementation

Merge pull request !1414 from zhaozhenlong/fix-issues-reshape-replace-flattern-grad
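The change rests on a simple identity: Flatten only reinterprets a tensor's layout, so its gradient is just the incoming gradient reshaped back to the input shape. A minimal NumPy sketch of that identity (illustrative only, not MindSpore code):

import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)   # Flatten input
y = x.reshape(x.shape[0], -1)                          # Flatten: (2, 3, 4) -> (2, 12)
dy = np.ones_like(y)                                   # gradient w.r.t. the output
dx = dy.reshape(x.shape)                               # FlattenGrad: reshape back to the input shape
assert dx.shape == x.shape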
......
@@ -385,7 +385,8 @@ bool IsNopNode(const AnfNodePtr &node) {
return false;
}
  static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
-                                                     prim::kPrimSqueeze->name(), prim::kPrimFlatten->name()};
+                                                     prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
+                                                     kFlattenGradOpName};
if (node == nullptr || !node->isa<CNode>()) {
return false;
}
......
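Because FlattenGrad now maps onto the reshape kernel, it joins the set of "nop nodes" whose kernels never move data and can therefore be skipped at execution time. A hedged Python analogue of the C++ check above (the set and helper names are illustrative, not MindSpore internals):

# Ops whose kernels do not move data; the runtime may forward the input
# buffer instead of launching them. Mirrors the C++ set above.
NOP_NODES = {"Reshape", "ExpandDims", "Squeeze", "Flatten", "FlattenGrad"}

def is_nop_node(op_name):
    """Return True if the op's kernel never moves data and may be skipped."""
    return op_name in NOP_NODES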
......
@@ -197,3 +197,4 @@ from .cum_sum import _cum_sum_tbe
from .apply_rms_prop import _apply_rms_prop_tbe
from .cumprod import _cumprop_tbe
from .reduce_prod import _reduce_prod_tbe
+from .flatten_grad import _flatten_grad_tbe
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Reshape op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
flatten_grad_op_info = TBERegOp("FlattenGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("reshape.so") \
.compute_cost(10) \
.kernel_name("reshape") \
.partial_flag(True) \
.attr("shape", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(flatten_grad_op_info)
def _flatten_grad_tbe():
"""Reshape TBE register"""
return
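The registration above deliberately points FlattenGrad at the existing reshape kernel (binfile_name("reshape.so"), kernel_name("reshape")), so no new TBE kernel is compiled. A hedged sketch of the decorator mechanics (helper and registry names are illustrative, not MindSpore internals); this is also why the __init__.py hunk above only needs to import the module:

# Illustrative re-creation of the @op_info_register pattern: the decorator
# records the op info as a side effect of import and returns the function
# unchanged, which is why _flatten_grad_tbe has an empty body.
_OP_INFO_REGISTRY = {}

def op_info_register_sketch(op_info):
    def decorator(func):
        _OP_INFO_REGISTRY[func.__name__] = op_info   # registered at import time
        return func
    return decorator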
......
@@ -121,6 +121,16 @@ class NetForFlatten0D(nn.Cell):
        return self.flatten(x)

+class NetForFlattenComposed(nn.Cell):
+    # compose Flatten with other ops so its gradient (FlattenGrad) is exercised
+    def __init__(self):
+        super(NetForFlattenComposed, self).__init__()
+        self.flatten = P.Flatten()
+
+    def construct(self, x, y):
+        return self.flatten(x + x) + y
+
class ArgmaxNet(nn.Cell):
    def __init__(self):
        super(ArgmaxNet, self).__init__()
......
@@ -695,7 +705,7 @@ test_case_nn_ops = [
    ('Flatten', {
        'block': P.Flatten(),
        'desc_inputs': [[128, 32, 32, 64]],
-        'desc_bprop': [[128 * 32 * 8 * 16]]}),
+        'desc_bprop': [[128, 65536]]}),
    ('LogSoftmax', {
        'block': P.LogSoftmax(),
        'desc_inputs': [[64, 2]],
......
@@ -897,6 +907,11 @@ test_case_nn_ops = [
        'desc_inputs': [Tensor(np.ones([8]).astype(np.int32)), Tensor(np.ones([8, 3]).astype(np.int32))],
        'desc_bprop': [Tensor(np.ones([8, 3]).astype(np.int32))],
        'skip': ['backward']}),
+    ('Flatten_3', {
+        'block': NetForFlattenComposed(),
+        'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))],
+        'desc_bprop': [Tensor(np.ones([2, 12]).astype(np.int32))],
+        'skip': []}),
    ('ArgmaxNet', {
        'block': ArgmaxNet(),
        'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))],
......
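For reference, what the new Flatten_3 case should produce: with out = flatten(x + x) + y, the gradient w.r.t. x is the incoming gradient reshaped to x's shape and doubled, while the gradient w.r.t. y passes through unchanged. A NumPy cross-check (illustrative, independent of MindSpore):

import numpy as np

x = np.ones([2, 3, 4], dtype=np.int32)
dout = np.ones([2, 12], dtype=np.int32)   # matches 'desc_bprop' above
dx = 2 * dout.reshape(x.shape)            # FlattenGrad reduces to a reshape; factor 2 from x + x
dy = dout                                 # the additive y branch passes the gradient through
assert dx.shape == (2, 3, 4) and (dx == 2).all()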