Commit 34868288 authored by: W wuzewu

add finetune checkpoint

Parent 16145775
......@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import numpy as np
from PIL import Image
......@@ -50,7 +51,7 @@ class ImageClassificationReader:
if self.image_width <= 0 or self.image_height <= 0:
raise ValueError("Image width and height should be positive.")
def data_generator(self, phase, shuffle=False):
def data_generator(self, batch_size, phase="train", shuffle=False):
if phase == "train":
data = self.dataset.train_data(shuffle)
elif phase == "test":
......@@ -81,4 +82,4 @@ class ImageClassificationReader:
image = image[color_mode_dict[self.color_mode], :, :]
yield ((image, label))
return _data_reader
return paddle.batch(_data_reader, batch_size=batch_size)
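A minimal usage sketch for the batched reader (illustrative only: the reader's constructor arguments and the dataset object are assumptions, not part of this commit):

reader = ImageClassificationReader(
    image_width=224, image_height=224, dataset=my_dataset)
train_reader = reader.data_generator(batch_size=32, phase="train", shuffle=True)
for batch in train_reader():
    # paddle.batch groups samples, so each batch is a list of up to 32
    # (image, label) tuples rather than a single sample
    images, labels = zip(*batch)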
// Copyright 2018 The Paddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
syntax = "proto3";
option optimize_for = LITE_RUNTIME;
package paddle_hub_finetune_checkpoint;
message CheckPoint {
int64 last_epoch = 1;
int64 last_step = 2;
string last_model_dir = 3;
}
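A hedged sketch of driving the generated message class directly (field values and the path are illustrative; checkpoint_pb2 is produced from this schema by the protoc command at the bottom of this commit):

from paddle_hub.finetune import checkpoint_pb2

# build a CheckPoint message and serialize it to the proto3 binary wire format
ckpt = checkpoint_pb2.CheckPoint()
ckpt.last_epoch = 2
ckpt.last_step = 800
ckpt.last_model_dir = "outputs/step_800"  # illustrative path
payload = ckpt.SerializeToString()

# parsing the bytes back recovers the same fields
restored = checkpoint_pb2.CheckPoint()
restored.ParseFromString(payload)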
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.finetune import checkpoint_pb2
def load_checkpoint(checkpoint_path):
ckpt = checkpoint_pb2.CheckPoint()
with open(checkpoint_path, "rb") as file:
ckpt.ParseFromString(file.read())
return ckpt.last_epoch, ckpt.last_step, ckpt.last_model_dir
def save_checkpoint(checkpoint_path, last_epoch, last_step, last_model_dir):
ckpt = checkpoint_pb2.CheckPoint()
ckpt.last_epoch = last_epoch
ckpt.last_step = last_step
ckpt.last_model_dir = last_model_dir
with open(checkpoint_path, "wb") as file:
file.write(ckpt.SerializeToString())
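A quick round-trip check for the two helpers above (the temporary path and the values are assumptions):

import os
import tempfile
from paddle_hub.finetune.checkpoint import load_checkpoint, save_checkpoint

# write checkpoint metadata, then read it back and verify the fields survive
path = os.path.join(tempfile.mkdtemp(), "ckpt.meta")
save_checkpoint(path, last_epoch=3, last_step=1200,
                last_model_dir="outputs/model_in_step_1200")
epoch, step, model_dir = load_checkpoint(path)
assert (epoch, step, model_dir) == (3, 1200, "outputs/model_in_step_1200")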
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: checkpoint.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='checkpoint.proto',
package='paddle_hub_finetune_checkpoint',
syntax='proto3',
serialized_pb=_b(
'\n\x10\x63heckpoint.proto\x12\x1epaddle_hub_finetune_checkpoint\"K\n\nCheckPoint\x12\x12\n\nlast_epoch\x18\x01 \x01(\x03\x12\x11\n\tlast_step\x18\x02 \x01(\x03\x12\x16\n\x0elast_model_dir\x18\x03 \x01(\tB\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_CHECKPOINT = _descriptor.Descriptor(
name='CheckPoint',
full_name='paddle_hub_finetune_checkpoint.CheckPoint',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='last_epoch',
full_name='paddle_hub_finetune_checkpoint.CheckPoint.last_epoch',
index=0,
number=1,
type=3,
cpp_type=2,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='last_step',
full_name='paddle_hub_finetune_checkpoint.CheckPoint.last_step',
index=1,
number=2,
type=3,
cpp_type=2,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='last_model_dir',
full_name='paddle_hub_finetune_checkpoint.CheckPoint.last_model_dir',
index=2,
number=3,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=52,
serialized_end=127,
)
DESCRIPTOR.message_types_by_name['CheckPoint'] = _CHECKPOINT
CheckPoint = _reflection.GeneratedProtocolMessageType(
'CheckPoint',
(_message.Message, ),
dict(
DESCRIPTOR=_CHECKPOINT,
__module__='checkpoint_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub_finetune_checkpoint.CheckPoint)
))
_sym_db.RegisterMessage(CheckPoint)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(),
_b('H\003'))
# @@protoc_insertion_point(module_scope)
......@@ -24,6 +24,9 @@ import paddle.fluid as fluid
from paddle_hub.tools.logger import logger
from paddle_hub.finetune.optimization import bert_finetune
from paddle_hub.finetune.checkpoint import load_checkpoint, save_checkpoint
CKPT_FILE = "ckpt.meta"
def _finetune_model(task,
......@@ -40,9 +43,9 @@ def _finetune_model(task,
batch_size = config.batch_size
learning_rate = config.learning_rate
use_cuda = config.use_cuda
batch_size = config.batch_size
with_memory_optimization = config.with_memory_optimization
checkpoint_dir = config.checkpoint_dir
checkpoint_path = os.path.join(checkpoint_dir, CKPT_FILE)
with fluid.program_guard(main_program, startup_program):
if use_cuda:
......@@ -60,7 +63,7 @@ def _finetune_model(task,
scheduled_lr = bert_finetune(task, main_program, data_processor,
config, dev_count)
elif config.optimizer == "adam":
optimzier = fluid.optimizer.Adam(learning_rate=config.learning_rate)
optimizer = fluid.optimizer.Adam(learning_rate=config.learning_rate)
optimizer.minimize(loss)
#TODO: add more finetune strategy
......@@ -82,18 +85,23 @@ def _finetune_model(task,
program=main_program, batch_size=batch_size)
logger.info("Theoretical memory usage in training: %.3f - %.3f %s" %
(lower_mem, upper_mem, unit)),
# initialize all parameters
exe.run(fluid.default_startup_program())
step = 0
# initialize
if os.path.exists(checkpoint_path):
last_epoch, step, last_model_dir = load_checkpoint(checkpoint_path)
fluid.io.load_persistables(exe, last_model_dir)
else:
exe.run(fluid.default_startup_program())
step = 0
last_epoch = 0
logger.info("Finetune start")
train_time_begin = time.time()
for index in range(epoch):
for index in range(last_epoch, epoch):
train_reader = data_processor.data_generator(
batch_size=batch_size, phase='train')
size = accuracy_sum = loss_sum = 0
for batch in train_reader():
loss_v, accuracy_v = exe.run(
feed=data_feeder.feed([batch]),
feed=data_feeder.feed(batch),
fetch_list=[loss.name, accuracy.name])
step += 1
size += len(batch)
......@@ -111,27 +119,36 @@ def _finetune_model(task,
if step % config.save_ckpt_interval == 0:
model_save_dir = os.path.join(checkpoint_dir,
"step_%d" % step)
"model_in_step_%d" % step)
fluid.io.save_persistables(exe, dirname=model_save_dir)
save_checkpoint(
checkpoint_path,
last_epoch=index,
last_step=step,
last_model_dir=model_save_dir)
if eval_model and step % config.eval_interval == 0:
eval(task, data_processor, feed_list, config)
eval(
task,
data_processor,
feed_list,
phase="validate",
config=config)
# update model and checkpoint
model_save_dir = os.path.join(checkpoint_dir, "model_latest")
fluid.io.save_persistables(exe, dirname=model_save_dir)
save_checkpoint(
checkpoint_path,
last_epoch=epoch + 1,
last_step=step,
last_model_dir=model_save_dir)
# eval before end
if eval_model:
eval(task, data_processor, feed_list, config)
eval(task, data_processor, feed_list, phase="test", config=config)
logger.info("Finetune finished")
def save_model_and_checkpoint(task, save_dir):
pass
def finetune_and_eval(
task,
data_processor,
feed_list,
config=None,
):
def finetune_and_eval(task, data_processor, feed_list, config=None):
_finetune_model(task, data_processor, feed_list, config, eval_model=True)
......@@ -139,7 +156,7 @@ def finetune(task, data_processor, feed_list, config=None):
_finetune_model(task, data_processor, feed_list, config, eval_model=False)
def eval(task, data_processor, feed_list, config=None):
def eval(task, data_processor, feed_list, phase="test", config=None):
inference_program = task.inference_program()
main_program = task.main_program()
loss = task.variable("loss")
......@@ -152,12 +169,11 @@ def eval(task, data_processor, feed_list, config=None):
exe = fluid.Executor(place=place)
size = accuracy_sum = loss_sum = 0
test_reader = data_processor.data_generator(
batch_size=batch_size, phase='test')
batch_size=batch_size, phase=phase)
eval_time_begin = time.time()
for index, batch in enumerate(test_reader()):
loss_v, accuracy_v, = exe.run(
feed=data_feeder.feed([batch]),
fetch_list=[loss, accuracy.name])
feed=data_feeder.feed(batch), fetch_list=[loss, accuracy.name])
size += len(batch)
accuracy_sum += accuracy_v * len(batch)
loss_sum += loss_v * len(batch)
......
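The resume flow this diff adds to _finetune_model boils down to the following sketch (names mirror the diff; restore_model is a hypothetical stand-in for fluid.io.load_persistables, and this block is a summary, not the commit's code):

import os
from paddle_hub.finetune.checkpoint import load_checkpoint

def resume_state(checkpoint_path):
    if os.path.exists(checkpoint_path):
        # a previous run left ckpt.meta behind: pick up where it stopped
        last_epoch, step, last_model_dir = load_checkpoint(checkpoint_path)
        restore_model(last_model_dir)  # hypothetical: reload persistable vars
    else:
        # fresh run: the caller initializes parameters instead
        last_epoch, step = 0, 0
    return last_epoch, step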
#!/bin/bash
protoc -I=../paddle_hub/module --python_out=../paddle_hub/module ../paddle_hub/module/module_desc.proto
protoc -I=../paddle_hub/module --python_out=../paddle_hub/module ../paddle_hub/module/check_info.proto
protoc -I=../paddle_hub/finetune --python_out=../paddle_hub/finetune ../paddle_hub/finetune/checkpoint.proto