提交 c003895c 编写于 作者: X X.Dragon 提交者: GitHub

Merge pull request #3920 from NHZlX/op_transpose

add the transpose op
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/transpose_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class TransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
size_t x_rank = x_dims.size();
size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(x_rank, axis_size,
"the input tensor's rank(%d) "
"should be equal to the axis's size(%d)",
x_rank, axis_size);
std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
PADDLE_ENFORCE(
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1,
"Each element of Attribute axis should be a unique value "
"range from 0 to (dims - 1), "
"where the dims is the axis's size");
}
framework::DDim out_dims(x_dims);
for (size_t i = 0; i < axis_size; i++) {
out_dims[i] = x_dims[axis[i]];
}
ctx.Output<framework::LoDTensor>("Out")->Resize(out_dims);
}
};
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TransposeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(Tensor)The input tensor, tensors with rank at most 6 are supported");
AddOutput("Out", "(Tensor)The output tensor");
AddAttr<std::vector<int>>(
"axis",
"(vector<int>)a list of values, and the size of the list should be "
"the same with the input tensor rank, the tensor will "
"permute the axes according the the values given");
AddComment(R"DOC(
The Tensor will be permuted according to the axis values given.
The op is very much like the numpy.transpose function in python
For example:
>> input = numpy.arange(6).reshape((2,3))
>> input
array([[0, 1, 2],
[3, 4, 5]])
>> axis = [1, 0]
>> output = input.transpose(axis)
>> output
array([[0, 3],
[1, 4],
[2, 5]])
So, given a input tensor of shape(N, C, H, W) and the axis is {0, 2, 3, 1},
the output tensor shape will be (N, H, W, C)
)DOC");
}
};
class TransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (x_grad) x_grad->Resize(x_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(transpose, ops::TransposeOp, ops::TransposeOpMaker, transpose_grad,
ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL(transpose,
ops::TransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/transpose_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(transpose,
ops::TransposeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T, int Rank>
void EigenTranspose(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor& out,
std::vector<int> axis) {
Eigen::array<int, Rank> permute;
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
}
auto in_dim = in.dims();
auto out_dim = out.dims();
auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
auto eigen_out = framework::EigenTensor<T, Rank>::From(out);
auto& dev = context.GetEigenDevice<Place>();
eigen_out.device(dev) = eigen_in.shuffle(permute);
}
template <typename Place, typename T>
class TransposeKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
int ndims = axis.size();
switch (ndims) {
case 1:
EigenTranspose<Place, T, 1>(context, *x, *out, axis);
break;
case 2:
EigenTranspose<Place, T, 2>(context, *x, *out, axis);
break;
case 3:
EigenTranspose<Place, T, 3>(context, *x, *out, axis);
break;
case 4:
EigenTranspose<Place, T, 4>(context, *x, *out, axis);
break;
case 5:
EigenTranspose<Place, T, 5>(context, *x, *out, axis);
break;
case 6:
EigenTranspose<Place, T, 6>(context, *x, *out, axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
}
}
};
template <typename Place, typename T>
class TransposeGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
if (x_grad) {
x_grad->mutable_data<T>(context.GetPlace());
std::vector<int> axis = context.Attr<std::vector<int>>("axis");
std::vector<int> reversed_axis(axis);
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
int ndims = axis.size();
switch (ndims) {
case 1:
EigenTranspose<Place, T, 1>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 2:
EigenTranspose<Place, T, 2>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 3:
EigenTranspose<Place, T, 3>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 4:
EigenTranspose<Place, T, 4>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 5:
EigenTranspose<Place, T, 5>(context, *out_grad, *x_grad,
reversed_axis);
break;
case 6:
EigenTranspose<Place, T, 6>(context, *out_grad, *x_grad,
reversed_axis);
break;
default:
PADDLE_THROW("Tensors with rank at most 6 are supported");
}
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestTransposeOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "transpose"
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
self.attrs = {'axis': list(self.axis)}
self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def initTestCase(self):
self.shape = (3, 4)
self.axis = (1, 0)
class TestCase0(TestTransposeOp):
def initTestCase(self):
self.shape = (3, )
self.axis = (0, )
class TestCase1(TestTransposeOp):
def initTestCase(self):
self.shape = (3, 4, 5)
self.axis = (0, 2, 1)
class TestCase2(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5)
self.axis = (0, 2, 3, 1)
class TestCase3(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.axis = (4, 2, 3, 1, 0)
class TestCase4(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6, 1)
self.axis = (4, 2, 3, 1, 0, 5)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册