提交 3cd9c13b 编写于 作者: H hjchen2

Refine: pass input's lod info to output for quantize and dequantize op

上级 ccaf1b12
......@@ -30,8 +30,8 @@ bool DequantizeKernel<CPU, float>::Init(DequantizeParam<CPU> *param) {
template <>
void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
const Tensor *input = param.input_;
Tensor *output = param.output_;
const LoDTensor *input = param.input_;
LoDTensor *output = param.output_;
float activation_scale = param.activation_scale_->data<float>()[0];
float weight_scale = param.weight_scale_;
const int32_t *x = input->data<const int32_t>();
......@@ -72,6 +72,7 @@ void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
for (size_t i = 0; i < size; ++i) {
y[i] = x[i] * scale;
}
output->set_lod(input->lod());
}
} // namespace operators
......
......@@ -186,8 +186,8 @@ bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) {
template <>
void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
const Tensor *input = param.input_;
Tensor *output = param.output_;
const LoDTensor *input = param.input_;
LoDTensor *output = param.output_;
Tensor *output_scale = param.online_scale_;
float max_abs = 0.f;
if (param.offline_) {
......@@ -212,6 +212,7 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
LOG(kLOG_ERROR) << "round type is not supported.";
break;
}
output->set_lod(input->lod());
}
} // namespace operators
......
......@@ -69,10 +69,10 @@ class LoDTensor2BatchFunctor {
auto lods = lod_tensor.lod();
PADDLE_MOBILE_ENFORCE((lods.size() == 1UL),
"Only support one level sequence now.");
"Only support 1 level sequence, but %d is given",
lods.size());
const auto& lod = lods[0];
std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
......
......@@ -2589,9 +2589,9 @@ class QuantizeParam : public OpParam {
public:
// op input
RType *input_;
GType *input_;
// op output
RType *output_;
GType *output_;
RType *online_scale_;
// quantize offline scale
RType *offline_scale_;
......@@ -2625,9 +2625,9 @@ class DequantizeParam : public OpParam {
public:
// op input
RType *input_;
GType *input_;
// op output
RType *output_;
GType *output_;
RType *activation_scale_;
float weight_scale_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册