Unverified commit 27ce06aa authored by zhaoyingli, committed by GitHub

[AutoParallel] quantization pass support export (#48072)

* [AutoParallel] quantization pass support export

* support subgraph

* move_presist_var_to_global_block

* update unittest

* fix ci-coverage

* fix codestyle

* fix fake_dequantize_op

* remove unused var

* fix ci error and approval error

* add unittest for fp16 in test_dequant_linear

* replace mutable data

* fix unittest in non-cuda-core

* fix unittest
Co-authored-by: carryyu <569782149@qq.com>
Co-authored-by: wufeisheng <wfs1997@163.com>
Parent 522c2bc0
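For orientation, the export path added here is driven entirely by the auto-parallel strategy. The following is a minimal usage sketch distilled from the unit test added in this commit (generate_model and create_data_holder come from the test helper get_gpt_model and are not general Paddle APIs):

import paddle
from paddle.distributed.fleet import auto
from get_gpt_model import create_data_holder, generate_model  # test helper

paddle.enable_static()

# Enable QAT with ONNX-format export in the auto-parallel strategy.
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
qat = dist_strategy.qat
qat.enable = True
qat.channel_wise_abs_max = True
qat.weight_bits = 8
qat.activation_bits = 8
qat.not_quant_pattern = ['skip_quant']
qat.onnx_format = True  # QAT config field whose default is set to True by this commit

model, loss = generate_model("mp")
engine = auto.Engine(model, loss, strategy=dist_strategy)
inputs_spec, labels_spec = create_data_holder(batch_size=1)
engine.prepare(inputs_spec, labels_spec, mode="predict")

# With qat.enable and qat.onnx_format set, Engine.save() converts the
# distributed main program to an IrGraph, applies QuantWeightPass to every
# sub-graph, and then writes the inference model.
engine.save("./inf", training=False)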
......@@ -24,6 +24,22 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
struct DequantizeFunctor<phi::CPUContext, T> {
void operator()(const phi::CPUContext &dev_ctx,
const phi::DenseTensor *in,
const phi::DenseTensor *scale,
T max_range,
phi::DenseTensor *out) {
auto in_e = framework::EigenVector<T>::Flatten(*in);
const T *scale_factor = scale->data<T>();
auto out_e = framework::EigenVector<T>::Flatten(*out);
auto &dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * scale_factor[0] / max_range;
}
};
template <typename T>
struct ChannelDequantizeFunctorV2<phi::CPUContext, T> {
void operator()(const phi::CPUContext &dev_ctx,
......@@ -55,7 +71,7 @@ struct ChannelDequantizeFunctorV2<phi::CPUContext, T> {
int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel);
auto *in_data = in->data<T>();
auto *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
auto *out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) {
auto *cur_in = in_data + i * step_i + j * step_j;
......@@ -72,6 +88,11 @@ struct ChannelDequantizeFunctorV2<phi::CPUContext, T> {
}
};
template struct DequantizeFunctor<phi::CPUContext, phi::dtype::float16>;
template struct DequantizeFunctor<phi::CPUContext, float>;
template struct DequantizeFunctor<phi::CPUContext, double>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext,
phi::dtype::float16>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext, float>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext, double>;
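The arithmetic implemented by the two CPU functors above is compact enough to restate as a sketch. The NumPy below is illustrative only (not Paddle code): the per-tensor formula is out = in * scale / max_range, and the per-channel variant applies one scale per slice along quant_axis:

import numpy as np

def dequantize(x_q, scale, max_range):
    # Per-tensor dequantization: a single scalar scale for the whole tensor.
    return x_q.astype(np.float32) * float(scale) / max_range

def channel_dequantize(x_q, scales, max_range, quant_axis=0):
    # Per-channel dequantization: each slice along quant_axis uses its own
    # scale, mirroring ChannelDequantizeFunctorV2's step_i/step_j loop.
    shape = [1] * x_q.ndim
    shape[quant_axis] = -1
    return x_q.astype(np.float32) * scales.reshape(shape) / max_range

x_q = np.array([[127, -64], [32, 0]], dtype=np.int8)
print(dequantize(x_q, scale=0.05, max_range=127))
print(channel_dequantize(x_q, np.array([0.05, 0.1]), max_range=127, quant_axis=0))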
......@@ -214,6 +235,6 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(dequantize_linear,
ops::DeQuantizeLinearKernel<CPU, float, float>,
ops::DeQuantizeLinearKernel<CPU, int8_t, float>,
ops::DeQuantizeLinearKernel<CPU, double, double>);
ops::DeQuantizeLinearKernel<CPU, float>,
ops::DeQuantizeLinearKernel<CPU, int8_t>,
ops::DeQuantizeLinearKernel<CPU, double>);
......@@ -15,14 +15,64 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_dequantize_op.cu.h"
#include "paddle/fluid/operators/fake_quantize_op.cu.h"
#include "paddle/fluid/operators/quantize_linear_op.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
using float16 = paddle::platform::float16;
namespace paddle {
namespace operators {
template <typename T>
__global__ void KeDequantize(
const T* in, const T* scale, T max_range, int64_t num, T* out) {
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
out[i] = in[i] * scale[0] / max_range;
}
}
template <typename T>
__global__ void DequantizeOneScaleQuantAxisN(const T* in,
const T* scale,
const T max_range,
const int64_t num,
const int n_scales,
const int quant_stride,
T* out) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
T s = scale[(i / quant_stride) % n_scales];
out[i] = in[i] * s / max_range;
}
}
template <typename T>
struct DequantizeFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* in,
const phi::DenseTensor* scale,
T max_range,
phi::DenseTensor* out) {
const T* in_data = in->data<T>();
const T* scale_factor = scale->data<T>();
T* out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
int64_t num = in->numel();
int64_t block_size = std::min(
num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
int64_t max_threads =
dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM
const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (num + block_size - 1) / block_size);
KeDequantize<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, out_data);
}
};
template <typename T>
struct ChannelDequantizeFunctorV2<phi::GPUContext, T> {
void operator()(const phi::GPUContext& dev_ctx,
......@@ -33,7 +83,7 @@ struct ChannelDequantizeFunctorV2<phi::GPUContext, T> {
phi::DenseTensor* out) {
auto in_dims = in->dims();
const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
T* out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
int64_t num = in->numel();
const T* scale_factor = scale->data<T>();
int64_t block_size = std::min(
......@@ -61,6 +111,10 @@ struct ChannelDequantizeFunctorV2<phi::GPUContext, T> {
}
};
template struct DequantizeFunctor<phi::GPUContext, phi::dtype::float16>;
template struct DequantizeFunctor<phi::GPUContext, float>;
template struct DequantizeFunctor<phi::GPUContext, double>;
template struct ChannelDequantizeFunctorV2<phi::GPUContext, float16>;
template struct ChannelDequantizeFunctorV2<phi::GPUContext, float>;
template struct ChannelDequantizeFunctorV2<phi::GPUContext, double>;
......@@ -70,9 +124,11 @@ template struct ChannelDequantizeFunctorV2<phi::GPUContext, double>;
namespace ops = paddle::operators;
using CUDA = phi::GPUContext;
REGISTER_OP_CUDA_KERNEL(dequantize_linear,
ops::DeQuantizeLinearKernel<CUDA, float, float>,
ops::DeQuantizeLinearKernel<CUDA, int8_t, float>,
ops::DeQuantizeLinearKernel<CUDA, double, double>);
ops::DeQuantizeLinearKernel<CUDA, float>,
ops::DeQuantizeLinearKernel<CUDA, float16>,
ops::DeQuantizeLinearKernel<CUDA, int8_t>,
ops::DeQuantizeLinearKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(quantize_linear,
ops::QuantizeLinearKernel<CUDA, float>);
ops::QuantizeLinearKernel<CUDA, float>,
ops::QuantizeLinearKernel<CUDA, float16>);
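The launch-size arithmetic shared by the GPU functors above can be restated as follows; this is an illustrative Python sketch in which the device limits are assumed example values (the real ones come from dev_ctx.GetMaxThreadsPerBlock() and dev_ctx.GetMaxPhysicalThreadCount()):

def launch_config(num, max_threads_per_block=1024, max_physical_threads=2048 * 80):
    # Block size is capped at a quarter of the per-block thread limit; the grid
    # is capped by the number of resident threads (SM * block_per_SM) / block size.
    block_size = min(num, max_threads_per_block // 4)
    max_blocks = max((max_physical_threads - 1) // block_size + 1, 1)
    grid_size = min(max_blocks, (num + block_size - 1) // block_size)
    return grid_size, block_size

print(launch_config(1_000_000))  # (640, 256) with the assumed limits
print(launch_config(100))        # (1, 100) for tiny inputs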
......@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/fake_dequantize_op.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/common/data_type.h"
......@@ -28,6 +27,15 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DequantizeFunctor {
void operator()(const DeviceContext& dev_ctx,
const phi::DenseTensor* in,
const phi::DenseTensor* scale,
T max_range,
phi::DenseTensor* out);
};
template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctorV2 {
void operator()(const DeviceContext& dev_ctx,
......@@ -105,10 +113,11 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T, typename D>
template <typename DeviceContext, typename T>
class DeQuantizeLinearKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
template <typename D>
void ComputeImpl(const framework::ExecutionContext& context) const {
auto& dev_ctx = context.template device_context<DeviceContext>();
auto* in = context.Input<phi::DenseTensor>("X");
......@@ -122,7 +131,7 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
auto* out = context.Output<phi::DenseTensor>("Y");
int bit_length = context.Attr<int>("bit_length");
auto quant_axis = context.Attr<int>("quant_axis");
out->mutable_data<D>(dev_ctx.GetPlace());
dev_ctx.template Alloc<D>(out, out->numel() * sizeof(D));
if (quant_axis < 0) {
float max_range = (std::pow(2, bit_length - 1) - 1);
......@@ -144,6 +153,27 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
dev_ctx, &in_tmp, scale, static_cast<D>(max_range), quant_axis, out);
}
}
void Compute(const framework::ExecutionContext& context) const override {
auto* scale = context.Input<phi::DenseTensor>("Scale");
switch (scale->dtype()) {
case experimental::DataType::FLOAT64:
ComputeImpl<double>(context);
break;
case experimental::DataType::FLOAT32:
ComputeImpl<float>(context);
break;
case experimental::DataType::FLOAT16:
ComputeImpl<paddle::platform::float16>(context);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"In DeQuantizeLinearKernel, "
"data type %d for scale/output is not supported ",
scale->dtype()));
break;
}
}
};
} // namespace operators
......
......@@ -107,6 +107,7 @@ set_field_default_config(QAT, "weight_bits", 8)
set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)
set_field_default_config(QAT, "onnx_format", True)
# #########################################
# auto tuning configuration
......
......@@ -29,7 +29,7 @@ from paddle.distributed import fleet
from paddle.fluid import Variable, core
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.executor import _to_name_str, global_scope
from paddle.fluid.framework import Operator
from paddle.fluid.framework import IrGraph, Operator
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layers.utils import flatten
......@@ -752,7 +752,9 @@ class Engine:
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
cur_rank = self._cur_rank
# NOTE: After the implementation of the unified dynamic and static communication group initialization mode in the future, the initialization logic of full mode will be removed because port occupation error may occur.
# NOTE: After the unified dynamic and static communication group initialization
# mode is implemented in the future, the initialization logic of full mode
# will be removed, because port occupation errors may occur.
if self._strategy.auto_mode == "full":
auto_utils.initialize_pg_in_full_mode(
all_process_groups, cur_rank
......@@ -763,9 +765,9 @@ class Engine:
continue
process_group.instantiate()
place = _get_device()
if isinstance(place, fluid.CUDAPlace):
place = fluid.CUDAPlace(ParallelEnv().dev_id)
self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace):
self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
if self._strategy.seed:
paddle.seed(self._strategy.seed + self._dp_ranks[0])
......@@ -775,10 +777,12 @@ class Engine:
if self._dygraph_mode:
dist_context = self._dist_contexts[mode]
dist_main_program = self._dist_main_progs[mode][self._cur_rank]
self.program_helper.init(dist_main_program, place, dist_context)
self.program_helper.init(
dist_main_program, self._place, dist_context
)
if self._executor is None:
self._executor = paddle.static.Executor(place)
self._executor = paddle.static.Executor(self._place)
uninitialized = []
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
for var in dist_startup_prog.list_vars():
......@@ -1612,6 +1616,22 @@ class Engine:
feed_vars = self._feed_vars["predict"]['inputs']
fetch_vars = self._fetch_vars["predict"]['outputs']
dist_main_prog = self._dist_main_progs["predict"][self._cur_rank]
if self._strategy.qat.enable and self._strategy.qat.onnx_format:
from paddle.fluid.contrib.slim.quantization import (
QuantWeightPass,
)
self._logger.info("export quantized model.")
self._logger.info(
"convert config {}".format(self._strategy.qat.to_dict())
)
test_graph = IrGraph(
core.Graph(dist_main_prog.desc), for_test=True
)
quant_weight_pass = QuantWeightPass(global_scope(), self._place)
for sub_graph in test_graph.all_sub_graphs():
quant_weight_pass.apply(sub_graph)
dist_main_prog = test_graph.to_program()
self._saver.save_inference_model(
path,
feed_vars,
......
......@@ -131,8 +131,12 @@ class Parallelizer:
else:
# Apply pre optimization passes
time0 = time.time()
self._apply_pre_optimization(
serial_main_program, serial_startup_program, None, None, None
(
serial_main_program,
serial_startup_program,
params_grads,
) = self._apply_pre_optimization(
serial_main_program, serial_startup_program, None, None, []
)
self._logger.debug(
"within parallel apply_pre_optimization time: {}, mode {}".format(
......@@ -207,22 +211,6 @@ class Parallelizer:
if self._strategy is None:
return
# apply quantization pass
# This pass can only be applied when the mode is 'train'
if self._mode == 'train' and self._strategy.qat.enable:
config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_quantization_pass = new_pass(
"auto_parallel_quantization", config
)
auto_parallel_quantization_pass.apply(
[main_program], [startup_program], self._pass_context
)
main_program = self._pass_context.get_attr("main_program")
startup_program = self._pass_context.get_attr("startup_program")
params_grads = self._pass_context.get_attr("params_grads")
# apply amp pass on train/eval/predict
if self._strategy.amp.enable:
config = copy.deepcopy(self._strategy.amp.to_dict())
......@@ -247,6 +235,25 @@ class Parallelizer:
)
loss = auto_parallel_amp_pass.get_loss()
# apply quantization pass on train/eval/predict
# The current mode and loss are passed to the pass through its config
if self._strategy.qat.enable:
config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["mode"] = self._mode
config["loss"] = loss
auto_parallel_quantization_pass = new_pass(
"auto_parallel_quantization", config
)
auto_parallel_quantization_pass.apply(
[main_program], [startup_program], self._pass_context
)
main_program = self._pass_context.get_attr("main_program")
startup_program = self._pass_context.get_attr("startup_program")
params_grads = self._pass_context.get_attr("params_grads")
loss = self._pass_context.get_attr("loss")
# apply recompute pass
# recompute is a train-only optimization
if self._mode == "train" and self._strategy.recompute.enable:
......
......@@ -2137,6 +2137,36 @@ class Resharder:
input_attrs.append([process_mesh, input_dims_mapping])
return input_attrs
def _get_subblock_output_attrs(self, op, var_name):
# NOTE: Multiple while loops are not supported
assert op.type in _g_subblock_ops
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
output_attrs = []
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
dist_attr = dist_op.dist_attr
for name in op.output_arg_names:
if name == var_name:
process_mesh = dist_attr.process_mesh
output_dims_mapping = dist_attr.get_output_dims_mapping(
var_name
)
has_exist = False
for output_attr in output_attrs:
if (
process_mesh == output_attr[0]
and output_dims_mapping == output_attr[1]
):
has_exist = True
break
if not has_exist:
output_attrs.append([process_mesh, output_dims_mapping])
return output_attrs
def _get_common_op_input_attrs(self, op, var_name):
process_meshes = []
dist_op = self.dist_context.get_dist_op_for_program(op)
......@@ -2166,6 +2196,11 @@ class Resharder:
if op.type in _g_subblock_ops:
op_input_attrs = self._get_subblock_input_attrs(op, var_name)
if not op_input_attrs:
# NOTE: [hack method]
# Adapt to the quantization pass, which moves persistable vars, including inputs and outputs, into the global_block.
# Therefore, the while_op's inputs will contain all persistable vars, which may be inputs or outputs of quantization ops in the sub-block.
op_input_attrs = self._get_subblock_output_attrs(op, var_name)
else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name)
......
......@@ -877,6 +877,7 @@ class AMPPass(PassBase):
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
first_backward_op, ref_mesh, [-1], self.dist_context
)
self._scaled_loss_grad.op = first_backward_op
# FIXME(JZ-LIANG) a trick to insert backward op
main_block._sync_with_cpp()
elementwise_mul_grad_op_desc = main_block.desc._insert_op(
......
......@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from paddle.fluid import core, framework
from paddle.fluid.contrib.slim.quantization import (
AddQuantDequantForInferencePass,
AddQuantDequantPassV2,
OutScaleForTrainingPass,
QuantizationTransformPassV2,
......@@ -26,6 +27,11 @@ from paddle.fluid.contrib.slim.quantization import (
)
from paddle.fluid.dygraph.parallel import ParallelEnv
from ..auto_parallel.converter import Converter
from ..auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from .pass_base import PassBase, register_pass
TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
......@@ -42,6 +48,8 @@ class QuantizationPass(PassBase):
super().__init__()
self.set_attr("dist_context", None)
self.set_attr("params_grads", None)
self.set_attr("mode", "train")
self.set_attr("loss", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
......@@ -57,15 +65,23 @@ class QuantizationPass(PassBase):
dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
mode = self.get_attr("mode")
loss = self.get_attr("loss")
# TODO: scope and place will be removed,
# because params should be initialized by the engine module.
scope = paddle.static.global_scope()
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
# 0. record the relation among blocks
parent_idx_dict = dict()
for block in main_program.blocks:
parent_idx_dict[block.idx] = block.parent_idx
is_test = True if mode != "train" else False
# 1. Convert the Program to an IrGraph; for_test is derived from the current mode
main_graph = framework.IrGraph(
core.Graph(main_program.desc), for_test=False
core.Graph(main_program.desc), for_test=mode != "train"
)
# 2. Prepare inputs
......@@ -91,44 +107,67 @@ class QuantizationPass(PassBase):
)
# 3. Add quant op for ops which have parameters
transform_pass = QuantizationTransformPassV2(
scope=scope,
place=place,
weight_bits=self.get_attr('weight_bits'),
activation_bits=self.get_attr('activation_bits'),
skip_pattern=self.get_attr('not_quant_pattern'),
activation_quantize_type="moving_average_abs_max",
quantizable_op_type=transform_pass_ops,
weight_quantize_type=weight_quantize_type,
weight_quantize_func=None,
act_quantize_func=None,
weight_preprocess_func=None,
act_preprocess_func=None,
optimizer_func=None,
executor=None,
)
transform_pass.apply(main_graph)
if len(transform_pass_ops) > 0:
transform_pass = QuantizationTransformPassV2(
scope=scope,
place=place,
weight_bits=self.get_attr('weight_bits'),
activation_bits=self.get_attr('activation_bits'),
skip_pattern=self.get_attr('not_quant_pattern'),
activation_quantize_type="moving_average_abs_max",
quantizable_op_type=transform_pass_ops,
weight_quantize_type=weight_quantize_type,
weight_quantize_func=None,
act_quantize_func=None,
weight_preprocess_func=None,
act_preprocess_func=None,
optimizer_func=None,
executor=None,
is_test=is_test,
)
for sub_graph in main_graph.all_sub_graphs():
transform_pass.apply(sub_graph)
# 4. Add quant op for ops which don't have parameters
quant_dequant_pass = AddQuantDequantPassV2(
scope=scope,
place=place,
quant_bits=self.get_attr('activation_bits'),
skip_pattern=self.get_attr('not_quant_pattern'),
quantizable_op_type=quant_dequant_ops,
)
quant_dequant_pass.apply(main_graph)
if len(quant_dequant_ops) > 0:
quant_dequant_pass = AddQuantDequantPassV2(
scope=scope,
place=place,
quant_bits=self.get_attr('activation_bits'),
skip_pattern=self.get_attr('not_quant_pattern'),
quantizable_op_type=quant_dequant_ops,
is_test=is_test,
)
for sub_graph in main_graph.all_sub_graphs():
quant_dequant_pass.apply(sub_graph)
# 5. Gather quantization scale information for the outputs
out_scale_training_pass = OutScaleForTrainingPass(
scope=scope, place=place
scope=scope, place=place, is_test=is_test
)
out_scale_training_pass.apply(main_graph)
for sub_graph in main_graph.all_sub_graphs():
out_scale_training_pass.apply(sub_graph)
# 6. When exporting the quant model, traverse to find the output of each op and insert the quant/dequant op after it.
if mode != "train" and self.get_attr('onnx_format'):
try:
out_scale_infer_pass = AddQuantDequantForInferencePass(
scope=scope,
place=place,
quant_bits=self.get_attr('activation_bits'),
)
# for sub_graph in main_graph.all_sub_graphs():
# out_scale_infer_pass.apply(sub_graph)
except:
logging.warning(
"Unable to convert quant model with onnx_format=True, please update PaddlePaddle >= 2.4.0"
)
# 6. Convert Graph back to Program
# 7. Convert Graph back to Program
quant_program = main_graph.to_program()
quant_program = self.move_presist_var_to_global_block(quant_program)
# 7. get new params_grads from quant_program
# 8.1 get new params_grads from quant_program
new_params_grads = []
for param, grad in params_grads:
if param.name not in quant_program.global_block().vars:
......@@ -138,7 +177,72 @@ class QuantizationPass(PassBase):
new_grad = quant_program.global_block().vars[grad.name]
new_params_grads.append((new_param, new_grad))
# 8. complete distributed attribution
# 8.2 get new loss var
new_loss = None
if loss:
new_loss = quant_program.global_block().vars[loss.name]
# 8.3 recover the relation among blocks
for block in quant_program.blocks:
block.desc._set_forward_block_idx(parent_idx_dict[block.idx])
# 9. complete distributed attribution
self.set_dist_attr_for_qat_program(
quant_program, main_program, dist_context
)
# 10. reset scale var value with dist_attr
self.reset_scope_var(quant_program, dist_context, scope, place)
context.set_attr("main_program", quant_program)
context.set_attr("startup_program", startup_program)
context.set_attr("params_grads", new_params_grads)
context.set_attr("loss", new_loss)
def move_presist_var_to_global_block(self, program):
global_block = program.global_block()
for _op in global_block.ops:
if _op.type == "while":
_block_id = _op.attr("sub_block").id
_block = program.block(_block_id)
persistables = []
for _name, _var in _block.vars.items():
if _var.persistable:
global_block._clone_variable(_var)
persistables.append(_name)
for _name in persistables:
_block._remove_var(_name)
persistables.extend(_op.input('X'))
_op.desc.set_input("X", persistables)
return program
def reset_scope_var(self, quant_program, dist_context, scope, place):
# The var_value created by the quantization passes should have the same shape as the value after parallelization.
for var in quant_program.list_vars():
scope_var = scope.find_var(var.name)
if not (scope_var and scope_var.get_tensor()._is_initialized()):
continue
tensor = scope_var.get_tensor()
if var.shape == tensor.shape:
continue
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
dist_attr = {
"dims_mapping": var_dist_attr.dims_mapping,
"process_shape": var_dist_attr.process_mesh.topology,
"process_group": var_dist_attr.process_mesh.processes,
}
# slice tensor_value with dist_attr
sliced_tensor = Converter.slice_with_dist_attr(
np.array(tensor), dist_attr
)
tensor._clear()
tensor.set(sliced_tensor, place)
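As a rough illustration of what the slicing step above does, here is a minimal sketch; the helper name is hypothetical and the real work is done by Converter.slice_with_dist_attr, whose implementation may differ. It assumes dims_mapping maps each tensor axis to a mesh dimension (or -1 for replicated):

import numpy as np

def slice_with_dims_mapping(value, dims_mapping, process_shape, rank_index):
    # Keep only this rank's shard on every axis mapped to a mesh dimension
    # (dims_mapping[axis] >= 0); axes mapped to -1 stay replicated.
    coords = np.unravel_index(rank_index, process_shape)
    out = value
    for axis, mesh_dim in enumerate(dims_mapping):
        if mesh_dim < 0:
            continue
        n_shards = process_shape[mesh_dim]
        shard_len = value.shape[axis] // n_shards
        start = coords[mesh_dim] * shard_len
        out = np.take(out, range(start, start + shard_len), axis=axis)
    return out

# e.g. a [4, 8] scale tensor sharded 2 ways along axis 1, as seen by rank 1:
full = np.arange(32, dtype=np.float32).reshape(4, 8)
print(slice_with_dims_mapping(full, dims_mapping=[-1, 0], process_shape=[2], rank_index=1).shape)  # (4, 4)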
def set_dist_attr_for_qat_program(
self, quant_program, main_program, dist_context
):
# NOTE: hack implementation, to be upgraded soon
for ib, block in enumerate(quant_program.blocks):
# recover origin ops' dist_attr and set quant ops' dist_attr
......@@ -150,15 +254,22 @@ class QuantizationPass(PassBase):
"quantize" in quant_op.type
or quant_op.type == "moving_average_abs_max_scale"
):
# set all quantization ops' dist_attr according to the quantized op
input_name = quant_op.desc.input('X')[0]
if "quantize" in input_name:
input_name = input_name[
: input_name.index(".quantized")
]
if quant_op.type == "moving_average_abs_max_scale":
consume_op = main_program.blocks[ib].vars[input_name].op
if (
quant_op.type == "moving_average_abs_max_scale"
or ip - qat_offset >= len(main_program.blocks[ib].ops)
):
consume_op = (
main_program.blocks[ib]
._var_recursive(input_name)
.op
)
else:
consume_op = main_program.blocks[ib].ops[
ip - qat_offset
......@@ -185,23 +296,42 @@ class QuantizationPass(PassBase):
)
for slot_name in quant_op.desc.input_names():
in_name = quant_op.desc.input(slot_name)[0]
input_var = block._var_recursive(in_name)
ref_dims_mapping = [-1]
if slot_name == "X":
continue
for in_name in quant_op.desc.input(slot_name):
input_var = block.vars[in_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(
input_var, tensor_dist_attr
)
quant_op_dist_attr.set_input_dist_attr(
in_name, tensor_dist_attr
)
elif slot_name in ['Scale', 'ZeroPoint']:
if (
quant_op.has_attr('quant_axis')
and quant_op.attr('quant_axis') != -1
):
x_name = quant_op.desc.input('X')[0]
x_var = block._var_recursive(x_name)
x_dist_attr = (
quant_op_dist_attr.get_input_dist_attr(
x_name
)
)
quant_axis = quant_op.attr('quant_axis')
ref_dims_mapping = [
x_dist_attr.dims_mapping[quant_axis]
]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program(
input_var, tensor_dist_attr
)
quant_op_dist_attr.set_input_dist_attr(
in_name, tensor_dist_attr
)
for slot_name in quant_op.desc.output_names():
output_name = quant_op.desc.output(slot_name)[0]
output_var = block.vars[output_name]
output_var = block._var_recursive(output_name)
ref_dims_mapping = [-1]
if slot_name == "Y":
dist_context.set_tensor_dist_attr_for_program(
output_var, consume_input_dist_attr
......@@ -209,22 +339,39 @@ class QuantizationPass(PassBase):
quant_op_dist_attr.set_output_dist_attr(
output_name, consume_input_dist_attr
)
else:
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr
)
quant_op_dist_attr.set_output_dist_attr(
output_name, tensor_dist_attr
)
continue
elif slot_name == "OutScale":
if (
quant_op.has_attr('quant_axis')
and quant_op.attr('quant_axis') != -1
):
x_name = quant_op.desc.input('X')[0]
x_var = block._var_recursive(x_name)
x_dist_attr = (
quant_op_dist_attr.get_input_dist_attr(
x_name
)
)
quant_axis = quant_op.attr('quant_axis')
ref_dims_mapping = [
x_dist_attr.dims_mapping[quant_axis]
]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr
)
quant_op_dist_attr.set_output_dist_attr(
output_name, tensor_dist_attr
)
quant_op._set_attr("op_device", "")
qat_offset += 1
else:
# recover origin ops' dist_attr
origin_op = main_program.blocks[ib].ops[ip - qat_offset]
quant_op.desc.set_original_id(origin_op.desc.original_id())
dist_origin_op = dist_context.get_dist_op_for_program(
......@@ -240,7 +387,21 @@ class QuantizationPass(PassBase):
quant_op_dist_attr.process_mesh = (
origin_op_dist_attr.process_mesh
)
scale_offset = 0
for idx, input_name in enumerate(quant_op.input_arg_names):
if (
origin_op.type == "while"
and input_name not in origin_op.input_arg_names
):
assert (
"@scale" in input_name
or "@zero_point" in input_name
)
scale_offset += 1
continue
idx -= scale_offset
origin_input_name = origin_op.input_arg_names[idx]
origin_input_dist_attr = (
origin_op_dist_attr.inputs_dist_attrs[
......@@ -251,20 +412,6 @@ class QuantizationPass(PassBase):
input_name, origin_input_dist_attr
)
if input_name not in main_program.blocks[ib].vars:
origin_input_var = main_program.blocks[ib].vars[
origin_input_name
]
origin_in_tensor_dist_attr = (
dist_context.get_dist_tensor_for_program(
origin_input_var
).dist_attr
)
quant_input_var = block.vars[input_name]
dist_context.set_tensor_dist_attr_for_program(
quant_input_var, origin_in_tensor_dist_attr
)
for idx, output_name in enumerate(
quant_op.output_arg_names
):
......@@ -278,16 +425,18 @@ class QuantizationPass(PassBase):
output_name, origin_output_dist_attr
)
if output_name not in main_program.blocks[ib].vars:
origin_output_var = main_program.blocks[ib].vars[
origin_output_name
]
if not main_program.blocks[ib]._find_var_recursive(
output_name
):
origin_output_var = main_program.blocks[
ib
]._var_recursive(origin_output_name)
origin_out_tensor_dist_attr = (
dist_context.get_dist_tensor_for_program(
origin_output_var
).dist_attr
)
quant_output_var = block.vars[output_name]
quant_output_var = block._var_recursive(output_name)
dist_context.set_tensor_dist_attr_for_program(
quant_output_var, origin_out_tensor_dist_attr
)
......@@ -308,7 +457,3 @@ class QuantizationPass(PassBase):
dist_context.set_tensor_dist_attr_for_program(
dst_var, dist_tensor.dist_attr
)
context.set_attr("main_program", quant_program)
context.set_attr("startup_program", startup_program)
context.set_attr("params_grads", new_params_grads)
......@@ -513,11 +513,12 @@ class QuantizationTransformPass:
var_dtype=var_node.dtype(),
)
scale_name = self._quantized_scale_name(name)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
scale_value = np.array(
self._scope.find_var(scale_name).get_tensor()
......@@ -560,11 +561,12 @@ class QuantizationTransformPass:
)
scale_name = self._quantized_scale_name(name)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
scale_value = np.array(
self._scope.find_var(scale_name).get_tensor()
......@@ -591,11 +593,12 @@ class QuantizationTransformPass:
shape=[self._window_size],
var_dtype=var_node.dtype(),
)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
scales_node,
np.zeros([self._window_size], dtype=data_type),
......@@ -640,11 +643,12 @@ class QuantizationTransformPass:
var_dtype=var_node.dtype(),
)
scale_name = self._quantized_scale_name(name)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
scale_value = np.array(
self._scope.find_var(scale_name).get_tensor()
......@@ -669,11 +673,12 @@ class QuantizationTransformPass:
var_dtype=var_node.dtype(),
shape=[1],
)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
state_in_node,
np.ones([1], dtype=data_type),
......@@ -746,11 +751,12 @@ class QuantizationTransformPass:
var_dtype=var_node.dtype(),
)
scale_name = self._quantized_scale_name(name)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
scale_value = np.array(
self._scope.find_var(scale_name).get_tensor()
......@@ -1287,11 +1293,13 @@ class QuantizationFreezePass:
shape=[channel_scale.shape[0]],
var_dtype=output_var_node.dtype(),
)
data_type = (
'float64'
if output_var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if output_var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif output_var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
weight_scale_node,
channel_scale.astype(data_type),
......@@ -1439,6 +1447,7 @@ class QuantizationFreezePass:
def _is_float(self, v):
return (
isinstance(v, float)
or isinstance(v, np.float16)
or isinstance(v, np.float32)
or isinstance(v, np.float64)
)
......@@ -1636,14 +1645,17 @@ class OutScaleForTrainingPass:
if in_node.dtype() not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]:
continue
data_type = (
'float64'
if in_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if in_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif in_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
graph._find_node_by_name(
graph.all_var_nodes(),
......@@ -1781,6 +1793,7 @@ class OutScaleForInferencePass:
not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]
):
continue
......@@ -1997,11 +2010,12 @@ class AddQuantDequantPass:
var_dtype=var_node.dtype(),
)
scale_name = "{}.quant_dequant@scale".format(var_node.name())
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
if (
self._scale_dict is not None
......@@ -2036,11 +2050,12 @@ class AddQuantDequantPass:
var_dtype=var_node.dtype(),
shape=[1],
)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
state_in_node,
np.ones([1], dtype=data_type),
......@@ -2149,11 +2164,12 @@ class InsertQuantizeLinear:
var_dtype=var_node.dtype(),
)
if not scale_var_node:
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
scale_name = self._quantized_scale_name(var_name)
if self.channel_wise:
scale_var_shape = var_node.shape()[self.quant_axis]
......@@ -2218,11 +2234,12 @@ class InsertQuantizeLinear:
var_dtype=var_node.dtype(),
shape=[1],
)
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
else 'float32'
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
state_in_node,
np.ones([1], dtype=data_type),
......@@ -3277,6 +3294,7 @@ class AddQuantDequantForInferencePass:
if out_node.dtype() not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]:
continue
if var_name in dequantized_vars_map:
......
......@@ -74,6 +74,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
py_test_modules(test_selective_recompute MODULES test_selective_recompute)
set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
py_test_modules(test_pass_quantization MODULES test_pass_quantization ENVS
${dist_ENVS})
set_tests_properties(test_pass_quantization PROPERTIES TIMEOUT 60)
py_test_modules(test_tuning_recompute MODULES test_tuning_recompute)
set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240)
......@@ -113,7 +116,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_process_mesh MODULES test_process_mesh)
py_test_modules(test_interface MODULES test_interface)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_dist_shape MODULES test_dist_shape)
py_test_modules(test_dist_assign MODULES test_dist_assign)
py_test_modules(test_conditional_block_reshard MODULES
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from get_gpt_model import FakeDataset, create_data_holder, generate_model
import paddle
from paddle.distributed.fleet import auto
paddle.enable_static()
def apply_pass():
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
amp = dist_strategy.amp
amp.enable = True
amp.custom_white_list = ["lookup_table", "lookup_table_v2"]
amp.custom_black_list = [
"reduce_sum",
"c_softmax_with_cross_entropy",
"elementwise_div",
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
qat = dist_strategy.qat
qat.enable = True
qat.channel_wise_abs_max = True
qat.weight_bits = 8
qat.activation_bits = 8
qat.not_quant_pattern = ['skip_quant']
qat.onnx_format = True
return dist_strategy
class TestQuantizationPassTrain(unittest.TestCase):
def test_qat_pass_training(self):
batch_size = 1
batch_num = 10
strategy = apply_pass()
model, loss = generate_model("mp")
opt = paddle.optimizer.AdamW(learning_rate=0.00001)
engine = auto.Engine(model, loss, opt, strategy=strategy)
dataset = FakeDataset(batch_size * batch_num)
engine.fit(dataset, 3, batch_size=batch_size)
self.check_program(engine.main_program)
def check_program(self, program):
quantizable_op_and_inputs = {'matmul_v2': ['X', 'Y']}
quantizable_grad_op_inputs = {'matmul_v2_grad': ['X', 'Y']}
quantized_ops = set()
for block in program.blocks:
for idx, op in enumerate(block.ops):
is_quantized = False
if op.type in quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
if ".quantized" in arg_name:
is_quantized = True
if not is_quantized:
continue
# check forward
if op.type in quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
if "c_identity" in arg_name:
arg_name = block.ops[idx - 1].input_arg_names[0]
assert arg_name.endswith('.quantized.dequantized')
quantized_ops.add(arg_name)
for op in block.ops:
is_quantized = False
if op.type in quantizable_grad_op_inputs:
for pname in quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
if ".quantized" in arg_name:
is_quantized = True
if not is_quantized:
continue
# check backward
if op.type in quantizable_grad_op_inputs:
for pname in quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
assert arg_name.endswith('.quantized.dequantized')
assert arg_name in quantized_ops
class TestQuantizationPassExport(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_qat_pass_2(self):
batch_size = 1
batch_num = 10
strategy = apply_pass()
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, strategy=strategy)
inputs_spec, labels_spec = create_data_holder(batch_size=1)
engine.prepare(inputs_spec, labels_spec, mode="predict")
path = os.path.join(self.temp_dir.name, 'inf')
engine.save(path, training=False)
self.check_export(engine._executor)
def check_export(self, exe):
sequence_len = 512
vocab_size = 1000
tokens = [np.random.randint(vocab_size, size=sequence_len)]
position_ids = [np.arange(sequence_len)]
attention_mask = [np.tril(np.ones(sequence_len))]
path_prefix = os.path.join(
self.temp_dir.name,
'inf_dist{}'.format(paddle.distributed.get_rank()),
)
[
inference_program,
feed_target_names,
fetch_targets,
] = paddle.static.load_inference_model(
path_prefix=path_prefix, executor=exe
)
out = exe.run(
inference_program,
feed={
"tokens": tokens,
"position_ids": position_ids,
"attention_mask": attention_mask,
},
fetch_list=fetch_targets,
)
if __name__ == "__main__":
unittest.main()
......@@ -12,83 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
paddle.enable_static()
def apply_pass():
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
qat = dist_strategy.qat
qat.enable = True
qat.channel_wise_abs_max = True
qat.weight_bits = 8
qat.activation_bits = 8
qat.not_quant_pattern = ['skip_quant']
return dist_strategy
class TestQuantizationPass(unittest.TestCase):
def test_qat_pass(self):
batch_size = 8
batch_num = 10
strategy = apply_pass()
model, loss = generate_model("serial")
opt = paddle.optimizer.AdamW(learning_rate=0.00001)
engine = auto.Engine(model, loss, opt, strategy=strategy)
dataset = FakeDataset(batch_size * batch_num)
engine.fit(dataset, 3, batch_size=batch_size)
self.check_program(engine.main_program)
def check_program(self, program):
quantizable_op_and_inputs = {'matmul_v2': ['X', 'Y']}
quantizable_grad_op_inputs = {'matmul_v2_grad': ['X', 'Y']}
quantized_ops = set()
for block in program.blocks:
for op in block.ops:
is_quantized = False
if op.type in quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
if ".quantized" in arg_name:
is_quantized = True
if not is_quantized:
continue
# check forward
if op.type in quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
assert arg_name.endswith('.quantized.dequantized')
quantized_ops.add(arg_name)
for op in block.ops:
is_quantized = False
if op.type in quantizable_grad_op_inputs:
for pname in quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
if ".quantized" in arg_name:
is_quantized = True
if not is_quantized:
continue
# check backward
if op.type in quantizable_grad_op_inputs:
for pname in quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
assert arg_name.endswith('.quantized.dequantized')
assert arg_name in quantized_ops
def test_mp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(
file_dir, "quantization_pass_unittest.py"
)
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
......
......@@ -834,13 +834,14 @@ class GPTForPretraining(nn.Layer):
x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]
if mesh:
matmul = auto.shard_op(
paddle.matmul, mesh, [x_dims_mapping, w_dims_mapping, None]
)
logits = matmul(x, w, transpose_y=True)
else:
logits = paddle.matmul(x, w, transpose_y=True)
with paddle.fluid.name_scope('skip_quant'):
if mesh:
matmul = auto.shard_op(
paddle.matmul, mesh, [x_dims_mapping, w_dims_mapping, None]
)
logits = matmul(x, w, transpose_y=True)
else:
logits = paddle.matmul(x, w, transpose_y=True)
if use_cache:
return logits, cached_kvs
......
......@@ -26,7 +26,7 @@ def quantize_max_abs(x, max_range):
def dequantize_max_abs(x, scale, max_range):
y = (scale / max_range) * x
y = x * scale / max_range
return y
......@@ -292,6 +292,45 @@ class TestDequantizeOpDouble(TestDequantizeOp):
self.quant_axis = -1
class TestDequantizeOpHalf(TestDequantizeOp):
def set_args(self):
self.bit_length = 8
self.max_range = math.pow(2, self.bit_length - 1) - 1
self.data_type = "float16"
self.quant_axis = -1
def setUp(self):
self.set_args()
self.op_type = "dequantize_linear"
x = np.random.randn(31, 65).astype(np.float16)
yq, scale = quantize_max_abs(x, self.max_range)
scale = np.array(scale).astype('float16')
yq = np.array(yq).astype('int8')
ydq = dequantize_max_abs(yq, scale, self.max_range)
ydq = ydq.astype('float16')
zero_point = np.zeros(scale.shape, dtype="int32")
self.inputs = {'X': yq, 'Scale': scale, 'ZeroPoint': zero_point}
self.attrs = {
'bit_length': self.bit_length,
'quant_axis': self.quant_axis,
}
self.outputs = {'Y': ydq}
def _get_places(self):
import paddle
import paddle.fluid.core as core
if core.is_compiled_with_cuda():
place = paddle.fluid.core.CUDAPlace(0)
if paddle.fluid.core.is_float16_supported(place):
return [place]
else:
return []
else:
return []
class TestDequantizeOp5Bits(TestDequantizeOp):
def set_args(self):
self.bit_length = 5
......