diff --git a/dygraph/models/__init__.py b/dygraph/models/__init__.py
index 3f33a7b13073ba5017ce0702ef9cbc88aa70806d..bd79647abc0a625f21a28b230fd16ec38baab1d1 100644
--- a/dygraph/models/__init__.py
+++ b/dygraph/models/__init__.py
@@ -1,15 +1,15 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .unet import UNet
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .unet import UNet
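The `models` package now re-exports the dygraph `UNet` directly, replacing the `nets` package that is deleted further down. A minimal sketch of the resulting import surface (assuming, as `train.py` below does, that scripts run from the `dygraph/` directory):

```python
# Sketch only: the bare "import models" assumes dygraph/ is the working
# directory, matching how train.py below imports the package.
import models

model = models.UNet(num_classes=2)  # now a fluid.dygraph.Layer, see below
```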
diff --git a/dygraph/models/unet.py b/dygraph/models/unet.py
index cd286486fa9a6cf9f9077ec6f6eb32ee1dc0ef26..7e03f7c0e1f80fdc228fdc6a212df6e50bc3aa2e 100644
--- a/dygraph/models/unet.py
+++ b/dygraph/models/unet.py
@@ -13,101 +13,48 @@
 # limitations under the License.
 
 from __future__ import absolute_import
-import paddle.fluid as fluid
-import os
-from os import path as osp
-import numpy as np
+from __future__ import division
+from __future__ import print_function
+
 from collections import OrderedDict
-import copy
-import math
-import time
-import tqdm
-import cv2
-import yaml
-import shutil
-
-from paddle.fluid.dygraph.base import to_variable
-
-import utils
-import utils.logging as logging
-from utils import seconds_to_hms
-from utils import ConfusionMatrix
-from utils import get_environ_info
-import nets
-import transforms as T
-
-
-def dict2str(dict_input):
-    out = ''
-    for k, v in dict_input.items():
-        try:
-            v = round(float(v), 6)
-        except:
-            pass
-        out = out + '{}={}, '.format(k, v)
-    return out.strip(', ')
-
-
-class UNet(object):
-    # DeepLab mobilenet
-    def __init__(self,
-                 num_classes=2,
-                 upsample_mode='bilinear',
-                 ignore_index=255):
-
-        self.num_classes = num_classes
-        self.upsample_mode = upsample_mode
-        self.ignore_index = ignore_index
-        self.labels = None
-        self.env_info = get_environ_info()
-        if self.env_info['place'] == 'cpu':
-            self.places = fluid.CPUPlace()
-        else:
-            self.places = fluid.CUDAPlace(0)
+import paddle.fluid as fluid
+from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D
+import contextlib
+
+regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)
+name_scope = ""
+
 
-    def build_model(self):
-        self.model = nets.UNet(self.num_classes, self.upsample_mode)
+@contextlib.contextmanager
+def scope(name):
+    global name_scope
+    bk = name_scope
+    name_scope = name_scope + name + '/'
+    yield
+    name_scope = bk
 
-    def arrange_transform(self, transforms, mode='train'):
-        arrange_transform = T.ArrangeSegmenter
-        if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
-            transforms.transforms[-1] = arrange_transform(mode=mode)
+
+class UNet(fluid.dygraph.Layer):
+    def __init__(self, num_classes, upsample_mode='bilinear', ignore_index=255):
+        super().__init__()
+        self.encode = Encoder()
+        self.decode = Decode(upsample_mode=upsample_mode)
+        self.get_logit = GetLogit(64, num_classes)
+        self.ignore_index = ignore_index
+
+    def forward(self, x, label, mode='train'):
+        encode_data, short_cuts = self.encode(x)
+        decode_data = self.decode(encode_data, short_cuts)
+        logit = self.get_logit(decode_data)
+        if mode == 'train':
+            return self._get_loss(logit, label)
         else:
-            transforms.transforms.append(arrange_transform(mode=mode))
-
-    def load_model(self, model_dir):
-        ckpt_path = osp.join(model_dir, 'model')
-        para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
-        self.model.set_dict(para_state_dict)
-
-    def save_model(self, state_dict, save_dir):
-        if not osp.isdir(save_dir):
-            if osp.exists(save_dir):
-                os.remove(save_dir)
-            os.makedirs(save_dir)
-        fluid.save_dygraph(state_dict, osp.join(save_dir, 'model'))
-
-    def default_optimizer(self,
-                          learning_rate,
-                          num_epochs,
-                          num_steps_each_epoch,
-                          parameter_list=None,
-                          lr_decay_power=0.9,
-                          regularization_coeff=4e-5):
-        decay_step = num_epochs * num_steps_each_epoch
-        lr_decay = fluid.layers.polynomial_decay(
-            learning_rate,
-            decay_step,
-            end_learning_rate=0,
-            power=lr_decay_power)
-        optimizer = fluid.optimizer.Momentum(
-            lr_decay,
-            momentum=0.9,
-            parameter_list=parameter_list,
-            regularization=fluid.regularizer.L2Decay(
-                regularization_coeff=regularization_coeff))
-        return optimizer
+            logit = fluid.layers.softmax(logit, axis=1)
+            logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
+            pred = fluid.layers.argmax(logit, axis=3)
+            pred = fluid.layers.unsqueeze(pred, axes=[3])
+            return pred, logit
 
     def _get_loss(self, logit, label):
         mask = label != self.ignore_index
@@ -126,181 +73,178 @@ class UNet(object):
         mask.stop_gradient = True
         return avg_loss
 
-    def train(self,
-              num_epochs,
-              train_dataset,
-              train_batch_size=2,
-              eval_dataset=None,
-              save_interval_epochs=1,
-              log_interval_steps=2,
-              save_dir='output',
-              pretrained_weights=None,
-              resume_weights=None,
-              optimizer=None,
-              learning_rate=0.01,
-              lr_decay_power=0.9,
-              regularization_coeff=4e-5,
-              use_vdl=False):
-        self.labels = train_dataset.labels
-        self.train_transforms = train_dataset.transforms
-        self.train_init = locals()
-        self.begin_epoch = 0
-        if optimizer is None:
-            num_steps_each_epoch = train_dataset.num_samples // train_batch_size
-            optimizer = self.default_optimizer(
-                learning_rate=learning_rate,
-                num_epochs=num_epochs,
-                num_steps_each_epoch=num_steps_each_epoch,
-                parameter_list=self.model.parameters(),
-                lr_decay_power=lr_decay_power,
-                regularization_coeff=regularization_coeff)
-
-        # TODO: load pretrained weights, resume training
-
-        if self.begin_epoch >= num_epochs:
-            raise ValueError(
-                ("begin epoch[{}] is larger than num_epochs[{}]").format(
-                    self.begin_epoch, num_epochs))
-
-        if not osp.isdir(save_dir):
-            if osp.exists(save_dir):
-                os.remove(save_dir)
-            os.makedirs(save_dir)
-
-        # add arrange op to transforms
-        self.arrange_transform(
-            transforms=train_dataset.transforms, mode='train')
-
-        if eval_dataset is not None:
-            self.eval_transforms = eval_dataset.transforms
-            self.test_transforms = copy.deepcopy(eval_dataset.transforms)
-
-        data_generator = train_dataset.generator(
-            batch_size=train_batch_size, drop_last=True)
-        total_num_steps = math.floor(
-            train_dataset.num_samples / train_batch_size)
-
-        for i in range(self.begin_epoch, num_epochs):
-            for step, data in enumerate(data_generator()):
-                images = np.array([d[0] for d in data])
-                labels = np.array([d[1] for d in data]).astype('int64')
-                images = to_variable(images)
-                labels = to_variable(labels)
-                logit = self.model(images)
-                loss = self._get_loss(logit, labels)
-                loss.backward()
-                optimizer.minimize(loss)
-                print("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format(
-                    i + 1, num_epochs, step + 1, total_num_steps, loss.numpy()))
-
-            if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
-                current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
-                if not osp.isdir(current_save_dir):
-                    os.makedirs(current_save_dir)
-                self.save_model(self.model.state_dict(), current_save_dir)
-                if eval_dataset is not None:
-                    self.model.eval()
-                    self.evaluate(eval_dataset, batch_size=train_batch_size)
-                    self.model.train()
-
-    def evaluate(self, eval_dataset, batch_size=1, epoch_id=None):
-        """Evaluation.
-
-        Args:
-            eval_dataset (paddlex.datasets): Reader for the evaluation data.
-            batch_size (int): Batch size used during evaluation. Defaults to 1.
-            epoch_id (int): Training epoch the evaluated model comes from.
-            return_details (bool): Whether to return details. Defaults to False.
-
-        Returns:
-            dict: When return_details is False, a dict with the keys 'miou',
-                'category_iou', 'macc', 'category_acc' and 'kappa': mean IoU, per-category IoU, mean accuracy, per-category accuracy and the kappa coefficient.
-            tuple (metrics, eval_details): When return_details is True, additionally returns a dict (eval_details)
-                with the key 'confusion_matrix', the confusion matrix of the evaluation.
-        """
-        self.model.eval()
-        self.arrange_transform(transforms=eval_dataset.transforms, mode='train')
-        total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
-        conf_mat = ConfusionMatrix(self.num_classes, streaming=True)
-        data_generator = eval_dataset.generator(
-            batch_size=batch_size, drop_last=False)
-        logging.info(
-            "Start to evaluating(total_samples={}, total_steps={})...".format(
-                eval_dataset.num_samples, total_steps))
total_steps={})...".format( - eval_dataset.num_samples, total_steps)) - for step, data in tqdm.tqdm( - enumerate(data_generator()), total=total_steps): - images = np.array([d[0] for d in data]) - labels = np.array([d[1] for d in data]) - images = to_variable(images) - - logit = self.model(images) - pred = fluid.layers.argmax(logit, axis=1) - pred = fluid.layers.unsqueeze(pred, axes=[3]) - pred = pred.numpy() - - mask = labels != self.ignore_index - conf_mat.calculate(pred=pred, label=labels, ignore=mask) - _, iou = conf_mat.mean_iou() - - logging.debug("[EVAL] Epoch={}, Step={}/{}, iou={}".format( - epoch_id, step + 1, total_steps, iou)) - - category_iou, miou = conf_mat.mean_iou() - category_acc, macc = conf_mat.accuracy() - - metrics = OrderedDict( - zip(['miou', 'category_iou', 'macc', 'category_acc', 'kappa'], - [miou, category_iou, macc, category_acc, - conf_mat.kappa()])) - - logging.info('[EVAL] Finished, Epoch={}, {} .'.format( - epoch_id, dict2str(metrics))) - return metrics - - def predict(self, im_file, transforms=None): - """预测。 - Args: - img_file(str|np.ndarray): 预测图像。 - transforms(paddlex.cv.transforms): 数据预处理操作。 - - Returns: - dict: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图, - 像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes) - """ - if isinstance(im_file, str): - if not osp.exists(im_file): - raise ValueError( - 'The Image file does not exist: {}'.format(im_file)) - - if transforms is None and not hasattr(self, 'test_transforms'): - raise Exception("transforms need to be defined, now is None.") - if transforms is not None: - self.arrange_transform(transforms=transforms, mode='test') - im, im_info = transforms(im_file) + +class Encoder(fluid.dygraph.Layer): + def __init__(self): + super().__init__() + with scope('encode'): + with scope('block1'): + self.double_conv = DoubleConv(3, 64) + with scope('block1'): + self.down1 = Down(64, 128) + with scope('block2'): + self.down2 = Down(128, 256) + with scope('block3'): + self.down3 = Down(256, 512) + with scope('block4'): + self.down4 = Down(512, 512) + + def forward(self, x): + short_cuts = [] + x = self.double_conv(x) + short_cuts.append(x) + x = self.down1(x) + short_cuts.append(x) + x = self.down2(x) + short_cuts.append(x) + x = self.down3(x) + short_cuts.append(x) + x = self.down4(x) + return x, short_cuts + + +class Decode(fluid.dygraph.Layer): + def __init__(self, upsample_mode='bilinear'): + super().__init__() + with scope('decode'): + with scope('decode1'): + self.up1 = Up(512, 256, upsample_mode) + with scope('decode2'): + self.up2 = Up(256, 128, upsample_mode) + with scope('decode3'): + self.up3 = Up(128, 64, upsample_mode) + with scope('decode4'): + self.up4 = Up(64, 64, upsample_mode) + + def forward(self, x, short_cuts): + x = self.up1(x, short_cuts[3]) + x = self.up2(x, short_cuts[2]) + x = self.up3(x, short_cuts[1]) + x = self.up4(x, short_cuts[0]) + return x + + +class GetLogit(fluid.dygraph.Layer): + def __init__(self): + super().__init__() + + +class DoubleConv(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters): + super().__init__() + with scope('conv0'): + param_attr = fluid.ParamAttr( + name=name_scope + 'weights', + regularizer=regularizer, + initializer=fluid.initializer.TruncatedNormal( + loc=0.0, scale=0.33)) + self.conv0 = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=3, + stride=1, + padding=1, + param_attr=param_attr) + self.bn0 = BatchNorm( + num_channels=num_filters, + param_attr=fluid.ParamAttr( + name=name_scope + 'gamma', 
+                bias_attr=fluid.ParamAttr(
+                    name=name_scope + 'beta', regularizer=regularizer),
+                moving_mean_name=name_scope + 'moving_mean',
+                moving_variance_name=name_scope + 'moving_variance')
+        with scope('conv1'):
+            param_attr = fluid.ParamAttr(
+                name=name_scope + 'weights',
+                regularizer=regularizer,
+                initializer=fluid.initializer.TruncatedNormal(
+                    loc=0.0, scale=0.33))
+            self.conv1 = Conv2D(
+                num_channels=num_filters,
+                num_filters=num_filters,
+                filter_size=3,
+                stride=1,
+                padding=1,
+                param_attr=param_attr)
+            self.bn1 = BatchNorm(
+                num_channels=num_filters,
+                param_attr=fluid.ParamAttr(
+                    name=name_scope + 'gamma', regularizer=regularizer),
+                bias_attr=fluid.ParamAttr(
+                    name=name_scope + 'beta', regularizer=regularizer),
+                moving_mean_name=name_scope + 'moving_mean',
+                moving_variance_name=name_scope + 'moving_variance')
+
+    def forward(self, x):
+        x = self.conv0(x)
+        x = self.bn0(x)
+        x = fluid.layers.relu(x)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = fluid.layers.relu(x)
+        return x
+
+
+class Down(fluid.dygraph.Layer):
+    def __init__(self, num_channels, num_filters):
+        super().__init__()
+        with scope("down"):
+            self.max_pool = Pool2D(
+                pool_size=2, pool_type='max', pool_stride=2, pool_padding=0)
+            self.double_conv = DoubleConv(num_channels, num_filters)
+
+    def forward(self, x):
+        x = self.max_pool(x)
+        x = self.double_conv(x)
+        return x
+
+
+class Up(fluid.dygraph.Layer):
+    def __init__(self, num_channels, num_filters, upsample_mode):
+        super().__init__()
+        self.upsample_mode = upsample_mode
+        with scope('up'):
+            if upsample_mode == 'bilinear':
+                self.double_conv = DoubleConv(2 * num_channels, num_filters)
+            else:
+                param_attr = fluid.ParamAttr(
+                    name=name_scope + 'weights',
+                    regularizer=regularizer,
+                    initializer=fluid.initializer.XavierInitializer(),
+                )
+                self.deconv = fluid.dygraph.Conv2DTranspose(
+                    num_channels=num_channels,
+                    num_filters=num_filters // 2,
+                    filter_size=2,
+                    stride=2,
+                    padding=0,
+                    param_attr=param_attr)
+                self.double_conv = DoubleConv(num_channels + num_filters // 2,
+                                              num_filters)
+
+    def forward(self, x, short_cut):
+        if self.upsample_mode == 'bilinear':
+            short_cut_shape = fluid.layers.shape(short_cut)
+            x = fluid.layers.resize_bilinear(x, short_cut_shape[2:])
         else:
-            self.arrange_transform(transforms=self.test_transforms, mode='test')
-            im, im_info = self.test_transforms(im_file)
-        im = np.expand_dims(im, axis=0)
-        im = to_variable(im)
-        logit = self.model(im)
-        logit = fluid.layers.softmax(logit)
-        pred = fluid.layers.argmax(logit, axis=1)
-        logit = logit.numpy()
-        pred = pred.numpy()
-
-        logit = np.squeeze(logit)
-        logit = np.transpose(logit, (1, 2, 0))
-        pred = np.squeeze(pred).astype('uint8')
-        keys = list(im_info.keys())
-        print(pred.shape, logit.shape)
-        for k in keys[::-1]:
-            if k == 'shape_before_resize':
-                h, w = im_info[k][0], im_info[k][1]
-                pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
-                logit = cv2.resize(logit, (w, h), cv2.INTER_LINEAR)
-            elif k == 'shape_before_padding':
-                h, w = im_info[k][0], im_info[k][1]
-                pred = pred[0:h, 0:w]
-                logit = logit[0:h, 0:w, :]
-
-        return {'label_map': pred, 'score_map': logit}
+            x = self.deconv(x)
+        x = fluid.layers.concat([x, short_cut], axis=1)
+        x = self.double_conv(x)
+        return x
+
+
+class GetLogit(fluid.dygraph.Layer):
+    def __init__(self, num_channels, num_classes):
+        super().__init__()
+        with scope('logit'):
+            param_attr = fluid.ParamAttr(
+                name=name_scope + 'weights',
+                regularizer=regularizer,
+                initializer=fluid.initializer.TruncatedNormal(
+                    loc=0.0, scale=0.01))
+            self.conv = Conv2D(
+                num_channels=num_channels,
+                num_filters=num_classes,
+                filter_size=3,
+                stride=1,
+                padding=1,
+                param_attr=param_attr)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
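In the new module, `scope` exists only to build the `name_scope` prefix that each `ParamAttr` name starts with; the prefix is read while a sublayer's `__init__` runs, so only layers constructed inside a `with scope(...)` block pick it up. A standalone sketch of the mechanism (mirroring the definitions above, with illustrative prints):

```python
import contextlib

name_scope = ""  # module-level prefix, as in models/unet.py above

@contextlib.contextmanager
def scope(name):
    global name_scope
    bk = name_scope
    name_scope = name_scope + name + '/'  # push one path component
    yield
    name_scope = bk                       # pop it again on exit

with scope('encode'):
    with scope('block1'):
        # A ParamAttr created here would be named like this:
        print(name_scope + 'weights')  # encode/block1/weights
print(repr(name_scope))                # '' -- prefix fully restored
```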
diff --git a/dygraph/nets/__init__.py b/dygraph/nets/__init__.py
deleted file mode 100644
index bd79647abc0a625f21a28b230fd16ec38baab1d1..0000000000000000000000000000000000000000
--- a/dygraph/nets/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .unet import UNet
diff --git a/dygraph/nets/unet.py b/dygraph/nets/unet.py
deleted file mode 100644
index 06db99091ab95268e52816c30ac5c060e4b1eee8..0000000000000000000000000000000000000000
--- a/dygraph/nets/unet.py
+++ /dev/null
@@ -1,234 +0,0 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import OrderedDict
-
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D
-import contextlib
-
-regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)
-name_scope = ""
-
-
-@contextlib.contextmanager
-def scope(name):
-    global name_scope
-    bk = name_scope
-    name_scope = name_scope + name + '/'
-    yield
-    name_scope = bk
-
-
-class UNet(fluid.dygraph.Layer):
-    def __init__(
-            self,
-            num_classes,
-            upsample_mode='bilinear',
-    ):
-        super().__init__()
-        self.encode = Encoder()
-        self.decode = Decode(upsample_mode=upsample_mode)
-        self.get_logit = GetLogit(64, num_classes)
-
-    def forward(self, x):
-        encode_data, short_cuts = self.encode(x)
-        decode_data = self.decode(encode_data, short_cuts)
-        logit = self.get_logit(decode_data)
-        return logit
-
-
-class Encoder(fluid.dygraph.Layer):
-    def __init__(self):
-        super().__init__()
-        with scope('encode'):
-            with scope('block1'):
-                self.double_conv = DoubleConv(3, 64)
-            with scope('block1'):
-                self.down1 = Down(64, 128)
-            with scope('block2'):
-                self.down2 = Down(128, 256)
-            with scope('block3'):
-                self.down3 = Down(256, 512)
-            with scope('block4'):
-                self.down4 = Down(512, 512)
-
-    def forward(self, x):
-        short_cuts = []
-        x = self.double_conv(x)
-        short_cuts.append(x)
-        x = self.down1(x)
-        short_cuts.append(x)
-        x = self.down2(x)
-        short_cuts.append(x)
-        x = self.down3(x)
-        short_cuts.append(x)
-        x = self.down4(x)
-        return x, short_cuts
-
-
-class Decode(fluid.dygraph.Layer):
-    def __init__(self, upsample_mode='bilinear'):
-        super().__init__()
-        with scope('decode'):
-            with scope('decode1'):
-                self.up1 = Up(512, 256, upsample_mode)
-            with scope('decode2'):
-                self.up2 = Up(256, 128, upsample_mode)
-            with scope('decode3'):
-                self.up3 = Up(128, 64, upsample_mode)
-            with scope('decode4'):
-                self.up4 = Up(64, 64, upsample_mode)
-
-    def forward(self, x, short_cuts):
-        x = self.up1(x, short_cuts[3])
-        x = self.up2(x, short_cuts[2])
-        x = self.up3(x, short_cuts[1])
-        x = self.up4(x, short_cuts[0])
-        return x
-
-
-class GetLogit(fluid.dygraph.Layer):
-    def __init__(self):
-        super().__init__()
-
-
-class DoubleConv(fluid.dygraph.Layer):
-    def __init__(self, num_channels, num_filters):
-        super().__init__()
-        with scope('conv0'):
-            param_attr = fluid.ParamAttr(
-                name=name_scope + 'weights',
-                regularizer=regularizer,
-                initializer=fluid.initializer.TruncatedNormal(
-                    loc=0.0, scale=0.33))
-            self.conv0 = Conv2D(
-                num_channels=num_channels,
-                num_filters=num_filters,
-                filter_size=3,
-                stride=1,
-                padding=1,
-                param_attr=param_attr)
-            self.bn0 = BatchNorm(
-                num_channels=num_filters,
-                param_attr=fluid.ParamAttr(
-                    name=name_scope + 'gamma', regularizer=regularizer),
-                bias_attr=fluid.ParamAttr(
-                    name=name_scope + 'beta', regularizer=regularizer),
-                moving_mean_name=name_scope + 'moving_mean',
-                moving_variance_name=name_scope + 'moving_variance')
-        with scope('conv1'):
-            param_attr = fluid.ParamAttr(
-                name=name_scope + 'weights',
-                regularizer=regularizer,
-                initializer=fluid.initializer.TruncatedNormal(
-                    loc=0.0, scale=0.33))
-            self.conv1 = Conv2D(
-                num_channels=num_filters,
-                num_filters=num_filters,
-                filter_size=3,
-                stride=1,
-                padding=1,
-                param_attr=param_attr)
-            self.bn1 = BatchNorm(
-                num_channels=num_filters,
-                param_attr=fluid.ParamAttr(
-                    name=name_scope + 'gamma', regularizer=regularizer),
-                bias_attr=fluid.ParamAttr(
-                    name=name_scope + 'beta', regularizer=regularizer),
-                moving_mean_name=name_scope + 'moving_mean',
-                moving_variance_name=name_scope + 'moving_variance')
-
-    def forward(self, x):
-        x = self.conv0(x)
-        x = self.bn0(x)
-        x = fluid.layers.relu(x)
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = fluid.layers.relu(x)
-        return x
-
-
-class Down(fluid.dygraph.Layer):
-    def __init__(self, num_channels, num_filters):
-        super().__init__()
-        with scope("down"):
-            self.max_pool = Pool2D(
-                pool_size=2, pool_type='max', pool_stride=2, pool_padding=0)
-            self.double_conv = DoubleConv(num_channels, num_filters)
-
-    def forward(self, x):
-        x = self.max_pool(x)
-        x = self.double_conv(x)
-        return x
-
-
-class Up(fluid.dygraph.Layer):
-    def __init__(self, num_channels, num_filters, upsample_mode):
-        super().__init__()
-        self.upsample_mode = upsample_mode
-        with scope('up'):
-            if upsample_mode == 'bilinear':
-                self.double_conv = DoubleConv(2 * num_channels, num_filters)
-            if not upsample_mode == 'bilinear':
-                param_attr = fluid.ParamAttr(
-                    name=name_scope + 'weights',
-                    regularizer=regularizer,
-                    initializer=fluid.initializer.XavierInitializer(),
-                )
-                self.deconv = fluid.dygraph.Conv2DTranspose(
-                    num_channels=num_channels,
-                    num_filters=num_filters // 2,
-                    filter_size=2,
-                    stride=2,
-                    padding=0,
-                    param_attr=param_attr)
-                self.double_conv = DoubleConv(num_channels + num_filters // 2,
-                                              num_filters)
-
-    def forward(self, x, short_cut):
-        if self.upsample_mode == 'bilinear':
-            short_cut_shape = fluid.layers.shape(short_cut)
-            x = fluid.layers.resize_bilinear(x, short_cut_shape[2:])
-        else:
-            x = self.deconv(x)
-        x = fluid.layers.concat([x, short_cut], axis=1)
-        x = self.double_conv(x)
-        return x
-
-
-class GetLogit(fluid.dygraph.Layer):
-    def __init__(self, num_channels, num_classes):
-        super().__init__()
-        with scope('logit'):
-            param_attr = fluid.ParamAttr(
-                name=name_scope + 'weights',
-                regularizer=regularizer,
-                initializer=fluid.initializer.TruncatedNormal(
-                    loc=0.0, scale=0.01))
-            self.conv = Conv2D(
-                num_channels=num_channels,
-                num_filters=num_classes,
-                filter_size=3,
-                stride=1,
-                padding=1,
-                param_attr=param_attr)
-
-    def forward(self, x):
-        x = self.conv(x)
-        return x
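`train.py` below drives the refactored layer directly under a dygraph guard: the model is called once per batch and returns either the loss or the predictions depending on `mode`. A condensed sketch of both calls; the NCHW image shape and the int64 label shape are assumptions based on the transforms and loss code, not values fixed by this diff:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable

import models  # assumes the dygraph/ directory is the working directory

with fluid.dygraph.guard(fluid.CPUPlace()):
    model = models.UNet(num_classes=2)

    # Assumed shapes: NCHW float32 images, one int64 class id per pixel.
    images = to_variable(np.random.rand(2, 3, 512, 512).astype('float32'))
    labels = to_variable(np.zeros((2, 512, 512, 1), dtype='int64'))

    loss = model(images, labels, 'train')  # mean masked softmax loss
    loss.backward()

    model.eval()
    pred, prob = model(images, labels, 'eval')  # argmax map + NHWC softmax
```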
diff --git a/dygraph/train.py b/dygraph/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..712ef170d210fe6085f5ac488253ef64dd59fd58
--- /dev/null
+++ b/dygraph/train.py
@@ -0,0 +1,230 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import os.path as osp
+
+from paddle.fluid.dygraph.base import to_variable
+import numpy as np
+import paddle.fluid as fluid
+
+from datasets.dataset import Dataset
+import transforms as T
+import models
+import utils.logging as logging
+from utils import get_environ_info
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Model training')
+
+    # params of model
+    parser.add_argument(
+        '--model_name',
+        dest='model_name',
+        help="Model type for training, which is one of ('UNet')",
+        type=str,
+        default='UNet')
+
+    # params of dataset
+    parser.add_argument(
+        '--data_dir',
+        dest='data_dir',
+        help='The root directory of dataset',
+        type=str)
+    parser.add_argument(
+        '--train_list',
+        dest='train_list',
+        help='Train list file of dataset',
+        type=str)
+    parser.add_argument(
+        '--val_list',
+        dest='val_list',
+        help='Val list file of dataset',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--num_classes',
+        dest='num_classes',
+        help='Number of classes',
+        type=int,
+        default=2)
+
+    # params of training
+    parser.add_argument(
+        "--input_size",
+        dest="input_size",
+        help="The image size for net inputs.",
+        nargs=2,
+        default=[512, 512],
+        type=int)
+    parser.add_argument(
+        '--num_epochs',
+        dest='num_epochs',
+        help='Number of epochs for training',
+        type=int,
+        default=100)
+    parser.add_argument(
+        '--batch_size',
+        dest='batch_size',
+        help='Mini batch size',
+        type=int,
+        default=2)
+    parser.add_argument(
+        '--learning_rate',
+        dest='learning_rate',
+        help='Learning rate',
+        type=float,
+        default=0.01)
+    parser.add_argument(
+        '--pretrained_model',
+        dest='pretrained_model',
+        help='The path of pretrained weight',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--save_interval_epochs',
+        dest='save_interval_epochs',
+        help='The interval epochs for save a model snapshot',
+        type=int,
+        default=5)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the model snapshot',
+        type=str,
+        default='./output')
+
+    return parser.parse_args()
+
+
+def train(model,
+          train_dataset,
+          eval_dataset=None,
+          optimizer=None,
+          save_dir='output',
+          num_epochs=100,
+          batch_size=2,
+          pretrained_model=None,
+          save_interval_epochs=1):
+    if not osp.isdir(save_dir):
+        if osp.exists(save_dir):
+            os.remove(save_dir)
+        os.makedirs(save_dir)
+
+    data_generator = train_dataset.generator(
+        batch_size=batch_size, drop_last=True)
+    num_steps_each_epoch = train_dataset.num_samples // batch_size
+
+    for epoch in range(num_epochs):
+        for step, data in enumerate(data_generator()):
+            images = np.array([d[0] for d in data])
+            labels = np.array([d[1] for d in data]).astype('int64')
+            images = to_variable(images)
+            labels = to_variable(labels)
+            loss = model(images, labels, mode='train')
+            loss.backward()
+            optimizer.minimize(loss)
+            model.clear_gradients()  # gradients accumulate in dygraph otherwise
+            logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format(
+                epoch + 1, num_epochs, step + 1, num_steps_each_epoch,
+                loss.numpy()))
+
+        if (epoch + 1) % save_interval_epochs == 0 or epoch == num_epochs - 1:
+            current_save_dir = osp.join(save_dir, "epoch_{}".format(epoch + 1))
+            if not osp.isdir(current_save_dir):
+                os.makedirs(current_save_dir)
+            fluid.save_dygraph(model.state_dict(),
+                               osp.join(current_save_dir, 'model'))
+
+            # if eval_dataset is not None:
+            #     model.eval()
+            #     evaluate(eval_dataset, batch_size=train_batch_size)
+            #     model.train()
+
+
+def arrange_transform(transforms, mode='train'):
+    arrange_op = T.ArrangeSegmenter
+    if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
+        transforms.transforms[-1] = arrange_op(mode=mode)
+    else:
+        transforms.transforms.append(arrange_op(mode=mode))
+
+
+def main(args):
+    # Create the dataset readers
+    train_transforms = T.Compose(
+        [T.Resize(args.input_size),
+         T.RandomHorizontalFlip(),
+         T.Normalize()])
+    arrange_transform(train_transforms, mode='train')
+    train_dataset = Dataset(
+        data_dir=args.data_dir,
+        file_list=args.train_list,
+        transforms=train_transforms,
+        num_workers='auto',
+        buffer_size=100,
+        parallel_method='thread',
+        shuffle=True)
+    eval_dataset = None
+    if args.val_list is not None:
+        eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
+        arrange_transform(eval_transforms, mode='eval')
+        eval_dataset = Dataset(
+            data_dir=args.data_dir,
+            file_list=args.val_list,
+            transforms=eval_transforms,
+            num_workers='auto',
+            buffer_size=100,
+            parallel_method='thread',
+            shuffle=False)
+
+    if args.model_name == 'UNet':
+        model = models.UNet(num_classes=args.num_classes)
+    else:
+        raise ValueError("Unsupported model_name: {}".format(args.model_name))
+
+    # Create the optimizer
+    num_steps_each_epoch = train_dataset.num_samples // args.batch_size
+    decay_step = args.num_epochs * num_steps_each_epoch
+    lr_decay = fluid.layers.polynomial_decay(
+        args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
+    optimizer = fluid.optimizer.Momentum(
+        lr_decay,
+        momentum=0.9,
+        parameter_list=model.parameters(),
+        regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
+
+    train(
+        model,
+        train_dataset,
+        eval_dataset,
+        optimizer,
+        save_dir=args.save_dir,
+        num_epochs=args.num_epochs,
+        batch_size=args.batch_size,
+        pretrained_model=args.pretrained_model,
+        save_interval_epochs=args.save_interval_epochs)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    env_info = get_environ_info()
+    if env_info['place'] == 'cpu':
+        places = fluid.CPUPlace()
+    else:
+        places = fluid.CUDAPlace(0)
+    with fluid.dygraph.guard(places):
+        main(args)
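For reference, `main()` pairs the Momentum optimizer with a polynomial learning-rate schedule whose horizon is the full training run. A hedged worked example of the curve `fluid.layers.polynomial_decay` computes (formula per the fluid API documentation; the sample numbers are illustrative, not taken from the diff):

```python
def poly_lr(step, base_lr=0.01, decay_steps=50000, end_lr=0.0, power=0.9):
    # lr = (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr
    step = min(step, decay_steps)
    return (base_lr - end_lr) * (1 - step / decay_steps)**power + end_lr

# e.g. 1000 samples at batch_size=2 -> 500 steps per epoch, and with
# num_epochs=100 the horizon is decay_step = 100 * 500 = 50000:
print(poly_lr(0))      # 0.01    (initial learning rate)
print(poly_lr(25000))  # ~0.0054 (0.5 ** 0.9 of the base rate)
print(poly_lr(50000))  # 0.0     (end_learning_rate)
```

Because `decay_step = args.num_epochs * num_steps_each_epoch`, the rate anneals to exactly zero on the final scheduled step.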