fake_dequantize_op.cc 10.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fake_dequantize_op.h"
#include <string>
17
#include <vector>
18
#include "paddle/fluid/framework/op_version_registry.h"
19 20 21 22

namespace paddle {
namespace operators {

23 24 25 26 27 28 29 30 31 32
// CPU specialization of the per-tensor dequantization functor.
// Every element of `in` is multiplied by the first value stored in `scale`
// and then divided by `max_range`; the result is written to `out`.
template <typename T>
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor* scale,
                  T max_range, framework::Tensor* out) {
    // Flattened 1-D Eigen views over the input and output buffers.
    auto src = framework::EigenVector<T>::Flatten(*in);
    const T* scale_data = scale->data<T>();
    auto dst = framework::EigenVector<T>::Flatten(*out);

    auto& eigen_dev = *dev_ctx.eigen_device();
    // Only element 0 of the scale tensor is read.
    dst.device(eigen_dev) = src * scale_data[0] / max_range;
  }
};

37 38 39 40
// CPU specialization of the channel-wise dequantization functor.
//
// Parameters:
//   dev_ctx    - CPU device context; supplies the Eigen device and the place
//                used to obtain the output buffer.
//   in         - tensor to dequantize.
//   scales     - array of `scale_num` scale tensors.
//   scale_num  - number of entries in `scales`; 1 and 2 are handled.
//   max_range  - divisor applied to every element.
//   quant_axis - channel axis for the scale_num == 1 case; 0 and 1 are handled.
//   out        - destination tensor, indexed exactly like `in`.
//
// Behaviour:
//   * scale_num == 1: an element x in channel c along `quant_axis` becomes
//     x * scales[0][c] / max_range.
//   * scale_num == 2: `in` is indexed as [batch, channel, ...]; an element x in
//     channel c becomes x * scales[0][c] * scales[1][0] / max_range.
//   * any other `scale_num` (or a `quant_axis` other than 0/1 when
//     scale_num == 1) writes nothing to `out`.
template <typename T>
struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor** scales,
                  const int scale_num, T max_range, const int quant_axis,
                  framework::Tensor* out) {
    if (scale_num == 1) {
      // Dequant op is before quantized op
      // Dequantize the weight of quantized op
      const auto& in_dims = in->dims();
      const int64_t channel = in_dims[quant_axis];
      const T* scale_factor = scales[0]->data<T>();
      if (quant_axis == 0) {
        auto& dev = *dev_ctx.eigen_device();
        for (int64_t i = 0; i < channel; i++) {
          const T s = scale_factor[i];
          // Slice along the leading axis yields one channel at a time.
          framework::Tensor one_channel_in = in->Slice(i, i + 1);
          framework::Tensor one_channel_out = out->Slice(i, i + 1);
          auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
          auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
          out_e.device(dev) = in_e * s / max_range;
        }
      } else if (quant_axis == 1) {
        // Number of outer iterations: product of the dims before quant_axis.
        int64_t out_iter = 1;
        for (int d = 0; d < quant_axis; d++) {
          out_iter *= in_dims[d];
        }
        // Elements per channel slice: product of the dims after quant_axis.
        // The strides are built by multiplication rather than by dividing
        // numel() by out_iter / (out_iter * channel), so a tensor with a
        // zero-sized dimension no longer triggers an integer division by zero.
        int64_t step_j = 1;
        for (int d = quant_axis + 1; d < in_dims.size(); d++) {
          step_j *= in_dims[d];
        }
        const int64_t step_i = channel * step_j;
        auto* in_data = in->data<T>();
        auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
        for (int64_t i = 0; i < out_iter; i++) {
          for (int64_t j = 0; j < channel; j++) {
            auto* cur_in = in_data + i * step_i + j * step_j;
            auto* cur_out = out_data + i * step_i + j * step_j;
            const T s = scale_factor[j];
            for (int64_t k = 0; k < step_j; k++) {
              *cur_out = (*cur_in) * s / max_range;
              ++cur_in;
              ++cur_out;
            }
          }
        }
      }
    } else if (scale_num == 2) {
      // Dequant op is after quantized op
      // Dequantize the output tensor of quantized op
      // Dims are 64-bit; keep them as int64_t instead of narrowing to int.
      const int64_t batch_size = in->dims()[0];
      const int64_t channel = in->dims()[1];
      const T* scale_one = scales[0]->data<T>();
      const T* scale_two = scales[1]->data<T>();
      auto& dev = *dev_ctx.eigen_device();
      // Shape of a single batch entry (leading axis dropped); loop-invariant.
      const auto in_batch_dims =
          framework::slice_ddim(in->dims(), 1, in->dims().size());
      const auto out_batch_dims =
          framework::slice_ddim(out->dims(), 1, out->dims().size());
      for (int64_t i = 0; i < batch_size; i++) {
        framework::Tensor one_batch_in =
            in->Slice(i, i + 1).Resize(in_batch_dims);
        framework::Tensor one_batch_out =
            out->Slice(i, i + 1).Resize(out_batch_dims);
        for (int64_t j = 0; j < channel; j++) {
          const T s = scale_one[j];
          framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
          framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
          auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
          auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
          // Per-channel scale times the single value read from scales[1].
          out_e.device(dev) = in_e * s * scale_two[0] / max_range;
        }
      }
    }
  }
};

107 108
// Explicit instantiations of the CPU functors, grouped by element type.
// float
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, float>;
// double
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, double>;
111

112 113
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
 public:
114 115 116 117
  FakeDequantizeMaxAbsOp(const std::string& type,
                         const framework::VariableNameMap& inputs,
                         const framework::VariableNameMap& outputs,
                         const framework::AttributeMap& attrs)
118 119
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

120
  void InferShape(framework::InferShapeContext* ctx) const override {
121 122 123
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeDequantizeMaxAbs");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "FakeDequantizeMaxAbs");
124 125

    ctx->ShareDim("X", /*->*/ "Out");
126 127 128 129 130 131 132 133 134 135
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

// Declares the inputs, output, attribute and documentation of the
// FakeDequantizeMaxAbs operator.
class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Inputs.
    AddInput(
        "X",
        "(Tensor) The input with float-32/64 type is the low precision "
        "tensor.");
    // NOTE(review): the description says "(float)" while the CPU functor in
    // this file reads its scale through a Tensor -- confirm the wording.
    AddInput("Scale", "(float) The scale in quantization stage.");

    // Output.
    AddOutput(
        "Out",
        "(Tensor) The output is the dequantized high precision tensor.");

    // Attributes.
    AddAttr<float>("max_range", "(float) The max range in quantization stage.");

    AddComment(R"DOC(
FakeDequantizeMaxAbsOp operator.

This calculation is an opposite operation of FakeQuantizeMaxAbsOp:

$$Out = \frac{scale*X}{ max_range }$$

)DOC");
  }
};

Z
Zhen Wang 已提交
152 153 154 155 156
// Operator class registered below as `fake_channel_wise_dequantize_max_abs`.
class FakeChannelWiseDequantizeMaxAbsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires input "X", the "Scales" inputs and output "Out" to be present,
  // then gives "Out" the same dims and LoD as "X".
  void InferShape(framework::InferShapeContext* ctx) const override {
    const char* const op_name = "FakeChannelWiseDequantizeMaxAbs";
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", op_name);
    OP_INOUT_CHECK(ctx->HasInputs("Scales"), "Input", "Scales", op_name);
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", op_name);

    ctx->ShareDim("X", "Out");
    ctx->ShareLoD("X", "Out");
  }
};

// Declares the inputs, output, attributes and documentation of the
// FakeChannelWiseDequantizeMaxAbs operator.
class FakeChannelWiseDequantizeMaxAbsOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Inputs. "Scales" is duplicable, i.e. it may bind several tensors.
    AddInput(
        "X",
        "(Tensor) The input with float-32/64 type is the low precision "
        "tensor.");
    AddInput(
        "Scales",
        "(Tensors) The scales in quantization stage. Now, `Scales` is a "
        "vector with at most two tensors. If Scales has two elements, the "
        "second tensor should only have one value.")
        .AsDuplicable();

    // Output.
    AddOutput(
        "Out",
        "(Tensor) The output is the dequantized high precision tensor.");

    // Attributes.
    AddAttr<std::vector<int>>(
        "quant_bits",
        "Quantization bit numbers in quantization stage. The size of "
        "`quant_bits` should be equal to the size of `Scales`.")
        .SetDefault({8});
    AddAttr<int>(
        "quant_axis",
        "(int, default 0) The axis for quantization. For conv2d, "
        "depthwise_conv2d, conv2d_transpose and mul, the quant_axis is equal "
        "to the cout axis.")
        .SetDefault(0)
        // Reject any axis other than 0 or 1.
        .AddCustomChecker([](const int& quant_axis) {
          PADDLE_ENFORCE_EQ(
              quant_axis == 0 || quant_axis == 1, true,
              platform::errors::InvalidArgument(
                  "'quant_axis' should be 0 or 1, but the received is %d",
                  quant_axis));
        });

    AddComment(R"DOC(
FakeChannelWiseDequantizeMaxAbsOp operator.

This calculation is an opposite operation of FakeChannelWiseQuantizeMaxAbsOp:

$$Out_c = \frac{X_c\prod_{i=1}^{n}Scales_{ic}}{\prod_{i=1}^{n}(2^{quant\_bits_i-1}-1)}$$

In the above formula, the range value of $c$ can be represented as $0 \leq c \lt \ the\ channel\ number\ of\ X$.
Besides, the size of $quant\_bits$ should be equal to the size of $Scales$, and it is called $n$  in the formula.

Notes: In general, the per-channel quantization is only applied to weights and the activations use per-layer quantization.
)DOC");
  }
};

218 219 220 221 222 223
}  // namespace operators
}  // namespace paddle

// Short aliases used by the registration macros that follow in this file.
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

H
hong 已提交
224 225 226 227 228
// Registers the `fake_dequantize_max_abs` operator with its op class and
// proto maker. Both grad-op maker slots (OpDesc and imperative OpBase) use
// EmptyGradOpMaker -- presumably meaning no gradient op is generated; confirm
// against the EmptyGradOpMaker definition.
REGISTER_OPERATOR(
    fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
    ops::FakeDequantizeMaxAbsOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
229 230 231
// CPU kernels for `fake_dequantize_max_abs`, for float and double elements.
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
                       ops::FakeDequantizeMaxAbsKernel<CPU, float>,
                       ops::FakeDequantizeMaxAbsKernel<CPU, double>);
Z
Zhen Wang 已提交
232

H
hong 已提交
233 234 235 236 237 238
// Registers the `fake_channel_wise_dequantize_max_abs` operator with its op
// class and proto maker. Both grad-op maker slots (OpDesc and imperative
// OpBase) use EmptyGradOpMaker -- presumably meaning no gradient op is
// generated; confirm against the EmptyGradOpMaker definition.
REGISTER_OPERATOR(
    fake_channel_wise_dequantize_max_abs,
    ops::FakeChannelWiseDequantizeMaxAbsOp,
    ops::FakeChannelWiseDequantizeMaxAbsOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
Z
Zhen Wang 已提交
239 240 241
// CPU kernels for `fake_channel_wise_dequantize_max_abs`, for float and
// double elements.
REGISTER_OP_CPU_KERNEL(fake_channel_wise_dequantize_max_abs,
                       ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, float>,
                       ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, double>);
242 243 244 245 246 247 248

// Op-version checkpoint recording the addition of the `quant_axis` attribute
// (registered with default value 0).
//
// The note was previously a raw string literal wrapped around what was meant
// to be two adjacent string literals, so the stored text contained a literal
// quote, a line break, indentation and another quote in the middle of the
// sentence. Ordinary adjacent literals produce the intended single sentence.
// The op name is also spelled `conv2d_transpose`, matching the `quant_axis`
// attribute description.
REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs)
    .AddCheckpoint(
        "add new attributes [quant_axis] for applying per-channel "
        "dequantization to conv2d_transpose and mul ops.",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "quant_axis", "The axis for dequantization.", 0));