提交 ddee474e 编写于 作者: H Hongkun Yu 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 272043067
上级 77710731
......@@ -21,7 +21,7 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling.activations import gelu
from official.modeling import activations
@keras_parameterized.run_all_keras_modes
......@@ -30,7 +30,7 @@ class GeluTest(keras_parameterized.TestCase):
def test_gelu(self):
expected_data = [[0.14967535, 0., -0.10032465],
[-0.15880796, -0.04540223, 2.9963627]]
gelu_data = gelu.gelu([[.25, 0, -.25], [-1, -2, 3]])
gelu_data = activations.gelu([[.25, 0, -.25], [-1, -2, 3]])
self.assertAllClose(expected_data, gelu_data)
......
......@@ -21,14 +21,14 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling.activations import swish
from official.modeling import activations
@keras_parameterized.run_all_keras_modes
class CustomizedSwishTest(keras_parameterized.TestCase):
def test_gelu(self):
customized_swish_data = swish.swish([[.25, 0, -.25], [-1, -2, 3]])
customized_swish_data = activations.swish([[.25, 0, -.25], [-1, -2, 3]])
swish_data = tf.nn.swish([[.25, 0, -.25], [-1, -2, 3]])
self.assertAllClose(customized_swish_data, swish_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册