未验证 提交 fd44de58 编写于 作者: H houj04 提交者: GitHub

add sigmoid cross entropy with logits to kl2 (#38915)

* add sigmoid cross entropy with logits to kl2. test=kunlun

* add sigmoid cross entropy with logits to kl2. test=kunlun

* follow comments. test=kunlun
上级 106b5514
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle { namespace paddle {
...@@ -41,24 +42,41 @@ class SigmoidCrossEntropyWithLogitsXPUKernel : public framework::OpKernel<T> { ...@@ -41,24 +42,41 @@ class SigmoidCrossEntropyWithLogitsXPUKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
// attrs // attrs
bool normalize = context.Attr<bool>("normalize");
PADDLE_ENFORCE_EQ(
normalize, false,
platform::errors::InvalidArgument("normalize only support true now."));
int ignore_index = context.Attr<int>("ignore_index"); int ignore_index = context.Attr<int>("ignore_index");
PADDLE_ENFORCE_EQ(ignore_index, kIgnoreIndex, bool normalize = context.Attr<bool>("normalize");
platform::errors::InvalidArgument(
"ignore_index only support %d now.", kIgnoreIndex)); // allocate temp memory
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* hit = RAII_GUARD.alloc_l3_or_gm<int>(input->numel());
PADDLE_ENFORCE_NOT_NULL(
hit, platform::errors::External("XPU alloc_l3_or_gm returns nullptr"));
int r = xpu::sigmoid_cross_entropy_with_logits( int r = xpu::sigmoid_cross_entropy_with_logits(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(input->data<T>()), dev_ctx.x_context(), reinterpret_cast<const XPUType*>(input->data<T>()),
reinterpret_cast<const XPUType*>(label->data<T>()), reinterpret_cast<const XPUType*>(label->data<T>()),
reinterpret_cast<XPUType*>(output->data<T>()), 1, input->numel()); reinterpret_cast<XPUType*>(output->data<T>()), 1, input->numel(), hit,
PADDLE_ENFORCE_EQ( ignore_index);
r, XPU_SUCCESS, PADDLE_ENFORCE_XDNN_SUCCESS(r, "sigmoid_cross_entropy_with_logits");
platform::errors::External("XPU sigmoid_cross_entropy_with_logits " if (normalize) {
"kernel return wrong value[%d %s]", int* non_zero = RAII_GUARD.alloc_l3_or_gm<int>(1);
r, XPUAPIErrorMsg[r])); PADDLE_ENFORCE_NOT_NULL(
non_zero,
platform::errors::External("XPU alloc_l3_or_gm returns nullptr"));
int r = xpu::nonzero_count(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(hit),
non_zero, input->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count");
int non_zero_cpu = 0;
memory::Copy(platform::CPUPlace(), static_cast<void*>(&non_zero_cpu),
context.GetPlace(), static_cast<void*>(non_zero),
sizeof(int));
r = xpu::scale(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(output->data<T>()),
reinterpret_cast<XPUType*>(output->data<T>()),
input->numel(), false,
1.0f / static_cast<float>(non_zero_cpu), 0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
} }
}; };
...@@ -81,16 +99,42 @@ class SigmoidCrossEntropyWithLogitsGradXPUKernel ...@@ -81,16 +99,42 @@ class SigmoidCrossEntropyWithLogitsGradXPUKernel
dx->mutable_data<T>(context.GetPlace()); dx->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
// attrs
int ignore_index = context.Attr<int>("ignore_index");
bool normalize = context.Attr<bool>("normalize");
// allocate temp memory
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* hit = RAII_GUARD.alloc_l3_or_gm<int>(input->numel());
PADDLE_ENFORCE_NOT_NULL(
hit, platform::errors::External("XPU alloc_l3_or_gm returns nullptr"));
int r = xpu::sigmoid_cross_entropy_with_logits_grad( int r = xpu::sigmoid_cross_entropy_with_logits_grad(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(input->data<T>()), dev_ctx.x_context(), reinterpret_cast<const XPUType*>(input->data<T>()),
reinterpret_cast<const XPUType*>(label->data<T>()), reinterpret_cast<const XPUType*>(label->data<T>()),
reinterpret_cast<const XPUType*>(dy->data<T>()), reinterpret_cast<const XPUType*>(dy->data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()), 1, input->numel()); reinterpret_cast<XPUType*>(dx->data<T>()), 1, input->numel(), hit,
PADDLE_ENFORCE_EQ( ignore_index);
r, XPU_SUCCESS, PADDLE_ENFORCE_XDNN_SUCCESS(r, "sigmoid_cross_entropy_with_logits");
platform::errors::External("XPU sigmoid_cross_entropy_with_logits_grad " if (normalize) {
"kernel return wrong value[%d %s]", int* non_zero = RAII_GUARD.alloc_l3_or_gm<int>(1);
r, XPUAPIErrorMsg[r])); PADDLE_ENFORCE_NOT_NULL(
non_zero,
platform::errors::External("XPU alloc_l3_or_gm returns nullptr"));
int r = xpu::nonzero_count(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(hit),
non_zero, input->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count");
int non_zero_cpu = 0;
memory::Copy(platform::CPUPlace(), static_cast<void*>(&non_zero_cpu),
context.GetPlace(), static_cast<void*>(non_zero),
sizeof(int));
r = xpu::scale(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(dx->data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()), input->numel(),
false, 1.0f / static_cast<float>(non_zero_cpu), 0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
} }
}; };
......
...@@ -292,6 +292,10 @@ XPUOpMap& get_kl2_ops() { ...@@ -292,6 +292,10 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT64, XPUPlace())})}, pOpKernelType(vartype::INT64, XPUPlace())})},
{"scatter", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), {"scatter", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"sigmoid_cross_entropy_with_logits_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sigmoid_cross_entropy_with_logits",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"shape", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})}, pOpKernelType(vartype::INT64, XPUPlace())})},
{"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sigmoid", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
...@@ -73,6 +73,39 @@ class TestSigmoidCrossEntropyWithLogitsOp1(XPUOpTest): ...@@ -73,6 +73,39 @@ class TestSigmoidCrossEntropyWithLogitsOp1(XPUOpTest):
self.dtype = np.float32 self.dtype = np.float32
class TestSigmoidCrossEntropyWithLogitsOp2(
TestSigmoidCrossEntropyWithLogitsOp1):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label
"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.set_xpu()
self.init_dtype()
batch_size = 64
num_classes = 20
ignore_index = -1
self.inputs = {
'X': logit(
np.random.uniform(0, 1, (batch_size, num_classes))
.astype(self.dtype)),
'Label': np.random.randint(-1, 2, (batch_size, num_classes))
.astype(self.dtype)
}
self.attrs = {'ignore_index': ignore_index, }
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
# Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
sigmoid_X = expit(self.inputs['X'])
term1 = self.inputs['Label'] * np.log(sigmoid_X)
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
out = -term1 - term2
out[np.where(self.inputs['Label'] == ignore_index)] = 0
self.outputs = {'Out': out}
class TestSigmoidCrossEntropyWithLogitsOp3( class TestSigmoidCrossEntropyWithLogitsOp3(
TestSigmoidCrossEntropyWithLogitsOp1): TestSigmoidCrossEntropyWithLogitsOp1):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label """Test sigmoid_cross_entropy_with_logit_op with probabalistic label
...@@ -102,6 +135,42 @@ class TestSigmoidCrossEntropyWithLogitsOp3( ...@@ -102,6 +135,42 @@ class TestSigmoidCrossEntropyWithLogitsOp3(
self.outputs = {'Out': -term1 - term2} self.outputs = {'Out': -term1 - term2}
class TestSigmoidCrossEntropyWithLogitsOp4(
TestSigmoidCrossEntropyWithLogitsOp1):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label
"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.set_xpu()
self.init_dtype()
batch_size = 64
num_classes = 20
ignore_index = -1
self.inputs = {
'X': logit(
np.random.uniform(0, 1, (batch_size, num_classes))
.astype(self.dtype)),
'Label': np.random.randint(-1, 2, (batch_size, num_classes))
.astype(self.dtype)
}
self.attrs = {'ignore_index': ignore_index, 'normalize': True}
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
# Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
sigmoid_X = expit(self.inputs['X'])
term1 = self.inputs['Label'] * np.log(sigmoid_X)
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
out = -term1 - term2
out[np.where(self.inputs['Label'] == ignore_index)] = 0
if self.attrs['normalize']:
out = out / float(
np.where(self.inputs['Label'] != ignore_index)[0].size)
self.outputs = {'Out': out}
class TestSigmoidCrossEntropyWithLogitsOp5( class TestSigmoidCrossEntropyWithLogitsOp5(
TestSigmoidCrossEntropyWithLogitsOp1): TestSigmoidCrossEntropyWithLogitsOp1):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label """Test sigmoid_cross_entropy_with_logit_op with probabalistic label
...@@ -131,6 +200,42 @@ class TestSigmoidCrossEntropyWithLogitsOp5( ...@@ -131,6 +200,42 @@ class TestSigmoidCrossEntropyWithLogitsOp5(
self.outputs = {'Out': -term1 - term2} self.outputs = {'Out': -term1 - term2}
class TestSigmoidCrossEntropyWithLogitsNorm(
TestSigmoidCrossEntropyWithLogitsOp1):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label
"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.set_xpu()
self.init_dtype()
batch_size = [10, 10]
num_classes = 20
ignore_index = -1
self.inputs = {
'X': logit(
np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
.astype(self.dtype)),
'Label': np.random.randint(-1, 2, tuple(batch_size + [num_classes]))
.astype(self.dtype)
}
self.attrs = {'ignore_index': ignore_index, 'normalize': True}
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
# Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
sigmoid_X = expit(self.inputs['X'])
term1 = self.inputs['Label'] * np.log(sigmoid_X)
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
out = -term1 - term2
out[np.where(self.inputs['Label'] == ignore_index)] = 0
if self.attrs['normalize']:
out = out / float(
np.where(self.inputs['Label'] != ignore_index)[0].size)
self.outputs = {'Out': out}
class TestSigmoidCrossEntropyWithLogitsOp6( class TestSigmoidCrossEntropyWithLogitsOp6(
TestSigmoidCrossEntropyWithLogitsOp1): TestSigmoidCrossEntropyWithLogitsOp1):
"""Test sigmoid_cross_entropy_with_logit_op with binary label """Test sigmoid_cross_entropy_with_logit_op with binary label
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册