diff --git a/official/projects/basnet/tasks/basnet.py b/official/projects/basnet/tasks/basnet.py index 5cb71d883cb398760845fcae6020267e797ee8bd..fcb2186166dcfcc85303480070b928acb7e9c2a6 100644 --- a/official/projects/basnet/tasks/basnet.py +++ b/official/projects/basnet/tasks/basnet.py @@ -203,8 +203,7 @@ class BASNetTask(base_task.Task): # For mixed_precision policy, when LossScaleOptimizer is used, loss is # scaled for numerical stability. - if isinstance( - optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer): + if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer): scaled_loss = optimizer.get_scaled_loss(scaled_loss) tvars = model.trainable_variables @@ -212,8 +211,7 @@ class BASNetTask(base_task.Task): # Scales back gradient before apply_gradients when LossScaleOptimizer is # used. - if isinstance( - optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer): + if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer): grads = optimizer.get_unscaled_gradients(grads) # Apply gradient clipping.