Commit fb52bc6e authored by Yibing Liu

revert code layout in multiplex_op

Parent 7620efdf
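The kernels in the diff below all perform the same row selection: ins[0] carries one index per output row, and row i of Out is copied from row i of ins[index[i] + 1] (candidates start at ins[1], and PADDLE_ENFORCE_LT rejects out-of-range indices). As a reading aid only, and not part of the commit, here is a minimal standalone C++ sketch of that selection, using std::vector stand-ins for the tensors and a hypothetical Multiplex helper:

#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-in for the kernels' per-row memory::Copy loop.
// ins[0] is the index tensor (one value per row); ins[1..] are the candidates.
std::vector<float> Multiplex(const std::vector<std::vector<float>>& ins,
                             std::size_t rows, std::size_t cols) {
  const std::vector<float>& index = ins[0];
  std::vector<float> out(rows * cols);
  for (std::size_t i = 0; i < rows; ++i) {
    std::size_t k = static_cast<std::size_t>(index[i]) + 1;
    assert(k < ins.size());  // mirrors PADDLE_ENFORCE_LT in the kernels
    // Copy one row of `cols` elements from the selected candidate.
    std::memcpy(out.data() + i * cols, ins[k].data() + i * cols,
                cols * sizeof(float));
  }
  return out;
}

int main() {
  // Two rows, two columns; row 0 selects ins[1], row 1 selects ins[2].
  std::vector<std::vector<float>> ins = {
      {0, 1},        // ins[0]: per-row index
      {1, 2, 3, 4},  // ins[1]: candidate tensor, 2 x 2
      {5, 6, 7, 8},  // ins[2]: candidate tensor, 2 x 2
  };
  std::vector<float> out = Multiplex(ins, /*rows=*/2, /*cols=*/2);
  // out == {1, 2, 7, 8}: row 0 copied from ins[1], row 1 copied from ins[2].
  return 0;
}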
@@ -106,8 +106,8 @@ namespace ops = paddle::operators;
 REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
             ops::MultiplexGradOp);
-REGISTER_OP_CPU_KERNEL(multiplex,
-                       ops::MultiplexKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    multiplex, ops::MultiplexCPUKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     multiplex_grad,
-    ops::MultiplexGradKernel<paddle::platform::CPUPlace, float>);
+    ops::MultiplexGradCPUKernel<paddle::platform::CPUPlace, float>);
@@ -15,10 +15,81 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/multiplex_op.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class MultiplexGPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
+
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    // copy index to cpu
+    framework::Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+    auto* index = index_t_cpu.data<T>();
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                      ctx.device_context())
+                      .stream();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      PADDLE_ENFORCE_LT(k, ins.size(),
+                        "index exceeds the number of candidate tensors.");
+      memory::Copy(place, out->data<T>() + i * cols, place,
+                   ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
+    }
+  }
+};
+
+template <typename Place, typename T>
+class MultiplexGradGPUKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto d_ins =
+        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
+    for (size_t i = 1; i < d_ins.size(); i++) {
+      if (d_ins[i]) {
+        d_ins[i]->mutable_data<T>(ctx.GetPlace());
+        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
+        t.device(ctx.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
+      }
+    }
+
+    auto rows = ins[1]->dims()[0];
+    auto cols = ins[1]->dims()[1];
+    // copy index to cpu
+    framework::Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
+    auto* index = index_t_cpu.data<T>();
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                      ctx.device_context())
+                      .stream();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      if (d_ins[k]) {
+        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                     d_out->data<T>() + i * cols, cols * sizeof(T), stream);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(multiplex,
-                       ops::MultiplexKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    multiplex, ops::MultiplexGPUKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     multiplex_grad,
-    ops::MultiplexGradKernel<paddle::platform::GPUPlace, float>);
+    ops::MultiplexGradGPUKernel<paddle::platform::GPUPlace, float>);
@@ -23,7 +23,7 @@ namespace paddle {
 namespace operators {
 
 template <typename Place, typename T>
-class MultiplexKernel : public framework::OpKernel {
+class MultiplexCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
@@ -33,40 +33,20 @@ class MultiplexKernel : public framework::OpKernel {
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    if (platform::is_cpu_place(ctx.GetPlace())) {
-      auto* index = ins[0]->data<T>();
-      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        PADDLE_ENFORCE_LT(k, ins.size(),
-                          "index exceeds the number of candidate tensors.");
-        memory::Copy(place, out->data<T>() + i * cols, place,
-                     ins[k]->data<T>() + i * cols, cols * sizeof(T));
-      }
-    } else {
-#ifndef PADDLE_ONLY_CPU
-      // copy index to cpu
-      framework::Tensor index_t_cpu;
-      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-      auto* index = index_t_cpu.data<T>();
-      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                        ctx.device_context())
-                        .stream();
-      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        PADDLE_ENFORCE_LT(k, ins.size(),
-                          "index exceeds the number of candidate tensors.");
-        memory::Copy(place, out->data<T>() + i * cols, place,
-                     ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
-      }
-#endif
-    }
+    auto* index = ins[0]->data<T>();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      PADDLE_ENFORCE_LT(k, ins.size(),
+                        "index exceeds the number of candidate tensors.");
+      memory::Copy(place, out->data<T>() + i * cols, place,
+                   ins[k]->data<T>() + i * cols, cols * sizeof(T));
+    }
   }
 };
 
 template <typename Place, typename T>
-class MultiplexGradKernel : public framework::OpKernel {
+class MultiplexGradCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
@@ -83,35 +63,14 @@ class MultiplexGradKernel : public framework::OpKernel {
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
-    if (platform::is_cpu_place(ctx.GetPlace())) {
-      auto* index = ins[0]->data<T>();
-      platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        if (d_ins[k]) {
-          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
-                       d_out->data<T>() + i * cols, cols * sizeof(T));
-        }
-      }
-    } else {
-#ifndef PADDLE_ONLY_CPU
-      // copy index to cpu
-      framework::Tensor index_t_cpu;
-      index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-      auto* index = index_t_cpu.data<T>();
-      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                        ctx.device_context())
-                        .stream();
-      platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
-      for (auto i = 0; i < rows; i++) {
-        int k = (int)index[i] + 1;
-        if (d_ins[k]) {
-          memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
-                       d_out->data<T>() + i * cols, cols * sizeof(T), stream);
-        }
-      }
-#endif
-    }
+    auto* index = ins[0]->data<T>();
+    Place place = boost::get<Place>(ctx.GetPlace());
+    for (auto i = 0; i < rows; i++) {
+      int k = (int)index[i] + 1;
+      if (d_ins[k]) {
+        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
+                     d_out->data<T>() + i * cols, cols * sizeof(T));
+      }
+    }
   }
 };
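The gradient path shown above zero-fills every requested d_ins[i] (the Eigen constant assignment in MultiplexGradGPUKernel) and then scatters each row of d_out back into d_ins[index[i] + 1], so rows that were never selected keep a zero gradient. Again as an editorial sketch, not part of the commit, with the same std::vector stand-ins and a hypothetical MultiplexGrad helper:

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-in for the grad kernels: zero-fill, then scatter rows.
// index holds one value per row; d_ins[0] corresponds to the index input and
// is left untouched, matching the kernels' loop that starts at i = 1.
void MultiplexGrad(const std::vector<float>& index,
                   const std::vector<float>& d_out,
                   std::vector<std::vector<float>>* d_ins,
                   std::size_t rows, std::size_t cols) {
  // Zero-fill first: candidates whose rows were never selected get zero grad.
  for (std::size_t i = 1; i < d_ins->size(); ++i) {
    std::fill((*d_ins)[i].begin(), (*d_ins)[i].end(), 0.0f);
  }
  // Scatter: row i of d_out goes to row i of the candidate that produced it.
  for (std::size_t i = 0; i < rows; ++i) {
    std::size_t k = static_cast<std::size_t>(index[i]) + 1;
    std::memcpy((*d_ins)[k].data() + i * cols, d_out.data() + i * cols,
                cols * sizeof(float));
  }
}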