From 5825196db95855da40083fc16322ae0691fe8289 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 29 May 2018 10:34:08 +0800 Subject: [PATCH] fix sgd for SelectedRows bug --- paddle/fluid/operators/sgd_op.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/sgd_op.h b/paddle/fluid/operators/sgd_op.h index f3e88b0a0b..f9e0596191 100644 --- a/paddle/fluid/operators/sgd_op.h +++ b/paddle/fluid/operators/sgd_op.h @@ -96,8 +96,12 @@ class SGDOpKernel : public framework::OpKernel { return; } - size_t param_row_width = param.value().numel() / param.rows().size(); - size_t grad_row_width = grad.value().numel() / grad.rows().size(); + auto param_row_width = param.value().dims()[1]; + auto grad_row_width = grad.value().dims()[1]; + VLOG(4) << " param rows: " << param.rows().size() + << " param memory rows: " << param.value().dims()[0] + << " grad rows: " << grad.rows().size() + << " grad memory rows: " << grad.value().dims()[0]; PADDLE_ENFORCE_EQ(param_row_width, grad_row_width, "param_row should have the same size with grad_row"); -- GitLab