提交 f7b7b0ab 编写于 作者: F fangzehua

fix reciprocal grad gpu and resizeNearestNeighor

上级 3e7ba14e
......@@ -425,11 +425,22 @@ def get_bprop_rsqrt(self):
@bprop_getters.register(P.Reciprocal)
def get_bprop_reciprocal(self):
"""Grad definition for `Reciprocal` operation."""
reciprocal_grad = G.ReciprocalGrad()
if self.target == "GPU":
neg = P.Neg()
mul = P.Mul()
square = P.Square()
reciprocal = P.Reciprocal()
def bprop(x, out, dout):
g = neg(reciprocal(square(x)))
dx = mul(dout, g)
return (dx,)
else:
reciprocal_grad = G.ReciprocalGrad()
def bprop(x, out, dout):
dx = reciprocal_grad(out, dout)
return (dx,)
def bprop(x, out, dout):
dx = reciprocal_grad(out, dout)
return (dx,)
return bprop
......
......@@ -2375,6 +2375,8 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
return tuple(x)[:-2] + tuple(self.size)
def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.name)
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.name)
return x
......
......@@ -1290,6 +1290,10 @@ class Reciprocal(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init Reciprocal"""
if context.get_context("device_target") == "GPU":
self.target = "GPU"
else:
self.target = "OTHER"
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def infer_shape(self, x):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册