Unverified commit d6d60cbc, authored by QingshuChen, committed by GitHub

fix cpu lars_momentum bug & add xpu grad_add/log_softmax/log_softmax_… (#44260)

* fix cpu lars_momentum bug & add xpu grad_add/log_softmax/log_softmax_grad
*test=kunlun

* minor
*test=kunlun
Parent 87619829
......@@ -33,6 +33,7 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T epsilon = ctx.Attr<float>("epsilon");
T rescale_grad = ctx.Attr<float>("rescale_grad");
int op_num = param.size();
for (int i = 0; i < op_num; ++i) {
......@@ -46,6 +47,7 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
auto p = framework::EigenVector<T>::Flatten(*(param[i]));
auto v = framework::EigenVector<T>::Flatten(*(velocity[i]));
auto g = framework::EigenVector<T>::Flatten(*(grad[i]));
auto rescale_g = rescale_grad * g;
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
......@@ -55,14 +57,14 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
ep_norm = p.square().sum().sqrt();
eg_norm = g.square().sum().sqrt();
eg_norm = rescale_g.square().sum().sqrt();
T local_lr = lr[0];
if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
local_lr = lr[0] * lars_coeff * ep_norm(0) /
(eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
}
v_out = v * mu + local_lr * (g + lars_weight_decay * p);
v_out = v * mu + local_lr * (rescale_g + lars_weight_decay * p);
p_out = p - v_out;
}
}
......
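For context, the fixed CPU update can be read as the following minimal NumPy sketch (not part of the patch; the helper name lars_momentum_step and the single-tensor signature are illustrative). It shows where rescale_grad now enters both the gradient norm and the velocity update:

```python
import numpy as np

# Minimal sketch of the fixed CPU LARS momentum step; names mirror the op
# attributes above, the function itself is hypothetical.
def lars_momentum_step(param, grad, velocity, lr, mu,
                       lars_coeff, lars_weight_decay, epsilon, rescale_grad):
    rescaled_grad = rescale_grad * grad  # the fix: rescale before norm and update
    p_norm = np.sqrt(np.sum(param * param))
    g_norm = np.sqrt(np.sum(rescaled_grad * rescaled_grad))
    local_lr = lr
    if lars_weight_decay > 0 and p_norm > 0 and g_norm > 0:
        local_lr = lr * lars_coeff * p_norm / (
            g_norm + lars_weight_decay * p_norm + epsilon)
    velocity_out = mu * velocity + local_lr * (
        rescaled_grad + lars_weight_decay * param)
    param_out = param - velocity_out
    return param_out, velocity_out
```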
......@@ -232,6 +232,7 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP16, XPUPlace())})},
{"generate_proposals_v2",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"grad_add", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"greater_equal",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
......@@ -274,6 +275,9 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"log_softmax_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lookup_table_v2_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"lookup_table_v2",
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void GradAddXPUKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
auto x_shape = phi::vectorize<int>(x.dims());
auto y_shape = phi::vectorize<int>(y.dims());
int r = xpu::broadcast_add(dev_ctx.x_context(),
x.data<T>(),
y.data<T>(),
out->data<T>(),
x_shape,
y_shape);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add");
}
} // namespace phi
PD_REGISTER_KERNEL(grad_add, XPU, ALL_LAYOUT, phi::GradAddXPUKernel, float) {}
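Semantically, grad_add on XPU is just an elementwise add with broadcasting over the two shapes handed to xpu::broadcast_add; a one-line NumPy reference (the name grad_add_ref is illustrative only):

```python
import numpy as np

def grad_add_ref(x, y):
    # NumPy broadcasting corresponds to the x_shape/y_shape pair passed above.
    return x + y
```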
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_softmax_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace phi {
template <typename T, typename Context>
void LogSoftmaxGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad) {
const int rank = out.dims().size();
axis = funcs::CanonicalAxis(axis, rank);
if (out.numel() != 0) {
auto out_shape = phi::vectorize<int>(out.dims());
dev_ctx.template Alloc<T>(x_grad);
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
T* tmp_ptr = RAII_GUARD.alloc_l3_or_gm<T>(out_grad.numel());
T* tmp2_ptr = RAII_GUARD.alloc_l3_or_gm<T>(out_grad.numel());
PADDLE_ENFORCE_NE(
tmp_ptr, nullptr, phi::errors::External("not enough memory in xpu"));
PADDLE_ENFORCE_NE(
tmp2_ptr, nullptr, phi::errors::External("not enough memory in xpu"));
int r =
xpu::exp(dev_ctx.x_context(), out.data<T>(), tmp_ptr, out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp");
r = xpu::reciprocal(
dev_ctx.x_context(), tmp_ptr, tmp2_ptr, out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal");
r = xpu::mul(dev_ctx.x_context(),
tmp2_ptr,
out_grad.data<T>(),
tmp2_ptr,
out_grad.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
r = xpu::softmax_grad(dev_ctx.x_context(),
tmp_ptr,
tmp2_ptr,
x_grad->data<T>(),
out_shape,
axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad");
}
}
} // namespace phi
PD_REGISTER_KERNEL(
log_softmax_grad, XPU, ALL_LAYOUT, phi::LogSoftmaxGradKernel, float) {}
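A minimal NumPy sketch of what the composition above computes, assuming xpu::softmax_grad implements the usual softmax backward y * (dy - sum(y * dy)) along the axis; with dy = out_grad / softmax this reduces to the standard log_softmax gradient out_grad - softmax * sum(out_grad):

```python
import numpy as np

def log_softmax_grad_ref(out, out_grad, axis):
    softmax = np.exp(out)        # mirrors xpu::exp on the forward output
    scaled = out_grad / softmax  # mirrors xpu::reciprocal + xpu::mul
    # softmax backward: y * (dy - sum(y * dy)); here this equals
    # out_grad - softmax * out_grad.sum(axis, keepdims=True)
    return softmax * (scaled - np.sum(softmax * scaled, axis=axis, keepdims=True))
```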
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_softmax_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace phi {
template <typename T, typename Context>
void LogSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out) {
const int rank = x.dims().size();
axis = funcs::CanonicalAxis(axis, rank);
if (x.numel() != 0) {
auto x_shape = phi::vectorize<int>(x.dims());
dev_ctx.template Alloc<T>(out);
if (axis < 0) axis += rank;
int r = xpu::softmax<T>(
dev_ctx.x_context(), x.data<T>(), out->data<T>(), x_shape, axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
r = xpu::log<T>(
dev_ctx.x_context(), out->data<T>(), out->data<T>(), out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "log");
}
}
} // namespace phi
PD_REGISTER_KERNEL(log_softmax, XPU, ALL_LAYOUT, phi::LogSoftmaxKernel, float) {
}
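With the kernel registered, the op can be exercised through the regular Python API. A usage sketch, assuming an XPU-enabled build of Paddle and that 'xpu' is an accepted device string for paddle.set_device:

```python
import paddle
import paddle.nn.functional as F

paddle.set_device('xpu')       # assumes an XPU-enabled build
x = paddle.rand([2, 3, 4, 5], dtype='float32')
y = F.log_softmax(x, axis=-1)  # should dispatch to the FP32 XPU kernel above
print(y.shape)
```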
......@@ -85,7 +85,10 @@ type_dict_str_to_numpy = {
xpu_test_op_white_list = []
xpu_test_device_type_white_list = ['xpu1_float64']
xpu_test_op_type_white_list = [
'dropout_float16', 'dropout_grad_float16', 'matmul_v2_float16'
'dropout_float16',
'dropout_grad_float16',
'matmul_v2_float16',
"grad_add_float32" # no api for grad_add, skip
]
xpu_test_device_op_white_list = []
xpu_test_device_op_type_white_list = []
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid.core as core
import paddle.nn.functional as F
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
np.random.seed(10)
def ref_log_softmax(x):
shiftx = (x - np.max(x))
out = shiftx - np.log(np.exp(shiftx).sum())
return out
def ref_log_softmax_grad(x, axis):
if axis < 0:
axis += len(x.shape)
out = np.apply_along_axis(ref_log_softmax, axis, x)
axis_dim = x.shape[axis]
dout = np.full_like(x, fill_value=1. / x.size)
dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat(
axis_dim, axis=axis)
return dx
class XPUTestLogSoftmaxOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'log_softmax'
self.use_dynamic_create_class = True
def dynamic_create_class(self):
base_class = self.TestXPULogSoftmaxOp
classes = []
axis_arr = [-1, 1]
shape_arr = [[2, 3, 4, 5], [12, 10], [2, 5], [7, 7], [3, 5, 7]]
for axis in axis_arr:
for shape in shape_arr:
class_name = 'XPUTestLogSoftmax_' + \
str(axis) + "_" + str(shape)
attr_dict = {'axis': axis, 'shape': shape}
classes.append([class_name, attr_dict])
return base_class, classes
class TestXPULogSoftmaxOp(XPUOpTest):
def setUp(self):
self.op_type = 'log_softmax'
self.python_api = F.log_softmax
self.dtype = 'float32'
self.set_attrs()
self.use_xpu = True
if not hasattr(self, 'axis'):
self.shape = [2, 3, 4, 5]
self.axis = -1
x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype)
out = np.apply_along_axis(ref_log_softmax, self.axis, x)
self.x_grad = ref_log_softmax_grad(x, self.axis)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'axis': self.axis}
def set_attrs(self):
pass
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], ['Out'],
user_defined_grads=[self.x_grad],
check_eager=True)
support_types = get_xpu_op_support_types('log_softmax')
for stype in support_types:
create_test_class(globals(), XPUTestLogSoftmaxOp, stype)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()