未验证 提交 5a4d9328 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #9888 from abhinavarora/fix_warnings_

Fix warnings in sgd_op.h
...@@ -65,7 +65,8 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -65,7 +65,8 @@ class SGDOpKernel : public framework::OpKernel<T> {
auto &grad_rows = grad->rows(); auto &grad_rows = grad->rows();
size_t grad_row_numel = grad_value.numel() / grad_rows.size(); size_t grad_row_numel = grad_value.numel() / grad_rows.size();
PADDLE_ENFORCE_EQ(grad_row_numel, param_out->numel() / grad_height); PADDLE_ENFORCE_EQ(static_cast<int64_t>(grad_row_numel),
param_out->numel() / grad_height);
auto *grad_data = grad_value.data<T>(); auto *grad_data = grad_value.data<T>();
auto *out_data = param_out->data<T>(); auto *out_data = param_out->data<T>();
...@@ -73,7 +74,7 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -73,7 +74,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < grad_rows.size(); i++) { for (size_t i = 0; i < grad_rows.size(); i++) {
PADDLE_ENFORCE(grad_rows[i] < grad_height, PADDLE_ENFORCE(grad_rows[i] < grad_height,
"Input rows index should less than height"); "Input rows index should less than height");
for (int64_t j = 0; j < grad_row_numel; j++) { for (size_t j = 0; j < grad_row_numel; j++) {
out_data[grad_rows[i] * grad_row_numel + j] -= out_data[grad_rows[i] * grad_row_numel + j] -=
lr[0] * grad_data[i * grad_row_numel + j]; lr[0] * grad_data[i * grad_row_numel + j];
} }
...@@ -107,7 +108,7 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -107,7 +108,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(grad.rows()[i] < grad.height(), PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
"Input rows index should less than height"); "Input rows index should less than height");
int64_t id_index = param.index(grad.rows()[i]); int64_t id_index = param.index(grad.rows()[i]);
for (int64_t j = 0; j < grad_row_width; j++) { for (size_t j = 0; j < grad_row_width; j++) {
out_data[id_index * grad_row_width + j] -= out_data[id_index * grad_row_width + j] -=
lr[0] * grad_data[i * grad_row_width + j]; lr[0] * grad_data[i * grad_row_width + j];
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册