Unverified · Commit e51d6f9d authored by zhouzj, committed by GitHub

Solve the OOM problem of ReconPTQ (#1523)

* Solve the OOM problem of QDrop.

* Fix syntax.
Parent: e94c69ce
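The heart of the fix is in the first hunk below: the student program is now cloned once outside the region loop instead of once per region, backpropagation is cut at each region's input via stop_gradient, and Adam receives an explicit parameters list containing only the per-region alpha variables, so gradients and optimizer state are no longer built for the whole network. A minimal sketch of this pattern in a Paddle static graph (the fc1/fc2 layers, shapes, and the fc2-prefix filter are hypothetical stand-ins, not taken from the patch):

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
    h = paddle.static.nn.fc(x, size=4, name='fc1')
    # Cut backprop at the region boundary, as the patch does with
    # stop_gradient = True on each region's input variable.
    h.stop_gradient = True
    out = paddle.static.nn.fc(h, size=2, name='fc2')
    loss = paddle.mean(out)
    # Hand Adam only the parameters of the current region; its moment
    # tensors are then allocated for these variables alone.
    update_params = [
        p for p in main_prog.global_block().all_parameters()
        if p.name.startswith('fc2')
    ]
    opt = paddle.optimizer.Adam(learning_rate=1e-3, parameters=update_params)
    opt.minimize(loss)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)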
@@ -358,44 +358,42 @@ class ReconstructionQuanter(object):
     def _run(self):
         self._preprocess()
         startup_program = paddle.static.Program()
+        tmp_program = self._student_program.clone()
         for k in range(len(self._regions)):
             region_ = self._regions[k]
-            names = self._region_weights_names[k]
-            tmp_program = self._student_program.clone()
+            tmp_program.global_block().var(region_[0]).stop_gradient = True
             quant_op_out_name = region_[1]
-            loss_function = ReconstructionQuanterLoss(
-                program=tmp_program, weight_region_names=names)
             with paddle.static.program_guard(tmp_program, startup_program):
+                names = self._region_weights_names[k]
+                _logger.info(f"Current weights: {names}")
+                loss_function = ReconstructionQuanterLoss(
+                    program=tmp_program, weight_region_names=names)
+                update_params = [
+                    tmp_program.global_block().var(name + '.alpha')
+                    for name in names
+                ]
                 student_var = tmp_program.global_block().var(quant_op_out_name)
                 teacher_var = tmp_program.global_block().var("teacher_" +
                                                              quant_op_out_name)
-                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
-                    learning_rate=20,
-                    eta_min=2,
-                    T_max=2000,
-                    verbose=True, )
                 total_loss, recon_loss, round_loss = loss_function.get_loss(
                     student_var,
-                    teacher_var,
-                    scheduler, )
+                    teacher_var, )
                 train_fetches_loss = {
                     "total_loss": total_loss,
                     "recon_loss": recon_loss,
                     "round_loss": round_loss,
                 }
-                optimizer = paddle.optimizer.Adam(learning_rate=self._lr)
+                optimizer = paddle.optimizer.Adam(
+                    learning_rate=self._lr, parameters=update_params)
                 optimizer.minimize(total_loss)
             self._exe.run(startup_program)
             start_time = time.time()
             prev_start_time = start_time
-            loader = self._data_loader()
             for epoch in range(self._epochs):
-                for i, data in (
-                        enumerate(loader) if
-                        (isinstance(self._data_loader, paddle.fluid.io.DataLoader)
-                         and self._data_loader.batch_size == 1) else
-                        enumerate(self._data_loader())):
+                for i, data in (enumerate(self._data_loader())):
                     prev_start_time = start_time
                     start_time = time.time()
                     out = self._exe.run(
@@ -406,14 +404,14 @@ class ReconstructionQuanter(object):
                         ],
                         return_numpy=True, )
                     _logger.info(
-                        "Iter {:d}, lr {}, total_loss {:.5f}, recon_loss {:.5f}, round_loss {:.5f}, time {:.5f}s"
-                        .format(epoch, self._lr,
+                        "Epoch {:d}, Iter {:d}, lr {}, total_loss {:.5f}, recon_loss {:.5f}, round_loss {:.5f}, time {:.5f}s"
+                        .format(epoch, i, self._lr,
                                 np.mean(out[0]),
                                 np.mean(out[1]),
                                 np.mean(out[2]),
                                 start_time - prev_start_time), )
                     sys.stdout.flush()
-                    if i == self._num_iterations:
+                    if i + 1 == self._num_iterations:
                         break
         self._update_scale()
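Two behavioral notes on the loop above. First, i is zero-based, so the old check if i == self._num_iterations ran one extra batch; i + 1 == self._num_iterations stops after exactly _num_iterations batches. Second, the first hunk also drops the cached loader = self._data_loader(): when that call returns a plain generator, reusing it across epochs silently yields nothing after the first pass, whereas calling self._data_loader() afresh each epoch always iterates the full data. A small self-contained illustration (the data_loader function here is hypothetical):

def data_loader():
    # Hypothetical generator-style loader yielding two batches.
    for batch in range(2):
        yield batch

gen = data_loader()
for epoch in range(2):
    # Reusing one generator: the second epoch sees no data.
    print('shared   ', epoch, list(gen))            # [0, 1], then []

for epoch in range(2):
    # Fresh call per epoch, as in the patched loop.
    print('per-epoch', epoch, list(data_loader()))  # [0, 1] both times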
@@ -831,7 +829,7 @@ class ReconstructionQuanterLoss(object):
             paddle.nn.functional.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, 0,
             1)
 
-    def get_loss(self, student_tensor, teacher_tensor, scheduler):
+    def get_loss(self, student_tensor, teacher_tensor, scheduler=None):
         if self.rec_loss_type == 'mse':
             rec_loss = paddle.nn.functional.mse_loss(
                 student_tensor,
...
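The last hunk makes scheduler an optional keyword rather than deleting it outright, now that the hardcoded CosineAnnealingDecay is gone from _run; call sites that still pass a scheduler keep working. A generic sketch of this compatibility pattern (the loss body is a hypothetical stand-in, not the real method, which dispatches on self.rec_loss_type):

def get_loss(student_tensor, teacher_tensor, scheduler=None):
    # scheduler is accepted but no longer required, so legacy callers
    # that pass one (as _run used to) do not break.
    return ((student_tensor - teacher_tensor) ** 2).mean()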