diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 8f22a5fbc3e7f0c4964d418c363298fe77e7ea3e..a7c89787e43df6173791bc54b3faffc034867f7d 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -58,7 +58,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) -op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu DEPS math_function) +op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) @@ -67,4 +67,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) op_library(uniform_random_op - SRCS uniform_random_op.cc uniform_random_op.cu DEPS math_function) + SRCS uniform_random_op.cc uniform_random_op.cu) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 899f05fa4774ceef90a8f11390f446612c71d038..dcd2237459d3bbabf449d30e4279b637c569555b 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -1,22 +1,44 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/gaussian_random_op.h" +#include +#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { +template +class CPUGaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + float mean = context.op_.GetAttr("mean"); + float std = context.op_.GetAttr("std"); + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::normal_distribution dist(mean, std); + ssize_t size = framework::product(tensor->dims()); + for (ssize_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } + } +}; + class GaussianRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -43,8 +65,12 @@ Use to initialize tensor with gaussian random generator. )DOC"); AddAttr>("dims", "The dimension of random tensor."); - AddAttr("mean", "mean value of random.").SetDefault(.0f); - AddAttr("std", "minimum value of random value.").SetDefault(1.0f); + AddAttr("mean", "mean of random tensor.").SetDefault(.0f); + AddAttr("std", "std of random tensor.").SetDefault(1.0f); + AddAttr("seed", + "Random seed of generator." + "0 means use system wide seed") + .SetDefault(0); } }; @@ -54,6 +80,4 @@ Use to initialize tensor with gaussian random generator. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL( - gaussian_random, - ops::GaussianRandomKernel); +REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel); \ No newline at end of file diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 31be16fdc81766e65ffefc58edd65c8b79c74698..1d312e7b5d41dd18d5b37977801593639f2de76f 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -1,20 +1,65 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/gaussian_random_op.h" +#include +#include +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +struct GaussianGenerator { + T mean_, std_; + unsigned int seed_; + + __host__ __device__ GaussianGenerator(T mean, T std, int seed) + : mean_(mean), std_(std), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::normal_distribution dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +template +class GPUGaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + if (seed == 0) { + std::random_device rd; + seed = rd(); + } + T mean = static_cast(context.op_.GetAttr("mean")); + T std = static_cast(context.op_.GetAttr("std")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + GaussianGenerator(mean, std, seed)); + } +}; + +} // namespace operators +} // namespace paddle -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - gaussian_random, - ops::GaussianRandomKernel); +REGISTER_OP_GPU_KERNEL(gaussian_random, + paddle::operators::GPUGaussianRandomKernel); \ No newline at end of file diff --git a/paddle/operators/gaussian_random_op.h b/paddle/operators/gaussian_random_op.h deleted file mode 100644 index c90b665fe0e077ef054ee75f2b8c1b9165eac618..0000000000000000000000000000000000000000 --- a/paddle/operators/gaussian_random_op.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" - -namespace paddle { -namespace operators { -template -class GaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); - T* data = tensor->mutable_data(context.GetPlace()); - T mean = static_cast(context.op_.GetAttr("mean")); - T std = static_cast(context.op_.GetAttr("std")); - auto n = framework::product(tensor->dims()); - - auto* device_context = - const_cast(context.device_context_); - math::RandGaussian(n, mean, std, data, device_context); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index a098e02f95d6a641fef0831eacf3be054b93f62b..d9824e5f966026f9fd1853e813131e34ae5ab2c4 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -118,28 +118,6 @@ void Set(const int n, const float alpha, out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha)); } -template <> -void RandUniform(const int n, const float min, - const float max, float* output, - platform::DeviceContext* context) { - auto* cpu_context = reinterpret_cast(context); - std::uniform_real_distribution distribution(min, max); - for (int i = 0; i < n; i++) { - output[i] = distribution(cpu_context->rand_engine()); - } -} - -template <> -void RandGaussian(const int n, const float mean, - const float std, float* output, - platform::DeviceContext* context) { - auto* cpu_context = reinterpret_cast(context); - std::normal_distribution distribution(mean, std); - for (int i = 0; i < n; i++) { - output[i] = distribution(cpu_context->rand_engine()); - } -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 908efe9e0f92170fa218c51ebac7d3d55aeac78d..9dff6f05fb705f82e8f70eaef4ddeeaa7d7acef3 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -135,54 +135,6 @@ void Set(const int n, const float alpha, out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha)); } -template -__global__ void UniformShift(const int n, const T min, const T max, T* x) { - float scale = max - min; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; - i += blockDim.x * gridDim.x) { - x[i] = x[i] * scale + min; - } -} - -template <> -void RandUniform(const int n, const float min, - const float max, float* output, - platform::DeviceContext* context) { - auto* cuda_context = reinterpret_cast(context); - PADDLE_ENFORCE(platform::dynload::curandGenerateUniform( - cuda_context->curand_generator(), output, n)); - int block = 512; - int grid = (n + block - 1) / block; - UniformShift<<stream()>>>(n, min, max, - output); -} - -template -int HandleOddLengthRandGaussian(const int n, const T mean, const T std, - T* output, - platform::CUDADeviceContext* context) { - if (n % 2 == 1) { - std::default_random_engine generator; - std::normal_distribution distribution(mean, std); - const T random_value = distribution(generator); - Set(1, random_value, output + (n - 1), context); - return n - 1; - } - return n; -} - -template <> -void RandGaussian(const int n, const float mean, - const float std, float* output, - platform::DeviceContext* context) { - auto* cuda_context = reinterpret_cast(context); - - const int even_n = - HandleOddLengthRandGaussian(n, mean, std, output, cuda_context); - PADDLE_ENFORCE(platform::dynload::curandGenerateNormal( - cuda_context->curand_generator(), output, even_n, mean, std)); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 6543a1b515e8c6d79d59e0811aa92caaffa829bb..a0e9660564589c4a84e5cf529a43bc0cf3ac6a53 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -82,14 +82,6 @@ template void Set(const int n, const T alpha, T* output, platform::DeviceContext* context); -template -void RandUniform(const int n, const T min, const T max, T* output, - platform::DeviceContext* context); - -template -void RandGaussian(const int n, const T mean, const T std, T* output, - platform::DeviceContext* context); - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 81487a6bd82534a6babb4df8cd68b12d97166a34..876b3ef5575325daed8bebc0c964bd80064a7bb3 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -1,22 +1,48 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/uniform_random_op.h" +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" namespace paddle { namespace operators { +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class CPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist( + static_cast(context.op_.GetAttr("min")), + static_cast(context.op_.GetAttr("max"))); + ssize_t size = framework::product(tensor->dims()); + for (ssize_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } + } +}; + class UniformRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -38,12 +64,15 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "The output tensor of uniform random op"); AddComment(R"DOC(Uniform random operator. - Used to initialize tensor with uniform random generator. )DOC"); AddAttr>("dims", "the dimension of random tensor"); AddAttr("min", "Minimum value of uniform random").SetDefault(-1.0f); AddAttr("max", "Maximun value of uniform random").SetDefault(1.0f); + AddAttr("seed", + "Random seed of uniform random. " + "0 means generate a seed by system") + .SetDefault(0); } }; } // namespace operators @@ -51,6 +80,5 @@ Used to initialize tensor with uniform random generator. REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, paddle::operators::UniformRandomOpMaker); -REGISTER_OP_CPU_KERNEL( - uniform_random, - paddle::operators::UniformRandomKernel); +REGISTER_OP_CPU_KERNEL(uniform_random, + paddle::operators::CPUUniformRandomKernel); \ No newline at end of file diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 1bfffc47783d54beaf839808829f962a131637d5..6716b7c7f20371f2970efaebe94cb1381983783c 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -1,19 +1,68 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/uniform_random_op.h" +#include +#include +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +struct UniformGenerator { + T min_, max_; + unsigned int seed_; + + __host__ __device__ UniformGenerator(T min, T max, int seed) + : min_(min), max_(max), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class GPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + if (seed == 0) { + std::random_device rd; + seed = rd(); + } + T min = static_cast(context.op_.GetAttr("min")); + T max = static_cast(context.op_.GetAttr("max")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + UniformGenerator(min, max, seed)); + } +}; + +} // namespace operators +} // namespace paddle -REGISTER_OP_GPU_KERNEL( - uniform_random, - paddle::operators::UniformRandomKernel); +REGISTER_OP_GPU_KERNEL(uniform_random, + paddle::operators::GPUUniformRandomKernel); \ No newline at end of file diff --git a/paddle/operators/uniform_random_op.h b/paddle/operators/uniform_random_op.h deleted file mode 100644 index dffa640f84e4995b12aded3a103fe42a9fe38e83..0000000000000000000000000000000000000000 --- a/paddle/operators/uniform_random_op.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" - -namespace paddle { -namespace operators { -template -class UniformRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); - T* data = tensor->mutable_data(context.GetPlace()); - T min = static_cast(context.op_.GetAttr("min")); - T max = static_cast(context.op_.GetAttr("max")); - auto n = framework::product(tensor->dims()); - - auto* device_context = - const_cast(context.device_context_); - math::RandUniform(n, min, max, data, device_context); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index ad9b4e42f334523255bc3f110eefe9b5575679cb..ad212c5b2c47312743362db4926c80bf056e100d 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -25,17 +25,8 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place, int seed) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place) { eigen_device_.reset(new Eigen::DefaultDevice()); - rand_seed_ = seed; -} - -std::minstd_rand& CPUDeviceContext::rand_engine() { - if (!rand_engine_) { - rand_engine_.reset(new std::minstd_rand()); - rand_engine_->seed(rand_seed_); - } - return *(rand_engine_.get()); } Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { @@ -104,8 +95,7 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { return reinterpret_cast(this)->eigen_device(); } -CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed) - : place_(place), rand_seed_(seed) { +CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); @@ -157,19 +147,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { return cudnn_handle_; } -curandGenerator_t CUDADeviceContext::curand_generator() { - if (!curand_generator_) { - SetDeviceId(place_.device); - PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, - CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed( - curand_generator_, rand_seed_)); - - PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); - } - return curand_generator_; -} - cudaStream_t CUDADeviceContext::stream() { return stream_; } #endif // PADDLE_ONLY_CPU diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index e18f48fef59fcce065d8a1c8bfa72ab5c7a75439..11528e1194e4516891034fa8febdac3ba6eed204 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -17,7 +17,6 @@ limitations under the License. */ #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -40,18 +39,14 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace place, int seed = 0); + explicit CPUDeviceContext(CPUPlace place); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; - std::minstd_rand& rand_engine(); - Place GetPlace() const override; private: - int rand_seed_; - std::unique_ptr rand_engine_; std::unique_ptr eigen_device_; }; @@ -60,7 +55,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace place, uint64_t seed = 0); + explicit CUDADeviceContext(GPUPlace place); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -79,9 +74,6 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - /*! \brief Return curand handle in the device context. */ - curandGenerator_t curand_generator(); - /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); // clang-format on @@ -92,13 +84,10 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - uint64_t rand_seed_; - // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; - curandGenerator_t curand_generator_{nullptr}; // clang-format on }; diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py index 367d21b3017f093b2b13b4eda2480a6848fd7046..f95ed70b58d611b3233a21d3f2a34c864ae4d1b3 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -17,7 +17,12 @@ class GaussianRandomTest(unittest.TestCase): scope.new_var("Out").get_tensor() op = Operator( - "gaussian_random", Out="Out", dims=[1000, 784], mean=.0, std=1.) + "gaussian_random", + Out="Out", + dims=[1000, 784], + mean=.0, + std=1., + seed=10) op.infer_shape(scope) context = core.DeviceContext.create(place) diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py index 95c36a27cf4f1115ffb264f094543c3e23becc01..c3d2bb44da3977c0899b2609a8efe15b7e1789f2 100644 --- a/python/paddle/v2/framework/tests/test_uniform_random_op.py +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -17,7 +17,12 @@ class UniformRandomTest(unittest.TestCase): scope.new_var("X").get_tensor() op = Operator( - "uniform_random", Out="X", dims=[1000, 784], min=-5.0, max=10.0) + "uniform_random", + Out="X", + dims=[1000, 784], + min=-5.0, + max=10.0, + seed=10) op.infer_shape(scope) ctx = core.DeviceContext.create(place)