Unverified commit e928274c, authored by Qi Li, committed by GitHub

[NPU] log_softmax_grad, test=develop (#35484)

* [NPU] log_softmax_grad, test=develop

* remove debug files, test=develop

* update lookup_table_v2 for CANN 5.0.x, test=develop
Parent e9ae8dd0
@@ -14,9 +14,13 @@
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
-template <typename DeviceContext, typename T>
+using NPUDeviceContext = platform::NPUDeviceContext;
+template <typename T>
class LogSoftmaxNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
@@ -24,22 +28,47 @@ class LogSoftmaxNPUKernel : public framework::OpKernel<T> {
    auto* Out = ctx.Output<framework::Tensor>("Out");
    const int rank = X->dims().size();
    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
-    std::vector<int> axes;
-    axes.push_back(axis);
-    framework::NPUAttributeMap attr_input = {{"axes", axes}};
    Out->mutable_data<T>(ctx.GetPlace());
-    const auto& runner = NpuOpRunner("LogSoftmaxV2", {*X}, {*Out}, attr_input);
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
-    runner.Run(stream);
+    if (X->numel() != 0) {
+      auto stream = ctx.template device_context<NPUDeviceContext>().stream();
+      const auto& runner = NpuOpRunner("LogSoftmaxV2", {*X}, {*Out},
+                                       {{"axes", std::vector<int>{axis}}});
+      runner.Run(stream);
+    }
  }
};
+template <typename T>
+class LogSoftmaxGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* Out = ctx.Input<framework::Tensor>("Out");
+    auto* dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    const int rank = dOut->dims().size();
+    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
+    // allocate memory on device.
+    dX->mutable_data<T>(ctx.GetPlace());
+    if (dOut->numel() != 0) {
+      auto stream = ctx.template device_context<NPUDeviceContext>().stream();
+      const auto& runner = NpuOpRunner("LogSoftmaxGrad", {*dOut, *Out}, {*dX},
+                                       {{"axis", std::vector<int>{axis}}});
+      runner.Run(stream);
+    }
+  }
+};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
+namespace plat = paddle::platform;
-REGISTER_OP_NPU_KERNEL(
-    log_softmax,
-    ops::LogSoftmaxNPUKernel<paddle::platform::NPUDeviceContext, float>);
+REGISTER_OP_NPU_KERNEL(log_softmax, ops::LogSoftmaxNPUKernel<float>,
+                       ops::LogSoftmaxNPUKernel<plat::float16>);
+REGISTER_OP_NPU_KERNEL(log_softmax_grad, ops::LogSoftmaxGradNPUKernel<float>,
+                       ops::LogSoftmaxGradNPUKernel<plat::float16>);
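For reference, the LogSoftmaxV2 / LogSoftmaxGrad pair registered above computes the standard log-softmax and its gradient along the chosen axis (writing y = log_softmax(x)):

```latex
y_i = x_i - \log\sum_j e^{x_j},
\qquad
\frac{\partial L}{\partial x_i}
    = \frac{\partial L}{\partial y_i} - e^{y_i} \sum_j \frac{\partial L}{\partial y_j}.
```

Since e^{y_i} is exactly softmax(x)_i, the backward pass can be formed from the saved forward output alone, which is why the grad kernel feeds {*dOut, *Out} to the runner rather than the input X.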
@@ -29,11 +29,6 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
    auto *output_t = ctx.Output<framework::LoDTensor>("Out");  // float tensor
    auto *table_t = ctx.Input<framework::LoDTensor>("W");
-    // It seems cann 20.1 accepts int64, but cann 20.2+ not.
-    PADDLE_ENFORCE_EQ(ids_t->type(), framework::proto::VarType::INT32,
-                      platform::errors::Unimplemented(
-                          "The index of LookupTableV2 should be int32."));
    auto *table_var = ctx.InputVar("W");
    PADDLE_ENFORCE_EQ(
        table_var->IsType<framework::LoDTensor>(), true,
......
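With the int32-only check removed above, the NPU lookup_table_v2 kernel can take Paddle's default int64 index tensors when running against CANN 5.0.x. A minimal sketch, assuming a Paddle build compiled with NPU support (the device string is otherwise unavailable):

```python
import paddle

paddle.set_device('npu:0')  # assumption: an NPU-enabled build with CANN 5.0.x

ids = paddle.to_tensor([[1], [3]], dtype='int64')  # int64, Paddle's default index dtype
emb = paddle.nn.Embedding(num_embeddings=10, embedding_dim=4)
out = emb(ids)  # dispatches to the lookup_table_v2 NPU kernel
print(out.shape)  # [2, 1, 4]
```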
@@ -22,9 +22,10 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid import core
import paddle.nn.functional as F
+from test_log_softmax import ref_log_softmax, ref_log_softmax_grad
paddle.enable_static()
np.random.seed(10)
class TestLogSoftmaxNPUOp(OpTest):
@@ -55,10 +56,16 @@ class TestLogSoftmaxNPUOp(OpTest):
        pass

    def test_check_output(self):
-        self.check_output_with_place(self.place)
+        if self.dtype == np.float16:
+            self.check_output_with_place(self.place, atol=1e-2)
+        else:
+            self.check_output_with_place(self.place)

    def test_check_grad(self):
-        pass
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['X'], ['Out'], user_defined_grads=[self.x_grad])
def test_class(op_type, typename):
@@ -88,8 +95,73 @@ def test_class2(op_type, typename):
    globals()[cls_name] = TestLogSoftmaxAxis

-for _typename in {'float32'}:
+for _typename in {np.float32, np.float16}:
    test_class("logsoftmax", _typename)
    test_class2("logsoftmax", _typename)
+class TestNNLogSoftmaxAPI(unittest.TestCase):
+    def setUp(self):
+        self.x_shape = [2, 3, 4, 5]
+        self.x = np.random.uniform(-1., 1., self.x_shape).astype(np.float32)
+        self.place = paddle.NPUPlace(0) \
+            if paddle.fluid.core.is_compiled_with_npu() \
+            else paddle.CPUPlace()
+
+    def check_api(self, axis=-1):
+        ref_out = np.apply_along_axis(ref_log_softmax, axis, self.x)
+        logsoftmax = paddle.nn.LogSoftmax(axis)
+        # test static api
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data(name='x', shape=self.x_shape)
+            y = logsoftmax(x)
+            exe = paddle.static.Executor(self.place)
+            out = exe.run(feed={'x': self.x}, fetch_list=[y])
+            self.assertTrue(np.allclose(out[0], ref_out))
+        # test dygraph api
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x)
+        y = logsoftmax(x)
+        self.assertTrue(np.allclose(y.numpy(), ref_out))
+        paddle.enable_static()
+
+    def test_check_api(self):
+        for axis in [-1, 1]:
+            self.check_api(axis)
+class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase):
+    def setUp(self):
+        self.x_shape = [2, 3, 4, 5]
+        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
+        self.place = paddle.NPUPlace(0) \
+            if paddle.fluid.core.is_compiled_with_npu() \
+            else paddle.CPUPlace()
+
+    def check_api(self, axis=-1, dtype=None):
+        x = self.x.copy()
+        if dtype is not None:
+            x = x.astype(dtype)
+        ref_out = np.apply_along_axis(ref_log_softmax, axis, x)
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.fluid.data(name='x', shape=self.x_shape)
+            y = F.log_softmax(x, axis, dtype)
+            exe = paddle.static.Executor(self.place)
+            out = exe.run(feed={'x': self.x}, fetch_list=[y])
+            self.assertTrue(np.allclose(out[0], ref_out))
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x)
+        y = F.log_softmax(x, axis, dtype)
+        self.assertTrue(np.allclose(y.numpy(), ref_out))
+        paddle.enable_static()
+
+    def test_check_api(self):
+        for axis in [-1, 1]:
+            self.check_api(axis)
if __name__ == '__main__':
    unittest.main()
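The user_defined_grads check above leans on the ref_log_softmax and ref_log_softmax_grad helpers imported from test_log_softmax. Their bodies are not part of this diff; a plausible NumPy reconstruction, consistent with the gradient formula given earlier and with how the tests call them, is:

```python
import numpy as np

def ref_log_softmax(x):
    # log-softmax of a 1-D slice, shifted by max(x) for numerical stability
    shiftx = x - np.max(x)
    return shiftx - np.log(np.exp(shiftx).sum())

def ref_log_softmax_grad(x, axis):
    # dX = dOut - exp(Out) * sum(dOut) along `axis`, assuming a uniform
    # upstream gradient of 1/x.size (the check_grad convention)
    if axis < 0:
        axis += len(x.shape)
    out = np.apply_along_axis(ref_log_softmax, axis, x)
    dout = np.full_like(x, fill_value=1. / x.size)
    return dout - np.exp(out) * dout.sum(axis=axis, keepdims=True)
```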