提交 02914ba0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1581 fix flatten grad error with reshape

Merge pull request !1581 from zhaozhenlong/fix-issue-flatten-grad
...@@ -127,10 +127,9 @@ def get_bprop_squeeze(self): ...@@ -127,10 +127,9 @@ def get_bprop_squeeze(self):
@bprop_getters.register(P.Flatten) @bprop_getters.register(P.Flatten)
def get_bprop_flatten(self): def get_bprop_flatten(self):
"""Generate bprop for Flatten""" """Generate bprop for Flatten"""
flatten_grad = G.FlattenGrad()
def bprop(x, out, dout): def bprop(x, out, dout):
dx = flatten_grad(dout, shape_op(x)) dx = reshape(dout, shape_op(x))
return (dx,) return (dx,)
return bprop return bprop
......
...@@ -695,7 +695,7 @@ test_case_nn_ops = [ ...@@ -695,7 +695,7 @@ test_case_nn_ops = [
('Flatten', { ('Flatten', {
'block': P.Flatten(), 'block': P.Flatten(),
'desc_inputs': [[128, 32, 32, 64]], 'desc_inputs': [[128, 32, 32, 64]],
'desc_bprop': [[128 * 32 * 8 * 16]]}), 'desc_bprop': [[128, 65536]]}),
('LogSoftmax', { ('LogSoftmax', {
'block': P.LogSoftmax(), 'block': P.LogSoftmax(),
'desc_inputs': [[64, 2]], 'desc_inputs': [[64, 2]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册