Unverified commit 7d15f930, authored by cifar10, committed by GitHub

add mlu label_smooth kernel (#43743)

Parent d77c4955
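For context, label smoothing computes Out = (1 - epsilon) * X + epsilon * PriorDist, falling back to a uniform prior of 1 / num_classes when no PriorDist input is supplied. A minimal NumPy sketch of the computation this kernel implements (the helper name label_smooth_ref is illustrative, not part of the patch):

import numpy as np

def label_smooth_ref(x, epsilon=0.1, prior_dist=None):
    # x holds one-hot labels with shape [..., num_classes].
    num_classes = x.shape[-1]
    if prior_dist is None:
        # Without an explicit prior, each class receives epsilon / num_classes.
        prior_dist = np.full((1, num_classes), 1.0 / num_classes, dtype=x.dtype)
    # Blend the original labels with the prior distribution.
    return (1.0 - epsilon) * x + epsilon * prior_dist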
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T>
class LabelSmoothMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<LoDTensor>("X");
auto* dist_t = ctx.Input<Tensor>("PriorDist");
auto* out_t = ctx.Output<LoDTensor>("Out");
auto epsilon = ctx.Attr<float>("epsilon");
auto epsilon_gt = 1.0f - epsilon;
if (in_t->numel() == 0) return;
out_t->mutable_data<T>(ctx.GetPlace());
auto label_dim = in_t->dims()[in_t->dims().size() - 1];
MLUCnnlTensorDesc x_desc(*in_t);
MLUCnnlTensorDesc out_desc(*out_t);
auto data_type = ToCnnlDataType<T>();
MLUCnnlOpTensorDesc op_tensor_desc(
CNNL_OP_TENSOR_ADD, data_type, CNNL_NOT_PROPAGATE_NAN);
if (ctx.HasInput("PriorDist")) {
MLUCnnlTensorDesc dist_desc(*dist_t);
MLUCnnl::OpTensor(ctx,
op_tensor_desc.get(),
x_desc.get(),
GetBasePtr(in_t),
dist_desc.get(),
GetBasePtr(dist_t),
out_desc.get(),
GetBasePtr(out_t),
data_type,
epsilon_gt,
epsilon);
} else {
auto& dev_ctx = ctx.template device_context<MLUDeviceContext>();
framework::Tensor dist_tensor =
ctx.AllocateTmpTensor<T, MLUDeviceContext>({1, label_dim}, dev_ctx);
MLUCnnlTensorDesc dist_desc(dist_tensor);
auto value = static_cast<T>(1.0f / label_dim);
MLUCnnl::Fill(ctx,
CNNL_POINTER_MODE_HOST,
&value,
dist_desc.get(),
GetBasePtr(&dist_tensor));
MLUCnnl::OpTensor(ctx,
op_tensor_desc.get(),
x_desc.get(),
GetBasePtr(in_t),
dist_desc.get(),
GetBasePtr(&dist_tensor),
out_desc.get(),
GetBasePtr(out_t),
data_type,
epsilon_gt,
epsilon);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(label_smooth,
ops::LabelSmoothMLUKernel<float>,
ops::LabelSmoothMLUKernel<plat::float16>);
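On the Python side, the new kernel backs the existing label_smooth operator. A minimal usage sketch, assuming an MLU-enabled Paddle build (the device string and eager-mode call are assumptions, not part of this patch):

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.set_device('mlu:0')  # assumes Paddle was built with MLU support
label = paddle.to_tensor(np.eye(4, dtype='float32'))  # one-hot labels
smoothed = F.label_smooth(label, epsilon=0.1)  # uniform prior of 1/4 per class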
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import numpy as np
import unittest
import sys

sys.path.append("..")
from op_test import OpTest

import paddle
import paddle.fluid as fluid

SEED = 2022

paddle.enable_static()


class TestLabelSmoothOp(OpTest):

    def init_dtype(self):
        self.dtype = np.float32

    def config(self):
        self.op_type = "label_smooth"
        self.place = paddle.MLUPlace(0)
        self.__class__.use_mlu = True
        self.__class__.no_need_check_grad = True
        self.epsilon = 0.1
        batch_size, self.label_dim = 10, 12
        np.random.seed(SEED)
        self.label = np.zeros((batch_size, self.label_dim)).astype(self.dtype)
        nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
        self.label[np.arange(batch_size), nonzero_index] = 1

    def setUp(self):
        self.init_dtype()
        self.config()
        smoothed_label = (
            1 - self.epsilon) * self.label + self.epsilon / self.label_dim
        smoothed_label = smoothed_label.astype(self.dtype)
        self.inputs = {'X': self.label}
        self.attrs = {'epsilon': self.epsilon}
        self.outputs = {'Out': smoothed_label}

    def test_check_output(self):
        self.check_output_with_place(self.place)


class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):

    def setUp(self):
        self.init_dtype()
        self.config()
        dist = np.random.random((1, self.label_dim)).astype(self.dtype)
        smoothed_label = (1 - self.epsilon) * self.label + self.epsilon * dist
        smoothed_label = smoothed_label.astype(self.dtype)
        self.inputs = {'X': self.label, 'PriorDist': dist}
        self.attrs = {'epsilon': self.epsilon}
        self.outputs = {'Out': smoothed_label}


class TestLabelSmoothOp3D(TestLabelSmoothOp):

    def setUp(self):
        super(TestLabelSmoothOp3D, self).setUp()
        self.inputs['X'] = self.inputs['X'].reshape(
            [2, -1, self.inputs['X'].shape[-1]])
        self.outputs['Out'] = self.outputs['Out'].reshape(
            self.inputs['X'].shape)


class TestLabelSmoothOpWithPriorDist3D(TestLabelSmoothOpWithPriorDist):

    def setUp(self):
        super(TestLabelSmoothOpWithPriorDist3D, self).setUp()
        self.inputs['X'] = self.inputs['X'].reshape(
            [2, -1, self.inputs['X'].shape[-1]])
        self.outputs['Out'] = self.outputs['Out'].reshape(
            self.inputs['X'].shape)


class TestLabelSmoothOpFP16(TestLabelSmoothOp):

    def init_dtype(self):
        self.dtype = np.float16


class TestLabelSmoothOpWithPriorDistFP16(TestLabelSmoothOpWithPriorDist):

    def init_dtype(self):
        self.dtype = np.float16


class TestLabelSmoothOp3DFP16(TestLabelSmoothOp3D):

    def init_dtype(self):
        self.dtype = np.float16


class TestLabelSmoothOpWithPriorDist3DFP16(TestLabelSmoothOpWithPriorDist3D):

    def init_dtype(self):
        self.dtype = np.float16


if __name__ == '__main__':
    unittest.main()