未验证 提交 5118968d 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] add npu kernel for softmax_with_cross_entropy (#31656)

* init

* fix bugs
上级 925432d8
......@@ -82,6 +82,8 @@ class NpuOpRunner {
aclopAttr *attr_{nullptr};
};
aclDataType ConvertToNpuDtype(framework::proto::VarType::Type dtype);
} // namespace operators
} // namespace paddle
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/softmax.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/softmax_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* logits = ctx.Input<Tensor>("Logits");
auto* labels = ctx.Input<Tensor>("Label");
auto* softmax = ctx.Output<Tensor>("Softmax");
auto* loss = ctx.Output<Tensor>("Loss");
int cls_num = logits->dims()[1];
const int rank = logits->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
std::vector<int> axes;
for (auto i = axis; i < logits->dims().size(); ++i) {
axes.push_back(i);
}
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// softmax
softmax->mutable_data<T>(ctx.GetPlace());
auto runner_softmax =
NpuOpRunner("SoftmaxV2", {*logits}, {*softmax}, {{"axes", axes}});
runner_softmax.Run(stream);
// cast label from int64/int32 to int32
Tensor tmp_labels(framework::proto::VarType::INT32);
if (labels->type() != framework::proto::VarType::INT32) {
tmp_labels.Resize(labels->dims());
tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32);
auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32);
auto runner_cast_label =
NpuOpRunner("Cast", {*labels}, {tmp_labels},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_label.Run(stream);
labels = &tmp_labels;
}
// on and off
Tensor on_tensor(framework::proto::VarType::INT32);
on_tensor.mutable_data<int>({1}, ctx.GetPlace());
TensorFromVector(std::vector<int>{static_cast<int>(1)},
ctx.device_context(), &on_tensor);
Tensor off_tensor(framework::proto::VarType::INT32);
off_tensor.mutable_data<int>({1}, ctx.GetPlace());
TensorFromVector(std::vector<int>{static_cast<int>(0)},
ctx.device_context(), &off_tensor);
// one_hot
Tensor tmp_onehot(on_tensor.type());
tmp_onehot.Resize(logits->dims());
tmp_onehot.mutable_data<int>(ctx.GetPlace());
auto runner_onehot =
NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot},
{{"axis", -1}, {"depth", cls_num}});
runner_onehot.Run(stream);
// cast one_hot from int32 to T
Tensor cast_onehot(logits->type());
cast_onehot.Resize(tmp_onehot.dims());
cast_onehot.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(logits->type());
auto runner_cast_onehot =
NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_onehot.Run(stream);
// SoftmaxCrossEntropyWithLogits
Tensor backprop(logits->type());
backprop.Resize(logits->dims());
backprop.mutable_data<T>(ctx.GetPlace());
loss->mutable_data<T>(ctx.GetPlace());
// SoftmaxCrossEntropyWithLogits requires loss to be of shape [batch_size]
auto loss_dims = loss->dims();
loss->Resize({loss_dims[0]});
auto runner_s = NpuOpRunner("SoftmaxCrossEntropyWithLogits",
{*logits, cast_onehot}, {*loss, backprop}, {});
runner_s.Run(stream);
loss->Resize(loss_dims);
}
};
template <typename DeviceContext, typename T>
class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* labels = ctx.Input<Tensor>("Label");
auto* softmax = ctx.Input<Tensor>("Softmax");
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* logits_grad = ctx.Output<Tensor>(framework::GradVarName("Logits"));
int cls_num = softmax->dims()[1];
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// cast label from int64/int32 to int32
Tensor tmp_labels(framework::proto::VarType::INT32);
if (labels->type() != framework::proto::VarType::INT32) {
tmp_labels.Resize(labels->dims());
tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32);
auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32);
auto runner_cast_label =
NpuOpRunner("Cast", {*labels}, {tmp_labels},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_label.Run(stream);
labels = &tmp_labels;
}
// on and off
Tensor on_tensor(framework::proto::VarType::INT32);
on_tensor.mutable_data<int>({1}, ctx.GetPlace());
TensorFromVector(std::vector<int>{static_cast<int>(1)},
ctx.device_context(), &on_tensor);
Tensor off_tensor(framework::proto::VarType::INT32);
off_tensor.mutable_data<int>({1}, ctx.GetPlace());
TensorFromVector(std::vector<int>{static_cast<int>(0)},
ctx.device_context(), &off_tensor);
// one_hot
Tensor tmp_onehot(on_tensor.type());
tmp_onehot.Resize(softmax->dims());
tmp_onehot.mutable_data<int>(ctx.GetPlace());
auto runner_onehot =
NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot},
{{"axis", -1}, {"depth", cls_num}});
runner_onehot.Run(stream);
// cast one_hot from int32 to T
Tensor cast_onehot(softmax->type());
cast_onehot.Resize(tmp_onehot.dims());
cast_onehot.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(softmax->type());
auto runner_cast_onehot =
NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_onehot.Run(stream);
// sub
Tensor tmp_sub(softmax->type());
tmp_sub.Resize(softmax->dims());
tmp_sub.mutable_data<T>(ctx.GetPlace());
auto runner_sub =
NpuOpRunner("Sub", {*softmax, cast_onehot}, {tmp_sub}, {});
runner_sub.Run(stream);
// mul
logits_grad->mutable_data<T>(ctx.GetPlace());
auto runner_mul =
NpuOpRunner("Mul", {*loss_grad, tmp_sub}, {*logits_grad}, {});
runner_mul.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyNPUKernel<paddle::platform::NPUDeviceContext,
float>,
ops::SoftmaxWithCrossEntropyNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradNPUKernel<
paddle::platform::NPUDeviceContext, float>,
ops::SoftmaxWithCrossEntropyGradNPUKernel<
paddle::platform::NPUDeviceContext, paddle::platform::float16>);
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from test_softmax_op import stable_softmax
from test_softmax_with_cross_entropy_op import cross_entropy
paddle.enable_static()
SEED = 2021
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestSoftmaxWithCrossEntropyOp(OpTest):
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.float32
def initParams(self):
self.set_npu()
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False
self.place = paddle.NPUPlace(0)
self.soft_label = False
self.init_dtype()
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
np.random.seed(SEED)
def setUp(self):
self.initParams()
logits = getattr(
self, "logits",
np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype))
softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
if self.soft_label:
labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
labels /= np.sum(labels, axis=self.axis, keepdims=True)
else:
axis_dim = self.shape[self.axis]
self.shape[self.axis] = 1
labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
self.ignore_index)
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype(self.dtype),
"Loss": loss.astype(self.dtype)
}
self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
"ignore_index": self.ignore_index,
}
if self.axis != -1:
self.attrs['axis'] = self.axis
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)
# TODO(ascendrc): Add grad test
# def test_check_grad(self):
# if self.dtype == np.float16:
# return
# self.check_grad(['X'], 'Out')
#
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestPowNet(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
sum = paddle.add(a, b)
z = paddle.pow(sum, 2.0)
fc_1 = fluid.layers.fc(input=z, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2)
cost = fluid.layers.softmax_with_cross_entropy(prediction, label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_npu(self):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred))
self.assertTrue(np.allclose(npu_loss, cpu_loss))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册