diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc
index f61042dfc831854d34ef36acbfe0e78312fd6c97..b600af950c5bdabe6ec30e11907c43c02d69fe97 100644
--- a/paddle/fluid/operators/layer_norm_op.cc
+++ b/paddle/fluid/operators/layer_norm_op.cc
@@ -141,7 +141,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
     }
     if (ctx->HasOutput(framework::GradVarName("Bias"))) {
       ctx->SetOutputDim(framework::GradVarName("Bias"),
-                        ctx->GetInputDim("Scale"));
+                        ctx->GetInputDim("Bias"));
     }
   }
 
@@ -182,6 +182,7 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
     }
 
     if (this->HasInput("Bias")) {
+      op->SetInput("Bias", this->Input("Bias"));
       op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
     }
 
@@ -191,6 +192,9 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInference,
+                                    "Bias");
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -198,7 +202,8 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
                   ops::LayerNormGradOpMaker<paddle::framework::OpDesc>,
                   ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp);
+REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp,
+                  ops::LayerNormGradNoNeedBufferVarInference);
 REGISTER_OP_CPU_KERNEL(
     layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
     ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu
index e42a1f4803dbb6998f06d52b169d8a89eb2341de..10a732e86ea2a612a321d53f6a16cc596e39ae75 100644
--- a/paddle/fluid/operators/layer_norm_op.cu
+++ b/paddle/fluid/operators/layer_norm_op.cu
@@ -45,6 +45,37 @@ inline static int GetDesiredBlockDim(int block_dim) {
   FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
   FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)
 
+#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(                             \
+    log2_block_dim, feature_size, kMaxBlockNum, ...)                          \
+  case (1 << (log2_block_dim)): {                                             \
+    for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
+      int col_offset = i * kMaxBlockNum;                                      \
+      int block_num = std::min(feature_size - col_offset, kMaxBlockNum);      \
+      constexpr auto kBlockDim = (1 << (log2_block_dim));                     \
+      __VA_ARGS__;                                                            \
+    }                                                                         \
+  } break
+
+#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...)  \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__);                   \
+  FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum,    \
+                                            ##__VA_ARGS__)
+
 static __device__ __forceinline__ float real_sqrt(float x) { return sqrtf(x); }
 static __device__ __forceinline__ double real_sqrt(double x) { return sqrt(x); }
 
@@ -131,12 +162,13 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                                              T *d_scale, T *d_bias, T *d_x,
                                              const T *mean, const T *var,
                                              const T *scale, float epsilon,
-                                             int batch_size, int feature_size) {
+                                             int batch_size, int feature_size,
+                                             int col_offset) {
   using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
-  int beg_idx = threadIdx.x * feature_size + blockIdx.x;
-  int end_idx = batch_size * feature_size + blockIdx.x;
+  int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
+  int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
   int stride = BlockDim * feature_size;
 
   T d_scale_partial = 0, d_bias_partial = 0;
@@ -147,7 +179,7 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
     d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
     d_bias_partial += d_y[i];
     if (HasDx) {
-      d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+      d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val;
     }
   }
 
@@ -156,8 +188,8 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
                           PairForLayerNormAddFunctor<T>());
 
   if (threadIdx.x == 0) {
-    d_scale[blockIdx.x] = pair.first_;
-    d_bias[blockIdx.x] = pair.second_;
+    d_scale[blockIdx.x + col_offset] = pair.first_;
+    d_bias[blockIdx.x + col_offset] = pair.second_;
   }
 }
 
@@ -168,11 +200,11 @@ template <typename T, int BlockDim, bool HasDx, bool HasDScale>
 __global__ void LayerNormBackwardGradientScaleOrBias(
     const T *x, const T *d_y, T *d_scale, T *d_bias, T *d_x, const T *mean,
     const T *var, const T *scale, float epsilon, int batch_size,
-    int feature_size) {
+    int feature_size, int col_offset) {
   using BlockReduce = cub::BlockReduce<T, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
-  int beg_idx = threadIdx.x * feature_size + blockIdx.x;
-  int end_idx = batch_size * feature_size + blockIdx.x;
+  int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
+  int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
   int stride = BlockDim * feature_size;
   T d_scale_or_d_bias_partial = 0;
 
@@ -187,7 +219,7 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
 
     if (HasDx) {
       if (scale != nullptr) {
-        d_x[i] = d_y[i] * scale[blockIdx.x] / var_val;
+        d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val;
       } else {
         d_x[i] = d_y[i] / var_val;
       }
@@ -199,9 +231,9 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
 
   if (threadIdx.x == 0) {
     if (HasDScale) {
-      d_scale[blockIdx.x] = d_scale_or_d_bias_partial;
+      d_scale[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
     } else {
-      d_bias[blockIdx.x] = d_scale_or_d_bias_partial;
+      d_bias[blockIdx.x + col_offset] = d_scale_or_d_bias_partial;
     }
   }
 }
 
@@ -322,6 +354,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
                               T *d_bias, float epsilon, int batch_size,
                               int feature_size, cudaStream_t stream) {
   const int kMaxBlockDim = 512;
+  const int kMaxBlockNum = 128;
   int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
                       ((d_scale != nullptr ? 1 : 0) << 1) |
                       ((d_bias != nullptr ? 1 : 0));
@@ -347,29 +380,33 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
   switch (gradient_flag) {
     case 1:  // d_x == nulptr, d_scale == nullptr, d_bias != nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
-                             T, kBlockDim, false,
-                             false><<<feature_size, kBlockDim, 0, stream>>>(
-            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
-            feature_size));
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
+            LayerNormBackwardGradientScaleOrBias<
+                T, kBlockDim, false,
+                false><<<block_num, kBlockDim, 0, stream>>>(
+                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
+                batch_size, feature_size, col_offset));
       }
       break;
     case 2:  // d_x == nullptr, d_scale != nullptr, d_bias == nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
-                             T, kBlockDim, false,
-                             true><<<feature_size, kBlockDim, 0, stream>>>(
-            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
-            feature_size));
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
+            LayerNormBackwardGradientScaleOrBias<
+                T, kBlockDim, false, true><<<block_num, kBlockDim, 0, stream>>>(
+                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
+                batch_size, feature_size, col_offset));
       }
       break;
     case 3:  // d_x == nullptr, d_scale != nulptr, d_bias != nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
             LayerNormBackwardGradientAll<
-                T, kBlockDim, false><<<feature_size, kBlockDim, 0, stream>>>(
+                T, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
                 x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
-                batch_size, feature_size));
+                batch_size, feature_size, col_offset));
       }
       break;
     case 4:  // d_x != nullptr, d_scale == nullptr, d_bias == nullptr
       switch (block_dim) {
@@ -382,11 +419,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       break;
     case 5:  // d_x != nulptr, d_scale == nullptr, d_bias != nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
-                             T, kBlockDim, true,
-                             false><<<feature_size, kBlockDim, 0, stream>>>(
-            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
-            feature_size));
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
+            LayerNormBackwardGradientScaleOrBias<
+                T, kBlockDim, true, false><<<block_num, kBlockDim, 0, stream>>>(
+                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
+                batch_size, feature_size, col_offset));
       }
       switch (GetDesiredBlockDim(feature_size)) {
         FIXED_BLOCK_DIM_CASE(
@@ -397,11 +435,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       break;
     case 6:  // d_x != nullptr, d_scale != nullptr, d_bias == nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(LayerNormBackwardGradientScaleOrBias<
-                             T, kBlockDim, true,
-                             true><<<feature_size, kBlockDim, 0, stream>>>(
-            x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, batch_size,
-            feature_size));
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
+            LayerNormBackwardGradientScaleOrBias<
+                T, kBlockDim, true, true><<<block_num, kBlockDim, 0, stream>>>(
+                x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
+                batch_size, feature_size, col_offset));
      }
       switch (GetDesiredBlockDim(feature_size)) {
         FIXED_BLOCK_DIM_CASE(
@@ -412,11 +451,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
       break;
     case 7:  // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
       switch (block_dim) {
-        FIXED_BLOCK_DIM_CASE(
+        FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
+            feature_size, kMaxBlockNum,
             LayerNormBackwardGradientAll<
-                T, kBlockDim, true><<<feature_size, kBlockDim, 0, stream>>>(
+                T, kBlockDim, true><<<block_num, kBlockDim, 0, stream>>>(
                 x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
-                batch_size, feature_size));
+                batch_size, feature_size, col_offset));
       }
       switch (GetDesiredBlockDim(feature_size)) {
         FIXED_BLOCK_DIM_CASE(
@@ -539,6 +579,8 @@ class LayerNormGradKernel
   }
 };
 template class LayerNormDirectCUDAFunctor<float>;
+#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
+#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
 #undef FIXED_BLOCK_DIM_CASE_BASE
 #undef FIXED_BLOCK_DIM_CASE
 }  // namespace operators
diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h
index 89f5dccc1ac16606b87ad7d7b875a33653fb95af..6968c1a5b131211a8fc7a474df8d1692d6a5ed0f 100644
--- a/paddle/fluid/operators/layer_norm_op.h
+++ b/paddle/fluid/operators/layer_norm_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <algorithm>
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
index bc475e19f4c6f51ad87916c77ff5e20e1343503b..8df7ea35ec116468f6bfe774c4f8932333b3d9db 100644
--- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py
@@ -36,42 +36,72 @@ def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
     mean = np.mean(x, axis=1)
     var = np.var(x, axis=1) + epsilon
 
-    output = scale.reshape([1, D]) * np.divide(
-        (x - mean.reshape([N, 1])),
-        (np.sqrt(var)).reshape([N, 1])) + beta.reshape([1, D])
+    output = np.divide((x - mean.reshape([N, 1])),
+                       (np.sqrt(var)).reshape([N, 1]))
+    if scale is not None:
+        output = scale.reshape([1, D]) * output
+    if beta is not None:
+        output = output + beta.reshape([1, D])
 
     x.shape, output.shape = x_shape, x_shape
     return output, mean, var
 
 
-def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
+def _reference_layer_norm_grad(x,
+                               grad_y,
+                               scale,
+                               bias,
+                               mean,
+                               var,
+                               begin_norm_axis=1):
     x_shape = x.shape
-    scale_shape = scale.shape
     N = reduce(mul, x_shape[0:begin_norm_axis], 1)
     D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
+
+    if scale is not None:
+        scale_shape = scale.shape
+        scale.shape = [1, D]
     x.shape, grad_y.shape = [N, D], [N, D]
     var.shape, mean.shape = [N, 1], [N, 1]
-    scale.shape = [1, D]
 
     # d_bias
-    d_bias = np.sum(grad_y, axis=0).reshape([1, D])
+    if bias is not None:
+        d_bias = np.sum(grad_y, axis=0).reshape([1, D])
+    else:
+        d_bias = None
 
     # d_scale
-    d_scale = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y,
-                     axis=0).reshape([1, D])
+    if scale is not None:
+        d_scale = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y,
+                         axis=0).reshape([1, D])
+    else:
+        d_scale = None
 
     # dx
-    dx_end = scale * np.sqrt(1.0 / var) * grad_y
-    d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * scale, axis=1).reshape(
-        [N, 1])  # the second part equals to zero.
-    d_mean = 1.0 / D * d_mean_0
-    d_std = np.sum(
-        -(1.0 / var) * (x - mean) * grad_y * scale, axis=1).reshape([N, 1]) * (
-            1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean))
+    if scale is not None:
+        dx_end = scale * np.sqrt(1.0 / var) * grad_y
+        d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * scale, axis=1).reshape(
+            [N, 1])  # the second part equals to zero.
+        d_mean = 1.0 / D * d_mean_0
+        d_std = np.sum(-(1.0 / var) * (x - mean) * grad_y * scale,
+                       axis=1).reshape([N, 1]) * (
+                           1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) *
+                           (x - mean))
+    else:
+        dx_end = 1.0 * np.sqrt(1.0 / var) * grad_y
+        d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * 1.0, axis=1).reshape(
+            [N, 1])  # the second part equals to zero.
+        d_mean = 1.0 / D * d_mean_0
+        d_std = np.sum(-(1.0 / var) * (x - mean) * grad_y * 1.0,
+                       axis=1).reshape([N, 1]) * (
+                           1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) *
+                           (x - mean))
 
     grad_x = dx_end + d_mean + d_std
 
     grad_x.shape, x.shape, grad_y.shape = x_shape, x_shape, x_shape
-    scale.shape = scale_shape
     var.shape, mean.shape = [N, ], [N, ]
+
+    if scale is not None:
+        scale.shape = scale_shape
     return grad_x, d_scale, d_bias
 
 
@@ -82,7 +112,12 @@ class TestLayerNormOp(unittest.TestCase):
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
-    def check_forward_backward(self, shape, begin_norm_axis):
+    def check_forward_backward(self,
+                               shape,
+                               begin_norm_axis,
+                               has_scale=True,
+                               has_bias=True,
+                               y_grad_scale=1.0):
         def test_with_place(place, shape, begin_norm_axis):
             # attr
             epsilon = 0.00001
@@ -92,21 +127,26 @@ class TestLayerNormOp(unittest.TestCase):
 
             np.random.seed(123)
             x = np.random.random_sample(x_shape).astype(np.float32)
-            scale = np.random.random_sample(scale_shape).astype(np.float32)
-            bias = np.random.random_sample(scale_shape).astype(np.float32)
-            y_grad = np.random.random_sample(x_shape).astype(np.float32)
+            scale = np.random.random_sample(scale_shape).astype(
+                np.float32) if has_scale else None
+            bias = np.random.random_sample(scale_shape).astype(
+                np.float32) if has_bias else None
+            y_grad = (np.random.random_sample(x_shape) *
+                      y_grad_scale).astype(np.float32)
 
             # reference forward & backward
             y, mean, variance = _reference_layer_norm_naive(
                 x, scale, bias, epsilon, begin_norm_axis)
             x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
-                x, y_grad, scale, mean, variance, begin_norm_axis)
+                x, y_grad, scale, bias, mean, variance, begin_norm_axis)
 
             var_dict = locals()
             var_dict['y@GRAD'] = y_grad
-            var_names = [
-                'x', 'scale', 'bias', 'mean', 'variance', 'y', 'y@GRAD'
-            ]
+            var_names = ['x', 'mean', 'variance', 'y', 'y@GRAD']
+            if has_scale:
+                var_names += ['scale']
+            if has_bias:
+                var_names += ['bias']
             ground_truth = {name: var_dict[name] for name in var_names}
 
             program = fluid.Program()
@@ -117,13 +157,22 @@ class TestLayerNormOp(unittest.TestCase):
                         name=name,
                         dtype='float32',
                         shape=ground_truth[name].shape)
+                inputs = {"X": block.var('x')}
+                fetch_list = [
+                    'y',
+                    'mean',
+                    'variance',
+                    'x@GRAD',
+                ]
+                if has_scale:
+                    inputs["Scale"] = block.var('scale')
+                    fetch_list += ['scale@GRAD']
+                if has_bias:
+                    inputs["Bias"] = block.var('bias')
+                    fetch_list += ['bias@GRAD']
                 layer_norm_op = block.append_op(
                     type="layer_norm",
-                    inputs={
-                        "X": block.var('x'),
-                        "Scale": block.var('scale'),
-                        "Bias": block.var('bias'),
-                    },
+                    inputs=inputs,
                     outputs={
                         "Y": block.var('y'),
                         "Mean": block.var('mean'),  # share the same memory
@@ -134,7 +183,6 @@ class TestLayerNormOp(unittest.TestCase):
"epsilon": epsilon, "begin_norm_axis": begin_norm_axis }) - # generate backward op_desc grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc( layer_norm_op.desc, set(), []) @@ -150,23 +198,25 @@ class TestLayerNormOp(unittest.TestCase): grad_var.set_dtype(core.VarDesc.VarType.FP32) program._sync_with_cpp() - exe = fluid.Executor(place) out = exe.run(program, feed={ name: var_dict[name] for name in ['x', 'scale', 'bias', 'y@GRAD'] }, - fetch_list=[ - 'y', 'mean', 'variance', 'x@GRAD', - 'scale@GRAD', 'bias@GRAD' - ]) + fetch_list=fetch_list) self.__assert_close(y, out[0], "y") self.__assert_close(mean, out[1], "mean") self.__assert_close(variance, out[2], "variance", 1e-3) self.__assert_close(x_grad, out[3], "x_grad") - self.__assert_close(scale_grad, out[4], "scale_grad", 1e-3) - self.__assert_close(bias_grad, out[5], "bias_grad") + if has_scale: + self.__assert_close(scale_grad, + out[fetch_list.index('scale@GRAD')], + "scale_grad", 1e-3) + if has_bias: + self.__assert_close(bias_grad, + out[fetch_list.index('bias@GRAD')], + "bias_grad") places = [core.CPUPlace()] if core.is_compiled_with_cuda() and core.op_support_gpu( @@ -178,7 +228,45 @@ class TestLayerNormOp(unittest.TestCase): def test_check_forward_backward_with_scale_and_bias(self): self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1) + self.check_forward_backward( + shape=[2, 3, 4, 5], + begin_norm_axis=1, + has_scale=False, + has_bias=True) + self.check_forward_backward( + shape=[2, 3, 4, 5], + begin_norm_axis=1, + has_scale=True, + has_bias=False) + self.check_forward_backward( + shape=[2, 3, 4, 5], + begin_norm_axis=1, + has_scale=False, + has_bias=False) self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3) + self.check_forward_backward( + shape=[92, 513, 129], begin_norm_axis=2, y_grad_scale=0.1) + self.check_forward_backward(shape=[3, 34, 1134], begin_norm_axis=2) + self.check_forward_backward( + shape=[92, 513, 1134], begin_norm_axis=2, y_grad_scale=0.1) + self.check_forward_backward( + shape=[92, 513, 1134], + begin_norm_axis=2, + has_scale=False, + has_bias=True, + y_grad_scale=0.1) + self.check_forward_backward( + shape=[92, 513, 1134], + begin_norm_axis=2, + has_scale=True, + has_bias=False, + y_grad_scale=0.1) + self.check_forward_backward( + shape=[92, 513, 1134], + begin_norm_axis=2, + has_scale=False, + has_bias=False, + y_grad_scale=0.1) class TestLayerNormAPI(unittest.TestCase):