Commit f64c861e authored by Zhen Wang

add pure fp16 training.

Parent 54b81fa3
@@ -124,13 +124,13 @@ class CPUDenseMomentumFunctor {
auto p = framework::EigenVector<T>::Flatten(*param);
auto v = framework::EigenVector<T>::Flatten(*velocity);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto* lr = learning_rate->data<T>();
const float* lr = learning_rate->data<float>();
v_out = v * mu + g;
if (use_nesterov) {
p_out = p - (g + v_out * mu) * lr[0];
p_out = p - (g + v_out * mu) * static_cast<T>(lr[0]);
} else {
p_out = p - lr[0] * v_out;
p_out = p - static_cast<T>(lr[0]) * v_out;
}
}
};
@@ -147,7 +147,7 @@ class DenseMomentumFunctor<T, UseNesterov> {
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const float* lr_;
const T mu_;
const int64_t num_;
T* p_out_;
@@ -155,8 +155,8 @@ class DenseMomentumFunctor<T, UseNesterov> {
public:
DenseMomentumFunctor(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu, const int64_t num,
T* p_out, T* v_out)
const float* learning_rate, const T mu,
const int64_t num, T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
@@ -169,10 +169,10 @@ class DenseMomentumFunctor<T, UseNesterov> {
// put memory access in register
const T p = p_[i];
const T g = g_[i];
const T lr = lr_[0];
const float lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - (g + v_out * mu_) * lr;
T p_out = p - (g + v_out * mu_) * static_cast<T>(lr);
// write register to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
@@ -185,7 +185,7 @@ class DenseMomentumFunctor<T, NoNesterov> {
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const float* lr_;
const T mu_;
const int64_t num_;
T* p_out_;
@@ -193,8 +193,8 @@ class DenseMomentumFunctor<T, NoNesterov> {
public:
DenseMomentumFunctor(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu, const int64_t num,
T* p_out, T* v_out)
const float* learning_rate, const T mu,
const int64_t num, T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
@@ -207,7 +207,7 @@ class DenseMomentumFunctor<T, NoNesterov> {
// put memory access in register
const T p = p_[i];
const T g = g_[i];
const T lr = lr_[0];
const T lr = static_cast<T>(lr_[0]);
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - lr * v_out;
@@ -226,7 +226,7 @@ class SparseMomentumFunctor<T, UseNesterov> {
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const float* lr_;
const T mu_;
const int64_t* rows_;
const int64_t row_numel_;
@@ -235,7 +235,7 @@ class SparseMomentumFunctor<T, UseNesterov> {
T* v_out_;
public:
SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
SparseMomentumFunctor(const T* p, const T* g, const T* v, const float* lr,
const T mu, const int64_t* rows, int64_t row_numel,
int64_t row_height, T* p_out, T* v_out)
: p_(p),
@@ -256,10 +256,10 @@ class SparseMomentumFunctor<T, UseNesterov> {
: static_cast<T>(0);
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
const float lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - (g + v_out * mu_) * lr;
T p_out = p - (g + v_out * mu_) * static_cast<T>(lr);
// write register to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
@@ -272,7 +272,7 @@ class SparseMomentumFunctor<T, NoNesterov> {
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const float* lr_;
const T mu_;
const int64_t* rows_;
const int64_t row_numel_;
@@ -281,7 +281,7 @@ class SparseMomentumFunctor<T, NoNesterov> {
T* v_out_;
public:
SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
SparseMomentumFunctor(const T* p, const T* g, const T* v, const float* lr,
const T mu, const int64_t* rows, int64_t row_numel,
int64_t row_height, T* p_out, T* v_out)
: p_(p),
@@ -302,7 +302,7 @@ class SparseMomentumFunctor<T, NoNesterov> {
: static_cast<T>(0);
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
const T lr = static_cast<T>(lr_[0]);
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - v_out * lr;
@@ -342,7 +342,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
if (use_nesterov) {
DenseMomentumFunctor<T, UseNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(),
learning_rate->data<float>(), mu, param->numel(),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
@@ -350,7 +350,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
} else {
DenseMomentumFunctor<T, NoNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(),
learning_rate->data<float>(), mu, param->numel(),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
@@ -382,8 +382,8 @@ class MomentumOpKernel : public framework::OpKernel<T> {
if (use_nesterov) {
SparseMomentumFunctor<T, UseNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
velocity->data<T>(), learning_rate->data<float>(), mu, rows,
row_numel, static_cast<int64_t>(merged_grad->rows().size()),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
@@ -391,8 +391,8 @@ class MomentumOpKernel : public framework::OpKernel<T> {
} else {
SparseMomentumFunctor<T, NoNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
velocity->data<T>(), learning_rate->data<float>(), mu, rows,
row_numel, static_cast<int64_t>(merged_grad->rows().size()),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
......
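The momentum kernel changes above keep the learning-rate tensor in fp32 even when the parameter, gradient, and velocity buffers use the fp16 parameter type `T`, casting `lr` to `T` only at the point of use. Below is a minimal NumPy sketch of the Nesterov variant of that mixed-precision update; the shapes and values are invented for illustration and are not part of the diff.

import numpy as np

# Illustration only: fp16 parameter/gradient/velocity, fp32 learning rate.
mu = np.float16(0.9)                        # momentum coefficient, parameter dtype T
lr = np.float32(1e-3)                       # learning rate stays in fp32, as in the diff
p = np.random.rand(8).astype(np.float16)    # parameter
g = np.random.rand(8).astype(np.float16)    # gradient
v = np.zeros(8, dtype=np.float16)           # velocity

v_out = v * mu + g                              # velocity update, all fp16
p_out = p - (g + v_out * mu) * np.float16(lr)   # cast lr down only when applying it

Storing the learning rate in fp32 and converting it only at the point of use preserves its precision, which helps when learning-rate schedules produce very small values.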
@@ -16,11 +16,17 @@ from __future__ import print_function
from ... import core
from ... import layers
from ... import global_scope
from ...log_helper import get_logger
import logging
import numpy as np
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
def _rename_arg(op, old_name, new_name):
    """
    If an op has old_name input and output, rename these input
    args to new_name.

    Args:
@@ -187,6 +193,124 @@ def _is_in_black_varnames(op, amp_lists):
return False
def cast_net_to_fp16(program):
    valid_types = [
        core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
        core.VarDesc.VarType.LOD_TENSOR_ARRAY
    ]
    global_block = program.global_block()

    for block in program.blocks:
        ops = block.ops
        for op in ops:
            for in_name in op.input_names:
                if op.type == 'batch_norm' and in_name != 'X':
                    continue
                for in_var_name in op.input(in_name):
                    in_var = None
                    try:
                        in_var = block.var(in_var_name)
                    except ValueError as e:
                        _logger.debug(
                            "-- {}, try to get it in the global block. --".
                            format(e))
                        in_var = global_block.var(in_var_name)
                        if in_var is not None:
                            _logger.debug(
                                "-- var {} is got in the global block. --".
                                format(in_var_name))

                    if in_var is None or in_var.type not in valid_types:
                        continue

                    if in_var.dtype == core.VarDesc.VarType.FP32:
                        in_var.desc.set_dtype(core.VarDesc.VarType.FP16)

                    _logger.debug(
                        "-- op type: {}, in var name: {}, in var dtype: {} --".
                        format(op.type, in_var_name, in_var.dtype))

            for out_name in op.output_names:
                if op.type == 'batch_norm' and out_name != 'Y':
                    continue
                for out_var_name in op.output(out_name):
                    out_var = None
                    try:
                        out_var = block.var(out_var_name)
                    except ValueError as e:
                        _logger.debug(
                            "-- {}, try to get it in the global block. --".
                            format(e))
                        out_var = global_block.var(out_var_name)
                        if out_var is not None:
                            _logger.debug(
                                "-- var {} is got in the global block. --".
                                format(out_var_name))

                    if out_var is None or out_var.type not in valid_types:
                        continue

                    if out_var.dtype == core.VarDesc.VarType.FP32:
                        out_var.desc.set_dtype(core.VarDesc.VarType.FP16)

                    _logger.debug(
                        "-- op type: {}, out var name: {}, out var dtype: {} --".
                        format(op.type, out_var_name, out_var.dtype))

            if op.has_attr('in_dtype') and op.attr(
                    'in_dtype') == core.VarDesc.VarType.FP32:
                op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
            if op.has_attr('out_dtype') and op.attr(
                    'out_dtype') == core.VarDesc.VarType.FP32:
                op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
            if op.has_attr('dtype') and op.attr(
                    'dtype') == core.VarDesc.VarType.FP32:
                op._set_attr('dtype', core.VarDesc.VarType.FP16)
def cast_parameters_to_fp16(exe, program):
    global_block = program.global_block()
    all_parameters = global_block.all_parameters()
    for param in all_parameters:
        if not (param.name.find('bn') != -1 and
                (param.name.endswith('_offset') or param.name.endswith('_mean')
                 or param.name.endswith('_scale') or
                 param.name.endswith('_variance'))):
            param_t = global_scope().find_var(param.name).get_tensor()
            data = np.array(param_t)
            param_t.set(np.float16(data), exe.place)
# def cast_parameters_to_fp16(program):
# global_block = program.global_block()
# all_parameters = global_block.all_parameters()
# is_bn_params = lambda param: (param.name.find('bn') != -1 and (param.name.endswith('_offset') or param.name.endswith('_mean') or param.name.endswith('_scale') or param.name.endswith('_variance')))
# all_param_names = {p.name for p in all_parameters if not is_bn_params(p)}
# ops = global_block.ops
# for param in all_parameters:
# if param.name in all_param_names:
# param_var = global_block.var(param.name)
# if param_var.dtype == core.VarDesc.VarType.FP32:
# param_var.desc.set_dtype(core.VarDesc.VarType.FP16)
# for op in ops:
# target_op = False
# for out_name in op.output_names:
# for out_var_name in op.output(out_name):
# if out_var_name in all_param_names:
# target_op = True
# if target_op:
# if op.has_attr('in_dtype') and op.attr(
# 'in_dtype') == core.VarDesc.VarType.FP32:
# op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
# if op.has_attr('out_dtype') and op.attr(
# 'out_dtype') == core.VarDesc.VarType.FP32:
# op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
# if op.has_attr('dtype') and op.attr(
# 'dtype') == core.VarDesc.VarType.FP32:
# op._set_attr('dtype', core.VarDesc.VarType.FP16)
def rewrite_program(main_prog, amp_lists):
    """
    Traverse all ops in current block and insert cast op according to
......
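The Python half of the commit adds `cast_net_to_fp16` and `cast_parameters_to_fp16` for pure fp16 training. A hedged usage sketch follows; `main_program`, `startup_program`, and the surrounding setup are assumptions and not part of this diff, and the two helpers are referenced as defined in the file above.

# Usage sketch under assumptions: a fluid static-graph network has already been
# built into `main_program` / `startup_program`; cast_net_to_fp16 and
# cast_parameters_to_fp16 are the helpers added in this commit.
import paddle.fluid as fluid

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)                    # parameters are created and initialized in fp32

cast_net_to_fp16(main_program)              # rewrite fp32 var dtypes and dtype attrs to fp16
cast_parameters_to_fp16(exe, main_program)  # cast initialized fp32 parameter tensors to fp16,
                                            # skipping batch-norm mean/variance/scale/offset

The parameter cast has to run after the startup program so that the tensors already exist in the global scope when `global_scope().find_var(param.name)` looks them up.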