提交 7d7ab3af 编写于 作者: F FrostML

align the behavior of simnet with static graph, test=develop

上级 1bf72647
...@@ -21,6 +21,8 @@ from paddle.fluid.dygraph.nn import Linear ...@@ -21,6 +21,8 @@ from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph import Layer from paddle.fluid.dygraph import Layer
from paddle import fluid from paddle import fluid
import numpy as np import numpy as np
from utils import seq_length
class GRU(Layer): class GRU(Layer):
""" """
...@@ -37,13 +39,16 @@ class GRU(Layer): ...@@ -37,13 +39,16 @@ class GRU(Layer):
self.emb_dim = conf_dict["net"]["emb_dim"] self.emb_dim = conf_dict["net"]["emb_dim"]
self.gru_dim = conf_dict["net"]["gru_dim"] self.gru_dim = conf_dict["net"]["gru_dim"]
self.hidden_dim = conf_dict["net"]["hidden_dim"] self.hidden_dim = conf_dict["net"]["hidden_dim"]
self.emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim, "emb").ops() self.emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim,
"emb").ops()
self.gru_layer = layers.DynamicGRULayer(self.gru_dim, "gru").ops() self.gru_layer = layers.DynamicGRULayer(self.gru_dim, "gru").ops()
self.fc_layer = layers.FCLayer(self.hidden_dim, None, "fc").ops() self.fc_layer = layers.FCLayer(self.hidden_dim, None, "fc").ops()
self.proj_layer = Linear(input_dim = self.hidden_dim, output_dim=self.gru_dim*3) self.proj_layer = Linear(
input_dim=self.hidden_dim, output_dim=self.gru_dim * 3)
self.softmax_layer = layers.FCLayer(2, "softmax", "cos_sim").ops() self.softmax_layer = layers.FCLayer(2, "softmax", "cos_sim").ops()
self.seq_len=conf_dict["seq_len"] self.last_layer = layers.ExtractLastLayer()
self.seq_len = conf_dict["seq_len"]
def forward(self, left, right): def forward(self, left, right):
""" """
...@@ -60,17 +65,14 @@ class GRU(Layer): ...@@ -60,17 +65,14 @@ class GRU(Layer):
h_0 = to_variable(h_0) h_0 = to_variable(h_0)
left_gru = self.gru_layer(left_emb, h_0=h_0) left_gru = self.gru_layer(left_emb, h_0=h_0)
right_gru = self.gru_layer(right_emb, h_0=h_0) right_gru = self.gru_layer(right_emb, h_0=h_0)
left_emb = fluid.layers.reduce_max(left_gru, dim=1) # Get sequence length before padding
right_emb = fluid.layers.reduce_max(right_gru, dim=1) left_len = seq_length(left)
left_emb = fluid.layers.reshape( left_len.stop_gradient = True
left_emb, shape=[-1, self.seq_len, self.hidden_dim]) right_len = seq_length(right)
right_emb = fluid.layers.reshape( right_len.stop_gradient = True
right_emb, shape=[-1, self.seq_len, self.hidden_dim]) # Extract last step
left_emb = fluid.layers.reduce_sum(left_emb, dim=1) left_last = self.last_layer.ops(left_gru, left_len)
right_emb = fluid.layers.reduce_sum(right_emb, dim=1) right_last = self.last_layer.ops(right_gru, right_len)
left_last = fluid.layers.tanh(left_emb)
right_last = fluid.layers.tanh(right_emb)
if self.task_mode == "pairwise": if self.task_mode == "pairwise":
left_fc = self.fc_layer(left_last) left_fc = self.fc_layer(left_last)
......
...@@ -17,6 +17,8 @@ lstm class ...@@ -17,6 +17,8 @@ lstm class
import paddle_layers as layers import paddle_layers as layers
from paddle.fluid.dygraph import Layer, Linear from paddle.fluid.dygraph import Layer, Linear
from paddle import fluid from paddle import fluid
from utils import seq_length
class LSTM(Layer): class LSTM(Layer):
""" """
...@@ -27,20 +29,22 @@ class LSTM(Layer): ...@@ -27,20 +29,22 @@ class LSTM(Layer):
""" """
initialize initialize
""" """
super(LSTM,self).__init__() super(LSTM, self).__init__()
self.dict_size = conf_dict["dict_size"] self.dict_size = conf_dict["dict_size"]
self.task_mode = conf_dict["task_mode"] self.task_mode = conf_dict["task_mode"]
self.emb_dim = conf_dict["net"]["emb_dim"] self.emb_dim = conf_dict["net"]["emb_dim"]
self.lstm_dim = conf_dict["net"]["lstm_dim"] self.lstm_dim = conf_dict["net"]["lstm_dim"]
self.hidden_dim = conf_dict["net"]["hidden_dim"] self.hidden_dim = conf_dict["net"]["hidden_dim"]
self.emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim, "emb").ops() self.emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim,
"emb").ops()
self.lstm_layer = layers.DynamicLSTMLayer(self.lstm_dim, "lstm").ops() self.lstm_layer = layers.DynamicLSTMLayer(self.lstm_dim, "lstm").ops()
self.fc_layer = layers.FCLayer(self.hidden_dim, None, "fc").ops() self.fc_layer = layers.FCLayer(self.hidden_dim, None, "fc").ops()
self.softmax_layer = layers.FCLayer(2, "softmax", "cos_sim").ops() self.softmax_layer = layers.FCLayer(2, "softmax", "cos_sim").ops()
self.proj_layer = Linear(input_dim = self.hidden_dim, output_dim=self.lstm_dim*4) self.proj_layer = Linear(
input_dim=self.hidden_dim, output_dim=self.lstm_dim * 4)
self.last_layer = layers.ExtractLastLayer()
self.seq_len = conf_dict["seq_len"] self.seq_len = conf_dict["seq_len"]
def forward(self, left, right): def forward(self, left, right):
""" """
Forward network Forward network
...@@ -53,19 +57,14 @@ class LSTM(Layer): ...@@ -53,19 +57,14 @@ class LSTM(Layer):
right_proj = self.proj_layer(right_emb) right_proj = self.proj_layer(right_emb)
left_lstm, _ = self.lstm_layer(left_proj) left_lstm, _ = self.lstm_layer(left_proj)
right_lstm, _ = self.lstm_layer(right_proj) right_lstm, _ = self.lstm_layer(right_proj)
# Get sequence length before padding
left_emb = fluid.layers.reduce_max(left_lstm, dim=1) left_len = seq_length(left)
right_emb = fluid.layers.reduce_max(right_lstm, dim=1) left_len.stop_gradient = True
left_emb = fluid.layers.reshape( right_len = seq_length(right)
left_emb, shape=[-1, self.seq_len, self.hidden_dim]) right_len.stop_gradient = True
right_emb = fluid.layers.reshape( # Extract last step
right_emb, shape=[-1, self.seq_len, self.hidden_dim]) left_last = self.last_layer.ops(left_lstm, left_len)
left_emb = fluid.layers.reduce_sum(left_emb, dim=1) right_last = self.last_layer.ops(right_lstm, right_len)
right_emb = fluid.layers.reduce_sum(right_emb, dim=1)
left_last = fluid.layers.tanh(left_emb)
right_last = fluid.layers.tanh(right_emb)
# matching layer # matching layer
if self.task_mode == "pairwise": if self.task_mode == "pairwise":
......
...@@ -1051,3 +1051,33 @@ class BasicGRUUnit(Layer): ...@@ -1051,3 +1051,33 @@ class BasicGRUUnit(Layer):
new_hidden = u * pre_hidden + (1 - u) * c new_hidden = u * pre_hidden + (1 - u) * c
return new_hidden return new_hidden
class ExtractLastLayer(object):
"""
a layer class: get the last step layer
"""
def __init__(self):
"""
init function
"""
pass
def ops(self, input_hidden, seq_length=None):
"""
operation
"""
if seq_length is not None:
output = input_hidden
output_shape = output.shape
batch_size = output_shape[0]
max_length = output_shape[1]
emb_size = output_shape[2]
index = fluid.layers.range(0, batch_size, 1,
'int32') * max_length + (seq_length - 1)
flat = fluid.layers.reshape(output, [-1, emb_size])
return fluid.layers.gather(flat, index)
else:
output = fluid.layers.transpose(input_hidden, [1, 0, 2])
return fluid.layers.gather(output, output.shape[0] - 1)
...@@ -32,12 +32,11 @@ class SimNetProcessor(object): ...@@ -32,12 +32,11 @@ class SimNetProcessor(object):
def padding_text(self, x): def padding_text(self, x):
if len(x) < self.seq_len: if len(x) < self.seq_len:
x += [0]*(self.seq_len-len(x)) x += [0] * (self.seq_len - len(x))
if len(x) > self.seq_len: if len(x) > self.seq_len:
x = x[0:self.seq_len] x = x[0:self.seq_len]
return x return x
def get_reader(self, mode, epoch=0): def get_reader(self, mode, epoch=0):
""" """
Get Reader Get Reader
...@@ -48,8 +47,8 @@ class SimNetProcessor(object): ...@@ -48,8 +47,8 @@ class SimNetProcessor(object):
Reader with Pairwise Reader with Pairwise
""" """
if mode == "valid": if mode == "valid":
with io.open(self.args.valid_data_dir, "r", with io.open(
encoding="utf8") as file: self.args.valid_data_dir, "r", encoding="utf8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len( if len(query) == 0 or len(title) == 0 or len(
...@@ -76,7 +75,8 @@ class SimNetProcessor(object): ...@@ -76,7 +75,8 @@ class SimNetProcessor(object):
yield [query, title] yield [query, title]
elif mode == "test": elif mode == "test":
with io.open(self.args.test_data_dir, "r", encoding="utf8") as file: with io.open(
self.args.test_data_dir, "r", encoding="utf8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len( if len(query) == 0 or len(title) == 0 or len(
...@@ -104,25 +104,29 @@ class SimNetProcessor(object): ...@@ -104,25 +104,29 @@ class SimNetProcessor(object):
yield [query, title] yield [query, title]
else: else:
for idx in range(epoch): for idx in range(epoch):
with io.open(self.args.train_data_dir, "r", with io.open(
self.args.train_data_dir, "r",
encoding="utf8") as file: encoding="utf8") as file:
for line in file: for line in file:
query, pos_title, neg_title = line.strip().split("\t") query, pos_title, neg_title = line.strip().split(
"\t")
if len(query) == 0 or len(pos_title) == 0 or len( if len(query) == 0 or len(pos_title) == 0 or len(
neg_title) == 0: neg_title) == 0:
logging.warning( logging.warning(
"line not match format in test file") "line not match format in train file")
continue continue
query = [ query = [
self.vocab[word] for word in query.split(" ") self.vocab[word] for word in query.split(" ")
if word in self.vocab if word in self.vocab
] ]
pos_title = [ pos_title = [
self.vocab[word] for word in pos_title.split(" ") self.vocab[word]
for word in pos_title.split(" ")
if word in self.vocab if word in self.vocab
] ]
neg_title = [ neg_title = [
self.vocab[word] for word in neg_title.split(" ") self.vocab[word]
for word in neg_title.split(" ")
if word in self.vocab if word in self.vocab
] ]
if len(query) == 0: if len(query) == 0:
...@@ -143,8 +147,8 @@ class SimNetProcessor(object): ...@@ -143,8 +147,8 @@ class SimNetProcessor(object):
Reader with Pointwise Reader with Pointwise
""" """
if mode == "valid": if mode == "valid":
with io.open(self.args.valid_data_dir, "r", with io.open(
encoding="utf8") as file: self.args.valid_data_dir, "r", encoding="utf8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len( if len(query) == 0 or len(title) == 0 or len(
...@@ -171,7 +175,8 @@ class SimNetProcessor(object): ...@@ -171,7 +175,8 @@ class SimNetProcessor(object):
yield [query, title] yield [query, title]
elif mode == "test": elif mode == "test":
with io.open(self.args.test_data_dir, "r", encoding="utf8") as file: with io.open(
self.args.test_data_dir, "r", encoding="utf8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len( if len(query) == 0 or len(title) == 0 or len(
...@@ -199,7 +204,8 @@ class SimNetProcessor(object): ...@@ -199,7 +204,8 @@ class SimNetProcessor(object):
yield [query, title] yield [query, title]
else: else:
for idx in range(epoch): for idx in range(epoch):
with io.open(self.args.train_data_dir, "r", with io.open(
self.args.train_data_dir, "r",
encoding="utf8") as file: encoding="utf8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
......
...@@ -48,7 +48,7 @@ train() { ...@@ -48,7 +48,7 @@ train() {
evaluate() { evaluate() {
python run_classifier.py \ python run_classifier.py \
--task_name ${TASK_NAME} \ --task_name ${TASK_NAME} \
--use_cuda false \ --use_cuda False \
--do_test True \ --do_test True \
--verbose_result True \ --verbose_result True \
--batch_size 128 \ --batch_size 128 \
...@@ -65,7 +65,7 @@ evaluate() { ...@@ -65,7 +65,7 @@ evaluate() {
infer() { infer() {
python run_classifier.py \ python run_classifier.py \
--task_name ${TASK_NAME} \ --task_name ${TASK_NAME} \
--use_cuda false \ --use_cuda False \
--do_infer True \ --do_infer True \
--batch_size 128 \ --batch_size 128 \
--infer_data_dir ${INFER_DATA_PATH} \ --infer_data_dir ${INFER_DATA_PATH} \
......
...@@ -161,10 +161,6 @@ def train(conf_dict, args): ...@@ -161,10 +161,6 @@ def train(conf_dict, args):
if args.task_mode == "pairwise": if args.task_mode == "pairwise":
for left, pos_right, neg_right in train_loader(): for left, pos_right, neg_right in train_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
neg_right = fluid.layers.reshape(neg_right, shape=[-1, 1])
net.train() net.train()
global_step += 1 global_step += 1
left_feat, pos_score = net(left, pos_right) left_feat, pos_score = net(left, pos_right)
...@@ -178,9 +174,6 @@ def train(conf_dict, args): ...@@ -178,9 +174,6 @@ def train(conf_dict, args):
if args.do_valid and global_step % args.validation_steps == 0: if args.do_valid and global_step % args.validation_steps == 0:
for left, pos_right in valid_loader(): for left, pos_right in valid_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(
pos_right, shape=[-1, 1])
net.eval() net.eval()
left_feat, pos_score = net(left, pos_right) left_feat, pos_score = net(left, pos_right)
pred = pos_score pred = pos_score
...@@ -212,9 +205,6 @@ def train(conf_dict, args): ...@@ -212,9 +205,6 @@ def train(conf_dict, args):
logging.info("saving infer model in %s" % model_path) logging.info("saving infer model in %s" % model_path)
else: else:
for left, right, label in train_loader(): for left, right, label in train_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
right = fluid.layers.reshape(right, shape=[-1, 1])
label = fluid.layers.reshape(label, shape=[-1, 1])
net.train() net.train()
global_step += 1 global_step += 1
left_feat, pred = net(left, right) left_feat, pred = net(left, right)
...@@ -226,8 +216,6 @@ def train(conf_dict, args): ...@@ -226,8 +216,6 @@ def train(conf_dict, args):
if args.do_valid and global_step % args.validation_steps == 0: if args.do_valid and global_step % args.validation_steps == 0:
for left, right in valid_loader(): for left, right in valid_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
right = fluid.layers.reshape(right, shape=[-1, 1])
net.eval() net.eval()
left_feat, pred = net(left, right) left_feat, pred = net(left, right)
pred_list += list(pred.numpy()) pred_list += list(pred.numpy())
...@@ -296,11 +284,7 @@ def train(conf_dict, args): ...@@ -296,11 +284,7 @@ def train(conf_dict, args):
place) place)
pred_list = [] pred_list = []
for left, pos_right in test_loader(): for left, pos_right in test_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
net.eval() net.eval()
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
left_feat, pos_score = net(left, pos_right) left_feat, pos_score = net(left, pos_right)
pred = pos_score pred = pos_score
pred_list += list(pred.numpy()) pred_list += list(pred.numpy())
...@@ -351,9 +335,6 @@ def test(conf_dict, args): ...@@ -351,9 +335,6 @@ def test(conf_dict, args):
"predictions.txt", "w", encoding="utf8") as predictions_file: "predictions.txt", "w", encoding="utf8") as predictions_file:
if args.task_mode == "pairwise": if args.task_mode == "pairwise":
for left, pos_right in test_loader(): for left, pos_right in test_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
left_feat, pos_score = net(left, pos_right) left_feat, pos_score = net(left, pos_right)
pred = pos_score pred = pos_score
...@@ -365,8 +346,6 @@ def test(conf_dict, args): ...@@ -365,8 +346,6 @@ def test(conf_dict, args):
else: else:
for left, right in test_loader(): for left, right in test_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
right = fluid.layers.reshape(right, shape=[-1, 1])
left_feat, pred = net(left, right) left_feat, pred = net(left, right)
pred_list += list( pred_list += list(
...@@ -433,8 +412,6 @@ def infer(conf_dict, args): ...@@ -433,8 +412,6 @@ def infer(conf_dict, args):
pred_list = [] pred_list = []
if args.task_mode == "pairwise": if args.task_mode == "pairwise":
for left, pos_right in infer_loader(): for left, pos_right in infer_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
left_feat, pos_score = net(left, pos_right) left_feat, pos_score = net(left, pos_right)
pred = pos_score pred = pos_score
...@@ -443,8 +420,6 @@ def infer(conf_dict, args): ...@@ -443,8 +420,6 @@ def infer(conf_dict, args):
else: else:
for left, right in infer_loader(): for left, right in infer_loader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(right, shape=[-1, 1])
left_feat, pred = net(left, right) left_feat, pred = net(left, right)
pred_list += map(lambda item: str(np.argmax(item)), pred_list += map(lambda item: str(np.argmax(item)),
pred.numpy()) pred.numpy())
......
...@@ -33,6 +33,7 @@ from functools import partial ...@@ -33,6 +33,7 @@ from functools import partial
******functions for file processing****** ******functions for file processing******
""" """
def load_vocab(file_path): def load_vocab(file_path):
""" """
load the given vocabulary load the given vocabulary
...@@ -59,8 +60,11 @@ def get_result_file(args): ...@@ -59,8 +60,11 @@ def get_result_file(args):
""" """
with io.open(args.test_data_dir, "r", encoding="utf8") as test_file: with io.open(args.test_data_dir, "r", encoding="utf8") as test_file:
with io.open("predictions.txt", "r", encoding="utf8") as predictions_file: with io.open(
with io.open(args.test_result_path, "w", encoding="utf8") as test_result_file: "predictions.txt", "r", encoding="utf8") as predictions_file:
with io.open(
args.test_result_path, "w",
encoding="utf8") as test_result_file:
test_datas = [line.strip("\n") for line in test_file] test_datas = [line.strip("\n") for line in test_file]
predictions = [line.strip("\n") for line in predictions_file] predictions = [line.strip("\n") for line in predictions_file]
for test_data, prediction in zip(test_datas, predictions): for test_data, prediction in zip(test_datas, predictions):
...@@ -168,52 +172,82 @@ class ArgumentGroup(object): ...@@ -168,52 +172,82 @@ class ArgumentGroup(object):
help=help + ' Default: %(default)s.', help=help + ' Default: %(default)s.',
**kwargs) **kwargs)
class ArgConfig(object): class ArgConfig(object):
def __init__(self): def __init__(self):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
model_g = ArgumentGroup(parser, "model", "model configuration and paths.") model_g = ArgumentGroup(parser, "model",
model_g.add_arg("config_path", str, None, "Path to the json file for EmoTect model config.") "model configuration and paths.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.") model_g.add_arg("config_path", str, None,
model_g.add_arg("output_dir", str, None, "Directory path to save checkpoints") "Path to the json file for EmoTect model config.")
model_g.add_arg("task_mode", str, None, "task mode: pairwise or pointwise") model_g.add_arg("init_checkpoint", str, None,
"Init checkpoint to resume training from.")
model_g.add_arg("output_dir", str, None,
"Directory path to save checkpoints")
model_g.add_arg("task_mode", str, None,
"task mode: pairwise or pointwise")
train_g = ArgumentGroup(parser, "training", "training options.") train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 10, "Number of epoches for training.") train_g.add_arg("epoch", int, 10, "Number of epoches for training.")
train_g.add_arg("save_steps", int, 200, "The steps interval to save checkpoints.") train_g.add_arg("save_steps", int, 200,
train_g.add_arg("validation_steps", int, 100, "The steps interval to evaluate model performance.") "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 100,
"The steps interval to evaluate model performance.")
log_g = ArgumentGroup(parser, "logging", "logging related") log_g = ArgumentGroup(parser, "logging", "logging related")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.") log_g.add_arg("skip_steps", int, 10,
log_g.add_arg("verbose_result", bool, True, "Whether to output verbose result.") "The steps interval to print loss.")
log_g.add_arg("test_result_path", str, "test_result", "Directory path to test result.") log_g.add_arg("verbose_result", bool, True,
log_g.add_arg("infer_result_path", str, "infer_result", "Directory path to infer result.") "Whether to output verbose result.")
log_g.add_arg("test_result_path", str, "test_result",
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options") "Directory path to test result.")
data_g.add_arg("train_data_dir", str, None, "Directory path to training data.") log_g.add_arg("infer_result_path", str, "infer_result",
data_g.add_arg("valid_data_dir", str, None, "Directory path to valid data.") "Directory path to infer result.")
data_g.add_arg("test_data_dir", str, None, "Directory path to testing data.")
data_g.add_arg("infer_data_dir", str, None, "Directory path to infer data.") data_g = ArgumentGroup(
parser, "data",
"Data paths, vocab paths and data processing options")
data_g.add_arg("train_data_dir", str, None,
"Directory path to training data.")
data_g.add_arg("valid_data_dir", str, None,
"Directory path to valid data.")
data_g.add_arg("test_data_dir", str, None,
"Directory path to testing data.")
data_g.add_arg("infer_data_dir", str, None,
"Directory path to infer data.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.") data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training.") data_g.add_arg("batch_size", int, 32,
"Total examples' number in batch for training.")
data_g.add_arg("seq_len", int, 32, "The length of each sentence.") data_g.add_arg("seq_len", int, 32, "The length of each sentence.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.") run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.") run_type_g.add_arg("use_cuda", bool, False,
run_type_g.add_arg("task_name", str, None, "The name of task to perform sentiment classification.") "If set, use GPU for training.")
run_type_g.add_arg("do_train", bool, False, "Whether to perform training.") run_type_g.add_arg(
"task_name", str, None,
"The name of task to perform sentiment classification.")
run_type_g.add_arg("do_train", bool, False,
"Whether to perform training.")
run_type_g.add_arg("do_valid", bool, False, "Whether to perform dev.") run_type_g.add_arg("do_valid", bool, False, "Whether to perform dev.")
run_type_g.add_arg("do_test", bool, False, "Whether to perform testing.") run_type_g.add_arg("do_test", bool, False,
run_type_g.add_arg("do_infer", bool, False, "Whether to perform inference.") "Whether to perform testing.")
run_type_g.add_arg("compute_accuracy", bool, False, "Whether to compute accuracy.") run_type_g.add_arg("do_infer", bool, False,
run_type_g.add_arg("lamda", float, 0.91, "When task_mode is pairwise, lamda is the threshold for calculating the accuracy.") "Whether to perform inference.")
run_type_g.add_arg("compute_accuracy", bool, False,
"Whether to compute accuracy.")
run_type_g.add_arg(
"lamda", float, 0.91,
"When task_mode is pairwise, lamda is the threshold for calculating the accuracy."
)
custom_g = ArgumentGroup(parser, "customize", "customized options.") custom_g = ArgumentGroup(parser, "customize", "customized options.")
self.custom_g = custom_g self.custom_g = custom_g
parser.add_argument('--enable_ce',action='store_true',help='If set, run the task with continuous evaluation logs.') parser.add_argument(
'--enable_ce',
action='store_true',
help='If set, run the task with continuous evaluation logs.')
self.parser = parser self.parser = parser
...@@ -384,6 +418,26 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -384,6 +418,26 @@ def load_dygraph(model_path, keep_name_table=False):
if six.PY3: if six.PY3:
load_bak = pickle.load load_bak = pickle.load
pickle.load = partial(load_bak, encoding="latin1") pickle.load = partial(load_bak, encoding="latin1")
para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table) para_dict, opti_dict = fluid.load_dygraph(model_path,
keep_name_table)
pickle.load = load_bak pickle.load = load_bak
return para_dict, opti_dict return para_dict, opti_dict
def seq_length(sequence):
"""
get sequence length
for id-sequence, (N, S)
or vector-sequence (N, S, D)
"""
if len(sequence.shape) == 2:
used = fluid.layers.sign(
fluid.layers.cast(fluid.layers.abs(sequence), np.float32))
else:
used = fluid.layers.sign(
fluid.layers.cast(
fluid.layers.reduce_max(fluid.layers.abs(sequence), 2),
np.float32))
length = fluid.layers.reduce_sum(used, 1)
length = fluid.layers.cast(length, np.int32)
return length
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册