From aecf9967cf901a3ea3bae3e1ed3bd2996fe909c5 Mon Sep 17 00:00:00 2001
From: fwenguang <95677191+fwenguang@users.noreply.github.com>
Date: Sun, 30 Jan 2022 10:47:31 +0800
Subject: [PATCH] [MLU] add softmax_with_cross_entropy mlu kernel (#39260)

---
 .../softmax_with_cross_entropy_op.cc           |  17 +-
 .../softmax_with_cross_entropy_op_mlu.cc       | 151 ++++++++++++++++
 python/paddle/fluid/layers/loss.py             |   2 +-
 .../test_softmax_with_cross_entropy_op_mlu.py  | 161 ++++++++++++++++++
 4 files changed, 321 insertions(+), 10 deletions(-)
 create mode 100644 paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_softmax_with_cross_entropy_op_mlu.py

diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc
index 78e813edda..cba779d0a7 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc
@@ -40,7 +40,7 @@ class SoftmaxWithCrossEntropyOpMaker
               "The outputs value of softmax activation by given the input batch, "
               "which will be used in backward calculation.")
         .AsIntermediate();
-#ifdef PADDLE_WITH_ASCEND_CL
+#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
     AddOutput(
         "Backprop",
         "(Tensor, default: Tensor), A tensor in same shape with "
@@ -49,7 +49,7 @@ class SoftmaxWithCrossEntropyOpMaker
         "is :"
         "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
         "where labels is ont-hot."
-        "Currently, the tensor is generated and used in npu kernel only. ")
+        "Currently, the tensor is generated and used in npu/mlu kernel. ")
         .AsIntermediate();
 #endif
     AddOutput("Loss",
@@ -131,7 +131,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"), true,
                       platform::errors::InvalidArgument(
                           "Output(Softmax) should be not null."));
-#ifdef PADDLE_WITH_ASCEND_CL
+#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
     PADDLE_ENFORCE_EQ(ctx->HasOutput("Backprop"), true,
                       platform::errors::InvalidArgument(
                           "Output(Backprop) should be not null."));
@@ -194,7 +194,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
     }
 
     ctx->SetOutputDim("Softmax", logits_dims);
-#ifdef PADDLE_WITH_ASCEND_CL
+#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
     ctx->SetOutputDim("Backprop", logits_dims);
     ctx->ShareLoD("Logits", /*->*/ "Backprop");
 #endif
@@ -225,7 +225,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true,
                       platform::errors::InvalidArgument(
                           "Input(Softmax) should be not null."));
-#ifdef PADDLE_WITH_ASCEND_CL
+#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
     PADDLE_ENFORCE_EQ(ctx->HasInput("Backprop"), true,
                       platform::errors::InvalidArgument(
                           "Input(Backprop) should be not null."));
@@ -306,7 +306,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpMaker {
     grad_op->SetType("softmax_with_cross_entropy_grad");
     grad_op->SetInput("Label", this->Input("Label"));
     grad_op->SetInput("Softmax", this->Output("Softmax"));
-#ifdef PADDLE_WITH_ASCEND_CL
+#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
     grad_op->SetInput("Backprop", this->Output("Backprop"));
 #endif
     grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
@@ -343,7 +343,7 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
                        ops::SoftmaxWithCrossEntropyGradKernel);
 
 REGISTER_OP_VERSION(softmax_with_cross_entropy)
-#ifdef PADDLE_WITH_ASCEND_CL
+#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
     .AddCheckpoint(
         R"ROC(
               Add a new attribute [use_softmax] )ROC",
@@ -358,8 +358,7 @@ REGISTER_OP_VERSION(softmax_with_cross_entropy)
             "calculation is :"
             "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
             "where labels is ont-hot."
-            "Currently, the tensor is generated and used in npu kernel "
-            "only. "));
+            "Currently, the tensor is generated and used in npu/mlu kernel. "));
 #else
     .AddCheckpoint(
         R"ROC(
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc
new file mode 100644
index 0000000000..0f14e6dabd
--- /dev/null
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_mlu.cc
@@ -0,0 +1,151 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class SoftmaxWithCrossEntropyMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* logits = ctx.Input<Tensor>("Logits");
+    auto* labels = ctx.Input<Tensor>("Label");
+    auto* softmax = ctx.Output<Tensor>("Softmax");
+    auto* loss = ctx.Output<Tensor>("Loss");
+    auto* backprop = ctx.Output<Tensor>("Backprop");
+    auto soft_label = ctx.Attr<bool>("soft_label");
+
+    PADDLE_ENFORCE_EQ(ctx.Attr<bool>("use_softmax"), true,
+                      platform::errors::InvalidArgument(
+                          "use_softmax=False is not supported in "
+                          "the mlu kernel of softmax_with_cross_entropy."));
+
+    const int rank = logits->dims().size();
+    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
+
+    loss->mutable_data<T>(ctx.GetPlace());
+    backprop->mutable_data<T>(ctx.GetPlace());
+    softmax->mutable_data<T>(ctx.GetPlace());
+
+    // cnnl softmax only supports 3 dims, so regard every shape as [d1, d2, d3]
+    const int cnnl_softmax_dims = 3;
+    const int d1 = SizeToAxis(axis, logits->dims());
+    const int d2_logits = logits->dims()[axis];
+    const int d2_labels = labels->dims()[axis];
+    const int d3 = SizeOutAxis(axis, logits->dims());
+
+    // CNNL_SOFTMAX_MODE_LOW_DIMENSION has better performance, so use it as
+    // much as possible.
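+    // Here d1 is the product of the dims before `axis`, d2 is the size of
+    // `axis` itself, and d3 is the product of the dims after `axis`. When
+    // d3 == 1 the softmax axis is the innermost dimension and the
+    // LOW_DIMENSION mode can be kept; otherwise fall back to
+    // CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION below.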
+    cnnlSoftmaxMode_t mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION;
+    std::vector<int> regard_logits_shape{d1, 1, d2_logits};
+    std::vector<int> regard_labels_shape{d1, 1, d2_labels};
+    std::vector<int> regard_loss_shape{d1, 1, 1};
+    if (d3 != 1) {
+      mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
+      regard_logits_shape = {d1, d2_logits, d3};
+      regard_labels_shape = {d1, d2_labels, d3};
+      regard_loss_shape = {d1, 1, d3};
+    }
+
+    MLUCnnlTensorDesc logits_desc(cnnl_softmax_dims, regard_logits_shape.data(),
+                                  ToCnnlDataType<T>());
+    MLUCnnlTensorDesc labels_desc(cnnl_softmax_dims, regard_labels_shape.data(),
+                                  ToCnnlDataType<T>());
+    MLUCnnlTensorDesc loss_desc(cnnl_softmax_dims, regard_loss_shape.data(),
+                                ToCnnlDataType<T>());
+
+    const cnnlSoftmaxAlgorithm_t algo = CNNL_SOFTMAX_ACCURATE;
+    MLUCnnl::SoftmaxForward(ctx, algo, mode, NULL, logits_desc.get(),
+                            GetBasePtr(logits), NULL, logits_desc.get(),
+                            GetBasePtr(softmax));
+
+    if (soft_label) {
+      const cnnlComputationPreference_t prefer =
+          CNNL_COMPUTATION_HIGH_PRECISION;
+      MLUCnnl::SoftmaxCrossEntropyWithLogits(
+          ctx, mode, prefer, logits_desc.get(), GetBasePtr(logits),
+          labels_desc.get(), GetBasePtr(labels), loss_desc.get(),
+          GetBasePtr(loss), logits_desc.get(), GetBasePtr(backprop));
+    } else {
+      PADDLE_ENFORCE_EQ(d3, 1,
+                        platform::errors::InvalidArgument(
+                            "If soft_label=False, axis must be -1 or"
+                            " can be regarded as the last dimension in mlu kernel."));
+      framework::Tensor labels_int32(VT::INT32);
+      labels_int32.Resize(labels->dims());
+      labels_int32.mutable_data<int32_t>(ctx.GetPlace());
+
+      MLUCnnlTensorDesc labels_int64_desc(*labels);
+      MLUCnnlTensorDesc labels_int32_desc(labels_int32);
+      cnnlCastDataType_t cast_type = GetCastDataType(VT::INT64, VT::INT32);
+      MLUCnnl::Cast(ctx, cast_type, labels_int64_desc.get(), GetBasePtr(labels),
+                    labels_int32_desc.get(), GetBasePtr(&labels_int32));
+
+      const int regard_sparse_shape[cnnl_softmax_dims - 1] = {d1, 1};
+      MLUCnnlTensorDesc sparse_labels_desc(cnnl_softmax_dims - 1,
+                                           regard_sparse_shape,
+                                           ToCnnlDataType<int32_t>());
+      MLUCnnlTensorDesc sparse_loss_desc(
+          cnnl_softmax_dims - 1, regard_sparse_shape, ToCnnlDataType<T>());
+
+      MLUCnnl::SparseSoftmaxXentWithLogits(
+          ctx, mode, logits_desc.get(), GetBasePtr(logits),
+          sparse_labels_desc.get(), GetBasePtr(&labels_int32),
+          sparse_loss_desc.get(), GetBasePtr(loss), logits_desc.get(),
+          GetBasePtr(backprop));
+    }
+  }
+};
+
+template <typename T>
+class SoftmaxWithCrossEntropyGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* backprop = ctx.Input<Tensor>("Backprop");
+    auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
+    auto* logits_grad = ctx.Output<Tensor>(framework::GradVarName("Logits"));
+    PADDLE_ENFORCE_NOT_NULL(backprop,
+                            platform::errors::PreconditionNotMet(
+                                "backprop should not be null in MLU kernel of "
+                                "softmax_with_cross_entropy_grad."));
+    logits_grad->mutable_data<T>(ctx.GetPlace());
+
+    // Logits@GRAD = Loss@GRAD * Backprop, where Backprop was saved by the
+    // forward kernel.
+    MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
+                                    CNNL_NOT_PROPAGATE_NAN);
+    MLUCnnlTensorDesc backprop_desc(*backprop);
+    MLUCnnlTensorDesc loss_grad_desc(*loss_grad);
+    MLUCnnlTensorDesc logits_grad_desc(*logits_grad);
+    MLUCnnl::OpTensor(ctx, mul_op_desc.get(), backprop_desc.get(),
+                      GetBasePtr(backprop), loss_grad_desc.get(),
+                      GetBasePtr(loss_grad), logits_grad_desc.get(),
+                      GetBasePtr(logits_grad), ToCnnlDataType<T>());
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_MLU_KERNEL(
+    softmax_with_cross_entropy,
+    ops::SoftmaxWithCrossEntropyMLUKernel<float>,
+    ops::SoftmaxWithCrossEntropyMLUKernel<paddle::platform::float16>);
+REGISTER_OP_MLU_KERNEL(
+    softmax_with_cross_entropy_grad,
+    ops::SoftmaxWithCrossEntropyGradMLUKernel<float>,
+    ops::SoftmaxWithCrossEntropyGradMLUKernel<paddle::platform::float16>);
diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py
index 3db4a894d1..07ed02181e 100644
--- a/python/paddle/fluid/layers/loss.py
+++ b/python/paddle/fluid/layers/loss.py
@@ -1287,7 +1287,7 @@ def softmax_with_cross_entropy(logits,
     loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
 
     outputs = {'Softmax': softmax, 'Loss': loss}
-    if core.is_compiled_with_npu():
+    if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
         backprop = helper.create_variable_for_type_inference(dtype=logits.dtype)
         outputs['Backprop'] = backprop
     helper.append_op(
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_softmax_with_cross_entropy_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_softmax_with_cross_entropy_op_mlu.py
new file mode 100644
index 0000000000..e626b6a093
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_softmax_with_cross_entropy_op_mlu.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+from test_softmax_op import stable_softmax
+from test_softmax_with_cross_entropy_op import cross_entropy
+
+paddle.enable_static()
+SEED = 2021
+
+
+class TestSoftmaxWithCrossEntropyOp(OpTest):
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def initParams(self):
+        self.set_mlu()
+        self.op_type = "softmax_with_cross_entropy"
+        self.numeric_stable_mode = False
+        self.place = paddle.device.MLUPlace(0)
+        self.soft_label = False
+        self.init_dtype()
+        self.axis = -1
+        self.ignore_index = -1
+        self.shape = [41, 37]
+        np.random.seed(SEED)
+
+    def setUp(self):
+        self.initParams()
+
+        logits = getattr(
+            self, "logits",
+            np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype))
+        softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
+
+        if self.soft_label:
+            labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
+            labels /= np.sum(labels, axis=self.axis, keepdims=True)
+        else:
+            axis_dim = self.shape[self.axis]
+            self.shape[self.axis] = 1
+            labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
+
+        loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
+                             self.ignore_index)
+
+        one_hot_label = np.eye(axis_dim)[labels.reshape(-1)]
+
+        self.inputs = {"Logits": logits, "Label": labels}
+        self.outputs = {
+            "Backprop": (softmax - one_hot_label).astype(self.dtype),
+            "Softmax": softmax.astype(self.dtype),
+            "Loss": loss.astype(self.dtype)
+        }
+        self.attrs = {
+            "numeric_stable_mode": self.numeric_stable_mode,
+            "soft_label": self.soft_label,
+            "ignore_index": self.ignore_index,
+        }
+
+        if self.axis != -1:
+            self.attrs['axis'] = self.axis
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        # fp32 has low precision, cpu and mlu both need to relax the max_relative_error if using fp32
+        self.check_grad_with_place(
+            self.place, ['Logits'],
+            'Loss',
+            numeric_grad_delta=0.001,
+            max_relative_error=0.5)
+
+
+class TestPowNet(unittest.TestCase):
+    def _test(self, run_mlu=True):
+        main_prog = paddle.static.Program()
+        startup_prog = paddle.static.Program()
+        main_prog.random_seed = SEED
+        startup_prog.random_seed = SEED
+        np.random.seed(SEED)
+
+        a_np = np.random.random(size=(32, 32)).astype('float32')
+        b_np = np.random.random(size=(32, 32)).astype('float32')
+        label_np = np.random.randint(2, size=(32, 1)).astype('int64')
+
+        with paddle.static.program_guard(main_prog, startup_prog):
+            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
+            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
+            label = paddle.static.data(
+                name="label", shape=[32, 1], dtype='int64')
+
+            sum = paddle.add(a, b)
+            z = paddle.pow(sum, 2.0)
+
+            fc_1 = fluid.layers.fc(input=z, size=128)
+            prediction = fluid.layers.fc(input=fc_1, size=2)
+
+            cost = fluid.layers.softmax_with_cross_entropy(prediction, label)
+            loss = fluid.layers.reduce_mean(cost)
+            sgd = fluid.optimizer.SGD(learning_rate=0.01)
+            sgd.minimize(loss)
+
+        if run_mlu:
+            place = paddle.device.MLUPlace(0)
+        else:
+            place = paddle.CPUPlace()
+
+        exe = paddle.static.Executor(place)
+        exe.run(startup_prog)
+
+        print("Start run on {}".format(place))
+        for epoch in range(100):
+
+            pred_res, loss_res = exe.run(
+                main_prog,
+                feed={"a": a_np,
+                      "b": b_np,
+                      "label": label_np},
+                fetch_list=[prediction, loss])
+            if epoch % 10 == 0:
+                print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
+                    epoch, pred_res[0], loss_res))
+
+        return pred_res, loss_res
+
+    def test_mlu(self):
+        cpu_pred, cpu_loss = self._test(False)
+        mlu_pred, mlu_loss = self._test(True)
+
+        self.assertTrue(np.allclose(mlu_pred, cpu_pred))
+        self.assertTrue(np.allclose(mlu_loss, cpu_loss))
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
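
For reference, the forward outputs produced by the new kernel can be written out in a few lines of NumPy. The sketch below is illustrative only and is not part of the patch: it assumes hard integer labels and axis = -1, does not model ignore_index, and mirrors how the unit test above builds its expected Softmax, Loss and Backprop values; reference_forward is a hypothetical helper name.

import numpy as np

def reference_forward(logits, labels):
    # Softmax with the max-subtraction trick, matching the formula in the
    # op's "Backprop" documentation: exp(logits - max) / sum(exp(logits - max)).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    softmax = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

    # One-hot encode the hard labels, exactly as the test does with np.eye.
    one_hot = np.eye(logits.shape[-1])[labels.reshape(-1)]

    # Loss[i] = -log(softmax[i, labels[i]]).
    loss = -np.log((softmax * one_hot).sum(axis=-1, keepdims=True))

    # Backprop = softmax - one_hot; the grad kernel multiplies it by Loss@GRAD.
    backprop = softmax - one_hot
    return softmax, loss, backprop

logits = np.random.uniform(0.1, 1.0, (41, 37)).astype(np.float32)
labels = np.random.randint(0, 37, (41, 1)).astype(np.int64)
softmax, loss, backprop = reference_forward(logits, labels)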