提交 7ec25147 编写于 作者: Z Zeyu Chen

fix senta runing issues

上级 e687af48
...@@ -28,8 +28,8 @@ def bow_net(data, ...@@ -28,8 +28,8 @@ def bow_net(data,
fc_2 = fluid.layers.fc( fc_2 = fluid.layers.fc(
input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2") input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2")
# softmax layer # softmax layer
prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax", prediction = fluid.layers.fc(
name="fc_softmax") input=[fc_2], size=class_dim, act="softmax", name="fc_softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=prediction, label=label) acc = fluid.layers.accuracy(input=prediction, label=label)
......
# coding: utf-8 # coding: utf-8
import sys import sys
# NOTE: just hack for fast test
sys.path.append("../") sys.path.append("../")
sys.path.append("../paddle_hub/")
import os
import time import time
import unittest import unittest
import contextlib import contextlib
...@@ -114,7 +117,7 @@ def remove_feed_fetch_op(program): ...@@ -114,7 +117,7 @@ def remove_feed_fetch_op(program):
def train_net(train_reader, def train_net(train_reader,
word_dict, word_dict,
network, network_name,
use_gpu, use_gpu,
parallel, parallel,
save_dirname, save_dirname,
...@@ -124,39 +127,85 @@ def train_net(train_reader, ...@@ -124,39 +127,85 @@ def train_net(train_reader,
""" """
train network train network
""" """
if network == "bilstm_net": if network_name == "bilstm_net":
network = bilstm_net network = bilstm_net
elif network == "bow_net": elif network_name == "bow_net":
network = bow_net network = bow_net
elif network == "cnn_net": elif network_name == "cnn_net":
network = cnn_net network = cnn_net
elif network == "lstm_net": elif network_name == "lstm_net":
network = lstm_net network = lstm_net
elif network == "gru_net": elif network_name == "gru_net":
network = gru_net
else:
print("unknown network type")
return
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
cost, acc, pred, emb = network(data, label, len(word_dict) + 2)
# set optimizer
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
sgd_optimizer.minimize(cost)
# set place, executor, datafeeder
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=["words", "label"], place=place)
exe.run(fluid.default_startup_program())
# start training...
for pass_id in range(pass_num):
data_size, data_count, total_acc, total_cost = 0, 0, 0.0, 0.0
for batch in train_reader():
avg_cost_np, avg_acc_np = exe.run(
fluid.default_main_program(),
feed=feeder.feed(batch),
fetch_list=[cost, acc],
return_numpy=True)
data_size = len(batch)
total_acc += data_size * avg_acc_np
total_cost += data_size * avg_cost_np
data_count += data_size
avg_cost = total_cost / data_count
avg_acc = total_acc / data_count
print("[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f" %
(pass_id, avg_acc, avg_cost))
# save the model
module_path = os.path.join(save_dirname, network_name)
hub.ModuleDesc.save_module_dict(
module_path=module_path, word_dict=word_dict)
fluid.io.save_inference_model(module_path, ["words"], emb, exe)
def retrain_net(train_reader,
word_dict,
network_name,
use_gpu,
parallel,
save_dirname,
lr=0.002,
batch_size=128,
pass_num=30):
"""
train network
"""
if network_name == "bilstm_net":
network = bilstm_net
elif network_name == "bow_net":
network = bow_net
elif network_name == "cnn_net":
network = cnn_net
elif network_name == "lstm_net":
network = lstm_net
elif network_name == "gru_net":
network = gru_net network = gru_net
else: else:
print("unknown network type") print("unknown network type")
return return
# word seq data
# data = fluid.layers.data(
# name="words", shape=[1], dtype="int64", lod_level=1)
# if not parallel:
# # set network
# cost, acc, pred, emb = network(data, label, len(word_dict) + 2)
# else:
# places = fluid.layers.get_places(device_count=2)
# pd = fluid.layers.ParallelDo(places)
# with pd.do():
# # set network
# cost, acc, prediction, emb = network(
# pd.read_input(data), pd.read_input(label),
# len(word_dict) + 2)
# pd.write_output(cost)
# pd.write_output(acc)
# cost, acc = pd()
# cost = fluid.layers.mean(cost)
# acc = fluid.layers.mean(acc)
dict_dim = len(word_dict) + 2 dict_dim = len(word_dict) + 2
emb_dim = 128 emb_dim = 128
...@@ -164,7 +213,8 @@ def train_net(train_reader, ...@@ -164,7 +213,8 @@ def train_net(train_reader,
hid_dim2 = 96 hid_dim2 = 96
class_dim = 2 class_dim = 2
module_link = "https://paddlehub.cdn.bcebos.com/senta/bow_module_3.tar.gz" # module_link = "https://paddlehub.cdn.bcebos.com/senta/bow_module_3.tar.gz"
module_link = "./models/bow_net/"
module = hub.Module(module_link) module = hub.Module(module_link)
main_program = fluid.Program() main_program = fluid.Program()
...@@ -178,6 +228,7 @@ def train_net(train_reader, ...@@ -178,6 +228,7 @@ def train_net(train_reader,
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.layers.data(name="label", shape=[1], dtype="int64")
data = fluid.default_main_program().global_block().var("words") data = fluid.default_main_program().global_block().var("words")
#TODO(ZeyuChen): how to get output paramter according to proto config
emb = module.get_module_output() emb = module.get_module_output()
# # # embedding layer # # # embedding layer
...@@ -198,20 +249,6 @@ def train_net(train_reader, ...@@ -198,20 +249,6 @@ def train_net(train_reader,
fluid.layers.cross_entropy(input=pred, label=label)) fluid.layers.cross_entropy(input=pred, label=label))
acc = fluid.layers.accuracy(input=pred, label=label) acc = fluid.layers.accuracy(input=pred, label=label)
# Original Senta BoW networks
# label = fluid.layers.data(name="label", shape=[1], dtype="int64")
# data = fluid.layers.data(
# name="words", shape=[1], dtype="int64", lod_level=1)
# cost, acc, pred, emb = network(data, label, len(word_dict) + 2)
# print("new program")
# with open("program_senta.prototxt", "w") as fo:
# fo.write(str(fluid.default_main_program()))
# print("program_senta", fluid.default_main_program())
with open("senta_load_module.prototxt", "w") as fo:
fo.write(str(fluid.default_main_program()))
print("senta_load_module", fluid.default_main_program())
# set optimizer # set optimizer
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
sgd_optimizer.minimize(cost) sgd_optimizer.minimize(cost)
...@@ -246,8 +283,8 @@ def train_net(train_reader, ...@@ -246,8 +283,8 @@ def train_net(train_reader,
# print("senta_load_module", fluid.default_main_program()) # print("senta_load_module", fluid.default_main_program())
# save the model # save the model
bow_module_path = save_dirname + "/" + "bow_module" module_path = os.path.join(save_dirname, network_name + "_retrain")
fluid.io.save_inference_model(bow_module_path, ["words"], emb, exe) fluid.io.save_inference_model(module_path, ["words"], emb, exe)
def eval_net(test_reader, use_gpu, model_path=None): def eval_net(test_reader, use_gpu, model_path=None):
...@@ -339,9 +376,12 @@ def main(args): ...@@ -339,9 +376,12 @@ def main(args):
args.word_dict_path, args.word_dict_path,
args.batch_size, args.mode) args.batch_size, args.mode)
train_net(train_reader, word_dict, args.model_type, args.use_gpu, # train_net(train_reader, word_dict, args.model_type, args.use_gpu,
args.is_parallel, args.model_path, args.lr, args.batch_size, # args.is_parallel, args.model_path, args.lr, args.batch_size,
args.num_passes) # args.num_passes)
retrain_net(train_reader, word_dict, args.model_type, args.use_gpu,
args.is_parallel, args.model_path, args.lr, args.batch_size,
args.num_passes)
# eval mode # eval mode
elif args.mode == "eval": elif args.mode == "eval":
......
...@@ -19,18 +19,33 @@ from __future__ import print_function ...@@ -19,18 +19,33 @@ from __future__ import print_function
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
import tempfile import tempfile
import utils
import os import os
from collections import defaultdict from collections import defaultdict
from downloader import download_and_uncompress from downloader import download_and_uncompress
__all__ = ["Module", "ModuleDesc"] __all__ = ["Module", "ModuleDesc"]
DICT_NAME = "dict.txt"
ASSETS_PATH = "assets"
def mkdir(path):
""" the same as the shell command mkdir -p "
"""
if not os.path.exists(path):
os.makedirs(path)
class Module(object): class Module(object):
def __init__(self, module_url): def __init__(self, module_url):
# donwload module # donwload module
module_dir = download_and_uncompress(module_url) if module_url.startswith("http"): # if it's remote url links
# if it's remote url link, then download and uncompress it
module_dir = download_and_uncompress(module_url)
else:
# otherwise it's local path, no need to deal with it
module_dir = module_url
# load paddle inference model # load paddle inference model
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -90,6 +105,7 @@ class Module(object): ...@@ -90,6 +105,7 @@ class Module(object):
def get_module_output(self): def get_module_output(self):
for var in self.inference_program.list_vars(): for var in self.inference_program.list_vars():
print(var) print(var)
# NOTE: just hack for load Senta's
if var.name == "embedding_0.tmp_0": if var.name == "embedding_0.tmp_0":
return var return var
...@@ -128,17 +144,22 @@ class Module(object): ...@@ -128,17 +144,22 @@ class Module(object):
# load assets folder # load assets folder
def _load_assets(self, module_dir): def _load_assets(self, module_dir):
assets_dir = os.path.join(module_dir, "assets") assets_dir = os.path.join(module_dir, ASSETS_PATH)
tokens_path = os.path.join(assets_dir, "tokens.txt") tokens_path = os.path.join(assets_dir, DICT_NAME)
word_id = 0 word_id = 0
with open(tokens_path) as fi: with open(tokens_path) as fi:
words = fi.readlines() words = fi.readlines()
words = map(str.strip, words) #TODO(ZeyuChen) check whether word id is duplicated and valid
for w in words: for line in fi:
self.dict[w] = word_id w, w_id = line.split()
word_id += 1 self.dict[w] = int(w_id)
print(w, word_id)
# words = map(str.strip, words)
# for w in words:
# self.dict[w] = word_id
# word_id += 1
# print(w, word_id)
def add_module_feed_list(self, feed_list): def add_module_feed_list(self, feed_list):
self.feed_list = feed_list self.feed_list = feed_list
...@@ -146,35 +167,27 @@ class Module(object): ...@@ -146,35 +167,27 @@ class Module(object):
def add_module_output_list(self, output_list): def add_module_output_list(self, output_list):
self.output_list = output_list self.output_list = output_list
def _mkdir(self, path):
if not os.path.exists(path):
os.makedirs(path)
class ModuleDesc(object): class ModuleDesc(object):
def __init__(self): def __init__(self):
pass pass
@staticmethod
def _mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
@staticmethod @staticmethod
def save_dict(path, word_dict, dict_name): def save_dict(path, word_dict, dict_name):
""" Save dictionary for NLP module """ Save dictionary for NLP module
""" """
ModuleDesc._mkdir(path) mkdir(path)
with open(os.path.join(path, dict_name), "w") as fo: with open(os.path.join(path, dict_name), "w") as fo:
print("tokens.txt path", os.path.join(path, "tokens.txt")) print("tokens.txt path", os.path.join(path, DICT_NAME))
dict_str = "\n".join(word_dict) for w in word_dict:
fo.write(dict_str) w_id = word_dict[w]
fo.write("{}\t{}\n".format(w, w_id))
@staticmethod @staticmethod
def save_module_dict(module_path, word_dict, dict_name="dict.txt"): def save_module_dict(module_path, word_dict, dict_name=DICT_NAME):
""" Save dictionary for NLP module """ Save dictionary for NLP module
""" """
assets_path = os.path.join(module_path, "assets") assets_path = os.path.join(module_path, ASSETS_PATH)
print("save_module_dict", assets_path) print("save_module_dict", assets_path)
ModuleDesc.save_dict(assets_path, word_dict, dict_name) ModuleDesc.save_dict(assets_path, word_dict, dict_name)
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册