From 2d2f11d161c65a3ddae6e69f9391190393564494 Mon Sep 17 00:00:00 2001
From: joeqiao12 <45232181+joeqiao12@users.noreply.github.com>
Date: Thu, 17 Feb 2022 20:38:22 +0800
Subject: [PATCH] add reshape2 op for mlu (#39562)

---
 paddle/fluid/operators/reshape_op_mlu.cc      | 145 ++++++++++++++++++
 .../unittests/mlu/test_reshape2_op_mlu.py     |  73 +++++++++
 2 files changed, 218 insertions(+)
 create mode 100644 paddle/fluid/operators/reshape_op_mlu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_reshape2_op_mlu.py

diff --git a/paddle/fluid/operators/reshape_op_mlu.cc b/paddle/fluid/operators/reshape_op_mlu.cc
new file mode 100644
index 00000000000..cc197e18096
--- /dev/null
+++ b/paddle/fluid/operators/reshape_op_mlu.cc
@@ -0,0 +1,145 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/utils.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class Reshape2MLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+
+    std::vector<int> target_shape_vector;
+    auto shape_tensor_vector = ctx.MultiInput<framework::Tensor>("ShapeTensor");
+    if (shape_tensor_vector.size() > 0) {
+      for (auto* shape_tensor : shape_tensor_vector) {
+        PADDLE_ENFORCE_EQ(
+            shape_tensor->dims().size(), 1,
+            platform::errors::InvalidArgument(
+                "If the element type of 'shape' in Reshape Op is Tensor, "
+                "the element's shape must be [1]. But received the element's "
+                "shape is [%d]",
+                shape_tensor->dims().size()));
+
+        target_shape_vector.push_back(
+            GetDataFromTensor<int>(shape_tensor)[0]);
+      }
+    } else {
+      auto* shape_tensor = ctx.HasInput("Shape")
ctx.Input("Shape") + : nullptr; + if (shape_tensor) { + target_shape_vector = GetDataFromTensor(shape_tensor); + } else { + target_shape_vector = ctx.Attr>("shape"); + PADDLE_ENFORCE_GT( + target_shape_vector.size(), 0, + platform::errors::InvalidArgument( + "The length of shape attribute should be larger than 0 when " + "input ShapeTensor and Shape are empty!")); + } + } + + int num_negative = + std::count(target_shape_vector.begin(), target_shape_vector.end(), -1); + PADDLE_ENFORCE_LE( + num_negative, 1, + platform::errors::InvalidArgument( + "The max number of -1 in shape attribute or shape tensor is 1 " + "but received %d.", + num_negative)); + auto it_zero = + std::find(target_shape_vector.begin(), target_shape_vector.end(), 0); + if (it_zero != target_shape_vector.end()) { + int x_rank = x->dims().size(); + for (size_t i = 0; i < target_shape_vector.size(); i++) { + if (target_shape_vector[i] == 0) { + PADDLE_ENFORCE_LT( + i, x_rank, + platform::errors::InvalidArgument( + "The index of 0 in shape attribute or shape tensor", + "should be less than input dim size, ", + "but the index is %d and input dim size is %d", i, x_rank)); + target_shape_vector[i] = x->dims().at(i); + } + } + } + + auto it = + std::find(target_shape_vector.begin(), target_shape_vector.end(), -1); + if (it != target_shape_vector.end()) { + auto ddim_out_vec = framework::vectorize(x->dims()); + int ddim_out_product = std::accumulate( + ddim_out_vec.begin(), ddim_out_vec.end(), 1, std::multiplies()); + int reshape_out_product = std::accumulate(target_shape_vector.begin(), + target_shape_vector.end(), -1, + std::multiplies()); + int index = std::distance(target_shape_vector.begin(), it); + target_shape_vector[index] = ddim_out_product / reshape_out_product; + } + + auto out_dims = framework::make_ddim(target_shape_vector); + out->mutable_data(out_dims, ctx.GetPlace()); + + // output should copy to mlu + framework::TensorCopy( + *x, ctx.GetPlace(), + ctx.template device_context(), out); + out->Resize(out_dims); + } +}; + +template +class Reshape2GradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* d_x = ctx.Output(framework::GradVarName("X")); + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto in_dims = d_x->dims(); + + d_x->mutable_data(ctx.GetPlace(), d_out->type()); + framework::TensorCopy( + *d_out, ctx.GetPlace(), + ctx.template device_context(), d_x); + d_x->Resize(in_dims); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_MLU_KERNEL( + reshape2, ops::Reshape2MLUKernel, + ops::Reshape2MLUKernel, + ops::Reshape2MLUKernel, + ops::Reshape2MLUKernel, + ops::Reshape2MLUKernel, + ops::Reshape2MLUKernel, + ops::Reshape2MLUKernel); +REGISTER_OP_MLU_KERNEL( + reshape2_grad, + ops::Reshape2GradMLUKernel, + ops::Reshape2GradMLUKernel, + ops::Reshape2GradMLUKernel, + ops::Reshape2GradMLUKernel, + ops::Reshape2GradMLUKernel, + ops::Reshape2GradMLUKernel, + ops::Reshape2GradMLUKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_reshape2_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_reshape2_op_mlu.py new file mode 100644 index 00000000000..9cff269913f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_reshape2_op_mlu.py @@ -0,0 +1,73 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+SEED = 2022
+
+
+class TestReshape2(OpTest):
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "reshape2"
+        self.place = paddle.MLUPlace(0)
+
+        self.init_data()
+        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
+        self.attrs = {"shape": self.new_shape}
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(self.infered_shape),
+            'XShape': np.random.random(self.ori_shape).astype("float32")
+        }
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+
+    def init_data(self):
+        self.ori_shape = (2, 100)
+        self.new_shape = (20, 10)
+        self.infered_shape = (20, 10)
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, no_check_set=['XShape'])
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out')
+
+
+class TestReshape2_case2(TestReshape2):
+    def init_data(self):
+        self.ori_shape = (2, 100)
+        self.new_shape = (-1, 10)
+        self.infered_shape = (20, 10)
+
+
+class TestReshape2_case3(TestReshape2):
+    def init_data(self):
+        self.ori_shape = (100, 5, 6)
+        self.new_shape = (-1, 0, 3)
+        self.infered_shape = (200, 5, 3)
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
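
For readers skimming the C++ kernel above: the target shape may contain a 0, which copies the input dimension at the same index, and at most one -1, which is inferred from the remaining element count (the kernel's accumulate-with-init--1 trick computes the product of the known dimensions). Below is a minimal NumPy sketch of that resolution logic, runnable without an MLU; the helper name resolve_shape is ours for illustration, not a Paddle API.

    import numpy as np

    def resolve_shape(in_shape, target):
        # Mirrors the target-shape resolution in reshape_op_mlu.cc above.
        target = list(target)
        assert target.count(-1) <= 1, "at most one -1 is allowed"
        # Rule 1: a 0 copies the input dimension at the same index.
        for i, d in enumerate(target):
            if d == 0:
                assert i < len(in_shape), "0 must refer to an existing input dim"
                target[i] = in_shape[i]
        # Rule 2: a single -1 absorbs the remaining element count.
        if -1 in target:
            # prod(target) includes the -1, so negating it yields the
            # product of the known dims (same as accumulate(..., -1, *)).
            known = -int(np.prod(target))
            # Truncating division, matching the kernel's integer divide.
            target[target.index(-1)] = int(np.prod(in_shape)) // known
        return tuple(target)

    # Matches the expectations encoded in the unit tests above:
    assert resolve_shape((2, 100), (-1, 10)) == (20, 10)
    assert resolve_shape((100, 5, 6), (-1, 0, 3)) == (200, 5, 3)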
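To exercise the op end to end, the new unit test should run directly on an MLU-enabled build of Paddle (typically compiled with -DWITH_MLU=ON; treat that flag as an assumption and check the build docs for your version): python python/paddle/fluid/tests/unittests/mlu/test_reshape2_op_mlu.py. TestReshape2_case2 and TestReshape2_case3 cover the -1 and combined -1/0 paths resolved in the sketch above.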