提交 6540701f 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #3293 from reyoung/feature/uniform_random_op

Add uniform random operator
......@@ -40,11 +40,12 @@ if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python backward
fc_op
sgd_op
add_op
mean_op
cross_entropy_op
fill_zeros_like_op
recurrent_op)
fc_op
sgd_op
add_op
mean_op
cross_entropy_op
recurrent_op
uniform_random_op
fill_zeros_like_op)
endif(WITH_PYTHON)
......@@ -42,6 +42,7 @@ USE_OP(softmax);
USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
USE_OP_WITHOUT_KERNEL(recurrent_op);
USE_OP(uniform_random);
namespace paddle {
namespace framework {
template <typename ClassType>
......
......@@ -66,3 +66,5 @@ op_library(fc_op
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS op_desc tensor op_registry operator net_op)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(
static_cast<T>(context.op_.GetAttr<float>("min")),
static_cast<T>(context.op_.GetAttr<float>("max")));
for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) {
data[i] = dist(engine);
}
}
};
class UniformRandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
"uniform_random's min must less then max");
auto* tensor = ctx.Output<framework::Tensor>(0);
auto dims = GetAttr<std::vector<int>>("dims");
tensor->Resize(framework::make_ddim(dims));
}
};
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
UniformRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The output tensor of uniform random op");
AddComment(R"DOC(Uniform random operator.
Used to initialize tensor with uniform random generator.
)DOC");
AddAttr<std::vector<int>>("dims", "the dimension of random tensor");
AddAttr<float>("min", "Minimum value of uniform random").SetDefault(-1.0f);
AddAttr<float>("max", "Maximun value of uniform random").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of uniform random. "
"0 means generate a seed by system")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker);
REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
__host__ __device__ UniformGenerator(T min, T max, int seed)
: min_(min), max_(max), seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
return dist(rng);
}
};
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename T>
class GPUUniformRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
if (seed == 0) {
seed = std::random_device()();
}
T min = static_cast<T>(context.op_.GetAttr<float>("min"));
T max = static_cast<T>(context.op_.GetAttr<float>("max"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_GPU_KERNEL(uniform_random,
paddle::operators::GPUUniformRandomKernel<float>);
......@@ -21,3 +21,4 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_operator SRCS test_operator.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
import unittest
from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core
import numpy
class UniformRandomTest(unittest.TestCase):
def test_uniform_random_cpu(self):
self.uniform_random_test(place=core.CPUPlace())
def test_uniform_random_gpu(self):
if core.is_compile_gpu():
self.uniform_random_test(place=core.GPUPlace(0))
def uniform_random_test(self, place):
scope = core.Scope()
scope.new_var("X").get_tensor()
op = Operator(
"uniform_random",
Out="X",
dims=[1000, 784],
min=-5.0,
max=10.0,
seed=10)
op.infer_shape(scope)
ctx = core.DeviceContext.create(place)
op.run(scope, ctx)
tensor = numpy.array(scope.find_var("X").get_tensor())
self.assertAlmostEqual(tensor.mean(), 2.5, delta=0.1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册