diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc
index 6b22c782fe2456147223c43f94c290def62495b9..6e77b86b5698a263b850a973cd1b8644a0aa2201 100644
--- a/paddle/operators/multiplex_op.cc
+++ b/paddle/operators/multiplex_op.cc
@@ -106,8 +106,8 @@ namespace ops = paddle::operators;
 REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
             ops::MultiplexGradOp);
-REGISTER_OP_CPU_KERNEL(multiplex,
-                       ops::MultiplexKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    multiplex, ops::MultiplexCPUKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     multiplex_grad,
-    ops::MultiplexGradKernel<paddle::platform::CPUPlace, float>);
+    ops::MultiplexGradCPUKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu
index 3d219389ba5b66b580b83a0f8816e413313aa233..4736f15bd594178168e3bcf799142d0fc18bff13 100644
--- a/paddle/operators/multiplex_op.cu
+++ b/paddle/operators/multiplex_op.cu
@@ -15,10 +15,81 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/multiplex_op.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class MultiplexGPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+
+    out->mutable_data<T>(ctx.GetPlace());
+
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    // copy index to cpu
+    framework::Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+    auto* index = index_t_cpu.data<T>();
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                      ctx.device_context())
+                      .stream();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      PADDLE_ENFORCE_LT(k, ins.size(),
+                        "index exceeds the number of candidate tensors.");
+      memory::Copy(place, out->data<T>() + i * cols, place,
+                   ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
+    }
+  }
+};
+
+template <typename Place, typename T>
+class MultiplexGradGPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto d_ins =
+        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
+    for (size_t i = 1; i < d_ins.size(); i++) {
+      if (d_ins[i]) {
+        d_ins[i]->mutable_data<T>(ctx.GetPlace());
+        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
+        t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+      }
+    }
+
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    // copy index to cpu
+    framework::Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+    auto* index = index_t_cpu.data<T>();
+
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                      ctx.device_context())
+                      .stream();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      if (d_ins[k]) {
+        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                     d_out->data<T>() + i * cols, cols * sizeof(T), stream);
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(multiplex,
-                       ops::MultiplexKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    multiplex, ops::MultiplexGPUKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     multiplex_grad,
-    ops::MultiplexGradKernel<paddle::platform::GPUPlace, float>);
+    ops::MultiplexGradGPUKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h
index dcc01d0f9818fcc00065d0ca5559b2b1a9e99ca4..44e8e0c1998014081b7e0aac603d573aba1f4a13 100644
--- a/paddle/operators/multiplex_op.h
+++ b/paddle/operators/multiplex_op.h
@@ -23,7 +23,7 @@ namespace paddle {
 namespace operators {
 
 template <typename Place, typename T>
-class MultiplexKernel : public framework::OpKernel {
+class MultiplexCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
@@ -33,40 +33,20 @@ class MultiplexKernel : public framework::OpKernel {
 
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    if (platform::is_cpu_place(ctx.GetPlace())) {
-      auto* index = ins[0]->data<T>();
-      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        PADDLE_ENFORCE_LT(k, ins.size(),
-                          "index exceeds the number of candidate tensors.");
-        memory::Copy(place, out->data<T>() + i * cols, place,
-                     ins[k]->data<T>() + i * cols, cols * sizeof(T));
-      }
-    } else {
-#ifndef PADDLE_ONLY_CPU
-      // copy index to cpu
-      framework::Tensor index_t_cpu;
-      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-      auto* index = index_t_cpu.data<T>();
-      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                        ctx.device_context())
-                        .stream();
-      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        PADDLE_ENFORCE_LT(k, ins.size(),
-                          "index exceeds the number of candidate tensors.");
-        memory::Copy(place, out->data<T>() + i * cols, place,
-                     ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
-      }
-#endif
+    auto* index = ins[0]->data<T>();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      PADDLE_ENFORCE_LT(k, ins.size(),
+                        "index exceeds the number of candidate tensors.");
+      memory::Copy(place, out->data<T>() + i * cols, place,
+                   ins[k]->data<T>() + i * cols, cols * sizeof(T));
     }
   }
 };
 
 template <typename Place, typename T>
-class MultiplexGradKernel : public framework::OpKernel {
+class MultiplexGradCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
@@ -83,35 +63,14 @@ class MultiplexGradKernel : public framework::OpKernel {
 
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    if (platform::is_cpu_place(ctx.GetPlace())) {
-      auto* index = ins[0]->data<T>();
-      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        if (d_ins[k]) {
-          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
-                       d_out->data<T>() + i * cols, cols * sizeof(T));
-        }
-      }
-    } else {
-#ifndef PADDLE_ONLY_CPU
-      // copy index to cpu
-      framework::Tensor index_t_cpu;
-      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-      auto* index = index_t_cpu.data<T>();
-
-      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                        ctx.device_context())
-                        .stream();
-      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        if (d_ins[k]) {
-          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
-                       d_out->data<T>() + i * cols, cols * sizeof(T), stream);
-        }
+    auto* index = ins[0]->data<T>();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      if (d_ins[k]) {
+        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                     d_out->data<T>() + i * cols, cols * sizeof(T));
       }
-#endif
     }
   }
 };
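
Both kernel variants implement the same row-wise gather: X[0] stores one index per output row, the candidate tensors are X[1..n-1], and row i of Out is copied from candidate k = index[i] + 1; the gradient kernels invert this by scattering rows of dOut back into the selected d(X[k]) after zero-filling. A minimal standalone sketch of the forward rule on raw buffers (hypothetical helper name, no Paddle types):

#include <cassert>
#include <cstring>
#include <vector>

// Sketch of the multiplex gather: out row i <- ins[index[i] + 1] row i,
// mirroring the kernels in the patch. `multiplex_rows` is a hypothetical
// name; ins[0] points at the per-row index, candidates start at ins[1].
void multiplex_rows(const std::vector<const float*>& ins, float* out,
                    int rows, int cols) {
  const float* index = ins[0];
  for (int i = 0; i < rows; i++) {
    int k = static_cast<int>(index[i]) + 1;  // candidate supplying row i
    assert(k >= 1 && k < static_cast<int>(ins.size()));
    std::memcpy(out + i * cols, ins[k] + i * cols, cols * sizeof(float));
  }
}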