diff --git a/dygraph/train.py b/dygraph/train.py index abc55179be14e844bb4ba87e65181f26e80f9a74..c8dce1b0f8dbbb829910d9afef10137ca3b6e7a1 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -171,7 +171,7 @@ def train(model, loss = model(images, labels, mode='train') loss.backward() optimizer.minimize(loss) - model_parallel.clear_gradients() + model.clear_gradients() logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format( epoch + 1, num_epochs, step + 1, num_steps_each_epoch, loss.numpy()))