提交 9b63b7dd 编写于 作者: A Abhinav Arora

Fix warnings in split_ids_op

上级 59234b72
...@@ -60,7 +60,9 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -60,7 +60,9 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
} else if (ids_var->IsType<framework::SelectedRows>()) { } else if (ids_var->IsType<framework::SelectedRows>()) {
const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids"); const auto *ids_selected_rows = ctx.Input<framework::SelectedRows>("Ids");
auto &ids_dims = ids_selected_rows->value().dims(); auto &ids_dims = ids_selected_rows->value().dims();
PADDLE_ENFORCE_EQ(ids_dims[0], ids_selected_rows->rows().size(), ""); PADDLE_ENFORCE_EQ(ids_dims[0],
static_cast<int64_t>(ids_selected_rows->rows().size()),
"");
const T *ids = ids_selected_rows->value().data<T>(); const T *ids = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows(); const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out"); auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
...@@ -77,7 +79,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -77,7 +79,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
framework::DDim ddim = framework::make_ddim( framework::DDim ddim = framework::make_ddim(
{static_cast<int64_t>(out->rows().size()), row_width}); {static_cast<int64_t>(out->rows().size()), row_width});
T *output = out->mutable_value()->mutable_data<T>(ddim, place); T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (size_t i = 0; i < ddim[0]; ++i) { for (int64_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width, ids + out->rows()[i] * row_width, memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册