提交 665eb015 编写于 作者: P peterzhang2029

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into bi_tensor_prod_op

...@@ -27,20 +27,22 @@ namespace platform { ...@@ -27,20 +27,22 @@ namespace platform {
This wrap is a hack to avoid this bug. This wrap is a hack to avoid this bug.
*/ */
template <class Callable, class... Args> template <typename Callable, typename... Args>
inline void call_once(std::once_flag& flag, Callable&& f, Args&&... args) { inline void call_once(std::once_flag& flag, Callable&& f, Args&&... args) {
bool good = false; bool good = false;
std::exception ex; std::exception ex;
std::call_once(flag, [&]() { std::call_once(flag,
try { [&](Args&&... args) {
f(args...); try {
good = true; f(args...);
} catch (const std::exception& e) { good = true;
ex = e; } catch (const std::exception& e) {
} catch (...) { ex = e;
ex = std::runtime_error("excption caught in call_once"); } catch (...) {
} ex = std::runtime_error("excption caught in call_once");
}); }
},
args...);
if (!good) { if (!good) {
throw std::exception(ex); throw std::exception(ex);
} }
......
...@@ -4,7 +4,7 @@ import itertools ...@@ -4,7 +4,7 @@ import itertools
from paddle.v2.framework.framework import Variable, g_main_program, \ from paddle.v2.framework.framework import Variable, g_main_program, \
g_startup_program, unique_name, Program g_startup_program, unique_name, Program
from paddle.v2.framework.initializer import ConstantInitializer, \ from paddle.v2.framework.initializer import ConstantInitializer, \
UniformInitializer UniformInitializer, XavierInitializer
class LayerHelper(object): class LayerHelper(object):
...@@ -61,7 +61,7 @@ class LayerHelper(object): ...@@ -61,7 +61,7 @@ class LayerHelper(object):
@property @property
def param_attr(self): def param_attr(self):
default = {'name': None, 'initializer': UniformInitializer()} default = {'name': None, 'initializer': XavierInitializer()}
actual = self.kwargs.get('param_attr', None) actual = self.kwargs.get('param_attr', None)
if actual is None: if actual is None:
actual = default actual = default
...@@ -70,10 +70,11 @@ class LayerHelper(object): ...@@ -70,10 +70,11 @@ class LayerHelper(object):
actual[default_field] = default[default_field] actual[default_field] = default[default_field]
return actual return actual
@property
def bias_attr(self): def bias_attr(self):
default = {'name': None, 'initializer': ConstantInitializer()} default = {'name': None, 'initializer': XavierInitializer()}
bias_attr = self.kwargs.get('bias_attr', None) bias_attr = self.kwargs.get('bias_attr', None)
if bias_attr is True: if bias_attr is None:
bias_attr = default bias_attr = default
if isinstance(bias_attr, dict): if isinstance(bias_attr, dict):
...@@ -166,7 +167,7 @@ class LayerHelper(object): ...@@ -166,7 +167,7 @@ class LayerHelper(object):
num_flatten_dims = 1 num_flatten_dims = 1
size = list(input_var.shape[num_flatten_dims:]) size = list(input_var.shape[num_flatten_dims:])
bias_attr = self.bias_attr() bias_attr = self.bias_attr
if not bias_attr: if not bias_attr:
return input_var return input_var
......
...@@ -16,7 +16,7 @@ __all__ = [ ...@@ -16,7 +16,7 @@ __all__ = [
def fc(input, def fc(input,
size, size,
param_attr=None, param_attr=None,
bias_attr=True, bias_attr=None,
name=None, name=None,
act=None, act=None,
num_flatten_dims=1, num_flatten_dims=1,
...@@ -125,6 +125,55 @@ def embedding(input, ...@@ -125,6 +125,55 @@ def embedding(input,
return tmp return tmp
# TODO(qijun): expose H0 and C0
def dynamic_lstm(input,
size,
data_type='float32',
param_attr=None,
bias_attr=None,
use_peepholes=True,
is_reverse=False,
gate_activation='sigmoid',
cell_activation='tanh',
candidate_activation='tanh',
main_program=None,
startup_program=None):
helper = LayerHelper('lstm', **locals())
size = size / 4
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 4 * size], dtype=data_type)
bias_size = [1, 7 * size]
if not use_peepholes:
bias_size[1] = 4 * size
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=data_type, suffix='b')
hidden = helper.create_tmp_variable(data_type)
cell = helper.create_tmp_variable(data_type)
batch_gate = helper.create_tmp_variable(data_type)
batch_cell_pre_act = helper.create_tmp_variable(data_type)
helper.append_op(
type='lstm',
inputs={'Input': input,
'Weight': weight,
'Bias': bias},
outputs={
'Hidden': hidden,
'Cell': cell,
'BatchGate': batch_gate,
'BatchCellPreAct': batch_cell_pre_act
},
attrs={
'use_peepholes': use_peepholes,
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'cell_activation': cell_activation,
'candidate_activation': candidate_activation
})
return hidden, cell
def data(name, def data(name,
shape, shape,
data_type='float32', data_type='float32',
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_main_program, g_startup_program
from paddle.v2.framework.executor import Executor
import numpy as np
def stacked_lstm_net(input_dim,
class_dim=2,
emb_dim=128,
hid_dim=512,
stacked_num=3):
assert stacked_num % 2 == 1
data = layers.data(name="words", shape=[1], data_type="int64")
label = layers.data(name="label", shape=[1], data_type="int64")
emb = layers.embedding(input=data, size=[input_dim, emb_dim])
# add bias attr
# TODO(qijun) linear act
fc1 = layers.fc(input=emb, size=hid_dim)
lstm1, cell1 = layers.dynamic_lstm(input=fc1, size=hid_dim)
inputs = [fc1, lstm1]
for i in range(2, stacked_num + 1):
fc = layers.fc(input=inputs, size=hid_dim)
lstm, cell = layers.dynamic_lstm(
input=fc, size=hid_dim, is_reverse=(i % 2) == 0)
inputs = [fc, lstm]
fc_last = layers.sequence_pool(input=inputs[0], pool_type='max')
lstm_last = layers.sequence_pool(input=inputs[1], pool_type='max')
prediction = layers.fc(input=[fc_last, lstm_last],
size=class_dim,
act='softmax')
cost = layers.cross_entropy(input=prediction, label=label)
avg_cost = layers.mean(x=cost)
adam_optimizer = optimizer.AdamOptimizer(learning_rate=0.002)
opts = adam_optimizer.minimize(avg_cost)
acc = layers.accuracy(input=prediction, label=label)
return avg_cost, acc
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = core.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def main():
BATCH_SIZE = 100
PASS_NUM = 5
word_dict = paddle.dataset.imdb.word_dict()
print "load word dict successfully"
dict_dim = len(word_dict)
class_dim = 2
cost, acc = stacked_lstm_net(input_dim=dict_dim, class_dim=class_dim)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=1000),
batch_size=BATCH_SIZE)
place = core.CPUPlace()
exe = Executor(place)
exe.run(g_startup_program)
for pass_id in xrange(PASS_NUM):
for data in train_data():
tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
label = np.array(map(lambda x: x[1], data)).astype("int64")
label = label.reshape([BATCH_SIZE, 1])
tensor_label = core.LoDTensor()
tensor_label.set(label, place)
outs = exe.run(g_main_program,
feed={"words": tensor_words,
"label": tensor_label},
fetch_list=[cost, acc])
cost_val = np.array(outs[0])
acc_val = np.array(outs[1])
print("cost=" + str(cost_val) + " acc=" + str(acc_val))
if cost_val < 1.0 and acc_val > 0.7:
exit(0)
exit(1)
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册