From 669d86896db27415418e960ccc30bd5a3292bcd3 Mon Sep 17 00:00:00 2001 From: fwenguang <95677191+fwenguang@users.noreply.github.com> Date: Wed, 15 Jun 2022 19:08:01 +0800 Subject: [PATCH] [MLU] add bce kernel (#43435) --- paddle/fluid/operators/mlu/mlu_baseop.cc | 50 +++++ paddle/fluid/operators/mlu/mlu_baseop.h | 17 ++ .../operators/reduce_ops/reduce_sum_op_mlu.cc | 1 + ...igmoid_cross_entropy_with_logits_op_mlu.cc | 101 +++++++++ .../mlu/test_bce_with_logits_loss_mlu.py | 200 ++++++++++++++++++ ...igmoid_cross_entropy_with_logits_op_mlu.py | 172 +++++++++++++++ 6 files changed, 541 insertions(+) create mode 100644 paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_mlu.cc create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_bce_with_logits_loss_mlu.py create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_sigmoid_cross_entropy_with_logits_op_mlu.py diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index d5b843d47af..2f5c54086b0 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -2832,5 +2832,55 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { diff, workspace_ptr, workspace_size, output_desc, output)); } +/* static */ void MLUCnnl::BceWithLogits( + const ExecutionContext& ctx, cnnlBceWithLogitsReduction_t reduction, + const cnnlTensorDescriptor_t input_desc, const void* input, + const cnnlTensorDescriptor_t target_desc, const void* target, + const cnnlTensorDescriptor_t weight_desc, const void* weight, + const cnnlTensorDescriptor_t pos_weight_desc, const void* pos_weight, + const cnnlTensorDescriptor_t output_desc, void* output) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + size_t workspace_size; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetBceWithLogitsWorkspaceSize( + handle, input_desc, weight_desc, pos_weight_desc, &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + const cnnlComputationPreference_t prefer = CNNL_COMPUTATION_HIGH_PRECISION; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlBceWithLogits_v2( + handle, prefer, input_desc, input, target_desc, target, weight_desc, + weight, pos_weight_desc, pos_weight, reduction, workspace_ptr, + workspace_size, output_desc, output)); +} + +/* static */ void MLUCnnl::BceWithLogitsBackward( + const ExecutionContext& ctx, cnnlBceWithLogitsReduction_t reduction, + const cnnlTensorDescriptor_t grad_desc, const void* grad, + const cnnlTensorDescriptor_t input_desc, const void* input, + const cnnlTensorDescriptor_t target_desc, const void* target, + const cnnlTensorDescriptor_t weight_desc, const void* weight, + const cnnlTensorDescriptor_t pos_weight_desc, const void* pos_weight, + const cnnlTensorDescriptor_t diff_input_desc, void* diff_input) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + size_t workspace_size; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetBceWithLogitsBackwardWorkspaceSize( + handle, target_desc, weight_desc, pos_weight_desc, &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlBceWithLogitsBackward( + handle, grad_desc, grad, input_desc, input, target_desc, target, + weight_desc, weight, pos_weight_desc, pos_weight, reduction, + workspace_ptr, workspace_size, diff_input_desc, diff_input)); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 71648c5c5fb..21214c16268 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -1279,6 +1279,23 @@ class MLUCnnl { const cnnlTensorDescriptor_t indices_desc, const void* indices, const cnnlTensorDescriptor_t diff_desc, const void* diff, const cnnlTensorDescriptor_t output_desc, void* output); + + static void BceWithLogits( + const ExecutionContext& ctx, cnnlBceWithLogitsReduction_t reduction, + const cnnlTensorDescriptor_t input_desc, const void* input, + const cnnlTensorDescriptor_t target_desc, const void* target, + const cnnlTensorDescriptor_t weight_desc, const void* weight, + const cnnlTensorDescriptor_t pos_weight_desc, const void* pos_weight, + const cnnlTensorDescriptor_t output_desc, void* output); + + static void BceWithLogitsBackward( + const ExecutionContext& ctx, cnnlBceWithLogitsReduction_t reduction, + const cnnlTensorDescriptor_t grad_desc, const void* grad, + const cnnlTensorDescriptor_t input_desc, const void* input, + const cnnlTensorDescriptor_t target_desc, const void* target, + const cnnlTensorDescriptor_t weight_desc, const void* weight, + const cnnlTensorDescriptor_t pos_weight_desc, const void* pos_weight, + const cnnlTensorDescriptor_t diff_input_desc, void* diff_input); }; template diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op_mlu.cc index fab8bb23b16..e3b116ab7e2 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op_mlu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op_mlu.cc @@ -73,6 +73,7 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_MLU_KERNEL(reduce_sum, ops::ReduceSumMLUKernel, + ops::ReduceSumMLUKernel, ops::ReduceSumMLUKernel); REGISTER_OP_MLU_KERNEL(reduce_sum_grad, ops::ReduceSumGradMLUKernel, ops::ReduceSumGradMLUKernel); diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_mlu.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_mlu.cc new file mode 100644 index 00000000000..c6440cd1a29 --- /dev/null +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op_mlu.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +const int kIgnoreIndex = -100; + +void CheckAttrs(const framework::ExecutionContext& ctx) { + // cnnl not support normalize and ignore_index + bool normalize = ctx.Attr("normalize"); + int ignore_index = ctx.Attr("ignore_index"); + PADDLE_ENFORCE_EQ(normalize, false, + platform::errors::InvalidArgument( + "attr normalize must be false, but got true")); + PADDLE_ENFORCE_EQ(ignore_index, kIgnoreIndex, + platform::errors::InvalidArgument( + "attr ignore_index must be default %d, but got %d", + kIgnoreIndex, ignore_index)); +} + +template +class SigmoidCrossEntropyWithLogitsMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + CheckAttrs(ctx); + + auto* x = ctx.Input("X"); + auto* label = ctx.Input("Label"); + + auto* out = ctx.Output("Out"); + + auto place = ctx.GetPlace(); + + out->mutable_data(place); + + MLUCnnlTensorDesc x_desc(*x); + MLUCnnlTensorDesc label_desc(*label); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnl::BceWithLogits(ctx, CNNL_BCE_WITH_LOGITS_NONE, x_desc.get(), + GetBasePtr(x), label_desc.get(), GetBasePtr(label), + nullptr, nullptr, nullptr, nullptr, out_desc.get(), + GetBasePtr(out)); + } +}; + +template +class SigmoidCrossEntropyWithLogitsMLUGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + CheckAttrs(ctx); + + auto* x = ctx.Input("X"); + auto* label = ctx.Input("Label"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + + auto* dx = ctx.Output(framework::GradVarName("X")); + + auto place = ctx.GetPlace(); + + dx->mutable_data(place); + + MLUCnnlTensorDesc x_desc(*x); + MLUCnnlTensorDesc label_desc(*label); + MLUCnnlTensorDesc dout_desc(*dout); + MLUCnnl::BceWithLogitsBackward( + ctx, CNNL_BCE_WITH_LOGITS_NONE, dout_desc.get(), GetBasePtr(dout), + x_desc.get(), GetBasePtr(x), label_desc.get(), GetBasePtr(label), + nullptr, nullptr, nullptr, nullptr, x_desc.get(), GetBasePtr(dx)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_MLU_KERNEL( + sigmoid_cross_entropy_with_logits, + ops::SigmoidCrossEntropyWithLogitsMLUKernel, + ops::SigmoidCrossEntropyWithLogitsMLUKernel); +REGISTER_OP_MLU_KERNEL( + sigmoid_cross_entropy_with_logits_grad, + ops::SigmoidCrossEntropyWithLogitsMLUGradKernel, + ops::SigmoidCrossEntropyWithLogitsMLUGradKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_bce_with_logits_loss_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_bce_with_logits_loss_mlu.py new file mode 100644 index 00000000000..42989a5c44b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_bce_with_logits_loss_mlu.py @@ -0,0 +1,200 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import unittest +import sys + +sys.path.append('..') +from op_test import OpTest +from test_bce_with_logits_loss import call_bce_layer, call_bce_functional, test_dygraph, calc_bce_with_logits_loss + + +def test_static(place, + logit_np, + label_np, + weight_np=None, + reduction='mean', + pos_weight_np=None, + functional=False): + paddle.enable_static() + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): + logit = paddle.fluid.data(name='logit', + shape=logit_np.shape, + dtype='float32') + label = paddle.fluid.data(name='label', + shape=label_np.shape, + dtype='float32') + feed_dict = {"logit": logit_np, "label": label_np} + + pos_weight = None + weight = None + if pos_weight_np is not None: + pos_weight = paddle.fluid.data(name='pos_weight', + shape=pos_weight_np.shape, + dtype='float32') + feed_dict["pos_weight"] = pos_weight_np + if weight_np is not None: + weight = paddle.fluid.data(name='weight', + shape=weight_np.shape, + dtype='float32') + feed_dict["weight"] = weight_np + if functional: + res = call_bce_functional(logit, label, weight, reduction, + pos_weight) + else: + res = call_bce_layer(logit, label, weight, reduction, pos_weight) + exe = paddle.static.Executor(place) + static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) + return static_result + + +paddle.enable_static() + + +class TestBCEWithLogitsLoss(unittest.TestCase): + + def test_BCEWithLogitsLoss(self): + logit_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float32) + label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float32) + places = [fluid.MLUPlace(0)] + reductions = ['sum', 'mean', 'none'] + for place in places: + for reduction in reductions: + static_result = test_static(place, + logit_np, + label_np, + reduction=reduction) + dy_result = test_dygraph(place, + logit_np, + label_np, + reduction=reduction) + expected = calc_bce_with_logits_loss(logit_np, label_np, + reduction) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static(place, + logit_np, + label_np, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place, + logit_np, + label_np, + reduction=reduction, + functional=True) + + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + def test_BCEWithLogitsLoss_weight(self): + logit_np = np.random.uniform(0.1, 0.8, + size=(2, 3, 4, 10)).astype(np.float32) + label_np = np.random.randint(0, 2, + size=(2, 3, 4, 10)).astype(np.float32) + weight_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float32) + place = fluid.MLUPlace(0) + for reduction in ['sum', 'mean', 'none']: + static_result = test_static(place, + logit_np, + label_np, + weight_np=weight_np, + reduction=reduction) + dy_result = test_dygraph(place, + logit_np, + label_np, + weight_np=weight_np, + reduction=reduction) + expected = calc_bce_with_logits_loss(logit_np, + label_np, + reduction, + weight_np=weight_np) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static(place, + logit_np, + label_np, + weight_np=weight_np, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place, + logit_np, + label_np, + weight_np=weight_np, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + def test_BCEWithLogitsLoss_pos_weight(self): + logit_np = np.random.uniform(0.1, 0.8, + size=(2, 3, 4, 10)).astype(np.float32) + label_np = np.random.randint(0, 2, + size=(2, 3, 4, 10)).astype(np.float32) + pos_weight_np = np.random.random(size=(3, 4, 10)).astype(np.float32) + weight_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float32) + place = fluid.MLUPlace(0) + reduction = "mean" + static_result = test_static(place, logit_np, label_np, weight_np, + reduction, pos_weight_np) + dy_result = test_dygraph(place, logit_np, label_np, weight_np, + reduction, pos_weight_np) + expected = calc_bce_with_logits_loss(logit_np, label_np, reduction, + weight_np, pos_weight_np) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static(place, + logit_np, + label_np, + weight_np, + reduction, + pos_weight_np, + functional=True) + dy_functional = test_dygraph(place, + logit_np, + label_np, + weight_np, + reduction, + pos_weight_np, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + def test_BCEWithLogitsLoss_error(self): + paddle.disable_static() + self.assertRaises(ValueError, + paddle.nn.BCEWithLogitsLoss, + reduction="unsupport reduction") + logit = paddle.to_tensor([[0.1, 0.3]], dtype='float32') + label = paddle.to_tensor([[0.0, 1.0]], dtype='float32') + self.assertRaises(ValueError, + paddle.nn.functional.binary_cross_entropy_with_logits, + logit=logit, + label=label, + reduction="unsupport reduction") + paddle.enable_static() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_sigmoid_cross_entropy_with_logits_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_sigmoid_cross_entropy_with_logits_op_mlu.py new file mode 100644 index 00000000000..738a810b89e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_sigmoid_cross_entropy_with_logits_op_mlu.py @@ -0,0 +1,172 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import sys + +sys.path.append('..') +from op_test import OpTest +from scipy.special import logit +from scipy.special import expit +import paddle.fluid.core as core +import unittest +from paddle.fluid import compiler, Program, program_guard +import paddle.fluid as fluid +import paddle + +paddle.enable_static() + + +class TestSigmoidCrossEntropyWithLogitsOp1(OpTest): + """Test sigmoid_cross_entropy_with_logit_op with binary label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_mlu() + self.init_dtype() + + batch_size = 64 + num_classes = 20 + self.inputs = { + 'X': + logit( + np.random.uniform(0, 1, (batch_size, num_classes)).astype( + self.dtype)), + 'Label': + np.random.randint(0, 2, + (batch_size, num_classes)).astype(self.dtype) + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-5) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + +class TestSigmoidCrossEntropyWithLogitsOp3(TestSigmoidCrossEntropyWithLogitsOp1 + ): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_mlu() + self.init_dtype() + + batch_size = 64 + num_classes = 20 + self.inputs = { + 'X': + logit( + np.random.uniform(0, 1, (batch_size, num_classes)).astype( + self.dtype)), + 'Label': + np.random.uniform(0, 1, + (batch_size, num_classes)).astype(self.dtype) + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + +class TestSigmoidCrossEntropyWithLogitsOp5(TestSigmoidCrossEntropyWithLogitsOp1 + ): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_mlu() + self.init_dtype() + + batch_size = [10, 10] + num_classes = 20 + self.inputs = { + 'X': + logit( + np.random.uniform(0, 1, + tuple(batch_size + [num_classes])).astype( + self.dtype)), + 'Label': + np.random.uniform(0, 1, tuple(batch_size + [num_classes])).astype( + self.dtype) + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + +class TestSigmoidCrossEntropyWithLogitsOp6(TestSigmoidCrossEntropyWithLogitsOp1 + ): + """Test sigmoid_cross_entropy_with_logit_op with binary label + """ + + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + self.set_mlu() + self.init_dtype() + + batch_size = [10, 10] + num_classes = 20 + self.inputs = { + 'X': + logit( + np.random.uniform(0, 1, + tuple(batch_size + [num_classes])).astype( + self.dtype)), + 'Label': + np.random.randint(0, 2, tuple(batch_size + [num_classes])).astype( + self.dtype) + } + + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + self.outputs = {'Out': -term1 - term2} + + +if __name__ == '__main__': + unittest.main() -- GitLab