/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cmath>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/cast_kernel.h"

namespace paddle {
namespace operators {

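// Dequantizes `in` with a single per-tensor scale:
// out = in * scale / max_range.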
template <typename DeviceContext, typename T>
struct DequantizeFunctor {
  void operator()(const DeviceContext& dev_ctx,
                  const phi::DenseTensor* in,
                  const phi::DenseTensor* scale,
                  T max_range,
                  phi::DenseTensor* out);
};

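// Dequantizes `in` per channel along `quant_axis`, using one scale value per
// channel: out[c] = in[c] * scale[c] / max_range.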
template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctorV2 {
  void operator()(const DeviceContext& dev_ctx,
                  const phi::DenseTensor* in,
                  const phi::DenseTensor* scale,
                  T max_range,
                  const int quant_axis,
                  phi::DenseTensor* out);
};

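// Quantizes Input(X) into Output(Y) with `bit_length` bits. When quant_axis
// is negative a single per-tensor scale is used (maintained as a moving
// average of the abs-max during training); otherwise scales are computed per
// channel along `quant_axis`.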
template <typename DeviceContext, typename T>
class QuantizeLinearKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<phi::DenseTensor>("X");
    auto* in_scale = context.Input<phi::DenseTensor>("Scale");

    auto* out = context.Output<phi::DenseTensor>("Y");
    out->mutable_data<T>(context.GetPlace());
    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    int quant_axis = context.Attr<int>("quant_axis");
    bool is_test = context.Attr<bool>("is_test");
    auto& dev_ctx = context.template device_context<DeviceContext>();

    if (quant_axis < 0) {
      if (!is_test) {
        // Training: update the moving-average abs-max scale.
        auto* in_accum = context.Input<phi::DenseTensor>("InAccum");
        auto* in_state = context.Input<phi::DenseTensor>("InState");
        phi::DenseTensor tmp_scale;
        tmp_scale.Resize(phi::make_dim(1));
        T* cur_scale_data = dev_ctx.template Alloc<T>(&tmp_scale);

        FindAbsMaxFunctor<DeviceContext, T>()(
            dev_ctx, in->data<T>(), in->numel(), cur_scale_data);

        auto* out_state = context.Output<phi::DenseTensor>("OutState");
        auto* out_accum = context.Output<phi::DenseTensor>("OutAccum");
        auto* out_scale = context.Output<phi::DenseTensor>("OutScale");
        out_state->mutable_data<T>(context.GetPlace());
        out_accum->mutable_data<T>(context.GetPlace());
        out_scale->mutable_data<T>(context.GetPlace());
        float moving_rate = context.Attr<float>("moving_rate");

        FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
                                                           *in_accum,
                                                           *in_state,
                                                           cur_scale_data,
                                                           moving_rate,
                                                           out_state,
                                                           out_accum,
                                                           out_scale);
        ClipAndFakeQuantFunctor<DeviceContext, T>()(
            dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
      } else {
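        // Inference: quantize with the pre-computed Scale input.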
        ClipAndFakeQuantFunctor<DeviceContext, T>()(
            dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
      }
    } else {
      if (!is_test) {
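        // Training: compute per-channel abs-max scales along quant_axis.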
        auto* out_scale = context.Output<phi::DenseTensor>("OutScale");
        T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
        FindChannelAbsMaxFunctor<DeviceContext, T>()(
            dev_ctx, *in, quant_axis, out_scale_data);
        ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
            dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
      } else {
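        // Inference: quantize with the pre-computed per-channel Scale input.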
        ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
            dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out);
      }
    }
  }
};

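// Dequantizes Input(X) into Output(Y). X is first cast to the dtype D of the
// Scale input; the result is then dequantized per tensor (quant_axis < 0) or
// per channel along `quant_axis`.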
template <typename DeviceContext, typename T>
class DeQuantizeLinearKernel : public framework::OpKernel<T> {
 public:
  template <typename D>
  void ComputeImpl(const framework::ExecutionContext& context) const {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto* in = context.Input<phi::DenseTensor>("X");

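    // Cast X to the scale dtype D so the arithmetic and Output(Y) match the
    // dtype of Scale.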
    auto in_tmp = phi::Cast<T>(
        static_cast<const typename paddle::framework::ConvertToPhiContext<
            DeviceContext>::TYPE&>(dev_ctx),
        *in,
        experimental::CppTypeToDataType<D>::Type());

    auto* scale = context.Input<phi::DenseTensor>("Scale");
    auto* out = context.Output<phi::DenseTensor>("Y");
    int bit_length = context.Attr<int>("bit_length");
    auto quant_axis = context.Attr<int>("quant_axis");
    dev_ctx.template Alloc<D>(out, out->numel() * sizeof(D));

    if (quant_axis < 0) {
      float max_range = (std::pow(2, bit_length - 1) - 1);
      DequantizeFunctor<DeviceContext, D>()(
          dev_ctx, &in_tmp, scale, static_cast<D>(max_range), out);
    } else {
      PADDLE_ENFORCE_EQ(
          scale->numel(),
          in_tmp.dims()[quant_axis],
          platform::errors::PreconditionNotMet(
              "The number of scale values must be equal to the quant_axis "
              "dimension of Input(X), but %ld != %ld here.",
              scale->numel(),
              in_tmp.dims()[quant_axis]));
      int max_range = (std::pow(2, bit_length - 1) - 1);

      ChannelDequantizeFunctorV2<DeviceContext, D>()(
          dev_ctx, &in_tmp, scale, static_cast<D>(max_range), quant_axis, out);
    }
  }

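  // Dispatch on the dtype of Scale; Output(Y) is allocated with the same
  // dtype.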
  void Compute(const framework::ExecutionContext& context) const override {
    auto* scale = context.Input<phi::DenseTensor>("Scale");
    switch (scale->dtype()) {
      case experimental::DataType::FLOAT64:
        ComputeImpl<double>(context);
        break;
      case experimental::DataType::FLOAT32:
        ComputeImpl<float>(context);
        break;
      case experimental::DataType::FLOAT16:
        ComputeImpl<paddle::platform::float16>(context);
        break;
      default:
        PADDLE_THROW(platform::errors::Unimplemented(
            "In DeQuantizeLinearKernel, "
            "data type %d for scale/output is not supported ",
            scale->dtype()));
        break;
    }
  }
};

}  // namespace operators
}  // namespace paddle