提交 9080613d 编写于 作者: F François Chollet

Fix deprecation warnings related to TF v1

上级 cf9595ae
......@@ -109,6 +109,12 @@ def eager(func):
return eager_fn_wrapper
def _has_compat_v1():
if hasattr(tf, 'compat') and hasattr(tf.compat, 'v1'):
return True
return False
def get_uid(prefix=''):
"""Provides a unique UID given a string prefix.
......@@ -2270,7 +2276,11 @@ def _fused_normalize_batch_in_training(x, gamma, beta, reduction_axes,
if beta.dtype != tf.float32:
beta = tf.cast(beta, tf.float32)
return tf.nn.fused_batch_norm(
if _has_compat_v1:
fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
else:
fused_batch_norm = tf.nn.fused_batch_norm
return fused_batch_norm(
x,
gamma,
beta,
......@@ -2373,7 +2383,12 @@ def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
if var.dtype != tf.float32:
var = tf.cast(var, tf.float32)
y, _, _ = tf.nn.fused_batch_norm(
if _has_compat_v1:
fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
else:
fused_batch_norm = tf.nn.fused_batch_norm
y, _, _ = fused_batch_norm(
x,
gamma,
beta,
......
......@@ -853,7 +853,11 @@ def get(identifier):
if K.backend() == 'tensorflow':
# Wrap TF optimizer instances
if tf.__version__.startswith('1.'):
if isinstance(identifier, tf.train.Optimizer):
try:
TFOpt = tf.compat.v1.train.Optimizer
except AttributeError:
TFOpt = tf.train.Optimizer
if isinstance(identifier, TFOpt):
return TFOptimizer(identifier)
elif isinstance(identifier, tf.keras.optimizers.Optimizer):
return TFOptimizer(identifier)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册