From fd44de588a78eb9b9fc85907b4252b54ba5ce83d Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Wed, 26 Jan 2022 13:20:37 +0800 Subject: [PATCH] add sigmoid cross entropy with logits to kl2 (#38915) * add sigmoid cross entropy with logits to kl2. test=kunlun * add sigmoid cross entropy with logits to kl2. test=kunlun * follow comments. test=kunlun --- ...igmoid_cross_entropy_with_logits_op_xpu.cc | 82 ++++++++++---- .../fluid/platform/device/xpu/xpu2_op_list.h | 4 + ...igmoid_cross_entropy_with_logits_op_xpu.py | 105 ++++++++++++++++++ 3 files changed, 172 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_xpu.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_xpu.cc index 7e21cba14b..6395aa1caa 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_xpu.cc +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_xpu.cc @@ -18,6 +18,7 @@ #include #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" +#include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { @@ -41,24 +42,41 @@ class SigmoidCrossEntropyWithLogitsXPUKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); // attrs - bool normalize = context.Attr("normalize"); - PADDLE_ENFORCE_EQ( - normalize, false, - platform::errors::InvalidArgument("normalize only support true now.")); int ignore_index = context.Attr("ignore_index"); - PADDLE_ENFORCE_EQ(ignore_index, kIgnoreIndex, - platform::errors::InvalidArgument( - "ignore_index only support %d now.", kIgnoreIndex)); + bool normalize = context.Attr("normalize"); + + // allocate temp memory + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + int* hit = RAII_GUARD.alloc_l3_or_gm(input->numel()); + PADDLE_ENFORCE_NOT_NULL( + hit, platform::errors::External("XPU alloc_l3_or_gm returns nullptr")); int r = xpu::sigmoid_cross_entropy_with_logits( dev_ctx.x_context(), reinterpret_cast(input->data()), reinterpret_cast(label->data()), - reinterpret_cast(output->data()), 1, input->numel()); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU sigmoid_cross_entropy_with_logits " - "kernel return wrong value[%d %s]", - r, XPUAPIErrorMsg[r])); + reinterpret_cast(output->data()), 1, input->numel(), hit, + ignore_index); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sigmoid_cross_entropy_with_logits"); + if (normalize) { + int* non_zero = RAII_GUARD.alloc_l3_or_gm(1); + PADDLE_ENFORCE_NOT_NULL( + non_zero, + platform::errors::External("XPU alloc_l3_or_gm returns nullptr")); + int r = xpu::nonzero_count(dev_ctx.x_context(), + reinterpret_cast(hit), + non_zero, input->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count"); + int non_zero_cpu = 0; + memory::Copy(platform::CPUPlace(), static_cast(&non_zero_cpu), + context.GetPlace(), static_cast(non_zero), + sizeof(int)); + r = xpu::scale(dev_ctx.x_context(), + reinterpret_cast(output->data()), + reinterpret_cast(output->data()), + input->numel(), false, + 1.0f / static_cast(non_zero_cpu), 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); + } } }; @@ -81,16 +99,42 @@ class SigmoidCrossEntropyWithLogitsGradXPUKernel dx->mutable_data(context.GetPlace()); auto& dev_ctx = context.template device_context(); + // attrs + int ignore_index = context.Attr("ignore_index"); + bool normalize = context.Attr("normalize"); + + // allocate temp memory + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + int* hit = RAII_GUARD.alloc_l3_or_gm(input->numel()); + PADDLE_ENFORCE_NOT_NULL( + hit, platform::errors::External("XPU alloc_l3_or_gm returns nullptr")); + int r = xpu::sigmoid_cross_entropy_with_logits_grad( dev_ctx.x_context(), reinterpret_cast(input->data()), reinterpret_cast(label->data()), reinterpret_cast(dy->data()), - reinterpret_cast(dx->data()), 1, input->numel()); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU sigmoid_cross_entropy_with_logits_grad " - "kernel return wrong value[%d %s]", - r, XPUAPIErrorMsg[r])); + reinterpret_cast(dx->data()), 1, input->numel(), hit, + ignore_index); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sigmoid_cross_entropy_with_logits"); + if (normalize) { + int* non_zero = RAII_GUARD.alloc_l3_or_gm(1); + PADDLE_ENFORCE_NOT_NULL( + non_zero, + platform::errors::External("XPU alloc_l3_or_gm returns nullptr")); + int r = xpu::nonzero_count(dev_ctx.x_context(), + reinterpret_cast(hit), + non_zero, input->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count"); + int non_zero_cpu = 0; + memory::Copy(platform::CPUPlace(), static_cast(&non_zero_cpu), + context.GetPlace(), static_cast(non_zero), + sizeof(int)); + r = xpu::scale(dev_ctx.x_context(), + reinterpret_cast(dx->data()), + reinterpret_cast(dx->data()), input->numel(), + false, 1.0f / static_cast(non_zero_cpu), 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); + } } }; diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index f83e3f6d0d..8764458433 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -292,6 +292,10 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::INT64, XPUPlace())})}, {"scatter", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sigmoid_cross_entropy_with_logits_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"sigmoid_cross_entropy_with_logits", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace())})}, {"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, diff --git a/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py index 4ceacd5209..9cb31d4270 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_sigmoid_cross_entropy_with_logits_op_xpu.py @@ -73,6 +73,39 @@ class TestSigmoidCrossEntropyWithLogitsOp1(XPUOpTest): self.dtype = np.float32 +class TestSigmoidCrossEntropyWithLogitsOp2( + TestSigmoidCrossEntropyWithLogitsOp1): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_xpu() + self.init_dtype() + + batch_size = 64 + num_classes = 20 + ignore_index = -1 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype)), + 'Label': np.random.randint(-1, 2, (batch_size, num_classes)) + .astype(self.dtype) + } + self.attrs = {'ignore_index': ignore_index, } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + out = -term1 - term2 + out[np.where(self.inputs['Label'] == ignore_index)] = 0 + self.outputs = {'Out': out} + + class TestSigmoidCrossEntropyWithLogitsOp3( TestSigmoidCrossEntropyWithLogitsOp1): """Test sigmoid_cross_entropy_with_logit_op with probabalistic label @@ -102,6 +135,42 @@ class TestSigmoidCrossEntropyWithLogitsOp3( self.outputs = {'Out': -term1 - term2} +class TestSigmoidCrossEntropyWithLogitsOp4( + TestSigmoidCrossEntropyWithLogitsOp1): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_xpu() + self.init_dtype() + + batch_size = 64 + num_classes = 20 + ignore_index = -1 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype(self.dtype)), + 'Label': np.random.randint(-1, 2, (batch_size, num_classes)) + .astype(self.dtype) + } + self.attrs = {'ignore_index': ignore_index, 'normalize': True} + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + out = -term1 - term2 + out[np.where(self.inputs['Label'] == ignore_index)] = 0 + if self.attrs['normalize']: + out = out / float( + np.where(self.inputs['Label'] != ignore_index)[0].size) + self.outputs = {'Out': out} + + class TestSigmoidCrossEntropyWithLogitsOp5( TestSigmoidCrossEntropyWithLogitsOp1): """Test sigmoid_cross_entropy_with_logit_op with probabalistic label @@ -131,6 +200,42 @@ class TestSigmoidCrossEntropyWithLogitsOp5( self.outputs = {'Out': -term1 - term2} +class TestSigmoidCrossEntropyWithLogitsNorm( + TestSigmoidCrossEntropyWithLogitsOp1): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_xpu() + self.init_dtype() + + batch_size = [10, 10] + num_classes = 20 + ignore_index = -1 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, tuple(batch_size + [num_classes])) + .astype(self.dtype)), + 'Label': np.random.randint(-1, 2, tuple(batch_size + [num_classes])) + .astype(self.dtype) + } + self.attrs = {'ignore_index': ignore_index, 'normalize': True} + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + out = -term1 - term2 + out[np.where(self.inputs['Label'] == ignore_index)] = 0 + if self.attrs['normalize']: + out = out / float( + np.where(self.inputs['Label'] != ignore_index)[0].size) + self.outputs = {'Out': out} + + class TestSigmoidCrossEntropyWithLogitsOp6( TestSigmoidCrossEntropyWithLogitsOp1): """Test sigmoid_cross_entropy_with_logit_op with binary label -- GitLab