diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 79ee3924055cc9d2f1713f3db8ac8f667680e2a6..eef7eba7bc7d242ab82c114d2b6406ef5ba3d469 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -33,6 +33,7 @@ XPUOpMap& get_kl2_ops() {
     {"abs_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"adadelta", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"adamw", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"adam",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
@@ -109,6 +110,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
     {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"clip_by_norm",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"coalesce_tensor",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"concat_grad",
diff --git a/paddle/phi/kernels/xpu/adadelta_kernel.cc b/paddle/phi/kernels/xpu/adadelta_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..153f27f54c97e584fc8c9ce1b4d6c40de12ab0ef
--- /dev/null
+++ b/paddle/phi/kernels/xpu/adadelta_kernel.cc
@@ -0,0 +1,53 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/adadelta_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void AdadeltaKernel(const Context& dev_ctx,
+                    const DenseTensor& param,
+                    const DenseTensor& grad,
+                    const DenseTensor& avg_squared_grad,
+                    const DenseTensor& avg_squared_update,
+                    float rho,
+                    float epsilon,
+                    DenseTensor* param_out,
+                    DenseTensor* avg_squared_grad_out,
+                    DenseTensor* avg_squared_update_out) {
+  dev_ctx.template Alloc<T>(param_out);
+  dev_ctx.template Alloc<T>(avg_squared_grad_out);
+  dev_ctx.template Alloc<T>(avg_squared_update_out);
+
+  int r = xpu::adadelta(dev_ctx.x_context(),
+                        param.data<T>(),
+                        grad.data<T>(),
+                        avg_squared_grad.data<T>(),
+                        avg_squared_update.data<T>(),
+                        param_out->data<T>(),
+                        avg_squared_grad_out->data<T>(),
+                        avg_squared_update_out->data<T>(),
+                        param.numel(),
+                        rho,
+                        epsilon);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "adadelta");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(adadelta, XPU, ALL_LAYOUT, phi::AdadeltaKernel, float) {}
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_adadelta_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_adadelta_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..d65e20522a20e72ec9d70ceeb05cf47b3b217ca7
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_adadelta_op_xpu.py
@@ -0,0 +1,239 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import sys
+
+sys.path.append("..")
+
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import (
+    create_test_class,
+    get_xpu_op_support_types,
+    XPUOpTestWrapper,
+)
+
+paddle.enable_static()
+
+
+class XPUTestAdadelta(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'adadelta'
+
+    class TestAdadeltaOp1(XPUOpTest):
+        def setUp(self):
+            self.op_type = "adadelta"
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+
+            param = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
+            grad = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
+            # The squared gradient is positive
+            avg_squared_grad = np.random.random((102, 105)).astype(self.dtype)
+            # The squared update is positive
+            avg_squared_update = np.random.random((102, 105)).astype(self.dtype)
+
+            rho = 0.95
+            epsilon = 1e-6
+
+            self.inputs = {
+                'Param': param,
+                'Grad': grad,
+                'AvgSquaredGrad': avg_squared_grad,
+                'AvgSquaredUpdate': avg_squared_update,
+            }
+
+            self.attrs = {'rho': rho, 'epsilon': epsilon}
+
+            avg_squared_grad_out = rho * avg_squared_grad + (
+                1 - rho
+            ) * np.square(grad)
+            update = -np.multiply(
+                np.sqrt(
+                    np.divide(
+                        avg_squared_update + epsilon,
+                        avg_squared_grad_out + epsilon,
+                    )
+                ),
+                grad,
+            )
+
+            avg_squared_update_out = rho * avg_squared_update + (
+                1 - rho
+            ) * np.square(update)
+
+            param_out = param + update
+
+            self.outputs = {
+                'ParamOut': param_out,
+                'AvgSquaredGradOut': avg_squared_grad_out,
+                'AvgSquaredUpdateOut': avg_squared_update_out,
+            }
+
+        def test_check_output(self):
+            self.check_output()
+
+    class TestAdadeltaOp2(OpTest):
+        '''Test Adadelta op with default attribute values'''
+
+        def setUp(self):
+            self.op_type = "adadelta"
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+
+            param = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
+            grad = np.random.uniform(-1, 1, (102, 105)).astype(self.dtype)
+            # The squared gradient is positive
+            avg_squared_grad = np.random.random((102, 105)).astype(self.dtype)
+            # The squared update is positive
+            avg_squared_update = np.random.random((102, 105)).astype(self.dtype)
+
+            rho = 0.95
+            epsilon = 1e-6
+
+            self.inputs = {
+                'Param': param,
+                'Grad': grad,
+                'AvgSquaredGrad': avg_squared_grad,
+                'AvgSquaredUpdate': avg_squared_update,
+            }
+
+            avg_squared_grad_out = rho * avg_squared_grad + (
+                1 - rho
+            ) * np.square(grad)
+            update = -np.multiply(
+                np.sqrt(
+                    np.divide(
+                        avg_squared_update + epsilon,
+                        avg_squared_grad_out + epsilon,
+                    )
+                ),
+                grad,
+            )
+
+            avg_squared_update_out = rho * avg_squared_update + (
+                1 - rho
+            ) * np.square(update)
+
+            param_out = param + update
+
+            self.outputs = {
+                'ParamOut': param_out,
+                'AvgSquaredGradOut': avg_squared_grad_out,
+                'AvgSquaredUpdateOut': avg_squared_update_out,
+            }
+
+        def test_check_output(self):
+            self.check_output()
+
+    class TestAdadeltaV2(unittest.TestCase):
+        def test_adadelta_dygraph(self):
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+
+            paddle.disable_static(self.place)
+            value = np.arange(26).reshape(2, 13).astype(self.dtype)
+            a = paddle.to_tensor(value)
+            linear = paddle.nn.Linear(13, 5)
+            # This can be any optimizer supported by dygraph.
+            adam = paddle.optimizer.Adadelta(
+                learning_rate=0.01,
+                parameters=linear.parameters(),
+                weight_decay=0.01,
+            )
+            out = linear(a)
+            out.backward()
+            adam.step()
+            adam.clear_gradients()
+
+        def test_adadelta(self):
+            self.dtype = self.in_type
+            paddle.enable_static()
+            place = fluid.XPUPlace(0)
+            main = fluid.Program()
+            with fluid.program_guard(main):
+                x = fluid.layers.data(name='x', shape=[13], dtype=self.dtype)
+                y = fluid.layers.data(name='y', shape=[1], dtype=self.dtype)
+                y_predict = fluid.layers.fc(input=x, size=1, act=None)
+                cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+                avg_cost = paddle.mean(cost)
+
+                rms_optimizer = paddle.optimizer.Adadelta(learning_rate=0.1)
+                rms_optimizer.minimize(avg_cost)
+
+                fetch_list = [avg_cost]
+                train_reader = paddle.batch(
+                    paddle.dataset.uci_housing.train(), batch_size=1
+                )
+                feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+                exe = fluid.Executor(place)
+                exe.run(fluid.default_startup_program())
+                for data in train_reader():
+                    exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
+
+        def test_raise_error(self):
+            self.assertRaises(ValueError, paddle.optimizer.Adadelta, None)
+            self.assertRaises(
+                ValueError,
+                paddle.optimizer.Adadelta,
+                learning_rate=0.1,
+                rho=None,
+            )
+            self.assertRaises(
+                ValueError,
+                paddle.optimizer.Adadelta,
+                learning_rate=0.1,
+                epsilon=None,
+            )
+
+    class TestAdadeltaV2Group(TestAdadeltaV2):
+        def test_adadelta_dygraph(self):
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+
+            paddle.disable_static(self.place)
+            value = np.arange(26).reshape(2, 13).astype(self.dtype)
+            a = paddle.to_tensor(value)
+            linear_1 = paddle.nn.Linear(13, 5)
+            linear_2 = paddle.nn.Linear(5, 5)
+            # This can be any optimizer supported by dygraph.
+            adam = paddle.optimizer.Adadelta(
+                learning_rate=0.01,
+                parameters=[
+                    {'params': linear_1.parameters()},
+                    {
+                        'params': linear_2.parameters(),
+                        'weight_decay': 0.001,
+                    },
+                ],
+                weight_decay=0.1,
+            )
+            out = linear_1(a)
+            out = linear_2(out)
+            out.backward()
+            adam.step()
+            adam.clear_gradients()
+
+
+support_types = get_xpu_op_support_types('adadelta')
+for stype in support_types:
+    create_test_class(globals(), XPUTestAdadelta, stype)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_clip_by_norm_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_clip_by_norm_op_xpu.py
index 1bfa96b5975deba6f12d42adea2533a5f27b8d25..e439b9fc29dc93904734dfd8908f75b4c352820f 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_clip_by_norm_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_clip_by_norm_op_xpu.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,56 +19,68 @@ import unittest
 import numpy as np
 from op_test_xpu import XPUOpTest
 import paddle
+from xpu.get_test_cover_info import (
+    create_test_class,
+    get_xpu_op_support_types,
+    XPUOpTestWrapper,
+)


-class TestXPUClipByNormOp(XPUOpTest):
-    def setUp(self):
-        self.op_type = "clip_by_norm"
-        self.dtype = np.float32
-        self.use_xpu = True
-        self.max_relative_error = 0.006
-        self.initTestCase()
-        input = np.random.random(self.shape).astype("float32")
-        input[np.abs(input) < self.max_relative_error] = 0.5
-        self.inputs = {
-            'X': input,
-        }
-        self.attrs = {}
-        self.attrs['max_norm'] = self.max_norm
-        norm = np.sqrt(np.sum(np.square(input)))
-        if norm > self.max_norm:
-            output = self.max_norm * input / norm
-        else:
-            output = input
-        self.outputs = {'Out': output}
+class XPUTestClipByNormOp(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'clip_by_norm'
+        self.use_dynamic_create_class = False

-    def test_check_output(self):
-        if paddle.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
+    class TestClipByNormOp(XPUOpTest):
+        def setUp(self):
+            self.op_type = "clip_by_norm"
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+            self.use_xpu = True
+            self.max_relative_error = 0.006
+            self.initTestCase()
+            input = np.random.random(self.shape).astype(self.dtype)
+            input[np.abs(input) < self.max_relative_error] = 0.5
+            self.inputs = {
+                'X': input,
+            }
+            self.attrs = {}
+            self.attrs['max_norm'] = self.max_norm
+            norm = np.sqrt(np.sum(np.square(input)))
+            if norm > self.max_norm:
+                output = self.max_norm * input / norm
+            else:
+                output = input
+            self.outputs = {'Out': output}

-    def initTestCase(self):
-        self.shape = (100,)
-        self.max_norm = 1.0
+        def test_check_output(self):
+            if paddle.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_output_with_place(self.place)
+
+        def initTestCase(self):
+            self.shape = (100,)
+            self.max_norm = 1.0

-
-class TestCase1(TestXPUClipByNormOp):
-    def initTestCase(self):
-        self.shape = (100,)
-        self.max_norm = 1e20
+    class TestCase1(TestClipByNormOp):
+        def initTestCase(self):
+            self.shape = (100,)
+            self.max_norm = 1e20
+
+    class TestCase2(TestClipByNormOp):
+        def initTestCase(self):
+            self.shape = (16, 16)
+            self.max_norm = 0.1

-
-class TestCase2(TestXPUClipByNormOp):
-    def initTestCase(self):
-        self.shape = (16, 16)
-        self.max_norm = 0.1
+    class TestCase3(TestClipByNormOp):
+        def initTestCase(self):
+            self.shape = (4, 8, 16)
+            self.max_norm = 1.0


-class TestCase3(TestXPUClipByNormOp):
-    def initTestCase(self):
-        self.shape = (4, 8, 16)
-        self.max_norm = 1.0
+support_types = get_xpu_op_support_types('clip_by_norm')
+for stype in support_types:
+    create_test_class(globals(), XPUTestClipByNormOp, stype)


 if __name__ == "__main__":