diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index f6ad5b2e4258553fc1a4eeb869b9d4d02cae9e26..33e6baf818a728d7bf50ba110274d60000dcc22e 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -40,11 +40,12 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python backward - fc_op - sgd_op - add_op - mean_op - cross_entropy_op - fill_zeros_like_op - recurrent_op) + fc_op + sgd_op + add_op + mean_op + cross_entropy_op + recurrent_op + uniform_random_op + fill_zeros_like_op) endif(WITH_PYTHON) diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 9ee2c6af86476ea50def237ed011fcddaa41daad..011391bc2df4206dde15030cd78d6ea329530de4 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -42,6 +42,7 @@ USE_OP(softmax); USE_OP(rowwise_add); USE_OP(fill_zeros_like); USE_OP_WITHOUT_KERNEL(recurrent_op); +USE_OP(uniform_random); namespace paddle { namespace framework { template diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 531c3c8affac3109a9b795b0d699fb0652c1660e..b5311cab959c8e8c941cdcff467ac9720aea0fe7 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -66,3 +66,5 @@ op_library(fc_op op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS op_desc tensor op_registry operator net_op) cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op) +op_library(uniform_random_op + SRCS uniform_random_op.cc uniform_random_op.cu) diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..405b84b76d2e24db25d2ff16e99495f2f132ef09 --- /dev/null +++ b/paddle/operators/uniform_random_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class CPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output(0); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist( + static_cast(context.op_.GetAttr("min")), + static_cast(context.op_.GetAttr("max"))); + for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { + data[i] = dist(engine); + } + } +}; + +class UniformRandomOp : public framework::OperatorWithKernel { + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE(GetAttr("min") < GetAttr("max"), + "uniform_random's min must less then max"); + auto* tensor = ctx.Output(0); + auto dims = GetAttr>("dims"); + tensor->Resize(framework::make_ddim(dims)); + } +}; + +class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { + public: + UniformRandomOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Out", "The output tensor of uniform random op"); + AddComment(R"DOC(Uniform random operator. + +Used to initialize tensor with uniform random generator. +)DOC"); + AddAttr>("dims", "the dimension of random tensor"); + AddAttr("min", "Minimum value of uniform random").SetDefault(-1.0f); + AddAttr("max", "Maximun value of uniform random").SetDefault(1.0f); + AddAttr("seed", + "Random seed of uniform random. " + "0 means generate a seed by system") + .SetDefault(0); + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP(uniform_random, paddle::operators::UniformRandomOp, + paddle::operators::UniformRandomOpMaker); +REGISTER_OP_CPU_KERNEL(uniform_random, + paddle::operators::CPUUniformRandomKernel); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..f1a63e52ec0d3d46a505a89d7d7916bf93a58221 --- /dev/null +++ b/paddle/operators/uniform_random_op.cu @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +struct UniformGenerator { + T min_, max_; + unsigned int seed_; + + __host__ __device__ UniformGenerator(T min, T max, int seed) + : min_(min), max_(max), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class GPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output(0); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + if (seed == 0) { + seed = std::random_device()(); + } + T min = static_cast(context.op_.GetAttr("min")); + T max = static_cast(context.op_.GetAttr("max")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + UniformGenerator(min, max, seed)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_GPU_KERNEL(uniform_random, + paddle::operators::GPUUniformRandomKernel); diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 541639ac21661529b0b1f2cc8d8fa25605052c8c..0328bea7f7f18562fcb6b5a19cd2a2c70f10a532 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -21,3 +21,4 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_operator SRCS test_operator.py) +py_test(test_uniform_random_op SRCS test_uniform_random_op.py) diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d2bb44da3977c0899b2609a8efe15b7e1789f2 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -0,0 +1,35 @@ +import unittest +from paddle.v2.framework.op import Operator +import paddle.v2.framework.core as core +import numpy + + +class UniformRandomTest(unittest.TestCase): + def test_uniform_random_cpu(self): + self.uniform_random_test(place=core.CPUPlace()) + + def test_uniform_random_gpu(self): + if core.is_compile_gpu(): + self.uniform_random_test(place=core.GPUPlace(0)) + + def uniform_random_test(self, place): + scope = core.Scope() + scope.new_var("X").get_tensor() + + op = Operator( + "uniform_random", + Out="X", + dims=[1000, 784], + min=-5.0, + max=10.0, + seed=10) + + op.infer_shape(scope) + ctx = core.DeviceContext.create(place) + op.run(scope, ctx) + tensor = numpy.array(scope.find_var("X").get_tensor()) + self.assertAlmostEqual(tensor.mean(), 2.5, delta=0.1) + + +if __name__ == '__main__': + unittest.main()