提交 58ac8f46 编写于 作者: Y Yibing Liu

apply more general dims for multiplex_op

上级 089f8e2d
......@@ -44,7 +44,8 @@ class MultiplexOp : public framework::OperatorWithKernel {
"one candidate input tensors.");
auto in_dim = ins[0]->dims();
PADDLE_ENFORCE(in_dim.size() == 2, "Candidate tensors must be matrix.");
PADDLE_ENFORCE(in_dim.size() >= 2,
"The rank of candidate tensors must be not less than 2.");
for (size_t i = 1; i < num_ins; i++) {
auto dim = ins[i]->dims();
PADDLE_ENFORCE(in_dim == dim,
......@@ -65,8 +66,7 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output tensor of multiplex operator.");
AddComment(R"DOC(Multiplex operator
Multiplex multiple tensors according to the index provided by the first
input tensor.
Multiplex multiple tensors according to the index provided by the index tensor.
Ids: the index tensor.
X[0 : N - 1]: the candidate tensors for output (N >= 2).
......@@ -75,7 +75,7 @@ the (Ids[i])-th tensor.
For i-th row of the output tensor:
y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{0}.width - 1)
y[i] = x_{k}[i]
where y is the output tensor. `x_{k}` is the k-th input tensor
and `k = Ids[i]`.
......
......@@ -30,7 +30,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
out->mutable_data<T>(ctx.GetPlace());
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
......@@ -67,7 +67,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
......
......@@ -33,7 +33,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
out->mutable_data<T>(ctx.GetPlace());
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto cols = ins[0]->numel() / rows;
auto index = ids->data<int32_t>();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
......@@ -65,7 +65,7 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto cols = ins[0]->numel() / rows;
auto* index = ids->data<int32_t>();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册