From 97e29411eb5a9c22ad9d2eee6d88c989eac206dc Mon Sep 17 00:00:00 2001 From: huangxu96 <46740794+huangxu96@users.noreply.github.com> Date: Mon, 21 Dec 2020 12:05:19 +0800 Subject: [PATCH] fix a bug in multi_precision_fp16 unittest. (#29756) --- .../fluid/contrib/tests/test_multi_precision_fp16_train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py index 83b920642b..812b817b92 100644 --- a/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py +++ b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py @@ -155,9 +155,10 @@ def train(use_pure_fp16=True, use_nesterov=False): loss, = exe.run(compiled_program, feed=feeder.feed(data), fetch_list=[sum_cost]) + loss_v = loss[0] if isinstance(loss, np.ndarray) else loss print('PassID {0:1}, Train Batch ID {1:04}, train loss {2:2.4}'. - format(pass_id, batch_id + 1, float(loss))) - train_loss_list.append(float(loss)) + format(pass_id, batch_id + 1, float(loss_v))) + train_loss_list.append(float(loss_v)) if batch_id >= 4: # For speeding up CI test_loss_list = [] -- GitLab