From 62044067d16a7a8f454ead85ecf36720e2a9f63b Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 19 Dec 2019 19:04:27 +0800 Subject: [PATCH] optimize cuda kernel test=develop (#2628) * optimize content-dnn cuda kernel --- lite/backends/cuda/math/cudnn_conv.cc | 8 +- lite/backends/cuda/math/gemm.h | 2 + .../cuda/match_matrix_tensor_compute.cu | 180 +++++++----- .../cuda/match_matrix_tensor_compute.h | 4 + lite/kernels/cuda/search_fc_compute.cu | 111 +------ lite/kernels/cuda/search_fc_compute.h | 7 +- lite/kernels/cuda/sequence_concat_compute.cu | 191 ++++++------ lite/kernels/cuda/sequence_concat_compute.h | 5 - .../cuda/sequence_topk_avg_pooling_compute.cu | 82 +++--- lite/kernels/cuda/softmax_compute.cu | 21 +- lite/kernels/cuda/softmax_compute.h | 6 + lite/kernels/cuda/var_conv_2d_compute.cu | 271 ++++-------------- lite/kernels/cuda/var_conv_2d_compute.h | 6 +- lite/kernels/x86/sequence_concat_compute.h | 24 +- lite/operators/op_params.h | 2 + lite/operators/sequence_concat_op.cc | 39 +-- lite/operators/var_conv_2d_op.cc | 23 +- 17 files changed, 390 insertions(+), 592 deletions(-) diff --git a/lite/backends/cuda/math/cudnn_conv.cc b/lite/backends/cuda/math/cudnn_conv.cc index a4f33f467f..5dd53084f4 100644 --- a/lite/backends/cuda/math/cudnn_conv.cc +++ b/lite/backends/cuda/math/cudnn_conv.cc @@ -89,9 +89,15 @@ bool CudnnConv2D::create(const operators::ConvParam& param, this->act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0)); } +#if CUDNN_VERSION_MIN(7, 0, 0) + cudnnMathType_t math_type = + use_tensor_core_ ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH; + CUDNN_CHECK(cudnnSetConvolutionMathType(this->conv_desc_, math_type)); +#endif + if (ic == param.groups && ic == oc && ic != 1) { this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; - } else if (1) { + } else if (!param.var_length) { const auto* i_data = param.x->data(); const auto* w_data = param.filter->data(); auto* o_data = param.output->mutable_data(TARGET(kCUDA)); diff --git a/lite/backends/cuda/math/gemm.h b/lite/backends/cuda/math/gemm.h index 12194d54b0..85576e6501 100644 --- a/lite/backends/cuda/math/gemm.h +++ b/lite/backends/cuda/math/gemm.h @@ -55,6 +55,8 @@ class Gemm { PtypeOut* c, Context* ctx); + cublasHandle_t get_handle() const { return cu_handle_; } + private: cudaStream_t exe_stream_; cublasHandle_t cu_handle_; diff --git a/lite/kernels/cuda/match_matrix_tensor_compute.cu b/lite/kernels/cuda/match_matrix_tensor_compute.cu index f89b9c9578..0458bb4e8e 100644 --- a/lite/kernels/cuda/match_matrix_tensor_compute.cu +++ b/lite/kernels/cuda/match_matrix_tensor_compute.cu @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include #include #include "lite/core/op_registry.h" #include "lite/kernels/cuda/match_matrix_tensor_compute.h" @@ -20,6 +21,54 @@ namespace kernels { namespace cuda { using Tensor = lite::Tensor; +template +void gpu_transpose( + cublasHandle_t handle, const dtype* src, int M, int N, dtype* dst); + +template <> +void gpu_transpose( + cublasHandle_t handle, const float* src, int M, int N, float* dst) { + float alpha = 1.0; + float beta = 0.0; + CUBLAS_CHECK(cublasSgeam(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + M, + N, + &alpha, + src, + N, + &beta, + dst, + M, + dst, + M)); +} + +template +__global__ void padding_out(const dtype* src, + const int* offset, + const int seq_num_r, + const int max_len_r, + const int tl, + const int count, + dtype* dst) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int thread_num = blockDim.x * gridDim.x; + for (tid = threadIdx.x + blockIdx.x * blockDim.x; tid < count; + tid += thread_num) { + int seq_id = tid / (tl * max_len_r); + int tl_id = (tid / (max_len_r)) % tl; + int r_id = tid % max_len_r; + int cur_len = offset[seq_id + 1] - offset[seq_id]; + if (r_id < cur_len) { + dst[tid] = src[(offset[seq_id] + r_id) * tl + tl_id]; + } else { + dst[tid] = 0.f; + } + } +} + void MatchMatrixTensorCompute::PrepareForRun() { gemm_impl_.reset(new lite::cuda::math::Gemm); } @@ -28,6 +77,7 @@ void MatchMatrixTensorCompute::Run() { CHECK(ctx_) << "running context should be set first"; auto& param = this->Param(); auto& context = this->ctx_->template As(); + auto stream = context.exec_stream(); auto* x = param.x; auto* w = param.w; @@ -39,76 +89,74 @@ void MatchMatrixTensorCompute::Run() { const auto& offset_l = x->lod()[0]; const auto& offset_r = y->lod()[0]; - - std::vector top_offset; - int top_size = 0; - top_offset.push_back(top_size); - for (size_t b = 0; b < x->lod()[0].size() - 1; b++) { - int len_l = offset_l[b + 1] - offset_l[b]; - int len_r = offset_r[b + 1] - offset_r[b]; - top_size += dim_t * len_l * len_r; - top_offset.push_back(top_size); + std::vector offset_r_int(offset_r.size()); + std::transform(offset_r.begin(), + offset_r.end(), + offset_r_int.begin(), + [](int64_t x) -> int { return static_cast(x); }); + + int batch = offset_r.size() - 1; + int len_l = offset_l[1] - offset_l[0]; + for (int i = 1; i < offset_l.size() - 1; i++) { + int cur_len = offset_l[i + 1] - offset_l[i]; + CHECK_EQ(cur_len, len_l) + << "each sequence of left matrix is the same length"; } - - auto* bottom_l_data = x->data(); - auto* bottom_r_data = y->data(); - auto* t_data = w->data(); - auto* out_data = out->mutable_data(TARGET(kCUDA)); - auto* bottom_l_trans_data = tmp->mutable_data(TARGET(kCUDA)); - - gemm_impl_->init( - false, false, x->dims()[0], dim_t * dim_in, dim_in, &context); - gemm_impl_->run( - 1.0f, 0.0f, bottom_l_data, t_data, bottom_l_trans_data, &context); - - for (size_t b = 0; b < x->lod()[0].size() - 1; b++) { - for (int t = 0; t < dim_t; t++) { - int len_l = offset_l[b + 1] - offset_l[b]; - int len_r = offset_r[b + 1] - offset_r[b]; - auto* top_data = out_data + top_offset[b] + t * len_l * len_r; - const auto* l_t_data = - bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in; - const auto* r_data = bottom_r_data + offset_r[b] * dim_in; - - gemm_impl_->init(false, - true, - len_l, - len_r, - dim_in, - dim_t * dim_in, - dim_in, - len_r, - &context); - gemm_impl_->run(1.0f, 0.0f, l_t_data, r_data, top_data, &context); - } + int max_len_r = 0; + for (int i = 0; i < offset_r.size() - 1; ++i) { + int cur_len = offset_r[i + 1] - 
offset_r[i]; + max_len_r = cur_len > max_len_r ? cur_len : max_len_r; } - int batch_size = x->lod()[0].size() - 1; - int lod_lv1_size = batch_size * dim_t; - int lod_lv2_size = x->lod()[0].back() * dim_t; - std::vector out_lod0(batch_size + 1, 0); - std::vector out_lod1(lod_lv1_size + 1, 0); - std::vector out_lod2(lod_lv2_size + 1, 0); - for (int i = 0; i < batch_size; i++) { - out_lod0[i + 1] = out_lod0[i] + dim_t; - int len_l = offset_l[i + 1] - offset_l[i]; - - for (int j = 0; j < dim_t; j++) { - out_lod1[i * dim_t + j + 1] = out_lod1[i * dim_t + j] + len_l; - int len_r = offset_r[i + 1] - offset_r[i]; - - for (int k = 0; k < len_l; k++) { - out_lod2[offset_l[i] * dim_t + j * len_l + k + 1] = - out_lod2[offset_l[i] * dim_t + j * len_l + k] + len_r; - } - } + _input_l_transform.Resize({batch, dim_t, dim_in, len_l}); + _input_l_transform_reorganize.Resize({batch, dim_t, len_l, dim_in}); + _output_tmp.Resize({batch, max_len_r, dim_t, len_l}); + out->Resize({batch, dim_t, len_l, max_len_r}); + + _offset_r.Resize({static_cast(offset_r.size())}); + TargetWrapperCuda::MemcpyAsync(_offset_r.mutable_data(TARGET(kCUDA)), + &offset_r_int[0], + sizeof(int) * offset_r.size(), + IoDirection::HtoD, + stream); + + int len_r = offset_r[offset_r.size() - 1]; + const float* input_l = x->data(); + const float* input_r = y->data(); + const float* weight_data = w->data(); + float* input_l_transform = + _input_l_transform.mutable_data(TARGET(kCUDA)); + float* input_l_transform_reorganize = + _input_l_transform_reorganize.mutable_data(TARGET(kCUDA)); + float* output_tmp = _output_tmp.mutable_data(TARGET(kCUDA)); + float* out_data = out->mutable_data(TARGET(kCUDA)); + + gemm_impl_->init(true, true, dim_t * dim_in, len_l, dim_in, &context); + gemm_impl_->run( + 1.0f, 0.0f, weight_data, input_l, input_l_transform, &context); + for (int i = 0; i < dim_t; ++i) { + int offset = i * dim_in * len_l; + gpu_transpose(gemm_impl_->get_handle(), + input_l_transform + offset, + dim_in, + len_l, + input_l_transform_reorganize + offset); } - - LoD out_lod; - out_lod.push_back(top_offset); - out_lod.push_back(offset_l); - out_lod.push_back(offset_r); - out->set_lod(out_lod); + gemm_impl_->init(false, true, len_r, dim_t * len_l, dim_in, &context); + gemm_impl_->run( + 1.0f, 0.0f, input_r, input_l_transform_reorganize, output_tmp, &context); + int seq_num = offset_r.size() - 1; + int count = seq_num * max_len_r * dim_t * len_l; + const int blocks = 512; + const int grids = (count + blocks - 1) / blocks; + padding_out<<>>(_output_tmp.data(), + _offset_r.data(), + seq_num, + max_len_r, + dim_t * len_l, + count, + out_data); + out->set_lod(y->lod()); } } // namespace cuda diff --git a/lite/kernels/cuda/match_matrix_tensor_compute.h b/lite/kernels/cuda/match_matrix_tensor_compute.h index 09db326ff3..d5fe8885f2 100644 --- a/lite/kernels/cuda/match_matrix_tensor_compute.h +++ b/lite/kernels/cuda/match_matrix_tensor_compute.h @@ -34,6 +34,10 @@ class MatchMatrixTensorCompute private: std::unique_ptr> gemm_impl_; + lite::Tensor _input_l_transform; + lite::Tensor _input_l_transform_reorganize; + lite::Tensor _output_tmp; + lite::Tensor _offset_r; }; } // namespace cuda diff --git a/lite/kernels/cuda/search_fc_compute.cu b/lite/kernels/cuda/search_fc_compute.cu index 591e2474a4..8bd5bad540 100644 --- a/lite/kernels/cuda/search_fc_compute.cu +++ b/lite/kernels/cuda/search_fc_compute.cu @@ -16,92 +16,6 @@ namespace paddle { namespace lite { namespace kernels { namespace cuda { -template -static void anakin_NV_gemv(cublasHandle_t handle, - 
const bool TransA, - const int M, - const int N, - const T alpha, - const T* A, - const T* x, - const T beta, - T* y); -template <> -void anakin_NV_gemv(cublasHandle_t handle, - const bool TransA, - const int M, - const int N, - const float alpha, - const float* A, - const float* x, - const float beta, - float* y) { - cublasOperation_t cuTransA = (TransA == false) ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBLAS_CHECK( - cublasSgemv(handle, cuTransA, N, M, &alpha, A, N, x, 1, &beta, y, 1)); -} -template -static void anakin_NV_gemm(cublasHandle_t handle, - const bool TransA, - const bool TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const T* B, - const T beta, - T* C); - -template <> -void anakin_NV_gemm(cublasHandle_t handle, - const bool TransA, - const bool TransB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const float* B, - const float beta, - float* C) { - // Note that cublas follows fortran order. - int lda = (!TransA /* == CblasNoTrans*/) ? K : M; - int ldb = (!TransB /* == CblasNoTrans*/) ? N : K; - cublasOperation_t cuTransA = - (!TransA /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (!TransB /* == CblasNoTrans*/) ? CUBLAS_OP_N : CUBLAS_OP_T; - CUBLAS_CHECK(cublasSgemm(handle, - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - N)); -} - -template <> -void anakin_NV_gemm(cublasHandle_t handle, - const bool TransA, - const bool TransB, - const int M, - const int N, - const int K, - const char alpha, - const char* A, - const char* B, - const char beta, - char* C) { - LOG(FATAL) << "int8 gemm is not implemented"; -} template static __global__ void add_bias(int n, @@ -115,6 +29,11 @@ static __global__ void add_bias(int n, } } +template +void SearchFcCompute::PrepareForRun() { + gemm_impl_.reset(new lite::cuda::math::Gemm); +} + template void SearchFcCompute::Run() { auto& param = this->Param(); @@ -132,22 +51,10 @@ void SearchFcCompute::Run() { const T* weight = w_tensor->data(); const Tensor* b_tensor = param.b; const T* bias = b_tensor->data(); - cublasCreate(&_handle); - if (_M == 1 && _K > 50000) { - anakin_NV_gemv(_handle, false, _N, _K, (T)1, weight, din, (T)0, dout); - } else { - anakin_NV_gemm(_handle, - false, - !_flag_trans_weights, - _M, - _N, - _K, - (T)1, - din, - weight, - (T)0, - dout); - } + + CHECK(gemm_impl_->init(false, true, _M, _N, _K, &ctx)); + gemm_impl_->run(1.0f, 0.0f, din, weight, dout, &ctx); + int total_size = _M * _N; add_bias<<>>( total_size, _N, bias, dout); diff --git a/lite/kernels/cuda/search_fc_compute.h b/lite/kernels/cuda/search_fc_compute.h index db09362734..a551486cba 100644 --- a/lite/kernels/cuda/search_fc_compute.h +++ b/lite/kernels/cuda/search_fc_compute.h @@ -14,7 +14,9 @@ #pragma once #include +#include #include "lite/backends/cuda/cuda_utils.h" +#include "lite/backends/cuda/math/gemm.h" #include "lite/core/kernel.h" namespace paddle { @@ -34,16 +36,15 @@ template class SearchFcCompute : public KernelLite { public: using param_t = operators::SearchFcParam; + void PrepareForRun() override; void Run() override; virtual ~SearchFcCompute() = default; private: - bool _flag_trans_weights{false}; + std::unique_ptr> gemm_impl_{nullptr}; int _M; int _K; int _N; - cublasHandle_t _handle; - bool _is_continue_buf{true}; }; } // namespace cuda diff --git a/lite/kernels/cuda/sequence_concat_compute.cu b/lite/kernels/cuda/sequence_concat_compute.cu index d4390046b0..57d6684ffd 100644 --- 
a/lite/kernels/cuda/sequence_concat_compute.cu +++ b/lite/kernels/cuda/sequence_concat_compute.cu @@ -22,43 +22,44 @@ namespace lite { namespace kernels { namespace cuda { -const int CUDA_NUM_THREADS = 512; - -template -inline LoD ConcatLoD(const std::vector& xs) { - std::vector result; - result.resize(xs[0]->lod()[0].size()); - - for (size_t i = 1; i < result.size(); ++i) { - size_t sum = 0; - for (size_t j = 0; j < xs.size(); ++j) { - auto& x_lod = xs[j]->lod()[0]; - sum += x_lod[i]; - } - result[i] = sum; +template +__global__ void concat_impl_cuda(const int nthreads, + const dtype* in_data, + const int num_concats, + const int concat_size, + const int top_concat_axis, + const int bottom_concat_axis, + const int offset_concat_axis, + dtype* out_data) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + const int total_concat_size = concat_size * bottom_concat_axis; + const int concat_num = index / total_concat_size; + const int concat_index = index % total_concat_size; + const int top_index = + concat_index + + (concat_num * top_concat_axis + offset_concat_axis) * concat_size; + out_data[top_index] = in_data[index]; } - LoD lod; - lod.emplace_back(result); - return lod; } -template -__global__ void ker_sequence_concat(Dtype* out_data, - const uint64_t* in_locate_data, - const int* o2i_map, - const int* o2i_w_map, - const int seq_num, - const int emb_size, - const int count) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - for (int tid = idx; tid < count; tid += blockDim.x * gridDim.x) { - int emb_id = tid % emb_size; - int word_id = tid / emb_size; - int input_id = o2i_map[word_id]; - int cur_work_id = o2i_w_map[word_id]; - const Dtype* in_data = reinterpret_cast( - reinterpret_cast(in_locate_data[input_id])); - out_data[tid] = in_data[cur_work_id * emb_size + emb_id]; +template +__global__ void concat_impl_2d_impl(const int inner_size, + const int num_concats, + const dtype* in_data, + const int concat_size, + const int out_concat_axis, + const int offset_concat_axis, + dtype* out_data) { + int idx_inner = threadIdx.x + blockIdx.x * blockDim.x; + int idx_outer = threadIdx.y + blockIdx.y * blockDim.y; + + if (idx_inner < inner_size && idx_outer < num_concats) { + int idx_input = idx_outer * inner_size + idx_inner; + int idx_output = + (idx_outer * out_concat_axis + offset_concat_axis) * concat_size + + idx_inner; + out_data[idx_output] = in_data[idx_input]; } } @@ -66,73 +67,75 @@ void SequenceConcatCompute::Run() { auto& param = this->Param(); auto& ctx = this->ctx_->template As(); auto stream = ctx.exec_stream(); - float* out_data = param.Out->mutable_data(TARGET(kCUDA)); - int seq_num = param.X[0]->lod()[0].size() - 1; - const int emb_size = param.X[0]->numel() / param.X[0]->dims()[0]; - std::vector in_locate_vec; - for (size_t i = 0; i < param.X.size(); ++i) { - in_locate_vec.push_back( - reinterpret_cast(param.X[i]->data())); - } - in_locate_tensor.Resize({static_cast(in_locate_vec.size())}); + const int BLOCK_SIZE = 32; + const int axis = 1; + int num_concats = param.X[0]->dims().count(0, axis); + int concat_input_size = + param.X[0]->dims().count(axis + 1, param.X[0]->dims().size()); - std::vector out2in_map; - std::vector out2in_word_map; - for (int i = 0; i < seq_num; ++i) { - for (int j = 0; j < param.X.size(); ++j) { - auto offset = param.X[j]->lod()[0]; - int cur_len = offset[i + 1] - offset[i]; - for (int k = 0; k < cur_len; ++k) { - out2in_map.push_back(j); - out2in_word_map.push_back(offset[i] + k); + 
int input_size = param.X.size(); + std::vector> shapes_in(input_size); + for (int i = 0; i < input_size; ++i) { + shapes_in[i] = param.X[i]->dims().Vectorize(); + } + std::vector shape_out = shapes_in[0]; + + // compute output shape + for (int i = 1; i < input_size; ++i) { + for (int j = 0; j < shapes_in[i].size(); ++j) { + if (j == axis) { + continue; + } else if (shapes_in[i][j] != -1) { + CHECK_EQ(shape_out[j], shapes_in[i][j]) + << "All inputs must have the same shape, except at concat_axis."; } } + shape_out[axis] += shapes_in[i][axis]; } - int word_num = out2in_map.size(); - out2in_map_tensor.Resize({word_num}); - out2in_word_map_tensor.Resize({word_num}); - int* gpu_o2i_map_data = out2in_map_tensor.mutable_data(TARGET(kCUDA)); - int* gpu_o2i_w_map_data = - out2in_word_map_tensor.mutable_data(TARGET(kCUDA)); - uint64_t* gpu_in_locate_data = - in_locate_tensor.mutable_data(TARGET(kCUDA)); - TargetWrapperCuda::MemcpyAsync(gpu_o2i_map_data, - out2in_map.data(), - sizeof(int) * out2in_map.size(), - IoDirection::HtoD, - stream); - TargetWrapperCuda::MemcpyAsync(gpu_o2i_w_map_data, - out2in_word_map.data(), - sizeof(int) * out2in_word_map.size(), - IoDirection::HtoD, - stream); - TargetWrapperCuda::MemcpyAsync(gpu_in_locate_data, - in_locate_vec.data(), - sizeof(uint64_t) * in_locate_vec.size(), - IoDirection::HtoD, - stream); - - param.Out->set_lod(ConcatLoD(param.X)); - - int count = param.X[0]->numel(); - for (int i = 1; i < param.X.size(); ++i) { - count += param.X[i]->numel(); + param.Out->Resize(shape_out); + float* out_data = param.Out->mutable_data(TARGET(kCUDA)); + int offset_concat_axis = 0; + const int out_concat_axis = shape_out[axis]; + + for (int i = 0; i < input_size; ++i) { + std::vector in_shape = param.X[i]->dims().Vectorize(); + const auto* in_data = param.X[i]->data(); + const int in_concat_axis = in_shape[axis]; + const int in_concat_size = in_concat_axis * concat_input_size; + const int nthreads = in_concat_size * num_concats; + float ratio = static_cast(in_concat_size) / num_concats; + bool is_balance = (ratio > 0.1 && ratio < 10); + if (is_balance) { + int block_x = BLOCK_SIZE; + int block_y = BLOCK_SIZE; + int grid_x = (in_concat_size + block_x - 1) / block_x; + int grid_y = (num_concats + block_y - 1) / block_y; + dim3 block(block_x, block_y); + dim3 grid(grid_x, grid_y); + concat_impl_2d_impl<<>>(in_concat_size, + num_concats, + in_data, + concat_input_size, + out_concat_axis, + offset_concat_axis, + out_data); + } else { + int grid = (nthreads + BLOCK_SIZE - 1) / BLOCK_SIZE; + concat_impl_cuda<<>>( + nthreads, + in_data, + num_concats, + concat_input_size, + out_concat_axis, + in_concat_axis, + offset_concat_axis, + out_data); + } + offset_concat_axis += in_concat_axis; } - - int blocks = (count + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; - ker_sequence_concat<<>>( - out_data, - gpu_in_locate_data, - gpu_o2i_map_data, - gpu_o2i_w_map_data, - seq_num, - emb_size, - count); - - cudaError_t error = cudaGetLastError(); - if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); + param.Out->set_lod(param.X[0]->lod()); } } // namespace cuda diff --git a/lite/kernels/cuda/sequence_concat_compute.h b/lite/kernels/cuda/sequence_concat_compute.h index 1737c18dd3..3f2204cd1b 100644 --- a/lite/kernels/cuda/sequence_concat_compute.h +++ b/lite/kernels/cuda/sequence_concat_compute.h @@ -27,11 +27,6 @@ class SequenceConcatCompute void Run() override; virtual ~SequenceConcatCompute() = default; - - private: - lite::Tensor out2in_map_tensor; - lite::Tensor 
out2in_word_map_tensor; - lite::Tensor in_locate_tensor; }; } // namespace cuda diff --git a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu index 8ea3edb30d..4794644c6d 100644 --- a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu +++ b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu @@ -26,6 +26,8 @@ __global__ void topk_avg_pooling_kernel_by_row_improve( const Dtype *input, const int *gpu_input_offset_l, const int *gpu_input_offset_r, + const int row_max, + const int col_max, const int topk_size, const int *topks, const int feat_map_num) { @@ -33,20 +35,17 @@ __global__ void topk_avg_pooling_kernel_by_row_improve( gpu_input_offset_l[blockIdx.x + 1] - gpu_input_offset_l[blockIdx.x]; // 8 int col = gpu_input_offset_r[blockIdx.x + 1] - gpu_input_offset_r[blockIdx.x]; // 30 + int max_k = topks[topk_size - 1]; max_k = max_k < col ? max_k : col; extern __shared__ Dtype smem[]; // H*W - const Dtype *fm_row_in_data = input; - for (int i = 0; i < blockIdx.x; ++i) { - int tmp_row = gpu_input_offset_l[i + 1] - gpu_input_offset_l[i]; - int tmp_col = gpu_input_offset_r[i + 1] - gpu_input_offset_r[i]; - fm_row_in_data += tmp_row * feat_map_num * tmp_col; - } - fm_row_in_data += blockIdx.y * row * col; + const Dtype *fm_row_in_data = input + + blockIdx.x * row_max * feat_map_num * col_max + + blockIdx.y * row_max * col_max; - for (int i = threadIdx.x; i < row * col; i += blockDim.x) { + for (int i = threadIdx.x; i < row * col_max; i += blockDim.x) { smem[i] = fm_row_in_data[i]; } __syncthreads(); @@ -57,13 +56,13 @@ __global__ void topk_avg_pooling_kernel_by_row_improve( (gpu_input_offset_l[blockIdx.x] + idx) * feat_map_num * topk_size + blockIdx.y * topk_size; - Dtype *smem_start_col = smem + idx * col; + Dtype *smem_start_col = smem + idx * col_max; int counter = max_k; // topk_size; Dtype last_max_val = -20000.0; while (counter) { Dtype max_val = -10000.0; - int max_pos = 0; + int max_pos = 0; // -1; int m = 0; for (; m < col; m++) { Dtype cur_data = smem_start_col[m]; @@ -77,6 +76,7 @@ __global__ void topk_avg_pooling_kernel_by_row_improve( max_val = last_max_val; } smem_start_col[max_pos] = -10000000.0; + int i = max_k - counter; for (int c = 0; c < topk_size; c++) { if (i <= topks[c] - 1) { @@ -98,22 +98,18 @@ void SequenceTopkAvgPoolingCompute::Run() { auto ¶m = this->Param(); auto &ctx = this->ctx_->template As(); auto cuda_stream = ctx.exec_stream(); - int topk_num = param.topks.size(); - lite::DDim top_ks_shape(std::vector{topk_num, 1, 1, 1}); - _top_ks.Resize(top_ks_shape); - cudaMemcpyAsync(_top_ks.mutable_data(TARGET(kCUDA)), - ¶m.topks[0], - sizeof(int) * topk_num, - cudaMemcpyHostToDevice, - cuda_stream); - int width_offset_len = param.COLUMN->lod()[0].size(); - lite::DDim width_offset_shape( - std::vector{width_offset_len, 1, 1, 1}); + CHECK(param.X->lod().size() > 0 && param.X->lod()[0].size() > 0) + << "X sequence offset is not valid"; + CHECK(param.ROW->lod().size() > 0 && param.ROW->lod()[0].size() > 0) + << "ROW sequence offset is not valid"; + + int width_offset_len = param.X->lod()[0].size(); + lite::DDim width_offset_shape(std::vector{width_offset_len}); _width_offset.Resize(width_offset_shape); std::vector width_lod_0(width_offset_len, 0); - for (size_t i = 0; i < param.COLUMN->lod()[0].size(); ++i) { - width_lod_0[i] = static_cast(param.COLUMN->lod()[0][i]); + for (size_t i = 0; i < param.X->lod()[0].size(); ++i) { + width_lod_0[i] = static_cast(param.X->lod()[0][i]); } 
cudaMemcpyAsync(_width_offset.mutable_data(TARGET(kCUDA)), &width_lod_0[0], @@ -122,8 +118,7 @@ void SequenceTopkAvgPoolingCompute::Run() { cuda_stream); int height_offset_len = param.ROW->lod()[0].size(); - lite::DDim height_offset_shape( - std::vector{height_offset_len, 1, 1, 1}); + lite::DDim height_offset_shape(std::vector{height_offset_len}); _height_offset.Resize(height_offset_shape); std::vector height_lod_0(height_offset_len, 0); for (size_t i = 0; i < param.ROW->lod()[0].size(); ++i) { @@ -139,39 +134,42 @@ void SequenceTopkAvgPoolingCompute::Run() { Tensor *out_tensor = param.Out; const T *in_data = x_tensor->data(); T *out_data = out_tensor->mutable_data(TARGET(kCUDA)); - TargetWrapperCuda::MemsetAsync(out_tensor->mutable_data(TARGET(kCUDA)), - 0, - sizeof(T) * out_tensor->numel(), - cuda_stream); + TargetWrapperCuda::MemsetAsync( + out_data, 0, sizeof(T) * param.Out->numel(), cuda_stream); + + int topk_num = param.topks.size(); + lite::DDim top_ks_shape(std::vector{topk_num, 1, 1, 1}); + _top_ks.Resize(top_ks_shape); + cudaMemcpyAsync(_top_ks.mutable_data(TARGET(kCUDA)), + ¶m.topks[0], + sizeof(int) * topk_num, + cudaMemcpyHostToDevice, + cuda_stream); - int num = param.ROW->lod()[0].size() - 1; - int channel = param.channel_num; + int num = param.X->dims()[0]; + int channel = param.X->dims()[1]; + int height = param.X->dims()[2]; + int width = param.X->dims()[3]; const int *height_offset = _height_offset.data(); const int *width_offset = _width_offset.data(); - int feat_map_size = 0; - for (size_t i = 0; i < height_lod_0.size() - 1; ++i) { - int height = height_lod_0[i + 1] - height_lod_0[i]; - int width = width_lod_0[i + 1] - width_lod_0[i]; - if (height * width > feat_map_size) { - feat_map_size = height * width; - } - } + int feat_map_size = height * width; + dim3 blocks(num, channel); dim3 threads(32, 1); + topk_avg_pooling_kernel_by_row_improve< T><<>>( out_data, in_data, height_offset, width_offset, + height, + width, param.topks.size(), _top_ks.data(), param.channel_num); - - cudaError_t error = cudaGetLastError(); - if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error); } } // namespace cuda diff --git a/lite/kernels/cuda/softmax_compute.cu b/lite/kernels/cuda/softmax_compute.cu index 6293f7295e..157c6ae889 100644 --- a/lite/kernels/cuda/softmax_compute.cu +++ b/lite/kernels/cuda/softmax_compute.cu @@ -21,6 +21,8 @@ namespace kernels { namespace cuda { using Tensor = lite::Tensor; +const int CUDA_NUM_THREADS = 512; + extern __shared__ char tile[]; template __global__ void sharemem_softmax_kernel(int total_size, @@ -149,6 +151,15 @@ __global__ void softmax_divid_output_kernel(int total_size, } } +void SoftmaxCompute::PrepareForRun() { + int device_id; + cudaGetDevice(&device_id); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, device_id); + sharedmem_size = deviceProp.sharedMemPerBlock; + max_dimsize = sharedmem_size / sizeof(float) / CUDA_NUM_THREADS; +} + void SoftmaxCompute::Run() { auto& param = this->Param(); auto& ctx = this->ctx_->template As(); @@ -165,18 +176,10 @@ void SoftmaxCompute::Run() { int total_threads = inner_num * outer_num; int axis_size = x_dims[axis]; - int device_id; - const int threads = 512; + const int threads = CUDA_NUM_THREADS; const int blocks = (total_threads + threads - 1) / threads; - cudaGetDevice(&device_id); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, device_id); - size_t sharedmem_size = deviceProp.sharedMemPerBlock; - int max_dimsize = sharedmem_size / sizeof(float) / threads; 
auto input_data = param.x->data(); auto output_data = param.output->mutable_data(TARGET(kCUDA)); - TargetWrapperCuda::MemsetSync( - output_data, 0, param.output->numel() * sizeof(float)); if (axis_size <= max_dimsize) { int use_sharemem_size = axis_size * threads * sizeof(float); sharemem_softmax_kernel<<>>( diff --git a/lite/kernels/cuda/softmax_compute.h b/lite/kernels/cuda/softmax_compute.h index 4acde4ab07..72d43a8eff 100644 --- a/lite/kernels/cuda/softmax_compute.h +++ b/lite/kernels/cuda/softmax_compute.h @@ -25,8 +25,14 @@ class SoftmaxCompute public: using param_t = operators::SoftmaxParam; + void PrepareForRun() override; void Run() override; virtual ~SoftmaxCompute() = default; + + private: + size_t sharedmem_size; + int num_threads; + int max_dimsize; }; } // namespace cuda diff --git a/lite/kernels/cuda/var_conv_2d_compute.cu b/lite/kernels/cuda/var_conv_2d_compute.cu index f2588a8f53..92a59876a6 100644 --- a/lite/kernels/cuda/var_conv_2d_compute.cu +++ b/lite/kernels/cuda/var_conv_2d_compute.cu @@ -25,224 +25,79 @@ namespace lite { namespace kernels { namespace cuda { -const int CUDA_NUM_THREADS = 512; - -template -__global__ void var_im2col_gpu_kernel(const int n, - const Dtype* data_im, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int height_col, - const int width_col, - Dtype* data_col) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - for (int index = idx; index < n; index += blockDim.x * gridDim.x) { - const int h_index = index / width_col; - const int h_col = h_index % height_col; - const int w_col = index % width_col; - const int c_im = h_index / height_col; - const int c_col = c_im * kernel_h * kernel_w; - const int h_offset = h_col * stride_h - pad_h; - const int w_offset = w_col * stride_w - pad_w; - - Dtype* data_col_ptr = data_col; - data_col_ptr += (c_col * height_col + h_col) * width_col + w_col; - const Dtype* data_im_ptr = data_im; - data_im_ptr += (c_im * height + h_offset) * width + w_offset; - - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - int h_im = h_offset + i; - int w_im = w_offset + j; - *data_col_ptr = - (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) - ? data_im_ptr[i * width + j] - : 0; - data_col_ptr += height_col * width_col; - } - } - } +inline int ConvOutputSize(int input_size, + int filter_size, + int dilation, + int pad_left, + int pad_right, + int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = + (input_size + (pad_left + pad_right) - dkernel) / stride + 1; + + return output_size; } -void VarConv2DCompute::var_im2col(const cudaStream_t& stream) { +void VarConv2DCompute::PrepareForRun() { + auto& context = this->ctx_->template As(); + auto stream = context.exec_stream(); auto& param = this->Param(); - int input_channel = param.input_channel; - int kernel_h = param.kernel_h; - int kernel_w = param.kernel_w; - int stride_h = param.stride_h; - int stride_w = param.stride_w; - // auto* in_row = param.ROW; - // auto* in_col = param.COLUMN; - const auto* input = param.X; - auto* col = param.Col; - - int batch = input->lod()[0].size() - 1; - const auto& bottom_offset = input->lod()[0]; - // 2-D lod info. 
- // const auto& offset_x = in_col->lod()[0]; - // const auto& offset_y = in_row->lod()[0]; - const auto& offset_y = param.X->lod()[1]; - const auto& offset_x = param.X->lod()[2]; - // top offset is the whole size of each data sample - std::vector top_offset; - int top_size = 0; - top_offset.push_back(top_size); - for (int b = 0; b < batch; ++b) { - int width = offset_x[b + 1] - offset_x[b]; - int height = offset_y[b + 1] - offset_y[b]; - int top_im_x = 0; - if (width == 0) { - top_im_x = 0; - } else { - top_im_x = (width - 1) / stride_w + 1; - } - int top_im_y = 0; - if (height == 0) { - top_im_y = 0; - } else { - top_im_y = (height - 1) / stride_h + 1; - } - int top_x = top_im_x * top_im_y; - int top_y = input_channel * kernel_h * kernel_w; - top_size += top_y * top_x; - top_offset.push_back(top_size); - } - - LoD col_lod; - col_lod.push_back(top_offset); - col->set_lod(col_lod); - std::vector col_dims_vec{top_size}; - col_dims_vec.push_back(1); - col->Resize(col_dims_vec); - auto* top_data = col->mutable_data(TARGET(kCUDA)); - const auto* bottom_data = input->data(); - - for (int b = 0; b < batch; ++b) { - int t_offset = top_offset[b]; - int b_offset = bottom_offset[b]; - int width = offset_x[b + 1] - offset_x[b]; - int height = offset_y[b + 1] - offset_y[b]; - if (width == 0 || height == 0) { - continue; - } - int width_col = (width - 1) / stride_w + 1; - int height_col = (height - 1) / stride_h + 1; - const float* data_im = bottom_data + b_offset; - float* data_col = top_data + t_offset; - - // We are going to launch channels * height_col * width_col kernels, each - // kernel responsible for copying a single-channel grid. - int num_kernels = height_col * width_col * input_channel; - const int CUDA_NUM_BLOCKS = - (num_kernels + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; - var_im2col_gpu_kernel< - float><<>>( - num_kernels, - data_im, - height, - width, - kernel_h, - kernel_w, - ((stride_h - 1) * height + kernel_h - 1) / 2, - ((stride_w - 1) * width + kernel_w - 1) / 2, - stride_h, - stride_w, - height_col, - width_col, - data_col); + conv_param_.x = const_cast(param.X); + conv_param_.var_length = true; + + conv_param_.paddings.reset(new std::vector); + conv_param_.paddings->push_back(static_cast(param.kernel_h / 2)); + conv_param_.paddings->push_back(static_cast(param.kernel_h / 2)); + conv_param_.paddings->push_back(static_cast(param.kernel_w / 2)); + conv_param_.paddings->push_back(static_cast(param.kernel_w / 2)); + conv_param_.dilations.reset(new std::vector); + conv_param_.dilations->push_back(1); + conv_param_.dilations->push_back(1); + conv_param_.strides[0] = param.stride_h; + conv_param_.strides[1] = param.stride_w; + conv_param_.filter = const_cast(param.W); + conv_param_.filter->Resize({param.output_channel, + param.input_channel, + param.kernel_h, + param.kernel_w}); + + conv_param_.output = param.Out; + std::vector output_shape( + {conv_param_.x->dims()[0], param.output_channel}); + for (size_t i = 0; i < conv_param_.strides.size(); ++i) { + output_shape.push_back( + ConvOutputSize(conv_param_.x->dims()[i + 2], + conv_param_.filter->dims()[i + 2], + (*conv_param_.dilations.get())[i], + (*conv_param_.paddings.get())[i * 2], + (*conv_param_.paddings.get())[i * 2 + 1], + conv_param_.strides[i])); } + conv_param_.output->Resize({output_shape}); + conv_impl_.reset(new lite::cuda::math::CudnnConv2D); + conv_impl_->init(conv_param_, &context); } void VarConv2DCompute::Run() { + auto& context = this->ctx_->template As(); + auto stream = context.exec_stream(); auto& param = 
this->Param(); - auto& ctx = this->ctx_->template As(); - auto stream = ctx.exec_stream(); - - auto* bottom = param.X; - // auto* in_row = param.ROW; - // auto* in_col = param.COLUMN; - auto* w = param.W; - auto* top = param.Out; - auto* col = param.Col; - int output_channel = param.output_channel; - int input_channel = param.input_channel; - int kernel_h = param.kernel_h; - int kernel_w = param.kernel_w; - int stride_h = param.stride_h; - int stride_w = param.stride_w; - - var_im2col(stream); - - int batch = bottom->lod()[0].size() - 1; - const auto& col_offset = col->lod()[0]; - // const auto& offset_x = in_col->lod()[0]; - // const auto& offset_y = in_row->lod()[0]; - const auto& offset_y = param.X->lod()[1]; - const auto& offset_x = param.X->lod()[2]; - std::vector top_offset; - std::vector height_vector; - std::vector width_vector; - int top_size = 0; - top_offset.push_back(top_size); - for (int b = 0; b < batch; ++b) { - int width = offset_x[b + 1] - offset_x[b]; - int height = offset_y[b + 1] - offset_y[b]; - int top_im_x = 0; - if (width == 0) { - top_im_x = 0; - } else { - top_im_x = (width - 1) / stride_w + 1; - } - int top_im_y = 0; - if (height == 0) { - top_im_y = 0; - } else { - top_im_y = (height - 1) / stride_h + 1; - } - height_vector.push_back(top_im_y); - width_vector.push_back(top_im_x); - int top_im_size = top_im_y * top_im_x; - top_size += output_channel * top_im_size; - top_offset.push_back(top_size); - } - - LoD top_lod; - top_lod.push_back(top_offset); - top->set_lod(top_lod); - std::vector top_dims_vec{top_size}; - top_dims_vec.push_back(1); - top->Resize(top_dims_vec); - auto* top_data = top->mutable_data(TARGET(kCUDA)); - const auto* w_data = w->data(); - const auto* col_data = col->data(); - - std::unique_ptr> gemm_impl_; - for (int b = 0; b < batch; ++b) { - int top_im_size = (top_offset[b + 1] - top_offset[b]) / output_channel; - if (top_im_size == 0) { - continue; - } - float* out_data = top_data + top_offset[b]; - const float* in_data = col_data + col->lod()[0][b]; - gemm_impl_.reset(new lite::cuda::math::Gemm); - gemm_impl_->init(false, - false, - w->dims()[0], - height_vector[b] * width_vector[b], - input_channel * kernel_h * kernel_w, - &ctx); - gemm_impl_->run(1., 0., w_data, in_data, out_data, &ctx); + param.Out->set_lod(param.X->lod()); + std::vector output_shape( + {conv_param_.x->dims()[0], param.output_channel}); + for (size_t i = 0; i < conv_param_.strides.size(); ++i) { + output_shape.push_back( + ConvOutputSize(conv_param_.x->dims()[i + 2], + conv_param_.filter->dims()[i + 2], + (*conv_param_.dilations.get())[i], + (*conv_param_.paddings.get())[i * 2], + (*conv_param_.paddings.get())[i * 2 + 1], + conv_param_.strides[i])); } - - cudaError_t error = cudaGetLastError(); - if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); + conv_param_.output->Resize({output_shape}); + conv_impl_->create(conv_param_, &context); + conv_impl_->run(conv_param_); } } // namespace cuda diff --git a/lite/kernels/cuda/var_conv_2d_compute.h b/lite/kernels/cuda/var_conv_2d_compute.h index e0b8e30c50..4bb61132db 100644 --- a/lite/kernels/cuda/var_conv_2d_compute.h +++ b/lite/kernels/cuda/var_conv_2d_compute.h @@ -13,6 +13,8 @@ // limitations under the License. 
#pragma once +#include +#include "lite/backends/cuda/math/cudnn_conv.h" #include "lite/core/kernel.h" namespace paddle { @@ -25,10 +27,12 @@ class VarConv2DCompute : public KernelLite { using param_t = operators::VarConv2DParam; void Run() override; + void PrepareForRun() override; virtual ~VarConv2DCompute() = default; private: - void var_im2col(const cudaStream_t& stream); + mutable operators::ConvParam conv_param_; + std::unique_ptr> conv_impl_; }; } // namespace cuda diff --git a/lite/kernels/x86/sequence_concat_compute.h b/lite/kernels/x86/sequence_concat_compute.h index 553e2e8b06..8dd7077f7d 100644 --- a/lite/kernels/x86/sequence_concat_compute.h +++ b/lite/kernels/x86/sequence_concat_compute.h @@ -52,7 +52,29 @@ class SequenceConcatCompute void Run() override { auto& param = *param_.get_mutable(); - // auto& param = Param(); + + int64_t batch_size = 0; + int64_t feature_size = 0; + std::vector out_dims; + for (const auto& tensor : param.X) { + const auto x_dims = tensor->dims(); + if (out_dims.empty()) { + out_dims = x_dims.Vectorize(); + } + batch_size += x_dims[0]; + if (feature_size == 0) { + feature_size = x_dims.production() / x_dims[0]; + } else { + CHECK_EQ(feature_size, x_dims.production() / x_dims[0]) + << "Inputs of sequence concat must have same feature size"; + } + } + if (batch_size < 0) { + batch_size = -1; // Normalize batch size for compile time. + } + out_dims[0] = batch_size; + param.Out->Resize(out_dims); + T* dout = param.Out->mutable_data(); std::vector x_in_order; diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index aade54c0e5..dc010b96c0 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -286,6 +286,8 @@ struct ConvParam { std::string data_format{"Anylayout"}; // for activation ActivationParam activation_param; + // support var_length or not + bool var_length{false}; // for int8 WITH_INT8_CONFIG }; diff --git a/lite/operators/sequence_concat_op.cc b/lite/operators/sequence_concat_op.cc index 2a54df890c..88afe5e00f 100644 --- a/lite/operators/sequence_concat_op.cc +++ b/lite/operators/sequence_concat_op.cc @@ -23,47 +23,10 @@ bool SequenceConcatOp::CheckShape() const { CHECK_GT(param_.X.size(), 1) << "The number of input sequences is at least two."; CHECK_OR_FALSE(param_.Out); - size_t lod_size = 0; - for (const auto &t : param_.X) { - CHECK_EQ(t->lod().empty(), false) - << "Input Tensor of X does not contain LoD information."; - // CHECK_EQ(t->lod().size(), 1) << "Only support one level sequence now."; - if (lod_size == 0) { - lod_size = t->lod()[0].size(); - } else { - CHECK_EQ(t->lod()[0].size(), lod_size) - << "The number of sequence must be same between each input"; - } - } - CHECK_NE(lod_size, 0) << "Each input must have sequence information"; return true; } -bool SequenceConcatOp::InferShape() const { - int64_t batch_size = 0; - int64_t feature_size = 0; - std::vector out_dims; - for (const auto &tensor : param_.X) { - const auto x_dims = tensor->dims(); - if (out_dims.empty()) { - out_dims = x_dims.Vectorize(); - } - batch_size += x_dims[0]; - if (feature_size == 0) { - feature_size = x_dims.production() / x_dims[0]; - } else { - CHECK_EQ(feature_size, x_dims.production() / x_dims[0]) - << "Inputs of sequence concat must have same feature size"; - } - } - if (batch_size < 0) { - batch_size = -1; // Normalize batch size for compile time. - } - out_dims[0] = batch_size; - param_.Out->Resize(out_dims); - // LoD info will be computed in Kernel. 
- return true; -} +bool SequenceConcatOp::InferShape() const { return true; } bool SequenceConcatOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { diff --git a/lite/operators/var_conv_2d_op.cc b/lite/operators/var_conv_2d_op.cc index 5c7fe374fc..d87c871a32 100644 --- a/lite/operators/var_conv_2d_op.cc +++ b/lite/operators/var_conv_2d_op.cc @@ -19,28 +19,7 @@ namespace paddle { namespace lite { namespace operators { -bool VarConv2dOp::CheckShape() const { - auto x_dims = param_.X->dims(); - CHECK_EQ(x_dims.size(), 2) << "The rank of X(Input) can't be less than 2."; - auto w_dims = param_.W->dims(); - CHECK_EQ(w_dims.size(), 2) << "W should be 2-D tensor"; - CHECK_EQ(w_dims[0], param_.output_channel) - << "W dim[0] should be equal to OutputChannel"; - CHECK_EQ(w_dims[1], param_.input_channel * param_.kernel_h * param_.kernel_w) - << "W dim[1] should be equal to InputChannel * KernelH * KernelW"; - LoD x_lod = param_.X->lod(); - CHECK_EQ(x_lod.empty(), false) << "The Input(X) must hold lod info."; - // CHECK_GE(x_lod.size(), 1) << "The Input(X)'s lod info is corrupted."; - CHECK_GE(x_lod.size(), 3) << "The Input(X)'s lod info is corrupted."; - CHECK_EQ(x_dims[0], static_cast(x_lod[0].back())) - << "The Input(X)'s lod info mismatches the actual tensor shape."; - // LoD row_lod = param_.ROW->lod(); - // CHECK_EQ(row_lod.empty(), false) << "The Input(ROW) must hold lod info."; - // LoD col_lod = param_.COLUMN->lod(); - // CHECK_EQ(col_lod.empty(), false) << "The Input(COLUMN) must hold lod - // info."; - return true; -} +bool VarConv2dOp::CheckShape() const { return true; } bool VarConv2dOp::InferShape() const { return true; } -- GitLab
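Note on the new match_matrix kernel: the layout transform performed by the `padding_out` kernel added in `match_matrix_tensor_compute.cu` can be expressed on the host as a plain loop. The sketch below is for illustration only and is not part of the patch (the `padding_out_ref` name and the vector-based signature are mine); it mirrors the kernel's index arithmetic: the packed GEMM output of shape [total_len_r, dim_t * len_l] is scattered into a dense [seq_num, dim_t * len_l, max_len_r] tensor, zero-filling positions past each right-hand sequence's actual length.

```cpp
// Host-side reference of the padding_out indexing (illustrative only).
#include <vector>

void padding_out_ref(const float* src,
                     const std::vector<int>& offset,  // LoD offsets of the right input, size seq_num + 1
                     int max_len_r,                   // longest right-hand sequence
                     int tl,                          // dim_t * len_l
                     std::vector<float>* dst) {
  const int seq_num = static_cast<int>(offset.size()) - 1;
  const int count = seq_num * tl * max_len_r;
  dst->assign(count, 0.f);  // positions past a sequence's real length stay zero
  for (int tid = 0; tid < count; ++tid) {
    // dst is row-major over [seq_num, tl, max_len_r], same as the CUDA kernel.
    const int seq_id = tid / (tl * max_len_r);
    const int tl_id = (tid / max_len_r) % tl;
    const int r_id = tid % max_len_r;
    const int cur_len = offset[seq_id + 1] - offset[seq_id];
    if (r_id < cur_len) {
      // src holds one row of tl values per word of the concatenated
      // right sequences, i.e. shape [offset.back(), tl].
      (*dst)[tid] = src[(offset[seq_id] + r_id) * tl + tl_id];
    }
  }
}
```

The design point of this change is that the old per-batch, per-dim_t GEMM loop is replaced by a single batched GEMM over the packed sequences followed by this padding step, so downstream kernels operate on one dense [batch, dim_t, len_l, max_len_r] tensor instead of variable-length slices.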