diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc
index 6e77b86b5698a263b850a973cd1b8644a0aa2201..e0c19a3190c0f52cbe19f6ba00762f02445366c4 100644
--- a/paddle/operators/multiplex_op.cc
+++ b/paddle/operators/multiplex_op.cc
@@ -18,7 +18,6 @@
 namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
 
 class MultiplexOp : public framework::OperatorWithKernel {
  public:
@@ -27,11 +26,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
  protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
-                   "Input(X) should not be null");
+                   "Input(X) should not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                             "Output(Out) shouldn't be null.");
     auto ins = ctx.MultiInput<Tensor>("X");
-    auto *out = ctx.Output<LoDTensor>("Out");
+    auto *out = ctx.Output<Tensor>("Out");
     auto num_ins = ins.size();
     PADDLE_ENFORCE(num_ins > 2,
                    "multiplex operator should have more than 2 inputs.");
@@ -41,9 +40,9 @@ class MultiplexOp : public framework::OperatorWithKernel {
 
     for (size_t i = 2; i < num_ins; i++) {
       auto dim = ins[i]->dims();
-      PADDLE_ENFORCE(
-          in_dim == dim,
-          "All the input tensors except the first one must have the same size");
+      PADDLE_ENFORCE(in_dim == dim,
+                     "All the input tensors except the first one must have the "
+                     "same size.");
     }
     out->Resize(in_dim);
   }
@@ -84,12 +83,12 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
  protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
-                   "Input(X) should not be null");
+                   "Input(X) should not be null.");
     PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
-                   "Output(X@Grad) should not be null");
+                   "Output(X@Grad) should not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) shouldn't be null.");
-    auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
+    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
     auto ins = ctx.MultiInput<Tensor>("X");
     // don't compute gradient for index (ins[0])
     for (size_t i = 1; i < ins.size(); i++) {
diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu
index 4736f15bd594178168e3bcf799142d0fc18bff13..ae4c7d183ac7891ce615b5464481ade319537b1a 100644
--- a/paddle/operators/multiplex_op.cu
+++ b/paddle/operators/multiplex_op.cu
@@ -18,19 +18,20 @@
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 template <typename Place, typename T>
 class MultiplexGPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
-
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());
 
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
     // copy index to cpu
-    framework::Tensor index_t_cpu;
+    Tensor index_t_cpu;
     index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
     auto* index = index_t_cpu.data<T>();
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
@@ -38,7 +39,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
                       .stream();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = (size_t)index[i] + 1;
       PADDLE_ENFORCE_LT(k, ins.size(),
                         "index exceeds the number of candidate tensors.");
       memory::Copy(place, out->data<T>() + i * cols, place,
@@ -51,10 +52,9 @@ template <typename Place, typename T>
 class MultiplexGradGPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto* d_out = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
-    auto d_ins =
-        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
+    auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
     for (size_t i = 1; i < d_ins.size(); i++) {
       if (d_ins[i]) {
         d_ins[i]->mutable_data<T>(ctx.GetPlace());
@@ -66,7 +66,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
 
     auto rows = ins[1]->dims()[0];
     auto cols = ins[1]->dims()[1];
     // copy index to cpu
-    framework::Tensor index_t_cpu;
+    Tensor index_t_cpu;
     index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
     auto* index = index_t_cpu.data<T>();
@@ -75,7 +75,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
                       .stream();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = (size_t)index[i] + 1;
       if (d_ins[k]) {
         memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                      d_out->data<T>() + i * cols, cols * sizeof(T), stream);
diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h
index 44e8e0c1998014081b7e0aac603d573aba1f4a13..98b8ec930d0e6da4b7fc21cee2787aece6c4ae81 100644
--- a/paddle/operators/multiplex_op.h
+++ b/paddle/operators/multiplex_op.h
@@ -27,7 +27,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    auto* out = ctx.Output<framework::Tensor>("Out");
 
     out->mutable_data<T>(ctx.GetPlace());
 
@@ -36,7 +36,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
     auto* index = ins[0]->data<T>();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = (size_t)index[i] + 1;
       PADDLE_ENFORCE_LT(k, ins.size(),
                         "index exceeds the number of candidate tensors.");
       memory::Copy(place, out->data<T>() + i * cols, place,
@@ -66,7 +66,7 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
     auto* index = ins[0]->data<T>();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = (size_t)index[i] + 1;
       if (d_ins[k]) {
         memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                      d_out->data<T>() + i * cols, cols * sizeof(T));