未验证 提交 17226782 编写于 作者: W whs 提交者: GitHub

Make nce support more distribution. (#13549)

* Fix truncated normal.

* Fix.

* Make nce support more distribution.

* Fix API.spec.

* Fix python API.

* Fix.
test=develop

* Fix API.spec
test=develop

* Fix sampler.

* Fix order of arguments in python API.
test=develop
上级 2f27c048
...@@ -97,7 +97,7 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti ...@@ -97,7 +97,7 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
......
...@@ -308,6 +308,7 @@ op_library(flatten_op DEPS reshape_op) ...@@ -308,6 +308,7 @@ op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding) op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op) op_library(unstack_op DEPS stack_op)
op_library(fake_quantize_op DEPS memory) op_library(fake_quantize_op DEPS memory)
op_library(nce_op DEPS sampler)
if (NOT WIN32) if (NOT WIN32)
op_library(crf_decoding_op DEPS jit_kernel) op_library(crf_decoding_op DEPS jit_kernel)
op_library(fusion_lstm_op DEPS jit_kernel) op_library(fusion_lstm_op DEPS jit_kernel)
......
...@@ -41,6 +41,7 @@ math_library(cross_entropy) ...@@ -41,6 +41,7 @@ math_library(cross_entropy)
math_library(cos_sim_functor) math_library(cos_sim_functor)
math_library(depthwise_conv) math_library(depthwise_conv)
math_library(im2col) math_library(im2col)
math_library(sampler)
if (NOT WIN32) # windows do not support avx functions yet. if (NOT WIN32) # windows do not support avx functions yet.
math_library(gru_compute DEPS activation_functions math_function) math_library(gru_compute DEPS activation_functions math_function)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -13,52 +13,46 @@ See the License for the specific language governing permissions and ...@@ -13,52 +13,46 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/operators/math/sampler.h"
#include <iostream>
#include <queue>
#include <utility>
#include <vector>
namespace paddle { namespace paddle {
namespace random { namespace operators {
namespace math {
Sampler::~Sampler() {} Sampler::~Sampler() {}
UniformSampler::UniformSampler(int64 range) UniformSampler::UniformSampler(int64_t range, unsigned int seed)
: Sampler(range), inv_range_(1.0 / range) { : Sampler(range, seed), inv_range_(1.0 / (range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_); random_engine_ = std::make_shared<std::mt19937_64>(seed_);
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range); dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
} }
UniformSampler::UniformSampler(int64 range, unsigned int seed) int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
: Sampler(range, seed), inv_range_(1.0 / range) {
random_engine_ = std::make_shared<std::mt19937>(seed_);
dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
}
int64 UniformSampler::Sample() const { return (*dist_)(*random_engine_); }
float UniformSampler::Probability(int64 value) const { return inv_range_; } float UniformSampler::Probability(int64_t value) const { return inv_range_; }
LogUniformSampler::LogUniformSampler(int64 range) LogUniformSampler::LogUniformSampler(int64_t range, unsigned int seed)
: Sampler(range), log_range_(log(range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_);
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
}
LogUniformSampler::LogUniformSampler(int64 range, unsigned int seed)
: Sampler(range, seed), log_range_(log(range + 1)) { : Sampler(range, seed), log_range_(log(range + 1)) {
random_engine_ = std::make_shared<std::mt19937>(seed_); random_engine_ = std::make_shared<std::mt19937_64>(seed_);
dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1); dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
} }
int64 LogUniformSampler::Sample() const {
int64_t LogUniformSampler::Sample() const {
// Got Log Uniform distribution from uniform distribution by // Got Log Uniform distribution from uniform distribution by
// inverse_transform_sampling method // inverse_transform_sampling method
// More details: // More details:
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/ // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/
const int64 value = const int64_t value =
static_cast<int64>(exp((*dist_)(*random_engine_) * log_range_)) - 1; static_cast<int64_t>(exp((*dist_)(*random_engine_) * log_range_)) - 1;
// Mathematically, value should be <= range_, but might not be due to some // Mathematically, value should be <= range_, but might not be due to some
// floating point roundoff, so we mod by range_. // floating point roundoff, so we mod by range_.
return value % range_; return value % range_;
} }
float LogUniformSampler::Probability(int64 value) const { float LogUniformSampler::Probability(int64_t value) const {
// Given f(x) = 1/[(x+1) * log_range_] // Given f(x) = 1/[(x+1) * log_range_]
// The value's probability is integral of f(x) from value to (value + 1) // The value's probability is integral of f(x) from value to (value + 1)
// More details: // More details:
...@@ -66,5 +60,76 @@ float LogUniformSampler::Probability(int64 value) const { ...@@ -66,5 +60,76 @@ float LogUniformSampler::Probability(int64 value) const {
return (log((value + 2.0) / (value + 1.0))) / log_range_; return (log((value + 2.0) / (value + 1.0))) / log_range_;
} }
} // namespace random CustomSampler::CustomSampler(int64_t range, const float* probabilities,
unsigned int seed)
: Sampler(range, seed) {
random_engine_ = std::make_shared<std::mt19937_64>(seed_);
real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
int_dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
alias_probs_ = std::make_shared<std::vector<float>>(range + 1);
alias_ = std::make_shared<std::vector<int64_t>>(range + 1);
probs_ = std::make_shared<std::vector<float>>(range + 1);
std::queue<std::pair<int64_t, float>> bigs;
std::queue<std::pair<int64_t, float>> littles;
for (int64_t i = 0; i <= range; ++i) {
(*probs_)[i] = probabilities[i];
float normal_prob = probabilities[i] * (range + 1);
if (normal_prob - 1.0 > 1e-4) {
bigs.emplace(i, normal_prob);
} else if (1.0 - normal_prob > 1e-4) {
littles.emplace(i, normal_prob);
} else {
(*alias_probs_)[i] = normal_prob;
(*alias_)[i] = -1;
}
}
while ((!littles.empty()) && (!bigs.empty())) {
auto big = bigs.front();
auto little = littles.front();
bigs.pop();
littles.pop();
(*alias_probs_)[little.first] = little.second;
(*alias_)[little.first] = big.first;
auto big_left = big.second - (1 - little.second);
if (big_left - 1.0 > 1e-4) {
bigs.emplace(big.first, big_left);
} else if (1.0 - big_left > 1e-4) {
littles.emplace(big.first, big_left);
} else {
(*alias_probs_)[big.first] = big_left;
(*alias_)[big.first] = -1;
}
}
if (!littles.empty()) { // littles.second is close to 1.0
auto little = littles.front();
(*alias_probs_)[little.first] = 1.0;
(*alias_)[little.first] = -1;
}
if (!bigs.empty()) { // bigs.second is close to 1.0
auto big = bigs.front();
(*alias_probs_)[big.first] = 1.0;
(*alias_)[big.first] = -1;
}
}
int64_t CustomSampler::Sample() const {
auto index = (*int_dist_)(*random_engine_);
auto p = (*real_dist_)(*random_engine_);
if (p > (*alias_probs_)[index]) {
return (*alias_)[index];
} else {
return index;
}
}
// Looks up the probability that was supplied to the constructor for class
// `value`. No bounds check is performed on `value`.
float CustomSampler::Probability(int64_t value) const {
  const std::vector<float>& class_probs = *probs_;
  return class_probs[value];
}
} // namespace math
} // namespace operators
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,8 @@ limitations under the License. */ ...@@ -16,6 +16,8 @@ limitations under the License. */
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <random> #include <random>
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -27,14 +29,14 @@ namespace math { ...@@ -27,14 +29,14 @@ namespace math {
*/ */
class Sampler { class Sampler {
public: public:
explicit Sampler(int64_t range) : range_(range) { explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) {
PADDLE_ENFORCE_GT(range, 0); // PADDLE_ENFORCE_GT(range, 0, "Range should be greater than 0.");
if (seed == 0) {
std::random_device r; std::random_device r;
seed_ = r(); seed_ = r();
} else {
seed_ = seed;
} }
explicit Sampler(int64_t range, unsigned int seed)
: range_(range), seed_(seed) {
PADDLE_ENFORCE_GT(range, 0);
} }
virtual ~Sampler(); virtual ~Sampler();
// Sample a single value // Sample a single value
...@@ -42,7 +44,7 @@ class Sampler { ...@@ -42,7 +44,7 @@ class Sampler {
// The probability that a single call to Sample() returns the given value. // The probability that a single call to Sample() returns the given value.
virtual float Probability(int64_t value) const = 0; virtual float Probability(int64_t value) const = 0;
int64 range() { return range_; } int64_t range() { return range_; }
protected: protected:
const int64_t range_; const int64_t range_;
...@@ -56,13 +58,11 @@ class Sampler { ...@@ -56,13 +58,11 @@ class Sampler {
*/ */
class UniformSampler : public Sampler { class UniformSampler : public Sampler {
public: public:
explicit UniformSampler(int64_t range); explicit UniformSampler(int64_t range, unsigned int seed = 0UL);
explicit UniformSampler(int64_t range, unsigned int seed);
~UniformSampler() override {} ~UniformSampler() override {}
int64 Sample() const override; int64_t Sample() const override;
float Probability(int64_t value) const override; float Probability(int64_t value) const override;
...@@ -79,13 +79,11 @@ class UniformSampler : public Sampler { ...@@ -79,13 +79,11 @@ class UniformSampler : public Sampler {
*/ */
class LogUniformSampler : public Sampler { class LogUniformSampler : public Sampler {
public: public:
explicit LogUniformSampler(int64_t range); explicit LogUniformSampler(int64_t range, unsigned int seed = 0UL);
explicit LogUniformSampler(int64_t range, unsigned int seed);
~LogUniformSampler() override {} ~LogUniformSampler() override {}
int64 Sample() const override; int64_t Sample() const override;
float Probability(int64_t value) const override; float Probability(int64_t value) const override;
...@@ -95,6 +93,29 @@ class LogUniformSampler : public Sampler { ...@@ -95,6 +93,29 @@ class LogUniformSampler : public Sampler {
std::shared_ptr<std::uniform_real_distribution<>> dist_; std::shared_ptr<std::uniform_real_distribution<>> dist_;
}; };
/**
* Sample integers from [0, range] (both ends inclusive) following a
* caller-supplied discrete distribution.
*
* The constructor reads probabilities[0..range], i.e. range + 1 values, and
* builds per-slot alias tables from them. Sample() picks a slot uniformly
* from [0, range] and returns either that slot or its alias, depending on a
* second uniform draw in [0, 1).
*/
class CustomSampler : public Sampler {
public:
// range:         largest value that can be sampled.
// probabilities: array of range + 1 floats; probabilities[i] is the
//                probability of sampling i. The values are copied.
// seed:          forwarded to Sampler; 0 makes Sampler pick a seed from
//                std::random_device.
explicit CustomSampler(int64_t range, const float* probabilities,
unsigned int seed = 0UL);
~CustomSampler() override {}
// Draw a single value.
int64_t Sample() const override;
// Return the probability that was supplied for `value`.
float Probability(int64_t value) const override;
private:
// Per-slot threshold compared against the uniform real draw in Sample().
std::shared_ptr<std::vector<float>> alias_probs_;
// Per-slot alias index; -1 marks a slot that has no alias.
std::shared_ptr<std::vector<int64_t>> alias_;
// Copy of the probabilities passed to the constructor.
std::shared_ptr<std::vector<float>> probs_;
// Engine seeded with seed_; shared by both distributions below.
std::shared_ptr<std::mt19937_64> random_engine_;
// Uniform real distribution over [0, 1).
std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
// Uniform integer distribution over [0, range].
std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
};
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -35,6 +35,7 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -35,6 +35,7 @@ class NCEOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("Input"); auto x_dims = ctx->GetInputDim("Input");
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]); PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1; int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
...@@ -98,6 +99,13 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -98,6 +99,13 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
"each sample. And it is a dispensable input. The default value of " "each sample. And it is a dispensable input. The default value of "
"sample is 1.") "sample is 1.")
.AsDispensable(); .AsDispensable();
AddInput(
"CustomDistribution",
"(Tensor) It is used in 'CostumDist' sampler. "
"It is a tensor with shape [num_total_classes]."
"The i-th element is the probsbility of the i-th class being sampled.")
.AsDispensable();
AddOutput("Cost", AddOutput("Cost",
"(Tensor) A tensor of shape [batch_size, 1]. Cost of samples."); "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
AddOutput("SampleLogits", AddOutput("SampleLogits",
...@@ -121,6 +129,17 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -121,6 +129,17 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("num_neg_samples", AddAttr<int>("num_neg_samples",
"The number of negative classes. The default value is 10.") "The number of negative classes. The default value is 10.")
.SetDefault(10); .SetDefault(10);
AddAttr<int>("sampler",
"(int) Which sampler to be used to sample negative class."
"0: Uniform; 1: LogUniform; 2: CostumDist.")
.SetDefault(0);
AddAttr<int>("seed",
"(int) The seed used in sampler. If it is 0, "
"the sampler will generate a seed randomly.")
.SetDefault(0);
AddAttr<std::vector<int>>("custom_neg_classes", AddAttr<std::vector<int>>("custom_neg_classes",
"This attribute only be used in unitest. Classes " "This attribute only be used in unitest. Classes "
"in this list wiil be used as negative classes " "in this list wiil be used as negative classes "
......
...@@ -19,29 +19,28 @@ limitations under the License. */ ...@@ -19,29 +19,28 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using Sampler = math::Sampler;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void PrepareSamples(const framework::ExecutionContext& context) { void PrepareSamples(const framework::ExecutionContext& context,
Sampler* sampler) {
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
const int64_t* label_data = label->data<int64_t>(); const int64_t* label_data = label->data<int64_t>();
auto label_dims = label->dims(); auto label_dims = label->dims();
int num_total_classes = context.Attr<int>("num_total_classes"); // int num_total_classes = context.Attr<int>("num_total_classes");
// for unitest // for unitest
std::vector<int> custom_neg_classes = std::vector<int> custom_neg_classes =
context.Attr<std::vector<int>>("custom_neg_classes"); context.Attr<std::vector<int>>("custom_neg_classes");
// random machine
std::random_device rd;
std::mt19937 rng(rd());
std::uniform_int_distribution<int> rand(0, num_total_classes - 1);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
auto sample_labels_dims = sample_labels->dims(); auto sample_labels_dims = sample_labels->dims();
...@@ -62,7 +61,7 @@ void PrepareSamples(const framework::ExecutionContext& context) { ...@@ -62,7 +61,7 @@ void PrepareSamples(const framework::ExecutionContext& context) {
} else { } else {
for (; j < sample_labels_dims[1]; ++j) { for (; j < sample_labels_dims[1]; ++j) {
// TODO(wanghaoshuang): support more distribution sampling // TODO(wanghaoshuang): support more distribution sampling
sample_labels_data[index++] = rand(rng); sample_labels_data[index++] = sampler->Sample();
} }
} }
} }
...@@ -72,7 +71,33 @@ template <typename DeviceContext, typename T> ...@@ -72,7 +71,33 @@ template <typename DeviceContext, typename T>
class NCEKernel : public framework::OpKernel<T> { class NCEKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PrepareSamples<DeviceContext, T>(context); int sampler_type = context.Attr<int>("sampler");
int seed = context.Attr<int>("seed");
int num_total_classes = context.Attr<int>("num_total_classes");
int num_neg_samples = context.Attr<int>("num_neg_samples");
Sampler* sampler;
switch (sampler_type) {
case 0: {
sampler = new math::UniformSampler(num_total_classes - 1, seed);
break;
}
case 1: {
sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
break;
}
case 2: {
auto custom_dist = context.Input<Tensor>("CustomDistribution");
const float* custom_dist_data = custom_dist->data<float>();
PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
sampler = new math::CustomSampler(num_total_classes - 1,
custom_dist_data, seed);
break;
}
default: { PADDLE_THROW("Unsupported SamplerType."); }
}
PrepareSamples<DeviceContext, T>(context, sampler);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
const int64_t* sample_labels_data = sample_labels->data<int64_t>(); const int64_t* sample_labels_data = sample_labels->data<int64_t>();
auto sample_out = context.Output<Tensor>("SampleLogits"); auto sample_out = context.Output<Tensor>("SampleLogits");
...@@ -85,13 +110,12 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -85,13 +110,12 @@ class NCEKernel : public framework::OpKernel<T> {
} }
auto out = context.Output<Tensor>("Cost"); auto out = context.Output<Tensor>("Cost");
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
int num_neg_samples = context.Attr<int>("num_neg_samples");
int num_total_classes = context.Attr<int>("num_total_classes");
int64_t num_true_class = 1; int64_t num_true_class = 1;
if (label != nullptr) { if (label != nullptr) {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_total_classes * num_neg_samples; int64_t sampled_labels_num = sample_labels->dims()[1];
// T b = 1. / num_total_classes * num_neg_samples;
// forward bias // forward bias
auto bias = context.Input<Tensor>("Bias"); auto bias = context.Input<Tensor>("Bias");
if (bias != nullptr) { if (bias != nullptr) {
...@@ -117,22 +141,17 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -117,22 +141,17 @@ class NCEKernel : public framework::OpKernel<T> {
} }
// forward cost // forward cost
for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) { for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
int64_t j = 0;
out_data[i] = 0; out_data[i] = 0;
T w = sample_weight == nullptr ? 1. : sample_weight_data[i]; T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
// for true classes for (int64_t j = 0; j < sampled_labels_num; ++j) {
for (; j < num_true_class; ++j) { int64_t target = sample_labels_data[i * sampled_labels_num + j];
T o = sample_out_data[i * sample_out->dims()[1] + j]; T o = sample_out_data[i * sampled_labels_num + j];
T cost = -log(o / (o + b)); float b = sampler->Probability(target) * num_neg_samples;
out_data[i] += w * cost; T cost = (j < num_true_class) ? -log(o / (o + b)) : -log(b / (o + b));
}
// for sampled neg classes
for (; j < sample_labels->dims()[1]; ++j) {
T o = sample_out_data[i * sample_out->dims()[1] + j];
T cost = -log(b / (o + b));
out_data[i] += w * cost; out_data[i] += w * cost;
} }
} }
delete sampler;
} }
}; };
...@@ -158,20 +177,45 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -158,20 +177,45 @@ class NCEGradKernel : public framework::OpKernel<T> {
if (label != nullptr) { if (label != nullptr) {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
} }
T b = 1. / num_total_classes * num_neg_samples;
int sampler_type = context.Attr<int>("sampler");
int seed = context.Attr<int>("seed");
Sampler* sampler;
switch (sampler_type) {
case 0: {
sampler = new math::UniformSampler(num_total_classes - 1, seed);
break;
}
case 1: {
sampler = new math::LogUniformSampler(num_total_classes - 1, seed);
break;
}
case 2: {
auto custom_dist = context.Input<Tensor>("CustomDistribution");
const float* custom_dist_data = custom_dist->data<float>();
PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
sampler = new math::CustomSampler(num_total_classes - 1,
custom_dist_data, seed);
break;
}
default: { PADDLE_THROW("Unsupported SamplerType."); }
}
// T b = 1. / num_total_classes * num_neg_samples;
Tensor sample_grad; // tmp tensor Tensor sample_grad; // tmp tensor
T* sample_grad_data = T* sample_grad_data =
sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace()); sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
// backward cost // backward cost
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
int64_t label_idx = i % sample_labels->dims()[1];
int64_t sample_idx = i / sample_labels->dims()[1];
float b = sampler->Probability(sample_labels_data[i]) * num_neg_samples;
T o = sample_out_data[i]; T o = sample_out_data[i];
T w = sample_weight == nullptr T w = sample_weight == nullptr ? 1 : sample_weight_data[sample_idx];
? 1 sample_grad_data[i] = label_idx < num_true_class
: sample_weight_data[i / sample_labels->dims()[1]];
sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
? w * (b / (o + b)) * (o - 1) ? w * (b / (o + b)) * (o - 1)
: w * (o * (1 - o) / (o + b)); : w * (o * (1 - o) / (o + b));
sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]]; sample_grad_data[i] *= d_out_data[sample_idx];
} }
// get d_bias // get d_bias
auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias")); auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
...@@ -207,6 +251,7 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -207,6 +251,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
} }
} }
delete sampler;
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -4313,7 +4313,10 @@ def nce(input, ...@@ -4313,7 +4313,10 @@ def nce(input,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
num_neg_samples=None, num_neg_samples=None,
name=None): name=None,
sampler="uniform",
custom_dist=None,
seed=0):
""" """
${comment} ${comment}
...@@ -4336,6 +4339,14 @@ def nce(input, ...@@ -4336,6 +4339,14 @@ def nce(input,
num_neg_samples (int): ${num_neg_samples_comment} num_neg_samples (int): ${num_neg_samples_comment}
name (str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None. will be named automatically. Default: None.
sampler (str): The sampler used to sample classes from the negative classes.
It can be 'uniform', 'log_uniform' or 'custom_dist'.
default: 'uniform'.
custom_dist (Variable): A tensor with shape [num_total_classes].
It is used when sampler is set to 'custom_dist'.
custom_dist[i] is the probability of the i-th class being sampled.
default: None.
seed (int): The seed used in sampler. default: 0.
Returns: Returns:
Variable: The output nce loss. Variable: The output nce loss.
...@@ -4365,6 +4376,16 @@ def nce(input, ...@@ -4365,6 +4376,16 @@ def nce(input,
loss = layers.nce(input=embs, label=words[label_word], loss = layers.nce(input=embs, label=words[label_word],
num_total_classes=dict_size, param_attr='nce.w', num_total_classes=dict_size, param_attr='nce.w',
bias_attr='nce.b') bias_attr='nce.b')
#or use custom distribution
dist = fluid.layers.assign(input=np.array([0.05,0.5,0.1,0.3,0.05]).astype("float32"))
loss = layers.nce(input=embs, label=words[label_word],
num_total_classes=5, param_attr='nce.w',
bias_attr='nce.b',
num_neg_samples=3,
sampler="custom_dist",
custom_dist=dist)
""" """
helper = LayerHelper('nce', **locals()) helper = LayerHelper('nce', **locals())
assert isinstance(input, Variable) assert isinstance(input, Variable)
...@@ -4399,9 +4420,31 @@ def nce(input, ...@@ -4399,9 +4420,31 @@ def nce(input,
else: else:
num_neg_samples = int(num_neg_samples) num_neg_samples = int(num_neg_samples)
inputs = {
'Input': input,
'Label': label,
'Weight': w,
'Bias': b,
'SampleWeight': sample_weight if sample_weight is not None else []
}
if sampler == "uniform":
sampler = 0
elif sampler == "log_uniform":
sampler = 1
elif sampler == "custom_dist":
assert custom_dist is not None
assert isinstance(custom_dist, Variable)
inputs['CustomDistribution'] = custom_dist
sampler = 2
else:
raise Exception("Unsupported sampler type.")
attrs = { attrs = {
'num_total_classes': int(num_total_classes), 'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples 'num_neg_samples': num_neg_samples,
'seed': seed,
'sampler': sampler
} }
helper.append_op( helper.append_op(
......
...@@ -68,7 +68,9 @@ class TestNCE(OpTest): ...@@ -68,7 +68,9 @@ class TestNCE(OpTest):
self.attrs = { self.attrs = {
'num_total_classes': num_classes, 'num_total_classes': num_classes,
'num_neg_samples': num_neg_samples, 'num_neg_samples': num_neg_samples,
'custom_neg_classes': list(range(num_neg_samples)) 'custom_neg_classes': list(range(num_neg_samples)),
'seed': 0,
'sampler': 0
} }
self.inputs = { self.inputs = {
'Input': input, 'Input': input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册