From c4e0498631989db423b89963de6b0bb0dcac3657 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Mon, 12 Jul 2021 16:14:37 +0800 Subject: [PATCH] [NPU] add dropout npu op (#34081) * add dropout npu op * fix bugs * add unittest * fix bugs * support 1-D input --- paddle/fluid/operators/dropout_op_npu.cc | 199 ++++++++++++ paddle/fluid/operators/npu_op_runner.cc | 14 +- .../unittests/npu/test_dropout_op_npu.py | 297 ++++++++++++++++++ 3 files changed, 507 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/dropout_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_dropout_op_npu.py diff --git a/paddle/fluid/operators/dropout_op_npu.cc b/paddle/fluid/operators/dropout_op_npu.cc new file mode 100644 index 00000000000..b5c8bfff0dc --- /dev/null +++ b/paddle/fluid/operators/dropout_op_npu.cc @@ -0,0 +1,199 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the Licnse. */ + +#include +#include + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class DropoutNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* seed_tensor = + ctx.HasInput("Seed") ? ctx.Input("Seed") : nullptr; + auto* out = ctx.Output("Out"); + auto* mask = ctx.Output("Mask"); + + auto dropout_prob = ctx.Attr("dropout_prob"); + auto is_test = ctx.Attr("is_test"); + + out->mutable_data(ctx.GetPlace()); + + auto stream = + ctx.template device_context() + .stream(); + + if (dropout_prob == 1.) { + const auto& runner_zeros_out = NpuOpRunner("ZerosLike", {*out}, {*out}); + runner_zeros_out.Run(stream); + mask->mutable_data(ctx.GetPlace()); + const auto& runner_zeros_mask = + NpuOpRunner("ZerosLike", {*mask}, {*mask}); + runner_zeros_mask.Run(stream); + return; + } + + // only achive the default `upscale_in_train` method + if (!is_test) { + Tensor tmp_x(x->type()); + Tensor tmp_out(out->type()); + tmp_x.ShareDataWith(*x); + tmp_out.ShareDataWith(*out); + if (x->dims().size() == 1) { + // DropOutDoMask will get error result when input + // is 1-D. Make it become 2-D. + std::vector vec_dim = framework::vectorize(x->dims()); + tmp_x.Resize(framework::make_ddim({vec_dim[0], 1})); + tmp_out.Resize(framework::make_ddim({vec_dim[0], 1})); + } + + int seed = 0; + int seed2 = 0; + float keep_prob = 1. - dropout_prob; + if (seed_tensor) { + std::vector seed_data; + TensorToVector(*seed_tensor, ctx.device_context(), &seed_data); + seed = seed_data[0]; + } else { + seed = ctx.Attr("fix_seed") ? ctx.Attr("seed") : 0; + } + + Tensor keep_prob_tensor(x->type()); + keep_prob_tensor.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&keep_prob_tensor, + static_cast(keep_prob)); + + mask->mutable_data(ctx.GetPlace()); + + // mask used in `DropOutGenMask` NPU OP is different from + // the output `Mask`. + Tensor npu_mask(framework::proto::VarType::UINT8); + uint32_t length = (x->numel() + 128 - 1) / 128 * 128; + npu_mask.Resize(framework::make_ddim({length / 8})); + npu_mask.mutable_data(ctx.GetPlace()); + + // TODO(pangyoki): `keep_prob` used in `DropOutGenMask` NPU + // OP must be a scalar with shape[0]. At present, the shape + // of the `prob` Tensor of this OP is forced to be set to 0 + // in `npu_op_runner.cc`, which needs to be optimized later. + NpuOpRunner runner_gen_mask; + runner_gen_mask.SetType("DropOutGenMask") + .AddInput(framework::vectorize(tmp_out.dims())) + .AddInput(keep_prob_tensor) + .AddOutput(npu_mask) + .AddAttr("seed", seed) + .AddAttr("seed2", seed2); + runner_gen_mask.Run(stream); + + NpuOpRunner runner_dropout; + runner_dropout.SetType("DropOutDoMask") + .AddInput(tmp_x) + .AddInput(npu_mask) + .AddInput(keep_prob_tensor) + .AddOutput(tmp_out); + runner_dropout.Run(stream); + + // cast `out` from float/float16 to bool + Tensor cast_mask(framework::proto::VarType::BOOL); + cast_mask.Resize(mask->dims()); + cast_mask.mutable_data(ctx.GetPlace()); + auto dst_dtype_bool = ConvertToNpuDtype(cast_mask.type()); + const auto& runner_cast_mask_bool = + NpuOpRunner("Cast", {*out}, {cast_mask}, + {{"dst_type", static_cast(dst_dtype_bool)}}); + runner_cast_mask_bool.Run(stream); + + // cast cast_mask from bool to uint8 + auto dst_dtype_uint8 = ConvertToNpuDtype(mask->type()); + const auto& runner_cast_mask_uint8 = + NpuOpRunner("Cast", {cast_mask}, {*mask}, + {{"dst_type", static_cast(dst_dtype_uint8)}}); + runner_cast_mask_uint8.Run(stream); + } else { + framework::TensorCopy( + *x, ctx.GetPlace(), + ctx.template device_context(), out); + } + } +}; + +template +class DropoutGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* mask = ctx.Input("Mask"); + + auto dropout_prob = ctx.Attr("dropout_prob"); + auto is_test = ctx.Attr("is_test"); + + PADDLE_ENFORCE_EQ(is_test, false, + platform::errors::PreconditionNotMet( + "GradOp is only callable when is_test is false")); + + dx->mutable_data(ctx.GetPlace()); + + auto stream = + ctx.template device_context() + .stream(); + + if (dropout_prob == 1.) { + const auto& runner_zeros = NpuOpRunner("ZerosLike", {*dx}, {*dx}); + runner_zeros.Run(stream); + return; + } + + // cast mask from uint8 to float32/float16 + Tensor cast_mask(dx->type()); + cast_mask.Resize(mask->dims()); + cast_mask.mutable_data(ctx.GetPlace()); + auto dst_dtype = ConvertToNpuDtype(dx->type()); + const auto& runner_cast_mask = + NpuOpRunner("Cast", {*mask}, {cast_mask}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_mask.Run(stream); + + const auto& runner = + NpuOpRunner("MaskedScale", {*dout, cast_mask}, {*dx}, + {{"value", static_cast(1. / (1 - dropout_prob))}}); + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + dropout, ops::DropoutNPUKernel, + ops::DropoutNPUKernel); + +REGISTER_OP_NPU_KERNEL( + dropout_grad, + ops::DropoutGradNPUKernel, + ops::DropoutGradNPUKernel); diff --git a/paddle/fluid/operators/npu_op_runner.cc b/paddle/fluid/operators/npu_op_runner.cc index 25ef24d04d2..4461941e85c 100644 --- a/paddle/fluid/operators/npu_op_runner.cc +++ b/paddle/fluid/operators/npu_op_runner.cc @@ -32,6 +32,7 @@ namespace operators { static std::map DTYPE_2_ACL_DTYPE = { {framework::proto::VarType::BOOL, ACL_BOOL}, + {framework::proto::VarType::UINT8, ACL_UINT8}, {framework::proto::VarType::INT16, ACL_INT16}, {framework::proto::VarType::INT32, ACL_INT32}, {framework::proto::VarType::INT64, ACL_INT64}, @@ -325,17 +326,24 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor, auto dtype = ConvertToNpuDtype(tensor.type()); auto format = ConvertToNpuFormat(tensor.layout()); auto dims = framework::vectorize(tensor.dims()); + int size = dims.size(); + // TODO(pangyoki): `keep_prob` used in `DropOutGenMask` NPU + // OP must be a scalar with shape[0]. At present, the shape + // of the `prob` Tensor of this OP is forced to be set to 0 + // in `npu_op_runner.cc`, which needs to be optimized later. + if (op_type_ == "DropOutGenMask" && size == 1 && *(dims.data()) == 1) { + size = 0; + } VLOG(4) << "NPU dtype:" << dtype << " " << "rank:" << dims.size() << " dims:" << tensor.dims() << " format:" << format; - auto *desc = aclCreateTensorDesc(dtype, dims.size(), dims.data(), format); + auto *desc = aclCreateTensorDesc(dtype, size, dims.data(), format); PADDLE_ENFORCE_NOT_NULL( desc, platform::errors::External("Call aclCreateTensorDesc failed.")); PADDLE_ENFORCE_NPU_SUCCESS(aclSetTensorStorageFormat(desc, format)); - PADDLE_ENFORCE_NPU_SUCCESS( - aclSetTensorStorageShape(desc, dims.size(), dims.data())); + PADDLE_ENFORCE_NPU_SUCCESS(aclSetTensorStorageShape(desc, size, dims.data())); if (mem_type == ACL_MEMTYPE_HOST) { PADDLE_ENFORCE_NPU_SUCCESS(aclSetTensorPlaceMent(desc, mem_type)); } diff --git a/python/paddle/fluid/tests/unittests/npu/test_dropout_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_dropout_op_npu.py new file mode 100644 index 00000000000..6b936514452 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_dropout_op_npu.py @@ -0,0 +1,297 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid as fluid + +paddle.enable_static() + +SEED = 2021 +EPOCH = 100 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestDropoutOp(OpTest): + def setUp(self): + self.op_type = "dropout" + self.set_npu() + self.init_dtype() + self.inputs = {'X': np.random.random((32, 64)).astype(self.dtype)} + self.attrs = { + 'dropout_prob': 0.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((32, 64)).astype('uint8') + } + + def init_dtype(self): + self.dtype = np.float32 + + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + def test_check_grad_normal(self): + if self.dtype == np.float16: + return + self.check_grad_with_place( + self.place, ['X'], 'Out', check_dygraph=False) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestDropoutOpInput1d(TestDropoutOp): + # change input shape + def setUp(self): + self.op_type = "dropout" + self.set_npu() + self.init_dtype() + self.inputs = {'X': np.random.random((3, 62)).astype(self.dtype)} + self.attrs = { + 'dropout_prob': 0.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((3, 62)).astype('uint8') + } + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestDropoutOpInput1d(TestDropoutOp): + # the input is 1-D + def setUp(self): + self.op_type = "dropout" + self.set_npu() + self.init_dtype() + self.inputs = {'X': np.random.random((2000, )).astype(self.dtype)} + self.attrs = { + 'dropout_prob': 0.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((2000)).astype('uint8') + } + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestDropoutOp2(TestDropoutOp): + # the dropout_prob is 1.0 + def setUp(self): + self.op_type = "dropout" + self.set_npu() + self.init_dtype() + self.inputs = {'X': np.random.random((32, 64)).astype(self.dtype)} + self.attrs = { + 'dropout_prob': 1.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': np.zeros((32, 64)).astype('float32'), + 'Mask': np.zeros((32, 64)).astype('uint8') + } + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestDropoutOp3(TestDropoutOp): + # the input dim is 3 + def setUp(self): + self.op_type = "dropout" + self.set_npu() + self.init_dtype() + self.inputs = {'X': np.random.random((32, 64, 2)).astype(self.dtype)} + self.attrs = { + 'dropout_prob': 0.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((32, 64, 2)).astype('uint8') + } + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +@skip_check_grad_ci(reason="For inference, check_grad is not required.") +class TestDropoutOpInference(OpTest): + # is_test = True + def setUp(self): + self.op_type = "dropout" + self.set_npu() + self.init_dtype() + self.inputs = {'X': np.random.random((32, 64)).astype(self.dtype)} + self.attrs = { + 'dropout_prob': 0.35, + 'fix_seed': True, + 'is_test': True, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = {'Out': self.inputs['X']} + + def init_dtype(self): + self.dtype = np.float32 + + def set_npu(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +@skip_check_grad_ci(reason="For inference, check_grad is not required.") +class TestDropoutOpInference2(TestDropoutOpInference): + def setUp(self): + self.op_type = "dropout" + self.set_npu() + self.init_dtype() + self.inputs = {'X': np.random.random((32, 64, 3)).astype(self.dtype)} + self.attrs = { + 'dropout_prob': 0.75, + 'is_test': True, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = {'Out': self.inputs['X']} + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestDropoutOpWithSeed(TestDropoutOp): + # the seed is a Tensor + def setUp(self): + self.op_type = "dropout" + self.set_npu() + self.init_dtype() + self.inputs = { + "X": np.random.random((32, 64)).astype(self.dtype), + "Seed": np.asarray( + [125], dtype="int32") + } + self.attrs = { + 'dropout_prob': 0.0, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((32, 64)).astype('uint8') + } + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestDropoutOpFp16(TestDropoutOp): + # float16 + def init_dtype(self): + self.dtype = np.float16 + + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True + self.place = paddle.NPUPlace(0) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestDropoutAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace(), paddle.NPUPlace(0)] + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[40, 40], dtype="float32") + res1 = paddle.nn.functional.dropout( + x=input, p=0., training=False, mode='upscale_in_train') + res2 = paddle.nn.functional.dropout( + x=input, p=0., axis=0, training=True, mode='upscale_in_train') + res3 = paddle.nn.functional.dropout( + x=input, p=0., axis=0, training=False, mode='upscale_in_train') + res4 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=True, + mode='upscale_in_train') + res5 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=[0, 1], + training=False, + mode='upscale_in_train') + res6 = paddle.nn.functional.dropout( + x=input, p=1., training=True, mode='upscale_in_train') + res7 = paddle.fluid.layers.dropout( + x=input, + dropout_prob=0., + dropout_implementation='upscale_in_train') + res8 = paddle.nn.functional.dropout( + x=input, + p=0., + axis=(0, 1), + training=False, + mode='upscale_in_train') + + in_np = np.random.random([40, 40]).astype("float32") + res_np = in_np + res_np2 = np.zeros_like(in_np) + + exe = fluid.Executor(place) + res_list = [res1, res2, res3, res4, res5, res7, res8] + for res in res_list: + fetches = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res]) + self.assertTrue(np.allclose(fetches[0], res_np)) + fetches2 = exe.run(fluid.default_main_program(), + feed={"input": in_np}, + fetch_list=[res6]) + self.assertTrue(np.allclose(fetches2[0], res_np2)) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + +if __name__ == '__main__': + unittest.main() -- GitLab