From 6eeb16b8944955e572b5cdc0af450adfc5cd37a1 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Mon, 17 Jan 2022 10:53:19 +0800
Subject: [PATCH] add squared_l2_norm (#38968)

---
 paddle/fluid/memory/buffer.h                  | 60 +++++++++++++
 paddle/fluid/operators/math/squared_l2_norm.h | 84 +++++++++++++++++++
 paddle/fluid/operators/optimizers/lamb_op.h   | 43 +++++-----
 paddle/fluid/operators/squared_l2_norm_op.h   | 16 ++--
 .../unittests/test_squared_l2_norm_op.py      | 18 ++++
 5 files changed, 191 insertions(+), 30 deletions(-)
 create mode 100644 paddle/fluid/memory/buffer.h
 create mode 100644 paddle/fluid/operators/math/squared_l2_norm.h

diff --git a/paddle/fluid/memory/buffer.h b/paddle/fluid/memory/buffer.h
new file mode 100644
index 00000000000..127d6357e4a
--- /dev/null
+++ b/paddle/fluid/memory/buffer.h
@@ -0,0 +1,60 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <type_traits>
+
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace memory {
+
+class Buffer {
+ public:
+  explicit Buffer(const platform::Place &place) : place_(place) {}
+
+  template <typename T>
+  T *Alloc(size_t size) {
+    using AllocT = typename std::conditional<std::is_same<T, void>::value,
+                                             uint8_t, T>::type;
+    if (UNLIKELY(size == 0)) return nullptr;
+    size *= sizeof(AllocT);
+    if (allocation_ == nullptr || allocation_->size() < size) {
+      allocation_ = memory::Alloc(place_, size);
+    }
+    return reinterpret_cast<T *>(allocation_->ptr());
+  }
+
+  template <typename T>
+  const T *Get() const {
+    return reinterpret_cast<const T *>(
+        allocation_ && allocation_->size() > 0 ? allocation_->ptr() : nullptr);
+  }
+
+  template <typename T>
+  T *GetMutable() {
+    return reinterpret_cast<T *>(
+        allocation_ && allocation_->size() > 0 ? allocation_->ptr() : nullptr);
+  }
+
+  // Size of the underlying allocation, or 0 when nothing has been allocated.
+  size_t Size() const { return allocation_ ? allocation_->size() : 0; }
+
+ private:
+  AllocationPtr allocation_;
+  platform::Place place_;
+};
+
+}  // namespace memory
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/squared_l2_norm.h b/paddle/fluid/operators/math/squared_l2_norm.h
new file mode 100644
index 00000000000..540f6961538
--- /dev/null
+++ b/paddle/fluid/operators/math/squared_l2_norm.h
@@ -0,0 +1,84 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/memory/buffer.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#else
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+#endif
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T1, typename T2 = T1>
+void SquaredL2Norm(const platform::CPUDeviceContext& ctx, const T1* x, T2* y,
+                   size_t numel, memory::Buffer* buffer = nullptr) {
+  if (std::is_same<T1, T2>::value) {
+    using EigenT = typename framework::EigenTensor<T1, 1>::Type;
+    using ConstEigenT = typename framework::EigenTensor<T1, 1>::ConstType;
+    using EigenDim = typename framework::EigenDim<1>::Type;
+    ConstEigenT input(x, EigenDim(numel));
+    EigenT output(reinterpret_cast<T1*>(y), EigenDim(1));
+    output.device(*ctx.eigen_device()) = input.square().sum();
+  } else {
+    T2 ret = static_cast<T2>(0);
+    for (size_t i = 0; i < numel; ++i) {
+      auto tmp = static_cast<T2>(x[i]);
+      ret += tmp * tmp;
+    }
+    *y = ret;
+  }
+}
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+template <typename T1, typename T2 = T1>
+void SquaredL2Norm(const platform::CUDADeviceContext& ctx, const T1* x, T2* y,
+                   size_t numel, memory::Buffer* buffer = nullptr) {
+  if (UNLIKELY(buffer == nullptr)) {
+    memory::Buffer tmp_buffer(ctx.GetPlace());
+    return SquaredL2Norm(ctx, x, y, numel, &tmp_buffer);
+  }
+
+  using FunctorT = kernel_primitives::SquareFunctor<T1, T2>;
+  cub::TransformInputIterator<T2, FunctorT, const T1*> iter(x, FunctorT());
+  size_t temp_storage_bytes = 0;
+  void* d_temp_storage = nullptr;
+  auto stream = ctx.stream();
+  // First pass queries the required temp storage size; second pass performs
+  // the actual reduction with the allocated workspace.
+#pragma unroll 2
+  for (size_t i = 0; i < 2; ++i) {
+    if (temp_storage_bytes > 0) {
+      d_temp_storage = buffer->Alloc<void>(temp_storage_bytes);
+    }
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes, iter, y,
+                                  numel, cub::Sum(), static_cast<T2>(0),
+                                  stream));
+  }
+}
+#endif
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h
index e3798b49dcb..6d98522d752 100644
--- a/paddle/fluid/operators/optimizers/lamb_op.h
+++ b/paddle/fluid/operators/optimizers/lamb_op.h
@@ -17,9 +17,11 @@ limitations under the License. */
 #include
 #include
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/buffer.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/math/algorithm.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/operators/math/squared_l2_norm.h"
 #include "paddle/fluid/platform/eigen_ext.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -383,8 +385,8 @@ struct LambParamUpateFunctor
   inline HOSTDEVICE void operator()(size_t i) const {
     if (skip_update_ && *skip_update_) return;
     MT lr = *lr_;
-    MT pn = *param_norm_;
-    MT tn = *trust_ratio_div_norm_;
+    MT pn = Eigen::numext::sqrt(*param_norm_);
+    MT tn = Eigen::numext::sqrt(*trust_ratio_div_norm_);
     MT r = (pn > static_cast<MT>(0) && tn > static_cast<MT>(0))
                ? pn / tn
                : static_cast<MT>(1);
@@ -488,9 +490,11 @@ class LambOpKernel : public framework::OpKernel<T> {
     }
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    platform::ForRange<DeviceContext> for_range(dev_ctx, param.numel());
+    auto numel = param.numel();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
     auto trust_ratio_div =
         ctx.AllocateTmpTensor<MT, DeviceContext>(param.dims(), dev_ctx);
+    auto* trust_ratio_div_ptr = trust_ratio_div.template data<MT>();
 
     const void* param_ptr = param.data();
     const void* master_param_ptr =
@@ -521,7 +525,7 @@ class LambOpKernel : public framework::OpKernel<T> {
           grad.template data<T>(),
           static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                   : param_ptr),
-          trust_ratio_div.template data<MT>(), skip_update_flag);
+          trust_ratio_div_ptr, skip_update_flag);
       for_range(moment_update_functor);
       beta1_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
           beta1 * beta1_pow.template data<MT>()[0];
@@ -545,7 +549,7 @@ class LambOpKernel : public framework::OpKernel<T> {
           grad.template data<T>(),
           static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                   : param_ptr),
-          trust_ratio_div.template data<MT>(), skip_update_flag);
+          trust_ratio_div_ptr, skip_update_flag);
       for_range(moment_update_functor);
     }
   } else if (grad_var->IsType()) {
@@ -638,34 +642,29 @@ class LambOpKernel : public framework::OpKernel<T> {
     // Update parameter
     auto p_norm_t = ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
+    auto* p_norm_ptr = p_norm_t.template data<MT>();
+
     auto trust_ratio_div_norm_t =
         ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
-
-    auto p_norm = framework::EigenScalar<MT>::From(p_norm_t);
-    auto trust_ratio_div_norm =
-        framework::EigenScalar<MT>::From(trust_ratio_div_norm_t);
-    auto t = framework::EigenVector<MT>::Flatten(trust_ratio_div);
+    auto* trust_ratio_div_norm_ptr =
+        trust_ratio_div_norm_t.template data<MT>();
 
     // TODO(zengjinle): remove the following Eigen operations when
     // *skip_update == true.
-    auto* place = dev_ctx.eigen_device();
-    if (IsMultiPrecision) {
-      auto mp = framework::EigenVector<MT>::Flatten(*master_param);
-      p_norm.device(*place) = mp.square().sum().sqrt();
-    } else {
-      auto p = framework::EigenVector<T>::Flatten(param);
-      p_norm.device(*place) = p.square().sum().sqrt();
-    }
-    trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();
+    memory::Buffer buffer(dev_ctx.GetPlace());
+    math::SquaredL2Norm(
+        dev_ctx,
+        reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                      : param_ptr),
+        p_norm_ptr, numel, &buffer);
+    math::SquaredL2Norm(dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr,
+                        numel, &buffer);
 
 #define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow)         \
   do {                                                                       \
     LambParamUpateFunctor                                                    \
         param_update_functor(                                                \
            lr.template data<MT>(), static_cast<const T*>(param_ptr),        \
-           static_cast<const MT*>(master_param_ptr),                        \
-           p_norm_t.template data<MT>(), trust_ratio_div.template data<MT>(), \
-           trust_ratio_div_norm_t.template data<MT>(),                      \
+           static_cast<const MT*>(master_param_ptr), p_norm_ptr,            \
+           trust_ratio_div_ptr, trust_ratio_div_norm_ptr,                   \
            static_cast<T*>(param_out_ptr),                                  \
            static_cast<MT*>(master_param_out_ptr), skip_update_flag);       \
     if (__should_update_beta_pow) {                                          \
diff --git a/paddle/fluid/operators/squared_l2_norm_op.h b/paddle/fluid/operators/squared_l2_norm_op.h
index d44eec1a80b..e35b432a645 100644
--- a/paddle/fluid/operators/squared_l2_norm_op.h
+++ b/paddle/fluid/operators/squared_l2_norm_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/squared_l2_norm.h"
 
 namespace paddle {
 namespace operators {
@@ -24,16 +25,15 @@ template <typename DeviceContext, typename T>
 class SquaredL2NormKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    const framework::Tensor *X = context.Input<framework::Tensor>("X");
-    framework::Tensor *Out = context.Output<framework::Tensor>("Out");
-    Out->mutable_data<T>(context.GetPlace());
+    const framework::Tensor *x = context.Input<framework::Tensor>("X");
+    const auto *x_ptr = x->data<T>();
+    auto numel = x->numel();
 
-    auto x = framework::EigenVector<T>::Flatten(*X);
-    auto out = framework::EigenScalar<T>::From(*Out);
-    auto *place =
-        context.template device_context<DeviceContext>().eigen_device();
+    framework::Tensor *out = context.Output<framework::Tensor>("Out");
+    auto *out_ptr = out->mutable_data<T>(context.GetPlace());
 
-    out.device(*place) = x.square().sum();
+    math::SquaredL2Norm(context.template device_context<DeviceContext>(),
+                        x_ptr, out_ptr, numel);
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py
index 439bae9510e..430632ebb87 100644
--- a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py
@@ -18,6 +18,8 @@ import numpy as np
 import unittest
 from numpy import linalg as LA
 from op_test import OpTest
+import paddle
+from paddle import _C_ops
 
 
 class TestL2LossOp(OpTest):
@@ -41,5 +43,21 @@ class TestL2LossOp(OpTest):
             ['X'], 'Out', max_relative_error=self.max_relative_error)
 
 
+class TestL2LossDeterministic(unittest.TestCase):
+    def check_place(self, place):
+        with paddle.fluid.dygraph.guard(place):
+            x_np = np.random.rand(5, 11, 13).astype('float32')
+            x = paddle.to_tensor(x_np)
+            y1 = _C_ops.squared_l2_norm(x)
+            y2 = _C_ops.squared_l2_norm(x)
+            self.assertTrue(np.array_equal(y1.numpy(), y2.numpy()))
+
+    def test_main(self):
+        self.check_place(paddle.CPUPlace())
+        if paddle.is_compiled_with_cuda():
+            self.check_place(paddle.CUDAPlace(0))
+
+
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
-- 
GitLab