Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
34868288
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
34868288
编写于
3月 29, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add finetune checkpoint
上级
16145775
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
211 addition
and
26 deletion
+211
-26
paddle_hub/dataset/cv_reader.py
paddle_hub/dataset/cv_reader.py
+3
-2
paddle_hub/finetune/checkpoint.proto
paddle_hub/finetune/checkpoint.proto
+25
-0
paddle_hub/finetune/checkpoint.py
paddle_hub/finetune/checkpoint.py
+35
-0
paddle_hub/finetune/checkpoint_pb2.py
paddle_hub/finetune/checkpoint_pb2.py
+107
-0
paddle_hub/finetune/finetune.py
paddle_hub/finetune/finetune.py
+40
-24
scripts/gen_proto.sh
scripts/gen_proto.sh
+1
-0
未找到文件。
paddle_hub/dataset/cv_reader.py
浏览文件 @
34868288
...
...
@@ -16,6 +16,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
numpy
as
np
from
PIL
import
Image
...
...
@@ -50,7 +51,7 @@ class ImageClassificationReader:
if
self
.
image_width
<=
0
or
self
.
image_height
<=
0
:
raise
ValueError
(
"Image width and height should not be negative."
)
def
data_generator
(
self
,
phase
,
shuffle
=
False
):
def
data_generator
(
self
,
batch_size
,
phase
=
"train"
,
shuffle
=
False
):
if
phase
==
"train"
:
data
=
self
.
dataset
.
train_data
(
shuffle
)
elif
phase
==
"test"
:
...
...
@@ -81,4 +82,4 @@ class ImageClassificationReader:
image
=
image
[
color_mode_dict
[
self
.
color_mode
],
:,
:]
yield
((
image
,
label
))
return
_data_reader
return
paddle
.
batch
(
_data_reader
,
batch_size
=
batch_size
)
paddle_hub/finetune/checkpoint.proto
0 → 100644
浏览文件 @
34868288
// Copyright 2018 The Paddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
syntax
=
"proto3"
;
option
optimize_for
=
LITE_RUNTIME
;
package
paddle_hub_finetune_checkpoint
;
message
CheckPoint
{
int64
last_epoch
=
1
;
int64
last_step
=
2
;
string
last_model_dir
=
3
;
}
paddle_hub/finetune/checkpoint.py
0 → 100644
浏览文件 @
34868288
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle_hub.finetune
import
checkpoint_pb2
def
load_checkpoint
(
checkpoint_path
):
ckpt
=
checkpoint_pb2
.
CheckPoint
()
with
open
(
checkpoint_path
,
"rb"
)
as
file
:
ckpt
.
ParseFromString
(
file
.
read
())
return
ckpt
.
last_epoch
,
ckpt
.
last_step
,
ckpt
.
last_model_dir
def
save_checkpoint
(
checkpoint_path
,
last_epoch
,
last_step
,
last_model_dir
):
ckpt
=
checkpoint_pb2
.
CheckPoint
()
ckpt
.
last_epoch
=
last_epoch
ckpt
.
last_step
=
last_step
ckpt
.
last_model_dir
=
last_model_dir
with
open
(
checkpoint_path
,
"wb"
)
as
file
:
file
.
write
(
ckpt
.
SerializeToString
())
paddle_hub/finetune/checkpoint_pb2.py
0 → 100644
浏览文件 @
34868288
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: checkpoint.proto
import
sys
_b
=
sys
.
version_info
[
0
]
<
3
and
(
lambda
x
:
x
)
or
(
lambda
x
:
x
.
encode
(
'latin1'
))
from
google.protobuf
import
descriptor
as
_descriptor
from
google.protobuf
import
message
as
_message
from
google.protobuf
import
reflection
as
_reflection
from
google.protobuf
import
symbol_database
as
_symbol_database
from
google.protobuf
import
descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db
=
_symbol_database
.
Default
()
DESCRIPTOR
=
_descriptor
.
FileDescriptor
(
name
=
'checkpoint.proto'
,
package
=
'paddle_hub_finetune_checkpoint'
,
syntax
=
'proto3'
,
serialized_pb
=
_b
(
'
\n\x10\x63
heckpoint.proto
\x12\x1e
paddle_hub_finetune_checkpoint
\"
K
\n\n
CheckPoint
\x12\x12\n\n
last_epoch
\x18\x01
\x01
(
\x03\x12\x11\n\t
last_step
\x18\x02
\x01
(
\x03\x12\x16\n\x0e
last_model_dir
\x18\x03
\x01
(
\t
B
\x02
H
\x03\x62\x06
proto3'
))
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
_CHECKPOINT
=
_descriptor
.
Descriptor
(
name
=
'CheckPoint'
,
full_name
=
'paddle_hub_finetune_checkpoint.CheckPoint'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'last_epoch'
,
full_name
=
'paddle_hub_finetune_checkpoint.CheckPoint.last_epoch'
,
index
=
0
,
number
=
1
,
type
=
3
,
cpp_type
=
2
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'last_step'
,
full_name
=
'paddle_hub_finetune_checkpoint.CheckPoint.last_step'
,
index
=
1
,
number
=
2
,
type
=
3
,
cpp_type
=
2
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'last_model_dir'
,
full_name
=
'paddle_hub_finetune_checkpoint.CheckPoint.last_model_dir'
,
index
=
2
,
number
=
3
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto3'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
52
,
serialized_end
=
127
,
)
DESCRIPTOR
.
message_types_by_name
[
'CheckPoint'
]
=
_CHECKPOINT
CheckPoint
=
_reflection
.
GeneratedProtocolMessageType
(
'CheckPoint'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_CHECKPOINT
,
__module__
=
'checkpoint_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub_finetune_checkpoint.CheckPoint)
))
_sym_db
.
RegisterMessage
(
CheckPoint
)
DESCRIPTOR
.
has_options
=
True
DESCRIPTOR
.
_options
=
_descriptor
.
_ParseOptions
(
descriptor_pb2
.
FileOptions
(),
_b
(
'H
\003
'
))
# @@protoc_insertion_point(module_scope)
paddle_hub/finetune/finetune.py
浏览文件 @
34868288
...
...
@@ -24,6 +24,9 @@ import paddle.fluid as fluid
from
paddle_hub.tools.logger
import
logger
from
paddle_hub.finetune.optimization
import
bert_finetune
from
paddle_hub.finetune.checkpoint
import
load_checkpoint
,
save_checkpoint
CKPT_FILE
=
"ckpt.meta"
def
_finetune_model
(
task
,
...
...
@@ -40,9 +43,9 @@ def _finetune_model(task,
batch_size
=
config
.
batch_size
learning_rate
=
config
.
learning_rate
use_cuda
=
config
.
use_cuda
batch_size
=
config
.
batch_size
with_memory_optimization
=
config
.
with_memory_optimization
checkpoint_dir
=
config
.
checkpoint_dir
checkpoint_path
=
os
.
path
.
join
(
checkpoint_dir
,
CKPT_FILE
)
with
fluid
.
program_guard
(
main_program
,
startup_program
):
if
use_cuda
:
...
...
@@ -60,7 +63,7 @@ def _finetune_model(task,
scheduled_lr
=
bert_finetune
(
task
,
main_program
,
data_processor
,
config
,
dev_count
)
elif
config
.
optimizer
==
"adam"
:
optim
zi
er
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
config
.
learning_rate
)
optim
iz
er
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
config
.
learning_rate
)
optimizer
.
minimize
(
loss
)
#TODO: add more finetune strategy
...
...
@@ -82,18 +85,23 @@ def _finetune_model(task,
program
=
main_program
,
batch_size
=
batch_size
)
logger
.
info
(
"Theoretical memory usage in training: %.3f - %.3f %s"
%
(
lower_mem
,
upper_mem
,
unit
)),
# initilize all parameters
exe
.
run
(
fluid
.
default_startup_program
())
step
=
0
# initilize
if
os
.
path
.
exists
(
checkpoint_path
):
last_epoch
,
step
,
last_model_dir
=
load_checkpoint
(
checkpoint_path
)
fluid
.
io
.
load_persistables
(
exe
,
last_model_dir
)
else
:
exe
.
run
(
fluid
.
default_startup_program
())
step
=
0
last_epoch
=
0
logger
.
info
(
"Finetune start"
)
train_time_begin
=
time
.
time
()
for
index
in
range
(
epoch
):
for
index
in
range
(
last_epoch
,
epoch
):
train_reader
=
data_processor
.
data_generator
(
batch_size
=
batch_size
,
phase
=
'train'
)
size
=
accuracy_sum
=
loss_sum
=
0
for
batch
in
train_reader
():
loss_v
,
accuracy_v
=
exe
.
run
(
feed
=
data_feeder
.
feed
(
[
batch
]
),
feed
=
data_feeder
.
feed
(
batch
),
fetch_list
=
[
loss
.
name
,
accuracy
.
name
])
step
+=
1
size
+=
len
(
batch
)
...
...
@@ -111,27 +119,36 @@ def _finetune_model(task,
if
step
%
config
.
save_ckpt_interval
==
0
:
model_save_dir
=
os
.
path
.
join
(
checkpoint_dir
,
"step_%d"
%
step
)
"
model_in_
step_%d"
%
step
)
fluid
.
io
.
save_persistables
(
exe
,
dirname
=
model_save_dir
)
save_checkpoint
(
checkpoint_path
,
last_epoch
=
index
,
last_step
=
step
,
last_model_dir
=
model_save_dir
)
if
eval_model
and
step
%
config
.
eval_interval
==
0
:
eval
(
task
,
data_processor
,
feed_list
,
config
)
eval
(
task
,
data_processor
,
feed_list
,
phase
=
"validate"
,
config
=
config
)
# update model and checkpoint
model_save_dir
=
os
.
path
.
join
(
checkpoint_dir
,
"model_latest"
)
fluid
.
io
.
save_persistables
(
exe
,
dirname
=
model_save_dir
)
save_checkpoint
(
checkpoint_path
,
last_epoch
=
epoch
+
1
,
last_step
=
step
,
last_model_dir
=
model_save_dir
)
# eval before end
if
eval_model
:
eval
(
task
,
data_processor
,
feed_list
,
config
)
eval
(
task
,
data_processor
,
feed_list
,
phase
=
"test"
,
config
=
config
)
logger
.
info
(
"Finetune finished"
)
def
save_model_and_checkpoint
(
task
,
save_dir
):
pass
def
finetune_and_eval
(
task
,
data_processor
,
feed_list
,
config
=
None
,
):
def
finetune_and_eval
(
task
,
data_processor
,
feed_list
,
config
=
None
):
_finetune_model
(
task
,
data_processor
,
feed_list
,
config
,
eval_model
=
True
)
...
...
@@ -139,7 +156,7 @@ def finetune(task, data_processor, feed_list, config=None):
_finetune_model
(
task
,
data_processor
,
feed_list
,
config
,
eval_model
=
False
)
def
eval
(
task
,
data_processor
,
feed_list
,
config
=
None
):
def
eval
(
task
,
data_processor
,
feed_list
,
phase
=
"test"
,
config
=
None
):
inference_program
=
task
.
inference_program
()
main_program
=
task
.
main_program
()
loss
=
task
.
variable
(
"loss"
)
...
...
@@ -152,12 +169,11 @@ def eval(task, data_processor, feed_list, config=None):
exe
=
fluid
.
Executor
(
place
=
place
)
size
=
accuracy_sum
=
loss_sum
=
0
test_reader
=
data_processor
.
data_generator
(
batch_size
=
batch_size
,
phase
=
'test'
)
batch_size
=
batch_size
,
phase
=
phase
)
eval_time_begin
=
time
.
time
()
for
index
,
batch
in
enumerate
(
test_reader
()):
loss_v
,
accuracy_v
,
=
exe
.
run
(
feed
=
data_feeder
.
feed
([
batch
]),
fetch_list
=
[
loss
,
accuracy
.
name
])
feed
=
data_feeder
.
feed
(
batch
),
fetch_list
=
[
loss
,
accuracy
.
name
])
size
+=
len
(
batch
)
accuracy_sum
+=
accuracy_v
*
len
(
batch
)
loss_sum
+=
loss_v
*
len
(
batch
)
...
...
scripts/gen_proto.sh
浏览文件 @
34868288
#/bin/bash
protoc
-I
=
../paddle_hub/module
--python_out
=
../paddle_hub/module ../paddle_hub/module/module_desc.proto
protoc
-I
=
../paddle_hub/module
--python_out
=
../paddle_hub/module ../paddle_hub/module/check_info.proto
protoc
-I
=
../paddle_hub/finetune
--python_out
=
../paddle_hub/finetune ../paddle_hub/finetune/checkpoint.proto
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录