diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_xpu.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e21cba14b7dcaad215aa040958a656e9b3058ec --- /dev/null +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_xpu.cc @@ -0,0 +1,110 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_XPU + +#include +#include + +#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" + +namespace paddle { +namespace operators { + +template +class SigmoidCrossEntropyWithLogitsXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE_EQ( + platform::is_xpu_place(context.GetPlace()), true, + platform::errors::Unavailable("This kernel only runs on XPU.")); + + // input and output data + auto* input = context.Input("X"); + auto* label = context.Input("Label"); + auto* output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + auto& dev_ctx = context.template device_context(); + + // attrs + bool normalize = context.Attr("normalize"); + PADDLE_ENFORCE_EQ( + normalize, false, + platform::errors::InvalidArgument("normalize only support true now.")); + int ignore_index = context.Attr("ignore_index"); + PADDLE_ENFORCE_EQ(ignore_index, kIgnoreIndex, + platform::errors::InvalidArgument( + "ignore_index only support %d now.", kIgnoreIndex)); + + int r = xpu::sigmoid_cross_entropy_with_logits( + dev_ctx.x_context(), reinterpret_cast(input->data()), + reinterpret_cast(label->data()), + reinterpret_cast(output->data()), 1, input->numel()); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU sigmoid_cross_entropy_with_logits " + "kernel return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +template +class SigmoidCrossEntropyWithLogitsGradXPUKernel + : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE_EQ( + platform::is_xpu_place(context.GetPlace()), true, + platform::errors::Unavailable("This kernel only runs on XPU.")); + + // input and output data + auto* input = context.Input("X"); + auto* label = context.Input("Label"); + auto* dy = context.Input(framework::GradVarName("Out")); + auto* dx = context.Output(framework::GradVarName("X")); + dx->mutable_data(context.GetPlace()); + auto& dev_ctx = context.template device_context(); + + int r = xpu::sigmoid_cross_entropy_with_logits_grad( + dev_ctx.x_context(), reinterpret_cast(input->data()), + reinterpret_cast(label->data()), + reinterpret_cast(dy->data()), + reinterpret_cast(dx->data()), 1, input->numel()); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU sigmoid_cross_entropy_with_logits_grad " + "kernel return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL(sigmoid_cross_entropy_with_logits, + ops::SigmoidCrossEntropyWithLogitsXPUKernel< + paddle::platform::XPUDeviceContext, float>); + +REGISTER_OP_XPU_KERNEL(sigmoid_cross_entropy_with_logits_grad, + ops::SigmoidCrossEntropyWithLogitsGradXPUKernel< + paddle::platform::XPUDeviceContext, float>); + +#endif diff --git a/paddle/fluid/platform/device/xpu/xpu1_op_list.h b/paddle/fluid/platform/device/xpu/xpu1_op_list.h index b2114afee63c6665cae40e1b63f50da5fd4bd7ae..a08e6a70c9863775a2c4989be7b28a36f67e9100 100644 --- a/paddle/fluid/platform/device/xpu/xpu1_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu1_op_list.h @@ -249,6 +249,10 @@ XPUOpMap& get_kl1_ops() { pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sigmoid_cross_entropy_with_logits_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sigmoid_cross_entropy_with_logits", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sigmoid_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, diff --git a/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..4ceacd52092341347ce5633c5b439ad49e7ca8de --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py @@ -0,0 +1,164 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test_xpu import OpTest, XPUOpTest +from op_test import skip_check_grad_ci +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.framework import convert_np_dtype_to_dtype_ + +from scipy.special import logit +from scipy.special import expit + +paddle.enable_static() + + +class TestSigmoidCrossEntropyWithLogitsOp1(XPUOpTest): + """Test sigmoid_cross_entropy_with_logit_op with binary label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_xpu() + self.init_dtype() + + batch_size = 64 + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype)), + 'Label': np.random.randint(0, 2, (batch_size, num_classes)) + .astype(self.dtype) + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + def set_xpu(self): + self.__class__.use_xpu = True + self.place = paddle.XPUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + +class TestSigmoidCrossEntropyWithLogitsOp3( + TestSigmoidCrossEntropyWithLogitsOp1): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_xpu() + self.init_dtype() + + batch_size = 64 + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype)), + 'Label': np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype) + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + +class TestSigmoidCrossEntropyWithLogitsOp5( + TestSigmoidCrossEntropyWithLogitsOp1): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_xpu() + self.init_dtype() + + batch_size = [10, 10] + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype(self.dtype)), + 'Label': np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype(self.dtype) + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + +class TestSigmoidCrossEntropyWithLogitsOp6( + TestSigmoidCrossEntropyWithLogitsOp1): + """Test sigmoid_cross_entropy_with_logit_op with binary label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_xpu() + self.init_dtype() + + batch_size = [10, 10] + num_classes = 20 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype(self.dtype)), + 'Label': np.random.randint(0, 2, tuple(batch_size + [num_classes])) + .astype(self.dtype) + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + +if __name__ == '__main__': + unittest.main()