提交 c6089088 编写于 作者: Z Zeyu Chen

update module_desc.proto

上级 96102e2f
...@@ -35,7 +35,7 @@ def bow_net(data, ...@@ -35,7 +35,7 @@ def bow_net(data,
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=prediction, label=label) acc = fluid.layers.accuracy(input=prediction, label=label)
return avg_cost, acc, prediction, bow_tanh return avg_cost, acc, prediction, fc_1
def cnn_net(data, def cnn_net(data,
......
...@@ -204,16 +204,16 @@ def retrain_net(train_reader, ...@@ -204,16 +204,16 @@ def retrain_net(train_reader,
# use switch program to test fine-tuning # use switch program to test fine-tuning
fluid.framework.switch_main_program(module.get_inference_program()) fluid.framework.switch_main_program(module.get_inference_program())
# remove feed fetch operator and variable
# hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.layers.data(name="label", shape=[1], dtype="int64")
data = module.get_feed_var_by_index(0) data = module.get_feed_var_by_index(0)
#TODO(ZeyuChen): how to get output paramter according to proto config #TODO(ZeyuChen): how to get output paramter according to proto config
sent_emb = module.get_fetch_var_by_index(0) sent_emb = module.get_fetch_var_by_index(0)
fc_1 = fluid.layers.fc(
input=sent_emb, size=hid_dim, act="tanh", name="bow_fc1")
fc_2 = fluid.layers.fc( fc_2 = fluid.layers.fc(
input=sent_emb, size=hid_dim2, act="tanh", name="bow_fc2") input=fc_1, size=hid_dim2, act="tanh", name="bow_fc2")
# softmax layer # softmax layer
pred = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") pred = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
# print(fluid.default_main_program()) # print(fluid.default_main_program())
...@@ -221,6 +221,9 @@ def retrain_net(train_reader, ...@@ -221,6 +221,9 @@ def retrain_net(train_reader,
fluid.layers.cross_entropy(input=pred, label=label)) fluid.layers.cross_entropy(input=pred, label=label))
acc = fluid.layers.accuracy(input=pred, label=label) acc = fluid.layers.accuracy(input=pred, label=label)
with open("./prototxt/bow_net.forward.program_desc.prototxt", "w") as fo:
program_desc = str(fluid.default_main_program())
fo.write(program_desc)
# set optimizer # set optimizer
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
sgd_optimizer.minimize(cost) sgd_optimizer.minimize(cost)
...@@ -228,6 +231,7 @@ def retrain_net(train_reader, ...@@ -228,6 +231,7 @@ def retrain_net(train_reader,
with open("./prototxt/bow_net.finetune.program_desc.prototxt", "w") as fo: with open("./prototxt/bow_net.finetune.program_desc.prototxt", "w") as fo:
program_desc = str(fluid.default_main_program()) program_desc = str(fluid.default_main_program())
fo.write(program_desc) fo.write(program_desc)
# set place, executor, datafeeder # set place, executor, datafeeder
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
......
...@@ -141,21 +141,13 @@ def retrain_net(train_reader, ...@@ -141,21 +141,13 @@ def retrain_net(train_reader,
main_program = fluid.Program() main_program = fluid.Program()
startup_program = fluid.Program() startup_program = fluid.Program()
# use switch program to test fine-tuning
fluid.framework.switch_main_program(module.get_inference_program())
# remove feed fetch operator and variable
# hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.layers.data(name="label", shape=[1], dtype="int64")
#data = fluid.default_main_program().global_block().var("words") data = module.get_feed_var_by_index(0)
data = module.get_feed_var("words") emb = module.get_fetch_var_by_index(0)
#TODO(ZeyuChen): how to get output paramter according to proto config
emb = module.get_fetch_var("emb")
emb2 = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
# # # embedding layer # # # embedding layer
# emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) # emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
# #input=data, size=[dict_dim, emb_dim], param_attr="bow_embedding")
# # bow layer # # bow layer
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow) bow_tanh = fluid.layers.tanh(bow)
......
...@@ -47,7 +47,7 @@ def data2tensor(data, place): ...@@ -47,7 +47,7 @@ def data2tensor(data, place):
""" """
data2tensor data2tensor
""" """
input_seq = to_lodtensor(map(lambda x: x[0], data), place) input_seq = to_lodtensor(list(map(lambda x: x[0], data)), place)
return {"words": input_seq} return {"words": input_seq}
......
...@@ -20,14 +20,12 @@ package paddle_hub; ...@@ -20,14 +20,12 @@ package paddle_hub;
// Feed Variable Description // Feed Variable Description
message FeedDesc { message FeedDesc {
string key = 1; string var_name = 1;
string var_name = 2;
}; };
// Fetch Variable Description // Fetch Variable Description
message FetchDesc { message FetchDesc {
string key = 1; string var_name = 1;
string var_name = 2;
}; };
// Module Variable // Module Variable
......
...@@ -17,7 +17,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -17,7 +17,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddle_hub', package='paddle_hub',
syntax='proto3', syntax='proto3',
serialized_pb=_b( serialized_pb=_b(
'\n\x11module_desc.proto\x12\npaddle_hub\")\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x10\n\x08var_name\x18\x02 \x01(\t\"*\n\tFetchDesc\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x10\n\x08var_name\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xc8\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3' '\n\x11module_desc.proto\x12\npaddle_hub\"\x1c\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"\x1d\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"\xc8\x01\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01\x42\x02H\x03\x62\x06proto3'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -28,27 +28,11 @@ _FEEDDESC = _descriptor.Descriptor( ...@@ -28,27 +28,11 @@ _FEEDDESC = _descriptor.Descriptor(
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.FeedDesc.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='var_name', name='var_name',
full_name='paddle_hub.FeedDesc.var_name', full_name='paddle_hub.FeedDesc.var_name',
index=1, index=0,
number=2, number=1,
type=9, type=9,
cpp_type=9, cpp_type=9,
label=1, label=1,
...@@ -70,7 +54,7 @@ _FEEDDESC = _descriptor.Descriptor( ...@@ -70,7 +54,7 @@ _FEEDDESC = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=33, serialized_start=33,
serialized_end=74, serialized_end=61,
) )
_FETCHDESC = _descriptor.Descriptor( _FETCHDESC = _descriptor.Descriptor(
...@@ -80,27 +64,11 @@ _FETCHDESC = _descriptor.Descriptor( ...@@ -80,27 +64,11 @@ _FETCHDESC = _descriptor.Descriptor(
file=DESCRIPTOR, file=DESCRIPTOR,
containing_type=None, containing_type=None,
fields=[ fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.FetchDesc.key',
index=0,
number=1,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='var_name', name='var_name',
full_name='paddle_hub.FetchDesc.var_name', full_name='paddle_hub.FetchDesc.var_name',
index=1, index=0,
number=2, number=1,
type=9, type=9,
cpp_type=9, cpp_type=9,
label=1, label=1,
...@@ -121,8 +89,8 @@ _FETCHDESC = _descriptor.Descriptor( ...@@ -121,8 +89,8 @@ _FETCHDESC = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=76, serialized_start=63,
serialized_end=118, serialized_end=92,
) )
_MODULEVAR = _descriptor.Descriptor( _MODULEVAR = _descriptor.Descriptor(
...@@ -173,8 +141,8 @@ _MODULEVAR = _descriptor.Descriptor( ...@@ -173,8 +141,8 @@ _MODULEVAR = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=120, serialized_start=94,
serialized_end=215, serialized_end=189,
) )
_MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
...@@ -226,8 +194,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor( ...@@ -226,8 +194,8 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=348, serialized_start=322,
serialized_end=418, serialized_end=392,
) )
_MODULEDESC = _descriptor.Descriptor( _MODULEDESC = _descriptor.Descriptor(
...@@ -312,8 +280,8 @@ _MODULEDESC = _descriptor.Descriptor( ...@@ -312,8 +280,8 @@ _MODULEDESC = _descriptor.Descriptor(
syntax='proto3', syntax='proto3',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=218, serialized_start=192,
serialized_end=418, serialized_end=392,
) )
_MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC _MODULEVAR.fields_by_name['fetch_desc'].message_type = _FETCHDESC
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册