提交 c6089088 编写于 作者: Z Zeyu Chen

update module_desc.proto

上级 96102e2f
......@@ -35,7 +35,7 @@ def bow_net(data,
avg_cost = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=prediction, label=label)
return avg_cost, acc, prediction, bow_tanh
return avg_cost, acc, prediction, fc_1
def cnn_net(data,
......
......@@ -204,16 +204,16 @@ def retrain_net(train_reader,
# use switch program to test fine-tuning
fluid.framework.switch_main_program(module.get_inference_program())
# remove feed fetch operator and variable
# hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
data = module.get_feed_var_by_index(0)
#TODO(ZeyuChen): how to get output paramter according to proto config
sent_emb = module.get_fetch_var_by_index(0)
fc_1 = fluid.layers.fc(
input=sent_emb, size=hid_dim, act="tanh", name="bow_fc1")
fc_2 = fluid.layers.fc(
input=sent_emb, size=hid_dim2, act="tanh", name="bow_fc2")
input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2")
# softmax layer
pred = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
# print(fluid.default_main_program())
......@@ -221,6 +221,9 @@ def retrain_net(train_reader,
fluid.layers.cross_entropy(input=pred, label=label))
acc = fluid.layers.accuracy(input=pred, label=label)
with open("./prototxt/bow_net.forward.program_desc.prototxt", "w") as fo:
program_desc = str(fluid.default_main_program())
fo.write(program_desc)
# set optimizer
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
sgd_optimizer.minimize(cost)
......@@ -228,6 +231,7 @@ def retrain_net(train_reader,
with open("./prototxt/bow_net.finetune.program_desc.prototxt", "w") as fo:
program_desc = str(fluid.default_main_program())
fo.write(program_desc)
# set place, executor, datafeeder
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
......
......@@ -141,21 +141,13 @@ def retrain_net(train_reader,
main_program = fluid.Program()
startup_program = fluid.Program()
# use switch program to test fine-tuning
fluid.framework.switch_main_program(module.get_inference_program())
# remove feed fetch operator and variable
# hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
#data = fluid.default_main_program().global_block().var("words")
data = module.get_feed_var("words")
#TODO(ZeyuChen): how to get output paramter according to proto config
emb = module.get_fetch_var("emb")
data = module.get_feed_var_by_index(0)
emb = module.get_fetch_var_by_index(0)
emb2 = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
# # # embedding layer
# emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
# #input=data, size=[dict_dim, emb_dim], param_attr="bow_embedding")
# # bow layer
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow)
......
......@@ -47,7 +47,7 @@ def data2tensor(data, place):
"""
data2tensor
"""
input_seq = to_lodtensor(map(lambda x: x[0], data), place)
input_seq = to_lodtensor(list(map(lambda x: x[0], data)), place)
return {"words": input_seq}
......
......@@ -20,14 +20,12 @@ package paddle_hub;
// Feed Variable Description
message FeedDesc {
string key = 1;
string var_name = 2;
string var_name = 1;
};
// Fetch Variable Description
message FetchDesc {
string key = 1;
string var_name = 2;
string var_name = 1;
};
// Module Variable
......
......@@ -17,7 +17,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle_hub',
syntax='proto3',
serialized_pb=_b(
'\n\x11module_desc.proto\x12\npaddle_hub\")\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x10\n\x08var_name\x18\x02 \x01(\t\"*\n\tFetchDesc\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x10\n\x08var_name\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xc8\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
'\n\x11module_desc.proto\x12\npaddle_hub\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xc8\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -28,27 +28,11 @@ _FEEDDESC = _descriptor.Descriptor(
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.FeedDesc.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='var_name',
full_name='paddle_hub.FeedDesc.var_name',
index=1,
number=2,
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
......@@ -70,7 +54,7 @@ _FEEDDESC = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[],
serialized_start=33,
serialized_end=74,
serialized_end=61,
)
_FETCHDESC = _descriptor.Descriptor(
......@@ -80,27 +64,11 @@ _FETCHDESC = _descriptor.Descriptor(
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.FetchDesc.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='var_name',
full_name='paddle_hub.FetchDesc.var_name',
index=1,
number=2,
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
......@@ -121,8 +89,8 @@ _FETCHDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=76,
serialized_end=118,
serialized_start=63,
serialized_end=92,
)
_MODULEVAR = _descriptor.Descriptor(
......@@ -173,8 +141,8 @@ _MODULEVAR = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=120,
serialized_end=215,
serialized_start=94,
serialized_end=189,
)
_MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
......@@ -226,8 +194,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=348,
serialized_end=418,
serialized_start=322,
serialized_end=392,
)
_MODULEDESC = _descriptor.Descriptor(
......@@ -312,8 +280,8 @@ _MODULEDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=218,
serialized_end=418,
serialized_start=192,
serialized_end=392,
)
_MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册