提交 7a860694 编写于 作者: K Kexin Zhao 提交者: Yi Wang

Add float16 demo code and put float16 work in contrib/float16 folder (#10331)

* add test float16 inference accuracy example

* complete the test

* clean code

* add argument parse and refine tests

* add shell script

* add float16 benchmark code

* refine code

* prepare for contrib/float16

* put things in contrib float16 folder

* update benchmark result

* further update benchmark report

* add float16 inference report

* update report
上级 f428e82d
# float16 benchmark
## Description
We want to compare the inference benchmark of float16 vs float32 on the "image_classification" example on Nvidia Tesla V100 GPU, where we can enable the tensor core computation for float16 mode. We test Vgg16 and Resnet50 on the imagenet data set, and Vgg16 and Resnet32 on the cifar10 data set. For completeness, we also add the inference benchmark of Vgg16 and Resnet50 on imagenet data set tested on Nvidia GeForce GTX 1080 Ti GPU.
For more details about tensor core, please refer to https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/
## Test environment
- GPU: single Nvidia Tesla V100 or single Nvidia GeForce GTX 1080 Ti
- CUDNN: 7.1.1
- CUDA: 9.0
- Code: https://github.com/PaddlePaddle/Paddle/pull/10331 (Tensor core is enabled in float16 mode)
## Benchmark on V100
All times are in ms (millisecond) averaged over 1000 iterations tested on a single Nvidia V100 GPU with respective to different mini-batch(mb) sizes.
### Vgg16 on imagenet (flowers data set: image.shape = [3, 224, 224]):
Total inference time for one batch:
| | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 |
|-------|-----: |-----: |-----: |-----: |------: |------:|-------:|
|float32| 14.01 | 9.70 | 22.99 | 28.26 | 53.87 | 84.42 | 178.95 |
|float16| 3.32 | 4.11 | 5.88 | 9.41 | 16.54 | 30.47 | 60.23 |
|Speedup| 4.22 | 2.36  | 3.91 | 3.00 | 3.26  | 2.77 | 2.97 |
Total time spent on conv op for one batch:
| | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 |
|-------|-----: |-----: |-----: |-----: |------: |------:|-------:|
|float32| 11.95 | 6.96 | 18.65 | 21.42 | 41.35 | 60.58 | 130.11 |
|float16| 1.78 | 2.10 | 2.93 | 4.55 | 7.99 | 14.63 | 28.67 |
|Speedup| 6.71 | 3.31  | 6.37 | 4.71 | 5.18  | 4.14 | 4.54 |
### Resnet50 on imagenet (flowers data set: image.shape = [3, 224, 224]):
Total inference time for one batch:
|       | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 | mb=128 |
|-------|-----: |-----: |-----: |-----: |------: |------:|-------:|-------:|
|float32| 7.03 | 7.41 | 9.16 | 12.55 | 21.13 | 38.27 | 67.93 | 127.02 |
|float16| 6.13 | 6.32 | 6.24 | 7.40 | 10.90 | 18.18 | 33.20 | 64.52 |
|Speedup| 1.15 | 1.17  | 1.47  | 1.70 | 1.94  | 2.11 | 2.05 | 1.97 |
Total time spent on conv op for one batch:
|       | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 | mb=128 |
|-------|-----: |-----: |-----: |-----: |------: |------:|-------:|-------:|
|float32| 5.43 | 5.46 | 6.50 | 8.36 | 13.80 | 24.45 | 41.21 | 73.44 |
|float16| 4.19 | 4.30 | 3.96 | 4.21 | 5.63 | 8.77 | 15.24 | 28.40 |
|Speedup| 1.30 | 1.27  | 1.64  | 1.99 | 2.45  | 2.79 | 2.70 | 2.59 |
### Vgg16 on cifar10 (image.shape = [3, 32, 32]):
Total inference time for one batch:
| | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 | mb=128 | mb=256 | mb=512 |
|-------|-----:|-----:|-----:|-----:|------:|------:|------:|-------:|-------:|-------:|
|float32| 3.13 | 3.17 | 3.19 | 3.58 | 3.98 | 6.23 | 8.42 | 13.44 | 24.19 | 44.97 |
|float16| 2.72 | 2.77 | 2.76 | 2,88 | 2.96 | 3.24 | 4.01 | 5.78 | 9.65 | 17.37 |
|Speedup| 1.15 | 1.14 | 1.16 | 1.24 | 1.34 | 1.92  | 2.10 | 2.33  | 2.51 | 2.59 |
### Resnet32 on cifar10 (image.shape = [3, 32, 32]):
Total inference time for one batch:
| | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 | mb=128 | mb=256 | mb=512 |
|-------|-----:|-----:|-----:|-----:|------:|------:|------:|-------:|-------:|-------:|
|float32| 3.11 | 3.14 | 2.99 | 3.04 | 3.10 | 3.28 | 4.47 | 6.86 | 11.63 | 21.16 |
|float16| 3.70 | 3.81 | 3.75 | 3.83 | 3.77 | 3.97 | 3.92 | 4.15 | 6.41 | 11.02 |
|Speedup|     |     |     |     |       | | 1.14  | 1.65 | 1.81 | 1.92 |
## Benchmark on 1080 Ti
All times are in ms (millisecond) averaged over 1000 iterations tested on a single Nvidia GeForce GTX 1080 Ti GPU with respective to different mini-batch(mb) sizes.
### Vgg16 on imagenet (flowers data set: image.shape = [3, 224, 224]):
Total inference time for one batch:
| | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 |
|-------|-----: |-----: |-----: |-----: |------: |-------:|
|float32| 5.60 | 9.38 | 15.86 | 29.79 | 57.60 | 117.73 |
|float16| 4.99 | 7.79 | 13.47 | 26.02 | 52.30 | 102.34 |
|Speedup| 1.12 | 1.20  | 1.18 | 1.15 | 1.10  | 1.15 |
### Resnet50 on imagenet (flowers data set: image.shape = [3, 224, 224]):
Total inference time for one batch:
| | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 |
|-------|-----: |-----: |-----: |-----: |------: |-------:|-------:|
|float32| 5.63 | 6.23 | 8.85 | 14.71 | 26.07 | 52.86 | 108.95 |
|float16| 5.89 | 6.44 | 7.94 | 12.57 | 22.03 | 45.06 | 92.68 |
|Speedup| |  | 1.12  | 1.17 | 1.18  | 1.17 | 1.18 |
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from float16_transpiler import Float16Transpiler
import argparse
import paddle
import paddle.fluid as fluid
import contextlib
import math
import sys
import numpy as np
import os
parser = argparse.ArgumentParser(
'Float16 inference accuracy test and benchmark.')
parser.add_argument(
'--train_batch_size', type=int, default=16, help="Batch size for training.")
parser.add_argument(
'--inf_batch_size', type=int, default=32, help="Batch size for inference.")
parser.add_argument(
'--repeat', type=int, default=1, help="How many times to run the test.")
parser.add_argument(
'--data_set',
type=str,
default='cifar10',
choices=['cifar10', 'imagenet'],
help="Optional dataset for benchmark.")
parser.add_argument(
'--model',
type=str,
default='vgg',
choices=['vgg', 'resnet'],
help="Optional model for benchmark.")
parser.add_argument(
'--threshold',
type=float,
default=0.005,
help='Save inference model when test accuracy reach this threshold.')
parser.add_argument('--learning_rate', type=float, default=0.001)
args = parser.parse_args()
def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
conv1 = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv1, act=act)
def shortcut(input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out:
return conv_bn_layer(input, ch_out, 1, stride, 0, None)
else:
return input
def basicblock(input, ch_out, stride):
short = shortcut(input, ch_out, stride)
conv1 = conv_bn_layer(input, ch_out, 3, stride, 1)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1, act=None)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def bottleneck(input, ch_out, stride):
short = shortcut(input, ch_out * 4, stride)
conv1 = conv_bn_layer(input, ch_out, 1, stride, 0)
conv2 = conv_bn_layer(conv1, ch_out, 3, 1, 1)
conv3 = conv_bn_layer(conv2, ch_out * 4, 1, 1, 0, act=None)
return fluid.layers.elementwise_add(x=short, y=conv3, act='relu')
def layer_warp(block_func, input, ch_out, count, stride):
res_out = block_func(input, ch_out, stride)
for i in range(1, count):
res_out = block_func(res_out, ch_out, 1)
return res_out
def resnet_imagenet(input, depth=50):
cfg = {
18: ([2, 2, 2, 1], basicblock),
34: ([3, 4, 6, 3], basicblock),
50: ([3, 4, 6, 3], bottleneck),
101: ([3, 4, 23, 3], bottleneck),
152: ([3, 8, 36, 3], bottleneck)
}
stages, block_func = cfg[depth]
conv1 = conv_bn_layer(input, ch_out=64, filter_size=7, stride=2, padding=3)
pool1 = fluid.layers.pool2d(
input=conv1, pool_type='avg', pool_size=3, pool_stride=2)
res1 = layer_warp(block_func, pool1, 64, stages[0], 1)
res2 = layer_warp(block_func, res1, 128, stages[1], 2)
res3 = layer_warp(block_func, res2, 256, stages[2], 2)
res4 = layer_warp(block_func, res3, 512, stages[3], 2)
pool2 = fluid.layers.pool2d(
input=res4,
pool_size=7,
pool_type='avg',
pool_stride=1,
global_pooling=True)
return pool2
def resnet_cifar10(input, depth=32):
assert (depth - 2) % 6 == 0
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, n, 1)
res2 = layer_warp(basicblock, res1, 32, n, 2)
res3 = layer_warp(basicblock, res2, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
return pool
def vgg16(input):
def conv_block(input, num_filter, groups, dropouts):
return fluid.nets.img_conv_group(
input=input,
pool_size=2,
pool_stride=2,
conv_num_filter=[num_filter] * groups,
conv_filter_size=3,
conv_act='relu',
conv_with_batchnorm=True,
conv_batchnorm_drop_rate=dropouts,
pool_type='max')
conv1 = conv_block(input, 64, 2, [0.3, 0])
conv2 = conv_block(conv1, 128, 2, [0.4, 0])
conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0])
conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0])
conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0])
drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
fc1 = fluid.layers.fc(input=drop, size=4096, act=None)
bn = fluid.layers.batch_norm(input=fc1, act='relu')
drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
fc2 = fluid.layers.fc(input=drop2, size=4096, act=None)
return fc2
def train(place, save_dirname):
if args.data_set == "cifar10":
class_dim = 10
data_shape = [3, 32, 32]
elif args.data_set == "imagenet":
class_dim = 102
data_shape = [3, 224, 224]
else:
raise ValueError("%s dataset is not supported" % data_set)
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
if args.model == "vgg":
print("train vgg")
net = vgg16(images)
elif args.model == "resnet":
print("train resnet")
if args.data_set == "cifar10":
net = resnet_cifar10(images)
elif args.data_set == "imagenet":
net = resnet_imagenet(images)
else:
raise ValueError("%s dataset is not supported" % args.data_set)
else:
raise ValueError("%s network is not supported" % args.model)
predict = fluid.layers.fc(input=net, size=class_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=predict, label=label)
#Test program
test_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
optimizer.minimize(avg_cost)
BATCH_SIZE = args.train_batch_size
PASS_NUM = 100
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.flowers.train()
if args.data_set == 'imagenet' else paddle.dataset.cifar.train10(),
buf_size=128 * 10),
batch_size=args.train_batch_size)
test_reader = paddle.batch(
paddle.dataset.flowers.test()
if args.data_set == 'imagenet' else paddle.dataset.cifar.test10(),
batch_size=args.inf_batch_size)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
exe.run(fluid.default_startup_program())
main_program = fluid.default_main_program()
for pass_id in range(PASS_NUM):
for batch_id, data in enumerate(train_reader()):
train_image = np.array(
map(lambda x: x[0].reshape(data_shape), data)).astype("float32")
train_label = np.array(map(lambda x: x[1], data)).astype("int64")
train_label = train_label.reshape([-1, 1])
exe.run(main_program,
feed={'pixel': train_image,
'label': train_label})
if (batch_id % 100) == 0:
acc_list = []
avg_loss_list = []
for tid, test_data in enumerate(test_reader()):
test_image = np.array(
map(lambda x: x[0].reshape(data_shape),
test_data)).astype("float32")
test_label = np.array(map(lambda x: x[1],
test_data)).astype("int64")
test_label = test_label.reshape([-1, 1])
loss_t, acc_t = exe.run(
program=test_program,
feed={"pixel": test_image,
"label": test_label},
fetch_list=[avg_cost, acc])
if math.isnan(float(loss_t)):
sys.exit("got NaN loss, training failed.")
acc_list.append(float(acc_t))
avg_loss_list.append(float(loss_t))
acc_value = np.array(acc_list).mean()
avg_loss_value = np.array(avg_loss_list).mean()
print(
'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Accuracy {3:2.2}'.
format(pass_id, batch_id + 1,
float(avg_loss_value), float(acc_value)))
if acc_value > args.threshold:
print(
'Save inference model with test accuracy of {0} at {1}'.
format(float(acc_value), save_dirname))
fluid.io.save_inference_model(save_dirname, ["pixel"],
[predict], exe)
return
def test_accuracy(executor, inference_program, feed_target_names,
fetch_targets):
if args.data_set == "cifar10":
data_shape = [3, 32, 32]
elif args.data_set == "imagenet":
data_shape = [3, 224, 224]
else:
raise ValueError("%s dataset is not supported" % data_set)
test_reader = paddle.batch(
paddle.dataset.cifar.test10()
if args.data_set == "cifar10" else paddle.dataset.flowers.test(),
batch_size=args.inf_batch_size)
test_num = 0
correct_num = 0
for test_data in test_reader():
test_image = np.array(
map(lambda x: x[0].reshape(data_shape), test_data)).astype(
"float32")
test_label = np.array(map(lambda x: x[1], test_data)).astype("int64")
test_label = test_label.reshape([-1, 1])
results = executor.run(program=inference_program,
feed={feed_target_names[0]: test_image},
fetch_list=fetch_targets)
prediction = np.argmax(results[0], axis=1).reshape([-1, 1])
correct_num += np.sum(prediction == test_label)
test_num += test_label.size
print("{0} out of {1} predictions are correct.".format(correct_num,
test_num))
print("Test accuray is {0}.".format(float(correct_num) / float(test_num)))
def infer(place, save_dirname):
exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
print("Load inference model from {0}".format(save_dirname))
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
print("The test set accuracy of inference in float mode is:")
test_accuracy(exe, inference_program, feed_target_names, fetch_targets)
float16_inference_program = inference_program.clone()
t = Float16Transpiler()
t.transpile(float16_inference_program, place)
print("The test set accuracy of inference in float16 mode is:")
test_accuracy(exe, float16_inference_program, feed_target_names,
fetch_targets)
fp16_save_dirname = "float16_" + save_dirname
fluid.io.save_inference_model(fp16_save_dirname, feed_target_names,
fetch_targets, exe,
float16_inference_program)
@contextlib.contextmanager
def scope_prog_guard():
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
if __name__ == "__main__":
if not fluid.core.is_compiled_with_cuda():
raise Exception("This test requires CUDA GPUs!")
place = fluid.CUDAPlace(0)
if not fluid.core.is_float16_supported(place):
raise Exception(
"This test requires compute capability of CUDA GPU >= 5.3!")
for i in range(args.repeat):
with scope_prog_guard():
save_dirname = "image_classification_" + args.data_set + "_" + args.model + ".inference.model"
train(place, save_dirname)
infer(place, save_dirname)
## Introduction
Working with deep neural networks (DNN) is a two-stage process. First we train DNN using labeled examples of inputs and desired outputs to obtain the model parameters (weights), then we deploy DNN along with the trained weights to run inference on unknown inputs. Typically, these weights are in float data type and hence we run inference in float mode using these weights. This post focuses on the discussion of how to use low precision float16 data type to represent these trained weights and run inference in float16 mode as well as the advantages of float16 inference over its float counterpart by showing some experiment results.
## What is float16?
float16 (or FP16) is a half-precision floating-point format that uses 16 bits in memory to represent a value. The advantage over 32-bit single-precision floating-point format (commonly known as float data type) is that it requires half the storage and bandwidth at the expense of precision and range. Fortunately, DNN inference has high tolerance against the loss of precision and range when using float16 to represent the weights and the inference accuracy will only be minimally affected in most cases. This gives us the opportunity to use float16 data type to speedup the inference.
Interested readers can refer to our [design doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/data_type/float16.md) and [code](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/platform/float16.h) for more details on how we implement the float16 data type.
## Why float16?
The trend in today's deep learning community is to use bigger and deeper model. This translates to larger memory footprint, higher computation demands, and as a result higher energy consumption on computing devices. The advantages of float16 over float are correspondingly three-fold:
1. We only need half the memory size to load the same model using float16 representations. Moreover, most of the intermediate results generated during float16 inference are also of float16 data type. This makes the whole memory footprint of float16 inference roughly about half of its float counterpart. This is especially useful when deploying inference on mobile devices with limited available memory. Also given the same available memory, the maximum batch size for float16 inference is about twice that for float inference.
2. Because float16 occupies less memory than float, in theory hardware devices can achieve much higher floating point operators per second (FLOPS) for float16 data than float data. Right now, an outstanding example of hardware devices that actually deliver such advantages is Nvidia's latest Volta architecture GPUs, including Tesla V100 and Titan V. Moreover float16 takes less time to read from or write to memory and hence float16 can make inference more efficient especially in memory-bound applications where the performance is largely affected by how fast it is to read and write data.
3. From the energy efficiency perspective, the energy needed to read, write, and compute float16 data is much less that its float counterpart, which can significantly reduce the battery power consumption on mobile devices or the total cost of ownership (TCO) of data centers.
## Fluid implementation of float16 inference
### Overview
Fluid use [Program](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/python_api.md#program) instead of computation graph to describe a neural network model and the optimization procedure. Fluid program is a python wrapper around a protobuf message called [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/program.md). Similar to programming languages, the basic structure of a Fluid program is some nested [blocks](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/python_api.md#block), where each block consists of some [variable](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/python_api.md#variable) definitions and a sequence of [operators](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/python_api.md#operator). An [executor](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/executor.md) will run a given program by sequentially executing the operators in the entrance block.
### Basic requirement
When an operator is run by an executor, it uses a kernel to perform computations on tensors contained in the input variables, and then write the results to the tensors in the output variables. Each operator has multiple kernels for different combinations of data types, devices, and library types, respectively. The operator will select the appropriate kernel to run based on, among other things, the data type of the input tensors. By default, every Fluid operator has a kernel for float data type that takes float inputs and generates float outputs.
This means that if we provide float input to the first operator in a program, then each operator will use float kernel to compute float output and send it as input to the next operator to trigger its float kernel. This chain effect will makes the program run in float mode and gives us a final output of float data type.
The same principle applies if we want a program to run in float16 mode. We provide input variable of float16 data type to the first operator and every subsequent operator will invoke the float16 kernel until we get the final output in float16 data type. So the preliminary requirements for float16 inference is to add float16 kernels to operators that are needed in a specific kind of neural networks. Our current focus is on Convolutional Neural Networks (CNN) and hence we have added float16 kernels to the following operators: convolution, pooling, GEMM, elementwise addition, batch norm, dropout, various activations including relu and tanh, and softmax.
### float16 transpiler
Furthermore, we need a float16 transpiler to achieve the following usage code:
```python
# Get the float32 inference program and load the associated float32 weights
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
# Prepare the float input data
batch_size = 1
tensor_img = numpy.random.rand(batch_size, 3, 32, 32).astype(numpy.float32)
# Running inference_program in float mode
float_results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
# Use float16 transpiler to speedup
float16_inference_program = float_inference_program.clone()
t = Float16Transpiler()
t.transpile(float16_inference_program, GPUPlace)
# Running float16_inference_program in float16 mode using the same input data
float16_results = exe.run(float16_inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
# Do some tests to verify the correctness of float16 inference
...
np.testing.assert_almost_equal(float_results, float16_results, ...)
...
# Save the float16 inference program and float16 weights for future deployment
fluid.io.save_inference_model(fp16_save_dirname, feed_target_names,
fetch_targets, exe,
float16_inference_program)
```
In this scenario, we already have a float32 inference program and some associated float32 weights that can do float32 inference. We can easily use the `transpile` method of the `Float16Transpiler` class to do certain modifications to the existing program and weights so that we have a new float16 program and the associated float16 weights.
We can then run various inference experiments in float16 mode and save the float16 program and weights on disk for future deployment. To enhance the code usability, we maintain a consistent API so that user can use the same float32 input data to run inference program in either float32 and float16 mode and obtain output data both of float32 data type. This requires us to add some cast operators in the program to convert between float16 tensor and float32 tensor.
The float16 transpiler is implemented to fulfill the requirements mentioned above. The details of the float16 transpiler can be found [here](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/data_type/float16.md#float16-inference).
### Experiment results
We provide demo codes that can be used to reproduce the experiment results by doing:
```bash
git clone https://github.com/PaddlePaddle/Paddle.git
cd Paddle
# This line will generate a paddle development docker image with cuda 8 and cudnn 7
# If you want test on cuda 9 instead, change the line 5 in Paddle/Dockerfile
# from `FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04`
# to `FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04` and similarly for other configurations
nvidia-docker build -t paddle:float16 .
# After running this, different results will be written to different log files in Paddle/contrib/float16/
nvidia-docker run -it -v $PWD:/paddle paddle:float16 /paddle/contrib/float16/run_float16_demo.sh
```
#### Correctness
As is mentioned before, DNN inference has been found to be tolerant against the loss of precision and range incured by float16 and we want to see how good this tolerance is.
We train a resnet32 model using cifar10 data set, save it when test set accuracy is above 60%, and then test the inference accuracy on the 10000 examples of the cifar10 test set in float16 and float32 mode, respectively.
We repeat the test ten times and get the following results:
| | float16 | float32 |
|--------|--------:|--------: |
| # 1 | 62.75% | 62.72% |
| # 2 | 61.27% | 61.28% |
| # 3 | 62.24% | 62.23% |
| # 4 | 64.16% | 64.17% |
| # 5 | 60.75% | 60.77% |
| # 6 | 63.25% | 63.24% |
| # 7 | 62.15% | 62.13% |
| # 8 | 62.05% | 62.02% |
| # 9 | 65.19% | 65.20% |
| #10 | 62.53% | 62.48% |
| average| 62.63% | 62.62% |
We can see that the accuracy of float16 inference is very close to that of float32 inference in every experiment (within 0.05% difference) and is overall 0.01% better than its float32 counterpart averaged over 10 tests.
#### Performance benchmark
Currently, Fluid inference in float16 mode is only supported on Nvidia GPU device. There is no motivation to support float16 inference on non-ARM CPUs because float16 is not natively supported there and float16 calculation will only be slower than its float counterpart.
Nvidia started to support its native float16 data type (which has the same internal memory representation as Fluid float16 class) on CUDA 7.5. Moreover, float16 speedups on common computational intensive tasks including GEMM (general matrix-matrix multiplication) and convolution are supported since cublas 7.5 and cuDNN 5.0.
Recently, the introduction of [tensor core](https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/) in volta architecture GPUs and the support of tensor core calculation in CUDA 9.0 and cuDNN 7 make float16 truly superior to float in certain deep learning applications.
We thus benchmark the float16 inference performance on a single Nvidia Tesla V100 GPU (volta architecture and with tensor cores) and compare it with its float32 counterpart. All the following results are in ms (millisecond) averaged over 1000 mini-batches with respective to different mini-batch(mb) sizes.
Average inference time for one mini-batch on Vgg16 model tested on imagenet data set:
| total | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 |
|-------|-----: |-----: |-----: |-----: |------: |------:|-------:|
|float32| 14.01 | 9.70 | 22.99 | 28.26 | 53.87 | 84.42 | 178.95 |
|float16| 3.32 | 4.11 | 5.88 | 9.41 | 16.54 | 30.47 | 60.23 |
|Speedup| 4.22 | 2.36  | 3.91 | 3.00 | 3.26  | 2.77 | 2.97 |
We can see that float16 inference provides 2x ~ 4x speedup on different batch sizes.
Convolution operation is ususally the computational bottleneck of CNN, so we also check the average time spent on the Fluid convolution operators for one mini-batch as follows:
|conv op| mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 |
|-------|-----: |-----: |-----: |-----: |------: |------:|-------:|
|float32| 11.95 | 6.96 | 18.65 | 21.42 | 41.35 | 60.58 | 130.11 |
|float16| 1.78 | 2.10 | 2.93 | 4.55 | 7.99 | 14.63 | 28.67 |
|Speedup| 6.71 | 3.31  | 6.37 | 4.71 | 5.18  | 4.14 | 4.54 |
Fluid convolution operator uses cuDNN 7 to implement the kernel and we can see that with the help of tensor core, float16 convolution is significantly faster than its float32 counterpart, which makes the overall float16 inference performance much better.
Similarly, we also list the benchmark results of Resnet50 model tested on imagenet data set:
| total | mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 | mb=128 |
|-------|-----: |-----: |-----: |-----: |------: |------:|-------:|-------:|
|float32| 7.03 | 7.41 | 9.16 | 12.55 | 21.13 | 38.27 | 67.93 | 127.02 |
|float16| 6.13 | 6.32 | 6.24 | 7.40 | 10.90 | 18.18 | 33.20 | 64.52 |
|Speedup| 1.15 | 1.17  | 1.47  | 1.70 | 1.94  | 2.11 | 2.05 | 1.97 |
|conv op| mb=1 | mb=2 | mb=4 | mb=8 | mb=16 | mb=32 | mb=64 | mb=128 |
|-------|-----: |-----: |-----: |-----: |------: |------:|-------:|-------:|
|float32| 5.43 | 5.46 | 6.50 | 8.36 | 13.80 | 24.45 | 41.21 | 73.44 |
|float16| 4.19 | 4.30 | 3.96 | 4.21 | 5.63 | 8.77 | 15.24 | 28.40 |
|Speedup| 1.30 | 1.27  | 1.64  | 1.99 | 2.45  | 2.79 | 2.70 | 2.59 |
We find that the speedup provided by float16 inference starts relatively small at 1.15x for batch size 1 and gradually increase to about 2x for larger batch sizes. Similar trend can be found for the time spent on the convolution operator. Note that right now the tensor core will only be utilized in the convolution operation when certain dimentional requirements are met for the input data and filter. The speedup by float16 inference for Resnet50 is smaller than the Vgg16 counterpart partially because the convolution operation in Resnet is much simpler than the Vgg counterpart and this makes the tensor core less utilized in Resnet than in Vgg.
We also did the same benchmark on a Nvidia GeForce GTX 1080 Ti GPU that does not support tensor core. The results show that for Vgg16, float16 inference provides consistent small speedup (around 1.15x) for all mini-batch sizes, while for Resnet50, float16 inference is slower than its float32 counterpart in small batch sizes (mb = 1 and 2) and then deliver around 1.15x speedup for all larger batch sizes. By comparing the benchmarks on 1080 Ti and V100, we find that tensor core, which is specialized for float16 computations, is a critical component for high performance float16 inference.
Please refer to [here](https://github.com/PaddlePaddle/Paddle/blob/develop/contrib/float16/float16_benchmark.md) for comprehensive benchmark results.
### Summary
1. Fluid is now able to run inference in float16 mode via a float16 transpiler. We currently support CNN programs, including Vgg and Resnet, to run in float16 inference mode.
2. The accuracy of float16 inference is verified to be almost identical to the float32 counterpart at least on CNNs.
3. float16 inference provides significant speedup on large and computationally intensive Vgg16 network on image net data set. For the much smaller and simpler Resnet50, the speedup provided by float16 inference is less significant than on Vgg16 but still favorable especially for large batch size.
4. We cannot achieve the superior float16 inference performance without the help of the newly introduced tensor cores on the Nvidia Volta architecture GPUs.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.framework import Program
from paddle.fluid.executor import global_scope
class Float16Transpiler:
def transpile(self, program, place, scope=None):
'''
Transpile the program desc and cast the weights to float16 data type to
enable float16 inference.
Since the operator in a program desc will automatically choose the
right compute kernel to run based on the data type of the input tensor.
We actually don't need to change the program desc to run in float16 mode.
However, in this way, users who are used to feeding and fetching tensors
of float32 data type when running typical inference may find it confusing
and difficult to run inference in float16 mode as they need to convert
input data to float16 dtype and then convert the results back to float32
dtype to match the rest of code.
So this function appends cast ops to the program desc where necessary so
that users are able to run inference in float16 mode while providing input
tensor (feed_holder) of float data type and obtaining output tensor
(fetch_holder) of float data type.
Moreover, it is desired that when we have the scope and program desc to run
inference in float32 mode, we can use a single API to do the necessary
modification and then user can run float16 inference on the fly. To make
this happen, this function also create new parameters in the scope to have the
converted float16 weights and change the operators in program desc to use
these new parameters.
:param program: program to transpile
:type program: Program
:param place: inference place
:type place: Place
:param scope: inference scope
:type scope: Scope
'''
if not isinstance(program, Program):
raise TypeError("program should be as Program type")
if not isinstance(place, core.CPUPlace) and not isinstance(
place, core.CUDAPlace):
raise TypeError("place should be as CPUPlace/CUDAPlace type")
if scope is None:
scope = global_scope()
if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None")
self.scope = scope
self.place = place
self.block = program.block(0)
self.input_map = {} # store the input names should be adjusted
self._modify_feed_fetch()
self._convert_param_to_float16()
self._adjust_input(skip=True)
self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
# ====================== private transpiler functions =====================
def _adjust_input(self, skip=False):
'''
Change the input variable name in operators.
When we are in the process of modifying a program desc, we usually
replace some variables with some other variables, where we create
a dictionary input_map to record the one-to-one correspondence
between each old variable and the new one.
After that, this function will search all the operators that use the
old variables and change the info in op to use the new variables. There
maybe some exceptions to this rule when we are using the float16 transpiler
and insert cast ops to cast float32 variable to float16 one. After we
insert the cast op to cast var_1 to var_1_fp16, we don't want to change
the input of cast op to var_1_fp16 after using this function.
'''
skip_ops = {"cast"}
for i in range(len(self.block.ops)):
current_op = self.block.ops[i]
if skip and current_op.type in skip_ops:
continue
for input_arg in current_op.input_arg_names:
if input_arg in self.input_map:
current_op.rename_input(input_arg,
self.input_map[input_arg])
def _remove_unused_var(self):
'''
remove unused varibles in program
'''
args = []
for i in range(len(self.block.ops)):
current_op = self.block.ops[i]
args += current_op.input_arg_names
args += current_op.output_arg_names
args = list(set(args)) # unique the input and output arguments
for var in self.block.vars.keys():
if var not in args:
self.block.remove_var(var)
def _modify_feed_fetch(self):
'''
Modify feed fetch op/vars for float16 inference.
For each feed op:
feed_op->feed_target_var
Change it to:
feed_op->feed_target_var->cast_op(from other dtype to float16)->tmp_var
For each fetch op:
fetch_target_var->fetch_op
Change it to:
tmp_var->cast_op(from float16 to other dtype)->fetch_target_var->fetch_op
:return: None
'''
def find_op(var):
# It is possible that var.op is not up to date after some
# modifications to program desc. Here we force to make it up to date.
var.op = None
for op in self.block.ops:
if var.name in op.output_arg_names:
var.op = op
break
if var.op is None:
raise ValueError("The target variable must have an "
"associated operator that generates it.")
i = 0
while i < len(self.block.ops):
cur_op = self.block.ops[i]
if cur_op.type == "feed":
var_name = cur_op.output("Out")[0]
tmp_var_name = var_name + ".fp16"
var = self.block.vars[var_name]
tmp_var = self.block.create_var(
name=tmp_var_name.encode('ascii'),
type=var.type,
dtype=core.VarDesc.VarType.FP16,
shape=var.shape,
persistable=var.persistable)
self.block.insert_op(
i + 1,
type="cast",
inputs={"X": var},
outputs={"Out": tmp_var},
attrs={
'in_dtype': int(var.dtype),
'out_dtype': int(tmp_var.dtype)
})
self.input_map[var_name] = tmp_var_name
i = i + 1
elif cur_op.type == "fetch":
var_name = cur_op.input("X")[0]
tmp_var_name = var_name + ".fp16"
var = self.block.vars[var_name]
tmp_var = self.block.create_var(
name=tmp_var_name.encode('ascii'),
type=var.type,
dtype=core.VarDesc.VarType.FP16,
shape=var.shape,
persistable=var.persistable)
find_op(var)
var.op.rename_output(var_name, tmp_var_name)
self.block.insert_op(
i,
type="cast",
inputs={"X": tmp_var},
outputs={"Out": var},
attrs={
'in_dtype': int(tmp_var.dtype),
'out_dtype': int(var.dtype)
})
i = i + 1
i = i + 1
def _convert_param_to_float16(self):
def _get_no_fp16_conversion_var_names():
'''
Get the set of input variable names that shouldn't be converted to float16.
When we want to run inference in float16 mode, most parameters need to be
firstly converted to float16. However, there are some parameters that
shouldn't be converted to float16 because the corresponding operator
requires float32 parameters even in float16 mode (when the input data is
of float16 data type). Currently, the only operator that has this exclusion
is the batch norm op.
:return: set of input variable names
:type var_names: set
'''
op_names = {'batch_norm'}
var_names = []
for op in self.block.ops:
if op.type in op_names:
var_names += op.input_arg_names
return set(var_names)
def _should_be_converted(var):
return var.persistable and \
var.name not in self.no_conversion_vars and \
var.type != core.VarDesc.VarType.FEED_MINIBATCH and \
var.type != core.VarDesc.VarType.FETCH_LIST
self.no_conversion_vars = _get_no_fp16_conversion_var_names()
conversion_var_list = filter(_should_be_converted,
self.block.vars.values())
for var in conversion_var_list:
fp16_var_name = var.name + ".fp16"
fp16_var = self.block.create_parameter(
name=fp16_var_name.encode('ascii'),
type=var.type,
dtype=core.VarDesc.VarType.FP16,
shape=var.shape)
# cast the data in the tensor of the original var to float16
# data type and store it in the tensor of the new float16 var
self.scope.var(fp16_var_name)
fp16_tensor = self.scope.find_var(fp16_var_name).get_tensor()
tensor = np.array(self.scope.find_var(var.name).get_tensor())
# After the old tensor data is converted to np.float16, view(np.uint16)
# is used so that the internal memory of the numpy array will be
# reinterpreted to be of np.uint16 data type, which is binded to fluid
# float16 data type via the help of pybind in tensor_py.h.
fp16_tensor.set(
tensor.astype(np.float16).view(np.uint16), self.place)
# old var will be replaced by the fp16 var in program desc
self.input_map[var.name] = fp16_var_name
self.block.remove_var(var.name)
#!/bin/bash
BUILD_PATH=/paddle/fp16_build
WHEEL_PATH=$BUILD_PATH/python/dist
INFER_PATH=$BUILD_PATH/paddle/fluid/inference/tests/book
DEMO_PATH=/paddle/contrib/float16
# Use the single most powerful CUDA GPU on your machine
export CUDA_VISIBLE_DEVICES=0
# Build the PaddlePaddle Fluid wheel package and install it.
mkdir -p $BUILD_PATH && cd $BUILD_PATH
cmake .. -DWITH_AVX=OFF \
-DWITH_MKL=OFF \
-DWITH_GPU=ON \
-DWITH_TESTING=ON \
-DWITH_TIMER=ON \
-DWITH_PROFILER=ON \
-DWITH_FLUID_ONLY=ON
make -j `nproc`
pip install -U "$WHEEL_PATH/$(ls $WHEEL_PATH)"
cd $DEMO_PATH
# Clear previous log results
rm -f *.log
# Test the float16 inference accuracy of resnet32 on cifar10 data set
stdbuf -oL python float16_inference_demo.py \
--data_set=cifar10 \
--model=resnet \
--threshold=0.6 \
--repeat=10 \
2>&1 | tee -a float16_inference_accuracy.log
# Sleep to cool down the GPU for consistent benchmarking
sleep 2m
# benchmarking parameters
REPEAT=1000
MAXIMUM_BATCH_SIZE=512
for ((batch_size = 1; batch_size <= MAXIMUM_BATCH_SIZE; batch_size *= 2));
do
# Test inference benchmark of vgg16 on imagenet
stdbuf -oL python float16_inference_demo.py \
--data_set=imagenet \
--model=vgg \
--threshold=0.001 \
--repeat=1 \
$INFER_PATH/test_inference_image_classification_vgg \
--data_set=imagenet \
--dirname=$DEMO_PATH/image_classification_imagenet_vgg.inference.model \
--fp16_dirname=$DEMO_PATH/float16_image_classification_imagenet_vgg.inference.model \
--repeat=$REPEAT \
--batch_size=$batch_size \
--skip_cpu=true \
2>&1 | tee -a imagenet_vgg16_benchmark.log
sleep 2m
# Test inference benchmark of resnet50 on imagenet
stdbuf -oL python float16_inference_demo.py \
--data_set=imagenet \
--model=resnet \
--threshold=0.001 \
--repeat=1 \
$INFER_PATH/test_inference_image_classification_resnet \
--data_set=imagenet \
--dirname=$DEMO_PATH/image_classification_imagenet_resnet.inference.model \
--fp16_dirname=$DEMO_PATH/float16_image_classification_imagenet_resnet.inference.model \
--repeat=$REPEAT \
--batch_size=$batch_size \
--skip_cpu=true \
2>&1 | tee -a imagenet_resnet50_benchmark.log
sleep 2m
# Test inference benchmark of vgg16 on cifar10
stdbuf -oL python float16_inference_demo.py \
--data_set=cifar10 \
--model=vgg \
--threshold=0.001 \
--repeat=1 \
$INFER_PATH/test_inference_image_classification_vgg \
--data_set=cifar10 \
--dirname=$DEMO_PATH/image_classification_cifar10_vgg.inference.model \
--fp16_dirname=$DEMO_PATH/float16_image_classification_cifar10_vgg.inference.model \
--repeat=$REPEAT \
--batch_size=$batch_size \
--skip_cpu=true \
2>&1 | tee -a cifar10_vgg16_benchmark.log
sleep 1m
# Test inference benchmark of resnet32 on cifar10
stdbuf -oL python float16_inference_demo.py \
--data_set=cifar10 \
--model=resnet \
--threshold=0.001 \
--repeat=1 \
$INFER_PATH/test_inference_image_classification_vgg \
--data_set=cifar10 \
--dirname=$DEMO_PATH/image_classification_cifar10_resnet.inference.model \
--fp16_dirname=$DEMO_PATH/float16_image_classification_cifar10_resnet.inference.model \
--repeat=$REPEAT \
--batch_size=$batch_size \
--skip_cpu=true \
2>&1 | tee -a cifar10_resnet32_benchmark.log
sleep 1m
done
......@@ -16,9 +16,12 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
DEFINE_string(data_set, "cifar10", "Data set to test");
DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model.");
DEFINE_int32(batch_size, 1, "Batch size of input data");
DEFINE_int32(repeat, 1, "Running the inference program repeat times");
DEFINE_bool(skip_cpu, false, "Skip the cpu test");
TEST(inference, image_classification) {
if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) {
......@@ -35,20 +38,31 @@ TEST(inference, image_classification) {
paddle::framework::LoDTensor input;
// Use normilized image pixels as input data,
// which should be in the range [0.0, 1.0].
SetupTensor<float>(&input, {FLAGS_batch_size, 3, 32, 32},
static_cast<float>(0), static_cast<float>(1));
if (FLAGS_data_set == "cifar10") {
SetupTensor<float>(&input, {FLAGS_batch_size, 3, 32, 32},
static_cast<float>(0), static_cast<float>(1));
} else if (FLAGS_data_set == "imagenet") {
SetupTensor<float>(&input, {FLAGS_batch_size, 3, 224, 224},
static_cast<float>(0), static_cast<float>(1));
} else {
LOG(FATAL) << "Only cifar10 or imagenet is supported.";
}
std::vector<paddle::framework::LoDTensor*> cpu_feeds;
cpu_feeds.push_back(&input);
paddle::framework::LoDTensor output1;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1);
// Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---";
TestInference<paddle::platform::CPUPlace, false, true>(
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims();
if (!FLAGS_skip_cpu) {
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1);
// Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference<paddle::platform::CPUPlace, false, true>(
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims();
}
#ifdef PADDLE_WITH_CUDA
paddle::framework::LoDTensor output2;
......@@ -57,24 +71,27 @@ TEST(inference, image_classification) {
// Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference<paddle::platform::CUDAPlace, false, true>(
dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
LOG(INFO) << output2.dims();
CheckError<float>(output1, output2);
if (!FLAGS_skip_cpu) {
CheckError<float>(output1, output2);
}
// float16 inference requires cuda GPUs with >= 5.3 compute capability
if (paddle::platform::GetCUDAComputeCapability(0) >= 53) {
if (!FLAGS_fp16_dirname.empty() &&
paddle::platform::GetCUDAComputeCapability(0) >= 53) {
paddle::framework::LoDTensor output3;
std::vector<paddle::framework::LoDTensor*> cpu_fetchs3;
cpu_fetchs3.push_back(&output3);
LOG(INFO) << "--- GPU Runs in float16 mode: ---";
std::string fp16_dirname = dirname;
fp16_dirname.replace(fp16_dirname.find("book/"),
std::string("book/").size(), "book/float16_");
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference<paddle::platform::CUDAPlace, false, true>(
fp16_dirname, cpu_feeds, cpu_fetchs3, FLAGS_repeat);
FLAGS_fp16_dirname, cpu_feeds, cpu_fetchs3, FLAGS_repeat);
CheckError<float>(output2, output3);
}
......
......@@ -121,60 +121,7 @@ class InferenceTranspiler:
# And a better solution will be considered later.
program = program.clone()
def float16_transpile(self, program, place, scope=None):
'''
Transpile the program desc and cast the weights to float16 data type to
enable float16 inference.
Since the operator in a program desc will automatically choose the
right compute kernel to run based on the data type of the input tensor.
We actually don't need to change the program desc to run in float16 mode.
However, in this way, users who are used to feeding and fetching tensors
of float32 data type when running typical inference may find it confusing
and difficult to run inference in float16 mode as they need to convert
input data to float16 dtype and then convert the results back to float32
dtype to match the rest of code.
So this function appends cast ops to the program desc where necessary so
that users are able to run inference in float16 mode while providing input
tensor (feed_holder) of float data type and obtaining output tensor
(fetch_holder) of float data type.
Moreover, it is desired that when we have the scope and program desc to run
inference in float32 mode, we can use a single API to do the necessary
modification and then user can run float16 inference on the fly. To make
this happen, this function also create new parameters in the scope to have the
converted float16 weights and change the operators in program desc to use
these new parameters.
:param program: program to transpile
:type program: Program
:param place: inference place
:type place: Place
:param scope: inference scope
:type scope: Scope
'''
if scope is None:
scope = global_scope()
self.scope = scope
self.place = place
self.block = program.block(0)
self.input_map = {} # store the input names should be adjusted
self._modify_feed_fetch()
self._convert_param_to_float16()
self._adjust_input(skip=True)
self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
# ====================== private transpiler functions =====================
def _insert_bias_op(self, index, current_op, bn_op):
'''
Construct elementwise_add operator for adding bias
......@@ -269,27 +216,9 @@ class InferenceTranspiler:
# collect the renamed input
self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
def _adjust_input(self, skip=False):
'''
Change the input variable name in operators.
When we are in the process of modifying a program desc, we usually
replace some variables with some other variables, where we create
a dictionary input_map to record the one-to-one correspondence
between each old variable and the new one.
After that, this function will search all the operators that use the
old variables and change the info in op to use the new variables. There
maybe some exceptions to this rule when we are using the float16 transpiler
and insert cast ops to cast float32 variable to float16 one. After we
insert the cast op to cast var_1 to var_1_fp16, we don't want to change
the input of cast op to var_1_fp16 after using this function.
'''
skip_ops = {"cast"}
def _adjust_input(self):
for i in range(len(self.block.ops)):
current_op = self.block.ops[i]
if skip and current_op.type in skip_ops:
continue
for input_arg in current_op.input_arg_names:
if input_arg in self.input_map:
current_op.rename_input(input_arg,
......@@ -309,138 +238,3 @@ class InferenceTranspiler:
for var in self.block.vars.keys():
if var not in args:
self.block.remove_var(var)
def _modify_feed_fetch(self):
'''
Modify feed fetch op/vars for float16 inference.
For each feed op:
feed_op->feed_target_var
Change it to:
feed_op->feed_target_var->cast_op(from other dtype to float16)->tmp_var
For each fetch op:
fetch_target_var->fetch_op
Change it to:
tmp_var->cast_op(from float16 to other dtype)->fetch_target_var->fetch_op
:return: None
'''
def find_op(var):
# It is possible that var.op is not up to date after some
# modifications to program desc. Here we force to make it up to date.
var.op = None
for op in self.block.ops:
if var.name in op.output_arg_names:
var.op = op
break
if var.op is None:
raise ValueError("The target variable must have an "
"associated operator that generates it.")
i = 0
while i < len(self.block.ops):
cur_op = self.block.ops[i]
if cur_op.type == "feed":
var_name = cur_op.output("Out")[0]
tmp_var_name = var_name + ".fp16"
var = self.block.vars[var_name]
tmp_var = self.block.create_var(
name=tmp_var_name.encode('ascii'),
type=var.type,
dtype=core.VarDesc.VarType.FP16,
shape=var.shape,
persistable=var.persistable)
self.block.insert_op(
i + 1,
type="cast",
inputs={"X": var},
outputs={"Out": tmp_var},
attrs={
'in_dtype': int(var.dtype),
'out_dtype': int(tmp_var.dtype)
})
self.input_map[var_name] = tmp_var_name
i = i + 1
elif cur_op.type == "fetch":
var_name = cur_op.input("X")[0]
tmp_var_name = var_name + ".fp16"
var = self.block.vars[var_name]
tmp_var = self.block.create_var(
name=tmp_var_name.encode('ascii'),
type=var.type,
dtype=core.VarDesc.VarType.FP16,
shape=var.shape,
persistable=var.persistable)
find_op(var)
var.op.rename_output(var_name, tmp_var_name)
self.block.insert_op(
i,
type="cast",
inputs={"X": tmp_var},
outputs={"Out": var},
attrs={
'in_dtype': int(tmp_var.dtype),
'out_dtype': int(var.dtype)
})
i = i + 1
i = i + 1
def _convert_param_to_float16(self):
def _get_no_fp16_conversion_var_names():
'''
Get the set of input variable names that shouldn't be converted to float16.
When we want to run inference in float16 mode, most parameters need to be
firstly converted to float16. However, there are some parameters that
shouldn't be converted to float16 because the corresponding operator
requires float32 parameters even in float16 mode (when the input data is
of float16 data type). Currently, the only operator that has this exclusion
is the batch norm op.
:return: set of input variable names
:type var_names: set
'''
op_names = {'batch_norm'}
var_names = []
for op in self.block.ops:
if op.type in op_names:
var_names += op.input_arg_names
return set(var_names)
def _should_be_converted(var):
return var.persistable and \
var.name not in self.no_conversion_vars and \
var.type != core.VarDesc.VarType.FEED_MINIBATCH and \
var.type != core.VarDesc.VarType.FETCH_LIST
self.no_conversion_vars = _get_no_fp16_conversion_var_names()
conversion_var_list = filter(_should_be_converted,
self.block.vars.values())
for var in conversion_var_list:
fp16_var_name = var.name + ".fp16"
fp16_var = self.block.create_parameter(
name=fp16_var_name.encode('ascii'),
type=var.type,
dtype=core.VarDesc.VarType.FP16,
shape=var.shape)
# cast the data in the tensor of the original var to float16
# data type and store it in the tensor of the new float16 var
self.scope.var(fp16_var_name)
fp16_tensor = self.scope.find_var(fp16_var_name).get_tensor()
tensor = np.array(self.scope.find_var(var.name).get_tensor())
# After the old tensor data is converted to np.float16, view(np.uint16)
# is used so that the internal memory of the numpy array will be
# reinterpreted to be of np.uint16 data type, which is binded to fluid
# float16 data type via the help of pybind in tensor_py.h.
fp16_tensor.set(
tensor.astype(np.float16).view(np.uint16), self.place)
# old var will be replaced by the fp16 var in program desc
self.input_map[var.name] = fp16_var_name
self.block.remove_var(var.name)
......@@ -247,26 +247,6 @@ def infer(use_cuda, save_dirname=None):
fetch_targets, exe,
inference_transpiler_program)
if use_cuda and fluid.core.is_float16_supported(place):
# Use float16_transpiler to speedup
fp16_transpiler_program = inference_transpiler_program.clone()
t.float16_transpile(fp16_transpiler_program, place)
fp16_results = exe.run(fp16_transpiler_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
assert len(results[0]) == len(fp16_results[0])
for i in range(len(results[0])):
np.testing.assert_almost_equal(
results[0][i], fp16_results[0][i], decimal=2)
print("float16 infer results: ", fp16_results[0])
fluid.io.save_inference_model("float16_" + save_dirname,
feed_target_names, fetch_targets, exe,
fp16_transpiler_program)
def main(net_type, use_cuda, is_local=True):
if use_cuda and not fluid.core.is_compiled_with_cuda():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册