Commit 58ac8f46 authored by Yibing Liu

apply more general dims for multiplex_op

Parent 089f8e2d
@@ -44,7 +44,8 @@ class MultiplexOp : public framework::OperatorWithKernel {
                    "one candidate input tensors.");
     auto in_dim = ins[0]->dims();
-    PADDLE_ENFORCE(in_dim.size() == 2, "Candidate tensors must be matrix.");
+    PADDLE_ENFORCE(in_dim.size() >= 2,
+                   "The rank of candidate tensors must be not less than 2.");
     for (size_t i = 1; i < num_ins; i++) {
       auto dim = ins[i]->dims();
       PADDLE_ENFORCE(in_dim == dim,
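With this relaxed check, candidates are no longer required to be 2-D matrices: any shape of rank >= 2 is accepted (for example, a hypothetical [batch, 3, 4] candidate now passes), and the kernels below flatten everything after the first dimension into a single row width.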
@@ -65,8 +66,7 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "The output tensor of multiplex operator.");
     AddComment(R"DOC(Multiplex operator

-Multiplex multiple tensors according to the index provided by the first
-input tensor.
+Multiplex multiple tensors according to the index provided by the index tensor.

 Ids: the index tensor.
 X[0 : N - 1]: the candidate tensors for output (N >= 2).
@@ -75,7 +75,7 @@ the (Ids[i])-th tensor.

 For i-th row of the output tensor:

-y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{0}.width - 1)
+y[i] = x_{k}[i]

 where y is the output tensor. `x_{k}` is the k-th input tensor
 and `k = Ids[i]`.
......
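As a worked example of the formula above (the concrete values are illustrative, not from the patch): with N = 2 candidates X[0] = [[1, 2], [3, 4], [5, 6]] and X[1] = [[7, 8], [9, 10], [11, 12]] and Ids = [1, 0, 1], the output is y = [[7, 8], [3, 4], [11, 12]], since row i of y is copied whole from row i of candidate X[Ids[i]].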
@@ -30,7 +30,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
     out->mutable_data<T>(ctx.GetPlace());

     auto rows = ins[0]->dims()[0];
-    auto cols = ins[0]->dims()[1];
+    auto cols = ins[0]->numel() / rows;
     // copy index to cpu
     Tensor index_t_cpu;
     index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
@@ -67,7 +67,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
     }

     auto rows = ins[0]->dims()[0];
-    auto cols = ins[0]->dims()[1];
+    auto cols = ins[0]->numel() / rows;
     // copy index to cpu
     Tensor index_t_cpu;
     index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
......
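The cols = numel / rows change repeated in both GPU kernels is the heart of the generalization. A minimal standalone sketch of the idea (plain C++, not Paddle code; the function name is made up for illustration):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Why cols = numel / rows generalizes cols = dims[1]: a row-major tensor of
// shape [d0, d1, ..., dn] can be viewed as a [d0, d1 * ... * dn] matrix, so
// each "row" the kernel copies is one contiguous run of d1 * ... * dn values.
std::size_t FlattenedCols(const std::vector<std::size_t>& dims) {
  std::size_t numel = std::accumulate(dims.begin(), dims.end(),
                                      static_cast<std::size_t>(1),
                                      std::multiplies<std::size_t>());
  return numel / dims[0];  // == dims[1] * ... * dims[n]
}

For a rank-2 tensor of shape [r, c] this reduces to c, matching the old dims()[1] behavior, so existing matrix inputs are unaffected.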
@@ -33,7 +33,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
     out->mutable_data<T>(ctx.GetPlace());

     auto rows = ins[0]->dims()[0];
-    auto cols = ins[0]->dims()[1];
+    auto cols = ins[0]->numel() / rows;
     auto index = ids->data<int32_t>();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
@@ -65,7 +65,7 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
     }

     auto rows = ins[0]->dims()[0];
-    auto cols = ins[0]->dims()[1];
+    auto cols = ins[0]->numel() / rows;
     auto* index = ids->data<int32_t>();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
......
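Putting the pieces together, the row-wise copy that the kernels perform over the flattened view can be sketched in plain C++ as follows (a simplified stand-in for the real kernels, not the actual Paddle code; the function name and the float element type are assumptions for illustration):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Copy row i of candidate ins[ids[i]] into row i of out. "cols" is the
// flattened row width (numel / rows), so the same loop handles candidate
// tensors of any rank >= 2.
void MultiplexRows(const std::vector<const float*>& ins, const int32_t* ids,
                   float* out, std::size_t rows, std::size_t cols) {
  for (std::size_t i = 0; i < rows; ++i) {
    auto k = static_cast<std::size_t>(ids[i]);
    assert(k < ins.size() && "Ids[i] must index a candidate tensor");
    std::memcpy(out + i * cols, ins[k] + i * cols, cols * sizeof(float));
  }
}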