diff --git a/paddle/operators/compare_op.h b/paddle/operators/compare_op.h index b275fd75b3512343825170fc38565dd27f7f1c75..79b8c6f59c7ad3d77aa969f6b4f36f8050cfe823 100644 --- a/paddle/operators/compare_op.h +++ b/paddle/operators/compare_op.h @@ -62,7 +62,7 @@ class CompareOpKernel z->mutable_data(context.GetPlace()); int axis = context.Attr("axis"); ElementwiseComputeEx(context, x, y, axis, - z); + Functor(), z); } }; diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h index c32288d6984f126f2374a13973541f4f663b25a4..c24f97a85092ff14e8211ca8bc4bb9b155510a2c 100644 --- a/paddle/operators/elementwise_add_op.h +++ b/paddle/operators/elementwise_add_op.h @@ -35,7 +35,8 @@ class ElementwiseAddKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + AddFunctor(), z); } }; diff --git a/paddle/operators/elementwise_div_op.h b/paddle/operators/elementwise_div_op.h index 07ebade31ff5b3d5c89156e28ff5fa0670a9a842..dc863cc598ec6015067f166b1544a5d20223662a 100644 --- a/paddle/operators/elementwise_div_op.h +++ b/paddle/operators/elementwise_div_op.h @@ -35,7 +35,8 @@ class ElementwiseDivKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + DivFunctor(), z); } }; diff --git a/paddle/operators/elementwise_max_op.h b/paddle/operators/elementwise_max_op.h index 717e45ab31db9b9a6629fb33e17654dbf986d8c5..67efe4e1511e054d54f91b5aa22ce28f222ed20a 100644 --- a/paddle/operators/elementwise_max_op.h +++ b/paddle/operators/elementwise_max_op.h @@ -35,7 +35,8 @@ class ElementwiseMaxKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + MaxFunctor(), z); } }; diff --git a/paddle/operators/elementwise_min_op.h b/paddle/operators/elementwise_min_op.h index 0de9a91c52b0ab82cd62604de318ce68e56b767d..cf11759404d3342b8a1c0080fa09f6cd57e735db 100644 --- a/paddle/operators/elementwise_min_op.h +++ b/paddle/operators/elementwise_min_op.h @@ -35,7 +35,8 @@ class ElementwiseMinKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + MinFunctor(), z); } }; diff --git a/paddle/operators/elementwise_mul_op.h b/paddle/operators/elementwise_mul_op.h index ae7a71e0244dfb8ad3e55683ac081f92bc36bea5..773125f5ca54e7b529df47a2823d56a5ad71e50d 100644 --- a/paddle/operators/elementwise_mul_op.h +++ b/paddle/operators/elementwise_mul_op.h @@ -34,7 +34,8 @@ class ElementwiseMulKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + MulFunctor(), z); } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 213fe1f5a818873e8b666464cb112637261c598c..74abf7c4a58788eb0e53025886f10f5a43021a9e 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -365,10 +365,10 @@ template void ElementwiseComputeEx(const framework::ExecutionContext& ctx, const framework::Tensor* x, - const framework::Tensor* y, int axis, + const framework::Tensor* y, int axis, Functor func, framework::Tensor* z) { TransformFunctor functor( - x, y, z, ctx.template device_context(), Functor()); + x, y, z, ctx.template device_context(), func); auto x_dims = x->dims(); auto y_dims = y->dims(); diff --git a/paddle/operators/elementwise_pow_op.h b/paddle/operators/elementwise_pow_op.h index 874fd3f09f2afaccfbfca75799cc3448f7393b03..0c5dd031ec46ebecaabb701839c0f69c02678eb0 100644 --- a/paddle/operators/elementwise_pow_op.h +++ b/paddle/operators/elementwise_pow_op.h @@ -36,7 +36,8 @@ class ElementwisePowKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + PowFunctor(), z); } }; diff --git a/paddle/operators/elementwise_sub_op.h b/paddle/operators/elementwise_sub_op.h index c2749a8e6ba689233dab4f3c72de10bf01f39fab..6a88c5f6b4c869f8ab5b4fa3b112ffc264be7145 100644 --- a/paddle/operators/elementwise_sub_op.h +++ b/paddle/operators/elementwise_sub_op.h @@ -34,7 +34,8 @@ class ElementwiseSubKernel : public framework::OpKernel { auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + SubFunctor(), z); } }; diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc index 8fcac00e08fc27a92a20e829013e0f66af73f613..6dd18277c9c074ab907ceabc98d953030a853bc7 100644 --- a/paddle/operators/layer_norm_op.cc +++ b/paddle/operators/layer_norm_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/layer_norm_op.h" +#include "paddle/operators/elementwise_op_function.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { @@ -353,8 +355,9 @@ namespace ops = paddle::operators; REGISTER_OP(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker, layer_norm_grad, ops::LayerNormGradOp); REGISTER_OP_CPU_KERNEL( - layer_norm, - ops::LayerNormKernel); + layer_norm, ops::LayerNormKernel, + ops::LayerNormKernel); REGISTER_OP_CPU_KERNEL( layer_norm_grad, - ops::LayerNormGradKernel); + ops::LayerNormGradKernel, + ops::LayerNormGradKernel); diff --git a/paddle/operators/layer_norm_op.cu b/paddle/operators/layer_norm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..a84f5a41eae4343d622ffcba1306b90739b6f0bd --- /dev/null +++ b/paddle/operators/layer_norm_op.cu @@ -0,0 +1,245 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/elementwise_op_function.h" +#include "paddle/operators/layer_norm_op.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using DataLayout = framework::DataLayout; + +namespace { +template +struct SubAndSquareFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); } +}; + +template +struct DivAndSqrtFunctor { + explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; } + inline HOSTDEVICE T operator()(T a, T b) const { + return a / (sqrt(b) + epsilon_); + } + + private: + T epsilon_; +}; + +template +struct MulFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return a * b; } +}; + +template +struct AddFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return a + b; } +}; + +template +struct SubFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return a - b; } +}; + +template +struct MulInvVarFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { + return a * std::sqrt(1.0 / b); + } +}; +} // namespace + +template +class LayerNormCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const float epsilon = ctx.Attr("epsilon"); + auto *scale = ctx.Input("Scale"); + auto *bias = ctx.Input("Bias"); + auto x = *ctx.Input("X"); + + auto *y = ctx.Output("Y"); + auto *mean = ctx.Output("Mean"); + auto *var = ctx.Output("Variance"); + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); + + const auto &x_dims = x.dims(); + + y->mutable_data(ctx.GetPlace()); + mean->mutable_data(ctx.GetPlace()); + var->mutable_data(ctx.GetPlace()); + + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int left = static_cast(matrix_dim[0]); + int right = static_cast(matrix_dim[1]); + + framework::DDim matrix_shape({left, right}); + + x.Resize(matrix_shape); + y->Resize(matrix_shape); + + auto &dev_ctx = ctx.template device_context(); + math::RowwiseMean row_mean; + + // functor-> get mean + row_mean(dev_ctx, x, mean); + + // functor-> get variance + ElementwiseComputeEx, DeviceContext, T>( + ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor(), y); + row_mean(dev_ctx, *y, var); + + // functor-> get norm_out + ElementwiseComputeEx, DeviceContext, T>( + ctx, &x, mean, /*axis*/ 0, SubFunctor(), y); + ElementwiseComputeEx, DeviceContext, T>( + ctx, y, var, /*axis*/ 0, DivAndSqrtFunctor(static_cast(epsilon)), + y); + + framework::DDim scale_shape({right}); + if (scale) { + Tensor scale_matrix = *scale; + scale_matrix.Resize(scale_shape); + ElementwiseComputeEx, DeviceContext, T>( + ctx, y, &scale_matrix, /*axis*/ 1, MulFunctor(), y); + } + if (bias) { + Tensor bias_matrix = *bias; + bias_matrix.Resize(scale_shape); + ElementwiseComputeEx, DeviceContext, T>( + ctx, y, &bias_matrix, /*axis*/ 1, AddFunctor(), y); + } + y->Resize(x_dims); + } +}; + +template +class LayerNormCUDAGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const float epsilon = ctx.Attr("epsilon"); + auto x = *ctx.Input("X"); + auto mean = *ctx.Input("Mean"); + auto var = *ctx.Input("Variance"); + auto scale = *ctx.Input("Scale"); + auto d_y = *ctx.Input(framework::GradVarName("Y")); + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); + + // init output + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + const auto &x_dims = x.dims(); + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int left = static_cast(matrix_dim[0]); + int right = static_cast(matrix_dim[1]); + framework::DDim matrix_shape({left, right}); + + d_y.Resize(matrix_shape); + auto &dev_ctx = ctx.template device_context(); + math::ColwiseSum colwise_sum; + + Tensor temp; + Tensor temp_norm; + if (d_scale || d_x) { + x.Resize(matrix_shape); + temp.mutable_data(matrix_shape, ctx.GetPlace()); + temp_norm.mutable_data(matrix_shape, ctx.GetPlace()); + + // get x_norm + ElementwiseComputeEx, DeviceContext, T>( + ctx, &x, &mean, /*axis*/ 0, SubFunctor(), &temp_norm); + ElementwiseComputeEx, DeviceContext, T>( + ctx, &temp_norm, &var, /*axis*/ 0, + DivAndSqrtFunctor(static_cast(epsilon)), &temp_norm); + } + + if (d_bias) { + d_bias->mutable_data(ctx.GetPlace()); + colwise_sum(dev_ctx, d_y, d_bias); + } + if (d_scale) { + d_scale->mutable_data(ctx.GetPlace()); + ElementwiseComputeEx, DeviceContext, T>( + ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor(), &temp); + colwise_sum(dev_ctx, temp, d_scale); + } + + if (d_x) { + framework::DDim vec_shape({left}); + d_x->mutable_data(ctx.GetPlace()); + Tensor temp_vec; + temp_vec.mutable_data(vec_shape, ctx.GetPlace()); + + auto &dev_ctx = ctx.template device_context(); + math::RowwiseMean row_mean; + + if (d_scale) { + // dy_dx + ElementwiseComputeEx, DeviceContext, T>( + ctx, &d_y, &scale, /*axis*/ 1, MulFunctor(), &temp); + framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x); + + // dy_dmean_dx + row_mean(dev_ctx, temp, &temp_vec); + ElementwiseComputeEx, DeviceContext, T>( + ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor(), d_x); + + // dy_var_dx + ElementwiseComputeEx, DeviceContext, T>( + ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor(), &temp); + + } else { + // dy_dx + framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x); + + // dy_dmean_dx + row_mean(dev_ctx, d_y, &temp_vec); + ElementwiseComputeEx, DeviceContext, T>( + ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor(), d_x); + + // dy_var_dx + ElementwiseComputeEx, DeviceContext, T>( + ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor(), &temp); + } + // dy_var_dx + row_mean(dev_ctx, temp, &temp_vec); + ElementwiseComputeEx, DeviceContext, T>( + ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor(), &temp_norm); + ElementwiseComputeEx, DeviceContext, T>( + ctx, d_x, &temp_norm, /*axis*/ 0, SubFunctor(), d_x); + + ElementwiseComputeEx, DeviceContext, T>( + ctx, d_x, &var, /*axis*/ 0, + DivAndSqrtFunctor(static_cast(epsilon)), d_x); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + layer_norm, + ops::LayerNormCUDAKernel, + ops::LayerNormCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + layer_norm_grad, + ops::LayerNormCUDAGradKernel, + ops::LayerNormCUDAGradKernel); diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index dcf4b85e1aadf88e4b1ca70ac7e8b5416fc58cd8..ce0a5f6cff873166e3308a625978ecefaed2aa29 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -331,6 +331,12 @@ template struct RowwiseAdd; template struct ColwiseSum; template struct ColwiseSum; +template struct RowwiseSum; +template struct RowwiseSum; + +template struct RowwiseMean; +template struct RowwiseMean; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index d47a7f818ded61baf31e46ea3b8ae3101324111f..c0a107470a4629506fc06dabc78a4a4716be6649 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -325,6 +325,31 @@ void ColwiseSum::operator()( vector->data()); } +template struct RowwiseSum; +// template struct RowwiseSum; +// TODO(zcd): Following ColwiseSum format, need to confirm. +// The RowwiseSum failed in debug mode, +// and only failed for this case. So reimplemented it. +template <> +void RowwiseSum::operator()( + const platform::CUDADeviceContext& context, const framework::Tensor& input, + framework::Tensor* vector) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0]); + framework::Tensor one; + one.mutable_data({size}, context.GetPlace()); + SetConstant set; + set(context, &one, static_cast(1.0)); + gemv( + context, true, static_cast(in_dims[1]), static_cast(in_dims[0]), + 1.0, one.data(), input.data(), 0.0, + vector->data()); +} + +template struct RowwiseMean; +template struct RowwiseMean; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 8cc03c2ba0facae691a0d2b8a4f2ea768cfa5491..cb14d1e57468564710640773fdabd41896c178e0 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -128,6 +128,18 @@ struct ColwiseSum { framework::Tensor* vec); }; +template +struct RowwiseSum { + void operator()(const DeviceContext& context, const framework::Tensor& input, + framework::Tensor* vec); +}; + +template +struct RowwiseMean { + void operator()(const DeviceContext& context, const framework::Tensor& input, + framework::Tensor* vec); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_impl.h b/paddle/operators/math/math_function_impl.h index de591626df28e2bc3391b609f909612411398247..af4127788af0aaeb99199f7d6e2138a449b9fe51 100644 --- a/paddle/operators/math/math_function_impl.h +++ b/paddle/operators/math/math_function_impl.h @@ -87,6 +87,88 @@ class ColwiseSum { } }; +template +void RowwiseMean::operator()(const DeviceContext& context, + const framework::Tensor& input, + framework::Tensor* out) { + auto in_dims = input.dims(); + PADDLE_ENFORCE_EQ(in_dims.size(), 2U); + PADDLE_ENFORCE_EQ(out->numel(), in_dims[0]); + + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenVector::Flatten(*out); + + vec.device(*context.eigen_device()) = in.mean(Eigen::array({{1}})); +} +// TODO(zcd): Following ColwiseSum format, need to confirm. +// Specialize for CPU, since Eigen implement a general reduce. However, +// rowwise-sum can be easily implemented. General reduce has a huge overhead in +// CPU +template +class RowwiseMean { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + auto& in_dims = input.dims(); + PADDLE_ENFORCE_EQ(in_dims.size(), 2U); + auto height = in_dims[0]; + auto size = in_dims[1]; + PADDLE_ENFORCE_EQ(out->numel(), height); + auto inv_size = 1.0 / size; + T* out_buf = out->mutable_data(out->place()); + const T* in_buf = input.data(); + + for (size_t i = 0; i < static_cast(height); ++i) { + T sum = 0; + for (size_t j = 0; j < static_cast(size); ++j) { + sum += in_buf[i * size + j]; + } + out_buf[i] = sum * inv_size; + } + } +}; + +template +void RowwiseSum::operator()(const DeviceContext& context, + const framework::Tensor& input, + framework::Tensor* out) { + auto in_dims = input.dims(); + PADDLE_ENFORCE_EQ(in_dims.size(), 2U); + PADDLE_ENFORCE_EQ(out->numel(), in_dims[0]); + + auto in = framework::EigenMatrix::From(input); + auto vec = framework::EigenVector::Flatten(*out); + + vec.device(*context.eigen_device()) = in.sum(Eigen::array({{1}})); +} +// TODO(zcd): Following ColwiseSum format, need to confirm. +// Specialize for CPU, since Eigen implement a general reduce. However, +// rowwise-sum can be easily implemented. General reduce has a huge overhead in +// CPU +template +class RowwiseSum { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, framework::Tensor* out) { + auto& in_dims = input.dims(); + PADDLE_ENFORCE_EQ(in_dims.size(), 2U); + auto height = in_dims[0]; + auto size = in_dims[1]; + PADDLE_ENFORCE_EQ(out->numel(), size); + + T* out_buf = out->mutable_data(out->place()); + const T* in_buf = input.data(); + + for (size_t i = 0; i < static_cast(height); ++i) { + T sum = 0; + for (size_t j = 0; j < static_cast(size); ++j) { + sum += in_buf[i * size + j]; + } + out_buf[i] = sum; + } + } +}; + } // namespace math } // namespace operators } // namespace paddle