未验证 提交 669d8689 编写于 作者: F fwenguang 提交者: GitHub

[MLU] add bce kernel (#43435)

上级 13cf4ced
...@@ -2832,5 +2832,55 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { ...@@ -2832,5 +2832,55 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
diff, workspace_ptr, workspace_size, output_desc, output)); diff, workspace_ptr, workspace_size, output_desc, output));
} }
/* static */ void MLUCnnl::BceWithLogits(
const ExecutionContext& ctx, cnnlBceWithLogitsReduction_t reduction,
const cnnlTensorDescriptor_t input_desc, const void* input,
const cnnlTensorDescriptor_t target_desc, const void* target,
const cnnlTensorDescriptor_t weight_desc, const void* weight,
const cnnlTensorDescriptor_t pos_weight_desc, const void* pos_weight,
const cnnlTensorDescriptor_t output_desc, void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetBceWithLogitsWorkspaceSize(
handle, input_desc, weight_desc, pos_weight_desc, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
const cnnlComputationPreference_t prefer = CNNL_COMPUTATION_HIGH_PRECISION;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlBceWithLogits_v2(
handle, prefer, input_desc, input, target_desc, target, weight_desc,
weight, pos_weight_desc, pos_weight, reduction, workspace_ptr,
workspace_size, output_desc, output));
}
/* static */ void MLUCnnl::BceWithLogitsBackward(
const ExecutionContext& ctx, cnnlBceWithLogitsReduction_t reduction,
const cnnlTensorDescriptor_t grad_desc, const void* grad,
const cnnlTensorDescriptor_t input_desc, const void* input,
const cnnlTensorDescriptor_t target_desc, const void* target,
const cnnlTensorDescriptor_t weight_desc, const void* weight,
const cnnlTensorDescriptor_t pos_weight_desc, const void* pos_weight,
const cnnlTensorDescriptor_t diff_input_desc, void* diff_input) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetBceWithLogitsBackwardWorkspaceSize(
handle, target_desc, weight_desc, pos_weight_desc, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlBceWithLogitsBackward(
handle, grad_desc, grad, input_desc, input, target_desc, target,
weight_desc, weight, pos_weight_desc, pos_weight, reduction,
workspace_ptr, workspace_size, diff_input_desc, diff_input));
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -1279,6 +1279,23 @@ class MLUCnnl { ...@@ -1279,6 +1279,23 @@ class MLUCnnl {
const cnnlTensorDescriptor_t indices_desc, const void* indices, const cnnlTensorDescriptor_t indices_desc, const void* indices,
const cnnlTensorDescriptor_t diff_desc, const void* diff, const cnnlTensorDescriptor_t diff_desc, const void* diff,
const cnnlTensorDescriptor_t output_desc, void* output); const cnnlTensorDescriptor_t output_desc, void* output);
static void BceWithLogits(
const ExecutionContext& ctx, cnnlBceWithLogitsReduction_t reduction,
const cnnlTensorDescriptor_t input_desc, const void* input,
const cnnlTensorDescriptor_t target_desc, const void* target,
const cnnlTensorDescriptor_t weight_desc, const void* weight,
const cnnlTensorDescriptor_t pos_weight_desc, const void* pos_weight,
const cnnlTensorDescriptor_t output_desc, void* output);
static void BceWithLogitsBackward(
const ExecutionContext& ctx, cnnlBceWithLogitsReduction_t reduction,
const cnnlTensorDescriptor_t grad_desc, const void* grad,
const cnnlTensorDescriptor_t input_desc, const void* input,
const cnnlTensorDescriptor_t target_desc, const void* target,
const cnnlTensorDescriptor_t weight_desc, const void* weight,
const cnnlTensorDescriptor_t pos_weight_desc, const void* pos_weight,
const cnnlTensorDescriptor_t diff_input_desc, void* diff_input);
}; };
template <typename T> template <typename T>
......
...@@ -73,6 +73,7 @@ namespace ops = paddle::operators; ...@@ -73,6 +73,7 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(reduce_sum, ops::ReduceSumMLUKernel<float>, REGISTER_OP_MLU_KERNEL(reduce_sum, ops::ReduceSumMLUKernel<float>,
ops::ReduceSumMLUKernel<int>,
ops::ReduceSumMLUKernel<plat::float16>); ops::ReduceSumMLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(reduce_sum_grad, ops::ReduceSumGradMLUKernel<float>, REGISTER_OP_MLU_KERNEL(reduce_sum_grad, ops::ReduceSumGradMLUKernel<float>,
ops::ReduceSumGradMLUKernel<plat::float16>); ops::ReduceSumGradMLUKernel<plat::float16>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
const int kIgnoreIndex = -100;
void CheckAttrs(const framework::ExecutionContext& ctx) {
// cnnl not support normalize and ignore_index
bool normalize = ctx.Attr<bool>("normalize");
int ignore_index = ctx.Attr<int>("ignore_index");
PADDLE_ENFORCE_EQ(normalize, false,
platform::errors::InvalidArgument(
"attr normalize must be false, but got true"));
PADDLE_ENFORCE_EQ(ignore_index, kIgnoreIndex,
platform::errors::InvalidArgument(
"attr ignore_index must be default %d, but got %d",
kIgnoreIndex, ignore_index));
}
template <typename T>
class SigmoidCrossEntropyWithLogitsMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
CheckAttrs(ctx);
auto* x = ctx.Input<Tensor>("X");
auto* label = ctx.Input<Tensor>("Label");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc label_desc(*label);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::BceWithLogits(ctx, CNNL_BCE_WITH_LOGITS_NONE, x_desc.get(),
GetBasePtr(x), label_desc.get(), GetBasePtr(label),
nullptr, nullptr, nullptr, nullptr, out_desc.get(),
GetBasePtr(out));
}
};
template <typename T>
class SigmoidCrossEntropyWithLogitsMLUGradKernel
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
CheckAttrs(ctx);
auto* x = ctx.Input<Tensor>("X");
auto* label = ctx.Input<Tensor>("Label");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto place = ctx.GetPlace();
dx->mutable_data<T>(place);
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc label_desc(*label);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnl::BceWithLogitsBackward(
ctx, CNNL_BCE_WITH_LOGITS_NONE, dout_desc.get(), GetBasePtr(dout),
x_desc.get(), GetBasePtr(x), label_desc.get(), GetBasePtr(label),
nullptr, nullptr, nullptr, nullptr, x_desc.get(), GetBasePtr(dx));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(
sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsMLUKernel<float>,
ops::SigmoidCrossEntropyWithLogitsMLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(
sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsMLUGradKernel<float>,
ops::SigmoidCrossEntropyWithLogitsMLUGradKernel<plat::float16>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy as np
import unittest
import sys
sys.path.append('..')
from op_test import OpTest
from test_bce_with_logits_loss import call_bce_layer, call_bce_functional, test_dygraph, calc_bce_with_logits_loss
def test_static(place,
logit_np,
label_np,
weight_np=None,
reduction='mean',
pos_weight_np=None,
functional=False):
paddle.enable_static()
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
logit = paddle.fluid.data(name='logit',
shape=logit_np.shape,
dtype='float32')
label = paddle.fluid.data(name='label',
shape=label_np.shape,
dtype='float32')
feed_dict = {"logit": logit_np, "label": label_np}
pos_weight = None
weight = None
if pos_weight_np is not None:
pos_weight = paddle.fluid.data(name='pos_weight',
shape=pos_weight_np.shape,
dtype='float32')
feed_dict["pos_weight"] = pos_weight_np
if weight_np is not None:
weight = paddle.fluid.data(name='weight',
shape=weight_np.shape,
dtype='float32')
feed_dict["weight"] = weight_np
if functional:
res = call_bce_functional(logit, label, weight, reduction,
pos_weight)
else:
res = call_bce_layer(logit, label, weight, reduction, pos_weight)
exe = paddle.static.Executor(place)
static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
return static_result
paddle.enable_static()
class TestBCEWithLogitsLoss(unittest.TestCase):
def test_BCEWithLogitsLoss(self):
logit_np = np.random.uniform(0.1, 0.8, size=(20, 30)).astype(np.float32)
label_np = np.random.randint(0, 2, size=(20, 30)).astype(np.float32)
places = [fluid.MLUPlace(0)]
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
static_result = test_static(place,
logit_np,
label_np,
reduction=reduction)
dy_result = test_dygraph(place,
logit_np,
label_np,
reduction=reduction)
expected = calc_bce_with_logits_loss(logit_np, label_np,
reduction)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(place,
logit_np,
label_np,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place,
logit_np,
label_np,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCEWithLogitsLoss_weight(self):
logit_np = np.random.uniform(0.1, 0.8,
size=(2, 3, 4, 10)).astype(np.float32)
label_np = np.random.randint(0, 2,
size=(2, 3, 4, 10)).astype(np.float32)
weight_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float32)
place = fluid.MLUPlace(0)
for reduction in ['sum', 'mean', 'none']:
static_result = test_static(place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction)
dy_result = test_dygraph(place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction)
expected = calc_bce_with_logits_loss(logit_np,
label_np,
reduction,
weight_np=weight_np)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction,
functional=True)
dy_functional = test_dygraph(place,
logit_np,
label_np,
weight_np=weight_np,
reduction=reduction,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCEWithLogitsLoss_pos_weight(self):
logit_np = np.random.uniform(0.1, 0.8,
size=(2, 3, 4, 10)).astype(np.float32)
label_np = np.random.randint(0, 2,
size=(2, 3, 4, 10)).astype(np.float32)
pos_weight_np = np.random.random(size=(3, 4, 10)).astype(np.float32)
weight_np = np.random.random(size=(2, 3, 4, 10)).astype(np.float32)
place = fluid.MLUPlace(0)
reduction = "mean"
static_result = test_static(place, logit_np, label_np, weight_np,
reduction, pos_weight_np)
dy_result = test_dygraph(place, logit_np, label_np, weight_np,
reduction, pos_weight_np)
expected = calc_bce_with_logits_loss(logit_np, label_np, reduction,
weight_np, pos_weight_np)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static(place,
logit_np,
label_np,
weight_np,
reduction,
pos_weight_np,
functional=True)
dy_functional = test_dygraph(place,
logit_np,
label_np,
weight_np,
reduction,
pos_weight_np,
functional=True)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))
def test_BCEWithLogitsLoss_error(self):
paddle.disable_static()
self.assertRaises(ValueError,
paddle.nn.BCEWithLogitsLoss,
reduction="unsupport reduction")
logit = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
label = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
self.assertRaises(ValueError,
paddle.nn.functional.binary_cross_entropy_with_logits,
logit=logit,
label=label,
reduction="unsupport reduction")
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import sys
sys.path.append('..')
from op_test import OpTest
from scipy.special import logit
from scipy.special import expit
import paddle.fluid.core as core
import unittest
from paddle.fluid import compiler, Program, program_guard
import paddle.fluid as fluid
import paddle
paddle.enable_static()
class TestSigmoidCrossEntropyWithLogitsOp1(OpTest):
"""Test sigmoid_cross_entropy_with_logit_op with binary label
"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.set_mlu()
self.init_dtype()
batch_size = 64
num_classes = 20
self.inputs = {
'X':
logit(
np.random.uniform(0, 1, (batch_size, num_classes)).astype(
self.dtype)),
'Label':
np.random.randint(0, 2,
(batch_size, num_classes)).astype(self.dtype)
}
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
# Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
sigmoid_X = expit(self.inputs['X'])
term1 = self.inputs['Label'] * np.log(sigmoid_X)
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
self.outputs = {'Out': -term1 - term2}
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.MLUPlace(0)
def init_dtype(self):
self.dtype = np.float32
class TestSigmoidCrossEntropyWithLogitsOp3(TestSigmoidCrossEntropyWithLogitsOp1
):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label
"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.set_mlu()
self.init_dtype()
batch_size = 64
num_classes = 20
self.inputs = {
'X':
logit(
np.random.uniform(0, 1, (batch_size, num_classes)).astype(
self.dtype)),
'Label':
np.random.uniform(0, 1,
(batch_size, num_classes)).astype(self.dtype)
}
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
# Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
sigmoid_X = expit(self.inputs['X'])
term1 = self.inputs['Label'] * np.log(sigmoid_X)
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
self.outputs = {'Out': -term1 - term2}
class TestSigmoidCrossEntropyWithLogitsOp5(TestSigmoidCrossEntropyWithLogitsOp1
):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label
"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.set_mlu()
self.init_dtype()
batch_size = [10, 10]
num_classes = 20
self.inputs = {
'X':
logit(
np.random.uniform(0, 1,
tuple(batch_size + [num_classes])).astype(
self.dtype)),
'Label':
np.random.uniform(0, 1, tuple(batch_size + [num_classes])).astype(
self.dtype)
}
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
# Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
sigmoid_X = expit(self.inputs['X'])
term1 = self.inputs['Label'] * np.log(sigmoid_X)
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
self.outputs = {'Out': -term1 - term2}
class TestSigmoidCrossEntropyWithLogitsOp6(TestSigmoidCrossEntropyWithLogitsOp1
):
"""Test sigmoid_cross_entropy_with_logit_op with binary label
"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.set_mlu()
self.init_dtype()
batch_size = [10, 10]
num_classes = 20
self.inputs = {
'X':
logit(
np.random.uniform(0, 1,
tuple(batch_size + [num_classes])).astype(
self.dtype)),
'Label':
np.random.randint(0, 2, tuple(batch_size + [num_classes])).astype(
self.dtype)
}
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
# Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
sigmoid_X = expit(self.inputs['X'])
term1 = self.inputs['Label'] * np.log(sigmoid_X)
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
self.outputs = {'Out': -term1 - term2}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册