提交 5962c6ef 编写于 作者: Z zhaozhenlong

Fix the BroadcastTo bprop error raised when the input and output shapes are identical

Make the error message for unsupported broadcast shapes explicit
上级 9bc2ffde
......@@ -673,6 +673,10 @@ def get_bprop_broadcast_to(self):
def bprop(x, out, dout):
x_shape = shape_op(x)
dout_shape = shape_op(dout)
if x_shape == dout_shape:
return (dout,)
_, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
reduced_grad = reduce_keep_dim(dout, reduction_axes)
dx = reshape(reduced_grad, x_shape)
......
......@@ -2719,6 +2719,8 @@ class BatchToSpaceND(PrimitiveWithInfer):
class BroadcastTo(PrimitiveWithInfer):
"""
Broadcasts input tensor to a given shape.
Input shape can be broadcast to target shape if for each dimension pair they are either equal or input is one.
When input shape is broadcast to target shape, it starts with the trailing dimensions.
Args:
shape (tuple): The target shape to broadcast.
......@@ -2741,11 +2743,20 @@ class BroadcastTo(PrimitiveWithInfer):
def __init__(self, shape):
"""Init BroadcastTo"""
validator.check_value_type("shape", shape, (tuple), self.name)
validator.check("shape length", len(shape), "", 0, Rel.GT, self.name)
for i in shape:
validator.check_integer("shape element", i, 0, Rel.GT, self.name)
self.shape = shape
def infer_shape(self, x_shape):
validator.check("input_x shape length", len(x_shape), "target shape", len(self.shape), Rel.LE, self.name)
reversed_x_shape = tuple(reversed(x_shape))
reversed_target = tuple(reversed(self.shape))
for i, v in enumerate(reversed_x_shape):
if v not in (reversed_target[i], 1):
raise ValueError(f"Not supported shapes for broadcast, "
f"x_shape: {tuple(x_shape)}, target shape {self.shape}.")
return self.shape
def infer_dtype(self, x_dtype):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册