# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools from paddle.trainer.config_parser import * from .activations import LinearActivation, SigmoidActivation, TanhActivation, \ ReluActivation, IdentityActivation, SoftmaxActivation from .evaluators import * from .poolings import MaxPooling, AvgPooling, BasePoolingType from .attrs import * from .default_decorators import * try: import cPickle as pickle except ImportError: import pickle import copy __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel", "identity_projection", "dotmul_projection", "table_projection", "mixed_layer", "data_layer", "embedding_layer", "fc_layer", "grumemory", "pooling_layer", "lstmemory", "last_seq", "first_seq", "cos_sim", "hsigmoid", "regression_cost", 'classification_cost', "LayerOutput", 'img_conv_layer', 'img_pool_layer', 'batch_norm_layer', 'img_cmrnorm_layer', 'img_rnorm_layer', 'addto_layer', 'concat_layer', 'lstm_step_layer', 'recurrent_group', 'memory', 'StaticInput', 'expand_layer', 'scaling_layer', 'power_layer', 'interpolation_layer', 'trans_layer', 'sum_to_one_norm_layer', 'get_output_layer', 'LayerType', 'context_projection', 'beam_search', 'maxid_layer', 'GeneratedInput', 'SubsequenceInput', 'gru_step_layer', 'recurrent_layer', 'BaseGeneratedInput', 'conv_operator', 'conv_shift_layer', 'tensor_layer', 'selective_fc_layer', 'sampling_id_layer', 'slope_intercept_layer', 'trans_full_matrix_projection', 'convex_comb_layer', 'ctc_layer', 'crf_layer', 'crf_decoding_layer', 'cross_entropy_with_selfnorm', 'cross_entropy', 'multi_binary_label_cross_entropy', 'rank_cost', 'lambda_cost', 'huber_cost', 'block_expand_layer', ] class LayerType(object): """ Layer type enumerations. """ DATA = "data" MIXED_LAYER = "mixed" LSTMEMORY = "lstmemory" GRUMEMORY = "gated_recurrent" SEQUENCE_LAST_INSTANCE = "seqlastins" SEQUENCE_FIRST_INSTANCE = "seqfirstins" POOLING_MAX = "max" POOLING_AVG = 'average' FC_LAYER = "fc" COST = 'cost' COSINE_SIM = 'cos_vm' HSIGMOID = 'hsigmoid' CONV_LAYER = "conv" POOL_LAYER = "pool" BATCH_NORM_LAYER = 'batch_norm' NORM_LAYER = 'norm' SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm' ADDTO_LAYER = 'addto' CONCAT_LAYER = 'concat' CONCAT_PROJ_LAYER = 'concat2' LSTM_STEP_LAYER = 'lstm_step' GRU_STEP_LAYER = 'gru_step' GET_OUTPUT_LAYER = 'get_output' EXPAND_LAYER = 'expand' INTERPOLATION_LAYER = 'interpolation' POWER_LAYER = 'power' SCALING_LAYER = 'scaling' TRANS_LAYER = 'trans' MEMORY = 'memory' MAXID_LAYER = 'maxid' EOSID_LAYER = 'eos_id' RECURRENT_LAYER = 'recurrent' CONV_SHIFT_LAYER = "conv_shift" TENSOR_LAYER = "tensor" SEL_FC_LAYER = "selective_fc" SAMPLING_ID_LAYER = "sampling_id" SLOPE_INTERCEPT_LAYER = "slope_intercept" CONVEX_COMBINATION_LAYER = "convex_comb" BLOCK_EXPAND = "blockexpand" CTC_LAYER = "ctc" CRF_LAYER = "crf" CRF_DECODING_LAYER = "crf_decoding" RANK_COST = "rank-cost" LAMBDA_COST = "lambda_cost" HUBER = "huber" CROSS_ENTROPY = "multi-class-cross-entropy" CROSS_ENTROPY_WITH_SELFNORM = "multi_class_cross_entropy_with_selfnorm" SOFT_BIN_CLASS_CROSS_ENTROPY = "soft_binary_class_cross_entropy" MULTI_BIN_LABEL_CROSS_ENTROPY = "multi_binary_label_cross_entropy" @staticmethod def is_layer_type(type_name): """ If type_name is a layer type. :param type_name: layer type name. Because layer type enumerations are strings. :type type_name: basestring :return: True if is a layer_type :rtype: bool """ for key in dir(LayerType): if key.isupper(): att = getattr(LayerType, key) if isinstance(att, basestring) and type_name == att: return True return False class AggregateLevel(object): EACH_TIMESTEP = 'non-seq' EACH_SEQUENCE = 'seq' class LayerOutput(object): """ LayerOutput is output for layer function. It is used internally by several reasons. - Check layer connection make sense. - FC(Softmax) => Cost(MSE Error) is not good for example. - Tracking layer connection. - Pass to layer methods as input. :param name: Layer output name. :type name: basestring :param layer_type: Current Layer Type. One of LayerType enumeration. :type layer_type: basestring :param activation: Layer Activation. :type activation: BaseActivation. :param parents: Layer's parents. :type parents: list|tuple """ def __init__(self, name, layer_type, parents=None, activation=None, num_filters=None, img_norm_type=None, size=None, outputs=None): assert isinstance(name, basestring) assert isinstance(layer_type, basestring) assert LayerType.is_layer_type(layer_type) self.name = name self.layer_type = layer_type self.parents = [] if parents is None else parents self.activation = activation self.num_filters = num_filters self.img_norm_type = img_norm_type self.size = size if outputs is None: outputs = ['default'] self.outputs = outputs def __repr__(self): """ Disable __repr__ for debug reason. Will be implemented when release """ assert False, "this method should not be invoked" def __str__(self): """ Disable __str__ for debug reason. Will be implemented when release """ assert False, "this method should not be invoked" ERROR_CLIPPING = 'error_clipping_threshold' DROPOUT = 'drop_rate' def layer_support(*attrs): def decorator(method): @functools.wraps(method) def wrapper(*args, **kwargs): for attr in attrs: for each in args: if isinstance(each, ExtraLayerAttribute): setattr(each, '_'.join(['can', attr]), True) for key in kwargs: val = kwargs[key] if isinstance(val, ExtraLayerAttribute): setattr(val, '_'.join(['can', attr]), True) for each in args: if isinstance(each, ExtraLayerAttribute): each.check(method.__name__) for key in kwargs: val = kwargs[key] if isinstance(val, ExtraLayerAttribute): val.check(method.__name__) return method(*args, **kwargs) return wrapper return decorator @wrap_param_attr_default() def full_matrix_projection(input, size=0, param_attr=None): """ Full Matrix Projection. It performs full matrix multiplication. .. math:: out.row[i] += in.row[i] * weight There are two styles of usage. 1. When used in mixed_layer like this, you can only set the input: .. code-block:: python with mixed_layer(size=100) as m: m += full_matrix_projection(input=layer) 2. When used as an independant object like this, you must set the size: .. code-block:: python proj = full_matrix_projection(input=layer, size=100, param_attr=ParamAttr(name='_proj')) :param input: input layer :type input: LayerOutput :param size: The parameter size. Means the width of parameter. :type size: int :param param_attr: Parameter config, None if use default. :type param_attr: ParameterAttribute :return: A FullMatrixProjection Object. :rtype: FullMatrixProjection """ proj = FullMatrixProjection(input_layer_name=input.name, size=size, **param_attr.attr) proj.origin = input proj.origin.projection = "matrix" return proj @wrap_param_attr_default() def table_projection(input, size=0, param_attr=None): """ Table Projection. It selects rows from parameter where row\_id is in input\_ids. .. math:: out.row[i] += table.row[ids[i]] where :math:`out` is output, :math:`table` is parameter, :math:`ids` is input\_ids, and :math:`i` is row\_id. There are two styles of usage. 1. When used in mixed_layer like this, you can only set the input: .. code-block:: python with mixed_layer(size=100) as m: m += table_projection(input=layer) 2. When used as an independant object like this, you must set the size: .. code-block:: python proj = table_projection(input=layer, size=100, param_attr=ParamAttr(name='_proj')) :param input: Input layer, which must contains id fields. :type input: LayerOutput :param size: The parameter size. Means the width of parameter. :type size: int :param param_attr: Parameter config, None if use default. :type param_attr: ParameterAttribute :return: A TableProjection Object. :rtype: TableProjection """ proj = TableProjection(input_layer_name=input.name, size=size, **param_attr.attr) proj.origin = input proj.origin.projection = "table" return proj def identity_projection(input, offset=None): """ 1. IdentityProjection if offset=None. It performs: .. math:: out.row[i] += in.row[i] The example usage is: .. code-block:: python proj = identity_projection(input=layer) 2. IdentityOffsetProjection if offset!=None. It likes IdentityProjection, but layer size may be smaller than input size. It select dimesions [offset, offset+layer_size) from input: .. math:: out.row[i] += in.row[i + \\textrm{offset}] The example usage is: .. code-block:: python proj = identity_projection(input=layer, offset=10) Note that both of two projections should not have any parameter. :param input: Input Layer. :type input: LayerOutput. :param offset: Offset, None if use default. :type offset: int :return: A IdentityProjection or IdentityOffsetProjection Object :rtype: IdentityProjection or IdentityOffsetProjection """ if offset is None: proj = IdentityProjection(input_layer_name=input.name) proj.origin = input proj.origin.projection = 'identity' else: proj = IdentityOffsetProjection(input_layer_name=input.name, offset=offset) proj.origin = input proj.origin.projection = 'identity_offset' return proj @wrap_param_attr_default() def dotmul_projection(input, param_attr=None, scale=1): """ 1. DotMulProjection if input is a layer. It performs element-wise multiplication with weight. .. math:: out.row[i] += in.row[i] .* weight where :math:`.*` means element-wise multiplication. The example usage is: .. code-block:: python proj = dotmul_projection(input=layer) 2. DotMulOperator if input is a list or tuple. It takes two inputs, performs element-wise multiplication: .. math:: out.row[i] += scale * (in1.row[i] .* in2.row[i]) where :math:`.*` means element-wise multiplication, and scale is a config scalar, its default value is one. The example usage is: .. code-block:: python op = dotmul_projection(input=[layer1, layer2], scale=2.0) :param input: Input layer. :type input: LayerOutput|list|tuple :param param_attr: Parameter config, None if use default. :type param_attr: ParameterAttribute :param scale: config scalar, default value is one. :type scale: float :return: A DotMulProjection or DotMulOperator Object. :rtype: DotMulProjection or DotMulOperator """ if isinstance(input, LayerOutput): proj = DotMulProjection(input_layer_name=input.name, size=input.size, **param_attr.attr) proj.origin = input proj.origin.projection = "dot_mul" return proj else: assert isinstance(input, list) or isinstance(input, tuple) assert len(input) == 2 assert param_attr is None op = DotMulOperator(input_layer_name=[x.name for x in input], scale=scale) op.origin = input op.origin.operator = "dot_mul" return op @wrap_bias_attr_default(['padding_attr']) def context_projection(input, context_len, context_start=None, padding_attr=False): """ Context Projection. It just simply reorganizes input sequence, combines "context_len" sequence to one context from context_start. "context_start" will be set to -(context_len - 1) / 2 by default. If context position out of sequence length, padding will be filled as zero if padding_attr = False, otherwise it is trainable. For example, origin sequence is [A B C D E F G], context len is 3, then after context projection and not set padding_attr, sequence will be [ 0AB ABC BCD CDE DEF EFG FG0 ]. :param input: Input Sequence. :type input: LayerOutput :param context_len: context length. :type context_len: int :param context_start: context start position. Default is -(context_len - 1)/2 :type context_start: int :param padding_attr: Padding Parameter Attribute. If false, it means padding always be zero. Otherwise Padding is learnable, and parameter attribute is set by this parameter. :type padding_attr: bool|ParameterAttribute :return: Projection :rtype: Projection """ context_start = -( context_len - 1) / 2 if context_start is None else context_start extra_dict = dict() trainable = isinstance(padding_attr, ParameterAttribute) if trainable: extra_dict = padding_attr.attr proj = ContextProjection(input_layer_name=input.name, context_length=context_len, context_start=context_start, trainable_padding=trainable, **extra_dict) proj.origin = input proj.origin.projection = 'context' return proj class MixedLayerType(LayerOutput): """ The internal object for trainer_helpers. """ class AddToSealedMixedLayerException(Exception): def __init__(self): Exception.__init__(self) def __init__(self, name, size, act, bias_attr, layer_attr, parents=None): """ Ctor. :param name: layer name. :type name: basestring :param size: layer size. :type size: int :param act: activation type. :type act: BaseActivation :param bias_attr: The Bias Attribute. If no bias, then pass False or something not type of ParameterAttribute. None will get a default Bias. :type bias_attr: ParameterAttribute or None means has bias. Any other type means no bias. :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute or None """ LayerOutput.__init__(self, name, LayerType.MIXED_LAYER, parents, size=size, activation=act) self.bias_attr = bias_attr self.layer_attr = layer_attr self.inputs = [] self.finalized = False def __add__(self, other): """ + += operator :param other: Other projection. :type other: Projection :return: self. :rtype: MixedLayerType """ if not self.finalized: assert isinstance(other, Projection) self.inputs.append(other) self.parents.append(other.origin) return self else: raise MixedLayerType.AddToSealedMixedLayerException() def __enter__(self): assert len(self.inputs) == 0 return self def __exit__(self, *args, **kwargs): del args, kwargs # unused parameter to suppress warning assert len(self.inputs) != 0 MixedLayer( name=self.name, size=self.size, active_type=self.activation.name, bias=ParamAttr.to_bias(self.bias_attr), inputs=self.inputs, **ExtraLayerAttribute.to_kwargs(self.layer_attr) ) @wrap_name_default("mixed") @wrap_act_default(act=LinearActivation()) @wrap_bias_attr_default(has_bias=False) @layer_support(ERROR_CLIPPING, DROPOUT) def mixed_layer(size, input=None, name=None, act=None, bias_attr=False, layer_attr=None): """ Mixed Layer. A mixed layer will add all inputs together, then activate. Each inputs is a projection or operator. There are two styles of usages. 1. When not set inputs parameter, use mixed_layer like this: .. code-block:: python with mixed_layer(size=256) as m: m += full_matrix_projection(input=layer1) m += identity_projection(input=layer2) 2. You can also set all inputs when invoke mixed_layer as follows: .. code-block:: python m = mixed_layer(size=256, input=[full_matrix_projection(input=layer1), full_matrix_projection(input=layer2)]) :param name: mixed layer name. Can be referenced by other layer. :type name: basestring :param size: layer size. :type size: int :param input: inputs layer. It is an optional parameter. If set, then this function will just return layer's name. :param act: Activation Type. :type act: BaseActivation :param bias_attr: The Bias Attribute. If no bias, then pass False or something not type of ParameterAttribute. None will get a default Bias. :type bias_attr: ParameterAttribute or None or bool :param layer_attr: The extra layer config. Default is None. :type layer_attr: ExtraLayerAttribute :return: MixedLayerType object can add inputs or layer name. :rtype: MixedLayerType """ if input is None: return MixedLayerType(name, size, act, bias_attr, layer_attr) else: with mixed_layer(name=name, size=size, act=act, bias_attr=bias_attr, layer_attr=layer_attr) as m: if isinstance(input, list) or isinstance(input, tuple): for each in input: m += each else: m += input return m @layer_support() def data_layer(name, size, layer_attr=None): """ Define DataLayer For NeuralNetwork. The example usage is: .. code-block:: python data = data_layer(name="input", size=1000) :param name: Name of this data layer. :type name: basestring :param size: Size of this data layer. :type size: int :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute. :return: Layer Output Object. :rtype: LayerOutput """ Layer(type=LayerType.DATA, name=name, size=size, **ExtraLayerAttribute.to_kwargs(layer_attr)) return LayerOutput(name, LayerType.DATA, size=size) @wrap_name_default("embedding") @wrap_param_attr_default() @layer_support(ERROR_CLIPPING) def embedding_layer(input, size, name=None, param_attr=None, layer_attr=None): """ Define a embedding Layer. :param name: Name of this embedding layer. :type name: basestring :param input: The input layer for this embedding. NOTE: must be Index Data. :type input: LayerOutput :param size: The embedding dimension. :type size: int :param param_attr: The embedding parameter attribute. See ParameterAttribute for details. :type param_attr: ParameterAttribute|None :param layer_attr: Extra layer Config. Default is None. :type layer_attr: ExtraLayerAttribute|None :return: Embedding Layer output :rtype: LayerOutput """ with mixed_layer(name=name, size=size, act=LinearActivation(), bias_attr=False, layer_attr=layer_attr) as mix: mix += table_projection(input=input, size=size, param_attr=param_attr) return mix @wrap_name_default() @wrap_param_attr_default() @wrap_bias_attr_default() @wrap_act_default() @layer_support(ERROR_CLIPPING, DROPOUT) def fc_layer(input, size, act=None, name=None, param_attr=None, bias_attr=None, layer_attr=None): """ Helper for declare fully connected layer. The example usage is: .. code-block:: python fc = fc_layer(input=layer, size=1024, act=LinearActivation(), bias_attr=False) which is equal to: .. code-block:: python with mixed_layer(size=1024) as fc: fc += full_matrix_projection(input=layer) :param name: The Layer Name. :type name: basestring :param input: The input layer. Could be a list/tuple of input layer. :type input: LayerOutput|list|tuple :param size: The layer dimension. :type size: int :param act: Activation Type. Default is tanh. :type act: BaseActivation :param param_attr: The Parameter Attribute|list. :type param_attr: ParameterAttribute :param bias_attr: The Bias Attribute. If no bias, then pass False or something not type of ParameterAttribute. None will get a default Bias. :type bias_attr: ParameterAttribute|None|Any :param layer_attr: Extra Layer config. :type layer_attr: ExtraLayerAttribute|None :return: Layer Name. :rtype: LayerOutput """ if isinstance(input, LayerOutput): input = [input] assert not isinstance(param_attr, list) param_attr = [param_attr] else: if isinstance(param_attr, list) or isinstance(param_attr, tuple): assert len(input) == len(param_attr) else: param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))] assert isinstance(input, list) def __idx_to_input__(i): attr = param_attr[i] assert isinstance(attr, ParameterAttribute) return Input(input[i].name, **attr.attr) Layer( inputs=map(__idx_to_input__, range(len(input))), name=name, type=LayerType.FC_LAYER, size=size, bias=ParamAttr.to_bias(bias_attr), active_type=act.name, **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.FC_LAYER, input, activation=act, size=size) @wrap_name_default("seq_pooling") @wrap_bias_attr_default(has_bias=False) @wrap_param_default(['pooling_type'], default_factory=lambda _: MaxPooling()) @layer_support() def pooling_layer(input, pooling_type=None, name=None, bias_attr=None, agg_level=AggregateLevel.EACH_TIMESTEP, layer_attr=None): """ Pooling layer for sequence inputs, not used for Image. The example usage is: .. code-block:: python seq_pool = pooling_layer(input=layer, pooling_type=AvgPooling(), agg_level=AggregateLevel.EACH_SEQUENCE) :param agg_level: AggregateLevel.EACH_TIMESTEP or AggregateLevel.EACH_SEQUENCE :type agg_level: AggregateLevel :param name: layer name. :type name: basestring :param input: input layer name. :type input: LayerOutput :param pooling_type: Type of pooling, MaxPooling(default), AvgPooling, SumPooling, SquareRootNPooling. :type pooling_type: BasePoolingType|None :param bias_attr: Bias parameter attribute. False if no bias. :type bias_attr: ParameterAttribute|None|False :param layer_attr: The Extra Attributes for layer, such as dropout. :type layer_attr: ExtraLayerAttribute|None :return: layer name. :rtype: LayerType """ extra_dict = dict() if isinstance(pooling_type, AvgPooling): extra_dict['average_strategy'] = pooling_type.strategy extra_dict.update(ExtraLayerAttribute.to_kwargs(layer_attr)) Layer( name=name, type=pooling_type.name, inputs=[Input(input.name)], bias=ParamAttr.to_bias(bias_attr), trans_type=agg_level, **extra_dict ) return LayerOutput(name, pooling_type.name, parents=[input], size=input.size) @wrap_bias_attr_default() @wrap_param_attr_default() @wrap_act_default(param_names=['gate_act'], act=SigmoidActivation()) @wrap_act_default(param_names=["act", 'state_act'], act=TanhActivation()) @wrap_name_default("lstmemory") @layer_support(DROPOUT) def lstmemory(input, name=None, reverse=False, act=None, gate_act=None, state_act=None, bias_attr=None, param_attr=None, layer_attr=None): """ Long Short-term Memory Cell. The memory cell was implemented as follow equations. .. math:: i_t = \\sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) f_t = \\sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) c_t = f_tc_{t-1} + i_t tanh (W_{xc}x_t+W_{hc}h_{t-1} + b_c) o_t = \\sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) h_t = o_t tanh(c_t) NOTE: In paddle's implementation, the multiply operation :math:`W_{xi}x_{t}` , :math:`W_{xf}x_{t}`, :math:`W_{xc}x_t`, :math:`W_{xo}x_{t}` is not done by lstmemory layer, so it must use a mixed_layer do this full_matrix_projection before lstm is used. NOTE: This is a low level user interface. You may use network.simple_lstm to config a simple plain lstm layer. Please refer **Generating Sequences With Recurrent Neural Networks** if you want to know what lstm is. Link_ is here. .. _Link: http://arxiv.org/abs/1308.0850 TODO(yuyang18): Check lstm can input multiple values or not? :param name: The lstmemory layer name. :type name: basestring :param input: input layer name. :type input: LayerOutput :param reverse: is sequence process reversed or not. :type reverse: bool :param act: activation type, TanhActivation by default. :math:`h_t` :type act: BaseActivation :param gate_act: gate activation type, SigmoidActivation by default. :type gate_act: BaseActivation :param state_act: state activation type, TanhActivation by default. :type state_act: BaseActivation :param bias_attr: Bias attribute. None means default bias. False means no bias. :type bias_attr: ParameterAttribute|None|False :param param_attr: Parameter Attribute. :type param_attr: ParameterAttribute|None|False :param layer_attr: Extra Layer attribute :type layer_attr: ExtraLayerAttribute|None :return: Layer name. :rtype: LayerOutput """ assert gate_act.support_hppl assert state_act.support_hppl assert act.support_hppl Layer(name=name, type=LayerType.LSTMEMORY, active_type=act.name, active_state_type=state_act.name, active_gate_type=gate_act.name, reversed=reverse, bias=ParamAttr.to_bias(bias_attr), inputs=[Input(input.name, **param_attr.attr)], **ExtraLayerAttribute.to_kwargs(layer_attr)) return LayerOutput(name, LayerType.LSTMEMORY, [input], size=input.size / 4 if input.size is not None else None) @wrap_bias_attr_default() @wrap_param_attr_default() @wrap_act_default(param_names=['gate_act'], act=SigmoidActivation()) @wrap_act_default(param_names=["act"], act=TanhActivation()) @wrap_name_default("gru") @layer_support(DROPOUT) def grumemory(input, name=None, reverse=False, act=None, gate_act=None, bias_attr=None, param_attr=None, layer_attr=None): """ Gate Recurrent Unit Layer. The memory cell was implemented as follow equations. 1. update gate :math:`z`: defines how much of the previous memory to keep around or the unit updates its activations. The update gate is computed by: .. math:: z_t = \\sigma(W_{z}x_{t} + U_{z}h_{t-1} + b_z) 2. reset gate :math:`r`: determines how to combine the new input with the previous memory. The reset gate is computed similarly to the update gate: .. math:: r_t = \\sigma(W_{r}x_{t} + U_{r}h_{t-1} + b_r) 3. The candidate activation :math:`\\tilde{h_t}` is computed similarly to that of the traditional recurrent unit: .. math:: {\\tilde{h_t}} = tanh(W x_{t} + U (r_{t} \odot h_{t-1}) + b) 4. The hidden activation :math:`h_t` of the GRU at time t is a linear interpolation between the previous activation :math:`h_{t-1}` and the candidate activation :math:`\\tilde{h_t}`: .. math:: h_t = (1 - z_t) h_{t-1} + z_t {\\tilde{h_t}} NOTE: In paddle's implementation, the multiply operation :math:`W_{r}x_{t}`, :math:`W_{z}x_{t}` and :math:`W x_t` are not computed in gate_recurrent layer. So it must use a mixed_layer with full_matrix_projection or fc_layer to compute them before GRU. The details can refer to `Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling. `_ The simple usage is: .. code-block:: python gru = grumemory(input) :param name: The gru layer name. :type name: None|basestring :param input: input layer. :type input: LayerOutput. :param reverse: Wether sequence process is reversed or not. :type reverse: bool :param act: activation type, TanhActivation by default. This activation affects the :math:`{\\tilde{h_t}}`. :type act: BaseActivation :param gate_act: gate activation type, SigmoidActivation by default. This activation affects the :math:`z_t` and :math:`r_t`. It is the :math:`\\sigma` in the above formula. :type gate_act: BaseActivation :param bias_attr: Bias attribute. None means default bias. False means no bias. :type bias_attr: ParameterAttribute|None|False :param param_attr: Parameter Attribute. :type param_attr: ParameterAttribute|None|False :param layer_attr: Extra Layer attribute :type layer_attr: ExtraLayerAttribute|None :return: Layer name. :rtype: LayerOutput """ assert act.support_hppl assert gate_act.support_hppl Layer(name=name, type=LayerType.GRUMEMORY, active_type=act.name, active_gate_type=gate_act.name, reversed=reverse, bias=ParamAttr.to_bias(bias_attr), inputs=[Input(input.name, **param_attr.attr)], **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.GRUMEMORY, [input], size=input.size / 3 if input.size is not None else None) @wrap_name_default() @layer_support() def last_seq(input, name=None, agg_level=AggregateLevel.EACH_TIMESTEP, layer_attr=None): """ Get Last Timestamp Activation of a sequence. :param agg_level: Aggregated level :param name: Layer name. :type name: basestring :param input: Input layer name. :type input: LayerOutput :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name. :rtype: LayerOutput """ Layer( name=name, type=LayerType.SEQUENCE_LAST_INSTANCE, inputs=[input.name], trans_type=agg_level, **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.SEQUENCE_LAST_INSTANCE, parents=[input], size=input.size) @wrap_name_default() @layer_support() def first_seq(input, name=None, agg_level=AggregateLevel.EACH_TIMESTEP, layer_attr=None): """ Get First Timestamp Activation of a sequence. :param agg_level: aggregation level :param name: Layer name. :type name: basestring :param input: Input layer name. :type input: LayerOutput :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name. :rtype: LayerOutput """ Layer( name=name, type=LayerType.SEQUENCE_FIRST_INSTANCE, inputs=[input.name], trans_type=agg_level, **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.SEQUENCE_FIRST_INSTANCE, parents=[input], size=input.size) class ExpandLevel(object): FROM_TIMESTEP = AggregateLevel.EACH_TIMESTEP FROM_SEQUENCE = AggregateLevel.EACH_SEQUENCE @wrap_name_default() @layer_support() def expand_layer(input, expand_as, name=None, bias_attr=False, expand_level=ExpandLevel.FROM_TIMESTEP, layer_attr=None): """ A layer for "Expand Dense data or (sequence data where the length of each sequence is one) to sequence data." The example usage is: .. code-block:: python expand = expand_layer(input=layer1, expand_as=layer2, expand_level=ExpandLevel.FROM_TIMESTEP) :param input: Input layer :type input: LayerOutput :param expand_as: Expand as this layer's sequence info. :type expand_as: LayerOutput :param name: Layer name. :type name: basestring :param bias_attr: Bias attribute. None means default bias. False means no bias. :type bias_attr: ParameterAttribute|None|False :param expand_level: whether input layer is timestep(default) or sequence. :type expand_level: ExpandLevel :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name :rtype: LayerOutput """ Layer( inputs=[input.name, expand_as.name], name=name, bias=ParamAttr.to_bias(bias_attr=bias_attr), type=LayerType.EXPAND_LAYER, trans_type=expand_level, **ExtraAttr.to_kwargs(layer_attr) ) return LayerOutput(name=name, size=input.size, layer_type=LayerType.EXPAND_LAYER, parents=[input, expand_as]) @wrap_name_default() @layer_support() def interpolation_layer(input, weight, name=None, layer_attr=None): """ This layer is for linear interpolation with two inputs, which is used in NEURAL TURING MACHINE. .. math:: y.row[i] = w[i] * x_1.row[i] + (1 - w[i]) * x_2.row[i] where :math:`x_1` and :math:`x_2` are two (batchSize x dataDim) inputs, :math:`w` is (batchSize x 1) weight vector, and :math:`y` is (batchSize x dataDim) output. The example usage is: .. code-block:: python interpolation = interpolation_layer(input=[layer1, layer2], weight=layer3) :param input: Input layer. :type input: list|tuple :param weight: Weight layer. :type weight: LayerOutput :param name: Layer name. :type name: basestring :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name. :rtype: LayerOutput """ assert isinstance(input, list) or isinstance(input, tuple) assert len(input) == 2 assert input[0].size == input[1].size assert weight.size == 1 Layer( name=name, type=LayerType.INTERPOLATION_LAYER, inputs=[weight.name, input[0].name, input[1].name], **ExtraAttr.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.INTERPOLATION_LAYER, parents=[weight, input[0], input[1]], size=input[0].size) @wrap_name_default() @layer_support() def power_layer(input, weight, name=None, layer_attr=None): """ This layer applies a power function to a vector element-wise, which is used in NEURAL TURING MACHINE. .. math:: y = x^w where :math:`x` is a input vector, :math:`w` is scalar weight, and :math:`y` is a output vector. The example usage is: .. code-block:: python power = power_layer(input=layer1, weight=layer2) :param input: Input layer. :type input: LayerOutput :param weight: Weight layer. :type weight: LayerOutput :param name: Layer name. :type name: basestring :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name. :rtype: LayerOutput """ assert weight.size == 1 Layer( name=name, type=LayerType.POWER_LAYER, inputs=[input.name, weight.name], **ExtraAttr.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.POWER_LAYER, parents=[input, weight], size=input.size) @wrap_name_default() @layer_support() def scaling_layer(input, weight, name=None, layer_attr=None): """ A layer for each row of a matrix, multiplying with a element of a vector. .. math:: y.row[i] = w[i] * x.row[i] where :math:`x` is (batchSize x dataDim) input, :math:`w` is (batchSize x 1) weight vector, and :math:`y` is (batchSize x dataDim) output. The example usage is: .. code-block:: python scale = scaling_layer(input=layer1, weight=layer2) :param input: Input layer. :type input: LayerOutput :param weight: Weight layer. :type weight: LayerOutput :param name: Layer name. :type name: basestring :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name. :rtype: LayerOutput """ assert weight.size == 1 Layer( name=name, type=LayerType.SCALING_LAYER, inputs=[weight.name, input.name], **ExtraAttr.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.SCALING_LAYER, parents=[weight, input], size=input.size) @wrap_name_default() @layer_support() def trans_layer(input, name=None, layer_attr=None): """ A layer for transposition. .. math:: y = x^\mathrm{T} where :math:`x` is (M x N) input, and :math:`y` is (N x M) output. The example usage is: .. code-block:: python trans = trans_layer(input=layer) :param input: Input layer. :type input: LayerOutput :param name: Layer name. :type name: basestring :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name. :rtype: LayerOutput """ Layer( name=name, type=LayerType.TRANS_LAYER, inputs=[input.name], **ExtraAttr.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.TRANS_LAYER, parents=[input], size=input.size) @wrap_name_default() @layer_support() def cos_sim(a, b, scale=5, size=1, name=None, layer_attr=None): """ Cosine Similarity Layer. The cosine similarity equation is here. .. math:: similarity = cos(\\theta) = {\\mathbf{A} \\cdot \\mathbf{B} \\over \\|\\mathbf{A}\\| \\|\\mathbf{B}\\|} And the input dimension is :math:`a \in R^M`, :math:`b \in R^{MN}`. The similarity will be calculated N times by step M. The output dimension is :math:`R^N`. The scale will be multiplied to similarity. :param name: layer name :type name: basestring :param a: input layer a :type a: LayerOutput :param b: input layer b :type b: LayerOutput :param scale: scale for cosine value. default is 5. :type scale: float :param size: layer size. NOTE size_a * size should equal size_b. :type size: int :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :return: layer name. :rtype: LayerOutput """ Layer( name=name, type=LayerType.COSINE_SIM, size=size, cos_scale=scale, inputs=[a.name, b.name], **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.COSINE_SIM, parents=[a, b]) @wrap_name_default() @wrap_bias_attr_default(has_bias=True) @layer_support() def hsigmoid(input, label, num_classes, name=None, bias_attr=None, layer_attr=None): """ Organize the classes into a binary tree. At each node, a sigmoid function is used to calculate the probability of belonging to the right branch. This idea is from "F. Morin, Y. Bengio (AISTATS 05): Hierarchical Probabilistic Neural Network Language Model." The example usage is: .. code-block:: python cost = hsigmoid(input=[layer1, layer2], label=data_layer, num_classes=3) :param name: layer name :type name: basestring :param input: Input layers. It could be a LayerOutput or list/tuple of LayerOutput. :type input: LayerOutput|list|tuple :param label: Label layer. :type label: LayerOutput :param num_classes: number of classes. :type num_classes: int :param bias_attr: Bias attribute. None means default bias. False means no bias. :type bias_attr: ParameterAttribute|False :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :return: layer name. :rtype: LayerOutput """ if isinstance(input, LayerOutput): input = [input] assert isinstance(input, list) or isinstance(input, tuple) assert isinstance(label, LayerOutput) assert label.layer_type == LayerType.DATA ipts_for_layer = [] parents = [] for each_input in input: assert isinstance(each_input, LayerOutput) ipts_for_layer.append(each_input.name) parents.append(each_input) ipts_for_layer.append(label.name) parents.append(label) Layer( name=name, type=LayerType.HSIGMOID, num_classes=num_classes, bias=ParamAttr.to_bias(bias_attr), inputs=ipts_for_layer, **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.HSIGMOID, parents=parents) @wrap_name_default("conv") @wrap_param_attr_default() @wrap_bias_attr_default() @wrap_act_default(act=ReluActivation()) @layer_support(DROPOUT) def img_conv_layer(input, filter_size, num_filters, name=None, num_channels=None, act=None, groups=1, stride=1, padding=0, bias_attr=None, param_attr=None, shared_biases=True, layer_attr=None, filter_size_y=None, stride_y=None, padding_y=None): """ Convolution layer for image. Paddle only support square input currently and thus input image's width equals height. The details of convolution layer, please refer UFLDL's `convolution `_ . The num_channel means input image's channel number. It may be 1 or 3 when input is raw pixels of image(mono or RGB), or it may be the previous layer's num_filters * num_group. There are several group of filter in paddle implementation. Each group will process some channel of inputs. For example, if input num_channel = 256, group = 4, num_filter=32, the paddle will create 32*4 = 128 filters to process inputs. The channels will be split into 4 pieces. First 256/4 = 64 channels will process by first 32 filters. The rest channels will be processed by rest group of filters. :param name: Layer name. :type name: basestring :param input: Layer Input. :type input: LayerOutput :param filter_size: The x dimension of a filter kernel. :type filter_size: int :param filter_size_y: The y dimension of a filter kernel. Since paddle now support rectangular filters, the filter's shape will be (filter_size, filter_size_y). :type filter_size_y: int :param num_filters: Each filter group's number of filter :param act: Activation type. Default is tanh :type act: BaseActivation :param groups: Group size of filters. :type groups: int :param stride: The x dimension of the stride. :type stride: int :param stride_y: The y dimension of the stride. :type stride_y: int :param padding: The x dimension of the padding. :type padding: int :param padding_y: The y dimension of the padding. :type padding_y: int :param bias_attr: Convolution bias attribute. None means default bias. False means no bias. :type bias_attr: ParameterAttribute|False :param num_channels: number of input channels. If None will be set automatically from previous output. :type num_channels: int :param param_attr: Convolution param attribute. None means default attribute :type param_attr: ParameterAttribute :param shared_biases: Is biases will be shared between filters or not. :type shared_biases: bool :param layer_attr: Layer Extra Attribute. :type layer_attr: ExtraLayerAttribute :return: Layer output. :rtype: LayerOutput """ if num_channels is None: assert input.num_filters is not None num_channels = input.num_filters if filter_size_y is None: filter_size_y = filter_size if stride_y is None: stride_y = stride if padding_y is None: padding_y = padding if param_attr.attr.get('initial_smart') == True: # special initial for conv layers. init_w = (2.0 / (filter_size ** 2 * num_channels)) ** 0.5 param_attr = ParameterAttribute(initial_mean=0.0, initial_std=init_w) Layer( name=name, inputs=Input(input.name, conv=Conv( filter_size=filter_size, padding=padding, stride=stride, channels=num_channels, groups=groups, filter_size_y=filter_size_y, padding_y=padding_y, stride_y=stride_y), **param_attr.attr), active_type=act.name, num_filters=num_filters, bias=ParamAttr.to_bias(bias_attr), shared_biases=shared_biases, type=LayerType.CONV_LAYER, **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.CONV_LAYER, parents=[input], activation=act, num_filters=num_filters) @wrap_name_default("pool") @layer_support() def img_pool_layer(input, pool_size, name=None, num_channels=None, pool_type=None, stride=1, start=None, padding=0, layer_attr=None): """ Image pooling Layer. The details of pooling layer, please refer ufldl's pooling_ . .. _pooling: http://ufldl.stanford.edu/tutorial/supervised/Pooling/ :param padding: pooling padding :type padding: int :param name: name of pooling layer :type name: basestring. :param input: layer's input :type input: LayerOutput :param pool_size: pooling size :type pool_size: int :param num_channels: number of input channel. :type num_channels: int :param pool_type: pooling type. MaxPooling or AveragePooling. Default is MaxPooling. :type pool_type: BasePoolingType :param stride: stride of pooling. :type stride: int :param start: start position of pooling operation. :type start: int :param layer_attr: Extra Layer attribute. :type layer_attr: ExtraLayerAttribute :return: LayerOutput """ if num_channels is None: assert input.num_filters is not None num_channels = input.num_filters if pool_type is None: pool_type = MaxPooling() elif isinstance(pool_type, AvgPooling): pool_type.name = 'avg' Layer( name=name, type=LayerType.POOL_LAYER, inputs=[Input(input.name, pool=Pool( pool_type=pool_type.name + '-projection', channels=num_channels, size_x=pool_size, start=start, stride=stride, padding=padding ))], **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.POOL_LAYER, parents=[input], num_filters=num_channels) def __img_norm_layer__(name, input, size, norm_type, scale, power, num_channels, blocked, layer_attr): if num_channels is None: assert input.num_filters is not None num_channels = input.num_filters Layer( name=name, type=LayerType.NORM_LAYER, inputs=Input( input.name, norm=Norm(norm_type=norm_type, channels=num_channels, size=size, scale=scale, pow=power, blocked=blocked) ), **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, layer_type=LayerType.NORM_LAYER, parents=[input], num_filters=num_channels, img_norm_type=norm_type) @wrap_name_default("crmnorm") @layer_support() def img_cmrnorm_layer(input, size, scale, power, name=None, num_channels=None, blocked=0, layer_attr=None): """ Convolution cross-map-response-normalize layer. TODO(yuyang18): Add reference and equations, to explain why cmr is work? :param name: layer name. :type name: basestring :param input: layer's input. :type input: LayerOutput :param size: cross map response size. :type size: int :param scale: TODO(yuyang18) :type scale: float :param power: TODO(yuyang18) :type power: float :param num_channels: input layer's filers number or channels. If num_channels is None, it will be set automatically. :param blocked: TODO(yuyang18) :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :return: Layer's output :rtype: LayerOutput """ return __img_norm_layer__(name, input, size, "cmrnorm-projection", scale, power, num_channels, blocked, layer_attr) @wrap_name_default("rnorm") @layer_support() def img_rnorm_layer(input, size, scale, power, name=None, num_channels=None, layer_attr=None): """ TODO(yuyang18): add comments TODO(yuyang18): Why it is always not implemented whenever use_gpu or not? :param name: :param input: :param size: :param scale: :param power: :param num_channels: :param layer_attr: :return: """ return __img_norm_layer__(name, input, size, 'rnorm', scale, power, num_channels, 0, layer_attr) @wrap_bias_attr_default() @wrap_param_attr_default(default_factory=lambda _: ParamAttr(initial_mean=1.0, initial_std=0.)) @wrap_act_default(act=ReluActivation()) @wrap_name_default("batch_norm") @layer_support(DROPOUT) def batch_norm_layer(input, act=None, name=None, num_channels=None, bias_attr=None, param_attr=None, layer_attr=None, batch_norm_type=None, moving_average_fraction=0.9, use_global_stats=None): """ Batch Normalization Layer. The notation of this layer as follow. :math:`x` is the input features over a mini-batch. .. math:: \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ \ mini-batch\ mean \\\\ \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ \\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\ y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift The details of batch normalization please refer to this `paper `_. :param name: layer name. :type name: basestring :param input: batch normalization input. Better be linear activation. Because there is an activation inside batch_normalization. :type input: LayerOutput :param batch_norm_type: We have batch_norm and cudnn_batch_norm. batch_norm supports both CPU and GPU. cudnn_batch_norm requires cuDNN version greater or equal to v4 (>=v4). But cudnn_batch_norm is faster and needs less memory than batch_norm. By default (None), we will automaticly select cudnn_batch_norm for GPU and batch_norm for CPU. Otherwise, select batch norm type based on the specified type. If you use cudnn_batch_norm, we suggested you use latest version, such as v5.1. :type type: None|string, None or "batch_norm" or "cudnn_batch_norm" :param act: Activation Type. Better be relu. Because batch normalization will normalize input near zero. :type act: BaseActivation :param num_channels: num of image channels or previous layer's number of filters. None will automatically get from layer's input. :type num_channels: int :param bias_attr: :math:`\\beta`, better be zero when initialize. So the initial_std=0, initial_mean=1 is best practice. :type bias_attr: ParameterAttribute :param param_attr: :math:`\\gamma`, better be one when initialize. So the initial_std=0, initial_mean=1 is best practice. :type param_attr: ParameterAttribute :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :param use_global_stats: whether use moving mean/variance statistics during testing peroid. If None or True, it will use moving mean/variance statistics during testing. If False, it will use the mean and variance of current batch of test data for testing. :type use_global_stats: bool|None. :param moving_average_fraction: Factor used in the moving average computation, referred to as facotr, :math:`runningMean = newMean*(1-factor) + runningMean*factor` :type moving_average_fraction: float. :return: Layer's output :rtype: LayerOutput """ if not isinstance(act, ReluActivation): logger.log(logging.WARN, "%s is not recommend for batch normalization's activation, " "maybe the relu is better" % act.name) if not isinstance(input.activation, LinearActivation): logger.log(logging.WARN, "The activation should be inside batch normalization, the " "previous layer's activation may be Linear") if num_channels is None: if input.num_filters is not None: num_channels = input.num_filters else: num_channels = input.size assert (batch_norm_type is None) or (batch_norm_type == "batch_norm") or \ (batch_norm_type == "cudnn_batch_norm") Layer( name=name, inputs=Input(input.name, image=Image(channels=num_channels), **param_attr.attr), active_type=act.name, type=LayerType.BATCH_NORM_LAYER, batch_norm_type=batch_norm_type, bias=ParamAttr.to_bias(bias_attr), moving_average_fraction=moving_average_fraction, use_global_stats=use_global_stats, **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name=name, layer_type=LayerType.BATCH_NORM_LAYER, parents=[input], activation=act, num_filters=num_channels) @wrap_name_default() @layer_support() def sum_to_one_norm_layer(input, name=None, layer_attr=None): """ A layer for sum-to-one normalization, which is used in NEURAL TURING MACHINE. .. math:: out[i] = \\frac {in[i]} {\sum_{k=1}^N in[k]} where :math:`in` is a (batchSize x dataDim) input vector, and :math:`out` is a (batchSize x dataDim) output vector. The example usage is: .. code-block:: python sum_to_one_norm = sum_to_one_norm_layer(input=layer) :param input: Input layer. :type input: LayerOutput :param name: Layer name. :type name: basestring :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name. :rtype: LayerOutput """ Layer( name=name, type=LayerType.SUM_TO_ONE_NORM_LAYER, inputs=[input.name], **ExtraAttr.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.SUM_TO_ONE_NORM_LAYER, parents=[input], size=input.size) @wrap_name_default("addto") @wrap_act_default(act=LinearActivation()) @wrap_bias_attr_default(has_bias=False) @layer_support(DROPOUT) def addto_layer(input, act=None, name=None, bias_attr=None, layer_attr=None): """ AddtoLayer. .. math:: y = f(\\sum_{i} x_i + b) where :math:`y` is output, :math:`x` is input, :math:`b` is bias, and :math:`f` is activation function. The example usage is: .. code-block:: python addto = addto_layer(input=[layer1, layer2], act=ReluActivation(), bias_attr=False) This layer just simply add all input layers together, then activate the sum inputs. Each input of this layer should be the same size, which is also the output size of this layer. There is no weight matrix for each input, because it just a simple add operation. If you want to a complicated operation before add, please use mixed_layer. It is a very good way to set dropout outside the layers. Since not all paddle layer support dropout, you can add an add_to layer, set dropout here. Please refer to dropout_layer for details. :param name: Layer name. :type name: basestring :param input: Input layers. It could be a LayerOutput or list/tuple of LayerOutput. :type input: LayerOutput|list|tuple :param act: Activation Type, default is tanh. :type act: BaseActivation :param bias_attr: Bias attribute. If False, means no bias. None is default bias. :type bias_attr: ParameterAttribute|bool :param layer_attr: Extra Layer attribute. :type layer_attr: ExtraLayerAttribute :return: layer's output :rtype: LayerOutput """ num_filters = None if isinstance(input, LayerOutput): input = [input] assert isinstance(input, list) or isinstance(input, tuple) ipts_for_layer = [] for each_input in input: assert isinstance(each_input, LayerOutput) ipts_for_layer.append(Input(each_input.name)) if each_input.num_filters is not None: num_filters = each_input.num_filters Layer( name=name, type=LayerType.ADDTO_LAYER, inputs=ipts_for_layer, bias=ParamAttr.to_bias(bias_attr), active_type=act.name, **ExtraLayerAttribute.to_kwargs(layer_attr) ) assert isinstance(input, list) or isinstance(input, tuple) return LayerOutput(name, LayerType.ADDTO_LAYER, parents=input, activation=act, num_filters=num_filters) @wrap_act_default(act=IdentityActivation()) @wrap_name_default("concat") @layer_support() def concat_layer(input, act=None, name=None, layer_attr=None): """ Concat all input vector into one huge vector. Inputs can be list of LayerOutput or list of projection. :param name: Layer name. :type name: basestring :param input: input layers or projections :type input: list|tuple :param act: Activation type. :type act: BaseActivation :param layer_attr: Extra Layer Attribute. :type layer_attr: ExtraLayerAttribute :return: layer's output :rtype: LayerOutput """ if isinstance(input, LayerOutput): input = [input] elif isinstance(input, Projection): input = [input] else: assert isinstance(input, list) or isinstance(input, tuple) def __is_type__(o, tp): if not isinstance(o, list) and not isinstance(o, tuple): if o == tp: return True elif len(o.__bases__) == 0: return False else: for bs in o.__bases__: if __is_type__(bs, tp): return True return False else: tmp = map(lambda _x: __is_type__(_x, tp), o) a = tmp[0] for b in tmp[1:]: assert a == b return a def __reduce_concat_type__(a, b): assert __is_type__([a, b], Projection) or __is_type__([a, b], LayerOutput) return a is_concat_layer = __is_type__(reduce(__reduce_concat_type__, map(type, input)), LayerOutput) layer_type = (LayerType.CONCAT_LAYER if is_concat_layer else LayerType.CONCAT_PROJ_LAYER) Layer( name=name, type=layer_type, inputs=[x.name for x in input] if is_concat_layer else input, active_type=act.name, **ExtraLayerAttribute.to_kwargs(layer_attr) ) sz = 0 for each_input in input: if each_input.size is not None: sz += each_input.size else: sz = None break return LayerOutput(name, layer_type=layer_type, parents=input if is_concat_layer else [ x.origin for x in input], activation=act, size=sz) def memory(name, size, is_seq=False, boot_layer=None, boot_bias=None, boot_bias_active_type=None, boot_with_const_id=None): """ The memory layers is a layer cross each time step. Reference this output as previous time step layer :code:`name` 's output. The default memory is zero in first time step, previous time step's output in the rest time steps. If boot_bias, the first time step value is this bias and with activation. If boot_with_const_id, then the first time stop is a IndexSlot, the Arguments.ids()[0] is this :code:`cost_id`. If boot_layer is not null, the memory is just the boot_layer's output. Set :code:`is_seq` is true boot layer is sequence. The same name layer in recurrent group will set memory on each time step. :param name: memory's name. :type name: basestring :param size: size of memory. :type size: int :param is_seq: is sequence for boot_layer :type is_seq: bool :param boot_layer: boot layer of memory. :type boot_layer: LayerOutput|None :param boot_bias: boot layer's bias :type boot_bias: ParameterAttribute|None :param boot_bias_active_type: boot layer's active type. :type boot_bias_active_type: BaseActivation :param boot_with_const_id: boot layer's id. :type boot_with_const_id: int :return: Memory layer's output :rtype: LayerOutput """ if boot_bias_active_type is None: boot_bias_active_type = LinearActivation() assert boot_bias is None or isinstance(boot_bias, ParameterAttribute) if isinstance(boot_bias, ParameterAttribute): boot_bias = ParamAttr.to_bias(boot_bias) assert boot_layer is None or isinstance(boot_layer, LayerOutput) agent_name = Memory(name, size, is_seq, boot_layer.name if boot_layer is not None else None, boot_bias, boot_bias_active_type.name, boot_with_const_id) lout = LayerOutput(name=agent_name, size=size, layer_type=LayerType.MEMORY, parents=[boot_layer] if boot_layer is not None else None) return lout @wrap_bias_attr_default() @wrap_act_default(param_names=['gate_act', 'state_act'], act=SigmoidActivation()) @wrap_act_default(act=TanhActivation()) @wrap_name_default('lstm_step') @layer_support() def lstm_step_layer(input, state, size, act=None, name=None, gate_act=None, state_act=None, bias_attr=None, layer_attr=None): """ LSTM Step Layer. It used in recurrent_group. The lstm equations are shown as follow. .. math:: i_t = \\sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) f_t = \\sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) c_t = f_tc_{t-1} + i_t tanh (W_{xc}x_t+W_{hc}h_{t-1} + b_c) o_t = \\sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) h_t = o_t tanh(c_t) The input\_ of lstm step is :math:`Wx_t + Wh_{t-1}`, and user should use :code:`mixed_layer` and :code:`full_matrix_projection` to calculate these input vector. The state of lstm step is :math:`c_{t-1}`. And lstm step layer will do .. math:: i_t = \\sigma(input + W_{ci}c_{t-1} + b_i) ... This layer contains two outputs. Default output is :math:`h_t`. The other output is :math:`o_t`, which name is 'state' and can use :code:`get_output_layer` to extract this output. :param name: Layer's name. :type name: basestring :param size: Layer's size. NOTE: lstm layer's size, should be equal as :code:`input.size/4`, and should be equal as :code:`state.size`. :type size: int :param input: input layer. :math:`Wx_t + Wh_{t-1}` :type input: LayerOutput :param state: State Layer. :math:`c_{t-1}` :type state: LayerOutput :param act: Activation type. Default is tanh :type act: BaseActivation :param gate_act: Gate Activation Type. Default is sigmoid, and should be sigmoid only. :type gate_act: BaseActivation :param state_act: State Activation Type. Default is sigmoid, and should be sigmoid only. :type state_act: BaseActivation :param bias_attr: Bias Attribute. :type bias_attr: ParameterAttribute :param layer_attr: layer's extra attribute. :type layer_attr: ExtraLayerAttribute :return: lstm step's layer output :rtype: LayerOutput """ Layer( name=name, type=LayerType.LSTM_STEP_LAYER, active_type=act.name, active_gate_type=gate_act.name, active_state_type=state_act.name, bias=ParamAttr.to_bias(bias_attr), size=size, inputs=[input.name, state.name], **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name=name, layer_type=LayerType.LSTM_STEP_LAYER, parents=[input, state], activation=act, size=size, outputs=['default', 'state']) @wrap_bias_attr_default() @wrap_act_default(param_names=['gate_act'], act=SigmoidActivation()) @wrap_act_default(act=TanhActivation()) @wrap_name_default('gru_step') @layer_support() def gru_step_layer(input, output_mem, size=None, act=None, name=None, gate_act=None, bias_attr=None, layer_attr=None): """ :param input: :type input: LayerOutput :param output_mem: :param size: :param act: :param name: :param gate_act: :param bias_attr: :param layer_attr: :return: :rtype: LayerOutput """ assert input.size % 3 == 0 if size is None: size = input.size / 3 Layer( name=name, type=LayerType.GRU_STEP_LAYER, inputs=[ input.name, output_mem.name ], bias=ParamAttr.to_bias(bias_attr), size=size, active_type=act.name, active_gate_type=gate_act.name, **ExtraAttr.to_kwargs(layer_attr) ) return LayerOutput( name=name, layer_type=LayerType.GRU_STEP_LAYER, parents=[input, output_mem], size=size, activation=act) @wrap_name_default() @layer_support() def get_output_layer(input, arg_name, name=None, layer_attr=None): """ Get layer's output by name. In paddle, a layer might return multiple value, but return one layer output. If user want to reference another output beside default output, use get_output_layer first to get another output from input. :param name: Layer's name. :type name: basestring :param input: get output layer's input. And this layer should contains multiple outputs. :type input: LayerOutput :param arg_name: Output name from input. :type arg_name: basestring :param layer_attr: Layer's extra attribute. :return: Layer's output :rtype: LayerOutput """ # GetOutputLayer assert arg_name in input.outputs, 'Get Output From an not existed input.' \ ' The get output name is %s, which not' \ ' in %s' % ( arg_name, ",".join(input.outputs)) Layer(name=name, type=LayerType.GET_OUTPUT_LAYER, inputs=[Input(input.name, input_layer_argument=arg_name)], size=input.size, **ExtraLayerAttribute.to_kwargs(layer_attr)) return LayerOutput(name=name, layer_type=LayerType.GET_OUTPUT_LAYER, parents=[input], size=input.size) @wrap_name_default() @wrap_act_default() @wrap_bias_attr_default() @wrap_param_attr_default() @layer_support() def recurrent_layer(input, act=None, bias_attr=None, param_attr=None, name=None, layer_attr=None): """ TODO(yuyang18): Add docs :param input: :param size: :param act: :param bias_attr: :param param_attr: :param name: :param layer_attr: :return: """ Layer(name=name, type=LayerType.RECURRENT_LAYER, inputs=Input(input.name, **param_attr.attr), active_type=act.name, size=input.size, bias=ParamAttr.to_bias(bias_attr), **ExtraAttr.to_kwargs(layer_attr)) return LayerOutput(name=name, layer_type=LayerType.RECURRENT_LAYER, parents=[input], size=input.size, activation=act) class StaticInput(object): """ StaticInput is only used in recurrent_group which defines a read-only memory that can be a sequence or non-sequence. """ def __init__(self, input, is_seq=False, size=None): assert isinstance(input, LayerOutput) self.input = input self.is_seq = is_seq assert input.size is not None or size is not None if size is not None: input.size = size class SubsequenceInput(object): """ Input sequence has sub-sequence, used in recurrent_group. The example usage is: .. code-block:: python input = SubsequenceInput(layer) """ def __init__(self, input): assert isinstance(input, LayerOutput) assert input.size is not None self.input = input @wrap_name_default("recurrent_group") def recurrent_group(step, input, reverse=False, name=None): """ Recurrent Group. It supports time steps and sequence steps mechanisms. The basic usage (time steps) is: .. code-block:: python def step(input): output = fc_layer(input=layer, size=1024, act=LinearActivation(), bias_attr=False) return output group = recurrent_group(input=layer, step=step) You can see following configs for further usages: - time steps: lstmemory_group, paddle/gserver/tests/sequence_layer_group.conf, \ demo/seqToseq/seqToseq_net.py - sequence steps: paddle/gserver/tests/sequence_nest_layer_group.conf :param step: recurrent one time step function.The input of this function is input of the group. The return of this function will be recurrent group's return value. The recurrent group scatter a sequence into time steps. And for each time step, will invoke step function, and return a time step result. Then gather each time step of output into layer group's output. :type step: callable :param name: recurrent_group's name. :type name: basestring :param input: Input links array. LayerOutput will be scattered into time steps. SubsequenceInput will be scattered into sequence steps. StaticInput will be imported to each time step, and doesn't change through time. It's a mechanism to access layer outside step function. :type input: LayerOutput|StaticInput|SubsequenceInput|list|tuple :param reverse: Reverse is true, rnn will process sequence reversely. :type reverse: bool :return: Layer output object :rtype: LayerOutput """ model_type('recurrent_nn') def is_single_input(x): return isinstance(x, LayerOutput) or isinstance(x, StaticInput) \ or isinstance(x, SubsequenceInput) if is_single_input(input): input = [input] assert isinstance(input, list) or isinstance(input, tuple) def is_in_links(x): return isinstance(x, LayerOutput) or isinstance(x, SubsequenceInput) in_links = filter(is_in_links, input) contains_sub_seq = [False] def map_in_links(x): if isinstance(x, SubsequenceInput): contains_sub_seq[0] = True return Link(name=x.input.name, has_subseq=True) else: return x.name RecurrentLayerGroupWithoutOutLinksBegin( name=name, in_links=map(map_in_links, in_links), seq_reversed=reverse) in_args = [] for each_input in input: assert is_single_input(each_input) if isinstance(each_input, LayerOutput): in_args.append(each_input) elif isinstance(each_input, SubsequenceInput): in_args.append(each_input.input) else: mem_name = "__%s_memory__" % each_input.input.name mem = memory(name=mem_name, is_seq=each_input.is_seq, size=each_input.input.size, boot_layer=each_input.input) with mixed_layer(name=mem_name, size=each_input.input.size, act=IdentityActivation()) as mix: mix += identity_projection(mem) in_args.append(mem) layer_outs = step(*in_args) if isinstance(layer_outs, LayerOutput): layer_outs = [layer_outs] for ot in layer_outs: assert isinstance(ot, LayerOutput) if contains_sub_seq[0]: RecurrentLayerGroupSetOutLink(Link(ot.name, has_subseq=True)) else: RecurrentLayerGroupSetOutLink(ot.name) RecurrentLayerGroupEnd(name=name) if len(layer_outs) == 1: return layer_outs[0] else: return layer_outs class BaseGeneratedInput(object): def __init__(self): self.bos_id = None self.eos_id = None def before_real_step(self): raise NotImplementedError() def after_real_step(self, *args): raise NotImplementedError() class GeneratedInput(BaseGeneratedInput): def after_real_step(self, input): return maxid_layer(input=input, name='__beam_search_predict__') def before_real_step(self): predict_id = memory(name='__beam_search_predict__', size=self.size, boot_with_const_id=self.bos_id) trg_emb = embedding_layer(input=predict_id, size=self.embedding_size, param_attr=ParamAttr( name=self.embedding_name)) return trg_emb def __init__(self, size, embedding_name, embedding_size): self.size = size self.embedding_name = embedding_name self.embedding_size = embedding_size @wrap_name_default() def maxid_layer(input, name=None, layer_attr=None): """ A layer for finding the id which has the maximal value for each sample. The result is stored in output.ids. The example usage is: .. code-block:: python maxid = maxid_layer(input=layer) :param input: Input layer name. :type input: LayerOutput :param name: Layer name. :type name: basestring :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name. :rtype: LayerOutput """ assert isinstance(input, LayerOutput) Layer(name=name, type='maxid', inputs=[input.name], **ExtraLayerAttribute.to_kwargs(layer_attr)) return LayerOutput(name=name, layer_type=LayerType.MAXID_LAYER, parents=[input]) @wrap_name_default() def eos_layer(input, eos_id, name=None, layer_attr=None): """ A layer for checking EOS for each sample: - output_id = (input_id == conf.eos_id) The result is stored in output\_.ids. It is used by recurrent layer group. The example usage is: .. code-block:: python eos = eos_layer(input=layer, eos_id=id) :param input: Input layer name. :type input: LayerOutput :param eos_id: end id of sequence :type eos_id: int :param name: Layer name. :type name: basestring :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. :return: layer name. :rtype: LayerOutput """ Layer(name=name, type=LayerType.EOSID_LAYER, eos_id=eos_id, inputs=[input.name], **ExtraLayerAttribute.to_kwargs(layer_attr)) return LayerOutput(name=name, layer_type=LayerType.EOSID_LAYER, parents=[input]) @wrap_name_default() def beam_search(step, input, bos_id, eos_id, beam_size, result_file, dict_file="", id_input=None, max_length=500, name=None, num_results_per_sample=None): if num_results_per_sample is None: num_results_per_sample = beam_size if num_results_per_sample > beam_size: logger.warning("num_results_per_sample should be less than beam_size") if isinstance(input, StaticInput) or isinstance(input, BaseGeneratedInput): input = [input] generated_input_index = -1 real_input = [] for i, each_input in enumerate(input): # print type(each_input) assert isinstance(each_input, StaticInput) or isinstance(each_input, BaseGeneratedInput) if isinstance(each_input, BaseGeneratedInput): assert generated_input_index == -1 generated_input_index = i else: real_input.append(each_input) assert generated_input_index != -1 gipt = input[generated_input_index] assert isinstance(gipt, BaseGeneratedInput) gipt.bos_id = bos_id gipt.eos_id = eos_id def __real_step__(*args): eos_name = "__%s_eos_layer__" % name RecurrentLayerGroupSetGenerator(Generator( eos_layer_name=eos_name, max_num_frames=max_length, beam_size=beam_size, num_results_per_sample=num_results_per_sample)) args = list(args) args.insert(generated_input_index, gipt.before_real_step()) predict = gipt.after_real_step(step(*args)) eos_layer(input=predict, eos_id=eos_id, name=eos_name) return predict tmp = recurrent_group(step=__real_step__, input=real_input, reverse=False, name=name) if id_input is None: inputs = [tmp.name] else: assert isinstance(id_input, LayerOutput) inputs = [id_input.name, tmp.name] tmp.parents.append(id_input) Evaluator(name='target_printer', type='seq_text_printer', dict_file=dict_file, result_file=result_file, inputs=inputs ) return tmp @wrap_name_default() def regression_cost(input, label, cost='square_error', name=None): """ Regression Layer. TODO(yuyang18): Complete this method. :param name: layer name. :param input: Network prediction. :param label: Data label. :param cost: Cost method. :return: layer name. """ Layer(inputs=[Input(input.name), Input(label.name)], type=cost, name=name) return LayerOutput( name, LayerType.COST, parents=[input, label] ) @wrap_name_default("cost") def classification_cost(input, label, name=None, cost="multi-class-cross-entropy", evaluator=classification_error_evaluator): """ classification cost Layer. :param name: layer name. :type name: basestring :param input: input layer name. network output. :type input: LayerOutput :param label: label layer name. data_layer often. :type label: LayerOutput :param cost: cost method. :type cost: basestring :param evaluator: Evaluator method. :return: layer name. :rtype: LayerOutput """ assert input.layer_type != LayerType.DATA assert isinstance(input.activation, SoftmaxActivation) assert label.layer_type == LayerType.DATA Layer(name=name, type=cost, inputs=[Input(input.name), Input(label.name)]) def __add_evaluator__(e): assert callable(e) assert hasattr(e, 'is_evaluator') assert isinstance(e.is_evaluator, bool) assert e.is_evaluator assert hasattr(e, "for_classification") assert isinstance(e.for_classification, bool) assert e.for_classification e(name=e.__name__, input=input, label=label) if not isinstance(evaluator, list) and not isinstance(evaluator, tuple): evaluator = [evaluator] for each_evaluator in evaluator: __add_evaluator__(each_evaluator) return LayerOutput(name, LayerType.COST, parents=[input, label]) def conv_operator(input, filter_size, num_filters, num_channel=None, stride=1, padding=0, filter_size_y=None, stride_y=None, padding_y=None): """ Different from img_conv_layer, conv_op is an Operator, which can be used in mixed_layer. And conv_op takes two inputs to perform convolution. The first input is the image and the second is filter kernel. It only support GPU mode. The example usage is: .. code-block:: python op = conv_operator(input=[layer1, layer2], filter_size=3.0, num_filters=64, num_channels=64) :param input: Input layer. :type input: LayerOutput|list|tuple :param filter_size: The x dimension of a filter kernel. :type filter_size: int :param filter_size_y: The y dimension of a filter kernel. Since paddle now support rectangular filters, the filter's shape will be (filter_size, filter_size_y). :type filter_size_y: int :param num_filter: channel of output data. :type num_filter: int :param num_channel: channel of input data. :rtype num_channel: int :param stride: The x dimension of the stride. :rtype stride: int :param stride_y: The y dimension of the stride. :rtype stride_y: int :param padding: The x dimension of padding. :type padding: int :param padding_y: The y dimension of padding. :type padding_y: int :return: A ConvOperator Object. :rtype: ConvOperator """ assert isinstance(input, list) or isinstance(input, tuple) if filter_size_y is None: filter_size_y = filter_size if stride_y is None: stride_y = stride if padding_y is None: padding_y = padding op = ConvOperator(input_layer_name=[x.name for x in input], num_filters = num_filter, conv_conf=Conv(filter_size=filter_size, padding=padding, stride=stride, channels=num_channel, filter_size_y=filter_size_y, padding_y=padding_y, stride_y=stride_y)) op.origin = input op.origin.operator = "conv_op" return op @wrap_name_default() def conv_shift_layer(input, name=None): """ This layer performs cyclic convolution for two input. For example: - a[in]: contains M elements. - b[in]: contains N elements (N should be odd). - c[out]: contains M elements. .. math:: c[i] = \sum_{j=-(N-1)/2}^{(N-1)/2}a_{i+j} * b_{j} In this formular: - a's index is computed modulo M. - b's index is computed modulo N. The example usage is: .. code-block:: python conv_shift = conv_shif_layer(input=[layer1, layer2]) :param name: layer name :type name: basestring :param input: Input layer. :type input: LayerOutput|list|tuple. :return: a object of LayerOutput. :rtype: LayerOutput """ assert isinstance(input, list) or isinstance(input, tuple) Layer( name=name, type=LayerType.CONV_SHIFT_LAYER, inputs=[x.name for x in input], ) return LayerOutput(name, LayerType.CONV_SHIFT_LAYER, parents=input) @wrap_name_default() @wrap_param_attr_default() @wrap_bias_attr_default() @layer_support(ERROR_CLIPPING, DROPOUT) def tensor_layer(input, size, act=None, name=None, param_attr=None, bias_attr=None, layer_attr=None): """ This layer performs tensor operation for two input. For example, each sample: .. math:: y_{i} = x_{1} * W_{i} * {x_{2}^\mathrm{T}}, i=0,1,...,K-1 In this formular: - :math:`x_{1}`: the first input contains M elements. - :math:`x_{2}`: the second input contains N elements. - y[out]: contains K elements. - :math:`y_{i}`: the i-th element of y. - :math:`W_{i}`: the i-th learned weight, shape if [M, N] - :math:`{x_{2}}^\mathrm{T}`: the transpose of :math:`x_{2}`. The simple usage is: .. code-block:: python tensor = tensor_layer(input=[layer1, layer2]) :param name: layer name :type name: basestring :param input: Input layer. :type input: LayerOutput|list|tuple. :param size: the layer dimension. :rtype: int. :param act: Activation Type. Default is tanh. :type act: BaseActivation :param param_attr: The Parameter Attribute. :type param_attr: ParameterAttribute|list :param bias_attr: The Bias Attribute. If no bias, then pass False or something not type of ParameterAttribute. None will get a default Bias. :type bias_attr: ParameterAttribute|None|Any :param layer_attr: Extra Layer config. :type layer_attr: ExtraLayerAttribute|None :return: a object of LayerOutput. :rtype: LayerOutput """ assert isinstance(input, list) or isinstance(input, tuple) assert len(input) == 2 Layer( name=name, size=size, type=LayerType.TENSOR_LAYER, active_type=act.name, bias=ParamAttr.to_bias(bias_attr), inputs=[Input(input[0].name, **param_attr), Input(input[1].name)], **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.TENSOR_LAYER, parents=input, activation=act, size=size) @wrap_param_attr_default() def trans_full_matrix_projection(input, size=0, param_attr=None): """ Different from full_matrix_projection, this projection performs matrix multiplication, using transpose of weight. .. math:: out.row[i] += in.row[i] * w^\mathrm{T} :math:`w^\mathrm{T}` means transpose of weight. The simply usage is: .. code-block:: python proj = trans_full_matrix_projection(input=layer, size=100, param_attr=ParamAttr( name='_proj', initial_mean=0.0, initial_std=0.01)) :param input: input layer :type input: LayerOutput :param size: The parameter size. Means the width of parameter. :type size: int :param param_attr: Parameter config, None if use default. :type param_attr: ParameterAttribute :return: A TransposedFullMatrixProjection Object. :rtype: TransposedFullMatrixProjection """ proj = TransposedFullMatrixProjection(input_layer_name=input.name, size=size, **param_attr.attr) proj.origin = input proj.origin.projection = "trans_matrix" return proj @wrap_name_default() @wrap_param_attr_default() @wrap_bias_attr_default() @wrap_act_default() def selective_fc_layer(input, size, act=None, name=None, pass_generation=False, has_selected_colums=True, mul_ratio=0.02, param_attr=None, bias_attr=None, layer_attr=None): """ Selectived fully connected layer. Different from fc_layer, the output of this layer maybe sparse. It requires an additional input to indicate several selected columns for output. If the selected columns is not specified, selective_fc_layer acts exactly like fc_layer. The simple usage is: .. code-block:: python sel_fc = selective_fc_layer(input=input, 128, act=TanhActivation()) :param name: The Layer Name. :type name: basestring :param input: The input layer. :type input: LayerOutput|list|tuple :param size: The layer dimension. :type size: int :param act: Activation Type. Default is tanh. :type act: BaseActivation :param param_attr: The Parameter Attribute. :type param_attr: ParameterAttribute :param bias_attr: The Bias Attribute. If no bias, then pass False or something not type of ParameterAttribute. None will get a default Bias. :type bias_attr: ParameterAttribute|None|Any :param layer_attr: Extra Layer config. :type layer_attr: ExtraLayerAttribute|None :return: a object of LayerOutput. :rtype: LayerOutput """ if isinstance(input, LayerOutput): input = [input] assert not isinstance(param_attr, list) param_attr = [param_attr] else: if isinstance(param_attr, list) or isinstance(param_attr, tuple): assert len(input) == len(param_attr) else: param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))] assert isinstance(input, list) def __idx_to_input__(i): attr = param_attr[i] assert isinstance(attr, ParameterAttribute) return Input(input[i].name, **attr.attr) Layer( inputs=map(__idx_to_input__, range(len(input))), name=name, type=LayerType.SEL_FC_LAYER, size=size, active_type=act.name, selective_fc_pass_generation=pass_generation, has_selected_colums=has_selected_colums, selective_fc_full_mul_ratio=mul_ratio, **ExtraLayerAttribute.to_kwargs(layer_attr) ) return LayerOutput(name, LayerType.SEL_FC_LAYER, input, activation=act, size=size) @wrap_name_default() def sampling_id_layer(input, name=None): """ A layer for sampling id from multinomial distribution from the input layer. Sampling one id for one sample. The simple usage is: .. code-block:: python samping_id = sampling_id_layer(input=input) :param input: The input layer. :type input: LayerOutput :param name: The Layer Name. :type name: basestring :return: a object of LayerOutput. :rtype: LayerOutput """ Layer( name=name, type=LayerType.SAMPLING_ID_LAYER, inputs=[Input(input.name)], ) return LayerOutput(name, LayerType.SAMPLING_ID_LAYER, input) @wrap_name_default() def slope_intercept_layer(input, name=None, slope=1.0, intercept=0.0): """ This layer for applying a slope and an intercept to the input element-wise. There is no activation and weight. .. math:: y = slope * x + intercept The simple usage is: .. code-block:: python scale = slope_intercept_layer(input=input, slope=-1.0, intercept=1.0) :param input: The input layer. :type input: LayerOutput :param name: The Layer Name. :type name: basestring :param slope: the scale factor. :type slope: float. :param intercept: the offset. :type intercept: float. :return: a object of LayerOutput. :rtype: LayerOutput """ Layer( name=name, type=LayerType.SLOPE_INTERCEPT_LAYER, slope=slope, intercept=intercept, inputs=[Input(input.name)], ) return LayerOutput(name, LayerType.SLOPE_INTERCEPT_LAYER, input) @wrap_name_default() def convex_comb_layer(input, size, name=None): """ A layer for convex weighted average of vectors takes two inputs. - Input: a vector containing the convex weights (batchSize x weightdim), and a matrix in a vector form (batchSize x (weightdim*datadim)). - Output: a vector (batchSize * datadim). .. math:: y[i][j] = \sum_{j}(x_{1}(i, j) * x_{2}(i,j + i * dataDim)), i = 0,1,...,(batchSize-1); j = 0, 1,...,(dataDim-1) In this formular: - :math:`x_{1}`: the first input. - :math:`x_{2}`: the second input. - :math:`y`: the output. The simple usage is: .. code-block:: python convex_comb = convex_comb_layer(input=inputs, size=elem_dim) :param input: The input layers. :type input: LayerOutput :param size: the dimension of this layer. :type size: int :param name: The Layer Name. :type name: basestring :return: a object of LayerOutput. :rtype: LayerOutput """ assert isinstance(input, list) or isinstance(input, tuple) assert len(input) == 2 Layer( name=name, type=LayerType.CONVEX_COMBINATION_LAYER, size=size, inputs=[Input(input[0].name), Input(input[1].name)], ) return LayerOutput(name, LayerType.CONVEX_COMBINATION_LAYER, input, size=size) @wrap_name_default() def block_expand_layer(input, channel=0, block_x=0, block_y=0, stride_x=0, stride_y=0, padding_x=0, padding_y=0, name=None): """ Expand feature map to minibatch matrix. - matrix width is: block_y * block_x * channel - matirx height is: outputH * outputW .. math:: outputH = 1 + (2 * padding_y + imgSizeH - block_y + stride_y - 1) / stride_y outputW = 1 + (2 * padding_x + imgSizeW - block_x + stride_x - 1) / stride_x The expand method is the same with ExpandConvLayer, but saved the transposed value. After expanding, output.sequenceStartPositions will store timeline. The number of time steps are outputH * outputW and the dimension of each time step is block_y * block_x * channel. This layer can be used after convolution neural network, and before recurrent neural network. :param input: The input layer. :type input: LayerOutput :param channel: The channel number of input layer. :type channel: int :param block_x: The width of sub block. :type block_x: int :param block_y: The width of sub block. :type block_y: int :param stride_x: The stride size in horizontal direction. :type stride_x: int :param stride_y: The stride size in vertical direction. :type stride_y: int :param padding_x: The padding size in horizontal direction. :type padding_x: int :param padding_y: The padding size in vertical direction. :type padding_y: int :param name: The name of this layer, which can not specify. :type name: None|basestring. :return: a object of LayerOutput. :rtype: LayerOutput """ Layer(name=name, input=Input(input.name, block_expand=BlockExpand(channel=channel, block_x=block_x, block_y=block_y, stride_x=stride_x, stride_y=stride_y, padding_x=padding_x, padding_y=padding_y) ), type=LayerType.BLOCK_EXPAND, ) return LayerOutput(name, LayerType.BLOCK_EXPAND, parents=[input], size=size) @wrap_name_default() def ctc_layer(input, label, size, name=None, norm_by_times=False): """ Connectionist Temporal Classification (CTC) is designed for temporal classication task. That is, for sequence labeling problems where the alignment between the inputs and the target labels is unknown. The simple usage: .. code-block:: python ctc = ctc_layer(input=input, label=label, size=9055, norm_by_times=True) :param input: The input layers. :type input: LayerOutput :param label: The data layer of label with variable length. :type label: LayerOutput :param size: category numbers. :type size: int :param name: The name of this layer, which can not specify. :type name: string|None :param norm_by_times: Whether to normalization by times. False by default. :type norm_by_times: bool :return: a object of LayerOutput. :rtype: LayerOutput """ assert isinstance(input, LayerOutput) assert isinstance(label, LayerOutput) Layer( name = name, type = LayerType.CTC_LAYER, size = size, norm_by_times = norm_by_times, inputs = [input.name, label.name] ) return LayerOutput(name, LayerType.CTC_LAYER, [input, label], size=size) @wrap_name_default() def crf_layer(input, label, size, weight=None, param_attr=None, name=None): """ A layer for calculating the cost of sequential conditional random field model. The simple usage: .. code-block:: python crf = crf_layer(input=input, label=label, size=label_dim) :param input: The first input layer is the feature. :type input: LayerOutput :param label: The second input layer is label. :type input: LayerOutput :param size: The category number. :type size: int :param weight: The third layer is "weight" of each sample, which is an optional argument. :type weight: LayerOutput :param param_attr: Parameter attribute. None means default attribute :type param_attr: ParameterAttribute :param name: The name of this layers. It is not necessary. :type name: None|basestring :return: a object of LayerOutput. :rtype: LayerOutput """ assert isinstance(input, LayerOutput) assert isinstance(label, LayerOutput) assert weight is None or isinstance(weight, LayerOutput) ipts = [Input(input.name, **param_attr), Input(label.name)] if weight is not None: ipts.append(Input(weight.name)) Layer( name = name, type = LayerType.CRF_LAYER, size = size, inputs = ipts, ) parents = [input, label] if weight is not None: parents.append(weight) return LayerOutput(name, LayerType.CRF_LAYER, parents, size=size) @wrap_name_default() def crf_decoding_layer(input, size, label=None, param_attr=None, name=None): """ A layer for calculating the decoding sequence of sequential conditional random field model. The decoding sequence is stored in output.ids. If a second input is provided, it is treated as the ground-truth label, and this layer will also calculate error. output.value[i] is 1 for incorrect decoding or 0 for correct decoding. :param input: The first input layer. :type input: LayerOutput :param size: size of this layer. :type size: int :param label: None or ground-truth label. :type label: LayerOutput or None :param param_attr: Parameter attribute. None means default attribute :type param_attr: ParameterAttribute :param name: The name of this layers. It is not necessary. :type name: None|basestring :return: a object of LayerOutput. :rtype: LayerOutput """ assert isinstance(input, LayerOutput) assert label is None or isinstance(label, LayerOutput) ipts = [Input(input.name, **param_attr)] if label is not None: ipts.append(Input(label.name)) Layer( name = name, type = LayerType.CRF_DECODING_LAYER, size = size, inputs = ipts, ) parents = [input] if label is not None: parents.append(label) return LayerOutput(name, LayerType.CRF_DECODING_LAYER, parents, size=size) """ following are cost Layers. """ @wrap_name_default() def rank_cost(left, right, lable, weight=None, name=None, coeff=1.0): """ A cost Layer for learning to rank using gradient descent. Details can refer to `papers `_. This layer contains at least three inputs. The weight is an optional argument, which affects the cost. .. math:: C_{i,j} = -\\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) o_{i,j} = o_i - o_j \\tilde{P_{i,j}} = \\{0, 0.5, 1\\} \ or \ \\{0, 1\\} In this formula: - :math:`C_{i,j}` is the cross entropy cost. - :math:`\\tilde{P_{i,j}}` is the label. 1 means positive order and 0 means reverse order. - :math:`o_i` and :math:`o_j`: the left output and right output. Their dimension is one. The simple usage: .. code-block:: python cost = rank_cost(left=out_left, right=out_right, label=label) :param left: The first input, the size of this layer is 1. :type left: LayerOutput :param right: The right input, the size of this layer is 1. :type right: LayerOutput :param label: Label is 1 or 0, means positive order and reverse order. :type label: LayerOutput :param weight: The weight affects the cost, namely the scale of cost. It is an optional argument. :type weight: LayerOutput :param name: The name of this layers. It is not necessary. :type name: None|basestring :param coeff: The coefficient affects the gradient in the backward. :type coeff: float :return: a object of LayerOutput. :rtype: LayerOutput """ assert left.size == 1 assert right.size == 1 assert label.size == 1 ipts = [left.name, right.name, label.name] parents = [left, right, label] if weight is not None: ipts.append(weight.name) parents.append(weight) Layer(name=name, type=LayerType.RANK_COST, inputs=ipts, coeff=coeff, ) return LayerOutput(name, LayerType.RANK_COST, parents=parents) @wrap_name_default() def lambda_cost(input, score, NDCG_num=5, max_sort_size=-1, coeff=1.0): """ lambdaCost for lambdaRank LTR approach. The simple usage: .. code-block:: python cost = lambda_cost(input=input, score=score, NDCG_num=8, max_sort_size=-1) :param input: The 1st input. Samples of the same query should be loaded as sequence. User should provided socres for each sample. The score should be the 2nd input of this layer. :type input: LayerOutput :param score: The 2nd input. Score of each sample. :type input: LayerOutput :param NDCG_num: The size of NDCG (Normalized Discounted Cumulative Gain), e.g., 5 for NDCG@5. It must be less than for equal to the minimum size of lists. :type NDCG_num: int :param max_sort_size: The size of partial sorting in calculating gradient. If max_sort_size = -1, then for each list, the algorithm will sort the entire list to get gradient. In other cases, max_sort_size must be greater than or equal to NDCG_num. And if max_sort_size is greater than the size of a list, the algorithm will sort the entire list of get gradient. :type max_sort_size: int :param name: The name of this layers. It is not necessary. :type name: None|basestring :param coeff: The coefficient affects the gradient in the backward. :type coeff: float :return: a object of LayerOutput. :rtype: LayerOutput """ Layer(name=name, type=LayerType.LAMBDA_COST, inputs=[input.name, score.name], NDCG_num=NDCG_num, max_sort_size=max_sort_size, coeff=coeff, ) return LayerOutput(name, LayerType.LAMBDA_COST, parents=[input, score]) @wrap_name_default() def cross_entropy(input, label, name=None, coeff=1.0): """ A loss layer for multi class entropy. .. code-block:: python cost = cross_entropy(input, label) :param input: The first input layer. :type input: LayerOutput. :param label: The input label. :type input: LayerOutput. :param type: The type of cost. :type type: basestring. :param name: The name of this layers. It is not necessary. :type name: None|basestring. :param coeff: The coefficient affects the gradient in the backward. :type coeff: float. :return: a object of LayerOutput. :rtype: LayerOutput. """ Layer(name=name, type=LayerType.CROSS_ENTROPY, inputs=[input.name, label.name], coeff=coeff, ) return LayerOutput(name, LayerType.CROSS_ENTROPY, parents=[input, label]) @wrap_name_default() def cross_entropy_with_selfnorm(input, label, name=None, coeff=1.0, softmax_selfnorm_alpha=0.1): """ A loss layer for multi class entropy with selfnorm. .. code-block:: python cost = cross_entropy_with_selfnorm(input, label) :param input: The first input layer. :type input: LayerOutput. :param label: The input label. :type input: LayerOutput. :param type: The type of cost. :type type: basestring. :param name: The name of this layers. It is not necessary. :type name: None|basestring. :param coeff: The coefficient affects the gradient in the backward. :type coeff: float. :param softmax_selfnorm_alpha: The scale factor affects the cost. :type softmax_selfnorm_alpha: float. :return: a object of LayerOutput. :rtype: LayerOutput. """ Layer(name=name, type=LayerType.CROSS_ENTROPY_WITH_SELFNORM, inputs=[input.name, label.name], coeff=coeff, softmax_selfnorm_alpha=softmax_selfnorm_alpha, ) return LayerOutput(name, LayerType.CROSS_ENTROPY_WITH_SELFNORM, parents=[input, label]) @wrap_name_default() def huber_cost(input, label, name=None, coeff=1.0): """ A loss layer for huber loss. .. code-block:: python cost = huber_cost(input, label) :param input: The first input layer. :type input: LayerOutput. :param label: The input label. :type input: LayerOutput. :param type: The type of cost. :type type: basestring. :param name: The name of this layers. It is not necessary. :type name: None|basestring. :param coeff: The coefficient affects the gradient in the backward. :type coeff: float. :return: a object of LayerOutput. :rtype: LayerOutput. """ Layer(name=name, type=LayerType.HUBER, inputs=[input.name, label.name], coeff=coeff, ) return LayerOutput(name, LayerType.HUBER, parents=[input, label]) @wrap_name_default() def multi_binary_label_cross_entropy(input, label, name=None, coeff=1.0): """ A loss layer for multi binary label cross entropy. .. code-block:: python cost = multi_binary_label_cross_entropy(input, label) :param input: The first input layer. :type input: LayerOutput :param label: The input label. :type input: LayerOutput :param type: The type of cost. :type type: basestring :param name: The name of this layers. It is not necessary. :type name: None|basestring :param coeff: The coefficient affects the gradient in the backward. :type coeff: float :return: a object of LayerOutput. :rtype: LayerOutput """ if not isinstance(input.act, SigmoidActivation): logger.log(logging.WARN, "%s is not recommend for batch normalization's activation, " "maybe the relu is better" % act.name) Layer(name=name, type=LayerType.MULTI_BIN_LABEL_CROSS_ENTROPY, inputs=[input.name, label.name], coeff=coeff, ) return LayerOutput(name, LayerType.MULTI_BIN_LABEL_CROSS_ENTROPY, parents=[input, label])