提交 b851515b 编写于 作者: Z zchen0211

merge new op grammar

上级 88a8eedd
...@@ -20,13 +20,8 @@ ...@@ -20,13 +20,8 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// template <typename T>
__global__ void print_arr(const float *params, const int N) {
CUDA_1D_KERNEL_LOOP(i, N) { printf("device: %d, %f\n", i, params[i]); }
}
template <typename T> template <typename T>
class GatherOpCUDAKernel : public framework::OpKernel { class GatherOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
...@@ -42,7 +37,7 @@ class GatherOpCUDAKernel : public framework::OpKernel { ...@@ -42,7 +37,7 @@ class GatherOpCUDAKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class GatherGradOpCUDAKernel : public framework::OpKernel { class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
class ScatterOpCUDAKernel : public framework::OpKernel { class ScatterOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
...@@ -37,7 +37,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel { ...@@ -37,7 +37,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class ScatterGradOpCUDAKernel : public framework::OpKernel { class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册