提交 848c317a 编写于 作者: Q qiaolongfei

update gpu code

上级 6fcdc916
...@@ -43,13 +43,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel { ...@@ -43,13 +43,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
if (seed == 0) { if (seed == 0) {
std::random_device rd; std::random_device rd;
seed = rd(); seed = rd();
} }
T mean = static_cast<T>(context.op_.GetAttr<float>("mean")); T mean = static_cast<T>(context.op().GetAttr<float>("mean"));
T std = static_cast<T>(context.op_.GetAttr<float>("std")); T std = static_cast<T>(context.op().GetAttr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0); thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims()); ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N, thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
...@@ -46,13 +46,13 @@ class GPUUniformRandomKernel : public framework::OpKernel { ...@@ -46,13 +46,13 @@ class GPUUniformRandomKernel : public framework::OpKernel {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed")); static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
if (seed == 0) { if (seed == 0) {
std::random_device rd; std::random_device rd;
seed = rd(); seed = rd();
} }
T min = static_cast<T>(context.op_.GetAttr<float>("min")); T min = static_cast<T>(context.op().GetAttr<float>("min"));
T max = static_cast<T>(context.op_.GetAttr<float>("max")); T max = static_cast<T>(context.op().GetAttr<float>("max"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0); thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims()); ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N, thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册