From 31b1f707a653d9028858f16d0259d707c89041a7 Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Mon, 10 Jan 2022 19:32:36 +0800 Subject: [PATCH] refactor the forward implementation of reshape npu op (#38748) * refactor the forward implementation of reshape npu op * update reshape npu op * update reshape npu op --- paddle/fluid/operators/reshape_op_npu.cc | 99 ++++++++++++++++++++---- 1 file changed, 85 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/reshape_op_npu.cc b/paddle/fluid/operators/reshape_op_npu.cc index d6b1d79f2b1..8b6f9d4d48d 100644 --- a/paddle/fluid/operators/reshape_op_npu.cc +++ b/paddle/fluid/operators/reshape_op_npu.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { @@ -25,23 +26,93 @@ template class Reshape2NPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto stream = + ctx.template device_context() + .stream(); + auto place = ctx.GetPlace(); auto* x = ctx.Input("X"); auto* out = ctx.Output("Out"); - auto list_new_shape_tensor = - ctx.MultiInput("ShapeTensor"); - if (list_new_shape_tensor.size() > 0) { - PADDLE_THROW(platform::errors::Unimplemented( - "Input(ShapeTensor) is not supported on NPU.")); + + std::vector target_shape_vector; + auto shape_tensor_vector = ctx.MultiInput("ShapeTensor"); + if (shape_tensor_vector.size() > 0) { + for (auto* shape_tensor : shape_tensor_vector) { + PADDLE_ENFORCE_EQ( + shape_tensor->dims().size(), 1, + platform::errors::InvalidArgument( + "If the element type of 'shape' in Reshape Op is Tensor, " + "the element's shape must be [1]. But received the element's " + "shape is [%d]", + shape_tensor->dims().size())); + + target_shape_vector.push_back(GetDataFromTensor(shape_tensor)[0]); + } + } else { + auto* shape_tensor = ctx.HasInput("Shape") + ? ctx.Input("Shape") + : nullptr; + if (shape_tensor) { + target_shape_vector = GetDataFromTensor(shape_tensor); + } else { + target_shape_vector = ctx.Attr>("shape"); + PADDLE_ENFORCE_GT( + target_shape_vector.size(), 0, + platform::errors::InvalidArgument( + "The length of shape attribute should be larger than 0 when " + "input ShapeTensor and Shape are empty!")); + } } - PADDLE_ENFORCE_EQ(ctx.Input("Shape"), nullptr, - platform::errors::Unimplemented( - "Input(Shape) is not supported on NPU.")); - auto shape = out->dims(); - out->mutable_data(ctx.GetPlace(), x->type()); - framework::TensorCopy( - *x, ctx.GetPlace(), - ctx.template device_context(), out); - out->Resize(shape); + + int num_negative = + std::count(target_shape_vector.begin(), target_shape_vector.end(), -1); + PADDLE_ENFORCE_LE( + num_negative, 1, + platform::errors::InvalidArgument( + "The max number of -1 in shape attribute or shape tensor is 1 " + "but received %d.", + num_negative)); + auto it_zero = + std::find(target_shape_vector.begin(), target_shape_vector.end(), 0); + if (it_zero != target_shape_vector.end()) { + int x_rank = x->dims().size(); + for (size_t i = 0; i < target_shape_vector.size(); i++) { + if (target_shape_vector[i] == 0) { + PADDLE_ENFORCE_LT( + i, x_rank, + platform::errors::InvalidArgument( + "The index of 0 in shape attribute or shape tensor", + "should be less than input dim size, ", + "but the index is %d and input dim size is %d", i, x_rank)); + target_shape_vector[i] = x->dims().at(i); + } + } + } + + auto it = + std::find(target_shape_vector.begin(), target_shape_vector.end(), -1); + if (it != target_shape_vector.end()) { + auto ddim_out_vec = framework::vectorize(x->dims()); + int ddim_out_product = std::accumulate( + ddim_out_vec.begin(), ddim_out_vec.end(), 1, std::multiplies()); + int reshape_out_product = std::accumulate(target_shape_vector.begin(), + target_shape_vector.end(), -1, + std::multiplies()); + int index = std::distance(target_shape_vector.begin(), it); + target_shape_vector[index] = ddim_out_product / reshape_out_product; + } + + auto out_dims = framework::make_ddim(target_shape_vector); + out->mutable_data(out_dims, place); + + NpuOpRunner runner; + // the shape input must be on the host side + runner.SetType("Reshape") + .AddInput(*x) + .AddInput(std::vector(target_shape_vector)) + .AddOutput(*out) + .AddAttr("axis", 0) + .AddAttr("num_axes", -1); + runner.Run(stream); } }; -- GitLab