/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

// Device-dispatched dequantization functor.
//
// Only the primary template is declared here; operator() has no definition in
// this header, so each DeviceContext must supply a specialization elsewhere
// (presumably the CPU/CUDA op sources — confirm there). Judging by the
// parameters, it writes into `out` the values of `in` rescaled by the
// quantization `scale` tensor and the integer range `max_range`; the exact
// formula lives in the specializations and is not visible here.
template <typename DeviceContext, typename T>
struct DequantizeFunctor {
  void operator()(const DeviceContext& dev_ctx,
                  const framework::Tensor* in,
                  const framework::Tensor* scale,
                  T max_range,
                  framework::Tensor* out);
};

// Kernel that dequantizes Input(X) with a single scale tensor Input(Scale)
// and the float attribute `max_range`, writing the result to Output(Out).
// The element-wise work is delegated to DequantizeFunctor<DeviceContext, T>.
template <typename DeviceContext, typename T>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* x = ctx.Input<framework::Tensor>("X");
    const auto* x_scale = ctx.Input<framework::Tensor>("Scale");
    auto* y = ctx.Output<framework::Tensor>("Out");
    const T range = static_cast<T>(ctx.Attr<float>("max_range"));

    // Allocate Out's buffer on the place owned by this kernel's device.
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    y->mutable_data<T>(dev_ctx.GetPlace());

    DequantizeFunctor<DeviceContext, T> dequantize;
    dequantize(dev_ctx, x, x_scale, range, y);
  }
};

// Kernel that dequantizes Input(X) channel by channel into Output(Out).
//
// Input(Scales) holds one or two tensors:
//   * one tensor:  scales[0] has one value per index of dim 0 of X; each slice
//                  X[i] is dequantized with scales[0][i] and the range derived
//                  from quant_bits[0].
//   * two tensors: scales[0] has one value per index of dim 1 of X; each slice
//                  X[i][j] is dequantized with scales[0][j] and the range from
//                  quant_bits[0]. The whole result is then dequantized again,
//                  in place, with the single value in scales[1] and the range
//                  from quant_bits[1].
// All shape and attribute checks run before Out is written, so a validation
// failure never leaves Out partially filled.
template <typename DeviceContext, typename T>
class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto scales = ctx.MultiInput<framework::Tensor>("Scales");
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");

    // Only the one- and two-scale layouts are implemented; any other count
    // would leave Out allocated but never written.
    const size_t scale_num = scales.size();
    PADDLE_ENFORCE(scale_num == 1 || scale_num == 2,
                   "The number of Input(Scales) tensors must be 1 or 2, "
                   "but got %d.",
                   scale_num);
    // quant_bits[k] is read for every scales[k]; make sure it exists.
    PADDLE_ENFORCE_GE(quant_bits.size(), scale_num,
                      "Attr(quant_bits) must provide one bit width for each "
                      "tensor in Input(Scales).");

    const auto& in_dims = in->dims();
    if (scale_num == 1) {
      PADDLE_ENFORCE_EQ(
          scales[0]->numel(), in_dims[0],
          "The number of first scale values must be the same with "
          "first dimension value of Input(X) when the `Scales` has only one "
          "element.");
    } else {
      // The two-scale layout indexes dim 1 of X, so X needs rank >= 2.
      PADDLE_ENFORCE_GE(in_dims.size(), 2,
                        "Input(X) must have at least 2 dimensions when the "
                        "`Scales` has two elements.");
      PADDLE_ENFORCE_EQ(
          scales[0]->numel(), in_dims[1],
          "The number of first scale values must be the same with "
          "second dimension value of Input(X) when the `Scales` has two "
          "elements.");
      PADDLE_ENFORCE_EQ(
          scales[1]->numel(), 1,
          "The second scale tensor should only have one value at now.");
    }

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    out->mutable_data<T>(dev_ctx.GetPlace());

    auto dequant = DequantizeFunctor<DeviceContext, T>();
    const T first_range = static_cast<T>(QuantMaxRange(quant_bits[0]));

    if (scale_num == 1) {
      // One scale per slice along dim 0.
      for (int64_t i = 0; i < in_dims[0]; i++) {
        framework::Tensor one_channel_in = in->Slice(i, i + 1);
        framework::Tensor one_channel_out = out->Slice(i, i + 1);
        framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
        dequant(dev_ctx, &one_channel_in, &one_channel_scale, first_range,
                &one_channel_out);
      }
      return;
    }

    // Two scales: per-channel pass along dim 1 of every batch entry, then a
    // single in-place pass over the whole output with scales[1].
    // The per-batch shape drops dim 0; it is the same for every i.
    const auto in_batch_dims =
        framework::slice_ddim(in_dims, 1, in_dims.size());
    const auto out_batch_dims =
        framework::slice_ddim(out->dims(), 1, out->dims().size());
    for (int64_t i = 0; i < in_dims[0]; i++) {
      framework::Tensor one_batch_in =
          in->Slice(i, i + 1).Resize(in_batch_dims);
      framework::Tensor one_batch_out =
          out->Slice(i, i + 1).Resize(out_batch_dims);
      for (int64_t j = 0; j < in_dims[1]; j++) {
        framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
        framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
        framework::Tensor one_channel_scale = scales[0]->Slice(j, j + 1);
        dequant(dev_ctx, &one_channel_in, &one_channel_scale, first_range,
                &one_channel_out);
      }
    }

    const T second_range = static_cast<T>(QuantMaxRange(quant_bits[1]));
    dequant(dev_ctx, out, scales[1], second_range, out);
  }

 private:
  // Largest magnitude of a signed `bits`-bit integer: 2^(bits - 1) - 1
  // (e.g. 127 for 8 bits). Integer arithmetic avoids the double round-trip of
  // std::pow. The bound keeps the shift well-defined for a 32-bit int.
  static int QuantMaxRange(int bits) {
    PADDLE_ENFORCE(bits >= 1 && bits <= 31,
                   "Each value of Attr(quant_bits) must be in [1, 31], "
                   "but got %d.",
                   bits);
    return (1 << (bits - 1)) - 1;
  }
};

}  // namespace operators
}  // namespace paddle