diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc index c3238f28f6376f3ae61a8dcce27f88ab0c703da5..ae6516b2464326eaa6e6731ae446090c8ab36042 100644 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -80,7 +80,9 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, framework::Tensor *recv_tensor = recv_var->GetMutable(); auto dev_ctx = paddle::platform::CPUDeviceContext(); + int64_t recv_numel = 0; for (auto *in : recved_tensors) { + recv_numel += in->numel(); auto in_stride = framework::stride_numel(in->dims()); auto out_stride = framework::stride_numel(recv_tensor->dims()); StridedNumelCopyWithAxis( @@ -88,6 +90,7 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, in->data(), in_stride, in_stride[0]); output_offset += in_stride[0]; } + PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel()); } delete local_scope; diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc index 388bc781c1387639042a70f08ead643a3048f77c..ec2884c25290aa3cfd9818ead61119cc6c6b6feb 100644 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -99,7 +99,7 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, // split rows index into output sparse vars for (size_t i = 0; i < send_rows.size(); ++i) { - int out_idx = FindOutIdx(send_rows[i], abs_sections); + int out_idx = GetSectionIndex(send_rows[i], abs_sections); outs_rows_idx[out_idx].push_back(send_rows[i]); outs_dense_idx[out_idx].push_back(i); } diff --git a/paddle/fluid/operators/distributed_ops/send_recv_util.h b/paddle/fluid/operators/distributed_ops/send_recv_util.h index 01caee9a9250dbb6084181e9d8aad7852e7ef8f5..c05a1ff1da8803c1ef3161d0e9d8604f9f1e5f3b 100644 --- a/paddle/fluid/operators/distributed_ops/send_recv_util.h +++ b/paddle/fluid/operators/distributed_ops/send_recv_util.h @@ -48,16 +48,6 @@ inline bool NeedSend(const framework::Scope& scope, return false; } -inline int FindOutIdx(int row, const std::vector& abs_sections) { - for (size_t i = 1; i < abs_sections.size(); ++i) { - if (row < abs_sections[i]) { - return i - 1; - } - } - PADDLE_ENFORCE_LT(row, abs_sections.back(), "row should be less then max id"); - return abs_sections.size() - 1; -} - inline std::vector ToAbsoluteSection( const std::vector& height_sections) { std::vector abs_sections; diff --git a/paddle/fluid/operators/split_selected_rows_op.h b/paddle/fluid/operators/split_selected_rows_op.h index c29065649e6ee0168ce44038b36dff25c62c082f..9ec459e2a68d85af526e741d7fd9ecd858383132 100644 --- a/paddle/fluid/operators/split_selected_rows_op.h +++ b/paddle/fluid/operators/split_selected_rows_op.h @@ -32,7 +32,8 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { auto abs_sections = ToAbsoluteSection(height_sections); - auto x_rows = x->rows(); + auto& x_rows = x->rows(); + auto height = x->height(); std::vector> outs_rows_idx; std::vector> outs_dense_idx; @@ -44,8 +45,10 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { // split rows index into output sparse vars for (size_t i = 0; i < x_rows.size(); ++i) { - int out_idx = FindOutIdx(x_rows[i], abs_sections); - outs_rows_idx[out_idx].push_back(x_rows[i]); + auto& id = x_rows[i]; + PADDLE_ENFORCE_LT(id, height); + int out_idx = GetSectionIndex(id, abs_sections); + outs_rows_idx[out_idx].push_back(id); outs_dense_idx[out_idx].push_back(i); } auto place = ctx.GetPlace(); @@ -59,7 +62,9 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { outs[i]->mutable_rows()->clear(); if (rows_idx.size() > 0) { for (auto idx : rows_idx) { - outs[i]->mutable_rows()->push_back(idx - abs_sections[i]); + auto id_offset = idx - abs_sections[i]; + PADDLE_ENFORCE_LT(id_offset, height_sections[i]); + outs[i]->mutable_rows()->push_back(id_offset); } auto dst = outs[i]->mutable_value()->mutable_data(ctx.GetPlace()); for (size_t j = 0; j < rows_idx.size(); j++) {