diff --git a/paddle/fluid/operators/huber_loss_op_mlu.cc b/paddle/fluid/operators/huber_loss_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..48937dc38df86c44e03871c5dbc8ca128c676a84 --- /dev/null +++ b/paddle/fluid/operators/huber_loss_op_mlu.cc @@ -0,0 +1,187 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +using Tensor = phi::DenseTensor; + +template +class HuberLossMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = GetDevCtxFromCTX(ctx); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* residual = ctx.Output("Residual"); + auto* out = ctx.Output("Out"); + auto delta = ctx.Attr("delta"); + + auto place = ctx.GetPlace(); + + // compute y-x + cnnlDataType_t data_type = ToCnnlDataType(); + residual->mutable_data(x->dims(), place); + MLUCnnlTensorDesc x_desc(*x); + MLUCnnlOpTensorDesc sub_op_desc( + CNNL_OP_TENSOR_SUB, data_type, CNNL_NOT_PROPAGATE_NAN); + MLUCnnl::OpTensor(ctx, + sub_op_desc.get(), + x_desc.get(), + GetBasePtr(y), + x_desc.get(), + GetBasePtr(x), + x_desc.get(), + GetBasePtr(residual), + data_type); + + // compute smoothl1loss + out->mutable_data(x->dims(), place); + cnnlSmoothL1LossAlgorithm_t smoothl1_algo = + CNNL_SMOOTHL1LOSS_REDUCTION_NONE; // defines whether to do reduction + // here + MLUCnnl::SmoothL1LossForward(ctx, + x_desc.get(), + GetBasePtr(x), + x_desc.get(), /* target has same shape as x */ + GetBasePtr(y), + static_cast(delta), + smoothl1_algo, + x_desc.get(), /* out has same shape as x */ + GetBasePtr(out)); + + // compute multiply by delta + framework::Tensor scale_tensor, bias_tensor; + scale_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + bias_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + FillMLUTensorWithHostValue(ctx, static_cast(delta), &scale_tensor); + FillMLUTensorWithHostValue(ctx, static_cast(0.f), &bias_tensor); + const int axis = std::max(out->dims().size() - 1, 0); + + MLUCnnlTensorDesc scale_desc(scale_tensor); + MLUCnnlTensorDesc bias_desc(bias_tensor); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnl::Scale(ctx, + axis, + out_desc.get(), + GetBasePtr(out), + scale_desc.get(), + GetBasePtr(&scale_tensor), + bias_desc.get(), + GetBasePtr(&bias_tensor), + out_desc.get(), + GetBasePtr(out)); + } +}; + +template +class HuberLossGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = GetDevCtxFromCTX(ctx); + auto* residual = ctx.Input("Residual"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + auto delta = ctx.Attr("delta"); + + auto place = ctx.GetPlace(); + + Tensor t_grad_rd; + t_grad_rd = + ctx.AllocateTmpTensor(residual->dims(), dev_ctx); + MLUCnnlTensorDesc t_grad_rd_desc(t_grad_rd); + if (dx || dy) { + Tensor t_zero; + t_zero = + ctx.AllocateTmpTensor(residual->dims(), dev_ctx); + FillMLUTensorWithHostValue(ctx, static_cast(0.f), &t_zero); + + MLUCnnlTensorDesc residual_desc(*residual); + MLUCnnlTensorDesc dout_desc(*dout); + + cnnlSmoothL1LossAlgorithm_t smoothl1_algo = + CNNL_SMOOTHL1LOSS_REDUCTION_NONE; // defines whether to do reduction + // here + MLUCnnl::SmoothL1LossBackward(ctx, + residual_desc.get(), + GetBasePtr(residual), + residual_desc.get(), + GetBasePtr(&t_zero), + dout_desc.get(), + GetBasePtr(dout), + static_cast(delta), + smoothl1_algo, + t_grad_rd_desc.get(), + GetBasePtr(&t_grad_rd)); + } + // compute multiply by delta + framework::Tensor scale_tensor, bias_tensor; + scale_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + bias_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); + + FillMLUTensorWithHostValue(ctx, static_cast(0.f), &bias_tensor); + const int axis = std::max(t_grad_rd.dims().size() - 1, 0); + + MLUCnnlTensorDesc scale_desc(scale_tensor); + MLUCnnlTensorDesc bias_desc(bias_tensor); + + if (dx) { + dx->mutable_data(place); + FillMLUTensorWithHostValue(ctx, static_cast(-delta), &scale_tensor); + MLUCnnlTensorDesc out_desc(*dx); + MLUCnnl::Scale(ctx, + axis, + t_grad_rd_desc.get(), + GetBasePtr(&t_grad_rd), + scale_desc.get(), + GetBasePtr(&scale_tensor), + bias_desc.get(), + GetBasePtr(&bias_tensor), + out_desc.get(), + GetBasePtr(dx)); + } + if (dy) { + dy->mutable_data(place); + FillMLUTensorWithHostValue(ctx, static_cast(delta), &scale_tensor); + MLUCnnlTensorDesc out_desc(*dy); + MLUCnnl::Scale(ctx, + axis, + t_grad_rd_desc.get(), + GetBasePtr(&t_grad_rd), + scale_desc.get(), + GetBasePtr(&scale_tensor), + bias_desc.get(), + GetBasePtr(&bias_tensor), + out_desc.get(), + GetBasePtr(dy)); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(huber_loss, + ops::HuberLossMLUKernel, + ops::HuberLossMLUKernel); +REGISTER_OP_MLU_KERNEL(huber_loss_grad, + ops::HuberLossGradMLUKernel, + ops::HuberLossGradMLUKernel); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 5d0ccc9f72283fea72f88b6cf937d0f265a450c2..fd2fb8b953dc40a356c17fb75e14bf268d4f220d 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -4725,6 +4725,78 @@ MLURNNDesc::~MLURNNDesc() { output)); } +/* static */ void MLUCnnl::SmoothL1LossForward( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t t_desc, + const void* target, + const float beta, + const cnnlSmoothL1LossAlgorithm_t algorithm, + const cnnlTensorDescriptor_t y_desc, + void* y) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + size_t workspace_size; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetSmoothL1LossForwardWorkspaceSize( + handle, x_desc, algorithm, &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSmoothL1LossForward_v2(handle, + x_desc, + x, + t_desc, + target, + beta, + algorithm, + workspace_ptr, + workspace_size, + y_desc, + y)); +} + +/* static */ void MLUCnnl::SmoothL1LossBackward( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t target_desc, + const void* target, + const cnnlTensorDescriptor_t dy_desc, + const void* dy, + const float beta, + const cnnlSmoothL1LossAlgorithm_t algorithm, + const cnnlTensorDescriptor_t dx_desc, + void* dx) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + size_t workspace_size; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetSmoothL1LossBackwardWorkspaceSize( + handle, x_desc, algorithm, &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSmoothL1LossBackward_v2(handle, + x_desc, + x, + target_desc, + target, + dy_desc, + dy, + beta, + algorithm, + workspace_ptr, + workspace_size, + dx_desc, + dx)); +} + /* static */ void MLUCnnl::EmbeddingForward( const ExecutionContext& ctx, const int padding_idx, diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 4c728df4e4ef821c7c1efdfa46bef083a064c91f..971c29b9e04d2c4de2452125fb55fb63f5f302c8 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -2042,6 +2042,28 @@ class MLUCnnl { const cnnlTensorDescriptor_t output_desc, void* output); + static void SmoothL1LossForward(const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t t_desc, + const void* target, + const float beta, + const cnnlSmoothL1LossAlgorithm_t algorithm, + const cnnlTensorDescriptor_t y_desc, + void* y); + + static void SmoothL1LossBackward(const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t target_desc, + const void* target, + const cnnlTensorDescriptor_t dy_desc, + const void* dy, + const float beta, + const cnnlSmoothL1LossAlgorithm_t algorithm, + const cnnlTensorDescriptor_t dx_desc, + void* dx); + static void EmbeddingForward(const ExecutionContext& ctx, const int padding_idx, const cnnlTensorDescriptor_t weight_desc, diff --git a/python/paddle/fluid/tests/unittests/mlu/test_huber_loss_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_huber_loss_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..6839f8ab6f01df1db013dba74457e586fcc789dc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_huber_loss_op_mlu.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys + +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + +paddle.enable_static() + + +def huber_loss_forward(val, delta): + abs_val = abs(val) + if abs_val <= delta: + return 0.5 * val * val + else: + return delta * (abs_val - 0.5 * delta) + + +class TestHuberLossOp(OpTest): + + def setUp(self): + self.op_type = 'huber_loss' + self.set_mlu() + self.python_api = paddle.fluid.layers.huber_loss + self.python_out_sig = ["Out"] + self.delta = 1.0 + self.init_input() + shape = self.set_shape() + residual = self.inputs['Y'] - self.inputs['X'] + loss = np.vectorize(huber_loss_forward)(residual, + self.delta).astype('float32') + self.attrs = {'delta': self.delta} + self.outputs = {'Residual': residual, 'Out': loss.reshape(shape)} + + def init_input(self): + shape = self.set_shape() + self.inputs = { + 'X': np.random.uniform(0, 1., shape).astype('float32'), + 'Y': np.random.uniform(0, 1., shape).astype('float32'), + } + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def set_shape(self): + return (100, 1) + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + self.check_grad_with_place(self.place, ['Y'], + 'Out', + max_relative_error=0.008, + no_grad_set=set("residual")) + + def test_check_grad_ingore_y(self): + self.check_grad_with_place(self.place, ['X'], + 'Out', + max_relative_error=0.008, + no_grad_set=set('residual')) + + +def TestHuberLossOp1(TestHuberLossOp): + + def set_shape(self): + return (64) + + +def TestHuberLossOp2(TestHuberLossOp): + + def set_shape(self): + return (6, 6) + + +def TestHuberLossOp3(TestHuberLossOp): + + def set_shape(self): + return (6, 6, 1) + + +class TestHuberLossOpError(unittest.TestCase): + + def test_errors(self): + with program_guard(Program(), Program()): + # the input and label must be Variable + xw = np.random.random((6, 6)).astype("float32") + xr = fluid.data(name='xr', shape=[None, 6], dtype="float32") + lw = np.random.random((6, 6)).astype("float32") + lr = fluid.data(name='lr', shape=[None, 6], dtype="float32") + delta = 1.0 + self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw, delta) + self.assertRaises(TypeError, fluid.layers.huber_loss, xw, lr, delta) + + # the dtype of input and label must be float32 or float64 + xw2 = fluid.data(name='xw2', shape=[None, 6], dtype="int32") + lw2 = fluid.data(name='lw2', shape=[None, 6], dtype="int32") + self.assertRaises(TypeError, fluid.layers.huber_loss, xw2, lr, + delta) + self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw2, + delta) + + +if __name__ == '__main__': + unittest.main()