/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/platform/transform.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/hostdevice.h" #include "paddle/phi/kernels/cast_kernel.h" namespace paddle { namespace operators { template struct DequantizeFunctor { void operator()(const DeviceContext& dev_ctx, const phi::DenseTensor* in, const phi::DenseTensor* scale, T max_range, phi::DenseTensor* out); }; template struct ChannelDequantizeFunctorV2 { void operator()(const DeviceContext& dev_ctx, const phi::DenseTensor* in, const phi::DenseTensor** scales, const int scale_num, T max_range, const int quant_axis, phi::DenseTensor* out); }; template class QuantizeLinearKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* in_scale = context.Input("Scale"); auto* out = context.Output("Y"); out->mutable_data(context.GetPlace()); int bit_length = context.Attr("bit_length"); int round_type = context.Attr("round_type"); int bin_cnt = std::pow(2, bit_length - 1) - 1; int quant_axis = context.Attr("quant_axis"); bool is_test = context.Attr("is_test"); auto& dev_ctx = context.template device_context(); if (quant_axis < 0) { if (!is_test) { // training auto* in_accum = context.Input("InAccum"); auto* in_state = context.Input("InState"); phi::DenseTensor tmp_scale; tmp_scale.Resize(phi::make_dim(1)); T* cur_scale_data = dev_ctx.template Alloc(&tmp_scale); FindAbsMaxFunctor()( dev_ctx, in->data(), in->numel(), cur_scale_data); auto* out_state = context.Output("OutState"); auto* out_accum = context.Output("OutAccum"); auto* out_scale = context.Output("OutScale"); out_state->mutable_data(context.GetPlace()); out_accum->mutable_data(context.GetPlace()); out_scale->mutable_data(context.GetPlace()); float moving_rate = context.Attr("moving_rate"); FindMovingAverageAbsMaxFunctor()(dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state, out_accum, out_scale); ClipAndFakeQuantFunctor()( dev_ctx, *in, *out_scale, bin_cnt, round_type, out); } else { ClipAndFakeQuantFunctor()( dev_ctx, *in, *in_scale, bin_cnt, round_type, out); } } else { if (!is_test) { auto* out_scale = context.Output("OutScale"); T* out_scale_data = out_scale->mutable_data(context.GetPlace()); FindChannelAbsMaxFunctor()( dev_ctx, *in, quant_axis, out_scale_data); ChannelClipAndFakeQuantFunctor()( dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); } else { ChannelClipAndFakeQuantFunctor()( dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out); } } } }; template class DeQuantizeLinearKernel : public framework::OpKernel { public: template void ComputeImpl(const framework::ExecutionContext& context) const { auto& dev_ctx = context.template device_context(); auto* in = context.Input("X"); auto in_tmp = phi::Cast( static_cast::TYPE&>(dev_ctx), *in, experimental::CppTypeToDataType::Type()); auto* scale = context.Input("Scale"); auto* out = context.Output("Y"); int bit_length = context.Attr("bit_length"); auto quant_axis = context.Attr("quant_axis"); dev_ctx.template Alloc(out, out->numel() * sizeof(D)); if (quant_axis < 0) { float max_range = (std::pow(2, bit_length - 1) - 1); DequantizeFunctor()( dev_ctx, &in_tmp, scale, static_cast(max_range), out); } else { PADDLE_ENFORCE_EQ( scale->numel(), in_tmp.dims()[quant_axis], platform::errors::PreconditionNotMet( "The number of first scale values must be the same with " "quant_axis dimension value of Input(X) when the `scale` has " "only one element, but %ld != %ld here.", scale->numel(), in_tmp.dims()[quant_axis])); int max_range = (std::pow(2, bit_length - 1) - 1); ChannelDequantizeFunctorV2()( dev_ctx, &in_tmp, scale, static_cast(max_range), quant_axis, out); } } void Compute(const framework::ExecutionContext& context) const override { auto* scale = context.Input("Scale"); switch (scale->dtype()) { case experimental::DataType::FLOAT64: ComputeImpl(context); break; case experimental::DataType::FLOAT32: ComputeImpl(context); break; case experimental::DataType::FLOAT16: ComputeImpl(context); break; default: PADDLE_THROW(platform::errors::Unimplemented( "In DeQuantizeLinearKernel, " "data type %d for scale/output is not supported ", scale->dtype())); break; } } }; } // namespace operators } // namespace paddle