From ae01801f0a12203a239c33cfd59d4d0ecaaec7da Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Wed, 14 Oct 2020 00:07:50 +0800 Subject: [PATCH] Add dropout and log_loss for kunlun (#27790) * add dropout,log_loss, test=kunlun * fix dropout, test=kunlun * polish error message, test=kunlun * change boost::get to BOOST_GET_CONST, test=kunlun * fix copyright, test=kunlun --- paddle/fluid/operators/dropout_op_xpu.cc | 130 ++++++++++++++++++ paddle/fluid/operators/log_loss_op_xpu.cc | 68 +++++++++ .../unittests/xpu/test_dropout_op_xpu.py | 112 +++++++++++++++ .../unittests/xpu/test_log_loss_op_xpu.py | 65 +++++++++ 4 files changed, 375 insertions(+) create mode 100644 paddle/fluid/operators/dropout_op_xpu.cc create mode 100644 paddle/fluid/operators/log_loss_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_log_loss_op_xpu.py diff --git a/paddle/fluid/operators/dropout_op_xpu.cc b/paddle/fluid/operators/dropout_op_xpu.cc new file mode 100644 index 00000000000..506239fd2bc --- /dev/null +++ b/paddle/fluid/operators/dropout_op_xpu.cc @@ -0,0 +1,130 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "paddle/fluid/operators/dropout_op.h" +#include +#include +#include "paddle/fluid/platform/xpu_header.h" +namespace paddle { +namespace operators { + +#ifdef PADDLE_WITH_XPU +static std::map mask_data_tables; +static const int max_data_size = 32 * 1024 * 1024; +static std::mutex s_mask_data_table_lock; +template +class DropoutXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Output("Out"); + const auto* x_data = x->data(); + auto* y_data = y->mutable_data(context.GetPlace()); + float dropout_prob = context.Attr("dropout_prob"); + auto dropout_implementation = + context.Attr("dropout_implementation"); + float* mask_data_table = nullptr; + PADDLE_ENFORCE_EQ(!context.HasInput("Seed"), true, + platform::errors::InvalidArgument( + ("Input(Seed) not supported on XPU"))); + if (!context.Attr("is_test")) { + int dev_id = + BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()).GetDeviceId(); + int prop = static_cast(dropout_prob * 100); + int is_upscale = (dropout_implementation == "upscale_in_train"); + /* mask_data_tables key contains 3 part: + * | 31-16 | 15-8 | 7-0 | + * | dev_id | prob | is_upscale | + */ + int index = (dev_id << 16) + (prop << 8) + is_upscale; + std::lock_guard lock(s_mask_data_table_lock); + if (mask_data_tables.find(index) == mask_data_tables.end()) { + float* mask_data_host = new float[max_data_size]; + std::random_device rnd; + std::minstd_rand engine; + int seed = + context.Attr("fix_seed") ? context.Attr("seed") : rnd(); + engine.seed(seed); + std::uniform_real_distribution dist(0, 1); + for (size_t i = 0; i < max_data_size; ++i) { + if (dist(engine) < dropout_prob) { + mask_data_host[i] = 0.0f; + } else { + if (is_upscale) { + mask_data_host[i] = 1.0f / static_cast(1.0f - dropout_prob); + } else { + mask_data_host[i] = 1.0; + } + } + } + PADDLE_ENFORCE( + xpu_malloc(reinterpret_cast(&mask_data_table), + max_data_size * sizeof(float)) == xpu::Error_t::SUCCESS, + "XPU no enough memory"); + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()), + mask_data_table, platform::CPUPlace(), mask_data_host, + max_data_size * sizeof(float)); + mask_data_tables[index] = mask_data_table; + free(mask_data_host); + } else { + mask_data_table = mask_data_tables[index]; + } + } + if (!context.Attr("is_test")) { // Train + auto* mask = context.Output("Mask"); + auto* mask_data = mask->mutable_data(context.GetPlace()); + size_t size = framework::product(mask->dims()); + auto& dev_ctx = context.template device_context(); + int r = xpu::dropout(dev_ctx.x_context(), mask_data_table, x_data, + mask_data, y_data, max_data_size, size); + PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, + platform::errors::InvalidArgument("XPU kernel error!")); + } else { // Infer + float scale = 0.0f; + if (dropout_implementation == "upscale_in_train") { + scale = 1.0f; + } else { + scale = static_cast(1.0f - dropout_prob); + } + auto& dev_ctx = context.template device_context(); + int r = xpu::scale(dev_ctx.x_context(), x->numel(), scale, 0.0f, 0, + x_data, y_data); + PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, + platform::errors::InvalidArgument("XPU kernel error!")); + } + } +}; +template +class DropoutGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE(!context.Attr("is_test"), + "GradOp is only callable when is_test is false"); + auto* grad_x = context.Output(framework::GradVarName("X")); + auto* grad_y = context.Input(framework::GradVarName("Out")); + auto* mask = context.Input("Mask"); + grad_x->mutable_data(context.GetPlace()); + auto& dev_ctx = context.template device_context(); + int r = xpu::elementwise_mul(dev_ctx.x_context(), grad_y->data(), + mask->data(), grad_x->data(), + grad_y->numel()); + PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, + platform::errors::InvalidArgument("XPU kernel error!")); + } +}; +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + dropout, ops::DropoutXPUKernel); +REGISTER_OP_XPU_KERNEL( + dropout_grad, + ops::DropoutGradXPUKernel); +#endif diff --git a/paddle/fluid/operators/log_loss_op_xpu.cc b/paddle/fluid/operators/log_loss_op_xpu.cc new file mode 100644 index 00000000000..80e5f8ec401 --- /dev/null +++ b/paddle/fluid/operators/log_loss_op_xpu.cc @@ -0,0 +1,68 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/log_loss_op.h" +#include +namespace paddle { +namespace operators { + +template +class LogLossXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* predict = ctx.Input("Predicted"); + auto* labels = ctx.Input("Labels"); + auto* loss = ctx.Output("Loss"); + auto epsilon = static_cast(ctx.Attr("epsilon")); + loss->mutable_data(ctx.GetPlace()); + int n = predict->numel(); + auto& dev_ctx = ctx.template device_context(); + int r = + xpu::log_loss_fwd(dev_ctx.x_context(), n, epsilon, predict->data(), + labels->data(), loss->data()); + PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, + platform::errors::InvalidArgument("XPU kernel error!")); + } +}; +template +class LogLossGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* predict = ctx.Input("Predicted"); + auto* labels = ctx.Input("Labels"); + auto* dloss = ctx.Input(framework::GradVarName("Loss")); + auto* dpred = ctx.Output(framework::GradVarName("Predicted")); + if (!dpred) { + return; + } + auto epsilon = static_cast(ctx.Attr("epsilon")); + dpred->mutable_data(ctx.GetPlace()); + int n = predict->numel(); + auto& dev_ctx = ctx.template device_context(); + int r = xpu::log_loss_bwd(dev_ctx.x_context(), n, epsilon, + predict->data(), labels->data(), + dloss->data(), dpred->data()); + PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, + platform::errors::InvalidArgument("XPU kernel error!")); + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + log_loss, ops::LogLossXPUKernel); +REGISTER_OP_XPU_KERNEL( + log_loss_grad, + ops::LogLossGradXPUKernel); + +#endif diff --git a/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py new file mode 100644 index 00000000000..6c3368c3b6b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py @@ -0,0 +1,112 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import sys +sys.path.append("..") +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + + +class TestDropoutOp(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False} + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((32, 64)).astype('uint8') + } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad_normal(self): + if paddle.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + +class TestDropoutOpInput1d(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((2000, )).astype("float32")} + self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False} + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((2000)).astype('uint8') + } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad_normal(self): + if paddle.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + +class TestDropoutOp2(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False} + self.outputs = { + 'Out': np.zeros((32, 64)).astype('float32'), + 'Mask': np.zeros((32, 64)).astype('uint8') + } + + +class TestDropoutOp3(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} + self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False} + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((32, 64, 2)).astype('uint8') + } + + +class TestDropoutOp6(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((32, 64, 2)).astype('uint8') + } + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_log_loss_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_log_loss_op_xpu.py new file mode 100644 index 00000000000..3ba3a8b5eef --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_log_loss_op_xpu.py @@ -0,0 +1,65 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys +sys.path.append("..") +import paddle.fluid.core as core +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + + +def sigmoid_array(x): + return 1 / (1 + np.exp(-x)) + + +class TestXPULogLossOp(OpTest): + def setUp(self): + self.op_type = 'log_loss' + samples_num = 100 + + x = np.random.random((samples_num, 1)).astype("float32") + predicted = sigmoid_array(x) + labels = np.random.randint(0, 2, (samples_num, 1)).astype("float32") + epsilon = 1e-7 + self.inputs = { + 'Predicted': predicted, + 'Labels': labels, + } + + self.attrs = {'epsilon': epsilon} + loss = -labels * np.log(predicted + epsilon) - ( + 1 - labels) * np.log(1 - predicted + epsilon) + self.outputs = {'Loss': loss} + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_grad(['Predicted'], 'Loss') + + +if __name__ == '__main__': + unittest.main() -- GitLab