未验证 提交 7584bb50 编写于 作者: F furnace 提交者: GitHub

Layer norm fp16 (#29169)

* add fp16 for layer_norm op

* revert layernorm api

* fix forward

* fix forward

* fix backward for layernorm with fp16

* fix unit test for layernorm with fp16

* fix with_mkldnn compile error for layernorm with fp16

* 1. revert to PADDLE_ENFORCE_NOT_NULL, 2. change static_cast<float> to static_cast<U>

* fix with_mkldnn compile error for layernorm with fp16

* fix with_mkldnn compile error for layernorm with fp16
Co-authored-by: Nzhiqiu <chenqiuliang@baidu.com>
上级 597897e3
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h" #include "paddle/fluid/operators/layer_norm_op.h"
#include <memory> #include <memory>
#include <string>
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
...@@ -98,7 +99,26 @@ class LayerNormOp : public framework::OperatorWithKernel { ...@@ -98,7 +99,26 @@ class LayerNormOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
auto ln_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
ln_param_type = framework::proto::VarType::FP64;
}
if (ctx.HasInput("Scale")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Scale")->type(),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
}
if (ctx.HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Bias")->type(),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
}
framework::LibraryType library = framework::LibraryType::kPlain; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
...@@ -110,9 +130,8 @@ class LayerNormOp : public framework::OperatorWithKernel { ...@@ -110,9 +130,8 @@ class LayerNormOp : public framework::OperatorWithKernel {
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), library);
layout, library);
} }
}; };
...@@ -224,7 +243,13 @@ class LayerNormGradOp : public framework::OperatorWithKernel { ...@@ -224,7 +243,13 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
} }
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found.")); t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found."));
return framework::OpKernelType(t->type(), ctx.GetPlace());
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
} }
}; };
......
...@@ -109,9 +109,11 @@ gray_list = { ...@@ -109,9 +109,11 @@ gray_list = {
'elementwise_mod', 'elementwise_mod',
'elementwise_floordiv', 'elementwise_floordiv',
'batch_norm', 'batch_norm',
'layer_norm',
'tanh', 'tanh',
'sigmoid', 'sigmoid',
'lookup_table', 'lookup_table',
'lookup_table_v2',
'top_k', 'top_k',
'pool2d', 'pool2d',
'pool3d', 'pool3d',
...@@ -123,6 +125,7 @@ gray_list = { ...@@ -123,6 +125,7 @@ gray_list = {
'flatten2', 'flatten2',
'stack', 'stack',
'unstack', 'unstack',
'uniform_random',
'uniform_random_batch_size_like', 'uniform_random_batch_size_like',
'gaussian_random', 'gaussian_random',
'gaussian_random_batch_size_like', 'gaussian_random_batch_size_like',
...@@ -192,7 +195,6 @@ unsupported_fp16_list = { ...@@ -192,7 +195,6 @@ unsupported_fp16_list = {
'sequence_concat', 'sequence_concat',
'sequence_slice', 'sequence_slice',
'data_norm', 'data_norm',
'layer_norm',
'group_norm', 'group_norm',
'spectral_norm', 'spectral_norm',
'depthwise_conv2d_transpose', 'depthwise_conv2d_transpose',
......
...@@ -70,7 +70,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -70,7 +70,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
for in_name in op.input_names: for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [ if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
'batch_norm', 'fused_bn_add_activation' 'batch_norm', 'fused_bn_add_activation', 'layer_norm'
]: ]:
if in_name not in {'X', 'Z'}: if in_name not in {'X', 'Z'}:
continue continue
...@@ -104,8 +104,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -104,8 +104,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
op._set_attr('in_dtype', dest_dtype) op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16: if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names: for out_name in op.output_names:
if op.type in ['batch_norm', 'fused_bn_add_activation' if op.type in [
] and out_name != 'Y': 'batch_norm', 'fused_bn_add_activation', 'layer_norm'
] and out_name != 'Y':
continue continue
for out_var_name in op.output(out_name): for out_var_name in op.output(out_name):
out_var = block.var(out_var_name) out_var = block.var(out_var_name)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
from operator import mul from operator import mul
import paddle.fluid.core as core import paddle.fluid.core as core
...@@ -210,7 +211,7 @@ class TestLayerNormOp(unittest.TestCase): ...@@ -210,7 +211,7 @@ class TestLayerNormOp(unittest.TestCase):
for name in ['x', 'scale', 'bias', 'y@GRAD'] for name in ['x', 'scale', 'bias', 'y@GRAD']
}, },
fetch_list=fetch_list) fetch_list=fetch_list)
self.__assert_close(y, out[0], "y") self.__assert_close(y, out[0], "y", 1e-3)
self.__assert_close(mean, out[1], "mean") self.__assert_close(mean, out[1], "mean")
self.__assert_close(variance, out[2], "variance", 1e-3) self.__assert_close(variance, out[2], "variance", 1e-3)
self.__assert_close(x_grad, out[3], "x_grad") self.__assert_close(x_grad, out[3], "x_grad")
...@@ -310,6 +311,8 @@ class TestLayerNormAPI(unittest.TestCase): ...@@ -310,6 +311,8 @@ class TestLayerNormAPI(unittest.TestCase):
class TestDygraphLayerNormAPIError(unittest.TestCase): class TestDygraphLayerNormAPIError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
paddle.enable_static()
layer_norm = fluid.LayerNorm([32, 32]) layer_norm = fluid.LayerNorm([32, 32])
# the input of LayerNorm must be Variable. # the input of LayerNorm must be Variable.
x1 = np.random.random((3, 32, 32)).astype('float32') x1 = np.random.random((3, 32, 32)).astype('float32')
......
...@@ -293,7 +293,8 @@ def layer_norm(x, ...@@ -293,7 +293,8 @@ def layer_norm(x,
'begin_norm_axis', begin_norm_axis) 'begin_norm_axis', begin_norm_axis)
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None) return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
check_variable_and_dtype(x, 'input', ['float32', 'float64'], 'LayerNorm') check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
'LayerNorm')
inputs = dict() inputs = dict()
inputs['X'] = [x] inputs['X'] = [x]
...@@ -305,11 +306,13 @@ def layer_norm(x, ...@@ -305,11 +306,13 @@ def layer_norm(x,
# create output # create output
helper = LayerHelper('layer_norm', **locals()) helper = LayerHelper('layer_norm', **locals())
dtype = x.dtype
mean_out = helper.create_variable_for_type_inference( mean_out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True) dtype=dtype, stop_gradient=True)
variance_out = helper.create_variable_for_type_inference( variance_out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True) dtype=dtype, stop_gradient=True)
layer_norm_out = helper.create_variable_for_type_inference(x.dtype) layer_norm_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="layer_norm", type="layer_norm",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册