提交 29262ab2 编写于 作者: W wanghaoshuang

Fix unitest.

上级 76a65a83
...@@ -41,11 +41,11 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -41,11 +41,11 @@ class NCEOp : public framework::OperatorWithKernel {
} }
auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples"); auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes"); auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
std::vector<int> sampled_labels = std::vector<int> custom_neg_classes =
ctx->Attrs().Get<std::vector<int>>("sampled_labels"); ctx->Attrs().Get<std::vector<int>>("custom_neg_classes");
PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]); PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]);
if (sampled_labels.size() > 0) { if (custom_neg_classes.size() > 0) {
PADDLE_ENFORCE_EQ(sampled_labels.size(), PADDLE_ENFORCE_EQ(custom_neg_classes.size(),
static_cast<size_t>(num_neg_samples)); static_cast<size_t>(num_neg_samples));
} }
// set dims of output(Out) // set dims of output(Out)
......
...@@ -33,14 +33,14 @@ void PrepareSamples(const framework::ExecutionContext& context) { ...@@ -33,14 +33,14 @@ void PrepareSamples(const framework::ExecutionContext& context) {
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
const int64_t* label_data = label->data<int64_t>(); const int64_t* label_data = label->data<int64_t>();
auto label_dims = label->dims(); auto label_dims = label->dims();
int num_classes = context.Attr<int>("num_classes"); int num_total_classes = context.Attr<int>("num_total_classes");
// for unitest // for unitest
std::vector<int> custom_neg_classes = std::vector<int> custom_neg_classes =
context.Attr<std::vector<int>>("custom_neg_classes"); context.Attr<std::vector<int>>("custom_neg_classes");
// random machine // random machine
std::random_device rd; std::random_device rd;
std::mt19937 rng(rd()); std::mt19937 rng(rd());
std::uniform_int_distribution<int> rand(0, num_classes - 1); std::uniform_int_distribution<int> rand(0, num_total_classes - 1);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
auto sample_labels_dims = sample_labels->dims(); auto sample_labels_dims = sample_labels->dims();
...@@ -84,13 +84,13 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -84,13 +84,13 @@ class NCEKernel : public framework::OpKernel<T> {
} }
auto out = context.Output<Tensor>("Cost"); auto out = context.Output<Tensor>("Cost");
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
int num_smalped_classes = context.Attr<int>("num_sampled_classes"); int num_neg_samples = context.Attr<int>("num_neg_samples");
int num_classes = context.Attr<int>("num_classes"); int num_total_classes = context.Attr<int>("num_total_classes");
int num_true_class = 1; int num_true_class = 1;
if (label != nullptr) { if (label != nullptr) {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_classes * num_smalped_classes; T b = 1. / num_total_classes * num_neg_samples;
// forward bias // forward bias
auto bias = context.Input<Tensor>("Bias"); auto bias = context.Input<Tensor>("Bias");
if (bias != nullptr) { if (bias != nullptr) {
...@@ -151,13 +151,13 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -151,13 +151,13 @@ class NCEGradKernel : public framework::OpKernel<T> {
if (sample_weight != nullptr) { if (sample_weight != nullptr) {
sample_weight_data = sample_weight->data<T>(); sample_weight_data = sample_weight->data<T>();
} }
int num_smalped_classes = context.Attr<int>("num_sampled_classes"); int num_neg_samples = context.Attr<int>("num_neg_samples");
int num_classes = context.Attr<int>("num_classes"); int num_total_classes = context.Attr<int>("num_total_classes");
int num_true_class = 1; int num_true_class = 1;
if (label != nullptr) { if (label != nullptr) {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_classes * num_smalped_classes; T b = 1. / num_total_classes * num_neg_samples;
Tensor sample_grad; // tmp tensor Tensor sample_grad; // tmp tensor
T* sample_grad_data = T* sample_grad_data =
sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace()); sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
......
...@@ -35,7 +35,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes, ...@@ -35,7 +35,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
o = sample_out[i] o = sample_out[i]
cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b)) cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
out[samples[i][0]] += cost * samples[i][3] out[samples[i][0]] += cost * samples[i][3]
return (out, np.array(sample_out).reshape( return (out[:, np.newaxis], np.array(sample_out).reshape(
batch_size, num_sample_class + num_true_class), batch_size, num_sample_class + num_true_class),
np.array(sample_labels).reshape(batch_size, np.array(sample_labels).reshape(batch_size,
num_sample_class + num_true_class)) num_sample_class + num_true_class))
...@@ -43,16 +43,16 @@ def nce(input, weight, bias, sample_weight, labels, num_classes, ...@@ -43,16 +43,16 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
class TestNCE(OpTest): class TestNCE(OpTest):
def generate_data(self, dim, batch_size, num_classes, num_true_class, def generate_data(self, dim, batch_size, num_classes, num_true_class,
num_sampled_classes): num_neg_samples):
input = np.random.randn(batch_size, dim).astype(np.float32) input = np.random.randn(batch_size, dim).astype(np.float32)
weight = np.random.randn(num_classes, dim).astype(np.float32) weight = np.random.randn(num_classes, dim).astype(np.float32)
bias = np.random.randn(num_classes).astype(np.float32) bias = np.random.randn(num_classes).astype(np.float32)
sample_weight = np.random.randn(batch_size).astype(np.float32) sample_weight = np.random.randn(batch_size).astype(np.float32)
labels = np.random.randint(0, num_classes, (batch_size, num_true_class)) labels = np.random.randint(0, num_classes, (batch_size, num_true_class))
self.attrs = { self.attrs = {
'num_classes': num_classes, 'num_total_classes': num_classes,
'num_sampled_classes': num_sampled_classes, 'num_neg_samples': num_neg_samples,
'sampled_labels': range(num_sampled_classes) 'custom_neg_classes': range(num_neg_samples)
} }
self.inputs = { self.inputs = {
'Input': input, 'Input': input,
...@@ -68,8 +68,8 @@ class TestNCE(OpTest): ...@@ -68,8 +68,8 @@ class TestNCE(OpTest):
def compute(self): def compute(self):
out = nce(self.inputs['Input'], self.inputs['Weight'], out = nce(self.inputs['Input'], self.inputs['Weight'],
self.inputs['Bias'], self.inputs['SampleWeight'], self.inputs['Bias'], self.inputs['SampleWeight'],
self.inputs['Label'], self.attrs['num_classes'], self.inputs['Label'], self.attrs['num_total_classes'],
self.attrs['num_sampled_classes']) self.attrs['num_neg_samples'])
self.outputs = { self.outputs = {
'Cost': out[0], 'Cost': out[0],
'SampleLogits': out[1], 'SampleLogits': out[1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册