From 1dfa2d49cbae1ecf4028e6e3f6710daa9d209afb Mon Sep 17 00:00:00 2001
From: fwenguang <95677191+fwenguang@users.noreply.github.com>
Date: Wed, 15 Jun 2022 19:08:22 +0800
Subject: [PATCH] [MLU] add bce kernel for mlu (#43467)

---
 paddle/fluid/operators/bce_loss_op_mlu.cc     |  73 +++++
 paddle/fluid/operators/mlu/mlu_baseop.cc      |  46 +++
 paddle/fluid/operators/mlu/mlu_baseop.h       |  15 +
 .../tests/unittests/mlu/test_bce_loss_mlu.py  | 283 ++++++++++++++++++
 4 files changed, 417 insertions(+)
 create mode 100644 paddle/fluid/operators/bce_loss_op_mlu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_bce_loss_mlu.py

diff --git a/paddle/fluid/operators/bce_loss_op_mlu.cc b/paddle/fluid/operators/bce_loss_op_mlu.cc
new file mode 100644
index 00000000000..f32ad69ba96
--- /dev/null
+++ b/paddle/fluid/operators/bce_loss_op_mlu.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class BCELossMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* labels = ctx.Input<Tensor>("Label");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    out->mutable_data<T>(ctx.GetPlace());
+
+    MLUCnnlTensorDesc x_desc(*x);
+    MLUCnnlTensorDesc label_desc(*labels);
+    MLUCnnlTensorDesc out_desc(*out);
+    MLUCnnl::BceLoss(ctx, CNNL_BCE_LOSS_NONE, x_desc.get(), GetBasePtr(x),
+                     label_desc.get(), GetBasePtr(labels), nullptr, nullptr,
+                     out_desc.get(), GetBasePtr(out));
+  }
+};
+
+template <typename T>
+class BCELossGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* labels = ctx.Input<Tensor>("Label");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    dx->mutable_data<T>(ctx.GetPlace());
+
+    MLUCnnlTensorDesc x_desc(*x);
+    MLUCnnlTensorDesc label_desc(*labels);
+    MLUCnnlTensorDesc dout_desc(*dout);
+    MLUCnnl::BceLossBackward(ctx, CNNL_BCE_LOSS_NONE, dout_desc.get(),
+                             GetBasePtr(dout), x_desc.get(), GetBasePtr(x),
+                             label_desc.get(), GetBasePtr(labels), nullptr,
+                             nullptr, x_desc.get(), GetBasePtr(dx));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(bce_loss, ops::BCELossMLUKernel<float>,
+                       ops::BCELossMLUKernel<plat::float16>);
+
+REGISTER_OP_MLU_KERNEL(bce_loss_grad, ops::BCELossGradMLUKernel<float>,
+                       ops::BCELossGradMLUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc
index 2f5c54086b0..8414a7921de 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.cc
+++ b/paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -2799,6 +2799,52 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
       cnnlReciprocal(handle, input_desc, input, output_desc, output));
 }
 
+/* static */ void MLUCnnl::BceLoss(
+    const ExecutionContext& ctx, const cnnlBceLossReduction_t reduction,
+    const cnnlTensorDescriptor_t input_desc, const void* input,
+    const cnnlTensorDescriptor_t target_desc, const void* target,
+    const cnnlTensorDescriptor_t weight_desc, const void* weight,
+    const cnnlTensorDescriptor_t output_desc, void* output) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  size_t workspace_size;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetBceLossWorkspaceSize(
+      handle, input_desc, weight_desc, &workspace_size));
+
+  auto& dev_ctx = GetDevCtxFromCTX(ctx);
+  Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
+      {static_cast<int64_t>(workspace_size)}, dev_ctx);
+  void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlBceLoss(
+      handle, input_desc, input, target_desc, target, weight_desc, weight,
+      reduction, workspace_ptr, workspace_size, output_desc, output));
+}
+
+/* static */ void MLUCnnl::BceLossBackward(
+    const ExecutionContext& ctx, const cnnlBceLossReduction_t reduction,
+    const cnnlTensorDescriptor_t grad_desc, const void* grad,
+    const cnnlTensorDescriptor_t input_desc, const void* input,
+    const cnnlTensorDescriptor_t target_desc, const void* target,
+    const cnnlTensorDescriptor_t weight_desc, const void* weight,
+    const cnnlTensorDescriptor_t output_desc, void* output) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  size_t workspace_size;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetBceLossBackwardWorkspaceSize(
+      handle, target_desc, weight_desc, &workspace_size));
+
+  auto& dev_ctx = GetDevCtxFromCTX(ctx);
+  Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
+      {static_cast<int64_t>(workspace_size)}, dev_ctx);
+  void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
+
+  PADDLE_ENFORCE_MLU_SUCCESS(
+      cnnlBceLossBackward(handle, grad_desc, grad, input_desc, input,
+                          target_desc, target, weight_desc, weight, reduction,
+                          workspace_ptr, workspace_size, output_desc, output));
+}
+
 /* static */ void MLUCnnl::EmbeddingForward(
     const ExecutionContext& ctx, const int padding_idx,
     const cnnlTensorDescriptor_t weight_desc, const void* weight,
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index 21214c16268..6c5f716625c 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -1268,6 +1268,21 @@ class MLUCnnl {
                       const cnnlTensorDescriptor_t output_desc,
                       void* output);
 
+  static void BceLoss(
+      const ExecutionContext& ctx, const cnnlBceLossReduction_t reduction,
+      const cnnlTensorDescriptor_t input_desc, const void* input,
+      const cnnlTensorDescriptor_t target_desc, const void* target,
+      const cnnlTensorDescriptor_t weight_desc, const void* weight,
+      const cnnlTensorDescriptor_t output_desc, void* output);
+
+  static void BceLossBackward(
+      const ExecutionContext& ctx, const cnnlBceLossReduction_t reduction,
+      const cnnlTensorDescriptor_t grad_desc, const void* grad,
+      const cnnlTensorDescriptor_t input_desc, const void* input,
+      const cnnlTensorDescriptor_t target_desc, const void* target,
+      const cnnlTensorDescriptor_t weight_desc, const void* weight,
+      const cnnlTensorDescriptor_t output_desc, void* output);
+
   static void EmbeddingForward(
       const ExecutionContext& ctx, const int padding_idx,
       const cnnlTensorDescriptor_t weight_desc, const void* weight,
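For reference, a NumPy sketch (not part of the patch) of the elementwise math
that cnnlBceLoss and cnnlBceLossBackward are expected to perform under
CNNL_BCE_LOSS_NONE with a null weight descriptor. The backward formula is the
standard BCE derivative; treating it as what the CNNL kernel computes is an
assumption, not something stated in this patch:

    import numpy as np

    def bce_forward(x, y):
        # Elementwise BCE, no reduction, no weight:
        # loss = -(y * log(x) + (1 - y) * log(1 - x))
        return -(y * np.log(x) + (1.0 - y) * np.log(1.0 - x))

    def bce_backward(x, y, dout):
        # Assumed standard BCE gradient w.r.t. x, scaled by the upstream
        # gradient dout: dL/dx = (x - y) / (x * (1 - x))
        return dout * (x - y) / (x * (1.0 - x))

This matches the reference implementations bce_loss and calc_bceloss in the
test file below.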
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_bce_loss_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_bce_loss_mlu.py
new file mode 100644
index 00000000000..78dd988aa7e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_bce_loss_mlu.py
@@ -0,0 +1,283 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+import numpy as np
+import unittest
+import sys
+
+sys.path.append('..')
+from op_test import OpTest
+
+paddle.enable_static()
+
+
+def test_static_layer(place,
+                      input_np,
+                      label_np,
+                      reduction='mean',
+                      weight_np=None):
+    prog = paddle.static.Program()
+    startup_prog = paddle.static.Program()
+    with paddle.static.program_guard(prog, startup_prog):
+        input = paddle.fluid.data(name='input',
+                                  shape=input_np.shape,
+                                  dtype='float32')
+        label = paddle.fluid.data(name='label',
+                                  shape=label_np.shape,
+                                  dtype='float32')
+        if weight_np is not None:
+            weight = paddle.fluid.data(name='weight',
+                                       shape=weight_np.shape,
+                                       dtype='float32')
+            bce_loss = paddle.nn.loss.BCELoss(weight=weight,
+                                              reduction=reduction)
+        else:
+            bce_loss = paddle.nn.loss.BCELoss(reduction=reduction)
+        res = bce_loss(input, label)
+        exe = paddle.static.Executor(place)
+        static_result = exe.run(prog,
+                                feed={
+                                    "input": input_np,
+                                    "label": label_np
+                                } if weight_np is None else {
+                                    "input": input_np,
+                                    "label": label_np,
+                                    "weight": weight_np
+                                },
+                                fetch_list=[res])
+    return static_result
+
+
+def test_static_functional(place,
+                           input_np,
+                           label_np,
+                           reduction='mean',
+                           weight_np=None):
+    prog = paddle.static.Program()
+    startup_prog = paddle.static.Program()
+    with paddle.static.program_guard(prog, startup_prog):
+        input = paddle.fluid.data(name='input',
+                                  shape=input_np.shape,
+                                  dtype='float32')
+        label = paddle.fluid.data(name='label',
+                                  shape=label_np.shape,
+                                  dtype='float32')
+        if weight_np is not None:
+            weight = paddle.fluid.data(name='weight',
+                                       shape=weight_np.shape,
+                                       dtype='float32')
+            res = paddle.nn.functional.binary_cross_entropy(
+                input, label, weight=weight, reduction=reduction)
+        else:
+            res = paddle.nn.functional.binary_cross_entropy(
+                input, label, reduction=reduction)
+        exe = paddle.static.Executor(place)
+        static_result = exe.run(prog,
+                                feed={
+                                    "input": input_np,
+                                    "label": label_np
+                                } if weight_np is None else {
+                                    "input": input_np,
+                                    "label": label_np,
+                                    "weight": weight_np
+                                },
+                                fetch_list=[res])
+    return static_result
+
+
+def test_dygraph_layer(place,
+                       input_np,
+                       label_np,
+                       reduction='mean',
+                       weight_np=None):
+    paddle.disable_static()
+    if weight_np is not None:
+        weight = paddle.to_tensor(weight_np)
+        bce_loss = paddle.nn.loss.BCELoss(weight=weight, reduction=reduction)
+    else:
+        bce_loss = paddle.nn.loss.BCELoss(reduction=reduction)
+    dy_res = bce_loss(paddle.to_tensor(input_np), paddle.to_tensor(label_np))
+    dy_result = dy_res.numpy()
+    paddle.enable_static()
+    return dy_result
+
+
+def test_dygraph_functional(place,
+                            input_np,
+                            label_np,
+                            reduction='mean',
+                            weight_np=None):
+    paddle.disable_static()
+    input = paddle.to_tensor(input_np)
+    label = paddle.to_tensor(label_np)
+
+    if weight_np is not None:
+        weight = paddle.to_tensor(weight_np)
+        dy_res = paddle.nn.functional.binary_cross_entropy(
+            input, label, weight=weight, reduction=reduction)
+    else:
+        dy_res = paddle.nn.functional.binary_cross_entropy(
+            input, label, reduction=reduction)
+    dy_result = dy_res.numpy()
+    paddle.enable_static()
+    return dy_result
+
+
+def calc_bceloss(input_np, label_np, reduction='mean', weight_np=None):
+    if weight_np is None:
+        expected = -1 * (label_np * np.log(input_np) +
+                         (1. - label_np) * np.log(1. - input_np))
+    else:
+        expected = -1 * weight_np * (label_np * np.log(input_np) +
+                                     (1. - label_np) * np.log(1. - input_np))
+
+    if reduction == 'mean':
+        expected = np.mean(expected)
+    elif reduction == 'sum':
+        expected = np.sum(expected)
+    else:
+        expected = expected
+
+    return expected
+
+
+class TestBCELoss(unittest.TestCase):
+
+    def test_BCELoss(self):
+        input_np = np.random.uniform(0.1, 0.8,
+                                     size=(20, 30)).astype(np.float32)
+        label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float32)
+        places = [fluid.MLUPlace(0)]
+        reductions = ['sum', 'mean', 'none']
+        for place in places:
+            for reduction in reductions:
+                static_result = test_static_layer(place, input_np, label_np,
+                                                  reduction)
+                dy_result = test_dygraph_layer(place, input_np, label_np,
+                                               reduction)
+                expected = calc_bceloss(input_np, label_np, reduction)
+                self.assertTrue(np.allclose(static_result, expected))
+                self.assertTrue(np.allclose(static_result, dy_result))
+                self.assertTrue(np.allclose(dy_result, expected))
+                static_functional = test_static_functional(
+                    place, input_np, label_np, reduction)
+                dy_functional = test_dygraph_functional(place, input_np,
+                                                        label_np, reduction)
+                self.assertTrue(np.allclose(static_functional, expected))
+                self.assertTrue(np.allclose(static_functional, dy_functional))
+                self.assertTrue(np.allclose(dy_functional, expected))
+
+    def test_BCELoss_weight(self):
+        input_np = np.random.uniform(0.1, 0.8,
+                                     size=(2, 3, 4, 10)).astype(np.float32)
+        label_np = np.random.randint(0, 2,
+                                     size=(2, 3, 4, 10)).astype(np.float32)
+        weight_np = np.random.random(size=(3, 4, 10)).astype(np.float32)
+        place = fluid.MLUPlace(0)
+        for reduction in ['sum', 'mean', 'none']:
+            static_result = test_static_layer(place,
+                                              input_np,
+                                              label_np,
+                                              reduction,
+                                              weight_np=weight_np)
+            dy_result = test_dygraph_layer(place,
+                                           input_np,
+                                           label_np,
+                                           reduction,
+                                           weight_np=weight_np)
+            expected = calc_bceloss(input_np,
+                                    label_np,
+                                    reduction,
+                                    weight_np=weight_np)
+            self.assertTrue(np.allclose(static_result, expected))
+            self.assertTrue(np.allclose(static_result, dy_result))
+            self.assertTrue(np.allclose(dy_result, expected))
+            static_functional = test_static_functional(place,
+                                                       input_np,
+                                                       label_np,
+                                                       reduction,
+                                                       weight_np=weight_np)
+            dy_functional = test_dygraph_functional(place,
+                                                    input_np,
+                                                    label_np,
+                                                    reduction,
+                                                    weight_np=weight_np)
+            self.assertTrue(np.allclose(static_functional, expected))
+            self.assertTrue(np.allclose(static_functional, dy_functional))
+            self.assertTrue(np.allclose(dy_functional, expected))
+
+    def test_BCELoss_error(self):
+        paddle.disable_static()
+        self.assertRaises(ValueError,
+                          paddle.nn.loss.BCELoss,
+                          reduction="unsupport reduction")
+        input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
+        label = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
+        self.assertRaises(ValueError,
+                          paddle.nn.functional.binary_cross_entropy,
+                          input=input,
+                          label=label,
+                          reduction="unsupport reduction")
+        paddle.enable_static()
+
+
+def bce_loss(input, label):
+    return -1 * (label * np.log(input) +
+                 (1. - label) * np.log(1. - input))
+
+
+class TestBceLossOp(OpTest):
+
+    def setUp(self):
+        self.init_test_case()
+        self.op_type = "bce_loss"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float32")
+        label_np = np.random.randint(0, 2, self.shape).astype("float32")
+        output_np = bce_loss(input_np, label_np)
+
+        self.inputs = {'X': input_np, 'Label': label_np}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+
+    def init_test_case(self):
+        self.shape = [10, 10]
+
+
+class TestBceLossOpCase1(TestBceLossOp):
+
+    def init_test_case(self):
+        self.shape = [2, 3, 4, 5]
+
+
+class TestBceLossOpCase2(TestBceLossOp):
+
+    def init_test_case(self):
+        self.shape = [2, 3, 20]
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
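For a quick end-to-end check of the new kernel outside the test suite, a
minimal sketch follows. It assumes a Paddle build with MLU support and at
least one MLU card; the paddle.set_device('mlu:0') call is an assumption
about the MLU device string and may differ by release:

    import numpy as np
    import paddle

    paddle.set_device('mlu:0')  # assumed MLU device string
    x = paddle.to_tensor(
        np.random.uniform(0.1, 0.8, (4, 5)).astype('float32'))
    y = paddle.to_tensor(
        np.random.randint(0, 2, (4, 5)).astype('float32'))
    # reduction='none' maps onto the elementwise bce_loss op added above;
    # 'mean'/'sum' reduce the same elementwise result at the Python layer.
    loss = paddle.nn.functional.binary_cross_entropy(x, y, reduction='none')
    print(loss.numpy())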