diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a7c89787e43df6173791bc54b3faffc034867f7d..8f22a5fbc3e7f0c4964d418c363298fe77e7ea3e 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -58,7 +58,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) -op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) +op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu DEPS math_function) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) @@ -67,4 +67,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) op_library(uniform_random_op - SRCS uniform_random_op.cc uniform_random_op.cu) + SRCS uniform_random_op.cc uniform_random_op.cu DEPS math_function) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index f30bbce9586d61063b4b61d98695bb568ef73c8d..aba8c6e5cd9fd2bbcc6d7b69126b191959256fef 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -12,36 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include "paddle/framework/op_registry.h" +#include "paddle/operators/gaussian_random_op.h" namespace paddle { namespace operators { -template -class GaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr("mean"); - float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); - T* data = tensor->mutable_data(context.GetPlace()); - - // TODO(dzh): attribute does not support unsigned int. - // And we need a global random seed configuration. - int seed = context.op_.GetAttr("seed"); - if (seed == 0) { - seed = std::random_device()(); - } - std::mt19937 g(seed); - std::normal_distribution distribution(mean, std); - ssize_t size = framework::product(tensor->dims()); - for (int i = 0; i < size; ++i) { - data[i] = distribution(g); - } - } -}; - class GaussianRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -70,10 +45,6 @@ Use to initialize tensor with gaussian random generator. AddAttr>("dims", "The dimension of random tensor."); AddAttr("mean", "mean value of random.").SetDefault(.0f); AddAttr("std", "minimum value of random value.").SetDefault(1.0f); - AddAttr("seed", - "Random seed of generator." - "0 means use system wide seed") - .SetDefault(0); } }; @@ -83,4 +54,6 @@ Use to initialize tensor with gaussian random generator. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_CPU_KERNEL( + gaussian_random, + ops::GaussianRandomKernel); diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 1340b1e1e9f19fd96ced9e57fab75fe9d33bc84e..31be16fdc81766e65ffefc58edd65c8b79c74698 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -12,42 +12,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "paddle/platform/dynload/curand.h" -#include "paddle/platform/gpu_info.h" - -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class GaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr("mean"); - float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); - T* data = tensor->mutable_data(context.GetPlace()); - - int seed = context.op_.GetAttr("seed"); - if (seed == 0) { - std::random_device rd; - seed = rd(); - } - curandGenerator_t g; - PADDLE_ENFORCE(platform::dynload::curandCreateGenerator( - &g, CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); - platform::dynload::curandGenerateNormal( - g, data, framework::product(tensor->dims()), mean, std); - } -}; - -} // namespace operators -} // namespace paddle +#include "paddle/operators/gaussian_random_op.h" namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_GPU_KERNEL( + gaussian_random, + ops::GaussianRandomKernel); diff --git a/paddle/operators/gaussian_random_op.h b/paddle/operators/gaussian_random_op.h new file mode 100644 index 0000000000000000000000000000000000000000..041390e954fe19f3b55fefd65551370301a6aa6e --- /dev/null +++ b/paddle/operators/gaussian_random_op.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +template +class GaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + T mean = static_cast(context.op_.GetAttr("mean")); + T std = static_cast(context.op_.GetAttr("std")); + auto n = framework::product(tensor->dims()); + + auto* device_context = + const_cast(context.device_context_); + math::RandGaussian(n, mean, std, data, device_context); + } +}; +} +} diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 1e86fc3d166077265e0f433a6712b0665ea5a152..da59044899762baeb0497255a08455de2d46c062 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -109,6 +109,28 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } +template <> +void RandUniform(const int n, const float min, + const float max, float* output, + platform::DeviceContext* context) { + auto* cpu_context = reinterpret_cast(context); + std::uniform_real_distribution distribution(min, max); + for (int i = 0; i < n; i++) { + output[i] = distribution(cpu_context->rand_engine()); + } +} + +template <> +void RandGaussian(const int n, const float mean, + const float std, float* output, + platform::DeviceContext* context) { + auto* cpu_context = reinterpret_cast(context); + std::normal_distribution distribution(mean, std); + for (int i = 0; i < n; i++) { + output[i] = distribution(cpu_context->rand_engine()); + } +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index da40b27c948918e4997f4a046d2145552296158b..5a400d44459efc71e182ff305695cc609073a448 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include +#include +#include #include "paddle/operators/math/math_function.h" namespace paddle { @@ -122,6 +126,38 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } +template <> +void RandUniform(const int n, const float min, + const float max, float* output, + platform::DeviceContext* context) { + auto* cuda_context = reinterpret_cast(context); + thrust::uniform_real_distribution distribution(min, max); + thrust::minstd_rand engine = cuda_context->rand_enigne(); + engine->discard(n); + + thrust::counting_iterator index_sequence_begin(0); + + thrust::transform(thrust::cuda::par.on(cuda_context->stream()), + index_sequence_begin, index_sequence_begin + n, + thrust::device_ptr(output), distribution(engine)); +} + +template <> +void RandGaussian(const int n, const float mean, + const float std, float* output, + platform::DeviceContext* context) { + auto* cuda_context = reinterpret_cast(context); + thrust::normal_distribution distribution(mean, std); + thrust::minstd_rand engine = cuda_context->rand_enigne(); + engine->discard(n); + + thrust::counting_iterator index_sequence_begin(0); + + thrust::transform(thrust::cuda::par.on(cuda_context->stream()), + index_sequence_begin, index_sequence_begin + n, + thrust::device_ptr(output), distribution(engine)); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 155589fadb3ed9f59160a750d546dd8093a56cbe..ea15e8fd2bd813649fe1504f9afcc66f252dc870 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -77,6 +77,14 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); +template +void RandUniform(const int n, const T min, const T max, T* output, + platform::DeviceContext* context); + +template +void RandGaussian(const int n, const T mean, const T std, T* output, + platform::DeviceContext* context); + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 460e458ca4f7f40746f0dbf7e258a165faa88e1a..173cc3850ca9d97200e272ec59d1bd3fe09b5053 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -13,7 +13,6 @@ limitations under the License. */ #include "paddle/operators/mul_op.h" -#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index a0a0d4d914b37fca4250e5218a953f573611a086..81487a6bd82534a6babb4df8cd68b12d97166a34 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -12,39 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" +#include "paddle/operators/uniform_random_op.h" namespace paddle { namespace operators { -// It seems that Eigen::Tensor::random in GPU will SEGFAULT. -// Use std::random and thrust::random(thrust is a std library in CUDA) to -// implement uniform random. -template -class CPUUniformRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); - T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = - static_cast(context.op_.GetAttr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); - std::uniform_real_distribution dist( - static_cast(context.op_.GetAttr("min")), - static_cast(context.op_.GetAttr("max"))); - for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { - data[i] = dist(engine); - } - } -}; - class UniformRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -72,10 +44,6 @@ Used to initialize tensor with uniform random generator. AddAttr>("dims", "the dimension of random tensor"); AddAttr("min", "Minimum value of uniform random").SetDefault(-1.0f); AddAttr("max", "Maximun value of uniform random").SetDefault(1.0f); - AddAttr("seed", - "Random seed of uniform random. " - "0 means generate a seed by system") - .SetDefault(0); } }; } // namespace operators @@ -83,5 +51,6 @@ Used to initialize tensor with uniform random generator. REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, paddle::operators::UniformRandomOpMaker); -REGISTER_OP_CPU_KERNEL(uniform_random, - paddle::operators::CPUUniformRandomKernel); +REGISTER_OP_CPU_KERNEL( + uniform_random, + paddle::operators::UniformRandomKernel); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 7a243555b6385af690e9632dfa81bf96d70f925d..91368fa73e9769b03d0dbf82973a7261ca23e30f 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -12,60 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include -#include -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" +#include "paddle/operators/uniform_random_op.h" namespace paddle { namespace operators { -template -struct UniformGenerator { - T min_, max_; - unsigned int seed_; - - __host__ __device__ UniformGenerator(T min, T max, int seed) - : min_(min), max_(max), seed_(seed) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(min_, max_); - rng.discard(n); - return dist(rng); - } -}; - -// It seems that Eigen::Tensor::random in GPU will SEGFAULT. -// Use std::random and thrust::random(thrust is a std library in CUDA) to -// implement uniform random. -template -class GPUUniformRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); - T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = - static_cast(context.op_.GetAttr("seed")); - if (seed == 0) { - std::random_device rd; - seed = rd(); - } - T min = static_cast(context.op_.GetAttr("min")); - T max = static_cast(context.op_.GetAttr("max")); - thrust::counting_iterator index_sequence_begin(0); - ssize_t N = framework::product(tensor->dims()); - thrust::transform(index_sequence_begin, index_sequence_begin + N, - thrust::device_ptr(data), - UniformGenerator(min, max, seed)); - } -}; - -} // namespace operators -} // namespace paddle - REGISTER_OP_GPU_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel); + paddle::operators::GPUUniformRandomKernel< + paddle::platform::GPUPlace, float>); diff --git a/paddle/operators/uniform_random_op.h b/paddle/operators/uniform_random_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ec009b025e7c32fbcb4325f3b1b46ce8f0d9eb85 --- /dev/null +++ b/paddle/operators/uniform_random_op.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +template +class UniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + T min = static_cast(context.op_.GetAttr("min")); + T max = static_cast(context.op_.GetAttr("max")); + auto n = framework::product(tensor->dims()); + + auto* device_context = + const_cast(context.device_context_); + math::RandUniform(n, min, max, data, device_context); + } +}; +} +} diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index f92c15ae450e94de44d27e77763e791e6bae4426..fabbb55443d1e117ca09e24881150026196c073a 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -25,8 +25,17 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place, int rand_seed) { eigen_device_.reset(new Eigen::DefaultDevice()); + rand_seed_ = rand_seed; +} + +std::minstd_rand& CPUDeviceContext::rand_engine() { + if (!rand_engine_) { + rand_engine_.reset(new std::minstd_rand()); + rand_engine_->seed(rand_seed_); + } + return *(rand_engine_.get()); } Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { @@ -95,7 +104,8 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { return reinterpret_cast(this)->eigen_device(); } -CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { +CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed) + : place_(place), seed_(seed) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); @@ -114,9 +124,6 @@ CUDADeviceContext::~CUDADeviceContext() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } - if (curand_generator_) { - PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_)); - } eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -150,21 +157,16 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { return cudnn_handle_; } -cudaStream_t CUDADeviceContext::stream() { return stream_; } - -curandGenerator_t CUDADeviceContext::curand_generator() { - if (!curand_generator_) { - SetDeviceId(place_.device); - PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, - CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); - - PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); +thrust::minstd_rand& CPUDeviceContext::rand_engine() { + if (!rand_engine_) { + rand_engine_.reset(new thrust::minstd_rand()); + rand_engine_->seed(rand_seed_); } - return curand_generator_; + return *(rand_engine_.get()); } +cudaStream_t CUDADeviceContext::stream() { return stream_; } + #endif // PADDLE_ONLY_CPU } // namespace platform diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index c5042ae33e47e04521e59e0d91ddd8d4efffe50a..e4de3807cd4925cade697a090d59500c8a3e0b39 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -15,9 +15,10 @@ limitations under the License. */ #include "paddle/platform/place.h" #ifndef PADDLE_ONLY_CPU +#include +#include #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -40,14 +41,18 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace); + explicit CPUDeviceContext(CPUPlace place, int rand_seed = 0); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; + std::minstd_rand& rand_engine(); + Place GetPlace() const override; private: + int rand_seed_; + std::unique_ptr rand_engine_; std::unique_ptr eigen_device_; }; @@ -56,7 +61,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace); + explicit CUDADeviceContext(GPUPlace place, uint64_t rand_seed = 0); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -75,8 +80,7 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - /*! \brief Return curand handle in the device context. */ - curandGenerator_t curand_generator(); + thrust::minstd_rand& CPUDeviceContext::rand_engine(); /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); @@ -85,18 +89,16 @@ class CUDADeviceContext : public DeviceContext { private: GPUPlace place_; - private: std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - private: - uint64_t seed_; + uint64_t rand_seed_; + std::unique_ptr rand_engine_; // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; - curandGenerator_t curand_generator_{nullptr}; // clang-format on }; diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ce57a0713092723b6a99b2416e06ff1a436f043b..b07a65f4d1fed12d82c638ee59f9de72379cfcbe 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_operator SRCS test_operator.py) -# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) +py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py index f95ed70b58d611b3233a21d3f2a34c864ae4d1b3..367d21b3017f093b2b13b4eda2480a6848fd7046 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -17,12 +17,7 @@ class GaussianRandomTest(unittest.TestCase): scope.new_var("Out").get_tensor() op = Operator( - "gaussian_random", - Out="Out", - dims=[1000, 784], - mean=.0, - std=1., - seed=10) + "gaussian_random", Out="Out", dims=[1000, 784], mean=.0, std=1.) op.infer_shape(scope) context = core.DeviceContext.create(place) diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py index c3d2bb44da3977c0899b2609a8efe15b7e1789f2..95c36a27cf4f1115ffb264f094543c3e23becc01 100644 --- a/python/paddle/v2/framework/tests/test_uniform_random_op.py +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -17,12 +17,7 @@ class UniformRandomTest(unittest.TestCase): scope.new_var("X").get_tensor() op = Operator( - "uniform_random", - Out="X", - dims=[1000, 784], - min=-5.0, - max=10.0, - seed=10) + "uniform_random", Out="X", dims=[1000, 784], min=-5.0, max=10.0) op.infer_shape(scope) ctx = core.DeviceContext.create(place)