提交 70825506 编写于 作者: D dongzhihong

"remove context random seeding "

上级 6fc6647c
...@@ -21,12 +21,10 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() ...@@ -21,12 +21,10 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
} }
CPUDeviceContext::CPUDeviceContext() { CPUDeviceContext::CPUDeviceContext() {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
eigen_device_.reset(new Eigen::DefaultDevice()); eigen_device_.reset(new Eigen::DefaultDevice());
} }
CPUDeviceContext::CPUDeviceContext(CPUPlace place) { CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
eigen_device_.reset(new Eigen::DefaultDevice()); eigen_device_.reset(new Eigen::DefaultDevice());
} }
...@@ -44,7 +42,6 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { ...@@ -44,7 +42,6 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
} }
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
SetDeviceId(place_.device); SetDeviceId(place_.device);
// TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly // TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly
// here will cause segment fault. We must implement a class derived from // here will cause segment fault. We must implement a class derived from
...@@ -111,8 +108,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() { ...@@ -111,8 +108,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
SetDeviceId(place_.device); SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT)); CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed( PADDLE_ENFORCE(
curand_generator_, random_seed_)); dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
} }
return curand_generator_; return curand_generator_;
} }
......
...@@ -21,7 +21,6 @@ limitations under the License. */ ...@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
#include <chrono>
#include <memory> #include <memory>
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
...@@ -40,7 +39,6 @@ class DeviceContext { ...@@ -40,7 +39,6 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext { class CPUDeviceContext : public DeviceContext {
public: public:
typedef std::mt19937 random_generator_type;
CPUDeviceContext(); CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace); explicit CPUDeviceContext(CPUPlace);
virtual ~CPUDeviceContext() {} virtual ~CPUDeviceContext() {}
...@@ -49,16 +47,7 @@ class CPUDeviceContext : public DeviceContext { ...@@ -49,16 +47,7 @@ class CPUDeviceContext : public DeviceContext {
Place GetPlace() const override; Place GetPlace() const override;
random_generator_type& RandGenerator() {
if (!rand_generator_) {
rand_generator_.reset(new random_generator_type(random_seed_));
}
return *rand_generator_.get();
}
private: private:
unsigned random_seed_;
std::unique_ptr<random_generator_type> rand_generator_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
}; };
...@@ -97,7 +86,8 @@ class CUDADeviceContext : public DeviceContext { ...@@ -97,7 +86,8 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_; std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
private: private:
unsigned random_seed_; uint64_t seed_;
// clang-format off // clang-format off
cudnnHandle_t cudnn_handle_ = nullptr; cudnnHandle_t cudnn_handle_ = nullptr;
cublasHandle_t cublas_handle_ = nullptr; cublasHandle_t cublas_handle_ = nullptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册