diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6e77b86b5698a263b850a973cd1b8644a0aa2201 --- /dev/null +++ b/paddle/operators/multiplex_op.cc @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/multiplex_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +class MultiplexOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) shouldn't be null."); + auto ins = ctx.MultiInput("X"); + auto *out = ctx.Output("Out"); + auto num_ins = ins.size(); + PADDLE_ENFORCE(num_ins > 2, + "multiplex operator should have more than 2 inputs."); + PADDLE_ENFORCE_EQ(ins[0]->dims().size(), 1, + "The first input must be a index vector."); + auto in_dim = ins[1]->dims(); + + for (size_t i = 2; i < num_ins; i++) { + auto dim = ins[i]->dims(); + PADDLE_ENFORCE( + in_dim == dim, + "All the input tensors except the first one must have the same size"); + } + out->Resize(in_dim); + } +}; + +class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MultiplexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensors of multiplex operator.").AsDuplicable(); + AddOutput("Out", "The output tensor of multiplex operator."); + AddComment(R"DOC(Multiplex operator + +Multiplex multiple tensors according to the index provided by the first +input tensor. + +ins[0]: the index tensor. +ins[1:N]: the candidate output tensors. +For each index i from 0 to batchSize - 1, the output is the i-th row of the +the (index[i] + 1)-th tensor. + +For i-th row of the output tensor: + +y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1) + +where y is the output tensor. `x_{k}` is the k-th input tensor +and `k = x{0}[i] + 1`. + +)DOC"); + } +}; + +class MultiplexGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + "Input(X) should not be null"); + PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(), + "Output(X@Grad) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); + auto ins = ctx.MultiInput("X"); + // don't compute gradient for index (ins[0]) + for (size_t i = 1; i < ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->Resize(ins[i]->dims()); + } + } + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad, + ops::MultiplexGradOp); +REGISTER_OP_CPU_KERNEL( + multiplex, ops::MultiplexCPUKernel); +REGISTER_OP_CPU_KERNEL( + multiplex_grad, + ops::MultiplexGradCPUKernel); diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..4736f15bd594178168e3bcf799142d0fc18bff13 --- /dev/null +++ b/paddle/operators/multiplex_op.cu @@ -0,0 +1,95 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/multiplex_op.h" + +namespace paddle { +namespace operators { + +template +class MultiplexGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto ins = ctx.MultiInput("X"); + auto* out = ctx.Output("Out"); + + out->mutable_data(ctx.GetPlace()); + + auto rows = ins[1]->dims()[0]; + auto cols = ins[1]->dims()[1]; + // copy index to cpu + framework::Tensor index_t_cpu; + index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); + auto* index = index_t_cpu.data(); + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + Place place = boost::get(ctx.GetPlace()); + for (auto i = 0; i < rows; i++) { + int k = (int)index[i] + 1; + PADDLE_ENFORCE_LT(k, ins.size(), + "index exceeds the number of candidate tensors."); + memory::Copy(place, out->data() + i * cols, place, + ins[k]->data() + i * cols, cols * sizeof(T), stream); + } + } +}; + +template +class MultiplexGradGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto d_ins = + ctx.MultiOutput(framework::GradVarName("X")); + for (size_t i = 1; i < d_ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*d_ins[i]); + t.device(ctx.GetEigenDevice()) = t.constant(static_cast(0)); + } + } + + auto rows = ins[1]->dims()[0]; + auto cols = ins[1]->dims()[1]; + // copy index to cpu + framework::Tensor index_t_cpu; + index_t_cpu.CopyFrom(*(ins[0]), platform::CPUPlace()); + auto* index = index_t_cpu.data(); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + Place place = boost::get(ctx.GetPlace()); + for (auto i = 0; i < rows; i++) { + int k = (int)index[i] + 1; + if (d_ins[k]) { + memory::Copy(place, d_ins[k]->data() + i * cols, place, + d_out->data() + i * cols, cols * sizeof(T), stream); + } + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + multiplex, ops::MultiplexGPUKernel); +REGISTER_OP_GPU_KERNEL( + multiplex_grad, + ops::MultiplexGradGPUKernel); diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h new file mode 100644 index 0000000000000000000000000000000000000000..44e8e0c1998014081b7e0aac603d573aba1f4a13 --- /dev/null +++ b/paddle/operators/multiplex_op.h @@ -0,0 +1,78 @@ + +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" + +namespace paddle { +namespace operators { + +template +class MultiplexCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto ins = ctx.MultiInput("X"); + auto* out = ctx.Output("Out"); + + out->mutable_data(ctx.GetPlace()); + + auto rows = ins[1]->dims()[0]; + auto cols = ins[1]->dims()[1]; + auto* index = ins[0]->data(); + Place place = boost::get(ctx.GetPlace()); + for (auto i = 0; i < rows; i++) { + int k = (int)index[i] + 1; + PADDLE_ENFORCE_LT(k, ins.size(), + "index exceeds the number of candidate tensors."); + memory::Copy(place, out->data() + i * cols, place, + ins[k]->data() + i * cols, cols * sizeof(T)); + } + } +}; + +template +class MultiplexGradCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto d_ins = + ctx.MultiOutput(framework::GradVarName("X")); + for (size_t i = 1; i < d_ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->mutable_data(ctx.GetPlace()); + auto t = framework::EigenVector::Flatten(*d_ins[i]); + t.device(ctx.GetEigenDevice()) = t.constant(static_cast(0)); + } + } + + auto rows = ins[1]->dims()[0]; + auto cols = ins[1]->dims()[1]; + auto* index = ins[0]->data(); + Place place = boost::get(ctx.GetPlace()); + for (auto i = 0; i < rows; i++) { + int k = (int)index[i] + 1; + if (d_ins[k]) { + memory::Copy(place, d_ins[k]->data() + i * cols, place, + d_out->data() + i * cols, cols * sizeof(T)); + } + } + } +}; +} +} diff --git a/python/paddle/v2/framework/tests/test_multiplex_op.py b/python/paddle/v2/framework/tests/test_multiplex_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b3881cde24c7fb96c3d7f9411352bc62d55077 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_multiplex_op.py @@ -0,0 +1,43 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestMultiplexOp(OpTest): + def setUp(self): + self.op_type = "multiplex" + rows = 3 + index = np.array([3, 1, 0]) + ins1 = np.random.random((rows, 10)).astype("float32") + ins2 = np.random.random((rows, 10)).astype("float32") + ins3 = np.random.random((rows, 10)).astype("float32") + ins4 = np.random.random((rows, 10)).astype("float32") + self.inputs = { + 'X': [('index', index), ('x1', ins1), ('x2', ins2), ('x3', ins3), + ('x4', ins4)] + } + # multiplex output + output = np.zeros_like(ins1) + for i in range(0, rows): + k = index[i] + 1 + output[i] = self.inputs['X'][k][1][i] + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out') + + def test_check_grad_ignore_x1(self): + self.check_grad(['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1')) + + def test_check_grad_ignore_x1_x2(self): + self.check_grad(['x3', 'x4'], 'Out', no_grad_set=set(['x1', 'x2'])) + + def test_check_grad_ignore_x3(self): + self.check_grad(['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3')) + + +if __name__ == '__main__': + unittest.main()