Unverified · Commit e0698e33 authored by Yu Yang, committed by GitHub

Make layers as a python module (#6564)

* Make cast op support bool

Also add `elemwise_sub/mul/abs/clip` layers

* Make fluid.layers a module

* Move layers as a module

* Split layers.py into layers module

* Fix CI

* Fix CI
Parent b84da668
import ops
from ops import *
import nn
from nn import *
import io
from io import *
import tensor
from tensor import *
import control_flow
from control_flow import *
__all__ = []
__all__ += nn.__all__
__all__ += io.__all__
__all__ += tensor.__all__
__all__ += control_flow.__all__
__all__ += ops.__all__
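With this aggregation module (presumably the package's __init__.py), everything that the submodules export becomes available directly as an attribute of fluid.layers. A minimal usage sketch, assuming the paddle.v2.fluid package from this tree is importable; variable names are illustrative only:

import paddle.v2.fluid as fluid

# data() lives in io.py and ones() in tensor.py, but after the re-exports
# above both are reachable through the single fluid.layers namespace.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
b = fluid.layers.ones(shape=[13], dtype='float32')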
from .. import core
from ..layer_helper import LayerHelper
__all__ = ['data']
def data(name,
         shape,
         append_batch_size=True,
         dtype='float32',
         lod_level=0,
         type=core.VarDesc.VarType.LOD_TENSOR,
         main_program=None,
         startup_program=None,
         stop_gradient=True):
    """
    Data Layer.

    This function creates a global variable that feeds input data into the
    graph. If `append_batch_size` is True, a -1 batch dimension is prepended
    to the declared shape. The resulting variable can be accessed by all the
    operations and layers that follow it in the graph.

    All the input arguments of this function are passed in as local variables
    to the LayerHelper constructor.

    Args:
        name: The name/alias of the variable.
        shape: Tuple declaring the shape, without the batch dimension.
        append_batch_size: Whether to prepend a -1 batch dimension to the shape.
        dtype: The data type: float32, float16, int, etc.
        lod_level(int): The LoD level. 0 means the input data is not a sequence.
        type: The output variable type. By default it is LOD_TENSOR.
        main_program: The main program this variable is added to.
        startup_program: The startup program.
        stop_gradient: Whether to stop gradient computation at this variable.
    """
    helper = LayerHelper('data', **locals())
    shape = list(shape)
    for i in xrange(len(shape)):
        if shape[i] is None:
            shape[i] = -1
            append_batch_size = False
        elif shape[i] < 0:
            append_batch_size = False

    if append_batch_size:
        shape = [-1] + shape  # prepend the batch dimension as -1

    return helper.create_global_variable(
        name=name,
        shape=shape,
        dtype=dtype,
        type=type,
        stop_gradient=stop_gradient,
        lod_level=lod_level)
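A short, hedged usage sketch of the shape handling described above (names are illustrative only):

# Declared shape [784] becomes [-1, 784]: append_batch_size is True by
# default, so a -1 batch dimension is prepended.
image = data(name='image', shape=[784], dtype='float32')

# A None (or negative) entry is rewritten to -1 and disables the implicit
# batch dimension, so the declared shape is kept as given.
seq = data(name='seq', shape=[None, 32], dtype='float32', lod_level=1)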
This diff is collapsed.
from ..registry import register_layer
__all__ = [
'mean', 'mul', 'dropout', 'reshape', 'sigmoid', 'scale', 'transpose',
'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div',
'elementwise_sub', 'elementwise_mul', 'clip', 'abs'
]
for _OP in set(__all__):
    globals()[_OP] = register_layer(_OP)
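For readers unfamiliar with the pattern, here is a small self-contained sketch (plain Python, no Paddle dependency, with a hypothetical make_layer helper standing in for register_layer) of how assigning into globals() turns generated functions into ordinary module attributes:

def make_layer(op_name):
    # Stand-in for register_layer(): a real layer would append an operator
    # to the current program; this stub only records the call.
    def layer(**kwargs):
        return '%s(%s)' % (op_name, sorted(kwargs.items()))
    layer.__name__ = op_name
    return layer

for _name in ['mean', 'mul', 'dropout']:
    globals()[_name] = make_layer(_name)

print(mean(x='some_var'))  # the generated 'mean' is now a module-level function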
from ..layer_helper import LayerHelper
__all__ = [
'create_tensor', 'cast', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'ones', 'zeros'
]
def create_tensor(dtype, name=None, main_program=None, startup_program=None):
helper = LayerHelper("create_tensor", **locals())
return helper.create_variable(name=helper.name, dtype=dtype)
def cast(x, dtype, main_program=None):
    """
    This function takes the input variable `x` and casts it to the requested
    `dtype`, returning the result as a new variable.
    """
    helper = LayerHelper('cast', **locals())
    out = helper.create_tmp_variable(dtype=dtype)
    helper.append_op(
        type='cast',
        inputs={'X': [x]},
        outputs={'Out': [out]},
        attrs={'in_dtype': x.dtype,
               'out_dtype': out.dtype})
    return out
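A hedged usage sketch of cast(), tying in the bool support mentioned in the commit message (it assumes string dtype names are accepted here, as in data(); variable names are illustrative):

scores = data(name='scores', shape=[1], dtype='float32')
labels = cast(x=scores, dtype='int64')   # float32 -> int64
mask = cast(x=scores, dtype='bool')      # float32 -> bool, newly supported by the cast op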
def concat(input, axis, main_program=None, startup_program=None):
    """
    This function concatenates the list of input variables along the given
    axis and returns the result as a single output variable.
    """
    helper = LayerHelper('concat', **locals())
    out = helper.create_tmp_variable(dtype=helper.input_dtype())
    helper.append_op(
        type='concat',
        inputs={'X': input},
        outputs={'Out': [out]},
        attrs={'axis': axis})
    return out
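A brief, hedged usage sketch (illustrative variables, built with fill_constant() defined later in this file):

a = fill_constant(shape=[2, 3], dtype='float32', value=1.0)
b = fill_constant(shape=[2, 5], dtype='float32', value=2.0)
ab = concat(input=[a, b], axis=1)   # result shape: [2, 8]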
def sums(input, out=None, main_program=None, startup_program=None):
    """
    This function sums the list of input variables element-wise and returns
    the result. If `out` is not given, a temporary variable is created to
    hold the result.
    """
    helper = LayerHelper('sum', **locals())
    if out is None:
        out = helper.create_tmp_variable(dtype=helper.input_dtype())
    helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
    return out
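A hedged usage sketch with illustrative variables:

x = fill_constant(shape=[4], dtype='float32', value=1.5)
y = fill_constant(shape=[4], dtype='float32', value=2.5)
total = sums(input=[x, y])   # element-wise sum, shape [4]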
def assign(input, output, main_program=None, startup_program=None):
    """
    This function copies `input` into the existing variable `output` by
    appending a scale op with scale 1.0.
    """
    helper = LayerHelper('assign', **locals())
    helper.append_op(
        type='scale',
        inputs={'X': [input]},
        outputs={'Out': [output]},
        attrs={'scale': 1.0})
    return output
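A hedged sketch of copying one variable into a pre-created tensor (names are illustrative):

src = fill_constant(shape=[2, 2], dtype='float32', value=3.0)
dst = create_tensor(dtype='float32')
assign(input=src, output=dst)   # dst now receives the values of src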
def fill_constant(shape,
                  dtype,
                  value,
                  out=None,
                  main_program=None,
                  startup_program=None):
    """
    This function creates a tensor with the given shape and dtype and fills
    it with the given constant value. It also sets stop_gradient on the
    result to True.
    """
    helper = LayerHelper("fill_constant", **locals())
    if out is None:
        out = helper.create_tmp_variable(dtype=dtype)
    helper.append_op(
        type='fill_constant',
        inputs={},
        outputs={'Out': [out]},
        attrs={'shape': shape,
               'dtype': out.dtype,
               'value': float(value)})
    out.stop_gradient = True
    return out
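A short, hedged usage sketch:

# A [2, 3] float32 tensor filled with 0.5; gradients will not flow through it.
half = fill_constant(shape=[2, 3], dtype='float32', value=0.5)
# Reusing an existing variable as the output is also allowed.
fill_constant(shape=[2, 3], dtype='float32', value=0.0, out=half)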
def fill_constant_batch_size_like(input,
                                  shape,
                                  dtype,
                                  value,
                                  input_dim_idx=0,
                                  output_dim_idx=0,
                                  main_program=None,
                                  startup_program=None):
    """
    Like fill_constant(), but the output dimension at `output_dim_idx` is
    taken from dimension `input_dim_idx` of `input`, which is typically the
    (runtime) batch size.
    """
    helper = LayerHelper("fill_constant_batch_size_like", **locals())
    out = helper.create_tmp_variable(dtype=dtype)
    helper.append_op(
        type='fill_constant_batch_size_like',
        inputs={'Input': input},
        outputs={'Out': [out]},
        attrs={
            'shape': shape,
            'dtype': out.dtype,
            'value': float(value),
            'input_dim_idx': input_dim_idx,
            'output_dim_idx': output_dim_idx
        })
    out.stop_gradient = True
    return out
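A hedged usage sketch (it assumes the entry at output_dim_idx in `shape` is only a placeholder that gets replaced by the input's batch dimension; `image` comes from the data() layer in io.py):

image = data(name='img', shape=[784], dtype='float32')   # shape [-1, 784]
# A zero tensor whose first dimension follows image's runtime batch size.
init_state = fill_constant_batch_size_like(
    input=image, shape=[-1, 10], dtype='float32', value=0.0)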
def ones(shape, dtype, main_program=None):
    """
    This function is equivalent to fill_constant() declared above, with the
    constant value fixed to 1.0.
    """
    return fill_constant(value=1.0, **locals())
def zeros(shape, dtype, main_program=None):
    """
    This function is equivalent to fill_constant() declared above, with the
    constant value fixed to 0.0.
    """
    return fill_constant(value=0.0, **locals())
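A one-line sketch of each wrapper (illustrative names):

bias_init = zeros(shape=[10], dtype='float32')    # same as fill_constant(value=0.0, ...)
scale_init = ones(shape=[10], dtype='float32')    # same as fill_constant(value=1.0, ...)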
from __future__ import print_function
import numpy as np
import sys
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import sys
def resnet_cifar10(input, depth=32):
......
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from paddle.v2.fluid.layer_helper import LayerHelper
def lstm(x,
         c_pre_init,
         hidden_dim,
         forget_bias=None,
         main_program=None,
         startup_program=None):
    """
    This function helps create an operator for the LSTM (Long Short Term
    Memory) cell that can be used inside an RNN.
    """
    helper = LayerHelper('lstm_unit', **locals())
    rnn = fluid.layers.StaticRNN()
    with rnn.step():
        c_pre = rnn.memory(init=c_pre_init)
        x_t = rnn.step_input(x)

        before_fc = fluid.layers.concat(
            input=[x_t, c_pre],
            axis=1,
            main_program=main_program,
            startup_program=startup_program)
        after_fc = fluid.layers.fc(input=before_fc,
                                   size=hidden_dim * 4,
                                   main_program=main_program,
                                   startup_program=startup_program)

        dtype = x.dtype
        c = helper.create_tmp_variable(dtype)
        h = helper.create_tmp_variable(dtype)

        helper.append_op(
            type='lstm_unit',
            inputs={"X": after_fc,
                    "C_prev": c_pre},
            outputs={"C": c,
                     "H": h},
            attrs={"forget_bias": forget_bias})
        rnn.update_memory(c_pre, c)
        rnn.output(h)

    return rnn()
def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50):
......@@ -23,8 +68,7 @@ def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50):
    c_pre_init = fluid.layers.fill_constant(
        dtype=emb.dtype, shape=[batch_size, emb_dim], value=0.0)
    c_pre_init.stop_gradient = False
    layer_1_out = fluid.layers.lstm(
        emb, c_pre_init=c_pre_init, hidden_dim=emb_dim)
    layer_1_out = lstm(emb, c_pre_init=c_pre_init, hidden_dim=emb_dim)
    layer_1_out = fluid.layers.transpose(x=layer_1_out, axis=[1, 0, 2])
    prediction = fluid.layers.fc(input=layer_1_out,
......
......@@ -68,6 +68,7 @@ packages=['paddle',
'paddle.v2.plot',
'paddle.v2.fluid',
'paddle.v2.fluid.proto',
'paddle.v2.fluid.layers',
'py_paddle']
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
......