提交 ea7359c6 编写于 作者: W wanghaoshuang

Refine code and comments

1. Remove checking for num_neg_samples.
2. Fix dims of Output(Cost) and Input(Bias).
3. Renamed num_sampled_classes to num_neg_samples.
4. Add a TODO for supporting more sampling distributions.
5. Init grad_data of bias by zero.
6. Refine comments.
7. Register a kernel for type double.
上级 e60eb1ea
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/nce_op.h"
......@@ -39,25 +39,25 @@ class NCEOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Weight")[0],
ctx->GetInputDim("Bias")[0]);
}
auto num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes");
auto num_classes = ctx->Attrs().Get<int>("num_classes");
auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
std::vector<int> sampled_labels =
ctx->Attrs().Get<std::vector<int>>("sampled_labels");
PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("Weight")[0]);
PADDLE_ENFORCE_LT(num_sampled_classes, num_classes);
PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]);
if (sampled_labels.size() > 0) {
PADDLE_ENFORCE_EQ(sampled_labels.size(),
static_cast<size_t>(num_sampled_classes));
static_cast<size_t>(num_neg_samples));
}
// set dims of output(Cost)
std::vector<int64_t> out_dims;
out_dims.push_back(x_dims[0]);
out_dims.push_back(1);
ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));
// set dims of output(SampleOut)
std::vector<int64_t> sample_out_dims;
sample_out_dims.push_back(x_dims[0]);
sample_out_dims.push_back(num_sampled_classes + num_true_classes);
sample_out_dims.push_back(num_neg_samples + num_true_classes);
ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
}
......@@ -76,34 +76,59 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
AddInput("Label",
"(Tensor) A tensor of shape [batch_size, num_true_class]. "
"'num_true_class' is the number of target class in each sample.");
AddInput(
"Label",
"(Tensor) A tensor of shape [batch_size, num_true_class]. "
"'num_true_class' is the number of target classes in each sample."
"The number of target classes per sample should be same. "
"If you have a variable number of target classes, "
"you can pad them out to a constant number by either repeating them"
" or by padding with an otherwise unused class.)");
AddInput("Weight",
"(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the "
"total number of class.");
AddInput("Bias",
"(Tensor) A tensor of shape [num_class]. 'num_class' is the total "
"number of class. It is a dispensable input.")
AddInput(
"Bias",
"(Tensor) A tensor of shape [num_class, 1]. 'num_class' is the total "
"number of class. It is a dispensable input.")
.AsDispensable();
AddInput("SampleWeight",
"(Tensor) A tensor of shape [batch_size] storing a weight for "
"(Tensor) A tensor of shape [batch_size, 1] storing a weight for "
"each sample. And it is a dispensable input. The default value of "
"sample is 1.")
.AsDispensable();
AddOutput("Cost",
"(Tensor) A tensor of shape [batch_size]. Cost of samples.");
AddOutput("SampleLogits", "An intermediate tensor.").AsIntermediate();
AddOutput("SampleLabels", "An intermediate tensor.").AsIntermediate();
AddAttr<int>("num_classes", "Total number of classes.");
AddAttr<int>("num_sampled_classes", "The number of negative classes.")
"(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
AddOutput("SampleLogits",
"An intermediate tensor of shape[batch_size, num_neg_samples + "
"num_pos_samples]."
"This tensor is output of forward kernel and used in backward "
"kernel to compute grads."
"Given X is the dot product of input tensor and sampled labels' "
"weights."
"Then 'SampleLogits' is sigmoid(X).")
.AsIntermediate();
AddOutput("SampleLabels",
"An intermediate tensor of shape[batch_size, num_neg_samples + "
"num_pos_samples]."
"This tensor is output of forward kernel and used in backward "
"kernel to compute grads."
"")
.AsIntermediate();
AddAttr<int>("num_total_classes",
"Total number of classes in all samples.");
AddAttr<int>("num_neg_samples",
"The number of negative classes. The default value is 10.")
.SetDefault(10);
AddAttr<std::vector<int>>("sampled_labels", "");
AddAttr<std::vector<int>>("custom_neg_classes",
"This attribute only be used in unitest. Classes "
"in this list wiil be used as negative classes "
"for every samples. Under normal conditions, "
"user should avoid setting this attribute.");
AddComment(R"DOC(
Computes and returns the noise-contrastive estimation training loss.
Compute and return the noise-contrastive estimation training loss.
See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
By default this uses a uniform distribution for sampling.
The number of target classes per example should be same. If you have a variable number of target classes, you can pad them out to a constant number by either repeating them or by padding with an otherwise unused class.
By default this operator uses a uniform distribution for sampling.
)DOC");
}
};
......@@ -119,7 +144,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("SampleLogits"));
PADDLE_ENFORCE(ctx->HasInput("SampleLabels"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")),
"The input(Out@GRAD) should not be null");
"The input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
......@@ -154,6 +179,8 @@ class NCEOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(nce, ops::NCEOp, ops::NCEOpMaker, nce_grad, ops::NCEOpGrad);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad,
ops::NCEGradKernel<paddle::platform::CPUPlace, float>);
ops::NCEGradKernel<paddle::platform::CPUPlace, float>,
ops::NCEGradKernel<paddle::platform::CPUPlace, double>);
......@@ -22,7 +22,7 @@
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
......@@ -35,8 +35,8 @@ void PrepareSamples(const framework::ExecutionContext& context) {
auto label_dims = label->dims();
int num_classes = context.Attr<int>("num_classes");
// for unit test
std::vector<int> sampled_labels =
context.Attr<std::vector<int>>("sampled_labels");
std::vector<int> custom_neg_classes =
context.Attr<std::vector<int>>("custom_neg_classes");
// random machine
std::random_device rd;
std::mt19937 rng(rd());
......@@ -54,12 +54,13 @@ void PrepareSamples(const framework::ExecutionContext& context) {
for (; j < num_label; ++j) {
sample_labels_data[index++] = label_data[i * num_label + j];
}
if (sampled_labels.size() > 0) {
for (auto label : sampled_labels) {
if (custom_neg_classes.size() > 0) {
for (auto label : custom_neg_classes) {
sample_labels_data[index++] = label;
}
} else {
for (; j < sample_labels_dims[1]; ++j) {
// TODO: support more distribution sampling
sample_labels_data[index++] = rand(rng);
}
}
......@@ -176,6 +177,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
if (d_bias != nullptr) {
T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
for (size_t i = 0; i < sample_labels->numel(); ++i) {
d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
}
......@@ -183,7 +185,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
// get d_w
auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
if (d_w != nullptr) {
d_w->mutable_data<T>(context.GetPlace());
auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
auto d_w_matrix = EigenMatrix<T>::From(*d_w);
auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
for (size_t i = 0; i < sample_labels->numel(); ++i) {
......
......@@ -18,25 +18,25 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
samples.append((i, num, False, w))
sample_labels.append(num)
# forward bias
sampleOut = np.zeros(len(samples)).astype(np.float32)
sample_out = np.zeros(len(samples)).astype(np.float32)
if bias is not None:
for i in range(len(samples)):
sampleOut[i] = bias[samples[i][1]]
sample_out[i] = bias[samples[i][1]]
# forward weight
for i in range(len(samples)):
sampleOut[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
sample_out[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
# forward activation
sampleOut = 1.0 / (1.0 + np.exp(-sampleOut))
sample_out = 1.0 / (1.0 + np.exp(-sample_out))
# forward cost
out = np.zeros(batch_size).astype(np.float32)
b = 1.0 / num_classes * num_sample_class
for i in range(len(samples)):
o = sampleOut[i]
o = sample_out[i]
cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
out[samples[i][0]] += cost * samples[i][3]
return (out, np.array(sampleOut).reshape(batch_size,
num_sample_class + num_true_class),
return (out, np.array(sample_out).reshape(
batch_size, num_sample_class + num_true_class),
np.array(sample_labels).reshape(batch_size,
num_sample_class + num_true_class))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册