未验证 提交 7ec8459c 编写于 作者: O OleNet 提交者: GitHub

[NPU] Support softmax npu kernel (#31564)

上级 7875bcb8
......@@ -187,4 +187,5 @@ endif()
if(WITH_ASCEND_CL)
cc_test(gelu_op_npu_test SRCS gelu_op_npu_test.cc DEPS op_registry gelu_op scope device_context enforce executor)
cc_test(top_k_op_npu_test SRCS top_k_op_npu_test.cc DEPS op_registry top_k_op scope device_context enforce executor)
endif()
......@@ -79,11 +79,13 @@ class SoftmaxOp : public framework::OperatorWithKernel {
#endif
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifndef PADDLE_WITH_ASCEND_CL
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument(
"float16 can only be used on GPU place"));
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/softmax_op.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SoftmaxNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::LoDTensor>("X");
auto axis = ctx.Attr<int>("axis");
std::vector<int> axes;
axes.push_back(axis);
framework::NPUAttributeMap attr_input = {{"axes", axes}};
auto* out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto runner = NpuOpRunner("SoftmaxV2", {*in}, {*out}, attr_input);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};
template <typename DeviceContext, typename T>
class SoftmaxGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<framework::LoDTensor>("Out");
auto* dOut = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dims = dX->dims();
const int rank = dims.size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
int64_t first_dim = 1;
int64_t sec_dim = 1;
for (int i = 0; i < axis; i++) {
first_dim *= dims[i];
}
for (int i = axis; i < rank; i++) {
sec_dim *= dims[i];
}
Tensor tmp_out;
tmp_out.ShareDataWith(*out).Resize({first_dim, sec_dim});
Tensor tmp_dOut;
tmp_dOut.ShareDataWith(*dOut).Resize({first_dim, sec_dim});
dX->Resize(framework::make_ddim({first_dim, sec_dim}));
dX->mutable_data<T>(ctx.GetPlace());
framework::NPUAttributeMap attr_input = {};
auto runner = NpuOpRunner(std::string("SoftmaxGrad"), {tmp_out, tmp_dOut},
{*dX}, attr_input);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
dX->Resize(dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
softmax, ops::SoftmaxNPUKernel<plat::NPUDeviceContext, float>,
ops::SoftmaxNPUKernel<plat::NPUDeviceContext, double>,
ops::SoftmaxNPUKernel<plat::NPUDeviceContext, plat::float16>);
REGISTER_OP_NPU_KERNEL(
softmax_grad, ops::SoftmaxGradNPUKernel<plat::NPUDeviceContext, float>,
ops::SoftmaxGradNPUKernel<plat::NPUDeviceContext, double>,
ops::SoftmaxGradNPUKernel<plat::NPUDeviceContext,
paddle::platform::float16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef _WIN32
#include <unistd.h>
#endif
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, NPU);
template <typename T>
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto x = scope->Var("X");
auto tensor_x = x->GetMutable<f::LoDTensor>();
std::vector<T> init;
for (int i = 3; i < 9; ++i) {
init.push_back(static_cast<T>(i));
}
TensorFromVector(init, ctx, tensor_x);
tensor_x->Resize({2, 3});
ctx.Wait();
auto place = ctx.GetPlace();
auto out = scope->Var("Out");
auto tensor_out = out->GetMutable<f::LoDTensor>();
tensor_out->Resize({2, 3});
tensor_out->mutable_data<T>(place); // allocate
// run
int axis = 1;
f::AttributeMap attrs = {
{"axis", axis},
{"use_cudnn", false},
{"use_mkldnn", false},
{"mkldnn_data_type", std::string("float32")},
{"is_test", false}, };
auto op =
f::OpRegistry::CreateOp("softmax", {{"X", {"X"}}},
{{"Out", {"Out"}}}, attrs);
op->Run(*scope, place);
ctx.Wait();
std::vector<T> out_vec;
TensorToVector(*tensor_out, ctx, &out_vec);
for (int i = 0; i < static_cast<int>(out_vec.size()); ++i) {
VLOG(3) << "out_vec[" << i << "] : "<< out_vec[i];
}
ctx.Wait();
EXPECT_EQ((uint32_t)out_vec.size(), (uint32_t)(6));
}
template <typename T>
void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto out = scope->Var("Out");
auto tensor_out = out->GetMutable<f::LoDTensor>();
std::vector<T> out_init;
out_init.push_back(static_cast<T>(0.6670));
out_init.push_back(static_cast<T>(0.5888));
out_init.push_back(static_cast<T>(0.4543));
out_init.push_back(static_cast<T>(0.3330));
out_init.push_back(static_cast<T>(0.4112));
out_init.push_back(static_cast<T>(0.5457));
TensorFromVector(out_init, ctx, tensor_out);
tensor_out->Resize({2, 3});
ctx.Wait();
auto dout = scope->Var("DOut");
auto tensor_dout = dout->GetMutable<f::LoDTensor>();
std::vector<T> dout_init;
for (int i = 0; i < 6; ++i) {
dout_init.push_back(static_cast<T>(1.0));
}
TensorFromVector(dout_init, ctx, tensor_dout);
tensor_dout->Resize({2, 3});
ctx.Wait();
auto dx = scope->Var("DX");
auto tensor_dx = dx->GetMutable<f::LoDTensor>();
ctx.Wait();
// run
f::AttributeMap attrs;
attrs = {
{"name", std::string("softmax_grad")},
{"axis", static_cast<int>(0)},
{"use_cudnn", false},
{"use_mkldnn", false},
{"mkldnn_data_type", std::string("float32")},
{"is_test", false},
{"data_format", std::string("AnyLayout")}, };
auto op =
f::OpRegistry::CreateOp("softmax_grad",
{{"Out", {"Out"}},
{"Out@GRAD", {"DOut"}}},
{{"X@GRAD", {"DX"}}}, attrs);
auto place = ctx.GetPlace();
op->Run(*scope, place);
ctx.Wait();
EXPECT_EQ((uint32_t)tensor_dx->dims()[0], (uint32_t)(2));
EXPECT_EQ((uint32_t)tensor_dx->dims()[1], (uint32_t)(3));
ctx.Wait();
std::vector<float> out_vec;
TensorToVector(*tensor_dx, ctx, &out_vec);
ctx.Wait();
EXPECT_EQ((uint32_t)out_vec.size(), (uint32_t)(6));
EXPECT_NEAR((float)out_vec[0], (float)(-0.4737), 0.1);
EXPECT_NEAR((float)out_vec[1], (float)(-0.4181), 0.1);
EXPECT_NEAR((float)out_vec[2], (float)(-0.3226), 0.1);
EXPECT_NEAR((float)out_vec[3], (float)(-0.0965), 0.1);
EXPECT_NEAR((float)out_vec[4], (float)(-0.1192), 0.1);
EXPECT_NEAR((float)out_vec[5], (float)(-0.1582), 0.1);
}
TEST(softmax, NPU_fp32) {
f::Scope scope;
p::NPUDeviceContext ctx(p::NPUPlace(0));
Compare<float>(&scope, ctx);
}
TEST(softmax_grad, NPU_fp32) {
f::Scope scope;
p::NPUDeviceContext ctx(p::NPUPlace(0));
CompareGrad<float>(&scope, ctx);
}
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
paddle.enable_static()
SEED = 2021
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestSoftmax(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "softmax"
self.init_dtype()
x = np.random.random([3, 3]).astype(self.dtype)
np_out = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
self.inputs = {'X': x}
self.attrs = {}
self.outputs = {'Out': np_out}
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestSoftmaxNet(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(4, 32)).astype('float32')
b_np = np.random.random(size=(4, 32)).astype('float32')
label_np = np.random.randint(2, size=(4, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[4, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[4, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[4, 1], dtype='int64')
c = paddle.multiply(a, b)
d = paddle.sqrt(c)
# 4 x 128
fc_1 = fluid.layers.fc(input=d, size=128)
# 4 x 2
prediction = fluid.layers.fc(input=fc_1, size=2)
# 4 x 2
prob = fluid.layers.softmax(prediction, axis=1)
cost = fluid.layers.cross_entropy(input=prob, label=label)
loss = fluid.layers.mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_npu(self):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-2))
self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-2))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册