diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..babf2f561c31d5436fe1611c576e6e7fc04401db --- /dev/null +++ b/paddle/operators/transpose_op.cc @@ -0,0 +1,119 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/transpose_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class TransposeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + std::vector axis = ctx.Attr>("axis"); + size_t x_rank = x_dims.size(); + size_t axis_size = axis.size(); + + PADDLE_ENFORCE_EQ(x_rank, axis_size, + "the input tensor's rank(%d) " + "should be equal to the axis's size(%d)", + x_rank, axis_size); + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + PADDLE_ENFORCE( + axis[i] < static_cast(axis_size) && ++count[axis[i]] == 1, + "Each element of Attribute axis should be a unique value " + "range from 0 to (dims - 1), " + "where the dims is the axis's size"); + } + + framework::DDim out_dims(x_dims); + for (size_t i = 0; i < axis_size; i++) { + out_dims[i] = x_dims[axis[i]]; + } + ctx.Output("Out")->Resize(out_dims); + } +}; + +class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + TransposeOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor)The input tensor, tensors with rank at most 6 are supported"); + AddOutput("Out", "(Tensor)The output tensor"); + AddAttr>( + "axis", + "(vector)a list of values, and the size of the list should be " + "the same with the input tensor rank, the tensor will " + "permute the axes according the the values given"); + AddComment(R"DOC( +The Tensor will be permuted according to the axis values given. +The op is very much like the numpy.transpose function in python +For example: + >> input = numpy.arange(6).reshape((2,3)) + >> input + array([[0, 1, 2], + [3, 4, 5]]) + >> axis = [1, 0] + >> output = input.transpose(axis) + >> output + array([[0, 3], + [1, 4], + [2, 5]]) +So, given a input tensor of shape(N, C, H, W) and the axis is {0, 2, 3, 1}, +the output tensor shape will be (N, H, W, C) +)DOC"); + } +}; + +class TransposeOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto *x_grad = + ctx.Output(framework::GradVarName("X")); + + if (x_grad) x_grad->Resize(x_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(transpose, ops::TransposeOp, ops::TransposeOpMaker, transpose_grad, + ops::TransposeOpGrad); +REGISTER_OP_CPU_KERNEL(transpose, + ops::TransposeKernel); +REGISTER_OP_CPU_KERNEL( + transpose_grad, + ops::TransposeGradKernel); diff --git a/paddle/operators/transpose_op.cu b/paddle/operators/transpose_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..af3f581462c919bbd2dd1067e536cc638f9c267d --- /dev/null +++ b/paddle/operators/transpose_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/transpose_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(transpose, + ops::TransposeKernel); +REGISTER_OP_GPU_KERNEL( + transpose_grad, + ops::TransposeGradKernel); diff --git a/paddle/operators/transpose_op.h b/paddle/operators/transpose_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ea299dce72ad340b0a65ee50582dc156b5ad7abb --- /dev/null +++ b/paddle/operators/transpose_op.h @@ -0,0 +1,128 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +void EigenTranspose(const framework::ExecutionContext& context, + const framework::Tensor& in, framework::Tensor& out, + std::vector axis) { + Eigen::array permute; + for (int i = 0; i < Rank; i++) { + permute[i] = axis[i]; + } + auto in_dim = in.dims(); + auto out_dim = out.dims(); + + auto eigen_in = framework::EigenTensor::From(in); + auto eigen_out = framework::EigenTensor::From(out); + auto& dev = context.GetEigenDevice(); + eigen_out.device(dev) = eigen_in.shuffle(permute); +} + +template +class TransposeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + std::vector axis = context.Attr>("axis"); + int ndims = axis.size(); + switch (ndims) { + case 1: + EigenTranspose(context, *x, *out, axis); + break; + case 2: + EigenTranspose(context, *x, *out, axis); + break; + case 3: + EigenTranspose(context, *x, *out, axis); + break; + case 4: + EigenTranspose(context, *x, *out, axis); + break; + case 5: + EigenTranspose(context, *x, *out, axis); + break; + case 6: + EigenTranspose(context, *x, *out, axis); + break; + default: + PADDLE_THROW("Tensors with rank at most 6 are supported"); + } + } +}; + +template +class TransposeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* out_grad = + context.Input(framework::GradVarName("Out")); + auto* x_grad = + context.Output(framework::GradVarName("X")); + if (x_grad) { + x_grad->mutable_data(context.GetPlace()); + + std::vector axis = context.Attr>("axis"); + std::vector reversed_axis(axis); + + for (size_t i = 0; i < axis.size(); i++) { + reversed_axis[axis[i]] = i; + } + + int ndims = axis.size(); + + switch (ndims) { + case 1: + EigenTranspose(context, *out_grad, *x_grad, + reversed_axis); + break; + case 2: + EigenTranspose(context, *out_grad, *x_grad, + reversed_axis); + break; + case 3: + EigenTranspose(context, *out_grad, *x_grad, + reversed_axis); + break; + case 4: + EigenTranspose(context, *out_grad, *x_grad, + reversed_axis); + break; + case 5: + EigenTranspose(context, *out_grad, *x_grad, + reversed_axis); + break; + case 6: + EigenTranspose(context, *out_grad, *x_grad, + reversed_axis); + break; + default: + PADDLE_THROW("Tensors with rank at most 6 are supported"); + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_transpose_op.py b/python/paddle/v2/framework/tests/test_transpose_op.py new file mode 100644 index 0000000000000000000000000000000000000000..9409cbaa00f792b60d5950556b869108aa732478 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_transpose_op.py @@ -0,0 +1,56 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestTransposeOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = "transpose" + self.inputs = {'X': np.random.random(self.shape).astype("float32")} + self.attrs = {'axis': list(self.axis)} + self.outputs = {'Out': self.inputs['X'].transpose(self.axis)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def initTestCase(self): + self.shape = (3, 4) + self.axis = (1, 0) + + +class TestCase0(TestTransposeOp): + def initTestCase(self): + self.shape = (3, ) + self.axis = (0, ) + + +class TestCase1(TestTransposeOp): + def initTestCase(self): + self.shape = (3, 4, 5) + self.axis = (0, 2, 1) + + +class TestCase2(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5) + self.axis = (0, 2, 3, 1) + + +class TestCase3(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6) + self.axis = (4, 2, 3, 1, 0) + + +class TestCase4(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 3, 4, 5, 6, 1) + self.axis = (4, 2, 3, 1, 0, 5) + + +if __name__ == '__main__': + unittest.main()