From c5c7dc2e826f344ccecdbca8c02314f2660d729a Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Wed, 18 Apr 2018 16:29:34 -0700 Subject: [PATCH] Fix CPPLint errors in multiclass_nms, nccl, nce, reduce and save_load_combine (#10032) * Fix CPPLint errors in multiclass_nms, nccl, nce, reduce and save_load_combine * Fix --- paddle/fluid/operators/multiclass_nms_op.cc | 20 +++---- paddle/fluid/operators/nccl_op.cu.cc | 5 +- paddle/fluid/operators/nce_op.h | 7 +-- paddle/fluid/operators/reduce_op.h | 52 +++++++++--------- .../operators/save_load_combine_op_test.cc | 54 +++++++++---------- 5 files changed, 70 insertions(+), 68 deletions(-) diff --git a/paddle/fluid/operators/multiclass_nms_op.cc b/paddle/fluid/operators/multiclass_nms_op.cc index 0f80f752c9..a12b975326 100644 --- a/paddle/fluid/operators/multiclass_nms_op.cc +++ b/paddle/fluid/operators/multiclass_nms_op.cc @@ -173,8 +173,8 @@ class MultiClassNMSKernel : public framework::OpKernel { void MultiClassNMS(const framework::ExecutionContext& ctx, const Tensor& scores, const Tensor& bboxes, - std::map>& indices, - int& num_nmsed_out) const { + std::map>* indices, + int* num_nmsed_out) const { int64_t background_label = ctx.Attr("background_label"); int64_t nms_top_k = ctx.Attr("nms_top_k"); int64_t keep_top_k = ctx.Attr("keep_top_k"); @@ -189,15 +189,15 @@ class MultiClassNMSKernel : public framework::OpKernel { if (c == background_label) continue; Tensor score = scores.Slice(c, c + 1); NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k, - &(indices[c])); - num_det += indices[c].size(); + &((*indices)[c])); + num_det += (*indices)[c].size(); } - num_nmsed_out = num_det; + *num_nmsed_out = num_det; const T* scores_data = scores.data(); if (keep_top_k > -1 && num_det > keep_top_k) { std::vector>> score_index_pairs; - for (const auto& it : indices) { + for (const auto& it : *indices) { int label = it.first; const T* sdata = scores_data + label * predict_dim; const std::vector& label_indices = it.second; @@ -220,13 +220,13 @@ class MultiClassNMSKernel : public framework::OpKernel { int idx = score_index_pairs[j].second.second; new_indices[label].push_back(idx); } - new_indices.swap(indices); - num_nmsed_out = keep_top_k; + new_indices.swap(*indices); + *num_nmsed_out = keep_top_k; } } void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, - std::map>& selected_indices, + const std::map>& selected_indices, Tensor* outs) const { int predict_dim = scores.dims()[1]; auto* scores_data = scores.data(); @@ -273,7 +273,7 @@ class MultiClassNMSKernel : public framework::OpKernel { std::map> indices; int num_nmsed_out = 0; - MultiClassNMS(ctx, ins_score, ins_boxes, indices, num_nmsed_out); + MultiClassNMS(ctx, ins_score, ins_boxes, &indices, &num_nmsed_out); all_indices.push_back(indices); batch_starts.push_back(batch_starts.back() + num_nmsed_out); } diff --git a/paddle/fluid/operators/nccl_op.cu.cc b/paddle/fluid/operators/nccl_op.cu.cc index ad623e1fe0..8de974bc2b 100644 --- a/paddle/fluid/operators/nccl_op.cu.cc +++ b/paddle/fluid/operators/nccl_op.cu.cc @@ -135,8 +135,9 @@ class NCCLBcastKernel : public framework::OpKernel { auto* x = ctx.Input("X"); VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. send " << x->numel(); PADDLE_ENFORCE(platform::dynload::ncclBcast( - (void*)x->data(), x->numel(), NCCLTypeWrapper::type, root, - comm->comms().at(idx), ctx.cuda_device_context().stream())); + reinterpret_cast(const_cast(x->data())), x->numel(), + NCCLTypeWrapper::type, root, comm->comms().at(idx), + ctx.cuda_device_context().stream())); VLOG(3) << "gpu : " << gpu_id << " finished Bcast."; } else { auto* out = ctx.Output("Out"); diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 9420763847..2c4c97f28b 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -108,7 +109,7 @@ class NCEKernel : public framework::OpKernel { auto weight_mat = EigenMatrix::From(*(context.Input("Weight"))); for (int64_t i = 0; i < sample_labels->numel(); ++i) { Eigen::Tensor result = - (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) * + (input_mat.chip(static_cast(i / sample_labels->dims()[1]), 0) * weight_mat.chip(sample_labels_data[i], 0)) .sum(); sample_out_data[i] += result(0); @@ -190,7 +191,7 @@ class NCEGradKernel : public framework::OpKernel { auto x_matrix = EigenMatrix::From(*(context.Input("Input"))); for (int64_t i = 0; i < sample_labels->numel(); ++i) { d_w_matrix.chip(sample_labels_data[i], 0) += - x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) * + x_matrix.chip(static_cast(i / sample_labels->dims()[1]), 0) * sample_grad_data[i]; } } @@ -202,7 +203,7 @@ class NCEGradKernel : public framework::OpKernel { auto d_x_matrix = EigenMatrix::From(*d_x); auto w_matrix = EigenMatrix::From(*(context.Input("Weight"))); for (int64_t i = 0; i < sample_labels->numel(); ++i) { - d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) += + d_x_matrix.chip(static_cast(i / sample_labels->dims()[1]), 0) += w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; } } diff --git a/paddle/fluid/operators/reduce_op.h b/paddle/fluid/operators/reduce_op.h index b28dd7f209..e42b4bfe42 100644 --- a/paddle/fluid/operators/reduce_op.h +++ b/paddle/fluid/operators/reduce_op.h @@ -35,77 +35,77 @@ using EigenVector = framework::EigenVector; struct SumFunctor { template - void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) { - y.device(place) = x.sum(dim); + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->sum(dim); } }; struct SumGradFunctor { template - void operator()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy, + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, const Dim& dim, int size) { - dx.device(place) = dy.broadcast(dim); + dx->device(place) = dy->broadcast(dim); } }; struct MeanFunctor { template - void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) { - y.device(place) = x.mean(dim); + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->mean(dim); } }; struct MeanGradFunctor { template - void operator()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy, + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, const Dim& dim, int size) { - dx.device(place) = dy.broadcast(dim) / dx.constant(size); + dx->device(place) = dy->broadcast(dim) / dx->constant(size); } }; struct MaxFunctor { template - void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) { - y.device(place) = x.maximum(dim); + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->maximum(dim); } }; struct MinFunctor { template - void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) { - y.device(place) = x.minimum(dim); + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->minimum(dim); } }; struct MaxOrMinGradFunctor { template - void operator()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy, + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, const Dim& dim, int size) { - auto equals = x == y.broadcast(dim); - auto ones = dx.constant(1); - auto zeros = dx.constant(0); + auto equals = (*x) == y->broadcast(dim); + auto ones = dx->constant(1); + auto zeros = dx->constant(0); // If there are multiple minimum or maximum elements, the subgradient of // each is the set [0, 1], and we pass gradient to all of them here. - dx.device(place) = dy.broadcast(dim) * equals.select(ones, zeros); + dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros); } }; struct ProdFunctor { template - void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) { - y.device(place) = x.prod(dim); + void operator()(const DeviceContext& place, X* x, Y* y, const Dim& dim) { + y->device(place) = x->prod(dim); } }; struct ProdGradFunctor { template - void operator()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy, + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, const Dim& dim, int size) { - dx.device(place) = dy.broadcast(dim) * y.broadcast(dim) * x.inverse(); + dx->device(place) = dy->broadcast(dim) * y->broadcast(dim) * x->inverse(); } }; @@ -125,7 +125,7 @@ class ReduceKernel : public framework::OpKernel { *context.template device_context().eigen_device(); auto reduce_dim = Eigen::array({{0}}); Functor functor; - functor(place, x, out, reduce_dim); + functor(place, &x, &out, reduce_dim); } else { int rank = context.Input("X")->dims().size(); switch (rank) { @@ -178,10 +178,10 @@ class ReduceKernel : public framework::OpKernel { if (D == 1) { auto out = EigenScalar::From(*output); - functor(place, x, out, reduce_dim); + functor(place, &x, &out, reduce_dim); } else { auto out = EigenTensor::From(*output, dims); - functor(place, x, out, reduce_dim); + functor(place, &x, &out, reduce_dim); } } }; @@ -206,7 +206,7 @@ class ReduceGradKernel : public framework::OpKernel { auto broadcast_dim = Eigen::array({{static_cast(input0->numel())}}); Functor functor; - functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim, + functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, broadcast_dim[0]); } else { int rank = context.Input("X")->dims().size(); @@ -258,7 +258,7 @@ class ReduceGradKernel : public framework::OpKernel { auto& place = *context.template device_context().eigen_device(); Functor functor; - functor(place, x, x_reduce, x_grad, x_reduce_grad, broadcast_dim, + functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, broadcast_dim[dim]); } }; diff --git a/paddle/fluid/operators/save_load_combine_op_test.cc b/paddle/fluid/operators/save_load_combine_op_test.cc index 286f75df4c..2773c32a0a 100644 --- a/paddle/fluid/operators/save_load_combine_op_test.cc +++ b/paddle/fluid/operators/save_load_combine_op_test.cc @@ -23,17 +23,17 @@ USE_NO_KERNEL_OP(load_combine); int* CreateForSaveCombineOp(int x, int y, const std::vector& lod_info, std::string var_name, - paddle::platform::CPUPlace& place, - paddle::framework::Scope& scope, - paddle::framework::LoD& expect_lod) { - auto var = scope.Var(var_name); + const paddle::platform::CPUPlace& place, + paddle::framework::Scope* scope, + paddle::framework::LoD* expect_lod) { + auto var = scope->Var(var_name); auto tensor = var->GetMutable(); tensor->Resize({x, y}); - expect_lod.resize(1); + expect_lod->resize(1); for (size_t i = 0; i < lod_info.size(); i++) { - expect_lod[0].push_back(lod_info[i]); + (*expect_lod)[0].push_back(lod_info[i]); } - tensor->set_lod(expect_lod); + tensor->set_lod(*expect_lod); int* expect = tensor->mutable_data(place); for (int64_t i = 0; i < tensor->numel(); ++i) { expect[i] = static_cast(i); @@ -42,17 +42,17 @@ int* CreateForSaveCombineOp(int x, int y, const std::vector& lod_info, } paddle::framework::LoDTensor* GeneratePlaceholderBeforeLoad( - const std::string out_var_name, paddle::framework::Scope& scope) { - auto load_var = scope.Var(out_var_name); + const std::string out_var_name, paddle::framework::Scope* scope) { + auto load_var = scope->Var(out_var_name); auto target = load_var->GetMutable(); return target; } int* GetValuesAfterLoadCombineOp(paddle::framework::LoDTensor* target, - paddle::framework::Scope& scope, - paddle::framework::LoD& actual_lod) { + const paddle::framework::Scope& scope, + paddle::framework::LoD* actual_lod) { int* actual = target->data(); - actual_lod = target->lod(); + *actual_lod = target->lod(); return actual; } @@ -78,26 +78,26 @@ TEST(SaveLoadCombineOp, CPU) { std::vector lod1 = {0, 1, 2, 3, 10}; int numel1 = 100; paddle::framework::LoD expect_lod1; - int* expect1 = CreateForSaveCombineOp(10, 10, lod1, "test_var1", place, scope, - expect_lod1); + int* expect1 = CreateForSaveCombineOp(10, 10, lod1, "test_var1", place, + &scope, &expect_lod1); std::vector lod2 = {0, 2, 5, 10}; int numel2 = 200; paddle::framework::LoD expect_lod2; - int* expect2 = CreateForSaveCombineOp(10, 20, lod2, "test_var2", place, scope, - expect_lod2); + int* expect2 = CreateForSaveCombineOp(10, 20, lod2, "test_var2", place, + &scope, &expect_lod2); std::vector lod3 = {0, 2, 3, 20}; int numel3 = 4000; paddle::framework::LoD expect_lod3; int* expect3 = CreateForSaveCombineOp(20, 200, lod3, "test_var3", place, - scope, expect_lod3); + &scope, &expect_lod3); std::vector lod4 = {0, 1, 20}; int numel4 = 1000; paddle::framework::LoD expect_lod4; - int* expect4 = CreateForSaveCombineOp(20, 50, lod4, "test_var4", place, scope, - expect_lod4); + int* expect4 = CreateForSaveCombineOp(20, 50, lod4, "test_var4", place, + &scope, &expect_lod4); // Set attributes std::string filename = "check_tensor.ls"; @@ -111,10 +111,10 @@ TEST(SaveLoadCombineOp, CPU) { save_combine_op->Run(scope, place); // Set up output vars - auto target1 = GeneratePlaceholderBeforeLoad("out_var1", scope); - auto target2 = GeneratePlaceholderBeforeLoad("out_var2", scope); - auto target3 = GeneratePlaceholderBeforeLoad("out_var3", scope); - auto target4 = GeneratePlaceholderBeforeLoad("out_var4", scope); + auto target1 = GeneratePlaceholderBeforeLoad("out_var1", &scope); + auto target2 = GeneratePlaceholderBeforeLoad("out_var2", &scope); + auto target3 = GeneratePlaceholderBeforeLoad("out_var3", &scope); + auto target4 = GeneratePlaceholderBeforeLoad("out_var4", &scope); // Run the load_combine_op auto load_combine_op = paddle::framework::OpRegistry::CreateOp( @@ -123,10 +123,10 @@ TEST(SaveLoadCombineOp, CPU) { load_combine_op->Run(scope, place); paddle::framework::LoD actual_lod1, actual_lod2, actual_lod3, actual_lod4; - int* actual1 = GetValuesAfterLoadCombineOp(target1, scope, actual_lod1); - int* actual2 = GetValuesAfterLoadCombineOp(target2, scope, actual_lod2); - int* actual3 = GetValuesAfterLoadCombineOp(target3, scope, actual_lod3); - int* actual4 = GetValuesAfterLoadCombineOp(target4, scope, actual_lod4); + int* actual1 = GetValuesAfterLoadCombineOp(target1, scope, &actual_lod1); + int* actual2 = GetValuesAfterLoadCombineOp(target2, scope, &actual_lod2); + int* actual3 = GetValuesAfterLoadCombineOp(target3, scope, &actual_lod3); + int* actual4 = GetValuesAfterLoadCombineOp(target4, scope, &actual_lod4); CheckValues(expect1, actual1, expect_lod1, actual_lod1, numel1); CheckValues(expect2, actual2, expect_lod2, actual_lod2, numel2); -- GitLab