diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 6fe18f2479478a49819da2608dc7c3a0bf5d3017..a82adb4254b3124275aa39712fc2899147faefa6 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -187,4 +187,5 @@ endif() if(WITH_ASCEND_CL) cc_test(gelu_op_npu_test SRCS gelu_op_npu_test.cc DEPS op_registry gelu_op scope device_context enforce executor) +cc_test(top_k_op_npu_test SRCS top_k_op_npu_test.cc DEPS op_registry top_k_op scope device_context enforce executor) endif() diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index ff750ab47a963c2f1d24e0f74b616534acaa2c41..7d4194227b4cd4b989d2c4a274cd730a2b78f90a 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -79,11 +79,13 @@ class SoftmaxOp : public framework::OperatorWithKernel { #endif auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); +#ifndef PADDLE_WITH_ASCEND_CL if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, platform::errors::InvalidArgument( "float16 can only be used on GPU place")); } +#endif return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, library_); diff --git a/paddle/fluid/operators/softmax_op_npu.cc b/paddle/fluid/operators/softmax_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c7e08b7a2ee0b268e5e8440c601d1976c537282 --- /dev/null +++ b/paddle/fluid/operators/softmax_op_npu.cc @@ -0,0 +1,102 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/operators/npu_op_runner.h" +#include "paddle/fluid/operators/softmax_op.h" + +namespace paddle { +namespace operators { + +template +class SoftmaxNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto axis = ctx.Attr("axis"); + std::vector axes; + axes.push_back(axis); + framework::NPUAttributeMap attr_input = {{"axes", axes}}; + + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto runner = NpuOpRunner("SoftmaxV2", {*in}, {*out}, attr_input); + + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +template +class SoftmaxGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Input("Out"); + auto* dOut = ctx.Input(framework::GradVarName("Out")); + + auto* dX = ctx.Output(framework::GradVarName("X")); + + auto dims = dX->dims(); + const int rank = dims.size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + int64_t first_dim = 1; + int64_t sec_dim = 1; + for (int i = 0; i < axis; i++) { + first_dim *= dims[i]; + } + for (int i = axis; i < rank; i++) { + sec_dim *= dims[i]; + } + + Tensor tmp_out; + tmp_out.ShareDataWith(*out).Resize({first_dim, sec_dim}); + + Tensor tmp_dOut; + tmp_dOut.ShareDataWith(*dOut).Resize({first_dim, sec_dim}); + + + dX->Resize(framework::make_ddim({first_dim, sec_dim})); + dX->mutable_data(ctx.GetPlace()); + + framework::NPUAttributeMap attr_input = {}; + auto runner = NpuOpRunner(std::string("SoftmaxGrad"), {tmp_out, tmp_dOut}, + {*dX}, attr_input); + + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + + dX->Resize(dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL( + softmax, ops::SoftmaxNPUKernel, + ops::SoftmaxNPUKernel, + ops::SoftmaxNPUKernel); + +REGISTER_OP_NPU_KERNEL( + softmax_grad, ops::SoftmaxGradNPUKernel, + ops::SoftmaxGradNPUKernel, + ops::SoftmaxGradNPUKernel); diff --git a/paddle/fluid/operators/softmax_op_npu_test.cc b/paddle/fluid/operators/softmax_op_npu_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..89357705ce0e6c8711567ed031b4fb1640bf7441 --- /dev/null +++ b/paddle/fluid/operators/softmax_op_npu_test.cc @@ -0,0 +1,175 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef _WIN32 +#include +#endif + +#include +#include // NOLINT +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/string/printf.h" +#include "paddle/fluid/framework/tensor_util.h" + + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; + +USE_OP(softmax); +USE_OP_DEVICE_KERNEL(softmax, NPU); + +template +void Compare(f::Scope* scope, const p::DeviceContext& ctx) { + // init + auto x = scope->Var("X"); + auto tensor_x = x->GetMutable(); + + std::vector init; + for (int i = 3; i < 9; ++i) { + init.push_back(static_cast(i)); + } + + TensorFromVector(init, ctx, tensor_x); + tensor_x->Resize({2, 3}); + + ctx.Wait(); + + auto place = ctx.GetPlace(); + auto out = scope->Var("Out"); + auto tensor_out = out->GetMutable(); + tensor_out->Resize({2, 3}); + tensor_out->mutable_data(place); // allocate + + // run + int axis = 1; + f::AttributeMap attrs = { + {"axis", axis}, + {"use_cudnn", false}, + {"use_mkldnn", false}, + {"mkldnn_data_type", std::string("float32")}, + {"is_test", false}, }; + + auto op = + f::OpRegistry::CreateOp("softmax", {{"X", {"X"}}}, + {{"Out", {"Out"}}}, attrs); + + op->Run(*scope, place); + ctx.Wait(); + + std::vector out_vec; + TensorToVector(*tensor_out, ctx, &out_vec); + + for (int i = 0; i < static_cast(out_vec.size()); ++i) { + VLOG(3) << "out_vec[" << i << "] : "<< out_vec[i]; + } + + ctx.Wait(); + + EXPECT_EQ((uint32_t)out_vec.size(), (uint32_t)(6)); +} + + +template +void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) { + // init + auto out = scope->Var("Out"); + auto tensor_out = out->GetMutable(); + + std::vector out_init; + + out_init.push_back(static_cast(0.6670)); + out_init.push_back(static_cast(0.5888)); + out_init.push_back(static_cast(0.4543)); + out_init.push_back(static_cast(0.3330)); + out_init.push_back(static_cast(0.4112)); + out_init.push_back(static_cast(0.5457)); + + TensorFromVector(out_init, ctx, tensor_out); + tensor_out->Resize({2, 3}); + + ctx.Wait(); + + auto dout = scope->Var("DOut"); + auto tensor_dout = dout->GetMutable(); + + std::vector dout_init; + for (int i = 0; i < 6; ++i) { + dout_init.push_back(static_cast(1.0)); + } + + TensorFromVector(dout_init, ctx, tensor_dout); + tensor_dout->Resize({2, 3}); + + ctx.Wait(); + + auto dx = scope->Var("DX"); + auto tensor_dx = dx->GetMutable(); + + ctx.Wait(); + + // run + f::AttributeMap attrs; + attrs = { + {"name", std::string("softmax_grad")}, + {"axis", static_cast(0)}, + {"use_cudnn", false}, + {"use_mkldnn", false}, + {"mkldnn_data_type", std::string("float32")}, + {"is_test", false}, + {"data_format", std::string("AnyLayout")}, }; + auto op = + f::OpRegistry::CreateOp("softmax_grad", + {{"Out", {"Out"}}, + {"Out@GRAD", {"DOut"}}}, + {{"X@GRAD", {"DX"}}}, attrs); + + auto place = ctx.GetPlace(); + op->Run(*scope, place); + ctx.Wait(); + + EXPECT_EQ((uint32_t)tensor_dx->dims()[0], (uint32_t)(2)); + EXPECT_EQ((uint32_t)tensor_dx->dims()[1], (uint32_t)(3)); + + ctx.Wait(); + + std::vector out_vec; + TensorToVector(*tensor_dx, ctx, &out_vec); + + ctx.Wait(); + + EXPECT_EQ((uint32_t)out_vec.size(), (uint32_t)(6)); + EXPECT_NEAR((float)out_vec[0], (float)(-0.4737), 0.1); + EXPECT_NEAR((float)out_vec[1], (float)(-0.4181), 0.1); + EXPECT_NEAR((float)out_vec[2], (float)(-0.3226), 0.1); + EXPECT_NEAR((float)out_vec[3], (float)(-0.0965), 0.1); + EXPECT_NEAR((float)out_vec[4], (float)(-0.1192), 0.1); + EXPECT_NEAR((float)out_vec[5], (float)(-0.1582), 0.1); +} + +TEST(softmax, NPU_fp32) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + Compare(&scope, ctx); +} + +TEST(softmax_grad, NPU_fp32) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + CompareGrad(&scope, ctx); +} diff --git a/python/paddle/fluid/tests/unittests/npu/test_softmax_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_softmax_op_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ba41943a359ba2103bfd34c722c697d6b01b2f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_softmax_op_npu.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import core + +paddle.enable_static() +SEED = 2021 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestSoftmax(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "softmax" + self.init_dtype() + + x = np.random.random([3, 3]).astype(self.dtype) + np_out = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) + self.inputs = {'X': x} + + self.attrs = {} + self.outputs = {'Out': np_out} + + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestSoftmaxNet(unittest.TestCase): + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(4, 32)).astype('float32') + b_np = np.random.random(size=(4, 32)).astype('float32') + label_np = np.random.randint(2, size=(4, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[4, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[4, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[4, 1], dtype='int64') + + c = paddle.multiply(a, b) + d = paddle.sqrt(c) + + # 4 x 128 + fc_1 = fluid.layers.fc(input=d, size=128) + # 4 x 2 + prediction = fluid.layers.fc(input=fc_1, size=2) + + # 4 x 2 + prob = fluid.layers.softmax(prediction, axis=1) + + cost = fluid.layers.cross_entropy(input=prob, label=label) + loss = fluid.layers.mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_npu(self): + cpu_pred, cpu_loss = self._test(False) + npu_pred, npu_loss = self._test(True) + + self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-2)) + self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-2)) + + +if __name__ == '__main__': + unittest.main()