diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 25dbd236e6432d99b27f1a6ffc4e07bf0f994155..41a2ddac769a33d2fb5be753f8b6b574ec0a0627 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -51,7 +51,8 @@ list(REMOVE_ITEM GENERAL_OPS
     minus_op
     mul_op
     recurrent_op
-    scale_op)
+    scale_op
+    transpose_op)
 
 op_library(net_op SRCS net_op.cc)
 op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
@@ -59,6 +60,7 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
            DEPS framework_proto tensor operator net_op)
 op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
+op_library(transpose_op SRCS transpose_op.cc transpose_op.cu DEPS paddle_memory device_context)
 
 foreach(src ${GENERAL_OPS})
   op_library(${src} SRCS ${src}.cc ${src}.cu)
diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b03d350151c025203a379df1669dab7446055f34
--- /dev/null
+++ b/paddle/operators/transpose_op.cc
@@ -0,0 +1,106 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/transpose_op.h"
+#include <algorithm>
+#include <vector>
+#include "paddle/framework/ddim.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class TransposeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto in_dim = ctx.Input<Tensor>("X")->dims();
+    auto axis = ctx.GetAttr<std::vector<int>>("axis");
+    size_t in_dim_size = in_dim.size();
+    size_t axis_size = axis.size();
+    PADDLE_ENFORCE_EQ(
+        in_dim_size, axis_size,
+        "the rank of the input tensor should be equal to the axis's size");
+
+    // axis must be a permutation of [0, 1, ..., rank - 1].
+    std::vector<int> axis_sorted(axis);
+    std::sort(axis_sorted.begin(), axis_sorted.end());
+    for (size_t i = 0; i < axis_sorted.size(); i++) {
+      PADDLE_ENFORCE_EQ(axis_sorted[i], (int)i,
+                        "the sorted axis should be [0, 1, ... dims - 1], "
+                        "where dims is the rank of the input tensor");
+    }
+
+    // The i-th output dimension is the axis[i]-th input dimension.
+    framework::DDim out_dim(in_dim);
+    for (size_t i = 0; i < axis.size(); i++) {
+      out_dim[i] = in_dim[axis[i]];
+    }
+    ctx.Output<Tensor>("Out")->Resize(out_dim);
+  }
+};
+
+class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  TransposeOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of transpose op");
+    AddOutput("Out", "The output of transpose op");
+    AddAttr<std::vector<int>>(
+        "axis",
+        "a list of integers whose length should be "
+        "the same as the rank of the input tensor");
+    AddComment(R"DOC(
+Transpose the input tensor.
+For example, given an input tensor of shape (N, C, H, W) and
+axis = {0, 2, 3, 1}, the output tensor's shape will be (N, H, W, C).
+)DOC");
+  }
+};
+
+class TransposeOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    auto out_grad_dims =
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    auto out_dims = ctx.Input<Tensor>("Out")->dims();
+
+    PADDLE_ENFORCE(out_grad_dims == out_dims,
+                   "Out@GRAD dims must be equal to Output(Out) dims");
+
+    x_grad->Resize(x_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(transpose, ops::TransposeOp, ops::TransposeOpMaker, transpose_grad,
+            ops::TransposeOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    transpose, ops::TransposeKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    transpose_grad,
+    ops::TransposeGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/transpose_op.cu b/paddle/operators/transpose_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..96e864e62aa0fc7c0851af4f39c3a3b1ac30ab52
--- /dev/null
+++ b/paddle/operators/transpose_op.cu
@@ -0,0 +1,123 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/memory/memcpy.h"
+#include "paddle/memory/memory.h"
+#include "paddle/operators/transpose_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void transpose_kernel(int nthreads, const T* in_data, T* out_data,
+                                 int* offset_buffer, int ndims) {
+  // offset_buffer packs three int arrays of length ndims:
+  // [input strides | output strides | axis].
+  int* in_offset = offset_buffer;
+  int* out_offset = offset_buffer + ndims;
+  int* axis = offset_buffer + ndims * 2;
+
+  int to_index = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (to_index < nthreads) {
+    // Decompose the output index into coordinates, then re-linearize
+    // them through the permuted input strides.
+    int from_index = 0;
+    int temp = to_index;
+    for (int i = 0; i < ndims; i++) {
+      from_index += (temp / out_offset[i]) * in_offset[axis[i]];
+      temp = temp % out_offset[i];
+    }
+    out_data[to_index] = in_data[from_index];
+  }
+}
+
+template <typename T>
+void TransposeCUDA(const framework::ExecutionContext& context,
+                   const framework::Tensor& in, framework::Tensor& out,
+                   std::vector<int> axis) {
+  auto* in_data = in.template data<T>();
+  auto* out_data = out.template mutable_data<T>(context.GetPlace());
+  auto in_dim = in.dims();
+  auto out_dim = out.dims();
+  auto data_size = product(in_dim);
+  size_t ndims = in_dim.size();
+  std::vector<int> in_offset(ndims, 1);
+  std::vector<int> out_offset(ndims, 1);
+  std::vector<int> buffer_dim_shape(1, ndims * 3);
+
+  auto buffer_dims = framework::make_ddim(buffer_dim_shape);
+  framework::Tensor host_buffer;
+  platform::CPUPlace cpu_place;
+  platform::GPUPlace gpu_place;
+
+  int* host_buffer_data =
+      host_buffer.mutable_data<int>(buffer_dims, cpu_place);
+
+  auto offset_buffer =
+      memory::Alloc(context.GetPlace(), ndims * 3 * sizeof(int));
+
+  // Row-major strides: offset[i] is the distance between two elements
+  // whose coordinates differ by one in dimension i.
+  for (int i = ndims - 2; i >= 0; i--) {
+    in_offset[i] = in_offset[i + 1] * in_dim[i + 1];
+    out_offset[i] = out_offset[i + 1] * out_dim[i + 1];
+  }
+
+  for (size_t i = 0; i < ndims; i++) {
+    host_buffer_data[i] = in_offset[i];
+    host_buffer_data[i + ndims] = out_offset[i];
+    host_buffer_data[i + ndims * 2] = axis[i];
+  }
+
+  memory::Copy(gpu_place, offset_buffer, cpu_place, host_buffer_data,
+               ndims * 3 * sizeof(int));
+  int block = 512;
+  int grid = (data_size + block - 1) / block;
+  transpose_kernel<T><<<grid, block>>>(
+      data_size, in_data, out_data, static_cast<int*>(offset_buffer), ndims);
+  memory::Free(gpu_place, offset_buffer);
+}
+
+template <typename T>
+class TransposeCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "It must use GPUPlace.");
+    auto* in = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto axis = context.GetAttr<std::vector<int>>("axis");
+    TransposeCUDA<T>(context, *in, *out, axis);
+  }
+};
+
+template <typename T>
+class TransposeGradCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "It must use GPUPlace.");
+    auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto axis_temp = context.GetAttr<std::vector<int>>("axis");
+
+    // The gradient of a transpose is a transpose by the inverse permutation.
+    std::vector<int> axis(axis_temp);
+    for (size_t i = 0; i < axis.size(); i++) {
+      axis[axis_temp[i]] = i;
+    }
+    TransposeCUDA<T>(context, *in, *out, axis);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(transpose, ops::TransposeCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(transpose_grad, ops::TransposeGradCUDAKernel<float>);
diff --git a/paddle/operators/transpose_op.h b/paddle/operators/transpose_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..1f24784eba4e629588ea77dbd8fb8a9b84c568eb
--- /dev/null
+++ b/paddle/operators/transpose_op.h
@@ -0,0 +1,141 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include <vector>
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+void NaiveCpuTranspose(const framework::ExecutionContext& context,
+                       const framework::Tensor& in, framework::Tensor& out,
+                       std::vector<int> axis) {
+  auto in_data = in.data<T>();
+  auto out_data = out.mutable_data<T>(context.GetPlace());
+  auto in_dim = in.dims();
+  auto out_dim = out.dims();
+  size_t ndims = in_dim.size();
+
+  // Row-major strides of the input and output tensors.
+  std::vector<int> in_offset(ndims, 1);
+  std::vector<int> out_offset(ndims, 1);
+
+  for (int i = ndims - 2; i >= 0; i--) {
+    in_offset[i] = in_offset[i + 1] * in_dim[i + 1];
+    out_offset[i] = out_offset[i + 1] * out_dim[i + 1];
+  }
+
+  size_t data_size = product(in_dim);
+
+  for (size_t to_index = 0; to_index < data_size; to_index++) {
+    // Decompose the output index into coordinates, then re-linearize
+    // them through the permuted input strides.
+    int from_index = 0;
+    int temp = to_index;
+    for (size_t i = 0; i < ndims; i++) {
+      from_index += (temp / out_offset[i]) * in_offset[axis[i]];
+      temp = temp % out_offset[i];
+    }
+    out_data[to_index] = in_data[from_index];
+  }
+}
+
+template <typename Place, typename T, int Dims>
+void DoTranspose(const framework::ExecutionContext& context,
+                 const framework::Tensor& in, framework::Tensor& out,
+                 std::vector<int> axis) {
+  Eigen::array<int, Dims> permute;
+  for (int i = 0; i < Dims; i++) {
+    permute[i] = axis[i];
+  }
+  auto in_dim = in.dims();
+  auto out_dim = out.dims();
+
+  auto eigen_in = framework::EigenTensor<T, Dims>::From(in);
+  auto eigen_out = framework::EigenTensor<T, Dims>::From(out);
+  auto& dev = context.GetEigenDevice<Place>();
+  eigen_out.device(dev) = eigen_in.shuffle(permute);
+}
+
+template <typename Place, typename T>
+class TransposeKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    auto axis = context.GetAttr<std::vector<int>>("axis");
+    int ndims = axis.size();
+    // Eigen's shuffle needs the rank at compile time, so dispatch on it;
+    // fall back to the naive element-wise copy for other ranks.
+    switch (ndims) {
+      case 2:
+        DoTranspose<Place, T, 2>(context, *in, *out, axis);
+        break;
+      case 3:
+        DoTranspose<Place, T, 3>(context, *in, *out, axis);
+        break;
+      case 4:
+        DoTranspose<Place, T, 4>(context, *in, *out, axis);
+        break;
+      case 5:
+        DoTranspose<Place, T, 5>(context, *in, *out, axis);
+        break;
+      default:
+        NaiveCpuTranspose<T>(context, *in, *out, axis);
+        break;
+    }
+  }
+};
+
+template <typename Place, typename T>
+class TransposeGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    out->mutable_data<T>(context.GetPlace());
+
+    auto axis_temp = context.GetAttr<std::vector<int>>("axis");
+
+    // The gradient of a transpose is a transpose by the inverse permutation.
+    std::vector<int> axis(axis_temp);
+    for (size_t i = 0; i < axis.size(); i++) {
+      axis[axis_temp[i]] = i;
+    }
+
+    int ndims = axis.size();
+
+    switch (ndims) {
+      case 2:
+        DoTranspose<Place, T, 2>(context, *in, *out, axis);
+        break;
+      case 3:
+        DoTranspose<Place, T, 3>(context, *in, *out, axis);
+        break;
+      case 4:
+        DoTranspose<Place, T, 4>(context, *in, *out, axis);
+        break;
+      case 5:
+        DoTranspose<Place, T, 5>(context, *in, *out, axis);
+        break;
+      default:
+        NaiveCpuTranspose<T>(context, *in, *out, axis);
+        break;
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index ba28b51adec2f1fc37b368531344c6802b84dc07..de120259191f3da23f50b45426f3ba3401f04a09 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -49,6 +49,7 @@ USE_OP(minus);
 USE_OP(cos_sim);
 USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
+USE_OP(transpose);
 
 namespace paddle {
 namespace framework {
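
Reviewer note (not part of the patch): NaiveCpuTranspose, the CUDA transpose_kernel, and both gradient kernels share two ideas — decompose a linear output index into coordinates via row-major strides and re-linearize it through the permuted input strides, and invert `axis` to obtain the gradient's permutation. The standalone C++ sketch below demonstrates both on a 2x3 array; every name in it is illustrative and nothing here ships with this diff.

// Standalone sketch of the patch's index arithmetic (illustrative only).
#include <cassert>
#include <cstdio>
#include <vector>

// Row-major strides: s[i] is the product of all dimensions after i.
std::vector<int> Strides(const std::vector<int>& dims) {
  std::vector<int> s(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; i--)
    s[i] = s[i + 1] * dims[i + 1];
  return s;
}

// Same mapping as the op: out dim i is in dim axis[i].
std::vector<float> Transpose(const std::vector<float>& in,
                             const std::vector<int>& in_dims,
                             const std::vector<int>& axis) {
  std::vector<int> out_dims(in_dims.size());
  for (size_t i = 0; i < axis.size(); i++) out_dims[i] = in_dims[axis[i]];
  auto in_off = Strides(in_dims), out_off = Strides(out_dims);
  std::vector<float> out(in.size());
  for (size_t to = 0; to < in.size(); to++) {
    // Decompose `to` into output coordinates, then re-linearize them
    // through the permuted input strides -- exactly the kernels' loop.
    int from = 0, temp = static_cast<int>(to);
    for (size_t i = 0; i < axis.size(); i++) {
      from += (temp / out_off[i]) * in_off[axis[i]];
      temp %= out_off[i];
    }
    out[to] = in[from];
  }
  return out;
}

int main() {
  std::vector<float> x = {0, 1, 2, 3, 4, 5};  // shape (2, 3)
  auto y = Transpose(x, {2, 3}, {1, 0});      // shape (3, 2)
  assert(y[1] == 3);                          // y[0][1] == x[1][0]

  // Gradient trick: transposing by the inverse permutation undoes the
  // forward transpose, which is why the grad kernels invert `axis`.
  std::vector<int> axis = {1, 0}, inv(axis.size());
  for (size_t i = 0; i < axis.size(); i++) inv[axis[i]] = static_cast<int>(i);
  auto x2 = Transpose(y, {3, 2}, inv);
  assert(x2 == x);
  std::printf("ok\n");
  return 0;
}

Compiling with `g++ -std=c++11 sketch.cc && ./a.out` prints `ok`, confirming the round trip.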