Unverified commit 681a6865, authored by hong, committed by GitHub

Move layer norm to phi (#40193)

* update

* fix bugs; test=develop

* update; test=develop

* fix test compile error; test=develop

* fix cpu compile error; test=develop

* fix test error; test=develop

* fix layer_norm_op plugin error; test=develop

* fix error; test=develop

* fix test bug; test=develop

* update; test=develop

* polish code; test=develop

* fix bugs; test=develop

* remove unused dependency; test=develop

* polish code; test=develop
Parent c335288d
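The core change: the layer_norm forward and backward kernels now live in phi as phi::LayerNormKernel and phi::LayerNormGradKernel (registered via PD_REGISTER_KERNEL), while the old fluid REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL registrations and the op-based test path are removed. A minimal sketch of the new call style, borrowing the tensor and context names from the updated fused-layernorm test further down in this diff (not a standalone program):

    // Optional Scale/Bias inputs, as in the updated test.
    paddle::optional<const framework::LoDTensor &> scale_opt = paddle::none;
    paddle::optional<const framework::LoDTensor &> bias_opt = paddle::none;
    // Old path: framework::OpRegistry::CreateOp("layer_norm", ...)->Run(scope, place);
    // New path: call the phi kernel directly on a phi::GPUContext.
    phi::LayerNormKernel<float>(static_cast<const phi::GPUContext &>(ctx),
                                *tensor_x, scale_opt, bias_opt,
                                /*epsilon=*/1e-5, /*begin_norm_axis=*/1,
                                /*is_test=*/false, tensor_y, tensor_mean,
                                tensor_variance);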
@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
@@ -17,7 +17,7 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
namespace paddle {
namespace inference {
@@ -83,7 +83,7 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream);
paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
phi::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps);
return cudaGetLastError() != cudaSuccess;
@@ -177,7 +177,7 @@ int LayerNormPluginDynamic::enqueue(
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream);
paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
phi::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps);
} else {
@@ -25,14 +25,16 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace memory = paddle::memory;
USE_OP_ITSELF(dropout);
USE_OP(layer_norm);
USE_OP_ITSELF(layer_norm);
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
@@ -136,18 +138,23 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,
const platform::CUDADeviceContext &ctx) {
framework::Scope scope;
auto place = ctx.GetPlace();
paddle::optional<const framework::LoDTensor &> scale_opt = paddle::none;
if (scale.size() > 0) {
auto var_scale = scope.Var("Scale");
auto tensor_scale = var_scale->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(scale, ctx, tensor_scale);
tensor_scale->Resize({cols});
scale_opt = *tensor_scale;
}
paddle::optional<const framework::LoDTensor &> bias_opt = paddle::none;
if (bias.size() > 0) {
auto var_bias = scope.Var("Bias");
auto tensor_bias = var_bias->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(bias, ctx, tensor_bias);
tensor_bias->Resize({cols});
bias_opt = *tensor_bias;
}
auto var_x = scope.Var("X");
@@ -157,20 +164,19 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,
auto var_y = scope.Var("Y");
auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
tensor_y->Resize({rows, cols});
auto var_mean = scope.Var("Mean");
auto tensor_mean = var_mean->GetMutable<framework::LoDTensor>();
tensor_mean->Resize({rows});
auto var_variance = scope.Var("Variance");
auto tensor_variance = var_variance->GetMutable<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs.insert({"epsilon", epsilon});
auto op = framework::OpRegistry::CreateOp(
"layer_norm", {{"X", {"X"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}},
{{"Y", {"Y"}}, {"Mean", {"Mean"}}, {"Variance", {"Variance"}}}, attrs);
op->Run(scope, place);
tensor_variance->Resize({rows});
ctx.Wait();
phi::LayerNormKernel<T>(static_cast<const phi::GPUContext &>(ctx), *tensor_x,
scale_opt, bias_opt, 1e-5, 1, false, tensor_y,
tensor_mean, tensor_variance);
framework::TensorToVector(*tensor_y, ctx, y);
framework::TensorToVector(*tensor_mean, ctx, means);
framework::TensorToVector(*tensor_variance, ctx, vars);
@@ -198,7 +198,6 @@ struct TestFusedLayernormResidualDropoutBias {
residual_vec[i * cols + j] + out2[i * cols + j];
}
}
LayerNorm<T>(scale_vec, layernorm_bias_vec, correct_out, &correct_means,
&correct_vars, &correct_layernorm_out, epsilon, rows, cols,
*ctx);
@@ -758,12 +758,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
*/
template <typename T, typename U, typename ScaleT = U,
typename MaskType = uint8_t>
void ln_bwd_1024_kernel_driver(
const platform::CUDADeviceContext &dev_ctx, const int rows, const int cols,
float epsilon, const T *x_ptr, const ScaleT *scale_ptr, const U *mean_ptr,
const U *var_ptr, const T *dout_ptr, T *dx_ptr, ScaleT *dscale_ptr,
ScaleT *dbias_ptr, const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0), T *d_dropout_src_ptr = nullptr) {
void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
const int cols, float epsilon, const T *x_ptr,
const ScaleT *scale_ptr, const U *mean_ptr,
const U *var_ptr, const T *dout_ptr, T *dx_ptr,
ScaleT *dscale_ptr, ScaleT *dbias_ptr,
const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0),
T *d_dropout_src_ptr = nullptr) {
auto stream = dev_ctx.stream();
if (cols == 1024) {
// step-1: compute dx and reduced part results of dscale and dbias.
@@ -1334,8 +1336,7 @@ static void LayerNormBackward(
const U *mean, const U *var, T *d_x,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, float epsilon,
int64_t batch_size, int64_t feature_size,
const platform::CUDADeviceContext &dev_ctx) {
int64_t batch_size, int64_t feature_size, const phi::GPUContext &dev_ctx) {
auto stream = dev_ctx.stream();
#ifdef __HIPCC__
const int kMaxBlockDim = 256;
@@ -12,10 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
@@ -278,10 +277,3 @@ REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp,
ops::LayerNormGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
const T *input,
std::vector<int> input_shape,
const T *bias, const T *scale,
T *output, T *mean, T *variance,
int begin_norm_axis, float eps) {
const auto x_dims = phi::make_ddim(input_shape);
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
input, scale, bias, output, mean, variance, eps, feature_size));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Product from begin_norm_axis to end in layer_norm must be larger "
"than 1"));
break;
}
}
template <typename T>
class LayerNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const float epsilon = ctx.Attr<float>("epsilon");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
auto *mean = ctx.Output<Tensor>("Mean");
auto *var = ctx.Output<Tensor>("Variance");
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto x_dims = x->dims();
auto *x_data = x->data<T>();
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
auto *var_data = var->mutable_data<U>(ctx.GetPlace());
auto *void_scale_data = (scale == nullptr ? nullptr : scale->data());
auto *void_bias_data = (bias == nullptr ? nullptr : bias->data());
framework::proto::VarType::Type x_dtype =
framework::TransToProtoVarType(x->dtype());
framework::proto::VarType::Type scale_bias_dtype;
if (void_scale_data != nullptr) {
scale_bias_dtype = framework::TransToProtoVarType(scale->dtype());
if (void_bias_data != nullptr) {
PADDLE_ENFORCE_EQ(scale_bias_dtype,
framework::TransToProtoVarType(bias->dtype()),
platform::errors::InvalidArgument(
"Thie Scale and Bias of layer_norm op "
"should have the same data type."));
}
} else {
scale_bias_dtype = (void_bias_data != nullptr
? framework::TransToProtoVarType(bias->dtype())
: x_dtype);
}
bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
if (!is_scale_bias_same_dtype_with_x) {
PADDLE_ENFORCE_EQ(scale_bias_dtype,
framework::DataTypeTrait<U>::DataType(),
platform::errors::InvalidArgument(
"Unsupported data type of Scale and Bias: %s",
framework::DataTypeToString(scale_bias_dtype)));
}
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
auto stream = ctx.cuda_device_context().stream();
#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
do { \
switch (GetDesiredBlockDim(feature_size)) { \
FIXED_BLOCK_DIM_CASE( \
LayerNormForward<T, U, kBlockDim, IsScaleBiasSameDTypeWithX><<< \
batch_size, kBlockDim, 0, stream>>>( \
x_data, static_cast<const ScaleBiasT *>(void_scale_data), \
static_cast<const ScaleBiasT *>(void_bias_data), y_data, \
mean_data, var_data, epsilon, feature_size)); \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Product from begin_norm_axis to end must be larger than 1")); \
break; \
} \
} while (0)
#ifdef PADDLE_WITH_CUDA
bool can_call_1024_kernel = false;
if (feature_size == 1024 && scale != nullptr && bias != nullptr) {
can_call_1024_kernel = true;
}
if (can_call_1024_kernel) {
const int WARPS_M = 4;
const int WARPS_N = 1;
const int THREADS_PER_WARP = 32;
const int BYTES_PER_LDG = 16;
const int VecSize = BYTES_PER_LDG / sizeof(T);
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
const int ROWS_PER_CTA = WARPS_M;
const int grid = static_cast<int>(
std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));
if (is_scale_bias_same_dtype_with_x) {
ln_fwd_1024_kernel<T, U, T, VecSize, WARPS_M, WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size, feature_size, epsilon, x_data,
static_cast<const T *>(void_scale_data),
static_cast<const T *>(void_bias_data), mean_data, var_data,
y_data);
} else {
ln_fwd_1024_kernel<T, U, U, VecSize, WARPS_M, WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size, feature_size, epsilon, x_data,
static_cast<const U *>(void_scale_data),
static_cast<const U *>(void_bias_data), mean_data, var_data,
y_data);
}
} else {
#endif
if (is_scale_bias_same_dtype_with_x) {
PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
}
#ifdef PADDLE_WITH_CUDA
}
#endif
#undef PADDLE_LAUNCH_LAYERNORM_FWD
}
};
template <typename T>
class LayerNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const float epsilon = ctx.Attr<float>("epsilon");
// d_x, d_scale, d_bias may be nullptr
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto *x = ctx.Input<Tensor>("X");
auto *mean = ctx.Input<Tensor>("Mean");
auto *var = ctx.Input<Tensor>("Variance");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
auto *x_data = x->data<T>();
auto *d_y_data = d_y->data<T>();
auto *mean_data = mean->data<U>();
auto *var_data = var->data<U>();
auto *d_x_data =
(d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));
framework::proto::VarType::Type x_dtype =
framework::TransToProtoVarType(x->dtype());
framework::proto::VarType::Type scale_bias_dtype;
if (scale != nullptr) {
scale_bias_dtype = framework::TransToProtoVarType(scale->dtype());
} else {
// FIXME(zengjinle): could not find a better way to get the right
// data type of d_scale and d_bias if scale == nullptr.
auto *bias = ctx.Input<Tensor>("Bias");
if (bias != nullptr) {
scale_bias_dtype = framework::TransToProtoVarType(bias->dtype());
} else {
scale_bias_dtype = x_dtype;
}
}
#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
do { \
auto *scale_data = \
(scale == nullptr ? nullptr : scale->data<ScaleBiasT>()); \
auto *d_scale_data = \
(d_scale == nullptr ? nullptr : d_scale->mutable_data<ScaleBiasT>( \
ctx.GetPlace())); \
auto *d_bias_data = \
(d_bias == nullptr ? nullptr : d_bias->mutable_data<ScaleBiasT>( \
ctx.GetPlace())); \
auto *d_x_data = \
(d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace())); \
LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>( \
x_data, d_y_data, scale_data, mean_data, var_data, d_x_data, \
d_scale_data, d_bias_data, epsilon, batch_size, feature_size, \
ctx.cuda_device_context()); \
} while (0)
if (scale_bias_dtype == x_dtype) {
PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
}
#undef PADDLE_LAUNCH_LAYERNORM_BWD
}
};
template class LayerNormDirectCUDAFunctor<float>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
REGISTER_OP_CUDA_KERNEL(
layer_norm,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
#elif CUDNN_VERSION_MIN(8, 1, 0)
REGISTER_OP_CUDA_KERNEL(
layer_norm,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>);
#else
REGISTER_OP_CUDA_KERNEL(
layer_norm,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__)
#include "paddle/fluid/operators/jit/kernels.h"
#endif
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace platform {
class CPUDeviceContext;
class CUDADeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
// Wrap RowwiseMean and ColwiseMean.
// Reuse the cpu codes and replace the gpu codes with cublas_gemv, which is
// significantly faster. Unlike the RowwiseMean and ColwiseMean, the
// implementation only considers 2D.
template <typename DeviceContext, typename T>
struct RowwiseMean2D {
RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx);
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor* vec);
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
class RowwiseMean2D<platform::CUDADeviceContext, T> {
public:
RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx)
: left_(left), right_(right) {
framework::DDim ones_dim({right_});
divisor_.mutable_data<T>(ones_dim, dev_ctx.GetPlace());
phi::funcs::set_constant(dev_ctx, &divisor_, 1.0 / right);
}
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* out) {
phi::funcs::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
false, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
out->data<T>());
}
private:
int left_;
int right_;
framework::Tensor divisor_;
};
#endif
template <typename T>
class RowwiseMean2D<platform::CPUDeviceContext, T> {
public:
RowwiseMean2D(int left, int right, const platform::DeviceContext& dev_ctx) {}
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* out) {
row_mean_(context, input, out);
}
private:
phi::funcs::RowwiseMean<platform::CPUDeviceContext, T> row_mean_;
};
template <typename DeviceContext, typename T>
struct ColwiseSum2D {
ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx);
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor* vec);
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
class ColwiseSum2D<platform::CUDADeviceContext, T> {
public:
ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx)
: left_(left), right_(right) {
framework::DDim ones_dim({left_});
divisor_.mutable_data<T>(ones_dim, dev_ctx.GetPlace());
phi::funcs::set_constant(dev_ctx, &divisor_, 1.0);
}
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* out) {
phi::funcs::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
true, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
out->data<T>());
}
private:
int left_;
int right_;
framework::Tensor divisor_;
};
#endif
template <typename T>
class ColwiseSum2D<platform::CPUDeviceContext, T> {
public:
ColwiseSum2D(int left, int right, const platform::DeviceContext& dev_ctx) {}
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* out) {
col_wise_(context, input, out);
}
private:
phi::funcs::ColwiseSum<platform::CPUDeviceContext, T> col_wise_;
};
template <typename T>
struct SubAndSquareFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
};
template <typename T>
struct DivAndSqrtFunctor {
explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
inline HOSTDEVICE T operator()(T a, T b) const {
return a / (sqrt(b + epsilon_));
}
private:
T epsilon_;
};
template <typename T>
struct MulInvVarFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
return a * std::sqrt(1.0 / b);
}
};
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
class LayerNormDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream, const T* input,
std::vector<int> input_shape, const T* bias, const T* scale,
T* output, T* mean, T* variance, int begin_norm_axis,
float eps);
};
#endif
template <typename DeviceContext, typename T>
class LayerNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto x = *ctx.Input<Tensor>("X");
auto* y = ctx.Output<Tensor>("Y");
auto* mean = ctx.Output<Tensor>("Mean");
auto* var = ctx.Output<Tensor>("Variance");
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto x_dims = x.dims();
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
x.Resize(matrix_shape);
Tensor out;
out.ShareDataWith(*y);
out.Resize(matrix_shape);
#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \
defined(__OSX__)
auto& dev_ctx = ctx.template device_context<DeviceContext>();
RowwiseMean2D<DeviceContext, T> row_mean(left, right, ctx.device_context());
// get mean
row_mean(dev_ctx, x, mean);
// get variance
ElementwiseComputeEx<SubAndSquareFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), &out);
row_mean(dev_ctx, out, var);
// get x_norm
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), &out);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, &out, var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &out);
if (scale) {
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &out, scale, /*axis*/ 1, MulFunctor<T>(), &out);
}
if (bias) {
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, &out, bias, /*axis*/ 1, AddFunctor<T>(), &out);
}
#else
PADDLE_ENFORCE_EQ(mean->numel(), left,
platform::errors::InvalidArgument(
"mean's length (%d) is not equal with expected (%d).",
mean->numel(), left));
PADDLE_ENFORCE_EQ(var->numel(), left,
platform::errors::InvalidArgument(
"var's length (%d) is not equal with expected (%d).",
var->numel(), left));
if (scale) {
PADDLE_ENFORCE_EQ(
scale->numel(), right,
platform::errors::InvalidArgument(
"scale's length (%d) is not equal with expected (%d).",
scale->numel(), right));
}
if (bias) {
PADDLE_ENFORCE_EQ(
bias->numel(), right,
platform::errors::InvalidArgument(
"bias's length (%d) is not equal with expected (%d).",
bias->numel(), right));
}
auto ker =
jit::KernelFuncs<jit::LayerNormTuple<T>, platform::CPUPlace>::Cache()
.At(right);
ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale ? scale->data<T>() : nullptr, bias ? bias->data<T>() : nullptr,
static_cast<int>(left), static_cast<const float>(epsilon), right);
#endif
}
};
template <typename DeviceContext, typename T>
class LayerNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto x = *ctx.Input<Tensor>("X");
auto* mean = ctx.Input<Tensor>("Mean");
auto* var = ctx.Input<Tensor>("Variance");
auto* scale = ctx.Input<Tensor>("Scale");
auto d_y = *ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
// init output
auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto& x_dims = x.dims();
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
d_y.Resize(matrix_shape);
auto& dev_ctx = ctx.template device_context<DeviceContext>();
ColwiseSum2D<DeviceContext, T> colwise_sum(left, right,
ctx.device_context());
Tensor temp;
Tensor temp_norm;
if (d_scale || d_x) {
x.Resize(matrix_shape);
temp.mutable_data<T>(matrix_shape, ctx.GetPlace());
temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
// get x_norm
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
}
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
colwise_sum(dev_ctx, d_y, d_bias);
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor<T>(), &temp);
colwise_sum(dev_ctx, temp, d_scale);
}
if (d_x) {
framework::DDim vec_shape({left});
d_x->mutable_data<T>(ctx.GetPlace());
auto dx_dim = d_x->dims();
Tensor temp_vec;
temp_vec.mutable_data<T>(vec_shape, ctx.GetPlace());
RowwiseMean2D<DeviceContext, T> row_mean(left, right,
ctx.device_context());
if (d_scale) {
// dy_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, scale, /*axis*/ 1, MulFunctor<T>(), &temp);
framework::TensorCopy(temp, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
} else {
// dy_dx
framework::TensorCopy(d_y, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, d_y, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
}
// dy_var_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor<T>(), &temp);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp, /*axis*/ 0, SubFunctor<T>(), d_x);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, d_x, var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), d_x);
d_x->Resize(dx_dim);
}
}
};
} // namespace operators
} // namespace paddle
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
@@ -14,7 +14,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace operators {
@@ -139,7 +140,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
layer_norm_p->execute(astream, args);
astream.wait();
y->set_layout(DataLayout::kMKLDNN);
y->set_layout(phi::DataLayout::kMKLDNN);
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
};
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/layer_norm_grad_kernel.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/funcs/layer_norm_util.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__)
#include "paddle/fluid/operators/jit/kernels.h"
#endif
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void LayerNormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mean,
const DenseTensor& variance,
paddle::optional<const DenseTensor&> scale_opt,
paddle::optional<const DenseTensor&> bias_opt,
const DenseTensor& out_grad,
float epsilon,
int begin_norm_axis,
bool is_test,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
auto* scale = scale_opt.get_ptr();
auto d_y = out_grad;
// init output
auto* d_x = x_grad;
auto* d_scale = scale_grad;
auto* d_bias = bias_grad;
const auto& x_dims = x.dims();
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
DDim matrix_shape({left, right});
d_y.Resize(matrix_shape);
funcs::ColwiseSum2D<phi::CPUContext, T> colwise_sum(left, right, dev_ctx);
DenseTensor x_tmp = x;
DenseTensor temp;
DenseTensor temp_norm;
if (d_scale || d_x) {
x_tmp.Resize(matrix_shape);
temp.Resize(matrix_shape);
dev_ctx.template Alloc<T>(&temp);
temp_norm.Resize(matrix_shape);
dev_ctx.template Alloc<T>(&temp_norm);
// get x_norm
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx,
x_tmp,
mean,
/*axis*/ 0,
funcs::SubtractFunctor<T>(),
&temp_norm);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
dev_ctx,
temp_norm,
variance,
/*axis*/ 0,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
&temp_norm);
}
if (d_bias) {
dev_ctx.template Alloc<T>(d_bias);
colwise_sum(dev_ctx, d_y, d_bias);
}
if (d_scale) {
dev_ctx.template Alloc<T>(d_scale);
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx, temp_norm, d_y, 0, funcs::MultiplyFunctor<T>(), &temp);
colwise_sum(dev_ctx, temp, d_scale);
}
if (d_x) {
DDim vec_shape({left});
dev_ctx.template Alloc<T>(d_x);
auto dx_dim = d_x->dims();
DenseTensor temp_vec;
temp_vec.Resize(vec_shape);
dev_ctx.template Alloc<T>(&temp_vec);
funcs::RowwiseMean2D<phi::CPUContext, T> row_mean(left, right, dev_ctx);
if (d_scale) {
// dy_dx
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx, d_y, *scale, /*axis*/ 1, funcs::MultiplyFunctor<T>(), &temp);
phi::Copy<Context>(dev_ctx, temp, dev_ctx.GetPlace(), false, d_x);
// dy_dmean_dx
row_mean(dev_ctx, temp, &temp_vec);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx,
*d_x,
temp_vec,
/*axis*/ 0,
funcs::SubtractFunctor<T>(),
d_x);
// dy_var_dx
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx,
temp,
temp_norm,
/*axis*/ 0,
funcs::MultiplyFunctor<T>(),
&temp);
} else {
// dy_dx
phi::Copy<Context>(dev_ctx, d_y, dev_ctx.GetPlace(), false, d_x);
// dy_dmean_dx
row_mean(dev_ctx, d_y, &temp_vec);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx,
*d_x,
temp_vec,
/*axis*/ 0,
funcs::SubtractFunctor<T>(),
d_x);
// dy_var_dx
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx,
d_y,
temp_norm,
/*axis*/ 0,
funcs::MultiplyFunctor<T>(),
&temp);
}
// dy_var_dx
row_mean(dev_ctx, temp, &temp_vec);
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx,
temp_norm,
temp_vec,
/*axis*/ 0,
funcs::MultiplyFunctor<T>(),
&temp);
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx, *d_x, temp, /*axis*/ 0, funcs::SubtractFunctor<T>(), d_x);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
dev_ctx,
*d_x,
variance,
/*axis*/ 0,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
d_x);
d_x->Resize(dx_dim);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
layer_norm_grad, CPU, ALL_LAYOUT, phi::LayerNormGradKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/layer_norm_kernel.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/funcs/layer_norm_util.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__)
#include "paddle/fluid/operators/jit/kernels.h"
#endif
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Context>
void LayerNormKernel(const Context& dev_ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> scale_opt,
paddle::optional<const DenseTensor&> bias_opt,
float epsilon,
int begin_norm_axis,
bool is_test,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* var) {
const auto x_dims = x.dims();
auto* scale = scale_opt.get_ptr();
auto* bias = bias_opt.get_ptr();
dev_ctx.template Alloc<T>(y);
dev_ctx.template Alloc<T>(mean);
dev_ctx.template Alloc<T>(var);
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
DDim matrix_shape({left, right});
auto x_tmp = x;
x_tmp.Resize(matrix_shape);
DenseTensor out;
out.ShareDataWith(*y);
out.Resize(matrix_shape);
#if defined(PADDLE_WITH_CUDA) || defined(_WIN32) || defined(__APPLE__) || \
defined(__OSX__)
funcs::RowwiseMean2D<phi::CPUContext, T> row_mean(left, right, dev_ctx);
// get mean
row_mean(dev_ctx, x_tmp, mean);
// get variance
phi::funcs::ElementwiseCompute<funcs::SubAndSquareFunctor<T>, T, T>(
dev_ctx, x_tmp, *mean, 0, funcs::SubAndSquareFunctor<T>(), &out);
row_mean(dev_ctx, out, var);
// get x_norm
phi::funcs::ElementwiseCompute<funcs::SubtractFunctor<T>, T, T>(
dev_ctx, x_tmp, *mean, 0, funcs::SubtractFunctor<T>(), &out);
phi::funcs::ElementwiseCompute<funcs::DivAndSqrtFunctor<T>, T, T>(
dev_ctx,
out,
*var,
0,
funcs::DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
&out);
if (scale) {
phi::funcs::ElementwiseCompute<funcs::MultiplyFunctor<T>, T, T>(
dev_ctx, out, *scale, 1, funcs::MultiplyFunctor<T>(), &out);
}
if (bias) {
phi::funcs::ElementwiseCompute<funcs::AddFunctor<T>, T, T>(
dev_ctx, out, *bias, 1, funcs::AddFunctor<T>(), &out);
}
#else
PADDLE_ENFORCE_EQ(mean->numel(),
left,
phi::errors::InvalidArgument(
"mean's length (%d) is not equal with expected (%d).",
mean->numel(),
left));
PADDLE_ENFORCE_EQ(var->numel(),
left,
phi::errors::InvalidArgument(
"var's length (%d) is not equal with expected (%d).",
var->numel(),
left));
if (scale) {
PADDLE_ENFORCE_EQ(
scale->numel(),
right,
phi::errors::InvalidArgument(
"scale's length (%d) is not equal with expected (%d).",
scale->numel(),
right));
}
if (bias) {
PADDLE_ENFORCE_EQ(bias->numel(),
right,
phi::errors::InvalidArgument(
"bias's length (%d) is not equal with expected (%d).",
bias->numel(),
right));
}
auto ker = paddle::operators::jit::KernelFuncs<
paddle::operators::jit::LayerNormTuple<T>,
phi::CPUPlace>::Cache()
.At(right);
ker(x_tmp.data<T>(),
out.data<T>(),
mean->data<T>(),
var->data<T>(),
scale ? scale->data<T>() : nullptr,
bias ? bias->data<T>() : nullptr,
static_cast<int>(left),
static_cast<const float>(epsilon),
right);
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(
layer_norm, CPU, ALL_LAYOUT, phi::LayerNormKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace funcs {
// Wrap RowwiseMean and ColwiseMean.
// Reuse the cpu codes and replace the gpu codes with cublas_gemv, which is
// significantly faster. Unlike the RowwiseMean and ColwiseMean, the
// implementation only considers 2D.
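// For example, with an input of shape [left, right], the GPU specialization of
// RowwiseMean2D below computes the row-wise mean with a single GEMV:
//   mean = input * divisor, where divisor is a [right] vector filled with 1/right,
// so mean[i] = sum_j input[i][j] / right. ColwiseSum2D uses the transposed GEMV
// with a divisor of ones to produce per-column sums.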
template <typename DeviceContext, typename T>
struct RowwiseMean2D {
RowwiseMean2D(int left, int right, const DeviceContext& dev_ctx);
void operator()(const DeviceContext& context,
const DenseTensor& input,
DenseTensor* vec);
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
class RowwiseMean2D<phi::GPUContext, T> {
public:
RowwiseMean2D(int left, int right, const DeviceContext& dev_ctx)
: left_(left), right_(right) {
DDim ones_dim({right_});
divisor_.Resize(ones_dim);
dev_ctx.template Alloc<T>(&divisor_);
phi::funcs::set_constant(dev_ctx, &divisor_, 1.0 / right);
}
void operator()(const phi::GPUContext& context,
const DenseTensor& input,
DenseTensor* out) {
phi::funcs::GetBlas<phi::GPUContext, T>(context).GEMV(false,
left_,
right_,
1.,
input.data<T>(),
divisor_.data<T>(),
0.,
out->data<T>());
}
private:
int left_;
int right_;
DenseTensor divisor_;
};
#endif
template <typename T>
class RowwiseMean2D<phi::CPUContext, T> {
public:
RowwiseMean2D(int left, int right, const DeviceContext& dev_ctx) {}
void operator()(const phi::CPUContext& context,
const DenseTensor& input,
DenseTensor* out) {
row_mean_(context, input, out);
}
private:
phi::funcs::RowwiseMean<phi::CPUContext, T> row_mean_;
};
template <typename DeviceContext, typename T>
struct ColwiseSum2D {
ColwiseSum2D(int left, int right, const DeviceContext& dev_ctx);
void operator()(const phi::DeviceContext& context,
const DenseTensor& input,
DenseTensor* vec);
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
class ColwiseSum2D<phi::GPUContext, T> {
public:
ColwiseSum2D(int left, int right, const phi::GPUContext& dev_ctx)
: left_(left), right_(right) {
DDim ones_dim({left_});
divisor_.Resize(ones_dim);
dev_ctx.template Alloc<T>(&divisor_);
phi::funcs::set_constant(dev_ctx, &divisor_, 1.0);
}
void operator()(const phi::GPUContext& context,
const DenseTensor& input,
DenseTensor* out) {
phi::funcs::GetBlas<phi::GPUContext, T>(context).GEMV(true,
left_,
right_,
1.,
input.data<T>(),
divisor_.data<T>(),
0.,
out->data<T>());
}
private:
int left_;
int right_;
DenseTensor divisor_;
};
#endif
template <typename T>
class ColwiseSum2D<phi::CPUContext, T> {
public:
ColwiseSum2D(int left, int right, const phi::CPUContext& dev_ctx) {}
void operator()(const phi::CPUContext& context,
const DenseTensor& input,
DenseTensor* out) {
col_wise_(context, input, out);
}
private:
phi::funcs::ColwiseSum<phi::CPUContext, T> col_wise_;
};
template <typename T>
struct SubAndSquareFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
};
template <typename T>
struct DivAndSqrtFunctor {
explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
inline HOSTDEVICE T operator()(T a, T b) const {
return a / (sqrt(b + epsilon_));
}
private:
T epsilon_;
};
template <typename T>
struct MulInvVarFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
return a * std::sqrt(1.0 / b);
}
};
} // namespace funcs
} // namespace phi
@@ -331,12 +331,20 @@ template struct ColwiseSum<paddle::platform::CPUDeviceContext, double>;
template struct ColwiseSum<paddle::platform::CPUDeviceContext, int>;
template struct ColwiseSum<paddle::platform::CPUDeviceContext, int64_t>;
template struct ColwiseSum<phi::CPUContext, float>;
template struct ColwiseSum<phi::CPUContext, double>;
template struct ColwiseSum<phi::CPUContext, int>;
template struct ColwiseSum<phi::CPUContext, int64_t>;
template struct RowwiseSum<paddle::platform::CPUDeviceContext, float>;
template struct RowwiseSum<paddle::platform::CPUDeviceContext, double>;
template struct RowwiseMean<paddle::platform::CPUDeviceContext, float>;
template struct RowwiseMean<paddle::platform::CPUDeviceContext, double>;
template struct RowwiseMean<phi::CPUContext, float>;
template struct RowwiseMean<phi::CPUContext, double>;
template <typename T>
struct ElementwiseAddTo<paddle::platform::CPUDeviceContext, T> {
void operator()(paddle::platform::CPUDeviceContext* ctx,
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/layer_norm_grad_kernel.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/layer_norm_util.h"
namespace phi {
template <typename T, typename Context>
void LayerNormGradKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &mean,
const DenseTensor &variance,
paddle::optional<const DenseTensor &> scale_opt,
paddle::optional<const DenseTensor &> bias_opt,
const DenseTensor &out_grad,
float epsilon,
int begin_norm_axis,
bool is_test,
DenseTensor *x_grad,
DenseTensor *scale_grad,
DenseTensor *bias_grad) {
using U = paddle::operators::LayerNormParamType<T>;
// d_x, d_scale, d_bias may be nullptr
auto *d_x = x_grad;
auto *d_scale = scale_grad;
auto *d_bias = bias_grad;
auto *scale = scale_opt.get_ptr();
auto *bias = bias_opt.get_ptr();
auto *d_y = &out_grad;
const auto &x_dims = x.dims();
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
auto *x_data = x.data<T>();
auto *d_y_data = d_y->data<T>();
auto *mean_data = mean.data<U>();
auto *var_data = variance.data<U>();
auto *d_x_data = (d_x == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_x));
auto x_dtype = x.dtype();
phi::DataType scale_bias_dtype;
if (scale != nullptr) {
scale_bias_dtype = scale->dtype();
} else {
// FIXME(zengjinle): could not find a better way to get the right
// data type of d_scale and d_bias if scale == nullptr.
if (bias != nullptr) {
scale_bias_dtype = bias->dtype();
} else {
scale_bias_dtype = x_dtype;
}
}
#define PADDLE_LAUNCH_LAYERNORM_BWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
do { \
auto *scale_data = \
(scale == nullptr ? nullptr : scale->data<ScaleBiasT>()); \
auto *d_scale_data = \
(d_scale == nullptr ? nullptr \
: dev_ctx.template Alloc<ScaleBiasT>(d_scale)); \
auto *d_bias_data = \
(d_bias == nullptr ? nullptr \
: dev_ctx.template Alloc<ScaleBiasT>(d_bias)); \
auto *d_x_data = \
(d_x == nullptr ? nullptr : dev_ctx.template Alloc<T>(d_x)); \
paddle::operators::LayerNormBackward<T, U, IsScaleBiasSameDTypeWithX>( \
x_data, \
d_y_data, \
scale_data, \
mean_data, \
var_data, \
d_x_data, \
d_scale_data, \
d_bias_data, \
epsilon, \
batch_size, \
feature_size, \
dev_ctx); \
} while (0)
if (scale_bias_dtype == x_dtype) {
PADDLE_LAUNCH_LAYERNORM_BWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_BWD(U, false);
}
#undef PADDLE_LAUNCH_LAYERNORM_BWD
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
PD_REGISTER_KERNEL(layer_norm_grad,
GPU,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
phi::dtype::float16) {}
#elif CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(layer_norm_grad,
GPU,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(layer_norm_grad,
GPU,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
double,
phi::dtype::float16) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/layer_norm_kernel.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/layer_norm_util.h"
namespace phi {
template <typename T>
void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
const T *input,
std::vector<int> input_shape,
const T *bias,
const T *scale,
T *output,
T *mean,
T *variance,
int begin_norm_axis,
float eps) {
const auto x_dims = phi::make_ddim(input_shape);
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
switch (paddle::operators::GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(paddle::operators::LayerNormForward<
T,
T,
kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
input, scale, bias, output, mean, variance, eps, feature_size));
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Product from begin_norm_axis to end in layer_norm must be larger "
"than 1"));
break;
}
}
template class LayerNormDirectCUDAFunctor<float>;
template <typename T, typename Context>
void LayerNormKernel(const Context &dev_ctx,
const DenseTensor &x,
paddle::optional<const DenseTensor &> scale_opt,
paddle::optional<const DenseTensor &> bias_opt,
float epsilon,
int begin_norm_axis,
bool is_test,
DenseTensor *y,
DenseTensor *mean,
DenseTensor *var) {
using U = paddle::operators::LayerNormParamType<T>;
auto *scale = scale_opt.get_ptr();
auto *bias = bias_opt.get_ptr();
const auto x_dims = x.dims();
auto *x_data = x.data<T>();
auto *y_data = dev_ctx.template Alloc<T>(y);
auto *mean_data = dev_ctx.template Alloc<U>(mean);
auto *var_data = dev_ctx.template Alloc<U>(var);
auto *void_scale_data = (scale == nullptr ? nullptr : scale->data());
auto *void_bias_data = (bias == nullptr ? nullptr : bias->data());
auto x_dtype = x.dtype();
phi::DataType scale_bias_dtype;
if (void_scale_data != nullptr) {
scale_bias_dtype = scale->dtype();
if (void_bias_data != nullptr) {
PADDLE_ENFORCE_EQ(
scale->dtype(),
bias->dtype(),
phi::errors::InvalidArgument("Thie Scale and Bias of layer_norm op "
"should have the same data type."));
}
} else {
scale_bias_dtype = (void_bias_data != nullptr ? bias->dtype() : x_dtype);
}
bool is_scale_bias_same_dtype_with_x = x_dtype == scale_bias_dtype;
if (!is_scale_bias_same_dtype_with_x) {
PADDLE_ENFORCE_EQ(scale_bias_dtype,
paddle::experimental::CppTypeToDataType<U>::Type(),
phi::errors::InvalidArgument(
"Unsupported data type of Scale and Bias"));
}
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
auto stream = dev_ctx.stream();
#define PADDLE_LAUNCH_LAYERNORM_FWD(ScaleBiasT, IsScaleBiasSameDTypeWithX) \
do { \
switch (paddle::operators::GetDesiredBlockDim(feature_size)) { \
FIXED_BLOCK_DIM_CASE(paddle::operators::LayerNormForward< \
T, \
U, \
kBlockDim, \
IsScaleBiasSameDTypeWithX><<<batch_size, \
kBlockDim, \
0, \
stream>>>( \
x_data, \
static_cast<const ScaleBiasT *>(void_scale_data), \
static_cast<const ScaleBiasT *>(void_bias_data), \
y_data, \
mean_data, \
var_data, \
epsilon, \
feature_size)); \
default: \
PADDLE_THROW(phi::errors::InvalidArgument( \
"Product from begin_norm_axis to end must be larger than 1")); \
break; \
} \
} while (0)
#ifdef PADDLE_WITH_CUDA
bool can_call_1024_kernel = false;
if (feature_size == 1024 && scale != nullptr && bias != nullptr) {
can_call_1024_kernel = true;
}
if (can_call_1024_kernel) {
const int WARPS_M = 4;
const int WARPS_N = 1;
const int THREADS_PER_WARP = 32;
const int BYTES_PER_LDG = 16;
const int VecSize = BYTES_PER_LDG / sizeof(T);
const int THREADS_PER_CTA = WARPS_N * THREADS_PER_WARP * WARPS_M;
const int ROWS_PER_CTA = WARPS_M;
const int grid = static_cast<int>(
std::ceil(batch_size / static_cast<float>(ROWS_PER_CTA)));
if (is_scale_bias_same_dtype_with_x) {
paddle::operators::ln_fwd_1024_kernel<
T,
U,
T,
VecSize,
WARPS_M,
WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size,
feature_size,
epsilon,
x_data,
static_cast<const T *>(void_scale_data),
static_cast<const T *>(void_bias_data),
mean_data,
var_data,
y_data);
} else {
paddle::operators::ln_fwd_1024_kernel<
T,
U,
U,
VecSize,
WARPS_M,
WARPS_N,
BYTES_PER_LDG><<<grid, THREADS_PER_CTA, 0, stream>>>(
batch_size,
feature_size,
epsilon,
x_data,
static_cast<const U *>(void_scale_data),
static_cast<const U *>(void_bias_data),
mean_data,
var_data,
y_data);
}
} else {
#endif
if (is_scale_bias_same_dtype_with_x) {
PADDLE_LAUNCH_LAYERNORM_FWD(T, true);
} else {
PADDLE_LAUNCH_LAYERNORM_FWD(U, false);
}
#ifdef PADDLE_WITH_CUDA
}
#endif
#undef PADDLE_LAUNCH_LAYERNORM_FWD
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
// MIOPEN does not support double
PD_REGISTER_KERNEL(layer_norm,
GPU,
ALL_LAYOUT,
phi::LayerNormKernel,
float,
phi::dtype::float16) {}
#elif CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(layer_norm,
GPU,
ALL_LAYOUT,
phi::LayerNormKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(layer_norm,
GPU,
ALL_LAYOUT,
phi::LayerNormKernel,
float,
double,
phi::dtype::float16) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LayerNormGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& mean,
const DenseTensor& variance,
paddle::optional<const DenseTensor&> scale,
paddle::optional<const DenseTensor&> bias,
const DenseTensor& out_grad,
float epsilon,
int begin_norm_axis,
bool is_test,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LayerNormKernel(const Context& ctx,
const DenseTensor& x,
paddle::optional<const DenseTensor&> scale,
paddle::optional<const DenseTensor&> bias,
float epsilon,
int begin_norm_axis,
bool is_test,
DenseTensor* out,
DenseTensor* mean,
DenseTensor* variance);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T>
class LayerNormDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream,
const T* input,
std::vector<int> input_shape,
const T* bias,
const T* scale,
T* output,
T* mean,
T* variance,
int begin_norm_axis,
float eps);
};
#endif
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LayerNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("layer_norm",
{"X", "Scale", "Bias"},
{"epsilon", "begin_norm_axis", "is_test"},
{"Y", "Mean", "Variance"});
}
KernelSignature LayerNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"layer_norm_grad",
{"X", "Mean", "Variance", "Scale", "Bias", GradVarName("Y")},
{"epsilon", "begin_norm_axis", "is_test"},
{GradVarName("X"), GradVarName("Scale"), GradVarName("Bias")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(layer_norm, phi::LayerNormOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(layer_norm_grad,
phi::LayerNormGradOpArgumentMapping);
@@ -215,6 +215,8 @@ class TestLayerNormOp(unittest.TestCase):
for name in ['x', 'scale', 'bias', 'y@GRAD']
},
fetch_list=fetch_list)
# print(y)
# print(out[0])
self.__assert_close(y, out[0], "y")
self.__assert_close(mean, out[1], "mean")
self.__assert_close(variance, out[2], "variance", 1e-3)
@@ -238,6 +240,7 @@ class TestLayerNormOp(unittest.TestCase):
def test_check_forward_backward_with_scale_and_bias(self):
self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
self.check_forward_backward(
shape=[2, 3, 4, 5],
@@ -432,4 +435,5 @@ class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()