未验证 提交 12aa5e80 编写于 作者: A andyjpaddle 提交者: GitHub

Merge pull request #6152 from andyjpaddle/dygraph

fix sar training on windows
...@@ -99,8 +99,8 @@ class SAREncoder(nn.Layer): ...@@ -99,8 +99,8 @@ class SAREncoder(nn.Layer):
if valid_ratios is not None: if valid_ratios is not None:
valid_hf = [] valid_hf = []
T = holistic_feat.shape[1] T = holistic_feat.shape[1]
for i, valid_ratio in enumerate(valid_ratios): for i in range(len(valid_ratios)):
valid_step = min(T, math.ceil(T * valid_ratio)) - 1 valid_step = min(T, math.ceil(T * valid_ratios[i])) - 1
valid_hf.append(holistic_feat[i, valid_step, :]) valid_hf.append(holistic_feat[i, valid_step, :])
valid_hf = paddle.stack(valid_hf, axis=0) valid_hf = paddle.stack(valid_hf, axis=0)
else: else:
...@@ -252,8 +252,8 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -252,8 +252,8 @@ class ParallelSARDecoder(BaseDecoder):
if valid_ratios is not None: if valid_ratios is not None:
# cal mask of attention weight # cal mask of attention weight
for i, valid_ratio in enumerate(valid_ratios): for i in range(len(valid_ratios)):
valid_width = min(w, math.ceil(w * valid_ratio)) valid_width = min(w, math.ceil(w * valid_ratios[i]))
if valid_width < w: if valid_width < w:
attn_weight[i, :, :, valid_width:, :] = float('-inf') attn_weight[i, :, :, valid_width:, :] = float('-inf')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册