Commit cc92eb41 authored by Zeyu Chen

add new test file

Parent db5fd353
python test_create_hub.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models
python test_finetune.py --train_data_path ./data/train_data/corpus.train --word_dict_path ./data/train.vocab --mode train --model_path ./models
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
import sys
import os
import time
import unittest
import contextlib
import logging
import argparse
import ast
import utils
import paddle.fluid as fluid
import paddle_hub as hub
from nets import bow_net
from nets import cnn_net
from nets import lstm_net
from nets import bilstm_net
from nets import gru_net
logger = logging.getLogger("paddle-fluid")
logger.setLevel(logging.INFO)
def parse_args():
    """Parse command-line arguments for the sentiment classification task.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser("Sentiment Classification.")
    # training data path
    parser.add_argument(
        "--train_data_path",
        type=str,
        required=False,
        help="The path of training data. Should be given in train mode!")
    # test data path
    parser.add_argument(
        "--test_data_path",
        type=str,
        required=False,
        help="The path of test data. Should be given in eval or infer mode!")
    # word_dict path
    parser.add_argument(
        "--word_dict_path",
        type=str,
        required=True,
        help="The path of word dictionary.")
    # current mode
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=['train', 'eval', 'infer'],
        help="train/eval/infer mode")
    # model type
    parser.add_argument(
        "--model_type", type=str, default="bow_net", help="type of model")
    # model save path.  NOTE: the original declared both default="models"
    # and required=True; argparse ignores the default when required, so the
    # redundant required flag is dropped to keep the default meaningful
    # (backward compatible: invocations passing --model_path still work).
    parser.add_argument(
        "--model_path",
        type=str,
        default="models",
        help="The path to save the trained models.")
    # Number of passes for the training task.
    parser.add_argument(
        "--num_passes",
        type=int,
        default=3,
        help="Number of passes for the training task.")
    # Batch size
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="The number of training examples in one forward/backward pass.")
    # lr value for training
    parser.add_argument(
        "--lr", type=float, default=0.002, help="The lr value for training.")
    # Whether to use gpu ("True"/"False" parsed via ast.literal_eval)
    parser.add_argument(
        "--use_gpu",
        type=ast.literal_eval,
        default=False,
        help="Whether to use gpu to train the model.")
    # parallel train
    parser.add_argument(
        "--is_parallel",
        type=ast.literal_eval,
        default=False,
        help="Whether to train the model in parallel.")
    args = parser.parse_args()
    return args
def bow_net_module(data,
                   label,
                   dict_dim,
                   emb_dim=128,
                   hid_dim=128,
                   hid_dim2=96,
                   class_dim=2):
    """
    Bag-of-words sentiment classification network.

    Args:
        data: int64 LoD tensor of word ids.
        label: int64 tensor of class labels.
        dict_dim: vocabulary size for the embedding table.
        emb_dim: embedding width.
        hid_dim: width of the first fully-connected layer.
        hid_dim2: width of the second fully-connected layer.
        class_dim: number of output classes.

    Returns:
        tuple: (avg_cost, acc, prediction, emb) -- mean cross-entropy loss,
        accuracy, softmax prediction, and the embedding output.
    """
    # embedding layer; named parameter so it can be shared/restored later
    emb = fluid.layers.embedding(
        input=data, size=[dict_dim, emb_dim], param_attr="bow_embedding")
    # bag-of-words: sum-pool word embeddings over the sequence
    bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
    bow_tanh = fluid.layers.tanh(bow)
    # two fully-connected hidden layers
    fc_1 = fluid.layers.fc(
        input=bow_tanh, size=hid_dim, act="tanh", name="bow_fc1")
    fc_2 = fluid.layers.fc(
        input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2")
    # softmax classification layer
    prediction = fluid.layers.fc(
        input=[fc_2], size=class_dim, act="softmax", name="fc_softmax")
    cost = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc = fluid.layers.accuracy(input=prediction, label=label)
    # NOTE: removed the unused local `module_dir` from the original.
    return avg_cost, acc, prediction, emb
def train_net(train_reader,
              word_dict,
              network_name,
              use_gpu,
              parallel,
              save_dirname,
              lr=0.002,
              batch_size=128,
              pass_num=10):
    """
    Train the selected network and export it as a PaddleHub module.

    Args:
        train_reader: callable returning an iterator over training batches.
        word_dict: vocabulary dict, saved alongside the module.
        network_name: one of bilstm_net/bow_net/cnn_net/lstm_net/gru_net.
        use_gpu: train on CUDAPlace(0) when True, otherwise CPUPlace.
        parallel: unused here -- kept for interface compatibility.
        save_dirname: parent directory for the exported module.
        lr: Adagrad learning rate.
        batch_size: unused here -- batching is done by train_reader.
        pass_num: number of passes over the training data.
    """
    networks = {
        "bilstm_net": bilstm_net,
        "bow_net": bow_net,
        "cnn_net": cnn_net,
        "lstm_net": lstm_net,
        "gru_net": gru_net,
    }
    if network_name not in networks:
        print("unknown network type")
        return
    network = networks[network_name]
    # input placeholders
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    data = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
    cost, acc, pred, emb = network(data, label, len(word_dict) + 2)
    # optimizer
    sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
    sgd_optimizer.minimize(cost)
    # place, executor, datafeeder
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    feeder = fluid.DataFeeder(feed_list=["words", "label"], place=place)
    exe.run(fluid.default_startup_program())
    # training loop
    for pass_id in range(pass_num):
        total_cost, total_acc, data_count = 0.0, 0.0, 0
        for batch in train_reader():
            cost_np, acc_np = exe.run(
                fluid.default_main_program(),
                feed=feeder.feed(batch),
                fetch_list=[cost, acc],
                return_numpy=True)
            batch_len = len(batch)
            total_acc += batch_len * acc_np
            total_cost += batch_len * cost_np
            data_count += batch_len
        print("[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f" %
              (pass_id, total_acc / data_count, total_cost / data_count))
    # export module: config (dict + feed/fetch signatures), then program
    module_dir = os.path.join(save_dirname, network_name)
    config = hub.ModuleConfig(module_dir)
    config.save_dict(word_dict=word_dict)
    config.register_feed_signature({"words": data.name})
    config.register_fetch_signature({"emb": emb.name})
    config.dump()
    feed_var_name = config.feed_var_name("words")
    fluid.io.save_inference_model(module_dir, [feed_var_name], emb, exe)
def retrain_net(train_reader,
                word_dict,
                network_name,
                use_gpu,
                parallel,
                save_dirname,
                lr=0.002,
                batch_size=128,
                pass_num=30):
    """
    Fine-tune on top of a previously exported PaddleHub module.

    Loads the inference program of the module saved under
    "./models/bow_net", strips its feed/fetch operators, stacks fresh
    fully-connected + softmax classification layers on the module's "emb"
    output, and trains the combined graph end to end.

    Args:
        train_reader: callable returning an iterator over training batches.
        word_dict: vocabulary dict; saved again with the retrained module.
        network_name: network identifier; only used for validation and for
            naming the output directory.
        use_gpu: train on CUDAPlace(0) when True, otherwise CPUPlace.
        parallel: unused here -- kept for signature parity with train_net.
        save_dirname: parent directory for the "<network_name>_retrain"
            module.
        lr: Adagrad learning rate.
        batch_size: unused here -- batching is done by train_reader.
        pass_num: number of passes over the training data.
    """
    # Validate the requested network type.  NOTE(review): the selected
    # `network` function is never called below -- the fine-tuning graph is
    # always built bow-style on top of the loaded module.
    if network_name == "bilstm_net":
        network = bilstm_net
    elif network_name == "bow_net":
        network = bow_net
    elif network_name == "cnn_net":
        network = cnn_net
    elif network_name == "lstm_net":
        network = lstm_net
    elif network_name == "gru_net":
        network = gru_net
    else:
        print("unknown network type")
        return
    dict_dim = len(word_dict) + 2  # presumably +2 for OOV/padding ids -- TODO confirm
    emb_dim = 128
    hid_dim = 128
    hid_dim2 = 96
    class_dim = 2
    # Load the module exported by train_net (path hard-coded to bow_net).
    module_path = "./models/bow_net"
    module = hub.Module(module_dir=module_path)
    # NOTE(review): these two programs are created but never used below.
    main_program = fluid.Program()
    startup_program = fluid.Program()
    # use switch program to test fine-tuning: make the module's inference
    # program the default main program
    fluid.framework.switch_main_program(module.get_inference_program())
    # remove feed fetch operator and variable left over from inference
    hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    # Reuse the module's registered input variable and embedding output
    # instead of rebuilding data/embedding layers from scratch.
    data = module.get_feed_var("words")
    emb = module.get_fetch_var("emb")
    # bow layer on top of the pretrained embedding
    bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
    bow_tanh = fluid.layers.tanh(bow)
    # full connect layers
    fc_1 = fluid.layers.fc(
        input=bow_tanh, size=hid_dim, act="tanh", name="bow_fc1")
    fc_2 = fluid.layers.fc(
        input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2")
    # softmax layer
    pred = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
    cost = fluid.layers.mean(
        fluid.layers.cross_entropy(input=pred, label=label))
    acc = fluid.layers.accuracy(input=pred, label=label)
    # set optimizer
    sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
    sgd_optimizer.minimize(cost)
    # set place, executor, datafeeder
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    feeder = fluid.DataFeeder(feed_list=["words", "label"], place=place)
    exe.run(fluid.default_startup_program())
    # start training...
    for pass_id in range(pass_num):
        data_size, data_count, total_acc, total_cost = 0, 0, 0.0, 0.0
        for batch in train_reader():
            avg_cost_np, avg_acc_np = exe.run(
                fluid.default_main_program(),
                feed=feeder.feed(batch),
                fetch_list=[cost, acc],
                return_numpy=True)
            data_size = len(batch)
            total_acc += data_size * avg_acc_np
            total_cost += data_size * avg_cost_np
            data_count += data_size
        avg_cost = total_cost / data_count
        avg_acc = total_acc / data_count
        print("[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f" %
              (pass_id, avg_acc, avg_cost))
    # save the retrained model and its config
    module_dir = os.path.join(save_dirname, network_name + "_retrain")
    fluid.io.save_inference_model(module_dir, ["words"], emb, exe)
    config = hub.ModuleConfig(module_dir)
    config.save_dict(word_dict=word_dict)
    config.dump()
def main(args):
    """Entry point: build the vocabulary/reader, then launch training."""
    vocab, reader = utils.prepare_data(args.train_data_path,
                                       args.word_dict_path, args.batch_size,
                                       args.mode)
    train_net(reader, vocab, args.model_type, args.use_gpu, args.is_parallel,
              args.model_path, args.lr, args.batch_size, args.num_passes)
    # NOTE(ZeyuChen): can't run train_net and retrain_net together
if __name__ == "__main__":
    # Script entry point.
    main(parse_args())
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf-8
import sys
import os
import time
import unittest
import contextlib
import logging
import argparse
import ast
import utils
import paddle.fluid as fluid
import paddle_hub as hub
from nets import bow_net
from nets import cnn_net
from nets import lstm_net
from nets import bilstm_net
from nets import gru_net
logger = logging.getLogger("paddle-fluid")
logger.setLevel(logging.INFO)
def parse_args():
    """Parse command-line arguments for the sentiment classification task.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser("Sentiment Classification.")
    # training data path
    parser.add_argument(
        "--train_data_path",
        type=str,
        required=False,
        help="The path of training data. Should be given in train mode!")
    # test data path
    parser.add_argument(
        "--test_data_path",
        type=str,
        required=False,
        help="The path of test data. Should be given in eval or infer mode!")
    # word_dict path
    parser.add_argument(
        "--word_dict_path",
        type=str,
        required=True,
        help="The path of word dictionary.")
    # current mode
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=['train', 'eval', 'infer'],
        help="train/eval/infer mode")
    # model type
    parser.add_argument(
        "--model_type", type=str, default="bow_net", help="type of model")
    # model save path.  NOTE: the original declared both default="models"
    # and required=True; argparse ignores the default when required, so the
    # redundant required flag is dropped to keep the default meaningful
    # (backward compatible: invocations passing --model_path still work).
    parser.add_argument(
        "--model_path",
        type=str,
        default="models",
        help="The path to save the trained models.")
    # Number of passes for the training task.
    parser.add_argument(
        "--num_passes",
        type=int,
        default=10,
        help="Number of passes for the training task.")
    # Batch size
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="The number of training examples in one forward/backward pass.")
    # lr value for training
    parser.add_argument(
        "--lr", type=float, default=0.002, help="The lr value for training.")
    # Whether to use gpu ("True"/"False" parsed via ast.literal_eval)
    parser.add_argument(
        "--use_gpu",
        type=ast.literal_eval,
        default=False,
        help="Whether to use gpu to train the model.")
    # parallel train
    parser.add_argument(
        "--is_parallel",
        type=ast.literal_eval,
        default=False,
        help="Whether to train the model in parallel.")
    args = parser.parse_args()
    return args
def retrain_net(train_reader,
                word_dict,
                network_name,
                use_gpu,
                parallel,
                save_dirname,
                lr=0.002,
                batch_size=128,
                pass_num=30):
    """
    Fine-tune on top of a previously exported PaddleHub module.

    Loads the inference program of the module saved under
    "./models/bow_net", strips its feed/fetch operators, stacks fresh
    fully-connected + softmax classification layers on the module's "emb"
    output, trains the combined graph, and exports the result as a new
    module named "<network_name>_retrain".

    Args:
        train_reader: callable returning an iterator over training batches.
        word_dict: vocabulary dict; saved again with the retrained module.
        network_name: network identifier; only used for validation and for
            naming the output directory.
        use_gpu: train on CUDAPlace(0) when True, otherwise CPUPlace.
        parallel: unused here -- kept for signature parity.
        save_dirname: parent directory for the retrained module.
        lr: Adagrad learning rate.
        batch_size: unused here -- batching is done by train_reader.
        pass_num: number of passes over the training data.
    """
    # Validate the requested network type.  NOTE(review): the selected
    # `network` function is never called below -- the fine-tuning graph is
    # always built bow-style on top of the loaded module.
    if network_name == "bilstm_net":
        network = bilstm_net
    elif network_name == "bow_net":
        network = bow_net
    elif network_name == "cnn_net":
        network = cnn_net
    elif network_name == "lstm_net":
        network = lstm_net
    elif network_name == "gru_net":
        network = gru_net
    else:
        print("unknown network type")
        return
    dict_dim = len(word_dict) + 2  # presumably +2 for OOV/padding ids -- TODO confirm
    emb_dim = 128
    hid_dim = 128
    hid_dim2 = 96
    class_dim = 2
    # Load the module exported by the create-hub script (hard-coded path).
    module_path = "./models/bow_net"
    module = hub.Module(module_dir=module_path)
    main_program = fluid.Program()
    startup_program = fluid.Program()
    # use switch program to test fine-tuning: make the module's inference
    # program the default main program
    fluid.framework.switch_main_program(module.get_inference_program())
    # remove feed fetch operator and variable left over from inference
    hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    # Reuse the module's registered input variable and embedding output.
    data = module.get_feed_var("words")
    emb = module.get_fetch_var("emb")
    # bow layer on top of the pretrained embedding
    bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
    bow_tanh = fluid.layers.tanh(bow)
    # full connect layers
    fc_1 = fluid.layers.fc(
        input=bow_tanh, size=hid_dim, act="tanh", name="bow_fc1")
    fc_2 = fluid.layers.fc(
        input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2")
    # softmax layer
    pred = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
    cost = fluid.layers.mean(
        fluid.layers.cross_entropy(input=pred, label=label))
    acc = fluid.layers.accuracy(input=pred, label=label)
    # set optimizer
    sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
    sgd_optimizer.minimize(cost)
    # set place, executor, datafeeder
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    feeder = fluid.DataFeeder(feed_list=["words", "label"], place=place)
    exe.run(fluid.default_startup_program())
    # start training...
    for pass_id in range(pass_num):
        data_size, data_count, total_acc, total_cost = 0, 0, 0.0, 0.0
        for batch in train_reader():
            avg_cost_np, avg_acc_np = exe.run(
                fluid.default_main_program(),
                feed=feeder.feed(batch),
                fetch_list=[cost, acc],
                return_numpy=True)
            data_size = len(batch)
            total_acc += data_size * avg_acc_np
            total_cost += data_size * avg_cost_np
            data_count += data_size
        avg_cost = total_cost / data_count
        avg_acc = total_acc / data_count
        print("[train info]: pass_id: %d, avg_acc: %f, avg_cost: %f" %
              (pass_id, avg_acc, avg_cost))
    # save the retrained model together with its module config
    module_dir = os.path.join(save_dirname, network_name + "_retrain")
    fluid.io.save_inference_model(module_dir, ["words"], emb, exe)
    config = hub.ModuleConfig(module_dir)
    config.save_dict(word_dict=word_dict)
    # BUG FIX: the original built input_desc/output_desc dicts but never
    # registered them, so the dumped config had no feed/fetch signatures
    # and get_feed_var/get_fetch_var could not resolve "words"/"emb" on
    # reload.  Register them the same way train_net does.
    config.register_feed_signature({"words": data.name})
    config.register_fetch_signature({"emb": emb.name})
    config.dump()
def main(args):
    """Entry point: prepare the reader/vocabulary, then fine-tune."""
    vocab, reader = utils.prepare_data(args.train_data_path,
                                       args.word_dict_path, args.batch_size,
                                       args.mode)
    retrain_net(reader, vocab, args.model_type, args.use_gpu,
                args.is_parallel, args.model_path, args.lr, args.batch_size,
                args.num_passes)
if __name__ == "__main__":
    # Script entry point.
    main(parse_args())
......@@ -20,10 +20,10 @@ import paddle.fluid as fluid
import numpy as np
import tempfile
import os
import paddle_hub.module_desc_pb2
from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2
__all__ = ["Module", "ModuleConfig", "ModuleUtils"]
DICT_NAME = "dict.txt"
......@@ -80,7 +80,6 @@ class Module(object):
#TODO(ZeyuChen): Need add register more signature to execute different
# implmentation
def __call__(self, inputs=None, signature=None):
""" Call default signature and return results
"""
......@@ -105,26 +104,23 @@ class Module(object):
return np_result
def add_input_desc(var_name):
pass
def get_vars(self):
return self.inference_program.list_vars()
def get_input_vars(self):
def get_feed_var(self, key, signature="default"):
for var in self.inference_program.list_vars():
print(var)
if var.name == "words":
if var.name == self.config.feed_var_name(key, signature):
return var
# return self.fetch_targets
def get_module_output(self):
raise Exception("Can't find input var {}".format(key))
def get_fetch_var(self, key, signature="default"):
for var in self.inference_program.list_vars():
print(var)
# NOTE: just hack for load Senta's
if var.name == "embedding_0.tmp_0":
if var.name == self.config.fetch_var_name(key, signature):
return var
raise Exception("Can't find output var {}".format(key))
def get_inference_program(self):
return self.inference_program
......@@ -159,12 +155,6 @@ class Module(object):
word_dict = self.config.get_dict()
return list(map(lambda x: word_dict[x], inputs))
def add_module_feed_list(self, feed_list):
self.feed_list = feed_list
def add_module_output_list(self, output_list):
self.output_list = output_list
class ModuleConfig(object):
def __init__(self, module_dir, module_name=None):
......@@ -175,8 +165,6 @@ class ModuleConfig(object):
module_name = module_dir.split("/")[-1]
self.desc.name = module_name
print("desc.name=", self.desc.name)
self.desc.signature = "default"
print("desc.signature=", self.desc.signature)
self.desc.contain_assets = True
print("desc.signature=", self.desc.contain_assets)
......@@ -184,12 +172,6 @@ class ModuleConfig(object):
self.dict = defaultdict(int)
self.dict.setdefault(0)
# feed_list
self.feed_list = []
# fetch_list
self.fetch_list = []
def load(self):
"""load module config from module dir
"""
......@@ -227,25 +209,45 @@ class ModuleConfig(object):
w_id = self.dict[w]
fo.write("{}\t{}\n".format(w, w_id))
def register_input_var(self, var, signature="default"):
var_name = var.name()
self.desc.sign2input[signature].append(var_name)
def register_output_var(self, var, signature="default"):
var_name = var.name()
self.desc.sign2output[signature].append(var_name)
def save_dict(self, word_dict, dict_name=DICT_NAME):
""" Save dictionary for NLP module
"""
mkdir(self.module_dir)
with open(os.path.join(self.module_dir, DICT_NAME), "w") as fo:
for w in word_dict:
self.dict[w] = word_dict[w]
for w in word_dict:
self.dict[w] = word_dict[w]
# mkdir(self.module_dir)
# with open(os.path.join(self.module_dir, DICT_NAME), "w") as fo:
# for w in word_dict:
# self.dict[w] = word_dict[w]
def get_dict(self):
    """Return the module's vocabulary dict (word -> id)."""
    return self.dict
def register_feed_signature(self, feed_desc, sign_name="default"):
    """Record the feed (input) variables for signature *sign_name*.

    Args:
        feed_desc: dict mapping a public input key (e.g. "words") to the
            name of the corresponding program variable.
        sign_name: signature the inputs belong to.
    """
    #TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
    for k in feed_desc:
        feed = self.desc.sign2var[sign_name].feed_desc.add()
        feed.key = k
        feed.var_name = feed_desc[k]
def register_fetch_signature(self, fetch_desc, sign_name="default"):
    """Record the fetch (output) variables for signature *sign_name*.

    Args:
        fetch_desc: dict mapping a public output key (e.g. "emb") to the
            name of the corresponding program variable.
        sign_name: signature the outputs belong to.
    """
    #TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
    for k in fetch_desc:
        fetch = self.desc.sign2var[sign_name].fetch_desc.add()
        fetch.key = k
        fetch.var_name = fetch_desc[k]
def feed_var_name(self, key, sign_name="default"):
    """Return the program variable name registered for feed key *key*.

    Raises:
        Exception: if *key* was not registered under *sign_name*.
    """
    for desc in self.desc.sign2var[sign_name].feed_desc:
        if desc.key == key:
            return desc.var_name
    raise Exception("feed variable {} not found".format(key))
def fetch_var_name(self, key, sign_name="default"):
    """Return the program variable name registered for fetch key *key*.

    Raises:
        Exception: if *key* was not registered under *sign_name*.
    """
    for desc in self.desc.sign2var[sign_name].fetch_desc:
        if desc.key == key:
            return desc.var_name
    raise Exception("fetch variable {} not found".format(key))
class ModuleUtils(object):
def __init__(self):
......@@ -269,9 +271,9 @@ class ModuleUtils(object):
block._remove_var("fetch")
program.desc.flush()
print("********************************")
print(program)
print("********************************")
# print("********************************")
# print(program)
# print("********************************")
if __name__ == "__main__":
......
......@@ -18,14 +18,21 @@ option optimize_for = LITE_RUNTIME;
package paddle_hub;
message InputDesc {
repeated string name = 1;
message FeedDesc {
string key = 1;
string var_name = 2;
};
message OutputDesc {
repeated string name = 1;
message FetchDesc {
string key = 1;
string var_name = 2;
};
message ModuleVar {
repeated FetchDesc fetch_desc = 1;
repeated FeedDesc feed_desc = 2;
}
message Parameter {
string name = 1;
double learning_rate = 2;
......@@ -39,10 +46,7 @@ message ModuleDesc {
string name = 1; // PaddleHub module name
// signature to input description
map<string, InputDesc> sign2input = 2;
// signature to output description
map<string, OutputDesc> sign2output = 3;
map<string, ModuleVar> sign2var = 2;
bool return_numpy = 4;
......
......@@ -17,27 +17,43 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle_hub',
syntax='proto3',
serialized_pb=_b(
'\n\x11module_desc.proto\x12\npaddle_hub\"\x19\n\tInputDesc\x12\x0c\n\x04name\x18\x01 \x03(\t\"\x1a\n\nOutputDesc\x12\x0c\n\x04name\x18\x01 \x03(\t\"C\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\rlearning_rate\x18\x02 \x01(\x01\x12\x11\n\ttrainable\x18\x03 \x01(\x08\"\xd8\x02\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12:\n\nsign2input\x18\x02 \x03(\x0b\x32&.paddle_hub.ModuleDesc.Sign2inputEntry\x12<\n\x0bsign2output\x18\x03 \x03(\x0b\x32\'.paddle_hub.ModuleDesc.Sign2outputEntry\x12\x14\n\x0creturn_numpy\x18\x04 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x05 \x01(\x08\x1aH\n\x0fSign2inputEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.InputDesc:\x02\x38\x01\x1aJ\n\x10Sign2outputEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.paddle_hub.OutputDesc:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
'\n\x11module_desc.proto\x12\npaddle_hub\")\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x10\n\x08var_name\x18\x02 \x01(\t\"*\n\tFetchDesc\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x10\n\x08var_name\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"C\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\rlearning_rate\x18\x02 \x01(\x01\x12\x11\n\ttrainable\x18\x03 \x01(\x08\"\xc8\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x04 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x05 \x01(\x08\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_INPUTDESC = _descriptor.Descriptor(
name='InputDesc',
full_name='paddle_hub.InputDesc',
_FEEDDESC = _descriptor.Descriptor(
name='FeedDesc',
full_name='paddle_hub.FeedDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle_hub.InputDesc.name',
name='key',
full_name='paddle_hub.FeedDesc.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=3,
label=1,
has_default_value=False,
default_value=[],
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='var_name',
full_name='paddle_hub.FeedDesc.var_name',
index=1,
number=2,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
......@@ -54,26 +70,42 @@ _INPUTDESC = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[],
serialized_start=33,
serialized_end=58,
serialized_end=74,
)
_OUTPUTDESC = _descriptor.Descriptor(
name='OutputDesc',
full_name='paddle_hub.OutputDesc',
_FETCHDESC = _descriptor.Descriptor(
name='FetchDesc',
full_name='paddle_hub.FetchDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle_hub.OutputDesc.name',
name='key',
full_name='paddle_hub.FetchDesc.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=3,
label=1,
has_default_value=False,
default_value=[],
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='var_name',
full_name='paddle_hub.FetchDesc.var_name',
index=1,
number=2,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
......@@ -89,27 +121,27 @@ _OUTPUTDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=60,
serialized_end=86,
serialized_start=76,
serialized_end=118,
)
_PARAMETER = _descriptor.Descriptor(
name='Parameter',
full_name='paddle_hub.Parameter',
_MODULEVAR = _descriptor.Descriptor(
name='ModuleVar',
full_name='paddle_hub.ModuleVar',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle_hub.Parameter.name',
name='fetch_desc',
full_name='paddle_hub.ModuleVar.fetch_desc',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=_b("").decode('utf-8'),
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
......@@ -117,31 +149,15 @@ _PARAMETER = _descriptor.Descriptor(
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='learning_rate',
full_name='paddle_hub.Parameter.learning_rate',
name='feed_desc',
full_name='paddle_hub.ModuleVar.feed_desc',
index=1,
number=2,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='trainable',
full_name='paddle_hub.Parameter.trainable',
index=2,
number=3,
type=8,
cpp_type=7,
label=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
......@@ -157,20 +173,20 @@ _PARAMETER = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=88,
serialized_end=155,
serialized_start=120,
serialized_end=215,
)
_MODULEDESC_SIGN2INPUTENTRY = _descriptor.Descriptor(
name='Sign2inputEntry',
full_name='paddle_hub.ModuleDesc.Sign2inputEntry',
_PARAMETER = _descriptor.Descriptor(
name='Parameter',
full_name='paddle_hub.Parameter',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.ModuleDesc.Sign2inputEntry.key',
name='name',
full_name='paddle_hub.Parameter.name',
index=0,
number=1,
type=9,
......@@ -185,15 +201,31 @@ _MODULEDESC_SIGN2INPUTENTRY = _descriptor.Descriptor(
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='value',
full_name='paddle_hub.ModuleDesc.Sign2inputEntry.value',
name='learning_rate',
full_name='paddle_hub.Parameter.learning_rate',
index=1,
number=2,
type=11,
cpp_type=10,
type=1,
cpp_type=5,
label=1,
has_default_value=False,
default_value=None,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='trainable',
full_name='paddle_hub.Parameter.trainable',
index=2,
number=3,
type=8,
cpp_type=7,
label=1,
has_default_value=False,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
......@@ -204,26 +236,25 @@ _MODULEDESC_SIGN2INPUTENTRY = _descriptor.Descriptor(
extensions=[],
nested_types=[],
enum_types=[],
options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(),
_b('8\001')),
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=354,
serialized_end=426,
serialized_start=217,
serialized_end=284,
)
_MODULEDESC_SIGN2OUTPUTENTRY = _descriptor.Descriptor(
name='Sign2outputEntry',
full_name='paddle_hub.ModuleDesc.Sign2outputEntry',
_MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
name='Sign2varEntry',
full_name='paddle_hub.ModuleDesc.Sign2varEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.ModuleDesc.Sign2outputEntry.key',
full_name='paddle_hub.ModuleDesc.Sign2varEntry.key',
index=0,
number=1,
type=9,
......@@ -239,7 +270,7 @@ _MODULEDESC_SIGN2OUTPUTENTRY = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='value',
full_name='paddle_hub.ModuleDesc.Sign2outputEntry.value',
full_name='paddle_hub.ModuleDesc.Sign2varEntry.value',
index=1,
number=2,
type=11,
......@@ -263,8 +294,8 @@ _MODULEDESC_SIGN2OUTPUTENTRY = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=428,
serialized_end=502,
serialized_start=417,
serialized_end=487,
)
_MODULEDESC = _descriptor.Descriptor(
......@@ -291,8 +322,8 @@ _MODULEDESC = _descriptor.Descriptor(
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='sign2input',
full_name='paddle_hub.ModuleDesc.sign2input',
name='sign2var',
full_name='paddle_hub.ModuleDesc.sign2var',
index=1,
number=2,
type=11,
......@@ -306,26 +337,10 @@ _MODULEDESC = _descriptor.Descriptor(
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='sign2output',
full_name='paddle_hub.ModuleDesc.sign2output',
index=2,
number=3,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='return_numpy',
full_name='paddle_hub.ModuleDesc.return_numpy',
index=3,
index=2,
number=4,
type=8,
cpp_type=7,
......@@ -341,7 +356,7 @@ _MODULEDESC = _descriptor.Descriptor(
_descriptor.FieldDescriptor(
name='contain_assets',
full_name='paddle_hub.ModuleDesc.contain_assets',
index=4,
index=3,
number=5,
type=8,
cpp_type=7,
......@@ -357,8 +372,7 @@ _MODULEDESC = _descriptor.Descriptor(
],
extensions=[],
nested_types=[
_MODULEDESC_SIGN2INPUTENTRY,
_MODULEDESC_SIGN2OUTPUTENTRY,
_MODULEDESC_SIGN2VARENTRY,
],
enum_types=[],
options=None,
......@@ -366,42 +380,50 @@ _MODULEDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=158,
serialized_end=502,
serialized_start=287,
serialized_end=487,
)
_MODULEDESC_SIGN2INPUTENTRY.fields_by_name['value'].message_type = _INPUTDESC
_MODULEDESC_SIGN2INPUTENTRY.containing_type = _MODULEDESC
_MODULEDESC_SIGN2OUTPUTENTRY.fields_by_name['value'].message_type = _OUTPUTDESC
_MODULEDESC_SIGN2OUTPUTENTRY.containing_type = _MODULEDESC
_MODULEDESC.fields_by_name[
'sign2input'].message_type = _MODULEDESC_SIGN2INPUTENTRY
_MODULEDESC.fields_by_name[
'sign2output'].message_type = _MODULEDESC_SIGN2OUTPUTENTRY
DESCRIPTOR.message_types_by_name['InputDesc'] = _INPUTDESC
DESCRIPTOR.message_types_by_name['OutputDesc'] = _OUTPUTDESC
_MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC
_MODULEVAR.fields_by_name['feed_desc'].message_type = _FEEDDESC
_MODULEDESC_SIGN2VARENTRY.fields_by_name['value'].message_type = _MODULEVAR
_MODULEDESC_SIGN2VARENTRY.containing_type = _MODULEDESC
_MODULEDESC.fields_by_name['sign2var'].message_type = _MODULEDESC_SIGN2VARENTRY
DESCRIPTOR.message_types_by_name['FeedDesc'] = _FEEDDESC
DESCRIPTOR.message_types_by_name['FetchDesc'] = _FETCHDESC
DESCRIPTOR.message_types_by_name['ModuleVar'] = _MODULEVAR
DESCRIPTOR.message_types_by_name['Parameter'] = _PARAMETER
DESCRIPTOR.message_types_by_name['ModuleDesc'] = _MODULEDESC
InputDesc = _reflection.GeneratedProtocolMessageType(
'InputDesc',
FeedDesc = _reflection.GeneratedProtocolMessageType(
'FeedDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_INPUTDESC,
DESCRIPTOR=_FEEDDESC,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.InputDesc)
# @@protoc_insertion_point(class_scope:paddle_hub.FeedDesc)
))
_sym_db.RegisterMessage(InputDesc)
_sym_db.RegisterMessage(FeedDesc)
OutputDesc = _reflection.GeneratedProtocolMessageType(
'OutputDesc',
FetchDesc = _reflection.GeneratedProtocolMessageType(
'FetchDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_OUTPUTDESC,
DESCRIPTOR=_FETCHDESC,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.OutputDesc)
# @@protoc_insertion_point(class_scope:paddle_hub.FetchDesc)
))
_sym_db.RegisterMessage(OutputDesc)
_sym_db.RegisterMessage(FetchDesc)
ModuleVar = _reflection.GeneratedProtocolMessageType(
'ModuleVar',
(_message.Message, ),
dict(
DESCRIPTOR=_MODULEVAR,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleVar)
))
_sym_db.RegisterMessage(ModuleVar)
Parameter = _reflection.GeneratedProtocolMessageType(
'Parameter',
......@@ -417,37 +439,25 @@ ModuleDesc = _reflection.GeneratedProtocolMessageType(
'ModuleDesc',
(_message.Message, ),
dict(
Sign2inputEntry=_reflection.GeneratedProtocolMessageType(
'Sign2inputEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_MODULEDESC_SIGN2INPUTENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.Sign2inputEntry)
)),
Sign2outputEntry=_reflection.GeneratedProtocolMessageType(
'Sign2outputEntry',
Sign2varEntry=_reflection.GeneratedProtocolMessageType(
'Sign2varEntry',
(_message.Message, ),
dict(
DESCRIPTOR=_MODULEDESC_SIGN2OUTPUTENTRY,
DESCRIPTOR=_MODULEDESC_SIGN2VARENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.Sign2outputEntry)
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.Sign2varEntry)
)),
DESCRIPTOR=_MODULEDESC,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc)
))
_sym_db.RegisterMessage(ModuleDesc)
_sym_db.RegisterMessage(ModuleDesc.Sign2inputEntry)
_sym_db.RegisterMessage(ModuleDesc.Sign2outputEntry)
_sym_db.RegisterMessage(ModuleDesc.Sign2varEntry)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(),
_b('H\003'))
_MODULEDESC_SIGN2INPUTENTRY.has_options = True
_MODULEDESC_SIGN2INPUTENTRY._options = _descriptor._ParseOptions(
descriptor_pb2.MessageOptions(), _b('8\001'))
_MODULEDESC_SIGN2OUTPUTENTRY.has_options = True
_MODULEDESC_SIGN2OUTPUTENTRY._options = _descriptor._ParseOptions(
_MODULEDESC_SIGN2VARENTRY.has_options = True
_MODULEDESC_SIGN2VARENTRY._options = _descriptor._ParseOptions(
descriptor_pb2.MessageOptions(), _b('8\001'))
# @@protoc_insertion_point(module_scope)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle_hub as hub
class TestModule(unittest.TestCase):
    """Smoke test for downloading and running a remote PaddleHub module."""

    def test_word2vec_module_usage(self):
        """Download the word2vec example module and run it on a toy batch."""
        url = "http://paddlehub.cdn.bcebos.com/word2vec/word2vec-dim16-simple-example-2.tar.gz"
        # BUG FIX: the file only imports `paddle_hub as hub`, so the bare
        # name `Module` raised NameError -- qualify it with the namespace.
        module = hub.Module(module_url=url)
        inputs = [["it", "is", "new"], ["hello", "world"]]
        tensor = module._process_input(inputs)
        print(tensor)
        result = module(inputs)
        print(result)
if __name__ == "__main__":
    # Run the unittest suite when executed as a script.
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.