提交 396bd65f 编写于 作者: J Juncheng 提交者: GitHub

Optimize transpose performance (#3487)



Former-commit-id: 809793c4
上级 ea1d417c
......@@ -33,8 +33,9 @@ template<int32_t NDIMS>
// Maps a flat index y_idx to a flat offset built from per-axis coordinates.
//
// y_idx is split into coordinates starting at the last axis (NDIMS - 1) and
// moving towards axis 0: the coordinate on axis i is y_idx modulo y_shape[i].
// Each coordinate is multiplied by x_strides[i] and summed into the result.
//
// y_shape   - NDIMS extents used to split y_idx.
// x_strides - NDIMS multipliers, one per axis of y_shape (presumably the input
//             strides arranged in output-axis order -- confirm at the caller).
// y_idx     - flat index to decompose; assumed to be non-negative.
// Returns the accumulated offset.
__device__ int32_t GetXIndex(const int32_t* y_shape, const int32_t* x_strides, int32_t y_idx) {
  int32_t x_idx = 0;
  for (int32_t i = NDIMS - 1; i >= 0; --i) {
    // Exactly one division per axis: the remainder is rebuilt from the
    // quotient with a multiply and a subtract rather than a separate `%`.
    const int32_t next_y_idx = y_idx / y_shape[i];
    x_idx += (y_idx - next_y_idx * y_shape[i]) * x_strides[i];
    y_idx = next_y_idx;
  }
  return x_idx;
}
......@@ -42,16 +43,8 @@ __device__ int32_t GetXIndex(const int32_t* y_shape, const int32_t* x_strides, i
template<int32_t NDIMS, typename T>
__global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<NDIMS> x_strides,
const int32_t elem_cnt, const T* x, T* y) {
__shared__ int32_t x_strides_shared[NDIMS];
__shared__ int32_t y_dims_shared[NDIMS];
const int32_t tid = threadIdx.x;
if (tid < NDIMS) {
y_dims_shared[tid] = y_shape.val[tid];
x_strides_shared[tid] = x_strides.val[tid];
}
__syncthreads();
CUDA_1D_KERNEL_LOOP(y_idx, elem_cnt) {
const int32_t x_idx = GetXIndex<NDIMS>(y_dims_shared, x_strides_shared, y_idx);
const int32_t x_idx = GetXIndex<NDIMS>(y_shape.val, x_strides.val, y_idx);
#if __CUDA_ARCH__ >= 350
y[y_idx] = __ldg(x + x_idx);
#else
......@@ -62,7 +55,8 @@ __global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<N
template<int32_t NDIMS, typename T>
void TransposeImpl(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation, const int64_t elem_cnt, const T* x, T* y) {
const std::vector<int32_t>& permutation, const int64_t elem_cnt, const T* x,
T* y) {
CHECK_LE(y_shape.elem_cnt(), GetMaxVal<int32_t>());
Int32Array<NDIMS> y_shape_struct;
FOR_RANGE(int32_t, i, 0, NDIMS) { y_shape_struct.val[i] = y_shape.At(i); }
......@@ -95,7 +89,7 @@ struct TransposeUtil final {
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const float* x, float* y) {
TRANSPOSE_CHECK;
TransposeUtil<float>::SwitchTransposeImpl(SwitchCase(num_axis), ctx, x_shape, y_shape,
......@@ -104,7 +98,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const double* x,
double* y) {
TRANSPOSE_CHECK;
......@@ -114,7 +108,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const float16* x,
float16* y) {
TRANSPOSE_CHECK;
......@@ -125,7 +119,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const int8_t* x,
int8_t* y) {
TRANSPOSE_CHECK;
......@@ -135,7 +129,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const int32_t* x,
int32_t* y) {
TRANSPOSE_CHECK;
......@@ -145,7 +139,7 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const int64_t* x,
int64_t* y) {
TRANSPOSE_CHECK;
......@@ -155,6 +149,65 @@ void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t nu
#undef TRANSPOSE_CHECK
// Transpose overload (GPU, float) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>; every other argument is passed through unchanged to
// the Transpose overload that takes the permutation as a std::vector<int32_t>.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const float* x, float* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// Transpose overload (GPU, double) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>; every other argument is passed through unchanged to
// the Transpose overload that takes the permutation as a std::vector<int32_t>.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const double* x,
                                                double* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// Transpose overload (GPU, float16) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>; every other argument is passed through unchanged to
// the Transpose overload that takes the permutation as a std::vector<int32_t>.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const float16* x,
                                                float16* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// Transpose overload (GPU, int8_t) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>; every other argument is passed through unchanged to
// the Transpose overload that takes the permutation as a std::vector<int32_t>.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const int8_t* x,
                                                int8_t* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// Transpose overload (GPU, int32_t) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>; every other argument is passed through unchanged to
// the Transpose overload that takes the permutation as a std::vector<int32_t>.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const int32_t* x,
                                                int32_t* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
// Transpose overload (GPU, int64_t) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>; every other argument is passed through unchanged to
// the Transpose overload that takes the permutation as a std::vector<int32_t>.
void ArithemeticIf<DeviceType::kGPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const int64_t* x,
                                                int64_t* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  ArithemeticIf<DeviceType::kGPU>::Transpose(ctx, num_axis, x_shape, y_shape, permutation_vec,
                                             elem_cnt, x, y);
}
void ArithemeticIf<DeviceType::kGPU>::InitializeWithConstConf(
DeviceCtx* ctx, const ConstantInitializerConf& initializer_conf, Blob* blob) {
WithHostBlobAndStreamSynchronizeEnv(ctx, blob, [&](Blob* host_blob) {
......
......@@ -29,24 +29,43 @@ class ConstantInitializerConf;
template<>
struct ArithemeticIf<DeviceType::kGPU> {
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const float* x, float* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const double* x, double* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const float16* x, float16* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const int8_t* x, int8_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const int32_t* x, int32_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const int64_t* x, int64_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const float* x, float* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const float* x, float* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const double* x, double* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const double* x, double* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const float16* x, float16* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const float16* x, float16* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const int8_t* x, int8_t* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const int8_t* x, int8_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const int32_t* x, int32_t* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const int32_t* x, int32_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const int64_t* x, int64_t* y);
int64_t elem_cnt, const int64_t* x, int64_t* y);
static void InitializeWithConstConf(DeviceCtx* ctx,
const ConstantInitializerConf& initializer_conf, Blob* blob);
......
......@@ -46,7 +46,7 @@ void IncreaseIndex(const int64_t* shape, DimVector& index) {
template<typename T>
void TransposeImpl(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const T* x, T* y) {
int64_t block_size = 1;
int32_t shared_idxs_num = 0;
......@@ -87,14 +87,14 @@ void ConstantInitializer(const T& value, Blob* blob) {
// Transpose (CPU, float) with the permutation given as std::vector<int32_t>.
// Delegates to TransposeImpl<float>, passing every argument through unchanged.
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const std::vector<int32_t>& permutation,
                                                const int64_t elem_cnt, const float* x, float* y) {
  TransposeImpl<float>(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y);
}
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const double* x,
double* y) {
TransposeImpl<double>(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y);
......@@ -102,7 +102,7 @@ void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const int8_t* x,
int8_t* y) {
TransposeImpl<int8_t>(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y);
......@@ -110,7 +110,7 @@ void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t nu
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
const ShapeView& x_shape, const ShapeView& y_shape,
const PbRf<int32_t>& permutation,
const std::vector<int32_t>& permutation,
const int64_t elem_cnt, const int32_t* x,
int32_t* y) {
TransposeImpl<int32_t>(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y);
......@@ -118,12 +118,61 @@ void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t nu
// Transpose (CPU, int64_t) with the permutation given as std::vector<int32_t>.
// Delegates to TransposeImpl<int64_t>, passing every argument through unchanged.
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const std::vector<int32_t>& permutation,
                                                const int64_t elem_cnt, const int64_t* x,
                                                int64_t* y) {
  TransposeImpl<int64_t>(ctx, num_axis, x_shape, y_shape, permutation, elem_cnt, x, y);
}
// Transpose overload (CPU, float) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>, which is handed to TransposeImpl<float> together with
// the remaining arguments, unchanged.
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const float* x, float* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  TransposeImpl<float>(ctx, num_axis, x_shape, y_shape, permutation_vec, elem_cnt, x, y);
}
// Transpose overload (CPU, double) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>, which is handed to TransposeImpl<double> together with
// the remaining arguments, unchanged.
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const double* x,
                                                double* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  TransposeImpl<double>(ctx, num_axis, x_shape, y_shape, permutation_vec, elem_cnt, x, y);
}
// Transpose overload (CPU, int8_t) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>, which is handed to TransposeImpl<int8_t> together with
// the remaining arguments, unchanged.
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const int8_t* x,
                                                int8_t* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  TransposeImpl<int8_t>(ctx, num_axis, x_shape, y_shape, permutation_vec, elem_cnt, x, y);
}
// Transpose overload (CPU, int32_t) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>, which is handed to TransposeImpl<int32_t> together with
// the remaining arguments, unchanged.
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const int32_t* x,
                                                int32_t* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  TransposeImpl<int32_t>(ctx, num_axis, x_shape, y_shape, permutation_vec, elem_cnt, x, y);
}
// Transpose overload (CPU, int64_t) that accepts the permutation as PbRf<int32_t>.
// The permutation entries in [cbegin(), cend()) are copied into a
// std::vector<int32_t>, which is handed to TransposeImpl<int64_t> together with
// the remaining arguments, unchanged.
void ArithemeticIf<DeviceType::kCPU>::Transpose(DeviceCtx* ctx, const int32_t num_axis,
                                                const ShapeView& x_shape, const ShapeView& y_shape,
                                                const PbRf<int32_t>& permutation,
                                                const int64_t elem_cnt, const int64_t* x,
                                                int64_t* y) {
  const std::vector<int32_t> permutation_vec(permutation.cbegin(), permutation.cend());
  TransposeImpl<int64_t>(ctx, num_axis, x_shape, y_shape, permutation_vec, elem_cnt, x, y);
}
void ArithemeticIf<DeviceType::kCPU>::InitializeWithConstConf(
DeviceCtx* ctx, const ConstantInitializerConf& initializer_conf, Blob* blob) {
DataType dtype = blob->data_type();
......
......@@ -27,21 +27,37 @@ class ConstantInitializerConf;
template<>
struct ArithemeticIf<DeviceType::kCPU> {
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const float* x, float* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const double* x, double* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const int8_t* x, int8_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const int32_t* x, int32_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const std::vector<int32_t>& permutation,
int64_t elem_cnt, const int64_t* x, int64_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const float* x, float* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const float* x, float* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const double* x, double* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const double* x, double* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const int8_t* x, int8_t* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const int8_t* x, int8_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const int32_t* x, int32_t* y);
static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
int64_t elem_cnt, const int32_t* x, int32_t* y);
static void Transpose(DeviceCtx* ctx, int32_t num_axis, const ShapeView& x_shape,
const ShapeView& y_shape, const PbRf<int32_t>& permutation,
const int64_t elem_cnt, const int64_t* x, int64_t* y);
int64_t elem_cnt, const int64_t* x, int64_t* y);
static void InitializeWithConstConf(DeviceCtx* ctx,
const ConstantInitializerConf& initializer_conf, Blob* blob);
......
......@@ -25,17 +25,37 @@ template<DeviceType device_type, typename T>
class TransposeKernel final : public OpKernel {
public:
TransposeKernel() = default;
~TransposeKernel() = default;
~TransposeKernel() override = default;
private:
void Compute(KernelComputeContext* ctx) const override {
const Tensor* tensor_in = ctx->Tensor4ArgNameAndIndex("input", 0);
Tensor* tensor_out = ctx->Tensor4ArgNameAndIndex("output", 0);
const auto& perm = ctx->Attr<std::vector<int32_t>>("perm");
NewKernelUtil<device_type>::Transpose(ctx->device_ctx(), tensor_in->shape().NumAxes(),
tensor_in->shape(), tensor_out->shape(),
StdVec2PbRf(perm), tensor_in->shape().elem_cnt(),
tensor_in->dptr<T>(), tensor_out->mut_dptr<T>());
using PackType = int64_t;
const size_t num_elem_per_pack = sizeof(PackType) / sizeof(T);
const ShapeView& in_shape = tensor_in->shape();
const ShapeView& out_shape = tensor_out->shape();
if (num_elem_per_pack != 1 && perm.back() == perm.size() - 1
&& in_shape.At(in_shape.NumAxes() - 1) % num_elem_per_pack == 0) {
CHECK_EQ(in_shape.At(in_shape.NumAxes() - 1), out_shape.At(out_shape.NumAxes() - 1));
DimVector packed_in_dim_vec;
in_shape.ToDimVector(&packed_in_dim_vec);
packed_in_dim_vec.back() /= num_elem_per_pack;
const Shape packed_in_shape(packed_in_dim_vec);
DimVector packed_out_dim_vec;
out_shape.ToDimVector(&packed_out_dim_vec);
packed_out_dim_vec.back() /= num_elem_per_pack;
const Shape packed_out_shape(packed_out_dim_vec);
NewKernelUtil<device_type>::Transpose(
ctx->device_ctx(), packed_in_shape.NumAxes(), packed_in_shape, packed_out_shape, perm,
packed_in_shape.elem_cnt(), reinterpret_cast<const PackType*>(tensor_in->dptr<T>()),
reinterpret_cast<PackType*>(tensor_out->mut_dptr<T>()));
} else {
NewKernelUtil<device_type>::Transpose(ctx->device_ctx(), in_shape.NumAxes(), in_shape,
tensor_out->shape(), perm, in_shape.elem_cnt(),
tensor_in->dptr<T>(), tensor_out->mut_dptr<T>());
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册