提交 4e3828a6 编写于 作者: P parap1uie-s 提交者: wangguanzhong

Add UNet method of AGE challenge baseline (#3065)

Add UNet method of AGE challenge baseline
上级 76711d60
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# [Angle closure Glaucoma Evaluation Challenge](https://age.grand-challenge.org/Details/)\n",
"## Scleral spur localization Baseline (ResNet50+UNet)\n",
"\n",
"- To keep model training stable, images with coordinate == -1, were removed.\n",
"\n",
"- For real inference, you MIGHT keep all images in val_file_path file."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"\n",
"- Assume `Training100.zip` and `Validation_ASOCT_Image.zip` are stored @ `./AGE_challenge Baseline/datasets/`\n",
"- Assume `weights` are stored @ `./AGE_challenge Baseline/weights/`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download ImageNet weight"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2019-08-08 16:00:14-- https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar\n",
"Resolving paddle-imagenet-models-name.bj.bcebos.com (paddle-imagenet-models-name.bj.bcebos.com)... 202.106.5.21, 111.206.47.194\n",
"Connecting to paddle-imagenet-models-name.bj.bcebos.com (paddle-imagenet-models-name.bj.bcebos.com)|202.106.5.21|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 102717440 (98M) [application/x-tar]\n",
"Saving to: ‘../weights/ResNet50_pretrained.tar’\n",
"\n",
"ResNet50_pretrained 100%[===================>] 97.96M 2.93MB/s in 34s \n",
"\n",
"2019-08-08 16:00:48 (2.90 MB/s) - ‘../weights/ResNet50_pretrained.tar’ saved [102717440/102717440]\n",
"\n"
]
}
],
"source": [
"# https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification\n",
"!rm ../weights/ResNet50_pretrained.tar \n",
"!rm -rf ../weights/ResNet50_pretrained\n",
"\n",
"!wget -P ../weights/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar \n",
"!tar xvf ../weights/ResNet50_pretrained.tar -C ../weights/ > /dev/null # silent\n",
"!rm ../weights/ResNet50_pretrained/fc*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Main Code"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os, random, functools, math\n",
"import cv2\n",
"import numpy as np\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Verify Fluid Program ... \n",
"Your Paddle Fluid works well on SINGLE GPU or CPU.\n",
"Your Paddle Fluid works well on MUTIPLE GPU or CPU.\n",
"Your Paddle Fluid is installed successfully! Let's start deep Learning with Paddle Fluid now\n"
]
}
],
"source": [
"import paddle\n",
"import paddle.fluid as fluid\n",
"import paddle.fluid.layers as FL\n",
"import paddle.fluid.optimizer as FO\n",
"fluid.install_check.run_check()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# preprocess: extract left/right label col in Training100_Location.xlsx\n",
"# save to train_csv file\n",
"data_root_path = \"../datasets/Training100/\"\n",
"image_path = os.path.join(data_root_path, \"ASOCT_Image_loc\")\n",
"\n",
"train_file_path = os.path.join(data_root_path, \"loc_train_split.csv\")\n",
"val_file_path = os.path.join(data_root_path, \"loc_val_split.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 8\n",
"THREAD = 8\n",
"BUF_SIZE = 32"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Remove last global pooling and fullyconnect layer to enable UNet arch.\n",
"# Standard ResNet Implement: \n",
"# https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/resnet.py\n",
"from resnet import *\n",
"from res_unet_paddle import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Data Loader"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def vflip_image(image):\n",
" return cv2.flip(image, flipCode=1)\n",
"\n",
"def gaussian_k(x0,y0, sigma, width, height):\n",
" \"\"\" Make a square gaussian kernel centered at (x0, y0) with sigma as SD.\n",
" \"\"\"\n",
" x = np.arange(0, width, 1, float) ## (width,)\n",
" y = np.arange(0, height, 1, float)[:, np.newaxis] ## (height,1)\n",
" return np.exp(-((x-x0)**2 + (y-y0)**2) / (2*sigma**2))\n",
"\n",
"def generate_hm(height, width, point, s=10):\n",
" \"\"\" Generate a full Heap Map for every landmarks in an array\n",
" Args:\n",
" height : The height of Heat Map (the height of target output)\n",
" width : The width of Heat Map (the width of target output)\n",
" point : (x,y)\n",
" \"\"\"\n",
" hm = gaussian_k(point[0], point[1], s, height, width)\n",
" return hm"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def reader(img_path, file_list, batch_size=32, shuffle=True, shuffle_seed=42):\n",
" def read_file_list():\n",
" batch_data = []\n",
" np.random.shuffle(file_list)\n",
" for line in file_list:\n",
" file_name, p_x, p_y = line.split(\",\")\n",
" batch_data.append([file_name, float(p_x), float(p_y)])\n",
" if len(batch_data) == batch_size:\n",
" yield batch_data\n",
" batch_data = []\n",
" if len(batch_data) != 0:\n",
" yield batch_data\n",
" return read_file_list\n",
"\n",
"def process_batch_data(input_data, mode, rotate=True, flip=True):\n",
" batch_data = []\n",
" for sample in input_data:\n",
" file, p_x, p_y = sample\n",
" \n",
" img = cv2.imread( file )\n",
" img = img[:, :, ::-1].astype('float32') / 255.0\n",
" \n",
" ratio = 256.0 / img.shape[0]\n",
" p_x, p_y = p_x * ratio, p_y * ratio\n",
" img = cv2.resize(img, (256, 256))\n",
"\n",
" if mode == 'train':\n",
" img = img + np.random.randn(*img.shape) * 0.3 / 255 \n",
" if flip and np.random.randint(0,2):\n",
" img = vflip_image(img)\n",
" p_x = 256 - p_x\n",
" else:\n",
" pass\n",
" \n",
" hm = generate_hm(256, 256, (p_x, p_y))\n",
" img = img.transpose((2, 0, 1))\n",
" batch_data.append((img, hm))\n",
"\n",
" return batch_data"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def data_loader(img_list, img_path, batch_size, order=False, mode='train'):\n",
" data_reader = reader(img_path, img_list, batch_size)\n",
" mapper = functools.partial(process_batch_data, mode=mode)\n",
" \n",
" data_reader = paddle.reader.shuffle(data_reader, 32)\n",
" \n",
" return paddle.reader.xmap_readers(\n",
" mapper, data_reader, THREAD, BUF_SIZE, order=order)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"with open(train_file_path) as flist:\n",
" train_file_list = [os.path.join(image_path,line.strip()) for line in flist]\n",
"\n",
"with open(val_file_path) as flist:\n",
" val_file_list = [os.path.join(image_path,line.strip()) for line in flist] "
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2586\n",
"607\n",
"../datasets/Training100/ASOCT_Image_loc/T0056-10_left.jpg,228.83365553922314,466.95960107867666\n"
]
}
],
"source": [
"print(len(train_file_list))\n",
"print(len(val_file_list))\n",
"print(train_file_list[0])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"np.random.shuffle(train_file_list)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"train_dataloader = data_loader(train_file_list, image_path, BATCH_SIZE, False, mode='train')\n",
"val_dataloader = data_loader(val_file_list, image_path, BATCH_SIZE, True, mode='val')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define model (compute graph)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def network():\n",
" data_shape = [3, 256, 256]\n",
" \n",
" model = ResUNet(\n",
" ResNet50().net, 1\n",
" )\n",
" \n",
" input_feature = FL.data(name='pixel', shape=data_shape, dtype='float32')\n",
" hm = FL.data(name='label', shape=data_shape[1:], dtype='float32')\n",
" \n",
" logit = model.net(input_feature)\n",
" pred_hm = FL.squeeze(logit, axes=[1]) # Bs, 256,256\n",
"\n",
" reader = fluid.io.PyReader(feed_list=[input_feature, hm], \n",
" capacity=64, iterable=True, use_double_buffer=True)\n",
"\n",
" cost = FL.square_error_cost(pred_hm, hm)\n",
" loss = FL.mean(cost)\n",
" \n",
" return [loss, pred_hm, reader]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def calc_dist(pred_hm, hm):\n",
" hm = np.array(hm)\n",
" \n",
" mean_dis = 0.\n",
" for single_hm, single_pred_hm in zip(hm, pred_hm):\n",
" # Find argmax_x, argmax_y from 2D tensor\n",
" label_x, label_y = np.unravel_index(single_hm.argmax(), single_hm.shape)\n",
" pred_x, pred_y = np.unravel_index(single_pred_hm.argmax(), single_pred_hm.shape)\n",
" mean_dis += np.sqrt((pred_x - label_x) ** 2 + (pred_y - label_y) ** 2)\n",
" \n",
" return mean_dis / hm.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def train(use_cuda, params_dirname_prefix, pretrained_model=False, EPOCH_NUM=10):\n",
" place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()\n",
" \n",
" startup_prog = fluid.Program()\n",
" train_prog = fluid.Program()\n",
" val_prog = fluid.Program()\n",
"\n",
" with fluid.program_guard(train_prog, startup_prog):\n",
" # fluid.unique_name.guard() to share parameters with test network\n",
" with fluid.unique_name.guard():\n",
" train_loss, train_output, train_reader = network()\n",
" \n",
" optimizer = fluid.optimizer.Adam(learning_rate=1e-4)\n",
" optimizer.minimize(train_loss)\n",
" \n",
" # 定义预测网络\n",
" with fluid.program_guard(val_prog, startup_prog):\n",
" # Use fluid.unique_name.guard() to share parameters with train network\n",
" with fluid.unique_name.guard():\n",
" val_loss, val_output, val_reader = network()\n",
"\n",
" val_prog = val_prog.clone(for_test=True)\n",
"\n",
" train_loss.persistable = True\n",
" val_loss.persistable = True\n",
" val_output.persistable = True\n",
" \n",
" exe = fluid.Executor(place)\n",
" exe.run(startup_prog)\n",
"\n",
" if pretrained_model:\n",
" def if_exist(var):\n",
" return os.path.exists(os.path.join(pretrained_model, var.name))\n",
"\n",
" fluid.io.load_vars(\n",
" exe, pretrained_model, main_program=train_prog, predicate=if_exist)\n",
"\n",
" train_reader.decorate_sample_list_generator( train_dataloader, places=place )\n",
" val_reader.decorate_sample_list_generator( val_dataloader, places=place )\n",
"\n",
" # For training test cost\n",
" def train_test(val_prog, val_reader):\n",
" count = 0\n",
" accumulated = [0,0]\n",
" \n",
" prediction = []\n",
" label_values = []\n",
" \n",
" for tid, val_data in enumerate(val_reader()):\n",
" avg_cost_np = exe.run(\n",
" program=val_prog,\n",
" feed=val_data,\n",
" fetch_list=[val_loss, val_output],\n",
" use_program_cache=True)\n",
" accumulated = [\n",
" x[0] + x[1][0] for x in zip(accumulated, avg_cost_np)\n",
" ]\n",
" prediction.append(avg_cost_np[1])\n",
" label_values.append( np.array(val_data[0]['label']) )\n",
" count += 1\n",
"\n",
" prediction = np.concatenate(prediction, 0)\n",
" label_values = np.concatenate(label_values, 0)\n",
"\n",
" mean_dis = calc_dist(prediction, label_values)\n",
" \n",
" return [x / count for x in accumulated], mean_dis\n",
"\n",
" # main train loop.\n",
" def train_loop():\n",
" step = 0\n",
" best_dist = 65536.\n",
"\n",
" for pass_id in range(EPOCH_NUM):\n",
" data_load_time = time.time()\n",
" for step_id, data_train in enumerate(train_reader()):\n",
" data_load_costtime = time.time() - data_load_time\n",
" start_time = time.time()\n",
" avg_loss_value = exe.run(\n",
" train_prog,\n",
" feed=data_train,\n",
" fetch_list=[train_loss, train_output], \n",
" use_program_cache=True)\n",
" cost_time = time.time() - start_time\n",
" if step_id % 50 == 0:\n",
" mean_dis = calc_dist(avg_loss_value[1], data_train[0]['label'])\n",
" print(\"Pass %d, Epoch %d, Cost %f, EuDis %f, Time %f, LoadTime %f\" % (\n",
" step_id, pass_id, avg_loss_value[0], mean_dis, cost_time, data_load_costtime))\n",
" else:\n",
" pass\n",
" step += 1\n",
" data_load_time = time.time()\n",
"\n",
" avg_cost_test, avg_dist_test = train_test(val_prog, val_reader)\n",
"\n",
" print('Test with Epoch {0}, Loss {1:2.4}, EuDis {2:2.4}'.format(\n",
" pass_id, avg_cost_test[0], avg_dist_test))\n",
"\n",
" if avg_dist_test < best_dist:\n",
" best_dist = avg_dist_test\n",
" print(\"\\nBest Dis, Checkpoint Saved!\\n\")\n",
" if not os.path.isdir(params_dirname_prefix+\"_best/\"):\n",
" os.makedirs(params_dirname_prefix+\"_best/\")\n",
" fluid.io.save_persistables(exe, params_dirname_prefix+\"_best/\", main_program=train_prog)\n",
"\n",
" if not os.path.isdir(params_dirname_prefix+\"_checkpoint/\"):\n",
" os.makedirs(params_dirname_prefix+\"_checkpoint/\")\n",
" fluid.io.save_persistables(exe, params_dirname_prefix+\"_checkpoint/\", main_program=train_prog)\n",
" train_loop()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# download imagenet pretrain weight from:\n",
"# https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification\n",
"train(use_cuda=True, params_dirname_prefix=\"../weights/loc_unet\", \n",
" pretrained_model=\"../weights/ResNet50_pretrained\", EPOCH_NUM=40)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Angle closure Glaucoma Evaluation Challenge
The goal of the challenge is to evaluate and compare automated algorithms for angle closure classification and localization of scleral spur (SS) points on a common dataset of AS-OCT images. We invite the medical image analysis community to participate by developing and testing existing and novel automated classification and segmentation methods.
More detail [AGE challenge](https://age.grand-challenge.org/Details/).
## Scleral spur localization task (ResNet50+UNet model)
1. Method
* Inspired by UNet method, a keypoint is equivalent to 2D gaussian heatmap.
<img src="assets/1.png">
<img src="assets/2.png">
* Then, a localization task could be transformed to a heatmap regression task.
2. Prepare data
* We assume that you have downloaded data(two zip files), and store @ `../datasets/`.
* (Updated on August 5) Replace update files.
* We provide a demo about `zip file extract`, `data structure explore`, and `Train/Val split`.
3. Train
* We assume that you have download data, extract compressed files, and store @ `../datasets/`.
* Based on PaddlePaddle and [ResNet50](https://github.com/PaddlePaddle/models/blob/develop/PaddleCV/image_classification/models/resnet.py), we modify the model structure to enable UNet model, which global pooling layer and final fc layer were removed.
4. Inference
* We assume that you have download data, extract compressed files, and store @ `../datasets/`.
* We assume that you stored checkpoint files @ `../weights/loc_unet`
* We provide a baseline about `inference` and `visualization`.
<img src="assets/3.png">
<img src="assets/4.png">
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as FL
from paddle.fluid.param_attr import ParamAttr
from resnet import *
def conv_bn_layer(input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = FL.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + '.conv2d.output.1')
bn_name = name + "_bn"
return FL.batch_norm(input=conv,
act=act,
name=bn_name+'.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',)
def DoubleConv_up(x, out_channels, name=None):
x = conv_bn_layer(x, out_channels, 3, 1, act='relu', name=name+"1")
x = conv_bn_layer(x, out_channels, 3, 1, act='relu', name=name+"2")
return x
def ConvUp(x1, x2, out_channels, name=None):
x1 = FL.conv2d_transpose(x1, num_filters=x1.shape[1] // 2, filter_size=2, stride=2)
x = FL.concat([x1,x2], axis=1)
x = DoubleConv_up(x, out_channels, name=name+"_doubleconv")
return x
class ResUNet():
def __init__(self, backbone, out_channels):
self.backbone = backbone
self.out_channels = out_channels
def net(self, input):
c1, c2, c3, c4, c5 = self.backbone(input)
channels = [64, 128, 256, 512]
x = ConvUp(c5, c4, channels[2], name='up5')
x = ConvUp(x, c3, channels[1], name='up6')
x = ConvUp(x, c2, channels[0], name='up7')
x = ConvUp(x, c1, channels[0], name='up8')
x = FL.conv2d_transpose(x, num_filters=self.out_channels, filter_size=2, stride=2)
return x
\ No newline at end of file
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50):
self.params = train_parameters
self.layers = layers
def net(self, input):
layers = self.layers
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
unet_collector = []
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu',name="conv1")
unet_collector.append(conv)
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name="res"+str(block+2)+"a"
else:
conv_name="res"+str(block+2)+"b"+str(i)
else:
conv_name="res"+str(block+2)+chr(97+i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1, name=conv_name)
unet_collector.append(conv)
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name="res"+str(block+2)+chr(97+i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
is_first=block==i==0,
name=conv_name)
unet_collector.append(conv)
return unet_collector
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + '.conv2d.output.1')
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
name=bn_name+'.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',)
def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1 or is_first == True:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu',name=name+"_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name+"_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name+"_branch2c")
short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu',name=name+".add.output.5")
def basic_block(self, input, num_filters, stride, is_first, name):
conv0 = self.conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride,
name=name+"_branch2a")
conv1 = self.conv_bn_layer(input=conv0, num_filters=num_filters, filter_size=3, act=None,
name=name+"_branch2b")
short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def ResNet18():
model = ResNet(layers=18)
return model
def ResNet34():
model = ResNet(layers=34)
return model
def ResNet50():
model = ResNet(layers=50)
return model
def ResNet101():
model = ResNet(layers=101)
return model
def ResNet152():
model = ResNet(layers=152)
return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册