未验证 提交 697facc9 编写于 作者: D dzhwinter 提交者: GitHub

"add registry interface" (#6449)

* "add registry interface"

* "move function to registry"

* "rename with meaningful name"

* "add exposed layers"

* "fixed based on comments"

* "remove unsed comments"
上级 8ad36cdb
import core import contextlib
import proto.framework_pb2 as framework_pb2 import proto.framework_pb2 as framework_pb2
import core
from framework import OpProtoHolder, Variable, Program, Operator from framework import OpProtoHolder, Variable, Program, Operator
from initializer import Constant, Normal, Xavier, Initializer from initializer import Constant, Normal, Xavier, Initializer
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
import re from registry import register_layer
import cStringIO
from param_attr import ParamAttr from param_attr import ParamAttr
import contextlib
__all__ = [ __all__ = [
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
...@@ -14,6 +14,15 @@ __all__ = [ ...@@ -14,6 +14,15 @@ __all__ = [
'batch_norm', 'accuracy', 'split_lod_tensor', 'While' 'batch_norm', 'accuracy', 'split_lod_tensor', 'While'
] ]
_REGISTER_LAYER_FROM_OPS = [
'mean', 'mul', 'elementwise_add', 'elementwise_div', 'dropout', 'reshape',
'sigmoid', 'scale', 'transpose', 'sigmoid_cross_entropy_with_logits'
]
for _OP in set(_REGISTER_LAYER_FROM_OPS):
globals()[_OP] = register_layer(_OP)
__all__.append(_OP)
def fc(input, def fc(input,
size, size,
...@@ -309,174 +318,6 @@ def create_tensor(dtype, name=None, main_program=None, startup_program=None): ...@@ -309,174 +318,6 @@ def create_tensor(dtype, name=None, main_program=None, startup_program=None):
return helper.create_variable(name=helper.name, dtype=dtype) return helper.create_variable(name=helper.name, dtype=dtype)
def _convert_(name):
"""
Formatting.
Args:
name: The name/alias
This function takes in a name and converts it to a standard format of
group1_group2. Where as per the regular expression, group1 can have
alphabets and numbers and group2 has capital alphabets.
"""
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def _generate_doc_string_(op_proto):
"""
Generate docstring by OpProto
Args:
op_proto (framework_pb2.OpProto): a protobuf message typed OpProto
Returns:
str: the document string
"""
def _type_to_str_(tp):
return framework_pb2.AttrType.Name(tp)
if not isinstance(op_proto, framework_pb2.OpProto):
raise TypeError("OpProto should be `framework_pb2.OpProto`")
buf = cStringIO.StringIO()
buf.write(op_proto.comment)
buf.write('\nArgs:\n')
for each_input in op_proto.inputs:
line_begin = ' {0}: '.format(_convert_(each_input.name))
buf.write(line_begin)
buf.write(each_input.comment)
buf.write('\n')
buf.write(' ' * len(line_begin))
buf.write('Duplicable: ')
buf.write(str(each_input.duplicable))
buf.write(' Optional: ')
buf.write(str(each_input.dispensable))
buf.write('\n')
for each_attr in op_proto.attrs:
buf.write(' ')
buf.write(each_attr.name)
buf.write(' (')
buf.write(_type_to_str_(each_attr.type))
buf.write('): ')
buf.write(each_attr.comment)
buf.write('\n')
if len(op_proto.outputs) != 0:
buf.write('\nReturns:\n')
buf.write(' ')
for each_opt in op_proto.outputs:
if not each_opt.intermediate:
break
buf.write(each_opt.comment)
return buf.getvalue()
def _create_op_func_(op_type):
"""
Create an Operator for a Function.
Args:
op_type: The name of the operator to be created
This function takes in the operator type (sigmoid, mean , average etc) and
creates the operator functionality.
"""
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
not_intermediate_outputs = \
filter(lambda output: not output.intermediate, op_proto.outputs)
intermediate_outputs = \
filter(lambda output: output.intermediate, op_proto.outputs)
if len(not_intermediate_outputs) != 1:
raise ValueError("Only one non intermediate output operator can be",
"automatically generated")
if not_intermediate_outputs[0].duplicable:
raise ValueError(
"Only non duplicable op can be automatically generated")
for output in intermediate_outputs:
if output.duplicable:
raise ValueError("The op can be automatically generated only when ",
"all intermediate ops are not duplicable")
o_name = not_intermediate_outputs[0].name
intermediate_output_names = [output.name for output in intermediate_outputs]
def infer_and_check_dtype(op_proto, **kwargs):
"""
This function performs the sanity check for dtype and
instance type.
"""
dtype = None
for ipt in op_proto.inputs:
name = _convert_(ipt.name)
val = kwargs.pop(name, [])
if not isinstance(val, list) and not isinstance(val, tuple):
val = [val]
for each in val:
if not isinstance(each, Variable):
raise ValueError("input of {0} must be variable".format(
op_type))
if dtype is None:
dtype = each.dtype
elif dtype != each.dtype:
raise ValueError(
"operator {0} must input same dtype. {1} vs {2}".format(
op_type, dtype, each.dtype))
return dtype
def func(**kwargs):
helper = LayerHelper(op_type, **kwargs)
dtype = infer_and_check_dtype(op_proto, **kwargs)
inputs = dict()
for ipt in op_proto.inputs:
name = _convert_(ipt.name)
val = kwargs.pop(name, [])
if not isinstance(val, list) and not isinstance(val, tuple):
val = [val]
inputs[ipt.name] = val
outputs = dict()
out = helper.create_tmp_variable(dtype=dtype)
outputs[o_name] = [out]
for name in intermediate_output_names:
outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
return helper.append_activation(out)
func.__name__ = op_type
globals()[op_type] = func
func.__doc__ = _generate_doc_string_(op_proto)
global __all__
__all__.append(op_type)
_create_op_func_('mean')
_create_op_func_('mul')
_create_op_func_('elementwise_add')
_create_op_func_('elementwise_div')
_create_op_func_('dropout')
_create_op_func_('reshape')
_create_op_func_('sigmoid')
_create_op_func_('scale')
_create_op_func_('reshape')
_create_op_func_('transpose')
_create_op_func_('sigmoid_cross_entropy_with_logits')
def cast(x, dtype, main_program=None): def cast(x, dtype, main_program=None):
""" """
This function takes in the input with input_dtype This function takes in the input with input_dtype
......
import re
import cStringIO
import warnings
import functools
import inspect
import proto.framework_pb2 as framework_pb2
from framework import OpProtoHolder, Variable, Program, Operator
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
__all__ = ['deprecated', 'register_layer']
def _convert_(name):
"""
Formatting.
Args:
name: The name/alias
This function takes in a name and converts it to a standard format of
group1_group2. Where as per the regular expression, group1 can have
alphabets and numbers and group2 has capital alphabets.
"""
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def _generate_doc_string_(op_proto):
"""
Generate docstring by OpProto
Args:
op_proto (framework_pb2.OpProto): a protobuf message typed OpProto
Returns:
str: the document string
"""
def _type_to_str_(tp):
return framework_pb2.AttrType.Name(tp)
if not isinstance(op_proto, framework_pb2.OpProto):
raise TypeError("OpProto should be `framework_pb2.OpProto`")
buf = cStringIO.StringIO()
buf.write(op_proto.comment)
buf.write('\nArgs:\n')
for each_input in op_proto.inputs:
line_begin = ' {0}: '.format(_convert_(each_input.name))
buf.write(line_begin)
buf.write(each_input.comment)
buf.write('\n')
buf.write(' ' * len(line_begin))
buf.write('Duplicable: ')
buf.write(str(each_input.duplicable))
buf.write(' Optional: ')
buf.write(str(each_input.dispensable))
buf.write('\n')
for each_attr in op_proto.attrs:
buf.write(' ')
buf.write(each_attr.name)
buf.write(' (')
buf.write(_type_to_str_(each_attr.type))
buf.write('): ')
buf.write(each_attr.comment)
buf.write('\n')
if len(op_proto.outputs) != 0:
buf.write('\nReturns:\n')
buf.write(' ')
for each_opt in op_proto.outputs:
if not each_opt.intermediate:
break
buf.write(each_opt.comment)
return buf.getvalue()
def register_layer(op_type):
"""
Register an Python layer for an Operator
Args:
op_type: The name of the operator to be created
This function takes in the operator type (sigmoid, mean , average etc) and
creates the operator functionality.
"""
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
not_intermediate_outputs = \
filter(lambda output: not output.intermediate, op_proto.outputs)
intermediate_outputs = \
filter(lambda output: output.intermediate, op_proto.outputs)
if len(not_intermediate_outputs) != 1:
raise ValueError("Only one non intermediate output operator can be",
"automatically generated")
if not_intermediate_outputs[0].duplicable:
raise ValueError(
"Only non duplicable op can be automatically generated")
for output in intermediate_outputs:
if output.duplicable:
raise ValueError("The op can be automatically generated only when ",
"all intermediate ops are not duplicable")
o_name = not_intermediate_outputs[0].name
intermediate_output_names = [output.name for output in intermediate_outputs]
def infer_and_check_dtype(op_proto, **kwargs):
"""
This function performs the sanity check for dtype and
instance type.
"""
dtype = None
for ipt in op_proto.inputs:
name = _convert_(ipt.name)
val = kwargs.pop(name, [])
if not isinstance(val, list) and not isinstance(val, tuple):
val = [val]
for each in val:
if not isinstance(each, Variable):
raise ValueError("input of {0} must be variable".format(
op_type))
if dtype is None:
dtype = each.dtype
elif dtype != each.dtype:
raise ValueError(
"operator {0} must input same dtype. {1} vs {2}".format(
op_type, dtype, each.dtype))
return dtype
def func(**kwargs):
helper = LayerHelper(op_type, **kwargs)
dtype = infer_and_check_dtype(op_proto, **kwargs)
inputs = dict()
for ipt in op_proto.inputs:
name = _convert_(ipt.name)
val = kwargs.pop(name, [])
if not isinstance(val, list) and not isinstance(val, tuple):
val = [val]
inputs[ipt.name] = val
outputs = dict()
out = helper.create_tmp_variable(dtype=dtype)
outputs[o_name] = [out]
for name in intermediate_output_names:
outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
return helper.append_activation(out)
func.__name__ = op_type
func.__doc__ = _generate_doc_string_(op_proto)
return func
def deprecated(func_or_class):
"""
Deprecated warning decorator. It will result a warning message.
Should be used before class or function, member function
"""
@functools.wraps(func)
def func_wrapper(*args, **kwargs):
"""
Wrap func with deprecated warning
"""
warnings.simplefilter('always', DeprecationWarning) #turn off filter
warnings.warn(
"Call to deprecated function {}.".format(func.__name__),
category=DeprecationWarning,
stacklevel=2)
warnings.simplefilter('default', DeprecationWarning) #reset filter
return func(*args, **kwargs)
return func_wrapper
import unittest
import warnings
import paddle.v2.fluid as fluid
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.registry as registry
class TestRegistry(unittest.TestCase):
def test_registry_layer(self):
self.layer_type = "mean"
program = framework.Program()
x = fluid.layers.data(name='X', shape=[10, 10], dtype='float32')
output = layers.mean(x)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
X = np.random.random((10, 10)).astype("float32")
mean_out = exe.run(program, feed={"X": X}, fetch_list=[output])
self.assertAlmostEqual(np.mean(X), mean_out)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册