Unverified commit aecf9967, authored by F fwenguang, committed by GitHub

[MLU] add softmax_with_cross_entropy mlu kernel (#39260)

Parent d28f6f7b
@@ -40,7 +40,7 @@ class SoftmaxWithCrossEntropyOpMaker
"The outputs value of softmax activation by given the input batch, "
"which will be used in backward calculation.")
.AsIntermediate();
#ifdef PADDLE_WITH_ASCEND_CL
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
AddOutput(
"Backprop",
"(Tensor, default: Tensor<float>), A tensor in same shape with "
@@ -49,7 +49,7 @@ class SoftmaxWithCrossEntropyOpMaker
"is :"
"exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
"where labels is ont-hot."
"Currently, the tensor is generated and used in npu kernel only. ")
"Currently, the tensor is generated and used in npu/mlu kernel. ")
.AsIntermediate();
#endif
AddOutput("Loss",
@@ -131,7 +131,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"), true,
platform::errors::InvalidArgument(
"Output(Softmax) should be not null."));
#ifdef PADDLE_WITH_ASCEND_CL
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
PADDLE_ENFORCE_EQ(ctx->HasOutput("Backprop"), true,
platform::errors::InvalidArgument(
"Output(Backprop) should be not null."));
@@ -194,7 +194,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim("Softmax", logits_dims);
#ifdef PADDLE_WITH_ASCEND_CL
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
ctx->SetOutputDim("Backprop", logits_dims);
ctx->ShareLoD("Logits", /*->*/ "Backprop");
#endif
@@ -225,7 +225,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true,
platform::errors::InvalidArgument(
"Input(Softmax) should be not null."));
#ifdef PADDLE_WITH_ASCEND_CL
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
PADDLE_ENFORCE_EQ(ctx->HasInput("Backprop"), true,
platform::errors::InvalidArgument(
"Input(Backprop) should be not null."));
@@ -306,7 +306,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetType("softmax_with_cross_entropy_grad");
grad_op->SetInput("Label", this->Input("Label"));
grad_op->SetInput("Softmax", this->Output("Softmax"));
#ifdef PADDLE_WITH_ASCEND_CL
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
grad_op->SetInput("Backprop", this->Output("Backprop"));
#endif
grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
@@ -343,7 +343,7 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradKernel<double>);
REGISTER_OP_VERSION(softmax_with_cross_entropy)
#ifdef PADDLE_WITH_ASCEND_CL
#if defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_MLU)
.AddCheckpoint(
R"ROC(
Add a new attribute [use_softmax] )ROC",
@@ -358,8 +358,7 @@ REGISTER_OP_VERSION(softmax_with_cross_entropy)
"calculation is :"
"exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
"where labels is ont-hot."
"Currently, the tensor is generated and used in npu kernel "
"only. "));
"Currently, the tensor is generated and used in npu/mlu kernel. "));
#else
.AddCheckpoint(
R"ROC(
......
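As a reading aid, here is a minimal NumPy sketch of the Backprop quantity described in the op maker above, assuming hard labels and the last axis; the variable names and shapes are illustrative only and are not part of this patch:

```python
import numpy as np

# Documented formula:
#   Backprop = exp(logits - max_logits) / sum(exp(logits - max_logits)) - labels (one-hot)
logits = np.random.uniform(0.1, 1.0, (4, 3)).astype(np.float32)  # [batch, classes]
labels = np.random.randint(0, 3, size=(4,))                      # hard class indices

shifted = logits - logits.max(axis=-1, keepdims=True)            # subtract max_logits for stability
softmax = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
backprop = softmax - np.eye(3)[labels]                           # softmax minus one-hot labels
```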
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SoftmaxWithCrossEntropyMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* logits = ctx.Input<Tensor>("Logits");
auto* labels = ctx.Input<Tensor>("Label");
auto* softmax = ctx.Output<Tensor>("Softmax");
auto* loss = ctx.Output<Tensor>("Loss");
auto* backprop = ctx.Output<Tensor>("Backprop");
auto soft_label = ctx.Attr<bool>("soft_label");
PADDLE_ENFORCE_EQ(ctx.Attr<bool>("use_softmax"), true,
platform::errors::InvalidArgument(
"use_softmax=False is not supported in "
"the mlu kernel of softmax_with_cross_entropy."));
const int rank = logits->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
loss->mutable_data<T>(ctx.GetPlace());
backprop->mutable_data<T>(ctx.GetPlace());
softmax->mutable_data<T>(ctx.GetPlace());
// cnnl softmax only supports 3 dims, so regard every shape as [d1, d2, d3]
const int cnnl_softmax_dims = 3;
const int d1 = SizeToAxis(axis, logits->dims());
const int d2_logits = logits->dims()[axis];
const int d2_labels = labels->dims()[axis];
const int d3 = SizeOutAxis(axis, logits->dims());
// CNNL_SOFTMAX_MODE_LOW_DIMENSION has better performance, so use it
// whenever possible.
cnnlSoftmaxMode_t mode = CNNL_SOFTMAX_MODE_LOW_DIMENSION;
std::vector<int> regard_logits_shape{d1, 1, d2_logits};
std::vector<int> regard_labels_shape{d1, 1, d2_labels};
std::vector<int> regard_loss_shape{d1, 1, 1};
if (d3 != 1) {
mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION;
regard_logits_shape = {d1, d2_logits, d3};
regard_labels_shape = {d1, d2_labels, d3};
regard_loss_shape = {d1, 1, d3};
}
MLUCnnlTensorDesc logits_desc(cnnl_softmax_dims, regard_logits_shape.data(),
ToCnnlDataType<T>());
MLUCnnlTensorDesc labels_desc(cnnl_softmax_dims, regard_labels_shape.data(),
ToCnnlDataType<T>());
MLUCnnlTensorDesc loss_desc(cnnl_softmax_dims, regard_loss_shape.data(),
ToCnnlDataType<T>());
const cnnlSoftmaxAlgorithm_t algo = CNNL_SOFTMAX_ACCURATE;
MLUCnnl::SoftmaxForward(ctx, algo, mode, NULL, logits_desc.get(),
GetBasePtr(logits), NULL, logits_desc.get(),
GetBasePtr(softmax));
if (soft_label) {
const cnnlComputationPreference_t prefer =
CNNL_COMPUTATION_HIGH_PRECISION;
MLUCnnl::SoftmaxCrossEntropyWithLogits(
ctx, mode, prefer, logits_desc.get(), GetBasePtr(logits),
labels_desc.get(), GetBasePtr(labels), loss_desc.get(),
GetBasePtr(loss), logits_desc.get(), GetBasePtr(backprop));
} else {
PADDLE_ENFORCE_EQ(d3, 1,
platform::errors::InvalidArgument(
"If soft_label=False, axis must be -1 or"
" can be regard as last dimention in mlu kernel."));
framework::Tensor labels_int32(VT::INT32);
labels_int32.Resize(labels->dims());
labels_int32.mutable_data<int32_t>(ctx.GetPlace());
MLUCnnlTensorDesc labels_int64_desc(*labels);
MLUCnnlTensorDesc labels_int32_desc(labels_int32);
cnnlCastDataType_t cast_type = GetCastDataType(VT::INT64, VT::INT32);
MLUCnnl::Cast(ctx, cast_type, labels_int64_desc.get(), GetBasePtr(labels),
labels_int32_desc.get(), GetBasePtr(&labels_int32));
const int regard_sparse_shape[cnnl_softmax_dims - 1] = {d1, 1};
MLUCnnlTensorDesc sparse_labels_desc(cnnl_softmax_dims - 1,
regard_sparse_shape,
ToCnnlDataType<int32_t>());
MLUCnnlTensorDesc sparse_loss_desc(
cnnl_softmax_dims - 1, regard_sparse_shape, ToCnnlDataType<T>());
MLUCnnl::SparseSoftmaxXentWithLogits(
ctx, mode, logits_desc.get(), GetBasePtr(logits),
sparse_labels_desc.get(), GetBasePtr(&labels_int32),
sparse_loss_desc.get(), GetBasePtr(loss), logits_desc.get(),
GetBasePtr(backprop));
}
}
};
template <typename T>
class SoftmaxWithCrossEntropyGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* backprop = ctx.Input<Tensor>("Backprop");
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* logits_grad = ctx.Output<Tensor>(framework::GradVarName("Logits"));
PADDLE_ENFORCE_NOT_NULL(backprop,
platform::errors::PreconditionNotMet(
"backprop should not be null in MLU kernel of "
"softmax_with_cross_entropy_grad."));
logits_grad->mutable_data<T>(ctx.GetPlace());
MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
CNNL_NOT_PROPAGATE_NAN);
MLUCnnlTensorDesc backprop_desc(*backprop);
MLUCnnlTensorDesc loss_grad_desc(*loss_grad);
MLUCnnlTensorDesc logits_grad_desc(*logits_grad);
MLUCnnl::OpTensor(ctx, mul_op_desc.get(), backprop_desc.get(),
GetBasePtr(backprop), loss_grad_desc.get(),
GetBasePtr(loss_grad), logits_grad_desc.get(),
GetBasePtr(logits_grad), ToCnnlDataType<T>());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(
softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyMLUKernel<float>,
ops::SoftmaxWithCrossEntropyMLUKernel<paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradMLUKernel<float>,
ops::SoftmaxWithCrossEntropyGradMLUKernel<paddle::platform::float16>);
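As a reading aid, a minimal NumPy sketch (illustrative only, not kernel code) of two things the MLU kernels above do: collapsing an arbitrary shape to the 3-D [d1, d2, d3] layout that cnnl softmax requires, and the grad kernel's elementwise product Logits@GRAD = Loss@GRAD * Backprop. The helper name collapse_to_3d and the assumption that cnnl's OpTensor broadcasts the [N, 1] loss gradient over the class axis are mine, not the patch's:

```python
import numpy as np

def collapse_to_3d(shape, axis):
    """Mimic SizeToAxis / SizeOutAxis: fold dims before and after `axis` (axis already canonicalized)."""
    d1 = int(np.prod(shape[:axis], dtype=np.int64))
    d2 = shape[axis]
    d3 = int(np.prod(shape[axis + 1:], dtype=np.int64))
    # d3 == 1 -> CNNL_SOFTMAX_MODE_LOW_DIMENSION, otherwise MEDIUM_DIMENSION.
    return d1, d2, d3

print(collapse_to_3d([8, 16, 10], axis=2))  # (128, 10, 1) -> LOW_DIMENSION
print(collapse_to_3d([8, 16, 10], axis=1))  # (8, 16, 10)  -> MEDIUM_DIMENSION

# Grad relation: Logits@GRAD = Loss@GRAD * Backprop, with the
# [N, 1] loss gradient broadcast over the class axis.
loss_grad = np.ones((4, 1), dtype=np.float32)
backprop = np.random.rand(4, 3).astype(np.float32)
logits_grad = loss_grad * backprop
```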
@@ -1287,7 +1287,7 @@ def softmax_with_cross_entropy(logits,
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs = {'Softmax': softmax, 'Loss': loss}
if core.is_compiled_with_npu():
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
backprop = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs['Backprop'] = backprop
helper.append_op(
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from test_softmax_op import stable_softmax
from test_softmax_with_cross_entropy_op import cross_entropy
paddle.enable_static()
SEED = 2021
class TestSoftmaxWithCrossEntropyOp(OpTest):
def set_mlu(self):
self.__class__.use_mlu = True
def init_dtype(self):
self.dtype = np.float32
def initParams(self):
self.set_mlu()
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False
self.place = paddle.device.MLUPlace(0)
self.soft_label = False
self.init_dtype()
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
np.random.seed(SEED)
def setUp(self):
self.initParams()
logits = getattr(
self, "logits",
np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype))
softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
if self.soft_label:
labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
labels /= np.sum(labels, axis=self.axis, keepdims=True)
else:
axis_dim = self.shape[self.axis]
self.shape[self.axis] = 1
labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
self.ignore_index)
one_hot_label = np.eye(axis_dim)[labels.reshape(-1)]
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Backprop": (softmax - one_hot_label).astype(self.dtype),
"Softmax": softmax.astype(self.dtype),
"Loss": loss.astype(self.dtype)
}
self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
"ignore_index": self.ignore_index,
}
if self.axis != -1:
self.attrs['axis'] = self.axis
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
if self.dtype == np.float16:
return
# fp32 has low precision here; both cpu and mlu need a relaxed max_relative_error when using fp32
self.check_grad_with_place(
self.place, ['Logits'],
'Loss',
numeric_grad_delta=0.001,
max_relative_error=0.5)
class TestPowNet(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
sum = paddle.add(a, b)
z = paddle.pow(sum, 2.0)
fc_1 = fluid.layers.fc(input=z, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2)
cost = fluid.layers.softmax_with_cross_entropy(prediction, label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.device.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred))
self.assertTrue(np.allclose(mlu_loss, cpu_loss))
if __name__ == '__main__':
unittest.main()