Unverified commit cba5736f authored by W Wilber, committed by GitHub

add transpose kernel for cuda test=develop (#1997)

add transpose kernel for cuda
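This patch adds NCHW<->NHWC batch transpose and a general N-dimensional transpose to lite::cuda::math, and registers a TransposeCompute kernel for kCUDA/kFloat/kNCHW. A minimal usage sketch of the new math API (d_x, d_y, and ctx are hypothetical device input/output pointers and an already configured CUDAContext, not part of this patch):

    // Permute a [2, 3, 4] float tensor to [4, 2, 3] on ctx's exec stream.
    std::vector<int64_t> x_dims{2, 3, 4};
    std::vector<int> axes{2, 0, 1};
    paddle::lite::cuda::math::Transpose<float>(x_dims, axes, d_x, d_y, &ctx);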
Parent 83d4b0e8
@@ -5,6 +5,7 @@ endif()
nv_library(cuda_activation SRCS activation.cu)
nv_library(cuda_scale SRCS scale.cu)
nv_library(cuda_type_trans SRCS type_trans.cu)
nv_library(cuda_transpose SRCS transpose.cu)
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans)
@@ -14,6 +15,7 @@ set (
cuda_activation
cuda_scale
cuda_type_trans
cuda_transpose
)
set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/math/transpose.h"
#include "lite/backends/cuda/math/utils.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
constexpr int kTileDim = 32;
constexpr int kBlockRows = 8;
constexpr int CUDA_NUM_THREADS = 128;
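// Each block is launched with dim3(kTileDim, kBlockRows) threads, so every
// thread stages kTileDim / kBlockRows rows of its 32 x 32 tile.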
// Splits the original matrix into submatrices with size 32 * 32.
// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/
template <typename T>
__global__ void BatchTranspose2DCUDAKernel(const int N,
const int H,
const int W,
const int dh,
const int dw,
const T* input,
T* out) {
__shared__ T tile[kTileDim][kTileDim + 1]; // plus 1 to prevent bank conflicts.
const int n = blockIdx.x / (dh * dw);
const int k = blockIdx.x % (dh * dw);
const int r = k / dw;
const int c = k % dw;
const int offset = n * H * W;
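// blockIdx.x encodes (batch n, tile row r, tile column c); offset points at the
// start of batch n's H x W slice. Stage one tile into shared memory first.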
int x = c * kTileDim + threadIdx.x;
int y = r * kTileDim + threadIdx.y;
if (x < W) {
for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) {
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
tile[threadIdx.y + i][threadIdx.x] =
__ldg(input + offset + (y + i) * W + x);
#else
tile[threadIdx.y + i][threadIdx.x] = input[offset + (y + i) * W + x];
#endif
}
}
__syncthreads();
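// Write the tile back transposed: x now runs over H and y over W, and the
// shared-memory indices are swapped so the global writes stay coalesced.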
x = r * kTileDim + threadIdx.x;
y = c * kTileDim + threadIdx.y;
if (x < H) {
for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) {
out[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
}
}
}
template <typename T>
void BatchTranspose2DCUDAImpl(const int N,
const int H,
const int W,
const T* input,
T* out,
CUDAContext* ctx) {
const int dh = (H + kTileDim - 1) / kTileDim;
const int dw = (W + kTileDim - 1) / kTileDim;
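// One thread block per (batch, tile) pair: N * dh * dw blocks of
// kTileDim x kBlockRows threads each.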
BatchTranspose2DCUDAKernel<
T><<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, ctx->exec_stream()>>>(
N, H, W, dh, dw, input, out);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
#define TYPE_SPECIALIZED_CUDA_NCHW2NHWC(T) \
template <> \
void NCHW2NHWC<T>(const int N, \
const int C, \
const int HxW, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
BatchTranspose2DCUDAImpl<T>(N, C, HxW, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NCHW2NHWC(float)
#undef TYPE_SPECIALIZED_CUDA_NCHW2NHWC
#define TYPE_SPECIALIZED_CUDA_NHWC2NCHW(T) \
template <> \
void NHWC2NCHW<T>(const int N, \
const int C, \
const int HxW, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
BatchTranspose2DCUDAImpl<T>(N, HxW, C, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_NHWC2NCHW(float)
#undef TYPE_SPECIALIZED_CUDA_NHWC2NCHW
template <typename T>
__global__ void TransposeCUDAKernel(const int size,
const int ndim,
const int* X_strides,
const int* Y_dims,
const T* X,
T* Y) {
const int Y_index = blockIdx.x * CUDA_NUM_THREADS + threadIdx.x;
if (Y_index < size) {
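// Decode the linear output index from the innermost dimension outwards and
// accumulate the corresponding (permuted) input strides.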
int X_index = 0;
int v = Y_index;
#pragma unroll
for (int i = ndim - 1; i >= 0; --i) {
X_index += v % Y_dims[i] * X_strides[i];
v /= Y_dims[i];
}
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
Y[Y_index] = __ldg(X + X_index);
#else
Y[Y_index] = X[X_index];
#endif
}
}
template <typename T>
void TransposeCUDAImpl(const std::vector<int64_t>& X_dims,
const std::vector<int>& axes,
const T* X,
T* Y,
CUDAContext* ctx) {
CHECK_EQ(X_dims.size(), axes.size()) << "dimension size should be equal";
int ndim = X_dims.size();
std::vector<int> strides(ndim, 0);
std::vector<int> Y_dims(ndim, 0);
std::vector<int> buf(ndim, 0);
int cur_stride = 1;
for (int i = ndim - 1; i >= 0; --i) {
buf[i] = cur_stride;
cur_stride *= X_dims[i];
}
for (int i = 0; i < ndim; ++i) {
strides[i] = buf[axes[i]];
}
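// strides[i] is the input stride of axis axes[i]: advancing output axis i by
// one element moves strides[i] elements in X.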
int size = 1;
for (int i = 0; i < ndim; ++i) {
Y_dims[i] = static_cast<int>(X_dims[axes[i]]);
size *= X_dims[i];
}
lite::Tensor Y_dims_, strides_;
Y_dims_.Resize(std::vector<int64_t>({ndim}));
int* d_y_dims = Y_dims_.mutable_data<int>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
d_y_dims, Y_dims.data(), sizeof(int) * Y_dims.size(), IoDirection::HtoD);
strides_.Resize(std::vector<int64_t>({ndim}));
int* d_strides = strides_.mutable_data<int>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(d_strides,
strides.data(),
sizeof(int) * strides.size(),
IoDirection::HtoD);
const int M = (size + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
TransposeCUDAKernel<<<M, CUDA_NUM_THREADS, 0, ctx->exec_stream()>>>(
size, ndim, d_strides, d_y_dims, X, Y);
// cudaError_t error = cudaGetLastError();
// if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
#define TYPE_SPECIALIZED_CUDA_TRANSPOSE(T) \
template <> \
void Transpose<T>(const std::vector<int64_t>& X_dims, \
const std::vector<int>& axes, \
const T* X, \
T* Y, \
CUDAContext* ctx) { \
TransposeCUDAImpl<T>(X_dims, axes, X, Y, ctx); \
}
TYPE_SPECIALIZED_CUDA_TRANSPOSE(float)
#undef TYPE_SPECIALIZED_CUDA_TRANSPOSE
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
#include <vector>
#include "lite/core/context.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace cuda {
namespace math {
template <typename T>
void NCHW2NHWC(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
template <typename T>
void NHWC2NCHW(int N, int C, int HxW, const T* X, T* Y, CUDAContext* context);
template <typename T>
void Transpose(const std::vector<int64_t>& X_dims,
const std::vector<int>& axes,
const T* X,
T* Y,
CUDAContext* ctx);
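// Example: Transpose<float>({N, C, H, W}, {0, 2, 3, 1}, X, Y, ctx) writes an
// N x H x W x C tensor to Y, equivalent to NCHW2NHWC(N, C, H * W, X, Y, ctx).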
} // namespace math
} // namespace cuda
} // namespace lite
} // namespace paddle
@@ -10,6 +10,7 @@ lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_
nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
nv_library(transpose_compute_cuda SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
nv_library(concat_compute_cuda SRCS concat_compute.cu DEPS ${lite_kernel_deps})
@@ -19,6 +20,7 @@ nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
@@ -34,6 +36,7 @@ nearest_interp_compute_cuda
concat_compute_cuda
elementwise_add_compute_cuda
yolo_box_compute_cuda
transpose_compute_cuda
)
set(cuda_kernels "${cuda_kernels}" CACHE GLOBAL "cuda kernels")
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/transpose_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void TransposeCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
const lite::Tensor* X = param.x;
lite::Tensor* Out = param.output;
std::vector<int> axes = param.axis;
const float* in = X->data<float>();
float* out = Out->mutable_data<float>(TARGET(kCUDA));
int ndim = X->dims().size();
std::vector<int64_t> dims = X->dims().data();
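// Fast paths: the two common 4-D layout permutations use the tiled batch
// transpose; any other axis order falls through to the generic Transpose.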
// NCHW -> NHWC
if (axes.size() == 4 && axes[0] == 0 && axes[1] == 2 && axes[2] == 3 &&
axes[3] == 1) {
lite::cuda::math::NCHW2NHWC(
dims[0], dims[1], dims[2] * dims[3], in, out, &ctx);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
return;
}
// NHWC -> NCHW
if (axes.size() == 4 && axes[0] == 0 && axes[1] == 3 && axes[2] == 1 &&
axes[3] == 2) {
lite::cuda::math::NHWC2NCHW(
dims[0], dims[3], dims[1] * dims[2], in, out, &ctx);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
return;
}
lite::cuda::math::Transpose(dims, axes, in, out, &ctx);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(transpose,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::TransposeCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// REGISTER_LITE_KERNEL(transpose2,
// kCUDA,
// kFloat,
// kNCHW,
// paddle::lite::kernels::cuda::TransposeCompute,
// def)
// .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
// .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kCUDA))})
// .Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/backends/cuda/math/transpose.h"
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class TransposeCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::TransposeParam;
void Run() override;
virtual ~TransposeCompute() = default;
private:
lite::Tensor axes_, dims_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/transpose_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
namespace {
#define IN(n, c, h, w) \
input_data[w + h * input_w + c * input_h * input_w + \
n * input_c * input_h * input_w]
#define OUT(n, c, h, w) \
output_data[w + h * output_w + c * output_h * output_w + \
n * output_c * output_h * output_w]
void nchw2nhwc_ref(lite::Tensor* input,
lite::Tensor* output,
const std::vector<int> axes) {
auto* input_data = input->data<float>();
auto* output_data = output->mutable_data<float>();
int input_n = input->dims()[0];
int input_c = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_n = output->dims()[0];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
for (int n = 0; n < input_n; ++n) {
for (int c = 0; c < input_c; ++c) {
for (int h = 0; h < input_h; ++h) {
for (int w = 0; w < input_w; ++w) {
OUT(n, h, w, c) = IN(n, c, h, w);
}
}
}
}
}
#undef IN
#undef OUT
#define IN(n, h, w, c) \
input_data[c + w * input_c + h * input_w * input_c + \
n * input_h * input_w * input_c]
#define OUT(n, h, w, c) \
output_data[c + w * output_c + h * output_w * output_c + \
n * output_h * output_w * output_c]
void nhwc2nchw_ref(lite::Tensor* input,
lite::Tensor* output,
const std::vector<int> axes) {
auto* input_data = input->data<float>();
auto* output_data = output->mutable_data<float>();
int input_n = input->dims()[0];
int input_h = input->dims()[1];
int input_w = input->dims()[2];
int input_c = input->dims()[3];
int output_n = output->dims()[0];
int output_h = output->dims()[1];
int output_w = output->dims()[2];
int output_c = output->dims()[3];
for (int n = 0; n < input_n; ++n) {
for (int c = 0; c < input_c; ++c) {
for (int h = 0; h < input_h; ++h) {
for (int w = 0; w < input_w; ++w) {
OUT(n, c, h, w) = IN(n, h, w, c);
}
}
}
}
}
void transpose_ref(lite::Tensor* input,
lite::Tensor* output,
const std::vector<int> axes) {
auto* input_data = input->data<float>();
auto* output_data = output->mutable_data<float>();
int ndim = input->dims().size();
auto dims = input->dims();
std::vector<int> strides(ndim, 0);
std::vector<int> buf(ndim, 0);
int cur_stride = 1;
for (int i = ndim - 1; i >= 0; --i) {
buf[i] = cur_stride;
cur_stride *= dims[i];
}
for (int i = 0; i < ndim; ++i) {
strides[i] = buf[axes[i]];
}
auto y_dims = output->dims();
int size = input->dims().production();
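// Gather formulation: for each output element, decode its multi-index against
// y_dims and accumulate the permuted input strides (mirrors TransposeCUDAKernel).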
for (int i = 0; i < size; ++i) {
int idx = 0;
int v = i;
for (int j = ndim - 1; j >= 0; --j) {
idx += v % y_dims[j] * strides[j];
v /= y_dims[j];
}
output_data[i] = input_data[idx];
}
}
} // namespace
TEST(transpose_nchw, normal) {
TransposeCompute transpose_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::TransposeParam param;
lite::Tensor x, x_cpu, x_ref;
lite::Tensor out, out_cpu, out_ref;
int N = 5, C = 6, H = 7, W = 8;
std::vector<int> axes({0, 2, 3, 1});
x.Resize({N, C, H, W});
out.Resize({N, H, W, C});
x_cpu.Resize({N, C, H, W});
out_cpu.Resize({N, H, W, C});
x_ref.Resize({N, C, H, W});
out_ref.Resize({N, H, W, C});
auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 1;
x_ref_data[i] = i + 1;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.x = &x;
param.output = &out;
param.axis = axes;
transpose_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
transpose_kernel.SetContext(std::move(ctx));
transpose_kernel.Launch();
cudaDeviceSynchronize();
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
nchw2nhwc_ref(&x_ref, &out_ref, axes);
auto* out_ref_data = out_ref.mutable_data<float>();
// transpose_ref(&x_ref, &out_ref, axes);
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
TEST(transpose_nhwc, normal) {
TransposeCompute transpose_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::TransposeParam param;
lite::Tensor x, x_cpu, x_ref;
lite::Tensor out, out_cpu, out_ref;
int N = 5, C = 6, H = 7, W = 8;
std::vector<int> axes({0, 3, 1, 2});
x.Resize({N, H, W, C});
out.Resize({N, C, H, W});
x_cpu.Resize({N, H, W, C});
out_cpu.Resize({N, C, H, W});
x_ref.Resize({N, H, W, C});
out_ref.Resize({N, C, H, W});
auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 1;
x_ref_data[i] = i + 1;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.x = &x;
param.output = &out;
param.axis = axes;
transpose_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
transpose_kernel.SetContext(std::move(ctx));
transpose_kernel.Launch();
cudaDeviceSynchronize();
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
nhwc2nchw_ref(&x_ref, &out_ref, axes);
// transpose_ref(&x_ref, &out_ref, axes);
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
TEST(transpose, normal) {
TransposeCompute transpose_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::TransposeParam param;
lite::Tensor x, x_cpu, x_ref;
lite::Tensor out, out_cpu, out_ref;
int C = 6, H = 7, W = 8;
std::vector<int> axes({2, 0, 1});
x.Resize({C, H, W});
out.Resize({W, C, H});
x_cpu.Resize({C, H, W});
out_cpu.Resize({W, C, H});
x_ref.Resize({C, H, W});
out_ref.Resize({W, C, H});
auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 1;
x_ref_data[i] = i + 1;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.x = &x;
param.output = &out;
param.axis = axes;
transpose_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
transpose_kernel.SetContext(std::move(ctx));
transpose_kernel.Launch();
cudaDeviceSynchronize();
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
transpose_ref(&x_ref, &out_ref, axes);
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle