提交 d525abed 编写于 作者: Q qijun

refine random related ops

上级 0d9846f3
...@@ -58,7 +58,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) ...@@ -58,7 +58,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu DEPS math_function)
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
...@@ -67,4 +67,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) ...@@ -67,4 +67,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op) DEPS framework_proto tensor op_registry operator net_op)
op_library(uniform_random_op op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu) SRCS uniform_random_op.cc uniform_random_op.cu DEPS math_function)
...@@ -12,36 +12,11 @@ ...@@ -12,36 +12,11 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <random> #include "paddle/operators/gaussian_random_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
// TODO(dzh): attribute does not support unsigned int.
// And we need a global random seed configuration.
int seed = context.op_.GetAttr<int>("seed");
if (seed == 0) {
seed = std::random_device()();
}
std::mt19937 g(seed);
std::normal_distribution<T> distribution(mean, std);
ssize_t size = framework::product(tensor->dims());
for (int i = 0; i < size; ++i) {
data[i] = distribution(g);
}
}
};
class GaussianRandomOp : public framework::OperatorWithKernel { class GaussianRandomOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -70,10 +45,6 @@ Use to initialize tensor with gaussian random generator. ...@@ -70,10 +45,6 @@ Use to initialize tensor with gaussian random generator.
AddAttr<std::vector<int>>("dims", "The dimension of random tensor."); AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0f); AddAttr<float>("mean", "mean value of random.").SetDefault(.0f);
AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f); AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of generator."
"0 means use system wide seed")
.SetDefault(0);
} }
}; };
...@@ -83,4 +54,6 @@ Use to initialize tensor with gaussian random generator. ...@@ -83,4 +54,6 @@ Use to initialize tensor with gaussian random generator.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker); ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); REGISTER_OP_CPU_KERNEL(
gaussian_random,
ops::GaussianRandomKernel<paddle::platform::CPUPlace, float>);
...@@ -12,42 +12,9 @@ ...@@ -12,42 +12,9 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <memory> #include "paddle/operators/gaussian_random_op.h"
#include <random>
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
int seed = context.op_.GetAttr<int>("seed");
if (seed == 0) {
std::random_device rd;
seed = rd();
}
curandGenerator_t g;
PADDLE_ENFORCE(platform::dynload::curandCreateGenerator(
&g, CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed));
platform::dynload::curandGenerateNormal(
g, data, framework::product(tensor->dims()), mean, std);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); REGISTER_OP_GPU_KERNEL(
gaussian_random,
ops::GaussianRandomKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
T mean = static_cast<T>(context.op_.GetAttr<float>("mean"));
T std = static_cast<T>(context.op_.GetAttr<float>("std"));
auto n = framework::product(tensor->dims());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::RandGaussian<Place, T>(n, mean, std, data, device_context);
}
};
}
}
...@@ -109,6 +109,28 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a, ...@@ -109,6 +109,28 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
matrix_b.data<double>(), beta, matrix_out->data<double>(), context); matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
} }
template <>
void RandUniform<platform::CPUPlace, float>(const int n, const float min,
const float max, float* output,
platform::DeviceContext* context) {
auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context);
std::uniform_real_distribution<float> distribution(min, max);
for (int i = 0; i < n; i++) {
output[i] = distribution(cpu_context->rand_engine());
}
}
template <>
void RandGaussian<platform::CPUPlace, float>(const int n, const float mean,
const float std, float* output,
platform::DeviceContext* context) {
auto* cpu_context = reinterpret_cast<platform::CPUDeviceContext*>(context);
std::normal_distribution<float> distribution(mean, std);
for (int i = 0; i < n; i++) {
output[i] = distribution(cpu_context->rand_engine());
}
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
...@@ -122,6 +126,38 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a, ...@@ -122,6 +126,38 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
matrix_b.data<double>(), beta, matrix_out->data<double>(), context); matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
} }
template <>
void RandUniform<platform::GPUPlace, float>(const int n, const float min,
const float max, float* output,
platform::DeviceContext* context) {
auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
thrust::uniform_real_distribution<float> distribution(min, max);
thrust::minstd_rand engine = cuda_context->rand_enigne();
engine->discard(n);
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
thrust::transform(thrust::cuda::par.on(cuda_context->stream()),
index_sequence_begin, index_sequence_begin + n,
thrust::device_ptr<float>(output), distribution(engine));
}
template <>
void RandGaussian<platform::GPUPlace, float>(const int n, const float mean,
const float std, float* output,
platform::DeviceContext* context) {
auto* cuda_context = reinterpret_cast<platform::CUDADeviceContext*>(context);
thrust::normal_distribution<float> distribution(mean, std);
thrust::minstd_rand engine = cuda_context->rand_enigne();
engine->discard(n);
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
thrust::transform(thrust::cuda::par.on(cuda_context->stream()),
index_sequence_begin, index_sequence_begin + n,
thrust::device_ptr<float>(output), distribution(engine));
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -77,6 +77,14 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, ...@@ -77,6 +77,14 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a,
framework::Tensor* matrix_out, T beta, framework::Tensor* matrix_out, T beta,
platform::DeviceContext* context); platform::DeviceContext* context);
template <typename Place, typename T>
void RandUniform(const int n, const T min, const T max, T* output,
platform::DeviceContext* context);
template <typename Place, typename T>
void RandGaussian(const int n, const T mean, const T std, T* output,
platform::DeviceContext* context);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
#include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -12,39 +12,11 @@ ...@@ -12,39 +12,11 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <random> #include "paddle/operators/uniform_random_op.h"
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(
static_cast<T>(context.op_.GetAttr<float>("min")),
static_cast<T>(context.op_.GetAttr<float>("max")));
for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) {
data[i] = dist(engine);
}
}
};
class UniformRandomOp : public framework::OperatorWithKernel { class UniformRandomOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -72,10 +44,6 @@ Used to initialize tensor with uniform random generator. ...@@ -72,10 +44,6 @@ Used to initialize tensor with uniform random generator.
AddAttr<std::vector<int>>("dims", "the dimension of random tensor"); AddAttr<std::vector<int>>("dims", "the dimension of random tensor");
AddAttr<float>("min", "Minimum value of uniform random").SetDefault(-1.0f); AddAttr<float>("min", "Minimum value of uniform random").SetDefault(-1.0f);
AddAttr<float>("max", "Maximun value of uniform random").SetDefault(1.0f); AddAttr<float>("max", "Maximun value of uniform random").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of uniform random. "
"0 means generate a seed by system")
.SetDefault(0);
} }
}; };
} // namespace operators } // namespace operators
...@@ -83,5 +51,6 @@ Used to initialize tensor with uniform random generator. ...@@ -83,5 +51,6 @@ Used to initialize tensor with uniform random generator.
REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker); paddle::operators::UniformRandomOpMaker);
REGISTER_OP_CPU_KERNEL(uniform_random, REGISTER_OP_CPU_KERNEL(
paddle::operators::CPUUniformRandomKernel<float>); uniform_random,
paddle::operators::UniformRandomKernel<paddle::platform::CPUPlace, float>);
...@@ -12,60 +12,11 @@ ...@@ -12,60 +12,11 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <thrust/device_ptr.h> #include "paddle/operators/uniform_random_op.h"
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
__host__ __device__ UniformGenerator(T min, T max, int seed)
: min_(min), max_(max), seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
return dist(rng);
}
};
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename T>
class GPUUniformRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
if (seed == 0) {
std::random_device rd;
seed = rd();
}
T min = static_cast<T>(context.op_.GetAttr<float>("min"));
T max = static_cast<T>(context.op_.GetAttr<float>("max"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_GPU_KERNEL(uniform_random, REGISTER_OP_GPU_KERNEL(uniform_random,
paddle::operators::GPUUniformRandomKernel<float>); paddle::operators::GPUUniformRandomKernel<
paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class UniformRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
T min = static_cast<T>(context.op_.GetAttr<float>("min"));
T max = static_cast<T>(context.op_.GetAttr<float>("max"));
auto n = framework::product(tensor->dims());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::RandUniform<Place, T>(n, min, max, data, device_context);
}
};
}
}
...@@ -25,8 +25,17 @@ CPUDeviceContext::CPUDeviceContext() { ...@@ -25,8 +25,17 @@ CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice()); eigen_device_.reset(new Eigen::DefaultDevice());
} }
CPUDeviceContext::CPUDeviceContext(CPUPlace place) { CPUDeviceContext::CPUDeviceContext(CPUPlace place, int rand_seed) {
eigen_device_.reset(new Eigen::DefaultDevice()); eigen_device_.reset(new Eigen::DefaultDevice());
rand_seed_ = rand_seed;
}
std::minstd_rand& CPUDeviceContext::rand_engine() {
if (!rand_engine_) {
rand_engine_.reset(new std::minstd_rand());
rand_engine_->seed(rand_seed_);
}
return *(rand_engine_.get());
} }
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
...@@ -95,7 +104,8 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { ...@@ -95,7 +104,8 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
} }
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed)
: place_(place), seed_(seed) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_)); PADDLE_ENFORCE(cudaStreamCreate(&stream_));
eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_.reset(new EigenCudaStreamDevice());
...@@ -114,9 +124,6 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -114,9 +124,6 @@ CUDADeviceContext::~CUDADeviceContext() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
} }
if (curand_generator_) {
PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_));
}
eigen_stream_.reset(); eigen_stream_.reset();
eigen_device_.reset(); eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE(cudaStreamDestroy(stream_));
...@@ -150,21 +157,16 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { ...@@ -150,21 +157,16 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
return cudnn_handle_; return cudnn_handle_;
} }
cudaStream_t CUDADeviceContext::stream() { return stream_; } thrust::minstd_rand& CPUDeviceContext::rand_engine() {
if (!rand_engine_) {
curandGenerator_t CUDADeviceContext::curand_generator() { rand_engine_.reset(new thrust::minstd_rand());
if (!curand_generator_) { rand_engine_->seed(rand_seed_);
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
} }
return curand_generator_; return *(rand_engine_.get());
} }
cudaStream_t CUDADeviceContext::stream() { return stream_; }
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
} // namespace platform } // namespace platform
......
...@@ -15,9 +15,10 @@ limitations under the License. */ ...@@ -15,9 +15,10 @@ limitations under the License. */
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
#include <thrust/device_ptr.h>
#include <thrust/random.h>
#include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
...@@ -40,14 +41,18 @@ class DeviceContext { ...@@ -40,14 +41,18 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
CPUDeviceContext(); CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace); explicit CPUDeviceContext(CPUPlace place, int rand_seed = 0);
virtual ~CPUDeviceContext() {} virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const; Eigen::DefaultDevice* eigen_device() const;
std::minstd_rand& rand_engine();
Place GetPlace() const override; Place GetPlace() const override;
private: private:
int rand_seed_;
std::unique_ptr<std::minstd_rand> rand_engine_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
}; };
...@@ -56,7 +61,7 @@ class EigenCudaStreamDevice; ...@@ -56,7 +61,7 @@ class EigenCudaStreamDevice;
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
explicit CUDADeviceContext(GPUPlace); explicit CUDADeviceContext(GPUPlace place, uint64_t rand_seed = 0);
virtual ~CUDADeviceContext(); virtual ~CUDADeviceContext();
/*! \brief Wait for all operations completion in the stream. */ /*! \brief Wait for all operations completion in the stream. */
...@@ -75,8 +80,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -75,8 +80,7 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */ /*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle(); cudnnHandle_t cudnn_handle();
/*! \brief Return curand handle in the device context. */ thrust::minstd_rand& CPUDeviceContext::rand_engine();
curandGenerator_t curand_generator();
/*! \brief Return cuda stream in the device context. */ /*! \brief Return cuda stream in the device context. */
cudaStream_t stream(); cudaStream_t stream();
...@@ -85,18 +89,16 @@ class CUDADeviceContext : public DeviceContext { ...@@ -85,18 +89,16 @@ class CUDADeviceContext : public DeviceContext {
private: private:
GPUPlace place_; GPUPlace place_;
private:
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
private: uint64_t rand_seed_;
uint64_t seed_; std::unique_ptr<thrust::minstd_rand> rand_engine_;
// clang-format off // clang-format off
cudaStream_t stream_{nullptr}; cudaStream_t stream_{nullptr};
cudnnHandle_t cudnn_handle_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr};
cublasHandle_t cublas_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr};
curandGenerator_t curand_generator_{nullptr};
// clang-format on // clang-format on
}; };
......
...@@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) ...@@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_operator SRCS test_operator.py) py_test(test_operator SRCS test_operator.py)
# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_sgd_op SRCS test_sgd_op.py)
......
...@@ -17,12 +17,7 @@ class GaussianRandomTest(unittest.TestCase): ...@@ -17,12 +17,7 @@ class GaussianRandomTest(unittest.TestCase):
scope.new_var("Out").get_tensor() scope.new_var("Out").get_tensor()
op = Operator( op = Operator(
"gaussian_random", "gaussian_random", Out="Out", dims=[1000, 784], mean=.0, std=1.)
Out="Out",
dims=[1000, 784],
mean=.0,
std=1.,
seed=10)
op.infer_shape(scope) op.infer_shape(scope)
context = core.DeviceContext.create(place) context = core.DeviceContext.create(place)
......
...@@ -17,12 +17,7 @@ class UniformRandomTest(unittest.TestCase): ...@@ -17,12 +17,7 @@ class UniformRandomTest(unittest.TestCase):
scope.new_var("X").get_tensor() scope.new_var("X").get_tensor()
op = Operator( op = Operator(
"uniform_random", "uniform_random", Out="X", dims=[1000, 784], min=-5.0, max=10.0)
Out="X",
dims=[1000, 784],
min=-5.0,
max=10.0,
seed=10)
op.infer_shape(scope) op.infer_shape(scope)
ctx = core.DeviceContext.create(place) ctx = core.DeviceContext.create(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册