未验证 提交 e51d6f9d 编写于 作者: Z zhouzj 提交者: GitHub

Solve the OOM problem of ReconPTQ (#1523)

* Solve the 'oom' problem of QDrop.

* Fix syntax.
上级 e94c69ce
......@@ -358,44 +358,42 @@ class ReconstructionQuanter(object):
def _run(self):
self._preprocess()
startup_program = paddle.static.Program()
tmp_program = self._student_program.clone()
for k in range(len(self._regions)):
region_ = self._regions[k]
names = self._region_weights_names[k]
tmp_program = self._student_program.clone()
tmp_program.global_block().var(region_[0]).stop_gradient = True
quant_op_out_name = region_[1]
names = self._region_weights_names[k]
_logger.info(f"Current weights: {names}")
loss_function = ReconstructionQuanterLoss(
program=tmp_program, weight_region_names=names)
update_params = [
tmp_program.global_block().var(name + '.alpha')
for name in names
]
with paddle.static.program_guard(tmp_program, startup_program):
loss_function = ReconstructionQuanterLoss(
program=tmp_program, weight_region_names=names)
student_var = tmp_program.global_block().var(quant_op_out_name)
teacher_var = tmp_program.global_block().var("teacher_" +
quant_op_out_name)
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=20,
eta_min=2,
T_max=2000,
verbose=True, )
total_loss, recon_loss, round_loss = loss_function.get_loss(
student_var,
teacher_var,
scheduler, )
teacher_var, )
train_fetches_loss = {
"total_loss": total_loss,
"recon_loss": recon_loss,
"round_loss": round_loss,
}
optimizer = paddle.optimizer.Adam(learning_rate=self._lr)
optimizer = paddle.optimizer.Adam(
learning_rate=self._lr, parameters=update_params)
optimizer.minimize(total_loss)
self._exe.run(startup_program)
start_time = time.time()
prev_start_time = start_time
loader = self._data_loader()
for epoch in range(self._epochs):
for i, data in (
enumerate(loader) if
(isinstance(self._data_loader, paddle.fluid.io.DataLoader)
and self._data_loader.batch_size == 1) else
enumerate(self._data_loader())):
for i, data in (enumerate(self._data_loader())):
prev_start_time = start_time
start_time = time.time()
out = self._exe.run(
......@@ -406,14 +404,14 @@ class ReconstructionQuanter(object):
],
return_numpy=True, )
_logger.info(
"Iter {:d}, lr {}, total_loss {:.5f}, recon_loss {:.5f}, round_loss {:.5f}, time {:.5f}s"
.format(epoch, self._lr,
"Epoch {:d}, Iter {:d}, lr {}, total_loss {:.5f}, recon_loss {:.5f}, round_loss {:.5f}, time {:.5f}s"
.format(epoch, i, self._lr,
np.mean(out[0]),
np.mean(out[1]),
np.mean(out[2]),
start_time - prev_start_time), )
sys.stdout.flush()
if i == self._num_iterations:
if i + 1 == self._num_iterations:
break
self._update_scale()
......@@ -831,7 +829,7 @@ class ReconstructionQuanterLoss(object):
paddle.nn.functional.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, 0,
1)
def get_loss(self, student_tensor, teacher_tensor, scheduler):
def get_loss(self, student_tensor, teacher_tensor, scheduler=None):
if self.rec_loss_type == 'mse':
rec_loss = paddle.nn.functional.mse_loss(
student_tensor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册