From 6a179e4868058e7470253692f510fa04e7a24a8a Mon Sep 17 00:00:00 2001
From: fuyou765 <64373205+fuyou765@users.noreply.github.com>
Date: Fri, 17 Jun 2022 17:53:17 +0800
Subject: [PATCH] [MLU]add mlu kernel for expand_v2 op (#43353)

---
 paddle/fluid/operators/expand_v2_op.h         |  14 +
 paddle/fluid/operators/expand_v2_op_mlu.cc    | 111 +++++++
 .../unittests/mlu/test_expand_v2_op_mlu.py    | 308 ++++++++++++++++++
 3 files changed, 433 insertions(+)
 create mode 100644 paddle/fluid/operators/expand_v2_op_mlu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_expand_v2_op_mlu.py

diff --git a/paddle/fluid/operators/expand_v2_op.h b/paddle/fluid/operators/expand_v2_op.h
index 158a9d1bc52..d78ae442064 100644
--- a/paddle/fluid/operators/expand_v2_op.h
+++ b/paddle/fluid/operators/expand_v2_op.h
@@ -50,6 +50,13 @@ inline std::vector<int> get_expand_shape(
                                       &cpu_shape_tensor);
     shape_data = cpu_shape_tensor.data<int>();
   }
+#endif
+#ifdef PADDLE_WITH_MLU
+  if (platform::is_mlu_place(shape_tensor->place())) {
+    paddle::framework::TensorCopySync(*shape_tensor, platform::CPUPlace(),
+                                      &cpu_shape_tensor);
+    shape_data = cpu_shape_tensor.data<int>();
+  }
 #endif
   auto vec_shape =
       std::vector<int>(shape_data, shape_data + shape_tensor->numel());
@@ -81,6 +88,13 @@ inline std::vector<int> get_expand_shape(
       paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
       vec_epxand_shape.push_back(*temp.data<int32_t>());
     }
+#endif
+#ifdef PADDLE_WITH_MLU
+    else if (platform::is_mlu_place(tensor->place())) {  // NOLINT
+      framework::Tensor temp;
+      paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
+      vec_epxand_shape.push_back(*temp.data<int32_t>());
+    }
 #endif
     else {  // NOLINT
       vec_epxand_shape.push_back(*tensor->data<int32_t>());
diff --git a/paddle/fluid/operators/expand_v2_op_mlu.cc b/paddle/fluid/operators/expand_v2_op_mlu.cc
new file mode 100644
index 00000000000..8f8104c48b5
--- /dev/null
+++ b/paddle/fluid/operators/expand_v2_op_mlu.cc
@@ -0,0 +1,111 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_MLU
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/expand_v2_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class ExpandV2MLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* X = ctx.Input<framework::Tensor>("X");
+    auto* Out = ctx.Output<framework::Tensor>("Out");
+    auto in_dims = X->dims();
+    auto expand_shape = get_expand_shape(ctx);
+    auto vec_in_dims = phi::vectorize<int>(in_dims);
+    auto diff = expand_shape.size() - vec_in_dims.size();
+    vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
+    std::vector<int> final_expand_shape(vec_in_dims.size());
+    for (size_t i = 0; i < vec_in_dims.size(); ++i) {
+      PADDLE_ENFORCE_NE(expand_shape[i], 0,
+                        platform::errors::InvalidArgument(
+                            "The expanded size cannot be zero."));
+      if (i < diff) {  // expand_shape = [3,4,-1,-1], X = [10,2] -->
+                       // final_expand_shape = [3,4,10,2]
+        PADDLE_ENFORCE_GT(
+            expand_shape[i], 0,
+            platform::errors::InvalidArgument(
+                "The expanded size (%d) for non-existing dimensions must be "
+                "positive for expand_v2 op.",
+                expand_shape[i]));
+        final_expand_shape[i] = expand_shape[i];
+      } else if (expand_shape[i] > 0) {  // expand_shape = [3,4,10,4], X =
+                                         // [10,1] --> final_expand_shape =
+                                         // [3,4,10,4]
+        if (vec_in_dims[i] != 1) {
+          PADDLE_ENFORCE_EQ(
+              vec_in_dims[i], expand_shape[i],
+              platform::errors::InvalidArgument(
+                  "The value (%d) of the non-singleton dimension does not match"
+                  " the corresponding value (%d) in shape for expand_v2 op.",
+                  vec_in_dims[i], expand_shape[i]));
+          final_expand_shape[i] = expand_shape[i];
+        } else {
+          final_expand_shape[i] = expand_shape[i];
+        }
+      } else {  // expand_shape = [3,4,-1,-1], X = [10,2] -->
+                // final_expand_shape = [3,4,10,2]
+        PADDLE_ENFORCE_EQ(
+            expand_shape[i], -1,
+            platform::errors::InvalidArgument(
+                "When the value in shape is negative for expand_v2 op, "
+                "only -1 is supported, but the value received is %d.",
+                expand_shape[i]));
+        final_expand_shape[i] = vec_in_dims[i];
+      }
+    }
+
+    auto rank = X->dims().size();
+    PADDLE_ENFORCE_GE(
+        rank, 1,
+        platform::errors::InvalidArgument(
+            "The rank of the input 'X' for expand_v2_mlu op must be positive, "
+            "but the value received is %d.",
+            rank));
+    auto shape_size = final_expand_shape.size();
+    PADDLE_ENFORCE_GE(
+        shape_size, rank,
+        platform::errors::InvalidArgument(
+            "The number (%d) of elements of 'shape' for expand_v2_mlu op must "
+            "be greater than or equal to the rank (%d) of the input 'X'.",
+            shape_size, rank));
+
+    framework::DDim out_dims = phi::make_ddim(final_expand_shape);
+    Out->Resize(out_dims);
+    auto place = ctx.GetPlace();
+    Out->mutable_data<T>(place);
+    MLUCnnlTensorDesc x_desc(*X);
+    MLUCnnlTensorDesc out_desc(*Out);
+    MLUCnnl::BroadcastTo(ctx, x_desc.get(), GetBasePtr(X), out_desc.get(),
+                         GetBasePtr(Out));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_MLU_KERNEL(expand_v2, ops::ExpandV2MLUKernel<float>,
+                       ops::ExpandV2MLUKernel<paddle::platform::float16>,
+                       ops::ExpandV2MLUKernel<bool>,
+                       ops::ExpandV2MLUKernel<int>,
+                       ops::ExpandV2MLUKernel<int64_t>);
+
+#endif
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_expand_v2_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_expand_v2_op_mlu.py
new file mode 100644
index 00000000000..d7b1768d509
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_expand_v2_op_mlu.py
@@ -0,0 +1,308 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import sys
+
+sys.path.append("..")
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid as fluid
+from paddle.fluid import compiler, Program, program_guard
+import paddle
+from paddle.fluid.framework import _test_eager_guard
+
+
+# Situation 1: shape is a list (without tensor)
+class TestExpandV2OpRank1(OpTest):
+
+    def setUp(self):
+        self.op_type = "expand_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.init_data()
+        self.python_api = paddle.expand
+
+        self.inputs = {'X': np.random.random(self.ori_shape).astype("float32")}
+        self.attrs = {'shape': self.shape}
+        output = np.tile(self.inputs['X'], self.expand_times)
+        self.outputs = {'Out': output}
+
+    def init_data(self):
+        self.ori_shape = [100]
+        self.shape = [100]
+        self.expand_times = [1]
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=False)
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_eager=True)
+
+
+class TestExpandV2OpRank2_DimExpanding(TestExpandV2OpRank1):
+
+    def init_data(self):
+        self.ori_shape = [120]
+        self.shape = [2, 120]
+        self.expand_times = [2, 1]
+
+
+class TestExpandV2OpRank2(TestExpandV2OpRank1):
+
+    def init_data(self):
+        self.ori_shape = [1, 140]
+        self.shape = [12, 140]
+        self.expand_times = [12, 1]
+
+
+class TestExpandV2OpRank3_Corner(TestExpandV2OpRank1):
+
+    def init_data(self):
+        self.ori_shape = (2, 10, 5)
+        self.shape = (2, 10, 5)
+        self.expand_times = (1, 1, 1)
+
+
+class TestExpandV2OpRank4(TestExpandV2OpRank1):
+
+    def init_data(self):
+        self.ori_shape = (2, 4, 5, 7)
+        self.shape = (-1, -1, -1, -1)
+        self.expand_times = (1, 1, 1, 1)
+
+
+class TestExpandV2OpRank5(TestExpandV2OpRank1):
+
+    def init_data(self):
+        self.ori_shape = (2, 4, 1, 15)
+        self.shape = (2, -1, 4, -1)
+        self.expand_times = (1, 1, 4, 1)
+
+
+class TestExpandV2OpRank6(TestExpandV2OpRank1):
+
+    def init_data(self):
+        self.ori_shape = (4, 1, 30)
+        self.shape = (2, -1, 4, 30)
+        self.expand_times = (2, 1, 4, 1)
+
+
+# Situation 2: shape is a list (with tensor)
+class TestExpandV2OpRank1_tensor_attr(OpTest):
+
+    def setUp(self):
+        self.op_type = "expand_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.init_data()
+        expand_shapes_tensor = []
+        for index, ele in enumerate(self.expand_shape):
+            expand_shapes_tensor.append(("x" + str(index), np.ones(
+                (1)).astype('int32') * ele))
+
+        self.inputs = {
+            'X': np.random.random(self.ori_shape).astype("float32"),
+            'expand_shapes_tensor': expand_shapes_tensor,
+        }
+        self.attrs = {"shape": self.infer_expand_shape}
+        output = np.tile(self.inputs['X'], self.expand_times)
+        self.outputs = {'Out': output}
+
+    def init_data(self):
+        self.ori_shape = [100]
+        self.expand_times = [1]
+        self.expand_shape = [100]
+        self.infer_expand_shape = [-1]
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=False)
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestExpandV2OpRank2_Corner_tensor_attr(TestExpandV2OpRank1_tensor_attr):
+
+    def init_data(self):
+        self.ori_shape = [12, 14]
+        self.expand_times = [1, 1]
+        self.expand_shape = [12, 14]
+        self.infer_expand_shape = [12, -1]
+
+
+# Situation 3: shape is a tensor
+class TestExpandV2OpRank1_tensor(OpTest):
+
+    def setUp(self):
+        self.op_type = "expand_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.init_data()
+
+        self.inputs = {
+            'X': np.random.random(self.ori_shape).astype("float32"),
+            'Shape': np.array(self.expand_shape).astype("int32"),
+        }
+        self.attrs = {}
+        output = np.tile(self.inputs['X'], self.expand_times)
+        self.outputs = {'Out': output}
+
+    def init_data(self):
+        self.ori_shape = [100]
+        self.expand_times = [2, 1]
+        self.expand_shape = [2, 100]
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=False)
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+# Situation 4: input x is int32
+class TestExpandV2OpInteger(OpTest):
+
+    def setUp(self):
+        self.op_type = "expand_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.inputs = {
+            'X': np.random.randint(10, size=(2, 4, 5)).astype("int32")
+        }
+        self.attrs = {'shape': [2, 4, 5]}
+        output = np.tile(self.inputs['X'], (1, 1, 1))
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=False)
+
+
+# Situation 5: input x is bool
+class TestExpandV2OpBoolean(OpTest):
+
+    def setUp(self):
+        self.op_type = "expand_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.inputs = {'X': np.random.randint(2, size=(2, 4, 5)).astype("bool")}
+        self.attrs = {'shape': [2, 4, 5]}
+        output = np.tile(self.inputs['X'], (1, 1, 1))
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=False)
+
+
+# Situation 6: input x is int64
+class TestExpandV2OpInt64_t(OpTest):
+
+    def setUp(self):
+        self.op_type = "expand_v2"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.inputs = {
+            'X': np.random.randint(10, size=(2, 4, 5)).astype("int64")
+        }
+        self.attrs = {'shape': [2, 4, 5]}
+        output = np.tile(self.inputs['X'], (1, 1, 1))
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_eager=False)
+
+
+class TestExpandV2Error(unittest.TestCase):
+
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            x1 = fluid.create_lod_tensor(np.array([[-1]]), [[1]],
+                                         paddle.device.MLUPlace(0))
+            shape = [2, 2]
+            self.assertRaises(TypeError, paddle.tensor.expand, x1, shape)
+            x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
+            self.assertRaises(TypeError, paddle.tensor.expand, x2, shape)
+            x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool")
+            x3.stop_gradient = False
+            self.assertRaises(ValueError, paddle.tensor.expand, x3, shape)
+
+
+# Test python API
+class TestExpandV2API(unittest.TestCase):
+
+    def test_api(self):
+        input = np.random.random([12, 14]).astype("float32")
+        x = fluid.layers.data(name='x',
+                              shape=[12, 14],
+                              append_batch_size=False,
+                              dtype="float32")
+
+        positive_2 = fluid.layers.fill_constant([1], "int32", 12)
+        expand_shape = fluid.layers.data(name="expand_shape",
+                                         shape=[2],
+                                         append_batch_size=False,
+                                         dtype="int32")
+
+        out_1 = paddle.expand(x, shape=[12, 14])
+        out_2 = paddle.expand(x, shape=[positive_2, 14])
+        out_3 = paddle.expand(x, shape=expand_shape)
+
+        g0 = fluid.backward.calc_gradient(out_2, x)
+
+        exe = fluid.Executor(place=paddle.device.MLUPlace(0))
+        res_1, res_2, res_3 = exe.run(fluid.default_main_program(),
+                                      feed={
+                                          "x": input,
+                                          "expand_shape":
+                                          np.array([12, 14]).astype("int32")
+                                      },
+                                      fetch_list=[out_1, out_2, out_3])
+        assert np.array_equal(res_1, np.tile(input, (1, 1)))
+        assert np.array_equal(res_2, np.tile(input, (1, 1)))
+        assert np.array_equal(res_3, np.tile(input, (1, 1)))
+
+
+class TestExpandInferShape(unittest.TestCase):
+
+    def test_shape_with_var(self):
+        with program_guard(Program(), Program()):
+            x = paddle.static.data(shape=[-1, 1, 3], name='x')
+            fake_var = paddle.randn([2, 3])
+            target_shape = [
+                -1, paddle.shape(fake_var)[0],
+                paddle.shape(fake_var)[1]
+            ]
+            out = paddle.expand(x, shape=target_shape)
+            self.assertListEqual(list(out.shape), [-1, -1, -1])
+
+
+# Test python Dygraph API
+class TestExpandV2DygraphAPI(unittest.TestCase):
+
+    def test_expand_times_is_tensor(self):
+        with paddle.fluid.dygraph.guard():
+            paddle.seed(1)
+            a = paddle.rand([2, 5])
+            expand_1 = paddle.expand(a, shape=[2, 5])
+            np_array = np.array([2, 5])
+            expand_2 = paddle.expand(a, shape=np_array)
+            self.assertTrue(np.array_equal(expand_1.numpy(), expand_2.numpy()))
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
--
GitLab
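
A minimal usage sketch of the kernel added above (not part of the patch; assumes a Paddle build compiled with PADDLE_WITH_MLU and at least one MLU card visible). In dygraph mode, paddle.expand on an MLU tensor dispatches to ExpandV2MLUKernel, which normalizes the target shape and lowers to MLUCnnl::BroadcastTo:

    import numpy as np
    import paddle

    # Route subsequent kernels to the first MLU card (assumes MLU support
    # in this build; on other builds use "gpu:0" or "cpu" instead).
    paddle.device.set_device("mlu:0")

    # Broadcast a [10, 2] input up to [3, 4, 10, 2], mirroring the shape
    # example in the kernel's inline comments; -1 keeps an existing dim.
    x = paddle.to_tensor(np.random.random([10, 2]).astype("float32"))
    out = paddle.expand(x, shape=[3, 4, -1, -1])
    print(out.shape)  # [3, 4, 10, 2]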