diff --git a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu index 0cd4c153991ce87b373c9301e12b190109004b65..6e4218736eedb8be9318f7f9eb3cc9f69aced0e7 100644 --- a/oneflow/core/kernel/util/cuda_arithemetic_interface.cu +++ b/oneflow/core/kernel/util/cuda_arithemetic_interface.cu @@ -33,8 +33,9 @@ template __device__ int32_t GetXIndex(const int32_t* y_shape, const int32_t* x_strides, int32_t y_idx) { int32_t x_idx = 0; for (int32_t i = NDIMS - 1; i >= 0; --i) { - x_idx += (y_idx % y_shape[i]) * x_strides[i]; - y_idx /= y_shape[i]; + const int32_t next_y_idx = y_idx / y_shape[i]; + x_idx += (y_idx - next_y_idx * y_shape[i]) * x_strides[i]; + y_idx = next_y_idx; } return x_idx; } @@ -42,16 +43,8 @@ __device__ int32_t GetXIndex(const int32_t* y_shape, const int32_t* x_strides, i template __global__ void TransposeGpu(const Int32Array y_shape, const Int32Array x_strides, const int32_t elem_cnt, const T* x, T* y) { - __shared__ int32_t x_strides_shared[NDIMS]; - __shared__ int32_t y_dims_shared[NDIMS]; - const int32_t tid = threadIdx.x; - if (tid < NDIMS) { - y_dims_shared[tid] = y_shape.val[tid]; - x_strides_shared[tid] = x_strides.val[tid]; - } - __syncthreads(); CUDA_1D_KERNEL_LOOP(y_idx, elem_cnt) { - const int32_t x_idx = GetXIndex(y_dims_shared, x_strides_shared, y_idx); + const int32_t x_idx = GetXIndex(y_shape.val, x_strides.val, y_idx); #if __CUDA_ARCH__ >= 350 y[y_idx] = __ldg(x + x_idx); #else @@ -62,7 +55,8 @@ __global__ void TransposeGpu(const Int32Array y_shape, const Int32Array void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, const int64_t elem_cnt, const T* x, T* y) { + const std::vector& permutation, const int64_t elem_cnt, const T* x, + T* y) { CHECK_LE(y_shape.elem_cnt(), GetMaxVal()); Int32Array y_shape_struct; FOR_RANGE(int32_t, i, 0, NDIMS) { y_shape_struct.val[i] = y_shape.At(i); } @@ -95,7 +89,7 @@ struct TransposeUtil final { void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const float* x, float* y) { TRANSPOSE_CHECK; TransposeUtil::SwitchTransposeImpl(SwitchCase(num_axis), ctx, x_shape, y_shape, @@ -104,7 +98,7 @@ void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t nu void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const double* x, double* y) { TRANSPOSE_CHECK; @@ -114,7 +108,7 @@ void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t nu void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const float16* x, float16* y) { TRANSPOSE_CHECK; @@ -125,7 +119,7 @@ void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t nu void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const int8_t* x, int8_t* y) { TRANSPOSE_CHECK; @@ -135,7 +129,7 @@ void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t nu void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const int32_t* x, int32_t* y) { TRANSPOSE_CHECK; @@ -145,7 +139,7 @@ void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t nu void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const int64_t* x, int64_t* y) { TRANSPOSE_CHECK; @@ -155,6 +149,65 @@ void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t nu #undef TRANSPOSE_CHECK +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const float* x, float* y) { + ArithemeticIf::Transpose( + ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, x, y); +} + +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const double* x, + double* y) { + ArithemeticIf::Transpose( + ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, x, y); +} + +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const float16* x, + float16* y) { + ArithemeticIf::Transpose( + ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, x, y); +} + +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const int8_t* x, + int8_t* y) { + ArithemeticIf::Transpose( + ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, x, y); +} + +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const int32_t* x, + int32_t* y) { + ArithemeticIf::Transpose( + ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, x, y); +} + +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const int64_t* x, + int64_t* y) { + ArithemeticIf::Transpose( + ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, x, y); +} + void ArithemeticIf::InitializeWithConstConf( DeviceCtx* ctx, const ConstantInitializerConf& initializer_conf, Blob* blob) { WithHostBlobAndStreamSynchronizeEnv(ctx, blob, [&](Blob* host_blob) { diff --git a/oneflow/core/kernel/util/cuda_arithemetic_interface.h b/oneflow/core/kernel/util/cuda_arithemetic_interface.h index 6b2806128f7073f57f5ec4c086e7084dda92045c..532978e2cd40f00287d134afe11c7fe239258a48 100644 --- a/oneflow/core/kernel/util/cuda_arithemetic_interface.h +++ b/oneflow/core/kernel/util/cuda_arithemetic_interface.h @@ -29,24 +29,43 @@ class ConstantInitializerConf; template<> struct ArithemeticIf { - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const float* x, float* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const double* x, double* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const float16* x, float16* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const int8_t* x, int8_t* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const int32_t* x, int32_t* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const int64_t* x, int64_t* y); + + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const float* x, float* y); - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + int64_t elem_cnt, const float* x, float* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const double* x, double* y); - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + int64_t elem_cnt, const double* x, double* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const float16* x, float16* y); - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + int64_t elem_cnt, const float16* x, float16* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const int8_t* x, int8_t* y); - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + int64_t elem_cnt, const int8_t* x, int8_t* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const int32_t* x, int32_t* y); - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + int64_t elem_cnt, const int32_t* x, int32_t* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const int64_t* x, int64_t* y); + int64_t elem_cnt, const int64_t* x, int64_t* y); static void InitializeWithConstConf(DeviceCtx* ctx, const ConstantInitializerConf& initializer_conf, Blob* blob); diff --git a/oneflow/core/kernel/util/host_arithemetic_interface.cpp b/oneflow/core/kernel/util/host_arithemetic_interface.cpp index 25c5eb8d12e8eb7bf54b5aff54707d9d6a02be19..59e90e60fa232b4ae0f3de2a934d340b905fa7d0 100644 --- a/oneflow/core/kernel/util/host_arithemetic_interface.cpp +++ b/oneflow/core/kernel/util/host_arithemetic_interface.cpp @@ -46,7 +46,7 @@ void IncreaseIndex(const int64_t* shape, DimVector& index) { template void TransposeImpl(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, - const ShapeView& y_shape, const PbRf& permutation, + const ShapeView& y_shape, const std::vector& permutation, const int64_t elem_cnt, const T* x, T* y) { int64_t block_size = 1; int32_t shared_idxs_num = 0; @@ -87,14 +87,14 @@ void ConstantInitializer(const T& value, Blob* blob) { void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const float* x, float* y) { TransposeImpl(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y); } void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const double* x, double* y) { TransposeImpl(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y); @@ -102,7 +102,7 @@ void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t nu void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const int8_t* x, int8_t* y) { TransposeImpl(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y); @@ -110,7 +110,7 @@ void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t nu void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const int32_t* x, int32_t* y) { TransposeImpl(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y); @@ -118,12 +118,61 @@ void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t nu void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, - const PbRf& permutation, + const std::vector& permutation, const int64_t elem_cnt, const int64_t* x, int64_t* y) { TransposeImpl(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y); } +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const float* x, float* y) { + TransposeImpl(ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, + x, y); +} + +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const double* x, + double* y) { + TransposeImpl(ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, + x, y); +} + +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const int8_t* x, + int8_t* y) { + TransposeImpl(ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, + x, y); +} + +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const int32_t* x, + int32_t* y) { + TransposeImpl(ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, + x, y); +} + +void ArithemeticIf::Transpose(DeviceCtx* ctx, const int32_t num_axis, + const ShapeView& x_shape, const ShapeView& y_shape, + const PbRf& permutation, + const int64_t elem_cnt, const int64_t* x, + int64_t* y) { + TransposeImpl(ctx, num_axis, x_shape, y_shape, + std::vector({permutation.cbegin(), permutation.cend()}), elem_cnt, + x, y); +} + void ArithemeticIf::InitializeWithConstConf( DeviceCtx* ctx, const ConstantInitializerConf& initializer_conf, Blob* blob) { DataType dtype = blob->data_type(); diff --git a/oneflow/core/kernel/util/host_arithemetic_interface.h b/oneflow/core/kernel/util/host_arithemetic_interface.h index d961cc7c9706c3d9df1b07b5413b72a9d8ec050d..6c68b54895465a7f5d494c5aed8401f702114109 100644 --- a/oneflow/core/kernel/util/host_arithemetic_interface.h +++ b/oneflow/core/kernel/util/host_arithemetic_interface.h @@ -27,21 +27,37 @@ class ConstantInitializerConf; template<> struct ArithemeticIf { - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const float* x, float* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const double* x, double* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const int8_t* x, int8_t* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const int32_t* x, int32_t* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, + const ShapeView& y_shape, const std::vector& permutation, + int64_t elem_cnt, const int64_t* x, int64_t* y); + + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const float* x, float* y); - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + int64_t elem_cnt, const float* x, float* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const double* x, double* y); - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + int64_t elem_cnt, const double* x, double* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const int8_t* x, int8_t* y); - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + int64_t elem_cnt, const int8_t* x, int8_t* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const int32_t* x, int32_t* y); - static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape, + int64_t elem_cnt, const int32_t* x, int32_t* y); + static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape, const ShapeView& y_shape, const PbRf& permutation, - const int64_t elem_cnt, const int64_t* x, int64_t* y); + int64_t elem_cnt, const int64_t* x, int64_t* y); static void InitializeWithConstConf(DeviceCtx* ctx, const ConstantInitializerConf& initializer_conf, Blob* blob); diff --git a/oneflow/user/kernels/transpose_kernel.cpp b/oneflow/user/kernels/transpose_kernel.cpp index 5a4dbfc91d2db203cd308958c96c4ebd3d0e1c92..ee1991ff20daca78d6b469411ac033597ac6a39a 100644 --- a/oneflow/user/kernels/transpose_kernel.cpp +++ b/oneflow/user/kernels/transpose_kernel.cpp @@ -25,17 +25,37 @@ template class TransposeKernel final : public OpKernel { public: TransposeKernel() = default; - ~TransposeKernel() = default; + ~TransposeKernel() override = default; private: void Compute(KernelComputeContext* ctx) const override { const Tensor* tensor_in = ctx->Tensor4ArgNameAndIndex("input", 0); Tensor* tensor_out = ctx->Tensor4ArgNameAndIndex("output", 0); const auto& perm = ctx->Attr>("perm"); - NewKernelUtil::Transpose(ctx->device_ctx(), tensor_in->shape().NumAxes(), - tensor_in->shape(), tensor_out->shape(), - StdVec2PbRf(perm), tensor_in->shape().elem_cnt(), - tensor_in->dptr(), tensor_out->mut_dptr()); + using PackType = int64_t; + const size_t num_elem_per_pack = sizeof(PackType) / sizeof(T); + const ShapeView& in_shape = tensor_in->shape(); + const ShapeView& out_shape = tensor_out->shape(); + if (num_elem_per_pack != 1 && perm.back() == perm.size() - 1 + && in_shape.At(in_shape.NumAxes() - 1) % num_elem_per_pack == 0) { + CHECK_EQ(in_shape.At(in_shape.NumAxes() - 1), out_shape.At(out_shape.NumAxes() - 1)); + DimVector packed_in_dim_vec; + in_shape.ToDimVector(&packed_in_dim_vec); + packed_in_dim_vec.back() /= num_elem_per_pack; + const Shape packed_in_shape(packed_in_dim_vec); + DimVector packed_out_dim_vec; + out_shape.ToDimVector(&packed_out_dim_vec); + packed_out_dim_vec.back() /= num_elem_per_pack; + const Shape packed_out_shape(packed_out_dim_vec); + NewKernelUtil::Transpose( + ctx->device_ctx(), packed_in_shape.NumAxes(), packed_in_shape, packed_out_shape, perm, + packed_in_shape.elem_cnt(), reinterpret_cast(tensor_in->dptr()), + reinterpret_cast(tensor_out->mut_dptr())); + } else { + NewKernelUtil::Transpose(ctx->device_ctx(), in_shape.NumAxes(), in_shape, + tensor_out->shape(), perm, in_shape.elem_cnt(), + tensor_in->dptr(), tensor_out->mut_dptr()); + } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } };