提交 5825196d 编写于 作者: Q qiaolongfei

fix sgd for SelectedRows bug

上级 c797aded
...@@ -96,8 +96,12 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -96,8 +96,12 @@ class SGDOpKernel : public framework::OpKernel<T> {
return; return;
} }
size_t param_row_width = param.value().numel() / param.rows().size(); auto param_row_width = param.value().dims()[1];
size_t grad_row_width = grad.value().numel() / grad.rows().size(); auto grad_row_width = grad.value().dims()[1];
VLOG(4) << " param rows: " << param.rows().size()
<< " param memory rows: " << param.value().dims()[0]
<< " grad rows: " << grad.rows().size()
<< " grad memory rows: " << grad.value().dims()[0];
PADDLE_ENFORCE_EQ(param_row_width, grad_row_width, PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
"param_row should have the same size with grad_row"); "param_row should have the same size with grad_row");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册