提交 b851515b 编写于 作者: Z zchen0211

merge new op grammar

上级 88a8eedd
......@@ -20,13 +20,8 @@
namespace paddle {
namespace operators {
// template <typename T>
// Debugging kernel: for every index the loop macro yields, prints that index
// together with the corresponding element of `params` using device-side
// printf.
//
// params: array of floats read at each visited index.
// N:      bound forwarded to CUDA_1D_KERNEL_LOOP.
//
// NOTE(review): CUDA_1D_KERNEL_LOOP is defined outside this file; presumably
// it visits the indices in [0, N) — confirm against the macro definition.
__global__ void print_arr(const float *params, const int N) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    const float value = params[i];
    printf("device: %d, %f\n", i, value);
  }
}
template <typename T>
class GatherOpCUDAKernel : public framework::OpKernel {
class GatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -42,7 +37,7 @@ class GatherOpCUDAKernel : public framework::OpKernel {
};
template <typename T>
class GatherGradOpCUDAKernel : public framework::OpKernel {
class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace operators {
template <typename T>
class ScatterOpCUDAKernel : public framework::OpKernel {
class ScatterOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -37,7 +37,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel {
};
template <typename T>
class ScatterGradOpCUDAKernel : public framework::OpKernel {
class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册