Unverified commit a99c3cd4, authored by ykkk2333, committed by GitHub

add xpu adagrad and where_grad kernels (#49701)

Parent commit: ddc8a726
......@@ -31,6 +31,7 @@ XPUOpMap& get_kl2_ops() {
{"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"adam_dense_param_sparse_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"adagrad", XPUKernelSet({phi::DataType::FLOAT32})},
{"arg_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"argsort_grad",
XPUKernelSet({phi::DataType::INT32,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/adagrad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {

template <typename T, typename Context>
void AdagradDenseKernel(const Context& ctx,
                        const DenseTensor& param,
                        const DenseTensor& grad,
                        const DenseTensor& moment,
                        const DenseTensor& learning_rate,
                        float epsilon_t,
                        DenseTensor* param_out_tensor,
                        DenseTensor* moment_out_tensor) {
  ctx.template Alloc<T>(param_out_tensor);
  ctx.template Alloc<T>(moment_out_tensor);
  T epsilon = static_cast<T>(epsilon_t);
  int r = xpu::adagrad(ctx.x_context(),
                       param.data<T>(),
                       grad.data<T>(),
                       moment.data<T>(),
                       learning_rate.data<T>(),
                       param_out_tensor->data<T>(),
                       moment_out_tensor->data<T>(),
                       param.numel(),
                       epsilon);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "adagrad");
}

}  // namespace phi

PD_REGISTER_KERNEL(adagrad, XPU, ALL_LAYOUT, phi::AdagradDenseKernel, float) {}
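
For context, the dense Adagrad update wrapped by xpu::adagrad above accumulates the squared gradient into the moment buffer and scales the learning rate by the inverse square root of that accumulator. A minimal NumPy sketch of the same math (it mirrors the reference computation used by the unit test added later in this commit; the helper name is illustrative, not a Paddle API):

import numpy as np

# Illustrative reference for the dense Adagrad step computed by the XPU
# kernel above (helper name is hypothetical, not part of Paddle).
def adagrad_step(param, grad, moment, lr, epsilon):
    moment_out = moment + grad * grad  # accumulate squared gradients
    param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)  # scaled step
    return param_out, moment_out
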
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {

template <typename T, typename Context>
void WhereGradKernel(const Context& ctx,
                     const DenseTensor& condition,
                     const DenseTensor& x,
                     const DenseTensor& y,
                     const DenseTensor& out_grad,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  const auto* cond_data = condition.data<bool>();
  auto* dout = out_grad.data<T>();
  auto cond_shape = phi::vectorize(condition.dims());
  auto out_shape = phi::vectorize(out_grad.dims());
  T* dx = nullptr;
  T* dy = nullptr;
  if (x_grad != nullptr) {
    dx = ctx.template Alloc<T>(x_grad);
  }
  if (y_grad != nullptr) {
    dy = ctx.template Alloc<T>(y_grad);
  }
  int r = xpu::select_grad(ctx.x_context(),
                           cond_data,
                           reinterpret_cast<const XPUType*>(dout),
                           reinterpret_cast<XPUType*>(dx),
                           reinterpret_cast<XPUType*>(dy),
                           cond_shape,
                           out_shape);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "select_grad");
}

}  // namespace phi

PD_REGISTER_KERNEL(where_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::WhereGradKernel,
                   float,
                   phi::dtype::float16,
                   int,
                   int64_t) {}
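
For reference, the where/select gradient routes the upstream gradient by the condition mask: elements where condition is true send out_grad into x_grad, the rest send it into y_grad. A minimal NumPy sketch of the non-broadcast case (illustrative only; the xpu::select_grad call above is additionally given cond_shape/out_shape, presumably so XDNN can handle broadcasting):

import numpy as np

# Illustrative reference for where_grad without broadcasting (helper name
# is hypothetical, not a Paddle API).
def where_grad_reference(condition, out_grad):
    x_grad = np.where(condition, out_grad, 0).astype(out_grad.dtype)
    y_grad = np.where(condition, 0, out_grad).astype(out_grad.dtype)
    return x_grad, y_grad
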
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import numpy as np

import paddle

sys.path.append("..")
import unittest

from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
    XPUOpTestWrapper,
    create_test_class,
    get_xpu_op_support_types,
)

paddle.enable_static()

class XPUTestAdagradOp(XPUOpTestWrapper):
    def __init__(self):
        self.op_name = 'adagrad'
        self.use_dynamic_create_class = False

    class TestAdagradOp1(XPUOpTest):
        '''Test Adagrad operator with explicit attributes'''

        def setUp(self):
            self.op_type = "adagrad"
            self.dtype = self.in_type
            param = np.random.random((123, 321)).astype(self.in_type)
            grad = np.random.random((123, 321)).astype(self.in_type)
            moment = np.zeros((123, 321)).astype(self.in_type)
            lr = 0.01
            epsilon = 1e-8
            self.inputs = {
                'Param': param,
                'Grad': grad,
                'Moment': moment,
                'LearningRate': np.array([lr]).astype(self.in_type),
            }
            self.attrs = {'epsilon': epsilon}
            moment_out = moment + grad * grad
            param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
            self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}

        def test_check_output(self):
            self.check_output_with_place(paddle.XPUPlace(0))

    class TestAdagradOp2(XPUOpTest):
        '''Test Adagrad operator with default attributes'''

        def setUp(self):
            self.op_type = "adagrad"
            param = np.random.random((123, 321)).astype(self.in_type)
            grad = np.random.random((123, 321)).astype(self.in_type)
            moment = np.zeros((123, 321)).astype(self.in_type)
            lr = 0.01
            epsilon = 1e-6
            self.inputs = {
                'Param': param,
                'Grad': grad,
                'Moment': moment,
                'LearningRate': np.array([lr]).astype(self.in_type),
            }
            self.attrs = {'epsilon': epsilon}
            moment_out = moment + grad * grad
            param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
            self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}

        def test_check_output(self):
            self.check_output_with_place(paddle.XPUPlace(0))


support_types = get_xpu_op_support_types('adagrad')
for stype in support_types:
    create_test_class(globals(), XPUTestAdagradOp, stype)

if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
......@@ -58,6 +58,9 @@ class XPUTestWhereOp(XPUOpTestWrapper):
        def test_check_output(self):
            self.check_output_with_place(self.place)

        def test_check_grad(self):
            self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')

    class TestXPUWhereOp2(TestXPUWhereOp):
        def init_data(self):
            self.x = np.random.uniform(-5, 5, (60, 2)).astype(self.dtype)
......