From 7a857924570084851be8b6094f181f217d58fb7c Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Wed, 2 Mar 2022 17:18:53 +0800 Subject: [PATCH] Move transpose to pten (#39327) * immigrate_transpose_to_pten cpu kernel only; test=develop * fix bug; test=develop * add transpose cuda api * bug fix; * fix bugs * fix bugs; test=develop * bug fix; * move transepose to pten; test=develop * fix bug; test=develop * fix bugs; test=develop * add transpose grad fp16 support; test=develop * fix bug; test=develop * fix npu bug; test=develop * fix nemul = 0 bug; test=develop * add fp16 support; test=develop * fix data type register bug; test=develop * fix transpose bug; test=develop * update transpose * fix transpose bug; test=develop * remove useless code; test=develop * remove useless code; test=develop * fix transpose alias bug; test=develop * polish code; test=develop * resolve confict; test=develop * resolve confilct; test=develop * recover prepared operator; test=develop * fix bug; test=develop * polish code; test=develop * fix bug; test=develop * fix bug; test=develop --- .../operators/mkldnn/test_mkldnn_op_nhwc.cc | 2 +- paddle/fluid/operators/transpose_op.cc | 60 ++------ paddle/fluid/operators/transpose_op.cu | 139 ------------------ paddle/fluid/operators/transpose_op.cu.h | 42 +++--- paddle/fluid/operators/transpose_op.h | 58 -------- .../fluid/operators/transpose_op_npu_test.cc | 2 +- .../phi/kernels/cpu/transpose_grad_kernel.cc | 32 ++++ paddle/phi/kernels/cpu/transpose_kernel.cc | 80 ++++++++++ paddle/phi/kernels/funcs/math_function.cu | 51 +++++++ .../phi/kernels/gpu/transpose_grad_kernel.cu | 34 +++++ paddle/phi/kernels/gpu/transpose_kernel.cu | 57 +++++++ .../kernels/impl/transpose_grad_kernel_impl.h | 38 +++++ paddle/phi/kernels/transpose_grad_kernel.h | 28 ++++ paddle/phi/kernels/transpose_kernel.h | 28 ++++ paddle/phi/ops/compat/transpose_sig.cc | 38 +++++ .../unittests/parallel_executor_test_base.py | 2 +- ..._imperative_lod_tensor_to_selected_rows.py | 1 + .../test_parallel_executor_transformer.py | 1 + ...test_partial_eager_deletion_transformer.py | 2 + .../tests/unittests/test_transpose_op.py | 1 + 20 files changed, 426 insertions(+), 270 deletions(-) delete mode 100644 paddle/fluid/operators/transpose_op.cu create mode 100644 paddle/phi/kernels/cpu/transpose_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/transpose_kernel.cc create mode 100644 paddle/phi/kernels/gpu/transpose_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/transpose_kernel.cu create mode 100644 paddle/phi/kernels/impl/transpose_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/transpose_grad_kernel.h create mode 100644 paddle/phi/kernels/transpose_kernel.h create mode 100644 paddle/phi/ops/compat/transpose_sig.cc diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc index 52e2caaeb6e..3791fed23a8 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc @@ -29,7 +29,7 @@ USE_OP(pool2d); USE_OP_DEVICE_KERNEL(pool2d, MKLDNN); USE_OP(relu); USE_OP_DEVICE_KERNEL(relu, MKLDNN); -USE_OP(transpose); +USE_OP_ITSELF(transpose); USE_OP_DEVICE_KERNEL(transpose, MKLDNN); namespace paddle { diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 768ab21936f..1a297e7238c 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -339,6 +339,14 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { } }; +class TransposeGradInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + ctx->SyncTypeAndDataType(framework::GradVarName("Out"), + framework::GradVarName("X")); + } +}; + } // namespace operators } // namespace paddle @@ -347,59 +355,13 @@ REGISTER_OPERATOR( transpose, ops::TransposeOp, ops::TransposeOpMaker, paddle::framework::DefaultGradOpMaker, paddle::framework::DefaultGradOpMaker); -REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad); - -REGISTER_OP_CPU_KERNEL( - transpose, ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel>, - ops::TransposeKernel>, - ops::TransposeKernel); -REGISTER_OP_CPU_KERNEL( - transpose_grad, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel>, - ops::TransposeGradKernel>, - ops::TransposeGradKernel); +REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad, + ops::TransposeGradInferVarType); REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker, ops::Transpose2GradMaker, ops::Transpose2GradMaker); REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad, + ops::TransposeGradInferVarType, ops::Transpose2DoubleGradMaker, ops::Transpose2DoubleGradMaker); - -REGISTER_OP_CPU_KERNEL( - transpose2, ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel, - ops::TransposeKernel>, - ops::TransposeKernel>, - ops::TransposeKernel); -REGISTER_OP_CPU_KERNEL( - transpose2_grad, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel, - ops::TransposeGradKernel>, - ops::TransposeGradKernel>, - ops::TransposeGradKernel); diff --git a/paddle/fluid/operators/transpose_op.cu b/paddle/fluid/operators/transpose_op.cu deleted file mode 100644 index 02e224549a5..00000000000 --- a/paddle/fluid/operators/transpose_op.cu +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/transpose_op.cu.h" -#include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace operators { - -template -class TransposeGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.InputVar("X"); - auto* out = context.OutputVar("Out"); - - const framework::Tensor* x_tensor = - GetLoDTensorOrSelectedRowsValueFromVar(*x); - framework::Tensor* out_tensor = - GetMutableLoDTensorOrSelectedRowsValueFromVar(out); - - out_tensor->mutable_data(context.GetPlace()); - if (out_tensor->numel() == 0) { - return; - } - - std::vector axis = context.Attr>("axis"); - int ndims = axis.size(); - const auto& dev_ctx = context.template device_context(); - TransposeGPUKernelDriver(dev_ctx, ndims, *x_tensor, axis, out_tensor); - } -}; -template -class TransposeGradGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* out_grad = context.InputVar(framework::GradVarName("Out")); - auto* x_grad = context.OutputVar(framework::GradVarName("X")); - if (!x_grad) { - return; - } - - const framework::Tensor* out_grad_tensor = - GetLoDTensorOrSelectedRowsValueFromVar(*out_grad); - framework::Tensor* x_grad_tensor = - GetMutableLoDTensorOrSelectedRowsValueFromVar(x_grad); - - x_grad_tensor->mutable_data(context.GetPlace()); - if (x_grad_tensor->numel() == 0) { - return; - } - std::vector axis = context.Attr>("axis"); - std::vector reversed_axis(axis); - - for (size_t i = 0; i < axis.size(); i++) { - reversed_axis[axis[i]] = i; - } - - int ndims = axis.size(); - const auto& dev_ctx = context.template device_context(); - TransposeGPUKernelDriver(dev_ctx, ndims, *out_grad_tensor, reversed_axis, - x_grad_tensor); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - transpose, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel>, - ops::TransposeGPUKernel>); -REGISTER_OP_CUDA_KERNEL( - transpose_grad, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel>, - ops::TransposeGradGPUKernel>); - -REGISTER_OP_CUDA_KERNEL( - transpose2, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel, - ops::TransposeGPUKernel>, - ops::TransposeGPUKernel>); -REGISTER_OP_CUDA_KERNEL( - transpose2_grad, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel, - ops::TransposeGradGPUKernel>, - ops::TransposeGradGPUKernel>); diff --git a/paddle/fluid/operators/transpose_op.cu.h b/paddle/fluid/operators/transpose_op.cu.h index b542fa37f88..a31ac28c991 100644 --- a/paddle/fluid/operators/transpose_op.cu.h +++ b/paddle/fluid/operators/transpose_op.cu.h @@ -16,8 +16,9 @@ limitations under the License. */ #include "paddle/fluid/framework/gpu_utils.h" #include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" namespace paddle { namespace operators { @@ -258,10 +259,10 @@ struct SystemElemType<16> { }; template -void LaunchNarrowDims2TransposeKernel(const platform::CUDADeviceContext& d, - int tile_size_i, int tile_size_j, - int total_tiles_count, const T* input, - const Dim3& input_dims, T* output) { +void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d, int tile_size_i, + int tile_size_j, int total_tiles_count, + const T* input, const Dim3& input_dims, + T* output) { constexpr int NumThreads = tile_long; if (tile_size_i <= tile_long && tile_size_j <= tile_short) { TilingSwapDim1And2< @@ -278,7 +279,7 @@ void LaunchNarrowDims2TransposeKernel(const platform::CUDADeviceContext& d, template struct NarrowDims2TransposeDispatch { - static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, + static void DoTranspose(const phi::GPUContext& d, int tile_size_i, int tile_size_j, int total_tiles_count, const T* input, const Dim3& input_dims, T* output) { PADDLE_ENFORCE_EQ( @@ -319,7 +320,7 @@ struct NarrowDims2TransposeDispatch< T, tile_long, tile_short, typename std::enable_if< CheckNonLongTileSize(tile_long, tile_short, sizeof(T)), void>::type> { - static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, + static void DoTranspose(const phi::GPUContext& d, int tile_size_i, int tile_size_j, int total_tiles_count, const T* input, const Dim3& input_dims, T* output) { PADDLE_ENFORCE_EQ( @@ -351,7 +352,7 @@ struct NarrowDims2TransposeDispatch< T, tile_long, tile_short, typename std::enable_if::type> { - static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i, + static void DoTranspose(const phi::GPUContext& d, int tile_size_i, int tile_size_j, int total_tiles_count, const T* input, const Dim3& input_dims, T* output) { PADDLE_ENFORCE_EQ( @@ -368,7 +369,7 @@ struct NarrowDims2TransposeDispatch< }; template -void SwapDim1And2InNarrow(const platform::CUDADeviceContext& d, const T* input, +void SwapDim1And2InNarrow(const phi::GPUContext& d, const T* input, const Dim3& input_dims, T* output, const int kMinTileSize) { // First get available tile sizes for the data type requested as backups @@ -473,9 +474,8 @@ __global__ void TransposeSimpleKernel(int nthreads, const T* __restrict__ input, // Here suppose convert all tensor to dim3, so just change dim1 and 2. template -void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d, - const T* input, const Dim3& input_dims, - T* output) { +void SendSwapDim1And2InTranspose(const phi::GPUContext& d, const T* input, + const Dim3& input_dims, T* output) { // Suppose tile size > 16 static const int kMinTileSize = 16; static const int kMinNarrowTileSize = 96; @@ -512,7 +512,7 @@ void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d, } else { // If input shape is small, such as 8X8, just do simple copy int total_elements = input_dims[0] * input_dims[1] * input_dims[2]; - auto config = GetGpuLaunchConfig1D(d, total_elements); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_elements); TransposeSimpleKernel<<< config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>( total_elements, input, input_dims, output); @@ -521,7 +521,7 @@ void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d, template struct SwapDim1And2InTranspose { - typedef platform::CUDADeviceContext Device; + typedef phi::GPUContext Device; void operator()(const Device& d, const T* in, const std::vector& combined_dims, T* out) { Dim3 input_dims = {static_cast(combined_dims[0]), @@ -533,7 +533,7 @@ struct SwapDim1And2InTranspose { template struct SwapDim0And2InTranspose { - typedef platform::CUDADeviceContext Device; + typedef phi::GPUContext Device; void operator()(const Device& d, const T* in, const std::vector& combined_dims, T* out) { Dim3 input_dims = {static_cast(combined_dims[0]), @@ -541,7 +541,7 @@ struct SwapDim0And2InTranspose { static_cast(combined_dims[2])}; size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2]; - auto config = GetGpuLaunchConfig1D(d, total_size); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_size); TransposeSimpleKernel<<< config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>( @@ -607,7 +607,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape, template struct TransposeSimple { - static bool run(const platform::CUDADeviceContext& ctx, const Tensor& in, + static bool run(const phi::GPUContext& ctx, const Tensor& in, const std::vector perm, Tensor* out) { // First reduce the dimensions of the input tensor if possible. std::vector new_perm; @@ -654,12 +654,12 @@ struct TransposeSimple { }; template -void TransposeGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, - const int ndims, const Tensor& in, - const std::vector perm, Tensor* out) { +void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, const int ndims, + const Tensor& in, + const std::vector& perm, Tensor* out) { auto ret = TransposeSimple::run(dev_ctx, in, perm, out); if (!ret) { - TransCompute(ndims, dev_ctx, in, out, perm); + TransCompute(ndims, dev_ctx, in, out, perm); } } diff --git a/paddle/fluid/operators/transpose_op.h b/paddle/fluid/operators/transpose_op.h index ec05a534c0e..a9e4876cc82 100644 --- a/paddle/fluid/operators/transpose_op.h +++ b/paddle/fluid/operators/transpose_op.h @@ -59,63 +59,5 @@ inline void TransCompute(const int dim, const DeviceContext& dev_ctx, } } -template -class TransposeKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.InputVar("X"); - auto* out = context.OutputVar("Out"); - - const framework::Tensor* x_tensor = - GetLoDTensorOrSelectedRowsValueFromVar(*x); - framework::Tensor* out_tensor = - GetMutableLoDTensorOrSelectedRowsValueFromVar(out); - - out_tensor->mutable_data(context.GetPlace()); - if (out_tensor->numel() == 0) { - return; - } - - std::vector axis = context.Attr>("axis"); - int ndims = axis.size(); - auto& dev_ctx = context.template device_context(); - TransCompute(ndims, dev_ctx, *x_tensor, out_tensor, axis); - } -}; - -template -class TransposeGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* out_grad = context.InputVar(framework::GradVarName("Out")); - auto* x_grad = context.OutputVar(framework::GradVarName("X")); - - if (!x_grad) { - return; - } - const framework::Tensor* out_grad_tensor = - GetLoDTensorOrSelectedRowsValueFromVar(*out_grad); - framework::Tensor* x_grad_tensor = - GetMutableLoDTensorOrSelectedRowsValueFromVar(x_grad); - - x_grad_tensor->mutable_data(context.GetPlace()); - if (x_grad_tensor->numel() == 0) { - return; - } - - std::vector axis = context.Attr>("axis"); - std::vector reversed_axis(axis); - - for (size_t i = 0; i < axis.size(); i++) { - reversed_axis[axis[i]] = i; - } - - int ndims = axis.size(); - auto& dev_ctx = context.template device_context(); - TransCompute(ndims, dev_ctx, *out_grad_tensor, - x_grad_tensor, reversed_axis); - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/transpose_op_npu_test.cc b/paddle/fluid/operators/transpose_op_npu_test.cc index cce3f188c8b..5617d728a51 100644 --- a/paddle/fluid/operators/transpose_op_npu_test.cc +++ b/paddle/fluid/operators/transpose_op_npu_test.cc @@ -31,7 +31,7 @@ limitations under the License. */ namespace f = paddle::framework; namespace p = paddle::platform; -USE_OP(transpose2); +USE_OP_ITSELF(transpose2); USE_OP_DEVICE_KERNEL(transpose2, NPU); template diff --git a/paddle/phi/kernels/cpu/transpose_grad_kernel.cc b/paddle/phi/kernels/cpu/transpose_grad_kernel.cc new file mode 100644 index 00000000000..9dbcf575f33 --- /dev/null +++ b/paddle/phi/kernels/cpu/transpose_grad_kernel.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/transpose_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(transpose_grad, + CPU, + ALL_LAYOUT, + phi::TransposeGradKernel, + bool, + float, + double, + int32_t, + int64_t, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/transpose_kernel.cc b/paddle/phi/kernels/cpu/transpose_kernel.cc new file mode 100644 index 00000000000..a80196e7f80 --- /dev/null +++ b/paddle/phi/kernels/cpu/transpose_kernel.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/transpose_kernel.h" +#include +#include "paddle/phi/api/ext/dispatch.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h" + +namespace phi { + +template +void TransposeKernel(const Context& ctx, + const DenseTensor& x, + const std::vector& axis, + DenseTensor* out) { + ctx.template Alloc(out); + if (out->numel() == 0) { + return; + } + int rank = axis.size(); + switch (rank) { + case 1: + funcs::Transpose trans1; + trans1(ctx, x, out, axis); + break; + case 2: + funcs::Transpose trans2; + trans2(ctx, x, out, axis); + break; + case 3: + funcs::Transpose trans3; + trans3(ctx, x, out, axis); + break; + case 4: + funcs::Transpose trans4; + trans4(ctx, x, out, axis); + break; + case 5: + funcs::Transpose trans5; + trans5(ctx, x, out, axis); + break; + case 6: + funcs::Transpose trans6; + trans6(ctx, x, out, axis); + break; + default: + // for rank >= 7 situation + funcs::TransposeNormal trans_normal; + trans_normal(ctx, x, out, axis); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(transpose, + CPU, + ALL_LAYOUT, + phi::TransposeKernel, + bool, + float, + double, + int32_t, + int64_t, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/funcs/math_function.cu b/paddle/phi/kernels/funcs/math_function.cu index ae368a005f0..df2af82d551 100644 --- a/paddle/phi/kernels/funcs/math_function.cu +++ b/paddle/phi/kernels/funcs/math_function.cu @@ -187,6 +187,57 @@ void TransposeNormal::operator()( in_ptr, out_ptr, elements, in_stride_ptr, out_stride_ptr, axis_ptr, rank); } +template +struct TransposeNormal { + void operator()(const phi::GPUContext& context, + const DenseTensor& in, + DenseTensor* out, + const std::vector& axis) { + const int rank = axis.size(); + auto in_stride = stride(in.dims()); + auto out_stride = stride(out->dims()); + auto* in_ptr = in.data(); + auto* out_ptr = out->data(); + + // copy in_stride, out_stride, axis to gpu device + const phi::GPUPlace& cuda_place = context.GetPlace(); + phi::CPUPlace cpu_place = paddle::platform::CPUPlace(); + size_t size = 3 * rank * sizeof(int64_t); + auto cpu_buf_holder = paddle::memory::Alloc(cpu_place, size); + auto cuda_buf_holder = paddle::memory::Alloc(cuda_place, size); + REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr()); + REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr()); + for (int i = 0; i < rank; ++i) { + cpu_buf[i] = in_stride[i]; + cpu_buf[rank + i] = out_stride[i]; + cpu_buf[2 * rank + i] = axis[i]; + } + paddle::memory::Copy( + cuda_place, cuda_buf, cpu_place, cpu_buf, size, context.stream()); + REINTERPRET(const int64_t, in_stride_ptr, cuda_buf); + REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank); + REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank); + + const int MAX_BLOCK_DIM = context.GetMaxThreadsPerBlock(); + const int MAX_GRID_DIM = + context.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM; + int64_t elements = in.numel(); + int block_size = (elements >= MAX_BLOCK_DIM) + ? MAX_BLOCK_DIM + : (1 << static_cast(std::log2(elements))); + int grid_size = elements / block_size; + grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size; + TransposeNormalKernel<<>>( + in_ptr, + out_ptr, + elements, + in_stride_ptr, + out_stride_ptr, + axis_ptr, + rank); + } +}; + // define transpose normal #define DEFINE_GPU_TRANS_NORMAL(TYPE) \ template struct TransposeNormal; \ diff --git a/paddle/phi/kernels/gpu/transpose_grad_kernel.cu b/paddle/phi/kernels/gpu/transpose_grad_kernel.cu new file mode 100644 index 00000000000..0687dc0c200 --- /dev/null +++ b/paddle/phi/kernels/gpu/transpose_grad_kernel.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h" +#include "paddle/phi/kernels/transpose_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(transpose_grad, + GPU, + ALL_LAYOUT, + phi::TransposeGradKernel, + bool, + float, + double, + int32_t, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/transpose_kernel.cu b/paddle/phi/kernels/gpu/transpose_kernel.cu new file mode 100644 index 00000000000..9ea2af292cc --- /dev/null +++ b/paddle/phi/kernels/gpu/transpose_kernel.cu @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/api/ext/dispatch.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +#include "paddle/fluid/framework/gpu_utils.h" +#include "paddle/fluid/operators/transpose_op.cu.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h" + +namespace phi { +template +void TransposeKernel(const Context& ctx, + const DenseTensor& x, + const std::vector& axis, + DenseTensor* out) { + int rank = axis.size(); + ctx.template Alloc(out); + if (out->numel() == 0) { + return; + } + paddle::operators::TransposeGPUKernelDriver(ctx, rank, x, axis, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(transpose, + GPU, + ALL_LAYOUT, + phi::TransposeKernel, + bool, + float, + double, + int32_t, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/transpose_grad_kernel_impl.h b/paddle/phi/kernels/impl/transpose_grad_kernel_impl.h new file mode 100644 index 00000000000..6bb555fe28f --- /dev/null +++ b/paddle/phi/kernels/impl/transpose_grad_kernel_impl.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/transpose_grad_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +template +void TransposeGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const std::vector& axis, + DenseTensor* x_grad) { + std::vector reversed_axis(axis); + + dev_ctx.template Alloc(x_grad); + for (size_t i = 0; i < axis.size(); i++) { + reversed_axis[axis[i]] = i; + } + + TransposeKernel(dev_ctx, out_grad, reversed_axis, x_grad); +} + +} // namespace phi diff --git a/paddle/phi/kernels/transpose_grad_kernel.h b/paddle/phi/kernels/transpose_grad_kernel.h new file mode 100644 index 00000000000..33d4ca7e3c6 --- /dev/null +++ b/paddle/phi/kernels/transpose_grad_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TransposeGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const std::vector& axis, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/transpose_kernel.h b/paddle/phi/kernels/transpose_kernel.h new file mode 100644 index 00000000000..303b4a9a8f0 --- /dev/null +++ b/paddle/phi/kernels/transpose_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TransposeKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/transpose_sig.cc b/paddle/phi/ops/compat/transpose_sig.cc new file mode 100644 index 00000000000..90961760cfc --- /dev/null +++ b/paddle/phi/ops/compat/transpose_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature TransposeOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("transpose", {"X"}, {"axis"}, {"Out"}); +} + +KernelSignature TransposeGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "transpose_grad", {GradVarName("Out")}, {"axis"}, {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(transpose2, transpose); +PD_REGISTER_BASE_KERNEL_NAME(transpose2_grad, transpose_grad); + +PD_REGISTER_ARG_MAPPING_FN(transpose2, phi::TransposeOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(transpose2_grad, + phi::TransposeGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(transpose, phi::TransposeOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(transpose_grad, phi::TransposeGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index 2a8f72c2170..2633a599256 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -43,7 +43,7 @@ class TestParallelExecutorBase(unittest.TestCase): get_data_from_feeder=None, use_parallel_executor=True, use_reduce=False, - use_ir_memory_optimize=True, + use_ir_memory_optimize=False, enable_inplace=True, fuse_elewise_add_act_ops=False, fuse_all_optimizer_ops=False, diff --git a/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py index d54194164a5..110bb961bbe 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_lod_tensor_to_selected_rows.py @@ -207,4 +207,5 @@ class TestDygraphSimpleNet(unittest.TestCase): if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py index 1cb39eb131b..b87e8d4e3c2 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py @@ -206,4 +206,5 @@ class TestTransformer(TestParallelExecutorBase): if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_partial_eager_deletion_transformer.py b/python/paddle/fluid/tests/unittests/test_partial_eager_deletion_transformer.py index 1661f753a84..15d9e0e2daa 100644 --- a/python/paddle/fluid/tests/unittests/test_partial_eager_deletion_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_partial_eager_deletion_transformer.py @@ -14,10 +14,12 @@ import unittest import paddle.fluid as fluid +import paddle fluid.core._set_eager_deletion_mode(0.0, 0.55, True) from test_parallel_executor_transformer import TestTransformer if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index 13b880b28bf..1e6b4354dd9 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -463,4 +463,5 @@ class TestMoveAxis(unittest.TestCase): if __name__ == '__main__': + paddle.enable_static() unittest.main() -- GitLab