提交 b9d7e4e6 编写于 作者: Z Ziyan

add uniform split in the bprop of concat

上级 4bdd8e16
...@@ -220,19 +220,37 @@ def get_bprop_transpose(self): ...@@ -220,19 +220,37 @@ def get_bprop_transpose(self):
return bprop return bprop
@constexpr
def _concat_grad_uniform(input_shapes, input_nums):
"""Helper function for bprop of Concat"""
is_uniform = True
for i in range(1, input_nums):
if input_shapes[i-1] != input_shapes[i]:
is_uniform = False
break
return is_uniform
@bprop_getters.register(P.Concat) @bprop_getters.register(P.Concat)
def get_bprop_concat(self): def get_bprop_concat(self):
"""Generate bprop for Concat""" """Generate bprop for Concat"""
axis = self.axis axis = self.axis
is_ascend = context.get_context('device_target') == "Ascend"
def bprop(x, out, dout): def bprop(x, out, dout):
dx = () dx = ()
out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x) out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
for i in range(F.tuple_len(x)): input_nums = F.tuple_len(x)
slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i])) input_shapes = ()
dx = dx + (slice_out,) for i in range(input_nums):
input_shapes = input_shapes + (shape_op(x[i]),)
is_uniform = _concat_grad_uniform(input_shapes, input_nums)
if is_uniform and is_ascend:
dx = P.Split(axis, input_nums)(dout)
else:
for i in range(input_nums):
slice_out = P.Slice()(dout, out_offset[i], input_shapes[i])
dx = dx + (slice_out,)
return (dx,) return (dx,)
return bprop return bprop
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册