提交 4661f558 编写于 作者: T tangwei12

random optimize

上级 478f73c1
...@@ -36,9 +36,19 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -36,9 +36,19 @@ class SamplingIdKernel : public framework::OpKernel<T> {
std::vector<T> ins_vector; std::vector<T> ins_vector;
framework::TensorToVector(*input, context.device_context(), &ins_vector); framework::TensorToVector(*input, context.device_context(), &ins_vector);
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
std::vector<T> ids(batch_size); std::vector<T> ids(batch_size);
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
double r = getRandReal(); double r = dist(engine);
int idx = width - 1; int idx = width - 1;
for (int j = 0; j < width; ++j) { for (int j = 0; j < width; ++j) {
if ((r -= ins_vector[i * width + j]) < 0) { if ((r -= ins_vector[i * width + j]) < 0) {
...@@ -57,16 +67,6 @@ class SamplingIdKernel : public framework::OpKernel<T> { ...@@ -57,16 +67,6 @@ class SamplingIdKernel : public framework::OpKernel<T> {
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::TensorFromVector(ids, context.device_context(), output); framework::TensorFromVector(ids, context.device_context(), output);
} }
private:
double getRandReal() const {
std::random_device
rd; // Will be used to obtain a seed for the random number engine
std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with
// rd()
std::uniform_real_distribution<> dis(1.0, 2.0);
return dis(gen);
}
}; };
class SamplingIdOp : public framework::OperatorWithKernel { class SamplingIdOp : public framework::OperatorWithKernel {
...@@ -78,6 +78,9 @@ class SamplingIdOp : public framework::OperatorWithKernel { ...@@ -78,6 +78,9 @@ class SamplingIdOp : public framework::OperatorWithKernel {
"Input(X) of SamplingIdOp should not be null."); "Input(X) of SamplingIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SamplingIdOp should not be null."); "Output(Out) of SamplingIdOp should not be null.");
PADDLE_ENFORCE(
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
"min must less then max");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 2, PADDLE_ENFORCE(input_dims.size() == 2,
...@@ -99,7 +102,17 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -99,7 +102,17 @@ class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
SamplingId Operator. SamplingId Operator.
A layer for sampling id from multinomial distribution from the A layer for sampling id from multinomial distribution from the
input layer. Sampling one id for one sample.)DOC"); input. Sampling one id for one sample.)DOC");
AddAttr<float>("min", "Minimum value of random. [default 0.0].")
.SetDefault(0.0f);
AddAttr<float>("max", "Maximun value of random. [default 1.0].")
.SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed used for the random number engine. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time. [default 0].")
.SetDefault(0);
} }
}; };
} // namespace operators } // namespace operators
...@@ -109,8 +122,5 @@ namespace ops = paddle::operators; ...@@ -109,8 +122,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker, REGISTER_OPERATOR(sampling_id, ops::SamplingIdOp, ops::SamplingIdOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel<float>,
sampling_id, ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int>, paddle::operators::SamplingIdKernel<double>);
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, float>,
ops::SamplingIdKernel<paddle::platform::CPUDeviceContext, double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册