From a105bbdf796b3a4b1ce29283c3bebbe838f89063 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Fri, 24 Jul 2020 12:38:32 +0800
Subject: [PATCH] [CUDA] Support model run correctly. (#3975)

---
 lite/backends/cuda/math/gru_forward.h       | 10 +++++-
 lite/backends/cuda/math/scale.cu            | 15 +++------
 lite/backends/cuda/math/scale.h             |  3 +-
 lite/backends/cuda/math/sequence2batch.cu   |  4 +--
 lite/backends/cuda/math/sequence2batch.h    | 20 ++++++------
 lite/backends/cuda/math/sequence_padding.cu |  6 ++--
 lite/kernels/cuda/assign_value_compute.cu   |  2 +-
 lite/kernels/cuda/dropout_compute.cc        |  5 ++-
 lite/kernels/cuda/gru_compute.cu            | 11 +++----
 lite/kernels/cuda/scale_compute.cc          |  7 +++--
 lite/kernels/cuda/sequence_mask_compute.cu  | 13 +++++---
 lite/kernels/cuda/sequence_pad_compute.cu   | 14 +++++++--
 lite/kernels/cuda/sequence_unpad_compute.cu | 34 ++++++++++++++++++++-
 lite/kernels/cuda/sequence_unpad_compute.h  |  1 +
 lite/kernels/cuda/var_conv_2d_compute.cu    |  5 +++
 lite/operators/gru_op.cc                    |  5 ++-
 lite/operators/sequence_pad_op.cc           | 13 ++++----
 lite/operators/sequence_unpad_op.cc         | 27 +---------------
 lite/operators/var_conv_2d_op.cc            | 16 ++++++----
 19 files changed, 121 insertions(+), 90 deletions(-)

diff --git a/lite/backends/cuda/math/gru_forward.h b/lite/backends/cuda/math/gru_forward.h
index 22d4ae3a9d..3a1648c437 100644
--- a/lite/backends/cuda/math/gru_forward.h
+++ b/lite/backends/cuda/math/gru_forward.h
@@ -30,9 +30,16 @@ namespace lite {
 namespace cuda {
 namespace math {
 
+#define SIGMOID_THRESHOLD_MIN -40.0
+#define SIGMOID_THRESHOLD_MAX 13.0
+#define EXP_MAX_INPUT 40.0
+
 template <typename Dtype>
 inline __device__ Dtype Sigmoid(const Dtype a) {
-  return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
+  const Dtype min = SIGMOID_THRESHOLD_MIN;
+  const Dtype max = SIGMOID_THRESHOLD_MAX;
+  Dtype tmp = (a < min) ? min : ((a > max) ? max : a);
+  return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-tmp));
 }
 
 template <>
@@ -63,6 +70,7 @@ inline __device__ half ReLU(const half a) {
 template <typename Dtype>
 inline __device__ Dtype Tanh(const Dtype a) {
   Dtype tmp = static_cast<Dtype>(-2.0) * a;
+  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
   return (static_cast<Dtype>(2.0) / (static_cast<Dtype>(1.0) + expf(tmp))) -
          static_cast<Dtype>(1.0);
 }
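Note on the gru_forward.h hunk: the new clamps bound the argument that reaches expf. The thresholds match the SIGMOID_THRESHOLD_MIN/MAX and EXP_MAX_INPUT constants used by Paddle's fluid activation functions on other backends, which keeps the CUDA GRU numerically consistent with them. A minimal host-side sketch of the clamped activations (standalone; fminf/fmaxf stand in for the ternaries above):

    #include <cmath>

    // Same thresholds as the macros in the hunk above.
    float clamped_sigmoid(float a) {
      float t = fminf(fmaxf(a, -40.0f), 13.0f);  // SIGMOID_THRESHOLD_MIN/MAX
      return 1.0f / (1.0f + expf(-t));
    }

    float clamped_tanh(float a) {
      float t = fminf(-2.0f * a, 40.0f);  // cap expf argument at EXP_MAX_INPUT
      return 2.0f / (1.0f + expf(t)) - 1.0f;
    }

Sigmoid saturates extremely fast outside [-40, 13], so the clamp has a negligible effect on results while keeping expf inside a well-behaved input range.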
diff --git a/lite/backends/cuda/math/scale.cu b/lite/backends/cuda/math/scale.cu
index 806a3697a2..f9d5209c3e 100644
--- a/lite/backends/cuda/math/scale.cu
+++ b/lite/backends/cuda/math/scale.cu
@@ -22,10 +22,6 @@ namespace lite {
 namespace cuda {
 namespace math {
 
-#define CUDA_KERNEL_LOOP(i, n)                                 \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
-       i += blockDim.x * gridDim.x)
-
 template <typename T>
 __global__ void scale_kernel(int count,
                              const T* in_data,
@@ -48,7 +44,6 @@ __global__ void scale_kernel(
 template <typename T>
 __global__ void scale_kernel(
     int count, const T* in_data, T* out_data, const T scale, const T bias) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
   CUDA_KERNEL_LOOP(tid, count) { out_data[tid] = scale * in_data[tid] + bias; }
 }
 
@@ -133,12 +128,11 @@ void fp32_scale_nhwc(int num,
 }
 
 template <typename T>
-void scale(int num, const T* in, T* out, T scale, cudaStream_t stream, T bias) {
+void scale(int num, const T* in, T* out, T scale, T bias, cudaStream_t stream) {
   int thread = 256;
   int block = (num + thread - 1) / thread;
   scale_kernel<<<block, thread, 0, stream>>>(num, in, out, scale, bias);
-  cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
+  CUDA_POST_KERNEL_CHECK;
 }
 
 template <typename T>
@@ -146,11 +140,10 @@ void scale(int num, const T* in, T* out, T scale, T bias) {
   int thread = 256;
   int block = (num + thread - 1) / thread;
   scale_kernel<<<block, thread>>>(num, in, out, scale, bias);
-  cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
+  CUDA_POST_KERNEL_CHECK;
 }
 
-template void scale(int num, const float*, float*, float, cudaStream_t, float);
+template void scale(int num, const float*, float*, float, float, cudaStream_t);
 template void scale(int num, const float*, float*, float, float);
 
 }  // namespace math
diff --git a/lite/backends/cuda/math/scale.h b/lite/backends/cuda/math/scale.h
index 52ed1d38ae..b9961b12c3 100644
--- a/lite/backends/cuda/math/scale.h
+++ b/lite/backends/cuda/math/scale.h
@@ -32,8 +32,7 @@ void fp32_scale_nhwc(int num,
                      cudaStream_t stream);
 
 template <typename T>
-void scale(
-    int num, const T* in, T* out, T scale, cudaStream_t stream, T bias = 0);
+void scale(int num, const T* in, T* out, T scale, T bias, cudaStream_t stream);
 
 template <typename T>
 void scale(int num, const T* in, T* out, T scale, T bias = 0);
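Note on the scale.cu/scale.h hunks: the file-local CUDA_KERNEL_LOOP macro and the hand-rolled cudaGetLastError prints are dropped in favor of shared definitions (presumably from lite/backends/cuda/cuda_utils.h), and the dead tid variable goes away because the loop macro declares its own index. Moving bias before the stream also removes a trap: a bare literal 0 is a valid cudaStream_t (it is a pointer type), so a trailing stream-with-defaulted-bias signature invites ambiguous or silently wrong calls. A self-contained sketch of the pattern (names mine):

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void scale_sketch(int n, const float* in, float* out,
                                 float scale, float bias) {
      // Grid-stride loop, which is what CUDA_KERNEL_LOOP expands to: a
      // fixed-size grid covers any n because each thread advances by the
      // total number of threads in the grid.
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += blockDim.x * gridDim.x) {
        out[i] = scale * in[i] + bias;
      }
    }

    void launch_scale(int n, const float* in, float* out, float scale,
                      float bias, cudaStream_t stream) {
      int thread = 256;
      int block = (n + thread - 1) / thread;
      scale_sketch<<<block, thread, 0, stream>>>(n, in, out, scale, bias);
      // Roughly what CUDA_POST_KERNEL_CHECK amounts to (assumption): surface
      // launch errors immediately instead of printing to stdout and moving on.
      cudaError_t err = cudaGetLastError();
      if (err != cudaSuccess) {
        fprintf(stderr, "launch failed: %s\n", cudaGetErrorString(err));
      }
    }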
diff --git a/lite/backends/cuda/math/sequence2batch.cu b/lite/backends/cuda/math/sequence2batch.cu
index edc2a1ec08..9a93362b3b 100644
--- a/lite/backends/cuda/math/sequence2batch.cu
+++ b/lite/backends/cuda/math/sequence2batch.cu
@@ -32,7 +32,7 @@ __global__ void CopyMatrixRowsKernel(const T* src,
                                      bool is_src_index) {
   int idx = threadIdx.x;
   int idy = threadIdx.y;
-  int row_id = blockDim.y * gridDim.x + idy;
+  int row_id = blockDim.y * blockIdx.x + idy;
   if (row_id < height) {
     int src_idx = is_src_index ? index[row_id] : row_id;
     int dst_idx = is_src_index ? row_id : index[row_id];
@@ -72,7 +72,7 @@ void CopyMatrixRowsFunctor<T>::operator()(
   dim3 threads(128, 8);
   dim3 grids((height + threads.y - 1) / threads.y);
   CopyMatrixRowsKernel<T><<<grids, threads, 0, stream>>>(
-      src_data, dst_data, index_tensor_data, height, width, true);
+      src_data, dst_data, index_tensor_data, height, width, is_src_index);
   CUDA_POST_KERNEL_CHECK;
 }
 
diff --git a/lite/backends/cuda/math/sequence2batch.h b/lite/backends/cuda/math/sequence2batch.h
index 0df9600307..e5a12ed0b4 100644
--- a/lite/backends/cuda/math/sequence2batch.h
+++ b/lite/backends/cuda/math/sequence2batch.h
@@ -53,11 +53,11 @@ class LoDTensor2BatchFunctor {
   // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
   // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
   struct SeqInfo {
-    SeqInfo(size_t start, size_t length, size_t seq_idx)
-        : start_(start), length_(length), seq_idx_(seq_idx) {}
-    size_t start_;
-    size_t length_;
-    size_t seq_idx_;
+    SeqInfo(size_t start_val, size_t len_val, size_t seq_val)
+        : start(start_val), length(len_val), seq_idx(seq_val) {}
+    size_t start;
+    size_t length;
+    size_t seq_idx;
   };
 
  public:
@@ -76,7 +76,7 @@ class LoDTensor2BatchFunctor {
     }
 
     std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) {
-      return a.length_ > b.length_;
+      return a.length > b.length;
     });
 
     // Calculate the start position of each batch.
@@ -106,7 +106,7 @@ class LoDTensor2BatchFunctor {
     batch_lods.emplace_back(std::vector<uint64_t>{0});
 
     // batch_lods[0] is the start positions for batch LoDTensor
-    size_t max_seqlen = seq_info[0].length_;
+    size_t max_seqlen = seq_info[0].length;
     batch_lods[0].resize(max_seqlen + 1);
     // batch_lods[1] is the raw index in the input LoDTensor
     batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
@@ -119,8 +119,8 @@ class LoDTensor2BatchFunctor {
     for (size_t n = 0; n < max_seqlen; ++n) {
       size_t batch_id = batch_starts[n];
       for (size_t i = 0; i < seq_info.size(); ++i) {
-        size_t seq_len = seq_info[i].length_;
-        size_t start = seq_info[i].start_;
+        size_t seq_len = seq_info[i].length;
+        size_t start = seq_info[i].start;
         if (n < seq_len) {
           seq2batch_idx[batch_id] =
               is_reverse ? start + seq_len - 1 - n : start + n;
@@ -133,7 +133,7 @@ class LoDTensor2BatchFunctor {
     }
     auto* seq_order = batch_lods[2].data();
     for (size_t i = 0; i < seq_info.size(); ++i) {
-      seq_order[i] = seq_info[i].seq_idx_;
+      seq_order[i] = seq_info[i].seq_idx;
    }
     batch_tensor->set_lod(batch_lods);
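Note on the sequence2batch hunks: both .cu changes are real bug fixes. blockDim.y * gridDim.x + idy evaluates to the same row for every block (gridDim.x is constant across blocks), so only one window of rows was ever copied; blockDim.y * blockIdx.x + idy gives each block its own window. Hard-coding true as the last argument also forced the sequence-to-batch direction, breaking the batch-to-sequence copy. For the reordering itself, a host-side sketch of the index construction the header comment describes (assumed semantics: sort sequences by length descending, then batch n gathers the n-th timestep of every sequence that is at least n+1 long):

    #include <algorithm>
    #include <array>
    #include <cstdio>
    #include <vector>

    int main() {
      // {start, length, seq_idx} for the header's example: s0=4, s1=5, s2=3.
      std::vector<std::array<size_t, 3>> seq_info = {
          {0, 4, 0}, {4, 5, 1}, {9, 3, 2}};
      std::sort(seq_info.begin(), seq_info.end(),
                [](const auto& a, const auto& b) { return a[1] > b[1]; });
      std::vector<size_t> seq2batch_idx;
      for (size_t n = 0; n < seq_info[0][1]; ++n) {  // n = timestep
        for (const auto& s : seq_info) {
          if (n < s[1]) seq2batch_idx.push_back(s[0] + n);  // forward order
        }
      }
      for (size_t v : seq2batch_idx) printf("%zu ", v);
      printf("\n");  // prints: 4 0 9 5 1 10 6 2 11 7 3 8
    }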
diff --git a/lite/backends/cuda/math/sequence_padding.cu b/lite/backends/cuda/math/sequence_padding.cu
index 3a32be2a34..e4f194b9c2 100644
--- a/lite/backends/cuda/math/sequence_padding.cu
+++ b/lite/backends/cuda/math/sequence_padding.cu
@@ -86,8 +86,7 @@ void SequencePadding(T* pad_data,
                      seq_num,
                      pad_seq_len,
                      step_width);
-  cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
+  CUDA_POST_KERNEL_CHECK;
 }
 
 template <typename T>
@@ -120,8 +119,7 @@ void SequenceUnpadding(T* seq_data,
                        seq_num,
                        pad_seq_len,
                        step_width);
-  cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
+  CUDA_POST_KERNEL_CHECK;
 }
 
 template void SequencePadding(float* pad_data,
diff --git a/lite/kernels/cuda/assign_value_compute.cu b/lite/kernels/cuda/assign_value_compute.cu
index 89f2937f10..6a2740101c 100644
--- a/lite/kernels/cuda/assign_value_compute.cu
+++ b/lite/kernels/cuda/assign_value_compute.cu
@@ -68,7 +68,7 @@ void AssignValueCompute::Run() {
 
 REGISTER_LITE_KERNEL(assign_value,
                      kCUDA,
-                     kAny,
+                     kFloat,
                      kNCHW,
                      paddle::lite::kernels::cuda::AssignValueCompute,
                      def)
diff --git a/lite/kernels/cuda/dropout_compute.cc b/lite/kernels/cuda/dropout_compute.cc
index 7e3a3a6243..f9303a39ce 100644
--- a/lite/kernels/cuda/dropout_compute.cc
+++ b/lite/kernels/cuda/dropout_compute.cc
@@ -23,6 +23,9 @@ namespace cuda {
 
 void DropoutCompute::Run() {
   auto& param = Param<operators::DropoutParam>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
   const float* x_data = param.x->data<float>();
   float* out_data = param.output->mutable_data<float>(TARGET(kCUDA));
   int num = param.x->dims().production();
@@ -31,7 +34,7 @@ void DropoutCompute::Run() {
   if (param.dropout_implementation == "downgrade_in_infer") {
     scale = 1.0f - prob_data;
   }
-  lite::cuda::math::scale(num, x_data, out_data, scale, 0);
+  lite::cuda::math::scale(num, x_data, out_data, scale, 0.f, stream);
 }
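Note on the dropout hunk: besides routing the kernel onto the kernel's execution stream, this is a reminder that inference-time dropout is a pure elementwise scale; no masking happens at inference. The rule the kernel applies (sketch, illustrative name):

    #include <string>

    float dropout_infer_scale(const std::string& impl, float prob) {
      // "downgrade_in_infer": training keeps activations unscaled, so
      // inference multiplies by the keep probability (1 - prob).
      // "upscale_in_train": training already divides by (1 - prob), so
      // inference is the identity and the scale stays 1.
      return impl == "downgrade_in_infer" ? 1.0f - prob : 1.0f;
    }

The 0.f literal in the call also matters: with the new scale<T> signature, a plain 0 would deduce T as int from the bias argument and conflict with float deduced from the pointers.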
+#include "lite/kernels/cuda/gru_compute.h" + #include #include "lite/backends/cuda/cuda_utils.h" @@ -19,7 +21,6 @@ #include "lite/backends/cuda/math/sequence2batch.h" #include "lite/backends/cuda/target_wrapper.h" #include "lite/core/op_registry.h" -#include "lite/kernels/cuda/gru_compute.h" namespace paddle { namespace lite { @@ -133,7 +134,6 @@ struct GRUUnitFunctor { value.gate_value, context); } - CUDA_POST_KERNEL_CHECK; lite::cuda::math::GruForwardResetOutput< T><<exec_stream()>>>( @@ -143,7 +143,7 @@ struct GRUUnitFunctor { frame_size, batch_size, active_gate, - batch_size == 1); + batch_size != 1); CUDA_POST_KERNEL_CHECK; if (value.prev_out_value) { @@ -163,7 +163,6 @@ struct GRUUnitFunctor { value.gate_value + frame_size * 2, context); } - CUDA_POST_KERNEL_CHECK; lite::cuda::math::GruForwardFinalOutput< T><<exec_stream()>>>(value.gate_value, @@ -173,7 +172,7 @@ struct GRUUnitFunctor { batch_size, active_node, origin_mode, - batch_size == 1); + batch_size != 1); CUDA_POST_KERNEL_CHECK; } }; @@ -218,7 +217,6 @@ struct GRUUnitFunctor { value.gate_value, context); } - CUDA_POST_KERNEL_CHECK; lite::cuda::math::GruForwardResetOutput< half><<exec_stream()>>>( @@ -248,7 +246,6 @@ struct GRUUnitFunctor { value.gate_value + frame_size * 2, context); } - CUDA_POST_KERNEL_CHECK; lite::cuda::math::GruForwardFinalOutput< half><<exec_stream()>>>( diff --git a/lite/kernels/cuda/scale_compute.cc b/lite/kernels/cuda/scale_compute.cc index 6bf7414d8c..9ce5905a7d 100644 --- a/lite/kernels/cuda/scale_compute.cc +++ b/lite/kernels/cuda/scale_compute.cc @@ -23,8 +23,11 @@ namespace cuda { void ScaleCompute::Run() { auto& param = Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + const float* x_data = param.x->data(); - float* output_data = param.output->mutable_data(); + float* output_data = param.output->mutable_data(TARGET(kCUDA)); DDim x_dims = param.x->dims(); bool bias_after_scale = param.bias_after_scale; float scale = param.scale; @@ -33,7 +36,7 @@ void ScaleCompute::Run() { bias *= scale; } lite::cuda::math::scale( - x_dims.production(), x_data, output_data, scale, bias); + x_dims.production(), x_data, output_data, scale, bias, stream); } } // namespace cuda diff --git a/lite/kernels/cuda/sequence_mask_compute.cu b/lite/kernels/cuda/sequence_mask_compute.cu index 8a8f292c10..8e227a6a27 100644 --- a/lite/kernels/cuda/sequence_mask_compute.cu +++ b/lite/kernels/cuda/sequence_mask_compute.cu @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "lite/kernels/cuda/sequence_mask_compute.h" - #include +#include #include #include "lite/backends/cuda/cuda_utils.h" #include "lite/core/op_registry.h" +#include "lite/kernels/cuda/sequence_mask_compute.h" namespace paddle { namespace lite { @@ -44,7 +44,7 @@ void SequenceMaskCompute::Run() { auto stream = ctx.exec_stream(); const auto* x = param.X; - auto* x_data = x->template data(); + const int64_t* x_data = x->template data(); auto* y = param.Y; int maxlen = param.maxlen; @@ -57,8 +57,11 @@ void SequenceMaskCompute::Run() { } if (maxlen < 0) { - maxlen = thrust::reduce( - x_data, x_data + x->numel(), 0, thrust::maximum()); + maxlen = static_cast( + thrust::reduce(thrust::device_pointer_cast(x_data), + thrust::device_pointer_cast(x_data) + x->numel(), + static_cast(0), + thrust::maximum())); } auto y_dim = x->dims().Vectorize(); diff --git a/lite/kernels/cuda/sequence_pad_compute.cu b/lite/kernels/cuda/sequence_pad_compute.cu index 1e304f0063..8368eb3007 100644 --- a/lite/kernels/cuda/sequence_pad_compute.cu +++ b/lite/kernels/cuda/sequence_pad_compute.cu @@ -32,9 +32,19 @@ void SequencePadCompute::Run() { const auto* pad_value = param.PadValue; auto* out = param.Out; auto* len_t = param.Length; - int padded_length = param.padded_length; - int seq_num = x->lod()[0].size() - 1; + int padded_length; + if (param.padded_length == -1) { + int max_seq_len = 0; + for (int i = 0; i < seq_num; ++i) { + max_seq_len = std::max( + max_seq_len, static_cast(x->lod()[0][i + 1] - x->lod()[0][i])); + } + padded_length = max_seq_len; + } else { + padded_length = param.padded_length; + } + int max_seq_len = 0; int step_width = x->numel() / x->dims()[0]; diff --git a/lite/kernels/cuda/sequence_unpad_compute.cu b/lite/kernels/cuda/sequence_unpad_compute.cu index bdedd74588..b4274e19a8 100644 --- a/lite/kernels/cuda/sequence_unpad_compute.cu +++ b/lite/kernels/cuda/sequence_unpad_compute.cu @@ -13,6 +13,7 @@ // limitations under the License. 
diff --git a/lite/kernels/cuda/sequence_unpad_compute.cu b/lite/kernels/cuda/sequence_unpad_compute.cu
index bdedd74588..b4274e19a8 100644
--- a/lite/kernels/cuda/sequence_unpad_compute.cu
+++ b/lite/kernels/cuda/sequence_unpad_compute.cu
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include <vector>
+
 #include "lite/backends/cuda/math/sequence_padding.h"
 #include "lite/core/op_registry.h"
 #include "lite/core/target_wrapper.h"
@@ -29,8 +30,39 @@ void SequenceUnpadCompute<T, Ptype>::Run() {
   auto& ctx = this->ctx_->template As<CUDAContext>();
   auto stream = ctx.exec_stream();
 
+  auto x_dims = param.X->dims();
+  auto len_dims = param.Length->dims();
+
+  auto* seq_len_ptr = param.Length->template data<int64_t>();
+  seq_len_cpu_.Resize(param.Length->dims());
+  TargetWrapperCuda::MemcpyAsync(seq_len_cpu_.mutable_data<int64_t>(),
+                                 seq_len_ptr,
+                                 sizeof(int64_t) * param.Length->numel(),
+                                 IoDirection::DtoH,
+                                 stream);
+  TargetWrapperCuda::StreamSync(stream);
+
+  int64_t batch_size = len_dims[0];
+  std::vector<uint64_t> out_lod0(batch_size + 1, 0);
+  for (int64_t i = 0; i < batch_size; ++i) {
+    out_lod0[i + 1] = out_lod0[i] + seq_len_cpu_.data<int64_t>()[i];
+  }
+  paddle::lite::LoD out_lod;
+  out_lod.push_back(out_lod0);
+
+  int64_t out_dim0 = out_lod0.back();
+  std::vector<int64_t> out_dims{out_dim0};
+  if (x_dims.size() == 2) {
+    out_dims.push_back(1);
+  } else {
+    for (size_t i = 2; i < x_dims.size(); ++i) {
+      out_dims.push_back(x_dims[i]);
+    }
+  }
+  param.Out->Resize(out_dims);
+  param.Out->set_lod(out_lod);
+
   const auto* pad_tensor = param.X;
-  const auto* len_t = param.Length;
   auto* seq_tensor = param.Out;
 
   int padded_length = pad_tensor->dims()[1];
diff --git a/lite/kernels/cuda/sequence_unpad_compute.h b/lite/kernels/cuda/sequence_unpad_compute.h
index f36520ea15..6b077a4dcb 100644
--- a/lite/kernels/cuda/sequence_unpad_compute.h
+++ b/lite/kernels/cuda/sequence_unpad_compute.h
@@ -31,6 +31,7 @@ class SequenceUnpadCompute : public KernelLite<TARGET(kCUDA), Ptype> {
 
  private:
  lite::Tensor seq_offsets_;
+  lite::Tensor seq_len_cpu_;
   std::vector<size_t> seq_offsets_vec_;
 };
diff --git a/lite/kernels/cuda/var_conv_2d_compute.cu b/lite/kernels/cuda/var_conv_2d_compute.cu
index b847069879..b14073e5e1 100644
--- a/lite/kernels/cuda/var_conv_2d_compute.cu
+++ b/lite/kernels/cuda/var_conv_2d_compute.cu
@@ -184,6 +184,8 @@ using VarConvFp16 =
 REGISTER_LITE_KERNEL(var_conv_2d, kCUDA, kFloat, kNCHW, VarConvFp32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("COLUMN", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("ROW", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .BindOutput("Col", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
@@ -191,6 +193,9 @@ REGISTER_LITE_KERNEL(var_conv_2d, kCUDA, kFP16, kNCHW, VarConvFp16, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
     .BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
+    .BindInput("COLUMN",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
+    .BindInput("ROW", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
     .BindOutput("Col", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
     .Finalize();
diff --git a/lite/operators/gru_op.cc b/lite/operators/gru_op.cc
index 862a1ff98f..0a9128dcd2 100644
--- a/lite/operators/gru_op.cc
+++ b/lite/operators/gru_op.cc
@@ -75,9 +75,8 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   auto batch_reset_hidden_prev = op_desc.Output("BatchResetHiddenPrev").front();
   auto batch_hidden = op_desc.Output("BatchHidden").front();
   auto hidden = op_desc.Output("Hidden").front();
-
   param_.input = scope->FindVar(input)->GetMutable<lite::Tensor>();
-  if (op_desc.Input("H0").size()) {
+  if (!op_desc.Input("H0").empty()) {
     auto h0 = op_desc.Input("H0").front();
     param_.h0 = scope->FindVar(h0)->GetMutable<lite::Tensor>();
   }
@@ -90,7 +89,7 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
       scope->FindVar(batch_hidden)->GetMutable<lite::Tensor>();
   param_.hidden = scope->FindVar(hidden)->GetMutable<lite::Tensor>();
 
-  if (op_desc.HasInput("Bias")) {
+  if (!op_desc.Input("Bias").empty()) {
     auto bias = op_desc.Input("Bias").front();
     param_.bias = scope->FindVar(bias)->GetMutable<lite::Tensor>();
   }
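Note on the sequence_unpad hunks above: Out's shape and LoD depend on the values in Length, which live on the GPU and only exist at run time, so that work moved out of the op's InferShapeImpl (removed in the sequence_unpad_op.cc hunk below) and into the kernel's Run(), behind an explicit DtoH copy and a stream sync. The persistent seq_len_cpu_ member avoids reallocating the host staging buffer every run. The LoD itself is just the exclusive prefix sum of the lengths (sketch, illustrative name):

    #include <cstdint>
    #include <vector>

    std::vector<uint64_t> lod_from_lengths(const std::vector<int64_t>& lengths) {
      std::vector<uint64_t> lod0(lengths.size() + 1, 0);
      for (size_t i = 0; i < lengths.size(); ++i) {
        lod0[i + 1] = lod0[i] + static_cast<uint64_t>(lengths[i]);
      }
      return lod0;  // e.g. lengths {3, 1, 2} -> offsets {0, 3, 4, 6}
    }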
diff --git a/lite/operators/sequence_pad_op.cc b/lite/operators/sequence_pad_op.cc
index 687c4a1989..858c0ffcbb 100644
--- a/lite/operators/sequence_pad_op.cc
+++ b/lite/operators/sequence_pad_op.cc
@@ -61,18 +61,19 @@ bool SequencePadOp::InferShapeImpl() const {
     max_seq_len = std::max(max_seq_len,
                            static_cast<int>(x_lod_0[i + 1] - x_lod_0[i]));
   }
-  if (param_.padded_length == -1) {
-    param_.padded_length = max_seq_len;
+  int real_padded_length = param_.padded_length;
+  if (real_padded_length == -1) {
+    real_padded_length = max_seq_len;
   }
-  CHECK_GE(param_.padded_length, max_seq_len)
+  CHECK_GE(real_padded_length, max_seq_len)
       << "The SequencePadOp Attr(padded_length) should be greater than or "
         "equal to the length of the longest original sequence. But the "
         "padded_length we received is "
-      << param_.padded_length
+      << real_padded_length
      << ", the length of the longest original sequence is " << max_seq_len;
   int out_dim_0 = seq_num;
-  std::vector<int64_t> out_dims_vec{out_dim_0, param_.padded_length};
+  std::vector<int64_t> out_dims_vec{out_dim_0, real_padded_length};
   std::vector<int64_t> len_dims_vec{out_dim_0};
   auto time_step_dims_vec = time_step_dims.Vectorize();
   out_dims_vec.insert(
@@ -87,7 +88,7 @@ bool SequencePadOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
       &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
   param_.PadValue = const_cast<lite::Tensor *>(
       &scope->FindVar(opdesc.Input("PadValue").front())->Get<lite::Tensor>());
-  param_.Length = scope->FindVar(opdesc.Input("Length").front())
+  param_.Length = scope->FindVar(opdesc.Output("Length").front())
                       ->GetMutable<lite::Tensor>();
   param_.Out =
       scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
diff --git a/lite/operators/sequence_unpad_op.cc b/lite/operators/sequence_unpad_op.cc
index b91d43c741..4f4497f0b8 100644
--- a/lite/operators/sequence_unpad_op.cc
+++ b/lite/operators/sequence_unpad_op.cc
@@ -32,32 +32,7 @@ bool SequenceUnpadOp::CheckShape() const {
   return true;
 }
 
-bool SequenceUnpadOp::InferShapeImpl() const {
-  auto x_dims = param_.X->dims();
-  auto len_dims = param_.Length->dims();
-
-  auto *seq_len_ptr = param_.Length->data<int64_t>();
-  int64_t batch_size = len_dims[0];
-  std::vector<uint64_t> out_lod0(batch_size + 1, 0);
-  for (int64_t i = 0; i < batch_size; ++i) {
-    out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i];
-  }
-  paddle::lite::LoD out_lod;
-  out_lod.push_back(out_lod0);
-
-  int64_t out_dim0 = out_lod0.back();
-  std::vector<int64_t> out_dims{out_dim0};
-  if (x_dims.size() == 2) {
-    out_dims.push_back(1);
-  } else {
-    for (size_t i = 2; i < x_dims.size(); ++i) {
-      out_dims.push_back(x_dims[i]);
-    }
-  }
-  param_.Out->Resize(out_dims);
-  param_.Out->set_lod(out_lod);
-  return true;
-}
+bool SequenceUnpadOp::InferShapeImpl() const { return true; }
 
 bool SequenceUnpadOp::AttachImpl(const cpp::OpDesc &opdesc,
                                  lite::Scope *scope) {
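Note on the operator hunks above: SequencePadOp::InferShapeImpl used to write the resolved length back into param_.padded_length, destroying the -1 "pad to this batch's longest sequence" sentinel after the first run; every later batch was then padded (or CHECK-failed) against the first batch's maximum. Keeping the resolution in a local, and re-deriving it inside the CUDA kernel (see the sequence_pad_compute.cu hunk earlier), makes repeated runs with different-length batches behave. The AttachImpl fix is simpler: Length is an output of sequence_pad, so looking it up under Input("Length") dereferenced a slot that does not exist. The distilled rule (illustrative):

    // -1 means "pad to this batch's longest sequence"; resolve it per run and
    // never write the result back into the op attribute.
    int resolve_padded_length(int attr_padded_length, int batch_max_seq_len) {
      return attr_padded_length == -1 ? batch_max_seq_len : attr_padded_length;
    }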
bool VarConv2dOp::InferShapeImpl() const { return true; }
 
 bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.X = const_cast<lite::Tensor *>(
       &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
-  // param_.ROW = const_cast<lite::Tensor *>(
-  //     &scope->FindVar(opdesc.Input("ROW").front())->Get<lite::Tensor>());
-  // param_.COLUMN = const_cast<lite::Tensor *>(
-  //     &scope->FindVar(opdesc.Input("COLUMN").front())->Get<lite::Tensor>());
+  if (opdesc.HasInput("ROW") && !opdesc.Input("ROW").empty()) {
+    param_.ROW = const_cast<lite::Tensor *>(
+        &scope->FindVar(opdesc.Input("ROW").front())->Get<lite::Tensor>());
+    CHECK(param_.ROW) << "Input(ROW) of VarConv2dOP should not be null.";
+  }
+  if (opdesc.HasInput("COLUMN") && !opdesc.Input("COLUMN").empty()) {
+    param_.COLUMN = const_cast<lite::Tensor *>(
+        &scope->FindVar(opdesc.Input("COLUMN").front())->Get<lite::Tensor>());
+    CHECK(param_.COLUMN) << "Input(COLUMN) of VarConv2dOP should not be null.";
+  }
   param_.W = const_cast<lite::Tensor *>(
       &scope->FindVar(opdesc.Input("W").front())->Get<lite::Tensor>());
   param_.Out =
@@ -37,8 +43,6 @@ bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.Col =
       scope->FindVar(opdesc.Output("Col").front())->GetMutable<lite::Tensor>();
   CHECK(param_.X) << "X(Input) of VarConv2dOP should not be null.";
-  // CHECK(param_.ROW) << "Input(ROW) of VarConv2dOP should not be null.";
-  // CHECK(param_.COLUMN) << "Input(COLUMN) of VarConv2dOP should not be null.";
   CHECK(param_.W) << "W(Input) of VarConv2dOP should not be null.";
   CHECK(param_.Out) << "Out(Output) of VarConv2dOP should not be null.";
   CHECK(param_.Col) << "Col(Output) of VarConv2dOP should not be null.";
-- 
GitLab