未验证 提交 b7b5de7f 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #7665 from JiayiFeng/dev_update_auto-registry

make auto-registry layers supporting specified output
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
import re import re
import cStringIO import cStringIO
import warnings import warnings
...@@ -167,13 +167,18 @@ def register_layer(op_type): ...@@ -167,13 +167,18 @@ def register_layer(op_type):
inputs[ipt.name] = val inputs[ipt.name] = val
outputs = dict() outputs = dict()
out = helper.create_tmp_variable(dtype=dtype) out = kwargs.pop(_convert_(o_name), [])
outputs[o_name] = [out] if out:
out_var = out[0] if (isinstance(out, list) or
isinstance(out, tuple)) else out
else:
out_var = helper.create_tmp_variable(dtype=dtype)
outputs[o_name] = [out_var]
for name in intermediate_output_names: for name in intermediate_output_names:
outputs[name] = [helper.create_tmp_variable(dtype=dtype)] outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
helper.append_op( helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs) type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
return helper.append_activation(out) return helper.append_activation(out_var)
func.__name__ = op_type func.__name__ = op_type
func.__doc__ = _generate_doc_string_(op_proto) func.__doc__ = _generate_doc_string_(op_proto)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册