From 58ac8f46b84093b66d1567deee37df5800231a52 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 26 Sep 2017 12:50:22 +0800 Subject: [PATCH] apply more general dims for multiplex_op --- paddle/operators/multiplex_op.cc | 8 ++++---- paddle/operators/multiplex_op.cu | 4 ++-- paddle/operators/multiplex_op.h | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index 44be9b38cee..342b3911ae6 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -44,7 +44,8 @@ class MultiplexOp : public framework::OperatorWithKernel { "one candidate input tensors."); auto in_dim = ins[0]->dims(); - PADDLE_ENFORCE(in_dim.size() == 2, "Candidate tensors must be matrix."); + PADDLE_ENFORCE(in_dim.size() >= 2, + "The rank of candidate tensors must be not less than 2."); for (size_t i = 1; i < num_ins; i++) { auto dim = ins[i]->dims(); PADDLE_ENFORCE(in_dim == dim, @@ -65,8 +66,7 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output tensor of multiplex operator."); AddComment(R"DOC(Multiplex operator -Multiplex multiple tensors according to the index provided by the first -input tensor. +Multiplex multiple tensors according to the index provided by the index tensor. Ids: the index tensor. X[0 : N - 1]: the candidate tensors for output (N >= 2). @@ -75,7 +75,7 @@ the (Ids[i])-th tensor. For i-th row of the output tensor: -y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{0}.width - 1) +y[i] = x_{k}[i] where y is the output tensor. `x_{k}` is the k-th input tensor and `k = Ids[i]`. diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu index d990b227e70..70e46815fc9 100644 --- a/paddle/operators/multiplex_op.cu +++ b/paddle/operators/multiplex_op.cu @@ -30,7 +30,7 @@ class MultiplexGPUKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); auto rows = ins[0]->dims()[0]; - auto cols = ins[0]->dims()[1]; + auto cols = ins[0]->numel() / rows; // copy index to cpu Tensor index_t_cpu; index_t_cpu.CopyFrom(*ids, platform::CPUPlace()); @@ -67,7 +67,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel { } auto rows = ins[0]->dims()[0]; - auto cols = ins[0]->dims()[1]; + auto cols = ins[0]->numel() / rows; // copy index to cpu Tensor index_t_cpu; index_t_cpu.CopyFrom(*ids, platform::CPUPlace()); diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h index c39684920c0..637c63a34af 100644 --- a/paddle/operators/multiplex_op.h +++ b/paddle/operators/multiplex_op.h @@ -33,7 +33,7 @@ class MultiplexCPUKernel : public framework::OpKernel { out->mutable_data(ctx.GetPlace()); auto rows = ins[0]->dims()[0]; - auto cols = ins[0]->dims()[1]; + auto cols = ins[0]->numel() / rows; auto index = ids->data(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { @@ -65,7 +65,7 @@ class MultiplexGradCPUKernel : public framework::OpKernel { } auto rows = ins[0]->dims()[0]; - auto cols = ins[0]->dims()[1]; + auto cols = ins[0]->numel() / rows; auto* index = ids->data(); Place place = boost::get(ctx.GetPlace()); for (auto i = 0; i < rows; i++) { -- GitLab