Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
b6126ac9
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b6126ac9
编写于
7月 19, 2017
作者:
C
Cao Ying
提交者:
GitHub
7月 19, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #172 from Superjom/dssm
DSSM enhancement
上级
3a784aa0
47818f8b
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
306 addition
and
85 deletion
+306
-85
dssm/README.md
dssm/README.md
+72
-5
dssm/infer.py
dssm/infer.py
+140
-0
dssm/network_conf.py
dssm/network_conf.py
+17
-46
dssm/reader.py
dssm/reader.py
+20
-8
dssm/train.py
dssm/train.py
+50
-26
dssm/utils.py
dssm/utils.py
+7
-0
未找到文件。
dssm/README.md
浏览文件 @
b6126ac9
...
...
@@ -384,11 +384,13 @@ def _build_rank_model(self):
```
usage
:
train
.
py
[-
h
]
[-
i
TRAIN_DATA_PATH
]
[-
t
TEST_DATA_PATH
]
[-
s
SOURCE_DIC_PATH
]
[--
target_dic_path
TARGET_DIC_PATH
]
[-
b
BATCH_SIZE
]
[-
p
NUM_PASSES
]
-
y
MODEL_TYPE
--
model_arch
MODEL_ARCH
[-
b
BATCH_SIZE
]
[-
p
NUM_PASSES
]
-
y
MODEL_TYPE
-
a
MODEL_ARCH
[--
share_network_between_source_target
SHARE_NETWORK_BETWEEN_SOURCE_TARGET
]
[--
share_embed
SHARE_EMBED
]
[--
dnn_dims
DNN_DIMS
]
[--
num_workers
NUM_WORKERS
]
[--
use_gpu
USE_GPU
]
[-
c
CLASS_NUM
]
[--
model_output_prefix
MODEL_OUTPUT_PREFIX
]
[-
g
NUM_BATCHES_TO_LOG
]
[-
e
NUM_BATCHES_TO_TEST
]
[-
z
NUM_BATCHES_TO_SAVE_MODEL
]
PaddlePaddle
DSSM
example
...
...
@@ -408,9 +410,9 @@ optional arguments:
-
p
NUM_PASSES
,
--
num_passes
NUM_PASSES
number
of
passes
to
run
(
default
:
10
)
-
y
MODEL_TYPE
,
--
model_type
MODEL_TYPE
model
type
,
0
for
classification
,
1
for
pairwise
rank
(
default
:
classification
)
--
model_arch
MODEL_ARCH
model
type
,
0
for
classification
,
1
for
pairwise
rank
,
2
for
regression
(
default
:
classification
)
-
a
MODEL_ARCH
,
-
-
model_arch
MODEL_ARCH
model
architecture
,
1
for
CNN
,
0
for
FC
,
2
for
RNN
--
share_network_between_source_target
SHARE_NETWORK_BETWEEN_SOURCE_TARGET
whether
to
share
network
parameters
between
source
and
...
...
@@ -426,8 +428,73 @@ optional arguments:
--
use_gpu
USE_GPU
whether
to
use
GPU
devices
(
default
:
False
)
-
c
CLASS_NUM
,
--
class_num
CLASS_NUM
number
of
categories
for
classification
task
.
--
model_output_prefix
MODEL_OUTPUT_PREFIX
prefix
of
the
path
for
model
to
store
,
(
default
:
./)
-
g
NUM_BATCHES_TO_LOG
,
--
num_batches_to_log
NUM_BATCHES_TO_LOG
number
of
batches
to
output
train
log
,
(
default
:
100
)
-
e
NUM_BATCHES_TO_TEST
,
--
num_batches_to_test
NUM_BATCHES_TO_TEST
number
of
batches
to
test
,
(
default
:
200
)
-
z
NUM_BATCHES_TO_SAVE_MODEL
,
--
num_batches_to_save_model
NUM_BATCHES_TO_SAVE_MODEL
number
of
batches
to
output
model
,
(
default
:
400
)
```
重要的参数描述如下
-
`train_data_path`
训练数据路径
-
`test_data_path`
测试数据路局,可以不设置
-
`source_dic_path`
源字典字典路径
-
`target_dic_path`
目标字典路径
-
`model_type`
模型的损失函数的类型,分类0,排序1,回归2
-
`model_arch`
模型结构,FC 0, CNN 1, RNN 2
-
`dnn_dims`
模型各层的维度设置,默认为
`256,128,64,32`
,即模型有4层,各层维度如上设置
## 用训练好的模型预测
```
usage
:
infer
.
py
[-
h
]
--
model_path
MODEL_PATH
-
i
DATA_PATH
-
o
PREDICTION_OUTPUT_PATH
-
y
MODEL_TYPE
[-
s
SOURCE_DIC_PATH
]
[--
target_dic_path
TARGET_DIC_PATH
]
-
a
MODEL_ARCH
[--
share_network_between_source_target
SHARE_NETWORK_BETWEEN_SOURCE_TARGET
]
[--
share_embed
SHARE_EMBED
]
[--
dnn_dims
DNN_DIMS
]
[-
c
CLASS_NUM
]
PaddlePaddle
DSSM
infer
optional
arguments
:
-
h
,
--
help
show
this
help
message
and
exit
--
model_path
MODEL_PATH
path
of
model
parameters
file
-
i
DATA_PATH
,
--
data_path
DATA_PATH
path
of
the
dataset
to
infer
-
o
PREDICTION_OUTPUT_PATH
,
--
prediction_output_path
PREDICTION_OUTPUT_PATH
path
to
output
the
prediction
-
y
MODEL_TYPE
,
--
model_type
MODEL_TYPE
model
type
,
0
for
classification
,
1
for
pairwise
rank
,
2
for
regression
(
default
:
classification
)
-
s
SOURCE_DIC_PATH
,
--
source_dic_path
SOURCE_DIC_PATH
path
of
the
source
's word dic
--target_dic_path TARGET_DIC_PATH
path of the target'
s
word
dic
,
if
not
set
,
the
`
source_dic_path
`
will
be
used
-
a
MODEL_ARCH
,
--
model_arch
MODEL_ARCH
model
architecture
,
1
for
CNN
,
0
for
FC
,
2
for
RNN
--
share_network_between_source_target
SHARE_NETWORK_BETWEEN_SOURCE_TARGET
whether
to
share
network
parameters
between
source
and
target
--
share_embed
SHARE_EMBED
whether
to
share
word
embedding
between
source
and
target
--
dnn_dims
DNN_DIMS
dimentions
of
dnn
layers
,
default
is
'256,128,64,32'
,
which
means
create
a
4
-
layer
dnn
,
demention
of
each
layer
is
256
,
128
,
64
and
32
-
c
CLASS_NUM
,
--
class_num
CLASS_NUM
number
of
categories
for
classification
task
.
```
部分参数可以参考
`train.py`
,重要参数解释如下
-
`data_path`
需要预测的数据路径
-
`prediction_output_path`
预测的输出路径
## 参考文献
1.
Huang P S, He X, Gao J, et al. Learning deep structured semantic models for web search using clickthrough data[C]//Proceedings of the 22nd ACM international conference on Conference on information & knowledge management. ACM, 2013: 2333-2338.
...
...
dssm/infer.py
浏览文件 @
b6126ac9
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import
argparse
import
itertools
import
reader
import
paddle.v2
as
paddle
from
network_conf
import
DSSM
from
utils
import
logger
,
ModelType
,
ModelArch
,
load_dic
parser
=
argparse
.
ArgumentParser
(
description
=
"PaddlePaddle DSSM infer"
)
parser
.
add_argument
(
'--model_path'
,
type
=
str
,
required
=
True
,
help
=
"path of model parameters file"
)
parser
.
add_argument
(
'-i'
,
'--data_path'
,
type
=
str
,
required
=
True
,
help
=
"path of the dataset to infer"
)
parser
.
add_argument
(
'-o'
,
'--prediction_output_path'
,
type
=
str
,
required
=
True
,
help
=
"path to output the prediction"
)
parser
.
add_argument
(
'-y'
,
'--model_type'
,
type
=
int
,
required
=
True
,
default
=
ModelType
.
CLASSIFICATION_MODE
,
help
=
"model type, %d for classification, %d for pairwise rank, %d for regression (default: classification)"
%
(
ModelType
.
CLASSIFICATION_MODE
,
ModelType
.
RANK_MODE
,
ModelType
.
REGRESSION_MODE
))
parser
.
add_argument
(
'-s'
,
'--source_dic_path'
,
type
=
str
,
required
=
False
,
help
=
"path of the source's word dic"
)
parser
.
add_argument
(
'--target_dic_path'
,
type
=
str
,
required
=
False
,
help
=
"path of the target's word dic, if not set, the `source_dic_path` will be used"
)
parser
.
add_argument
(
'-a'
,
'--model_arch'
,
type
=
int
,
required
=
True
,
default
=
ModelArch
.
CNN_MODE
,
help
=
"model architecture, %d for CNN, %d for FC, %d for RNN"
%
(
ModelArch
.
CNN_MODE
,
ModelArch
.
FC_MODE
,
ModelArch
.
RNN_MODE
))
parser
.
add_argument
(
'--share_network_between_source_target'
,
type
=
bool
,
default
=
False
,
help
=
"whether to share network parameters between source and target"
)
parser
.
add_argument
(
'--share_embed'
,
type
=
bool
,
default
=
False
,
help
=
"whether to share word embedding between source and target"
)
parser
.
add_argument
(
'--dnn_dims'
,
type
=
str
,
default
=
'256,128,64,32'
,
help
=
"dimentions of dnn layers, default is '256,128,64,32', which means create a 4-layer dnn, demention of each layer is 256, 128, 64 and 32"
)
parser
.
add_argument
(
'-c'
,
'--class_num'
,
type
=
int
,
default
=
0
,
help
=
"number of categories for classification task."
)
args
=
parser
.
parse_args
()
args
.
model_type
=
ModelType
(
args
.
model_type
)
args
.
model_arch
=
ModelArch
(
args
.
model_arch
)
if
args
.
model_type
.
is_classification
():
assert
args
.
class_num
>
1
,
"--class_num should be set in classification task."
layer_dims
=
map
(
int
,
args
.
dnn_dims
.
split
(
','
))
args
.
target_dic_path
=
args
.
source_dic_path
if
not
args
.
target_dic_path
else
args
.
target_dic_path
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
class
Inferer
(
object
):
def
__init__
(
self
,
param_path
):
logger
.
info
(
"create DSSM model"
)
cost
,
prediction
,
label
=
DSSM
(
dnn_dims
=
layer_dims
,
vocab_sizes
=
[
len
(
load_dic
(
path
))
for
path
in
[
args
.
source_dic_path
,
args
.
target_dic_path
]
],
model_type
=
args
.
model_type
,
model_arch
=
args
.
model_arch
,
share_semantic_generator
=
args
.
share_network_between_source_target
,
class_num
=
args
.
class_num
,
share_embed
=
args
.
share_embed
)()
# load parameter
logger
.
info
(
"load model parameters from %s"
%
param_path
)
self
.
parameters
=
paddle
.
parameters
.
Parameters
.
from_tar
(
open
(
param_path
,
'r'
))
self
.
inferer
=
paddle
.
inference
.
Inference
(
output_layer
=
prediction
,
parameters
=
self
.
parameters
)
def
infer
(
self
,
data_path
):
logger
.
info
(
"infer data..."
)
dataset
=
reader
.
Dataset
(
train_path
=
data_path
,
test_path
=
None
,
source_dic_path
=
args
.
source_dic_path
,
target_dic_path
=
args
.
target_dic_path
,
model_type
=
args
.
model_type
,
)
infer_reader
=
paddle
.
batch
(
dataset
.
infer
,
batch_size
=
1000
)
logger
.
warning
(
'write predictions to %s'
%
args
.
prediction_output_path
)
output_f
=
open
(
args
.
prediction_output_path
,
'w'
)
for
id
,
batch
in
enumerate
(
infer_reader
()):
res
=
self
.
inferer
.
infer
(
input
=
batch
)
predictions
=
[
' '
.
join
(
map
(
str
,
x
))
for
x
in
res
]
assert
len
(
batch
)
==
len
(
predictions
),
"predict error, %d inputs, but %d predictions"
%
(
len
(
batch
),
len
(
predictions
))
output_f
.
write
(
'
\n
'
.
join
(
map
(
str
,
predictions
))
+
'
\n
'
)
if
__name__
==
'__main__'
:
inferer
=
Inferer
(
args
.
model_path
)
inferer
.
infer
(
args
.
data_path
)
dssm/network_conf.py
浏览文件 @
b6126ac9
...
...
@@ -11,7 +11,8 @@ class DSSM(object):
model_arch
=
ModelArch
.
create_cnn
(),
share_semantic_generator
=
False
,
class_num
=
None
,
share_embed
=
False
):
share_embed
=
False
,
is_infer
=
False
):
'''
@dnn_dims: list of int
dimentions of each layer in semantic vector generator.
...
...
@@ -40,6 +41,7 @@ class DSSM(object):
self
.
model_type
=
ModelType
(
model_type
)
self
.
model_arch
=
ModelArch
(
model_arch
)
self
.
class_num
=
class_num
self
.
is_infer
=
is_infer
logger
.
warning
(
"build DSSM model with config of %s, %s"
%
(
self
.
model_type
,
self
.
model_arch
))
logger
.
info
(
"vocabulary sizes: %s"
%
str
(
self
.
vocab_sizes
))
...
...
@@ -68,9 +70,6 @@ class DSSM(object):
self
.
model_type_creater
=
_model_type
[
str
(
self
.
model_type
)]
def
__call__
(
self
):
# if self.model_type.is_classification():
# return self._build_classification_model()
# return self._build_rank_model()
return
self
.
model_type_creater
()
def
create_embedding
(
self
,
input
,
prefix
=
''
):
...
...
@@ -189,8 +188,9 @@ class DSSM(object):
right_target
=
paddle
.
layer
.
data
(
name
=
'right_target_input'
,
type
=
paddle
.
data_type
.
integer_value_sequence
(
self
.
vocab_sizes
[
1
]))
label
=
paddle
.
layer
.
data
(
name
=
'label_input'
,
type
=
paddle
.
data_type
.
integer_value
(
1
))
if
not
self
.
is_infer
:
label
=
paddle
.
layer
.
data
(
name
=
'label_input'
,
type
=
paddle
.
data_type
.
integer_value
(
1
))
prefixs
=
'_ _ _'
.
split
(
)
if
self
.
share_semantic_generator
else
'source left right'
.
split
()
...
...
@@ -212,12 +212,14 @@ class DSSM(object):
# cossim score of source and right target
right_score
=
paddle
.
layer
.
cos_sim
(
semantics
[
0
],
semantics
[
2
])
# rank cost
cost
=
paddle
.
layer
.
rank_cost
(
left_score
,
right_score
,
label
=
label
)
# prediction = left_score - right_score
# but this operator is not supported currently.
# so AUC will not used.
return
cost
,
None
,
None
if
not
self
.
is_infer
:
# rank cost
cost
=
paddle
.
layer
.
rank_cost
(
left_score
,
right_score
,
label
=
label
)
# prediction = left_score - right_score
# but this operator is not supported currently.
# so AUC will not used.
return
cost
,
None
,
label
return
None
,
[
left_score
,
right_score
],
label
def
_build_classification_or_regression_model
(
self
,
is_classification
):
'''
...
...
@@ -270,38 +272,7 @@ class DSSM(object):
else
:
prediction
=
paddle
.
layer
.
cos_sim
(
*
semantics
)
cost
=
paddle
.
layer
.
mse_cost
(
prediction
,
label
)
return
cost
,
prediction
,
label
class
RankMetrics
(
object
):
'''
A custom metrics to calculate AUC.
Paddle's rank model do not support auc evaluator directly,
to make it, infer all the outputs and use python to calculate
the metrics.
'''
def
__init__
(
self
,
model_parameters
,
left_score_layer
,
right_score_layer
,
label
):
'''
@model_parameters: dict
model's parameters
@left_score_layer: paddle.layer
left part's score
@right_score_laeyr: paddle.layer
right part's score
@label: paddle.data_layer
label input
'''
self
.
inferer
=
paddle
.
inference
.
Inference
(
output_layer
=
[
left_score_layer
,
right_score_layer
],
parameters
=
model_parameters
)
def
test
(
self
,
input
):
scores
=
[]
for
id
,
rcd
in
enumerate
(
input
()):
# output [left_score, right_score, label]
res
=
self
.
inferer
(
input
=
input
)
scores
.
append
(
res
)
print
scores
if
not
self
.
is_infer
:
return
cost
,
prediction
,
label
return
None
,
prediction
,
label
dssm/reader.py
浏览文件 @
b6126ac9
...
...
@@ -23,6 +23,7 @@ class Dataset(object):
assert
isinstance
(
model_type
,
ModelType
)
self
.
record_reader
=
_record_reader
[
model_type
.
mode
]
self
.
is_infer
=
False
def
train
(
self
):
'''
...
...
@@ -37,11 +38,17 @@ class Dataset(object):
'''
Load testset.
'''
logger
.
info
(
"[reader] load testset from %s"
%
self
.
test_path
)
#
logger.info("[reader] load testset from %s" % self.test_path)
with
open
(
self
.
test_path
)
as
f
:
for
line_id
,
line
in
enumerate
(
f
):
yield
self
.
record_reader
(
line
)
def
infer
(
self
):
self
.
is_infer
=
True
with
open
(
self
.
train_path
)
as
f
:
for
line
in
f
:
yield
self
.
record_reader
(
line
)
def
_read_classification_record
(
self
,
line
):
'''
data format:
...
...
@@ -56,8 +63,10 @@ class Dataset(object):
"<source words> [TAB] <target words> [TAB] <label>'"
source
=
sent2ids
(
fs
[
0
],
self
.
source_dic
)
target
=
sent2ids
(
fs
[
1
],
self
.
target_dic
)
label
=
int
(
fs
[
2
])
return
(
source
,
target
,
label
,
)
if
not
self
.
is_infer
:
label
=
int
(
fs
[
2
])
return
(
source
,
target
,
label
,
)
return
source
,
target
def
_read_regression_record
(
self
,
line
):
'''
...
...
@@ -73,8 +82,10 @@ class Dataset(object):
"<source words> [TAB] <target words> [TAB] <label>'"
source
=
sent2ids
(
fs
[
0
],
self
.
source_dic
)
target
=
sent2ids
(
fs
[
1
],
self
.
target_dic
)
label
=
float
(
fs
[
2
])
return
(
source
,
target
,
[
label
],
)
if
not
self
.
is_infer
:
label
=
float
(
fs
[
2
])
return
(
source
,
target
,
[
label
],
)
return
source
,
target
def
_read_rank_record
(
self
,
line
):
'''
...
...
@@ -89,9 +100,10 @@ class Dataset(object):
source
=
sent2ids
(
fs
[
0
],
self
.
source_dic
)
left_target
=
sent2ids
(
fs
[
1
],
self
.
target_dic
)
right_target
=
sent2ids
(
fs
[
2
],
self
.
target_dic
)
label
=
int
(
fs
[
3
])
return
(
source
,
left_target
,
right_target
,
label
)
if
not
self
.
is_infer
:
label
=
int
(
fs
[
3
])
return
(
source
,
left_target
,
right_target
,
label
)
return
source
,
left_target
,
right_target
if
__name__
==
'__main__'
:
...
...
dssm/train.py
浏览文件 @
b6126ac9
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import
argparse
import
gzip
import
paddle.v2
as
paddle
from
network_conf
import
DSSM
import
reader
from
utils
import
TaskType
,
load_dic
,
logger
,
ModelType
,
ModelArch
from
utils
import
TaskType
,
load_dic
,
logger
,
ModelType
,
ModelArch
,
display_args
parser
=
argparse
.
ArgumentParser
(
description
=
"PaddlePaddle DSSM example"
)
...
...
@@ -56,6 +55,7 @@ parser.add_argument(
%
(
ModelType
.
CLASSIFICATION_MODE
,
ModelType
.
RANK_MODE
,
ModelType
.
REGRESSION_MODE
))
parser
.
add_argument
(
'-a'
,
'--model_arch'
,
type
=
int
,
required
=
True
,
...
...
@@ -91,6 +91,29 @@ parser.add_argument(
type
=
int
,
default
=
0
,
help
=
"number of categories for classification task."
)
parser
.
add_argument
(
'--model_output_prefix'
,
type
=
str
,
default
=
"./"
,
help
=
"prefix of the path for model to store, (default: ./)"
)
parser
.
add_argument
(
'-g'
,
'--num_batches_to_log'
,
type
=
int
,
default
=
100
,
help
=
"number of batches to output train log, (default: 100)"
)
parser
.
add_argument
(
'-e'
,
'--num_batches_to_test'
,
type
=
int
,
default
=
200
,
help
=
"number of batches to test, (default: 200)"
)
parser
.
add_argument
(
'-z'
,
'--num_batches_to_save_model'
,
type
=
int
,
default
=
400
,
help
=
"number of batches to output model, (default: 400)"
)
# arguments check.
args
=
parser
.
parse_args
()
...
...
@@ -100,10 +123,7 @@ if args.model_type.is_classification():
assert
args
.
class_num
>
1
,
"--class_num should be set in classification task."
layer_dims
=
[
int
(
i
)
for
i
in
args
.
dnn_dims
.
split
(
','
)]
target_dic_path
=
args
.
source_dic_path
if
not
args
.
target_dic_path
else
args
.
target_dic_path
model_save_name_prefix
=
"dssm_pass_%s_%s"
%
(
args
.
model_type
,
args
.
model_arch
,
)
args
.
target_dic_path
=
args
.
source_dic_path
if
not
args
.
target_dic_path
else
args
.
target_dic_path
def
train
(
train_data_path
=
None
,
...
...
@@ -174,15 +194,10 @@ def train(train_data_path=None,
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
extra_layers
=
None
,
extra_layers
=
paddle
.
evaluator
.
auc
(
input
=
prediction
,
label
=
label
)
if
not
model_type
.
is_rank
()
else
None
,
parameters
=
parameters
,
update_equation
=
adam_optimizer
)
# trainer = paddle.trainer.SGD(
# cost=cost,
# extra_layers=paddle.evaluator.auc(input=prediction, label=label)
# if prediction and model_type.is_classification() else None,
# parameters=parameters,
# update_equation=adam_optimizer)
feeding
=
{}
if
model_type
.
is_classification
()
or
model_type
.
is_regression
():
...
...
@@ -200,21 +215,29 @@ def train(train_data_path=None,
Define batch handler
'''
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
100
==
0
:
logger
.
info
(
"Pass %d, Batch %d, Cost %f, %s
\n
"
%
(
# output train log
if
event
.
batch_id
%
args
.
num_batches_to_log
==
0
:
logger
.
info
(
"Pass %d, Batch %d, Cost %f, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
))
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
if
test_reader
is
not
None
:
if
model_type
.
is_classification
():
result
=
trainer
.
test
(
reader
=
test_reader
,
feeding
=
feeding
)
logger
.
info
(
"Test at Pass %d, %s
\n
"
%
(
event
.
pass_id
,
result
.
metrics
))
else
:
result
=
None
with
gzip
.
open
(
"dssm_%s_pass_%05d.tar.gz"
%
(
model_save_name_prefix
,
event
.
pass_id
),
"w"
)
as
f
:
parameters
.
to_tar
(
f
)
# test model
if
event
.
batch_id
>
0
and
event
.
batch_id
%
args
.
num_batches_to_test
==
0
:
if
test_reader
is
not
None
:
if
model_type
.
is_classification
():
result
=
trainer
.
test
(
reader
=
test_reader
,
feeding
=
feeding
)
logger
.
info
(
"Test at Pass %d, %s"
%
(
event
.
pass_id
,
result
.
metrics
))
else
:
result
=
None
# save model
if
event
.
batch_id
>
0
and
event
.
batch_id
%
args
.
num_batches_to_save_model
==
0
:
model_desc
=
"{type}_{arch}"
.
format
(
type
=
str
(
args
.
model_type
),
arch
=
str
(
args
.
model_arch
))
with
open
(
"%sdssm_%s_pass_%05d.tar"
%
(
args
.
model_output_prefix
,
model_desc
,
event
.
pass_id
),
"w"
)
as
f
:
parameters
.
to_tar
(
f
)
trainer
.
train
(
reader
=
train_reader
,
...
...
@@ -226,6 +249,7 @@ def train(train_data_path=None,
if
__name__
==
'__main__'
:
display_args
(
args
)
train
(
train_data_path
=
args
.
train_data_path
,
test_data_path
=
args
.
test_data_path
,
...
...
dssm/utils.py
浏览文件 @
b6126ac9
import
logging
import
paddle
UNK
=
0
...
...
@@ -126,6 +127,12 @@ def load_dic(path):
return
dic
def
display_args
(
args
):
logger
.
info
(
"arguments passed by command line:"
)
for
k
,
v
in
sorted
(
v
for
v
in
vars
(
args
).
items
()):
logger
.
info
(
"{}:
\t
{}"
.
format
(
k
,
v
))
if
__name__
==
'__main__'
:
t
=
TaskType
(
1
)
t
=
TaskType
.
create_train
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录