Commit ee1262d5
Authored on Aug 24, 2017 by caoying03
Parent: d30a28c7

    update model config.

Showing 5 changed files with 114 additions and 85 deletions.
globally_normalized_reader/basic_modules.py   +9   -6
globally_normalized_reader/config.py          +11  -3
globally_normalized_reader/model.py           +30  -29
globally_normalized_reader/reader.py          +17  -16
globally_normalized_reader/train.py           +47  -31
globally_normalized_reader/basic_modules.py

@@ -38,9 +38,11 @@ def stacked_bidirectional_lstm(inputs, size, depth, drop_rate=0., prefix=""):
         paddle.layer.last_seq(input=lstm_last[0]),
         paddle.layer.first_seq(input=lstm_last[1]),
     ])
-    return final_states, paddle.layer.concat(
-        input=lstm_last,
-        layer_attr=paddle.attr.ExtraLayerAttribute(
-            drop_rate=drop_rate), )
+    lstm_outs = paddle.layer.concat(
+        input=lstm_last,
+        layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=drop_rate))
+    return final_states, lstm_outs

 def lstm_by_nested_sequence(input_layer, hidden_dim, name="", reverse=False):

@@ -70,8 +72,9 @@ def lstm_by_nested_sequence(input_layer, hidden_dim, name="", reverse=False):
             name="__inner_state_%s__" % name,
             size=hidden_dim,
             boot_layer=outer_memory)
         input_proj = paddle.layer.fc(
             size=hidden_dim * 4,
             bias_attr=False,
             input=input_layer)
         return paddle.networks.lstmemory_unit(
             input=input_proj,
             name="__inner_state_%s__" % name,

@@ -91,12 +94,12 @@ def lstm_by_nested_sequence(input_layer, hidden_dim, name="", reverse=False):
         inner_last_output = paddle.layer.first_seq(
             input=inner_out,
             name="__inner_%s_last__" % name,
-            agg_level=paddle.layer.AggregateLevel.TO_SEQUENCE)
+            agg_level=paddle.layer.AggregateLevel.TO_NO_SEQUENCE)
     else:
         inner_last_output = paddle.layer.last_seq(
             input=inner_out,
             name="__inner_%s_last__" % name,
-            agg_level=paddle.layer.AggregateLevel.TO_SEQUENCE)
+            agg_level=paddle.layer.AggregateLevel.TO_NO_SEQUENCE)
         return inner_out

     return paddle.layer.recurrent_group(
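The aggregation-level switch above (TO_SEQUENCE to TO_NO_SEQUENCE) changes what first_seq/last_seq return when applied inside the nested-sequence recurrent group: TO_SEQUENCE collapses each inner sequence to one timestep while keeping an outer sequence, whereas TO_NO_SEQUENCE collapses one sequence all the way down to a single step. The following pure-Python analogue is only an illustration of that distinction; the nested lists stand in for PaddlePaddle nested sequences and none of these helper names come from the Paddle API.

# Pure-Python analogue of the two aggregation levels used with last_seq.
def last_seq_to_sequence(nested_seq):
    # TO_SEQUENCE: reduce each inner sequence, keeping an outer sequence.
    return [inner[-1] for inner in nested_seq]

def last_seq_to_no_sequence(seq):
    # TO_NO_SEQUENCE: reduce one sequence to its single last timestep.
    return seq[-1]

doc = [[1, 2, 3], [4, 5]]               # two "sentences" of scalar steps
print(last_seq_to_sequence(doc))         # [3, 5]  -- still a sequence
print(last_seq_to_no_sequence(doc[0]))   # 3       -- a single value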
globally_normalized_reader/config.py

@@ -6,8 +6,8 @@ __all__ = ["ModelConfig"]
 class ModelConfig(object):
     beam_size = 3

-    vocab_size = 102400
-    embedding_dim = 256
+    vocab_size = 104808
+    embedding_dim = 300
     embedding_droprate = 0.3
     lstm_depth = 3

@@ -17,9 +17,17 @@ class ModelConfig(object):
     passage_indep_embedding_dim = 300
     passage_aligned_embedding_dim = 128

-    beam_size = 5
+    beam_size = 32
+
+    dict_path = "data/featurized/vocab.txt"
+    pretrained_emb_path = "data/featurized/embeddings.npy"

 class TrainerConfig(object):
     learning_rate = 1e-3

     data_dir = "data/featurized"
+    save_dir = "models"
+
+    batch_size = 12 * 4
+    epochs = 100
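The new vocab_size (104808) and embedding_dim (300) match a 300-dimensional pretrained embedding matrix stored at pretrained_emb_path. A minimal sketch of how such an embeddings.npy could be produced from vocab.txt is shown below; the GloVe file path and the out-of-vocabulary handling are assumptions for illustration, not part of this commit.

import numpy as np

def build_embedding_matrix(vocab_path, glove_path, emb_dim=300):
    # One word per line in the vocabulary file.
    with open(vocab_path) as f:
        vocab = [line.strip() for line in f]
    # GloVe text format: word followed by emb_dim floats.
    glove = {}
    with open(glove_path) as f:
        for line in f:
            parts = line.rstrip().split(" ")
            glove[parts[0]] = np.asarray(parts[1:], dtype=np.float32)
    # Words missing from GloVe get small random vectors.
    mat = np.random.uniform(-0.05, 0.05,
                            (len(vocab), emb_dim)).astype(np.float32)
    for i, word in enumerate(vocab):
        if word in glove:
            mat[i] = glove[word]
    return mat

# np.save("data/featurized/embeddings.npy",
#         build_embedding_matrix("data/featurized/vocab.txt",
#                                "glove.840B.300d.txt"))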
globally_normalized_reader/model.py

@@ -10,15 +10,10 @@ from config import ModelConfig
 __all__ = ["GNR"]

-def build_pretrained_embedding(name,
-                               data_type,
-                               vocab_size,
-                               emb_dim,
-                               emb_drop=0.):
-    one_hot_input = paddle.layer.data(
-        name=name, type=paddle.data_type.integer_value_sequence(vocab_size))
-    return paddle.layer.embedding(
-        input=one_hot_input,
+def build_pretrained_embedding(name, data_type, emb_dim, emb_drop=0.):
+    return paddle.layer.embedding(
+        input=paddle.layer.data(name=name, type=data_type),
         size=emb_dim,
         param_attr=paddle.attr.Param(
             name="GloveVectors", is_static=True),

@@ -112,25 +107,24 @@ def encode_documents(input_embedding, same_as_question, question_vector,
     ])

-def search_answer(doc_lstm_outs, sentence_idx, start_idx, end_idx, config):
+def search_answer(doc_lstm_outs, sentence_idx, start_idx, end_idx, config,
+                  is_infer):
     last_state_of_sentence = paddle.layer.last_seq(
         input=doc_lstm_outs, agg_level=paddle.layer.AggregateLevel.TO_SEQUENCE)
+    # HERE do not use sequence softmax activition.
     sentence_scores = paddle.layer.fc(input=last_state_of_sentence,
                                       size=1,
-                                      act=paddle.activation.Exp())
+                                      act=paddle.activation.Linear())
     topk_sentence_ids = paddle.layer.kmax_sequence_score(
         input=sentence_scores, beam_size=config.beam_size)
     topk_sen = paddle.layer.sub_nested_seq(
-        input=last_state_of_sentence, selected_indices=topk_sentence_ids)
+        input=doc_lstm_outs, selected_indices=topk_sentence_ids)

     # expand beam to search start positions on selected sentences
     start_pos_scores = paddle.layer.fc(input=topk_sen,
                                        size=1,
-                                       act=paddle.activation.Exp())
+                                       act=paddle.activation.Linear())
     topk_start_pos_ids = paddle.layer.kmax_sequence_score(
-        input=sentence_scores, beam_size=config.beam_size)
+        input=start_pos_scores, beam_size=config.beam_size)
     topk_start_spans = paddle.layer.seq_slice(
         input=topk_sen, starts=topk_start_pos_ids, ends=None)

@@ -143,33 +137,40 @@ def search_answer(doc_lstm_outs, sentence_idx, start_idx, end_idx, config):
         prefix="__end_span_embeddings__")
     end_pos_scores = paddle.layer.fc(input=end_span_embedding,
                                      size=1,
-                                     act=paddle.activation.Exp())
+                                     act=paddle.activation.Linear())
     topk_end_pos_ids = paddle.layer.kmax_sequence_score(
         input=end_pos_scores, beam_size=config.beam_size)

-    cost = paddle.layer.cross_entropy_over_beam(
-        input=[
-            sentence_scores, topk_sentence_ids, start_pos_scores,
-            topk_start_pos_ids, end_pos_scores, topk_end_pos_ids
-        ],
-        label=[sentence_idx, start_idx, end_idx])
-    return cost
+    if is_infer:
+        return [topk_sentence_ids, topk_start_pos_ids, topk_end_pos_ids]
+    else:
+        return paddle.layer.cross_entropy_over_beam(input=[
+            paddle.layer.BeamInput(sentence_scores, topk_sentence_ids,
+                                   sentence_idx),
+            paddle.layer.BeamInput(start_pos_scores, topk_start_pos_ids,
+                                   start_idx),
+            paddle.layer.BeamInput(end_pos_scores, topk_end_pos_ids, end_idx)
+        ])

-def GNR(config):
+def GNR(config, is_infer=False):
     # encoding question words
     question_embeddings = build_pretrained_embedding(
-        "question", paddle.data_type.integer_value_sequence, config.vocab_size,
+        "question", paddle.data_type.integer_value_sequence(config.vocab_size),
         config.embedding_dim, config.embedding_droprate)
     question_vector, question_lstm_outs = encode_question(
         input_embedding=question_embeddings, config=config, prefix="__ques")

     # encoding document words
     document_embeddings = build_pretrained_embedding(
-        "documents", paddle.data_type.integer_value_sub_sequence,
-        config.vocab_size, config.embedding_dim, config.embedding_droprate)
+        "documents",
+        paddle.data_type.integer_value_sub_sequence(config.vocab_size),
+        config.embedding_dim, config.embedding_droprate)
     same_as_question = paddle.layer.data(
         name="same_as_question",
-        type=paddle.data_type.integer_value_sub_sequence(2))
+        type=paddle.data_type.dense_vector_sub_sequence(1))
     document_words_ecoding = encode_documents(
         input_embedding=document_embeddings,
         question_vector=question_vector,

@@ -192,7 +193,7 @@ def GNR(config):
     end_idx = paddle.layer.data(
         name="end_idx", type=paddle.data_type.integer_value(1))

     return search_answer(doc_lstm_outs, sentence_idx, start_idx, end_idx,
-                         config)
+                         config, is_infer)

 if __name__ == "__main__":
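The switch from Exp to Linear activations goes together with the new cross_entropy_over_beam call: the per-step fc layers now emit raw scores, and normalization happens once, globally, over all candidate paths on the beam, which is the core idea of the globally normalized reader. The numpy sketch below illustrates that loss on a toy beam; it is an illustration of the idea only, not the Paddle kernel, and the function and variable names are made up for the example.

import numpy as np

def beam_cross_entropy(path_scores, gold_index):
    # path_scores: unnormalized scores of candidate (sentence, start, end)
    # paths on the beam. The loss is -log softmax over the whole beam,
    # i.e. normalization is global rather than per decision step.
    s = np.asarray(path_scores, dtype=np.float64)
    m = s.max()
    log_z = np.log(np.exp(s - m).sum()) + m   # stable log-sum-exp
    return log_z - s[gold_index]

print(beam_cross_entropy([2.0, 0.5, -1.0], gold_index=0))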
globally_normalized_reader/reader.py

 #!/usr/bin/env python
 #coding=utf-8
+import pdb
 import os
 import random
 import json
+import logging
+
+logger = logging.getLogger("paddle")
+logger.setLevel(logging.INFO)

 def train_reader(data_list, is_train=True):

@@ -14,22 +17,20 @@ def train_reader(data_list, is_train=True):
         for train_sample in data_list:
             data = json.load(open(train_sample, "r"))
-            sent_len = data['sent_lengths']
-            doc_len = len(data['context'])
-            same_as_question_word = [[[x]] for x in data['same_as_question_word']]
-            ans_sentence = [0] * doc_len
-            ans_sentence[data['ans_sentence']] = 1
-            ans_start = [0] * doc_len
-            ans_start[data['ans_start']] = 1
-            ans_end = [0] * doc_len
-            ans_end[data['ans_end']] = 1
-            yield (data['question'], data['context'], same_as_question_word,
-                   ans_sentence, ans_start, ans_end)
+            start_pos = 0
+            doc = []
+            same_as_question_word = []
+            for l in data['sent_lengths']:
+                doc.append(data['context'][start_pos:start_pos + l])
+                same_as_question_word.append([
+                    [[x]] for x in data['same_as_question_word']
+                ][start_pos:start_pos + l])
+                start_pos += l
+            yield (data['question'], doc, same_as_question_word,
+                   data['ans_sentence'], data['ans_start'],
+                   data['ans_end'] - data['ans_start'])

     return reader
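After this change the reader yields the document as a nested sequence (a list of sentences) carved out of the flat 'context' token list using 'sent_lengths', with answer labels as scalar indices and the end given as an offset from the start. The self-contained snippet below traces the new layout on a made-up sample dict; the field values are invented purely for the example.

# Illustration of the new sample layout produced by train_reader.
data = {
    "question": [7, 8],
    "context": [1, 2, 3, 4, 5],
    "sent_lengths": [3, 2],
    "same_as_question_word": [0, 1, 0, 0, 1],
    "ans_sentence": 0,
    "ans_start": 1,
    "ans_end": 2,
}

start_pos, doc, same = 0, [], []
for l in data["sent_lengths"]:
    doc.append(data["context"][start_pos:start_pos + l])
    same.append([[[x]] for x in data["same_as_question_word"]]
                [start_pos:start_pos + l])
    start_pos += l

sample = (data["question"], doc, same, data["ans_sentence"],
          data["ans_start"], data["ans_end"] - data["ans_start"])
print(sample)  # ([7, 8], [[1, 2, 3], [4, 5]], ..., 0, 1, 1)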
globally_normalized_reader/train.py

@@ -9,6 +9,7 @@ import logging
 import random
 import glob
 import gzip
+import numpy as np

 import reader
 import paddle.v2 as paddle

@@ -21,7 +22,7 @@ logger.setLevel(logging.INFO)
 def load_pretrained_parameters(path, height, width):
-    return
+    return np.load(path)

 def save_model(save_path, parameters):

@@ -51,27 +52,30 @@ def choose_samples(path):
     train_samples.sort()
     valid_samples.sort()
-    random.shuffle(train_samples)
+    # random.shuffle(train_samples)

     return train_samples, valid_samples

-def build_reader(data_dir):
+def build_reader(data_dir, batch_size):
     """
     Build the data reader for this model.
     """
     train_samples, valid_samples = choose_samples(data_dir)
+    pdb.set_trace()

     train_reader = paddle.batch(
         paddle.reader.shuffle(
             reader.train_reader(train_samples), buf_size=102400),
-        batch_size=config.batch_size)
+        batch_size=batch_size)
+    # train_reader = paddle.batch(
+    #     reader.train_reader(train_samples), batch_size=batch_size)

     # testing data is not shuffled
     test_reader = paddle.batch(
-        reader.train_reader(valid_samples, is_train=False),
-        batch_size=config.batch_size)
+        reader.train_reader(
+            valid_samples, is_train=False),
+        batch_size=batch_size)

     return train_reader, test_reader

@@ -85,53 +89,65 @@ def build_event_handler(config, parameters, trainer, test_reader):
         """The event handler."""
         if isinstance(event, paddle.event.EndIteration):
             if (not event.batch_id % 100) and event.batch_id:
-                save_model("checkpoint_param.latest.tar.gz", parameters)
+                save_path = os.path.join(config.save_dir,
+                                         "checkpoint_param.latest.tar.gz")
+                save_model(save_path, parameters)

-            if not event.batch_id % 5:
+            if not event.batch_id % 1:
                 logger.info(
                     "Pass %d, Batch %d, Cost %f, %s" %
                     (event.pass_id, event.batch_id, event.cost, event.metrics))

         if isinstance(event, paddle.event.EndPass):
-            save_model(config.param_save_filename_format % event.pass_id,
-                       parameters)
-            with gzip.open(param_path, 'w') as handle:
-                parameters.to_tar(handle)
-            result = trainer.test(reader=test_reader)
-            logger.info("Test with Pass %d, %s" %
-                        (event.pass_id, result.metrics))
+            save_path = os.path.join(config.save_dir,
+                                     "pass_%05d.tar.gz" % event.pass_id)
+            save_model(save_path, parameters)
+            # result = trainer.test(reader=test_reader)
+            # logger.info("Test with Pass %d, %s" %
+            #             (event.pass_id, result.metrics))

     return event_handler

 def train(model_config, trainer_config):
-    paddle.init(use_gpu=True, trainer_count=1)
+    if not os.path.exists(trainer_config.save_dir):
+        os.mkdir(trainer_config.save_dir)
+
+    paddle.init(use_gpu=True, trainer_count=4)

     # define the optimizer
     optimizer = paddle.optimizer.Adam(
         learning_rate=trainer_config.learning_rate,
         regularization=paddle.optimizer.L2Regularization(rate=1e-3),
-        model_average=paddle.optimizer.ModelAverage(average_window=0.5))
+        # model_average=paddle.optimizer.ModelAverage(average_window=0.5))
+    )

     # define network topology
-    losses = GNR(model_config)
-    parameters = paddle.parameters.create(losses)
-    # print(parse_network(losses))
-    trainer = paddle.trainer.SGD(cost=losses,
-                                 parameters=parameters,
-                                 update_equation=optimizer)
-    """
-    parameters.set('GloveVectors',
-                   load_pretrained_parameters(parameter_path, height, width))
-    """
+    loss = GNR(model_config)
+    # print(parse_network(loss))
+
+    parameters = paddle.parameters.create(loss)
+    parameters.set("GloveVectors",
+                   load_pretrained_parameters(
+                       ModelConfig.pretrained_emb_path,
+                       height=ModelConfig.vocab_size,
+                       width=ModelConfig.embedding_dim))
+
+    trainer = paddle.trainer.SGD(cost=loss,
+                                 parameters=parameters,
+                                 update_equation=optimizer)

     # define data reader
-    train_reader, test_reader = build_reader(trainer_config.data_dir)
-    event_handler = build_event_handler(conf, parameters, trainer, test_reader)
+    train_reader, test_reader = build_reader(trainer_config.data_dir,
+                                             trainer_config.batch_size)
+    event_handler = build_event_handler(trainer_config, parameters, trainer,
+                                        test_reader)
     trainer.train(
-        reader=train_reader,
-        num_passes=conf.epochs,
+        reader=data_reader,
+        num_passes=trainer_config.epochs,
         event_handler=event_handler)
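Because GloveVectors is declared static in model.py, the matrix loaded via parameters.set must match (vocab_size, embedding_dim) from ModelConfig exactly. A small sanity check one could run before training is sketched below; the helper name is made up for illustration and the paths come from config.py above.

import numpy as np

def check_pretrained(path, height, width):
    # Verify the pretrained embedding matrix matches the model config
    # before handing it to parameters.set("GloveVectors", ...).
    emb = np.load(path)
    assert emb.shape == (height, width), (
        "embedding matrix is %s, config expects (%d, %d)" %
        (emb.shape, height, width))
    return emb

# check_pretrained("data/featurized/embeddings.npy", 104808, 300)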