Commit 98b91a73, authored by MRXLT

Merge remote-tracking branch 'upstream/master'

# PLSC
Paddle Large Scale Classification Tools
## Introduction
PLSC (PaddlePaddle Large Scale Classification) is an ultra-large-scale classification library built on the [PaddlePaddle platform](https://www.paddlepaddle.org.cn). It provides users with an end-to-end large-scale classification solution, from training all the way to deployment.
PLSC has the following features:
- Built on [PaddlePaddle](https://www.paddlepaddle.org.cn), an open-source deep learning platform grown out of industrial practice
- A rich collection of pre-trained models (TBD)
- An end-to-end solution from training to deployment (TBD)
## Tutorials
We provide a series of tutorials to help users train, evaluate, and deploy models with the PLSC large-scale classification library.
The documentation is organized into four parts, __Quick Start__, __Basic Features__, __Inference and Deployment__, and __Advanced Features__, introducing the design and usage of PLSC from basic to advanced topics.
### Quick Start
* [Installation](docs/installation.md)
* [Training/Evaluation/Deployment](docs/usage.md)
### Basic Features
* [API Overview](docs/api_intro.md)
* [Custom Models](docs/custom_modes.md)
* [Custom Reader Interface]
### Inference and Deployment
* [Model Export](docs/export_for_infer.md)
* [Using the C++ Inference Library]
### Advanced Features
* [Mixed-Precision Training]
* [Distributed Parameter Conversion]
* [Base64 Image Preprocessing]
# PLSC API Overview
## Default Configuration Parameters
PLSC provides default configuration parameters for training, evaluation, and the model itself, such as the training dataset directory and the number of training epochs.
These parameters live in the plsc.config module. Their meanings and default values are listed below.
### Training
| Parameter | Description | Default |
| :------- | :------- | :----- |
| train_batch_size | Batch size for training | 128 |
| dataset_dir | Root directory of the dataset | './train_data' |
| train_image_num | Number of training images | 5822653 |
| train_epochs | Number of training epochs | 120 |
| warmup_epochs | Number of warmup epochs | 0 |
| lr | Initial learning rate | 0.1 |
| lr_steps | Steps at which the learning rate decays | (100000, 160000, 220000) |
### Evaluation
| Parameter | Description | Default |
| :------- | :------- | :----- |
| val_targets | Names of the validation datasets, comma-separated, e.g. 'lfw,cfp_fp' | lfw |
| test_batch_size | Batch size for evaluation | 120 |
| with_test | Whether to evaluate the model after each training epoch | True |
### Model
| Parameter | Description | Default |
| :------- | :------- | :----- |
| model_name | Name of the model to use | 'ResNet50' |
| checkpoint_dir | Directory of the pre-trained model to load | "" |
| model_save_dir | Directory for saving trained models | "./output" |
| loss_type | Loss type; one of softmax, arcface, dist_softmax, and dist_arcface | 'dist_arcface' |
| num_classes | Number of classes | 85742 |
| image_shape | Image shape, in CHW format | [3, 112, 112] |
| margin | Margin parameter of dist_arcface and arcface | 0.5 |
| scale | Scale parameter of dist_arcface and arcface | 64.0 |
| emb_size | Output dimension of the model's last hidden layer | 512 |
Notes:
* Difference between checkpoint_dir and model_save_dir: checkpoint_dir is the directory of a pre-trained model to be loaded before training/evaluation, while model_save_dir is the directory where models are saved during training.
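A minimal sketch of inspecting these defaults, assuming the configuration object is exposed as `config` in the plsc.config module (as in the config source included later in this commit):
```python
# A sketch, assuming the defaults are an EasyDict named `config`
# in the plsc.config module.
from plsc.config import config

print(config.train_batch_size)   # 128
print(config.loss_type)          # 'dist_arcface'
print(config.lr_steps)           # (100000, 160000, 220000)
```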
### Parameter-Setting API
The default parameters can be overridden through the following APIs.
| API | Description | Arguments |
| :------------------- | :--------------------| :---------------------- |
| set_val_targets(targets) | Set the validation datasets | Comma-separated validation set names, string |
| set_train_batch_size(size) | Set the training batch size | int |
| set_test_batch_size(size) | Set the evaluation batch size | int |
| set_hdfs_info(fs_name, fs_ugi, directory) | Set the HDFS file system information | fs_name is the HDFS address, string; fs_ugi is the comma-separated user name and password, string; directory is the path on HDFS |
| set_model_save_dir(dir) | Set the model save directory model_save_dir | string |
| set_dataset_dir(dir) | Set the dataset root directory dataset_dir | string |
| set_train_image_num(num) | Set the total number of training images | int |
| set_class_num(num) | Set the total number of classes | int |
| set_emb_size(size) | Set the output dimension of the last hidden layer | int |
| set_model(model) | Set a user-defined model class instance | subclass of BaseModel |
| set_train_epochs(num) | Set the number of training epochs | int |
| set_checkpoint_dir(dir) | Set the directory of the pre-trained model to load | string |
| set_warmup_epochs(num) | Set the number of warmup epochs | int |
| set_loss_type(loss_type) | Set the model's loss type | string |
| set_image_size(size) | Set the image shape, in CHW format | tuple |
| set_optimizer(optimizer) | Set the optimizer used for training | Optimizer instance |
| convert_for_prediction() | Convert a pre-trained model to an inference model | None |
| predict() | Offline prediction, used to verify the correctness of the online model | None |
| test() | Model evaluation | None |
| train() | Model training | None |
Note: all the APIs above are methods of the plsc.entry.Entry class and must be called on an instance of that class, for example:
```python
import plsc.entry as entry
ins = entry.Entry()
ins.set_class_num(85742)
ins.train()
```
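A slightly fuller configuration sketch built only from the setters listed above (the directory names and values are illustrative placeholders):
```python
import plsc.entry as entry

if __name__ == "__main__":
    ins = entry.Entry()
    ins.set_dataset_dir('./train_data')     # placeholder dataset root
    ins.set_class_num(85742)
    ins.set_train_batch_size(128)
    ins.set_train_epochs(120)
    ins.set_loss_type('dist_arcface')
    ins.set_model_save_dir('./output')      # placeholder save directory
    ins.train()
```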
# Custom Models
By default, PLSC builds its training network on the ResNet50 model.
PLSC provides the model base class plsc.models.base_model.BaseModel, on which users can build their own network models. A user-defined model class must inherit from this base class and implement the build_network method, which constructs the custom network.
When the model is used, the get_output method of the class is called; it automatically appends the distributed FC layer to the end of the user-defined network.
The example below shows how to define a custom network model with the BaseModel base class, and how to use it.
```python
import paddle.fluid as fluid
import plsc.entry as entry
from plsc.models.base_model import BaseModel
class ResNet(BaseModel):
def __init__(self, layers=50, emb_dim=512):
super(ResNet, self).__init__()
self.layers = layers
self.emb_dim = emb_dim
def build_network(self,
input,
label,
is_train):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers {}, but given {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 14, 3]
num_filters = [64, 128, 256, 512]
elif layers == 101:
depth = [3, 4, 23, 3]
num_filters = [256, 512, 1024, 2048]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [256, 512, 1024, 2048]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=1,
pad=1, act='prelu', is_train=is_train)
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 else 1,
is_train=is_train)
bn = fluid.layers.batch_norm(input=conv, act=None, epsilon=2e-05,
is_test=False if is_train else True)
drop = fluid.layers.dropout(x=bn, dropout_prob=0.4,
dropout_implementation='upscale_in_train',
is_test=False if is_train else True)
fc = fluid.layers.fc(
input=drop,
size=self.emb_dim,
act=None,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False, fan_in=0.0)),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer()))
emb = fluid.layers.batch_norm(input=fc, act=None, epsilon=2e-05,
is_test=False if is_train else True)
return emb
... ...
if __name__ == "__main__":
ins = entry.Entry()
ins.set_model(ResNet())
ins.train()
```
A user-defined model class must inherit from the base class BaseModel and implement the build_network method, which defines the custom network.
The inputs of the build_network method are:
* input: the input image data
* label: the image labels
* is_train: whether the network is built for training or for testing/prediction
The build_network method returns the output variable of the user-defined network. The get_output method of the BaseModel class calls build_network to obtain this output and automatically appends the distributed FC layer after it. A minimal custom model is sketched below.
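The following is a minimal sketch of such a custom model, only to illustrate the BaseModel contract; the layer configuration is arbitrary and chosen purely for illustration:
```python
import paddle.fluid as fluid
import plsc.entry as entry
from plsc.models.base_model import BaseModel


class TinyModel(BaseModel):
    """A toy backbone: two conv layers followed by an embedding FC."""

    def __init__(self, emb_dim=512):
        super(TinyModel, self).__init__()
        self.emb_dim = emb_dim

    def build_network(self, input, label, is_train=True):
        conv = fluid.layers.conv2d(input=input, num_filters=32,
                                   filter_size=3, stride=2, padding=1,
                                   act='relu')
        conv = fluid.layers.conv2d(input=conv, num_filters=64,
                                   filter_size=3, stride=2, padding=1,
                                   act='relu')
        pool = fluid.layers.pool2d(input=conv, pool_type='avg',
                                   global_pooling=True)
        # The returned variable is the embedding; get_output appends the
        # distributed FC layer (and the loss) after it.
        emb = fluid.layers.fc(input=pool, size=self.emb_dim, act=None)
        return emb


if __name__ == "__main__":
    ins = entry.Entry()
    ins.set_model(TinyModel())
    ins.train()
```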
# Exporting Models for Inference
Models saved by PLSC during training usually contain only the model parameters,
not the inference network structure. To deploy with the PLSC inference library, the pre-trained model must first be exported as an inference model.
The pre-trained model can be exported as an inference model with the following code:
```python
import plsc.entry as entry
if __name__ == "__main__":
ins = entry.Entry()
ins.set_checkpoint_dir('./pretrain_model')
ins.set_model_save_dir('./inference_model')
ins.convert_for_prediction()
```
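After the conversion, the exported directory can be loaded with Paddle's standard inference-model API. The following is a minimal sketch; the directory name follows the example above, and the exact feed/fetch names depend on the exported program:
```python
import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load the inference model produced by convert_for_prediction().
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname='./inference_model', executor=exe)

# Dummy input matching the default image shape [3, 112, 112];
# feed preprocessed face images here in practice.
image = np.random.random((1, 3, 112, 112)).astype('float32')
outputs = exe.run(program,
                  feed={feed_names[0]: image},
                  fetch_list=fetch_targets)
print(outputs[0].shape)
```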
# Installation
## 1. Install PaddlePaddle
Requirements:
* PaddlePaddle >= 1.6.2
* Python 2.7 or 3.5+
For PaddlePaddle's compatibility with operating systems, CUDA, cuDNN, and other software versions, see the [PaddlePaddle installation guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html).
### Installing with pip
Currently, the large-scale classification library requires the GPU version of PaddlePaddle.
```shell
pip install paddlepaddle-gpu
```
### Installing with Conda
PaddlePaddle can also be installed with Conda, which reduces the cost of installing its dependencies. See [Anaconda](https://www.anaconda.com/distribution/) for Conda usage.
```shell
conda install -c paddle paddlepaddle-gpu cudatoolkit=9.0
```
* Please install NVIDIA NCCL >= 2.4.7 and run on a Linux system.
For more installation options and details, see the [PaddlePaddle installation guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html).
## 2. Install the Large-Scale Classification Library
```shell
pip install plsc
```
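A quick way to check that both packages import correctly (a minimal sanity check; the install_check utility ships with PaddlePaddle 1.x):
```shell
python -c "import paddle.fluid as fluid; fluid.install_check.run_check()"
python -c "import plsc.entry"
```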
# Training, Evaluation, and Deployment
PLSC provides an end-to-end solution covering training, evaluation, and inference deployment. This document describes how to quickly complete training, evaluation, and deployment with the PaddlePaddle large-scale classification library.
## Data Preparation
We assume the user's dataset is organized as follows:
```shell
train_data/
|-- agedb_30.bin
|-- cfp_ff.bin
|-- cfp_fp.bin
|-- images
|-- label.txt
`-- lfw.bin
```
Here, *train_data* is the root directory of the user's data; *agedb_30.bin*, *cfp_ff.bin*, *cfp_fp.bin*, and *lfw.bin* are different validation datasets, not all of which are required. This tutorial uses lfw.bin as the validation dataset by default, so please make sure lfw.bin is available when following along. The *images* directory contains the JPEG training images, and each line of *label.txt* gives one training image together with its class label. A sketch for generating this file follows the example below.
An example of the contents of *label.txt*:
```shell
images/00000000.jpg 0
images/00000001.jpg 0
images/00000002.jpg 0
images/00000003.jpg 0
images/00000004.jpg 0
images/00000005.jpg 0
images/00000006.jpg 0
images/00000007.jpg 0
... ...
```
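A hedged sketch for producing such a label.txt, assuming a hypothetical layout in which each class has its own subdirectory under images/; adapt the path handling if the images are organized differently:
```python
import os

data_root = 'train_data'                       # hypothetical dataset root
images_dir = os.path.join(data_root, 'images')

with open(os.path.join(data_root, 'label.txt'), 'w') as f:
    # Treat every subdirectory of images/ as one class.
    for label, class_name in enumerate(sorted(os.listdir(images_dir))):
        class_dir = os.path.join(images_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for name in sorted(os.listdir(class_dir)):
            rel_path = os.path.join('images', class_name, name)
            f.write('{} {}\n'.format(rel_path, label))
```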
## Model Training
### Training Code
The following example gives a script *train.py* that uses PLSC for large-scale classification training:
```python
import plsc.entry as entry
if __name__ == "__main__":
ins = entry.Entry()
ins.train()
```
1. Import the entry.Entry class from the plsc package; it is the interface class of the PLSC large-scale classification library.
2. Create an instance of the Entry class.
3. Call the train method of the Entry class to start training.
### Starting Training
The following example shows how to launch a training job with the script above:
```shell
python -m paddle.distributed.launch \
--cluster_ips="127.0.0.1" \
--node_ip="127.0.0.1" \
--selected_gpus=0,1,2,3,4,5,6,7 \
train.py
```
The paddle.distributed.launch module launches multi-node/multi-GPU distributed training jobs and simplifies the launch process. Its arguments are:
* cluster_ips: comma-separated list of the IP addresses of the nodes participating in training;
* node_ip: IP address of the current training node;
* selected_gpus: comma-separated list of the GPU devices used on each training node.
A two-node launch is sketched below.
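For example, a hedged two-node launch; the IP addresses are placeholders, and the same command is run on each node with its own node_ip:
```shell
# On node 192.168.0.1 (placeholder IPs)
python -m paddle.distributed.launch \
       --cluster_ips="192.168.0.1,192.168.0.2" \
       --node_ip="192.168.0.1" \
       --selected_gpus=0,1,2,3,4,5,6,7 \
       train.py

# On node 192.168.0.2, use --node_ip="192.168.0.2" with the same cluster_ips.
```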
## Model Evaluation
In this tutorial, we use the lfw.bin validation dataset to evaluate the trained model.
### Evaluation Code
The following example gives a script *val.py* that uses PLSC for model evaluation:
```python
import plsc.entry as entry
if __name__ == "__main__":
ins = entry.Entry()
    ins.set_checkpoint_dir("output/0")
ins.test()
```
By default, PLSC saves trained models under the './output' directory, using the pass_id as a subdirectory name to distinguish models from different training epochs; for example, './output/0' holds the model saved after the first training epoch.
For model evaluation, we first set the directory of the trained model and then call the test method of the Entry class to start the evaluation. A sketch that evaluates several saved epochs in turn follows.
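For example, a hedged sketch that evaluates every epoch saved under './output', following the pass_id directory convention described above:
```python
import os
import plsc.entry as entry

if __name__ == "__main__":
    output_dir = './output'
    # Each numeric subdirectory corresponds to one saved training epoch.
    for pass_id in sorted(d for d in os.listdir(output_dir) if d.isdigit()):
        ins = entry.Entry()
        ins.set_checkpoint_dir(os.path.join(output_dir, pass_id))
        ins.test()
```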
## Inference Deployment
TBD
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from easydict import EasyDict as edict
"""
Default Parameters
"""
config = edict()
config.train_batch_size = 128
config.test_batch_size = 120
config.val_targets = 'lfw'
config.dataset_dir = './train_data'
config.train_image_num = 5822653
config.model_name = 'ResNet50'
config.train_epochs = 120
config.checkpoint_dir = ""
config.with_test = True
config.model_save_dir = "output"
config.warmup_epochs = 0
config.loss_type = "dist_arcface"
config.num_classes = 85742
config.image_shape = (3,112,112)
config.margin = 0.5
config.scale = 64.0
config.lr = 0.1
config.lr_steps = (100000,160000,220000)
config.emb_dim = 512
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import resnet
from .resnet import *
from . import base_model
from .base_model import *
__all__ = []
__all__ += resnet.__all__
__all__ += base_model.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import unique_name
from . import dist_algo
__all__ = ["BaseModel"]
class BaseModel(object):
"""
Base class for custom models.
The sub-class must implement the build_network method,
which constructs the custom model. And we will add the
distributed fc layer for you automatically.
"""
def __init__(self):
super(BaseModel, self).__init__()
def build_network(self, input, label, is_train=True):
"""
Construct the custom model, and we will add the
distributed fc layer for you automatically.
"""
raise NotImplementedError(
"You must implement this method in your sub class.")
def get_output(self,
input,
label,
num_classes,
is_train=True,
param_attr=None,
bias_attr=None,
loss_type="dist_softmax",
margin=0.5,
scale=64.0):
"""
Add the distributed fc layer for the custom model.
"""
supported_loss_types = ["dist_softmax", "dist_arcface",
"softmax", "arcface"]
assert loss_type in supported_loss_types, \
"Supported loss types: {}, but given: {}".format(
supported_loss_types, loss_type)
nranks = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
rank_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
emb = self.build_network(input, label, is_train)
if loss_type == "softmax":
loss, prob = self.fc_classify(emb,
label,
num_classes,
param_attr,
bias_attr)
elif loss_type == "arcface":
loss, prob = self.fc_arcface(emb,
label,
num_classes,
param_attr,
margin,
scale)
elif loss_type == "dist_arcface":
loss = dist_algo._distributed_arcface_classify(
x=emb, label=label, class_num=num_classes,
nranks=nranks, rank_id=rank_id, margin=margin,
logit_scale=scale, param_attr=param_attr)
prob = None
elif loss_type == "dist_softmax":
loss = dist_algo._distributed_softmax_classify(
x=emb, label=label, class_num=num_classes,
nranks=nranks, rank_id=rank_id, param_attr=param_attr,
use_bias=True, bias_attr=bias_attr)
prob = None
return emb, loss, prob
def fc_classify(self, input, label, out_dim, param_attr, bias_attr):
if param_attr is None:
stdv = 1.0 / math.sqrt(input.shape[1] * 1.0)
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv))
out = fluid.layers.fc(input=input,
size=out_dim,
param_attr=param_attr,
bias_attr=bias_attr)
loss, prob = fluid.layers.softmax_with_cross_entropy(logits=out,
label=label, return_softmax=True)
avg_loss = fluid.layers.mean(x=loss)
return avg_loss, prob
    def fc_arcface(self, input, label, out_dim, param_attr, margin, scale):
input_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(fluid.layers.square(input), dim=1))
input = fluid.layers.elementwise_div(input, input_norm, axis=0)
if param_attr is None:
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(
uniform=False, fan_in=0.0))
weight = fluid.layers.create_parameter(
shape=[input.shape[1], out_dim],
dtype='float32',
name=unique_name.generate('final_fc_w'),
attr=param_attr)
weight_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(fluid.layers.square(weight), dim=0))
weight = fluid.layers.elementwise_div(weight, weight_norm, axis=1)
cos = fluid.layers.mul(input, weight)
theta = fluid.layers.acos(cos)
margin_cos = fluid.layers.cos(theta + margin)
one_hot = fluid.layers.one_hot(label, out_dim)
diff = (margin_cos - cos) * one_hot
target_cos = cos + diff
logit = fluid.layers.scale(target_cos, scale=scale)
loss, prob = fluid.layers.softmax_with_cross_entropy(
logits=logit, label=label, return_softmax=True)
avg_loss = fluid.layers.mean(x=loss)
one_hot.stop_gradient = True
return avg_loss, prob
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import math
from six.moves import reduce
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable, default_startup_program
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.ops as ops
import paddle.fluid.layers.collective as collective
from paddle.fluid.optimizer import Optimizer
class DistributedClassificationOptimizer(Optimizer):
'''
A optimizer wrapper to generate backward network for distributed
classification training of model parallelism.
'''
def __init__(self, optimizer, batch_size, use_fp16=False):
self._optimizer = optimizer
self._batch_size = batch_size
self._use_fp16 = use_fp16
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
assert loss._get_info('shard_logit')
shard_logit = loss._get_info('shard_logit')
shard_prob = loss._get_info('shard_prob')
shard_label = loss._get_info('shard_label')
shard_dim = loss._get_info('shard_dim')
op_maker = fluid.core.op_proto_and_checker_maker
op_role_key = op_maker.kOpRoleAttrName()
op_role_var_key = op_maker.kOpRoleVarAttrName()
backward_role = int(op_maker.OpRole.Backward)
loss_backward_role = int(op_maker.OpRole.Loss) | int(
op_maker.OpRole.Backward)
# minimize a scalar of reduce_sum to generate the backward network
scalar = fluid.layers.reduce_sum(shard_logit)
ret = self._optimizer.minimize(scalar)
block = loss.block
# remove the unnecessary ops
index = 0
for i, op in enumerate(block.ops):
if op.all_attrs()[op_role_key] == loss_backward_role:
index = i
break
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
# insert the calculated gradient
dtype = shard_logit.dtype
shard_one_hot = fluid.layers.create_tensor(dtype, name='shard_one_hot')
block._insert_op(
index - 1,
type='one_hot',
inputs={'X': shard_label},
outputs={'Out': shard_one_hot},
attrs={
'depth': shard_dim,
'allow_out_of_range': True,
op_role_key: backward_role
})
shard_logit_grad = fluid.layers.create_tensor(
dtype, name=fluid.backward._append_grad_suffix_(shard_logit.name))
block._insert_op(
index,
type='elementwise_sub',
inputs={'X': shard_prob,
'Y': shard_one_hot},
outputs={'Out': shard_logit_grad},
attrs={op_role_key: backward_role})
block._insert_op(
index + 1,
type='scale',
inputs={'X': shard_logit_grad},
outputs={'Out': shard_logit_grad},
attrs={
'scale': 1.0 / self._batch_size,
op_role_key: loss_backward_role
})
return ret
class DistributedClassifier(object):
'''
Tookit for distributed classification, in which the parameter of the last
full-connected layer is distributed to all trainers
'''
def __init__(self, nclasses, nranks, rank_id, layer_helper):
self.nclasses = nclasses
self.nranks = nranks
self.rank_id = rank_id
self._layer_helper = layer_helper
self.shard_dim = (nclasses + nranks - 1) // nranks
self.padding_dim = 0
self.is_equal_division = True
if nclasses % nranks != 0:
self.is_equal_division = False
if rank_id == nranks - 1:
other_shard_dim = self.shard_dim
self.shard_dim = nclasses % other_shard_dim
self.padding_dim = other_shard_dim - self.shard_dim
def create_parameter(self,
dtype,
in_dim,
param_attr=None,
bias_attr=None,
transpose_weight=False,
use_bias=True):
if param_attr is None:
stdv = math.sqrt(2.0 / (in_dim + self.nclasses))
param_attr = ParamAttr(initializer=Normal(scale=stdv))
weight_shape = [self.shard_dim, in_dim
] if transpose_weight else [in_dim, self.shard_dim]
weight = self._layer_helper.create_parameter(
shape=weight_shape, dtype=dtype, attr=param_attr, is_bias=False)
# avoid distributed parameter allreduce gradients
weight.is_distributed = True
# avoid distributed parameter broadcasting in startup program
default_startup_program().global_block().vars[
weight.name].is_distributed = True
bias = None
if use_bias:
bias = self._layer_helper.create_parameter(
shape=[self.shard_dim],
attr=bias_attr,
dtype=dtype,
is_bias=True)
bias.is_distributed = True
default_startup_program().global_block().vars[
bias.name].is_distributed = True
return weight, bias
def softmax_with_cross_entropy(self, shard_logit, shard_label):
shard_max = nn.reduce_max(shard_logit, dim=1, keep_dim=True)
global_max = collective._c_allreduce(
shard_max, reduce_type='max', use_calc_stream=True)
shard_logit_new = nn.elementwise_sub(shard_logit, global_max)
shard_exp = ops.exp(shard_logit_new)
shard_demon = nn.reduce_sum(shard_exp, dim=1, keep_dim=True)
global_demon = collective._c_allreduce(
shard_demon, reduce_type='sum', use_calc_stream=True)
global_log_demon = nn.log(global_demon)
shard_log_prob = shard_logit_new - global_log_demon
shard_prob = ops.exp(shard_log_prob)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
target_log_prob = nn.reduce_min(
shard_log_prob * shard_one_hot, dim=1, keep_dim=True)
shard_loss = nn.scale(target_log_prob, scale=-1.0)
global_loss = collective._c_reducescatter(
shard_loss, nranks=self.nranks, use_calc_stream=True)
return global_loss, shard_prob
def softmax_classify(self,
x,
label,
param_attr=None,
use_bias=True,
bias_attr=None):
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
bias_attr=bias_attr,
use_bias=use_bias)
x_all = collective._c_allgather(
x, nranks=self.nranks, use_calc_stream=True)
label_all = collective._c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_fc = nn.mul(x_all, weight, x_num_col_dims=1)
if use_bias:
shard_fc = nn.elementwise_add(shard_fc, bias)
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
shard_label.stop_gradient = True
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_fc,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_fc)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
def arcface_classify(self,
x,
label,
margin=0.5,
logit_scale=64,
param_attr=None):
'''
reference: ArcFace. https://arxiv.org/abs/1801.07698
'''
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
use_bias=False)
# normalize x
x_l2 = ops.sqrt(nn.reduce_sum(ops.square(x), dim=1))
norm_x = nn.elementwise_div(x, x_l2, axis=0)
norm_x_all = collective._c_allgather(
norm_x, nranks=self.nranks, use_calc_stream=True)
label_all = collective._c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
# TODO check necessary
shard_label.stop_gradient = True
# normalize weight
weight_l2 = ops.sqrt(nn.reduce_sum(ops.square(weight), dim=0))
norm_weight = nn.elementwise_div(weight, weight_l2, axis=1)
shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1)
theta = ops.acos(shard_cos)
margin_cos = ops.cos(theta + margin)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
# TODO check necessary
shard_one_hot.stop_gradient = True
diff = (margin_cos - shard_cos) * shard_one_hot
shard_target_cos = shard_cos + diff
shard_logit = nn.scale(shard_target_cos, scale=logit_scale)
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_logit,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_logit)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
def _distributed_softmax_classify(x,
label,
class_num,
nranks,
rank_id,
param_attr=None,
use_bias=True,
bias_attr=None,
name=None):
'''
Classification layer with FC, softmax and cross entropy calculation of
distibuted version in case of too large number of classes.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024].
label (Variable): The label corresponding to the input samples.
class_num (integer): The number of classes of the classification problem.
nranks (integer): The number of ranks of distributed trainers.
rank_id (integer): The rank index of the current trainer.
param_attr (ParamAttr, default None): The parameter attribute for
learnable distributed parameters/weights of this layer.
        use_bias (bool, default True): Whether to add a bias term to the
            sharded FC output.
        bias_attr (ParamAttr, default None): The parameter attribute for the
            bias of this layer.
        name (str, default None): The name of this layer.
    Returns:
        Variable: The softmax cross entropy loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_softmax_classify(x=input,
label=label,
class_num=1000,
nranks=8,
rank_id=0)
'''
if name is None:
name = 'dist@softmax@rank@%05d' % rank_id
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.softmax_classify(x, label, param_attr, use_bias,
bias_attr)
def _distributed_arcface_classify(x,
label,
class_num,
nranks,
rank_id,
margin=0.5,
logit_scale=64.0,
param_attr=None,
name=None):
'''
Classification layer with ArcFace loss of distibuted version in case of
too large number of classes. the equation is
.. math::
        L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(\cos(\theta_{y_i}+m))}}{e^{s(\cos(\theta_{y_i}+m))}+\sum^n_{j=1,j\neq y_i} e^{s\cos\theta_{j}}}
where the :math: `\theta_{y_i}` is the angle between the feature :math: `x` and
the representation of class :math: `i`. The details of ArcFace loss
could be referred to https://arxiv.org/abs/1801.07698.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024].
label (Variable): The label corresponding to the input samples.
class_num (integer): The number of classes of the classification problem.
nranks (integer): The number of ranks of distributed trainers.
rank_id (integer): The rank index of the current trainer.
margin (float, default 0.5): The angular margin penalty to enhance
the intra-class compactness and inter-class discrepancy.
logit_scale (float, default 64.0): The scale factor for logit value
of cosine range.
param_attr (ParamAttr, default None): The parameter attribute for
learnable distributed parameters/weights of this layer.
name (str, default None): The name of this layer.
Returns:
Variable: The ArcFace loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_arcface_classify(x=input,
label=label,
class_num=1000,
nranks=8,
rank_id=0)
'''
if name is None:
name = 'dist@arcface@rank@%05d' % rank_id
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.arcface_classify(
x=x,
label=label,
margin=margin,
logit_scale=logit_scale,
param_attr=param_attr)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import math
import os
import numpy as np
from paddle.fluid import unique_name
from .base_model import BaseModel
__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]
class ResNet(BaseModel):
def __init__(self, layers=50, emb_dim=512):
super(ResNet, self).__init__()
self.layers = layers
self.emb_dim = emb_dim
def build_network(self,
input,
label,
is_train):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers {}, but given {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 14, 3]
num_filters = [64, 128, 256, 512]
elif layers == 101:
depth = [3, 4, 23, 3]
num_filters = [256, 512, 1024, 2048]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [256, 512, 1024, 2048]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=1,
pad=1, act='prelu', is_train=is_train)
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 else 1,
is_train=is_train)
bn = fluid.layers.batch_norm(input=conv, act=None, epsilon=2e-05,
is_test=False if is_train else True)
drop = fluid.layers.dropout(x=bn, dropout_prob=0.4,
dropout_implementation='upscale_in_train',
is_test=False if is_train else True)
fc = fluid.layers.fc(
input=drop,
size=self.emb_dim,
act=None,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False, fan_in=0.0)),
bias_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.ConstantInitializer()))
emb = fluid.layers.batch_norm(input=fc, act=None, epsilon=2e-05,
is_test=False if is_train else True)
return emb
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
pad=0,
groups=1,
is_train=True,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=pad,
groups=groups,
act=None,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(
uniform=False, fan_in=0.0)),
bias_attr=False)
if act == 'prelu':
bn = fluid.layers.batch_norm(input=conv, act=None, epsilon=2e-05,
momentum=0.9, is_test=False if is_train else True)
return fluid.layers.prelu(bn, mode="all",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Constant(0.25)))
else:
return fluid.layers.batch_norm(input=conv, act=act, epsilon=2e-05,
is_test=False if is_train else True)
def shortcut(self, input, ch_out, stride, is_train):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride,
is_train=is_train)
else:
return input
def bottleneck_block(self, input, num_filters, stride, is_train):
if self.layers < 101:
bn1 = fluid.layers.batch_norm(input=input, act=None, epsilon=2e-05,
is_test=False if is_train else True)
conv1 = self.conv_bn_layer(
input=bn1, num_filters=num_filters, filter_size=3, pad=1,
act='prelu', is_train=is_train)
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters, filter_size=3,
stride=stride, pad=1, act=None, is_train=is_train)
else:
bn0 = fluid.layers.batch_norm(input=input, act=None, epsilon=2e-05,
is_test=False if is_train else True)
conv0 = self.conv_bn_layer(
                input=bn0, num_filters=num_filters // 4, filter_size=1, pad=0,
act='prelu', is_train=is_train)
conv1 = self.conv_bn_layer(
                input=conv0, num_filters=num_filters // 4, filter_size=3, pad=1,
act='prelu', is_train=is_train)
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters, filter_size=1,
stride=stride, pad=0, act=None, is_train=is_train)
short = self.shortcut(input, num_filters, stride, is_train=is_train)
return fluid.layers.elementwise_add(x=short, y=conv2, act=None)
def ResNet50(emb_dim=512):
model = ResNet(layers=50, emb_dim=emb_dim)
return model
def ResNet101(emb_dim=512):
model = ResNet(layers=101, emb_dim=emb_dim)
return model
def ResNet152(emb_dim=512):
model = ResNet(layers=152, emb_dim=emb_dim)
return model
#!/usr/bin/env bash
export FLAGS_cudnn_exhaustive_search=true
export FLAGS_fraction_of_gpu_memory_to_use=0.96
export FLAGS_eager_delete_tensor_gb=0.0
selected_gpus="0,1,2,3,4,5,6,7"
#selected_gpus="4,5,6"
python -m paddle.distributed.launch \
--selected_gpus $selected_gpus \
--log_dir mylog \
do_train.py \
--model=ResNet_ARCFACE50 \
--loss_type=dist_softmax \
--model_save_dir=output \
--margin=0.5 \
--train_batch_size 32 \
--class_dim 85742 \
--with_test=True
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import random
import pickle
import base64
import functools
import numpy as np
import paddle
from PIL import Image, ImageEnhance
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
random.seed(0)
DATA_DIM = 112
THREAD = 8
BUF_SIZE = 10240
img_mean = np.array([127.5, 127.5, 127.5]).reshape((3, 1, 1))
img_std = np.array([128.0, 128.0, 128.0]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.BILINEAR)
return img
def Scale(img, size):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), Image.BILINEAR)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), Image.BILINEAR)
def CenterCrop(img, size):
w, h = img.size
th, tw = int(size), int(size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return img.crop((x1, y1, x1 + tw, y1 + th))
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = random.randint(0, width - size)
h_start = random.randint(0, height - size)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def RandomResizedCrop(img, size):
for attempt in range(10):
area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area
aspect_ratio = random.uniform(3. / 4, 4. / 3)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5:
w, h = h, w
if w <= img.size[0] and h <= img.size[1]:
x1 = random.randint(0, img.size[0] - w)
y1 = random.randint(0, img.size[1] - h)
img = img.crop((x1, y1, x1 + w, y1 + h))
assert(img.size == (w, h))
return img.resize((size, size), Image.BILINEAR)
w = min(img.size[0], img.size[1])
i = (img.size[1] - w) // 2
j = (img.size[0] - w) // 2
img = img.crop((i, j, i+w, j+w))
img = img.resize((size, size), Image.BILINEAR)
return img
def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
aspect_ratio = math.sqrt(random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.size[0]) / img.size[1]) / (w**2),
(float(img.size[1]) / img.size[0]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.size[0] * img.size[1] * random.uniform(scale_min,
scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = random.randint(0, img.size[0] - w)
j = random.randint(0, img.size[1] - h)
img = img.crop((i, j, i + w, j + h))
img = img.resize((size, size), Image.BILINEAR)
return img
def rotate_image(img):
angle = random.randint(-10, 10)
img = img.rotate(angle)
return img
def distort_color(img):
def random_brightness(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
random.shuffle(ops)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
return img
def process_image_imagepath(sample,
class_dim,
color_jitter,
rotate,
rand_mirror,
normalize):
img_data = base64.b64decode(sample[0])
img = Image.open(StringIO(img_data))
if rotate:
img = rotate_image(img)
img = RandomResizedCrop(img, DATA_DIM)
if color_jitter:
img = distort_color(img)
if rand_mirror:
if random.randint(0, 1) == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1))
if normalize:
img -= img_mean
img /= img_std
assert sample[1] < class_dim, \
"label of train dataset should be less than the class_dim."
return img, sample[1]
def arc_iterator(file_list,
class_dim,
color_jitter=False,
rotate=False,
rand_mirror=False,
normalize=False):
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
def reader():
with open(file_list, 'r') as f:
flist = f.readlines()
assert len(flist) % trainer_count == 0, \
"Number of files should be divisible by trainer count, " \
"run base64 file preprocessing tool first."
num_files_per_trainer = len(flist) // trainer_count
start = num_files_per_trainer * trainer_id
end = start + num_files_per_trainer
flist = flist[start:end]
for file in flist:
with open(file, 'r') as f:
                    for line in f:
line = line.strip().split('\t')
image, label = line[0], line[1]
yield image, label
mapper = functools.partial(process_image_imagepath,
class_dim=class_dim, color_jitter=color_jitter, rotate=rotate,
rand_mirror=rand_mirror, normalize=normalize)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def load_bin(path, image_size):
bins, issame_list = pickle.load(open(path, 'rb'))
data_list = []
for flip in [0, 1]:
data = np.empty((len(issame_list)*2, 3, image_size[0], image_size[1]))
data_list.append(data)
for i in xrange(len(issame_list)*2):
_bin = bins[i]
if not isinstance(_bin, basestring):
_bin = _bin.tostring()
img_ori = Image.open(StringIO(_bin))
for flip in [0, 1]:
img = img_ori.copy()
if flip == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1))
img -= img_mean
img /= img_std
data_list[flip][i][:] = img
if i % 1000 == 0:
print('loading bin', i)
print(data_list[0].shape)
return (data_list, issame_list)
def train(data_dir, file_list, num_classes):
file_path = os.path.join(data_dir, file_list)
return arc_iterator(file_path, class_dim=num_classes, color_jitter=False,
rotate=False, rand_mirror=True, normalize=True)
def test(data_dir, datasets):
test_list = []
test_name_list = []
for name in datasets.split(','):
path = os.path.join(data_dir, name+".bin")
if os.path.exists(path):
data_set = load_bin(path, (DATA_DIM, DATA_DIM))
test_list.append(data_set)
test_name_list.append(name)
print('test', name)
return test_list, test_name_list
import os
import math
import random
import pickle
import functools
import numpy as np
import paddle
from PIL import Image, ImageEnhance
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
random.seed(0)
DATA_DIM = 112
THREAD = 8
BUF_SIZE = 10240
#TEST_LIST = 'lfw,cfp_fp,agedb_30,cfp_ff'
TEST_LIST = 'lfw'
def get_train_image_list(data_dir):
train_list_file = os.path.join(data_dir, 'label.txt')
train_list = open(train_list_file, "r").readlines()
random.shuffle(train_list)
train_image_list = []
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = (len(train_list) + trainer_count - 1) // trainer_count
train_list += train_list[0:per_node_lines
* trainer_count-len(train_list)]
lines = train_list[trainer_id * per_node_lines:(trainer_id + 1)
* per_node_lines]
print("read images from %d, length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(train_list)))
for i, item in enumerate(lines):
path, label = item.strip().split()
label = int(label)
train_image_list.append((path, label))
print("train_data size:", len(train_image_list))
return train_image_list
img_mean = np.array([127.5, 127.5, 127.5]).reshape((3, 1, 1))
img_std = np.array([128.0, 128.0, 128.0]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.BILINEAR)
return img
def Scale(img, size):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), Image.BILINEAR)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), Image.BILINEAR)
def CenterCrop(img, size):
w, h = img.size
th, tw = int(size), int(size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return img.crop((x1, y1, x1 + tw, y1 + th))
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = random.randint(0, width - size)
h_start = random.randint(0, height - size)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def RandomResizedCrop(img, size):
for attempt in range(10):
area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area
aspect_ratio = random.uniform(3. / 4, 4. / 3)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5:
w, h = h, w
if w <= img.size[0] and h <= img.size[1]:
x1 = random.randint(0, img.size[0] - w)
y1 = random.randint(0, img.size[1] - h)
img = img.crop((x1, y1, x1 + w, y1 + h))
assert(img.size == (w, h))
return img.resize((size, size), Image.BILINEAR)
w = min(img.size[0], img.size[1])
i = (img.size[1] - w) // 2
j = (img.size[0] - w) // 2
img = img.crop((i, j, i+w, j+w))
img = img.resize((size, size), Image.BILINEAR)
return img
def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
aspect_ratio = math.sqrt(random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.size[0]) / img.size[1]) / (w**2),
(float(img.size[1]) / img.size[0]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.size[0] * img.size[1] * random.uniform(scale_min,
scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = random.randint(0, img.size[0] - w)
j = random.randint(0, img.size[1] - h)
img = img.crop((i, j, i + w, j + h))
img = img.resize((size, size), Image.BILINEAR)
return img
def rotate_image(img):
angle = random.randint(-10, 10)
img = img.rotate(angle)
return img
def distort_color(img):
def random_brightness(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
random.shuffle(ops)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
return img
def process_image_imagepath(sample,
class_dim,
color_jitter,
rotate,
rand_mirror,
normalize):
imgpath = sample[0]
img = Image.open(imgpath)
if rotate:
img = rotate_image(img)
img = RandomResizedCrop(img, DATA_DIM)
if color_jitter:
img = distort_color(img)
if rand_mirror:
if random.randint(0, 1) == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1))
if normalize:
img -= img_mean
img /= img_std
assert sample[1] < class_dim, \
"label of train dataset should be less than the class_dim."
return img, sample[1]
def arc_iterator(data,
class_dim,
data_dir,
shuffle=False,
color_jitter=False,
rotate=False,
rand_mirror=False,
normalize=False):
def reader():
if shuffle:
random.shuffle(data)
for j in xrange(len(data)):
path, label = data[j]
path = os.path.join(data_dir, path)
yield path, label
mapper = functools.partial(process_image_imagepath, class_dim=class_dim,
color_jitter=color_jitter, rotate=rotate,
rand_mirror=rand_mirror, normalize=normalize)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def load_bin(path, image_size):
bins, issame_list = pickle.load(open(path, 'rb'))
data_list = []
for flip in [0, 1]:
data = np.empty((len(issame_list)*2, 3, image_size[0], image_size[1]))
data_list.append(data)
for i in xrange(len(issame_list)*2):
_bin = bins[i]
if not isinstance(_bin, basestring):
_bin = _bin.tostring()
img_ori = Image.open(StringIO(_bin))
for flip in [0, 1]:
img = img_ori.copy()
if flip == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1))
img -= img_mean
img /= img_std
data_list[flip][i][:] = img
if i % 1000 == 0:
print('loading bin', i)
print(data_list[0].shape)
return (data_list, issame_list)
def arc_train(data_dir, class_dim):
train_image_list = get_train_image_list(data_dir)
return arc_iterator(train_image_list, shuffle=True, class_dim=class_dim,
data_dir=data_dir, color_jitter=False, rotate=False, rand_mirror=True,
normalize=True)
def test(data_dir, datasets):
test_list = []
test_name_list = []
for name in datasets.split(','):
path = os.path.join(data_dir, name+".bin")
if os.path.exists(path):
data_set = load_bin(path, (DATA_DIM, DATA_DIM))
test_list.append(data_set)
test_name_list.append(name)
print('test', name)
return test_list, test_name_list
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
"""
Applies linear learning rate warmup for distributed training.
The parameter learning_rate should be a float or a Variable.
lr = start_lr + (warmup_rate * step / warmup_steps), where warmup_rate
is end_lr - start_lr, and step is the current step.
"""
assert (isinstance(end_lr, float))
assert (isinstance(start_lr, float))
linear_step = end_lr - start_lr
with fluid.default_main_program()._lr_schedule_guard():
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate_warmup")
global_step = fluid.layers.learning_rate_scheduler._decay_step_counter()
with fluid.layers.control_flow.Switch() as switch:
with switch.case(global_step < warmup_steps):
decayed_lr = start_lr + linear_step * (global_step /
warmup_steps)
fluid.layers.tensor.assign(decayed_lr, lr)
with switch.default():
fluid.layers.tensor.assign(learning_rate, lr)
return lr
"""Helper for evaluation on the Labeled Faces in the Wild dataset
"""
# MIT License
#
# Copyright (c) 2016 David Sandberg
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
from sklearn.model_selection import KFold
from scipy import interpolate
import sklearn
import math
import datetime
import pickle
from sklearn.decomposition import PCA
class LFold:
def __init__(self, n_splits = 2, shuffle = False):
self.n_splits = n_splits
if self.n_splits>1:
self.k_fold = KFold(n_splits = n_splits, shuffle = shuffle)
def split(self, indices):
if self.n_splits>1:
return self.k_fold.split(indices)
else:
return [(indices, indices)]
def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, pca = 0):
assert(embeddings1.shape[0] == embeddings2.shape[0])
assert(embeddings1.shape[1] == embeddings2.shape[1])
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
nrof_thresholds = len(thresholds)
k_fold = LFold(n_splits=nrof_folds, shuffle=False)
tprs = np.zeros((nrof_folds,nrof_thresholds))
fprs = np.zeros((nrof_folds,nrof_thresholds))
accuracy = np.zeros((nrof_folds))
indices = np.arange(nrof_pairs)
#print('pca', pca)
if pca==0:
diff = np.subtract(embeddings1, embeddings2)
dist = np.sum(np.square(diff),1)
for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
#print('train_set', train_set)
#print('test_set', test_set)
if pca>0:
print('doing pca on', fold_idx)
embed1_train = embeddings1[train_set]
embed2_train = embeddings2[train_set]
_embed_train = np.concatenate( (embed1_train, embed2_train), axis=0 )
#print(_embed_train.shape)
pca_model = PCA(n_components=pca)
pca_model.fit(_embed_train)
embed1 = pca_model.transform(embeddings1)
embed2 = pca_model.transform(embeddings2)
embed1 = sklearn.preprocessing.normalize(embed1)
embed2 = sklearn.preprocessing.normalize(embed2)
#print(embed1.shape, embed2.shape)
diff = np.subtract(embed1, embed2)
dist = np.sum(np.square(diff),1)
# Find the best threshold for the fold
acc_train = np.zeros((nrof_thresholds))
for threshold_idx, threshold in enumerate(thresholds):
_, _, acc_train[threshold_idx] = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])
best_threshold_index = np.argmax(acc_train)
#print('threshold', thresholds[best_threshold_index])
for threshold_idx, threshold in enumerate(thresholds):
tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set])
_, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])
tpr = np.mean(tprs,0)
fpr = np.mean(fprs,0)
return tpr, fpr, accuracy
def calculate_accuracy(threshold, dist, actual_issame):
predict_issame = np.less(dist, threshold)
tp = np.sum(np.logical_and(predict_issame, actual_issame))
fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))
fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn)
fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn)
acc = float(tp+tn)/dist.size
return tpr, fpr, acc
def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10):
assert(embeddings1.shape[0] == embeddings2.shape[0])
assert(embeddings1.shape[1] == embeddings2.shape[1])
nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
nrof_thresholds = len(thresholds)
k_fold = LFold(n_splits=nrof_folds, shuffle=False)
val = np.zeros(nrof_folds)
far = np.zeros(nrof_folds)
diff = np.subtract(embeddings1, embeddings2)
dist = np.sum(np.square(diff),1)
indices = np.arange(nrof_pairs)
for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
# Find the threshold that gives FAR = far_target
far_train = np.zeros(nrof_thresholds)
for threshold_idx, threshold in enumerate(thresholds):
_, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])
if np.max(far_train)>=far_target:
f = interpolate.interp1d(far_train, thresholds, kind='slinear')
threshold = f(far_target)
else:
threshold = 0.0
val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])
val_mean = np.mean(val)
far_mean = np.mean(far)
val_std = np.std(val)
return val_mean, val_std, far_mean
def calculate_val_far(threshold, dist, actual_issame):
predict_issame = np.less(dist, threshold)
true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
n_same = np.sum(actual_issame)
n_diff = np.sum(np.logical_not(actual_issame))
#print(true_accept, false_accept)
#print(n_same, n_diff)
val = float(true_accept) / float(n_same)
far = float(false_accept) / float(n_diff)
return val, far
def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
# Calculate evaluation metrics
thresholds = np.arange(0, 4, 0.01)
embeddings1 = embeddings[0::2]
embeddings2 = embeddings[1::2]
tpr, fpr, accuracy = calculate_roc(thresholds, embeddings1, embeddings2,
np.asarray(actual_issame), nrof_folds=nrof_folds, pca = pca)
thresholds = np.arange(0, 4, 0.001)
val, val_std, far = calculate_val(thresholds, embeddings1, embeddings2,
np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds)
return tpr, fpr, accuracy, val, val_std, far
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup, find_packages
setup(name="plsc",
version="0.1.0",
description="Large Scale Classfication via distributed fc.",
author='lilong',
author_email="lilong.albert@gmail.com",
url="http",
license="Apache",
#packages=['paddleXML'],
packages=find_packages(),
#install_requires=['paddlepaddle>=1.6.1'],
python_requires='>=2'
)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import os
import argparse
import random
import time
import math
import logging
import sqlite3
import tempfile
logging.basicConfig(level=logging.INFO,
format='[%(levelname)s %(asctime)s line:%(lineno)d] %(message)s',
datefmt='%d %b %Y %H:%M:%S')
logger = logging.getLogger()
parser = argparse.ArgumentParser(description="""
Tool to preprocess dataset in base64 format.""")
"""
We assume that the directory of dataset contains a file-list file, and one
or more data files. Each line of the file-list file represents a data file.
Each line of a data file represents a image in base64 format.
For example:
dir
|-- file_list.txt
|-- part_one.txt
`-- part_two.txt
In the above example, the file file_list.txt has two lines:
part_one.txt
part_two.txt
Each line in part_one.txt and part_two.txt represents a image in base64
format.
"""
parser.add_argument("--data_dir",
type=str,
required=True,
default=None,
help="Directory for datasets.")
parser.add_argument("--file_list",
type=str,
required=True,
default=None,
help="The file contains a set of data files.")
parser.add_argument("--nranks",
type=int,
required=True,
default=1,
help="Number of ranks.")
args = parser.parse_args()
class Base64Preprocessor(object):
def __init__(self, data_dir, file_list, nranks):
super(Base64Preprocessor, self).__init__()
self.data_dir = data_dir
self.file_list = file_list
self.nranks = nranks
self.tempfile = tempfile.NamedTemporaryFile(delete=False, dir=data_dir)
self.sqlite3_file = self.tempfile.name
self.conn = None
self.cursor = None
def create_db(self):
start = time.time()
print(self.sqlite3_file)
self.conn = sqlite3.connect(self.sqlite3_file)
self.cursor = self.conn.cursor()
self.cursor.execute('''CREATE TABLE DATASET
(ID INT PRIMARY KEY NOT NULL,
DATA TEXT NOT NULL,
LABEL INT NOT NULL);''')
file_list_path = os.path.join(self.data_dir, self.file_list)
with open(file_list_path, 'r') as f:
cnt = 0
            for line in f:
line = line.strip()
file_path = os.path.join(self.data_dir, line)
with open(file_path, 'r') as df:
                    for line in df:
line = line.strip().split('\t')
label = int(line[0])
data = line[1]
sql_cmd = "INSERT INTO DATASET (ID, DATA, LABEL) "
sql_cmd += "VALUES ({}, '{}', {});".format(cnt, data, label)
self.cursor.execute(sql_cmd)
cnt += 1
os.remove(file_path)
self.conn.commit()
diff = time.time() - start
print("time: ", diff)
return cnt
def shuffle_files(self):
num = self.create_db()
nranks = self.nranks
index = [i for i in range(num)]
seed = int(time.time())
random.seed(seed)
random.shuffle(index)
start_time = time.time()
lines_per_rank = int(math.ceil(num/nranks))
total_num = lines_per_rank * nranks
index = index + index[0:total_num - num]
assert len(index) == total_num
for rank in range(nranks):
start = rank * lines_per_rank
end = (rank + 1) * lines_per_rank # exclusive
f_handler = open(os.path.join(self.data_dir,
".tmp_" + str(rank)), 'w')
for i in range(start, end):
idx = index[i]
sql_cmd = "SELECT DATA, LABEL FROM DATASET WHERE ID={};".format(idx)
cursor = self.cursor.execute(sql_cmd)
for result in cursor:
data = result[0]
label = result[1]
line = data + '\t' + str(label) + '\n'
f_handler.writelines(line)
f_handler.close()
data_dir = self.data_dir
file_list = self.file_list
file_list = os.path.join(data_dir, file_list)
temp_file_list = file_list + "_temp"
with open(temp_file_list, 'w') as f_t:
for rank in range(nranks):
line = "base64_rank_{}".format(rank)
line += '\n'
f_t.writelines(line)
os.rename(os.path.join(data_dir, ".tmp_" + str(rank)),
os.path.join(data_dir, "base64_rank_{}".format(rank)))
os.remove(file_list)
os.rename(temp_file_list, file_list)
print("shuffle time: ", time.time() - start_time)
def close_db(self):
self.conn.close()
self.tempfile.close()
def main():
global args
obj = Base64Preprocessor(args.data_dir, args.file_list, args.nranks)
obj.shuffle_files()
obj.close_db()
#data_dir = args.data_dir
#file_list = args.file_list
#nranks = args.nranks
#names, file_num_map, num = get_image_info(data_dir, file_list)
#
if __name__ == "__main__":
main()