diff --git a/paddle_hub/dataset/cv_reader.py b/paddle_hub/dataset/cv_reader.py
index a801ec43e8458d24c3c66ef08f57cda6a210e120..9d4facc2508e0cc638b4951a0929b20a68cb5065 100644
--- a/paddle_hub/dataset/cv_reader.py
+++ b/paddle_hub/dataset/cv_reader.py
@@ -16,6 +16,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import paddle
 import numpy as np
 from PIL import Image
 
@@ -50,7 +51,7 @@ class ImageClassificationReader:
         if self.image_width <= 0 or self.image_height <= 0:
             raise ValueError("Image width and height should not be negative.")
 
-    def data_generator(self, phase, shuffle=False):
+    def data_generator(self, batch_size, phase="train", shuffle=False):
         if phase == "train":
             data = self.dataset.train_data(shuffle)
         elif phase == "test":
@@ -81,4 +82,4 @@ class ImageClassificationReader:
                     image = image[color_mode_dict[self.color_mode], :, :]
                     yield ((image, label))
 
-        return _data_reader
+        return paddle.batch(_data_reader, batch_size=batch_size)
diff --git a/paddle_hub/finetune/checkpoint.proto b/paddle_hub/finetune/checkpoint.proto
new file mode 100644
index 0000000000000000000000000000000000000000..3cee62e31d34772bbd8a5fd4f4a42895cac23c61
--- /dev/null
+++ b/paddle_hub/finetune/checkpoint.proto
@@ -0,0 +1,25 @@
+// Copyright 2018 The Paddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+syntax = "proto3";
+option optimize_for = LITE_RUNTIME;
+
+package paddle_hub_finetune_checkpoint;
+
+message CheckPoint {
+  int64 last_epoch = 1;
+  int64 last_step = 2;
+  string last_model_dir = 3;
+}
diff --git a/paddle_hub/finetune/checkpoint.py b/paddle_hub/finetune/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca6e166cf54869cc3ccd4a93ad21ce2ab80a37b7
--- /dev/null
+++ b/paddle_hub/finetune/checkpoint.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle_hub.finetune import checkpoint_pb2 + + +def load_checkpoint(checkpoint_path): + ckpt = checkpoint_pb2.CheckPoint() + with open(checkpoint_path, "rb") as file: + ckpt.ParseFromString(file.read()) + return ckpt.last_epoch, ckpt.last_step, ckpt.last_model_dir + + +def save_checkpoint(checkpoint_path, last_epoch, last_step, last_model_dir): + ckpt = checkpoint_pb2.CheckPoint() + ckpt.last_epoch = last_epoch + ckpt.last_step = last_step + ckpt.last_model_dir = last_model_dir + with open(checkpoint_path, "wb") as file: + file.write(ckpt.SerializeToString()) diff --git a/paddle_hub/finetune/checkpoint_pb2.py b/paddle_hub/finetune/checkpoint_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..2661808cd6e326d96767e7d7a2723f48b9128a91 --- /dev/null +++ b/paddle_hub/finetune/checkpoint_pb2.py @@ -0,0 +1,107 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: checkpoint.proto + +import sys +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + +DESCRIPTOR = _descriptor.FileDescriptor( + name='checkpoint.proto', + package='paddle_hub_finetune_checkpoint', + syntax='proto3', + serialized_pb=_b( + '\n\x10\x63heckpoint.proto\x12\x1epaddle_hub_finetune_checkpoint\"K\n\nCheckPoint\x12\x12\n\nlast_epoch\x18\x01 \x01(\x03\x12\x11\n\tlast_step\x18\x02 \x01(\x03\x12\x16\n\x0elast_model_dir\x18\x03 \x01(\tB\x02H\x03\x62\x06proto3' + )) +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +_CHECKPOINT = _descriptor.Descriptor( + name='CheckPoint', + full_name='paddle_hub_finetune_checkpoint.CheckPoint', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='last_epoch', + full_name='paddle_hub_finetune_checkpoint.CheckPoint.last_epoch', + index=0, + number=1, + type=3, + cpp_type=2, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='last_step', + full_name='paddle_hub_finetune_checkpoint.CheckPoint.last_step', + index=1, + number=2, + type=3, + cpp_type=2, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='last_model_dir', + full_name='paddle_hub_finetune_checkpoint.CheckPoint.last_model_dir', + index=2, + number=3, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + ], + extensions=[], + nested_types=[], + enum_types=[], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=52, + serialized_end=127, +) + +DESCRIPTOR.message_types_by_name['CheckPoint'] = _CHECKPOINT + +CheckPoint = _reflection.GeneratedProtocolMessageType( + 'CheckPoint', + (_message.Message, 
+    ),
+    dict(
+        DESCRIPTOR=_CHECKPOINT,
+        __module__='checkpoint_pb2'
+        # @@protoc_insertion_point(class_scope:paddle_hub_finetune_checkpoint.CheckPoint)
+    ))
+_sym_db.RegisterMessage(CheckPoint)
+
+DESCRIPTOR.has_options = True
+DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(),
+                                                _b('H\003'))
+# @@protoc_insertion_point(module_scope)
diff --git a/paddle_hub/finetune/finetune.py b/paddle_hub/finetune/finetune.py
index 44f68f4410c2d3dce868c6ade4ed8d2083b643d8..24603b655723836473ee2812e1acc25cd746387e 100644
--- a/paddle_hub/finetune/finetune.py
+++ b/paddle_hub/finetune/finetune.py
@@ -24,6 +24,9 @@ import paddle.fluid as fluid
 from paddle_hub.tools.logger import logger
 from paddle_hub.finetune.optimization import bert_finetune
+from paddle_hub.finetune.checkpoint import load_checkpoint, save_checkpoint
+
+CKPT_FILE = "ckpt.meta"
 
 
 def _finetune_model(task,
@@ -40,9 +43,9 @@ def _finetune_model(task,
     batch_size = config.batch_size
     learning_rate = config.learning_rate
     use_cuda = config.use_cuda
-    batch_size = config.batch_size
     with_memory_optimization = config.with_memory_optimization
     checkpoint_dir = config.checkpoint_dir
+    checkpoint_path = os.path.join(checkpoint_dir, CKPT_FILE)
 
     with fluid.program_guard(main_program, startup_program):
         if use_cuda:
@@ -60,7 +63,7 @@ def _finetune_model(task,
             scheduled_lr = bert_finetune(task, main_program, data_processor,
                                          config, dev_count)
         elif config.optimizer == "adam":
-            optimzier = fluid.optimizer.Adam(learning_rate=config.learning_rate)
+            optimizer = fluid.optimizer.Adam(learning_rate=config.learning_rate)
             optimizer.minimize(loss)
         #TODO: add more finetune strategy
 
@@ -82,18 +85,23 @@ def _finetune_model(task,
             program=main_program, batch_size=batch_size)
         logger.info("Theoretical memory usage in training: %.3f - %.3f %s" %
                     (lower_mem, upper_mem, unit)),
-    # initilize all parameters
-    exe.run(fluid.default_startup_program())
-    step = 0
+    # initialize, or resume from the last saved checkpoint
+    if os.path.exists(checkpoint_path):
+        last_epoch, step, last_model_dir = load_checkpoint(checkpoint_path)
+        fluid.io.load_persistables(exe, last_model_dir)
+    else:
+        exe.run(fluid.default_startup_program())
+        step = 0
+        last_epoch = 0
     logger.info("Finetune start")
     train_time_begin = time.time()
-    for index in range(epoch):
+    for index in range(last_epoch, epoch):
         train_reader = data_processor.data_generator(
             batch_size=batch_size, phase='train')
         size = accuracy_sum = loss_sum = 0
         for batch in train_reader():
             loss_v, accuracy_v = exe.run(
-                feed=data_feeder.feed([batch]),
+                feed=data_feeder.feed(batch),
                 fetch_list=[loss.name, accuracy.name])
             step += 1
             size += len(batch)
@@ -111,27 +119,36 @@ def _finetune_model(task,
 
             if step % config.save_ckpt_interval == 0:
                 model_save_dir = os.path.join(checkpoint_dir,
-                                              "step_%d" % step)
+                                              "model_in_step_%d" % step)
                 fluid.io.save_persistables(exe, dirname=model_save_dir)
+                save_checkpoint(
+                    checkpoint_path,
+                    last_epoch=index,
+                    last_step=step,
+                    last_model_dir=model_save_dir)
 
             if eval_model and step % config.eval_interval == 0:
-                eval(task, data_processor, feed_list, config)
+                eval(
+                    task,
+                    data_processor,
+                    feed_list,
+                    phase="validate",
+                    config=config)
+    # update model and checkpoint
+    model_save_dir = os.path.join(checkpoint_dir, "model_latest")
+    fluid.io.save_persistables(exe, dirname=model_save_dir)
+    save_checkpoint(
+        checkpoint_path,
+        last_epoch=epoch + 1,
+        last_step=step,
+        last_model_dir=model_save_dir)
 
     # eval before end
     if eval_model:
-        eval(task, data_processor, feed_list, config)
+        eval(task, data_processor, feed_list, phase="test", config=config)
phase="test", config=config) logger.info("Finetune finished") -def save_model_and_checkpoint(task, save_dir): - pass - - -def finetune_and_eval( - task, - data_processor, - feed_list, - config=None, -): +def finetune_and_eval(task, data_processor, feed_list, config=None): _finetune_model(task, data_processor, feed_list, config, eval_model=True) @@ -139,7 +156,7 @@ def finetune(task, data_processor, feed_list, config=None): _finetune_model(task, data_processor, feed_list, config, eval_model=False) -def eval(task, data_processor, feed_list, config=None): +def eval(task, data_processor, feed_list, phase="test", config=None): inference_program = task.inference_program() main_program = task.main_program() loss = task.variable("loss") @@ -152,12 +169,11 @@ def eval(task, data_processor, feed_list, config=None): exe = fluid.Executor(place=place) size = accuracy_sum = loss_sum = 0 test_reader = data_processor.data_generator( - batch_size=batch_size, phase='test') + batch_size=batch_size, phase=phase) eval_time_begin = time.time() for index, batch in enumerate(test_reader()): loss_v, accuracy_v, = exe.run( - feed=data_feeder.feed([batch]), - fetch_list=[loss, accuracy.name]) + feed=data_feeder.feed(batch), fetch_list=[loss, accuracy.name]) size += len(batch) accuracy_sum += accuracy_v * len(batch) loss_sum += loss_v * len(batch) diff --git a/scripts/gen_proto.sh b/scripts/gen_proto.sh index 6c4481b455faa19481a93551abca64f50836f842..f5b614d70099f95aa5d69915671f3b0d1fa32075 100755 --- a/scripts/gen_proto.sh +++ b/scripts/gen_proto.sh @@ -1,3 +1,4 @@ #/bin/bash protoc -I=../paddle_hub/module --python_out=../paddle_hub/module ../paddle_hub/module/module_desc.proto protoc -I=../paddle_hub/module --python_out=../paddle_hub/module ../paddle_hub/module/check_info.proto +protoc -I=../paddle_hub/finetune --python_out=../paddle_hub/finetune ../paddle_hub/finetune/checkpoint.proto