import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
    Operator
from paddle.v2.framework.initializer import ConstantInitializer, \
    NormalInitializer
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
import re

__all__ = [
    'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
    'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums', 'cos_sim',
    'batch_norm', 'accuracy'
]


def fc(input,
       size,
       param_attr=None,
       bias_attr=True,
       name=None,
       act=None,
       num_flatten_dims=1,
       main_program=None,
       startup_program=None):
    # create helper
    helper = LayerHelper('fc', **locals())

    dtype = helper.input_dtype()

    # mul
    mul_results = []
    for input_var, param_attr in helper.iter_inputs_and_params():
        input_shape = input_var.shape
        param_shape = [
            reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
        ] + [size]
        w = helper.create_parameter(
            attr=param_attr, shape=param_shape, dtype=dtype)
        tmp = helper.create_tmp_variable(dtype)
        helper.append_op(
            type="mul",
            inputs={
                "X": input_var,
                "Y": w,
            },
            outputs={"Out": tmp},
            attrs={'x_num_col_dims': num_flatten_dims,
                   'y_num_col_dims': 1})
        mul_results.append(tmp)

    # sum
    if len(mul_results) == 1:
        pre_bias = mul_results[0]
    else:
        pre_bias = helper.create_tmp_variable(dtype)
        helper.append_op(
            type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias})
    # add bias
    pre_activation = helper.append_bias_op(pre_bias)
    # add activation
    return helper.append_activation(pre_activation)

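# Illustrative usage of fc (a sketch, not taken from the tests; `image` and the
# 'relu' activation string are assumptions about the caller's program context):
#
#     image = data(name='pixel', shape=[784], data_type='float32')
#     hidden = fc(input=image, size=128, act='relu')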

def embedding(input,
              size,
              data_type='float32',
              is_sparse=False,
              param_attr=None,
              main_program=None,
              startup_program=None):
    helper = LayerHelper('embedding', **locals())
    w = helper.create_parameter(
        attr=helper.param_attr, shape=size, dtype=data_type)
    tmp = helper.create_tmp_variable(data_type)
    helper.append_op(
        type='lookup_table',
        inputs={'Ids': input,
                'W': w},
        outputs={'Out': tmp},
        attrs={'is_sparse': is_sparse})
    return tmp

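# Illustrative usage of embedding (a sketch; `word_ids`, the vocabulary size
# and the embedding width are placeholder assumptions):
#
#     word_ids = data(name='word', shape=[1], data_type='int64')
#     emb = embedding(input=word_ids, size=[10000, 32], is_sparse=True)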

def data(name,
         shape,
         data_type='float32',
         type=core.VarDesc.VarType.LOD_TENSOR,
         append_batch_size=True,
         main_program=None,
         startup_program=None,
         stop_gradient=True):
    helper = LayerHelper('data', **locals())
    shape = list(shape)
    for i in xrange(len(shape)):
        if shape[i] is None:
            shape[i] = -1
            append_batch_size = False
        elif shape[i] < 0:
            append_batch_size = False

    if append_batch_size:
        shape = [-1] + shape  # append batch size as -1

    return helper.create_global_variable(
        name=name,
        shape=shape,
        dtype=data_type,
        type=type,
        stop_gradient=stop_gradient)


def _convert_(name):
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


def _create_op_func_(op_type):
    op_proto = OpProtoHolder.instance().get_op_proto(op_type)
    not_intermediate_outputs = \
        filter(lambda output: not output.intermediate, op_proto.outputs)
    intermediate_outputs = \
        filter(lambda output: output.intermediate, op_proto.outputs)

    if len(not_intermediate_outputs) != 1:
        raise ValueError(
            "Only one non-intermediate output operator can be automatically generated"
        )

    if not_intermediate_outputs[0].duplicable:
        raise ValueError(
            "Only a non-duplicable op can be automatically generated")

    for output in intermediate_outputs:
        if output.duplicable:
            raise ValueError(
                "Only when all intermediate outputs are non-duplicable can "
                "this op be automatically generated")

    o_name = not_intermediate_outputs[0].name
    intermediate_output_names = [output.name for output in intermediate_outputs]

    def infer_and_check_data_type(op_proto, **kwargs):
        dtype = None
        for ipt in op_proto.inputs:
            name = _convert_(ipt.name)
            val = kwargs.pop(name, [])
            if not isinstance(val, list) and not isinstance(val, tuple):
                val = [val]
            for each in val:
                if not isinstance(each, Variable):
                    raise ValueError("input of {0} must be variable".format(
                        op_type))

                if dtype is None:
                    dtype = each.data_type
                elif dtype != each.data_type:
                    raise ValueError(
                        "operator {0} must input same dtype".format(op_type))

        return dtype

    def func(**kwargs):
        helper = LayerHelper(op_type, **kwargs)

        dtype = infer_and_check_data_type(op_proto, **kwargs)

        inputs = dict()
        for ipt in op_proto.inputs:
            name = _convert_(ipt.name)
            val = kwargs.pop(name, [])
            if not isinstance(val, list) and not isinstance(val, tuple):
                val = [val]
            inputs[ipt.name] = val

        outputs = dict()
        out = helper.create_tmp_variable(dtype=dtype)
        outputs[o_name] = [out]
        for name in intermediate_output_names:
            outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
        helper.append_op(
            type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
        return helper.append_activation(out)

    func.__name__ = op_type
    globals()[op_type] = func
    global __all__
    __all__.append(op_type)

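# Each op registered below gets a module-level wrapper whose keyword arguments
# are the op's inputs (renamed to snake_case by _convert_) plus any remaining
# attributes forwarded through `attrs`. A rough sketch of calling such a
# wrapper (`some_var` is a placeholder Variable):
#
#     avg = mean(x=some_var)
#     doubled = scale(x=some_var, scale=2.0)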

_create_op_func_('mean')
_create_op_func_('mul')
_create_op_func_('elementwise_add')
_create_op_func_('dropout')
_create_op_func_('reshape')
_create_op_func_('sigmoid')
_create_op_func_('scale')
_create_op_func_('transpose')


def cast(x, data_type, main_program=None):
    helper = LayerHelper('cast', **locals())
    out = helper.create_tmp_variable(dtype=data_type)
    helper.append_op(
        type='cast',
        inputs={'X': [x]},
        outputs={'Out': [out]},
        attrs={'in_data_type': x.data_type,
               'out_data_type': out.data_type})
    return out


def concat(input, axis, main_program=None, startup_program=None):
    helper = LayerHelper('concat', **locals())
    out = helper.create_tmp_variable(dtype=helper.input_dtype())
    helper.append_op(
        type='concat',
        inputs={'X': input},
        outputs={'Out': [out]},
        attrs={'axis': axis})
    return out


def sums(input, main_program=None, startup_program=None):
    helper = LayerHelper('sum', **locals())
    out = helper.create_tmp_variable(dtype=helper.input_dtype())
    helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
    return out


def cos_sim(X, Y, **kwargs):
    helper = LayerHelper('cos_sim', **kwargs)
    out = helper.create_tmp_variable(dtype=X.data_type)
    xnorm = helper.create_tmp_variable(dtype=X.data_type)
    ynorm = helper.create_tmp_variable(dtype=X.data_type)
    helper.append_op(
        type='cos_sim',
        inputs={'X': [X],
                'Y': [Y]},
        outputs={'Out': [out],
                 'XNorm': [xnorm],
                 'YNorm': [ynorm]})
    return out


def cross_entropy(input, label, **kwargs):
    helper = LayerHelper('cross_entropy', **kwargs)
    out = helper.create_tmp_variable(dtype=input.data_type)
    helper.append_op(
        type='cross_entropy',
        inputs={'X': [input],
                'Label': [label]},
        outputs={'Y': [out]},
        attrs=kwargs)
    return out


def square_error_cost(input, label, **kwargs):
    helper = LayerHelper('square_error_cost', **kwargs)
    minus_out = helper.create_tmp_variable(dtype=input.data_type)
    helper.append_op(
        type='elementwise_sub',
        inputs={'X': [input],
                'Y': [label]},
        outputs={'Out': [minus_out]})

    square_out = helper.create_tmp_variable(dtype=input.data_type)
    helper.append_op(
        type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]})
    return square_out


def accuracy(input, label, k=1, **kwargs):
    helper = LayerHelper("accuracy", **kwargs)
    topk_out = helper.create_tmp_variable(dtype=input.data_type)
    topk_indices = helper.create_tmp_variable(dtype="int64")
    helper.append_op(
        type="top_k",
        inputs={"X": [input]},
        outputs={"Out": [topk_out],
                 "Indices": [topk_indices]},
        attrs={"k": k})
    acc_out_dtype = kwargs.get("out_dtype", "float32")
    acc_out = helper.create_tmp_variable(dtype=acc_out_dtype)
    helper.append_op(
        type="accuracy",
        inputs={
            "Out": [topk_out],
            "Indices": [topk_indices],
            "Label": [label]
        },
        outputs={"Accuracy": [acc_out]})
    return acc_out

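# Sketch of how these layers are commonly combined into a classification tail
# (`predict` and `label` are placeholder layers, not taken from the tests):
#
#     cost = cross_entropy(input=predict, label=label)
#     avg_cost = mean(x=cost)
#     acc = accuracy(input=predict, label=label, k=1)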

def sequence_conv(input,
                  num_filters,
                  filter_size=3,
                  filter_stride=1,
                  act=None,
                  padding=None,
                  bias_attr=None,
                  param_attr=None,
                  main_program=None,
                  startup_program=None):
    # FIXME(dzh): want to unify the argument of python layer
    # functions. So we ignore some unnecessary attributes,
    # such as padding_trainable and context_start.

    helper = LayerHelper('sequence_conv', **locals())
    dtype = helper.input_dtype()

    filter_shape = [filter_size * input.shape[1], num_filters]
    filter = helper.create_parameter(
        attr=helper.param_attr, shape=filter_shape, dtype=dtype)
    pre_bias = helper.create_tmp_variable(dtype)

    helper.append_op(
        type='sequence_conv',
        inputs={
            'X': [input],
            'Filter': [filter],
        },
        outputs={"Out": pre_bias},
        attrs={
            'contextStride': filter_stride,
            'contextStart': -int(filter_size / 2),
            'contextLength': filter_size
        })
    pre_act = helper.append_bias_op(pre_bias)
    return helper.append_activation(pre_act)


def conv2d(input,
           num_filters,
           name=None,
           filter_size=[1, 1],
           act=None,
           groups=None,
           stride=[1, 1],
           padding=None,
           bias_attr=None,
           param_attr=None,
           main_program=None,
           startup_program=None):
    helper = LayerHelper('conv2d', **locals())
    dtype = helper.input_dtype()

    num_channels = input.shape[1]
    if groups is None:
        num_filter_channels = num_channels
    else:
        if num_channels % groups != 0:
            raise ValueError("num_channels must be divisible by groups.")
        num_filter_channels = num_channels // groups

    if isinstance(filter_size, int):
        filter_size = [filter_size, filter_size]
    if isinstance(stride, int):
        stride = [stride, stride]
    if isinstance(padding, int):
        padding = [padding, padding]

    input_shape = input.shape
    filter_shape = [num_filters, num_filter_channels] + filter_size

    std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
    filter = helper.create_parameter(
        attr=helper.param_attr,
        shape=filter_shape,
        dtype=dtype,
        initializer=NormalInitializer(0.0, std, 0))
    pre_bias = helper.create_tmp_variable(dtype)

    helper.append_op(
        type='conv2d',
        inputs={
            'Input': input,
            'Filter': filter,
        },
        outputs={"Output": pre_bias},
        attrs={'strides': stride,
               'paddings': padding,
               'groups': groups})

    pre_act = helper.append_bias_op(pre_bias, 1)

    return helper.append_activation(pre_act)


def sequence_pool(input, pool_type, **kwargs):
    helper = LayerHelper('sequence_pool', input=input, **kwargs)
    dtype = helper.input_dtype()
    pool_out = helper.create_tmp_variable(dtype)
    max_index = helper.create_tmp_variable(dtype)

    helper.append_op(
        type="sequence_pool",
        inputs={"X": input},
        outputs={"Out": pool_out,
                 "MaxIndex": max_index},
        attrs={"pooltype": pool_type.upper()})

    return pool_out


def pool2d(input,
           pool_size,
           pool_type,
           pool_stride=[1, 1],
           pool_padding=[0, 0],
           global_pooling=False,
           main_program=None,
           startup_program=None):
    if pool_type not in ["max", "avg"]:
        raise ValueError(
            "Unknown pool_type: '%s'. It can only be 'max' or 'avg'." %
            str(pool_type))
    if isinstance(pool_size, int):
        pool_size = [pool_size, pool_size]
    if isinstance(pool_stride, int):
        pool_stride = [pool_stride, pool_stride]
    if isinstance(pool_padding, int):
        pool_padding = [pool_padding, pool_padding]

    helper = LayerHelper('pool2d', **locals())
    dtype = helper.input_dtype()
    pool_out = helper.create_tmp_variable(dtype)

    helper.append_op(
        type="pool2d",
        inputs={"X": input},
        outputs={"Out": pool_out},
        attrs={
            "pooling_type": pool_type,
            "ksize": pool_size,
            "global_pooling": global_pooling,
            "strides": pool_stride,
            "paddings": pool_padding
        })

    return pool_out

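# Sketch of chaining the two layers above (shapes, names and the 'relu'
# activation string are illustrative assumptions):
#
#     images = data(name='pixel', shape=[1, 28, 28], data_type='float32')
#     conv = conv2d(input=images, num_filters=20, filter_size=5, act='relu')
#     pool = pool2d(input=conv, pool_size=2, pool_type='max', pool_stride=2)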

def batch_norm(input,
               act=None,
               is_test=False,
               momentum=0.9,
               epsilon=1e-05,
               param_attr=None,
               bias_attr=None,
               data_layout='NCHW',
               main_program=None,
               startup_program=None):
    helper = LayerHelper('batch_norm', **locals())
    dtype = helper.input_dtype()

    input_shape = input.shape
    if data_layout == 'NCHW':
        channel_num = input_shape[1]
    else:
        if data_layout == 'NHWC':
            channel_num = input_shape[-1]
        else:
            raise ValueError("unsupported data layout:" + data_layout)

    param_shape = [channel_num]

    # create parameter
    scale = helper.create_parameter(
        attr=helper.param_attr,
        shape=param_shape,
        dtype=dtype,
        initializer=ConstantInitializer(1.0))
    bias = helper.create_parameter(
        attr=helper.param_attr,
        shape=param_shape,
        dtype=dtype,
        initializer=ConstantInitializer(0.0))

    mean = helper.create_global_variable(
        dtype=input.data_type, shape=param_shape, persistable=True)
    helper.set_variable_initializer(
        var=mean, initializer=ConstantInitializer(0.0))

    variance = helper.create_global_variable(
        dtype=input.data_type, shape=param_shape, persistable=True)
    helper.set_variable_initializer(
        var=variance, initializer=ConstantInitializer(1.0))

    # create output
    # mean and mean_out share the same memory
    mean_out = mean
    # variance and variance out share the same memory
    variance_out = variance
    saved_mean = helper.create_tmp_variable(dtype)
    saved_variance = helper.create_tmp_variable(dtype)

    batch_norm_out = helper.create_tmp_variable(dtype)

    helper.append_op(
        type="batch_norm",
        inputs={
            "X": input,
            "Scale": scale,
            "Bias": bias,
            "Mean": mean,
            "Variance": variance
        },
        outputs={
            "Y": batch_norm_out,
            "MeanOut": mean_out,
            "VarianceOut": variance_out,
            "SavedMean": saved_mean,
            "SavedVariance": saved_variance
        },
        attrs={"momentum": momentum,
               "epsilon": epsilon,
               "is_test": is_test})

    return helper.append_activation(batch_norm_out)


class BlockGuard(object):
    """
    BlockGuard is used to create a sub-block in a program by using the
    Python `with` keyword.
    """

    def __init__(self, main_program):
        if not isinstance(main_program, Program):
            raise TypeError("BlockGuard takes a program")
        self.main_program = main_program

    def __enter__(self):
        self.main_program.create_block()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.main_program.rollback()
        if exc_type is not None:
            return False  # re-raise exception
        return True

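# Minimal sketch of how BlockGuard is meant to be used (assumes a Program
# instance named `prog`): entering the guard creates a new sub-block and
# leaving it rolls the program back to the parent block.
#
#     with BlockGuard(prog):
#         pass  # ops appended here land in the newly created sub-block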

class StaticRNNGuard(BlockGuard):
    def __init__(self, rnn):
        if not isinstance(rnn, StaticRNN):
            raise TypeError("StaticRNNGuard takes a StaticRNN")
        super(StaticRNNGuard, self).__init__(rnn.helper.main_program)
        self.rnn = rnn

    def __enter__(self):
        self.rnn.status = StaticRNN.IN_RNN_BLOCK
        return super(StaticRNNGuard, self).__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            return False
        self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
        self.rnn.complete_rnn_op()
        return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)


class StaticRNNMemoryLink(object):
    """
    :param init: the initial variable for Memory
    :type init: Variable
    :param pre_mem: the memory variable in previous time step
    :type pre_mem: Variable
    :param mem: the memory variable in current time step
    :type mem: Variable
    """

    def __init__(self, init, pre_mem, mem=None):
        self.init = init
        self.pre_mem = pre_mem
        self.mem = mem


class StaticRNN(object):
    BEFORE_RNN_BLOCK = 0
    IN_RNN_BLOCK = 1
    AFTER_RNN_BLOCK = 2

    def __init__(self, name=None, main_program=None):
        self.helper = LayerHelper(
            "static_rnn", name=name, main_program=main_program)
        self.memories = {}  # memory map, from pre_mem.name --> MemoryLink
        self.inputs = []  # input variable list in current block
        self.outputs = []  # output variable list in parent block
        self.status = StaticRNN.BEFORE_RNN_BLOCK  # status flag.
        # sequence length; since this is a static RNN, the sequence length is fixed
        self.seq_len = None

    def step(self):
        return StaticRNNGuard(self)

    def _assert_in_rnn_block_(self, method):
        if self.status != StaticRNN.IN_RNN_BLOCK:
            raise ValueError("You must invoke {0} in rnn block".format(method))

    def memory(self,
               init=None,
               shape=None,
               batch_ref=None,
               init_value=0.0,
               init_batch_dim_idx=0,
               ref_batch_dim_idx=1):
        '''
        :param init: boot memory; if not set, shape and batch_ref must be provided
        :param shape: shape of the boot memory
        :param batch_ref: batch size reference variable
        :param init_value: the initial value of the boot memory
        :param init_batch_dim_idx: the index of the batch size dimension in init
        :param ref_batch_dim_idx: the index of the batch size dimension in batch_ref
        :return: boot memory
        '''
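        # Sketch of intended use inside a `with rnn.step():` block (`word` and
        # `hidden_dim` are illustrative placeholders):
        #     prev = rnn.memory(shape=[-1, hidden_dim], batch_ref=word)
        #     hidden = fc(input=[word, prev], size=hidden_dim)
        #     rnn.update_memory(prev, hidden)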
        self._assert_in_rnn_block_('memory')
        if init is None:
            if shape is None or batch_ref is None:
                raise ValueError(
                    "if init is None, memory at least needs shape and batch_ref")
            parent_block = self.parent_block()
            var_name = unique_name("@".join([self.helper.name, "memory_boot"]))
            boot_var = parent_block.create_var(
                name=var_name,
                shape=shape,
                dtype=batch_ref.data_type,
                persistable=False)

            parent_block.append_op(
                type="fill_constant_batch_size_like",
                inputs={'Input': [batch_ref]},
                outputs={'Out': [boot_var]},
                attrs={
                    'value': init_value,
                    'shape': boot_var.shape,
                    'data_type': boot_var.data_type,
                    'input_dim_idx': ref_batch_dim_idx,
                    'output_dim_idx': init_batch_dim_idx
                })

            return self.memory(init=boot_var)
        else:
            pre_mem = self.helper.create_variable(
                name=unique_name("@".join([self.helper.name, "mem"])),
                dtype=init.data_type,
                shape=init.shape)
            self.memories[pre_mem.name] = StaticRNNMemoryLink(
                init=init, pre_mem=pre_mem)
            return pre_mem

    def step_input(self, x):
        self._assert_in_rnn_block_('step_input')
        if not isinstance(x, Variable):
            raise TypeError("step input takes a Variable")
        if self.seq_len is None:
            self.seq_len = x.shape[0]
        elif self.seq_len != x.shape[0]:
            raise ValueError("Static RNN only takes fixed seq_len inputs")

        ipt = self.helper.create_variable(
            name=x.name,
            dtype=x.data_type,
            shape=list(x.shape[1:]),
            type=x.type)
        self.inputs.append(ipt)
        return ipt

    def step_output(self, o):
        self._assert_in_rnn_block_('step_output')
        if not isinstance(o, Variable):
            raise TypeError("step output takes a Variable")

        tmp_o = self.helper.create_tmp_variable(dtype=o.data_type)
        self.helper.append_op(
            type='rnn_memory_helper',
            inputs={'X': [o]},
            outputs={'Out': tmp_o},
            attrs={'data_type': o.data_type})

        out_var = self.parent_block().create_var(
            name=tmp_o.name,
            shape=[self.seq_len] + list(tmp_o.shape),
            dtype=tmp_o.data_type)

        self.outputs.append(out_var)

    def output(self, *outputs):
        for each in outputs:
            self.step_output(each)

    def update_memory(self, mem, var):
        if not isinstance(mem, Variable) or not isinstance(var, Variable):
            raise TypeError("update memory should take variables")
        self.memories[mem.name].mem = var

    def parent_block(self):
        prog = self.helper.main_program
        parent_idx = prog.current_block().parent_idx
        assert parent_idx >= 0
        parent_block = prog.block(parent_idx)
        return parent_block

    def __call__(self, *args, **kwargs):
        if self.status != StaticRNN.AFTER_RNN_BLOCK:
            raise ValueError("RNN output can only be retrieved after rnn block")
        if len(self.outputs) == 0:
            raise ValueError("RNN has no output")
        elif len(self.outputs) == 1:
            return self.outputs[0]
        else:
            return self.outputs

    def complete_rnn_op(self):
        main_program = self.helper.main_program
        rnn_block = main_program.current_block()
        parent_block = self.parent_block()

        local_inputs = set()

        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for oname in op.output_names:
                for out_var_name in op.output(oname):
                    local_inputs.add(out_var_name)

        for var in self.inputs:
            local_inputs.add(var.name)
        for m in self.memories:
            local_inputs.add(m)

        params = list()
        for op in rnn_block.ops:
            assert isinstance(op, Operator)
            for iname in op.input_names:
                for in_var_name in op.input(iname):
                    if in_var_name not in local_inputs:
                        params.append(in_var_name)

        parameters = [parent_block.var(name) for name in params]

        step_scope = parent_block.create_var(
            type=core.VarDesc.VarType.STEP_SCOPES)

        inlinks = [parent_block.var(i.name) for i in self.inputs]
        outlinks = self.outputs

        boot_memories = []
        pre_memories = []
        memories = []
        for _, mem in self.memories.iteritems():
            boot_memories.append(mem.init)
            pre_memories.append(mem.pre_mem.name)
            mem_var = rnn_block.var(mem.mem.name)
            assert isinstance(mem_var, Variable)
            new_mem = self.helper.create_tmp_variable(dtype=mem_var.data_type)

            rnn_block.append_op(
                type='rnn_memory_helper',
                inputs={'X': [mem_var]},
                outputs={'Out': [new_mem]},
                attrs={'data_type': mem_var.data_type})

            memories.append(new_mem.name)

        parent_block.append_op(
            type='recurrent',
            inputs={
                'inputs': inlinks,
                'initial_states': boot_memories,
                'parameters': parameters
            },
            outputs={'outputs': outlinks,
                     'step_scopes': [step_scope]},
            attrs={
                'ex_states': pre_memories,
                'states': memories,
                'step_block': rnn_block
            })


def lstm(x,
         c_pre_init,
         hidden_dim,
         forget_bias=None,
         main_program=None,
         startup_program=None):
    helper = LayerHelper('lstm_unit', **locals())
    rnn = StaticRNN()
    with rnn.step():
        c_pre = rnn.memory(init=c_pre_init)
        x_t = rnn.step_input(x)

        before_fc = concat(
            input=[x_t, c_pre],
            axis=1,
            main_program=main_program,
            startup_program=startup_program)
        after_fc = fc(input=before_fc,
                      size=hidden_dim * 4,
                      main_program=main_program,
                      startup_program=startup_program)

        data_type = x.data_type
        c = helper.create_tmp_variable(data_type)
        h = helper.create_tmp_variable(data_type)

        helper.append_op(
            type='lstm_unit',
            inputs={"X": after_fc,
                    "C_prev": c_pre},
            outputs={"C": c,
                     "H": h},
            attrs={"forget_bias": forget_bias})

        rnn.update_memory(c_pre, c)
        rnn.output(h)

    return rnn()


def lod_rank_table(x, level=0, main_program=None):
    helper = LayerHelper("lod_rank_table", **locals())
    table = helper.create_variable(
        type=core.VarDesc.VarType.LOD_RANK_TABLE,
        name=unique_name("lod_rank_table"))
    helper.append_op(
        type='lod_rank_table',
        inputs={'X': x},
        outputs={'Out': table},
        attrs={'level': level})
    return table


def lod_tensor_to_array(x, table, main_program=None):
    helper = LayerHelper("lod_tensor_to_array", **locals())
    array = helper.create_variable(
        name=unique_name("lod_tensor_to_array"),
        type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
        dtype=x.data_type)
    helper.append_op(
        type='lod_tensor_to_array',
        inputs={'X': x,
                'RankTable': table},
        outputs={'Out': array})
    return array


def array_to_lod_tensor(x, table, main_program=None):
    helper = LayerHelper("array_to_lod_tensor", **locals())
    tmp = helper.create_tmp_variable(dtype=x.data_type)
    helper.append_op(
        type="array_to_lod_tensor",
        inputs={'X': x,
                'RankTable': table},
        outputs={'Out': tmp})
    return tmp


def fill_constant(shape, dtype, value, main_program=None):
    helper = LayerHelper("fill_constant", **locals())
    out = helper.create_tmp_variable(dtype=dtype)
    helper.append_op(
        type='fill_constant',
        inputs={},
        outputs={'Out': [out]},
        attrs={
            'shape': shape,
            'data_type': out.data_type,
            'value': float(value)
        })
    out.stop_gradient = True
    return out


def ones(shape, dtype, main_program=None):
    return fill_constant(value=1.0, **locals())


def zeros(shape, dtype, main_program=None):
    return fill_constant(value=0.0, **locals())


def increment(x, value=1.0, main_program=None):
    helper = LayerHelper("increment", **locals())
    out = helper.create_tmp_variable(dtype=x.data_type)
    helper.append_op(
        type='increment',
        inputs={'X': [x]},
        outputs={'Out': [out]},
        attrs={'step': value})
    return out


def array_write(x, i, array=None, main_program=None):
    helper = LayerHelper('array_write', **locals())
    if array is None:
        array = helper.create_variable(
            name="{0}.out".format(helper.name),
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
            dtype=x.data_type)
    helper.append_op(
        type='write_to_array',
        inputs={'X': [x],
                'I': [i]},
        outputs={'Out': [array]})
    return array


def array_read(array, i, main_program=None):
    helper = LayerHelper('array_read', **locals())
    if not isinstance(
            array,
            Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        raise TypeError("array should be a tensor array variable")
    out = helper.create_tmp_variable(dtype=array.data_type)
    helper.append_op(
        type='read_from_array',
        inputs={'X': [array],
                'I': [i]},
        outputs={'Out': [out]})
    return out

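# Sketch of the intended write/read round trip (the variable names and the
# int64 index dtype are illustrative assumptions):
#
#     i = zeros(shape=[1], dtype='int64')
#     arr = array_write(x=some_tensor, i=i)
#     restored = array_read(array=arr, i=i)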

def shrink_memory(x, i, table, main_program=None):
    helper = LayerHelper('shrink_memory', **locals())
    out = helper.create_tmp_variable(dtype=x.data_type)
    helper.append_op(
        type='shrink_rnn_memory',
        inputs={'X': [x],
                'I': [i],
                'RankTable': [table]},
        outputs={'Out': [out]},
        attrs={})
    return out