提交 3ab7da0c 编写于 作者: Z Zeyu Chen

fix typo

上级 7ec25147
...@@ -20,7 +20,6 @@ from nets import cnn_net ...@@ -20,7 +20,6 @@ from nets import cnn_net
from nets import lstm_net from nets import lstm_net
from nets import bilstm_net from nets import bilstm_net
from nets import gru_net from nets import gru_net
logger = logging.getLogger("paddle-fluid") logger = logging.getLogger("paddle-fluid")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
...@@ -93,28 +92,6 @@ def parse_args(): ...@@ -93,28 +92,6 @@ def parse_args():
return args return args
def remove_feed_fetch_op(program):
""" remove feed and fetch operator and variable for fine-tuning
"""
print("remove feed fetch op")
block = program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(block.ops):
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
block._remove_op(index)
block._remove_var("feed")
block._remove_var("fetch")
program.desc.flush()
print("********************************")
print(program)
print("********************************")
def train_net(train_reader, def train_net(train_reader,
word_dict, word_dict,
network_name, network_name,
...@@ -224,6 +201,7 @@ def retrain_net(train_reader, ...@@ -224,6 +201,7 @@ def retrain_net(train_reader,
fluid.framework.switch_main_program(module.get_inference_program()) fluid.framework.switch_main_program(module.get_inference_program())
# remove feed fetch operator and variable # remove feed fetch operator and variable
ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
remove_feed_fetch_op(fluid.default_main_program()) remove_feed_fetch_op(fluid.default_main_program())
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.layers.data(name="label", shape=[1], dtype="int64")
...@@ -231,6 +209,9 @@ def retrain_net(train_reader, ...@@ -231,6 +209,9 @@ def retrain_net(train_reader,
#TODO(ZeyuChen): how to get output paramter according to proto config #TODO(ZeyuChen): how to get output paramter according to proto config
emb = module.get_module_output() emb = module.get_module_output()
print(
"adfjkajdlfjoqi jqiorejlmsfdlkjoi jqwierjoajsdklfjoi qjerijoajdfiqwjeor adfkalsf"
)
# # # embedding layer # # # embedding layer
# emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) # emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
# #input=data, size=[dict_dim, emb_dim], param_attr="bow_embedding") # #input=data, size=[dict_dim, emb_dim], param_attr="bow_embedding")
...@@ -376,12 +357,9 @@ def main(args): ...@@ -376,12 +357,9 @@ def main(args):
args.word_dict_path, args.word_dict_path,
args.batch_size, args.mode) args.batch_size, args.mode)
# train_net(train_reader, word_dict, args.model_type, args.use_gpu, train_net(train_reader, word_dict, args.model_type, args.use_gpu,
# args.is_parallel, args.model_path, args.lr, args.batch_size, args.is_parallel, args.model_path, args.lr, args.batch_size,
# args.num_passes) args.num_passes)
retrain_net(train_reader, word_dict, args.model_type, args.use_gpu,
args.is_parallel, args.model_path, args.lr, args.batch_size,
args.num_passes)
# eval mode # eval mode
elif args.mode == "eval": elif args.mode == "eval":
......
...@@ -109,7 +109,7 @@ def download_and_uncompress(url, save_name=None): ...@@ -109,7 +109,7 @@ def download_and_uncompress(url, save_name=None):
for file_name in file_names: for file_name in file_names:
tar.extract(file_name, dirname) tar.extract(file_name, dirname)
return module_dir return module_name, module_dir
class TqdmProgress(tqdm): class TqdmProgress(tqdm):
......
...@@ -19,15 +19,15 @@ from __future__ import print_function ...@@ -19,15 +19,15 @@ from __future__ import print_function
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
import tempfile import tempfile
import utils
import os import os
import module_desc_pb2
from collections import defaultdict from collections import defaultdict
from downloader import download_and_uncompress from downloader import download_and_uncompress
__all__ = ["Module", "ModuleDesc"] __all__ = ["Module", "ModuleConfig", "ModuleUtils"]
DICT_NAME = "dict.txt" DICT_NAME = "dict.txt"
ASSETS_PATH = "assets" ASSETS_NAME = "assets"
def mkdir(path): def mkdir(path):
...@@ -40,12 +40,13 @@ def mkdir(path): ...@@ -40,12 +40,13 @@ def mkdir(path):
class Module(object): class Module(object):
def __init__(self, module_url): def __init__(self, module_url):
# donwload module # donwload module
if module_url.startswith("http"): # if it's remote url links if module_url.startswith("http"):
# if it's remote url link, then download and uncompress it # if it's remote url link, then download and uncompress it
module_dir = download_and_uncompress(module_url) module_name, module_dir = download_and_uncompress(module_url)
else: else:
# otherwise it's local path, no need to deal with it # otherwise it's local path, no need to deal with it
module_dir = module_url module_dir = module_url
module_name = module_url.split()[-1]
# load paddle inference model # load paddle inference model
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -62,9 +63,9 @@ class Module(object): ...@@ -62,9 +63,9 @@ class Module(object):
print(self.fetch_targets) print(self.fetch_targets)
# load assets # load assets
self.dict = defaultdict(int) # self.dict = defaultdict(int)
self.dict.setdefault(0) # self.dict.setdefault(0)
self._load_assets(module_dir) # self._load_assets(module_dir)
#TODO(ZeyuChen): Need add register more signature to execute different #TODO(ZeyuChen): Need add register more signature to execute different
# implmentation # implmentation
...@@ -92,6 +93,9 @@ class Module(object): ...@@ -92,6 +93,9 @@ class Module(object):
return np_result return np_result
def add_input_desc(var_name):
pass
def get_vars(self): def get_vars(self):
return self.inference_program.list_vars() return self.inference_program.list_vars()
...@@ -144,23 +148,17 @@ class Module(object): ...@@ -144,23 +148,17 @@ class Module(object):
# load assets folder # load assets folder
def _load_assets(self, module_dir): def _load_assets(self, module_dir):
assets_dir = os.path.join(module_dir, ASSETS_PATH) assets_dir = os.path.join(module_dir, ASSETS_NAME)
tokens_path = os.path.join(assets_dir, DICT_NAME) dict_path = os.path.join(assets_dir, DICT_NAME)
word_id = 0 word_id = 0
with open(tokens_path) as fi: with open(dict_path) as fi:
words = fi.readlines() words = fi.readlines()
#TODO(ZeyuChen) check whether word id is duplicated and valid #TODO(ZeyuChen) check whether word id is duplicated and valid
for line in fi: for line in fi:
w, w_id = line.split() w, w_id = line.split()
self.dict[w] = int(w_id) self.dict[w] = int(w_id)
# words = map(str.strip, words)
# for w in words:
# self.dict[w] = word_id
# word_id += 1
# print(w, word_id)
def add_module_feed_list(self, feed_list): def add_module_feed_list(self, feed_list):
self.feed_list = feed_list self.feed_list = feed_list
...@@ -168,30 +166,89 @@ class Module(object): ...@@ -168,30 +166,89 @@ class Module(object):
self.output_list = output_list self.output_list = output_list
class ModuleDesc(object): class ModuleConfig(object):
def __init__(self): def __init__(self, module_dir):
pass # generate model desc protobuf
self.module_dir = module_dir
@staticmethod self.desc = module_desc_pb3.ModuleDesc()
def save_dict(path, word_dict, dict_name): self.desc.name = module_name
""" Save dictionary for NLP module print("desc.name=", self.desc.name)
self.desc.signature = "default"
print("desc.signature=", self.desc.signature)
self.desc.contain_assets = True
print("desc.signature=", self.desc.contain_assets)
def load(module_dir):
"""load module config from module dir
""" """
mkdir(path) #TODO(ZeyuChen): check module_desc.pb exsitance
with open(os.path.join(path, dict_name), "w") as fo: with open(pb_file_path, "rb") as fi:
print("tokens.txt path", os.path.join(path, DICT_NAME)) self.desc.ParseFromString(fi.read())
if self.desc.contain_assets:
# load assets
self.dict = defaultdict(int)
self.dict.setdefault(0)
assets_dir = os.path.join(self.module_dir, assets_dir)
dict_path = os.path.join(assets_dir, DICT_NAME)
word_id = 0
with open(dict_path) as fi:
words = fi.readlines()
#TODO(ZeyuChen) check whether word id is duplicated and valid
for line in fi:
w, w_id = line.split()
self.dict[w] = int(w_id)
def dump():
# save module_desc.proto first
pb_path = os.path.join(self.module, "module_desc.pb")
with open(pb_path, "wb") as fo:
fo.write(self.desc.SerializeToString())
# save assets/dictionary
assets_dir = os.path.join(self.module_dir, assets_dir)
mkdir(assets_dir)
with open(os.path.join(assets_dir, DICT_NAME), "w") as fo:
for w in word_dict: for w in word_dict:
w_id = word_dict[w] w_id = word_dict[w]
fo.write("{}\t{}\n".format(w, w_id)) fo.write("{}\t{}\n".format(w, w_id))
@staticmethod def save_dict(word_dict, dict_name=DICT_NAME):
def save_module_dict(module_path, word_dict, dict_name=DICT_NAME):
""" Save dictionary for NLP module """ Save dictionary for NLP module
""" """
assets_path = os.path.join(module_path, ASSETS_PATH) mkdir(path)
print("save_module_dict", assets_path) with open(os.path.join(self.module_dir, DICT_NAME), "w") as fo:
ModuleDesc.save_dict(assets_path, word_dict, dict_name) for w in word_dict:
self.dict[w] = word_dict[w]
class ModuleUtils(object):
def __init__(self):
pass pass
@staticmethod
def remove_feed_fetch_op(program):
""" remove feed and fetch operator and variable for fine-tuning
"""
print("remove feed fetch op")
block = program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(block.ops):
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
block._remove_op(index)
block._remove_var("feed")
block._remove_var("fetch")
program.desc.flush()
print("********************************")
print(program)
print("********************************")
if __name__ == "__main__": if __name__ == "__main__":
module_link = "http://paddlehub.cdn.bcebos.com/word2vec/w2v_saved_inference_module.tar.gz" module_link = "http://paddlehub.cdn.bcebos.com/word2vec/w2v_saved_inference_module.tar.gz"
......
...@@ -12,23 +12,34 @@ ...@@ -12,23 +12,34 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
syntax = "proto3"; syntax = "proto3";
option optimize_for = LITE_RUNTIME;
package paddle_hub; package paddle_hub;
message InputDesc { message InputDesc {
} string name = 1;
};
message OutputDesc { message OutputDesc {
bool return_numpy = 1; string name = 1;
} };
// A Hub Module is stored in a directory with a file 'paddlehub_module.pb'
// A Hub Module is stored in a directory with a file 'paddlehub.pb'
// containing a serialized protocol message of this type. The further contents // containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message. // of the directory depend on the storage format described by the message.
message ModuleDesc { message ModuleDesc {
string name = 1; // PaddleHub module name string name = 1; // PaddleHub module name
repeated InputDesc input_desc = 2;
repeated OutputDesc output_desc = 3;
string signature = 4;
bool return_numpy = 5;
repeated string input_signature bool contain_assets = 6;
} };
[metadata]
license_file = LICENSE
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ==============================================================================
"""Setup for pip package.""" """Setup for pip package."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -29,7 +28,7 @@ REQUIRED_PACKAGES = [ ...@@ -29,7 +28,7 @@ REQUIRED_PACKAGES = [
] ]
setup( setup(
name='paddle_hub', name='paddle_hub',
version=__version__.replace('-', ''), version=__version__.replace('-', ''),
description=('PaddleHub is a library to foster the publication, ' description=('PaddleHub is a library to foster the publication, '
'discovery, and consumption of reusable parts of machine ' 'discovery, and consumption of reusable parts of machine '
......
...@@ -184,7 +184,7 @@ def train(use_cuda=False): ...@@ -184,7 +184,7 @@ def train(use_cuda=False):
dictionary.append(w) dictionary.append(w)
# save word dict to assets folder # save word dict to assets folder
hub.ModuleDesc.save_module_dict( hub.ModuleConfig.save_module_dict(
module_path=saved_model_path, word_dict=dictionary) module_path=saved_model_path, word_dict=dictionary)
...@@ -214,9 +214,9 @@ def test_save_module(use_cuda=False): ...@@ -214,9 +214,9 @@ def test_save_module(use_cuda=False):
np_result = np.array(results[0]) np_result = np.array(results[0])
print(np_result) print(np_result)
saved_module_path = "./test/word2vec_inference_module" saved_module_dir = "./test/word2vec_inference_module"
fluid.io.save_inference_model( fluid.io.save_inference_model(
dirname=saved_module_path, dirname=saved_module_dir,
feeded_var_names=["words"], feeded_var_names=["words"],
target_vars=[word_emb], target_vars=[word_emb],
executor=exe) executor=exe)
...@@ -227,17 +227,19 @@ def test_save_module(use_cuda=False): ...@@ -227,17 +227,19 @@ def test_save_module(use_cuda=False):
w = w.decode("ascii") w = w.decode("ascii")
dictionary.append(w) dictionary.append(w)
# save word dict to assets folder # save word dict to assets folder
hub.ModuleDesc.save_module_dict( config = hub.ModuleConfig(saved_module_dir)
module_path=saved_module_path, word_dict=dictionary) config.save_dict(word_dict=dictionary)
config.dump()
def test_load_module(use_cuda=False): def test_load_module(use_cuda=False):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
saved_module_path = "./test/word2vec_inference_module" saved_module_dir = "./test/word2vec_inference_module"
[inference_program, feed_target_names, [inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model( fetch_targets] = fluid.io.load_inference_model(
saved_module_path, executor=exe) saved_module_dir, executor=exe)
# Sequence input in Paddle must be LOD Tensor, so we need to convert them inside Module # Sequence input in Paddle must be LOD Tensor, so we need to convert them inside Module
word_ids = [[1, 2, 3, 4, 5]] word_ids = [[1, 2, 3, 4, 5]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册