提交 09d32b06 编写于 作者: W wanghaoshuang

Add unitest and comments.

上级 c2dd75be
...@@ -23,57 +23,87 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -23,57 +23,87 @@ class NCEOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X")); PADDLE_ENFORCE(ctx->HasInput("Input"));
PADDLE_ENFORCE(ctx->HasInput("Label")); PADDLE_ENFORCE(ctx->HasInput("Label"));
PADDLE_ENFORCE(ctx->HasInput("W")); PADDLE_ENFORCE(ctx->HasInput("Weight"));
PADDLE_ENFORCE(ctx->HasOutput("Out")); PADDLE_ENFORCE(ctx->HasOutput("Cost"));
PADDLE_ENFORCE(ctx->HasOutput("SampleLogits")); PADDLE_ENFORCE(ctx->HasOutput("SampleLogits"));
PADDLE_ENFORCE(ctx->HasOutput("SampleLabels")); PADDLE_ENFORCE(ctx->HasOutput("SampleLabels"));
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("Input");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]); PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
if (ctx->HasInput("B")) { int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
PADDLE_ENFORCE_EQ(ctx->GetInputDim("W")[0], ctx->GetInputDim("B")[0]); if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Weight")[0],
ctx->GetInputDim("Bias")[0]);
} }
int num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes"); auto num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes");
int num_classes = ctx->Attrs().Get<int>("num_classes"); auto num_classes = ctx->Attrs().Get<int>("num_classes");
PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("W")[0]); std::vector<int> sampled_labels =
ctx->Attrs().Get<std::vector<int>>("sampled_labels");
PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("Weight")[0]);
PADDLE_ENFORCE_LT(num_sampled_classes, num_classes); PADDLE_ENFORCE_LT(num_sampled_classes, num_classes);
if (sampled_labels.size() > 0) {
PADDLE_ENFORCE_EQ(sampled_labels.size(),
static_cast<size_t>(num_sampled_classes));
}
// set dims of output(Out) // set dims of output(Out)
std::vector<int64_t> out_dims(1); std::vector<int64_t> out_dims;
out_dims.push_back(x_dims[0]); out_dims.push_back(x_dims[0]);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));
// set dims of output(SampleOut) // set dims of output(SampleOut)
std::vector<int64_t> sample_out_dims(2); std::vector<int64_t> sample_out_dims;
sample_out_dims.push_back(x_dims[0]); sample_out_dims.push_back(x_dims[0]);
sample_out_dims.push_back(num_sampled_classes + 1); sample_out_dims.push_back(num_sampled_classes + num_true_classes);
ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims)); ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims)); ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
} }
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
ctx.device_context());
}
}; };
class NCEOpMaker : public framework::OpProtoAndCheckerMaker { class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", ""); AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
AddInput("Label", ""); AddInput("Label",
AddInput("W", ""); "(Tensor) A tensor of shape [batch_size, num_true_class]. "
AddInput("B", ""); "'num_true_class' is the number of target class in each sample.");
AddInput("SampleWeight", ""); AddInput("Weight",
AddOutput("Out", ""); "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the "
AddOutput("SampleLogits", ""); "total number of class.");
AddOutput("SampleLabels", ""); AddInput("Bias",
AddAttr<int>("num_classes", ""); "(Tensor) A tensor of shape [num_class]. 'num_class' is the total "
AddAttr<int>("num_sampled_classes", "").SetDefault(10); "number of class. It is a dispensable input.")
.AsDispensable();
AddInput("SampleWeight",
"(Tensor) A tensor of shape [batch_size] storing a weight for "
"each sample. And it is a dispensable input. The default value of "
"sample is 1.")
.AsDispensable();
AddOutput("Cost",
"(Tensor) A tensor of shape [batch_size]. Cost of samples.");
AddOutput("SampleLogits", "An intermediate tensor.").AsIntermediate();
AddOutput("SampleLabels", "An intermediate tensor.").AsIntermediate();
AddAttr<int>("num_classes", "Total number of classes.");
AddAttr<int>("num_sampled_classes", "The number of negative classes.")
.SetDefault(10);
AddAttr<std::vector<int>>("sampled_labels", "");
AddComment(R"DOC( AddComment(R"DOC(
Expand input(X) according to LOD of input(Y). Computes and returns the noise-contrastive estimation training loss.
See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
By default this uses a uniform distribution for sampling.
The number of target classes per example should be same. If you have a variable number of target classes, you can pad them out to a constant number by either repeating them or by padding with an otherwise unused class.
)DOC"); )DOC");
} }
}; };
...@@ -82,32 +112,41 @@ class NCEOpGrad : public framework::OperatorWithKernel { ...@@ -82,32 +112,41 @@ class NCEOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X")); PADDLE_ENFORCE(ctx->HasInput("Input"));
PADDLE_ENFORCE(ctx->HasInput("W")); PADDLE_ENFORCE(ctx->HasInput("Weight"));
PADDLE_ENFORCE(ctx->HasInput("Out")); PADDLE_ENFORCE(ctx->HasInput("Cost"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput("SampleLogits"));
PADDLE_ENFORCE(ctx->HasInput("SampleLabels"));
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")),
"The input(Out@GRAD) should not be null"); "The input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
} }
auto w_dims = ctx->GetInputDim("W"); auto w_dims = ctx->GetInputDim("Weight");
auto w_grad_name = framework::GradVarName("W"); auto w_grad_name = framework::GradVarName("Weight");
if (ctx->HasOutput(w_grad_name)) { if (ctx->HasOutput(w_grad_name)) {
ctx->SetOutputDim(w_grad_name, w_dims); ctx->SetOutputDim(w_grad_name, w_dims);
} }
auto bias_grad_name = framework::GradVarName("B"); auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name)) { if (ctx->HasOutput(bias_grad_name)) {
auto bias_dims = ctx->GetInputDim("B"); auto bias_dims = ctx->GetInputDim("Bias");
ctx->SetOutputDim(bias_grad_name, bias_dims); ctx->SetOutputDim(bias_grad_name, bias_dims);
} }
} }
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
ctx.device_context());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -14,12 +14,11 @@ ...@@ -14,12 +14,11 @@
#pragma once #pragma once
#include <math.h>
#include <random> #include <random>
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -32,9 +31,12 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; ...@@ -32,9 +31,12 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
void PrepareSamples(const framework::ExecutionContext& context) { void PrepareSamples(const framework::ExecutionContext& context) {
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
const T* label_data = label->data<T>(); const int64_t* label_data = label->data<int64_t>();
auto label_dims = label->dims(); auto label_dims = label->dims();
int num_classes = context.Attr<int>("num_classes"); int num_classes = context.Attr<int>("num_classes");
// for unitest
std::vector<int> sampled_labels =
context.Attr<std::vector<int>>("sampled_labels");
// random machine // random machine
std::random_device rd; std::random_device rd;
std::mt19937 rng(rd()); std::mt19937 rng(rd());
...@@ -42,19 +44,24 @@ void PrepareSamples(const framework::ExecutionContext& context) { ...@@ -42,19 +44,24 @@ void PrepareSamples(const framework::ExecutionContext& context) {
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
auto sample_labels_dims = sample_labels->dims(); auto sample_labels_dims = sample_labels->dims();
int* sample_labels_data = int64_t* sample_labels_data =
sample_labels->mutable_data<int>(context.GetPlace()); sample_labels->mutable_data<int64_t>(context.GetPlace());
int num_label = label_dims.size() == 2 ? label_dims[1] : 1; int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
int index = 0;
for (size_t i = 0; i < label_dims[0]; ++i) { for (size_t i = 0; i < label_dims[0]; ++i) {
int j = 0; int j = 0;
for (; j < num_label; ++j) { for (; j < num_label; ++j) {
sample_labels_data[sample_labels_dims[1] * i + j] = sample_labels_data[index++] = label_data[i * num_label + j];
label_data[i * num_label + j];
} }
for (; j < sample_labels_dims[1]; ++j) { if (sampled_labels.size() > 0) {
int id = rand(rng); for (auto label : sampled_labels) {
sample_labels_data[sample_labels_dims[1] * i + j] = id; sample_labels_data[index++] = label;
}
} else {
for (; j < sample_labels_dims[1]; ++j) {
sample_labels_data[index++] = rand(rng);
}
} }
} }
} }
...@@ -65,7 +72,7 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -65,7 +72,7 @@ class NCEKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PrepareSamples<Place, T>(context); PrepareSamples<Place, T>(context);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
const int* sample_labels_data = sample_labels->data<int>(); const int64_t* sample_labels_data = sample_labels->data<int64_t>();
auto sample_out = context.Output<Tensor>("SampleLogits"); auto sample_out = context.Output<Tensor>("SampleLogits");
T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace()); T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
...@@ -74,7 +81,7 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -74,7 +81,7 @@ class NCEKernel : public framework::OpKernel<T> {
if (sample_weight != nullptr) { if (sample_weight != nullptr) {
sample_weight_data = sample_weight->data<T>(); sample_weight_data = sample_weight->data<T>();
} }
auto out = context.Output<Tensor>("Out"); auto out = context.Output<Tensor>("Cost");
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
int num_smalped_classes = context.Attr<int>("num_sampled_classes"); int num_smalped_classes = context.Attr<int>("num_sampled_classes");
int num_classes = context.Attr<int>("num_classes"); int num_classes = context.Attr<int>("num_classes");
...@@ -83,9 +90,8 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -83,9 +90,8 @@ class NCEKernel : public framework::OpKernel<T> {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_classes * num_smalped_classes; T b = 1. / num_classes * num_smalped_classes;
// forward bias // forward bias
auto bias = context.Input<Tensor>("B"); auto bias = context.Input<Tensor>("Bias");
if (bias != nullptr) { if (bias != nullptr) {
const T* bias_data = bias->data<T>(); const T* bias_data = bias->data<T>();
for (size_t i = 0; i < sample_labels->numel(); ++i) { for (size_t i = 0; i < sample_labels->numel(); ++i) {
...@@ -96,27 +102,23 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -96,27 +102,23 @@ class NCEKernel : public framework::OpKernel<T> {
sample_out_data[i] = 0; sample_out_data[i] = 0;
} }
} }
// forward mul // forward mul
auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("X"))); auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("W"))); auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
for (size_t i = 0; i < sample_labels->numel(); ++i) { for (size_t i = 0; i < sample_labels->numel(); ++i) {
// sample_out_data[i] += (input_mat.chip((int)(i /
// sample_labels->dims()[1]), 0) * weight_mat.chip(sample_labels_data[i],
// 0)).sum();
Eigen::Tensor<float, 0, Eigen::RowMajor, Eigen::DenseIndex> result = Eigen::Tensor<float, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
(input_mat.chip((int)(i / sample_labels->dims()[1]), 0) * (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) *
weight_mat.chip(sample_labels_data[i], 0)) weight_mat.chip(sample_labels_data[i], 0))
.sum(); .sum();
sample_out_data[i] += result(0); sample_out_data[i] += result(0);
// activation_->forward // activation_->forward
sample_out_data[i] = (1 / 1 + (sample_out_data[i])); sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
} }
// forward cost // forward cost
for (size_t i = 0; i < sample_labels->dims()[0]; ++i) { for (size_t i = 0; i < sample_labels->dims()[0]; ++i) {
size_t j = 0; size_t j = 0;
T w = sample_weight == nullptr ? 1 : sample_weight_data[i]; out_data[i] = 0;
T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
// for true classes // for true classes
for (; j < num_true_class; ++j) { for (; j < num_true_class; ++j) {
T o = sample_out_data[i * sample_out->dims()[1] + j]; T o = sample_out_data[i * sample_out->dims()[1] + j];
...@@ -137,11 +139,13 @@ template <typename Place, typename T> ...@@ -137,11 +139,13 @@ template <typename Place, typename T>
class NCEGradKernel : public framework::OpKernel<T> { class NCEGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
const T* d_out_data = d_out->data<T>();
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
auto sample_out = context.Input<Tensor>("SampleLogits"); auto sample_out = context.Input<Tensor>("SampleLogits");
const T* sample_out_data = sample_out->data<T>(); const T* sample_out_data = sample_out->data<T>();
auto sample_labels = context.Input<Tensor>("SampleLabels"); auto sample_labels = context.Input<Tensor>("SampleLabels");
const int* sample_labels_data = sample_labels->data<int>(); const int64_t* sample_labels_data = sample_labels->data<int64_t>();
auto sample_weight = context.Input<Tensor>("SampleWeight"); auto sample_weight = context.Input<Tensor>("SampleWeight");
const T* sample_weight_data = nullptr; const T* sample_weight_data = nullptr;
if (sample_weight != nullptr) { if (sample_weight != nullptr) {
...@@ -154,11 +158,9 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -154,11 +158,9 @@ class NCEGradKernel : public framework::OpKernel<T> {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_classes * num_smalped_classes; T b = 1. / num_classes * num_smalped_classes;
Tensor sample_grad; // tmp tensor Tensor sample_grad; // tmp tensor
T* sample_grad_data = T* sample_grad_data =
sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace()); sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
// backward cost // backward cost
for (size_t i = 0; i < sample_labels->numel(); ++i) { for (size_t i = 0; i < sample_labels->numel(); ++i) {
T o = sample_out_data[i]; T o = sample_out_data[i];
...@@ -166,15 +168,12 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -166,15 +168,12 @@ class NCEGradKernel : public framework::OpKernel<T> {
? 1 ? 1
: sample_weight_data[i / sample_labels->dims()[1]]; : sample_weight_data[i / sample_labels->dims()[1]];
sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
? -w * b / (o * (o + b)) ? w * (b / (o + b)) * (o - 1)
: w / (o + b); : w * (o * (1 - o) / (o + b));
// sigmoid->backward sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]];
sample_grad_data[i] =
(o > 0) ? sample_grad_data[i] : ((o < 0) ? -sample_grad_data[i] : 0);
} }
// get d_bias // get d_bias
auto d_bias = context.Output<Tensor>(framework::GradVarName("B")); auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
if (d_bias != nullptr) { if (d_bias != nullptr) {
T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace()); T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < sample_labels->numel(); ++i) { for (size_t i = 0; i < sample_labels->numel(); ++i) {
...@@ -182,22 +181,23 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -182,22 +181,23 @@ class NCEGradKernel : public framework::OpKernel<T> {
} }
} }
// get d_w // get d_w
auto d_w = context.Output<Tensor>(framework::GradVarName("W")); auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
if (d_w != nullptr) { if (d_w != nullptr) {
d_w->mutable_data<T>(context.GetPlace());
auto d_w_matrix = EigenMatrix<T>::From(*d_w); auto d_w_matrix = EigenMatrix<T>::From(*d_w);
auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("X"))); auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
for (size_t i = 0; i < sample_labels->numel(); ++i) { for (size_t i = 0; i < sample_labels->numel(); ++i) {
d_w_matrix.chip(sample_labels_data[i], 0) = d_w_matrix.chip(sample_labels_data[i], 0) +=
x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) * x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) *
sample_grad_data[i]; sample_grad_data[i];
} }
} }
// get d_x // get d_x
auto d_x = context.Output<Tensor>(framework::GradVarName("X")); auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
if (d_x != nullptr) { if (d_x != nullptr) {
d_x->mutable_data<T>(context.GetPlace());
auto d_x_matrix = EigenMatrix<T>::From(*d_x); auto d_x_matrix = EigenMatrix<T>::From(*d_x);
auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("W"))); auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
for (size_t i = 0; i < sample_labels->numel(); ++i) { for (size_t i = 0; i < sample_labels->numel(); ++i) {
d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) += d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) +=
w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
...@@ -205,6 +205,5 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -205,6 +205,5 @@ class NCEGradKernel : public framework::OpKernel<T> {
} }
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
def nce(input, weight, bias, sample_weight, labels, num_classes,
        num_sample_class):
    """NumPy reference for the forward pass of the NCE operator.

    For every row of ``input`` the candidate classes are the row's true
    labels followed by the fixed negative classes
    ``0 .. num_sample_class - 1``.

    Args:
        input: array of shape [batch_size, dim].
        weight: array of shape [num_classes, dim].
        bias: array of shape [num_classes], or None for no bias.
        sample_weight: array of shape [batch_size], or None to weight every
            row by 1.
        labels: integer array of shape [batch_size, num_true_class]; a 1-D
            array of shape [batch_size] is treated as one true class per row.
        num_classes: total number of classes.
        num_sample_class: number of negative classes per row.

    Returns:
        A tuple ``(cost, sample_logits, sample_labels)`` where ``cost`` has
        shape [batch_size] and the other two have shape
        [batch_size, num_true_class + num_sample_class]. ``sample_logits``
        holds the sigmoid activations and ``sample_labels`` (int64) the class
        id used at each position.
    """
    batch_size = input.shape[0]
    labels = np.asarray(labels)
    if labels.ndim == 1:
        # One true class per row: promote to [batch_size, 1].
        labels = labels.reshape(-1, 1)
    num_true_class = labels.shape[1]

    # Each entry is (row index, class id, is_true_class, row weight).
    samples = []
    sample_labels = []
    for i in range(batch_size):
        w = 1 if sample_weight is None else sample_weight[i]
        for label in labels[i]:
            samples.append((i, label, True, w))
            sample_labels.append(label)
        for num in range(num_sample_class):
            samples.append((i, num, False, w))
            sample_labels.append(num)

    # forward bias
    sample_out = np.zeros(len(samples)).astype(np.float32)
    if bias is not None:
        for i in range(len(samples)):
            sample_out[i] = bias[samples[i][1]]

    # forward weight
    for i in range(len(samples)):
        sample_out[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])

    # forward activation (sigmoid)
    sample_out = 1.0 / (1.0 + np.exp(-sample_out))

    # forward cost
    out = np.zeros(batch_size).astype(np.float32)
    b = 1.0 / num_classes * num_sample_class
    for i in range(len(samples)):
        o = sample_out[i]
        cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
        out[samples[i][0]] += cost * samples[i][3]

    width = num_sample_class + num_true_class
    # int64 regardless of platform, so the expected labels have a fixed dtype.
    sample_labels = np.array(sample_labels).astype(np.int64)
    return (out, np.array(sample_out).reshape(batch_size, width),
            sample_labels.reshape(batch_size, width))
class TestNCE(OpTest):
    """Checks the 'nce' operator against the NumPy reference ``nce()``.

    The variable names used here ('Input', 'Weight', 'Bias', 'Cost') are the
    ones the operator declares; the negative classes are pinned through the
    'sampled_labels' attribute so the operator output is reproducible.
    """

    def generate_data(self, dim, batch_size, num_classes, num_true_class,
                      num_sampled_classes):
        """Fill ``self.attrs`` and ``self.inputs`` with random test data."""
        input = np.random.randn(batch_size, dim).astype(np.float32)
        weight = np.random.randn(num_classes, dim).astype(np.float32)
        bias = np.random.randn(num_classes).astype(np.float32)
        sample_weight = np.random.randn(batch_size).astype(np.float32)
        # The kernel reads labels as int64, so pin the dtype explicitly
        # instead of relying on the platform's default integer type.
        labels = np.random.randint(
            0, num_classes, (batch_size, num_true_class)).astype(np.int64)
        self.attrs = {
            'num_classes': num_classes,
            'num_sampled_classes': num_sampled_classes,
            # A concrete list: on Python 3 a bare range() is a lazy object.
            'sampled_labels': list(range(num_sampled_classes))
        }
        self.inputs = {
            'Input': input,
            'Label': labels,
            'Weight': weight,
            'Bias': bias,
            'SampleWeight': sample_weight
        }

    def set_data(self):
        """Default problem size; subclasses override this to vary it."""
        self.generate_data(5, 5, 4, 1, 2)

    def compute(self):
        """Compute the expected outputs with the NumPy reference."""
        out = nce(self.inputs['Input'], self.inputs['Weight'],
                  self.inputs['Bias'], self.inputs['SampleWeight'],
                  self.inputs['Label'], self.attrs['num_classes'],
                  self.attrs['num_sampled_classes'])
        self.outputs = {
            'Cost': out[0],
            'SampleLogits': out[1],
            'SampleLabels': out[2]
        }

    def setUp(self):
        self.op_type = 'nce'
        self.set_data()
        self.compute()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(
            ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)
class TestNCECase1(TestNCE):
    """Larger configuration with two true classes per sample."""

    def set_data(self):
        self.generate_data(
            dim=10,
            batch_size=20,
            num_classes=10,
            num_true_class=2,
            num_sampled_classes=5)
if __name__ == '__main__':
    # Run every test case in this module when executed as a script.
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册