提交 f95e05a3 编写于 作者: L Liu Yiqun

Refine the inference unittests.

上级 899ba0d0
...@@ -30,5 +30,5 @@ inference_test(label_semantic_roles) ...@@ -30,5 +30,5 @@ inference_test(label_semantic_roles)
inference_test(recognize_digits ARGS mlp) inference_test(recognize_digits ARGS mlp)
inference_test(recommender_system) inference_test(recommender_system)
#inference_test(rnn_encoder_decoder) #inference_test(rnn_encoder_decoder)
inference_test(understand_sentiment) inference_test(understand_sentiment ARGS conv lstm)
inference_test(word2vec) inference_test(word2vec)
...@@ -32,16 +32,42 @@ TEST(inference, label_semantic_roles) { ...@@ -32,16 +32,42 @@ TEST(inference, label_semantic_roles) {
paddle::framework::LoDTensor word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, paddle::framework::LoDTensor word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1,
ctx_p2, mark; ctx_p2, mark;
paddle::framework::LoD lod{{0, 4, 10}}; paddle::framework::LoD lod{{0, 4, 10}};
int64_t word_dict_len = 44068;
SetupLoDTensor(word, lod, static_cast<int64_t>(0), static_cast<int64_t>(1)); int64_t predicate_dict_len = 3162;
SetupLoDTensor( int64_t mark_dict_len = 2;
predicate, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
SetupLoDTensor(ctx_n2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1)); SetupLoDTensor(word,
SetupLoDTensor(ctx_n1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1)); lod,
SetupLoDTensor(ctx_0, lod, static_cast<int64_t>(0), static_cast<int64_t>(1)); static_cast<int64_t>(0),
SetupLoDTensor(ctx_p1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1)); static_cast<int64_t>(word_dict_len - 1));
SetupLoDTensor(ctx_p2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1)); SetupLoDTensor(predicate,
SetupLoDTensor(mark, lod, static_cast<int64_t>(0), static_cast<int64_t>(1)); lod,
static_cast<int64_t>(0),
static_cast<int64_t>(predicate_dict_len - 1));
SetupLoDTensor(ctx_n2,
lod,
static_cast<int64_t>(0),
static_cast<int64_t>(word_dict_len - 1));
SetupLoDTensor(ctx_n1,
lod,
static_cast<int64_t>(0),
static_cast<int64_t>(word_dict_len - 1));
SetupLoDTensor(ctx_0,
lod,
static_cast<int64_t>(0),
static_cast<int64_t>(word_dict_len - 1));
SetupLoDTensor(ctx_p1,
lod,
static_cast<int64_t>(0),
static_cast<int64_t>(word_dict_len - 1));
SetupLoDTensor(ctx_p2,
lod,
static_cast<int64_t>(0),
static_cast<int64_t>(word_dict_len - 1));
SetupLoDTensor(mark,
lod,
static_cast<int64_t>(0),
static_cast<int64_t>(mark_dict_len - 1));
std::vector<paddle::framework::LoDTensor*> cpu_feeds; std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&word); cpu_feeds.push_back(&word);
......
...@@ -31,7 +31,12 @@ TEST(inference, understand_sentiment) { ...@@ -31,7 +31,12 @@ TEST(inference, understand_sentiment) {
paddle::framework::LoDTensor words; paddle::framework::LoDTensor words;
paddle::framework::LoD lod{{0, 4, 10}}; paddle::framework::LoD lod{{0, 4, 10}};
SetupLoDTensor(words, lod, static_cast<int64_t>(0), static_cast<int64_t>(10)); int64_t word_dict_len = 5147;
SetupLoDTensor(words,
lod,
static_cast<int64_t>(0),
static_cast<int64_t>(word_dict_len - 1));
std::vector<paddle::framework::LoDTensor*> cpu_feeds; std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&words); cpu_feeds.push_back(&words);
......
...@@ -296,7 +296,6 @@ def infer(use_cuda, save_dirname=None): ...@@ -296,7 +296,6 @@ def infer(use_cuda, save_dirname=None):
print(results[0].lod()) print(results[0].lod())
np_data = np.array(results[0]) np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape) print("Inference Shape: ", np_data.shape)
print("Inference results: ", np_data)
def main(use_cuda): def main(use_cuda):
......
...@@ -93,7 +93,7 @@ def create_random_lodtensor(lod, place, low, high): ...@@ -93,7 +93,7 @@ def create_random_lodtensor(lod, place, low, high):
return res return res
def train(word_dict, net_method, use_cuda, save_dirname=None): def train(word_dict, nn_type, use_cuda, save_dirname=None):
BATCH_SIZE = 128 BATCH_SIZE = 128
PASS_NUM = 5 PASS_NUM = 5
dict_dim = len(word_dict) dict_dim = len(word_dict)
...@@ -102,6 +102,11 @@ def train(word_dict, net_method, use_cuda, save_dirname=None): ...@@ -102,6 +102,11 @@ def train(word_dict, net_method, use_cuda, save_dirname=None):
data = fluid.layers.data( data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1) name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.layers.data(name="label", shape=[1], dtype="int64")
if nn_type == "conv":
net_method = convolution_net
else:
net_method = stacked_lstm_net
cost, acc_out, prediction = net_method( cost, acc_out, prediction = net_method(
data, label, input_dim=dict_dim, class_dim=class_dim) data, label, input_dim=dict_dim, class_dim=class_dim)
...@@ -132,7 +137,7 @@ def train(word_dict, net_method, use_cuda, save_dirname=None): ...@@ -132,7 +137,7 @@ def train(word_dict, net_method, use_cuda, save_dirname=None):
net_method.__name__)) net_method.__name__))
def infer(use_cuda, save_dirname=None): def infer(word_dict, use_cuda, save_dirname=None):
if save_dirname is None: if save_dirname is None:
return return
...@@ -146,10 +151,11 @@ def infer(use_cuda, save_dirname=None): ...@@ -146,10 +151,11 @@ def infer(use_cuda, save_dirname=None):
[inference_program, feed_target_names, [inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe) fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
word_dict_len = len(word_dict)
lod = [0, 4, 10] lod = [0, 4, 10]
word_dict = paddle.dataset.imdb.word_dict()
tensor_words = create_random_lodtensor( tensor_words = create_random_lodtensor(
lod, place, low=0, high=len(word_dict) - 1) lod, place, low=0, high=word_dict_len - 1)
# Construct feed as a dictionary of {feed_target_name: feed_target_data} # Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets. # and results will contain a list of data corresponding to fetch_targets.
...@@ -164,15 +170,15 @@ def infer(use_cuda, save_dirname=None): ...@@ -164,15 +170,15 @@ def infer(use_cuda, save_dirname=None):
print("Inference results: ", np_data) print("Inference results: ", np_data)
def main(word_dict, net_method, use_cuda): def main(word_dict, nn_type, use_cuda):
if use_cuda and not fluid.core.is_compiled_with_cuda(): if use_cuda and not fluid.core.is_compiled_with_cuda():
return return
# Directory for saving the trained model # Directory for saving the trained model
save_dirname = "understand_sentiment.inference.model" save_dirname = "understand_sentiment_" + nn_type + ".inference.model"
train(word_dict, net_method, use_cuda, save_dirname) train(word_dict, nn_type, use_cuda, save_dirname)
infer(use_cuda, save_dirname) infer(word_dict, use_cuda, save_dirname)
class TestUnderstandSentiment(unittest.TestCase): class TestUnderstandSentiment(unittest.TestCase):
...@@ -191,19 +197,19 @@ class TestUnderstandSentiment(unittest.TestCase): ...@@ -191,19 +197,19 @@ class TestUnderstandSentiment(unittest.TestCase):
def test_conv_cpu(self): def test_conv_cpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=convolution_net, use_cuda=False) main(self.word_dict, nn_type="conv", use_cuda=False)
def test_stacked_lstm_cpu(self): def test_stacked_lstm_cpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=stacked_lstm_net, use_cuda=False) main(self.word_dict, nn_type="lstm", use_cuda=False)
def test_conv_gpu(self): def test_conv_gpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=convolution_net, use_cuda=True) main(self.word_dict, nn_type="conv", use_cuda=True)
def test_stacked_lstm_gpu(self): def test_stacked_lstm_gpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=stacked_lstm_net, use_cuda=True) main(self.word_dict, nn_type="lstm", use_cuda=True)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册