From f94109d4285ef34ddd1fbdea36bacd33f9d97231 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Mon, 25 Sep 2017 14:13:45 +0800
Subject: [PATCH] replace LoDTensor in multiplex_op

---
 paddle/operators/multiplex_op.cc | 17 ++++++++---------
 paddle/operators/multiplex_op.cu | 22 +++++++++++-----------
 paddle/operators/multiplex_op.h  |  6 +++---
 3 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc
index 6e77b86b569..e0c19a3190c 100644
--- a/paddle/operators/multiplex_op.cc
+++ b/paddle/operators/multiplex_op.cc
@@ -18,7 +18,6 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
 
 class MultiplexOp : public framework::OperatorWithKernel {
  public:
@@ -27,11 +26,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
  protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
-                   "Input(X) should not be null");
+                   "Input(X) should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                            "Output(Out) shouldn't be null.");
    auto ins = ctx.MultiInput<Tensor>("X");
-    auto *out = ctx.Output<LoDTensor>("Out");
+    auto *out = ctx.Output<Tensor>("Out");
    auto num_ins = ins.size();
    PADDLE_ENFORCE(num_ins > 2,
                   "multiplex operator should have more than 2 inputs.");
@@ -41,9 +40,9 @@ class MultiplexOp : public framework::OperatorWithKernel {
 
    for (size_t i = 2; i < num_ins; i++) {
      auto dim = ins[i]->dims();
-      PADDLE_ENFORCE(
-          in_dim == dim,
-          "All the input tensors except the first one must have the same size");
+      PADDLE_ENFORCE(in_dim == dim,
+                     "All the input tensors except the first one must have the "
+                     "same size.");
    }
    out->Resize(in_dim);
  }
@@ -84,12 +83,12 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
  protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
-                   "Input(X) should not be null");
+                   "Input(X) should not be null.");
    PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
-                   "Output(X@Grad) should not be null");
+                   "Output(X@Grad) should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) shouldn't be null.");
-    auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
+    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
    auto ins = ctx.MultiInput<Tensor>("X");
    // don't compute gradient for index (ins[0])
    for (size_t i = 1; i < ins.size(); i++) {
diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu
index 4736f15bd59..ae4c7d183ac 100644
--- a/paddle/operators/multiplex_op.cu
+++ b/paddle/operators/multiplex_op.cu
@@ -18,19 +18,20 @@ namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 template <typename Place, typename T>
 class MultiplexGPUKernel : public framework::OpKernel {
  public:
  void Compute(const framework::ExecutionContext& ctx) const {
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
-
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
 
    auto rows = ins[1]->dims()[0];
    auto cols = ins[1]->dims()[1];
    // copy index to cpu
-    framework::Tensor index_t_cpu;
+    Tensor index_t_cpu;
    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
    auto* index = index_t_cpu.data<T>();
 
    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
@@ -38,7 +39,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
                      .stream();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = (size_t)index[i] + 1;
      PADDLE_ENFORCE_LT(k, ins.size(),
                        "index exceeds the number of candidate tensors.");
      memory::Copy(place, out->data<T>() + i * cols, place,
@@ -51,10 +52,9 @@ template <typename Place, typename T>
 class MultiplexGradGPUKernel : public framework::OpKernel {
  public:
  void Compute(const framework::ExecutionContext& ctx) const {
-    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto d_ins =
-        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
+    auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
    for (size_t i = 1; i < d_ins.size(); i++) {
      if (d_ins[i]) {
        d_ins[i]->mutable_data<T>(ctx.GetPlace());
@@ -66,7 +66,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
    auto rows = ins[1]->dims()[0];
    auto cols = ins[1]->dims()[1];
    // copy index to cpu
-    framework::Tensor index_t_cpu;
+    Tensor index_t_cpu;
    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
    auto* index = index_t_cpu.data<T>();
 
@@ -75,7 +75,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
                      .stream();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = (size_t)index[i] + 1;
      if (d_ins[k]) {
        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                     d_out->data<T>() + i * cols, cols * sizeof(T), stream);
diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h
index 44e8e0c1998..98b8ec930d0 100644
--- a/paddle/operators/multiplex_op.h
+++ b/paddle/operators/multiplex_op.h
@@ -27,7 +27,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
  public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto ins = ctx.MultiInput<Tensor>("X");
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    auto* out = ctx.Output<Tensor>("Out");
 
    out->mutable_data<T>(ctx.GetPlace());
 
@@ -36,7 +36,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
    auto* index = ins[0]->data<T>();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = (size_t)index[i] + 1;
      PADDLE_ENFORCE_LT(k, ins.size(),
                        "index exceeds the number of candidate tensors.");
      memory::Copy(place, out->data<T>() + i * cols, place,
@@ -66,7 +66,7 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
    auto* index = ins[0]->data<T>();
    Place place = boost::get<Place>(ctx.GetPlace());
    for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = (size_t)index[i] + 1;
      if (d_ins[k]) {
        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                     d_out->data<T>() + i * cols, cols * sizeof(T));
-- 
GitLab
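
Editor's note (not part of the patch): every kernel touched above implements the same multiplex selection rule. ins[0] carries one index value per row, and row i of the output is copied from row i of the candidate tensor ins[index[i] + 1], where the + 1 skips the index tensor itself; that offset is why the forward kernels enforce k < ins.size(). Below is a minimal standalone C++ sketch of that rule, using hypothetical names (Matrix, Multiplex) that do not appear in the patch:

    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    using Matrix = std::vector<std::vector<float>>;

    // ins[0] holds one index value per row; ins[1..n-1] are the candidates.
    // Row i of the result is a copy of row i of ins[index[i] + 1], mirroring
    // the per-row memory::Copy done by the CPU and GPU kernels in the patch.
    Matrix Multiplex(const std::vector<Matrix>& ins) {
      const Matrix& index_t = ins[0];
      const std::size_t rows = ins[1].size();
      Matrix out(rows);
      for (std::size_t i = 0; i < rows; ++i) {
        // Same computation as the patch: size_t k = (size_t)index[i] + 1;
        std::size_t k = static_cast<std::size_t>(index_t[i][0]) + 1;
        // Stand-in for the patch's PADDLE_ENFORCE_LT(k, ins.size(), ...).
        if (k >= ins.size()) {
          throw std::out_of_range(
              "index exceeds the number of candidate tensors.");
        }
        out[i] = ins[k][i];
      }
      return out;
    }

One observation on the int-to-size_t change itself: ins.size() is unsigned, so declaring k as size_t makes the single upper-bound comparison well-typed (no signed/unsigned mismatch), which is presumably part of the motivation for the change.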