Commit 4db40afc authored by W Wilber, committed by GitHub

add var_conv_2d cuda kernel and unit test test=develop (#2441)

- add var_conv_2d cuda kernel

- add var_conv_2d cuda kernel unit test

- temporarily switch to a two-input mode: remove Input(ROW) and Input(COLUMN); the per-sample row/column extents are now read from Input(X)'s LoD (see the sketch below)
Parent: cfa086e9
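With ROW and COLUMN dropped, the CUDA kernel reads the per-sample heights and widths from levels 1 and 2 of Input(X)'s LoD, mirroring the setup in the new unit test below. A minimal sketch of that layout, assuming the nested-vector `LoD` type from `lite/core/tensor.h`; the helper name and the concrete sizes are illustrative only:

```cpp
#include <vector>
#include "lite/core/tensor.h"

// Two samples, each 10 rows x 10 cols, with input_channel = 5 (as in the test).
void SetVarConv2dInputLod(paddle::lite::Tensor* X) {
  std::vector<uint64_t> sample_offsets{0, 500, 1000};  // 5 * 10 * 10 elements per sample
  std::vector<uint64_t> row_offsets{0, 10, 20};        // formerly Input(ROW)'s lod
  std::vector<uint64_t> col_offsets{0, 10, 20};        // formerly Input(COLUMN)'s lod
  paddle::lite::LoD x_lod;
  x_lod.push_back(sample_offsets);  // lod[0]: element offsets into the flattened X buffer
  x_lod.push_back(row_offsets);     // lod[1]: per-sample row (height) offsets
  x_lod.push_back(col_offsets);     // lod[2]: per-sample column (width) offsets
  X->Resize({1000, 1});
  X->set_lod(x_lod);
}
```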
......@@ -29,6 +29,7 @@ add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.
add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_cuda CUDA basic SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(var_conv_2d_compute_cuda CUDA basic SRCS var_conv_2d_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
......@@ -50,6 +51,7 @@ nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_cuda)
nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda)
if(LITE_BUILD_EXTRA)
nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <vector>
#include "lite/backends/cuda/math/gemm.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/core/tensor.h"
#include "lite/kernels/cuda/var_conv_2d_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
const int CUDA_NUM_THREADS = 512;
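// var_im2col_gpu_kernel below uses a grid-stride loop: each thread walks the
// output indices idx, idx + blockDim.x * gridDim.x, ... and copies one
// kernel-window column of the im2col output per index, so any n is covered.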
template <typename Dtype>
__global__ void var_im2col_gpu_kernel(const int n,
const Dtype* data_im,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int height_col,
const int width_col,
Dtype* data_col) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int index = idx; index < n; index += blockDim.x * gridDim.x) {
const int h_index = index / width_col;
const int h_col = h_index % height_col;
const int w_col = index % width_col;
const int c_im = h_index / height_col;
const int c_col = c_im * kernel_h * kernel_w;
const int h_offset = h_col * stride_h - pad_h;
const int w_offset = w_col * stride_w - pad_w;
Dtype* data_col_ptr = data_col;
data_col_ptr += (c_col * height_col + h_col) * width_col + w_col;
const Dtype* data_im_ptr = data_im;
data_im_ptr += (c_im * height + h_offset) * width + w_offset;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
int h_im = h_offset + i;
int w_im = w_offset + j;
*data_col_ptr =
(h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
? data_im_ptr[i * width + j]
: 0;
data_col_ptr += height_col * width_col;
}
}
}
}
void VarConv2DCompute::var_im2col(const cudaStream_t& stream) {
auto& param = this->Param<param_t>();
int input_channel = param.input_channel;
int kernel_h = param.kernel_h;
int kernel_w = param.kernel_w;
int stride_h = param.stride_h;
int stride_w = param.stride_w;
// auto* in_row = param.ROW;
// auto* in_col = param.COLUMN;
const auto* input = param.X;
auto* col = param.Col;
int batch = input->lod()[0].size() - 1;
const auto& bottom_offset = input->lod()[0];
// 2-D lod info.
// const auto& offset_x = in_col->lod()[0];
// const auto& offset_y = in_row->lod()[0];
const auto& offset_y = param.X->lod()[1];
const auto& offset_x = param.X->lod()[2];
// top_offset records the cumulative output size, one entry per data sample
std::vector<uint64_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
int top_x = top_im_x * top_im_y;
int top_y = input_channel * kernel_h * kernel_w;
top_size += top_y * top_x;
top_offset.push_back(top_size);
}
LoD col_lod;
col_lod.push_back(top_offset);
col->set_lod(col_lod);
std::vector<int64_t> col_dims_vec{top_size};
col_dims_vec.push_back(1);
col->Resize(col_dims_vec);
auto* top_data = col->mutable_data<float>(TARGET(kCUDA));
const auto* bottom_data = input->data<float>();
for (int b = 0; b < batch; ++b) {
int t_offset = top_offset[b];
int b_offset = bottom_offset[b];
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
if (width == 0 || height == 0) {
continue;
}
int width_col = (width - 1) / stride_w + 1;
int height_col = (height - 1) / stride_h + 1;
const float* data_im = bottom_data + b_offset;
float* data_col = top_data + t_offset;
// Launch input_channel * height_col * width_col CUDA threads; each thread
// copies one kernel window (one column of the im2col output).
int num_kernels = height_col * width_col * input_channel;
const int CUDA_NUM_BLOCKS =
(num_kernels + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
var_im2col_gpu_kernel<
float><<<CUDA_NUM_BLOCKS, CUDA_NUM_THREADS, 0, stream>>>(
num_kernels,
data_im,
height,
width,
kernel_h,
kernel_w,
((stride_h - 1) * height + kernel_h - 1) / 2,
((stride_w - 1) * width + kernel_w - 1) / 2,
stride_h,
stride_w,
height_col,
width_col,
data_col);
}
}
void VarConv2DCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto* bottom = param.X;
// auto* in_row = param.ROW;
// auto* in_col = param.COLUMN;
auto* w = param.W;
auto* top = param.Out;
auto* col = param.Col;
int output_channel = param.output_channel;
int input_channel = param.input_channel;
int kernel_h = param.kernel_h;
int kernel_w = param.kernel_w;
int stride_h = param.stride_h;
int stride_w = param.stride_w;
var_im2col(stream);
int batch = bottom->lod()[0].size() - 1;
const auto& col_offset = col->lod()[0];
// const auto& offset_x = in_col->lod()[0];
// const auto& offset_y = in_row->lod()[0];
const auto& offset_y = param.X->lod()[1];
const auto& offset_x = param.X->lod()[2];
std::vector<size_t> top_offset;
std::vector<int64_t> height_vector;
std::vector<int64_t> width_vector;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
height_vector.push_back(top_im_y);
width_vector.push_back(top_im_x);
int top_im_size = top_im_y * top_im_x;
top_size += output_channel * top_im_size;
top_offset.push_back(top_size);
}
LoD top_lod;
top_lod.push_back(top_offset);
top->set_lod(top_lod);
std::vector<int64_t> top_dims_vec{top_size};
top_dims_vec.push_back(1);
top->Resize(top_dims_vec);
auto* top_data = top->mutable_data<float>(TARGET(kCUDA));
const auto* w_data = w->data<float>();
const auto* col_data = col->data<float>();
std::unique_ptr<lite::cuda::math::Gemm<float, float>> gemm_impl_;
for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) {
continue;
}
float* out_data = top_data + top_offset[b];
const float* in_data = col_data + col->lod()[0][b];
gemm_impl_.reset(new lite::cuda::math::Gemm<float, float>);
gemm_impl_->init(false,
false,
w->dims()[0],
height_vector[b] * width_vector[b],
input_channel * kernel_h * kernel_w,
&ctx);
gemm_impl_->run(1., 0., w_data, in_data, out_data, &ctx);
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(var_conv_2d,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::VarConv2DCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Col", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class VarConv2DCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::VarConv2DParam;
void Run() override;
virtual ~VarConv2DCompute() = default;
private:
void var_im2col(const cudaStream_t& stream);
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/var_conv_2d_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
static void im2col_ref(const lite::Tensor& input,
const lite::Tensor* in_row,
const lite::Tensor* in_col,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int input_channel,
lite::Tensor* col) {
int batch = input.lod()[0].size() - 1;
const auto& bottom_offset = input.lod()[0];
// 2-D lod info.
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
// top_offset records the cumulative output size, one entry per data sample
std::vector<uint64_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
int top_x = top_im_x * top_im_y;
int top_y = input_channel * kernel_h * kernel_w;
top_size += top_y * top_x;
top_offset.push_back(top_size);
}
LoD col_lod;
col_lod.push_back(top_offset);
col->set_lod(col_lod);
std::vector<int64_t> col_dims_vec{top_size};
col_dims_vec.push_back(1);
col->Resize(col_dims_vec);
auto* top_data = col->mutable_data<float>();
const auto* bottom_data = input.data<float>();
int kernel_win_size = kernel_h * kernel_w;
int half_kernel_h = kernel_h / 2;
int half_kernel_w = kernel_w / 2;
for (int b = 0; b < batch; ++b) {
int t_offset = top_offset[b];
int b_offset = bottom_offset[b];
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
if (width == 0 || height == 0) {
continue;
}
int top_im_x = (width - 1) / stride_w + 1;
int top_im_y = (height - 1) / stride_h + 1;
int top_x = top_im_y * top_im_x;
for (int z = 0; z < input_channel; ++z) {
int row_offset = kernel_win_size * z;
int im_offset = z * width * height;
for (int y = 0; y < height; y += stride_h) {
for (int x = 0; x < width; x += stride_w) {
int col_offset = x / stride_w + y / stride_h * top_im_x;
for (int ky = 0; ky < kernel_h; ++ky) {
for (int kx = 0; kx < kernel_w; ++kx) {
int im_y = y + ky - half_kernel_h;
int im_x = x + kx - half_kernel_w;
if (im_x >= 0 && im_x < width && im_y >= 0 && im_y < height) {
top_data[t_offset + (row_offset + ky * kernel_w + kx) * top_x +
col_offset] =
bottom_data[b_offset + im_offset + im_y * width + im_x];
} else {
top_data[t_offset + (row_offset + ky * kernel_w + kx) * top_x +
col_offset] = 0;
}
}
}
}
}
}
}
}
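// Host-side reference GEMM used only by this test. Note that lda, ldb, and ldc
// are accepted for signature compatibility but ignored: A, B, and C are indexed
// as tightly packed row-major matrices.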
static void naive_sgemm(const bool transpose_A,
const bool transpose_B,
const int M,
const int N,
const int K,
const float alpha,
const float* A, // m x k (after transpose if TransA)
const int lda, // leading dimension of a
const float* B, // k x n (after transpose if TransB)
const int ldb, // leading dimension of b
const float beta,
float* C, // m x n
const int ldc) {
  for (int m = 0; m < M; ++m) {
    // Scale the existing C once per element (BLAS convention:
    // C = alpha * A * B + beta * C), not once per k iteration.
    for (int n = 0; n < N; ++n) {
      C[m * N + n] *= beta;
    }
    for (int k = 0; k < K; ++k) {
      for (int n = 0; n < N; ++n) {
size_t A_idx = 0, B_idx = 0;
if (transpose_A) {
A_idx = k * M + m; // A is k x m
} else {
A_idx = m * K + k; // A is m x k
}
if (transpose_B) {
B_idx = n * K + k; // B is n x k
} else {
B_idx = k * N + n; // B is k x n
}
C[m * N + n] += alpha * A[A_idx] * B[B_idx];
}
}
}
}
static void var_conv_2d_ref(const lite::Tensor* bottom,
const lite::Tensor* w,
const lite::Tensor* in_row,
const lite::Tensor* in_col,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int input_channel,
const int output_channel,
lite::Tensor* top,
lite::Tensor* col) {
im2col_ref(*bottom,
in_row,
in_col,
kernel_h,
kernel_w,
stride_h,
stride_w,
input_channel,
col);
int batch = bottom->lod()[0].size() - 1;
const auto& col_offset = col->lod()[0];
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
for (int b = 0; b < batch; ++b) {
int width = offset_x[b + 1] - offset_x[b];
int height = offset_y[b + 1] - offset_y[b];
int top_im_x = 0;
if (width == 0) {
top_im_x = 0;
} else {
top_im_x = (width - 1) / stride_w + 1;
}
int top_im_y = 0;
if (height == 0) {
top_im_y = 0;
} else {
top_im_y = (height - 1) / stride_h + 1;
}
int top_im_size = top_im_y * top_im_x;
top_size += output_channel * top_im_size;
top_offset.push_back(top_size);
}
LoD top_lod;
top_lod.push_back(top_offset);
top->set_lod(top_lod);
std::vector<int64_t> top_dims_vec{top_size};
top_dims_vec.push_back(1);
top->Resize(top_dims_vec);
auto* top_data = top->mutable_data<float>();
const auto* w_data = w->data<float>();
const auto* col_data = col->data<float>();
for (int b = 0; b < batch; ++b) {
int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel;
if (top_im_size == 0) {
continue;
}
naive_sgemm(false,
false,
output_channel,
top_im_size,
input_channel * kernel_h * kernel_w,
1.0,
w_data,
input_channel * kernel_h * kernel_w,
col_data + col_offset[b],
top_im_size,
0.0,
top_data + top_offset[b],
top_im_size);
}
}
TEST(var_conv_2d_cuda, normal) {
VarConv2DCompute var_conv_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::VarConv2DParam param;
lite::Tensor X, W, ROW, COLUMN;
lite::Tensor x_cpu, w_cpu;
lite::Tensor Out, Col, out_cpu, col_cpu;
int kernel_h = 5, kernel_w = 5;
int stride_h = 1, stride_w = 1;
int input_channel = 5, output_channel = 5;
std::vector<int64_t> w_dims_vec;
w_dims_vec.push_back(output_channel);
w_dims_vec.push_back(input_channel * kernel_h * kernel_w);
W.Resize(w_dims_vec);
w_cpu.Resize(w_dims_vec);
auto* w_cpu_data = w_cpu.mutable_data<float>();
for (int i = 0; i < W.numel(); ++i) {
w_cpu_data[i] = i - 1.f;
}
std::vector<uint64_t> row_lod_vec{0, 10, 20};
LoD row_lod;
row_lod.push_back(row_lod_vec);
ROW.set_lod(row_lod);
std::vector<uint64_t> column_lod_vec{0, 10, 20};
LoD column_lod;
column_lod.push_back(column_lod_vec);
COLUMN.set_lod(column_lod);
int x_size = 0;
std::vector<uint64_t> x_lod_vec;
x_lod_vec.push_back(0);
for (size_t i = 0; i < row_lod_vec.size() - 1; ++i) {
int height = row_lod_vec[i + 1] - row_lod_vec[i];
int width = column_lod_vec[i + 1] - column_lod_vec[i];
x_lod_vec.push_back(x_lod_vec.back() + height * width);
x_size += height * width;
}
for (size_t i = 0; i < x_lod_vec.size(); ++i) {
x_lod_vec[i] *= input_channel;
}
x_size *= input_channel;
std::vector<int64_t> x_dims_vec{x_size, 1};
LoD x_lod;
x_lod.push_back(x_lod_vec);
x_lod.push_back(row_lod_vec);
x_lod.push_back(column_lod_vec);
X.Resize(x_dims_vec);
x_cpu.Resize(x_dims_vec);
X.set_lod(x_lod);
x_cpu.set_lod(x_lod);
auto* x_cpu_data = x_cpu.mutable_data<float>();
for (int i = 0; i < X.numel(); ++i) {
x_cpu_data[i] = i % 20 * 1.f;
}
int sum_num = 0;
int out_sum_num = 0;
for (size_t i = 0; i < row_lod_vec.size() - 1; ++i) {
int height = row_lod_vec[i + 1] - row_lod_vec[i];
int width = column_lod_vec[i + 1] - column_lod_vec[i];
sum_num += height * width * input_channel * kernel_h * kernel_w;
out_sum_num += height * width * output_channel;
}
col_cpu.Resize({sum_num, 1});
out_cpu.Resize({out_sum_num, 1});
float* out_cpu_data = out_cpu.mutable_data<float>();
float* col_cpu_data = col_cpu.mutable_data<float>();
X.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
W.Assign<float, lite::DDim, TARGET(kCUDA)>(w_cpu_data, w_cpu.dims());
param.X = &X;
param.W = &W;
// param.ROW = &ROW;
// param.COLUMN = &COLUMN;
param.Out = &Out;
param.Col = &Col;
param.stride_h = stride_h;
param.stride_w = stride_w;
param.kernel_h = kernel_h;
param.kernel_w = kernel_w;
param.input_channel = input_channel;
param.output_channel = output_channel;
var_conv_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
var_conv_kernel.SetContext(std::move(ctx));
var_conv_kernel.Run();
cudaDeviceSynchronize();
const float* out_data = Out.data<float>();
const float* col_data = Col.data<float>();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * Out.numel(), IoDirection::DtoH);
CopySync<TARGET(kCUDA)>(
col_cpu_data, col_data, sizeof(float) * Col.numel(), IoDirection::DtoH);
lite::Tensor top_ref, col_ref;
var_conv_2d_ref(&x_cpu,
&w_cpu,
&ROW,
&COLUMN,
kernel_h,
kernel_w,
stride_h,
stride_w,
input_channel,
output_channel,
&top_ref,
&col_ref);
for (int i = 0; i < Out.numel(); ++i) {
EXPECT_NEAR(out_cpu_data[i], top_ref.data<float>()[i], 1e-5);
}
for (int i = 0; i < Col.numel(); ++i) {
EXPECT_NEAR(col_cpu_data[i], col_ref.data<float>()[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -22,8 +22,6 @@ REGISTER_LITE_KERNEL(var_conv_2d,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("ROW", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("COLUMN", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Col", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -36,14 +36,16 @@ class VarConv2DCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
int kernel_w = param.kernel_w;
int stride_h = param.stride_h;
int stride_w = param.stride_w;
auto* in_row = param.ROW;
auto* in_col = param.COLUMN;
// auto* in_row = param.ROW;
// auto* in_col = param.COLUMN;
int batch = input.lod()[0].size() - 1;
const auto& bottom_offset = input.lod()[0];
// 2-D lod info.
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
// const auto& offset_x = in_col->lod()[0];
// const auto& offset_y = in_row->lod()[0];
const auto& offset_y = param.X->lod()[1];
const auto& offset_x = param.X->lod()[2];
// top offset is the whole size of each data sample
std::vector<uint64_t> top_offset;
......@@ -126,8 +128,8 @@ class VarConv2DCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
auto* bottom = param.X;
auto* in_row = param.ROW;
auto* in_col = param.COLUMN;
// auto* in_row = param.ROW;
// auto* in_col = param.COLUMN;
auto* w = param.W;
auto* top = param.Out;
auto* col = param.Col;
......@@ -142,8 +144,10 @@ class VarConv2DCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
Im2Col(*bottom, col);
int batch = bottom->lod()[0].size() - 1;
const auto& col_offset = col->lod()[0];
const auto& offset_x = in_col->lod()[0];
const auto& offset_y = in_row->lod()[0];
// const auto& offset_x = in_col->lod()[0];
// const auto& offset_y = in_row->lod()[0];
const auto& offset_y = param.X->lod()[1];
const auto& offset_x = param.X->lod()[2];
std::vector<size_t> top_offset;
int top_size = 0;
top_offset.push_back(top_size);
......
......@@ -250,16 +250,18 @@ TEST(var_conv_2d_x86, run_test) {
int x_size = 0;
std::vector<uint64_t> x_lod_vec;
x_lod_vec.push_back(0);
for (size_t i = 0; i < row_lod_vec.size() - 1; ++i) {
int height = row_lod_vec[i + 1] - row_lod_vec[i];
int width = column_lod_vec[i + 1] - column_lod_vec[i];
x_lod_vec.push_back(height * width);
x_size += height * width;
x_lod_vec.push_back(height * width * input_channel);
x_size += height * width * input_channel;
}
x_size *= input_channel;
std::vector<int64_t> x_dims_vec{x_size, 1};
LoD x_lod;
x_lod.push_back(x_lod_vec);
x_lod.push_back(row_lod_vec);
x_lod.push_back(column_lod_vec);
X.Resize(x_dims_vec);
X.set_lod(x_lod);
auto* x_data = X.mutable_data<float>();
......@@ -269,8 +271,8 @@ TEST(var_conv_2d_x86, run_test) {
param.X = &X;
param.W = &W;
param.ROW = &ROW;
param.COLUMN = &COLUMN;
// param.ROW = &ROW;
// param.COLUMN = &COLUMN;
param.Out = &Out;
param.Col = &Col;
param.stride_h = stride_h;
......
......@@ -30,13 +30,15 @@ bool VarConv2dOp::CheckShape() const {
<< "W dim[1] should be equal to InputChannel * KernelH * KernelW";
LoD x_lod = param_.X->lod();
CHECK_EQ(x_lod.empty(), false) << "The Input(X) must hold lod info.";
CHECK_GE(x_lod.size(), 1) << "The Input(X)'s lod info is corrupted.";
// CHECK_GE(x_lod.size(), 1) << "The Input(X)'s lod info is corrupted.";
CHECK_GE(x_lod.size(), 3) << "The Input(X)'s lod info is corrupted.";
CHECK_EQ(x_dims[0], static_cast<int64_t>(x_lod[0].back()))
<< "The Input(X)'s lod info mismatches the actual tensor shape.";
LoD row_lod = param_.ROW->lod();
CHECK_EQ(row_lod.empty(), false) << "The Input(ROW) must hold lod info.";
LoD col_lod = param_.COLUMN->lod();
CHECK_EQ(col_lod.empty(), false) << "The Input(COLUMN) must hold lod info.";
// LoD row_lod = param_.ROW->lod();
// CHECK_EQ(row_lod.empty(), false) << "The Input(ROW) must hold lod info.";
// LoD col_lod = param_.COLUMN->lod();
// CHECK_EQ(col_lod.empty(), false) << "The Input(COLUMN) must hold lod
// info.";
return true;
}
......@@ -45,10 +47,10 @@ bool VarConv2dOp::InferShape() const { return true; }
bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.ROW = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("ROW").front())->Get<lite::Tensor>());
param_.COLUMN = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("COLUMN").front())->Get<lite::Tensor>());
// param_.ROW = const_cast<lite::Tensor *>(
// &scope->FindVar(opdesc.Input("ROW").front())->Get<lite::Tensor>());
// param_.COLUMN = const_cast<lite::Tensor *>(
// &scope->FindVar(opdesc.Input("COLUMN").front())->Get<lite::Tensor>());
param_.W = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("W").front())->Get<lite::Tensor>());
param_.Out =
......@@ -56,8 +58,8 @@ bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.Col =
scope->FindVar(opdesc.Output("Col").front())->GetMutable<lite::Tensor>();
CHECK(param_.X) << "X(Input) of VarConv2dOP should not be null.";
CHECK(param_.ROW) << "Input(ROW) of VarConv2dOP should not be null.";
CHECK(param_.COLUMN) << "Input(COLUMN) of VarConv2dOP should not be null.";
// CHECK(param_.ROW) << "Input(ROW) of VarConv2dOP should not be null.";
// CHECK(param_.COLUMN) << "Input(COLUMN) of VarConv2dOP should not be null.";
CHECK(param_.W) << "W(Input) of VarConv2dOP should not be null.";
CHECK(param_.Out) << "Out(Output) of VarConv2dOP should not be null.";
CHECK(param_.Col) << "Col(Output) of VarConv2dOP should not be null.";
......