Commit 4c681dfb authored by Yibing Liu

Fix test & feed type and disable grad clip temporarily

Parent 8c5f9d7e
@@ -25,7 +25,7 @@ class Net(object):
         # turns ids
         shapes = [[-1, self._max_turn_len, 1]
                   for i in six.moves.xrange(self._max_turn_num)]
-        dtypes = ["int32" for i in six.moves.xrange(self._max_turn_num)]
+        dtypes = ["int64" for i in six.moves.xrange(self._max_turn_num)]
         # turns mask
         shapes += [[-1, self._max_turn_len, 1]
                    for i in six.moves.xrange(self._max_turn_num)]
@@ -34,7 +34,7 @@ class Net(object):
         # response ids, response mask, label
         shapes += [[-1, self._max_turn_len, 1], [-1, self._max_turn_len, 1],
                    [-1, 1]]
-        dtypes += ["int32", "float32", "float32"]
+        dtypes += ["int64", "float32", "float32"]
         py_reader = fluid.layers.py_reader(
             capacity=capacity,
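Note on the dtype change in these two hunks: in the Fluid releases this model targets, the lookup_table op behind `fluid.layers.embedding` accepts only int64 indices, so ids declared (and fed) as int32 fail a dtype check once the embedding layer consumes them. A minimal, self-contained sketch of that constraint (not from this repo; the vocab and embedding sizes are made up):

```python
import numpy as np
import paddle.fluid as fluid

# ids must be declared, and later fed, as int64 for the embedding lookup.
ids = fluid.layers.data(name="ids", shape=[10, 1], dtype="int64")
emb = fluid.layers.embedding(input=ids, size=[1000, 32])  # [vocab, emb_size]

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

feed_ids = np.random.randint(0, 1000, size=(4, 10, 1)).astype("int64")
out, = exe.run(feed={"ids": feed_ids}, fetch_list=[emb])
print(out.shape)  # (4, 10, 32)
```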
@@ -60,7 +60,7 @@ class Net(object):
         for i in six.moves.xrange(self._max_turn_num):
             name = "turn_%d" % i
             turn = fluid.layers.data(
-                name=name, shape=[self._max_turn_len, 1], dtype="int32")
+                name=name, shape=[self._max_turn_len, 1], dtype="int64")
             self.turns_data.append(turn)
             self._feed_names.append(name)
@@ -73,7 +73,7 @@ class Net(object):
             self._feed_names.append(name)
         self.response = fluid.layers.data(
-            name="response", shape=[self._max_turn_len, 1], dtype="int32")
+            name="response", shape=[self._max_turn_len, 1], dtype="int64")
         self.response_mask = fluid.layers.data(
             name="response_mask",
             shape=[self._max_turn_len, 1],
......
@@ -126,6 +126,7 @@ def test(args):
     dam = Net(args.max_turn_num, args.max_turn_len, args.vocab_size,
               args.emb_size, args.stack_num, args.channel1_num,
               args.channel2_num)
+    dam.create_data_layers()
     loss, logits = dam.create_network()
     loss.persistable = True
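The test path builds its inputs from plain `fluid.layers.data` rather than the `py_reader`, so `create_data_layers()` now has to be called explicitly before `create_network()`. A condensed sketch of the pattern this relies on (the method names follow the diff, but the body is illustrative, not the repo's actual implementation):

```python
import paddle.fluid as fluid

class TinyNet(object):
    """Illustrative stand-in for Net; only the feed bookkeeping is shown."""

    def __init__(self, max_turn_len):
        self._max_turn_len = max_turn_len
        self._feed_names = []

    def create_data_layers(self):
        # Declare the input layers and record their names in feed order,
        # so the test loop can later build feed dicts by name.
        self.response = fluid.layers.data(
            name="response", shape=[self._max_turn_len, 1], dtype="int64")
        self._feed_names.append("response")

    def get_feed_names(self):
        return self._feed_names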
@@ -191,7 +192,8 @@ def test(args):
         feed_list = []
         for dev in six.moves.xrange(dev_count):
             index = it * dev_count + dev
-            feed_dict = reader.make_one_batch_input(test_batches, index)
+            batch_data = reader.make_one_batch_input(test_batches, index)
+            feed_dict = dict(zip(dam.get_feed_names(), batch_data))
             feed_list.append(feed_dict)
         predicts = test_exe.run(feed=feed_list, fetch_list=[logits.name])
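With `get_feed_names()` exposed, `make_one_batch_input` no longer has to return a dict; it can return the batch arrays in feed-name order and the `zip` rebuilds the mapping. A toy usage example with made-up names and shapes:

```python
import numpy as np

feed_names = ["turn_0", "turn_1", "response"]  # e.g. dam.get_feed_names()
batch_data = [np.zeros((2, 50, 1), dtype="int64") for _ in feed_names]

feed_dict = dict(zip(feed_names, batch_data))
assert set(feed_dict) == set(feed_names)  # one entry per declared input
```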
......
@@ -203,8 +203,8 @@ def train(args):
     loss.persistable = True
     logits.persistable = True
     # gradient clipping
-    fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
-        max=1.0, min=-1.0))
+    #fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
+    #    max=1.0, min=-1.0))
     optimizer = fluid.optimizer.Adam(
         learning_rate=fluid.layers.exponential_decay(
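Since the commit message says the clip is disabled only temporarily: to restore it, the same call just has to run before `optimizer.minimize(loss)`, which is when the clip attribute is applied to the backward ops. A minimal sketch (the learning rate is a placeholder, not this repo's schedule):

```python
import paddle.fluid as fluid

# Re-enable element-wise gradient clipping to [-1.0, 1.0].
fluid.clip.set_gradient_clip(
    clip=fluid.clip.GradientClipByValue(max=1.0, min=-1.0))

optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
# optimizer.minimize(loss)  # loss comes from the network definition
```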
......