提交 e4ace4ca 编写于 作者: K Ken Franko 提交者: TensorFlower Gardener

Add test for custom TF Module based Keras model.

PiperOrigin-RevId: 295000925
Change-Id: I254678450994c63a422c0953e6c4a6a484afe4e0
上级 22f80136
......@@ -30,10 +30,27 @@ from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.module import module
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
class CustomModel(module.Module):
def __init__(self, name=None):
super(CustomModel, self).__init__(name=name)
with self.name_scope:
self._layers = [
keras.layers.Dense(4, name="dense"),
]
@module.Module.with_name_scope
def __call__(self, x):
for layer in self._layers:
x = layer(x)
return x
class KerasModelsTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
......@@ -325,6 +342,35 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
for model_v, model2_v in zip(model.variables, model2.variables):
self.assertAllClose(model_v.numpy(), model2_v.numpy())
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies,
mode=["eager"]
))
def test_customized_tf_module_experimental_run(self, distribution):
dataset = self._get_dataset()
input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
with distribution.scope():
model = CustomModel()
@def_function.function
def train_step(iterator):
def step_fn(inputs):
images, targets = inputs
with backprop.GradientTape() as tape:
outputs = model(images)
loss = math_ops.reduce_sum(outputs - targets)
grads = tape.gradient(loss, model.variables)
return grads
outputs = distribution.experimental_run_v2(
step_fn, args=(next(iterator),))
return nest.map_structure(distribution.experimental_local_results,
outputs)
train_step(input_iterator)
def _get_dataset(self):
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册