未验证 提交 7584bb50 编写于 作者: F furnace 提交者: GitHub

Layer norm fp16 (#29169)

* add fp16 for layer_norm op

* revert layernorm api

* fix forward

* fix forward

* fix backward for layernorm with fp16

* fix unit test for layernorm with fp16

* fix with_mkldnn compile error for layernorm with fp16

* 1. revert to PADDLE_ENFORCE_NOT_NULL, 2. change static_cast<float> to static_cast<U>

* fix with_mkldnn compile error for layernorm with fp16

* fix with_mkldnn compile error for layernorm with fp16
Co-authored-by: Nzhiqiu <chenqiuliang@baidu.com>
上级 597897e3
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include <memory>
#include <string>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
......@@ -98,7 +99,26 @@ class LayerNormOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
auto ln_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
ln_param_type = framework::proto::VarType::FP64;
}
if (ctx.HasInput("Scale")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Scale")->type(),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
}
if (ctx.HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Bias")->type(),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
}
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
......@@ -110,9 +130,8 @@ class LayerNormOp : public framework::OperatorWithKernel {
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}
};
......@@ -224,7 +243,13 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
}
PADDLE_ENFORCE_NOT_NULL(
t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found."));
return framework::OpKernelType(t->type(), ctx.GetPlace());
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
}
};
......
......@@ -109,9 +109,11 @@ gray_list = {
'elementwise_mod',
'elementwise_floordiv',
'batch_norm',
'layer_norm',
'tanh',
'sigmoid',
'lookup_table',
'lookup_table_v2',
'top_k',
'pool2d',
'pool3d',
......@@ -123,6 +125,7 @@ gray_list = {
'flatten2',
'stack',
'unstack',
'uniform_random',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
......@@ -192,7 +195,6 @@ unsupported_fp16_list = {
'sequence_concat',
'sequence_slice',
'data_norm',
'layer_norm',
'group_norm',
'spectral_norm',
'depthwise_conv2d_transpose',
......
......@@ -70,7 +70,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
'batch_norm', 'fused_bn_add_activation'
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
]:
if in_name not in {'X', 'Z'}:
continue
......@@ -104,8 +104,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names:
if op.type in ['batch_norm', 'fused_bn_add_activation'
] and out_name != 'Y':
if op.type in [
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
] and out_name != 'Y':
continue
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
......
......@@ -15,6 +15,7 @@
from __future__ import print_function
import unittest
import numpy as np
import paddle
from operator import mul
import paddle.fluid.core as core
......@@ -210,7 +211,7 @@ class TestLayerNormOp(unittest.TestCase):
for name in ['x', 'scale', 'bias', 'y@GRAD']
},
fetch_list=fetch_list)
self.__assert_close(y, out[0], "y")
self.__assert_close(y, out[0], "y", 1e-3)
self.__assert_close(mean, out[1], "mean")
self.__assert_close(variance, out[2], "variance", 1e-3)
self.__assert_close(x_grad, out[3], "x_grad")
......@@ -310,6 +311,8 @@ class TestLayerNormAPI(unittest.TestCase):
class TestDygraphLayerNormAPIError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
paddle.enable_static()
layer_norm = fluid.LayerNorm([32, 32])
# the input of LayerNorm must be Variable.
x1 = np.random.random((3, 32, 32)).astype('float32')
......
......@@ -293,7 +293,8 @@ def layer_norm(x,
'begin_norm_axis', begin_norm_axis)
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
check_variable_and_dtype(x, 'input', ['float32', 'float64'], 'LayerNorm')
check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
'LayerNorm')
inputs = dict()
inputs['X'] = [x]
......@@ -305,11 +306,13 @@ def layer_norm(x,
# create output
helper = LayerHelper('layer_norm', **locals())
dtype = x.dtype
mean_out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
dtype=dtype, stop_gradient=True)
variance_out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
layer_norm_out = helper.create_variable_for_type_inference(x.dtype)
dtype=dtype, stop_gradient=True)
layer_norm_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="layer_norm",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册