Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
5c5a64a3
M
models
项目概览
PaddlePaddle
/
models
1 年多 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5c5a64a3
编写于
7月 01, 2017
作者:
S
Superjom
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
utils.py refactor
上级
4f13dbfa
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
85 addition
and
34 deletion
+85
-34
dssm/utils.py
dssm/utils.py
+85
-34
未找到文件。
dssm/utils.py
浏览文件 @
5c5a64a3
...
...
@@ -6,56 +6,100 @@ logger = logging.getLogger("logger")
logger
.
setLevel
(
logging
.
INFO
)
class
TaskType
:
TRAIN_MODE
=
0
TEST_MODE
=
1
INFER_MODE
=
2
def
mode_attr_name
(
mode
):
return
mode
.
upper
()
+
'_MODE'
def
__init__
(
self
,
mode
):
self
.
mode
=
mode
def
is_train
(
self
):
return
self
.
mode
==
self
.
TRAIN_MODE
def
create_attrs
(
cls
):
for
id
,
mode
in
enumerate
(
cls
.
modes
):
setattr
(
cls
,
mode_attr_name
(
mode
),
id
)
def
make_check_method
(
cls
):
'''
create methods for classes.
'''
def
method
(
mode
):
def
_method
(
self
):
return
self
.
mode
==
getattr
(
cls
,
mode_attr_name
(
mode
))
return
_method
for
id
,
mode
in
enumerate
(
cls
.
modes
):
setattr
(
cls
,
'is_'
+
mode
,
method
(
mode
))
def
make_create_method
(
cls
):
def
method
(
mode
):
@
staticmethod
def
_method
():
key
=
getattr
(
cls
,
mode_attr_name
(
mode
))
return
cls
(
key
)
return
_method
for
id
,
mode
in
enumerate
(
cls
.
modes
):
setattr
(
cls
,
'create_'
+
mode
,
method
(
mode
))
def
is_test
(
self
):
return
self
.
mode
==
self
.
TEST_MODE
def
is_infer
(
self
):
return
self
.
mode
==
self
.
INFER_MODE
def
make_str_method
(
cls
):
def
_str_
(
self
):
for
mode
in
cls
.
modes
:
if
self
.
mode
==
getattr
(
cls
,
mode_attr_name
(
mode
)):
return
mode
@
staticmethod
def
create_train
():
return
TaskType
(
TaskType
.
TRAIN_MODE
)
def
_hash_
(
self
):
return
self
.
mode
@
staticmethod
def
create_test
():
return
TaskType
(
TaskType
.
TEST_MODE
)
setattr
(
cls
,
'__str__'
,
_str_
)
setattr
(
cls
,
'__repr__'
,
_str_
)
setattr
(
cls
,
'__hash__'
,
_hash_
)
@
staticmethod
def
create_infer
():
return
TaskType
(
TaskType
.
INFER_MODE
)
def
_init_
(
self
,
mode
,
cls
):
if
isinstance
(
mode
,
int
):
self
.
mode
=
mode
elif
isinstance
(
mode
,
cls
):
self
.
mode
=
mode
.
mode
else
:
raise
def
build_mode_class
(
cls
):
create_attrs
(
cls
)
make_str_method
(
cls
)
make_check_method
(
cls
)
make_create_method
(
cls
)
class
TaskType
(
object
):
# TRAIN_MODE = 0
# TEST_MODE = 1
# INFER_MODE = 2
modes
=
'train test infer'
.
split
()
def
__init__
(
self
,
mode
):
_init_
(
self
,
mode
,
TaskType
)
class
ModelType
:
CLASSIFICATION
=
0
RANK
=
1
modes
=
'classification rank regression'
.
split
()
def
__init__
(
self
,
mode
):
self
.
mode
=
mode
_init_
(
self
,
mode
,
ModelType
)
def
is_classification
(
self
)
:
return
self
.
mode
==
self
.
CLASSIFICATION
class
ModelArch
:
modes
=
'fc cnn rnn'
.
split
()
def
is_rank
(
self
):
return
self
.
mode
==
self
.
RANK
def
__init__
(
self
,
mode
):
_init_
(
self
,
mode
,
ModelArch
)
@
staticmethod
def
create_classification
():
return
ModelType
(
ModelType
.
CLASSIFICATION
)
@
staticmethod
def
create_rank
():
return
ModelType
(
ModelType
.
RANK
)
build_mode_class
(
TaskType
)
build_mode_class
(
ModelType
)
build_mode_class
(
ModelArch
)
def
sent2ids
(
sent
,
vocab
):
...
...
@@ -81,3 +125,10 @@ def load_dic(path):
w
=
line
.
strip
()
dic
[
w
]
=
id
return
dic
if
__name__
==
'__main__'
:
t
=
TaskType
(
1
)
t
=
TaskType
.
create_train
()
print
t
print
'is'
,
t
.
is_train
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录