gaussian_random_op.cu 1.0 KB
Newer Older
D
dongzhihong 已提交
1 2 3
#include "paddle/operators/random_op.h"

#include <stdexcept>
#include <string>
#include <type_traits>

#include "paddle/framework/op_registry.h"

D
dongzhihong 已提交
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
namespace paddle {
namespace operators {

// GPU specialization of GaussianRandomOpKernel.
//
// Fills the operator's first output tensor with samples drawn from a normal
// distribution, using the cuRAND generator obtained from the GPU device
// context and the op attributes "mean" and "std" (read as type T).
//
// Requirements / failure modes:
//   * T must be float: curandGenerateNormal is single precision only.
//   * The result of curandGenerateNormal is checked. On any cuRAND error this
//     throws std::runtime_error instead of leaving the output uninitialized.
//     In particular, cuRAND pseudorandom generators reject an odd sample count
//     (CURAND_STATUS_LENGTH_NOT_MULTIPLE). The message calls out that case.
template <typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T>
    : public framework::OpKernel {
 public:
  void Compute(const framework::KernelContext& context) const override {
    static_assert(std::is_same<T, float>::value,
                  "GPU GaussianRandomOpKernel only supports float "
                  "(curandGenerateNormal is single precision)");

    auto mean = context.op_.GetAttr<T>("mean");
    // Named `stddev` rather than `std` so it does not read like the namespace.
    auto stddev = context.op_.GetAttr<T>("std");

    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
    T* data = output->mutable_data<T>(context.GetPlace());
    const size_t numel =
        static_cast<size_t>(framework::product(output->dims()));
    if (numel == 0) {
      return;  // Nothing to sample; don't hand cuRAND an empty request.
    }

    auto ctx = static_cast<const platform::GPUDeviceContext*>(
        context.device_context_);
    // Obtaining the generator needs to modify the device context, so the
    // const qualifier is cast away here.
    auto generator =
        const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();

    curandStatus_t status =
        curandGenerateNormal(generator, data, numel, mean, stddev);
    if (status != CURAND_STATUS_SUCCESS) {
      std::string message =
          "gaussian_random: curandGenerateNormal failed with status " +
          std::to_string(static_cast<int>(status)) + " for " +
          std::to_string(numel) + " elements";
      if (numel % 2 != 0) {
        message +=
            " (cuRAND pseudorandom generators require an even sample count)";
      }
      throw std::runtime_error(message);
    }
  }
};

}  // namespace operators
}  // namespace paddle
  

// Float instantiation of the GPU Gaussian random kernel, registered as the
// GPU kernel of the `gaussian_random` operator. The alias name must match the
// identifier passed to REGISTER_OP_GPU_KERNEL below.
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace,
                                                  float>
    GaussianRandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float);