diff --git a/paddle/fluid/operators/label_smooth_op_npu.cc b/paddle/fluid/operators/label_smooth_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a20b7f06d794e7f40f7bc1f395e67a144cfa8eca
--- /dev/null
+++ b/paddle/fluid/operators/label_smooth_op_npu.cc
@@ -0,0 +1,108 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/label_smooth_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+template <typename T>
+void LabelSmoothMuls(const platform::Place& place, const aclrtStream& stream,
+                     const Tensor* in, float val, Tensor* out) {
+  out->mutable_data<T>(in->dims(), place);
+  const auto& runner = NpuOpRunner("Muls", {*in}, {*out}, {{"value", val}});
+  runner.Run(stream);
+}
+
+template <typename T>
+void LabelSmoothAdds(const platform::Place& place, const aclrtStream& stream,
+                     const Tensor* in, float val, Tensor* out) {
+  out->mutable_data<T>(in->dims(), place);
+  const auto& runner = NpuOpRunner("Adds", {*in}, {*out}, {{"value", val}});
+  runner.Run(stream);
+}
+
+template <typename T>
+void LabelSmoothAddBroadCast(const platform::Place& place,
+                             const aclrtStream& stream, const Tensor* in1,
+                             const Tensor* in2, Tensor* out) {
+  out->mutable_data<T>(place);
+  const auto& runner = NpuOpRunner("AddV2", {*in1, *in2}, {*out}, {});
+  runner.Run(stream);
+}
+
+template <typename T>
+class LabelSmoothNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out_t = ctx.Output<LoDTensor>("Out");
+    auto* in_t = ctx.Input<LoDTensor>("X");
+    auto* dist_t = ctx.Input<Tensor>("PriorDist");
+    auto epsilon = ctx.Attr<float>("epsilon");
+
+    auto label_dim = in_t->dims()[in_t->dims().size() - 1];
+    auto place = ctx.GetPlace();
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    if (dist_t) {
+      Tensor tmp;
+      Tensor dist;
+      Tensor tmp2;
+      LabelSmoothMuls<T>(place, stream, in_t, (1 - epsilon), &tmp);
+      LabelSmoothMuls<T>(place, stream, dist_t, epsilon, &tmp2);
+      tmp2.Resize({1, label_dim});
+      LabelSmoothAddBroadCast<T>(place, stream, &tmp, &tmp2, out_t);
+    } else {
+      Tensor tmp;
+      LabelSmoothMuls<T>(place, stream, in_t, (1 - epsilon), &tmp);
+      LabelSmoothAdds<T>(place, stream, &tmp, (epsilon / label_dim), out_t);
+    }
+  }
+};
+
+template <typename T>
+class LabelSmoothGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* d_out_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* d_in_t = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto epsilon = ctx.Attr<float>("epsilon");
+
+    auto place = ctx.GetPlace();
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    LabelSmoothMuls<T>(place, stream, d_out_t, 1 - epsilon, d_in_t);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_NPU_KERNEL(label_smooth, ops::LabelSmoothNPUKernel<float>,
+                       ops::LabelSmoothNPUKernel<plat::float16>);
+REGISTER_OP_NPU_KERNEL(label_smooth_grad, ops::LabelSmoothGradNPUKernel<float>,
+                       ops::LabelSmoothGradNPUKernel<plat::float16>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_label_smooth_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_label_smooth_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e5b4c012053f7e5e8cee28c7d54be3152ecb4cd
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_label_smooth_op_npu.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+SEED = 2021
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestLabelSmoothOp(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "label_smooth"
+        self.place = paddle.NPUPlace(0)
+
+        self.init_dtype()
+        np.random.seed(SEED)
+
+        self.set_inputs()
+        self.set_attrs()
+        self.set_outputs()
+
+    def calc_out(self, label, epsilon, dist=None):
+        label_dim = label.shape[-1]
+        y = (1 - epsilon) * label
+        if dist is not None:
+            y += epsilon * dist
+        else:
+            y += epsilon / label_dim
+        return y.astype(self.dtype)
+
+    def set_inputs(self):
+        batch_size, label_dim = 10, 12
+        x = np.zeros((batch_size, label_dim)).astype(self.dtype)
+        nonzero_index = np.random.randint(label_dim, size=(batch_size))
+        x[np.arange(batch_size), nonzero_index] = 1
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+
+    def set_attrs(self):
+        epsilon = 0.1
+        self.attrs = {"epsilon": epsilon}
+
+    def set_outputs(self):
+        dist = None if 'PriorDist' not in self.inputs else self.inputs[
+            'PriorDist']
+        out = self.calc_out(self.inputs['X'], self.attrs['epsilon'], dist)
+        self.outputs = {'Out': out}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+
+
+class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):
+    def set_inputs(self):
+        super(TestLabelSmoothOpWithPriorDist, self).set_inputs()
+        label_dim = self.inputs['X'].shape[-1]
+        dist = np.random.random((1, label_dim)).astype(self.dtype)
+        self.inputs['PriorDist'] = dist
+
+
+class TestLabelSmoothOp3D(TestLabelSmoothOp):
+    def set_inputs(self):
+        super(TestLabelSmoothOp3D, self).set_inputs()
+        self.inputs['X'] = self.inputs['X'].reshape([2, -1, self.inputs['X'].shape[-1]])
+
+
+class TestLabelSmoothOpWithPriorDist3D(TestLabelSmoothOpWithPriorDist):
+    def set_inputs(self):
+        super(TestLabelSmoothOpWithPriorDist3D, self).set_inputs()
+        self.inputs['X'] = self.inputs['X'].reshape([2, -1, self.inputs['X'].shape[-1]])
+
+
+class TestLabelSmoothOpFP16(TestLabelSmoothOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestLabelSmoothOpWithPriorDistFP16(TestLabelSmoothOpWithPriorDist):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestLabelSmoothOp3DFP16(TestLabelSmoothOp3D):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestLabelSmoothOpWithPriorDist3DFP16(TestLabelSmoothOpWithPriorDist3D):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+if __name__ == '__main__':
+    unittest.main()
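For reference, the forward computation implemented by LabelSmoothNPUKernel (and reproduced by calc_out in the test above) is the standard label-smoothing formula. A minimal NumPy sketch, not part of the patch, with illustrative shapes and values chosen here:

import numpy as np

def label_smooth(x, epsilon, prior_dist=None):
    # Out = (1 - epsilon) * X + epsilon * PriorDist   (explicit prior, broadcast over rows)
    # Out = (1 - epsilon) * X + epsilon / label_dim   (uniform prior when PriorDist is absent)
    label_dim = x.shape[-1]
    if prior_dist is not None:
        return (1 - epsilon) * x + epsilon * prior_dist
    return (1 - epsilon) * x + epsilon / label_dim

# One-hot labels for 4 samples over 5 classes.
x = np.eye(5, dtype=np.float32)[[0, 2, 1, 4]]
print(label_smooth(x, 0.1))  # uniform prior: 0.92 on the hot class, 0.02 elsewhere
print(label_smooth(x, 0.1, np.full((1, 5), 0.2, dtype=np.float32)))  # explicit PriorDist

The gradient kernel follows directly: since Out depends on X only through the (1 - epsilon) * X term, dX = (1 - epsilon) * dOut, which is the single Muls call in LabelSmoothGradNPUKernel.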