diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.h b/paddle/fluid/operators/optimizers/lars_momentum_op.h
index df4d7b9a0438bc103f262bb4a8971a3ee31d6ebb..459900b14f61d9b64d92e3814510a6e7b7d4a35a 100644
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.h
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.h
@@ -33,6 +33,7 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
     T mu = static_cast<T>(ctx.Attr<float>("mu"));
     T lars_coeff = ctx.Attr<float>("lars_coeff");
     T epsilon = ctx.Attr<float>("epsilon");
+    T rescale_grad = ctx.Attr<float>("rescale_grad");
 
     int op_num = param.size();
     for (int i = 0; i < op_num; ++i) {
@@ -46,6 +47,7 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
       auto p = framework::EigenVector<T>::Flatten(*(param[i]));
       auto v = framework::EigenVector<T>::Flatten(*(velocity[i]));
       auto g = framework::EigenVector<T>::Flatten(*(grad[i]));
+      auto rescale_g = rescale_grad * g;
 
       framework::Tensor p_norm_t, g_norm_t;
       p_norm_t.Resize({1});
@@ -55,14 +57,14 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
       auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
       auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
       ep_norm = p.square().sum().sqrt();
-      eg_norm = g.square().sum().sqrt();
+      eg_norm = rescale_g.square().sum().sqrt();
 
       T local_lr = lr[0];
       if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
         local_lr = lr[0] * lars_coeff * ep_norm(0) /
                    (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
       }
-      v_out = v * mu + local_lr * (g + lars_weight_decay * p);
+      v_out = v * mu + local_lr * (rescale_g + lars_weight_decay * p);
       p_out = p - v_out;
     }
   }
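(Editorial illustration, not part of the patch.) The change above folds the new rescale_grad attribute into both the gradient norm and the velocity update of the CPU LARS kernel. A minimal NumPy sketch of the resulting update rule, with ad-hoc names and 1-D arrays assumed:

    import numpy as np

    def lars_step(p, g, v, lr, mu, lars_coeff, weight_decay, epsilon, rescale_grad):
        g = rescale_grad * g                       # rescale the gradient first
        p_norm = np.sqrt((p * p).sum())
        g_norm = np.sqrt((g * g).sum())            # norm of the rescaled gradient
        local_lr = lr
        if weight_decay > 0 and p_norm > 0 and g_norm > 0:
            local_lr = lr * lars_coeff * p_norm / (g_norm + weight_decay * p_norm + epsilon)
        v_out = mu * v + local_lr * (g + weight_decay * p)   # velocity update
        p_out = p - v_out                                    # parameter update
        return p_out, v_out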
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 7e9c61289b67f42e9f0b6d6dc3e537fd421b4cec..9f07f05ff7fa6dac5c8b90ab069a820e4b9cdb99 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -232,6 +232,7 @@ XPUOpMap& get_kl2_ops() {
                    pOpKernelType(vartype::FP16, XPUPlace())})},
     {"generate_proposals_v2",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"grad_add", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"greater_equal",
      XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
                    pOpKernelType(vartype::INT32, XPUPlace()),
@@ -274,6 +275,9 @@ XPUOpMap& get_kl2_ops() {
                    pOpKernelType(vartype::INT32, XPUPlace()),
                    pOpKernelType(vartype::FP32, XPUPlace())})},
     {"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"log_softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"log_softmax_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"lookup_table_v2_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"lookup_table_v2",
diff --git a/paddle/phi/kernels/xpu/elementwise_add_kernel.cc b/paddle/phi/kernels/xpu/elementwise_add_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..34d39b0a83da2fd80160982442633006413e9a8a
--- /dev/null
+++ b/paddle/phi/kernels/xpu/elementwise_add_kernel.cc
@@ -0,0 +1,41 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/api/ext/dispatch.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void GradAddXPUKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& y,
+                      DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  auto x_shape = phi::vectorize<int>(x.dims());
+  auto y_shape = phi::vectorize<int>(y.dims());
+  int r = xpu::broadcast_add(dev_ctx.x_context(),
+                             x.data<T>(),
+                             y.data<T>(),
+                             out->data<T>(),
+                             x_shape,
+                             y_shape);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(grad_add, XPU, ALL_LAYOUT, phi::GradAddXPUKernel, float) {}
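(Editorial illustration, not part of the patch.) grad_add is registered above for FP32 only and has no standalone Python API (hence the test whitelist entry further down); numerically the kernel is just a broadcasted elementwise add, roughly equivalent to this NumPy sketch with made-up shapes:

    import numpy as np

    x = np.random.rand(2, 3, 4).astype('float32')
    y = np.random.rand(3, 4).astype('float32')   # broadcast over the leading dimension
    out = x + y                                   # what xpu::broadcast_add computes here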
diff --git a/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc b/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c9165f3ef7d7e042b6ff17ea21b6c1601e8a7729
--- /dev/null
+++ b/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc
@@ -0,0 +1,68 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/log_softmax_grad_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LogSoftmaxGradKernel(const Context& dev_ctx,
+                          const DenseTensor& out,
+                          const DenseTensor& out_grad,
+                          int axis,
+                          DenseTensor* x_grad) {
+  const int rank = out.dims().size();
+  axis = funcs::CanonicalAxis(axis, rank);
+
+  if (out.numel() != 0) {
+    auto out_shape = phi::vectorize<int>(out.dims());
+    dev_ctx.template Alloc<T>(x_grad);
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    T* tmp_ptr = RAII_GUARD.alloc_l3_or_gm<T>(out_grad.numel());
+    T* tmp2_ptr = RAII_GUARD.alloc_l3_or_gm<T>(out_grad.numel());
+    PADDLE_ENFORCE_NE(
+        tmp_ptr, nullptr, phi::errors::External("no enough memory in xpu"));
+    PADDLE_ENFORCE_NE(
+        tmp2_ptr, nullptr, phi::errors::External("no enough memory in xpu"));
+
+    int r =
+        xpu::exp(dev_ctx.x_context(), out.data<T>(), tmp_ptr, out_grad.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");
+    r = xpu::reciprocal(
+        dev_ctx.x_context(), tmp_ptr, tmp2_ptr, out_grad.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal");
+    r = xpu::mul(dev_ctx.x_context(),
+                 tmp2_ptr,
+                 out_grad.data<T>(),
+                 tmp2_ptr,
+                 out_grad.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
+    r = xpu::softmax_grad(dev_ctx.x_context(),
+                          tmp_ptr,
+                          tmp2_ptr,
+                          x_grad->data<T>(),
+                          out_shape,
+                          axis);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad");
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    log_softmax_grad, XPU, ALL_LAYOUT, phi::LogSoftmaxGradKernel, float) {}
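(Editorial illustration, not part of the patch.) The exp -> reciprocal -> mul -> softmax_grad chain above reproduces the usual log_softmax backward, dx = dout - exp(out) * sum(dout). A NumPy check of that identity, assuming the last axis; the softmax_grad helper here is ad hoc, not an XPU API:

    import numpy as np

    def softmax_grad(s, dy, axis=-1):            # standard softmax backward
        return s * (dy - (s * dy).sum(axis=axis, keepdims=True))

    x = np.random.rand(4, 5)
    s = np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)   # softmax(x) == exp(log_softmax(x))
    out = np.log(s)                                          # log_softmax output
    dout = np.random.rand(4, 5)

    via_kernel = softmax_grad(s, dout / s)       # exp -> reciprocal -> mul -> softmax_grad
    direct = dout - np.exp(out) * dout.sum(axis=-1, keepdims=True)
    assert np.allclose(via_kernel, direct)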
diff --git a/paddle/phi/kernels/xpu/log_softmax_kernel.cc b/paddle/phi/kernels/xpu/log_softmax_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1f084d0e6cbf7ddfb525eaefcc887013c5407898
--- /dev/null
+++ b/paddle/phi/kernels/xpu/log_softmax_kernel.cc
@@ -0,0 +1,47 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/log_softmax_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/axis_utils.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LogSoftmaxKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      int axis,
+                      DenseTensor* out) {
+  const int rank = x.dims().size();
+  axis = funcs::CanonicalAxis(axis, rank);
+
+  if (x.numel() != 0) {
+    auto x_shape = phi::vectorize<int>(x.dims());
+    dev_ctx.template Alloc<T>(out);
+    if (axis < 0) axis += rank;
+    int r = xpu::softmax(
+        dev_ctx.x_context(), x.data<T>(), out->data<T>(), x_shape, axis);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
+    r = xpu::log(
+        dev_ctx.x_context(), out->data<T>(), out->data<T>(), out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "log");
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(log_softmax, XPU, ALL_LAYOUT, phi::LogSoftmaxKernel, float) {
+}
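(Editorial illustration, not part of the patch.) The forward kernel computes softmax and then takes the log in place; this agrees with the numerically stable shifted form used by the test reference further down. A quick NumPy check, last axis assumed:

    import numpy as np

    x = np.random.rand(3, 7).astype('float32')
    via_kernel = np.log(np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True))  # softmax, then log
    shifted = x - x.max(axis=-1, keepdims=True)
    stable = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))  # shifted reference form
    assert np.allclose(via_kernel, stable, atol=1e-6)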
diff --git a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py
index b4032f2dcb67e0eb0b7143ee1cf8a71abc6c62c9..3da9e32b015eda8efc6208fb50adc8cb99a73a55 100644
--- a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py
+++ b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py
@@ -85,7 +85,10 @@ type_dict_str_to_numpy = {
 xpu_test_op_white_list = []
 xpu_test_device_type_white_list = ['xpu1_float64']
 xpu_test_op_type_white_list = [
-    'dropout_float16', 'dropout_grad_float16', 'matmul_v2_float16'
+    'dropout_float16',
+    'dropout_grad_float16',
+    'matmul_v2_float16',
+    "grad_add_float32"  # no api for grad_add, skip
 ]
 xpu_test_device_op_white_list = []
 xpu_test_device_op_type_white_list = []
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_log_softmax_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_log_softmax_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7e730d9b2e25da99cf6263dd14c9f381dd51992
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_log_softmax_op_xpu.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import sys
+
+sys.path.append("..")
+from op_test import OpTest
+
+import paddle
+import paddle.fluid.core as core
+import paddle.nn.functional as F
+
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+
+paddle.enable_static()
+np.random.seed(10)
+
+
+def ref_log_softmax(x):
+    shiftx = (x - np.max(x))
+    out = shiftx - np.log(np.exp(shiftx).sum())
+    return out
+
+
+def ref_log_softmax_grad(x, axis):
+    if axis < 0:
+        axis += len(x.shape)
+    out = np.apply_along_axis(ref_log_softmax, axis, x)
+    axis_dim = x.shape[axis]
+    dout = np.full_like(x, fill_value=1. / x.size)
+    dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat(
+        axis_dim, axis=axis)
+    return dx
+
+
+class XPUTestLogSoftmaxOp(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'log_softmax'
+        self.use_dynamic_create_class = True
+
+    def dynamic_create_class(self):
+        base_class = self.TestXPULogSoftmaxOp
+        classes = []
+        axis_arr = [-1, 1]
+        shape_arr = [[2, 3, 4, 5], [12, 10], [2, 5], [7, 7], [3, 5, 7]]
+        for axis in axis_arr:
+            for shape in shape_arr:
+                class_name = 'XPUTestLogSoftmax_' + \
+                    str(axis) + "_" + str(shape)
+                attr_dict = {'axis': axis, 'shape': shape}
+                classes.append([class_name, attr_dict])
+        return base_class, classes
+
+    class TestXPULogSoftmaxOp(XPUOpTest):
+
+        def setUp(self):
+            self.op_type = 'log_softmax'
+            self.python_api = F.log_softmax
+            self.dtype = 'float32'
+            self.set_attrs()
+            self.use_xpu = True
+            if not hasattr(self, 'axis'):
+                self.shape = [2, 3, 4, 5]
+                self.axis = -1
+
+            x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype)
+            out = np.apply_along_axis(ref_log_softmax, self.axis, x)
+            self.x_grad = ref_log_softmax_grad(x, self.axis)
+
+            self.inputs = {'X': x}
+            self.outputs = {'Out': out}
+            self.attrs = {'axis': self.axis}
+
+        def set_attrs(self):
+            pass
+
+        def test_check_output(self):
+            self.check_output(check_eager=True)
+
+        def test_check_grad(self):
+            self.check_grad(['X'], ['Out'],
+                            user_defined_grads=[self.x_grad],
+                            check_eager=True)
+
+
+support_types = get_xpu_op_support_types('log_softmax')
+for stype in support_types:
+    create_test_class(globals(), XPUTestLogSoftmaxOp, stype)
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
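(Editorial illustration, not part of the patch.) With the kernel registered, the existing Python API should dispatch to it on an XPU device; a dynamic-graph usage sketch, assuming a Paddle build with XPU support:

    import paddle
    import paddle.nn.functional as F

    paddle.set_device('xpu')                     # assumes an XPU build and device
    x = paddle.rand([2, 3, 4], dtype='float32')
    y = F.log_softmax(x, axis=-1)                # now served by the new XPU kernel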