From 83537c7ada62153d9bd323de6144d67902cdcd39 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 6 Dec 2017 13:10:04 +0800 Subject: [PATCH] Fix warning about comparison of integers of different signs --- paddle/operators/nce_op.h | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h index ea92a797fe1..0a8a95de5f4 100644 --- a/paddle/operators/nce_op.h +++ b/paddle/operators/nce_op.h @@ -49,7 +49,7 @@ void PrepareSamples(const framework::ExecutionContext& context) { int num_label = label_dims.size() == 2 ? label_dims[1] : 1; int index = 0; - for (size_t i = 0; i < label_dims[0]; ++i) { + for (int64_t i = 0; i < label_dims[0]; ++i) { int j = 0; for (; j < num_label; ++j) { sample_labels_data[index++] = label_data[i * num_label + j]; @@ -86,7 +86,7 @@ class NCEKernel : public framework::OpKernel { T* out_data = out->mutable_data(context.GetPlace()); int num_neg_samples = context.Attr("num_neg_samples"); int num_total_classes = context.Attr("num_total_classes"); - int num_true_class = 1; + int64_t num_true_class = 1; if (label != nullptr) { num_true_class = label->dims()[1]; } @@ -95,18 +95,18 @@ class NCEKernel : public framework::OpKernel { auto bias = context.Input("Bias"); if (bias != nullptr) { const T* bias_data = bias->data(); - for (size_t i = 0; i < sample_labels->numel(); ++i) { + for (int64_t i = 0; i < sample_labels->numel(); ++i) { sample_out_data[i] = bias_data[sample_labels_data[i]]; } } else { - for (size_t i = 0; i < sample_labels->numel(); ++i) { + for (int64_t i = 0; i < sample_labels->numel(); ++i) { sample_out_data[i] = 0; } } // forward mul auto input_mat = EigenMatrix::From(*(context.Input("Input"))); auto weight_mat = EigenMatrix::From(*(context.Input("Weight"))); - for (size_t i = 0; i < sample_labels->numel(); ++i) { + for (int64_t i = 0; i < sample_labels->numel(); ++i) { Eigen::Tensor result = (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) * weight_mat.chip(sample_labels_data[i], 0)) @@ -115,8 +115,8 @@ class NCEKernel : public framework::OpKernel { sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i]))); } // forward cost - for (size_t i = 0; i < sample_labels->dims()[0]; ++i) { - size_t j = 0; + for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) { + int64_t j = 0; out_data[i] = 0; T w = sample_weight == nullptr ? 1. : sample_weight_data[i]; // for true classes @@ -162,7 +162,7 @@ class NCEGradKernel : public framework::OpKernel { T* sample_grad_data = sample_grad.mutable_data(sample_labels->dims(), context.GetPlace()); // backward cost - for (size_t i = 0; i < sample_labels->numel(); ++i) { + for (int64_t i = 0; i < sample_labels->numel(); ++i) { T o = sample_out_data[i]; T w = sample_weight == nullptr ? 1 @@ -177,7 +177,7 @@ class NCEGradKernel : public framework::OpKernel { if (d_bias != nullptr) { T* d_bias_data = d_bias->mutable_data(context.GetPlace()); std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0); - for (size_t i = 0; i < sample_labels->numel(); ++i) { + for (int64_t i = 0; i < sample_labels->numel(); ++i) { d_bias_data[sample_labels_data[i]] += sample_grad_data[i]; } } @@ -188,7 +188,7 @@ class NCEGradKernel : public framework::OpKernel { std::fill(d_w_data, d_w_data + d_w->numel(), 0.0); auto d_w_matrix = EigenMatrix::From(*d_w); auto x_matrix = EigenMatrix::From(*(context.Input("Input"))); - for (size_t i = 0; i < sample_labels->numel(); ++i) { + for (int64_t i = 0; i < sample_labels->numel(); ++i) { d_w_matrix.chip(sample_labels_data[i], 0) += x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) * sample_grad_data[i]; @@ -200,7 +200,7 @@ class NCEGradKernel : public framework::OpKernel { d_x->mutable_data(context.GetPlace()); auto d_x_matrix = EigenMatrix::From(*d_x); auto w_matrix = EigenMatrix::From(*(context.Input("Weight"))); - for (size_t i = 0; i < sample_labels->numel(); ++i) { + for (int64_t i = 0; i < sample_labels->numel(); ++i) { d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) += w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; } -- GitLab