Unverified commit 1dfa2d49, authored by F fwenguang, committed by GitHub

[MLU] add bce kernel for mlu (#43467)

Parent 669d8689
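For reference, the element-wise loss this kernel computes (it matches the numpy reference used by the test file in this commit) and its standard derivative, which the backward kernel delegates to cnnlBceLossBackward, are:

\[
\text{out}_i = -\bigl(y_i \log x_i + (1 - y_i)\log(1 - x_i)\bigr), \qquad
\frac{\partial\,\text{out}_i}{\partial x_i} = \frac{x_i - y_i}{x_i\,(1 - x_i)}
\]

with \(x_i\) the predicted probability and \(y_i \in \{0, 1\}\) the label; the derivative is stated here for orientation and is not spelled out in the diff.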
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class BCELossMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* labels = ctx.Input<Tensor>("Label");
    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    MLUCnnlTensorDesc x_desc(*x);
    MLUCnnlTensorDesc label_desc(*labels);
    MLUCnnlTensorDesc out_desc(*out);

    MLUCnnl::BceLoss(ctx, CNNL_BCE_LOSS_NONE, x_desc.get(), GetBasePtr(x),
                     label_desc.get(), GetBasePtr(labels), nullptr, nullptr,
                     out_desc.get(), GetBasePtr(out));
  }
};

template <typename T>
class BCELossGradMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* labels = ctx.Input<Tensor>("Label");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    dx->mutable_data<T>(ctx.GetPlace());

    MLUCnnlTensorDesc x_desc(*x);
    MLUCnnlTensorDesc label_desc(*labels);
    MLUCnnlTensorDesc dout_desc(*dout);

    MLUCnnl::BceLossBackward(ctx, CNNL_BCE_LOSS_NONE, dout_desc.get(),
                             GetBasePtr(dout), x_desc.get(), GetBasePtr(x),
                             label_desc.get(), GetBasePtr(labels), nullptr,
                             nullptr, x_desc.get(), GetBasePtr(dx));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_MLU_KERNEL(bce_loss, ops::BCELossMLUKernel<float>,
                       ops::BCELossMLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(bce_loss_grad, ops::BCELossGradMLUKernel<float>,
                       ops::BCELossGradMLUKernel<plat::float16>);
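A minimal sketch of reaching this kernel from Python through the static-graph API, mirroring the path the test file below exercises; it assumes a Paddle build with MLU support, and the tensor names are illustrative. Note that the bce_loss op itself is element-wise (the kernel passes CNNL_BCE_LOSS_NONE); 'mean' and 'sum' reductions are applied on top in the Python layer.

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.fluid.data(name='x', shape=[4, 8], dtype='float32')  # probabilities in (0, 1)
    y = paddle.fluid.data(name='y', shape=[4, 8], dtype='float32')  # 0/1 targets
    # reduction='none' keeps the element-wise bce_loss output as-is.
    loss = paddle.nn.functional.binary_cross_entropy(x, y, reduction='none')

exe = paddle.static.Executor(fluid.MLUPlace(0))
x_np = np.random.uniform(0.1, 0.8, (4, 8)).astype('float32')
y_np = np.random.randint(0, 2, (4, 8)).astype('float32')
out, = exe.run(main_prog, feed={'x': x_np, 'y': y_np}, fetch_list=[loss])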
@@ -2799,6 +2799,52 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
      cnnlReciprocal(handle, input_desc, input, output_desc, output));
}

/* static */ void MLUCnnl::BceLoss(
    const ExecutionContext& ctx, const cnnlBceLossReduction_t reduction,
    const cnnlTensorDescriptor_t input_desc, const void* input,
    const cnnlTensorDescriptor_t target_desc, const void* target,
    const cnnlTensorDescriptor_t weight_desc, const void* weight,
    const cnnlTensorDescriptor_t output_desc, void* output) {
  cnnlHandle_t handle = GetHandleFromCTX(ctx);

  size_t workspace_size;
  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetBceLossWorkspaceSize(
      handle, input_desc, weight_desc, &workspace_size));

  auto& dev_ctx = GetDevCtxFromCTX(ctx);
  Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
      {static_cast<int64_t>(workspace_size)}, dev_ctx);
  void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());

  PADDLE_ENFORCE_MLU_SUCCESS(cnnlBceLoss(
      handle, input_desc, input, target_desc, target, weight_desc, weight,
      reduction, workspace_ptr, workspace_size, output_desc, output));
}

/* static */ void MLUCnnl::BceLossBackward(
    const ExecutionContext& ctx, const cnnlBceLossReduction_t reduction,
    const cnnlTensorDescriptor_t grad_desc, const void* grad,
    const cnnlTensorDescriptor_t input_desc, const void* input,
    const cnnlTensorDescriptor_t target_desc, const void* target,
    const cnnlTensorDescriptor_t weight_desc, const void* weight,
    const cnnlTensorDescriptor_t output_desc, void* output) {
  cnnlHandle_t handle = GetHandleFromCTX(ctx);

  size_t workspace_size;
  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetBceLossBackwardWorkspaceSize(
      handle, target_desc, weight_desc, &workspace_size));

  auto& dev_ctx = GetDevCtxFromCTX(ctx);
  Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
      {static_cast<int64_t>(workspace_size)}, dev_ctx);
  void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());

  PADDLE_ENFORCE_MLU_SUCCESS(
      cnnlBceLossBackward(handle, grad_desc, grad, input_desc, input,
                          target_desc, target, weight_desc, weight, reduction,
                          workspace_ptr, workspace_size, output_desc, output));
}

/* static */ void MLUCnnl::EmbeddingForward(
    const ExecutionContext& ctx, const int padding_idx,
    const cnnlTensorDescriptor_t weight_desc, const void* weight,
......
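As a quick, self-contained sanity check of the formulas these two wrappers implement (illustrative only, not part of the diff), the analytic gradient can be verified against a central finite difference:

import numpy as np

# Element-wise BCE and its analytic gradient (standard derivative).
def bce(x, y):
    return -(y * np.log(x) + (1.0 - y) * np.log(1.0 - x))

def bce_grad(x, y):
    return (x - y) / (x * (1.0 - x))

x = np.random.uniform(0.1, 0.8, (5,))
y = np.random.randint(0, 2, (5,)).astype(np.float64)
eps = 1e-6
# bce is element-wise, so perturbing all elements at once still
# yields per-element derivatives.
numeric = (bce(x + eps, y) - bce(x - eps, y)) / (2.0 * eps)
assert np.allclose(bce_grad(x, y), numeric, atol=1e-5)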
@@ -1268,6 +1268,21 @@ class MLUCnnl {
                      const cnnlTensorDescriptor_t output_desc,
                      void* output);

  static void BceLoss(
      const ExecutionContext& ctx, const cnnlBceLossReduction_t reduction,
      const cnnlTensorDescriptor_t input_desc, const void* input,
      const cnnlTensorDescriptor_t target_desc, const void* target,
      const cnnlTensorDescriptor_t weight_desc, const void* weight,
      const cnnlTensorDescriptor_t output_desc, void* output);

  static void BceLossBackward(
      const ExecutionContext& ctx, const cnnlBceLossReduction_t reduction,
      const cnnlTensorDescriptor_t grad_desc, const void* grad,
      const cnnlTensorDescriptor_t input_desc, const void* input,
      const cnnlTensorDescriptor_t target_desc, const void* target,
      const cnnlTensorDescriptor_t weight_desc, const void* weight,
      const cnnlTensorDescriptor_t output_desc, void* output);

  static void EmbeddingForward(
      const ExecutionContext& ctx, const int padding_idx,
      const cnnlTensorDescriptor_t weight_desc, const void* weight,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy as np
import unittest
import sys
sys.path.append('..')
from op_test import OpTest
paddle.enable_static()
def test_static_layer(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.fluid.data(name='input',
shape=input_np.shape,
dtype='float32')
label = paddle.fluid.data(name='label',
shape=label_np.shape,
dtype='float32')
if weight_np is not None:
weight = paddle.fluid.data(name='weight',
shape=weight_np.shape,
dtype='float32')
bce_loss = paddle.nn.loss.BCELoss(weight=weight,
reduction=reduction)
else:
bce_loss = paddle.nn.loss.BCELoss(reduction=reduction)
res = bce_loss(input, label)
exe = paddle.static.Executor(place)
static_result = exe.run(prog,
feed={
"input": input_np,
"label": label_np
} if weight_np is None else {
"input": input_np,
"label": label_np,
"weight": weight_np
},
fetch_list=[res])
return static_result
def test_static_functional(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.fluid.data(name='input',
shape=input_np.shape,
dtype='float32')
label = paddle.fluid.data(name='label',
shape=label_np.shape,
dtype='float32')
if weight_np is not None:
weight = paddle.fluid.data(name='weight',
shape=weight_np.shape,
dtype='float32')
res = paddle.nn.functional.binary_cross_entropy(input,
label,
weight=weight,
reduction=reduction)
else:
res = paddle.nn.functional.binary_cross_entropy(input,
label,
reduction=reduction)
exe = paddle.static.Executor(place)
static_result = exe.run(prog,
feed={
"input": input_np,
"label": label_np
} if weight_np is None else {
"input": input_np,
"label": label_np,
"weight": weight_np
},
fetch_list=[res])
return static_result
def test_dygraph_layer(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
paddle.disable_static()
if weight_np is not None:
weight = paddle.to_tensor(weight_np)
bce_loss = paddle.nn.loss.BCELoss(weight=weight, reduction=reduction)
else:
bce_loss = paddle.nn.loss.BCELoss(reduction=reduction)
dy_res = bce_loss(paddle.to_tensor(input_np), paddle.to_tensor(label_np))
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def test_dygraph_functional(place,
input_np,
label_np,
reduction='mean',
weight_np=None):
paddle.disable_static()
input = paddle.to_tensor(input_np)
label = paddle.to_tensor(label_np)
if weight_np is not None:
weight = paddle.to_tensor(weight_np)
dy_res = paddle.nn.functional.binary_cross_entropy(input,
label,
weight=weight,
reduction=reduction)
else:
dy_res = paddle.nn.functional.binary_cross_entropy(input,
label,
reduction=reduction)
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result
def calc_bceloss(input_np, label_np, reduction='mean', weight_np=None):
if weight_np is None:
expected = -1 * (label_np * np.log(input_np) +
(1. - label_np) * np.log(1. - input_np))
else:
expected = -1 * weight_np * (label_np * np.log(input_np) +
(1. - label_np) * np.log(1. - input_np))
    if reduction == 'mean':
        expected = np.mean(expected)
    elif reduction == 'sum':
        expected = np.sum(expected)
    return expected
class TestBCELoss(unittest.TestCase):
def test_BCELoss(self):
input_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float32)
label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float32)
places = [fluid.MLUPlace(0)]
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
static_result = test_static_layer(place, input_np, label_np,
reduction)
dy_result = test_dygraph_layer(place, input_np, label_np,
reduction)
expected = calc_bceloss(input_np, label_np, reduction)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static_functional(
place, input_np, label_np, reduction)
dy_functional = test_dygraph_functional(place, input_np,
label_np, reduction)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCELoss_weight(self):
input_np = np.random.uniform(0.1, 0.8,
size=(2, 3, 4, 10)).astype(np.float32)
label_np = np.random.randint(0, 2,
size=(2, 3, 4, 10)).astype(np.float32)
weight_np = np.random.random(size=(3, 4, 10)).astype(np.float32)
place = fluid.MLUPlace(0)
for reduction in ['sum', 'mean', 'none']:
static_result = test_static_layer(place,
input_np,
label_np,
reduction,
weight_np=weight_np)
dy_result = test_dygraph_layer(place,
input_np,
label_np,
reduction,
weight_np=weight_np)
expected = calc_bceloss(input_np,
label_np,
reduction,
weight_np=weight_np)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static_functional(place,
input_np,
label_np,
reduction,
weight_np=weight_np)
dy_functional = test_dygraph_functional(place,
input_np,
label_np,
reduction,
weight_np=weight_np)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCELoss_error(self):
paddle.disable_static()
self.assertRaises(ValueError,
paddle.nn.loss.BCELoss,
reduction="unsupport reduction")
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
label = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
self.assertRaises(ValueError,
paddle.nn.functional.binary_cross_entropy,
input=input,
label=label,
reduction="unsupport reduction")
paddle.enable_static()
def bce_loss(input, label):
return -1 * (label * np.log(input) + (1. - label) * np.log(1. - input))
class TestBceLossOp(OpTest):
def setUp(self):
self.init_test_case()
self.op_type = "bce_loss"
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float32")
label_np = np.random.randint(0, 2, self.shape).astype("float32")
output_np = bce_loss(input_np, label_np)
self.inputs = {'X': input_np, 'Label': label_np}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def init_test_case(self):
self.shape = [10, 10]
class TestBceLossOpCase1(TestBceLossOp):
def init_test_case(self):
self.shape = [2, 3, 4, 5]
class TestBceLossOpCase2(TestBceLossOp):
def init_test_case(self):
self.shape = [2, 3, 20]
if __name__ == "__main__":
unittest.main()
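A note on coverage: test_check_output compares the MLU forward kernel against the numpy reference above, while test_check_grad validates BceLossBackward against a numeric gradient that OpTest computes by finite differences, conceptually along these lines (a simplified sketch; the real logic lives in the shared op_test.py helper, and the delta value here is only indicative):

import numpy as np

def numeric_grad(f, x, delta=0.005):
    # Perturb one element at a time and take a central difference of the
    # summed output; this recovers d(sum f)/dx element by element.
    grad = np.zeros_like(x)
    flat_x = x.reshape(-1)
    flat_g = grad.reshape(-1)
    for i in range(flat_x.size):
        orig = flat_x[i]
        flat_x[i] = orig + delta
        pos = f(x).sum()
        flat_x[i] = orig - delta
        neg = f(x).sum()
        flat_x[i] = orig
        flat_g[i] = (pos - neg) / (2.0 * delta)
    return grad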