提交 03d181cf 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #2522 from emailweixu/print_layer

Allow printer layer to print user-provided message
...@@ -22,10 +22,33 @@ public: ...@@ -22,10 +22,33 @@ public:
void forward(PassType passType) override { void forward(PassType passType) override {
Layer::forward(passType); Layer::forward(passType);
std::vector<std::string> vals;
for (size_t i = 0; i != inputLayers_.size(); ++i) { for (size_t i = 0; i != inputLayers_.size(); ++i) {
getInput(i).printValueString(LOG(INFO), std::ostringstream s;
"layer=" + inputLayers_[i]->getName() + " "); getInput(i).printValueString(s, "");
vals.push_back(s.str());
} }
size_t pos = 0;
int i = 0;
std::ostringstream s;
const std::string& format = config_.user_arg();
while (true) {
size_t pos1 = format.find("%s", pos);
if (pos1 == std::string::npos) break;
if (i >= vals.size()) {
break;
}
s << format.substr(pos, pos1 - pos) << vals[i];
pos = pos1 + 2;
++i;
}
if (i != inputLayers_.size()) {
LOG(ERROR) << "Number of value in the format (" << format
<< ") is not same as the number of inputs ("
<< inputLayers_.size() << ") at " << getName();
}
s << format.substr(pos);
LOG(INFO) << s.str();
} }
void backward(const UpdateCallback& callback) override {} void backward(const UpdateCallback& callback) override {}
......
...@@ -1628,8 +1628,14 @@ class SelectiveFCLayer(LayerBase): ...@@ -1628,8 +1628,14 @@ class SelectiveFCLayer(LayerBase):
@config_layer('print') @config_layer('print')
class PrintLayer(LayerBase): class PrintLayer(LayerBase):
def __init__(self, name, inputs): def __init__(self, name, inputs, format=None):
super(PrintLayer, self).__init__(name, 'print', 0, inputs) super(PrintLayer, self).__init__(name, 'print', 0, inputs)
if format is None:
format = "\n".join([
"layer=" + input.input_layer_name + " %s"
for input in self.inputs
])
self.config.user_arg = format
@config_layer('priorbox') @config_layer('priorbox')
......
...@@ -964,7 +964,7 @@ def fc_layer(input, ...@@ -964,7 +964,7 @@ def fc_layer(input,
@wrap_name_default("print") @wrap_name_default("print")
def printer_layer(input, name=None): def printer_layer(input, format=None, name=None):
""" """
Print the output value of input layers. This layer is useful for debugging. Print the output value of input layers. This layer is useful for debugging.
...@@ -982,6 +982,7 @@ def printer_layer(input, name=None): ...@@ -982,6 +982,7 @@ def printer_layer(input, name=None):
Layer( Layer(
name=name, name=name,
format=format,
type=LayerType.PRINT_LAYER, type=LayerType.PRINT_LAYER,
inputs=[l.name for l in input], ) inputs=[l.name for l in input], )
# this layer don't return anything, can not be input of other layer. # this layer don't return anything, can not be input of other layer.
......
...@@ -12,6 +12,7 @@ layers { ...@@ -12,6 +12,7 @@ layers {
inputs { inputs {
input_layer_name: "input" input_layer_name: "input"
} }
user_arg: "layer=input %s"
} }
input_layer_names: "input" input_layer_names: "input"
output_layer_names: "input" output_layer_names: "input"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册