Unverified commit 7a857924 authored by H hong, committed by GitHub

Move transpose to pten (#39327)

* migrate transpose to pten, cpu kernel only; test=develop

* fix bug; test=develop

* add transpose cuda api

* bug fix;

* fix bugs

* fix bugs; test=develop

* bug fix;

* move transpose to pten; test=develop

* fix bug; test=develop

* fix bugs; test=develop

* add transpose grad fp16 support; test=develop

* fix bug; test=develop

* fix npu bug; test=develop

* fix numel = 0 bug; test=develop

* add fp16 support; test=develop

* fix data type register bug; test=develop

* fix transpose bug; test=develop

* update transpose

* fix transpose bug; test=develop

* remove useless code; test=develop

* remove useless code; test=develop

* fix transpose alias bug; test=develop

* polish code; test=develop

* resolve conflict; test=develop

* resolve conflict; test=develop

* recover prepared operator; test=develop

* fix bug; test=develop

* polish code; test=develop

* fix bug; test=develop

* fix bug; test=develop
Parent 2a5590a1
......@@ -29,7 +29,7 @@ USE_OP(pool2d);
USE_OP_DEVICE_KERNEL(pool2d, MKLDNN);
USE_OP(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP(transpose);
USE_OP_ITSELF(transpose);
USE_OP_DEVICE_KERNEL(transpose, MKLDNN);
namespace paddle {
......
......@@ -339,6 +339,14 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
}
};
class TransposeGradInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType(framework::GradVarName("Out"),
framework::GradVarName("X"));
}
};
} // namespace operators
} // namespace paddle
......@@ -347,59 +355,13 @@ REGISTER_OPERATOR(
transpose, ops::TransposeOp, ops::TransposeOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad,
ops::TransposeGradInferVarType);
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
ops::Transpose2GradMaker<paddle::framework::OpDesc>,
ops::Transpose2GradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad,
ops::TransposeGradInferVarType,
ops::Transpose2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class TransposeGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.InputVar("X");
auto* out = context.OutputVar("Out");
const framework::Tensor* x_tensor =
GetLoDTensorOrSelectedRowsValueFromVar(*x);
framework::Tensor* out_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(out);
out_tensor->mutable_data<T>(context.GetPlace());
if (out_tensor->numel() == 0) {
return;
}
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
int ndims = axis.size();
const auto& dev_ctx = context.template device_context<DeviceContext>();
TransposeGPUKernelDriver<T>(dev_ctx, ndims, *x_tensor, axis, out_tensor);
}
};
template <typename DeviceContext, typename T>
class TransposeGradGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out_grad = context.InputVar(framework::GradVarName("Out"));
auto* x_grad = context.OutputVar(framework::GradVarName("X"));
if (!x_grad) {
return;
}
const framework::Tensor* out_grad_tensor =
GetLoDTensorOrSelectedRowsValueFromVar(*out_grad);
framework::Tensor* x_grad_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(x_grad);
x_grad_tensor->mutable_data<T>(context.GetPlace());
if (x_grad_tensor->numel() == 0) {
return;
}
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
std::vector<int> reversed_axis(axis);
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
int ndims = axis.size();
const auto& dev_ctx = context.template device_context<DeviceContext>();
TransposeGPUKernelDriver<T>(dev_ctx, ndims, *out_grad_tensor, reversed_axis,
x_grad_tensor);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
transpose,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
transpose_grad,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
transpose2,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
transpose2_grad,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, bool>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
plat::bfloat16>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
......@@ -16,8 +16,9 @@ limitations under the License. */
#include "paddle/fluid/framework/gpu_utils.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
namespace paddle {
namespace operators {
......@@ -258,10 +259,10 @@ struct SystemElemType<16> {
};
template <typename T, int tile_long, int tile_short>
void LaunchNarrowDims2TransposeKernel(const platform::CUDADeviceContext& d,
int tile_size_i, int tile_size_j,
int total_tiles_count, const T* input,
const Dim3& input_dims, T* output) {
void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d, int tile_size_i,
int tile_size_j, int total_tiles_count,
const T* input, const Dim3& input_dims,
T* output) {
constexpr int NumThreads = tile_long;
if (tile_size_i <= tile_long && tile_size_j <= tile_short) {
TilingSwapDim1And2<
......@@ -278,7 +279,7 @@ void LaunchNarrowDims2TransposeKernel(const platform::CUDADeviceContext& d,
template <typename T, int tile_long, int tile_short, typename dummy = void>
struct NarrowDims2TransposeDispatch {
static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i,
static void DoTranspose(const phi::GPUContext& d, int tile_size_i,
int tile_size_j, int total_tiles_count,
const T* input, const Dim3& input_dims, T* output) {
PADDLE_ENFORCE_EQ(
......@@ -319,7 +320,7 @@ struct NarrowDims2TransposeDispatch<
T, tile_long, tile_short,
typename std::enable_if<
CheckNonLongTileSize(tile_long, tile_short, sizeof(T)), void>::type> {
static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i,
static void DoTranspose(const phi::GPUContext& d, int tile_size_i,
int tile_size_j, int total_tiles_count,
const T* input, const Dim3& input_dims, T* output) {
PADDLE_ENFORCE_EQ(
......@@ -351,7 +352,7 @@ struct NarrowDims2TransposeDispatch<
T, tile_long, tile_short,
typename std::enable_if<CheckLongTileSize(tile_long, tile_short, sizeof(T)),
void>::type> {
static void DoTranspose(const platform::CUDADeviceContext& d, int tile_size_i,
static void DoTranspose(const phi::GPUContext& d, int tile_size_i,
int tile_size_j, int total_tiles_count,
const T* input, const Dim3& input_dims, T* output) {
PADDLE_ENFORCE_EQ(
......@@ -368,7 +369,7 @@ struct NarrowDims2TransposeDispatch<
};
template <typename T, bool conjugate = false>
void SwapDim1And2InNarrow(const platform::CUDADeviceContext& d, const T* input,
void SwapDim1And2InNarrow(const phi::GPUContext& d, const T* input,
const Dim3& input_dims, T* output,
const int kMinTileSize) {
// First get available tile sizes for the data type requested as backups
......@@ -473,9 +474,8 @@ __global__ void TransposeSimpleKernel(int nthreads, const T* __restrict__ input,
// Here suppose convert all tensor to dim3, so just change dim1 and 2.
template <typename T>
void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d,
const T* input, const Dim3& input_dims,
T* output) {
void SendSwapDim1And2InTranspose(const phi::GPUContext& d, const T* input,
const Dim3& input_dims, T* output) {
// Suppose tile size > 16
static const int kMinTileSize = 16;
static const int kMinNarrowTileSize = 96;
......@@ -512,7 +512,7 @@ void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d,
} else {
// If input shape is small, such as 8X8, just do simple copy
int total_elements = input_dims[0] * input_dims[1] * input_dims[2];
auto config = GetGpuLaunchConfig1D(d, total_elements);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_elements);
TransposeSimpleKernel<T, 0, 2, 1><<<
config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
total_elements, input, input_dims, output);
......@@ -521,7 +521,7 @@ void SendSwapDim1And2InTranspose(const platform::CUDADeviceContext& d,
template <typename T>
struct SwapDim1And2InTranspose {
typedef platform::CUDADeviceContext Device;
typedef phi::GPUContext Device;
void operator()(const Device& d, const T* in,
const std::vector<int>& combined_dims, T* out) {
Dim3 input_dims = {static_cast<int>(combined_dims[0]),
......@@ -533,7 +533,7 @@ struct SwapDim1And2InTranspose {
template <typename T>
struct SwapDim0And2InTranspose {
typedef platform::CUDADeviceContext Device;
typedef phi::GPUContext Device;
void operator()(const Device& d, const T* in,
const std::vector<int>& combined_dims, T* out) {
Dim3 input_dims = {static_cast<int>(combined_dims[0]),
......@@ -541,7 +541,7 @@ struct SwapDim0And2InTranspose {
static_cast<int>(combined_dims[2])};
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
auto config = GetGpuLaunchConfig1D(d, total_size);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(d, total_size);
TransposeSimpleKernel<T, 2, 1, 0><<<
config.block_per_grid.x, config.thread_per_block.x, 0, d.stream()>>>(
......@@ -607,7 +607,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape,
template <typename T>
struct TransposeSimple {
static bool run(const platform::CUDADeviceContext& ctx, const Tensor& in,
static bool run(const phi::GPUContext& ctx, const Tensor& in,
const std::vector<int32_t> perm, Tensor* out) {
// First reduce the dimensions of the input tensor if possible.
std::vector<int> new_perm;
......@@ -654,12 +654,12 @@ struct TransposeSimple {
};
template <typename T>
void TransposeGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const int ndims, const Tensor& in,
const std::vector<int32_t> perm, Tensor* out) {
void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, const int ndims,
const Tensor& in,
const std::vector<int32_t>& perm, Tensor* out) {
auto ret = TransposeSimple<T>::run(dev_ctx, in, perm, out);
if (!ret) {
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, in, out, perm);
TransCompute<phi::GPUContext, T>(ndims, dev_ctx, in, out, perm);
}
}
......
......@@ -59,63 +59,5 @@ inline void TransCompute(const int dim, const DeviceContext& dev_ctx,
}
}
template <typename DeviceContext, typename T>
class TransposeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.InputVar("X");
auto* out = context.OutputVar("Out");
const framework::Tensor* x_tensor =
GetLoDTensorOrSelectedRowsValueFromVar(*x);
framework::Tensor* out_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(out);
out_tensor->mutable_data<T>(context.GetPlace());
if (out_tensor->numel() == 0) {
return;
}
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
int ndims = axis.size();
auto& dev_ctx = context.template device_context<DeviceContext>();
TransCompute<DeviceContext, T>(ndims, dev_ctx, *x_tensor, out_tensor, axis);
}
};
template <typename DeviceContext, typename T>
class TransposeGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out_grad = context.InputVar(framework::GradVarName("Out"));
auto* x_grad = context.OutputVar(framework::GradVarName("X"));
if (!x_grad) {
return;
}
const framework::Tensor* out_grad_tensor =
GetLoDTensorOrSelectedRowsValueFromVar(*out_grad);
framework::Tensor* x_grad_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(x_grad);
x_grad_tensor->mutable_data<T>(context.GetPlace());
if (x_grad_tensor->numel() == 0) {
return;
}
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
std::vector<int> reversed_axis(axis);
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
int ndims = axis.size();
auto& dev_ctx = context.template device_context<DeviceContext>();
TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad_tensor,
x_grad_tensor, reversed_axis);
}
};
} // namespace operators
} // namespace paddle
......@@ -31,7 +31,7 @@ limitations under the License. */
namespace f = paddle::framework;
namespace p = paddle::platform;
USE_OP(transpose2);
USE_OP_ITSELF(transpose2);
USE_OP_DEVICE_KERNEL(transpose2, NPU);
template <typename T>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/transpose_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
PD_REGISTER_KERNEL(transpose_grad,
CPU,
ALL_LAYOUT,
phi::TransposeGradKernel,
bool,
float,
double,
int32_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/transpose_kernel.h"
#include <vector>
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void TransposeKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
ctx.template Alloc<T>(out);
if (out->numel() == 0) {
return;
}
int rank = axis.size();
switch (rank) {
case 1:
funcs::Transpose<Context, T, 1> trans1;
trans1(ctx, x, out, axis);
break;
case 2:
funcs::Transpose<Context, T, 2> trans2;
trans2(ctx, x, out, axis);
break;
case 3:
funcs::Transpose<Context, T, 3> trans3;
trans3(ctx, x, out, axis);
break;
case 4:
funcs::Transpose<Context, T, 4> trans4;
trans4(ctx, x, out, axis);
break;
case 5:
funcs::Transpose<Context, T, 5> trans5;
trans5(ctx, x, out, axis);
break;
case 6:
funcs::Transpose<Context, T, 6> trans6;
trans6(ctx, x, out, axis);
break;
default:
// for rank >= 7 situation
funcs::TransposeNormal<Context, T> trans_normal;
trans_normal(ctx, x, out, axis);
}
}
} // namespace phi
PD_REGISTER_KERNEL(transpose,
CPU,
ALL_LAYOUT,
phi::TransposeKernel,
bool,
float,
double,
int32_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -187,6 +187,57 @@ void TransposeNormal<DeviceContext, T>::operator()(
in_ptr, out_ptr, elements, in_stride_ptr, out_stride_ptr, axis_ptr, rank);
}
template <typename T>
struct TransposeNormal<phi::GPUContext, T> {
void operator()(const phi::GPUContext& context,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
const int rank = axis.size();
auto in_stride = stride(in.dims());
auto out_stride = stride(out->dims());
auto* in_ptr = in.data<T>();
auto* out_ptr = out->data<T>();
// copy in_stride, out_stride, axis to gpu device
const phi::GPUPlace& cuda_place = context.GetPlace();
phi::CPUPlace cpu_place = paddle::platform::CPUPlace();
size_t size = 3 * rank * sizeof(int64_t);
auto cpu_buf_holder = paddle::memory::Alloc(cpu_place, size);
auto cuda_buf_holder = paddle::memory::Alloc(cuda_place, size);
REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
for (int i = 0; i < rank; ++i) {
cpu_buf[i] = in_stride[i];
cpu_buf[rank + i] = out_stride[i];
cpu_buf[2 * rank + i] = axis[i];
}
paddle::memory::Copy(
cuda_place, cuda_buf, cpu_place, cpu_buf, size, context.stream());
REINTERPRET(const int64_t, in_stride_ptr, cuda_buf);
REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank);
REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank);
const int MAX_BLOCK_DIM = context.GetMaxThreadsPerBlock();
const int MAX_GRID_DIM =
context.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
int64_t elements = in.numel();
int block_size = (elements >= MAX_BLOCK_DIM)
? MAX_BLOCK_DIM
: (1 << static_cast<int>(std::log2(elements)));
int grid_size = elements / block_size;
grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
TransposeNormalKernel<T><<<grid_size, block_size, 0, context.stream()>>>(
in_ptr,
out_ptr,
elements,
in_stride_ptr,
out_stride_ptr,
axis_ptr,
rank);
}
};
// define transpose normal
#define DEFINE_GPU_TRANS_NORMAL(TYPE) \
template struct TransposeNormal<paddle::platform::CUDADeviceContext, TYPE>; \
......
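A side note on the launch configuration in the TransposeNormal<phi::GPUContext, T> specialization above: when the tensor is smaller than one full block, block_size is rounded down to a power of two, and grid_size is capped by the device's physical thread count. Below is a minimal standalone sketch of that arithmetic; the device limits 1024 and 64 are illustrative assumptions, not values taken from the diff.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed device limits, for illustration only.
  const int kMaxBlockDim = 1024;  // stands in for context.GetMaxThreadsPerBlock()
  const int kMaxGridDim = 64;     // stands in for max physical threads / kMaxBlockDim
  const int64_t elements = 1000;  // numel of the tensor being transposed
  // Same arithmetic as the kernel launch in the specialization above.
  int block_size = (elements >= kMaxBlockDim)
                       ? kMaxBlockDim
                       : (1 << static_cast<int>(std::log2(elements)));
  int grid_size = static_cast<int>(elements / block_size);
  grid_size = std::min(grid_size, kMaxGridDim);
  std::printf("block_size=%d grid_size=%d\n", block_size, grid_size);  // 512, 1
  return 0;
}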
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
#include "paddle/phi/kernels/transpose_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(transpose_grad,
GPU,
ALL_LAYOUT,
phi::TransposeGradKernel,
bool,
float,
double,
int32_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/fluid/framework/gpu_utils.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void TransposeKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out) {
int rank = axis.size();
ctx.template Alloc<T>(out);
if (out->numel() == 0) {
return;
}
paddle::operators::TransposeGPUKernelDriver<T>(ctx, rank, x, axis, out);
}
} // namespace phi
PD_REGISTER_KERNEL(transpose,
GPU,
ALL_LAYOUT,
phi::TransposeKernel,
bool,
float,
double,
int32_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/transpose_grad_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
template <typename T, typename Context>
void TransposeGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& axis,
DenseTensor* x_grad) {
std::vector<int> reversed_axis(axis);
dev_ctx.template Alloc<T>(x_grad);
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
TransposeKernel<T, Context>(dev_ctx, out_grad, reversed_axis, x_grad);
}
} // namespace phi
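The loop in TransposeGradKernel above inverts the forward permutation: the gradient of a transpose is just another transpose with the inverse axis order. Below is a minimal standalone sketch of that inversion; the axis value {2, 0, 1} is only an illustrative example.

#include <cstdio>
#include <vector>

int main() {
  // Forward permutation used by the transpose op (illustrative value).
  std::vector<int> axis = {2, 0, 1};
  // Invert it the same way TransposeGradKernel does.
  std::vector<int> reversed_axis(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) {
    reversed_axis[axis[i]] = static_cast<int>(i);
  }
  // Transposing by {2, 0, 1} and then by the inverse {1, 2, 0}
  // restores the original dimension order.
  for (int a : reversed_axis) std::printf("%d ", a);  // prints: 1 2 0
  std::printf("\n");
  return 0;
}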
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TransposeGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& axis,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TransposeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TransposeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("transpose", {"X"}, {"axis"}, {"Out"});
}
KernelSignature TransposeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"transpose_grad", {GradVarName("Out")}, {"axis"}, {GradVarName("X")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(transpose2, transpose);
PD_REGISTER_BASE_KERNEL_NAME(transpose2_grad, transpose_grad);
PD_REGISTER_ARG_MAPPING_FN(transpose2, phi::TransposeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(transpose2_grad,
phi::TransposeGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(transpose, phi::TransposeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(transpose_grad, phi::TransposeGradOpArgumentMapping);
......@@ -43,7 +43,7 @@ class TestParallelExecutorBase(unittest.TestCase):
get_data_from_feeder=None,
use_parallel_executor=True,
use_reduce=False,
use_ir_memory_optimize=True,
use_ir_memory_optimize=False,
enable_inplace=True,
fuse_elewise_add_act_ops=False,
fuse_all_optimizer_ops=False,
......
......@@ -207,4 +207,5 @@ class TestDygraphSimpleNet(unittest.TestCase):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -206,4 +206,5 @@ class TestTransformer(TestParallelExecutorBase):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -14,10 +14,12 @@
import unittest
import paddle.fluid as fluid
import paddle
fluid.core._set_eager_deletion_mode(0.0, 0.55, True)
from test_parallel_executor_transformer import TestTransformer
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -463,4 +463,5 @@ class TestMoveAxis(unittest.TestCase):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()