提交 72d2fc74 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4050 fix unsortedgrad clipbynorm boundingboxdecode

Merge pull request !4050 from fangzehua/unsortedgrad
......@@ -250,6 +250,10 @@ def _is_equal_one(x):
return False
return bool(x.asnumpy().mean() == 1.0)
@constexpr
def _dtype_check(x_dtype):
if x_dtype not in [mstype.float32, mstype.float16]:
raise TypeError("The input type must be float32 or float16.")
class ClipByNorm(Cell):
r"""
......@@ -264,12 +268,11 @@ class ClipByNorm(Cell):
where :math:`L_2(X)` is the :math:`L_2`-norm of :math:`X`.
Inputs:
- **input** (Tensor) - Tensor of shape N-D.
- **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)` and of
the same type as the input Tensor.
- **input** (Tensor) - Tensor of shape N-D. The type should be float32 or float16.
- **clip_norm** (Tensor) - A scalar Tensor of shape :math:`()` or :math:`(1)`.
Outputs:
Tensor, clipped tensor with the same shape as the input.
Tensor, clipped tensor with the same shape as the input, whose type is float32.
Examples:
>>> net = nn.ClipByNorm()
......@@ -300,10 +303,10 @@ class ClipByNorm(Cell):
l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
cond = self.greater_(l2sum, 0)
ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
l2norm = self.select_(cond, self.sqrt(l2sum_safe), l2sum)
_dtype_check(self.dtype(x))
if _is_equal_one(clip_norm):
intermediate = x
else:
......
......@@ -827,13 +827,3 @@ def get_bprop_unique(self):
dx = op(dout, out)
return (dx,)
return bprop
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""
op = G.UnsortedSegmentSumGrad()
def bprop(x, segment_ids, num_segments, out, dout):
dx = op(dout, segment_ids)
return (dx, zeros_like(segment_ids), zeros_like(num_segments))
return bprop
......@@ -502,20 +502,6 @@ class UniqueGrad(Primitive):
raise NotImplementedError
class UnsortedSegmentSumGrad(PrimitiveWithInfer):
"""Gradients of UnsortedSegmentSum operation."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['grads', 'ids'], outputs=['y'])
def infer_shape(self, grads, ids):
return ids + grads[len(ids):]
def infer_dtype(self, grads, ids):
return grads
class BNTrainingReduceGrad(PrimitiveWithInfer):
"""Gradients of FusedBatchNorm operation."""
......
......@@ -93,8 +93,12 @@ class BoundingBoxEncode(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
validator.check_value_type('means', means, [tuple], self.name)
validator.check_value_type('stds', stds, [tuple], self.name)
validator.check_value_type('means', means, [tuple, list], self.name)
validator.check_value_type('stds', stds, [tuple, list], self.name)
for i, value in enumerate(means):
validator.check_value_type("means[%d]" % i, value, [float], self.name)
for i, value in enumerate(stds):
validator.check_value_type("stds[%d]" % i, value, [float], self.name)
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
......@@ -143,8 +147,12 @@ class BoundingBoxDecode(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016):
validator.check_value_type('means', means, [tuple], self.name)
validator.check_value_type('stds', stds, [tuple], self.name)
validator.check_value_type('means', means, [tuple, list], self.name)
validator.check_value_type('stds', stds, [tuple, list], self.name)
for i, value in enumerate(means):
validator.check_value_type("means[%d]" % i, value, [float], self.name)
for i, value in enumerate(stds):
validator.check_value_type("stds[%d]" % i, value, [float], self.name)
validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name)
validator.check_integer("means len", len(means), 4, Rel.EQ, self.name)
validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册