提交 70825506 编写于 作者: D dongzhihong

"remove context random seeding "

上级 6fc6647c
......@@ -21,12 +21,10 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
}
CPUDeviceContext::CPUDeviceContext() {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
eigen_device_.reset(new Eigen::DefaultDevice());
}
CPUDeviceContext::CPUDeviceContext(CPUPlace place) {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
eigen_device_.reset(new Eigen::DefaultDevice());
}
......@@ -44,7 +42,6 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
}
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
random_seed_ = std::chrono::system_clock::now().time_since_epoch().count();
SetDeviceId(place_.device);
// TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly
// here will cause segment fault. We must implement a class derived from
......@@ -111,8 +108,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed(
curand_generator_, random_seed_));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
}
return curand_generator_;
}
......
......@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
#include <chrono>
#include <memory>
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
......@@ -40,7 +39,6 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext {
public:
typedef std::mt19937 random_generator_type;
CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace);
virtual ~CPUDeviceContext() {}
......@@ -49,16 +47,7 @@ class CPUDeviceContext : public DeviceContext {
Place GetPlace() const override;
random_generator_type& RandGenerator() {
if (!rand_generator_) {
rand_generator_.reset(new random_generator_type(random_seed_));
}
return *rand_generator_.get();
}
private:
unsigned random_seed_;
std::unique_ptr<random_generator_type> rand_generator_;
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
......@@ -97,7 +86,8 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
private:
unsigned random_seed_;
uint64_t seed_;
// clang-format off
cudnnHandle_t cudnn_handle_ = nullptr;
cublasHandle_t cublas_handle_ = nullptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册