From ea7359c60bdf6062b1296f471f50cbeaf8da243e Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Tue, 28 Nov 2017 12:47:17 +0800
Subject: [PATCH] Refine code and comments

1. Remove checking for num_neg_samples.
2. Fix dims of Output(Cost) and Input(Bias).
3. Rename num_sampled_classes to num_neg_samples.
4. Add TODO for supporting more distribution samplers.
5. Initialize grad_data of bias to zero.
6. Refine comments.
7. Register a kernel for type double.
---
 paddle/operators/nce_op.cc               | 95 +++++++++++++++---
 paddle/operators/nce_op.h                | 15 ++--
 python/paddle/v2/fluid/tests/test_nce.py | 14 ++--
 3 files changed, 77 insertions(+), 47 deletions(-)

diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc
index c365d5d92..bb9346b13 100644
--- a/paddle/operators/nce_op.cc
+++ b/paddle/operators/nce_op.cc
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #include "paddle/operators/nce_op.h"
@@ -39,25 +39,25 @@ class NCEOp : public framework::OperatorWithKernel {
       PADDLE_ENFORCE_EQ(ctx->GetInputDim("Weight")[0],
                         ctx->GetInputDim("Bias")[0]);
     }
-    auto num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes");
-    auto num_classes = ctx->Attrs().Get<int>("num_classes");
+    auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
+    auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
     std::vector<int> sampled_labels =
         ctx->Attrs().Get<std::vector<int>>("sampled_labels");
-    PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("Weight")[0]);
-    PADDLE_ENFORCE_LT(num_sampled_classes, num_classes);
+    PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]);
     if (sampled_labels.size() > 0) {
       PADDLE_ENFORCE_EQ(sampled_labels.size(),
-                        static_cast<size_t>(num_sampled_classes));
+                        static_cast<size_t>(num_neg_samples));
     }
     // set dims of output(Out)
     std::vector<int64_t> out_dims;
     out_dims.push_back(x_dims[0]);
+    out_dims.push_back(1);
     ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));
     // set dims of output(SampleOut)
     std::vector<int64_t> sample_out_dims;
     sample_out_dims.push_back(x_dims[0]);
-    sample_out_dims.push_back(num_sampled_classes + num_true_classes);
+    sample_out_dims.push_back(num_neg_samples + num_true_classes);
     ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
     ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
   }
@@ -76,34 +76,59 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
   NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
-    AddInput("Label",
-             "(Tensor) A tensor of shape [batch_size, num_true_class]. "
-             "'num_true_class' is the number of target class in each sample.");
+    AddInput(
+        "Label",
+        "(Tensor) A tensor of shape [batch_size, num_true_class]. "
+        "'num_true_class' is the number of target classes in each sample. "
+        "The number of target classes per sample should be the same. "
+        "If you have a variable number of target classes, "
+        "you can pad them out to a constant number by either repeating them "
+        "or by padding with an otherwise unused class.");
     AddInput("Weight",
              "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the "
             "total number of class.");
-    AddInput("Bias",
-             "(Tensor) A tensor of shape [num_class]. 'num_class' is the total "
-             "number of class. It is a dispensable input.")
+    AddInput(
+        "Bias",
+        "(Tensor) A tensor of shape [num_class, 1]. 'num_class' is the total "
+        "number of class. It is a dispensable input.")
         .AsDispensable();
     AddInput("SampleWeight",
-             "(Tensor) A tensor of shape [batch_size] storing a weight for "
+             "(Tensor) A tensor of shape [batch_size, 1] storing a weight for "
              "each sample. And it is a dispensable input. The default value of "
             "sample is 1.")
        .AsDispensable();
     AddOutput("Cost",
-              "(Tensor) A tensor of shape [batch_size]. Cost of samples.");
-    AddOutput("SampleLogits", "An intermediate tensor.").AsIntermediate();
-    AddOutput("SampleLabels", "An intermediate tensor.").AsIntermediate();
-    AddAttr<int>("num_classes", "Total number of classes.");
-    AddAttr<int>("num_sampled_classes", "The number of negative classes.")
+              "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
+    AddOutput("SampleLogits",
+              "An intermediate tensor of shape [batch_size, num_neg_samples + "
+              "num_pos_samples]. "
+              "This tensor is the output of the forward kernel and is used in "
+              "the backward kernel to compute grads. "
+ "Given X is the dot product of input tensor and sampled labels' " + "weights." + "Then 'SampleLogits' is sigmoid(X).") + .AsIntermediate(); + AddOutput("SampleLabels", + "An intermediate tensor of shape[batch_size, num_neg_samples + " + "num_pos_samples]." + "This tensor is output of forward kernel and used in backward " + "kernel to compute grads." + "") + .AsIntermediate(); + AddAttr("num_total_classes", + "Total number of classes in all samples."); + AddAttr("num_neg_samples", + "The number of negative classes. The default value is 10.") .SetDefault(10); - AddAttr>("sampled_labels", ""); + AddAttr>("custom_neg_classes", + "This attribute only be used in unitest. Classes " + "in this list wiil be used as negative classes " + "for every samples. Under normal conditions, " + "user should avoid setting this attribute."); AddComment(R"DOC( -Computes and returns the noise-contrastive estimation training loss. +Compute and return the noise-contrastive estimation training loss. See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf). -By default this uses a uniform distribution for sampling. -The number of target classes per example should be same. If you have a variable number of target classes, you can pad them out to a constant number by either repeating them or by padding with an otherwise unused class. +By default this operator uses a uniform distribution for sampling. )DOC"); } }; @@ -119,7 +144,7 @@ class NCEOpGrad : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("SampleLogits")); PADDLE_ENFORCE(ctx->HasInput("SampleLabels")); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")), - "The input(Out@GRAD) should not be null"); + "The input(Out@GRAD) should not be null."); auto x_dims = ctx->GetInputDim("Input"); auto x_grad_name = framework::GradVarName("Input"); @@ -154,6 +179,8 @@ class NCEOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(nce, ops::NCEOp, ops::NCEOpMaker, nce_grad, ops::NCEOpGrad); -REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel); +REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel, + ops::NCEKernel); REGISTER_OP_CPU_KERNEL(nce_grad, - ops::NCEGradKernel); + ops::NCEGradKernel, + ops::NCEGradKernel); diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h index 3017bccdc..c41393d26 100644 --- a/paddle/operators/nce_op.h +++ b/paddle/operators/nce_op.h @@ -22,7 +22,7 @@ namespace paddle { namespace operators { -using Tensor = framework::Tensor; +using framework::Tensor; template @@ -35,8 +35,8 @@ void PrepareSamples(const framework::ExecutionContext& context) { auto label_dims = label->dims(); int num_classes = context.Attr("num_classes"); // for unitest - std::vector sampled_labels = - context.Attr>("sampled_labels"); + std::vector custom_neg_classes = + context.Attr>("custom_neg_classes"); // random machine std::random_device rd; std::mt19937 rng(rd()); @@ -54,12 +54,13 @@ void PrepareSamples(const framework::ExecutionContext& context) { for (; j < num_label; ++j) { sample_labels_data[index++] = label_data[i * num_label + j]; } - if (sampled_labels.size() > 0) { - for (auto label : sampled_labels) { + if (custom_neg_classes.size() > 0) { + for (auto label : custom_neg_classes) { sample_labels_data[index++] = label; } } else { for (; j < sample_labels_dims[1]; ++j) { + // TODO: support more distribution sampling sample_labels_data[index++] = rand(rng); } } @@ -176,6 +177,7 @@ class 
     auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
     if (d_bias != nullptr) {
       T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
+      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
         d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
       }
@@ -183,7 +185,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
     // get d_w
     auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
     if (d_w != nullptr) {
-      d_w->mutable_data<T>(context.GetPlace());
+      auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
+      std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
       auto d_w_matrix = EigenMatrix<T>::From(*d_w);
       auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
diff --git a/python/paddle/v2/fluid/tests/test_nce.py b/python/paddle/v2/fluid/tests/test_nce.py
index 82978f2d2..6cbf468e0 100644
--- a/python/paddle/v2/fluid/tests/test_nce.py
+++ b/python/paddle/v2/fluid/tests/test_nce.py
@@ -18,25 +18,25 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
             samples.append((i, num, False, w))
             sample_labels.append(num)
     # forward bias
-    sampleOut = np.zeros(len(samples)).astype(np.float32)
+    sample_out = np.zeros(len(samples)).astype(np.float32)
     if bias is not None:
         for i in range(len(samples)):
-            sampleOut[i] = bias[samples[i][1]]
+            sample_out[i] = bias[samples[i][1]]
     # forward weight
     for i in range(len(samples)):
-        sampleOut[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
+        sample_out[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
     # forward activation
-    sampleOut = 1.0 / (1.0 + np.exp(-sampleOut))
+    sample_out = 1.0 / (1.0 + np.exp(-sample_out))
     # forward cost
     out = np.zeros(batch_size).astype(np.float32)
     b = 1.0 / num_classes * num_sample_class
     for i in range(len(samples)):
-        o = sampleOut[i]
+        o = sample_out[i]
         cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
         out[samples[i][0]] += cost * samples[i][3]
-    return (out, np.array(sampleOut).reshape(batch_size,
-                                             num_sample_class + num_true_class),
+    return (out, np.array(sample_out).reshape(
+        batch_size, num_sample_class + num_true_class),
             np.array(sample_labels).reshape(batch_size,
                                             num_sample_class + num_true_class))
--
GitLab
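
Note: the reference nce() helper in test_nce.py above defines the per-sample cost this operator is checked against. Below is a minimal NumPy sketch of that formula under the patch's uniform-sampling assumption, where b = num_neg_samples / num_total_classes; the helper name nce_cost_sketch and its arguments are illustrative only and are not part of the patch.

    import numpy as np

    def nce_cost_sketch(logit, is_true_class, num_total_classes, num_neg_samples):
        # o is the sigmoid of the raw score for one (example, label) pair,
        # i.e. the "SampleLogits" value; b is the uniform noise probability
        # scaled by the number of negative samples, matching b in test_nce.py.
        o = 1.0 / (1.0 + np.exp(-logit))
        b = float(num_neg_samples) / num_total_classes
        if is_true_class:
            return -np.log(o / (o + b))   # data (positive) term
        return -np.log(b / (o + b))       # noise (negative) term

    # Example: one target class and one sampled negative for the same input row.
    print(nce_cost_sketch(2.0, True, num_total_classes=20, num_neg_samples=5))
    print(nce_cost_sketch(-1.0, False, num_total_classes=20, num_neg_samples=5))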