diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/darknet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/darknet.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf478f87107ae96e4638e0b4b168fbf0f042eb1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/darknet.py @@ -0,0 +1,184 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from paddle.fluid.dygraph.nn import Conv2D, BatchNorm +from paddle.fluid.dygraph.base import to_variable + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size=3, + stride=1, + groups=1, + padding=0, + act="leaky", + is_test=True): + super(ConvBNLayer, self).__init__() + + self.conv = Conv2D( + num_channels=ch_in, + num_filters=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + param_attr=ParamAttr( + initializer=fluid.initializer.Normal(0., 0.02)), + bias_attr=False, + act=None) + self.batch_norm = BatchNorm( + num_channels=ch_out, + is_test=is_test, + param_attr=ParamAttr( + initializer=fluid.initializer.Normal(0., 0.02), + regularizer=L2Decay(0.)), + bias_attr=ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=L2Decay(0.))) + + self.act = act + + def forward(self, inputs): + out = self.conv(inputs) + out = self.batch_norm(out) + if self.act == 'leaky': + out = fluid.layers.leaky_relu(x=out, alpha=0.1) + return out + + +class DownSample(fluid.dygraph.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size=3, + stride=2, + padding=1, + is_test=True): + + super(DownSample, self).__init__() + + self.conv_bn_layer = ConvBNLayer( + ch_in=ch_in, + ch_out=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + is_test=is_test) + self.ch_out = ch_out + + def forward(self, inputs): + out = self.conv_bn_layer(inputs) + return out + + +class BasicBlock(fluid.dygraph.Layer): + def __init__(self, ch_in, ch_out, is_test=True): + super(BasicBlock, self).__init__() + + self.conv1 = ConvBNLayer( + ch_in=ch_in, + ch_out=ch_out, + filter_size=1, + stride=1, + padding=0, + is_test=is_test) + self.conv2 = ConvBNLayer( + ch_in=ch_out, + ch_out=ch_out * 2, + filter_size=3, + stride=1, + padding=1, + is_test=is_test) + + def forward(self, inputs): + conv1 = self.conv1(inputs) + conv2 = self.conv2(conv1) + out = fluid.layers.elementwise_add(x=inputs, y=conv2, act=None) + return out + + +class LayerWarp(fluid.dygraph.Layer): + def __init__(self, ch_in, ch_out, count, is_test=True): + super(LayerWarp, self).__init__() + + self.basicblock0 = BasicBlock(ch_in, ch_out, is_test=is_test) + self.res_out_list = [] + for i in range(1, count): + res_out = self.add_sublayer( + "basic_block_%d" % (i), + BasicBlock( + ch_out * 2, ch_out, is_test=is_test)) + self.res_out_list.append(res_out) + self.ch_out = ch_out + + def forward(self, inputs): + y = self.basicblock0(inputs) + for basic_block_i in self.res_out_list: + y = basic_block_i(y) + return y + + +DarkNet_cfg = {53: ([1, 2, 8, 8, 4])} + + +class DarkNet53_conv_body(fluid.dygraph.Layer): + def __init__(self, ch_in=3, is_test=True): + super(DarkNet53_conv_body, self).__init__() + self.stages = DarkNet_cfg[53] + self.stages = self.stages[0:5] + + self.conv0 = ConvBNLayer( + ch_in=ch_in, + ch_out=32, + filter_size=3, + stride=1, + padding=1, + is_test=is_test) + + self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2, is_test=is_test) + self.darknet53_conv_block_list = [] + self.downsample_list = [] + ch_in = [64, 128, 256, 512, 1024] + for i, stage in enumerate(self.stages): + conv_block = self.add_sublayer( + "stage_%d" % (i), + LayerWarp( + int(ch_in[i]), 32 * (2**i), stage, is_test=is_test)) + self.darknet53_conv_block_list.append(conv_block) + for i in range(len(self.stages) - 1): + downsample = self.add_sublayer( + "stage_%d_downsample" % i, + DownSample( + ch_in=32 * (2**(i + 1)), + ch_out=32 * (2**(i + 2)), + is_test=is_test)) + self.downsample_list.append(downsample) + + def forward(self, inputs): + + out = self.conv0(inputs) + out = self.downsample0(out) + blocks = [] + for i, conv_block_i in enumerate(self.darknet53_conv_block_list): + out = conv_block_i(out) + blocks.append(out) + if i < len(self.stages) - 1: + out = self.downsample_list[i](out) + return blocks[-1:-4:-1] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_yolov3.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_yolov3.py new file mode 100644 index 0000000000000000000000000000000000000000..53ee32cb1d0ae5a66a631bab2147af3dc398ca57 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_yolov3.py @@ -0,0 +1,174 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import random +import time +import unittest + +import paddle.fluid as fluid +from paddle.fluid.dygraph import ProgramTranslator +from paddle.fluid.dygraph import to_variable + +from yolov3 import cfg, YOLOv3 + +random.seed(0) +np.random.seed(0) + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self): + self.loss_sum = 0.0 + self.iter_cnt = 0 + + def add_value(self, value): + self.loss_sum += np.mean(value) + self.iter_cnt += 1 + + def get_mean_value(self): + return self.loss_sum / self.iter_cnt + + +class FakeDataReader(object): + def __init__(self): + self.generator_out = [] + self.total_iter = cfg.max_iter + for i in range(self.total_iter): + batch_out = [] + for j in range(cfg.batch_size): + img = np.random.normal(0.485, 0.229, + [3, cfg.input_size, cfg.input_size]) + gt_boxes_node1 = np.random.randint( + low=cfg.input_size / 4, + high=cfg.input_size / 2, + size=[1, 2]) + gt_boxes_node2 = gt_boxes_node1 + cfg.input_size / 4 + gt_boxes = np.concatenate( + (gt_boxes_node1, gt_boxes_node2), axis=1) + gt_labels = np.random.randint( + low=0, high=cfg.class_num, size=[1]) + gt_scores = np.zeros([1]) + batch_out.append([img, gt_boxes, gt_labels, gt_scores]) + self.generator_out.append(batch_out) + + def reader(self): + def generator(): + for i in range(self.total_iter): + yield self.generator_out[i] + + return generator + + +fake_data_reader = FakeDataReader() + + +def train(to_static): + program_translator = ProgramTranslator() + program_translator.enable(to_static) + + random.seed(0) + np.random.seed(0) + + place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = 1000 + fluid.default_main_program().random_seed = 1000 + model = YOLOv3(3, is_train=True) + + boundaries = cfg.lr_steps + gamma = cfg.lr_gamma + step_num = len(cfg.lr_steps) + learning_rate = cfg.learning_rate + values = [learning_rate * (gamma**i) for i in range(step_num + 1)] + + lr = fluid.dygraph.PiecewiseDecay( + boundaries=boundaries, values=values, begin=0) + + lr = fluid.layers.linear_lr_warmup( + learning_rate=lr, + warmup_steps=cfg.warm_up_iter, + start_lr=0.0, + end_lr=cfg.learning_rate, ) + + optimizer = fluid.optimizer.Momentum( + learning_rate=lr, + regularization=fluid.regularizer.L2Decay(cfg.weight_decay), + momentum=cfg.momentum, + parameter_list=model.parameters()) + + start_time = time.time() + snapshot_loss = 0 + snapshot_time = 0 + total_sample = 0 + + input_size = cfg.input_size + shuffle = True + shuffle_seed = None + total_iter = cfg.max_iter + mixup_iter = total_iter - cfg.no_mixup_iter + + train_reader = FakeDataReader().reader() + + smoothed_loss = SmoothedValue() + ret = [] + for iter_id, data in enumerate(train_reader()): + prev_start_time = start_time + start_time = time.time() + img = np.array([x[0] for x in data]).astype('float32') + img = to_variable(img) + + gt_box = np.array([x[1] for x in data]).astype('float32') + gt_box = to_variable(gt_box) + + gt_label = np.array([x[2] for x in data]).astype('int32') + gt_label = to_variable(gt_label) + + gt_score = np.array([x[3] for x in data]).astype('float32') + gt_score = to_variable(gt_score) + + loss = model(img, gt_box, gt_label, gt_score, None, None) + smoothed_loss.add_value(np.mean(loss.numpy())) + snapshot_loss += loss.numpy() + snapshot_time += start_time - prev_start_time + total_sample += 1 + + print("Iter {:d}, loss {:.6f}, time {:.5f}".format( + iter_id, + smoothed_loss.get_mean_value(), start_time - prev_start_time)) + ret.append(smoothed_loss.get_mean_value()) + + loss.backward() + + optimizer.minimize(loss) + model.clear_gradients() + + return np.array(ret) + + +class TestYolov3(unittest.TestCase): + def test_dygraph_static_same_loss(self): + dygraph_loss = train(to_static=False) + static_loss = train(to_static=True) + self.assertTrue( + np.allclose(dygraph_loss, static_loss), + msg="dygraph_loss: {} \nstatic_loss: {}".format(dygraph_loss, + static_loss)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/yolov3.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/yolov3.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4cb50d6464eb67e222d299318b6ce0001c2954 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/yolov3.py @@ -0,0 +1,332 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import division +from __future__ import print_function + +import paddle.fluid as fluid +from paddle.fluid.dygraph import declarative +from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.dygraph.nn import Conv2D, BatchNorm +from paddle.fluid.initializer import Constant +from paddle.fluid.initializer import Normal +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from darknet import DarkNet53_conv_body +from darknet import ConvBNLayer + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + elif name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + if name in self.__dict__: + self.__dict__[name] = value + else: + self[name] = value + + +# +# Training options +# +cfg = AttrDict() +# Snapshot period +cfg.snapshot_iter = 2000 +# min valid area for gt boxes +cfg.gt_min_area = -1 +# max target box number in an image +cfg.max_box_num = 50 +# valid score threshold to include boxes +cfg.valid_thresh = 0.005 +# threshold vale for box non-max suppression +cfg.nms_thresh = 0.45 +# the number of top k boxes to perform nms +cfg.nms_topk = 400 +# the number of output boxes after nms +cfg.nms_posk = 100 +# score threshold for draw box in debug mode +cfg.draw_thresh = 0.5 +# Use label smooth in class label +cfg.label_smooth = True +# +# Model options +# +# input size +cfg.input_size = 608 +# pixel mean values +cfg.pixel_means = [0.485, 0.456, 0.406] +# pixel std values +cfg.pixel_stds = [0.229, 0.224, 0.225] +# anchors box weight and height +cfg.anchors = [ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326 +] +# anchor mask of each yolo layer +cfg.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] +# IoU threshold to ignore objectness loss of pred box +cfg.ignore_thresh = .7 +# +# SOLVER options +# +# batch size +cfg.batch_size = 8 if fluid.is_compiled_with_cuda() else 4 +# derived learning rate the to get the final learning rate. +cfg.learning_rate = 0.001 +# maximum number of iterations +cfg.max_iter = 20 if fluid.is_compiled_with_cuda() else 2 +# Disable mixup in last N iter +cfg.no_mixup_iter = 10 if fluid.is_compiled_with_cuda() else 1 +# warm up to learning rate +cfg.warm_up_iter = 10 if fluid.is_compiled_with_cuda() else 1 +cfg.warm_up_factor = 0. +# lr steps_with_decay +cfg.lr_steps = [400000, 450000] +cfg.lr_gamma = 0.1 +# L2 regularization hyperparameter +cfg.weight_decay = 0.0005 +# momentum with SGD +cfg.momentum = 0.9 +# +# ENV options +# +# support both CPU and GPU +cfg.use_gpu = fluid.is_compiled_with_cuda() +# Class number +cfg.class_num = 80 + + +class YoloDetectionBlock(fluid.dygraph.Layer): + def __init__(self, ch_in, channel, is_test=True): + super(YoloDetectionBlock, self).__init__() + + assert channel % 2 == 0, \ + "channel {} cannot be divided by 2".format(channel) + + self.conv0 = ConvBNLayer( + ch_in=ch_in, + ch_out=channel, + filter_size=1, + stride=1, + padding=0, + is_test=is_test) + self.conv1 = ConvBNLayer( + ch_in=channel, + ch_out=channel * 2, + filter_size=3, + stride=1, + padding=1, + is_test=is_test) + self.conv2 = ConvBNLayer( + ch_in=channel * 2, + ch_out=channel, + filter_size=1, + stride=1, + padding=0, + is_test=is_test) + self.conv3 = ConvBNLayer( + ch_in=channel, + ch_out=channel * 2, + filter_size=3, + stride=1, + padding=1, + is_test=is_test) + self.route = ConvBNLayer( + ch_in=channel * 2, + ch_out=channel, + filter_size=1, + stride=1, + padding=0, + is_test=is_test) + self.tip = ConvBNLayer( + ch_in=channel, + ch_out=channel * 2, + filter_size=3, + stride=1, + padding=1, + is_test=is_test) + + def forward(self, inputs): + out = self.conv0(inputs) + out = self.conv1(out) + out = self.conv2(out) + out = self.conv3(out) + route = self.route(out) + tip = self.tip(route) + return route, tip + + +class Upsample(fluid.dygraph.Layer): + def __init__(self, scale=2): + super(Upsample, self).__init__() + self.scale = scale + + def forward(self, inputs): + # get dynamic upsample output shape + shape_nchw = fluid.layers.shape(inputs) + shape_hw = fluid.layers.slice( + shape_nchw, axes=[0], starts=[2], ends=[4]) + shape_hw.stop_gradient = True + in_shape = fluid.layers.cast(shape_hw, dtype='int32') + out_shape = in_shape * self.scale + out_shape.stop_gradient = True + + # reisze by actual_shape + out = fluid.layers.resize_nearest( + input=inputs, scale=self.scale, actual_shape=out_shape) + return out + + +class YOLOv3(fluid.dygraph.Layer): + def __init__(self, ch_in, is_train=True, use_random=False): + super(YOLOv3, self).__init__() + + self.is_train = is_train + self.use_random = use_random + + self.block = DarkNet53_conv_body(ch_in=ch_in, is_test=not self.is_train) + self.block_outputs = [] + self.yolo_blocks = [] + self.route_blocks_2 = [] + ch_in_list = [1024, 768, 384] + for i in range(3): + yolo_block = self.add_sublayer( + "yolo_detecton_block_%d" % (i), + YoloDetectionBlock( + ch_in_list[i], + channel=512 // (2**i), + is_test=not self.is_train)) + self.yolo_blocks.append(yolo_block) + + num_filters = len(cfg.anchor_masks[i]) * (cfg.class_num + 5) + + block_out = self.add_sublayer( + "block_out_%d" % (i), + Conv2D( + num_channels=1024 // (2**i), + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + act=None, + param_attr=ParamAttr( + initializer=fluid.initializer.Normal(0., 0.02)), + bias_attr=ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=L2Decay(0.)))) + self.block_outputs.append(block_out) + if i < 2: + route = self.add_sublayer( + "route2_%d" % i, + ConvBNLayer( + ch_in=512 // (2**i), + ch_out=256 // (2**i), + filter_size=1, + stride=1, + padding=0, + is_test=(not self.is_train))) + self.route_blocks_2.append(route) + self.upsample = Upsample() + + @declarative + def forward(self, + inputs, + gtbox=None, + gtlabel=None, + gtscore=None, + im_id=None, + im_shape=None): + self.outputs = [] + self.boxes = [] + self.scores = [] + self.losses = [] + self.downsample = 32 + blocks = self.block(inputs) + for i, block in enumerate(blocks): + if i > 0: + block = fluid.layers.concat(input=[route, block], axis=1) + route, tip = self.yolo_blocks[i](block) + block_out = self.block_outputs[i](tip) + self.outputs.append(block_out) + + if i < 2: + route = self.route_blocks_2[i](route) + route = self.upsample(route) + self.gtbox = gtbox + self.gtlabel = gtlabel + self.gtscore = gtscore + self.im_id = im_id + self.im_shape = im_shape + + # cal loss + for i, out in enumerate(self.outputs): + anchor_mask = cfg.anchor_masks[i] + if self.is_train: + loss = fluid.layers.yolov3_loss( + x=out, + gt_box=self.gtbox, + gt_label=self.gtlabel, + gt_score=self.gtscore, + anchors=cfg.anchors, + anchor_mask=anchor_mask, + class_num=cfg.class_num, + ignore_thresh=cfg.ignore_thresh, + downsample_ratio=self.downsample, + use_label_smooth=cfg.label_smooth) + self.losses.append(fluid.layers.reduce_mean(loss)) + + else: + mask_anchors = [] + for m in anchor_mask: + mask_anchors.append(cfg.anchors[2 * m]) + mask_anchors.append(cfg.anchors[2 * m + 1]) + boxes, scores = fluid.layers.yolo_box( + x=out, + img_size=self.im_shape, + anchors=mask_anchors, + class_num=cfg.class_num, + conf_thresh=cfg.valid_thresh, + downsample_ratio=self.downsample, + name="yolo_box" + str(i)) + self.boxes.append(boxes) + self.scores.append( + fluid.layers.transpose( + scores, perm=[0, 2, 1])) + self.downsample //= 2 + + if not self.is_train: + # get pred + yolo_boxes = fluid.layers.concat(self.boxes, axis=1) + yolo_scores = fluid.layers.concat(self.scores, axis=2) + + pred = fluid.layers.multiclass_nms( + bboxes=yolo_boxes, + scores=yolo_scores, + score_threshold=cfg.valid_thresh, + nms_top_k=cfg.nms_topk, + keep_top_k=cfg.nms_posk, + nms_threshold=cfg.nms_thresh, + background_label=-1) + return pred + else: + return sum(self.losses)