/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/dequantize_abs_max_op.h"
#include <string>

// Forward declarations of the framework types this file refers to:
// InferShapeContext (DequantizeMaxAbsOp::InferShape) and OpDesc / OpBase /
// EmptyGradOpMaker (the REGISTER_OPERATOR invocation).
namespace paddle {
namespace framework {
class InferShapeContext;
class OpDesc;
template <typename T>
class EmptyGradOpMaker;
}  // namespace framework
namespace imperative {
class OpBase;
}  // namespace imperative
}  // namespace paddle

namespace paddle {
namespace operators {

template <typename T>
L
Leo Chen 已提交
36 37
struct DequantizeFunctor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext& dev_ctx,
38 39 40 41
                  const framework::Tensor* in,
                  const framework::Tensor* scale,
                  float max_range,
                  framework::Tensor* out) {
42 43 44 45 46 47 48 49 50 51
    const float* scale_factor = scale->data<float>();
    const T* input_data = in->data<T>();
    float* output_data = out->mutable_data<float>(dev_ctx.GetPlace());
    int ind = in->numel();
    for (size_t i = 0; i < (unsigned)ind; i++) {
      output_data[i] = scale_factor[0] * input_data[i] / max_range;
    }
  }
};

L
Leo Chen 已提交
52 53
template struct DequantizeFunctor<phi::CPUContext, int8_t>;
template struct DequantizeFunctor<phi::CPUContext, int16_t>;
54 55 56 57 58 59 60 61 62 63

// Operator definition for `dequantize_abs_max`: checks that its inputs and
// output are present, propagates shape and LoD from X to Out, and selects
// the kernel by the data type of X.
class DequantizeMaxAbsOp : public framework::OperatorWithKernel {
 public:
  DequantizeMaxAbsOp(const std::string& type,
                     const framework::VariableNameMap& inputs,
                     const framework::VariableNameMap& outputs,
                     const framework::AttributeMap& attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  // Out takes the same dims and LoD as X.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DequantizeMaxAbs");
    // DequantizeFunctor dereferences the Scale tensor unconditionally, so a
    // missing Scale is reported here with a clear message instead of
    // failing later inside the kernel.
    OP_INOUT_CHECK(
        ctx->HasInput("Scale"), "Input", "Scale", "DequantizeMaxAbs");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "DequantizeMaxAbs");

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

  // The kernel is keyed on X's data type (the registered kernels are
  // templated on int8_t / int16_t), on the current device context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

// Declares the proto of `dequantize_abs_max`: integer input X, input Scale,
// float32 output Out, and the float attribute `max_range`.
class DequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Int Tensor) The input with int8/16 type is the "
             "low precision tensor.");
    // Only the first element of Scale is read by the CPU functor.
    AddInput("Scale", "(float) The scale in quantization stage.");
    AddOutput("Out",
              "(float32 Tensor) The output is the dequantized high "
              "precision tensor.");
    // Divisor in Out = Scale * X / max_range.
    AddAttr<float>("max_range", "(float) The max range in quantization stage.");
    AddComment(R"DOC(
DequantizeMaxAbsOp operator.

This calculation is an opposite operation of QuantizeMaxAbsOp:

$$Out = \frac{scale*X}{ max\_range }$$

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
L
Leo Chen 已提交
105
using CPU = phi::CPUContext;
106 107

REGISTER_OPERATOR(
108 109 110
    dequantize_abs_max,
    ops::DequantizeMaxAbsOp,
    ops::DequantizeMaxAbsOpMaker,
111 112 113
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(dequantize_abs_max,
114 115
                       ops::DequantizeMaxAbsKernel<CPU, int8_t>,
                       ops::DequantizeMaxAbsKernel<CPU, int16_t>);