Commit 57ce3e5d authored by: mindspore-ci-bot; committed by: Gitee

!3711 fix TopK multi-dimension grad func

Merge pull request !3711 from fangzehua/topkgrad
@@ -14,6 +14,7 @@
# ============================================================================
"""Define the grad rules of neural network related operations."""
import math
import numpy as np
from mindspore.ops import _selected_grad_ops as SG
from mindspore.ops.primitive import constexpr
@@ -628,19 +629,62 @@ def get_bprop_onehot(self):
    return bprop

@constexpr
def _range_op(start, limit, delta, dtype):
    """helper function for Grad TopK"""
    range_op = inner.Range(float(start), float(limit), float(delta))
    length_input = math.ceil((limit - start) / delta)
    input_tensor = Tensor(list(range(length_input)), dtype)
    range_out = range_op(input_tensor)
    return range_out
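
# For example, _range_op(0, 12, 4, mstype.int32) is expected to produce the
# 1-D tensor [0, 4, 8] (illustrative values; assumes mstype is imported
# elsewhere in this module).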

@constexpr
def _get_1d_shape(in_shape):
    """helper function for Grad TopK"""
    out_shape = 1
    for i in in_shape:
        out_shape *= i
    return (out_shape,)
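
# For example, _get_1d_shape((2, 3, 4)) returns (24,), i.e. the shape of the
# input flattened to one dimension.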

@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Grad definition for `TopK` operation."""
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    dtype = P.DType()

    def bprop(input_x, k, out, dout):
        # input_x has shape (n_1, n_2, ..., n_p); in_lastdim = n_p
        in_shape = shape_op(input_x)
        in_lastdim = in_shape[-1]

        # indices has shape (n_1, ..., n_(p-1), k); ind_lastdim = k
        indices = out[1]
        ind_shape = shape_op(indices)
        ind_lastdim = ind_shape[-1]

        # reshape to (outerdim, k), where outerdim = n_1*n_2*...*n_(p-1)
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outerdim = shape_op(ind_2d)[0]

        # per-row offsets [0, in_lastdim, 2*in_lastdim, ..., (outerdim-1)*in_lastdim]
        indices_dtype = dtype(indices)
        range_flatten_index = _range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)

        # expand_dims to (outerdim, 1), then broadcast against ind_2d of shape (outerdim, k)
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        in_shape_1d = _get_1d_shape(in_shape)

        out_grad = reshape_op(
            scatter(
                expand_dims(ind, -1),
                reshape_op(dout[0], (-1,)),
                in_shape_1d),
            in_shape)
        return out_grad, zeros_like(k)

    return bprop
...
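
To see what the new bprop computes, here is a minimal NumPy sketch of the same
flattened-index scatter trick, assuming a gradient of ones flowing into the
TopK values (the function name and the check below are illustrative, not part
of the commit):

import numpy as np

def topk_grad_numpy(x_shape, indices, dout):
    """Scatter dout back to the positions that TopK selected in the input."""
    in_lastdim = x_shape[-1]
    k = indices.shape[-1]
    ind_2d = indices.reshape(-1, k)          # (outerdim, k)
    outerdim = ind_2d.shape[0]
    # Per-row offsets [0, in_lastdim, 2*in_lastdim, ...] turn row-local
    # indices into indices of the flattened input.
    offsets = np.arange(0, outerdim * in_lastdim, in_lastdim).reshape(-1, 1)
    flat_ind = (ind_2d + offsets).reshape(-1)
    out_grad = np.zeros(int(np.prod(x_shape)), dtype=dout.dtype)
    out_grad[flat_ind] = dout.reshape(-1)    # TopK indices are distinct per row
    return out_grad.reshape(x_shape)

x = np.array([[1., 5., 3.], [4., 2., 6.]])
idx = np.argsort(-x, axis=-1)[:, :2]        # indices of the top-2 values per row
dout = np.ones(idx.shape, dtype=np.float32)
print(topk_grad_numpy(x.shape, idx, dout))
# [[0. 1. 1.]
#  [1. 0. 1.]]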