Unverified · Commit 6eeb16b8 authored by sneaxiy, committed by GitHub

add squared_l2_norm (#38968)
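Summary of the change: this commit adds a shared squared-L2-norm helper — a small memory::Buffer wrapper (paddle/fluid/memory/buffer.h) that recycles a scratch allocation, and math::SquaredL2Norm (paddle/fluid/operators/math/squared_l2_norm.h) with an Eigen-based CPU path and a cub::DeviceReduce-based GPU path. The LAMB optimizer kernel and the squared_l2_norm op kernel are rewritten on top of it; since the helper returns the squared norm, LambParamUpateFunctor now applies a square root before forming the trust ratio. A run-to-run determinism unit test is also added.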

Parent: ac933235
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace memory {
class Buffer {
 public:
  explicit Buffer(const platform::Place &place) : place_(place) {}

  template <typename T>
  T *Alloc(size_t size) {
    // `size` is an element count; Alloc<void> is treated as a byte count.
    using AllocT = typename std::conditional<std::is_same<T, void>::value,
                                             uint8_t, T>::type;
    if (UNLIKELY(size == 0)) return nullptr;
    size *= sizeof(AllocT);
    // Grow-only reuse: keep the existing allocation unless the new request
    // is larger than what is currently held.
    if (allocation_ == nullptr || allocation_->size() < size) {
      allocation_ = memory::Alloc(place_, size);
    }
    return reinterpret_cast<T *>(allocation_->ptr());
  }

  template <typename T>
  const T *Get() const {
    return reinterpret_cast<const T *>(
        allocation_ && allocation_->size() > 0 ? allocation_->ptr() : nullptr);
  }

  template <typename T>
  T *GetMutable() {
    return reinterpret_cast<T *>(
        allocation_ && allocation_->size() > 0 ? allocation_->ptr() : nullptr);
  }

  size_t Size() const { return allocation_ ? allocation_->size() : 0; }

 private:
  AllocationPtr allocation_;
  platform::Place place_;
};
} // namespace memory
} // namespace paddle
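A minimal usage sketch of the new helper (illustrative only: the function name Example and the sizes below are not part of the commit; note that Alloc takes element counts, not bytes):

// A Buffer is bound to one place and recycles its allocation across calls.
void Example(const paddle::platform::Place &place) {
  paddle::memory::Buffer buffer(place);
  float *p0 = buffer.Alloc<float>(256);  // allocates 256 * sizeof(float) bytes
  float *p1 = buffer.Alloc<float>(128);  // smaller request: reuses the block
  // p0 == p1 here, since the second call did not trigger a reallocation.
}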
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/memory/buffer.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#endif
namespace paddle {
namespace operators {
namespace math {
// Computes y[0] = sum_i x[i]^2, i.e. the *squared* L2 norm, not the norm.
template <typename T1, typename T2 = T1>
void SquaredL2Norm(const platform::CPUDeviceContext& ctx, const T1* x, T2* y,
                   size_t numel, memory::Buffer* buffer = nullptr) {
  if (std::is_same<T1, T2>::value) {
    // Same input/output type: use a vectorized Eigen reduction.
    using EigenT = typename framework::EigenTensor<T1, 1>::Type;
    using ConstEigenT = typename framework::EigenTensor<T1, 1>::ConstType;
    using EigenDim = typename framework::EigenDim<1>::Type;
    ConstEigenT input(x, EigenDim(numel));
    EigenT output(reinterpret_cast<T1*>(y), EigenDim(1));
    output.device(*ctx.eigen_device()) = input.square().sum();
  } else {
    // Mixed precision: accumulate in the (wider) output type T2.
    T2 ret = static_cast<T2>(0);
    for (size_t i = 0; i < numel; ++i) {
      auto tmp = static_cast<T2>(x[i]);
      ret += tmp * tmp;
    }
    *y = ret;
  }
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T1, typename T2 = T1>
void SquaredL2Norm(const platform::CUDADeviceContext& ctx, const T1* x, T2* y,
                   size_t numel, memory::Buffer* buffer = nullptr) {
  if (UNLIKELY(buffer == nullptr)) {
    memory::Buffer tmp_buffer(ctx.GetPlace());
    return SquaredL2Norm(ctx, x, y, numel, &tmp_buffer);
  }

  // Square each element on the fly while cub reduces the sequence.
  using FunctorT = kernel_primitives::SquareFunctor<T1, T2>;
  cub::TransformInputIterator<T2, FunctorT, const T1*> iter(x, FunctorT());

  size_t temp_storage_bytes = 0;
  void* d_temp_storage = nullptr;
  auto stream = ctx.stream();
  // Standard two-pass cub idiom: the first iteration (d_temp_storage ==
  // nullptr) only queries temp_storage_bytes; the second runs the reduction.
#pragma unroll 2
  for (size_t i = 0; i < 2; ++i) {
    if (temp_storage_bytes > 0) {
      d_temp_storage = buffer->Alloc<void>(temp_storage_bytes);
    }
    PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceReduce::Reduce(
        d_temp_storage, temp_storage_bytes, iter, y, numel, cub::Sum(),
        static_cast<T2>(0), stream));
  }
}
#endif
} // namespace math
} // namespace operators
} // namespace paddle
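Because the helper returns the squared norm rather than the norm itself, call sites that need ||x||_2 take a square root afterwards — exactly what the Eigen::numext::sqrt calls in the LAMB hunk below do. A minimal CPU-side call sketch (ExampleCpu and the default-constructed context are illustrative assumptions, not part of the commit):

void ExampleCpu() {
  paddle::platform::CPUDeviceContext cpu_ctx;  // assumed default-constructible
  const float x[3] = {1.f, 2.f, 2.f};
  float y = 0.f;
  // T1 == T2 == float, so this takes the Eigen fast path.
  paddle::operators::math::SquaredL2Norm(cpu_ctx, x, &y, 3);
  // y == 9.f (1 + 4 + 4); the true L2 norm would be sqrt(9.f) == 3.f.
}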
@@ -17,9 +17,11 @@ limitations under the License. */
 #include <Eigen/Dense>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/memory/buffer.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/math/algorithm.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/operators/math/squared_l2_norm.h"
 #include "paddle/fluid/platform/eigen_ext.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -383,8 +385,8 @@ struct LambParamUpateFunctor
   inline HOSTDEVICE void operator()(size_t i) const {
     if (skip_update_ && *skip_update_) return;
     MT lr = *lr_;
-    MT pn = *param_norm_;
-    MT tn = *trust_ratio_div_norm_;
+    MT pn = Eigen::numext::sqrt(*param_norm_);
+    MT tn = Eigen::numext::sqrt(*trust_ratio_div_norm_);
     MT r = (pn > static_cast<MT>(0) && tn > static_cast<MT>(0))
                ? pn / tn
@@ -488,9 +490,11 @@ class LambOpKernel : public framework::OpKernel<T> {
     }
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    platform::ForRange<DeviceContext> for_range(dev_ctx, param.numel());
+    auto numel = param.numel();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
     auto trust_ratio_div =
         ctx.AllocateTmpTensor<MT, DeviceContext>(param.dims(), dev_ctx);
+    auto* trust_ratio_div_ptr = trust_ratio_div.template data<MT>();
     const void* param_ptr = param.data();
     const void* master_param_ptr =
@@ -521,7 +525,7 @@ class LambOpKernel : public framework::OpKernel<T> {
           grad.template data<T>(),
           static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                   : param_ptr),
-          trust_ratio_div.template data<MT>(), skip_update_flag);
+          trust_ratio_div_ptr, skip_update_flag);
       for_range(moment_update_functor);
       beta1_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
           beta1 * beta1_pow.template data<MT>()[0];
@@ -545,7 +549,7 @@ class LambOpKernel : public framework::OpKernel<T> {
           grad.template data<T>(),
           static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                   : param_ptr),
-          trust_ratio_div.template data<MT>(), skip_update_flag);
+          trust_ratio_div_ptr, skip_update_flag);
       for_range(moment_update_functor);
     }
   } else if (grad_var->IsType<framework::SelectedRows>()) {
@@ -638,34 +642,29 @@ class LambOpKernel : public framework::OpKernel<T> {
     // Update parameter
     auto p_norm_t = ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
+    auto* p_norm_ptr = p_norm_t.template data<MT>();
     auto trust_ratio_div_norm_t =
         ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
-    auto p_norm = framework::EigenScalar<MT>::From(p_norm_t);
-    auto trust_ratio_div_norm =
-        framework::EigenScalar<MT>::From(trust_ratio_div_norm_t);
-    auto t = framework::EigenVector<MT>::Flatten(trust_ratio_div);
+    auto* trust_ratio_div_norm_ptr = trust_ratio_div_norm_t.template data<MT>();

     // TODO(zengjinle): remove the following Eigen operations when
     // *skip_update == true.
-    auto* place = dev_ctx.eigen_device();
-    if (IsMultiPrecision) {
-      auto mp = framework::EigenVector<MT>::Flatten(*master_param);
-      p_norm.device(*place) = mp.square().sum().sqrt();
-    } else {
-      auto p = framework::EigenVector<MT>::Flatten(param);
-      p_norm.device(*place) = p.square().sum().sqrt();
-    }
-    trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();
+    memory::Buffer buffer(dev_ctx.GetPlace());
+    math::SquaredL2Norm(
+        dev_ctx,
+        reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
                                                      : param_ptr),
+        p_norm_ptr, numel, &buffer);
+    math::SquaredL2Norm(dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr,
+                        numel, &buffer);

 #define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow)         \
   do {                                                                       \
     LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
         param_update_functor(                                                \
             lr.template data<MT>(), static_cast<const T*>(param_ptr),        \
-            static_cast<const MT*>(master_param_ptr),                        \
-            p_norm_t.template data<MT>(),                                    \
-            trust_ratio_div.template data<MT>(),                             \
-            trust_ratio_div_norm_t.template data<MT>(),                      \
+            static_cast<const MT*>(master_param_ptr), p_norm_ptr,            \
+            trust_ratio_div_ptr, trust_ratio_div_norm_ptr,                   \
             static_cast<T*>(param_out_ptr),                                  \
             static_cast<MT*>(master_param_out_ptr), skip_update_flag);       \
     if (__should_update_beta_pow) {                                          \
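Net effect of this hunk: the three Eigen reductions (two norms, each with an in-place sqrt) collapse into two math::SquaredL2Norm calls that share a single scratch Buffer, and the square root moves into LambParamUpateFunctor, as shown in the sqrt hunk above.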
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/squared_l2_norm.h"

 namespace paddle {
 namespace operators {
@@ -24,16 +25,15 @@ template <typename DeviceContext, typename T>
 class SquaredL2NormKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    const framework::Tensor *X = context.Input<framework::Tensor>("X");
-    framework::Tensor *Out = context.Output<framework::Tensor>("Out");
-    Out->mutable_data<T>(context.GetPlace());
+    const framework::Tensor *x = context.Input<framework::Tensor>("X");
+    const auto *x_ptr = x->data<T>();
+    auto numel = x->numel();

-    auto x = framework::EigenVector<T>::Flatten(*X);
-    auto out = framework::EigenScalar<T>::From(*Out);
-    auto *place =
-        context.template device_context<DeviceContext>().eigen_device();
+    framework::Tensor *out = context.Output<framework::Tensor>("Out");
+    auto *out_ptr = out->mutable_data<T>(context.GetPlace());

-    out.device(*place) = x.square().sum();
+    math::SquaredL2Norm(context.template device_context<DeviceContext>(),
+                        x_ptr, out_ptr, numel);
   }
 };
@@ -18,6 +18,8 @@ import numpy as np
 import unittest
 from numpy import linalg as LA
 from op_test import OpTest
+import paddle
+from paddle import _C_ops


 class TestL2LossOp(OpTest):
@@ -41,5 +43,21 @@ class TestL2LossOp(OpTest):
             ['X'], 'Out', max_relative_error=self.max_relative_error)


+class TestL2LossDeterministic(unittest.TestCase):
+    def check_place(self, place):
+        with paddle.fluid.dygraph.guard(place):
+            x_np = np.random.rand(5, 11, 13).astype('float32')
+            x = paddle.to_tensor(x_np)
+            y1 = _C_ops.squared_l2_norm(x)
+            y2 = _C_ops.squared_l2_norm(x)
+            self.assertTrue(np.array_equal(y1.numpy(), y2.numpy()))
+
+    def test_main(self):
+        self.check_place(paddle.CPUPlace())
+        if paddle.is_compiled_with_cuda():
+            self.check_place(paddle.CUDAPlace(0))
+
+
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
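The new TestL2LossDeterministic case invokes the op twice on the same input tensor and asserts bitwise-equal outputs, on CPU and, when available, on GPU — checking that the cub-based reduction is reproducible across runs.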