未验证 提交 97e29411 编写于 作者: H huangxu96 提交者: GitHub

fix a bug in multi_precision_fp16 unittest. (#29756)

上级 2e5b4a21
...@@ -155,9 +155,10 @@ def train(use_pure_fp16=True, use_nesterov=False): ...@@ -155,9 +155,10 @@ def train(use_pure_fp16=True, use_nesterov=False):
loss, = exe.run(compiled_program, loss, = exe.run(compiled_program,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[sum_cost]) fetch_list=[sum_cost])
loss_v = loss[0] if isinstance(loss, np.ndarray) else loss
print('PassID {0:1}, Train Batch ID {1:04}, train loss {2:2.4}'. print('PassID {0:1}, Train Batch ID {1:04}, train loss {2:2.4}'.
format(pass_id, batch_id + 1, float(loss))) format(pass_id, batch_id + 1, float(loss_v)))
train_loss_list.append(float(loss)) train_loss_list.append(float(loss_v))
if batch_id >= 4: # For speeding up CI if batch_id >= 4: # For speeding up CI
test_loss_list = [] test_loss_list = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册