Commit 98b91a73, authored by MRXLT

Merge remote-tracking branch 'upstream/master'

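# PLSC: PaddlePaddle Large Scale Classification Library
Paddle Large Scale Classification Tools
## Introduction
PLSC (PaddlePaddle Large Scale Classification) is a large-scale classification library built on [PaddlePaddle](https://www.paddlepaddle.org.cn). It provides an end-to-end solution for large-scale classification applications, from training to deployment.
PLSC has the following features:
- Built on [PaddlePaddle](https://www.paddlepaddle.org.cn), an open-source deep learning platform that grew out of industrial practice
- Includes a large number of pre-trained models (TBD)
- Provides an end-to-end solution from training to deployment (TBD)
## Tutorials
We provide a series of tutorials to help users train, evaluate, and deploy models with PLSC.
The tutorials are organized into four parts, __Quick Start__, __Basic Features__, __Inference and Deployment__, and __Advanced Features__, which introduce the design and usage of PLSC step by step.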
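### Quick Start
* [Installation](docs/installation.md)
* [Training/Evaluation/Deployment](docs/usage.md)
### Basic Features
* [API Introduction](docs/api_intro.md)
* [Custom Models](docs/custom_modes.md)
* [Custom Reader Interface]
### Inference and Deployment
* [Model Export](docs/export_for_infer.md)
* [Using the C++ Inference Library]
### Advanced Features
* [Mixed-Precision Training]
* [Distributed Parameter Conversion]
* [Base64 Image Preprocessing]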
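# PLSC API Introduction
## Default Configuration Parameters
PLSC provides default configuration parameters for training, evaluation, and the model, such as the training dataset directory and the number of training epochs.
These parameters are defined in the plsc.config module. Their meanings and default values are listed below.
### Training
| Parameter | Description | Default |
| :------- | :------- | :----- |
| train_batch_size | Batch size for training | 128 |
| dataset_dir | Root directory of the dataset | './train_data' |
| train_image_num | Number of training images | 5822653 |
| train_epochs | Number of training epochs | 120 |
| warmup_epochs | Number of warmup epochs | 0 |
| lr | Initial learning rate | 0.1 |
| lr_steps | Step boundaries at which the learning rate decays | (100000,160000,220000) |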
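To illustrate how `lr` and `lr_steps` interact, the following minimal sketch computes a piecewise-constant learning rate. It assumes the rate is multiplied by 0.1 at every boundary, which mirrors the `piecewise_decay` setup in the training script included in this commit; `learning_rate_at` and its `decay` argument are hypothetical and not part of the PLSC API.

```python
import bisect


def learning_rate_at(step, base_lr=0.1,
                     lr_steps=(100000, 160000, 220000), decay=0.1):
    """Return the piecewise-constant learning rate for a global step."""
    # Count the boundaries already reached; a step equal to a boundary
    # is treated as having passed it.
    num_decays = bisect.bisect_right(lr_steps, step)
    return base_lr * (decay ** num_decays)


for step in (0, 100000, 200000, 250000):
    print(step, learning_rate_at(step))
```

With the default values, the rate stays at 0.1 until step 100000 and is reduced by a factor of ten at each of the three boundaries.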
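### Evaluation
| Parameter | Description | Default |
| :------- | :------- | :----- |
| val_targets | Names of the validation datasets, separated by commas, e.g. 'lfw,cfp_fp' | lfw |
| test_batch_size | Batch size for evaluation | 120 |
| with_test | Whether to evaluate the model after each training epoch | True |
### Model
| Parameter | Description | Default |
| :------- | :------- | :----- |
| model_name | Name of the model to use | 'ResNet50' |
| checkpoint_dir | Directory of the pre-trained model | "" |
| model_save_dir | Directory where trained models are saved | "./output" |
| loss_type | Loss type; one of softmax, arcface, dist_softmax, and dist_arcface | 'dist_arcface' |
| num_classes | Number of classes | 85742 |
| image_shape | Image shape, in CHW format | [3, 112, 112] |
| margin | The margin parameter of dist_arcface and arcface | 0.5 |
| scale | The scale parameter of dist_arcface and arcface | 64.0 |
| emb_size | Output dimension of the model's last hidden layer | 512 |
Notes:
* Difference between checkpoint_dir and model_save_dir: checkpoint_dir is the directory from which a pre-trained model is loaded before training or evaluation; model_save_dir is the directory where models are saved after training.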
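The `margin` and `scale` parameters correspond to the additive angular margin m and the feature scale s of the ArcFace loss. As a rough illustration, the sketch below applies them to the cosine similarities of a single sample, using the commonly cited form s·cos(θ + m) for the target class and s·cos(θ) for all other classes. It is plain Python written for this document; `arcface_logits` is a hypothetical helper, and PLSC's own implementation may differ in its details.

```python
import math


def arcface_logits(cosines, target, margin=0.5, scale=64.0):
    """Add an angular margin to the target class, then scale all logits."""
    logits = []
    for class_id, cos_theta in enumerate(cosines):
        # Clamp to [-1, 1] so acos stays defined despite rounding errors.
        cos_theta = max(-1.0, min(1.0, cos_theta))
        if class_id == target:
            theta = math.acos(cos_theta)
            cos_theta = math.cos(theta + margin)
        logits.append(scale * cos_theta)
    return logits


print(arcface_logits([0.8, 0.1, -0.3], target=0))
```

The margin lowers only the target-class logit, which forces the embedding to sit closer to its class center before the loss is satisfied.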
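### Parameter-Setting API
The default parameters can be changed through the following API. The methods and their descriptions are listed in the table below.
| API | Description | Argument |
| :------------------- | :--------------------| :---------------------- |
| set_val_targets(targets) | Set the validation datasets | Comma-separated validation dataset names; string |
| set_train_batch_size(size) | Set the training batch size | int |
| set_test_batch_size(size) | Set the evaluation batch size | int |
| set_hdfs_info(fs_name, fs_ugi, directory) | Set the HDFS file system information | fs_name is the HDFS address (string); fs_ugi is the comma-separated user name and password (string); directory is the path on HDFS |
| set_model_save_dir(dir) | Set the model save directory model_save_dir | string |
| set_dataset_dir(dir) | Set the dataset root directory dataset_dir | string |
| set_train_image_num(num) | Set the total number of training images | int |
| set_class_num(num) | Set the total number of classes | int |
| set_emb_size(size) | Set the output dimension of the last hidden layer | int |
| set_model(model) | Set the user-defined model instance | Instance of a BaseModel subclass |
| set_train_epochs(num) | Set the number of training epochs | int |
| set_checkpoint_dir(dir) | Set the directory from which the pre-trained model is loaded | string |
| set_warmup_epochs(num) | Set the number of warmup epochs | int |
| set_loss_type(loss_type) | Set the loss type of the model | string |
| set_image_size(size) | Set the image size, in CHW format | tuple |
| set_optimizer(optimizer) | Set the optimizer used for training | Optimizer instance |
| convert_for_prediction() | Convert a pre-trained model into an inference model | None |
| predict() | Offline prediction interface, used to verify the correctness of the online model | None |
| test() | Evaluate the model | None |
| train() | Train the model | None |
Note: all of the methods above belong to the plsc.entry.Entry class of PLSC and must be called on an instance of that class, for example:
```python
import plsc.entry as entry
ins = entry.Entry()
ins.set_class_num(85742)
ins.train()
```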
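# Custom Models
By default, PLSC builds its training model on top of ResNet50.
PLSC provides the model base class plsc.models.base_model.BaseModel, on which users can build their own network models. A user-defined model class must inherit from this base class and implement the build_network method, which constructs the user-defined network.
To use the model, call the class's get_output method, which automatically appends a distributed FC layer to the end of the user-defined network.
The following example shows how to define a custom network model with the BaseModel base class and how to use it.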
```python
import paddle.fluid as fluid
import plsc.entry as entry
from plsc.models.base_model import BaseModel


class ResNet(BaseModel):
    def __init__(self, layers=50, emb_dim=512):
        super(ResNet, self).__init__()
        self.layers = layers
        self.emb_dim = emb_dim

    def build_network(self,
                      input,
                      label,
                      is_train):
        layers = self.layers
        supported_layers = [50, 101, 152]
        assert layers in supported_layers, \
            "supported layers {}, but given {}".format(supported_layers, layers)

        if layers == 50:
            depth = [3, 4, 14, 3]
            num_filters = [64, 128, 256, 512]
        elif layers == 101:
            depth = [3, 4, 23, 3]
            num_filters = [256, 512, 1024, 2048]
        elif layers == 152:
            depth = [3, 8, 36, 3]
            num_filters = [256, 512, 1024, 2048]

        conv = self.conv_bn_layer(
            input=input, num_filters=64, filter_size=3, stride=1,
            pad=1, act='prelu', is_train=is_train)

        for block in range(len(depth)):
            for i in range(depth[block]):
                conv = self.bottleneck_block(
                    input=conv,
                    num_filters=num_filters[block],
                    stride=2 if i == 0 else 1,
                    is_train=is_train)

        bn = fluid.layers.batch_norm(input=conv, act=None, epsilon=2e-05,
                                     is_test=False if is_train else True)
        drop = fluid.layers.dropout(x=bn, dropout_prob=0.4,
                                    dropout_implementation='upscale_in_train',
                                    is_test=False if is_train else True)
        fc = fluid.layers.fc(
            input=drop,
            size=self.emb_dim,
            act=None,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Xavier(uniform=False, fan_in=0.0)),
            bias_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.ConstantInitializer()))
        emb = fluid.layers.batch_norm(input=fc, act=None, epsilon=2e-05,
                                      is_test=False if is_train else True)
        return emb

    ... ...

if __name__ == "__main__":
    ins = entry.Entry()
    ins.set_model(ResNet())
    ins.train()
```
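A user-defined model class must inherit from the base class BaseModel and implement the build_network method, which defines the custom model.
The build_network method takes the following inputs:
* input: the input image data
* label: the image class labels
* is_train: whether the network is built for training or for testing/inference
The build_network method returns the output variable of the user-defined network. The get_output method of BaseModel calls build_network to obtain that output and automatically appends a distributed FC layer after it.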
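The relationship between build_network and get_output is a template-method pattern: the base class owns the fixed tail of the network and delegates the backbone to the subclass. The framework-free sketch below mimics only that control flow. `SketchBaseModel`, `TinyModel`, the `num_classes` argument, and the string-based "layers" are hypothetical stand-ins, not PLSC classes or signatures.

```python
class SketchBaseModel(object):
    """Stand-in for a base class that appends a classification head."""

    def build_network(self, input, label, is_train):
        raise NotImplementedError("subclasses must implement build_network")

    def get_output(self, input, label, num_classes, is_train=True):
        # Step 1: ask the subclass for the embedding of its backbone.
        emb = self.build_network(input, label, is_train)
        # Step 2: append the classification layer (in PLSC this is the
        # distributed FC layer; here it is only a formatted string).
        return "fc({}, num_classes={})".format(emb, num_classes)


class TinyModel(SketchBaseModel):
    def build_network(self, input, label, is_train):
        return "backbone({})".format(input)


output = TinyModel().get_output("image", "label", num_classes=85742)
print(output)
```

The subclass never builds the final layer itself, which is why a custom model only needs to return its embedding.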
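# Exporting the Inference Model
Models saved by PLSC during training usually contain only the model parameters, not the structure of the inference model. To deploy with the PLSC inference library, the pre-trained model must first be exported as an inference model.
The following code exports a pre-trained model as an inference model:
```python
import plsc.entry as entry

if __name__ == "__main__":
    ins = entry.Entry()
    ins.set_checkpoint_dir('./pretrain_model')
    ins.set_model_save_dir('./inference_model')
    ins.convert_for_prediction()
```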
# Installation
## 1. Install PaddlePaddle
Requirements:
* PaddlePaddle >= 1.6.2
* Python 2.7 or 3.5+
For the operating systems and the CUDA and cuDNN versions that PaddlePaddle supports, see the [PaddlePaddle installation guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html).
### Install with pip
PLSC currently requires the GPU version of PaddlePaddle.
```shell
pip install paddlepaddle-gpu
```
### Install with Conda
PaddlePaddle can also be installed with Conda, which reduces the cost of installing its dependencies. For how to use conda, see [Anaconda](https://www.anaconda.com/distribution/).
```shell
conda install -c paddle paddlepaddle-gpu cudatoolkit=9.0
```
* Install NVIDIA NCCL >= 2.4.7 and run on Linux.
For other installation methods and more information, see the [PaddlePaddle installation guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html).
## 2. Install PLSC
```shell
pip install plsc
```
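# Training, Evaluation, and Deployment
PLSC provides an end-to-end solution covering training, evaluation, and inference deployment. This document describes how to use PLSC to quickly train, evaluate, and deploy a model.
## Data Preparation
We assume the user's dataset is organized as follows: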
```shell
train_data/
|-- agedb_30.bin
|-- cfp_ff.bin
|-- cfp_fp.bin
|-- images
|-- label.txt
`-- lfw.bin
```
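Here, *train_data* is the root directory of the user's data. *agedb_30.bin*, *cfp_ff.bin*, *cfp_fp.bin*, and *lfw.bin* are different validation datasets, and not all of them are required. This tutorial uses lfw.bin as the validation dataset by default, so make sure lfw.bin is available when following along. The *images* directory contains the training images in JPEG format, and each line of *label.txt* gives one training image and its class label.
An example of the contents of *label.txt*: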
```shell
images/00000000.jpg 0
images/00000001.jpg 0
images/00000002.jpg 0
images/00000003.jpg 0
images/00000004.jpg 0
images/00000005.jpg 0
images/00000006.jpg 0
images/00000007.jpg 0
... ...
```
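To make the *label.txt* format concrete, here is a minimal parsing sketch. It assumes, based only on the example above, that each line holds a relative image path and an integer class label separated by whitespace; `parse_label_lines` is a hypothetical helper and not part of PLSC.

```python
def parse_label_lines(lines):
    """Parse 'relative/path.jpg label' lines into (path, int_label) pairs."""
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Split on the last whitespace run so that paths containing
        # spaces are still handled.
        path, label = line.rsplit(None, 1)
        samples.append((path, int(label)))
    return samples


example = ["images/00000000.jpg 0", "images/00000001.jpg 0", ""]
samples = parse_label_lines(example)
print(samples)
```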
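## Model Training
### Training Code
The following script, *train.py*, uses PLSC to run large-scale classification training:
```python
import plsc.entry as entry

if __name__ == "__main__":
    ins = entry.Entry()
    ins.train()
```
1. Import the entry.Entry class from the plsc package; it is the interface class for the functionality of PLSC.
2. Create an instance of the Entry class.
3. Call the train method of the Entry instance to start training.
### Start Training
The following example shows how to launch a training job with the script above: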
```shell
python -m paddle.distributed.launch \
--cluster_ips="127.0.0.1" \
--node_ip="127.0.0.1" \
--selected_gpus=0,1,2,3,4,5,6,7 \
train.py
```
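The paddle.distributed.launch module launches multi-node/multi-GPU distributed training scripts and simplifies starting distributed training jobs. Its arguments are:
* cluster_ips: comma-separated list of the IP addresses of the nodes taking part in training;
* node_ip: IP address of the current training node;
* selected_gpus: comma-separated list of the GPU devices used on each training node.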
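The launch arguments determine how many trainer processes take part in a job. The sketch below assumes one trainer process per selected GPU on every node, and a global batch size equal to the per-trainer batch size times the number of trainers, which is how the training script in this commit uses `PADDLE_TRAINERS_NUM`. `cluster_size` is a hypothetical helper for illustration only.

```python
def cluster_size(cluster_ips, selected_gpus, train_batch_size=128):
    """Return (num_trainers, global_batch_size) for a launch configuration."""
    nodes = [ip for ip in cluster_ips.split(",") if ip]
    gpus = [gpu for gpu in selected_gpus.split(",") if gpu]
    # Assumption: every node uses the same GPU list, one trainer per GPU.
    num_trainers = len(nodes) * len(gpus)
    return num_trainers, num_trainers * train_batch_size


print(cluster_size("127.0.0.1", "0,1,2,3,4,5,6,7"))
```

For the single-node, eight-GPU command above and the default train_batch_size of 128, this gives 8 trainers and a global batch size of 1024.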
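## Model Evaluation
In this tutorial, we use the lfw.bin validation dataset to evaluate the trained model.
### Evaluation Code
The following script, *val.py*, uses PLSC to evaluate a trained model:
```python
import plsc.entry as entry

if __name__ == "__main__":
    ins = entry.Entry()
    ins.set_checkpoint_dir("output/0")
    ins.test()
```
By default, PLSC saves trained models under the './output' directory and uses pass_id as the sub-directory name that distinguishes the models of different epochs. For example, './output/0' holds the model saved after the first epoch.
To evaluate a model, first set the directory of the trained model, then call the test method of the Entry class.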
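Because models are saved in one sub-directory per pass_id, it can be convenient to pick the most recent one automatically before evaluation. The sketch below assumes the sub-directory names are plain integers, as in the './output/0' example; `latest_checkpoint` is a hypothetical helper and not part of PLSC.

```python
import os
import tempfile


def latest_checkpoint(model_save_dir):
    """Return the sub-directory with the highest numeric pass_id, or None."""
    pass_ids = [int(name) for name in os.listdir(model_save_dir)
                if name.isdigit()
                and os.path.isdir(os.path.join(model_save_dir, name))]
    if not pass_ids:
        return None
    return os.path.join(model_save_dir, str(max(pass_ids)))


# Demonstrate on a throwaway directory that imitates './output'.
demo_root = tempfile.mkdtemp()
for pass_id in (0, 1, 10):
    os.makedirs(os.path.join(demo_root, str(pass_id)))
print(latest_checkpoint(demo_root))
```

Comparing the names as integers matters: a plain string sort would rank '10' before '2'.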
## Inference and Deployment
TBD
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
File added
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from easydict import EasyDict as edict
"""
Default Parameters
"""
config = edict()
config.train_batch_size = 128
config.test_batch_size = 120
config.val_targets = 'lfw'
config.dataset_dir = './train_data'
config.train_image_num = 5822653
config.model_name = 'ResNet50'
config.train_epochs = 120
config.checkpoint_dir = ""
config.with_test = True
config.model_save_dir = "output"
config.warmup_epochs = 0
config.loss_type = "dist_arcface"
config.num_classes = 85742
config.image_shape = (3,112,112)
config.margin = 0.5
config.scale = 64.0
config.lr = 0.1
config.lr_steps = (100000,160000,220000)
config.emb_dim = 512
File added
import os
import sys
import time
import argparse
import functools
import numpy as np
import paddle
import paddle.fluid as fluid
import resnet
import sklearn
import reader
from verification import evaluate
from utility import add_arguments, print_arguments
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.collective import DistFCConfig
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.transpiler.details.program_utils import program_to_code
from paddle.fluid.optimizer import Optimizer
import paddle.fluid.profiler as profiler
from fp16_utils import rewrite_program, update_role_var_grad, update_loss_scaling, move_optimize_ops_back
from fp16_lists import AutoMixedPrecisionLists
from paddle.fluid.transpiler.details import program_to_code
import paddle.fluid.layers as layers
import paddle.fluid.unique_name as unique_name
parser = argparse.ArgumentParser(description="Train parallel face network.")
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('train_batch_size', int, 128, "Minibatch size for training.")
add_arg('test_batch_size', int, 120, "Minibatch size for test.")
add_arg('num_epochs', int, 120, "Number of epochs to run.")
add_arg('image_shape', str, "3,112,112", "Image size in the format of CHW.")
add_arg('emb_dim', int, 512, "Embedding dim size.")
add_arg('class_dim', int, 85742, "Number of classes.")
add_arg('model_save_dir', str, None, "Directory to save model.")
add_arg('pretrained_model', str, None, "Directory for pretrained model.")
add_arg('lr', float, 0.1, "Initial learning rate.")
add_arg('model', str, "ResNet_ARCFACE50", "The network to use.")
add_arg('loss_type', str, "softmax", "Type of network loss to use.")
add_arg('margin', float, 0.5, "Parameter of margin for arcface or dist_arcface.")
add_arg('scale', float, 64.0, "Parameter of scale for arcface or dist_arcface.")
add_arg('with_test', bool, False, "Whether to do test during training.")
add_arg('fp16', bool, True, "Whether to use mixed-precision (fp16) training.")
add_arg('profile', bool, False, "Enable profiler or not." )
# yapf: enable
args = parser.parse_args()
model_list = [m for m in dir(resnet) if "__" not in m]
def optimizer_setting(params, args):
    ls = params["learning_strategy"]
    step = 1
    bd = [step * e for e in ls["epochs"]]
    base_lr = params["lr"]
    lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]
    print("bd: {}".format(bd))
    print("lr_step: {}".format(lr))
    step_lr = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=step_lr,
        momentum=0.9,
        regularization=fluid.regularizer.L2Decay(5e-4))
    num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
    if args.loss_type in ["dist_softmax", "dist_arcface"]:
        if args.fp16:
            wrapper = DistributedClassificationOptimizer(
                optimizer, args.train_batch_size * num_trainers, step_lr,
                loss_type=args.loss_type, init_loss_scaling=1.0)
        else:
            wrapper = DistributedClassificationOptimizer(
                optimizer, args.train_batch_size * num_trainers, step_lr)
    elif args.loss_type in ["softmax", "arcface"]:
        wrapper = optimizer
    return wrapper
def build_program(args,
                  main_program,
                  startup_program,
                  is_train=True,
                  use_parallel_test=False,
                  fleet=None,
                  strategy=None):
    model_name = args.model
    assert model_name in model_list, \
        "{} is not in supported lists: {}".format(args.model, model_list)
    assert not (is_train and use_parallel_test), \
        "is_train and use_parallel_test cannot be set simultaneously"
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
    worker_num = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
    image_shape = [int(m) for m in args.image_shape.split(",")]
    # model definition
    model = resnet.__dict__[model_name]()
    with fluid.program_guard(main_program, startup_program):
        with fluid.unique_name.guard():
            image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            emb, loss = model.net(input=image,
                                  label=label,
                                  is_train=is_train,
                                  emb_dim=args.emb_dim,
                                  class_dim=args.class_dim,
                                  loss_type=args.loss_type,
                                  margin=args.margin,
                                  scale=args.scale)
            if args.loss_type in ["dist_softmax", "dist_arcface"]:
                # gather the per-trainer probability shards to compute accuracy
                shard_prob = loss._get_info("shard_prob")
                prob_all = fluid.layers.collective._c_allgather(shard_prob,
                    nranks=worker_num, use_calc_stream=True)
                prob_list = fluid.layers.split(prob_all, dim=0,
                    num_or_sections=worker_num)
                prob = fluid.layers.concat(prob_list, axis=1)
                label_all = fluid.layers.collective._c_allgather(label,
                    nranks=worker_num, use_calc_stream=True)
                acc1 = fluid.layers.accuracy(input=prob, label=label_all, k=1)
                acc5 = fluid.layers.accuracy(input=prob, label=label_all, k=5)
            elif args.loss_type in ["softmax", "arcface"]:
                prob = loss[1]
                loss = loss[0]
                acc1 = fluid.layers.accuracy(input=prob, label=label, k=1)
                acc5 = fluid.layers.accuracy(input=prob, label=label, k=5)
            optimizer = None
            if is_train:
                # parameters from model and arguments
                params = model.params
                params["lr"] = args.lr
                params["num_epochs"] = args.num_epochs
                params["learning_strategy"]["batch_size"] = args.train_batch_size
                # initialize optimizer
                optimizer = optimizer_setting(params, args)
                dist_optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
                dist_optimizer.minimize(loss)
            elif use_parallel_test:
                emb = fluid.layers.collective._c_allgather(emb,
                    nranks=worker_num, use_calc_stream=True)
    return emb, loss, acc1, acc5, optimizer
def train(args):
pretrained_model = args.pretrained_model
model_save_dir = args.model_save_dir
model_name = args.model
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
worker_num = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
strategy = DistributedStrategy()
strategy.mode = "collective"
strategy.collective_mode = "grad_allreduce"
startup_prog = fluid.Program()
train_prog = fluid.Program()
test_program = fluid.Program()
train_emb, train_loss, train_acc1, train_acc5, optimizer = \
build_program(args, train_prog, startup_prog, True, False,
fleet, strategy)
test_emb, test_loss, test_acc1, test_acc5, _ = \
build_program(args, test_program, startup_prog, False, True)
if args.loss_type in ["dist_softmax", "dist_arcface"]:
# The fp16 and fp32 paths both wrap the inner optimizer, so the lookup is the same.
global_lr = optimizer._optimizer._global_learning_rate(
program=train_prog)
elif args.loss_type in ["softmax", "arcface"]:
global_lr = optimizer._global_learning_rate(program=train_prog)
origin_prog = fleet._origin_program
train_prog = fleet.main_program
if trainer_id == 0:
with open('start.program', 'w') as fout:
program_to_code(startup_prog, fout, True)
with open('main.program', 'w') as fout:
program_to_code(train_prog, fout, True)
with open('origin.program', 'w') as fout:
program_to_code(origin_prog, fout, True)
gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
place = fluid.CUDAPlace(gpu_id)
exe = fluid.Executor(place)
exe.run(startup_prog)
if pretrained_model:
pretrained_model = os.path.join(pretrained_model, str(trainer_id))
def if_exist(var):
has_var = os.path.exists(os.path.join(pretrained_model, var.name))
if has_var:
print('var: %s found' % (var.name))
return has_var
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist,
main_program=train_prog)
train_reader = paddle.batch(reader.arc_train(args.class_dim),
batch_size=args.train_batch_size)
if args.with_test:
test_list, test_name_list = reader.test()
test_feeder = fluid.DataFeeder(place=place, feed_list=['image', 'label'], program=test_program)
fetch_list_test = [test_emb.name, test_acc1.name, test_acc5.name]
feeder = fluid.DataFeeder(place=place, feed_list=['image', 'label'], program=train_prog)
fetch_list_train = [train_loss.name, global_lr.name, train_acc1.name, train_acc5.name,train_emb.name,"loss_scaling_0"]
# test_program = test_program._prune(targets=loss)
num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
real_batch_size = args.train_batch_size * num_trainers
real_test_batch_size = args.test_batch_size * num_trainers
local_time = 0.0
nsamples = 0
inspect_steps = 100
step_cnt = 0
for pass_id in range(args.num_epochs):
train_info = [[], [], [], []]
local_train_info = [[], [], [], []]
for batch_id, data in enumerate(train_reader()):
nsamples += real_batch_size
t1 = time.time()
loss, lr, acc1, acc5, train_embedding, loss_scaling = exe.run(train_prog, feed=feeder.feed(data),
fetch_list=fetch_list_train, use_program_cache=True)
t2 = time.time()
if args.profile and step_cnt == 50:
print("begin profiler")
if trainer_id == 0:
profiler.start_profiler("All")
elif args.profile and batch_id == 55:
print("begin to end profiler")
if trainer_id == 0:
profiler.stop_profiler("total", "./profile_%d" % (trainer_id))
print("end profiler break!")
args.profile=False
period = t2 - t1
local_time += period
train_info[0].append(np.array(loss)[0])
train_info[1].append(np.array(lr)[0])
local_train_info[0].append(np.array(loss)[0])
local_train_info[1].append(np.array(lr)[0])
if batch_id % inspect_steps == 0:
avg_loss = np.mean(local_train_info[0])
avg_lr = np.mean(local_train_info[1])
print("Pass:%d batch:%d lr:%f loss:%f qps:%.2f acc1:%.4f acc5:%.4f" % (
pass_id, batch_id, avg_lr, avg_loss, nsamples / local_time,
acc1, acc5))
#print("train_embedding:,",np.array(train_embedding)[0])
print("train_embedding is nan:",np.isnan(np.array(train_embedding)[0]).sum())
print("loss_scaling",loss_scaling)
local_time = 0
nsamples = 0
local_train_info = [[], [], [], []]
step_cnt += 1
if args.with_test and step_cnt % inspect_steps == 0:
test_start = time.time()
for i in range(len(test_list)):
data_list, issame_list = test_list[i]
embeddings_list = []
for j in range(len(data_list)):
data = data_list[j]
embeddings = None
parallel_test_steps = data.shape[0] // real_test_batch_size
beg = 0
end = 0
for idx in range(parallel_test_steps):
start = idx * real_test_batch_size
offset = trainer_id * args.test_batch_size
begin = start + offset
end = begin + args.test_batch_size
_data = []
for k in range(begin, end):
_data.append((data[k], 0))
assert len(_data) == args.test_batch_size
[_embeddings, acc1, acc5] = exe.run(test_program,
fetch_list = fetch_list_test, feed=test_feeder.feed(_data),
use_program_cache=True)
if embeddings is None:
embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
embeddings[start:start+real_test_batch_size, :] = _embeddings[:, :]
beg = parallel_test_steps * real_test_batch_size
while beg < data.shape[0]:
end = min(beg + args.test_batch_size, data.shape[0])
count = end - beg
_data = []
for k in range(end - args.test_batch_size, end):
_data.append((data[k], 0))
[_embeddings, acc1, acc5] = exe.run(test_program,
fetch_list = fetch_list_test, feed=test_feeder.feed(_data),
use_program_cache=True)
_embeddings = _embeddings[0:args.test_batch_size,:]
embeddings[beg:end, :] = _embeddings[(args.test_batch_size-count):, :]
beg = end
embeddings_list.append(embeddings)
xnorm = 0.0
xnorm_cnt = 0
for embed in embeddings_list:
xnorm += np.sqrt((embed * embed).sum(axis=1)).sum(axis=0)
xnorm_cnt += embed.shape[0]
xnorm /= xnorm_cnt
embeddings = embeddings_list[0] + embeddings_list[1]
if np.isnan(embeddings).sum() > 1:
print("======test np.isnan(embeddings).sum()",np.isnan(embeddings).sum())
continue
embeddings = sklearn.preprocessing.normalize(embeddings)
_, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=10)
acc, std = np.mean(accuracy), np.std(accuracy)
print('[%s][%d]XNorm: %f' % (test_name_list[i], step_cnt, xnorm))
print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (test_name_list[i], step_cnt, acc, std))
sys.stdout.flush()
test_end = time.time()
print("test time: {}".format(test_end - test_start))
train_loss = np.array(train_info[0]).mean()
print("End pass {0}, train_loss {1}".format(pass_id, train_loss))
sys.stdout.flush()
#save model
#if trainer_id == 0:
if model_save_dir:
model_path = os.path.join(model_save_dir + '/' + model_name,
str(pass_id), str(trainer_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
class DistributedClassificationOptimizer(Optimizer):
'''
An optimizer wrapper to generate backward network for distributed
classification training of model parallelism.
'''
def __init__(self,optimizer, batch_size, lr,
loss_type='dist_arcface',
amp_lists=None,
init_loss_scaling=1.0,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.5,
use_dynamic_loss_scaling=True):
super(DistributedClassificationOptimizer, self).__init__(
learning_rate=lr)
self._optimizer = optimizer
self._batch_size = batch_size
self._amp_lists = amp_lists
if amp_lists is None:
self._amp_lists = AutoMixedPrecisionLists()
self._params_grads = None
self._scaled_loss = None
self._loss_type = loss_type
self._init_loss_scaling = init_loss_scaling
self._loss_scaling = layers.create_global_var(
name=unique_name.generate("loss_scaling"),
shape=[1],
value=init_loss_scaling,
dtype='float32',
persistable=True)
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
if self._use_dynamic_loss_scaling:
self._incr_every_n_steps = layers.fill_constant(
shape=[1], dtype='int32', value=incr_every_n_steps)
self._decr_every_n_nan_or_inf = layers.fill_constant(
shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
self._incr_ratio = incr_ratio
self._decr_ratio = decr_ratio
self._num_good_steps = layers.create_global_var(
name=unique_name.generate("num_good_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
self._num_bad_steps = layers.create_global_var(
name=unique_name.generate("num_bad_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
if isinstance(optimizer._learning_rate, float):
optimizer._learning_rate_map[fluid.default_main_program()] = \
layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(optimizer._learning_rate),
dtype='float32',
persistable=True)
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
assert loss._get_info('shard_logit')
shard_logit = loss._get_info('shard_logit')
shard_prob = loss._get_info('shard_prob')
shard_label = loss._get_info('shard_label')
shard_dim = loss._get_info('shard_dim')
op_maker = fluid.core.op_proto_and_checker_maker
op_role_key = op_maker.kOpRoleAttrName()
op_role_var_key = op_maker.kOpRoleVarAttrName()
backward_role = int(op_maker.OpRole.Backward)
loss_backward_role = int(op_maker.OpRole.Loss) | int(
op_maker.OpRole.Backward)
# minimize a scalar of reduce_sum to generate the backward network
scalar = fluid.layers.reduce_sum(shard_logit)
if not args.fp16:
ret = self._optimizer.minimize(scalar)
# block must be bound before the program is dumped below
block = loss.block
with open("fp32_before.program", "w") as f:
program_to_code(block.program,fout=f, skip_op_callstack=False)
# remove the unnecessary ops
index = 0
for i, op in enumerate(block.ops):
if op.all_attrs()[op_role_key] == loss_backward_role:
index = i
break
print("op_role_key: ",op_role_key)
print("loss_backward_role:",loss_backward_role)
# print("\nblock.ops: ",block.ops)
print("block.ops[index - 1].type: ", block.ops[index - 1].type)
print("block.ops[index].type: ", block.ops[index].type)
print("block.ops[index + 1].type: ", block.ops[index + 1].type)
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
# insert the calculated gradient
dtype = shard_logit.dtype
shard_one_hot = fluid.layers.create_tensor(dtype, name='shard_one_hot')
block._insert_op(
index - 1,
type='one_hot',
inputs={'X': shard_label},
outputs={'Out': shard_one_hot},
attrs={
'depth': shard_dim,
'allow_out_of_range': True,
op_role_key: backward_role
})
shard_logit_grad = fluid.layers.create_tensor(
dtype, name=fluid.backward._append_grad_suffix_(shard_logit.name))
block._insert_op(
index,
type='elementwise_sub',
inputs={'X': shard_prob,
'Y': shard_one_hot},
outputs={'Out': shard_logit_grad},
attrs={op_role_key: backward_role})
block._insert_op(
index + 1,
type='scale',
inputs={'X': shard_logit_grad},
outputs={'Out': shard_logit_grad},
attrs={
'scale': 1.0 / self._batch_size,
op_role_key: loss_backward_role
})
with open("fp32_after.program", "w") as f:
program_to_code(block.program,fout=f, skip_op_callstack=False)
# use mixed_precision for training
else:
block = loss.block
rewrite_program(block.program, self._amp_lists)
self._params_grads = self._optimizer.backward(
scalar, startup_program, parameter_list, no_grad_set,
callbacks)
update_role_var_grad(block.program, self._params_grads)
move_optimize_ops_back(block.program.global_block())
scaled_params_grads = []
for p, g in self._params_grads:
with fluid.default_main_program()._optimized_guard([p, g]):
scaled_g = g / self._loss_scaling
scaled_params_grads.append([p, scaled_g])
index = 0
for i, op in enumerate(block.ops):
if op.all_attrs()[op_role_key] == loss_backward_role:
index = i
break
fp32 = fluid.core.VarDesc.VarType.FP32
dtype = shard_logit.dtype
if self._loss_type == 'dist_arcface':
assert block.ops[index - 2].type == 'fill_constant'
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
assert block.ops[index + 2].type == 'scale'
assert block.ops[index + 3].type == 'elementwise_add_grad'
block._remove_op(index + 2)
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
# insert the calculated gradient
shard_one_hot = fluid.layers.create_tensor(dtype, name='shard_one_hot')
block._insert_op(
index - 1,
type='one_hot',
inputs={'X': shard_label},
outputs={'Out': shard_one_hot},
attrs={
'depth': shard_dim,
'allow_out_of_range': True,
op_role_key: backward_role
})
shard_one_hot_fp32 = fluid.layers.create_tensor(fp32, name=(shard_one_hot.name+".cast_fp32"))
block._insert_op(
index,
type="cast",
inputs={"X": shard_one_hot},
outputs={"Out": shard_one_hot_fp32},
attrs={
"in_dtype": fluid.core.VarDesc.VarType.FP16,
"out_dtype": fluid.core.VarDesc.VarType.FP32,
op_role_key: backward_role
})
name = 'tmp_3@GRAD'
shard_logit_grad_fp32 = block.var(name)
block._insert_op(
index+1,
type='elementwise_sub',
inputs={'X': shard_prob,
'Y': shard_one_hot_fp32},
outputs={'Out': shard_logit_grad_fp32},
attrs={op_role_key: backward_role})
block._insert_op(
index+2,
type='elementwise_mul',
inputs={'X': shard_logit_grad_fp32,
'Y': self._loss_scaling},
outputs={'Out': shard_logit_grad_fp32},
attrs={op_role_key: backward_role})
block._insert_op(
index+3,
type='scale',
inputs={'X': shard_logit_grad_fp32},
outputs={'Out': shard_logit_grad_fp32},
attrs={
'scale': 1.0 / self._batch_size,
op_role_key: loss_backward_role
})
elif self._loss_type == 'dist_softmax':
print("block.ops[index - 3].type: ", block.ops[index - 3].type)
print("block.ops[index - 2].type: ", block.ops[index - 2].type)
print("block.ops[index-1].type: ", block.ops[index - 1].type)
print("block.ops[index].type: ", block.ops[index].type)
print("block.ops[index + 1].type: ", block.ops[index +1].type)
print("block.ops[index + 2].type: ", block.ops[index +2].type)
print("block.ops[index + 3].type: ", block.ops[index +3].type)
with open("fp16_softmax_before.program", "w") as f:
program_to_code(block.program,fout=f, skip_op_callstack=False)
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
assert block.ops[index + 2].type == 'cast'
assert block.ops[index + 3].type == 'elementwise_add_grad'
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
# insert the calculated gradient
shard_one_hot = fluid.layers.create_tensor(fp32, name='shard_one_hot')
shard_one_hot_fp32 = fluid.layers.create_tensor(fp32,
name=(shard_one_hot.name+".cast_fp32"))
shard_logit_grad_fp32 = block.var(shard_logit.name+".cast_fp32@GRAD")
block._insert_op(
index - 1,
type='one_hot',
inputs={'X': shard_label},
outputs={'Out': shard_one_hot_fp32},
attrs={
'depth': shard_dim,
'allow_out_of_range': True,
op_role_key: backward_role
})
block._insert_op(
index,
type='elementwise_sub',
inputs={'X': shard_prob,
'Y': shard_one_hot_fp32},
outputs={'Out': shard_logit_grad_fp32},
attrs={op_role_key: backward_role})
block._insert_op(
index + 1,
type='elementwise_mul',
inputs={'X': shard_logit_grad_fp32,
'Y': self._loss_scaling},
outputs={'Out': shard_logit_grad_fp32},
attrs={op_role_key: backward_role})
block._insert_op(
index + 2,
type='scale',
inputs={'X': shard_logit_grad_fp32},
outputs={'Out': shard_logit_grad_fp32},
attrs={
'scale': 1.0 / self._batch_size,
op_role_key: loss_backward_role
})
if self._use_dynamic_loss_scaling:
grads = [layers.reduce_sum(g) for [_, g] in scaled_params_grads]
all_grads = layers.concat(grads)
all_grads_sum = layers.reduce_sum(all_grads)
is_overall_finite = layers.isfinite(all_grads_sum)
update_loss_scaling(is_overall_finite, self._loss_scaling,
self._num_good_steps, self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf, self._incr_ratio,
self._decr_ratio)
with layers.Switch() as switch:
with switch.case(is_overall_finite):
pass
with switch.default():
for _, g in scaled_params_grads:
layers.assign(layers.zeros_like(g), g)
optimize_ops = self._optimizer.apply_gradients(scaled_params_grads)
ret = optimize_ops, scaled_params_grads
with open("fp16_softmax.program", "w") as f:
program_to_code(block.program,fout=f, skip_op_callstack=False)
return ret
def main():
    global args
    all_loss_types = ["softmax", "arcface", "dist_softmax", "dist_arcface"]
    assert args.loss_type in all_loss_types, \
        "All supported loss types [{}], but given {}.".format(
            all_loss_types, args.loss_type)
    print_arguments(args)
    train(args)


if __name__ == '__main__':
    main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import os
import sys
import time
import argparse
import numpy as np
import math
import pickle
import subprocess
import shutil
import logging
import paddle
import paddle.fluid as fluid
import sklearn.preprocessing
from . import config
from .models import resnet
from .models import base_model
from .models.dist_algo import DistributedClassificationOptimizer
from .utils.learning_rate import lr_warmup
from .utils.verification import evaluate
from .utils import jpeg_reader as reader
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.transpiler.details.program_utils import program_to_code
import paddle.fluid.transpiler.distribute_transpiler as dist_transpiler
from paddle.fluid.optimizer import Optimizer
logging.basicConfig(
format='[%(asctime)s %(levelname)s line:%(lineno)d] %(message)s',
datefmt='%d %b %Y %H:%M:%S')
logger = logging.getLogger(__name__)
class Entry(object):
"""
The class to encapsulate all operations.
"""
def _check(self):
"""
Check the validity of parameters.
"""
supported_types = ["softmax", "arcface",
"dist_softmax", "dist_arcface"]
assert self.loss_type in supported_types, \
"All supported types are {}, but given {}.".format(
supported_types, self.loss_type)
if self.loss_type in ["dist_softmax", "dist_arcface"]:
assert self.num_trainers > 1, \
"At least 2 trainers are required to use distributed fc-layer."
def __init__(self):
self.config = config.config
super(Entry, self).__init__()
assert os.getenv("PADDLE_TRAINERS_NUM") is not None, \
"Please start script using paddle.distributed.launch module."
num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM"))
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
self.trainer_id = trainer_id
self.num_trainers = num_trainers
self.train_batch_size = self.config.train_batch_size
self.test_batch_size = self.config.test_batch_size
self.global_train_batch_size = self.train_batch_size * num_trainers
self.global_test_batch_size = self.test_batch_size * num_trainers
self.optimizer = None
self.model = None
self.train_reader = None
self.test_reader = None
self.train_program = fluid.Program()
self.startup_program = fluid.Program()
self.test_program = fluid.Program()
self.predict_program = fluid.Program()
self.fs_name = None
self.fs_ugi = None
self.fs_dir = None
self.val_targets = self.config.val_targets
self.dataset_dir = self.config.dataset_dir
self.num_classes = self.config.num_classes
self.image_shape = self.config.image_shape
self.loss_type = self.config.loss_type
self.margin = self.config.margin
self.scale = self.config.scale
self.lr = self.config.lr
self.lr_steps = self.config.lr_steps
self.train_image_num = self.config.train_image_num
self.model_name = self.config.model_name
self.emb_dim = self.config.emb_dim
self.train_epochs = self.config.train_epochs
self.checkpoint_dir = self.config.checkpoint_dir
self.with_test = self.config.with_test
self.model_save_dir = self.config.model_save_dir
self.warmup_epochs = self.config.warmup_epochs
logger.info('=' * 30)
logger.info("Default configuration: ")
for key in self.config:
logger.info('\t' + str(key) + ": " + str(self.config[key]))
logger.info('trainer_id: {}, num_trainers: {}'.format(
trainer_id, num_trainers))
logger.info('=' * 30)
def set_val_targets(self, targets):
self.val_targets = targets
logger.info("Set val_targets to {} by user.".format(targets))
def set_train_batch_size(self, batch_size):
self.train_batch_size = batch_size
self.global_train_batch_size = batch_size * self.num_trainers
logger.info("Set train batch size to {} by user.".format(batch_size))
def set_test_batch_size(self, batch_size):
self.test_batch_size = batch_size
self.global_test_batch_size = batch_size * self.num_trainers
logger.info("Set test batch size to {} by user.".format(batch_size))
def set_hdfs_info(self, fs_name, fs_ugi, directory):
"""
Set the information used to download from or upload to HDFS.
If this information is provided, the pretrained model is downloaded
from HDFS at the beginning, and the models are uploaded
to HDFS at the end automatically.
"""
self.fs_name = fs_name
self.fs_ugi = fs_ugi
self.fs_dir = directory
logger.info("HDFS Info:")
logger.info("\tfs_name: {}".format(fs_name))
logger.info("\tfs_ugi: {}".format(fs_ugi))
logger.info("\tremote directory: {}".format(directory))
def set_model_save_dir(self, directory):
"""
Set the directory to save model.
"""
self.model_save_dir = directory
logger.info("Set model_save_dir to {} by user.".format(directory))
def set_dataset_dir(self, directory):
"""
Set the root directory for datasets.
"""
self.dataset_dir = directory
logger.info("Set dataset_dir to {} by user.".format(directory))
def set_train_image_num(self, num):
"""
Set the total number of images for training.
"""
self.train_image_num = num
logger.info("Set train_image_num to {} by user.".format(num))
def set_class_num(self, num):
"""
Set the number of classes.
"""
self.num_classes = num
logger.info("Set num_classes to {} by user.".format(num))
def set_emb_size(self, size):
"""
Set the size of the last hidden layer before the distributed fc-layer.
"""
self.emb_dim = size
logger.info("Set emb_size to {} by user.".format(size))
def set_model(self, model):
"""
Set user-defined model to use.
"""
if not isinstance(model, base_model.BaseModel):
raise ValueError("The parameter for set_model must be an "
"instance of BaseModel.")
self.model = model
logger.info("Set model to {} by user.".format(model))
def set_train_epochs(self, num):
"""
Set the number of epochs to train.
"""
self.train_epochs = num
logger.info("Set train_epochs to {} by user.".format(num))
def set_checkpoint_dir(self, directory):
"""
Set the directory of the checkpoint to load before training/testing.
"""
self.checkpoint_dir = directory
logger.info("Set checkpoint_dir to {} by user.".format(directory))
def set_warmup_epochs(self, num):
self.warmup_epochs = num
logger.info("Set warmup_epochs to {} by user.".format(num))
def set_loss_type(self, type):
supported_types = ["dist_softmax", "dist_arcface", "softmax", "arcface"]
if type not in supported_types:
raise ValueError("All supported loss types: {}".format(
supported_types))
self.loss_type = type
logger.info("Set loss_type to {} by user.".format(type))
def set_image_shape(self, shape):
if not isinstance(shape, (list, tuple)):
raise ValueError("shape must be of type list or tuple")
self.image_shape = shape
logger.info("Set image_shape to {} by user.".format(shape))
def set_optimizer(self, optimizer):
if not isinstance(optimizer, Optimizer):
raise ValueError("optimizer must be an instance of Optimizer")
self.optimizer = optimizer
logger.info("User manually set optimizer")
def get_optimizer(self):
if self.optimizer:
return self.optimizer
bd = [step for step in self.lr_steps]
start_lr = self.lr
global_batch_size = self.global_train_batch_size
train_image_num = self.train_image_num
images_per_trainer = int(math.ceil(
train_image_num * 1.0 / self.num_trainers))
steps_per_pass = int(math.ceil(
images_per_trainer * 1.0 / self.train_batch_size))
logger.info("steps per epoch: %d" % steps_per_pass)
warmup_steps = steps_per_pass * self.warmup_epochs
batch_denom = 1024
base_lr = start_lr * global_batch_size / batch_denom
lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]
logger.info("lr boundaries: {}".format(bd))
logger.info("lr_step: {}".format(lr))
if self.warmup_epochs:
lr_val = lr_warmup(fluid.layers.piecewise_decay(boundaries=bd,
values=lr), warmup_steps, start_lr, base_lr)
else:
lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=lr_val, momentum=0.9,
regularization=fluid.regularizer.L2Decay(5e-4))
self.optimizer = optimizer
if self.loss_type in ["dist_softmax", "dist_arcface"]:
self.optimizer = DistributedClassificationOptimizer(
self.optimizer, global_batch_size)
return self.optimizer
def build_program(self,
is_train=True,
use_parallel_test=False):
model_name = self.model_name
assert not (is_train and use_parallel_test), \
"is_train and use_parallel_test cannot be set simultaneously."
trainer_id = self.trainer_id
num_trainers = self.num_trainers
image_shape = [int(m) for m in self.image_shape]
# model definition
model = self.model
if model is None:
model = resnet.__dict__[model_name](emb_dim=self.emb_dim)
main_program = self.train_program if is_train else self.test_program
startup_program = self.startup_program
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(name='image',
shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label',
shape=[1], dtype='int64')
emb, loss, prob = model.get_output(
input=image,
label=label,
is_train=is_train,
num_classes=self.num_classes,
loss_type=self.loss_type,
margin=self.margin,
scale=self.scale)
if self.loss_type in ["dist_softmax", "dist_arcface"]:
shard_prob = loss._get_info("shard_prob")
prob_all = fluid.layers.collective._c_allgather(shard_prob,
nranks=num_trainers, use_calc_stream=True)
prob_list = fluid.layers.split(prob_all, dim=0,
num_or_sections=num_trainers)
prob = fluid.layers.concat(prob_list, axis=1)
label_all = fluid.layers.collective._c_allgather(label,
nranks=num_trainers, use_calc_stream=True)
acc1 = fluid.layers.accuracy(input=prob, label=label_all, k=1)
acc5 = fluid.layers.accuracy(input=prob, label=label_all, k=5)
else:
acc1 = fluid.layers.accuracy(input=prob, label=label, k=1)
acc5 = fluid.layers.accuracy(input=prob, label=label, k=5)
optimizer = None
if is_train:
# initialize optimizer
optimizer = self.get_optimizer()
dist_optimizer = self.fleet.distributed_optimizer(
optimizer, strategy=self.strategy)
dist_optimizer.minimize(loss)
elif use_parallel_test:
emb = fluid.layers.collective._c_allgather(emb,
nranks=num_trainers, use_calc_stream=True)
return emb, loss, acc1, acc5, optimizer
def get_files_from_hdfs(self, local_dir):
cmd = "hadoop fs -D fs.default.name="
cmd += self.fs_name + " "
cmd += "-D hadoop.job.ugi="
cmd += self.fs_ugi + " "
cmd += "-get " + self.fs_dir
cmd += " " + local_dir
logger.info("hdfs download cmd: {}".format(cmd))
cmd = cmd.split(' ')
process = subprocess.Popen(cmd,
stdout=sys.stdout,
stderr=subprocess.STDOUT)
process.wait()
def put_files_to_hdfs(self, local_dir):
cmd = "hadoop fs -D fs.default.name="
cmd += self.fs_name + " "
cmd += "-D hadoop.job.ugi="
cmd += self.fs_ugi + " "
cmd += "-put " + local_dir
cmd += " " + self.fs_dir
logger.info("hdfs upload cmd: {}".format(cmd))
cmd = cmd.split(' ')
process = subprocess.Popen(cmd,
stdout=sys.stdout,
stderr=subprocess.STDOUT)
process.wait()
def preprocess_distributed_params(self,
local_dir):
local_dir = os.path.abspath(local_dir)
output_dir = local_dir + "_@tmp"
assert not os.path.exists(output_dir), \
"The temp directory {} for distributed params exists.".format(
output_dir)
os.makedirs(output_dir)
cmd = sys.executable + ' -m plsc.utils.process_distfc_parameter '
cmd += "--nranks {} ".format(self.num_trainers)
cmd += "--num_classes {} ".format(self.num_classes)
cmd += "--pretrained_model_dir {} ".format(local_dir)
cmd += "--output_dir {}".format(output_dir)
cmd = cmd.split(' ')
logger.info("Distributed parameters processing cmd: {}".format(cmd))
process = subprocess.Popen(cmd,
stdout=sys.stdout,
stderr=subprocess.STDOUT)
process.wait()
for file in os.listdir(local_dir):
if "dist@" in file and "@rank@" in file:
file = os.path.join(local_dir, file)
os.remove(file)
for file in os.listdir(output_dir):
if "dist@" in file and "@rank@" in file:
file = os.path.join(output_dir, file)
shutil.move(file, local_dir)
shutil.rmtree(output_dir)
file_name = os.path.join(local_dir, '.lock')
with open(file_name, 'w') as f:
pass
def append_broadcast_ops(self, program):
"""
Before test, we broadcast bn-related parameters to all other trainers.
"""
bn_vars = [var for var in program.list_vars()
if 'batch_norm' in var.name and var.persistable]
block = program.current_block()
for var in bn_vars:
block._insert_op(
0,
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={'use_calc_stream': True})
def load_checkpoint(self,
executor,
main_program,
use_per_trainer_checkpoint=False,
load_for_train=True):
if use_per_trainer_checkpoint:
checkpoint_dir = os.path.join(
self.checkpoint_dir, str(self.trainer_id))
else:
checkpoint_dir = self.checkpoint_dir
if self.fs_name is not None:
if os.path.exists(checkpoint_dir):
ans = input("Downloading pretrained model, but the local "
"checkpoint directory ({}) exists, overwrite it "
"or not? [Y/N]".format(checkpoint_dir))
if ans.lower() == 'n':
logger.info("Using the local checkpoint directory, instead"
" of the remote one.")
else:
logger.info("Overwriting the local checkpoint directory.")
shutil.rmtree(checkpoint_dir)
os.makedirs(checkpoint_dir)
file_name = os.path.join(checkpoint_dir, '.lock')
if self.trainer_id == 0:
self.get_files_from_hdfs(checkpoint_dir)
with open(file_name, 'w') as f:
pass
time.sleep(5)
os.remove(file_name)
else:
while True:
if not os.path.exists(file_name):
time.sleep(1)
else:
break
else:
self.get_files_from_hdfs(checkpoint_dir)
# Preprocess distributed parameters.
file_name = os.path.join(checkpoint_dir, '.lock')
distributed = self.loss_type in ["dist_softmax", "dist_arcface"]
if load_for_train and self.trainer_id == 0 and distributed:
self.preprocess_distributed_params(checkpoint_dir)
time.sleep(5)
os.remove(file_name)
elif load_for_train and distributed:
# Wait for trainer 0 to finish preprocessing.
while True:
if not os.path.exists(file_name):
time.sleep(1)
else:
break
def if_exist(var):
has_var = os.path.exists(os.path.join(checkpoint_dir, var.name))
if has_var:
print('var: %s found' % (var.name))
return has_var
fluid.io.load_vars(executor, checkpoint_dir, predicate=if_exist,
main_program=main_program)
def convert_for_prediction(self):
model_name = self.model_name
image_shape = [int(m) for m in self.image_shape]
# model definition
model = self.model
if model is None:
model = resnet.__dict__[model_name](emb_dim=self.emb_dim)
main_program = self.train_program
startup_program = self.startup_program
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(name='image',
shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label',
shape=[1], dtype='int64')
emb = model.build_network(
input=image,
label=label,
is_train=False)
gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
place = fluid.CUDAPlace(gpu_id)
exe = fluid.Executor(place)
exe.run(startup_program)
assert self.checkpoint_dir, "No checkpoint found for converting."
self.load_checkpoint(executor=exe, main_program=main_program,
load_for_train=False)
assert self.model_save_dir, \
"model_save_dir is not set for inference."
if os.path.exists(self.model_save_dir):
ans = input("model_save_dir for inference model ({}) exists, "
"overwrite it or not? [Y/N]".format(self.model_save_dir))
if ans.lower() == 'n':
logger.error("model_save_dir for inference model exists, "
"and cannot overwrite it.")
sys.exit()
shutil.rmtree(self.model_save_dir)
fluid.io.save_inference_model(self.model_save_dir,
feeded_var_names=[image.name],
target_vars=[emb],
executor=exe,
main_program=main_program)
if self.fs_name:
self.put_files_to_hdfs(self.model_save_dir)
def predict(self):
model_name = self.model_name
image_shape = [int(m) for m in self.image_shape]
# model definition
model = self.model
if model is None:
model = resnet.__dict__[model_name](emb_dim=self.emb_dim)
main_program = self.predict_program
startup_program = self.startup_program
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(name='image',
shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label',
shape=[1], dtype='int64')
emb = model.build_network(
input=image,
label=label,
is_train=False)
gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
place = fluid.CUDAPlace(gpu_id)
exe = fluid.Executor(place)
exe.run(startup_program)
assert self.checkpoint_dir, "No checkpoint found for predicting."
self.load_checkpoint(executor=exe, main_program=main_program,
load_for_train=False)
if self.train_reader is None:
train_reader = paddle.batch(reader.arc_train(
self.dataset_dir, self.num_classes),
batch_size=self.train_batch_size)
else:
train_reader = self.train_reader
feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'], program=main_program)
fetch_list = [emb.name]
for data in train_reader():
emb = exe.run(main_program, feed=feeder.feed(data),
fetch_list=fetch_list, use_program_cache=True)
print("emb: ", emb)
def test(self, pass_id=0):
self._check()
trainer_id = self.trainer_id
num_trainers = self.num_trainers
worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
emb, loss, acc1, acc5, _ = self.build_program(
False, self.num_trainers > 1)
config = dist_transpiler.DistributeTranspilerConfig()
config.mode = "collective"
config.collective_mode = "grad_allreduce"
t = dist_transpiler.DistributeTranspiler(config=config)
t.transpile(
trainer_id=trainer_id,
trainers=worker_endpoints,
startup_program=self.startup_program,
program=self.test_program,
current_endpoint=current_endpoint)
gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
place = fluid.CUDAPlace(gpu_id)
exe = fluid.Executor(place)
exe.run(self.startup_program)
test_list, test_name_list = reader.test(
self.dataset_dir, self.val_targets)
test_program = self.test_program
#test_program = test_program._prune(emb)
assert self.checkpoint_dir, "No checkpoint found for test."
self.load_checkpoint(executor=exe, main_program=test_program,
load_for_train=False)
feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'], program=test_program)
fetch_list = [emb.name, acc1.name, acc5.name]
real_test_batch_size = self.global_test_batch_size
test_start = time.time()
for i in range(len(test_list)):
data_list, issame_list = test_list[i]
embeddings_list = []
for j in range(len(data_list)):
data = data_list[j]
embeddings = None
parallel_test_steps = data.shape[0] // real_test_batch_size
beg = 0
end = 0
for idx in range(parallel_test_steps):
start = idx * real_test_batch_size
offset = trainer_id * self.test_batch_size
begin = start + offset
end = begin + self.test_batch_size
_data = []
for k in range(begin, end):
_data.append((data[k], 0))
assert len(_data) == self.test_batch_size
[_embeddings, acc1, acc5] = exe.run(test_program,
fetch_list = fetch_list, feed=feeder.feed(_data),
use_program_cache=True)
if embeddings is None:
embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
embeddings[start:start+real_test_batch_size, :] = _embeddings[:, :]
beg = parallel_test_steps * real_test_batch_size
while beg < data.shape[0]:
end = min(beg + self.test_batch_size, data.shape[0])
count = end - beg
_data = []
for k in range(end - self.test_batch_size, end):
_data.append((data[k], 0))
[_embeddings, acc1, acc5] = exe.run(test_program,
fetch_list = fetch_list, feed=feeder.feed(_data),
use_program_cache=True)
_embeddings = _embeddings[0:self.test_batch_size,:]
embeddings[beg:end, :] = _embeddings[(self.test_batch_size-count):, :]
beg = end
embeddings_list.append(embeddings)
xnorm = 0.0
xnorm_cnt = 0
for embed in embeddings_list:
xnorm += np.sqrt((embed * embed).sum(axis=1)).sum(axis=0)
xnorm_cnt += embed.shape[0]
xnorm /= xnorm_cnt
embeddings = embeddings_list[0] + embeddings_list[1]
embeddings = sklearn.preprocessing.normalize(embeddings)
_, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=10)
acc, std = np.mean(accuracy), np.std(accuracy)
print('[%s][%d]XNorm: %f' % (test_name_list[i], pass_id, xnorm))
print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (test_name_list[i], pass_id, acc, std))
sys.stdout.flush()
test_end = time.time()
print("test time: {}".format(test_end - test_start))
def train(self):
self._check()
trainer_id = self.trainer_id
num_trainers = self.num_trainers
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
strategy = DistributedStrategy()
strategy.mode = "collective"
strategy.collective_mode = "grad_allreduce"
self.fleet = fleet
self.strategy = strategy
train_emb, train_loss, train_acc1, train_acc5, optimizer = \
self.build_program(True, False)
if self.with_test:
test_emb, test_loss, test_acc1, test_acc5, _ = \
self.build_program(False, True)
test_list, test_name_list = reader.test(
self.dataset_dir, self.val_targets)
test_program = self.test_program
self.append_broadcast_ops(test_program)
if self.loss_type in ["dist_softmax", "dist_arcface"]:
global_lr = optimizer._optimizer._global_learning_rate(
program=self.train_program)
else:
global_lr = optimizer._global_learning_rate(
program=self.train_program)
origin_prog = fleet._origin_program
train_prog = fleet.main_program
if trainer_id == 0:
with open('start.program', 'w') as fout:
program_to_code(self.startup_program, fout, True)
with open('main.program', 'w') as fout:
program_to_code(train_prog, fout, True)
with open('origin.program', 'w') as fout:
program_to_code(origin_prog, fout, True)
with open('test.program', 'w') as fout:
program_to_code(test_program, fout, True)
gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
place = fluid.CUDAPlace(gpu_id)
exe = fluid.Executor(place)
exe.run(self.startup_program)
if self.with_test:
test_feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'], program=test_program)
fetch_list_test = [test_emb.name, test_acc1.name, test_acc5.name]
real_test_batch_size = self.global_test_batch_size
if self.checkpoint_dir == "":
load_checkpoint = False
else:
load_checkpoint = True
if load_checkpoint:
self.load_checkpoint(executor=exe, main_program=origin_prog)
if self.train_reader is None:
train_reader = paddle.batch(reader.arc_train(
self.dataset_dir, self.num_classes),
batch_size=self.train_batch_size)
else:
train_reader = self.train_reader
feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'], program=origin_prog)
fetch_list = [train_loss.name, global_lr.name,
train_acc1.name, train_acc5.name]
local_time = 0.0
nsamples = 0
inspect_steps = 200
global_batch_size = self.global_train_batch_size
for pass_id in range(self.train_epochs):
train_info = [[], [], [], []]
local_train_info = [[], [], [], []]
for batch_id, data in enumerate(train_reader()):
nsamples += global_batch_size
t1 = time.time()
loss, lr, acc1, acc5 = exe.run(train_prog,
feed=feeder.feed(data), fetch_list=fetch_list,
use_program_cache=True)
t2 = time.time()
period = t2 - t1
local_time += period
train_info[0].append(np.array(loss)[0])
train_info[1].append(np.array(lr)[0])
local_train_info[0].append(np.array(loss)[0])
local_train_info[1].append(np.array(lr)[0])
if batch_id % inspect_steps == 0:
avg_loss = np.mean(local_train_info[0])
avg_lr = np.mean(local_train_info[1])
print("Pass:%d batch:%d lr:%f loss:%f qps:%.2f acc1:%.4f acc5:%.4f" % (
pass_id, batch_id, avg_lr, avg_loss, nsamples / local_time,
acc1, acc5))
local_time = 0
nsamples = 0
local_train_info = [[], [], [], []]
train_loss = np.array(train_info[0]).mean()
print("End pass {0}, train_loss {1}".format(pass_id, train_loss))
sys.stdout.flush()
if self.with_test:
test_start = time.time()
for i in range(len(test_list)):
data_list, issame_list = test_list[i]
embeddings_list = []
for j in range(len(data_list)):
data = data_list[j]
embeddings = None
parallel_test_steps = data.shape[0] // real_test_batch_size
beg = 0
end = 0
for idx in range(parallel_test_steps):
start = idx * real_test_batch_size
offset = trainer_id * self.test_batch_size
begin = start + offset
end = begin + self.test_batch_size
_data = []
for k in range(begin, end):
_data.append((data[k], 0))
assert len(_data) == self.test_batch_size
[_embeddings, acc1, acc5] = exe.run(test_program,
fetch_list = fetch_list_test, feed=test_feeder.feed(_data),
use_program_cache=True)
if embeddings is None:
embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
embeddings[start:start+real_test_batch_size, :] = _embeddings[:, :]
beg = parallel_test_steps * real_test_batch_size
while beg < data.shape[0]:
end = min(beg + self.test_batch_size, data.shape[0])
count = end - beg
_data = []
for k in range(end - self.test_batch_size, end):
_data.append((data[k], 0))
[_embeddings, acc1, acc5] = exe.run(test_program,
fetch_list = fetch_list_test, feed=test_feeder.feed(_data),
use_program_cache=True)
_embeddings = _embeddings[0:self.test_batch_size,:]
embeddings[beg:end, :] = _embeddings[(self.test_batch_size-count):, :]
beg = end
embeddings_list.append(embeddings)
xnorm = 0.0
xnorm_cnt = 0
for embed in embeddings_list:
xnorm += np.sqrt((embed * embed).sum(axis=1)).sum(axis=0)
xnorm_cnt += embed.shape[0]
xnorm /= xnorm_cnt
embeddings = embeddings_list[0] + embeddings_list[1]
embeddings = sklearn.preprocessing.normalize(embeddings)
_, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=10)
acc, std = np.mean(accuracy), np.std(accuracy)
print('[%s][%d]XNorm: %f' % (test_name_list[i], pass_id, xnorm))
print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (test_name_list[i], pass_id, acc, std))
sys.stdout.flush()
test_end = time.time()
print("test time: {}".format(test_end - test_start))
#save model
if self.model_save_dir:
model_save_dir = os.path.join(
self.model_save_dir, str(pass_id))
if not os.path.exists(model_save_dir):
os.makedirs(model_save_dir)
if trainer_id == 0:
fluid.io.save_persistables(exe,
model_save_dir,
origin_prog)
else:
def save_var(var):
to_save = "dist@" in var.name and '@rank@' in var.name
return to_save and var.persistable
fluid.io.save_vars(exe, model_save_dir,
origin_prog, predicate=save_var)
#save training info
if self.model_save_dir and trainer_id == 0:
config_file = os.path.join(
self.model_save_dir, str(pass_id), 'meta.pickle')
train_info = dict()
train_info["pretrain_nranks"] = self.num_trainers
train_info["emb_dim"] = self.emb_dim
train_info['num_classes'] = self.num_classes
with open(config_file, 'wb') as f:
pickle.dump(train_info, f)
#upload model
if self.model_save_dir and self.fs_name and trainer_id == 0:
self.put_files_to_hdfs(self.model_save_dir)
if __name__ == '__main__':
ins = Entry()
ins.train()
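The learning-rate arithmetic in `get_optimizer` above (steps per epoch, linear scaling of the base rate by the global batch size, and the piecewise decay values) can be checked without Paddle. The following is a minimal sketch under stated assumptions: the function names `piecewise_lr_schedule` and `lr_at_step` are hypothetical, warmup is left out, and the lookup rule (the first boundary greater than the current step selects the value) is assumed to match `fluid.layers.piecewise_decay`.

```python
import math


def piecewise_lr_schedule(start_lr, lr_steps, train_image_num, num_trainers,
                          train_batch_size, batch_denom=1024):
    """Mirror the arithmetic of Entry.get_optimizer (hypothetical helper)."""
    global_batch_size = train_batch_size * num_trainers
    images_per_trainer = int(math.ceil(train_image_num * 1.0 / num_trainers))
    steps_per_pass = int(math.ceil(
        images_per_trainer * 1.0 / train_batch_size))
    # Linear scaling rule: the rate grows with the global batch size.
    base_lr = start_lr * global_batch_size / batch_denom
    # One value per interval: len(boundaries) + 1 entries, each 10x smaller.
    values = [base_lr * (0.1 ** i) for i in range(len(lr_steps) + 1)]
    return {
        "steps_per_pass": steps_per_pass,
        "base_lr": base_lr,
        "boundaries": list(lr_steps),
        "values": values,
    }


def lr_at_step(step, boundaries, values):
    """Return the learning rate used at a given global step."""
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]


if __name__ == "__main__":
    # Default config values from plsc.config, with 8 trainers as an example.
    sched = piecewise_lr_schedule(0.1, (100000, 160000, 220000),
                                  5822653, 8, 128)
    print(sched["steps_per_pass"], sched["base_lr"], sched["values"])
```

Note that `lr_steps` is expressed in global steps rather than epochs, so changing the batch size or the number of trainers changes how many epochs pass before each decay.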
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import resnet
from .resnet import *
from . import base_model
from .base_model import *
__all__ = []
__all__ += resnet.__all__
__all__ += base_model.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import unique_name
from . import dist_algo
__all__ = ["BaseModel"]
class BaseModel(object):
"""
Base class for custom models.
Subclasses must implement the build_network method,
which constructs the custom model; the distributed
fc layer is then added automatically.
"""
def __init__(self):
super(BaseModel, self).__init__()
def build_network(self, input, label, is_train=True):
"""
Construct the custom model, and we will add the
distributed fc layer for you automatically.
"""
raise NotImplementedError(
"You must implement this method in your sub class.")
def get_output(self,
input,
label,
num_classes,
is_train=True,
param_attr=None,
bias_attr=None,
loss_type="dist_softmax",
margin=0.5,
scale=64.0):
"""
Add the distributed fc layer for the custom model.
"""
supported_loss_types = ["dist_softmax", "dist_arcface",
"softmax", "arcface"]
assert loss_type in supported_loss_types, \
"Supported loss types: {}, but given: {}".format(
supported_loss_types, loss_type)
nranks = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
rank_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
emb = self.build_network(input, label, is_train)
if loss_type == "softmax":
loss, prob = self.fc_classify(emb,
label,
num_classes,
param_attr,
bias_attr)
elif loss_type == "arcface":
loss, prob = self.arcface(emb,
label,
num_classes,
param_attr,
margin,
scale)
elif loss_type == "dist_arcface":
loss = dist_algo._distributed_arcface_classify(
x=emb, label=label, class_num=num_classes,
nranks=nranks, rank_id=rank_id, margin=margin,
logit_scale=scale, param_attr=param_attr)
prob = None
elif loss_type == "dist_softmax":
loss = dist_algo._distributed_softmax_classify(
x=emb, label=label, class_num=num_classes,
nranks=nranks, rank_id=rank_id, param_attr=param_attr,
use_bias=True, bias_attr=bias_attr)
prob = None
return emb, loss, prob
def fc_classify(self, input, label, out_dim, param_attr, bias_attr):
if param_attr is None:
stdv = 1.0 / math.sqrt(input.shape[1] * 1.0)
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv))
out = fluid.layers.fc(input=input,
size=out_dim,
param_attr=param_attr,
bias_attr=bias_attr)
loss, prob = fluid.layers.softmax_with_cross_entropy(logits=out,
label=label, return_softmax=True)
avg_loss = fluid.layers.mean(x=loss)
return avg_loss, prob
def arcface(self, input, label, out_dim, param_attr, margin, scale):
input_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(fluid.layers.square(input), dim=1))
input = fluid.layers.elementwise_div(input, input_norm, axis=0)
if param_attr is None:
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(
uniform=False, fan_in=0.0))
weight = fluid.layers.create_parameter(
shape=[input.shape[1], out_dim],
dtype='float32',
name=unique_name.generate('final_fc_w'),
attr=param_attr)
weight_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(fluid.layers.square(weight), dim=0))
weight = fluid.layers.elementwise_div(weight, weight_norm, axis=1)
cos = fluid.layers.mul(input, weight)
theta = fluid.layers.acos(cos)
margin_cos = fluid.layers.cos(theta + margin)
one_hot = fluid.layers.one_hot(label, out_dim)
diff = (margin_cos - cos) * one_hot
target_cos = cos + diff
logit = fluid.layers.scale(target_cos, scale=scale)
loss, prob = fluid.layers.softmax_with_cross_entropy(
logits=logit, label=label, return_softmax=True)
avg_loss = fluid.layers.mean(x=loss)
one_hot.stop_gradient = True
return avg_loss, prob
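The logit computation in the `arcface` method above (normalize embeddings and weights, take the cosine, add the angular margin on the target class only, then scale) can be reproduced in NumPy to inspect the values it produces. This is a minimal sketch, not the PLSC implementation: `arcface_logits` is a hypothetical name, labels are taken as a 1-D integer array rather than the `[N, 1]` tensor used above, and the clipping of the cosine before `arccos` is an addition for numerical safety that the original code does not perform.

```python
import numpy as np


def arcface_logits(emb, weight, label, margin=0.5, scale=64.0):
    """Compute ArcFace logits for a batch (hypothetical NumPy helper).

    emb:    [N, D] embeddings
    weight: [D, C] class weights
    label:  [N] integer class ids
    """
    # L2-normalize each embedding (rows) and each class weight (columns).
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
    # Cosine similarity between every sample and every class center.
    # Assumption: clip to [-1, 1] so arccos never sees rounding overshoot.
    cos = np.clip(emb.dot(weight), -1.0, 1.0)
    theta = np.arccos(cos)
    margin_cos = np.cos(theta + margin)
    # Apply the margin only at the ground-truth class position.
    one_hot = np.eye(weight.shape[1])[label]
    target_cos = cos + (margin_cos - cos) * one_hot
    return scale * target_cos


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    logits = arcface_logits(rng.randn(4, 8), rng.randn(8, 5),
                            np.array([0, 1, 2, 3]))
    print(logits.shape)
```

Only the target-class column of each row differs from the plain scaled cosine, which is what the `diff = (margin_cos - cos) * one_hot` line above expresses.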
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import math
from six.moves import reduce
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable, default_startup_program
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.ops as ops
import paddle.fluid.layers.collective as collective
from paddle.fluid.optimizer import Optimizer
class DistributedClassificationOptimizer(Optimizer):
'''
An optimizer wrapper that generates the backward network for
model-parallel distributed classification training.
'''
def __init__(self, optimizer, batch_size, use_fp16=False):
self._optimizer = optimizer
self._batch_size = batch_size
self._use_fp16 = use_fp16
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
assert loss._get_info('shard_logit')
shard_logit = loss._get_info('shard_logit')
shard_prob = loss._get_info('shard_prob')
shard_label = loss._get_info('shard_label')
shard_dim = loss._get_info('shard_dim')
op_maker = fluid.core.op_proto_and_checker_maker
op_role_key = op_maker.kOpRoleAttrName()
op_role_var_key = op_maker.kOpRoleVarAttrName()
backward_role = int(op_maker.OpRole.Backward)
loss_backward_role = int(op_maker.OpRole.Loss) | int(
op_maker.OpRole.Backward)
# minimize a scalar of reduce_sum to generate the backward network
scalar = fluid.layers.reduce_sum(shard_logit)
ret = self._optimizer.minimize(scalar)
block = loss.block
# remove the unnecessary ops
index = 0
for i, op in enumerate(block.ops):
if op.all_attrs()[op_role_key] == loss_backward_role:
index = i
break
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
# insert the calculated gradient
dtype = shard_logit.dtype
shard_one_hot = fluid.layers.create_tensor(dtype, name='shard_one_hot')
block._insert_op(
index - 1,
type='one_hot',
inputs={'X': shard_label},
outputs={'Out': shard_one_hot},
attrs={
'depth': shard_dim,
'allow_out_of_range': True,
op_role_key: backward_role
})
shard_logit_grad = fluid.layers.create_tensor(
dtype, name=fluid.backward._append_grad_suffix_(shard_logit.name))
block._insert_op(
index,
type='elementwise_sub',
inputs={'X': shard_prob,
'Y': shard_one_hot},
outputs={'Out': shard_logit_grad},
attrs={op_role_key: backward_role})
block._insert_op(
index + 1,
type='scale',
inputs={'X': shard_logit_grad},
outputs={'Out': shard_logit_grad},
attrs={
'scale': 1.0 / self._batch_size,
op_role_key: loss_backward_role
})
return ret
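The ops inserted by `minimize` compute the logit gradient as `(shard_prob - shard_one_hot) / batch_size`. The sketch below checks that identity for ordinary (non-sharded) softmax cross entropy with a finite-difference comparison. All function names are hypothetical and the values are arbitrary illustration data.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    return -math.log(softmax(logits)[label])


def analytic_grad(logits, label, batch_size):
    # Gradient of the batch-averaged loss w.r.t. one sample's logits.
    probs = softmax(logits)
    return [(p - (1.0 if j == label else 0.0)) / batch_size
            for j, p in enumerate(probs)]


def numeric_grad(logits, label, batch_size, eps=1e-6):
    # Central finite differences, used only as a reference.
    grads = []
    for j in range(len(logits)):
        hi = list(logits)
        lo = list(logits)
        hi[j] += eps
        lo[j] -= eps
        diff = cross_entropy(hi, label) - cross_entropy(lo, label)
        grads.append(diff / (2 * eps) / batch_size)
    return grads


logits, label, batch_size = [2.0, -1.0, 0.5], 2, 4
analytic = analytic_grad(logits, label, batch_size)
numeric = numeric_grad(logits, label, batch_size)
```

Because the gradient has this closed form, each rank can build it from its own shard of probabilities and labels, with no need to differentiate through the collective ops.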
class DistributedClassifier(object):
'''
Toolkit for distributed classification, in which the parameters of the last
fully-connected layer are distributed across all trainers.
'''
def __init__(self, nclasses, nranks, rank_id, layer_helper):
self.nclasses = nclasses
self.nranks = nranks
self.rank_id = rank_id
self._layer_helper = layer_helper
self.shard_dim = (nclasses + nranks - 1) // nranks
self.padding_dim = 0
self.is_equal_division = True
if nclasses % nranks != 0:
self.is_equal_division = False
if rank_id == nranks - 1:
other_shard_dim = self.shard_dim
self.shard_dim = nclasses % other_shard_dim
self.padding_dim = other_shard_dim - self.shard_dim
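The shard-size arithmetic in `__init__` can be replayed for every rank at once. This is a small sketch; `shard_layout` is a hypothetical helper that mirrors the logic above.

```python
def shard_layout(nclasses, nranks):
    """Return a (shard_dim, padding_dim) pair for every rank."""
    full = (nclasses + nranks - 1) // nranks  # ceil(nclasses / nranks)
    layout = []
    for rank_id in range(nranks):
        shard_dim, padding_dim = full, 0
        if nclasses % nranks != 0 and rank_id == nranks - 1:
            # The last rank holds the remainder and records the shortfall.
            shard_dim = nclasses % full
            padding_dim = full - shard_dim
        layout.append((shard_dim, padding_dim))
    return layout


small = shard_layout(nclasses=10, nranks=4)
large = shard_layout(nclasses=85742, nranks=8)
```

With 10 classes on 4 ranks the first three ranks hold 3 classes each and the last holds 1, with a padding of 2. Note that some combinations (for example 9 classes on 4 ranks) leave the last rank with an empty shard.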
def create_parameter(self,
dtype,
in_dim,
param_attr=None,
bias_attr=None,
transpose_weight=False,
use_bias=True):
if param_attr is None:
stdv = math.sqrt(2.0 / (in_dim + self.nclasses))
param_attr = ParamAttr(initializer=Normal(scale=stdv))
weight_shape = [self.shard_dim, in_dim
] if transpose_weight else [in_dim, self.shard_dim]
weight = self._layer_helper.create_parameter(
shape=weight_shape, dtype=dtype, attr=param_attr, is_bias=False)
# avoid distributed parameter allreduce gradients
weight.is_distributed = True
# avoid distributed parameter broadcasting in startup program
default_startup_program().global_block().vars[
weight.name].is_distributed = True
bias = None
if use_bias:
bias = self._layer_helper.create_parameter(
shape=[self.shard_dim],
attr=bias_attr,
dtype=dtype,
is_bias=True)
bias.is_distributed = True
default_startup_program().global_block().vars[
bias.name].is_distributed = True
return weight, bias
def softmax_with_cross_entropy(self, shard_logit, shard_label):
shard_max = nn.reduce_max(shard_logit, dim=1, keep_dim=True)
global_max = collective._c_allreduce(
shard_max, reduce_type='max', use_calc_stream=True)
shard_logit_new = nn.elementwise_sub(shard_logit, global_max)
shard_exp = ops.exp(shard_logit_new)
shard_denom = nn.reduce_sum(shard_exp, dim=1, keep_dim=True)
global_denom = collective._c_allreduce(
shard_denom, reduce_type='sum', use_calc_stream=True)
global_log_denom = nn.log(global_denom)
shard_log_prob = shard_logit_new - global_log_denom
shard_prob = ops.exp(shard_log_prob)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
target_log_prob = nn.reduce_min(
shard_log_prob * shard_one_hot, dim=1, keep_dim=True)
shard_loss = nn.scale(target_log_prob, scale=-1.0)
global_loss = collective._c_reducescatter(
shard_loss, nranks=self.nranks, use_calc_stream=True)
return global_loss, shard_prob
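The two all-reduce steps above (a global max for numerical stability, then a global sum of exponentials) reproduce an ordinary log-softmax over the concatenated logits. The sketch below simulates the ranks with plain Python lists for a single sample. The function names are hypothetical, and the `max`/`sum` over shards stand in for the collective ops.

```python
import math


def sharded_log_softmax(shards):
    """shards: one list of logits per rank, all for the same sample."""
    # Stand-in for the 'max' all-reduce over per-shard maxima.
    global_max = max(max(s) for s in shards)
    # Stand-in for the 'sum' all-reduce over per-shard exp sums.
    global_denom = sum(
        sum(math.exp(v - global_max) for v in s) for s in shards)
    log_denom = math.log(global_denom)
    return [[v - global_max - log_denom for v in s] for s in shards]


def log_softmax(logits):
    m = max(logits)
    log_denom = math.log(sum(math.exp(v - m) for v in logits))
    return [v - m - log_denom for v in logits]


full = [1.0, 3.0, -2.0, 0.5, 4.0, 2.5]
shards = [full[0:3], full[3:6]]
sharded = [v for s in sharded_log_softmax(shards) for v in s]
reference = log_softmax(full)
```

Each rank only ever exchanges one max and one sum per sample, so the communication cost does not grow with the number of classes.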
def softmax_classify(self,
x,
label,
param_attr=None,
use_bias=True,
bias_attr=None):
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
bias_attr=bias_attr,
use_bias=use_bias)
x_all = collective._c_allgather(
x, nranks=self.nranks, use_calc_stream=True)
label_all = collective._c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_fc = nn.mul(x_all, weight, x_num_col_dims=1)
if use_bias:
shard_fc = nn.elementwise_add(shard_fc, bias)
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
shard_label.stop_gradient = True
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_fc,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_fc)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
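`shard_index` maps each global label to a shard-local index, or to `ignore_value` when the label belongs to another rank. The pure-Python sketch below reflects my understanding of that mapping and should be checked against the Paddle documentation. This `shard_index` function is a hypothetical stand-in, not the Paddle op.

```python
def shard_index(labels, index_num, nshards, shard_id, ignore_value=-1):
    """Map global labels to local labels for the shard `shard_id`."""
    shard_size = (index_num + nshards - 1) // nshards
    return [
        lab % shard_size if lab // shard_size == shard_id else ignore_value
        for lab in labels
    ]


labels = [16, 1, 9, 10]
rank0 = shard_index(labels, index_num=20, nshards=2, shard_id=0)
rank1 = shard_index(labels, index_num=20, nshards=2, shard_id=1)
```

Labels mapped to `-1` fall outside the one-hot depth, which is why the callers pass `allow_out_of_range=True` to `one_hot`: those rows become all zeros on ranks that do not own the class.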
def arcface_classify(self,
x,
label,
margin=0.5,
logit_scale=64,
param_attr=None):
'''
reference: ArcFace. https://arxiv.org/abs/1801.07698
'''
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
use_bias=False)
# normalize x
x_l2 = ops.sqrt(nn.reduce_sum(ops.square(x), dim=1))
norm_x = nn.elementwise_div(x, x_l2, axis=0)
norm_x_all = collective._c_allgather(
norm_x, nranks=self.nranks, use_calc_stream=True)
label_all = collective._c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
# TODO check necessary
shard_label.stop_gradient = True
# normalize weight
weight_l2 = ops.sqrt(nn.reduce_sum(ops.square(weight), dim=0))
norm_weight = nn.elementwise_div(weight, weight_l2, axis=1)
shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1)
theta = ops.acos(shard_cos)
margin_cos = ops.cos(theta + margin)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
# TODO check necessary
shard_one_hot.stop_gradient = True
diff = (margin_cos - shard_cos) * shard_one_hot
shard_target_cos = shard_cos + diff
shard_logit = nn.scale(shard_target_cos, scale=logit_scale)
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_logit,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_logit)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
def _distributed_softmax_classify(x,
label,
class_num,
nranks,
rank_id,
param_attr=None,
use_bias=True,
bias_attr=None,
name=None):
'''
Distributed classification layer (FC, softmax and cross entropy) for
problems where the number of classes is too large for a single device.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024].
label (Variable): The label corresponding to the input samples.
class_num (integer): The number of classes of the classification problem.
nranks (integer): The number of ranks of distributed trainers.
rank_id (integer): The rank index of the current trainer.
param_attr (ParamAttr, default None): The parameter attribute for
learnable distributed parameters/weights of this layer.
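use_bias (bool, default True): Whether to add a bias term to the
output of the FC layer.
bias_attr (ParamAttr, default None): The parameter attribute for the
distributed bias of this layer.
name (str, default None): The name of this layer.
Returns:
Variable: The softmax cross-entropy loss.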
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_softmax_classify(x=input,
label=label,
class_num=1000,
nranks=8,
rank_id=0)
'''
if name is None:
name = 'dist@softmax@rank@%05d' % rank_id
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.softmax_classify(x, label, param_attr, use_bias,
bias_attr)
def _distributed_arcface_classify(x,
label,
class_num,
nranks,
rank_id,
margin=0.5,
logit_scale=64.0,
param_attr=None,
name=None):
r'''
Distributed classification layer with ArcFace loss for problems where the
number of classes is too large for a single device. The equation is
.. math::
L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s\cos(\theta_{y_i}+m)}}{e^{s\cos(\theta_{y_i}+m)}+\sum^n_{j=1,j\neq y_i} e^{s\cos\theta_j}}
where :math:`\theta_j` is the angle between the feature :math:`x_i` and
the representation of class :math:`j`, :math:`s` is the logit scale and
:math:`m` is the angular margin. For details of the ArcFace loss, see
https://arxiv.org/abs/1801.07698.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024].
label (Variable): The label corresponding to the input samples.
class_num (integer): The number of classes of the classification problem.
nranks (integer): The number of ranks of distributed trainers.
rank_id (integer): The rank index of the current trainer.
margin (float, default 0.5): The angular margin penalty to enhance
the intra-class compactness and inter-class discrepancy.
logit_scale (float, default 64.0): The scale factor for logit value
of cosine range.
param_attr (ParamAttr, default None): The parameter attribute for
learnable distributed parameters/weights of this layer.
name (str, default None): The name of this layer.
Returns:
Variable: The ArcFace loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_arcface_classify(x=input,
label=label,
class_num=1000,
nranks=8,
rank_id=0)
'''
if name is None:
name = 'dist@arcface@rank@%05d' % rank_id
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.arcface_classify(
x=x,
label=label,
margin=margin,
logit_scale=logit_scale,
param_attr=param_attr)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import math
import os
import numpy as np
from paddle.fluid import unique_name
from .base_model import BaseModel
__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]
class ResNet(BaseModel):
def __init__(self, layers=50, emb_dim=512):
super(ResNet, self).__init__()
self.layers = layers
self.emb_dim = emb_dim
def build_network(self,
input,
label,
is_train):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers {}, but given {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 14, 3]
num_filters = [64, 128, 256, 512]
elif layers == 101:
depth = [3, 4, 23, 3]
num_filters = [256, 512, 1024, 2048]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [256, 512, 1024, 2048]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=1,
pad=1, act='prelu', is_train=is_train)
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 else 1,
is_train=is_train)
bn = fluid.layers.batch_norm(input=conv, act=None, epsilon=2e-05,
is_test=False if is_train else True)
drop = fluid.layers.dropout(x=bn, dropout_prob=0.4,
dropout_implementation='upscale_in_train',
is_test=False if is_train else True)
fc = fluid.layers.fc(
input=drop,
size=self.emb_dim,
act=None,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False, fan_in=0.0)),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer()))
emb = fluid.layers.batch_norm(input=fc, act=None, epsilon=2e-05,
is_test=False if is_train else True)
return emb
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
pad=0,
groups=1,
is_train=True,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=pad,
groups=groups,
act=None,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(
uniform=False, fan_in=0.0)),
bias_attr=False)
if act == 'prelu':
bn = fluid.layers.batch_norm(input=conv, act=None, epsilon=2e-05,
momentum=0.9, is_test=False if is_train else True)
return fluid.layers.prelu(bn, mode="all",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Constant(0.25)))
else:
return fluid.layers.batch_norm(input=conv, act=act, epsilon=2e-05,
is_test=False if is_train else True)
def shortcut(self, input, ch_out, stride, is_train):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride,
is_train=is_train)
else:
return input
def bottleneck_block(self, input, num_filters, stride, is_train):
if self.layers < 101:
bn1 = fluid.layers.batch_norm(input=input, act=None, epsilon=2e-05,
is_test=False if is_train else True)
conv1 = self.conv_bn_layer(
input=bn1, num_filters=num_filters, filter_size=3, pad=1,
act='prelu', is_train=is_train)
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters, filter_size=3,
stride=stride, pad=1, act=None, is_train=is_train)
else:
bn0 = fluid.layers.batch_norm(input=input, act=None, epsilon=2e-05,
is_test=False if is_train else True)
conv0 = self.conv_bn_layer(
input=bn0, num_filters=num_filters // 4, filter_size=1, pad=0,
act='prelu', is_train=is_train)
conv1 = self.conv_bn_layer(
input=conv0, num_filters=num_filters // 4, filter_size=3, pad=1,
act='prelu', is_train=is_train)
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters, filter_size=1,
stride=stride, pad=0, act=None, is_train=is_train)
short = self.shortcut(input, num_filters, stride, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv2, act=None)
def ResNet50(emb_dim=512):
model = ResNet(layers=50, emb_dim=emb_dim)
return model
def ResNet101(emb_dim=512):
model = ResNet(layers=101, emb_dim=emb_dim)
return model
def ResNet152(emb_dim=512):
model = ResNet(layers=152, emb_dim=emb_dim)
return model
#!/usr/bin/env bash
export FLAGS_cudnn_exhaustive_search=true
export FLAGS_fraction_of_gpu_memory_to_use=0.96
export FLAGS_eager_delete_tensor_gb=0.0
selected_gpus="0,1,2,3,4,5,6,7"
#selected_gpus="4,5,6"
python -m paddle.distributed.launch \
--selected_gpus $selected_gpus \
--log_dir mylog \
do_train.py \
--model=ResNet_ARCFACE50 \
--loss_type=dist_softmax \
--model_save_dir=output \
--margin=0.5 \
--train_batch_size 32 \
--class_dim 85742 \
--with_test=True
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import random
import pickle
import base64
import functools
import numpy as np
import paddle
from PIL import Image, ImageEnhance
try:
from StringIO import StringIO
except ImportError:
from io import BytesIO as StringIO  # Python 3: image payloads are bytes
random.seed(0)
DATA_DIM = 112
THREAD = 8
BUF_SIZE = 10240
img_mean = np.array([127.5, 127.5, 127.5]).reshape((3, 1, 1))
img_std = np.array([128.0, 128.0, 128.0]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.BILINEAR)
return img
def Scale(img, size):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), Image.BILINEAR)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), Image.BILINEAR)
def CenterCrop(img, size):
w, h = img.size
th, tw = int(size), int(size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return img.crop((x1, y1, x1 + tw, y1 + th))
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center:
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = random.randint(0, width - size)
h_start = random.randint(0, height - size)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def RandomResizedCrop(img, size):
for attempt in range(10):
area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area
aspect_ratio = random.uniform(3. / 4, 4. / 3)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5:
w, h = h, w
if w <= img.size[0] and h <= img.size[1]:
x1 = random.randint(0, img.size[0] - w)
y1 = random.randint(0, img.size[1] - h)
img = img.crop((x1, y1, x1 + w, y1 + h))
assert(img.size == (w, h))
return img.resize((size, size), Image.BILINEAR)
w = min(img.size[0], img.size[1])
i = (img.size[1] - w) // 2
j = (img.size[0] - w) // 2
# PIL crop boxes are (left, upper, right, lower): j is the horizontal offset.
img = img.crop((j, i, j + w, i + w))
img = img.resize((size, size), Image.BILINEAR)
return img
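When all ten random attempts fail, `RandomResizedCrop` falls back to a centered square crop. Pillow's `Image.crop` takes a `(left, upper, right, lower)` box, so the horizontal offset must come first. The sketch below computes that box from the image size alone. `center_square_box` is a hypothetical helper.

```python
def center_square_box(width, height):
    """Return the (left, upper, right, lower) box of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


landscape = center_square_box(200, 100)
portrait = center_square_box(100, 200)
```

For a square input the box covers the whole image, so the fallback then amounts to a plain resize.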
def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
aspect_ratio = math.sqrt(random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.size[0]) / img.size[1]) / (w**2),
(float(img.size[1]) / img.size[0]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.size[0] * img.size[1] * random.uniform(scale_min,
scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = random.randint(0, img.size[0] - w)
j = random.randint(0, img.size[1] - h)
img = img.crop((i, j, i + w, j + h))
img = img.resize((size, size), Image.BILINEAR)
return img
def rotate_image(img):
angle = random.randint(-10, 10)
img = img.rotate(angle)
return img
def distort_color(img):
def random_brightness(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
random.shuffle(ops)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
return img
def process_image_imagepath(sample,
class_dim,
color_jitter,
rotate,
rand_mirror,
normalize):
img_data = base64.b64decode(sample[0])
img = Image.open(StringIO(img_data))
if rotate:
img = rotate_image(img)
img = RandomResizedCrop(img, DATA_DIM)
if color_jitter:
img = distort_color(img)
if rand_mirror:
if random.randint(0, 1) == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1))
if normalize:
img -= img_mean
img /= img_std
assert sample[1] < class_dim, \
"label of train dataset should be less than the class_dim."
return img, sample[1]
def arc_iterator(file_list,
class_dim,
color_jitter=False,
rotate=False,
rand_mirror=False,
normalize=False):
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
def reader():
with open(file_list, 'r') as f:
flist = f.readlines()
assert len(flist) % trainer_count == 0, \
"Number of files should be divisible by trainer count, " \
"run base64 file preprocessing tool first."
num_files_per_trainer = len(flist) // trainer_count
start = num_files_per_trainer * trainer_id
end = start + num_files_per_trainer
flist = flist[start:end]
for file in flist:
with open(file, 'r') as f:
for line in f:
line = line.strip().split('\t')
# Convert the label to int so it can be compared with class_dim.
image, label = line[0], int(line[1])
yield image, label
mapper = functools.partial(process_image_imagepath,
class_dim=class_dim, color_jitter=color_jitter, rotate=rotate,
rand_mirror=rand_mirror, normalize=normalize)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
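The reader above does two things: it gives each trainer a contiguous, equally sized slice of the file list, and it parses every line as `<base64 image>\t<label>`. The sketch below shows both steps using only the standard library. The helper names, file names and payload are hypothetical.

```python
import base64


def split_for_trainer(items, trainer_id, trainer_count):
    """Return the contiguous slice of `items` owned by one trainer."""
    assert len(items) % trainer_count == 0, \
        "number of files must be divisible by the trainer count"
    per_trainer = len(items) // trainer_count
    start = per_trainer * trainer_id
    return items[start:start + per_trainer]


def parse_line(line):
    """Decode one '<base64>\\t<label>' line into (bytes, int)."""
    encoded, label = line.strip().split('\t')[:2]
    return base64.b64decode(encoded), int(label)


files = ['part-0', 'part-1', 'part-2', 'part-3']
mine = split_for_trainer(files, trainer_id=1, trainer_count=2)
line = base64.b64encode(b'raw-image-bytes').decode('ascii') + '\t7\n'
image_bytes, label = parse_line(line)
```

The decoded payload is a byte string, so on Python 3 it has to be wrapped in a bytes buffer such as `io.BytesIO` before it is handed to `Image.open`.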
def load_bin(path, image_size):
bins, issame_list = pickle.load(open(path, 'rb'))
data_list = []
for flip in [0, 1]:
data = np.empty((len(issame_list)*2, 3, image_size[0], image_size[1]))
data_list.append(data)
for i in range(len(issame_list)*2):
_bin = bins[i]
if not isinstance(_bin, bytes):
_bin = _bin.tobytes()
img_ori = Image.open(StringIO(_bin))
for flip in [0, 1]:
img = img_ori.copy()
if flip == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1))
img -= img_mean
img /= img_std
data_list[flip][i][:] = img
if i % 1000 == 0:
print('loading bin', i)
print(data_list[0].shape)
return (data_list, issame_list)
def train(data_dir, file_list, num_classes):
file_path = os.path.join(data_dir, file_list)
return arc_iterator(file_path, class_dim=num_classes, color_jitter=False,
rotate=False, rand_mirror=True, normalize=True)
def test(data_dir, datasets):
test_list = []
test_name_list = []
for name in datasets.split(','):
path = os.path.join(data_dir, name+".bin")
if os.path.exists(path):
data_set = load_bin(path, (DATA_DIM, DATA_DIM))
test_list.append(data_set)
test_name_list.append(name)
print('test', name)
return test_list, test_name_list
import os
import math
import random
import pickle
import functools
import numpy as np
import paddle
from PIL import Image, ImageEnhance
try:
from StringIO import StringIO
except ImportError:
from io import BytesIO as StringIO  # Python 3: image payloads are bytes
random.seed(0)
DATA_DIM = 112
THREAD = 8
BUF_SIZE = 10240
#TEST_LIST = 'lfw,cfp_fp,agedb_30,cfp_ff'
TEST_LIST = 'lfw'
def get_train_image_list(data_dir):
train_list_file = os.path.join(data_dir, 'label.txt')
train_list = open(train_list_file, "r").readlines()
random.shuffle(train_list)
train_image_list = []
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = (len(train_list) + trainer_count - 1) // trainer_count
train_list += train_list[0:per_node_lines
* trainer_count-len(train_list)]
lines = train_list[trainer_id * per_node_lines:(trainer_id + 1)
* per_node_lines]
print("read images from %d, length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(train_list)))
for i, item in enumerate(lines):
path, label = item.strip().split()
label = int(label)
train_image_list.append((path, label))
print("train_data size:", len(train_image_list))
return train_image_list
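`get_train_image_list` pads the shuffled list by repeating its first entries so that every trainer receives the same number of samples. The sketch below isolates that padding and slicing. `shard_with_padding` is a hypothetical helper and the sample names are placeholders.

```python
def shard_with_padding(lines, trainer_id, trainer_count):
    """Return one trainer's slice after padding to a multiple of trainer_count."""
    per_node = (len(lines) + trainer_count - 1) // trainer_count
    # Reuse the first few entries to fill the last trainer's slice.
    padded = lines + lines[:per_node * trainer_count - len(lines)]
    return padded[trainer_id * per_node:(trainer_id + 1) * per_node]


lines = ['a', 'b', 'c', 'd', 'e']
shards = [shard_with_padding(lines, i, 2) for i in range(2)]
```

A side effect of this padding is that a few samples are seen twice per epoch, once by the first trainer and once by the last.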
img_mean = np.array([127.5, 127.5, 127.5]).reshape((3, 1, 1))
img_std = np.array([128.0, 128.0, 128.0]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.BILINEAR)
return img
def Scale(img, size):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), Image.BILINEAR)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), Image.BILINEAR)
def CenterCrop(img, size):
w, h = img.size
th, tw = int(size), int(size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return img.crop((x1, y1, x1 + tw, y1 + th))
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center:
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = random.randint(0, width - size)
h_start = random.randint(0, height - size)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def RandomResizedCrop(img, size):
for attempt in range(10):
area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area
aspect_ratio = random.uniform(3. / 4, 4. / 3)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5:
w, h = h, w
if w <= img.size[0] and h <= img.size[1]:
x1 = random.randint(0, img.size[0] - w)
y1 = random.randint(0, img.size[1] - h)
img = img.crop((x1, y1, x1 + w, y1 + h))
assert(img.size == (w, h))
return img.resize((size, size), Image.BILINEAR)
w = min(img.size[0], img.size[1])
i = (img.size[1] - w) // 2
j = (img.size[0] - w) // 2
# PIL crop boxes are (left, upper, right, lower): j is the horizontal offset.
img = img.crop((j, i, j + w, i + w))
img = img.resize((size, size), Image.BILINEAR)
return img
def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
aspect_ratio = math.sqrt(random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.size[0]) / img.size[1]) / (w**2),
(float(img.size[1]) / img.size[0]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.size[0] * img.size[1] * random.uniform(scale_min,
scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = random.randint(0, img.size[0] - w)
j = random.randint(0, img.size[1] - h)
img = img.crop((i, j, i + w, j + h))
img = img.resize((size, size), Image.BILINEAR)
return img
def rotate_image(img):
angle = random.randint(-10, 10)
img = img.rotate(angle)
return img
def distort_color(img):
def random_brightness(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
random.shuffle(ops)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
return img
def process_image_imagepath(sample,
class_dim,
color_jitter,
rotate,
rand_mirror,
normalize):
imgpath = sample[0]
img = Image.open(imgpath)
if rotate:
img = rotate_image(img)
img = RandomResizedCrop(img, DATA_DIM)
if color_jitter:
img = distort_color(img)
if rand_mirror:
if random.randint(0, 1) == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1))
if normalize:
img -= img_mean
img /= img_std
assert sample[1] < class_dim, \
"label of train dataset should be less than the class_dim."
return img, sample[1]
def arc_iterator(data,
class_dim,
data_dir,
shuffle=False,
color_jitter=False,
rotate=False,
rand_mirror=False,
normalize=False):
def reader():
if shuffle:
random.shuffle(data)
for j in range(len(data)):
path, label = data[j]
path = os.path.join(data_dir, path)
yield path, label
mapper = functools.partial(process_image_imagepath, class_dim=class_dim,
color_jitter=color_jitter, rotate=rotate,
rand_mirror=rand_mirror, normalize=normalize)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def load_bin(path, image_size):
bins, issame_list = pickle.load(open(path, 'rb'))
data_list = []
for flip in [0, 1]:
data = np.empty((len(issame_list)*2, 3, image_size[0], image_size[1]))
data_list.append(data)
for i in range(len(issame_list)*2):
_bin = bins[i]
if not isinstance(_bin, bytes):
_bin = _bin.tobytes()
img_ori = Image.open(StringIO(_bin))
for flip in [0, 1]:
img = img_ori.copy()
if flip == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1))
img -= img_mean
img /= img_std
data_list[flip][i][:] = img
if i % 1000 == 0:
print('loading bin', i)
print(data_list[0].shape)
return (data_list, issame_list)
def arc_train(data_dir, class_dim):
train_image_list = get_train_image_list(data_dir)
return arc_iterator(train_image_list, shuffle=True, class_dim=class_dim,
data_dir=data_dir, color_jitter=False, rotate=False, rand_mirror=True,
normalize=True)
def test(data_dir, datasets):
test_list = []
test_name_list = []
for name in datasets.split(','):
path = os.path.join(data_dir, name+".bin")
if os.path.exists(path):
data_set = load_bin(path, (DATA_DIM, DATA_DIM))
test_list.append(data_set)
test_name_list.append(name)
print('test', name)
return test_list, test_name_list
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """
    Applies linear learning rate warmup for distributed training.
    The parameter learning_rate should be a float or a Variable.
    lr = start_lr + (warmup_rate * step / warmup_steps), where warmup_rate
    is end_lr - start_lr, and step is the current step.
    """
    assert (isinstance(end_lr, float))
    assert (isinstance(start_lr, float))
    linear_step = end_lr - start_lr
    with fluid.default_main_program()._lr_schedule_guard():
        lr = fluid.layers.tensor.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=True,
            name="learning_rate_warmup")
        global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
        with fluid.layers.control_flow.Switch() as switch:
            # During warmup, interpolate linearly from start_lr to end_lr.
            with switch.case(global_step < warmup_steps):
                decayed_lr = start_lr + linear_step * (global_step /
                                                       warmup_steps)
                fluid.layers.tensor.assign(decayed_lr, lr)
            # After warmup, hand over to the regular schedule.
            with switch.default():
                fluid.layers.tensor.assign(learning_rate, lr)
        return lr
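The warmup rule documented in `lr_warmup` can be checked without building a Paddle program. The following is a minimal plain-Python sketch of the same formula; the helper name `warmup_lr`, the `base_lr` argument, and the sample values are illustrative assumptions and are not part of PLSC.

```python
def warmup_lr(step, warmup_steps, start_lr, end_lr, base_lr):
    """Hypothetical helper mirroring the formula in the lr_warmup docstring."""
    if step < warmup_steps:
        # Linear interpolation from start_lr towards end_lr.
        return start_lr + (end_lr - start_lr) * (float(step) / warmup_steps)
    # After warmup, the regular learning rate is used unchanged.
    return base_lr


if __name__ == "__main__":
    for step in (0, 50, 100, 150):
        print(step, warmup_lr(step, 100, 0.0, 0.1, 0.1))
```

In the graph version above, `learning_rate` may itself be a Variable produced by a decay schedule, which this sketch replaces with the constant `base_lr`.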
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import warnings
import os
import six
import logging
import argparse
import shutil
import pickle
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.transpiler.details import program_to_code
logging.basicConfig(level=logging.INFO,
format='[%(levelname)s %(asctime)s line:%(lineno)d] %(message)s',
datefmt='%d %b %Y %H:%M:%S')
logger = logging.getLogger()
parser = argparse.ArgumentParser(description="""
Tool to convert pretrained distributed fc parameters for inference.
Note that the number of ranks or GPUs for inference can be different
from that for pretraining.""")
parser.add_argument("--name_feature",
type=str,
default="@rank@",
help="Feature for names of distributed fc parameters. "
"For example, by default the name for the "
"distributed fc weight parameter is like "
"dist@xxx@rank@rank_id.w_0 where xxx is softmax or arcface "
"depending on the loss type used and rank_id is the "
"id of the rank generating this parameter, and hence the "
"feature could be @rank@.")
parser.add_argument("--pretrain_nranks",
type=int,
default=-1,
help="Number of ranks (GPUs) for pre-training.")
parser.add_argument("--nranks",
type=int,
required=True,
help="Number of ranks (GPUs) for inference or finetuning.")
parser.add_argument("--num_classes",
type=int,
default=-1,
help="Number of classes for classification.")
parser.add_argument("--emb_dim",
type=int,
default=-1,
help="Embedding dim.")
parser.add_argument("--pretrained_model_dir",
type=str,
required=True,
default=None,
help="Directory for pretrained model.")
parser.add_argument("--output_dir",
type=str,
required=True,
default=None,
help="Directory for output.")
args = parser.parse_args()
def load_config(args):
"""
Load config file which contains the following information for pretraining:
1. pretrain_nranks (int): number of ranks for pretraining;
2. emb_dim (int): embedding dim for pretraining;
3. num_classes (int): number of classes for classification.
"""
meta_file = os.path.join(args.pretrained_model_dir, 'meta.pickle')
if not os.path.exists(meta_file):
if args.pretrain_nranks < 0 or args.emb_dim < 0 or args.num_classes < 0:
logger.error("Meta file does not exist, you have to set "
"'--pretrain_nranks', '--emb_dim' and '--num_classes' "
"parameters manually.")
exit()
logger.debug("Meta file does not exist, make sure you have correctly "
"set --pretrain_nranks ({}), --emb_dim ({}) and "
"--num_classes ({}) parameters manually.".format(
args.pretrain_nranks, args.emb_dim, args.num_classes))
else:
with open(meta_file, 'rb') as handle:
config = pickle.load(handle)
if args.pretrain_nranks < 0:
args.pretrain_nranks = config['pretrain_nranks']
elif args.pretrain_nranks != config['pretrain_nranks']:
logger.error("The --pretrain_nranks ({}) parameter you set is not "
"equal to that ({}) for pretraining, please check "
"it.".format(args.pretrain_nranks,
config['pretrain_nranks']))
exit()
if args.emb_dim < 0:
args.emb_dim = config['emb_dim']
elif args.emb_dim != config['emb_dim']:
logger.error("The --emb_dim ({}) parameter you set is not equal to "
"that ({}) for pretraining, please check it.".format(
args.emb_dim, config['emb_dim']))
exit()
if args.num_classes < 0:
args.num_classes = config['num_classes']
elif args.num_classes != config['num_classes']:
logger.error("The --num_classes ({}) parameter you set is not equal"
" to that ({}) for pretraining, please check "
"it.".format(args.num_classes, config['num_classes']))
exit()
logger.debug("Parameters for pretraining: pretrain_nranks ({}), emb_dim "
"({}), and num_classes ({}).".format(args.pretrain_nranks,
args.emb_dim, args.num_classes))
logger.debug("Parameters for inference or finetuning: nranks ({}).".format(
args.nranks))
def find_distfc_var_names(args):
"""
Find all names of pretrained distfc-related parameters,
e.g., dist@softmax@rank@00000.w_0, dist@softmax@rank@00000.b_0 etc.
We assume that names of distfc-related parameters contain the
feature string given by --name_feature ('@rank@' by default).
"""
var_names = []
model_dir = os.path.abspath(args.pretrained_model_dir)
if not os.path.exists(model_dir):
logger.error("The directory for pretrained model ({}) does not exist, "
"please check it.".format(model_dir))
exit()
logger.info("The directory for pretrained model: {}".format(model_dir))
args.pretrained_model_dir = model_dir
for file in os.listdir(model_dir):
if args.name_feature in file:
var_names.append(file)
assert len(var_names) > 0, \
logger.error("No distributed fc parameters found.")
logger.info("Number of distributed fc parameters: {}.".format(
len(var_names)))
logger.debug("Distributed fc parameters: {}.".format(var_names))
return var_names
def split_load_and_save(args,
name_index,
param_names,
save_rank_id,
remainder,
as_bias,
train_nshards,
train_nranks,
nshards,
dtype="float32"):
var2 = None
advance = False
emb_dim = args.emb_dim
main_program = fluid.Program()
startup_program = fluid.Program()
load_var_name = param_names[name_index]
save_var_name_list = load_var_name.split('.')
save_var_name_list[0] = save_var_name_list[0].split('@')
save_var_name_list[0][-1] = "%05d" % save_rank_id
save_var_name_list[0] = '@'.join(save_var_name_list[0])
save_var_name = '.'.join(save_var_name_list)
last_train_nshards = args.num_classes - (train_nranks - 1) * train_nshards
with fluid.program_guard(main_program, startup_program):
if name_index == train_nranks - 1:
var_dim = last_train_nshards
else:
var_dim = train_nshards
shape = [var_dim] if as_bias else [emb_dim, var_dim]
var = fluid.layers.create_parameter(shape, dtype=dtype,
name=load_var_name)
if as_bias:
var = fluid.layers.slice(var, axes=[0],
starts=[var.shape[0] - remainder], ends=[var.shape[0]])
else:
var = fluid.layers.split(var, [var.shape[1] - remainder, remainder],
dim=1)[1]
save_var_dim = nshards
if remainder < nshards:
if name_index == train_nranks - 1:
save_var_dim = remainder
else:
name_index += 1
advance = True
load_var_name = param_names[name_index]
if name_index == train_nranks - 1:
var_dim = last_train_nshards
else:
var_dim = train_nshards
shape = [var_dim] if as_bias else [emb_dim, var_dim]
var2 = fluid.layers.create_parameter(shape, dtype=dtype,
name=load_var_name)
if remainder + var_dim < nshards:
# The last train rank
save_var_dim = remainder + var_dim
else:
remainder = remainder + var_dim - nshards
elif remainder == nshards:
if name_index == train_nranks - 2:
remainder = last_train_nshards
advance = True
elif name_index < train_nranks - 2:
remainder = train_nshards
advance = True
else:
remainder = remainder - nshards
if var2 is not None:
var = fluid.layers.concat([var, var2], axis=0 if as_bias else 1)
shape = [save_var_dim] if as_bias else [emb_dim, save_var_dim]
to_save_var = fluid.layers.create_parameter(shape, dtype=dtype,
name=save_var_name + '_temp')
if save_var_dim != nshards: # get last dim
if as_bias:
temp_var = fluid.layers.slice(var, axes=[0],
starts=[var.shape[0] - save_var_dim], ends=[var.shape[0]])
else:
temp_var = fluid.layers.split(var,
[var.shape[1] - save_var_dim, save_var_dim], dim=1)[1]
fluid.layers.assign(temp_var, to_save_var)
else:
if as_bias:
temp_var = fluid.layers.slice(var, axes=[0], starts=[0],
ends=[nshards])
else:
temp_var = fluid.layers.split(var,
[nshards, var.shape[1] - nshards], dim=1)[0]
fluid.layers.assign(temp_var, to_save_var)
def expected_var(var):
has_var = os.path.exists(os.path.join(args.pretrained_model_dir,
var.name))
if has_var:
return True
return False
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
fluid.io.load_vars(exe, dirname=args.pretrained_model_dir,
predicate=expected_var, main_program=main_program)
exe.run(main_program)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
fluid.io.save_vars(exe, args.output_dir, vars=[to_save_var],
main_program=main_program)
srcfile = os.path.join(args.output_dir, to_save_var.name)
dstfile = os.path.join(args.output_dir, save_var_name)
shutil.move(srcfile, dstfile)
return remainder, advance
def split_parameters(args, param_names, as_bias):
"""
Split parameters whose names are in param_names.
Params:
args: command line parameters
param_names: list of names of parameters to split
as_bias: whether parameters to split are as bias or not
"""
num_classes = args.num_classes
train_nranks = args.pretrain_nranks
nranks = args.nranks
train_nshards = (num_classes + train_nranks - 1) // train_nranks
nshards = (num_classes + nranks - 1) // nranks # for inference or finetuning
save_rank_id = 0
remainder_var_dim = train_nshards # remainder dim that is not split in a var
name_index = 0 # index of name of pretrained parameter to process
for save_rank_id in range(nranks):
assert name_index < train_nranks
remainder_var_dim, advance = split_load_and_save(args, name_index,
param_names, save_rank_id, remainder_var_dim, as_bias,
train_nshards, train_nranks, nshards)
name_index += 1 if advance else 0
processed_var_count = name_index + 1
assert processed_var_count == train_nranks, logger.error("Number of "
"pretrained parameters processed ({}) is not equal to the number of "
"ranks ({}) for pretraining.".format(processed_var_count, train_nranks))
assert save_rank_id == nranks - 1, logger.error("Number of saved parameters"
" ({}) is not equal to the number of ranks ({}) for inference or "
"finetuning.".format(save_rank_id + 1, nranks))
def split_distfc_parameters(args,
weight_param_names,
weight_velocity_param_names,
bias_param_names,
bias_velocity_param_names):
"""
Split each distributed fc-related parameter according to number of ranks
for inference or finetuning.
Params:
args: command line parameters
weight_param_names: list of names of weight parameters
bias_param_names: list of names of bias parameters
"""
split_parameters(args, weight_param_names, as_bias=False)
split_parameters(args, weight_velocity_param_names, as_bias=False)
if len(bias_param_names) != 0:
split_parameters(args, bias_param_names, as_bias=True)
split_parameters(args, bias_velocity_param_names, as_bias=True)
def concat_load_and_save(args,
name_index,
param_names,
save_rank_id,
remainder,
as_bias,
train_nshards,
train_nranks,
nshards,
dtype="float32"):
advance = 0
orig_nshards = nshards
emb_dim = args.emb_dim
main_program = fluid.Program()
startup_program = fluid.Program()
load_var_name = param_names[name_index]
save_var_name_list = load_var_name.split('.')
save_var_name_list[0] = save_var_name_list[0].split('@')
save_var_name_list[0][-1] = "%05d" % save_rank_id
save_var_name_list[0] = '@'.join(save_var_name_list[0])
save_var_name = '.'.join(save_var_name_list)
last_train_nshards = args.num_classes - (train_nranks - 1) * train_nshards
with fluid.program_guard(main_program, startup_program):
if name_index == train_nranks - 1:
var_dim = last_train_nshards
else:
var_dim = train_nshards
shape = [var_dim] if as_bias else [emb_dim, var_dim]
var = fluid.layers.create_parameter(shape, dtype=dtype,
name=load_var_name)
if as_bias:
var = fluid.layers.slice(var, axes=[0],
starts=[var.shape[0] - remainder], ends=[var.shape[0]])
else:
var = fluid.layers.split(var, [var.shape[1] - remainder, remainder],
dim=1)[1]
to_concat_var_list = [var]
while remainder < nshards and name_index < train_nranks - 1:
name_index += 1
advance += 1
load_var_name = param_names[name_index]
if name_index == train_nranks - 1:
var_dim = last_train_nshards
else:
var_dim = train_nshards
shape = [var_dim] if as_bias else [emb_dim, var_dim]
var = fluid.layers.create_parameter(shape, dtype=dtype,
name=load_var_name)
to_concat_var_list.append(var)
remainder += var_dim
if len(to_concat_var_list) > 1:
var = fluid.layers.concat(
to_concat_var_list, axis=0 if as_bias else 1)
save_var_dim = nshards
if remainder > nshards:
if as_bias:
var = fluid.layers.slice(var, axes=[0], starts=[0],
ends=[nshards])
else:
var = fluid.layers.split(var,
[nshards, var.shape[1] - nshards], dim=1)[0]
remainder = remainder - nshards
elif remainder == nshards:
if name_index == train_nranks - 2:
#advance += 1 if len(to_concat_var_list) > 1 else 0 # to avoid duplicate add
#name_index += 1 if len(to_concat_var_list) > 1 else 0 # to avoid duplicate add
advance += 1
name_index += 1
remainder = last_train_nshards
elif name_index < train_nranks - 2:
#advance += 1 if len(to_concat_var_list) > 1 else 0 # to avoid duplicate add
#name_index += 1 if len(to_concat_var_list) > 1 else 0 # to avoid duplicate add
advance += 1
name_index += 1
remainder = train_nshards
else:
save_var_dim = remainder
shape = [save_var_dim] if as_bias else [emb_dim, save_var_dim]
to_save_var = fluid.layers.create_parameter(shape, dtype=dtype,
name=save_var_name + '_temp')
fluid.layers.assign(var, to_save_var)
def expected_var(var):
has_var = os.path.exists(os.path.join(args.pretrained_model_dir,
var.name))
if has_var:
return True
return False
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
fluid.io.load_vars(exe, dirname=args.pretrained_model_dir,
predicate=expected_var, main_program=main_program)
exe.run(main_program)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
fluid.io.save_vars(exe, args.output_dir, vars=[to_save_var],
main_program=main_program)
srcfile = os.path.join(args.output_dir, to_save_var.name)
dstfile = os.path.join(args.output_dir, save_var_name)
shutil.move(srcfile, dstfile)
return remainder, advance
def concat_parameters(args, param_names, as_bias):
"""
Concat parameters whose names are in param_names.
Params:
args: command line parameters
param_names: list of names of parameters to concat
as_bias: whether parameters to concat are as bias or not
"""
num_classes = args.num_classes
train_nranks = args.pretrain_nranks
nranks = args.nranks
train_nshards = (num_classes + train_nranks - 1) // train_nranks
nshards = (num_classes + nranks - 1) // nranks # for inference or finetuning
save_rank_id = 0
remainder_dim = train_nshards # remainder dim that is not concatted
name_index = 0 # index of name of pretrained parameter to process
for save_rank_id in range(nranks):
assert name_index < train_nranks
remainder_dim, advance = concat_load_and_save(args,
name_index, param_names, save_rank_id, remainder_dim,
as_bias, train_nshards, train_nranks, nshards)
name_index += advance
processed_var_count = name_index + 1
assert processed_var_count == train_nranks, logger.error("Number of "
"pretrained parameters processed ({}) is not equal to the number of "
"ranks ({}) for pretraining.".format(processed_var_count, train_nranks))
assert save_rank_id == nranks - 1, logger.error("Number of saved parameters"
" ({}) is not equal to the number of ranks ({}) for inference or "
"finetuning.".format(save_rank_id + 1, nranks))
def concat_distfc_parameters(args,
weight_param_names,
weight_velocity_param_names,
bias_param_names,
bias_velocity_param_names):
"""
Concat distributed fc-related parameters according to number of ranks
for inference or finetuning.
Params:
args: command line parameters
weight_param_names: list of names of weight parameters
bias_param_names: list of names of bias parameters
"""
concat_parameters(args, weight_param_names, as_bias=False)
concat_parameters(args, weight_velocity_param_names, as_bias=False)
if len(bias_param_names) != 0:
concat_parameters(args, bias_param_names, as_bias=True)
concat_parameters(args, bias_velocity_param_names, as_bias=True)
def parameter_name_compare(x, y):
    """
    Compare two parameter names depending on their rank ids.
    A parameter name is like dist@softmax@rank@00000.w_0,
    where 00000 is the rank id.
    """
    rank_id_x = int(x.split('.')[0].split('@')[-1])
    rank_id_y = int(y.split('.')[0].split('@')[-1])
    if rank_id_x < rank_id_y:
        return -1
    elif rank_id_x == rank_id_y:
        return 0
    else:
        return 1
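Ordering parameter files by rank id can also be expressed with a sort key, which works on both Python 2 and Python 3. This is a minimal sketch; the helper name `rank_id` and the sample file names are assumptions that follow the `dist@xxx@rank@rank_id.w_0` pattern described in the `--name_feature` help text.

```python
def rank_id(name):
    """Hypothetical helper: extract the integer rank id from a parameter name."""
    return int(name.split('.')[0].split('@')[-1])


# Sample names are made up for illustration only.
names = [
    "dist@arcface@rank@00002.w_0",
    "dist@arcface@rank@00000.w_0",
    "dist@arcface@rank@00001.w_0",
]
sorted_names = sorted(names, key=rank_id)

if __name__ == "__main__":
    for name in sorted_names:
        print(name)
```

A key function is evaluated once per element, whereas a comparison function is called for every pair that the sort compares.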
def main():
global args
load_config(args)
var_names = find_distfc_var_names(args)
weight_param_names = [name for name in var_names
if '.w' in name and 'velocity' not in name]
weight_velocity_param_names = [name for name in var_names
if '.w' in name and 'velocity' in name]
bias_param_names = [name for name in var_names
if '.b' in name and 'velocity' not in name]
bias_velocity_param_names = [name for name in var_names
if '.b' in name and 'velocity' in name]
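# list.sort() takes no positional comparison function on Python 3, so sort by
# the rank id encoded in each name; the resulting order matches
# parameter_name_compare.
weight_param_names.sort(key=lambda n: int(n.split('.')[0].split('@')[-1]))
weight_velocity_param_names.sort(key=lambda n: int(n.split('.')[0].split('@')[-1]))
bias_param_names.sort(key=lambda n: int(n.split('.')[0].split('@')[-1]))
bias_velocity_param_names.sort(key=lambda n: int(n.split('.')[0].split('@')[-1]))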
assert len(weight_param_names) == args.pretrain_nranks, \
logger.error("Number of distributed fc-related weight parameters ({}) "
"should be equal to the number of ranks ({}) for "
"pretraining.".format(len(weight_param_names),
args.pretrain_nranks))
assert len(weight_velocity_param_names) == args.pretrain_nranks, \
logger.error("Number of distributed fc-related weight velocity parameters ({}) "
"should be equal to the number of ranks ({}) for "
"pretraining.".format(len(weight_velocity_param_names),
args.pretrain_nranks))
assert len(bias_param_names) == 0 or \
len(bias_param_names) == args.pretrain_nranks, logger.error("Number of "
"distributed fc-related bias parameters ({}) should be 0 or equal "
"to the number of ranks ({}) for pretraining.".format(
len(bias_param_names), args.pretrain_nranks))
assert len(bias_velocity_param_names) == 0 or \
len(bias_velocity_param_names) == args.pretrain_nranks, logger.error("Number of "
"distributed fc-related bias velocity parameters ({}) should be 0 or equal "
"to the number of ranks ({}) for pretraining.".format(
len(bias_velocity_param_names), args.pretrain_nranks))
pretrain_nranks = args.pretrain_nranks
nranks = args.nranks
if pretrain_nranks == nranks:
logger.info("Pre-training and inference (or finetuning) have the same "
"number of ranks, nothing to do.")
elif pretrain_nranks < nranks:
split_distfc_parameters(args, weight_param_names,
weight_velocity_param_names, bias_param_names,
bias_velocity_param_names)
else:
concat_distfc_parameters(args, weight_param_names,
weight_velocity_param_names, bias_param_names,
bias_velocity_param_names)
logger.info("Done.")
if __name__ == "__main__":
main()
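Both `split_parameters` and `concat_parameters` derive shard widths with the same ceiling division, and every rank except the last holds the full width. The sketch below reproduces that arithmetic in isolation; the helper name `shard_widths` and the `ValueError` check are assumptions added for illustration, and the class count is the default `num_classes` from the configuration table.

```python
def shard_widths(num_classes, nranks):
    """Hypothetical helper: number of classes held by each rank."""
    per_rank = (num_classes + nranks - 1) // nranks  # ceiling division
    last = num_classes - (nranks - 1) * per_rank     # what is left for the last rank
    if last <= 0:
        # Assumption: the converter is only meaningful when every rank holds
        # at least one class; the original script does not check this.
        raise ValueError("num_classes is too small for this number of ranks")
    return [per_rank] * (nranks - 1) + [last]


if __name__ == "__main__":
    print(shard_widths(85742, 8))   # widths used when pretraining on 8 ranks
    print(shard_widths(85742, 16))  # widths needed for 16 inference ranks
```

Converting from 8 to 16 ranks therefore means slicing each pretrained shard into narrower pieces, while going from 16 to 8 means concatenating neighbouring shards.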
"""Helper for evaluation on the Labeled Faces in the Wild dataset
"""
# MIT License
#
# Copyright (c) 2016 David Sandberg
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
from sklearn.model_selection import KFold
from scipy import interpolate
import sklearn.preprocessing  # calculate_roc calls sklearn.preprocessing.normalize
import math
import datetime
import pickle
from sklearn.decomposition import PCA
class LFold:
def __init__(self, n_splits = 2, shuffle = False):
self.n_splits = n_splits
if self.n_splits>1:
self.k_fold = KFold(n_splits = n_splits, shuffle = shuffle)
def split(self, indices):
if self.n_splits>1:
return self.k_fold.split(indices)
else:
return [(indices, indices)]
def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, pca = 0):
assert(embeddings1.shape[0] == embeddings2.shape[0])
assert(embeddings1.shape[1] == embeddings2.shape[1])
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
nrof_thresholds = len(thresholds)
k_fold = LFold(n_splits=nrof_folds, shuffle=False)
tprs = np.zeros((nrof_folds,nrof_thresholds))
fprs = np.zeros((nrof_folds,nrof_thresholds))
accuracy = np.zeros((nrof_folds))
indices = np.arange(nrof_pairs)
#print('pca', pca)
if pca==0:
diff = np.subtract(embeddings1, embeddings2)
dist = np.sum(np.square(diff),1)
for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
#print('train_set', train_set)
#print('test_set', test_set)
if pca>0:
print('doing pca on', fold_idx)
embed1_train = embeddings1[train_set]
embed2_train = embeddings2[train_set]
_embed_train = np.concatenate( (embed1_train, embed2_train), axis=0 )
#print(_embed_train.shape)
pca_model = PCA(n_components=pca)
pca_model.fit(_embed_train)
embed1 = pca_model.transform(embeddings1)
embed2 = pca_model.transform(embeddings2)
embed1 = sklearn.preprocessing.normalize(embed1)
embed2 = sklearn.preprocessing.normalize(embed2)
#print(embed1.shape, embed2.shape)
diff = np.subtract(embed1, embed2)
dist = np.sum(np.square(diff),1)
# Find the best threshold for the fold
acc_train = np.zeros((nrof_thresholds))
for threshold_idx, threshold in enumerate(thresholds):
_, _, acc_train[threshold_idx] = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])
best_threshold_index = np.argmax(acc_train)
#print('threshold', thresholds[best_threshold_index])
for threshold_idx, threshold in enumerate(thresholds):
tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set])
_, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])
tpr = np.mean(tprs,0)
fpr = np.mean(fprs,0)
return tpr, fpr, accuracy
def calculate_accuracy(threshold, dist, actual_issame):
    # A pair is predicted as "same" when its distance is below the threshold.
    predict_issame = np.less(dist, threshold)
    tp = np.sum(np.logical_and(predict_issame, actual_issame))
    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
    tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))
    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))

    tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn)
    fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn)
    acc = float(tp+tn)/dist.size
    return tpr, fpr, acc
def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10):
assert(embeddings1.shape[0] == embeddings2.shape[0])
assert(embeddings1.shape[1] == embeddings2.shape[1])
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
nrof_thresholds = len(thresholds)
k_fold = LFold(n_splits=nrof_folds, shuffle=False)
val = np.zeros(nrof_folds)
far = np.zeros(nrof_folds)
diff = np.subtract(embeddings1, embeddings2)
dist = np.sum(np.square(diff),1)
indices = np.arange(nrof_pairs)
for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
# Find the threshold that gives FAR = far_target
far_train = np.zeros(nrof_thresholds)
for threshold_idx, threshold in enumerate(thresholds):
_, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])
if np.max(far_train)>=far_target:
f = interpolate.interp1d(far_train, thresholds, kind='slinear')
threshold = f(far_target)
else:
threshold = 0.0
val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])
val_mean = np.mean(val)
far_mean = np.mean(far)
val_std = np.std(val)
return val_mean, val_std, far_mean
def calculate_val_far(threshold, dist, actual_issame):
predict_issame = np.less(dist, threshold)
true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
n_same = np.sum(actual_issame)
n_diff = np.sum(np.logical_not(actual_issame))
#print(true_accept, false_accept)
#print(n_same, n_diff)
val = float(true_accept) / float(n_same)
far = float(false_accept) / float(n_diff)
return val, far
def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
# Calculate evaluation metrics
thresholds = np.arange(0, 4, 0.01)
embeddings1 = embeddings[0::2]
embeddings2 = embeddings[1::2]
tpr, fpr, accuracy = calculate_roc(thresholds, embeddings1, embeddings2,
np.asarray(actual_issame), nrof_folds=nrof_folds, pca = pca)
thresholds = np.arange(0, 4, 0.001)
val, val_std, far = calculate_val(thresholds, embeddings1, embeddings2,
np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds)
return tpr, fpr, accuracy, val, val_std, far
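The per-threshold metrics computed by `calculate_accuracy` can be illustrated on a handful of pairs. The sketch below is a pure-Python restatement that avoids NumPy; the function name `accuracy_at_threshold` and the toy distances are invented for illustration and do not come from any dataset.

```python
def accuracy_at_threshold(threshold, dists, issame):
    """Hypothetical helper: TPR, FPR and accuracy for one distance threshold."""
    predicted = [d < threshold for d in dists]
    tp = sum(1 for p, s in zip(predicted, issame) if p and s)
    fp = sum(1 for p, s in zip(predicted, issame) if p and not s)
    tn = sum(1 for p, s in zip(predicted, issame) if not p and not s)
    fn = sum(1 for p, s in zip(predicted, issame) if not p and s)
    tpr = 0.0 if tp + fn == 0 else tp / float(tp + fn)
    fpr = 0.0 if fp + tn == 0 else fp / float(fp + tn)
    acc = (tp + tn) / float(len(dists))
    return tpr, fpr, acc


if __name__ == "__main__":
    # Squared distances for four pairs; the first two pairs show the same identity.
    toy_dists = [0.2, 0.9, 1.5, 0.4]
    toy_issame = [True, True, False, False]
    print(accuracy_at_threshold(1.0, toy_dists, toy_issame))
```

`calculate_roc` repeats this computation for every threshold on the training folds, picks the threshold with the highest accuracy, and reports the accuracy of that threshold on the held-out fold.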
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup, find_packages
setup(name="plsc",
version="0.1.0",
description="Large Scale Classification via distributed fc.",
author='lilong',
author_email="lilong.albert@gmail.com",
url="http",
license="Apache",
#packages=['paddleXML'],
packages=find_packages(),
#install_requires=['paddlepaddle>=1.6.1'],
python_requires='>=2'
)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import os
import argparse
import random
import time
import math
import logging
import sqlite3
import tempfile
logging.basicConfig(level=logging.INFO,
format='[%(levelname)s %(asctime)s line:%(lineno)d] %(message)s',
datefmt='%d %b %Y %H:%M:%S')
logger = logging.getLogger()
parser = argparse.ArgumentParser(description="""
Tool to preprocess dataset in base64 format.""")
"""
We assume that the directory of dataset contains a file-list file, and one
or more data files. Each line of the file-list file represents a data file.
Each line of a data file holds a label and an image in base64 format,
separated by a tab.
For example:
dir
|-- file_list.txt
|-- part_one.txt
`-- part_two.txt
In the above example, the file file_list.txt has two lines:
part_one.txt
part_two.txt
Each line in part_one.txt and part_two.txt holds a label and an image in
base64 format, separated by a tab.
"""
parser.add_argument("--data_dir",
type=str,
required=True,
default=None,
help="Directory for datasets.")
parser.add_argument("--file_list",
type=str,
required=True,
default=None,
help="The file contains a set of data files.")
parser.add_argument("--nranks",
type=int,
required=True,
default=1,
help="Number of ranks.")
args = parser.parse_args()
class Base64Preprocessor(object):
def __init__(self, data_dir, file_list, nranks):
super(Base64Preprocessor, self).__init__()
self.data_dir = data_dir
self.file_list = file_list
self.nranks = nranks
self.tempfile = tempfile.NamedTemporaryFile(delete=False, dir=data_dir)
self.sqlite3_file = self.tempfile.name
self.conn = None
self.cursor = None
def create_db(self):
start = time.time()
print(self.sqlite3_file)
self.conn = sqlite3.connect(self.sqlite3_file)
self.cursor = self.conn.cursor()
self.cursor.execute('''CREATE TABLE DATASET
(ID INT PRIMARY KEY NOT NULL,
DATA TEXT NOT NULL,
LABEL INT NOT NULL);''')
file_list_path = os.path.join(self.data_dir, self.file_list)
with open(file_list_path, 'r') as f:
cnt = 0
for line in f:
line = line.strip()
file_path = os.path.join(self.data_dir, line)
with open(file_path, 'r') as df:
for line in df:
line = line.strip().split('\t')
label = int(line[0])
data = line[1]
# Use a parameterized query so the payload is never spliced into SQL text.
sql_cmd = "INSERT INTO DATASET (ID, DATA, LABEL) VALUES (?, ?, ?);"
self.cursor.execute(sql_cmd, (cnt, data, label))
cnt += 1
os.remove(file_path)
self.conn.commit()
diff = time.time() - start
print("time: ", diff)
return cnt
def shuffle_files(self):
num = self.create_db()
nranks = self.nranks
index = [i for i in range(num)]
seed = int(time.time())
random.seed(seed)
random.shuffle(index)
start_time = time.time()
lines_per_rank = int(math.ceil(num/nranks))
total_num = lines_per_rank * nranks
index = index + index[0:total_num - num]
assert len(index) == total_num
for rank in range(nranks):
start = rank * lines_per_rank
end = (rank + 1) * lines_per_rank # exclusive
f_handler = open(os.path.join(self.data_dir,
".tmp_" + str(rank)), 'w')
for i in range(start, end):
idx = index[i]
sql_cmd = "SELECT DATA, LABEL FROM DATASET WHERE ID={};".format(idx)
cursor = self.cursor.execute(sql_cmd)
for result in cursor:
data = result[0]
label = result[1]
line = data + '\t' + str(label) + '\n'
f_handler.writelines(line)
f_handler.close()
data_dir = self.data_dir
file_list = self.file_list
file_list = os.path.join(data_dir, file_list)
temp_file_list = file_list + "_temp"
with open(temp_file_list, 'w') as f_t:
for rank in range(nranks):
line = "base64_rank_{}".format(rank)
line += '\n'
f_t.writelines(line)
os.rename(os.path.join(data_dir, ".tmp_" + str(rank)),
os.path.join(data_dir, "base64_rank_{}".format(rank)))
os.remove(file_list)
os.rename(temp_file_list, file_list)
print("shuffle time: ", time.time() - start_time)
def close_db(self):
self.conn.close()
self.tempfile.close()
def main():
global args
obj = Base64Preprocessor(args.data_dir, args.file_list, args.nranks)
obj.shuffle_files()
obj.close_db()
#data_dir = args.data_dir
#file_list = args.file_list
#nranks = args.nranks
#names, file_num_map, num = get_image_info(data_dir, file_list)
#
if __name__ == "__main__":
main()
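The core of `shuffle_files` is the index arithmetic: shuffle all sample ids, pad the list so that it divides evenly, and hand each rank a contiguous slice. The following minimal sketch isolates that step without SQLite or file I/O; the function name `partition_indices` and the fixed `seed` argument are assumptions made so the example is reproducible, whereas the script above seeds from the current time.

```python
import math
import random


def partition_indices(num, nranks, seed=0):
    """Hypothetical helper: shuffled, padded sample ids for each rank."""
    index = list(range(num))
    random.Random(seed).shuffle(index)
    lines_per_rank = int(math.ceil(num / float(nranks)))
    total_num = lines_per_rank * nranks
    # Pad by repeating the first few shuffled ids so every rank gets equal work.
    index = index + index[0:total_num - num]
    assert len(index) == total_num
    return [index[r * lines_per_rank:(r + 1) * lines_per_rank]
            for r in range(nranks)]


if __name__ == "__main__":
    for rank, ids in enumerate(partition_indices(10, 4)):
        print(rank, ids)
```

Because of the padding, a few samples appear twice per epoch whenever the sample count is not a multiple of the number of ranks.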