提交 02914ba0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1581 fix flatten grad error with reshape

Merge pull request !1581 from zhaozhenlong/fix-issue-flatten-grad
......@@ -127,10 +127,9 @@ def get_bprop_squeeze(self):
@bprop_getters.register(P.Flatten)
def get_bprop_flatten(self):
"""Generate bprop for Flatten"""
flatten_grad = G.FlattenGrad()
def bprop(x, out, dout):
dx = flatten_grad(dout, shape_op(x))
dx = reshape(dout, shape_op(x))
return (dx,)
return bprop
......
......@@ -695,7 +695,7 @@ test_case_nn_ops = [
('Flatten', {
'block': P.Flatten(),
'desc_inputs': [[128, 32, 32, 64]],
'desc_bprop': [[128 * 32 * 8 * 16]]}),
'desc_bprop': [[128, 65536]]}),
('LogSoftmax', {
'block': P.LogSoftmax(),
'desc_inputs': [[64, 2]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册