From fd91b8281480233cb814443a49095ea4046db0a5 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Thu, 20 Dec 2018 16:06:07 +0800 Subject: [PATCH] Refine: pass input's lod info to output for quantize and dequantize op --- src/operators/kernel/arm/dequantize_kernel.cpp | 5 +++-- src/operators/kernel/arm/quantize_kernel.cpp | 5 +++-- src/operators/math/sequence2batch.h | 4 ++-- src/operators/op_param.h | 8 ++++---- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/operators/kernel/arm/dequantize_kernel.cpp b/src/operators/kernel/arm/dequantize_kernel.cpp index 2c13cac1a6..c9dfd6f936 100644 --- a/src/operators/kernel/arm/dequantize_kernel.cpp +++ b/src/operators/kernel/arm/dequantize_kernel.cpp @@ -30,8 +30,8 @@ bool DequantizeKernel<CPU, float>::Init(DequantizeParam<CPU> *param) { template <> void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) { - const Tensor *input = param.input_; - Tensor *output = param.output_; + const LoDTensor *input = param.input_; + LoDTensor *output = param.output_; float activation_scale = param.activation_scale_->data<float>()[0]; float weight_scale = param.weight_scale_; const int32_t *x = input->data<int32_t>(); @@ -72,6 +72,7 @@ void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) { for (size_t i = 0; i < size; ++i) { y[i] = x[i] * scale; } + output->set_lod(input->lod()); } } // namespace operators diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index 03f4ac81fb..97ffa05c86 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -186,8 +186,8 @@ bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) { template <> void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) { - const Tensor *input = param.input_; - Tensor *output = param.output_; + const LoDTensor *input = param.input_; + LoDTensor *output = param.output_; Tensor *output_scale = param.online_scale_; float max_abs = 0.f; if (param.offline_) { @@ -212,6 +212,7 @@ void 
QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) { LOG(kLOG_ERROR) << "round type is not supported."; break; } + output->set_lod(input->lod()); } } // namespace operators diff --git a/src/operators/math/sequence2batch.h b/src/operators/math/sequence2batch.h index 42b369f7dc..537f2326d0 100644 --- a/src/operators/math/sequence2batch.h +++ b/src/operators/math/sequence2batch.h @@ -69,10 +69,10 @@ class LoDTensor2BatchFunctor { auto lods = lod_tensor.lod(); PADDLE_MOBILE_ENFORCE((lods.size() == 1UL), - "Only support one level sequence now."); + "Only support 1 level sequence, but %d is given", + lods.size()); const auto& lod = lods[0]; - std::vector<SeqInfo> seq_info; for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { int length = lod[seq_id + 1] - lod[seq_id]; diff --git a/src/operators/op_param.h b/src/operators/op_param.h index adc725e25f..643bd65ee0 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -2589,9 +2589,9 @@ class QuantizeParam : public OpParam<Dtype> { public: // op input - RType *input_; + GType *input_; // op output - RType *output_; + GType *output_; RType *online_scale_; // quantize offline scale RType *offline_scale_; @@ -2625,9 +2625,9 @@ class DequantizeParam : public OpParam<Dtype> { public: // op input - RType *input_; + GType *input_; // op output - RType *output_; + GType *output_; RType *activation_scale_; float weight_scale_; }; -- GitLab