Commit f26e2d64 authored by Ruoxin Sang, committed by TensorFlower Gardener

Add a failing repro test for the "No gradients provided for any variable" issue.

PiperOrigin-RevId: 381088871
Parent: a4f3d7d9
@@ -351,6 +351,42 @@ class KerasModelsTest(tf.test.TestCase, parameterized.TestCase):
    train_steps(input_iterator)

  def test_nested_tf_functions_with_tf_function_passing_to_strategy_run(
      self, distribution):
    self.skipTest("b/190608193")
    inputs = np.random.random((10, 3)).astype(np.float32)
    targets = np.ones((10, 4), dtype=np.float32)
    dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)).repeat()
    dataset = dataset.batch(10)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    def get_model():
      x = keras.layers.Input(shape=(3,), name="input")
      y = keras.layers.Dense(4, name="dense")(x)
      model = keras.Model(x, y)
      return model

    with distribution.scope():
      model = get_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1, momentum=0.01)

    @tf.function
    def compute_loss(images, targets):
      outputs = model(images)
      return keras.losses.mean_squared_error(targets, outputs)

    @tf.function
    def step_fn(inputs):
      images, targets = inputs
      with tf.GradientTape() as tape:
        loss = compute_loss(images, targets)
      grads = tape.gradient(loss, model.variables)
      optimizer.apply_gradients(zip(grads, model.variables))

    inputs = next(input_iterator)
    distribution.run(step_fn, args=(inputs,))

  def test_customized_tf_module_run(self, distribution):
    dataset = _get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
......
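Background on the error named in the commit message: in tf.keras's optimizer_v2, apply_gradients raises "No gradients provided for any variable" when every gradient handed to it is None, i.e. when tape.gradient could not connect the loss to any of the variables. The following is a minimal sketch of that general failure mode, not of bug b/190608193 itself; it uses plain TF 2.x with no tf.distribute strategy and hypothetical names.

    import tensorflow as tf

    # Minimal sketch (assumption: plain TF 2.x, no tf.distribute involved).
    # The loss has no dependency on the trainable variable, so tape.gradient
    # returns None for it, and the optimizer rejects an all-None gradient list
    # with the "No gradients provided for any variable" error.
    v = tf.Variable(1.0, name="v")

    with tf.GradientTape() as tape:
      loss = tf.constant(3.0) * 2.0  # does not depend on `v`

    grads = tape.gradient(loss, [v])  # -> [None]

    optimizer = tf.keras.optimizers.SGD(0.1)
    try:
      optimizer.apply_gradients(zip(grads, [v]))
    except ValueError as e:
      print(e)  # "No gradients provided for any variable: ..."

In the test added above, the gradients presumably come back as None in the same way when compute_loss is itself a tf.function nested inside the tf.function step passed to distribution.run, which is exactly the pattern the repro exercises.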