// fake_dequantize_op.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fake_dequantize_op.h"
#include <string>

namespace paddle {
namespace operators {

// CPU specialization of DequantizeFunctor.
// Writes out = (scale[0] / max_range) * in, element by element, using the
// flattened (1-D) Eigen views of the input and output tensors.
template <typename T>
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
  // dev_ctx   : CPU context that supplies the Eigen device to evaluate on.
  // in        : low-precision input tensor; only read.
  // scale     : tensor holding the quantization scale; only element [0] is
  //             read.
  // max_range : the max range used in the quantization stage.
  // out       : destination tensor. The functor does not allocate it itself,
  //             so it presumably must already be allocated by the caller
  //             (TODO confirm in the kernel).
  void operator()(const platform::CPUDeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor* scale,
                  T max_range, framework::Tensor* out) {
    auto src = framework::EigenVector<T>::Flatten(*in);
    const T* scale_ptr = scale->data<T>();
    auto dst = framework::EigenVector<T>::Flatten(*out);

    // A single scalar multiplier applies to every element.
    const T multiplier = scale_ptr[0] / max_range;
    auto& eigen_device = *dev_ctx.eigen_device();
    dst.device(eigen_device) = multiplier * src;
  }
};

// Instantiate the CPU functor for the element types the CPU kernel uses.
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;

class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
 public:
40 41 42 43
  FakeDequantizeMaxAbsOp(const std::string& type,
                         const framework::VariableNameMap& inputs,
                         const framework::VariableNameMap& outputs,
                         const framework::AttributeMap& attrs)
44 45
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

46
  void InferShape(framework::InferShapeContext* ctx) const override {
47 48 49 50
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of FakeDequantizeMaxAbsOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
51 52

    ctx->ShareDim("X", /*->*/ "Out");
53 54 55 56 57 58 59 60 61 62
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

// Declares the inputs, outputs, attributes and user-facing documentation of
// the fake_dequantize_max_abs operator.
class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input with float-32/64 type is the "
             "low precision tensor.");
    // Scale is read as a tensor (the CPU functor calls scale->data<T>() and
    // uses element [0]), so it is described as a Tensor rather than a float.
    AddInput("Scale", "(Tensor) The scale in quantization stage.");
    AddOutput("Out",
              "(Tensor) The output is the dequantized high "
              "precision tensor.");
    AddAttr<float>("max_range", "(float) The max range in quantization stage.");
    // The underscore in max_range is escaped so the LaTeX renders it as a
    // name instead of a subscript.
    AddComment(R"DOC(
FakeDequantizeMaxAbsOp operator.

This calculation is an opposite operation of FakeQuantizeMaxAbsOp:

$$Out = \frac{scale*X}{ max\_range }$$

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

// Short aliases used by the registration macros below.
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

// Register fake_dequantize_max_abs with its operator class (shape inference)
// and proto maker (inputs/outputs/attributes/docs). EmptyGradOpMaker is
// passed as the grad maker, which presumably means no gradient op is
// generated for it (confirm against the framework).
REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
                  ops::FakeDequantizeMaxAbsOpMaker,
                  paddle::framework::EmptyGradOpMaker);
// CPU kernels for float and double element types. These match the
// DequantizeFunctor instantiations above.
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
                       ops::FakeDequantizeMaxAbsKernel<CPU, float>,
                       ops::FakeDequantizeMaxAbsKernel<CPU, double>);