提交 2c7d8e7b 编写于 作者: E emailweixu 提交者: GitHub

Merge pull request #2369 from emailweixu/print_layer

Correctly handle print_layer in V2 API
...@@ -111,6 +111,7 @@ __all__ = [ ...@@ -111,6 +111,7 @@ __all__ = [
'block_expand_layer', 'block_expand_layer',
'maxout_layer', 'maxout_layer',
'out_prod_layer', 'out_prod_layer',
'printer_layer',
'print_layer', 'print_layer',
'priorbox_layer', 'priorbox_layer',
'cross_channel_norm_layer', 'cross_channel_norm_layer',
...@@ -969,7 +970,7 @@ def fc_layer(input, ...@@ -969,7 +970,7 @@ def fc_layer(input,
@wrap_name_default("print") @wrap_name_default("print")
def print_layer(input, name=None): def printer_layer(input, name=None):
""" """
Print the output value of input layers. This layer is useful for debugging. Print the output value of input layers. This layer is useful for debugging.
...@@ -991,6 +992,13 @@ def print_layer(input, name=None): ...@@ -991,6 +992,13 @@ def print_layer(input, name=None):
inputs=[l.name for l in input], ) inputs=[l.name for l in input], )
# this layer don't return anything, can not be input of other layer. # this layer don't return anything, can not be input of other layer.
# Keep print_layer for compatibility with V1 API.
# 'print_layer' does not work for V2 API because it will be changed to
# 'print' for V2 API. But 'print' is a reserved key word in python.
print_layer = printer_layer
@wrap_name_default("priorbox") @wrap_name_default("priorbox")
def priorbox_layer(input, def priorbox_layer(input,
......
...@@ -149,6 +149,20 @@ def __get_used_layers__(output_layers, extra_layers=None): ...@@ -149,6 +149,20 @@ def __get_used_layers__(output_layers, extra_layers=None):
for layer in output_layers: for layer in output_layers:
dfs_travel(layer.full_name) dfs_travel(layer.full_name)
# print layer needs to be specially handled because no other
# layer depends on it. It is used to print the result of some
# layers when running the model for debug purpose. So we explicitly
# add a print layer to the topolty if its input is in the toplogy.
for layer in cp.g_config.model_config.layers:
if layer.type == 'print':
used = True
for inp in layer.inputs:
if inp.input_layer_name not in layer_names:
used = False
break
if used:
layer_names.add(layer.name)
return layer_names return layer_names
......
...@@ -164,6 +164,7 @@ class OtherLayerTest(unittest.TestCase): ...@@ -164,6 +164,7 @@ class OtherLayerTest(unittest.TestCase):
maxid = layer.max_id(input=inference) maxid = layer.max_id(input=inference)
sampling_id = layer.sampling_id(input=inference) sampling_id = layer.sampling_id(input=inference)
eos = layer.eos(input=maxid, eos_id=5) eos = layer.eos(input=maxid, eos_id=5)
layer.printer(maxid)
print layer.parse_network([maxid, sampling_id, eos]) print layer.parse_network([maxid, sampling_id, eos])
def test_slicing_joining_layer(self): def test_slicing_joining_layer(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册