未验证 提交 6d637a93 作者: lilu 提交者: GitHub

test=develop (#2249)

上级 0a31618b
...@@ -11,22 +11,24 @@ import math ...@@ -11,22 +11,24 @@ import math
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
class Network(object): class Network(object):
""" """
Network Network
""" """
def __init__(self,
vocab_size, def __init__(self,
emb_size, vocab_size,
hidden_size, emb_size,
clip_value=10.0, hidden_size,
word_emb_name="shared_word_emb", clip_value=10.0,
lstm_W_name="shared_lstm_W", word_emb_name="shared_word_emb",
lstm_bias_name="shared_lstm_bias"): lstm_W_name="shared_lstm_W",
lstm_bias_name="shared_lstm_bias"):
""" """
Init function Init function
""" """
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.emb_size = emb_size self.emb_size = emb_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -44,8 +46,7 @@ class Network(object): ...@@ -44,8 +46,7 @@ class Network(object):
name="context_wordseq", shape=[1], dtype="int64", lod_level=1) name="context_wordseq", shape=[1], dtype="int64", lod_level=1)
response_wordseq = fluid.layers.data( response_wordseq = fluid.layers.data(
name="response_wordseq", shape=[1], dtype="int64", lod_level=1) name="response_wordseq", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data( label = fluid.layers.data(name="label", shape=[1], dtype="float32")
name="label", shape=[1], dtype="float32")
self._feed_name = ["context_wordseq", "response_wordseq", "label"] self._feed_name = ["context_wordseq", "response_wordseq", "label"]
self._feed_infer_name = ["context_wordseq", "response_wordseq"] self._feed_infer_name = ["context_wordseq", "response_wordseq"]
...@@ -58,7 +59,7 @@ class Network(object): ...@@ -58,7 +59,7 @@ class Network(object):
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name=self.word_emb_name, name=self.word_emb_name,
initializer=fluid.initializer.Normal(scale=0.1))) initializer=fluid.initializer.Normal(scale=0.1)))
response_emb = fluid.layers.embedding( response_emb = fluid.layers.embedding(
input=response_wordseq, input=response_wordseq,
size=[self.vocab_size, self.emb_size], size=[self.vocab_size, self.emb_size],
...@@ -66,7 +67,7 @@ class Network(object): ...@@ -66,7 +67,7 @@ class Network(object):
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name=self.word_emb_name, name=self.word_emb_name,
initializer=fluid.initializer.Normal(scale=0.1))) initializer=fluid.initializer.Normal(scale=0.1)))
#fc to fit dynamic LSTM #fc to fit dynamic LSTM
context_fc = fluid.layers.fc( context_fc = fluid.layers.fc(
input=context_emb, input=context_emb,
...@@ -96,10 +97,10 @@ class Network(object): ...@@ -96,10 +97,10 @@ class Network(object):
bias_attr=fluid.ParamAttr(name=self.lstm_bias_name)) bias_attr=fluid.ParamAttr(name=self.lstm_bias_name))
response_rep = fluid.layers.sequence_last_step(input=response_rep) response_rep = fluid.layers.sequence_last_step(input=response_rep)
print('response_rep shape: %s' % str(response_rep.shape)) print('response_rep shape: %s' % str(response_rep.shape))
logits = fluid.layers.bilinear_tensor_product( logits = fluid.layers.bilinear_tensor_product(
context_rep, response_rep, size=1) context_rep, response_rep, size=1)
print('logits shape: %s' % str(logits.shape)) #[batch,1] print('logits shape: %s' % str(logits.shape)) #[batch,1]
if loss_type == 'CLS': if loss_type == 'CLS':
loss = fluid.layers.sigmoid_cross_entropy_with_logits(logits, label) loss = fluid.layers.sigmoid_cross_entropy_with_logits(logits, label)
...@@ -111,6 +112,7 @@ class Network(object): ...@@ -111,6 +112,7 @@ class Network(object):
elif loss_type == 'L2': elif loss_type == 'L2':
norm_score = 2 * fluid.layers.sigmoid(logits) norm_score = 2 * fluid.layers.sigmoid(logits)
loss = fluid.layers.square_error_cost(norm_score, label) / 4 loss = fluid.layers.square_error_cost(norm_score, label) / 4
loss = fluid.layers.reduce_mean(loss)
else: else:
raise ValueError raise ValueError
...@@ -129,10 +131,9 @@ class Network(object): ...@@ -129,10 +131,9 @@ class Network(object):
Return feed names Return feed names
""" """
return self._feed_name return self._feed_name
def get_feed_inference_names(self): def get_feed_inference_names(self):
""" """
Return inference names Return inference names
""" """
return self._feed_infer_name return self._feed_infer_name
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册