From 56a4912b763d5cb0f21d163c48a46f0ebc795913 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 27 Nov 2018 16:39:08 +0800 Subject: [PATCH] Make NCE_OP more efficient and support SelectedRows (#14469) * Fix truncated normal. * Fix. * Make nce support more distribution. * Fix API.spec. * Fix python API. * Fix. test=develop * Fix API.spec test=develop * Fix sampler. * Fix order of arguments in python API. test=develop * NCE add selectedrows support * NCE update weighted sampling * fix bugs in nce_op, and assign_value_op optimized * fix bugs in nce_op, revert assign_value_op * nce_op optimize * nce_op optimize * nce_op optimize * add selectedRows test later test=develop * add selectedRows supported * add selectedRows supported test=develop * add selectedRows supported * add nce selectedRows supported, test=develop * add nce selectedRows supported * add nce selectedRows supported, test=develop * fix height in nce, test=develop * add ut * add ut, test=develop * make AutoGrownIndex inline test=develop * fix tinny error, test=develop --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/math/sampler.cc | 63 +----- paddle/fluid/operators/math/sampler.h | 13 +- paddle/fluid/operators/nce_op.cc | 68 ++++++- paddle/fluid/operators/nce_op.h | 182 +++++++++++++----- python/paddle/fluid/layers/nn.py | 95 ++++++--- .../paddle/fluid/tests/unittests/test_nce.py | 118 +++++++++++- 7 files changed, 398 insertions(+), 143 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 0a71f15343..e15fdc8257 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -97,7 +97,7 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) -paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0)) +paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) diff --git a/paddle/fluid/operators/math/sampler.cc b/paddle/fluid/operators/math/sampler.cc index 690d6f6baa..2708f3bcd8 100644 --- a/paddle/fluid/operators/math/sampler.cc +++ b/paddle/fluid/operators/math/sampler.cc @@ -60,75 +60,30 @@ float LogUniformSampler::Probability(int64_t value) const { return (log((value + 2.0) / (value + 1.0))) / log_range_; } -CustomSampler::CustomSampler(int64_t range, const float* probabilities, +CustomSampler::CustomSampler(int64_t range, const float *probabilities, + const int *alias, const float *alias_probabilities, unsigned int seed) : Sampler(range, seed) { - random_engine_ = std::make_shared(seed_); + random_engine_ = std::make_shared(seed_); real_dist_ = std::make_shared>(0, 1); int_dist_ = std::make_shared>(0, range); - alias_probs_ = std::make_shared>(range + 1); - alias_ = std::make_shared>(range + 1); - probs_ = std::make_shared>(range + 1); - - std::queue> bigs; - std::queue> littles; - for (int64_t i = 0; i <= range; ++i) { - (*probs_)[i] = probabilities[i]; - float normal_prob = probabilities[i] * (range + 1); - if (normal_prob - 1.0 > 1e-4) { - bigs.emplace(i, normal_prob); - } else if (1.0 - normal_prob > 1e-4) { - littles.emplace(i, normal_prob); - } else { - (*alias_probs_)[i] = normal_prob; - (*alias_)[i] = -1; - } - } - - while ((!littles.empty()) && (!bigs.empty())) { - auto big = bigs.front(); - auto little = littles.front(); - bigs.pop(); - littles.pop(); - (*alias_probs_)[little.first] = little.second; - (*alias_)[little.first] = big.first; - auto big_left = big.second - (1 - little.second); - if (big_left - 1.0 > 1e-4) { - bigs.emplace(big.first, big_left); - } else if (1.0 - big_left > 1e-4) { - littles.emplace(big.first, big_left); - } else { - (*alias_probs_)[big.first] = big_left; - (*alias_)[big.first] = -1; - } - } - if (!littles.empty()) { // littles.second is close to 1.0 - auto little = littles.front(); - (*alias_probs_)[little.first] = 1.0; - (*alias_)[little.first] = -1; - } - - if (!bigs.empty()) { // bigs.second is close to 1.0 - auto big = bigs.front(); - (*alias_probs_)[big.first] = 1.0; - (*alias_)[big.first] = -1; - } + alias_probs_ = alias_probabilities; + probs_ = probabilities; + alias_ = alias; } int64_t CustomSampler::Sample() const { auto index = (*int_dist_)(*random_engine_); auto p = (*real_dist_)(*random_engine_); - if (p > (*alias_probs_)[index]) { - return (*alias_)[index]; + if (p > alias_probs_[index]) { + return alias_[index]; } else { return index; } } -float CustomSampler::Probability(int64_t value) const { - return (*probs_)[value]; -} +float CustomSampler::Probability(int64_t value) const { return probs_[value]; } } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sampler.h b/paddle/fluid/operators/math/sampler.h index 836cdad51f..98e0b898a5 100644 --- a/paddle/fluid/operators/math/sampler.h +++ b/paddle/fluid/operators/math/sampler.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include #include @@ -38,9 +39,12 @@ class Sampler { seed_ = seed; } } + virtual ~Sampler(); + // Sample a single value virtual int64_t Sample() const = 0; + // The probability that a single call to Sample() returns the given value. virtual float Probability(int64_t value) const = 0; @@ -99,6 +103,7 @@ class LogUniformSampler : public Sampler { class CustomSampler : public Sampler { public: explicit CustomSampler(int64_t range, const float* probabilities, + const int* alias, const float* alias_probabilities, unsigned int seed = 0UL); ~CustomSampler() override {} @@ -108,10 +113,10 @@ class CustomSampler : public Sampler { float Probability(int64_t value) const override; private: - std::shared_ptr> alias_probs_; - std::shared_ptr> alias_; - std::shared_ptr> probs_; - std::shared_ptr random_engine_; + const float* alias_probs_; + const int* alias_; + const float* probs_; + std::shared_ptr random_engine_; std::shared_ptr> real_dist_; std::shared_ptr> int_dist_; }; diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index 9b0d45ae5b..655e171e63 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/nce_op.h" +#include #include namespace paddle { @@ -25,7 +26,7 @@ class NCEOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input")); PADDLE_ENFORCE(ctx->HasInput("Label")); PADDLE_ENFORCE(ctx->HasInput("Weight")); @@ -67,7 +68,7 @@ class NCEOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), platform::CPUPlace()); @@ -101,11 +102,24 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { .AsDispensable(); AddInput( - "CustomDistribution", + "CustomDistProbs", "(Tensor) It is used in 'CostumDist' sampler. " "It is a tensor with shape [num_total_classes]." "The i-th element is the probsbility of the i-th class being sampled.") .AsDispensable(); + AddInput( + "CustomDistAlias", + "(Tensor) It is used in 'CostumDist' sampler. " + "It is a tensor with shape [num_total_classes]." + "The i-th element is the probsbility of the i-th class being sampled.") + .AsDispensable(); + AddInput( + "CustomDistAliasProbs", + "(Tensor) It is used in 'CostumDist' sampler. " + "It is a tensor with shape [num_total_classes]." + "The i-th element is the probsbility of the i-th class being sampled.") + .AsDispensable(); + AddOutput("Cost", "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples."); AddOutput("SampleLogits", @@ -124,21 +138,22 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { "kernel to compute grads." "") .AsIntermediate(); + AddAttr("num_total_classes", "Total number of classes in all samples."); AddAttr("num_neg_samples", "The number of negative classes. The default value is 10.") .SetDefault(10); - AddAttr("sampler", "(int) Which sampler to be used to sample negative class." "0: Uniform; 1: LogUniform; 2: CostumDist.") .SetDefault(0); - AddAttr("seed", "(int) The seed used in sampler. If it is 0, " "the sampler will generate a seed randomly.") .SetDefault(0); + AddAttr("is_sparse", "(boolean, default false) Sparse update.") + .SetDefault(false); AddAttr>("custom_neg_classes", "This attribute only be used in unitest. Classes " @@ -156,11 +171,19 @@ By default this operator uses a uniform distribution for sampling. } }; +class NCEOpGradDescMaker : public framework::DefaultGradOpDescMaker { + using ::paddle::framework::DefaultGradOpDescMaker< + true>::DefaultGradOpDescMaker; + + protected: + virtual std::string GradOpType() const { return "nce_grad"; } +}; + class NCEOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input")); PADDLE_ENFORCE(ctx->HasInput("Weight")); PADDLE_ENFORCE(ctx->HasInput("Cost")); @@ -190,20 +213,45 @@ class NCEOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("Input")->type()), platform::CPUPlace()); } }; +class NCEOpGradVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front(); + auto bias_grad = op_desc.Output(framework::GradVarName("Bias")).front(); + + auto attr = op_desc.GetAttr("is_sparse"); + bool is_sparse = boost::get(attr); + if (is_sparse) { + VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad + << " is set to SelectedRows"; + block->Var(weight_grad) + ->SetType(framework::proto::VarType::SELECTED_ROWS); + block->Var(bias_grad)->SetType(framework::proto::VarType::SELECTED_ROWS); + } else { + VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad + << " is set to LoDTensor"; + block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR); + block->Var(bias_grad)->SetType(framework::proto::VarType::LOD_TENSOR); + } + block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType()); + block->Var(bias_grad)->SetDataType(block->Var("Input")->GetDataType()); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad); +REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpGradDescMaker, ops::NCEOpMaker); +REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference); REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel, ops::NCEKernel); REGISTER_OP_CPU_KERNEL(nce_grad, diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index e9af8ad4ce..f2ca6ec247 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -16,26 +16,32 @@ limitations under the License. */ #include #include +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/math/sampler.h" #include "unsupported/Eigen/CXX11/Tensor" + namespace paddle { namespace operators { using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; +using SelectedRows = framework::SelectedRows; using Sampler = math::Sampler; +using DDim = framework::DDim; template using EigenMatrix = framework::EigenMatrix; template -void PrepareSamples(const framework::ExecutionContext& context, - Sampler* sampler) { +void PrepareSamples(const framework::ExecutionContext &context, + Sampler *sampler) { auto label = context.Input("Label"); - const int64_t* label_data = label->data(); + const int64_t *label_data = label->data(); auto label_dims = label->dims(); // int num_total_classes = context.Attr("num_total_classes"); // for unitest @@ -44,7 +50,7 @@ void PrepareSamples(const framework::ExecutionContext& context, auto sample_labels = context.Output("SampleLabels"); auto sample_labels_dims = sample_labels->dims(); - int64_t* sample_labels_data = + int64_t *sample_labels_data = sample_labels->mutable_data(context.GetPlace()); int num_label = label_dims.size() == 2 ? label_dims[1] : 1; @@ -70,13 +76,13 @@ void PrepareSamples(const framework::ExecutionContext& context, template class NCEKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { + void Compute(const framework::ExecutionContext &context) const override { int sampler_type = context.Attr("sampler"); int seed = context.Attr("seed"); int num_total_classes = context.Attr("num_total_classes"); int num_neg_samples = context.Attr("num_neg_samples"); - Sampler* sampler; + Sampler *sampler; switch (sampler_type) { case 0: { sampler = new math::UniformSampler(num_total_classes - 1, seed); @@ -87,11 +93,19 @@ class NCEKernel : public framework::OpKernel { break; } case 2: { - auto custom_dist = context.Input("CustomDistribution"); - const float* custom_dist_data = custom_dist->data(); - PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes); - sampler = new math::CustomSampler(num_total_classes - 1, - custom_dist_data, seed); + auto dist_probs = context.Input("CustomDistProbs"); + auto dist_alias = context.Input("CustomDistAlias"); + auto dist_alias_probs = context.Input("CustomDistAliasProbs"); + + PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes); + PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes); + PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes); + + const float *probs_data = dist_probs->data(); + const int *alias_data = dist_alias->data(); + const float *alias_probs_data = dist_alias_probs->data(); + sampler = new math::CustomSampler(num_total_classes - 1, probs_data, + alias_data, alias_probs_data, seed); break; } default: { PADDLE_THROW("Unsupported SamplerType."); } @@ -99,17 +113,17 @@ class NCEKernel : public framework::OpKernel { PrepareSamples(context, sampler); auto sample_labels = context.Output("SampleLabels"); - const int64_t* sample_labels_data = sample_labels->data(); + const int64_t *sample_labels_data = sample_labels->data(); auto sample_out = context.Output("SampleLogits"); - T* sample_out_data = sample_out->mutable_data(context.GetPlace()); + T *sample_out_data = sample_out->mutable_data(context.GetPlace()); auto label = context.Input("Label"); auto sample_weight = context.Input("SampleWeight"); - const T* sample_weight_data = nullptr; + const T *sample_weight_data = nullptr; if (sample_weight != nullptr) { sample_weight_data = sample_weight->data(); } auto out = context.Output("Cost"); - T* out_data = out->mutable_data(context.GetPlace()); + T *out_data = out->mutable_data(context.GetPlace()); int64_t num_true_class = 1; if (label != nullptr) { num_true_class = label->dims()[1]; @@ -119,7 +133,7 @@ class NCEKernel : public framework::OpKernel { // forward bias auto bias = context.Input("Bias"); if (bias != nullptr) { - const T* bias_data = bias->data(); + const T *bias_data = bias->data(); for (int64_t i = 0; i < sample_labels->numel(); ++i) { sample_out_data[i] = bias_data[sample_labels_data[i]]; } @@ -158,16 +172,16 @@ class NCEKernel : public framework::OpKernel { template class NCEGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { + void Compute(const framework::ExecutionContext &context) const override { auto d_out = context.Input(framework::GradVarName("Cost")); - const T* d_out_data = d_out->data(); + const T *d_out_data = d_out->data(); auto label = context.Input("Label"); auto sample_out = context.Input("SampleLogits"); - const T* sample_out_data = sample_out->data(); + const T *sample_out_data = sample_out->data(); auto sample_labels = context.Input("SampleLabels"); - const int64_t* sample_labels_data = sample_labels->data(); + const int64_t *sample_labels_data = sample_labels->data(); auto sample_weight = context.Input("SampleWeight"); - const T* sample_weight_data = nullptr; + const T *sample_weight_data = nullptr; if (sample_weight != nullptr) { sample_weight_data = sample_weight->data(); } @@ -180,7 +194,7 @@ class NCEGradKernel : public framework::OpKernel { int sampler_type = context.Attr("sampler"); int seed = context.Attr("seed"); - Sampler* sampler; + Sampler *sampler; switch (sampler_type) { case 0: { sampler = new math::UniformSampler(num_total_classes - 1, seed); @@ -191,11 +205,19 @@ class NCEGradKernel : public framework::OpKernel { break; } case 2: { - auto custom_dist = context.Input("CustomDistribution"); - const float* custom_dist_data = custom_dist->data(); - PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes); - sampler = new math::CustomSampler(num_total_classes - 1, - custom_dist_data, seed); + auto dist_probs = context.Input("CustomDistProbs"); + auto dist_alias = context.Input("CustomDistAlias"); + auto dist_alias_probs = context.Input("CustomDistAliasProbs"); + + PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes); + PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes); + PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes); + + const float *probs_data = dist_probs->data(); + const int *alias_data = dist_alias->data(); + const float *alias_probs_data = dist_alias_probs->data(); + sampler = new math::CustomSampler(num_total_classes - 1, probs_data, + alias_data, alias_probs_data, seed); break; } default: { PADDLE_THROW("Unsupported SamplerType."); } @@ -203,7 +225,7 @@ class NCEGradKernel : public framework::OpKernel { // T b = 1. / num_total_classes * num_neg_samples; Tensor sample_grad; // tmp tensor - T* sample_grad_data = + T *sample_grad_data = sample_grad.mutable_data(sample_labels->dims(), context.GetPlace()); // backward cost for (int64_t i = 0; i < sample_labels->numel(); ++i) { @@ -217,32 +239,105 @@ class NCEGradKernel : public framework::OpKernel { : w * (o * (1 - o) / (o + b)); sample_grad_data[i] *= d_out_data[sample_idx]; } - // get d_bias - auto d_bias = context.Output(framework::GradVarName("Bias")); - if (d_bias != nullptr) { - T* d_bias_data = d_bias->mutable_data(context.GetPlace()); - std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0); + + bool is_sparse = context.Attr("is_sparse"); + + if (!is_sparse) { + // get d_bias + auto d_bias = context.Output(framework::GradVarName("Bias")); + if (d_bias != nullptr) { + T *d_bias_data = d_bias->mutable_data(context.GetPlace()); + std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0); + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + d_bias_data[sample_labels_data[i]] += sample_grad_data[i]; + } + } + // get d_w + auto d_w = context.Output(framework::GradVarName("Weight")); + if (d_w != nullptr) { + auto d_w_data = d_w->mutable_data(context.GetPlace()); + std::fill(d_w_data, d_w_data + d_w->numel(), 0.0); + auto d_w_matrix = EigenMatrix::From(*d_w); + auto x_matrix = EigenMatrix::From(*(context.Input("Input"))); + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + d_w_matrix.chip(sample_labels_data[i], 0) += + x_matrix.chip(static_cast(i / sample_labels->dims()[1]), 0) * + sample_grad_data[i]; + } + } + } else { + std::vector labels; for (int64_t i = 0; i < sample_labels->numel(); ++i) { - d_bias_data[sample_labels_data[i]] += sample_grad_data[i]; + labels.push_back(sample_labels_data[i]); } - } - // get d_w - auto d_w = context.Output(framework::GradVarName("Weight")); - if (d_w != nullptr) { - auto d_w_data = d_w->mutable_data(context.GetPlace()); - std::fill(d_w_data, d_w_data + d_w->numel(), 0.0); - auto d_w_matrix = EigenMatrix::From(*d_w); + std::set st(labels.begin(), labels.end()); + labels.assign(st.begin(), st.end()); + + auto *bias_var = context.InputVar("Bias"); + DDim bias_dim; + if (bias_var->IsType()) { + bias_dim = context.Input("Bias")->dims(); + } else if (bias_var->IsType()) { + auto *table_t = context.Input("Bias"); + bias_dim = table_t->value().dims(); + } else { + PADDLE_THROW( + "The parameter Bias of a NCE_OP " + "must be either LoDTensor or SelectedRows"); + } + + auto d_bias = + context.Output(framework::GradVarName("Bias")); + d_bias->set_rows(labels); + d_bias->set_height(bias_dim[0]); + + d_bias->mutable_value()->Resize( + {static_cast(labels.size()), bias_dim[1]}); + T *d_bias_data = + d_bias->mutable_value()->mutable_data(context.GetPlace()); + std::fill(d_bias_data, d_bias_data + labels.size(), 0.0); + for (int64_t i = 0; i < sample_labels->numel(); ++i) { + d_bias_data[d_bias->Index(sample_labels_data[i])] += + sample_grad_data[i]; + } + + auto *table_var = context.InputVar("Weight"); + DDim table_dim; + if (table_var->IsType()) { + table_dim = context.Input("Weight")->dims(); + } else if (table_var->IsType()) { + auto *table_t = context.Input("Weight"); + table_dim = table_t->value().dims(); + } else { + PADDLE_THROW( + "The parameter Weight of a NCE_OP " + "must be either LoDTensor or SelectedRows"); + } + + auto d_w = context.Output(framework::GradVarName("Weight")); + + d_w->set_rows(labels); + d_w->set_height(table_dim[0]); + + auto *d_table_value = d_w->mutable_value(); + d_table_value->Resize( + {static_cast(labels.size()), table_dim[1]}); + auto d_w_data = d_table_value->mutable_data(context.GetPlace()); + std::fill(d_w_data, d_w_data + d_table_value->numel(), 0.0); + + auto d_w_matrix = EigenMatrix::From(*d_table_value); auto x_matrix = EigenMatrix::From(*(context.Input("Input"))); for (int64_t i = 0; i < sample_labels->numel(); ++i) { - d_w_matrix.chip(sample_labels_data[i], 0) += + d_w_matrix.chip(d_w->Index(sample_labels_data[i]), 0) += x_matrix.chip(static_cast(i / sample_labels->dims()[1]), 0) * sample_grad_data[i]; } } + // get d_x auto d_x = context.Output(framework::GradVarName("Input")); if (d_x != nullptr) { - auto* d_x_data = d_x->mutable_data(context.GetPlace()); + auto *d_x_data = d_x->mutable_data(context.GetPlace()); std::fill(d_x_data, d_x_data + d_x->numel(), 0.0); auto d_x_matrix = EigenMatrix::From(*d_x); auto w_matrix = EigenMatrix::From(*(context.Input("Weight"))); @@ -251,6 +346,7 @@ class NCEGradKernel : public framework::OpKernel { w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; } } + delete sampler; } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 40a3649973..48f571a7cc 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4394,7 +4394,8 @@ def nce(input, name=None, sampler="uniform", custom_dist=None, - seed=0): + seed=0, + is_sparse=False): """ ${comment} @@ -4420,11 +4421,12 @@ def nce(input, sampler (str): The sampler used to sample class from negtive classes. It can be 'uniform', 'log_uniform' or 'custom_dist'. default: 'uniform'. - custom_dist (Variable): A tensor with shape [num_total_classes]. + custom_dist (float[]): A float[] with size=num_total_classes. It is used when sampler is set to 'custom_dist'. custom_dist[i] is the probsbility of i-th class to be sampled. default: None. seed (int): The seed used in sampler. default: 0. + is_sparse(bool): The flag indicating whether to use sparse update, the weight@GRAD and bias@GRAD will be changed to SelectedRows. Returns: Variable: The output nce loss. @@ -4476,12 +4478,7 @@ def nce(input, shape=[num_total_classes, dim], is_bias=False, dtype=input.dtype) - inputs = { - 'Input': input, - 'Label': label, - 'Weight': w, - 'SampleWeight': sample_weight if sample_weight is not None else [] - } + inputs = {} if helper.bias_attr: b = helper.create_parameter( attr=helper.bias_attr, @@ -4493,18 +4490,10 @@ def nce(input, sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype) sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype) - if num_neg_samples is None: - num_neg_samples = 10 - else: - num_neg_samples = int(num_neg_samples) - - inputs = { - 'Input': input, - 'Label': label, - 'Weight': w, - 'Bias': b, - 'SampleWeight': sample_weight if sample_weight is not None else [] - } + inputs['Input'] = input + inputs['Label'] = label + inputs['Weight'] = w + inputs['SampleWeight'] = sample_weight if sample_weight is not None else [] if sampler == "uniform": sampler = 0 @@ -4512,17 +4501,73 @@ def nce(input, sampler = 1 elif sampler == "custom_dist": assert custom_dist is not None - assert isinstance(custom_dist, Variable) - inputs['CustomDistribution'] = custom_dist + # assert isinstance(custom_dist, Variable) + + custom_dist_len = len(custom_dist) + alias_probs_ = [0] * custom_dist_len + alias_ = [0] * custom_dist_len + bigs = [] + littles = [] + for i in range(custom_dist_len): + normal_prob = custom_dist[i] * custom_dist_len + if normal_prob - 1.0 > 1e-4: + bigs.append((i, normal_prob)) + elif 1.0 - normal_prob > 1e-4: + littles.append((i, normal_prob)) + else: + alias_probs_[i] = normal_prob + alias_[i] = -1 + + while len(bigs) and len(littles): + big = bigs.pop(0) + little = littles.pop(0) + + big_idx = big[0] + big_prob = big[1] + + alias_probs_[little[0]] = little[1] + alias_[little[0]] = big_idx + big_left = big[1] + little[1] - 1 + if big_left - 1.0 > 1e-4: + bigs.append((big_idx, big_left)) + elif 1.0 - big_left > 1e-4: + littles.append((big_idx, big_left)) + else: + alias_probs_[big_idx] = big_left + alias_[big_idx] = -1 + + if len(bigs): + big = bigs.pop(0) + alias_probs_[big[0]] = 1.0 + alias_[big[0]] = -1 + if len(littles): + little = littles.pop(0) + alias_probs_[little[0]] = 1.0 + alias_[little[0]] = -1 + + probs = assign(input=np.array(custom_dist).astype('float32')) + custom_alias = assign(input=np.array(alias_).astype('int32')) + custom_alias_probs = assign( + input=np.array(alias_probs_).astype('float32')) + + inputs['CustomDistProbs'] = probs + inputs['CustomDistAlias'] = custom_alias + inputs['CustomDistAliasProbs'] = custom_alias_probs sampler = 2 else: raise Exception("Unsupported sampler type.") + if num_neg_samples is None: + num_neg_samples = 10 + else: + num_neg_samples = int(num_neg_samples) + attrs = { 'num_total_classes': int(num_total_classes), 'num_neg_samples': num_neg_samples, 'seed': seed, - 'sampler': sampler + 'sampler': sampler, + 'is_sparse': is_sparse } helper.append_op( @@ -6474,7 +6519,7 @@ def crop(x, shape=None, offsets=None, name=None): helper = LayerHelper('crop', **locals()) if not (isinstance(shape, list) or isinstance(shape, tuple) or \ - isinstance(shape, Variable)): + isinstance(shape, Variable)): raise ValueError("The shape should be a list, tuple or Variable.") if offsets is None: @@ -6596,7 +6641,7 @@ def affine_grid(theta, out_shape, name=None): helper = LayerHelper('affine_grid') if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \ - isinstance(out_shape, Variable)): + isinstance(out_shape, Variable)): raise ValueError("The out_shape should be a list, tuple or Variable.") if not isinstance(theta, Variable): diff --git a/python/paddle/fluid/tests/unittests/test_nce.py b/python/paddle/fluid/tests/unittests/test_nce.py index c01fdd5ddd..f4f9744674 100644 --- a/python/paddle/fluid/tests/unittests/test_nce.py +++ b/python/paddle/fluid/tests/unittests/test_nce.py @@ -14,8 +14,12 @@ from __future__ import print_function -import unittest import numpy as np +import unittest + +import paddle.fluid as fluid +import paddle.fluid.initializer as initializer + from op_test import OpTest @@ -59,7 +63,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes, class TestNCE(OpTest): def generate_data(self, dim, batch_size, num_classes, num_true_class, - num_neg_samples): + num_neg_samples, is_sparse): input = np.random.randn(batch_size, dim).astype(np.float32) weight = np.random.randn(num_classes, dim).astype(np.float32) bias = np.random.randn(num_classes).astype(np.float32) @@ -70,7 +74,8 @@ class TestNCE(OpTest): 'num_neg_samples': num_neg_samples, 'custom_neg_classes': list(range(num_neg_samples)), 'seed': 0, - 'sampler': 0 + 'sampler': 0, + 'is_sparse': is_sparse } self.inputs = { 'Input': input, @@ -81,7 +86,7 @@ class TestNCE(OpTest): } def set_data(self): - self.generate_data(5, 5, 4, 1, 2) + self.generate_data(5, 5, 4, 1, 2, False) def compute(self): out = nce(self.inputs['Input'], self.inputs['Weight'], @@ -107,9 +112,110 @@ class TestNCE(OpTest): ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02) -class TestNCECase1(TestNCE): +class TestNCECase1Tensor(TestNCE): def set_data(self): - self.generate_data(10, 20, 10, 2, 5) + self.generate_data(10, 20, 10, 2, 5, False) + + +class TestNCECase1SelectedRows(unittest.TestCase): + def setUp(self): + self.base_lr = 0.0001 + self.batch_size = 8 + + @staticmethod + def get_place(): + place = fluid.core.CPUPlace() + return place + + @staticmethod + def get_train_data(batch_size): + batchs = [] + for i in range(batch_size): + input = np.random.randn(batch_size, 10).astype(np.float32) + labels = np.random.randint(0, 20, (batch_size, 1)) + batchs.append([input, labels]) + return batchs + + def get_optimizer(self): + # SGD optimizer + optimizer = fluid.optimizer.SGD(learning_rate=self.base_lr) + return optimizer + + def train_network(self, num_total_classes, num_neg_samples, sampler, + custom_dist, is_sparse): + input = fluid.layers.data(name="input", shape=[10], dtype="float32") + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + w_param = fluid.default_main_program().global_block().create_parameter( + shape=[num_total_classes, 10], + dtype='float32', + name='nce_w', + initializer=initializer.ConstantInitializer()) + b_param = fluid.default_main_program().global_block().create_parameter( + shape=[num_total_classes, 1], + dtype='float32', + name='nce_b', + initializer=initializer.ConstantInitializer()) + + cost = fluid.layers.nce(input=input, + label=label, + num_total_classes=num_total_classes, + sampler=sampler, + custom_dist=custom_dist, + sample_weight=None, + param_attr='nce_w', + bias_attr='nce_b', + seed=1, + num_neg_samples=num_neg_samples, + is_sparse=is_sparse) + avg_cost = fluid.layers.mean(cost) + # optimizer + optimizer = self.get_optimizer() + optimizer.minimize(avg_cost) + + return [avg_cost, [input, label]] + + def test_input_is_selected_rows(self): + place = self.get_place() + exe = fluid.Executor(place) + + data = self.get_train_data(self.batch_size) + nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32') + + rets = [] + # for dense + dense_scope = fluid.core.Scope() + dense_startup_program = fluid.framework.Program() + dense_train_program = fluid.framework.Program() + with fluid.scope_guard(dense_scope): + with fluid.program_guard(dense_train_program, + dense_startup_program): + cost, feeds = self.train_network(20, 5, "custom_dist", + nid_freq_arr.tolist(), False) + feeder = fluid.DataFeeder(feed_list=feeds, place=place) + exe.run(dense_startup_program) + loss_val = exe.run(dense_train_program, + feed=feeder.feed(data), + fetch_list=[cost.name]) + rets.append(np.mean(loss_val)) + + # for sparse + sparse_scope = fluid.core.Scope() + sparse_startup_program = fluid.framework.Program() + sparse_train_program = fluid.framework.Program() + with fluid.scope_guard(sparse_scope): + with fluid.program_guard(sparse_train_program, + sparse_startup_program): + cost, feeds = self.train_network(20, 5, "custom_dist", + nid_freq_arr.tolist(), True) + feeder = fluid.DataFeeder(feed_list=feeds, place=place) + exe.run(sparse_startup_program) + loss_val = exe.run(sparse_train_program, + feed=feeder.feed(data), + fetch_list=[cost.name]) + rets.append(np.mean(loss_val)) + + self.assertEqual(rets[0], rets[1]) if __name__ == '__main__': -- GitLab