From ea0df4e8a2cf291a0e6626771c58d1d75635b3c1 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sat, 16 Mar 2019 15:11:45 +0800 Subject: [PATCH] add some check --- .../fluid/operators/distributed/parameter_recv.cc | 3 +++ .../fluid/operators/distributed/parameter_send.cc | 2 +- .../operators/distributed_ops/send_recv_util.h | 10 ---------- paddle/fluid/operators/split_selected_rows_op.h | 13 +++++++++---- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc index c3238f28f63..ae6516b2464 100644 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -80,7 +80,9 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, framework::Tensor *recv_tensor = recv_var->GetMutable(); auto dev_ctx = paddle::platform::CPUDeviceContext(); + int64_t recv_numel = 0; for (auto *in : recved_tensors) { + recv_numel += in->numel(); auto in_stride = framework::stride_numel(in->dims()); auto out_stride = framework::stride_numel(recv_tensor->dims()); StridedNumelCopyWithAxis( @@ -88,6 +90,7 @@ void ParameterRecv::operator()(const RpcContext &rpc_ctx, in->data(), in_stride, in_stride[0]); output_offset += in_stride[0]; } + PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel()); } delete local_scope; diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc index 388bc781c13..ec2884c2529 100644 --- a/paddle/fluid/operators/distributed/parameter_send.cc +++ b/paddle/fluid/operators/distributed/parameter_send.cc @@ -99,7 +99,7 @@ void ParameterSend::operator()(const RpcContext &rpc_ctx, // split rows index into output sparse vars for (size_t i = 0; i < send_rows.size(); ++i) { - int out_idx = FindOutIdx(send_rows[i], abs_sections); + int out_idx = GetSectionIndex(send_rows[i], abs_sections); outs_rows_idx[out_idx].push_back(send_rows[i]); outs_dense_idx[out_idx].push_back(i); } diff --git a/paddle/fluid/operators/distributed_ops/send_recv_util.h b/paddle/fluid/operators/distributed_ops/send_recv_util.h index 01caee9a925..c05a1ff1da8 100644 --- a/paddle/fluid/operators/distributed_ops/send_recv_util.h +++ b/paddle/fluid/operators/distributed_ops/send_recv_util.h @@ -48,16 +48,6 @@ inline bool NeedSend(const framework::Scope& scope, return false; } -inline int FindOutIdx(int row, const std::vector& abs_sections) { - for (size_t i = 1; i < abs_sections.size(); ++i) { - if (row < abs_sections[i]) { - return i - 1; - } - } - PADDLE_ENFORCE_LT(row, abs_sections.back(), "row should be less then max id"); - return abs_sections.size() - 1; -} - inline std::vector ToAbsoluteSection( const std::vector& height_sections) { std::vector abs_sections; diff --git a/paddle/fluid/operators/split_selected_rows_op.h b/paddle/fluid/operators/split_selected_rows_op.h index c29065649e6..9ec459e2a68 100644 --- a/paddle/fluid/operators/split_selected_rows_op.h +++ b/paddle/fluid/operators/split_selected_rows_op.h @@ -32,7 +32,8 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { auto abs_sections = ToAbsoluteSection(height_sections); - auto x_rows = x->rows(); + auto& x_rows = x->rows(); + auto height = x->height(); std::vector> outs_rows_idx; std::vector> outs_dense_idx; @@ -44,8 +45,10 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { // split rows index into output sparse vars for (size_t i = 0; i < x_rows.size(); ++i) { - int out_idx = FindOutIdx(x_rows[i], abs_sections); - outs_rows_idx[out_idx].push_back(x_rows[i]); + auto& id = x_rows[i]; + PADDLE_ENFORCE_LT(id, height); + int out_idx = GetSectionIndex(id, abs_sections); + outs_rows_idx[out_idx].push_back(id); outs_dense_idx[out_idx].push_back(i); } auto place = ctx.GetPlace(); @@ -59,7 +62,9 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { outs[i]->mutable_rows()->clear(); if (rows_idx.size() > 0) { for (auto idx : rows_idx) { - outs[i]->mutable_rows()->push_back(idx - abs_sections[i]); + auto id_offset = idx - abs_sections[i]; + PADDLE_ENFORCE_LT(id_offset, height_sections[i]); + outs[i]->mutable_rows()->push_back(id_offset); } auto dst = outs[i]->mutable_value()->mutable_data(ctx.GetPlace()); for (size_t j = 0; j < rows_idx.size(); j++) { -- GitLab