Unverified commit 27ce06aa, authored by zhaoyingli, committed by GitHub

[AutoParallel] quantization pass support export (#48072)

* [AutoParallel] quantization pass support export

* support subgraph

* move_presist_var_to_global_block

* update unittest

* fix ci-coverage

* fix codestyle

* fix fake_dequantize_op

* remove unused var

* fix ci error and approval error

* add unittest for fp16 in test_dequant_linear

* replace mutable data

* fix unittest in non-cuda-core

* fix unittest
Co-authored-by: carryyu <569782149@qq.com>
Co-authored-by: wufeisheng <wfs1997@163.com>
Parent commit: 522c2bc0
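For orientation, a minimal export sketch assembled from the unittest added in this PR (quantization_pass_unittest.py); the get_gpt_model helpers and the "./inf" save path are test-side assumptions, not part of the pass itself:

import paddle
from paddle.distributed.fleet import auto
from get_gpt_model import create_data_holder, generate_model  # test-local helpers used by the new unittest

paddle.enable_static()

dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
qat = dist_strategy.qat
qat.enable = True
qat.channel_wise_abs_max = True
qat.weight_bits = 8
qat.activation_bits = 8
qat.not_quant_pattern = ['skip_quant']
qat.onnx_format = True  # new strategy field; defaults to True after this PR

model, loss = generate_model("mp")
engine = auto.Engine(model, loss, strategy=dist_strategy)
inputs_spec, labels_spec = create_data_holder(batch_size=1)
engine.prepare(inputs_spec, labels_spec, mode="predict")
engine.save("./inf", training=False)  # QuantWeightPass runs here when qat.enable and qat.onnx_format are set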
@@ -24,6 +24,22 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
struct DequantizeFunctor<phi::CPUContext, T> {
void operator()(const phi::CPUContext &dev_ctx,
const phi::DenseTensor *in,
const phi::DenseTensor *scale,
T max_range,
phi::DenseTensor *out) {
auto in_e = framework::EigenVector<T>::Flatten(*in);
const T *scale_factor = scale->data<T>();
auto out_e = framework::EigenVector<T>::Flatten(*out);
auto &dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * scale_factor[0] / max_range;
}
};
template <typename T>
struct ChannelDequantizeFunctorV2<phi::CPUContext, T> {
void operator()(const phi::CPUContext &dev_ctx,
@@ -55,7 +71,7 @@ struct ChannelDequantizeFunctorV2<phi::CPUContext, T> {
int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel);
auto *in_data = in->data<T>();
auto *out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) {
auto *cur_in = in_data + i * step_i + j * step_j;
@@ -72,6 +88,11 @@ struct ChannelDequantizeFunctorV2<phi::CPUContext, T> {
}
};
template struct DequantizeFunctor<phi::CPUContext, phi::dtype::float16>;
template struct DequantizeFunctor<phi::CPUContext, float>;
template struct DequantizeFunctor<phi::CPUContext, double>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext,
phi::dtype::float16>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext, float>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext, double>;
@@ -214,6 +235,6 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(dequantize_linear,
ops::DeQuantizeLinearKernel<CPU, float>,
ops::DeQuantizeLinearKernel<CPU, int8_t>,
ops::DeQuantizeLinearKernel<CPU, double>);
@@ -15,14 +15,64 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_dequantize_op.cu.h"
#include "paddle/fluid/operators/fake_quantize_op.cu.h" #include "paddle/fluid/operators/fake_quantize_op.cu.h"
#include "paddle/fluid/operators/quantize_linear_op.h" #include "paddle/fluid/operators/quantize_linear_op.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_primitives.h"
using float16 = paddle::platform::float16;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
__global__ void KeDequantize(
const T* in, const T* scale, T max_range, int64_t num, T* out) {
int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
out[i] = in[i] * scale[0] / max_range;
}
}
template <typename T>
__global__ void DequantizeOneScaleQuantAxisN(const T* in,
const T* scale,
const T max_range,
const int64_t num,
const int n_scales,
const int quant_stride,
T* out) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
T s = scale[(i / quant_stride) % n_scales];
out[i] = in[i] * s / max_range;
}
}
template <typename T>
struct DequantizeFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* in,
const phi::DenseTensor* scale,
T max_range,
phi::DenseTensor* out) {
const T* in_data = in->data<T>();
const T* scale_factor = scale->data<T>();
T* out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
int64_t num = in->numel();
int64_t block_size = std::min(
num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
int64_t max_threads =
dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM
const int64_t max_blocks =
std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
const int64_t grid_size =
std::min(max_blocks, (num + block_size - 1) / block_size);
KeDequantize<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, out_data);
}
};
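For reference, the launch configuration above in plain arithmetic (a sketch only; the device limits below are hypothetical example values, not queried from a real GPUContext):

# Hypothetical example values standing in for the GPUContext queries above.
num = 1 << 20                     # in->numel()
max_threads_per_block = 1024      # dev_ctx.GetMaxThreadsPerBlock()
max_physical_threads = 80 * 2048  # dev_ctx.GetMaxPhysicalThreadCount()

block_size = min(num, max_threads_per_block // 4)                  # 256 threads per block
max_blocks = max((max_physical_threads - 1) // block_size + 1, 1)  # enough blocks to fill the device
grid_size = min(max_blocks, (num + block_size - 1) // block_size)  # capped by the number of element blocks
# KeDequantize then covers all `num` elements with a grid-stride loop over block_size * grid_size threads.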
template <typename T>
struct ChannelDequantizeFunctorV2<phi::GPUContext, T> {
void operator()(const phi::GPUContext& dev_ctx,
@@ -33,7 +83,7 @@ struct ChannelDequantizeFunctorV2<phi::GPUContext, T> {
phi::DenseTensor* out) {
auto in_dims = in->dims();
const T* in_data = in->data<T>();
T* out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
int64_t num = in->numel();
const T* scale_factor = scale->data<T>();
int64_t block_size = std::min(
@@ -61,6 +111,10 @@ struct ChannelDequantizeFunctorV2<phi::GPUContext, T> {
}
};
template struct DequantizeFunctor<phi::GPUContext, phi::dtype::float16>;
template struct DequantizeFunctor<phi::GPUContext, float>;
template struct DequantizeFunctor<phi::GPUContext, double>;
template struct ChannelDequantizeFunctorV2<phi::GPUContext, float16>;
template struct ChannelDequantizeFunctorV2<phi::GPUContext, float>;
template struct ChannelDequantizeFunctorV2<phi::GPUContext, double>;
@@ -70,9 +124,11 @@ template struct ChannelDequantizeFunctorV2<phi::GPUContext, double>;
namespace ops = paddle::operators;
using CUDA = phi::GPUContext;
REGISTER_OP_CUDA_KERNEL(dequantize_linear,
ops::DeQuantizeLinearKernel<CUDA, float>,
ops::DeQuantizeLinearKernel<CUDA, float16>,
ops::DeQuantizeLinearKernel<CUDA, int8_t>,
ops::DeQuantizeLinearKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(quantize_linear,
ops::QuantizeLinearKernel<CUDA, float>,
ops::QuantizeLinearKernel<CUDA, float16>);
@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/fake_dequantize_op.h"
#include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
...@@ -28,6 +27,15 @@ limitations under the License. */ ...@@ -28,6 +27,15 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T>
struct DequantizeFunctor {
void operator()(const DeviceContext& dev_ctx,
const phi::DenseTensor* in,
const phi::DenseTensor* scale,
T max_range,
phi::DenseTensor* out);
};
template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctorV2 {
void operator()(const DeviceContext& dev_ctx,
@@ -105,10 +113,11 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class DeQuantizeLinearKernel : public framework::OpKernel<T> {
public:
template <typename D>
void ComputeImpl(const framework::ExecutionContext& context) const {
auto& dev_ctx = context.template device_context<DeviceContext>();
auto* in = context.Input<phi::DenseTensor>("X");
@@ -122,7 +131,7 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
auto* out = context.Output<phi::DenseTensor>("Y");
int bit_length = context.Attr<int>("bit_length");
auto quant_axis = context.Attr<int>("quant_axis");
dev_ctx.template Alloc<D>(out, out->numel() * sizeof(D));
if (quant_axis < 0) {
float max_range = (std::pow(2, bit_length - 1) - 1);
@@ -144,6 +153,27 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
dev_ctx, &in_tmp, scale, static_cast<D>(max_range), quant_axis, out);
}
}
void Compute(const framework::ExecutionContext& context) const override {
auto* scale = context.Input<phi::DenseTensor>("Scale");
switch (scale->dtype()) {
case experimental::DataType::FLOAT64:
ComputeImpl<double>(context);
break;
case experimental::DataType::FLOAT32:
ComputeImpl<float>(context);
break;
case experimental::DataType::FLOAT16:
ComputeImpl<paddle::platform::float16>(context);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"In DeQuantizeLinearKernel, "
"data type %d for scale/output is not supported ",
scale->dtype()));
break;
}
}
};
} // namespace operators
...
@@ -107,6 +107,7 @@ set_field_default_config(QAT, "weight_bits", 8)
set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)
set_field_default_config(QAT, "onnx_format", True)
# #########################################
# auto tuning configuration
...
@@ -29,7 +29,7 @@ from paddle.distributed import fleet
from paddle.fluid import Variable, core
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.executor import _to_name_str, global_scope
from paddle.fluid.framework import IrGraph, Operator
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layers.utils import flatten
@@ -752,7 +752,9 @@ class Engine:
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
cur_rank = self._cur_rank
# NOTE: After the implementation of the unified dynamic and static communication group
# initialization mode in the future, the initialization logic of full mode
# will be removed because port occupation error may occur.
if self._strategy.auto_mode == "full":
auto_utils.initialize_pg_in_full_mode(
all_process_groups, cur_rank
@@ -763,9 +765,9 @@ class Engine:
continue
process_group.instantiate()
self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace):
self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
if self._strategy.seed:
paddle.seed(self._strategy.seed + self._dp_ranks[0])
@@ -775,10 +777,12 @@ class Engine:
if self._dygraph_mode:
dist_context = self._dist_contexts[mode]
dist_main_program = self._dist_main_progs[mode][self._cur_rank]
self.program_helper.init(
dist_main_program, self._place, dist_context
)
if self._executor is None:
self._executor = paddle.static.Executor(self._place)
uninitialized = []
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
for var in dist_startup_prog.list_vars():
@@ -1612,6 +1616,22 @@ class Engine:
feed_vars = self._feed_vars["predict"]['inputs']
fetch_vars = self._fetch_vars["predict"]['outputs']
dist_main_prog = self._dist_main_progs["predict"][self._cur_rank]
if self._strategy.qat.enable and self._strategy.qat.onnx_format:
from paddle.fluid.contrib.slim.quantization import (
QuantWeightPass,
)
self._logger.info("export quantized model.")
self._logger.info(
"convert config {}".format(self._strategy.qat.to_dict())
)
test_graph = IrGraph(
core.Graph(dist_main_prog.desc), for_test=True
)
quant_weight_pass = QuantWeightPass(global_scope(), self._place)
for sub_graph in test_graph.all_sub_graphs():
quant_weight_pass.apply(sub_graph)
dist_main_prog = test_graph.to_program()
self._saver.save_inference_model(
path,
feed_vars,
...
@@ -131,8 +131,12 @@ class Parallelizer:
else:
# Apply pre optimization passes
time0 = time.time()
(
serial_main_program,
serial_startup_program,
params_grads,
) = self._apply_pre_optimization(
serial_main_program, serial_startup_program, None, None, []
)
self._logger.debug(
"within parallel apply_pre_optimization time: {}, mode {}".format(
@@ -207,22 +211,6 @@ class Parallelizer:
if self._strategy is None:
return
# apply quantization pass
# The pass can be applied when mode must be 'train'
if self._mode == 'train' and self._strategy.qat.enable:
config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_quantization_pass = new_pass(
"auto_parallel_quantization", config
)
auto_parallel_quantization_pass.apply(
[main_program], [startup_program], self._pass_context
)
main_program = self._pass_context.get_attr("main_program")
startup_program = self._pass_context.get_attr("startup_program")
params_grads = self._pass_context.get_attr("params_grads")
# apply amp pass on train/eval/predict
if self._strategy.amp.enable:
config = copy.deepcopy(self._strategy.amp.to_dict())
@@ -247,6 +235,25 @@ class Parallelizer:
)
loss = auto_parallel_amp_pass.get_loss()
# apply quantization pass
# The pass can be applied when mode must be 'train'
if self._strategy.qat.enable:
config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["mode"] = self._mode
config["loss"] = loss
auto_parallel_quantization_pass = new_pass(
"auto_parallel_quantization", config
)
auto_parallel_quantization_pass.apply(
[main_program], [startup_program], self._pass_context
)
main_program = self._pass_context.get_attr("main_program")
startup_program = self._pass_context.get_attr("startup_program")
params_grads = self._pass_context.get_attr("params_grads")
loss = self._pass_context.get_attr("loss")
# apply recompute pass
# recompute is then train-only optimization
if self._mode == "train" and self._strategy.recompute.enable:
...
@@ -2137,6 +2137,36 @@ class Resharder:
input_attrs.append([process_mesh, input_dims_mapping])
return input_attrs
def _get_subblock_output_attrs(self, op, var_name):
# NOTE: Multiple while loops are not supported
assert op.type in _g_subblock_ops
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
output_attrs = []
for op in ops:
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
dist_attr = dist_op.dist_attr
for name in op.output_arg_names:
if name == var_name:
process_mesh = dist_attr.process_mesh
output_dims_mapping = dist_attr.get_output_dims_mapping(
var_name
)
has_exist = False
for output_attr in output_attrs:
if (
process_mesh == output_attr[0]
and output_dims_mapping == output_attr[1]
):
has_exist = True
break
if not has_exist:
output_attrs.append([process_mesh, output_dims_mapping])
return output_attrs
def _get_common_op_input_attrs(self, op, var_name):
process_meshes = []
dist_op = self.dist_context.get_dist_op_for_program(op)
@@ -2166,6 +2196,11 @@ class Resharder:
if op.type in _g_subblock_ops:
op_input_attrs = self._get_subblock_input_attrs(op, var_name)
if not op_input_attrs:
# NOTE: [hack method]
# Adapt to the quantization pass, whose persistable vars (both inputs and outputs) are all in the global block.
# Therefore, the while_op's inputs will contain all the persistable vars, which may be inputs or outputs of the quantization ops in the sub-block.
op_input_attrs = self._get_subblock_output_attrs(op, var_name)
else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name)
...
@@ -877,6 +877,7 @@ class AMPPass(PassBase):
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
first_backward_op, ref_mesh, [-1], self.dist_context
)
self._scaled_loss_grad.op = first_backward_op
# FIXME(JZ-LIANG) a trick to insert backward op
main_block._sync_with_cpp()
elementwise_mul_grad_op_desc = main_block.desc._insert_op(
...
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from paddle.fluid import core, framework
from paddle.fluid.contrib.slim.quantization import (
AddQuantDequantForInferencePass,
AddQuantDequantPassV2,
OutScaleForTrainingPass,
QuantizationTransformPassV2,
@@ -26,6 +27,11 @@ from paddle.fluid.contrib.slim.quantization import (
)
from paddle.fluid.dygraph.parallel import ParallelEnv
from ..auto_parallel.converter import Converter
from ..auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
TensorDistributedAttribute,
)
from .pass_base import PassBase, register_pass
TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
@@ -42,6 +48,8 @@ class QuantizationPass(PassBase):
super().__init__()
self.set_attr("dist_context", None)
self.set_attr("params_grads", None)
self.set_attr("mode", "train")
self.set_attr("loss", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
@@ -57,15 +65,23 @@ class QuantizationPass(PassBase):
dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
mode = self.get_attr("mode")
loss = self.get_attr("loss")
# TODO: scope and place will be removed,
# cause params should be initialized by engine module.
scope = paddle.static.global_scope()
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
# 0. record the relation among blocks
parent_idx_dict = dict()
for block in main_program.blocks:
parent_idx_dict[block.idx] = block.parent_idx
is_test = True if mode != "train" else False
# 1. Program convert to Graph, and this pass is only for train mode
main_graph = framework.IrGraph(
core.Graph(main_program.desc), for_test=mode != "train"
)
# 2. Prepare inputs
@@ -91,6 +107,7 @@ class QuantizationPass(PassBase):
)
# 3. Add quant op for ops which have parameters
if len(transform_pass_ops) > 0:
transform_pass = QuantizationTransformPassV2(
scope=scope,
place=place,
@@ -106,29 +123,51 @@ class QuantizationPass(PassBase):
act_preprocess_func=None,
optimizer_func=None,
executor=None,
is_test=is_test,
)
for sub_graph in main_graph.all_sub_graphs():
transform_pass.apply(sub_graph)
# 4. Add quant op for ops which don't have parameter
if len(quant_dequant_ops) > 0:
quant_dequant_pass = AddQuantDequantPassV2(
scope=scope,
place=place,
quant_bits=self.get_attr('activation_bits'),
skip_pattern=self.get_attr('not_quant_pattern'),
quantizable_op_type=quant_dequant_ops,
is_test=is_test,
)
for sub_graph in main_graph.all_sub_graphs():
quant_dequant_pass.apply(sub_graph)
# 5. Gather quantitative information for the output
out_scale_training_pass = OutScaleForTrainingPass(
scope=scope, place=place, is_test=is_test
)
for sub_graph in main_graph.all_sub_graphs():
out_scale_training_pass.apply(sub_graph)
# 6. When exporting the quant model, traverse to find the output of each op, and insert the quant/dequant op after it.
if mode != "train" and self.get_attr('onnx_format'):
try:
out_scale_infer_pass = AddQuantDequantForInferencePass(
scope=scope,
place=place,
quant_bits=self.get_attr('activation_bits'),
)
# for sub_graph in main_graph.all_sub_graphs():
# out_scale_infer_pass.apply(sub_graph)
except:
logging.warning(
"Unable to convert quant model with onnx_format=True, please update PaddlePaddle >= 2.4.0"
)
# 7. Convert Graph back to Program
quant_program = main_graph.to_program()
quant_program = self.move_presist_var_to_global_block(quant_program)
# 8.1 get new params_grads from quant_program
new_params_grads = []
for param, grad in params_grads:
if param.name not in quant_program.global_block().vars:
@@ -138,7 +177,72 @@ class QuantizationPass(PassBase):
new_grad = quant_program.global_block().vars[grad.name]
new_params_grads.append((new_param, new_grad))
# 8.2 get new loss var
new_loss = None
if loss:
new_loss = quant_program.global_block().vars[loss.name]
# 8.3 recover the relation among blocks
for block in quant_program.blocks:
block.desc._set_forward_block_idx(parent_idx_dict[block.idx])
# 9. complete distributed attribution
self.set_dist_attr_for_qat_program(
quant_program, main_program, dist_context
)
# 10. reset scale var value with dist_attr
self.reset_scope_var(quant_program, dist_context, scope, place)
context.set_attr("main_program", quant_program)
context.set_attr("startup_program", startup_program)
context.set_attr("params_grads", new_params_grads)
context.set_attr("loss", new_loss)
def move_presist_var_to_global_block(self, program):
global_block = program.global_block()
for _op in global_block.ops:
if _op.type == "while":
_block_id = _op.attr("sub_block").id
_block = program.block(_block_id)
persistables = []
for _name, _var in _block.vars.items():
if _var.persistable:
global_block._clone_variable(_var)
persistables.append(_name)
for _name in persistables:
_block._remove_var(_name)
persistables.extend(_op.input('X'))
_op.desc.set_input("X", persistables)
return program
def reset_scope_var(self, quant_program, dist_context, scope, place):
# The var value created by the quantization passes should have the same shape as the value after parallelization.
for var in quant_program.list_vars():
scope_var = scope.find_var(var.name)
if not (scope_var and scope_var.get_tensor()._is_initialized()):
continue
tensor = scope_var.get_tensor()
if var.shape == tensor.shape:
continue
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
dist_attr = {
"dims_mapping": var_dist_attr.dims_mapping,
"process_shape": var_dist_attr.process_mesh.topology,
"process_group": var_dist_attr.process_mesh.processes,
}
# slice tensor_value with dist_attr
sliced_tensor = Converter.slice_with_dist_attr(
np.array(tensor), dist_attr
)
tensor._clear()
tensor.set(sliced_tensor, place)
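As a rough illustration of what the slicing above amounts to (hypothetical example only; the real conversion is delegated to Converter.slice_with_dist_attr and uses the tensor's dims_mapping, process_shape, and process_group):

import numpy as np

# Hypothetical case: a per-channel scale of 8 values sharded along axis 0
# (dims_mapping [0]) over a mesh of 2 processes, so each rank keeps one half.
full_scale = np.arange(8, dtype=np.float32)
num_ranks, rank = 2, 1
per_rank = full_scale.shape[0] // num_ranks
sliced = full_scale[rank * per_rank:(rank + 1) * per_rank]
# `sliced` is what this rank would write back into its scope tensor via tensor.set(...).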
def set_dist_attr_for_qat_program(
self, quant_program, main_program, dist_context
):
# NOTE: hack implementation, to be upgraded soon
for ib, block in enumerate(quant_program.blocks):
# recover origin ops' dist_attr and set quant ops' dist_attr
@@ -150,15 +254,22 @@ class QuantizationPass(PassBase):
"quantize" in quant_op.type
or quant_op.type == "moving_average_abs_max_scale"
):
# set all quantization ops' dist_attr by quantified op
input_name = quant_op.desc.input('X')[0]
if "quantize" in input_name:
input_name = input_name[
: input_name.index(".quantized")
]
if (
quant_op.type == "moving_average_abs_max_scale"
or ip - qat_offset >= len(main_program.blocks[ib].ops)
):
consume_op = (
main_program.blocks[ib]
._var_recursive(input_name)
.op
)
else:
consume_op = main_program.blocks[ib].ops[
ip - qat_offset
@@ -185,13 +296,31 @@ class QuantizationPass(PassBase):
)
for slot_name in quant_op.desc.input_names():
in_name = quant_op.desc.input(slot_name)[0]
input_var = block._var_recursive(in_name)
ref_dims_mapping = [-1]
if slot_name == "X": if slot_name == "X":
continue continue
for in_name in quant_op.desc.input(slot_name): elif slot_name in ['Scale', 'ZeroPoint']:
input_var = block.vars[in_name] if (
quant_op.has_attr('quant_axis')
and quant_op.attr('quant_axis') != -1
):
x_name = quant_op.desc.input('X')[0]
x_var = block._var_recursive(x_name)
x_dist_attr = (
quant_op_dist_attr.get_input_dist_attr(
x_name
)
)
quant_axis = quant_op.attr('quant_axis')
ref_dims_mapping = [
x_dist_attr.dims_mapping[quant_axis]
]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program(
input_var, tensor_dist_attr
)
@@ -201,7 +330,8 @@ class QuantizationPass(PassBase):
for slot_name in quant_op.desc.output_names():
output_name = quant_op.desc.output(slot_name)[0]
output_var = block._var_recursive(output_name)
ref_dims_mapping = [-1]
if slot_name == "Y":
dist_context.set_tensor_dist_attr_for_program(
output_var, consume_input_dist_attr
@@ -209,10 +339,27 @@ class QuantizationPass(PassBase):
quant_op_dist_attr.set_output_dist_attr(
output_name, consume_input_dist_attr
)
continue
elif slot_name == "OutScale":
if (
quant_op.has_attr('quant_axis')
and quant_op.attr('quant_axis') != -1
):
x_name = quant_op.desc.input('X')[0]
x_var = block._var_recursive(x_name)
x_dist_attr = (
quant_op_dist_attr.get_input_dist_attr(
x_name
)
)
quant_axis = quant_op.attr('quant_axis')
ref_dims_mapping = [
x_dist_attr.dims_mapping[quant_axis]
]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr
)
@@ -224,7 +371,7 @@ class QuantizationPass(PassBase):
qat_offset += 1
else:
# recover origin ops' dist_attr
origin_op = main_program.blocks[ib].ops[ip - qat_offset]
quant_op.desc.set_original_id(origin_op.desc.original_id())
dist_origin_op = dist_context.get_dist_op_for_program(
@@ -240,7 +387,21 @@ class QuantizationPass(PassBase):
quant_op_dist_attr.process_mesh = (
origin_op_dist_attr.process_mesh
)
scale_offset = 0
for idx, input_name in enumerate(quant_op.input_arg_names):
if (
origin_op.type == "while"
and input_name not in origin_op.input_arg_names
):
assert (
"@scale" in input_name
or "@zero_point" in input_name
)
scale_offset += 1
continue
idx -= scale_offset
origin_input_name = origin_op.input_arg_names[idx]
origin_input_dist_attr = (
origin_op_dist_attr.inputs_dist_attrs[
@@ -251,20 +412,6 @@ class QuantizationPass(PassBase):
input_name, origin_input_dist_attr
)
if input_name not in main_program.blocks[ib].vars:
origin_input_var = main_program.blocks[ib].vars[
origin_input_name
]
origin_in_tensor_dist_attr = (
dist_context.get_dist_tensor_for_program(
origin_input_var
).dist_attr
)
quant_input_var = block.vars[input_name]
dist_context.set_tensor_dist_attr_for_program(
quant_input_var, origin_in_tensor_dist_attr
)
for idx, output_name in enumerate(
quant_op.output_arg_names
):
@@ -278,16 +425,18 @@ class QuantizationPass(PassBase):
output_name, origin_output_dist_attr
)
if not main_program.blocks[ib]._find_var_recursive(
output_name
):
origin_output_var = main_program.blocks[
ib
]._var_recursive(origin_output_name)
origin_out_tensor_dist_attr = (
dist_context.get_dist_tensor_for_program(
origin_output_var
).dist_attr
)
quant_output_var = block._var_recursive(output_name)
dist_context.set_tensor_dist_attr_for_program(
quant_output_var, origin_out_tensor_dist_attr
)
@@ -308,7 +457,3 @@ class QuantizationPass(PassBase):
dist_context.set_tensor_dist_attr_for_program(
dst_var, dist_tensor.dist_attr
)
context.set_attr("main_program", quant_program)
context.set_attr("startup_program", startup_program)
context.set_attr("params_grads", new_params_grads)
@@ -513,11 +513,12 @@ class QuantizationTransformPass:
var_dtype=var_node.dtype(),
)
scale_name = self._quantized_scale_name(name)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
scale_value = np.array(
self._scope.find_var(scale_name).get_tensor()
@@ -560,11 +561,12 @@ class QuantizationTransformPass:
)
scale_name = self._quantized_scale_name(name)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
scale_value = np.array(
self._scope.find_var(scale_name).get_tensor()
@@ -591,11 +593,12 @@ class QuantizationTransformPass:
shape=[self._window_size],
var_dtype=var_node.dtype(),
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
scales_node,
np.zeros([self._window_size], dtype=data_type),
@@ -640,11 +643,12 @@ class QuantizationTransformPass:
var_dtype=var_node.dtype(),
)
scale_name = self._quantized_scale_name(name)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
scale_value = np.array(
self._scope.find_var(scale_name).get_tensor()
@@ -669,11 +673,12 @@ class QuantizationTransformPass:
var_dtype=var_node.dtype(),
shape=[1],
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
state_in_node,
np.ones([1], dtype=data_type),
@@ -746,11 +751,12 @@ class QuantizationTransformPass:
var_dtype=var_node.dtype(),
)
scale_name = self._quantized_scale_name(name)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
scale_value = np.array(
self._scope.find_var(scale_name).get_tensor()
@@ -1287,11 +1293,13 @@ class QuantizationFreezePass:
shape=[channel_scale.shape[0]],
var_dtype=output_var_node.dtype(),
)
if output_var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif output_var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
weight_scale_node,
channel_scale.astype(data_type),
@@ -1439,6 +1447,7 @@ class QuantizationFreezePass:
def _is_float(self, v):
return (
isinstance(v, float)
or isinstance(v, np.float16)
or isinstance(v, np.float32)
or isinstance(v, np.float64)
)
@@ -1636,14 +1645,17 @@ class OutScaleForTrainingPass:
if in_node.dtype() not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]:
continue
if in_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif in_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
graph._find_node_by_name(
graph.all_var_nodes(),
@@ -1781,6 +1793,7 @@ class OutScaleForInferencePass:
not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]
):
continue
@@ -1997,11 +2010,12 @@ class AddQuantDequantPass:
var_dtype=var_node.dtype(),
)
scale_name = "{}.quant_dequant@scale".format(var_node.name())
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
try:
if (
self._scale_dict is not None
@@ -2036,11 +2050,12 @@ class AddQuantDequantPass:
var_dtype=var_node.dtype(),
shape=[1],
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
state_in_node,
np.ones([1], dtype=data_type),
@@ -2149,11 +2164,12 @@ class InsertQuantizeLinear:
var_dtype=var_node.dtype(),
)
if not scale_var_node:
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
scale_name = self._quantized_scale_name(var_name)
if self.channel_wise:
scale_var_shape = var_node.shape()[self.quant_axis]
@@ -2218,11 +2234,12 @@ class InsertQuantizeLinear:
var_dtype=var_node.dtype(),
shape=[1],
)
if var_node.dtype() == core.VarDesc.VarType.FP64:
data_type = 'float64'
elif var_node.dtype() == core.VarDesc.VarType.FP32:
data_type = 'float32'
else:
data_type = "float16"
_init_var_node(
state_in_node,
np.ones([1], dtype=data_type),
@@ -3277,6 +3294,7 @@ class AddQuantDequantForInferencePass:
if out_node.dtype() not in [
core.VarDesc.VarType.FP64,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
]:
continue
if var_name in dequantized_vars_map:
...
@@ -74,6 +74,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
py_test_modules(test_selective_recompute MODULES test_selective_recompute)
set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
py_test_modules(test_pass_quantization MODULES test_pass_quantization ENVS
${dist_ENVS})
set_tests_properties(test_pass_quantization PROPERTIES TIMEOUT 60)
py_test_modules(test_tuning_recompute MODULES test_tuning_recompute)
set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240)
@@ -113,7 +116,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_process_mesh MODULES test_process_mesh)
py_test_modules(test_interface MODULES test_interface)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_dist_shape MODULES test_dist_shape)
py_test_modules(test_dist_assign MODULES test_dist_assign)
py_test_modules(test_conditional_block_reshard MODULES
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from get_gpt_model import FakeDataset, create_data_holder, generate_model
import paddle
from paddle.distributed.fleet import auto
paddle.enable_static()
def apply_pass():
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
amp = dist_strategy.amp
amp.enable = True
amp.custom_white_list = ["lookup_table", "lookup_table_v2"]
amp.custom_black_list = [
"reduce_sum",
"c_softmax_with_cross_entropy",
"elementwise_div",
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
qat = dist_strategy.qat
qat.enable = True
qat.channel_wise_abs_max = True
qat.weight_bits = 8
qat.activation_bits = 8
qat.not_quant_pattern = ['skip_quant']
qat.onnx_format = True
return dist_strategy
class TestQuantizationPassTrain(unittest.TestCase):
def test_qat_pass_training(self):
batch_size = 1
batch_num = 10
strategy = apply_pass()
model, loss = generate_model("mp")
opt = paddle.optimizer.AdamW(learning_rate=0.00001)
engine = auto.Engine(model, loss, opt, strategy=strategy)
dataset = FakeDataset(batch_size * batch_num)
engine.fit(dataset, 3, batch_size=batch_size)
self.check_program(engine.main_program)
def check_program(self, program):
quantizable_op_and_inputs = {'matmul_v2': ['X', 'Y']}
quantizable_grad_op_inputs = {'matmul_v2_grad': ['X', 'Y']}
quantized_ops = set()
for block in program.blocks:
for idx, op in enumerate(block.ops):
is_quntized = False
if op.type in quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
if ".quantized" in arg_name:
is_quntized = True
if not is_quntized:
continue
# check forward
if op.type in quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
if "c_identity" in arg_name:
arg_name = block.ops[idx - 1].input_arg_names[0]
assert arg_name.endswith('.quantized.dequantized')
quantized_ops.add(arg_name)
for op in block.ops:
is_quntized = False
if op.type in quantizable_grad_op_inputs:
for pname in quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
if ".quantized" in arg_name:
is_quntized = True
if not is_quntized:
continue
# check backward
if op.type in quantizable_grad_op_inputs:
for pname in quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
assert arg_name.endswith('.quantized.dequantized')
assert arg_name in quantized_ops
class TestQuantizationPassExport(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_qat_pass_2(self):
batch_size = 1
batch_num = 10
strategy = apply_pass()
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, strategy=strategy)
inputs_spec, labels_spec = create_data_holder(batch_size=1)
engine.prepare(inputs_spec, labels_spec, mode="predict")
path = os.path.join(self.temp_dir.name, 'inf')
engine.save(path, training=False)
self.check_export(engine._executor)
def check_export(self, exe):
sequence_len = 512
vocab_size = 1000
tokens = [np.random.randint(vocab_size, size=sequence_len)]
position_ids = [np.arange(sequence_len)]
attention_mask = [np.tril(np.ones(sequence_len))]
path_prefix = os.path.join(
self.temp_dir.name,
'inf_dist{}'.format(paddle.distributed.get_rank()),
)
[
inference_program,
feed_target_names,
fetch_targets,
] = paddle.static.load_inference_model(
path_prefix=path_prefix, executor=exe
)
out = exe.run(
inference_program,
feed={
"tokens": tokens,
"position_ids": position_ids,
"attention_mask": attention_mask,
},
fetch_list=fetch_targets,
)
if __name__ == "__main__":
unittest.main()
@@ -12,83 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest


class TestQuantizationPass(unittest.TestCase):
    def test_mp2(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(
            file_dir, "quantization_pass_unittest.py"
        )

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )

        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()


if __name__ == "__main__":
...
@@ -834,6 +834,7 @@ class GPTForPretraining(nn.Layer):
x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]
with paddle.fluid.name_scope('skip_quant'):
if mesh:
matmul = auto.shard_op(
paddle.matmul, mesh, [x_dims_mapping, w_dims_mapping, None]
...
@@ -26,7 +26,7 @@ def quantize_max_abs(x, max_range):
def dequantize_max_abs(x, scale, max_range):
y = x * scale / max_range
return y
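The reference helper now multiplies before dividing, matching the order used by the dequantize kernels in this PR (in[i] * scale[0] / max_range); in float16 the two orders can round differently, so keeping them aligned avoids spurious mismatches in the new fp16 test. A quick worked example with illustrative numbers:

# Illustration only: 8-bit dequantization, so max_range = 2**7 - 1 = 127.
scale, max_range = 0.5, 127.0
yq = 127                    # a quantized int8 value at the positive extreme
y = yq * scale / max_range  # 127 * 0.5 / 127 = 0.5, the recovered float value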
@@ -292,6 +292,45 @@ class TestDequantizeOpDouble(TestDequantizeOp):
self.quant_axis = -1
class TestDequantizeOpHalf(TestDequantizeOp):
def set_args(self):
self.bit_length = 8
self.max_range = math.pow(2, self.bit_length - 1) - 1
self.data_type = "float16"
self.quant_axis = -1
def setUp(self):
self.set_args()
self.op_type = "dequantize_linear"
x = np.random.randn(31, 65).astype(np.float16)
yq, scale = quantize_max_abs(x, self.max_range)
scale = np.array(scale).astype('float16')
yq = np.array(yq).astype('int8')
ydq = dequantize_max_abs(yq, scale, self.max_range)
ydq = ydq.astype('float16')
zero_point = np.zeros(scale.shape, dtype="int32")
self.inputs = {'X': yq, 'Scale': scale, 'ZeroPoint': zero_point}
self.attrs = {
'bit_length': self.bit_length,
'quant_axis': self.quant_axis,
}
self.outputs = {'Y': ydq}
def _get_places(self):
import paddle
import paddle.fluid.core as core
if core.is_compiled_with_cuda():
place = paddle.fluid.core.CUDAPlace(0)
if paddle.fluid.core.is_float16_supported(place):
return [place]
else:
return []
else:
return []
class TestDequantizeOp5Bits(TestDequantizeOp):
def set_args(self):
self.bit_length = 5
...