From bc543e35d659086083b1acb00c60c148e88cc502 Mon Sep 17 00:00:00 2001 From: Fan Zhang Date: Thu, 12 Aug 2021 10:39:06 +0800 Subject: [PATCH] [NPU] Support npu op expand_v2 and expand_v2_grad (#34764) * [NPU] Support npu op expand_v2 and expand_v2_grad * [NPU] Support npu op expand_v2 and expand_v2_grad * [NPU] Support npu op expand_v2 and expand_v2_grad * update test_expand_v2_op_npu.py * update test_expand_v2_op_npu.py * modify expand_v2_op_npu.cc * modify expand_v2_op_npu.cc --- paddle/fluid/operators/expand_v2_op.h | 16 +- paddle/fluid/operators/expand_v2_op_npu.cc | 191 ++++++++++++ .../unittests/npu/test_expand_v2_op_npu.py | 277 ++++++++++++++++++ 3 files changed, 483 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/expand_v2_op_npu.cc create mode 100755 python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py diff --git a/paddle/fluid/operators/expand_v2_op.h b/paddle/fluid/operators/expand_v2_op.h index a720bd7b551..08131b71064 100644 --- a/paddle/fluid/operators/expand_v2_op.h +++ b/paddle/fluid/operators/expand_v2_op.h @@ -36,6 +36,12 @@ inline std::vector get_expand_shape( TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); shape_data = cpu_shape_tensor.data(); } +#ifdef PADDLE_WITH_ASCEND_CL + if (platform::is_npu_place(shape_tensor->place())) { + TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); + shape_data = cpu_shape_tensor.data(); + } +#endif auto vec_shape = std::vector(shape_data, shape_data + shape_tensor->numel()); return vec_shape; @@ -52,7 +58,15 @@ inline std::vector get_expand_shape( framework::Tensor temp; TensorCopySync(*tensor, platform::CPUPlace(), &temp); vec_epxand_shape.push_back(*temp.data()); - } else { + } +#ifdef PADDLE_WITH_ASCEND_CL + else if (platform::is_npu_place(tensor->place())) { // NOLINT + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + vec_epxand_shape.push_back(*temp.data()); + } +#endif + else { // NOLINT vec_epxand_shape.push_back(*tensor->data()); } } diff --git a/paddle/fluid/operators/expand_v2_op_npu.cc b/paddle/fluid/operators/expand_v2_op_npu.cc new file mode 100644 index 00000000000..85fe86a9e60 --- /dev/null +++ b/paddle/fluid/operators/expand_v2_op_npu.cc @@ -0,0 +1,191 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the Licnse. */ + +#include "paddle/fluid/operators/expand_v2_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +class ExpandV2NPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Out = ctx.Output("Out"); + + auto in_dims = X->dims(); + auto expand_shape = get_expand_shape(ctx); + auto vec_in_dims = framework::vectorize(in_dims); + auto diff = expand_shape.size() - vec_in_dims.size(); + vec_in_dims.insert(vec_in_dims.begin(), diff, 1); + std::vector final_expand_shape(vec_in_dims.size()); + for (size_t i = 0; i < vec_in_dims.size(); ++i) { + PADDLE_ENFORCE_NE(expand_shape[i], 0, + platform::errors::InvalidArgument( + "The expanded size cannot be zero.")); + if (i < diff) { // expand_shape = [3,4,-1,-1], X = [10,2] --> + // final_expand_shape = [3,4,10,2] + PADDLE_ENFORCE_GT( + expand_shape[i], 0, + platform::errors::InvalidArgument( + "The expanded size (%d) for non-existing dimensions must be " + "positive for expand_v2 op.", + expand_shape[i])); + final_expand_shape[i] = expand_shape[i]; + } else if (expand_shape[i] > 0) { // expand_shape = [3,4,10,4], X = + // [10,1] --> final_expand_shape = + // [3,4,10,4] + if (vec_in_dims[i] != 1) { + PADDLE_ENFORCE_EQ( + vec_in_dims[i], expand_shape[i], + platform::errors::InvalidArgument( + "The value (%d) of the non-singleton dimension does not match" + " the corresponding value (%d) in shape for expand_v2 op.", + vec_in_dims[i], expand_shape[i])); + final_expand_shape[i] = expand_shape[i]; + } else { + final_expand_shape[i] = expand_shape[i]; + } + } else { // expand_shape = [3,4,-1,-1], X = [10,2] --> final_expand_shape + // = [3,4,10,2] + PADDLE_ENFORCE_EQ( + expand_shape[i], -1, + platform::errors::InvalidArgument( + "When the value in shape is negative for expand_v2 op, " + "only -1 is supported, but the value received is %d.", + expand_shape[i])); + final_expand_shape[i] = vec_in_dims[i]; + } + } + + framework::NPUAttributeMap attr_input = {{"shape", final_expand_shape}}; + + auto rank = X->dims().size(); + + PADDLE_ENFORCE_GE( + rank, 1, + platform::errors::InvalidArgument( + "The rank of the input 'X' for expand_v2_npu op must be positive, " + "but the value received is %d.", + rank)); + PADDLE_ENFORCE_LE( + rank, MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The rank of the input 'X' for expand_v2_npu op must be less than " + "or equal to %d, but the value received is %d.", + MAX_RANK_SUPPORTED, rank)); + auto shape_size = final_expand_shape.size(); + PADDLE_ENFORCE_GE( + shape_size, rank, + platform::errors::InvalidArgument( + "The number (%d) of elements of 'shape' for expand_v2_npu op must " + "be " + "greater than or equal to the rank (%d) of the input 'X'.", + shape_size, rank)); + PADDLE_ENFORCE_LE(shape_size, MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The number (%d) of elements of 'shape' for " + "expand_v2_npu op must be " + "less than or equal to %d.", + shape_size, MAX_RANK_SUPPORTED)); + + framework::DDim out_dims = framework::make_ddim(final_expand_shape); + Out->Resize(out_dims); + Out->mutable_data(ctx.GetPlace()); + + const auto& runner = NpuOpRunner("ExpandD", {*X}, {*Out}, attr_input); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +template +class ExpandV2NPUGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + + auto stream = + ctx.template device_context() + .stream(); + + // case 1: reduce dout dims to dx dims + // For example: [2, 120] --> [120] + auto reduce_ndim = dout->dims().size() - dx->dims().size(); + std::vector axes; + for (auto i = 0; i < reduce_ndim; ++i) { + axes.push_back(i); + } + + Tensor tmp_dout(dout->type()); + Tensor reduced_dout(dx->type()); + tmp_dout.ShareDataWith(*dout); + if (axes.size() != 0) { + std::vector reduced_dout_dims; + for (auto i = reduce_ndim; i < dout->dims().size(); ++i) { + reduced_dout_dims.push_back(dout->dims()[i]); + } + tmp_dout.Resize(framework::make_ddim(reduced_dout_dims)); + reduced_dout.Resize(framework::make_ddim(reduced_dout_dims)); + reduced_dout.mutable_data(ctx.GetPlace()); + const auto& runner = NpuOpRunner("ReduceSumD", {*dout}, {reduced_dout}, + {{"axes", axes}, {"keep_dims", false}}); + runner.Run(stream); + tmp_dout = reduced_dout; + } + + // case 2: reduce axis of dout in which dim is 1 + // For example: [12, 140] --> [1, 140] + + // case 3: copy dout to dx when shape is totally same, and dim in dx != 1 + // For example: [2, 10, 5] --> [2, 10, 5] + axes.clear(); + for (auto i = 0; i < dx->dims().size(); ++i) { + if (dx->dims()[i] == 1) { + axes.push_back(i); + } + } + if (axes.size() != 0) { + const auto& runner = NpuOpRunner("ReduceSumD", {tmp_dout}, {*dx}, + {{"axes", axes}, {"keep_dims", true}}); + runner.Run(stream); + } else { + framework::TensorCopySync(tmp_dout, ctx.GetPlace(), dx); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL( + expand_v2, + ops::ExpandV2NPUKernel, + ops::ExpandV2NPUKernel, + ops::ExpandV2NPUKernel); + +REGISTER_OP_NPU_KERNEL( + expand_v2_grad, + ops::ExpandV2NPUGradKernel, + ops::ExpandV2NPUGradKernel, + ops::ExpandV2NPUGradKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py new file mode 100755 index 00000000000..d48d2a84301 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_expand_v2_op_npu.py @@ -0,0 +1,277 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import sys +import numpy as np +sys.path.append("..") +from op_test import OpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import paddle + +paddle.enable_static() +np.random.seed(10) + + +# CANN Op Support X: float16, float32, int32, int8 ,uint8 +# Situation 1: shape is a list(without tensor) +class TestExpandV2NPUOpRank1(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "expand_v2" + self.dtype = np.float32 + self.init_data() + + self.inputs = {'X': np.random.random(self.ori_shape).astype(self.dtype)} + self.attrs = {'shape': self.shape} + output = np.tile(self.inputs['X'], self.expand_times) + self.outputs = {'Out': output} + + def set_npu(self): + self.__class__.use_npu = True + + def init_data(self): + self.ori_shape = [100] + self.shape = [100] + self.expand_times = [1] + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestExpandV2OpRank2_DimExpanding(TestExpandV2NPUOpRank1): + def init_data(self): + self.ori_shape = [120] + self.shape = [2, 120] + self.expand_times = [2, 1] + + +class TestExpandV2OpRank2(TestExpandV2NPUOpRank1): + def init_data(self): + self.ori_shape = [1, 140] + self.shape = [12, 140] + self.expand_times = [12, 1] + + +class TestExpandV2OpRank3_Corner(TestExpandV2NPUOpRank1): + def init_data(self): + self.ori_shape = (2, 10, 5) + self.shape = (2, 10, 5) + self.expand_times = (1, 1, 1) + + +class TestExpandV2OpRank4(TestExpandV2NPUOpRank1): + def init_data(self): + self.ori_shape = (2, 4, 5, 7) + self.shape = (-1, -1, -1, -1) + self.expand_times = (1, 1, 1, 1) + + +class TestExpandV2OpRank5(TestExpandV2NPUOpRank1): + def init_data(self): + self.ori_shape = (2, 4, 1, 15) + self.shape = (2, -1, 4, -1) + self.expand_times = (1, 1, 4, 1) + + +class TestExpandV2OpRank6(TestExpandV2NPUOpRank1): + def init_data(self): + self.ori_shape = (4, 1, 30) + self.shape = (2, -1, 4, 30) + self.expand_times = (2, 1, 4, 1) + + +# Situation 2: shape is a list(with tensor) +class TestExpandV2OpNPURank1_tensor_attr(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "expand_v2" + self.init_data() + self.dtype = np.float32 + expand_shapes_tensor = [] + for index, ele in enumerate(self.expand_shape): + expand_shapes_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = { + 'X': np.random.random(self.ori_shape).astype(self.dtype), + 'expand_shapes_tensor': expand_shapes_tensor, + } + self.attrs = {"shape": self.infer_expand_shape} + output = np.tile(self.inputs['X'], self.expand_times) + self.outputs = {'Out': output} + + def set_npu(self): + self.__class__.use_npu = True + + def init_data(self): + self.ori_shape = [100] + self.expand_times = [1] + self.expand_shape = [100] + self.infer_expand_shape = [-1] + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestExpandV2OpRank2_Corner_tensor_attr( + TestExpandV2OpNPURank1_tensor_attr): + def init_data(self): + self.ori_shape = [12, 14] + self.expand_times = [1, 1] + self.expand_shape = [12, 14] + self.infer_expand_shape = [12, -1] + + +# Situation 3: shape is a tensor +class TestExpandV2NPUOpRank1_tensor(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "expand_v2" + self.init_data() + self.dtype = np.float32 + + self.inputs = { + 'X': np.random.random(self.ori_shape).astype(self.dtype), + 'Shape': np.array(self.expand_shape).astype("int32"), + } + self.attrs = {} + output = np.tile(self.inputs['X'], self.expand_times) + self.outputs = {'Out': output} + + def set_npu(self): + self.__class__.use_npu = True + + def init_data(self): + self.ori_shape = [100] + self.expand_times = [2, 1] + self.expand_shape = [2, 100] + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + +# Situation 4: input x is float16 +# skip grad check for float16 +class TestExpandV2OpFloat(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "expand_v2" + self.dtype = np.float16 + self.ori_shape = (2, 4, 20) + self.inputs = {'X': np.random.random(self.ori_shape).astype(self.dtype)} + self.attrs = {'shape': [2, 4, 20]} + output = np.tile(self.inputs['X'], (1, 1, 1)) + self.outputs = {'Out': output} + + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + +# Situation 5: input x is int32 +# skip grad check for int32 +class TestExpandV2OpInteger(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "expand_v2" + self.inputs = { + 'X': np.random.randint( + 10, size=(2, 4, 20)).astype("int32") + } + self.attrs = {'shape': [2, 4, 20]} + output = np.tile(self.inputs['X'], (1, 1, 1)) + self.outputs = {'Out': output} + + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestExpandV2Error(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], paddle.NPUPlace(0)) + shape = [2, 2] + self.assertRaises(TypeError, paddle.tensor.expand, x1, shape) + x2 = fluid.layers.data(name='x2', shape=[2], dtype="uint8") + self.assertRaises(TypeError, paddle.tensor.expand, x2, shape) + x3 = fluid.layers.data(name='x3', shape=[2], dtype="bool") + x3.stop_gradient = False + self.assertRaises(ValueError, paddle.tensor.expand, x3, shape) + + +# Test python API +class TestExpandV2API(unittest.TestCase): + def test_static(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = np.random.random([12, 14]).astype("float32") + x = fluid.layers.data( + name='x', + shape=[12, 14], + append_batch_size=False, + dtype="float32") + + positive_2 = fluid.layers.fill_constant([1], "int32", 12) + expand_shape = fluid.layers.data( + name="expand_shape", + shape=[2], + append_batch_size=False, + dtype="int32") + + out_1 = paddle.expand(x, shape=[12, 14]) + out_2 = paddle.expand(x, shape=[positive_2, 14]) + out_3 = paddle.expand(x, shape=expand_shape) + + g0 = fluid.backward.calc_gradient(out_2, x) + + exe = fluid.Executor(place=paddle.NPUPlace(0)) + res_1, res_2, res_3 = exe.run(fluid.default_main_program(), + feed={ + "x": input, + "expand_shape": + np.array([12, 14]).astype("int32") + }, + fetch_list=[out_1, out_2, out_3]) + + assert np.array_equal(res_1, np.tile(input, (1, 1))) + assert np.array_equal(res_2, np.tile(input, (1, 1))) + assert np.array_equal(res_3, np.tile(input, (1, 1))) + + +if __name__ == "__main__": + unittest.main() -- GitLab