提交 f94109d4 编写于 作者: Yibing Liu

Replace LoDTensor with Tensor in multiplex_op

上级 47fbc96f
...@@ -18,7 +18,6 @@ namespace paddle { ...@@ -18,7 +18,6 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class MultiplexOp : public framework::OperatorWithKernel { class MultiplexOp : public framework::OperatorWithKernel {
public: public:
...@@ -27,11 +26,11 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -27,11 +26,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
"Input(X) should not be null"); "Input(X) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) shouldn't be null."); "Output(Out) shouldn't be null.");
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
auto *out = ctx.Output<LoDTensor>("Out"); auto *out = ctx.Output<Tensor>("Out");
auto num_ins = ins.size(); auto num_ins = ins.size();
PADDLE_ENFORCE(num_ins > 2, PADDLE_ENFORCE(num_ins > 2,
"multiplex operator should have more than 2 inputs."); "multiplex operator should have more than 2 inputs.");
...@@ -41,9 +40,9 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -41,9 +40,9 @@ class MultiplexOp : public framework::OperatorWithKernel {
for (size_t i = 2; i < num_ins; i++) { for (size_t i = 2; i < num_ins; i++) {
auto dim = ins[i]->dims(); auto dim = ins[i]->dims();
PADDLE_ENFORCE( PADDLE_ENFORCE(in_dim == dim,
in_dim == dim, "All the input tensors except the first one must have the "
"All the input tensors except the first one must have the same size"); "same size.");
} }
out->Resize(in_dim); out->Resize(in_dim);
} }
...@@ -84,12 +83,12 @@ class MultiplexGradOp : public framework::OperatorWithKernel { ...@@ -84,12 +83,12 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
"Input(X) should not be null"); "Input(X) should not be null.");
PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(), PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
"Output(X@Grad) should not be null"); "Output(X@Grad) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null."); "Input(Out@GRAD) shouldn't be null.");
auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X")); auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
// don't compute gradient for index (ins[0]) // don't compute gradient for index (ins[0])
for (size_t i = 1; i < ins.size(); i++) { for (size_t i = 1; i < ins.size(); i++) {
......
...@@ -18,19 +18,20 @@ ...@@ -18,19 +18,20 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class MultiplexGPUKernel : public framework::OpKernel { class MultiplexGPUKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
auto* out = ctx.Output<framework::LoDTensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
auto rows = ins[1]->dims()[0]; auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1]; auto cols = ins[1]->dims()[1];
// copy index to cpu // copy index to cpu
framework::Tensor index_t_cpu; Tensor index_t_cpu;
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace()); index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
auto* index = index_t_cpu.data<T>(); auto* index = index_t_cpu.data<T>();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>( auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
...@@ -38,7 +39,7 @@ class MultiplexGPUKernel : public framework::OpKernel { ...@@ -38,7 +39,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
.stream(); .stream();
Place place = boost::get<Place>(ctx.GetPlace()); Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) { for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1; size_t k = (size_t)index[i] + 1;
PADDLE_ENFORCE_LT(k, ins.size(), PADDLE_ENFORCE_LT(k, ins.size(),
"index exceeds the number of candidate tensors."); "index exceeds the number of candidate tensors.");
memory::Copy(place, out->data<T>() + i * cols, place, memory::Copy(place, out->data<T>() + i * cols, place,
...@@ -51,10 +52,9 @@ template <typename Place, typename T> ...@@ -51,10 +52,9 @@ template <typename Place, typename T>
class MultiplexGradGPUKernel : public framework::OpKernel { class MultiplexGradGPUKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
auto d_ins = auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
for (size_t i = 1; i < d_ins.size(); i++) { for (size_t i = 1; i < d_ins.size(); i++) {
if (d_ins[i]) { if (d_ins[i]) {
d_ins[i]->mutable_data<T>(ctx.GetPlace()); d_ins[i]->mutable_data<T>(ctx.GetPlace());
...@@ -66,7 +66,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel { ...@@ -66,7 +66,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
auto rows = ins[1]->dims()[0]; auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1]; auto cols = ins[1]->dims()[1];
// copy index to cpu // copy index to cpu
framework::Tensor index_t_cpu; Tensor index_t_cpu;
index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace()); index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
auto* index = index_t_cpu.data<T>(); auto* index = index_t_cpu.data<T>();
...@@ -75,7 +75,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel { ...@@ -75,7 +75,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
.stream(); .stream();
Place place = boost::get<Place>(ctx.GetPlace()); Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) { for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1; size_t k = (size_t)index[i] + 1;
if (d_ins[k]) { if (d_ins[k]) {
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place, memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
d_out->data<T>() + i * cols, cols * sizeof(T), stream); d_out->data<T>() + i * cols, cols * sizeof(T), stream);
......
...@@ -27,7 +27,7 @@ class MultiplexCPUKernel : public framework::OpKernel { ...@@ -27,7 +27,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::LoDTensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
...@@ -36,7 +36,7 @@ class MultiplexCPUKernel : public framework::OpKernel { ...@@ -36,7 +36,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
auto* index = ins[0]->data<T>(); auto* index = ins[0]->data<T>();
Place place = boost::get<Place>(ctx.GetPlace()); Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) { for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1; size_t k = (size_t)index[i] + 1;
PADDLE_ENFORCE_LT(k, ins.size(), PADDLE_ENFORCE_LT(k, ins.size(),
"index exceeds the number of candidate tensors."); "index exceeds the number of candidate tensors.");
memory::Copy(place, out->data<T>() + i * cols, place, memory::Copy(place, out->data<T>() + i * cols, place,
...@@ -66,7 +66,7 @@ class MultiplexGradCPUKernel : public framework::OpKernel { ...@@ -66,7 +66,7 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
auto* index = ins[0]->data<T>(); auto* index = ins[0]->data<T>();
Place place = boost::get<Place>(ctx.GetPlace()); Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) { for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1; size_t k = (size_t)index[i] + 1;
if (d_ins[k]) { if (d_ins[k]) {
memory::Copy(place, d_ins[k]->data<T>() + i * cols, place, memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
d_out->data<T>() + i * cols, cols * sizeof(T)); d_out->data<T>() + i * cols, cols * sizeof(T));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册