提交 5825196d 编写于 作者: Q qiaolongfei

fix sgd for SelectedRows bug

上级 c797aded
......@@ -96,8 +96,12 @@ class SGDOpKernel : public framework::OpKernel<T> {
return;
}
size_t param_row_width = param.value().numel() / param.rows().size();
size_t grad_row_width = grad.value().numel() / grad.rows().size();
auto param_row_width = param.value().dims()[1];
auto grad_row_width = grad.value().dims()[1];
VLOG(4) << " param rows: " << param.rows().size()
<< " param memory rows: " << param.value().dims()[0]
<< " grad rows: " << grad.rows().size()
<< " grad memory rows: " << grad.value().dims()[0];
PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
"param_row should have the same size with grad_row");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册