diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc
index 6e77b86b5698a263b850a973cd1b8644a0aa2201..7b50444d16dc57fd14b918d1159e3e21ecd1f1c4 100644
--- a/paddle/operators/multiplex_op.cc
+++ b/paddle/operators/multiplex_op.cc
@@ -18,7 +18,6 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
 
 class MultiplexOp : public framework::OperatorWithKernel {
  public:
@@ -26,24 +25,31 @@ class MultiplexOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"),
+                            "Input(Ids) shouldn't be null.");
     PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
-                   "Input(X) should not be null");
+                   "MultiInput(X) shouldn't be empty.");
     PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                             "Output(Out) shouldn't be null.");
+    auto ids_dim = ctx.Input<Tensor>("Ids")->dims();
+    PADDLE_ENFORCE(
+        ids_dim.size() == 2 && ids_dim[1] == 1,
+        "The index tensor must be a vector with size batchSize x 1.");
+
     auto ins = ctx.MultiInput<Tensor>("X");
-    auto *out = ctx.Output<LoDTensor>("Out");
+    auto *out = ctx.Output<Tensor>("Out");
     auto num_ins = ins.size();
-    PADDLE_ENFORCE(num_ins > 2,
-                   "multiplex operator should have more than 2 inputs.");
-    PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
-                      "The first input must be a index vector.");
-    auto in_dim = ins[1]->dims();
-
-    for (size_t i = 2; i < num_ins; i++) {
+    PADDLE_ENFORCE(num_ins > 1,
+                   "multiplex operator should have more than "
+                   "one candidate input tensors.");
+
+    auto in_dim = ins[0]->dims();
+    PADDLE_ENFORCE(in_dim.size() >= 2,
+                   "The rank of candidate tensors must be not less than 2.");
+    for (size_t i = 1; i < num_ins; i++) {
       auto dim = ins[i]->dims();
-      PADDLE_ENFORCE(
-          in_dim == dim,
-          "All the input tensors except the first one must have the same size");
+      PADDLE_ENFORCE(in_dim == dim,
+                     "All the candidate tensors must have the same size.");
     }
     out->Resize(in_dim);
   }
@@ -54,25 +60,25 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
   MultiplexOpMaker(framework::OpProto *proto,
                    framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensors of multiplex operator.").AsDuplicable();
+    AddInput("Ids", "The index tensor of multiplex operator.");
+    AddInput("X", "The candidate tensors of multiplex operator.")
+        .AsDuplicable();
     AddOutput("Out", "The output tensor of multiplex operator.");
     AddComment(R"DOC(Multiplex operator
 
-Multiplex multiple tensors according to the index provided by the first
-input tensor.
+Multiplex multiple tensors according to the index provided by the index tensor.
 
-ins[0]: the index tensor.
-ins[1:N]: the candidate output tensors.
+Ids: the index tensor.
+X[0 : N - 1]: the candidate tensors for output (N >= 2).
 For each index i from 0 to batchSize - 1, the output is the i-th row of the
-the (index[i] + 1)-th tensor.
+the (Ids[i])-th tensor.
 
 For i-th row of the output tensor:
 
-y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
+y[i] = x_{k}[i]
 
 where y is the output tensor. `x_{k}` is the k-th input tensor
-and `k = x{0}[i] + 1`.
-
+and `k = Ids[i]`.
)DOC"); } }; @@ -84,15 +90,15 @@ class MultiplexGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), - "Input(X) should not be null"); + "Input(X) should not be null."); PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(), - "Output(X@Grad) should not be null"); + "Output(X@Grad) should not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); + "Input(Out@GRAD) should not be null."); + auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); auto ins = ctx.MultiInput("X"); - // don't compute gradient for index (ins[0]) - for (size_t i = 1; i < ins.size(); i++) { + // No need to compute gradient for Input(Ids) + for (size_t i = 0; i < ins.size(); i++) { if (d_ins[i]) { d_ins[i]->Resize(ins[i]->dims()); } diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu index 4736f15bd594178168e3bcf799142d0fc18bff13..70e46815fc9148a2530d437d20c14f5d40baa1a4 100644 --- a/paddle/operators/multiplex_op.cu +++ b/paddle/operators/multiplex_op.cu @@ -18,27 +18,30 @@ namespace paddle { namespace operators { +using Tensor = framework::Tensor; + template class MultiplexGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { - auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); - + auto ins = ctx.MultiInput("X"); + auto* ids = ctx.Input("Ids"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->numel() / rows; // copy index to cpu - framework::Tensor index_t_cpu; - index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); - auto* index = index_t_cpu.data(); + Tensor index_t_cpu; + index_t_cpu.CopyFrom(*ids, platform::CPUPlace()); + auto* index = index_t_cpu.data(); auto stream = reinterpret_cast( ctx.device_context()) .stream(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { - int k = (int)index[i] + 1; + int32_t k = index[i]; + PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative."); PADDLE_ENFORCE_LT(k, ins.size(), "index exceeds the number of candidate tensors."); memory::Copy(place, out->data() + i * cols, place, @@ -51,11 +54,11 @@ template class MultiplexGradGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { - auto* d_out = ctx.Input(framework::GradVarName("Out")); - auto ins = ctx.MultiInput("X"); - auto d_ins = - ctx.MultiOutput(framework::GradVarName("X")); - for (size_t i = 1; i < d_ins.size(); i++) { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto* ids = ctx.Input("Ids"); + auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); + for (size_t i = 0; i < d_ins.size(); i++) { if (d_ins[i]) { d_ins[i]->mutable_data(ctx.GetPlace()); auto t = framework::EigenVector::Flatten(*d_ins[i]); @@ -63,19 +66,19 @@ class MultiplexGradGPUKernel : public framework::OpKernel { } } - auto rows = ins[1]->dims()[0]; - auto cols = ins[1]->dims()[1]; + auto rows = ins[0]->dims()[0]; + auto cols = ins[0]->numel() / rows; // copy index to cpu - framework::Tensor index_t_cpu; - index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); - auto* index = index_t_cpu.data(); + Tensor index_t_cpu; + 
+    index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
+    auto* index = index_t_cpu.data<int32_t>();
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                       ctx.device_context())
                       .stream();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = static_cast<size_t>(index[i]);
       if (d_ins[k]) {
         memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                      d_out->data<T>() + i * cols, cols * sizeof(T), stream);
diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h
index 98466426bd90bc30a22ecf74e6739e2d4ad1d21d..637c63a34af394f5f54997c46c00a9ff00577476 100644
--- a/paddle/operators/multiplex_op.h
+++ b/paddle/operators/multiplex_op.h
@@ -27,16 +27,18 @@ class MultiplexCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    auto ids = ctx.Input<framework::Tensor>("Ids");
+    auto* out = ctx.Output<framework::Tensor>("Out");
 
     out->mutable_data<T>(ctx.GetPlace());
 
-    auto rows = ins[1]->dims()[0];
-    auto cols = ins[1]->dims()[1];
-    auto* index = ins[0]->data<T>();
+    auto rows = ins[0]->dims()[0];
+    auto cols = ins[0]->numel() / rows;
+    auto index = ids->data<int32_t>();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      int32_t k = index[i];
+      PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
       PADDLE_ENFORCE_LT(static_cast<size_t>(k), ins.size(),
                         "index exceeds the number of candidate tensors.");
       memory::Copy(place, out->data<T>() + i * cols, place,
@@ -50,10 +52,11 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* ids = ctx.Input<framework::Tensor>("Ids");
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto d_ins =
         ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
-    for (size_t i = 1; i < d_ins.size(); i++) {
+    for (size_t i = 0; i < d_ins.size(); i++) {
       if (d_ins[i]) {
         d_ins[i]->mutable_data<T>(ctx.GetPlace());
         auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
@@ -61,12 +64,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
       }
     }
 
-    auto rows = ins[1]->dims()[0];
-    auto cols = ins[1]->dims()[1];
-    auto* index = ins[0]->data<T>();
+    auto rows = ins[0]->dims()[0];
+    auto cols = ins[0]->numel() / rows;
+    auto* index = ids->data<int32_t>();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = static_cast<size_t>(index[i]);
       if (d_ins[k]) {
         memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                      d_out->data<T>() + i * cols, cols * sizeof(T));
diff --git a/python/paddle/v2/framework/tests/test_multiplex_op.py b/python/paddle/v2/framework/tests/test_multiplex_op.py
index f2b3881cde24c7fb96c3d7f9411352bc62d55077..5937eb5aa4621556c9b8d59ea83a39d9738c7925 100644
--- a/python/paddle/v2/framework/tests/test_multiplex_op.py
+++ b/python/paddle/v2/framework/tests/test_multiplex_op.py
@@ -6,20 +6,22 @@ from op_test import OpTest
 
 class TestMultiplexOp(OpTest):
     def setUp(self):
         self.op_type = "multiplex"
-        rows = 3
-        index = np.array([3, 1, 0])
+        rows = 4
+        index = np.arange(0, rows).astype('int32')
+        np.random.shuffle(index)
+        index = np.reshape(index, (rows, 1))
         ins1 = np.random.random((rows, 10)).astype("float32")
         ins2 = np.random.random((rows, 10)).astype("float32")
         ins3 = np.random.random((rows, 10)).astype("float32")
         ins4 = np.random.random((rows, 10)).astype("float32")
         self.inputs = {
-            'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3),
-                  ('x4', ins4)]
+            'Ids': index,
+            'X': [('x1', ins1), ('x2', ins2), ('x3', ins3), ('x4', ins4)]
         }
         # multiplex output
         output = np.zeros_like(ins1)
         for i in range(0, rows):
-            k = index[i] + 1
+            k = index[i][0]
             output[i] = self.inputs['X'][k][1][i]
         self.outputs = {'Out': output}
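For reference, the new semantics are easy to state in a few lines of numpy. The sketch below is illustrative only and not part of the patch; the helper names `multiplex` and `multiplex_grad` are ours. It mirrors the forward copy performed by the kernels (`out[i] = X[Ids[i]][i]`) and the gradient routing in the grad kernels (each output row's gradient is copied back to the candidate row it was selected from, all other rows staying zero), using the same shapes as the updated unit test.

```python
import numpy as np


def multiplex(ids, xs):
    """Forward: the i-th output row is the i-th row of the Ids[i]-th candidate."""
    out = np.zeros_like(xs[0])
    for i in range(xs[0].shape[0]):
        k = ids[i][0]  # Ids has shape (batchSize, 1)
        assert 0 <= k < len(xs), "index exceeds the number of candidate tensors"
        out[i] = xs[k][i]
    return out


def multiplex_grad(ids, xs, d_out):
    """Backward: route each output row's gradient to the candidate it came from."""
    d_xs = [np.zeros_like(x) for x in xs]
    for i in range(d_out.shape[0]):
        d_xs[ids[i][0]][i] = d_out[i]
    return d_xs


# Mirrors the updated unit test: batch of 4 rows, 4 candidates, int32 index.
rows = 4
ids = np.arange(rows, dtype='int32').reshape(rows, 1)
np.random.shuffle(ids)
xs = [np.random.random((rows, 10)).astype('float32') for _ in range(4)]
out = multiplex(ids, xs)
for i in range(rows):
    assert np.array_equal(out[i], xs[ids[i][0]][i])
```

Note that with the old layout the index tensor rode along as `ins[0]` and candidates were `ins[1:]` (hence the `+ 1` offset and the `i = 1` loop starts); moving the index into a dedicated `Ids` input removes the offset and lets the gradient loop cover every entry of `X`.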