未验证 提交 56a4912b 编写于 作者: T tangwei12 提交者: GitHub

Make NCE_OP more efficient and support SelectedRows (#14469)

* Fix truncated normal.

* Fix.

* Make nce support more distribution.

* Fix API.spec.

* Fix python API.

* Fix.
test=develop

* Fix API.spec
test=develop

* Fix sampler.

* Fix order of arguments in python API.
test=develop

* NCE add selectedrows support

* NCE update weighted sampling

* fix bugs in nce_op, and assign_value_op optimized

* fix bugs in nce_op, revert assign_value_op

* nce_op optimize

* nce_op optimize

* nce_op optimize

* add selectedRows test later

test=develop

* add selectedRows supported

* add selectedRows supported

test=develop

* add selectedRows supported

* add nce selectedRows supported, test=develop

* add nce selectedRows supported

* add nce selectedRows supported, test=develop

* fix height in nce, test=develop

* add ut

* add ut, test=develop

* make AutoGrownIndex inline
test=develop

* fix tinny error, test=develop
上级 1c48d614
...@@ -97,7 +97,7 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti ...@@ -97,7 +97,7 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)) paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0)) paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
......
...@@ -60,75 +60,30 @@ float LogUniformSampler::Probability(int64_t value) const { ...@@ -60,75 +60,30 @@ float LogUniformSampler::Probability(int64_t value) const {
return (log((value + 2.0) / (value + 1.0))) / log_range_; return (log((value + 2.0) / (value + 1.0))) / log_range_;
} }
CustomSampler::CustomSampler(int64_t range, const float* probabilities, CustomSampler::CustomSampler(int64_t range, const float *probabilities,
const int *alias, const float *alias_probabilities,
unsigned int seed) unsigned int seed)
: Sampler(range, seed) { : Sampler(range, seed) {
random_engine_ = std::make_shared<std::mt19937_64>(seed_); random_engine_ = std::make_shared<std::mt19937>(seed_);
real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1); real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
int_dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range); int_dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
alias_probs_ = std::make_shared<std::vector<float>>(range + 1);
alias_ = std::make_shared<std::vector<int64_t>>(range + 1);
probs_ = std::make_shared<std::vector<float>>(range + 1);
std::queue<std::pair<int64_t, float>> bigs;
std::queue<std::pair<int64_t, float>> littles;
for (int64_t i = 0; i <= range; ++i) {
(*probs_)[i] = probabilities[i];
float normal_prob = probabilities[i] * (range + 1);
if (normal_prob - 1.0 > 1e-4) {
bigs.emplace(i, normal_prob);
} else if (1.0 - normal_prob > 1e-4) {
littles.emplace(i, normal_prob);
} else {
(*alias_probs_)[i] = normal_prob;
(*alias_)[i] = -1;
}
}
while ((!littles.empty()) && (!bigs.empty())) {
auto big = bigs.front();
auto little = littles.front();
bigs.pop();
littles.pop();
(*alias_probs_)[little.first] = little.second;
(*alias_)[little.first] = big.first;
auto big_left = big.second - (1 - little.second);
if (big_left - 1.0 > 1e-4) {
bigs.emplace(big.first, big_left);
} else if (1.0 - big_left > 1e-4) {
littles.emplace(big.first, big_left);
} else {
(*alias_probs_)[big.first] = big_left;
(*alias_)[big.first] = -1;
}
}
if (!littles.empty()) { // littles.second is close to 1.0 alias_probs_ = alias_probabilities;
auto little = littles.front(); probs_ = probabilities;
(*alias_probs_)[little.first] = 1.0; alias_ = alias;
(*alias_)[little.first] = -1;
}
if (!bigs.empty()) { // bigs.second is close to 1.0
auto big = bigs.front();
(*alias_probs_)[big.first] = 1.0;
(*alias_)[big.first] = -1;
}
} }
int64_t CustomSampler::Sample() const { int64_t CustomSampler::Sample() const {
auto index = (*int_dist_)(*random_engine_); auto index = (*int_dist_)(*random_engine_);
auto p = (*real_dist_)(*random_engine_); auto p = (*real_dist_)(*random_engine_);
if (p > (*alias_probs_)[index]) { if (p > alias_probs_[index]) {
return (*alias_)[index]; return alias_[index];
} else { } else {
return index; return index;
} }
} }
float CustomSampler::Probability(int64_t value) const { float CustomSampler::Probability(int64_t value) const { return probs_[value]; }
return (*probs_)[value];
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <random> #include <random>
...@@ -38,9 +39,12 @@ class Sampler { ...@@ -38,9 +39,12 @@ class Sampler {
seed_ = seed; seed_ = seed;
} }
} }
virtual ~Sampler(); virtual ~Sampler();
// Sample a single value // Sample a single value
virtual int64_t Sample() const = 0; virtual int64_t Sample() const = 0;
// The probability that a single call to Sample() returns the given value. // The probability that a single call to Sample() returns the given value.
virtual float Probability(int64_t value) const = 0; virtual float Probability(int64_t value) const = 0;
...@@ -99,6 +103,7 @@ class LogUniformSampler : public Sampler { ...@@ -99,6 +103,7 @@ class LogUniformSampler : public Sampler {
class CustomSampler : public Sampler { class CustomSampler : public Sampler {
public: public:
explicit CustomSampler(int64_t range, const float* probabilities, explicit CustomSampler(int64_t range, const float* probabilities,
const int* alias, const float* alias_probabilities,
unsigned int seed = 0UL); unsigned int seed = 0UL);
~CustomSampler() override {} ~CustomSampler() override {}
...@@ -108,10 +113,10 @@ class CustomSampler : public Sampler { ...@@ -108,10 +113,10 @@ class CustomSampler : public Sampler {
float Probability(int64_t value) const override; float Probability(int64_t value) const override;
private: private:
std::shared_ptr<std::vector<float>> alias_probs_; const float* alias_probs_;
std::shared_ptr<std::vector<int64_t>> alias_; const int* alias_;
std::shared_ptr<std::vector<float>> probs_; const float* probs_;
std::shared_ptr<std::mt19937_64> random_engine_; std::shared_ptr<std::mt19937> random_engine_;
std::shared_ptr<std::uniform_real_distribution<>> real_dist_; std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
std::shared_ptr<std::uniform_int_distribution<>> int_dist_; std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
}; };
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/nce_op.h" #include "paddle/fluid/operators/nce_op.h"
#include <string>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
...@@ -25,7 +26,7 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -25,7 +26,7 @@ class NCEOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input")); PADDLE_ENFORCE(ctx->HasInput("Input"));
PADDLE_ENFORCE(ctx->HasInput("Label")); PADDLE_ENFORCE(ctx->HasInput("Label"));
PADDLE_ENFORCE(ctx->HasInput("Weight")); PADDLE_ENFORCE(ctx->HasInput("Weight"));
...@@ -67,7 +68,7 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -67,7 +68,7 @@ class NCEOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
platform::CPUPlace()); platform::CPUPlace());
...@@ -101,11 +102,24 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -101,11 +102,24 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable(); .AsDispensable();
AddInput( AddInput(
"CustomDistribution", "CustomDistProbs",
"(Tensor) It is used in 'CostumDist' sampler. " "(Tensor) It is used in 'CostumDist' sampler. "
"It is a tensor with shape [num_total_classes]." "It is a tensor with shape [num_total_classes]."
"The i-th element is the probsbility of the i-th class being sampled.") "The i-th element is the probsbility of the i-th class being sampled.")
.AsDispensable(); .AsDispensable();
AddInput(
"CustomDistAlias",
"(Tensor) It is used in 'CostumDist' sampler. "
"It is a tensor with shape [num_total_classes]."
"The i-th element is the probsbility of the i-th class being sampled.")
.AsDispensable();
AddInput(
"CustomDistAliasProbs",
"(Tensor) It is used in 'CostumDist' sampler. "
"It is a tensor with shape [num_total_classes]."
"The i-th element is the probsbility of the i-th class being sampled.")
.AsDispensable();
AddOutput("Cost", AddOutput("Cost",
"(Tensor) A tensor of shape [batch_size, 1]. Cost of samples."); "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
AddOutput("SampleLogits", AddOutput("SampleLogits",
...@@ -124,21 +138,22 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -124,21 +138,22 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
"kernel to compute grads." "kernel to compute grads."
"") "")
.AsIntermediate(); .AsIntermediate();
AddAttr<int>("num_total_classes", AddAttr<int>("num_total_classes",
"Total number of classes in all samples."); "Total number of classes in all samples.");
AddAttr<int>("num_neg_samples", AddAttr<int>("num_neg_samples",
"The number of negative classes. The default value is 10.") "The number of negative classes. The default value is 10.")
.SetDefault(10); .SetDefault(10);
AddAttr<int>("sampler", AddAttr<int>("sampler",
"(int) Which sampler to be used to sample negative class." "(int) Which sampler to be used to sample negative class."
"0: Uniform; 1: LogUniform; 2: CostumDist.") "0: Uniform; 1: LogUniform; 2: CostumDist.")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("seed", AddAttr<int>("seed",
"(int) The seed used in sampler. If it is 0, " "(int) The seed used in sampler. If it is 0, "
"the sampler will generate a seed randomly.") "the sampler will generate a seed randomly.")
.SetDefault(0); .SetDefault(0);
AddAttr<bool>("is_sparse", "(boolean, default false) Sparse update.")
.SetDefault(false);
AddAttr<std::vector<int>>("custom_neg_classes", AddAttr<std::vector<int>>("custom_neg_classes",
"This attribute only be used in unitest. Classes " "This attribute only be used in unitest. Classes "
...@@ -156,11 +171,19 @@ By default this operator uses a uniform distribution for sampling. ...@@ -156,11 +171,19 @@ By default this operator uses a uniform distribution for sampling.
} }
}; };
class NCEOpGradDescMaker : public framework::DefaultGradOpDescMaker<true> {
using ::paddle::framework::DefaultGradOpDescMaker<
true>::DefaultGradOpDescMaker;
protected:
virtual std::string GradOpType() const { return "nce_grad"; }
};
class NCEOpGrad : public framework::OperatorWithKernel { class NCEOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input")); PADDLE_ENFORCE(ctx->HasInput("Input"));
PADDLE_ENFORCE(ctx->HasInput("Weight")); PADDLE_ENFORCE(ctx->HasInput("Weight"));
PADDLE_ENFORCE(ctx->HasInput("Cost")); PADDLE_ENFORCE(ctx->HasInput("Cost"));
...@@ -190,20 +213,45 @@ class NCEOpGrad : public framework::OperatorWithKernel { ...@@ -190,20 +213,45 @@ class NCEOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
class NCEOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front();
auto bias_grad = op_desc.Output(framework::GradVarName("Bias")).front();
auto attr = op_desc.GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad
<< " is set to SelectedRows";
block->Var(weight_grad)
->SetType(framework::proto::VarType::SELECTED_ROWS);
block->Var(bias_grad)->SetType(framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad
<< " is set to LoDTensor";
block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
block->Var(bias_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
}
block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType());
block->Var(bias_grad)->SetDataType(block->Var("Input")->GetDataType());
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker, REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpGradDescMaker, ops::NCEOpMaker);
paddle::framework::DefaultGradOpDescMaker<true>); REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference);
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>, REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
ops::NCEKernel<paddle::platform::CPUPlace, double>); ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad, REGISTER_OP_CPU_KERNEL(nce_grad,
......
...@@ -16,26 +16,32 @@ limitations under the License. */ ...@@ -16,26 +16,32 @@ limitations under the License. */
#include <math.h> #include <math.h>
#include <random> #include <random>
#include <set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/operators/math/sampler.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using Sampler = math::Sampler; using Sampler = math::Sampler;
using DDim = framework::DDim;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void PrepareSamples(const framework::ExecutionContext& context, void PrepareSamples(const framework::ExecutionContext &context,
Sampler* sampler) { Sampler *sampler) {
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
const int64_t* label_data = label->data<int64_t>(); const int64_t *label_data = label->data<int64_t>();
auto label_dims = label->dims(); auto label_dims = label->dims();
// int num_total_classes = context.Attr<int>("num_total_classes"); // int num_total_classes = context.Attr<int>("num_total_classes");
// for unitest // for unitest
...@@ -44,7 +50,7 @@ void PrepareSamples(const framework::ExecutionContext& context, ...@@ -44,7 +50,7 @@ void PrepareSamples(const framework::ExecutionContext& context,
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
auto sample_labels_dims = sample_labels->dims(); auto sample_labels_dims = sample_labels->dims();
int64_t* sample_labels_data = int64_t *sample_labels_data =
sample_labels->mutable_data<int64_t>(context.GetPlace()); sample_labels->mutable_data<int64_t>(context.GetPlace());
int num_label = label_dims.size() == 2 ? label_dims[1] : 1; int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
...@@ -70,13 +76,13 @@ void PrepareSamples(const framework::ExecutionContext& context, ...@@ -70,13 +76,13 @@ void PrepareSamples(const framework::ExecutionContext& context,
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class NCEKernel : public framework::OpKernel<T> { class NCEKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
int sampler_type = context.Attr<int>("sampler"); int sampler_type = context.Attr<int>("sampler");
int seed = context.Attr<int>("seed"); int seed = context.Attr<int>("seed");
int num_total_classes = context.Attr<int>("num_total_classes"); int num_total_classes = context.Attr<int>("num_total_classes");
int num_neg_samples = context.Attr<int>("num_neg_samples"); int num_neg_samples = context.Attr<int>("num_neg_samples");
Sampler* sampler; Sampler *sampler;
switch (sampler_type) { switch (sampler_type) {
case 0: { case 0: {
sampler = new math::UniformSampler(num_total_classes - 1, seed); sampler = new math::UniformSampler(num_total_classes - 1, seed);
...@@ -87,11 +93,19 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -87,11 +93,19 @@ class NCEKernel : public framework::OpKernel<T> {
break; break;
} }
case 2: { case 2: {
auto custom_dist = context.Input<Tensor>("CustomDistribution"); auto dist_probs = context.Input<Tensor>("CustomDistProbs");
const float* custom_dist_data = custom_dist->data<float>(); auto dist_alias = context.Input<Tensor>("CustomDistAlias");
PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes); auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");
sampler = new math::CustomSampler(num_total_classes - 1,
custom_dist_data, seed); PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);
const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>();
const float *alias_probs_data = dist_alias_probs->data<float>();
sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
alias_data, alias_probs_data, seed);
break; break;
} }
default: { PADDLE_THROW("Unsupported SamplerType."); } default: { PADDLE_THROW("Unsupported SamplerType."); }
...@@ -99,17 +113,17 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -99,17 +113,17 @@ class NCEKernel : public framework::OpKernel<T> {
PrepareSamples<DeviceContext, T>(context, sampler); PrepareSamples<DeviceContext, T>(context, sampler);
auto sample_labels = context.Output<Tensor>("SampleLabels"); auto sample_labels = context.Output<Tensor>("SampleLabels");
const int64_t* sample_labels_data = sample_labels->data<int64_t>(); const int64_t *sample_labels_data = sample_labels->data<int64_t>();
auto sample_out = context.Output<Tensor>("SampleLogits"); auto sample_out = context.Output<Tensor>("SampleLogits");
T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace()); T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
auto sample_weight = context.Input<Tensor>("SampleWeight"); auto sample_weight = context.Input<Tensor>("SampleWeight");
const T* sample_weight_data = nullptr; const T *sample_weight_data = nullptr;
if (sample_weight != nullptr) { if (sample_weight != nullptr) {
sample_weight_data = sample_weight->data<T>(); sample_weight_data = sample_weight->data<T>();
} }
auto out = context.Output<Tensor>("Cost"); auto out = context.Output<Tensor>("Cost");
T* out_data = out->mutable_data<T>(context.GetPlace()); T *out_data = out->mutable_data<T>(context.GetPlace());
int64_t num_true_class = 1; int64_t num_true_class = 1;
if (label != nullptr) { if (label != nullptr) {
num_true_class = label->dims()[1]; num_true_class = label->dims()[1];
...@@ -119,7 +133,7 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -119,7 +133,7 @@ class NCEKernel : public framework::OpKernel<T> {
// forward bias // forward bias
auto bias = context.Input<Tensor>("Bias"); auto bias = context.Input<Tensor>("Bias");
if (bias != nullptr) { if (bias != nullptr) {
const T* bias_data = bias->data<T>(); const T *bias_data = bias->data<T>();
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
sample_out_data[i] = bias_data[sample_labels_data[i]]; sample_out_data[i] = bias_data[sample_labels_data[i]];
} }
...@@ -158,16 +172,16 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -158,16 +172,16 @@ class NCEKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class NCEGradKernel : public framework::OpKernel<T> { class NCEGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto d_out = context.Input<Tensor>(framework::GradVarName("Cost")); auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
const T* d_out_data = d_out->data<T>(); const T *d_out_data = d_out->data<T>();
auto label = context.Input<Tensor>("Label"); auto label = context.Input<Tensor>("Label");
auto sample_out = context.Input<Tensor>("SampleLogits"); auto sample_out = context.Input<Tensor>("SampleLogits");
const T* sample_out_data = sample_out->data<T>(); const T *sample_out_data = sample_out->data<T>();
auto sample_labels = context.Input<Tensor>("SampleLabels"); auto sample_labels = context.Input<Tensor>("SampleLabels");
const int64_t* sample_labels_data = sample_labels->data<int64_t>(); const int64_t *sample_labels_data = sample_labels->data<int64_t>();
auto sample_weight = context.Input<Tensor>("SampleWeight"); auto sample_weight = context.Input<Tensor>("SampleWeight");
const T* sample_weight_data = nullptr; const T *sample_weight_data = nullptr;
if (sample_weight != nullptr) { if (sample_weight != nullptr) {
sample_weight_data = sample_weight->data<T>(); sample_weight_data = sample_weight->data<T>();
} }
...@@ -180,7 +194,7 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -180,7 +194,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
int sampler_type = context.Attr<int>("sampler"); int sampler_type = context.Attr<int>("sampler");
int seed = context.Attr<int>("seed"); int seed = context.Attr<int>("seed");
Sampler* sampler; Sampler *sampler;
switch (sampler_type) { switch (sampler_type) {
case 0: { case 0: {
sampler = new math::UniformSampler(num_total_classes - 1, seed); sampler = new math::UniformSampler(num_total_classes - 1, seed);
...@@ -191,11 +205,19 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -191,11 +205,19 @@ class NCEGradKernel : public framework::OpKernel<T> {
break; break;
} }
case 2: { case 2: {
auto custom_dist = context.Input<Tensor>("CustomDistribution"); auto dist_probs = context.Input<Tensor>("CustomDistProbs");
const float* custom_dist_data = custom_dist->data<float>(); auto dist_alias = context.Input<Tensor>("CustomDistAlias");
PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes); auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");
sampler = new math::CustomSampler(num_total_classes - 1,
custom_dist_data, seed); PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);
const float *probs_data = dist_probs->data<float>();
const int *alias_data = dist_alias->data<int>();
const float *alias_probs_data = dist_alias_probs->data<float>();
sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
alias_data, alias_probs_data, seed);
break; break;
} }
default: { PADDLE_THROW("Unsupported SamplerType."); } default: { PADDLE_THROW("Unsupported SamplerType."); }
...@@ -203,7 +225,7 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -203,7 +225,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
// T b = 1. / num_total_classes * num_neg_samples; // T b = 1. / num_total_classes * num_neg_samples;
Tensor sample_grad; // tmp tensor Tensor sample_grad; // tmp tensor
T* sample_grad_data = T *sample_grad_data =
sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace()); sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
// backward cost // backward cost
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
...@@ -217,32 +239,105 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -217,32 +239,105 @@ class NCEGradKernel : public framework::OpKernel<T> {
: w * (o * (1 - o) / (o + b)); : w * (o * (1 - o) / (o + b));
sample_grad_data[i] *= d_out_data[sample_idx]; sample_grad_data[i] *= d_out_data[sample_idx];
} }
// get d_bias
auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias")); bool is_sparse = context.Attr<bool>("is_sparse");
if (d_bias != nullptr) {
T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace()); if (!is_sparse) {
std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0); // get d_bias
auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
if (d_bias != nullptr) {
T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
for (int64_t i = 0; i < sample_labels->numel(); ++i) {
d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
}
}
// get d_w
auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
if (d_w != nullptr) {
auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
auto d_w_matrix = EigenMatrix<T>::From(*d_w);
auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
for (int64_t i = 0; i < sample_labels->numel(); ++i) {
d_w_matrix.chip(sample_labels_data[i], 0) +=
x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
sample_grad_data[i];
}
}
} else {
std::vector<int64_t> labels;
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
d_bias_data[sample_labels_data[i]] += sample_grad_data[i]; labels.push_back(sample_labels_data[i]);
} }
} std::set<T> st(labels.begin(), labels.end());
// get d_w labels.assign(st.begin(), st.end());
auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
if (d_w != nullptr) { auto *bias_var = context.InputVar("Bias");
auto d_w_data = d_w->mutable_data<T>(context.GetPlace()); DDim bias_dim;
std::fill(d_w_data, d_w_data + d_w->numel(), 0.0); if (bias_var->IsType<LoDTensor>()) {
auto d_w_matrix = EigenMatrix<T>::From(*d_w); bias_dim = context.Input<LoDTensor>("Bias")->dims();
} else if (bias_var->IsType<SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("Bias");
bias_dim = table_t->value().dims();
} else {
PADDLE_THROW(
"The parameter Bias of a NCE_OP "
"must be either LoDTensor or SelectedRows");
}
auto d_bias =
context.Output<SelectedRows>(framework::GradVarName("Bias"));
d_bias->set_rows(labels);
d_bias->set_height(bias_dim[0]);
d_bias->mutable_value()->Resize(
{static_cast<int64_t>(labels.size()), bias_dim[1]});
T *d_bias_data =
d_bias->mutable_value()->mutable_data<T>(context.GetPlace());
std::fill(d_bias_data, d_bias_data + labels.size(), 0.0);
for (int64_t i = 0; i < sample_labels->numel(); ++i) {
d_bias_data[d_bias->Index(sample_labels_data[i])] +=
sample_grad_data[i];
}
auto *table_var = context.InputVar("Weight");
DDim table_dim;
if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("Weight")->dims();
} else if (table_var->IsType<SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("Weight");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW(
"The parameter Weight of a NCE_OP "
"must be either LoDTensor or SelectedRows");
}
auto d_w = context.Output<SelectedRows>(framework::GradVarName("Weight"));
d_w->set_rows(labels);
d_w->set_height(table_dim[0]);
auto *d_table_value = d_w->mutable_value();
d_table_value->Resize(
{static_cast<int64_t>(labels.size()), table_dim[1]});
auto d_w_data = d_table_value->mutable_data<T>(context.GetPlace());
std::fill(d_w_data, d_w_data + d_table_value->numel(), 0.0);
auto d_w_matrix = EigenMatrix<T>::From(*d_table_value);
auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input"))); auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
for (int64_t i = 0; i < sample_labels->numel(); ++i) { for (int64_t i = 0; i < sample_labels->numel(); ++i) {
d_w_matrix.chip(sample_labels_data[i], 0) += d_w_matrix.chip(d_w->Index(sample_labels_data[i]), 0) +=
x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) * x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
sample_grad_data[i]; sample_grad_data[i];
} }
} }
// get d_x // get d_x
auto d_x = context.Output<Tensor>(framework::GradVarName("Input")); auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
if (d_x != nullptr) { if (d_x != nullptr) {
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace()); auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());
std::fill(d_x_data, d_x_data + d_x->numel(), 0.0); std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
auto d_x_matrix = EigenMatrix<T>::From(*d_x); auto d_x_matrix = EigenMatrix<T>::From(*d_x);
auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight"))); auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
...@@ -251,6 +346,7 @@ class NCEGradKernel : public framework::OpKernel<T> { ...@@ -251,6 +346,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i]; w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
} }
} }
delete sampler; delete sampler;
} }
}; };
......
...@@ -4394,7 +4394,8 @@ def nce(input, ...@@ -4394,7 +4394,8 @@ def nce(input,
name=None, name=None,
sampler="uniform", sampler="uniform",
custom_dist=None, custom_dist=None,
seed=0): seed=0,
is_sparse=False):
""" """
${comment} ${comment}
...@@ -4420,11 +4421,12 @@ def nce(input, ...@@ -4420,11 +4421,12 @@ def nce(input,
sampler (str): The sampler used to sample class from negtive classes. sampler (str): The sampler used to sample class from negtive classes.
It can be 'uniform', 'log_uniform' or 'custom_dist'. It can be 'uniform', 'log_uniform' or 'custom_dist'.
default: 'uniform'. default: 'uniform'.
custom_dist (Variable): A tensor with shape [num_total_classes]. custom_dist (float[]): A float[] with size=num_total_classes.
It is used when sampler is set to 'custom_dist'. It is used when sampler is set to 'custom_dist'.
custom_dist[i] is the probsbility of i-th class to be sampled. custom_dist[i] is the probsbility of i-th class to be sampled.
default: None. default: None.
seed (int): The seed used in sampler. default: 0. seed (int): The seed used in sampler. default: 0.
is_sparse(bool): The flag indicating whether to use sparse update, the weight@GRAD and bias@GRAD will be changed to SelectedRows.
Returns: Returns:
Variable: The output nce loss. Variable: The output nce loss.
...@@ -4476,12 +4478,7 @@ def nce(input, ...@@ -4476,12 +4478,7 @@ def nce(input,
shape=[num_total_classes, dim], shape=[num_total_classes, dim],
is_bias=False, is_bias=False,
dtype=input.dtype) dtype=input.dtype)
inputs = { inputs = {}
'Input': input,
'Label': label,
'Weight': w,
'SampleWeight': sample_weight if sample_weight is not None else []
}
if helper.bias_attr: if helper.bias_attr:
b = helper.create_parameter( b = helper.create_parameter(
attr=helper.bias_attr, attr=helper.bias_attr,
...@@ -4493,18 +4490,10 @@ def nce(input, ...@@ -4493,18 +4490,10 @@ def nce(input,
sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype) sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype)
sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype) sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype)
if num_neg_samples is None: inputs['Input'] = input
num_neg_samples = 10 inputs['Label'] = label
else: inputs['Weight'] = w
num_neg_samples = int(num_neg_samples) inputs['SampleWeight'] = sample_weight if sample_weight is not None else []
inputs = {
'Input': input,
'Label': label,
'Weight': w,
'Bias': b,
'SampleWeight': sample_weight if sample_weight is not None else []
}
if sampler == "uniform": if sampler == "uniform":
sampler = 0 sampler = 0
...@@ -4512,17 +4501,73 @@ def nce(input, ...@@ -4512,17 +4501,73 @@ def nce(input,
sampler = 1 sampler = 1
elif sampler == "custom_dist": elif sampler == "custom_dist":
assert custom_dist is not None assert custom_dist is not None
assert isinstance(custom_dist, Variable) # assert isinstance(custom_dist, Variable)
inputs['CustomDistribution'] = custom_dist
custom_dist_len = len(custom_dist)
alias_probs_ = [0] * custom_dist_len
alias_ = [0] * custom_dist_len
bigs = []
littles = []
for i in range(custom_dist_len):
normal_prob = custom_dist[i] * custom_dist_len
if normal_prob - 1.0 > 1e-4:
bigs.append((i, normal_prob))
elif 1.0 - normal_prob > 1e-4:
littles.append((i, normal_prob))
else:
alias_probs_[i] = normal_prob
alias_[i] = -1
while len(bigs) and len(littles):
big = bigs.pop(0)
little = littles.pop(0)
big_idx = big[0]
big_prob = big[1]
alias_probs_[little[0]] = little[1]
alias_[little[0]] = big_idx
big_left = big[1] + little[1] - 1
if big_left - 1.0 > 1e-4:
bigs.append((big_idx, big_left))
elif 1.0 - big_left > 1e-4:
littles.append((big_idx, big_left))
else:
alias_probs_[big_idx] = big_left
alias_[big_idx] = -1
if len(bigs):
big = bigs.pop(0)
alias_probs_[big[0]] = 1.0
alias_[big[0]] = -1
if len(littles):
little = littles.pop(0)
alias_probs_[little[0]] = 1.0
alias_[little[0]] = -1
probs = assign(input=np.array(custom_dist).astype('float32'))
custom_alias = assign(input=np.array(alias_).astype('int32'))
custom_alias_probs = assign(
input=np.array(alias_probs_).astype('float32'))
inputs['CustomDistProbs'] = probs
inputs['CustomDistAlias'] = custom_alias
inputs['CustomDistAliasProbs'] = custom_alias_probs
sampler = 2 sampler = 2
else: else:
raise Exception("Unsupported sampler type.") raise Exception("Unsupported sampler type.")
if num_neg_samples is None:
num_neg_samples = 10
else:
num_neg_samples = int(num_neg_samples)
attrs = { attrs = {
'num_total_classes': int(num_total_classes), 'num_total_classes': int(num_total_classes),
'num_neg_samples': num_neg_samples, 'num_neg_samples': num_neg_samples,
'seed': seed, 'seed': seed,
'sampler': sampler 'sampler': sampler,
'is_sparse': is_sparse
} }
helper.append_op( helper.append_op(
...@@ -6474,7 +6519,7 @@ def crop(x, shape=None, offsets=None, name=None): ...@@ -6474,7 +6519,7 @@ def crop(x, shape=None, offsets=None, name=None):
helper = LayerHelper('crop', **locals()) helper = LayerHelper('crop', **locals())
if not (isinstance(shape, list) or isinstance(shape, tuple) or \ if not (isinstance(shape, list) or isinstance(shape, tuple) or \
isinstance(shape, Variable)): isinstance(shape, Variable)):
raise ValueError("The shape should be a list, tuple or Variable.") raise ValueError("The shape should be a list, tuple or Variable.")
if offsets is None: if offsets is None:
...@@ -6596,7 +6641,7 @@ def affine_grid(theta, out_shape, name=None): ...@@ -6596,7 +6641,7 @@ def affine_grid(theta, out_shape, name=None):
helper = LayerHelper('affine_grid') helper = LayerHelper('affine_grid')
if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \ if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
isinstance(out_shape, Variable)): isinstance(out_shape, Variable)):
raise ValueError("The out_shape should be a list, tuple or Variable.") raise ValueError("The out_shape should be a list, tuple or Variable.")
if not isinstance(theta, Variable): if not isinstance(theta, Variable):
......
...@@ -14,8 +14,12 @@ ...@@ -14,8 +14,12 @@
from __future__ import print_function from __future__ import print_function
import unittest
import numpy as np import numpy as np
import unittest
import paddle.fluid as fluid
import paddle.fluid.initializer as initializer
from op_test import OpTest from op_test import OpTest
...@@ -59,7 +63,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes, ...@@ -59,7 +63,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
class TestNCE(OpTest): class TestNCE(OpTest):
def generate_data(self, dim, batch_size, num_classes, num_true_class, def generate_data(self, dim, batch_size, num_classes, num_true_class,
num_neg_samples): num_neg_samples, is_sparse):
input = np.random.randn(batch_size, dim).astype(np.float32) input = np.random.randn(batch_size, dim).astype(np.float32)
weight = np.random.randn(num_classes, dim).astype(np.float32) weight = np.random.randn(num_classes, dim).astype(np.float32)
bias = np.random.randn(num_classes).astype(np.float32) bias = np.random.randn(num_classes).astype(np.float32)
...@@ -70,7 +74,8 @@ class TestNCE(OpTest): ...@@ -70,7 +74,8 @@ class TestNCE(OpTest):
'num_neg_samples': num_neg_samples, 'num_neg_samples': num_neg_samples,
'custom_neg_classes': list(range(num_neg_samples)), 'custom_neg_classes': list(range(num_neg_samples)),
'seed': 0, 'seed': 0,
'sampler': 0 'sampler': 0,
'is_sparse': is_sparse
} }
self.inputs = { self.inputs = {
'Input': input, 'Input': input,
...@@ -81,7 +86,7 @@ class TestNCE(OpTest): ...@@ -81,7 +86,7 @@ class TestNCE(OpTest):
} }
def set_data(self): def set_data(self):
self.generate_data(5, 5, 4, 1, 2) self.generate_data(5, 5, 4, 1, 2, False)
def compute(self): def compute(self):
out = nce(self.inputs['Input'], self.inputs['Weight'], out = nce(self.inputs['Input'], self.inputs['Weight'],
...@@ -107,9 +112,110 @@ class TestNCE(OpTest): ...@@ -107,9 +112,110 @@ class TestNCE(OpTest):
["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02) ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)
class TestNCECase1(TestNCE): class TestNCECase1Tensor(TestNCE):
def set_data(self): def set_data(self):
self.generate_data(10, 20, 10, 2, 5) self.generate_data(10, 20, 10, 2, 5, False)
class TestNCECase1SelectedRows(unittest.TestCase):
def setUp(self):
self.base_lr = 0.0001
self.batch_size = 8
@staticmethod
def get_place():
place = fluid.core.CPUPlace()
return place
@staticmethod
def get_train_data(batch_size):
batchs = []
for i in range(batch_size):
input = np.random.randn(batch_size, 10).astype(np.float32)
labels = np.random.randint(0, 20, (batch_size, 1))
batchs.append([input, labels])
return batchs
def get_optimizer(self):
# SGD optimizer
optimizer = fluid.optimizer.SGD(learning_rate=self.base_lr)
return optimizer
def train_network(self, num_total_classes, num_neg_samples, sampler,
custom_dist, is_sparse):
input = fluid.layers.data(name="input", shape=[10], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
w_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 10],
dtype='float32',
name='nce_w',
initializer=initializer.ConstantInitializer())
b_param = fluid.default_main_program().global_block().create_parameter(
shape=[num_total_classes, 1],
dtype='float32',
name='nce_b',
initializer=initializer.ConstantInitializer())
cost = fluid.layers.nce(input=input,
label=label,
num_total_classes=num_total_classes,
sampler=sampler,
custom_dist=custom_dist,
sample_weight=None,
param_attr='nce_w',
bias_attr='nce_b',
seed=1,
num_neg_samples=num_neg_samples,
is_sparse=is_sparse)
avg_cost = fluid.layers.mean(cost)
# optimizer
optimizer = self.get_optimizer()
optimizer.minimize(avg_cost)
return [avg_cost, [input, label]]
def test_input_is_selected_rows(self):
place = self.get_place()
exe = fluid.Executor(place)
data = self.get_train_data(self.batch_size)
nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
rets = []
# for dense
dense_scope = fluid.core.Scope()
dense_startup_program = fluid.framework.Program()
dense_train_program = fluid.framework.Program()
with fluid.scope_guard(dense_scope):
with fluid.program_guard(dense_train_program,
dense_startup_program):
cost, feeds = self.train_network(20, 5, "custom_dist",
nid_freq_arr.tolist(), False)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
exe.run(dense_startup_program)
loss_val = exe.run(dense_train_program,
feed=feeder.feed(data),
fetch_list=[cost.name])
rets.append(np.mean(loss_val))
# for sparse
sparse_scope = fluid.core.Scope()
sparse_startup_program = fluid.framework.Program()
sparse_train_program = fluid.framework.Program()
with fluid.scope_guard(sparse_scope):
with fluid.program_guard(sparse_train_program,
sparse_startup_program):
cost, feeds = self.train_network(20, 5, "custom_dist",
nid_freq_arr.tolist(), True)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
exe.run(sparse_startup_program)
loss_val = exe.run(sparse_train_program,
feed=feeder.feed(data),
fetch_list=[cost.name])
rets.append(np.mean(loss_val))
self.assertEqual(rets[0], rets[1])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册