gaussian_random_op.cu 1.1 KB
Newer Older
D
dongzhihong 已提交
1
#include <curand.h>

#include <stdexcept>
#include <string>

#include "paddle/framework/op_registry.h"
// NOTE(review): "guassian" looks like a misspelling of "gaussian"; confirm the
// header's actual filename before renaming this include.
#include "paddle/operators/guassian_random_op.h"
D
dongzhihong 已提交
3

D
dongzhihong 已提交
4 5
namespace paddle {
namespace operators {

// GPU specialization of the gaussian_random kernel: fills the op's first
// output tensor with samples drawn from N(mean, std) using the cuRAND
// generator owned by the GPU device context.
//
// Attributes read: "mean" and "std", both as T. Supported T: float (via
// curandGenerateNormal) and double (via curandGenerateNormalDouble).
// Failures reported by cuRAND or the CUDA runtime are raised as
// std::runtime_error rather than leaving the output silently unfilled.
template <typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T>
    : public framework::OpKernel {
 public:
  void Compute(const framework::KernelContext& context) const override {
    auto mean = context.op_.GetAttr<T>("mean");
    auto stddev = context.op_.GetAttr<T>("std");
    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
    T* r = output->mutable_data<T>(context.GetPlace());
    const size_t n = static_cast<size_t>(framework::product(output->dims()));
    if (n == 0) {
      return;  // Nothing to fill.
    }

    auto ctx =
        static_cast<const platform::GPUDeviceContext*>(context.device_context_);
    // Drawing samples advances the generator's state, so the const device
    // context is cast to mutable to obtain it.
    auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();

    // cuRAND's pseudorandom normal generators produce Box-Muller pairs and
    // reject odd lengths with CURAND_STATUS_LENGTH_NOT_MULTIPLE, so odd
    // element counts are handled separately.
    // NOTE(review): assumes RandGenerator() is a pseudorandom generator (or a
    // quasirandom one with dimension 1) -- confirm in the device context.
    if (n % 2 == 0) {
      EnforceCurand(GenerateNormal(g, r, n, mean, stddev));
    } else if (n == 1) {
      FillSingle(g, r, mean, stddev);
    } else {
      // Fill the first n - 1 (even) slots, then draw one more pair into the
      // last two slots. r[n - 2] is overwritten by a fresh sample, so every
      // element is still a draw from N(mean, std).
      EnforceCurand(GenerateNormal(g, r, n - 1, mean, stddev));
      EnforceCurand(GenerateNormal(g, r + n - 2, 2, mean, stddev));
    }
  }

 private:
  // Single-precision normal generation.
  static curandStatus_t GenerateNormal(curandGenerator_t g, float* out,
                                       size_t n, float mean, float stddev) {
    return curandGenerateNormal(g, out, n, mean, stddev);
  }

  // Double-precision normal generation.
  static curandStatus_t GenerateNormal(curandGenerator_t g, double* out,
                                       size_t n, double mean, double stddev) {
    return curandGenerateNormalDouble(g, out, n, mean, stddev);
  }

  // Writes one N(mean, stddev) sample to out[0] (a device pointer). It draws
  // a pair into a temporary device buffer and copies the first value. This is
  // only reached for single-element outputs, so the extra synchronization is
  // acceptable there.
  static void FillSingle(curandGenerator_t g, T* out, T mean, T stddev) {
    T* pair = nullptr;
    EnforceCuda(cudaMalloc(&pair, 2 * sizeof(T)), "cudaMalloc");
    const curandStatus_t gen_status = GenerateNormal(g, pair, 2, mean, stddev);
    cudaError_t err = cudaSuccess;
    if (gen_status == CURAND_STATUS_SUCCESS) {
      // The generator may be bound to a stream other than the one cudaMemcpy
      // uses, so wait for the draw before copying, and wait for the copy
      // before the scratch buffer is released and the output is consumed.
      err = cudaDeviceSynchronize();
      if (err == cudaSuccess) {
        err = cudaMemcpy(out, pair, sizeof(T), cudaMemcpyDeviceToDevice);
      }
      if (err == cudaSuccess) {
        err = cudaDeviceSynchronize();
      }
    }
    const cudaError_t free_err = cudaFree(pair);  // Always release scratch.
    EnforceCurand(gen_status);
    EnforceCuda(err, "cudaDeviceSynchronize/cudaMemcpy");
    EnforceCuda(free_err, "cudaFree");
  }

  // Throws if a cuRAND call did not succeed. cuRAND has no error-string
  // API, so the numeric status is reported.
  static void EnforceCurand(curandStatus_t status) {
    if (status != CURAND_STATUS_SUCCESS) {
      throw std::runtime_error(
          "gaussian_random: cuRAND normal generation failed, "
          "curandStatus_t = " +
          std::to_string(static_cast<int>(status)));
    }
  }

  // Throws if a CUDA runtime call did not succeed.
  static void EnforceCuda(cudaError_t err, const char* what) {
    if (err != cudaSuccess) {
      throw std::runtime_error(std::string("gaussian_random: ") + what +
                               " failed: " + cudaGetErrorString(err));
    }
  }
};

}  // namespace operators
}  // namespace paddle

D
dongzhihong 已提交
27 28 29
// Float instantiation of the GPU gaussian_random kernel, registered under the
// "gaussian_random" operator type. The typedef name must match the identifier
// passed to REGISTER_OP_GPU_KERNEL; previously it did not, so the
// registration referenced an undeclared name.
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace,
                                                  float>
    GaussianRandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float);