diff --git a/oneflow/python/ops/nn_ops.py b/oneflow/python/ops/nn_ops.py
index eb70f35cf7dfaba2b7acbaca753db2b4e6c64a7b..f420bcc4bfcb96b457adc60afa23e51760d36604 100644
--- a/oneflow/python/ops/nn_ops.py
+++ b/oneflow/python/ops/nn_ops.py
@@ -926,6 +926,120 @@ def batch_normalization(
     raise NotImplementedError
 
 
+@oneflow_export("nn.layer_norm")
+def layer_norm(
+    inputs: remote_blob_util.BlobDef,
+    gamma: Optional[remote_blob_util.BlobDef] = None,
+    beta: Optional[remote_blob_util.BlobDef] = None,
+    begin_norm_axis: int = 1,
+    begin_params_axis: int = -1,
+    epsilon: float = 1e-5,
+    name: Optional[str] = None,
+) -> remote_blob_util.BlobDef:
+    r"""Layer Normalization.
+
+    Args:
+        inputs (remote_blob_util.BlobDef): Input `Blob`.
+        gamma (Optional[remote_blob_util.BlobDef], optional): Scale parameter `Blob`. Defaults to None.
+        beta (Optional[remote_blob_util.BlobDef], optional): Shift parameter `Blob`. Defaults to None.
+        begin_norm_axis (int, optional): The first axis along which the mean and variance are computed. Defaults to 1.
+        begin_params_axis (int, optional): The first axis to which gamma and beta are applied. Defaults to -1.
+        epsilon (float, optional): A small float added to the variance to avoid division by zero. Defaults to 1e-5.
+        name (Optional[str], optional): This operator's name. Defaults to None.
+
+    Returns:
+        remote_blob_util.BlobDef: A normalized `Blob` with the same shape as the input.
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        import numpy as np
+        import oneflow.typing as tp
+
+
+        @flow.global_function()
+        def layer_norm_Job(x: tp.Numpy.Placeholder((1, 64, 128, 128))
+        ) -> tp.Numpy:
+            layer_norm = flow.nn.layer_norm(
+                x,
+                name="LayerNorm1"
+            )
+            return layer_norm
+
+
+        x = np.random.randn(1, 64, 128, 128).astype(np.float32)
+        out = layer_norm_Job(x)
+
+        # out.shape (1, 64, 128, 128)
+
+    """
+    param_shape = inputs.shape[begin_params_axis:]
+
+    if name is None:
+        name = id_util.UniqueStr("LayerNorm_")
+
+    if flow.current_scope().device_parallel_desc_symbol.device_tag == "cpu":
+        if begin_norm_axis < 0:
+            begin_norm_axis = begin_norm_axis + len(inputs.shape)
+
+        reduce_axis = []
+        for dim in range(len(inputs.shape)):
+            if dim >= begin_norm_axis:
+                reduce_axis.append(dim)
+        mean, variance = flow.nn.moments(inputs, reduce_axis, keepdims=True)
+
+        axis = begin_norm_axis
+        normalized = flow.nn.batch_normalization(
+            x=inputs,
+            mean=mean,
+            variance=variance,
+            variance_epsilon=epsilon,
+            axis=axis,
+            name=name,
+        )
+        nd_params_shape = [1] * (len(inputs.shape) - len(param_shape)) + list(
+            param_shape
+        )
+        affined = normalized
+        if gamma:
+            gamma = flow.reshape(gamma, nd_params_shape)
+            affined *= gamma
+        if beta:
+            beta = flow.reshape(beta, nd_params_shape)
+            affined += beta
+        return affined
+    elif flow.current_scope().device_parallel_desc_symbol.device_tag == "gpu":
+        op_builder = (
+            flow.user_op_builder(name)
+            .Op("layer_norm")
+            .Input("x", [inputs])
+            .Output("y")
+            .Output("mean")
+            .Output("inv_variance")
+        )
+        scale = False
+        center = False
+        if beta is not None:
+            center = True
+            op_builder.Input("beta", [beta])
+        if gamma is not None:
+            scale = True
+            op_builder.Input("gamma", [gamma])
+            op_builder.Output("normalized")
+        op_builder.Attr("center", center)
+        op_builder.Attr("scale", scale)
+        op_builder.Attr("begin_norm_axis", begin_norm_axis)
+        op_builder.Attr("begin_params_axis", begin_params_axis)
+        op_builder.Attr("epsilon", epsilon)
+
+        y = op_builder.Build().InferAndTryRun().RemoteBlobList()[0]
+        return y
+    else:
+        raise NotImplementedError
+
+
 @oneflow_export("nn.compat_conv2d")
 def tf_conv2d(
     input: remote_blob_util.BlobDef,
diff --git a/oneflow/python/test/ops/test_layer_norm.py b/oneflow/python/test/ops/test_layer_norm.py
index ac32f6c3f07474c8d2664c360f7e09d2eac7e487..c8beadf003d63d711ff27dc7872df097546ab4b7 100644
--- a/oneflow/python/test/ops/test_layer_norm.py
+++ b/oneflow/python/test/ops/test_layer_norm.py
@@ -20,6 +20,8 @@ from collections import OrderedDict
 import numpy as np
 import oneflow as flow
 import tensorflow as tf
+import test_global_storage
+
 from test_util import GenArgList, type_name_to_flow_type, type_name_to_np_type
 import oneflow.typing as oft
 
@@ -30,12 +32,12 @@ for gpu in gpus:
 
 def test_layer_norm(_):
     confs = [
-        {"x_shape": (4, 5, 2, 6), "begin_norm_axis": -1, "begin_params_axis": -1},
+        {"x_shape": (40, 50), "begin_norm_axis": -1, "begin_params_axis": -1},
     ]
     arg_dict = OrderedDict()
     arg_dict["device_type"] = ["cpu", "gpu"]
     arg_dict["confs"] = confs
-    arg_dict["data_type"] = ["float32"]
+    arg_dict["data_type"] = ["float32", "float16"]
    arg_dict["trainable"] = [True, False]
     arg_dict["center"] = [True, False]
     arg_dict["scale"] = [True, False]
@@ -43,6 +45,8 @@ def test_layer_norm(_):
     for case in GenArgList(arg_dict):
         (device_type, confs, data_type, trainable, center, scale, epsilon) = case
+        if device_type == "cpu" and data_type == "float16":
+            continue
         x_shape = confs["x_shape"]
         begin_norm_axis = confs["begin_norm_axis"]
         begin_params_axis = confs["begin_params_axis"]
@@ -51,13 +55,26 @@ def test_layer_norm(_):
             begin_norm_axis == begin_params_axis
         ), "tf doesn't support a dedicated begin_params_axis"
         # Random inputs
-        x = np.random.randn(*x_shape).astype(type_name_to_np_type[data_type])
+        if data_type == "float16":
+            x = (
+                np.random.uniform(low=-1, high=1, size=x_shape)
+                .astype(np.float16)
+                .astype(np.float32)
+            )
+        else:
+            x = np.random.uniform(low=-1, high=1, size=x_shape).astype(
+                type_name_to_np_type[data_type]
+            )
+        dim = len(x.shape) - 2
 
         # TF results
         with tf.GradientTape(persistent=True) as tape:
             x_tf = tf.Variable(x)
-            y_tf = tf.keras.layers.LayerNormalization(
+            if data_type == "float16":
+                x_tf = tf.cast(x_tf, dtype=tf.float16)
+                tf.keras.backend.set_floatx("float16")
+            layer = tf.keras.layers.LayerNormalization(
                 axis=begin_norm_axis,
                 epsilon=epsilon,
                 center=center,
@@ -69,20 +86,65 @@ def test_layer_norm(_):
                 beta_constraint=None,
                 gamma_constraint=None,
                 trainable=trainable,
-            )(x_tf)
-
-            dx_tf = tape.gradient(y_tf, x_tf, tf.constant(1.0, shape=y_tf.shape))
+            )
+            y_tf = layer(x_tf)
+            if data_type == "float16":
+                dx_tf = tape.gradient(
+                    y_tf, x_tf, tf.constant(1.0, shape=y_tf.shape, dtype=tf.float16)
+                )
+            else:
+                dx_tf = tape.gradient(y_tf, x_tf, tf.constant(1.0, shape=y_tf.shape))
+            grad = tape.gradient(y_tf, layer.trainable_variables)
+            if trainable:
+                if scale and center:
+                    tf_gamma_diff = grad[0]
+                    tf_beta_diff = grad[1]
+                elif scale and not center:
+                    tf_gamma_diff = grad[0]
+                elif not scale and center:
+                    tf_beta_diff = grad[0]
+                else:
+                    pass
+            else:
+                pass
 
         def assert_grad(b):
             diff = dx_tf.numpy() - b.numpy()
             max_diff = np.max(np.abs(diff))
-            assert np.allclose(dx_tf.numpy(), b.numpy(), rtol=1e-5, atol=1e-5), (
+            if data_type == "float16":
+                tolerance = 2e-3
+            else:
+                tolerance = 1e-5
+            assert np.allclose(
+                dx_tf.numpy(), b.numpy(), rtol=tolerance, atol=tolerance
+            ), (
+                case,
+                max_diff,
+            )
+
+        def assert_grad_gamma(b):
+            diff = tf_gamma_diff.numpy() - b.numpy()
+            max_diff = np.max(np.abs(diff))
+            assert np.allclose(
+                tf_gamma_diff.numpy(), b.numpy(), rtol=1e-4, atol=1e-4
+            ), (
+                case,
+                max_diff,
+            )
+
+        def assert_grad_beta(b):
+            diff = tf_beta_diff.numpy() - b.numpy()
+            max_diff = np.max(np.abs(diff))
+            assert np.allclose(tf_beta_diff.numpy(), b.numpy(), rtol=1e-5, atol=1e-5), (
                 case,
                 max_diff,
             )
 
         # 1F results
-        dtype = type_name_to_flow_type[data_type]
+        if data_type == "float16":
+            dtype = flow.float
+        else:
+            dtype = type_name_to_flow_type[data_type]
 
         func_config = flow.FunctionConfig()
         func_config.default_data_type(flow.float)
@@ -98,14 +160,60 @@ def test_layer_norm(_):
             )
             flow.watch_diff(v, assert_grad)
             x += v
+            if data_type == "float16":
+                x = flow.cast(x, dtype=flow.float16)
             with flow.scope.placement(device_type, "0:0"):
-                y = flow.layers.layer_norm(
+                param_shape = x.shape[begin_params_axis:]
+                gamma = None
+                beta = None
+                if center:
+                    with flow.scope.namespace("LayerNorm"):
+                        beta = flow.get_variable(
+                            name="beta",
+                            shape=param_shape,
+                            dtype=flow.float,
+                            initializer=flow.constant_initializer(0.0),
+                            trainable=trainable,
+                            model_name="beta",
+                            reuse=False,
+                        )
+                    if trainable:
+                        flow.watch_diff(beta, assert_grad_beta)
+                    if data_type == "float16":
+                        beta = flow.cast(beta, dtype=flow.float16)
+
+                if scale:
+                    with flow.scope.namespace("LayerNorm"):
+                        gamma = flow.get_variable(
+                            name="gamma",
+                            shape=param_shape,
+                            dtype=flow.float,
+                            initializer=flow.constant_initializer(1.0),
+                            trainable=trainable,
+                            model_name="gamma",
+                            reuse=False,
+                        )
+                    if trainable:
+                        if data_type == "float16":
+                            flow.watch_diff(
+                                gamma, test_global_storage.Setter("gamma_diff")
+                            )
+                        else:
+                            flow.watch_diff(gamma, assert_grad_gamma)
+                    if data_type == "float16":
+                        gamma = flow.cast(gamma, dtype=flow.float16)
+
+                y = flow.nn.layer_norm(
                     x,
+                    gamma=gamma,
+                    beta=beta,
                     begin_norm_axis=begin_norm_axis,
                     begin_params_axis=begin_params_axis,
-                    center=center,
-                    scale=scale,
+                    epsilon=epsilon,
                 )
+                if data_type == "float16":
+                    y = flow.cast(y, dtype=flow.float)
+
                 flow.optimizer.SGD(
                     flow.optimizer.PiecewiseConstantScheduler([], [1e-4]), momentum=0
                 ).minimize(y)
@@ -114,6 +222,7 @@ def test_layer_norm(_):
         check_point = flow.train.CheckPoint()
         check_point.init()
         y = test_job(x).get()
+
         assert y.numpy().shape == y_tf.numpy().shape, (
             y.numpy().shape,
             y_tf.numpy().shape,
@@ -124,3 +233,23 @@ def test_layer_norm(_):
             case,
             max_diff,
         )
+        if data_type == "float16" and trainable and scale:
+            np_dy = np.ones(x.shape).astype(np.float32)
+            np_gamma_diff = np.sum(np_dy * y.numpy().astype(np.float32), axis=0).astype(
+                np.float16
+            )
+            max_diff = np.max(
+                np.abs(
+                    np_gamma_diff
+                    - test_global_storage.Get("gamma_diff").astype(np.float16)
+                )
+            )
+            assert np.allclose(
+                np_gamma_diff,
+                test_global_storage.Get("gamma_diff").astype(np.float16),
+                rtol=1e-2,
+                atol=1e-2,
+            ), (
+                case,
+                max_diff,
+            )
diff --git a/oneflow/user/kernels/layer_norm_gpu_kernel.cu b/oneflow/user/kernels/layer_norm_gpu_kernel.cu
index 32c71c501b2993efb845ff2acbd13ef5324e89bb..893bc0f3b37a59186b6af3331e9ce08cbf764c74 100644
--- a/oneflow/user/kernels/layer_norm_gpu_kernel.cu
+++ b/oneflow/user/kernels/layer_norm_gpu_kernel.cu
@@ -17,6 +17,7 @@ limitations under the License.
#include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/ndarray_util.h" +#include "oneflow/core/kernel/kernel_util.cuh" namespace oneflow { @@ -142,6 +143,86 @@ void InstanceScaleCenter(DeviceCtx* ctx, const int64_t batch_size, } } +constexpr int64_t kLayerNormGpuBlockSize = 512; + +int64_t GetLayerNormBlockSize() { return kLayerNormGpuBlockSize; } + +int64_t GetLayerNormNumBlocks(const int64_t elem_cnt) { + return std::min( + static_cast((elem_cnt + kLayerNormGpuBlockSize - 1) / kLayerNormGpuBlockSize), 256); +} + +template +int64_t GetDynamicSharedMemorySize(const int64_t instance_size) { + return 2 * instance_size * sizeof(T); +} + +template<> +int64_t GetDynamicSharedMemorySize(const int64_t instance_size) { + return 2 * instance_size * sizeof(float); +} + +template +__global__ void LayerNormParamGradImpl(const I n, const I instance_size, const T* dy, + const T* normalized, const T* gamma, T* gamma_diff, + T* beta_diff, T* normalized_diff) { + extern __shared__ __align__(sizeof(T)) unsigned char bw_shared_buf[]; + auto* gamma_diff_sum_buf = reinterpret_cast(bw_shared_buf); + auto* beta_diff_sum_buf = gamma_diff_sum_buf + instance_size; + const I tid = threadIdx.x; + for (I elem_id = tid; elem_id < instance_size; elem_id += blockDim.x) { + gamma_diff_sum_buf[elem_id] = 0; + beta_diff_sum_buf[elem_id] = 0; + } + __syncthreads(); + CUDA_1D_KERNEL_LOOP_T(I, i, n) { + const I elem_id = i % instance_size; + T dy_val = dy[i]; + T normalized_val = normalized[i]; + gpu_atomic_add(&gamma_diff_sum_buf[elem_id], dy_val * normalized_val); + gpu_atomic_add(&beta_diff_sum_buf[elem_id], dy_val); + T gamma_val = gamma[elem_id]; + normalized_diff[i] = gamma_val * dy_val; + } + __syncthreads(); + for (I elem_id = tid; elem_id < instance_size; elem_id += blockDim.x) { + gpu_atomic_add(gamma_diff + elem_id, gamma_diff_sum_buf[elem_id]); + gpu_atomic_add(beta_diff + elem_id, beta_diff_sum_buf[elem_id]); + } +} + +template +__global__ void LayerNormParamGradHalfImpl(const I n, const I instance_size, const half* dy, + const half* normalized, const half* gamma, + half* tmp_gamma_diff, half* tmp_beta_diff, + half* normalized_diff) { + extern __shared__ __align__(sizeof(float)) unsigned char bw_shared_buf[]; + auto* gamma_diff_sum_buf = reinterpret_cast(bw_shared_buf); + auto* beta_diff_sum_buf = gamma_diff_sum_buf + instance_size; + const I tid = threadIdx.x; + for (I elem_id = tid; elem_id < instance_size; elem_id += blockDim.x) { + gamma_diff_sum_buf[elem_id] = 0; + beta_diff_sum_buf[elem_id] = 0; + } + __syncthreads(); + CUDA_1D_KERNEL_LOOP_T(I, i, n) { + const I elem_id = i % instance_size; + half dy_val = dy[i]; + half normalized_val = normalized[i]; + gpu_atomic_add(&gamma_diff_sum_buf[elem_id], + __half2float(dy_val) * __half2float(normalized_val)); + gpu_atomic_add(&beta_diff_sum_buf[elem_id], __half2float(dy_val)); + half gamma_val = gamma[elem_id]; + normalized_diff[i] = __hmul(gamma_val, dy_val); + } + __syncthreads(); + for (I elem_id = tid; elem_id < instance_size; elem_id += blockDim.x) { + const I offset = blockIdx.x * instance_size + elem_id; + tmp_gamma_diff[offset] = __float2half(gamma_diff_sum_buf[elem_id]); + tmp_beta_diff[offset] = __float2half(beta_diff_sum_buf[elem_id]); + } +} + } // namespace template @@ -298,36 +379,67 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel { const bool has_gamma_diff = gamma_diff != nullptr; const bool has_normalized_diff = normalized_diff != nullptr; const bool has_gamma = 
gamma != nullptr; - if (has_beta_diff) { - user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0); - const int64_t m = beta_diff->shape().elem_cnt(); - CHECK_EQ(dy->shape().elem_cnt() % m, 0); - const int64_t n = dy->shape().elem_cnt() / m; - NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, beta_diff->mut_dptr()), - Val({n, m}, dy->dptr()), Var({n, m}, reduce_buf->mut_dptr())); - } - if (has_gamma_diff) { + const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); + const int64_t elem_cnt = dy->shape().elem_cnt(); + const int64_t m = dy->shape().Count(begin_params_axis); + int max_active_blocks; + OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, LayerNormParamGradImpl, GetLayerNormBlockSize(), + GetDynamicSharedMemorySize(m))); + if (has_gamma_diff && has_beta_diff && has_normalized_diff && max_active_blocks > 0) { const user_op::Tensor* normalized = ctx->Tensor4ArgNameAndIndex("normalized", 0); - user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0); - const int64_t m = gamma_diff->shape().elem_cnt(); - CHECK_EQ(dy->shape().elem_cnt() % m, 0); - const int64_t n = dy->shape().elem_cnt() / m; - NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, reduce_buf->mut_dptr()), - Val({n, m}, normalized->dptr()), Val({n, m}, dy->dptr())); - NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, gamma_diff->mut_dptr()), - Val({n, m}, reduce_buf->dptr()), Var({n, m}, reduce_buf->mut_dptr())); - } - if (has_normalized_diff) { - if (has_gamma) { - const int64_t m = gamma->shape().elem_cnt(); + Memset(ctx->device_ctx(), gamma_diff->mut_dptr(), 0, + gamma_diff->shape().elem_cnt() * sizeof(T)); + Memset(ctx->device_ctx(), beta_diff->mut_dptr(), 0, + beta_diff->shape().elem_cnt() * sizeof(T)); + if (elem_cnt > static_cast(GetMaxVal() / 2)) { + LayerNormParamGradImpl + <<(m), ctx->device_ctx()->cuda_stream()>>>( + elem_cnt, m, dy->dptr(), normalized->dptr(), gamma->dptr(), + gamma_diff->mut_dptr(), beta_diff->mut_dptr(), + normalized_diff->mut_dptr()); + } else { + LayerNormParamGradImpl + <<(m), ctx->device_ctx()->cuda_stream()>>>( + static_cast(elem_cnt), static_cast(m), dy->dptr(), + normalized->dptr(), gamma->dptr(), gamma_diff->mut_dptr(), + beta_diff->mut_dptr(), normalized_diff->mut_dptr()); + } + } else { + if (has_beta_diff) { + user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0); + CHECK_EQ(m, beta_diff->shape().elem_cnt()); CHECK_EQ(dy->shape().elem_cnt() % m, 0); const int64_t n = dy->shape().elem_cnt() / m; - NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, normalized_diff->mut_dptr()), - Val({n, m}, dy->dptr()), Val({1, m}, gamma->dptr())); - } else { - Memcpy(ctx->device_ctx(), normalized_diff->mut_dptr(), - dy->dptr(), - dy->shape().elem_cnt() * GetSizeOfDataType(dy->data_type())); + NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, beta_diff->mut_dptr()), + Val({n, m}, dy->dptr()), Var({n, m}, reduce_buf->mut_dptr())); + } + if (has_gamma_diff) { + const user_op::Tensor* normalized = ctx->Tensor4ArgNameAndIndex("normalized", 0); + user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0); + CHECK_EQ(m, gamma_diff->shape().elem_cnt()); + CHECK_EQ(dy->shape().elem_cnt() % m, 0); + const int64_t n = dy->shape().elem_cnt() / m; + NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, reduce_buf->mut_dptr()), + Val({n, m}, normalized->dptr()), Val({n, m}, dy->dptr())); + NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, gamma_diff->mut_dptr()), + Val({n, m}, reduce_buf->dptr()), + 
Var({n, m}, reduce_buf->mut_dptr())); + } + if (has_normalized_diff) { + if (has_gamma) { + CHECK_EQ(m, gamma->shape().elem_cnt()); + CHECK_EQ(dy->shape().elem_cnt() % m, 0); + const int64_t n = dy->shape().elem_cnt() / m; + NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, normalized_diff->mut_dptr()), + Val({n, m}, dy->dptr()), Val({1, m}, gamma->dptr())); + } else { + Memcpy(ctx->device_ctx(), normalized_diff->mut_dptr(), + dy->dptr(), + dy->shape().elem_cnt() * GetSizeOfDataType(dy->data_type())); + } } } }; @@ -341,6 +453,127 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel { REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float) REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double) -REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float16) +class LayerNormParamGradGpuHalfKernel final : public user_op::OpKernel { + public: + LayerNormParamGradGpuHalfKernel() = default; + ~LayerNormParamGradGpuHalfKernel() = default; + + private: + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + void Compute(user_op::KernelComputeContext* ctx) const override { + using NdUtil = NdarrayUtil; + auto Val = NdUtil::GetValNdarrayBuilder(); + auto Var = NdUtil::GetVarNdarrayBuilder(); + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + user_op::Tensor* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0); + user_op::Tensor* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0); + user_op::Tensor* normalized_diff = ctx->Tensor4ArgNameAndIndex("normalized_diff", 0); + user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); + const bool has_beta_diff = beta_diff != nullptr; + const bool has_gamma_diff = gamma_diff != nullptr; + const bool has_normalized_diff = normalized_diff != nullptr; + const bool has_gamma = gamma != nullptr; + const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); + const int64_t elem_cnt = dy->shape().elem_cnt(); + const int64_t m = dy->shape().Count(begin_params_axis); + int max_active_blocks; + OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, LayerNormParamGradHalfImpl, GetLayerNormBlockSize(), + GetDynamicSharedMemorySize(m))); + if (has_gamma_diff && has_beta_diff && has_normalized_diff && max_active_blocks > 0) { + const user_op::Tensor* normalized = ctx->Tensor4ArgNameAndIndex("normalized", 0); + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + const int64_t num_blocks = GetLayerNormNumBlocks(dy->shape().elem_cnt()); + const size_t tmp_diff_size = GetCudaAlignedSize(num_blocks * m * sizeof(float16)); + float16* tmp_gamma_diff = tmp_buffer->mut_dptr(); + float16* tmp_beta_diff = + reinterpret_cast(tmp_buffer->mut_dptr() + tmp_diff_size); + float16* tmp_reduce_buf = + reinterpret_cast(tmp_buffer->mut_dptr() + 2 * tmp_diff_size); + CHECK_GE(tmp_buffer->shape().elem_cnt(), 3 * tmp_diff_size); + if (elem_cnt > static_cast(GetMaxVal() / 2)) { + LayerNormParamGradHalfImpl + <<(m), ctx->device_ctx()->cuda_stream()>>>( + elem_cnt, m, dy->dptr(), normalized->dptr(), gamma->dptr(), + reinterpret_cast(tmp_gamma_diff), reinterpret_cast(tmp_beta_diff), + normalized_diff->mut_dptr()); + } else { + LayerNormParamGradHalfImpl + <<(m), ctx->device_ctx()->cuda_stream()>>>( + static_cast(elem_cnt), static_cast(m), dy->dptr(), + normalized->dptr(), gamma->dptr(), + reinterpret_cast(tmp_gamma_diff), reinterpret_cast(tmp_beta_diff), + normalized_diff->mut_dptr()); + } + NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, gamma_diff->mut_dptr()), + Val({num_blocks, m}, 
tmp_gamma_diff), Var({num_blocks, m}, tmp_reduce_buf)); + NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, beta_diff->mut_dptr()), + Val({num_blocks, m}, tmp_beta_diff), Var({num_blocks, m}, tmp_reduce_buf)); + } else { + if (has_beta_diff) { + user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0); + CHECK_EQ(m, beta_diff->shape().elem_cnt()); + CHECK_EQ(dy->shape().elem_cnt() % m, 0); + const int64_t n = dy->shape().elem_cnt() / m; + NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, beta_diff->mut_dptr()), + Val({n, m}, dy->dptr()), + Var({n, m}, reduce_buf->mut_dptr())); + } + if (has_gamma_diff) { + const user_op::Tensor* normalized = ctx->Tensor4ArgNameAndIndex("normalized", 0); + user_op::Tensor* reduce_buf = ctx->Tensor4ArgNameAndIndex("reduce_buf", 0); + CHECK_EQ(m, gamma_diff->shape().elem_cnt()); + CHECK_EQ(dy->shape().elem_cnt() % m, 0); + const int64_t n = dy->shape().elem_cnt() / m; + NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, reduce_buf->mut_dptr()), + Val({n, m}, normalized->dptr()), + Val({n, m}, dy->dptr())); + NdUtil::ReduceSum(ctx->device_ctx(), Var({1, m}, gamma_diff->mut_dptr()), + Val({n, m}, reduce_buf->dptr()), + Var({n, m}, reduce_buf->mut_dptr())); + } + if (has_normalized_diff) { + if (has_gamma) { + CHECK_EQ(m, gamma->shape().elem_cnt()); + CHECK_EQ(dy->shape().elem_cnt() % m, 0); + const int64_t n = dy->shape().elem_cnt() / m; + NdUtil::BroadcastMul(ctx->device_ctx(), Var({n, m}, normalized_diff->mut_dptr()), + Val({n, m}, dy->dptr()), + Val({1, m}, gamma->dptr())); + } else { + Memcpy(ctx->device_ctx(), normalized_diff->mut_dptr(), + dy->dptr(), + dy->shape().elem_cnt() * GetSizeOfDataType(dy->data_type())); + } + } + } + } +}; + +REGISTER_USER_KERNEL("layer_norm_param_grad") + .SetCreateFn() + .SetIsMatchedHob((user_op::HobDeviceTag() == "gpu") + & (user_op::HobDataType("dy", 0) == DataType::kFloat16)) + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { + const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); + const bool has_gamma_diff = ctx->user_op_conf().has_output("gamma_diff", 0); + const bool has_beta_diff = ctx->user_op_conf().has_output("beta_diff", 0); + const bool has_normalized_diff = ctx->user_op_conf().has_output("normalized_diff", 0); + const auto* dy = ctx->TensorDesc4ArgNameAndIndex("dy", 0); + const int64_t instance_size = dy->shape().Count(begin_params_axis); + size_t tmp_buffer_size = 0; + if (has_gamma_diff && has_beta_diff && has_normalized_diff) { + const size_t tmp_gamma_diff = GetCudaAlignedSize( + GetLayerNormNumBlocks(dy->shape().elem_cnt()) * instance_size * sizeof(float16)); + const size_t tmp_beta_diff = tmp_gamma_diff; + const size_t tmp_reduce_buf = tmp_gamma_diff; + tmp_buffer_size = tmp_gamma_diff + tmp_beta_diff + tmp_reduce_buf; + } else { + tmp_buffer_size = 0; + } + return tmp_buffer_size; + }); } // namespace oneflow