Unverified commit 047971f0, authored by zhangyikun02, committed by GitHub

add adadelta op for xpu, test=kunlun (#47661)

Parent: 6a6a3ff1
@@ -33,6 +33,7 @@ XPUOpMap& get_kl2_ops() {
      {"abs_grad",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace())})},
+     {"adadelta", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"adam",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
@@ -109,6 +110,8 @@ XPUOpMap& get_kl2_ops() {
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                     pOpKernelType(vartype::FP16, XPUPlace())})},
      {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+     {"clip_by_norm",
+      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"coalesce_tensor",
       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
      {"concat_grad",
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/adadelta_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void AdadeltaKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& avg_squared_grad,
const DenseTensor& avg_squared_update,
float rho,
float epsilon,
DenseTensor* param_out,
DenseTensor* avg_squared_grad_out,
DenseTensor* avg_squared_update_out) {
dev_ctx.template Alloc<T>(param_out);
dev_ctx.template Alloc<T>(avg_squared_grad_out);
dev_ctx.template Alloc<T>(avg_squared_update_out);
int r = xpu::adadelta<T, T>(dev_ctx.x_context(),
param.data<T>(),
grad.data<T>(),
avg_squared_grad.data<T>(),
avg_squared_update.data<T>(),
param_out->data<T>(),
avg_squared_grad_out->data<T>(),
avg_squared_update_out->data<T>(),
param.numel(),
rho,
epsilon);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adadelta");
}
} // namespace phi
PD_REGISTER_KERNEL(adadelta, XPU, ALL_LAYOUT, phi::AdadeltaKernel, float) {}
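
For reference, the per-element update the kernel above applies (and which the NumPy reference in the test below recomputes) is the standard Adadelta step; writing E[g^2] for AvgSquaredGrad, E[dx^2] for AvgSquaredUpdate, rho for the decay rate and epsilon for the stability term:

    E[g^2]_t  = rho * E[g^2]_{t-1}  + (1 - rho) * g_t^2
    dx_t      = -sqrt((E[dx^2]_{t-1} + epsilon) / (E[g^2]_t + epsilon)) * g_t
    E[dx^2]_t = rho * E[dx^2]_{t-1} + (1 - rho) * dx_t^2
    param_t   = param_{t-1} + dx_t
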
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
create_test_class,
get_xpu_op_support_types,
XPUOpTestWrapper,
)
paddle.enable_static()
class XPUTestAdadelta(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'adadelta'
class TestAdadeltaOp1(XPUOpTest):
def setUp(self):
self.op_type = "adadelta"
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
param = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
grad = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
# The squared gradient is positive
avg_squared_grad = np.random.random((102, 105)).astype(self.dtype)
# The squared update is positive
avg_squared_update = np.random.random((102, 105)).astype(self.dtype)
rho = 0.95
epsilon = 1e-6
self.inputs = {
'Param': param,
'Grad': grad,
'AvgSquaredGrad': avg_squared_grad,
'AvgSquaredUpdate': avg_squared_update,
}
self.attrs = {'rho': rho, 'epsilon': epsilon}
avg_squared_grad_out = rho * avg_squared_grad + (
1 - rho
) * np.square(grad)
update = -np.multiply(
np.sqrt(
np.divide(
avg_squared_update + epsilon,
avg_squared_grad_out + epsilon,
)
),
grad,
)
avg_squared_update_out = rho * avg_squared_update + (
1 - rho
) * np.square(update)
param_out = param + update
self.outputs = {
'ParamOut': param_out,
'AvgSquaredGradOut': avg_squared_grad_out,
'AvgSquaredUpdateOut': avg_squared_update_out,
}
def test_check_output(self):
self.check_output()
class TestAdadeltaOp2(OpTest):
'''Test Adadelta op with default attribute values'''
def setUp(self):
self.op_type = "adadelta"
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
param = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
grad = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
# The squared gradient is positive
avg_squared_grad = np.random.random((102, 105)).astype(self.dtype)
# The squared update is positive
avg_squared_update = np.random.random((102, 105)).astype(self.dtype)
rho = 0.95
epsilon = 1e-6
self.inputs = {
'Param': param,
'Grad': grad,
'AvgSquaredGrad': avg_squared_grad,
'AvgSquaredUpdate': avg_squared_update,
}
avg_squared_grad_out = rho * avg_squared_grad + (
1 - rho
) * np.square(grad)
update = -np.multiply(
np.sqrt(
np.divide(
avg_squared_update + epsilon,
avg_squared_grad_out + epsilon,
)
),
grad,
)
avg_squared_update_out = rho * avg_squared_update + (
1 - rho
) * np.square(update)
param_out = param + update
self.outputs = {
'ParamOut': param_out,
'AvgSquaredGradOut': avg_squared_grad_out,
'AvgSquaredUpdateOut': avg_squared_update_out,
}
def test_check_output(self):
self.check_output()
class TestAdadeltaV2(unittest.TestCase):
def test_adadelta_dygraph(self):
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
paddle.disable_static(self.place)
value = np.arange(26).reshape(2, 13).astype(self.dtype)
a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5)
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.Adadelta(
learning_rate=0.01,
parameters=linear.parameters(),
weight_decay=0.01,
)
out = linear(a)
out.backward()
adam.step()
adam.clear_gradients()
def test_adadelta(self):
self.dtype = self.in_type
paddle.enable_static()
place = fluid.XPUPlace(0)
main = fluid.Program()
with fluid.program_guard(main):
x = fluid.layers.data(name='x', shape=[13], dtype=self.dtype)
y = fluid.layers.data(name='y', shape=[1], dtype=self.dtype)
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = paddle.mean(cost)
rms_optimizer = paddle.optimizer.Adadelta(learning_rate=0.1)
rms_optimizer.minimize(avg_cost)
fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1
)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
def test_raise_error(self):
self.assertRaises(ValueError, paddle.optimizer.Adadelta, None)
self.assertRaises(
ValueError,
paddle.optimizer.Adadelta,
learning_rate=0.1,
rho=None,
)
self.assertRaises(
ValueError,
paddle.optimizer.Adadelta,
learning_rate=0.1,
epsilon=None,
)
class TestAdadeltaV2Group(TestAdadeltaV2):
def test_adadelta_dygraph(self):
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
paddle.disable_static(self.place)
value = np.arange(26).reshape(2, 13).astype(self.dtype)
a = paddle.to_tensor(value)
linear_1 = paddle.nn.Linear(13, 5)
linear_2 = paddle.nn.Linear(5, 5)
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.Adadelta(
learning_rate=0.01,
parameters=[
{'params': linear_1.parameters()},
{
'params': linear_2.parameters(),
'weight_decay': 0.001,
},
],
weight_decay=0.1,
)
out = linear_1(a)
out = linear_2(out)
out.backward()
adam.step()
adam.clear_gradients()
support_types = get_xpu_op_support_types('adadelta')
for stype in support_types:
create_test_class(globals(), XPUTestAdadelta, stype)
if __name__ == "__main__":
unittest.main()
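
The second file changed by this commit migrates the existing clip_by_norm XPU test onto the XPUOpTestWrapper framework, so that create_test_class generates one test class per dtype reported by get_xpu_op_support_types. The reference value the test checks is unchanged: with norm = sqrt(sum(X^2)),

    Out = max_norm * X / norm    if norm > max_norm
    Out = X                      otherwise
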
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,56 +19,68 @@ import unittest
 import numpy as np
 from op_test_xpu import XPUOpTest
 import paddle
+from xpu.get_test_cover_info import (
+    create_test_class,
+    get_xpu_op_support_types,
+    XPUOpTestWrapper,
+)


-class TestXPUClipByNormOp(XPUOpTest):
-    def setUp(self):
-        self.op_type = "clip_by_norm"
-        self.dtype = np.float32
-        self.use_xpu = True
-        self.max_relative_error = 0.006
-        self.initTestCase()
-        input = np.random.random(self.shape).astype("float32")
-        input[np.abs(input) < self.max_relative_error] = 0.5
-        self.inputs = {
-            'X': input,
-        }
-        self.attrs = {}
-        self.attrs['max_norm'] = self.max_norm
-        norm = np.sqrt(np.sum(np.square(input)))
-        if norm > self.max_norm:
-            output = self.max_norm * input / norm
-        else:
-            output = input
-        self.outputs = {'Out': output}
-
-    def test_check_output(self):
-        if paddle.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
-
-    def initTestCase(self):
-        self.shape = (100,)
-        self.max_norm = 1.0
-
-
-class TestCase1(TestXPUClipByNormOp):
-    def initTestCase(self):
-        self.shape = (100,)
-        self.max_norm = 1e20
-
-
-class TestCase2(TestXPUClipByNormOp):
-    def initTestCase(self):
-        self.shape = (16, 16)
-        self.max_norm = 0.1
-
-
-class TestCase3(TestXPUClipByNormOp):
-    def initTestCase(self):
-        self.shape = (4, 8, 16)
-        self.max_norm = 1.0
+class XPUTestClipByNormOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'clip_by_norm'
+        self.use_dynamic_create_class = False
+
+    class TestClipByNormOp(XPUOpTest):
+        def setUp(self):
+            self.op_type = "clip_by_norm"
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+            self.use_xpu = True
+            self.max_relative_error = 0.006
+            self.initTestCase()
+            input = np.random.random(self.shape).astype(self.dtype)
+            input[np.abs(input) < self.max_relative_error] = 0.5
+            self.inputs = {
+                'X': input,
+            }
+            self.attrs = {}
+            self.attrs['max_norm'] = self.max_norm
+            norm = np.sqrt(np.sum(np.square(input)))
+            if norm > self.max_norm:
+                output = self.max_norm * input / norm
+            else:
+                output = input
+            self.outputs = {'Out': output}
+
+        def test_check_output(self):
+            if paddle.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_output_with_place(self.place)
+
+        def initTestCase(self):
+            self.shape = (100,)
+            self.max_norm = 1.0
+
+    class TestCase1(TestClipByNormOp):
+        def initTestCase(self):
+            self.shape = (100,)
+            self.max_norm = 1e20
+
+    class TestCase2(TestClipByNormOp):
+        def initTestCase(self):
+            self.shape = (16, 16)
+            self.max_norm = 0.1
+
+    class TestCase3(TestClipByNormOp):
+        def initTestCase(self):
+            self.shape = (4, 8, 16)
+            self.max_norm = 1.0
+
+
+support_types = get_xpu_op_support_types('clip_by_norm')
+for stype in support_types:
+    create_test_class(globals(), XPUTestClipByNormOp, stype)

 if __name__ == "__main__":
...