Commit 367a54e0 authored by Yibing Liu, committed by GitHub

Merge pull request #4360 from kuke/multiplex_modify_dev

Modify multiplex_op
@@ -18,7 +18,6 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
 
 class MultiplexOp : public framework::OperatorWithKernel {
  public:
@@ -26,24 +25,31 @@ class MultiplexOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Ids"),
+                            "Input(Ids) shouldn't be null.");
     PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
-                   "Input(X) should not be null");
+                   "MultiInput(X) shouldn't be empty.");
     PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                             "Output(Out) shouldn't be null.");
+    auto ids_dim = ctx.Input<Tensor>("Ids")->dims();
+    PADDLE_ENFORCE(
+        ids_dim.size() == 2 && ids_dim[1] == 1,
+        "The index tensor must be a vector with size batchSize x 1.");
     auto ins = ctx.MultiInput<Tensor>("X");
-    auto *out = ctx.Output<LoDTensor>("Out");
+    auto *out = ctx.Output<Tensor>("Out");
     auto num_ins = ins.size();
-    PADDLE_ENFORCE(num_ins > 2,
-                   "multiplex operator should have more than 2 inputs.");
-    PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1,
-                      "The first input must be a index vector.");
-    auto in_dim = ins[1]->dims();
+    PADDLE_ENFORCE(num_ins > 1,
+                   "multiplex operator should have more than "
+                   "one candidate input tensors.");
+    auto in_dim = ins[0]->dims();
+    PADDLE_ENFORCE(in_dim.size() >= 2,
+                   "The rank of candidate tensors must be not less than 2.");
 
-    for (size_t i = 2; i < num_ins; i++) {
+    for (size_t i = 1; i < num_ins; i++) {
       auto dim = ins[i]->dims();
-      PADDLE_ENFORCE(
-          in_dim == dim,
-          "All the input tensors except the first one must have the same size");
+      PADDLE_ENFORCE(in_dim == dim,
+                     "All the candidate tensors must have the same size.");
     }
     out->Resize(in_dim);
   }
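In effect, the revised InferShape enforces: Ids is a (batchSize x 1) tensor, X holds at least two candidates that all share one shape of rank >= 2, and Out takes that shared shape. A minimal NumPy sketch of the same checks (hypothetical helper for illustration, not part of this commit):

    import numpy as np

    def infer_multiplex_shape(ids_shape, x_shapes):
        # Ids must be a column vector: (batchSize, 1).
        assert len(ids_shape) == 2 and ids_shape[1] == 1
        # At least two candidates, all of rank >= 2 and identical shape.
        assert len(x_shapes) > 1 and len(x_shapes[0]) >= 2
        assert all(s == x_shapes[0] for s in x_shapes)
        # Out takes the candidates' shape.
        return x_shapes[0]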
@@ -54,25 +60,25 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
   MultiplexOpMaker(framework::OpProto *proto,
                    framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensors of multiplex operator.").AsDuplicable();
+    AddInput("Ids", "The index tensor of multiplex operator.");
+    AddInput("X", "The candidate tensors of multiplex operator.")
+        .AsDuplicable();
     AddOutput("Out", "The output tensor of multiplex operator.");
     AddComment(R"DOC(Multiplex operator
 
-Multiplex multiple tensors according to the index provided by the first
-input tensor.
+Multiplex multiple tensors according to the index provided by the index tensor.
 
-ins[0]: the index tensor.
-ins[1:N]: the candidate output tensors.
+Ids: the index tensor.
+X[0 : N - 1]: the candidate tensors for output (N >= 2).
 For each index i from 0 to batchSize - 1, the output is the i-th row of the
-the (index[i] + 1)-th tensor.
+the (Ids[i])-th tensor.
 
 For i-th row of the output tensor:
 
-y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
+y[i] = x_{k}[i]
 
 where y is the output tensor. `x_{k}` is the k-th input tensor
-and `k = x{0}[i] + 1`.
+and `k = Ids[i]`.
 )DOC");
   }
 };
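For concreteness, a small NumPy sketch of the documented rule y[i] = x_{k}[i] with k = Ids[i] (a reference implementation only; the helper name is made up):

    import numpy as np

    def multiplex_ref(ids, xs):
        # ids: (batchSize, 1) int32; xs: candidate arrays of identical shape.
        out = np.zeros_like(xs[0])
        for i in range(ids.shape[0]):
            k = int(ids[i, 0])    # k = Ids[i]
            out[i] = xs[k][i]     # y[i] = x_{k}[i]
        return out

    ids = np.array([[2], [0], [1]], dtype='int32')
    xs = [np.full((3, 4), v, dtype='float32') for v in (0., 1., 2.)]
    # Row i of the output comes from candidate Ids[i].
    print(multiplex_ref(ids, xs)[:, 0])  # -> [2. 0. 1.]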
@@ -84,15 +90,15 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
-                   "Input(X) should not be null");
+                   "Input(X) should not be null.");
     PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
-                   "Output(X@Grad) should not be null");
+                   "Output(X@Grad) should not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
-                            "Input(Out@GRAD) shouldn't be null.");
-    auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
+                            "Input(Out@GRAD) should not be null.");
+    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
     auto ins = ctx.MultiInput<Tensor>("X");
-    // don't compute gradient for index (ins[0])
-    for (size_t i = 1; i < ins.size(); i++) {
+    // No need to compute gradient for Input(Ids)
+    for (size_t i = 0; i < ins.size(); i++) {
       if (d_ins[i]) {
         d_ins[i]->Resize(ins[i]->dims());
       }
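The backward rule mirrors the forward pass: each row of Out@GRAD is scattered back to the candidate that produced it, every other row stays zero, and Ids receives no gradient at all. A hedged NumPy sketch of that rule, assuming the same (batchSize, 1) index layout:

    import numpy as np

    def multiplex_grad_ref(ids, num_candidates, d_out):
        # One zero-initialized gradient per candidate tensor.
        d_xs = [np.zeros_like(d_out) for _ in range(num_candidates)]
        for i in range(ids.shape[0]):
            k = int(ids[i, 0])
            d_xs[k][i] = d_out[i]  # only the selected row receives gradient
        return d_xs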
...
@@ -18,27 +18,30 @@
 namespace paddle {
 namespace operators {
+using Tensor = framework::Tensor;
 
 template <typename Place, typename T>
 class MultiplexGPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto* ids = ctx.Input<Tensor>("Ids");
+    auto* out = ctx.Output<Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());
 
-    auto rows = ins[1]->dims()[0];
-    auto cols = ins[1]->dims()[1];
+    auto rows = ins[0]->dims()[0];
+    auto cols = ins[0]->numel() / rows;
     // copy index to cpu
-    framework::Tensor index_t_cpu;
-    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-    auto* index = index_t_cpu.data<T>();
+    Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
+    auto* index = index_t_cpu.data<int32_t>();
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                       ctx.device_context())
                       .stream();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      int32_t k = index[i];
+      PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
       PADDLE_ENFORCE_LT(k, ins.size(),
                         "index exceeds the number of candidate tensors.");
       memory::Copy(place, out->data<T>() + i * cols, place,
...
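Note the switch from cols = ins[1]->dims()[1] to cols = ins[0]->numel() / rows: a candidate of rank greater than 2 is treated as rows contiguous chunks of numel / rows elements each, so the per-row memory::Copy still covers the whole tensor. Roughly, in NumPy terms:

    import numpy as np

    x = np.random.random((4, 3, 5)).astype('float32')  # rank-3 candidate
    rows = x.shape[0]
    cols = x.size // rows         # 15, not dims()[1] == 3
    flat = x.reshape(rows, cols)  # each row is one contiguous copy unit
    assert np.allclose(flat[2], x[2].ravel())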
@@ -51,11 +54,11 @@ template <typename Place, typename T>
 class MultiplexGradGPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto d_ins =
-        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
-    for (size_t i = 1; i < d_ins.size(); i++) {
+    auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto* ids = ctx.Input<Tensor>("Ids");
+    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
+    for (size_t i = 0; i < d_ins.size(); i++) {
       if (d_ins[i]) {
         d_ins[i]->mutable_data<T>(ctx.GetPlace());
         auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
@@ -63,19 +66,19 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
       }
     }
 
-    auto rows = ins[1]->dims()[0];
-    auto cols = ins[1]->dims()[1];
+    auto rows = ins[0]->dims()[0];
+    auto cols = ins[0]->numel() / rows;
     // copy index to cpu
-    framework::Tensor index_t_cpu;
-    index_t_cpu.CopyFrom<T>(*(ins[0]), platform::CPUPlace());
-    auto* index = index_t_cpu.data<T>();
+    Tensor index_t_cpu;
+    index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
+    auto* index = index_t_cpu.data<int32_t>();
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                       ctx.device_context())
                       .stream();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = static_cast<size_t>(index[i]);
       if (d_ins[k]) {
         memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                      d_out->data<T>() + i * cols, cols * sizeof(T), stream);
...
@@ -27,16 +27,18 @@ class MultiplexCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
+    auto ids = ctx.Input<framework::Tensor>("Ids");
+    auto* out = ctx.Output<framework::Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());
 
-    auto rows = ins[1]->dims()[0];
-    auto cols = ins[1]->dims()[1];
-    auto* index = ins[0]->data<T>();
+    auto rows = ins[0]->dims()[0];
+    auto cols = ins[0]->numel() / rows;
+    auto index = ids->data<int32_t>();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      int32_t k = index[i];
+      PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
       PADDLE_ENFORCE_LT(static_cast<size_t>(k), ins.size(),
                         "index exceeds the number of candidate tensors.");
       memory::Copy(place, out->data<T>() + i * cols, place,
...
@@ -50,10 +52,11 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* ids = ctx.Input<framework::Tensor>("Ids");
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto d_ins =
         ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
-    for (size_t i = 1; i < d_ins.size(); i++) {
+    for (size_t i = 0; i < d_ins.size(); i++) {
       if (d_ins[i]) {
         d_ins[i]->mutable_data<T>(ctx.GetPlace());
         auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
@@ -61,12 +64,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
       }
     }
 
-    auto rows = ins[1]->dims()[0];
-    auto cols = ins[1]->dims()[1];
-    auto* index = ins[0]->data<T>();
+    auto rows = ins[0]->dims()[0];
+    auto cols = ins[0]->numel() / rows;
+    auto* index = ids->data<int32_t>();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
-      int k = (int)index[i] + 1;
+      size_t k = static_cast<size_t>(index[i]);
       if (d_ins[k]) {
         memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
                      d_out->data<T>() + i * cols, cols * sizeof(T));
...
@@ -6,20 +6,22 @@ from op_test import OpTest
 class TestMultiplexOp(OpTest):
     def setUp(self):
         self.op_type = "multiplex"
-        rows = 3
-        index = np.array([3, 1, 0])
+        rows = 4
+        index = np.arange(0, rows).astype('int32')
+        np.random.shuffle(index)
+        index = np.reshape(index, (rows, 1))
         ins1 = np.random.random((rows, 10)).astype("float32")
         ins2 = np.random.random((rows, 10)).astype("float32")
         ins3 = np.random.random((rows, 10)).astype("float32")
         ins4 = np.random.random((rows, 10)).astype("float32")
         self.inputs = {
-            'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3),
-                  ('x4', ins4)]
+            'Ids': index,
+            'X': [('x1', ins1), ('x2', ins2), ('x3', ins3), ('x4', ins4)]
         }
         # multiplex output
         output = np.zeros_like(ins1)
         for i in range(0, rows):
-            k = index[i] + 1
+            k = index[i][0]
             output[i] = self.inputs['X'][k][1][i]
         self.outputs = {'Out': output}
...