Commit b3f44ad7 authored by Yibing Liu

add multiplex operator

Parent 4137cb0b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/multiplex_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MultiplexOp : public framework::OperatorWithKernel {
public:
MultiplexOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto ins = ctx.MultiInput<Tensor>("X");
auto *out = ctx.Output<Tensor>("Out");
auto num_ins = ins.size();
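// ins[0] is the row-wise index vector; ins[1..N-1] are the candidate
// tensors, so the operator needs more than two inputs in total.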
PADDLE_ENFORCE(num_ins > 2,
"multiplex operator should have more than 2 inputs.");
PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
"The first input must be a index vector.");
auto in_dim = ins[1]->dims();
for (size_t i = 2; i < num_ins; i++) {
auto dim = ins[i]->dims();
PADDLE_ENFORCE(
in_dim == dim,
"All the input tensors except the first one must have the same size");
}
out->Resize(in_dim);
}
};
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MultiplexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of multiplex operator.").AsDuplicable();
AddOutput("Out", "The output tensor of multiplex operator.");
AddComment(R"DOC(Multiplex operator
Multiplex multiple tensors according to the index provided by the first
input tensor.
ins[0]: the index of the tensor to output of size batchSize.
ins[1:N]: the candidate output tensor.
For each index i from 0 to batchSize - 1, the output is the i-th row of the
the (index[i] + 1)-th tensor.
For each i-th row of output:
y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
where y is the output tensor. `x_{k}` is the k-th input layer
and `k = x{0}[i] + 1`.
)DOC");
}
};
class MultiplexGradOp : public framework::OperatorWithKernel {
public:
MultiplexGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
auto ins = ctx.MultiInput<Tensor>("X");
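// Resize each gradient output to the shape of its corresponding forward
// input (the index tensor included).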
for (size_t i = 0; i < ins.size(); i++) {
auto dims = ins[i]->dims();
d_ins[i]->Resize(dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
ops::MultiplexGradOp);
REGISTER_OP_CPU_KERNEL(multiplex, ops::MultiplexCPUKernel<float>);
REGISTER_OP_CPU_KERNEL(multiplex_grad, ops::MultiplexGradCPUKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class MultiplexGPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom<T>(*(ins[0]), paddle::platform::CPUPlace());
auto index = index_t_cpu.data<T>();
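// The index must live on the host so the source tensor for each row can
// be selected; each selected row is then copied device-to-device.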
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
cudaMemcpy(out->data<T>() + i * cols, ins[k]->data<T>() + i * cols,
cols * sizeof(T), cudaMemcpyDeviceToDevice);
}
}
};
template <typename T>
class MultiplexGradGPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<Tensor>("X");
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
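// Zero every input gradient first; only the rows selected in the forward
// pass receive non-zero gradient in the scatter loop below.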
for (auto d_in : d_ins) {
d_in->mutable_data<T>(ctx.GetPlace());
auto dims = d_in->dims();
cudaMemset(d_in->data<T>(), 0, framework::product(dims) * sizeof(T));
}
auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom<T>(*(ins[0]), paddle::platform::CPUPlace());
auto index = index_t_cpu.data<T>();
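// Route each row of the output gradient back to the candidate tensor that
// produced the corresponding output row in the forward pass.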
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
cudaMemcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
cols * sizeof(T), cudaMemcpyDeviceToDevice);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(multiplex, ops::MultiplexGPUKernel<float>);
REGISTER_OP_GPU_KERNEL(multiplex_grad, ops::MultiplexGradGPUKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class MultiplexCPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto index = ins[0]->data<T>();
auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
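// For each row, copy the i-th row of the selected candidate tensor
// ins[index[i] + 1] into the i-th row of the output.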
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
memcpy(out->data<T>() + i * cols, ins[k]->data<T>() + i * cols,
cols * sizeof(T));
}
}
};
template <typename T>
class MultiplexGradCPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto d_ins =
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
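// Zero every input gradient first; only the rows selected in the forward
// pass are filled in by the loop below.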
for (auto d_in : d_ins) {
d_in->mutable_data<T>(ctx.GetPlace());
auto dims = d_in->dims();
memset(d_in->data<T>(), 0, framework::product(dims) * sizeof(T));
}
auto index = ins[0]->data<T>();
auto rows = ins[1]->dims()[0];
auto cols = ins[1]->dims()[1];
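// Scatter each row of the output gradient back to the input that was
// selected for that row in the forward pass.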
for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1;
memcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
cols * sizeof(T));
}
}
};
} // namespace operators
} // namespace paddle
@@ -55,6 +55,7 @@ USE_OP(top_k);
USE_OP(squared_l2_distance);
USE_OP(sum);
USE_OP(reshape);
USE_OP(multiplex);
namespace paddle {
namespace framework {
@@ -36,3 +36,4 @@ py_test(mnist SRCS mnist.py)
py_test(test_concat_op SRCS test_concat_op.py)
py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)
py_test(test_reshape_op SRCS test_reshape_op.py)
py_test(test_multiplex_op SRCS test_multiplex_op.py)
import unittest
import numpy as np
from op_test import OpTest
class TestMultiplexOp(OpTest):
    def setUp(self):
        self.op_type = "multiplex"
        rows = 3
        index = np.array([3, 1, 0])
        ins1 = np.random.random((rows, 10)).astype("float32")
        ins2 = np.random.random((rows, 10)).astype("float32")
        ins3 = np.random.random((rows, 10)).astype("float32")
        ins4 = np.random.random((rows, 10)).astype("float32")
        self.inputs = {
            'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3),
                  ('x4', ins4)]
        }
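        # Row i of the output comes from candidate x(index[i] + 1); the +1
        # accounts for X[0] being the index tensor itself.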
        # multiplex output
        output = np.zeros_like(ins1)
        for i in range(0, rows):
            k = index[i] + 1
            output[i] = self.inputs['X'][k][1][i]
        self.outputs = {'Out': output}
    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["x1"], "Out")


if __name__ == '__main__':
    unittest.main()