未验证 提交 db41e39e 编写于 作者: W Weilong Wu 提交者: GitHub

Support test_layers(group_norm,while_loop) with eager mode (#40816)

上级 fdafbc7b
...@@ -89,6 +89,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -89,6 +89,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"Input", "Label", "Weight", "Bias", "SampleWeight", "CustomDistProbs", {"Input", "Label", "Weight", "Bias", "SampleWeight", "CustomDistProbs",
"CustomDistAlias", "CustomDistAliasProbs"}}, "CustomDistAlias", "CustomDistAliasProbs"}},
{"check_finite_and_unscale", {"X", "Scale", "FloatStatus"}}, {"check_finite_and_unscale", {"X", "Scale", "FloatStatus"}},
{"group_norm", {"X", "Scale", "Bias"}},
}; };
// NOTE(zhiqiu): Like op_ins_map. // NOTE(zhiqiu): Like op_ins_map.
......
...@@ -2986,6 +2986,12 @@ class GroupNorm(layers.Layer): ...@@ -2986,6 +2986,12 @@ class GroupNorm(layers.Layer):
is_bias=True) is_bias=True)
def forward(self, input): def forward(self, input):
if in_dygraph_mode():
attrs = ('epsilon', self._epsilon, 'groups', self._groups)
out, _, _ = _C_ops.group_norm(input, self.weight, self.bias, *attrs)
return dygraph_utils._append_activation_in_dygraph(out, self._act)
inputs = {'X': input} inputs = {'X': input}
if self.bias is not None: if self.bias is not None:
inputs['Bias'] = self.bias inputs['Bias'] = self.bias
......
...@@ -1819,7 +1819,7 @@ class TestLayer(LayerTest): ...@@ -1819,7 +1819,7 @@ class TestLayer(LayerTest):
self.assertTrue(np.allclose(static_ret, static_ret2)) self.assertTrue(np.allclose(static_ret, static_ret2))
def test_group_norm(self): def func_group_norm(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
else: else:
...@@ -1873,7 +1873,6 @@ class TestLayer(LayerTest): ...@@ -1873,7 +1873,6 @@ class TestLayer(LayerTest):
with_lod=True)[0] with_lod=True)[0]
with self.dynamic_graph(): with self.dynamic_graph():
# TODO(wuweilong): Add with _test_eager_guard():
groupNorm = nn.GroupNorm( groupNorm = nn.GroupNorm(
channels=shape[1], channels=shape[1],
groups=2, groups=2,
...@@ -1886,6 +1885,11 @@ class TestLayer(LayerTest): ...@@ -1886,6 +1885,11 @@ class TestLayer(LayerTest):
self.assertTrue(np.allclose(static_ret, dy_rlt_value)) self.assertTrue(np.allclose(static_ret, dy_rlt_value))
self.assertTrue(np.allclose(static_ret, static_ret2)) self.assertTrue(np.allclose(static_ret, static_ret2))
def test_group_norm(self):
with _test_eager_guard():
self.func_group_norm()
self.func_group_norm()
def test_instance_norm(self): def test_instance_norm(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
...@@ -2348,7 +2352,7 @@ class TestLayer(LayerTest): ...@@ -2348,7 +2352,7 @@ class TestLayer(LayerTest):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
layers.eye(num_rows=3, batch_shape=[-1]) layers.eye(num_rows=3, batch_shape=[-1])
def test_while_loop(self): def func_while_loop(self):
with self.static_graph(): with self.static_graph():
i = layers.fill_constant(shape=[1], dtype='int64', value=0) i = layers.fill_constant(shape=[1], dtype='int64', value=0)
ten = layers.fill_constant(shape=[1], dtype='int64', value=10) ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
...@@ -2363,7 +2367,6 @@ class TestLayer(LayerTest): ...@@ -2363,7 +2367,6 @@ class TestLayer(LayerTest):
static_ret = self.get_static_graph_result(feed={}, fetch_list=out) static_ret = self.get_static_graph_result(feed={}, fetch_list=out)
with self.dynamic_graph(): with self.dynamic_graph():
# TODO(wuweilong): Add with _test_eager_guard():
i = layers.fill_constant(shape=[1], dtype='int64', value=0) i = layers.fill_constant(shape=[1], dtype='int64', value=0)
ten = layers.fill_constant(shape=[1], dtype='int64', value=10) ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
...@@ -2384,6 +2387,11 @@ class TestLayer(LayerTest): ...@@ -2384,6 +2387,11 @@ class TestLayer(LayerTest):
self.assertTrue(np.array_equal(static_ret[0], dy_ret[0].numpy())) self.assertTrue(np.array_equal(static_ret[0], dy_ret[0].numpy()))
def test_while_loop(self):
with _test_eager_guard():
self.func_while_loop()
self.func_while_loop()
def test_compare(self): def test_compare(self):
value_a = np.arange(3) value_a = np.arange(3)
value_b = np.arange(3) value_b = np.arange(3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册