From 2d16d69b5f06cbd00fafe42f71d47328c9b8a7f4 Mon Sep 17 00:00:00 2001 From: Linjie Chen <40840292+linjieccc@users.noreply.github.com> Date: Tue, 15 Feb 2022 10:26:16 +0800 Subject: [PATCH] [Pten]Move expand_v2 to pten (#39471) * move expand to pten * move expand_v2 to pten * move expand_v2 to pten * fix grad register * fix grad register * fix tensorcpry * fix tensorcopy * fix tensorcopy * fix tensorcopy * fix tensorcopy * fix ci * fix tensorcopy --- paddle/fluid/operators/expand_v2_op.cc | 33 +-- paddle/fluid/operators/expand_v2_op.h | 254 ------------------ paddle/fluid/operators/expand_v2_op_npu.cc | 1 + paddle/fluid/operators/expand_v2_op_xpu.cc | 1 + paddle/pten/core/compat/op_utils.h | 2 + paddle/pten/kernels/cpu/expand_grad_kernel.cc | 29 ++ paddle/pten/kernels/cpu/expand_kernel.cc | 30 +++ paddle/pten/kernels/expand_grad_kernel.h | 30 +++ paddle/pten/kernels/expand_kernel.h | 29 ++ paddle/pten/kernels/gpu/expand_grad_kernel.cu | 29 ++ paddle/pten/kernels/gpu/expand_kernel.cu | 31 +++ .../kernels/impl/expand_grad_kernel_impl.h | 142 ++++++++++ paddle/pten/kernels/impl/expand_kernel_impl.h | 169 ++++++++++++ paddle/pten/ops/compat/expand_sig.cc | 54 ++++ 14 files changed, 550 insertions(+), 284 deletions(-) mode change 100755 => 100644 paddle/fluid/operators/expand_v2_op.cc create mode 100644 paddle/pten/kernels/cpu/expand_grad_kernel.cc create mode 100644 paddle/pten/kernels/cpu/expand_kernel.cc create mode 100644 paddle/pten/kernels/expand_grad_kernel.h create mode 100644 paddle/pten/kernels/expand_kernel.h create mode 100644 paddle/pten/kernels/gpu/expand_grad_kernel.cu create mode 100644 paddle/pten/kernels/gpu/expand_kernel.cu create mode 100644 paddle/pten/kernels/impl/expand_grad_kernel_impl.h create mode 100644 paddle/pten/kernels/impl/expand_kernel_impl.h create mode 100644 paddle/pten/ops/compat/expand_sig.cc diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc old mode 100755 new mode 100644 index 6d803c500d..901e073ccd --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -16,6 +16,9 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/op_registry.h" + +#define MAX_RANK_SUPPORTED 6 namespace paddle { namespace operators { @@ -296,33 +299,3 @@ REGISTER_OPERATOR(expand_v2_grad, ops::ExpandV2GradOp, ops::ExpandV2DoubleGradOpMaker, ops::ExpandV2DoubleGradOpMaker, ops::ExpandV2GradNoNeedBufVarsInferer); -REGISTER_OP_CPU_KERNEL( - expand_v2, ops::ExpandV2Kernel, - ops::ExpandV2Kernel, - ops::ExpandV2Kernel, - ops::ExpandV2Kernel, - ops::ExpandV2Kernel); -REGISTER_OP_CPU_KERNEL( - expand_v2_grad, - ops::ExpandV2GradKernel, - ops::ExpandV2GradKernel, - ops::ExpandV2GradKernel, - ops::ExpandV2GradKernel); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -REGISTER_OP_CUDA_KERNEL( - expand_v2, ops::ExpandV2Kernel, - ops::ExpandV2Kernel, - ops::ExpandV2Kernel, - ops::ExpandV2Kernel, - ops::ExpandV2Kernel, - ops::ExpandV2Kernel); -REGISTER_OP_CUDA_KERNEL( - expand_v2_grad, - ops::ExpandV2GradKernel, - ops::ExpandV2GradKernel, - ops::ExpandV2GradKernel, - ops::ExpandV2GradKernel, - ops::ExpandV2GradKernel); -#endif diff --git a/paddle/fluid/operators/expand_v2_op.h b/paddle/fluid/operators/expand_v2_op.h index dd16250134..158a9d1bc5 100644 --- a/paddle/fluid/operators/expand_v2_op.h +++ b/paddle/fluid/operators/expand_v2_op.h @@ -91,259 +91,5 @@ inline std::vector get_expand_shape( return ctx.Attr>("shape"); } } - -using Tensor = framework::Tensor; -template -using EigenVector = framework::EigenVector; -template -using EigenTensor = framework::EigenTensor; -using framework::To32BitIndex; - -template -class ExpandV2Kernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto rank = context.Input("X")->dims().size(); - PADDLE_ENFORCE_GE( - rank, 1, - platform::errors::InvalidArgument( - "The rank of the input 'X' for expand_v2 op must be positive, " - "but the value received is %d.", - rank)); - PADDLE_ENFORCE_LE( - rank, MAX_RANK_SUPPORTED, - platform::errors::InvalidArgument( - "The rank of the input 'X' for expand_v2 op must be less than " - "or equal to %d, but the value received is %d.", - MAX_RANK_SUPPORTED, rank)); - auto expand_shape = get_expand_shape(context); - auto shape_size = expand_shape.size(); - PADDLE_ENFORCE_GE( - shape_size, rank, - platform::errors::InvalidArgument( - "The number (%d) of elements of 'shape' for expand_v2 op must be " - "greater than or equal to the rank (%d) of the input 'X'.", - shape_size, rank)); - PADDLE_ENFORCE_LE( - shape_size, MAX_RANK_SUPPORTED, - platform::errors::InvalidArgument( - "The number (%d) of elements of 'shape' for expand_v2 op must be " - "less than or equal to %d.", - shape_size, MAX_RANK_SUPPORTED)); - rank = std::max(rank, static_cast(shape_size)); - switch (rank) { - case 1: - Expand<1>(context); - break; - case 2: - Expand<2>(context); - break; - case 3: - Expand<3>(context); - break; - case 4: - Expand<4>(context); - break; - case 5: - Expand<5>(context); - break; - case 6: - Expand<6>(context); - break; - } - } - - protected: - template - void Expand(const framework::ExecutionContext& context) const { - auto* in0 = context.Input("X"); - - auto in_dims = in0->dims(); - auto expand_shape = get_expand_shape(context); - auto vec_in_dims = framework::vectorize(in_dims); - auto diff = expand_shape.size() - vec_in_dims.size(); - vec_in_dims.insert(vec_in_dims.begin(), diff, 1); - std::vector repeat_times(vec_in_dims.size()); - for (size_t i = 0; i < vec_in_dims.size(); ++i) { - PADDLE_ENFORCE_NE(expand_shape[i], 0, - platform::errors::InvalidArgument( - "The expanded size cannot be zero.")); - if (i < diff) { - PADDLE_ENFORCE_GT( - expand_shape[i], 0, - platform::errors::InvalidArgument( - "The expanded size (%d) for non-existing dimensions must be " - "positive for expand_v2 op.", - expand_shape[i])); - repeat_times[i] = expand_shape[i]; - } else if (expand_shape[i] > 0) { - if (vec_in_dims[i] != 1) { - PADDLE_ENFORCE_EQ( - vec_in_dims[i], expand_shape[i], - platform::errors::InvalidArgument( - "The value (%d) of the non-singleton dimension does not match" - " the corresponding value (%d) in shape for expand_v2 op.", - vec_in_dims[i], expand_shape[i])); - repeat_times[i] = 1; - } else { - repeat_times[i] = expand_shape[i]; - } - } else { - PADDLE_ENFORCE_EQ( - expand_shape[i], -1, - platform::errors::InvalidArgument( - "When the value in shape is negative for expand_v2 op, " - "only -1 is supported, but the value received is %d.", - expand_shape[i])); - repeat_times[i] = 1; - } - } - - auto* out0 = context.Output("Out"); - Eigen::DSizes bcast_dims; - for (size_t i = 0; i < repeat_times.size(); ++i) { - bcast_dims[i] = repeat_times[i]; - } - - framework::DDim new_in_dims = framework::make_ddim(vec_in_dims); - framework::DDim out_dims(new_in_dims); - for (size_t i = 0; i < repeat_times.size(); ++i) { - out_dims[i] *= repeat_times[i]; - } - - out0->Resize(out_dims); - auto x = EigenTensor::From(*in0, new_in_dims); - out0->mutable_data(context.GetPlace()); - auto y = EigenTensor::From(*out0, out_dims); - auto& place = - *context.template device_context().eigen_device(); - // use 32-bit index to speed up - bool use_32bit_index = y.size() < Eigen::NumTraits::highest(); - if (use_32bit_index) { - EigenBroadcast, T, Rank>::Eval( - place, To32BitIndex(y), To32BitIndex(x), bcast_dims); - } else { - EigenBroadcast, T, Rank>::Eval(place, y, x, - bcast_dims); - } - } -}; - -template -class ExpandV2GradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in0 = context.Input("X"); - auto expand_shape = get_expand_shape(context); - auto x_dims = in0->dims(); - auto vec_in_dims = framework::vectorize(x_dims); - auto diff = expand_shape.size() - vec_in_dims.size(); - vec_in_dims.insert(vec_in_dims.begin(), diff, 1); - // 1. reshape_dims_vec is the broadcast parameter. - // 2. reduce_dims_vec is the dimension parameter to compute gradients. For - // each dimension expanded, the gradients should be summed to original - // size. - std::vector repeat_times(vec_in_dims.size()); - for (size_t i = 0; i < vec_in_dims.size(); ++i) { - if (expand_shape[i] < 0) { - repeat_times[i] = 1; - } else { - repeat_times[i] = expand_shape[i] / vec_in_dims[i]; - } - } - std::vector reshape_dims_vec; - std::vector reduce_dims_vec; - for (size_t i = 0; i < repeat_times.size(); ++i) { - reduce_dims_vec.push_back(reshape_dims_vec.size()); - reshape_dims_vec.push_back(repeat_times[i]); - reshape_dims_vec.push_back(vec_in_dims[i]); - } - - int dims = reduce_dims_vec.size(); - - bool just_copy = true; - for (size_t i = 0; i < repeat_times.size(); i++) { - if (repeat_times[i] != 1) { - just_copy = false; - break; - } - } - // no need reduce, just copy - if (just_copy) { - auto* in0 = context.Input(framework::GradVarName("Out")); - auto* out0 = context.Output(framework::GradVarName("X")); - out0->mutable_data(context.GetPlace()); - framework::TensorCopy(*in0, context.GetPlace(), context.device_context(), - out0); - } else { - PADDLE_ENFORCE_GE(dims, 1, - platform::errors::InvalidArgument( - "The rank of the input 'Out@GRAD' for " - "expand_v2_grad op must be greater than or " - "equal to 1, but the value received is %d.", - dims)); - PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED, - platform::errors::InvalidArgument( - "The rank of the input 'Out@GRAD' for " - "expand_v2_grad op must be less than or equal " - "to %d, but the value received is %d.", - MAX_RANK_SUPPORTED, dims)); - switch (dims) { - case 1: - ExpandBackward<1>(context, reshape_dims_vec, reduce_dims_vec); - break; - case 2: - ExpandBackward<2>(context, reshape_dims_vec, reduce_dims_vec); - break; - case 3: - ExpandBackward<3>(context, reshape_dims_vec, reduce_dims_vec); - break; - case 4: - ExpandBackward<4>(context, reshape_dims_vec, reduce_dims_vec); - break; - case 5: - ExpandBackward<5>(context, reshape_dims_vec, reduce_dims_vec); - break; - case 6: - ExpandBackward<6>(context, reshape_dims_vec, reduce_dims_vec); - break; - default: - PADDLE_THROW(platform::errors::InvalidArgument( - "Only support tensor with rank being between 1 and 6. But " - "received tensor's rank = %d.", - dims)); - } - } - } - - protected: - template - void ExpandBackward(const framework::ExecutionContext& context, - const std::vector& reshape_dims_vec, - const std::vector& reduce_dims_vec) const { - size_t reshape_size = reshape_dims_vec.size(); - size_t reduce_size = reduce_dims_vec.size(); - auto* in0 = context.Input(framework::GradVarName("Out")); - auto* out0 = context.Output(framework::GradVarName("X")); - out0->mutable_data(context.GetPlace()); - auto x_grad = EigenVector::Flatten(*out0); - Eigen::DSizes reshape_dims; - for (size_t i = 0; i < reshape_size; ++i) { - reshape_dims[i] = reshape_dims_vec[i]; - } - Eigen::DSizes reduce_dims; - for (size_t i = 0; i < reduce_size; ++i) { - reduce_dims[i] = reduce_dims_vec[i]; - } - auto out_grad = EigenVector::Flatten(*in0); - auto& place = - *context.template device_context().eigen_device(); - EigenBroadcastGrad, T, Dims>::Eval( - place, x_grad, out_grad, reduce_dims, reshape_dims); - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/expand_v2_op_npu.cc b/paddle/fluid/operators/expand_v2_op_npu.cc index d2f8daf732..f7de5a336d 100644 --- a/paddle/fluid/operators/expand_v2_op_npu.cc +++ b/paddle/fluid/operators/expand_v2_op_npu.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/expand_v2_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/expand_v2_op_xpu.cc b/paddle/fluid/operators/expand_v2_op_xpu.cc index 791f8e8236..e2e0cff100 100644 --- a/paddle/fluid/operators/expand_v2_op_xpu.cc +++ b/paddle/fluid/operators/expand_v2_op_xpu.cc @@ -14,6 +14,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU #include "paddle/fluid/operators/expand_v2_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { diff --git a/paddle/pten/core/compat/op_utils.h b/paddle/pten/core/compat/op_utils.h index 950b6cd039..14a1872c17 100644 --- a/paddle/pten/core/compat/op_utils.h +++ b/paddle/pten/core/compat/op_utils.h @@ -45,6 +45,8 @@ const std::unordered_set deprecated_op_names({"flatten", "mean", "reshape", "reshape_grad", + "expand", + "expand_grad", "sum"}); class DefaultKernelSignatureMap { diff --git a/paddle/pten/kernels/cpu/expand_grad_kernel.cc b/paddle/pten/kernels/cpu/expand_grad_kernel.cc new file mode 100644 index 0000000000..518d81d89e --- /dev/null +++ b/paddle/pten/kernels/cpu/expand_grad_kernel.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/expand_grad_kernel.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/expand_grad_kernel_impl.h" + +PT_REGISTER_KERNEL(expand_grad, + CPU, + ALL_LAYOUT, + pten::ExpandGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/cpu/expand_kernel.cc b/paddle/pten/kernels/cpu/expand_kernel.cc new file mode 100644 index 0000000000..c5c019bd72 --- /dev/null +++ b/paddle/pten/kernels/cpu/expand_kernel.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/expand_kernel.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/expand_kernel_impl.h" + +PT_REGISTER_KERNEL(expand, + CPU, + ALL_LAYOUT, + pten::ExpandKernel, + float, + double, + int, + int64_t, + bool) {} diff --git a/paddle/pten/kernels/expand_grad_kernel.h b/paddle/pten/kernels/expand_grad_kernel.h new file mode 100644 index 0000000000..8bcb599cc9 --- /dev/null +++ b/paddle/pten/kernels/expand_grad_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/device_context.h" + +namespace pten { + +template +void ExpandGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const ScalarArray& shape, + DenseTensor* in_grad); + +} // namespace pten diff --git a/paddle/pten/kernels/expand_kernel.h b/paddle/pten/kernels/expand_kernel.h new file mode 100644 index 0000000000..91bea8c07e --- /dev/null +++ b/paddle/pten/kernels/expand_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/device_context.h" + +namespace pten { + +template +void ExpandKernel(const Context& ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out); + +} // namepsace pten diff --git a/paddle/pten/kernels/gpu/expand_grad_kernel.cu b/paddle/pten/kernels/gpu/expand_grad_kernel.cu new file mode 100644 index 0000000000..49f8718c48 --- /dev/null +++ b/paddle/pten/kernels/gpu/expand_grad_kernel.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/expand_grad_kernel.h" +#include "paddle/pten/kernels/impl/expand_grad_kernel_impl.h" + +PT_REGISTER_KERNEL(expand_grad, + GPU, + ALL_LAYOUT, + pten::ExpandGradKernel, + float, + double, + paddle::platform::float16, + int, + int64_t) {} diff --git a/paddle/pten/kernels/gpu/expand_kernel.cu b/paddle/pten/kernels/gpu/expand_kernel.cu new file mode 100644 index 0000000000..e0d8536d6a --- /dev/null +++ b/paddle/pten/kernels/gpu/expand_kernel.cu @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/expand_kernel.h" +#include "paddle/pten/kernels/impl/expand_kernel_impl.h" + +PT_REGISTER_KERNEL(expand, + GPU, + ALL_LAYOUT, + pten::ExpandKernel, + float, + double, + paddle::platform::float16, + int, + int64_t, + bool) {} diff --git a/paddle/pten/kernels/impl/expand_grad_kernel_impl.h b/paddle/pten/kernels/impl/expand_grad_kernel_impl.h new file mode 100644 index 0000000000..05ccf2e00d --- /dev/null +++ b/paddle/pten/kernels/impl/expand_grad_kernel_impl.h @@ -0,0 +1,142 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/kernels/copy_kernel.h" +#include "paddle/pten/kernels/funcs/eigen/common.h" +#include "paddle/pten/kernels/funcs/eigen/eigen_function.h" +#include "paddle/pten/kernels/impl/expand_kernel_impl.h" + +namespace pten { +template +void ExpandBackward(const Context& ctx, + const DenseTensor& out_grad, + const std::vector& reshape_dims_vec, + const std::vector& reduce_dims_vec, + DenseTensor* in_grad) { + size_t reshape_size = reshape_dims_vec.size(); + size_t reduce_size = reduce_dims_vec.size(); + ctx.template Alloc(in_grad); + in_grad->data(); + + auto x_grad = EigenVector::Flatten(*in_grad); + Eigen::DSizes reshape_dims; + for (size_t i = 0; i < reshape_size; ++i) { + reshape_dims[i] = reshape_dims_vec[i]; + } + Eigen::DSizes reduce_dims; + for (size_t i = 0; i < reduce_size; ++i) { + reduce_dims[i] = reduce_dims_vec[i]; + } + auto out_grad0 = EigenVector::Flatten(out_grad); + auto& place = *ctx.eigen_device(); + pten::funcs::EigenBroadcastGrad, T, Dims>::Eval( + place, x_grad, out_grad0, reduce_dims, reshape_dims); +} + +template +void ExpandGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const ScalarArray& shape, + DenseTensor* in_grad) { + auto expand_shape = shape.GetData(); + auto x_dims = x.dims(); + auto vec_in_dims = framework::vectorize(x_dims); + auto diff = expand_shape.size() - vec_in_dims.size(); + vec_in_dims.insert(vec_in_dims.begin(), diff, 1); + // 1. reshape_dims_vec is the broadcast parameter. + // 2. reduce_dims_vec is the dimension parameter to compute gradients. For + // each dimension expanded, the gradients should be summed to original + // size. + std::vector repeat_times(vec_in_dims.size()); + for (size_t i = 0; i < vec_in_dims.size(); ++i) { + if (expand_shape[i] < 0) { + repeat_times[i] = 1; + } else { + repeat_times[i] = expand_shape[i] / vec_in_dims[i]; + } + } + std::vector reshape_dims_vec; + std::vector reduce_dims_vec; + for (size_t i = 0; i < repeat_times.size(); ++i) { + reduce_dims_vec.push_back(reshape_dims_vec.size()); + reshape_dims_vec.push_back(repeat_times[i]); + reshape_dims_vec.push_back(vec_in_dims[i]); + } + + int dims = reduce_dims_vec.size(); + + bool just_copy = true; + for (size_t i = 0; i < repeat_times.size(); i++) { + if (repeat_times[i] != 1) { + just_copy = false; + break; + } + } + // no need reduce, just copy + if (just_copy) { + pten::Copy(ctx, out_grad, false, in_grad); + } else { + PADDLE_ENFORCE_GE(dims, + 1, + pten::errors::InvalidArgument( + "The rank of the input 'Out@GRAD' for " + "expand_v2_grad op must be greater than or " + "equal to 1, but the value received is %d.", + dims)); + PADDLE_ENFORCE_LE(dims, + MAX_RANK_SUPPORTED, + pten::errors::InvalidArgument( + "The rank of the input 'Out@GRAD' for " + "expand_v2_grad op must be less than or equal " + "to %d, but the value received is %d.", + MAX_RANK_SUPPORTED, + dims)); + switch (dims) { + case 1: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + case 2: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + case 3: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + case 4: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + case 5: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + case 6: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + default: + PADDLE_THROW(pten::errors::InvalidArgument( + "Only support tensor with rank being between 1 and 6. But " + "received tensor's rank = %d.", + dims)); + } + } +} + +} // namespace pten diff --git a/paddle/pten/kernels/impl/expand_kernel_impl.h b/paddle/pten/kernels/impl/expand_kernel_impl.h new file mode 100644 index 0000000000..14f2b8ad24 --- /dev/null +++ b/paddle/pten/kernels/impl/expand_kernel_impl.h @@ -0,0 +1,169 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/pten/kernels/funcs/eigen/common.h" +#include "paddle/pten/kernels/funcs/eigen/eigen_function.h" +#define MAX_RANK_SUPPORTED 6 + +namespace pten { +using Tensor = DenseTensor; + +template +void Expand(const Context& ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out) { + auto in_dims = x.dims(); + auto expand_shape = shape.GetData(); + auto vec_in_dims = framework::vectorize(in_dims); + auto diff = expand_shape.size() - vec_in_dims.size(); + vec_in_dims.insert(vec_in_dims.begin(), diff, 1); + std::vector repeat_times(vec_in_dims.size()); + for (size_t i = 0; i < vec_in_dims.size(); ++i) { + PADDLE_ENFORCE_NE( + expand_shape[i], + 0, + pten::errors::InvalidArgument("The expanded size cannot be zero.")); + if (i < diff) { + PADDLE_ENFORCE_GT( + expand_shape[i], + 0, + pten::errors::InvalidArgument( + "The expanded size (%d) for non-existing dimensions must be " + "positive for expand_v2 op.", + expand_shape[i])); + repeat_times[i] = expand_shape[i]; + } else if (expand_shape[i] > 0) { + if (vec_in_dims[i] != 1) { + PADDLE_ENFORCE_EQ( + vec_in_dims[i], + expand_shape[i], + pten::errors::InvalidArgument( + "The value (%d) of the non-singleton dimension does not match" + " the corresponding value (%d) in shape for expand_v2 op.", + vec_in_dims[i], + expand_shape[i])); + repeat_times[i] = 1; + } else { + repeat_times[i] = expand_shape[i]; + } + } else { + PADDLE_ENFORCE_EQ( + expand_shape[i], + -1, + pten::errors::InvalidArgument( + "When the value in shape is negative for expand_v2 op, " + "only -1 is supported, but the value received is %d.", + expand_shape[i])); + repeat_times[i] = 1; + } + } + + Eigen::DSizes bcast_dims; + for (size_t i = 0; i < repeat_times.size(); ++i) { + bcast_dims[i] = repeat_times[i]; + } + + framework::DDim new_in_dims = framework::make_ddim(vec_in_dims); + framework::DDim out_dims(new_in_dims); + for (size_t i = 0; i < repeat_times.size(); ++i) { + out_dims[i] *= repeat_times[i]; + } + + out->Resize(out_dims); + auto x0 = EigenTensor::From(x, new_in_dims); + ctx.template Alloc(out); + out->data(); + + auto y = EigenTensor::From(*out, out_dims); + auto& place = *ctx.eigen_device(); + // use 32-bit index to speed up + bool use_32bit_index = y.size() < Eigen::NumTraits::highest(); + if (use_32bit_index) { + pten::funcs::EigenBroadcast, T, Rank>::Eval( + place, To32BitIndex(y), To32BitIndex(x0), bcast_dims); + } else { + pten::funcs::EigenBroadcast, T, Rank>::Eval( + place, y, x0, bcast_dims); + } +} + +template +void ExpandKernel(const Context& ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out) { + auto rank = x.dims().size(); + PADDLE_ENFORCE_GE( + rank, + 1, + pten::errors::InvalidArgument( + "The rank of the input 'X' for expand_v2 op must be positive, " + "but the value received is %d.", + rank)); + PADDLE_ENFORCE_LE( + rank, + MAX_RANK_SUPPORTED, + pten::errors::InvalidArgument( + "The rank of the input 'X' for expand_v2 op must be less than " + "or equal to %d, but the value received is %d.", + MAX_RANK_SUPPORTED, + rank)); + auto expand_shape = shape.GetData(); + auto shape_size = expand_shape.size(); + PADDLE_ENFORCE_GE( + shape_size, + rank, + pten::errors::InvalidArgument( + "The number (%d) of elements of 'shape' for expand_v2 op must be " + "greater than or equal to the rank (%d) of the input 'X'.", + shape_size, + rank)); + PADDLE_ENFORCE_LE( + shape_size, + MAX_RANK_SUPPORTED, + pten::errors::InvalidArgument( + "The number (%d) of elements of 'shape' for expand_v2 op must be " + "less than or equal to %d.", + shape_size, + MAX_RANK_SUPPORTED)); + rank = std::max(rank, static_cast(shape_size)); + switch (rank) { + case 1: + Expand(ctx, x, shape, out); + break; + case 2: + Expand(ctx, x, shape, out); + break; + case 3: + Expand(ctx, x, shape, out); + break; + case 4: + Expand(ctx, x, shape, out); + break; + case 5: + Expand(ctx, x, shape, out); + break; + case 6: + Expand(ctx, x, shape, out); + break; + } +} + +} // namespace pten diff --git a/paddle/pten/ops/compat/expand_sig.cc b/paddle/pten/ops/compat/expand_sig.cc new file mode 100644 index 0000000000..a04052cdac --- /dev/null +++ b/paddle/pten/ops/compat/expand_sig.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("Shape")) { + return KernelSignature("expand", {"X"}, {"Shape"}, {"Out"}); + } else if (ctx.InputSize("expand_shapes_tensor") > 0) { + return KernelSignature("expand", {"X"}, {"expand_shapes_tensor"}, {"Out"}); + } else { + return KernelSignature("expand", {"X"}, {"shape"}, {"Out"}); + } +} + +KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("Shape")) { + return KernelSignature("expand_grad", + {"X", GradVarName("Out")}, + {"Shape"}, + {GradVarName("X")}); + } else if (ctx.InputSize("expand_shapes_tensor") > 0) { + return KernelSignature("expand_grad", + {"X", GradVarName("Out")}, + {"expand_shapes_tensor"}, + {GradVarName("X")}); + } else { + return KernelSignature("expand_grad", + {"X", GradVarName("Out")}, + {"shape"}, + {GradVarName("X")}); + } +} + +} // namespace pten + +PT_REGISTER_BASE_KERNEL_NAME(expand_v2, expand); +PT_REGISTER_BASE_KERNEL_NAME(expand_v2_grad, expand_grad); + +PT_REGISTER_ARG_MAPPING_FN(expand_v2, pten::ExpandOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(expand_v2_grad, pten::ExpandGradOpArgumentMapping); -- GitLab