提交 cf1628a3 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2741 fix BatchToSpaceND

Merge pull request !2741 from jiangjinsheng/issue_fix4
......@@ -25,8 +25,8 @@ batch_to_space_nd_op_info = TBERegOp("BatchToSpaceND") \
.partial_flag(True) \
.attr("block_shape", "required", "listInt", "all") \
.attr("crops", "required", "listListInt", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.input(0, "x", False, "required", "all", reshape_type="NH") \
.output(0, "y", False, "required", "all", reshape_type="NH") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -27,6 +27,8 @@ conv2d_backprop_filter_op_info = TBERegOp("Conv2DBackpropFilter") \
.attr("stride", "required", "listInt", "all") \
.attr("pad_list", "required", "listInt", "all") \
.attr("dilation", "required", "listInt", "all") \
.attr("groups", "optional", "int", "all") \
.attr("data_format", "optional", "str", "all") \
.input(0, "out_backprop", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
......
......@@ -25,8 +25,8 @@ space_to_batch_nd_op_info = TBERegOp("SpaceToBatchND") \
.partial_flag(True) \
.attr("block_shape", "required", "listInt", "all") \
.attr("paddings", "required", "listListInt", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.input(0, "x", False, "required", "all", reshape_type="NH") \
.output(0, "y", False, "required", "all", reshape_type="NH") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
......
......@@ -237,6 +237,7 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
self.add_prim_attr('stride', self.stride)
self.dilation = dilation
self.group = group
self.add_prim_attr('groups', group)
self.add_prim_attr('data_format', "NCHW")
def __infer__(self, doutput, x, w_size):
......
......@@ -2636,16 +2636,20 @@ class SpaceToBatchND(PrimitiveWithInfer):
def infer_shape(self, x_shape):
    """Infer the output shape of SpaceToBatchND from the input shape.

    Each spatial dimension is padded per ``self.paddings`` and then divided
    by the matching ``self.block_shape`` entry; the batch dimension is
    multiplied by the product of all block_shape entries.

    Args:
        x_shape (list[int]): shape of the input tensor.

    Returns:
        list[int]: the inferred output shape.

    Raises:
        ValueError: if a padded spatial dimension is not divisible by its
            block_shape entry.
    """
    x_rank = len(x_shape)
    # NOTE(review): this check enforces rank == 4, which makes the
    # x_rank < 4 branch below unreachable as written — presumably the
    # rank constraint was relaxed elsewhere in this commit; confirm.
    validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
    out_shape = copy.deepcopy(x_shape)
    block_shape_prod = 1
    # For rank-4 (NCHW) input the spatial dims start at index 2; for a
    # lower-rank input one leading dim is absent, so they start at 1.
    offset = 2
    if x_rank < 4:
        offset = 1
    for i in range(len(self.block_shape)):
        padded = out_shape[i + offset] + self.paddings[i][0] + \
                 self.paddings[i][1]
        if padded % self.block_shape[i] != 0:
            raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
                             f'block_shape[{i}] {self.block_shape[i]}')
        out_shape[i + offset] = padded // self.block_shape[i]
        block_shape_prod = block_shape_prod * self.block_shape[i]
    # All block factors move from the spatial dims into the batch dim.
    out_shape[0] *= block_shape_prod
    return out_shape
......@@ -2716,15 +2720,19 @@ class BatchToSpaceND(PrimitiveWithInfer):
def infer_shape(self, x_shape):
x_rank = len(x_shape)
validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
out_shape = copy.deepcopy(x_shape)
block_shape_prod = 1
for i in range(x_rank - 2):
offset = 2
if x_rank < 4:
offset = 1
for i in range(len(self.block_shape)):
block_shape_prod = block_shape_prod * self.block_shape[i]
x_block_prod = out_shape[i + 2] * self.block_shape[i]
x_block_prod = out_shape[i + offset] * self.block_shape[i]
crops_sum = self.crops[i][0] + self.crops[i][1]
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
out_shape[i + 2] = x_block_prod - crops_sum
out_shape[i + offset] = x_block_prod - crops_sum
if out_shape[0] % block_shape_prod != 0:
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册