diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 41a2ddac769a33d2fb5be753f8b6b574ec0a0627..25dbd236e6432d99b27f1a6ffc4e07bf0f994155 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -51,8 +51,7 @@ list(REMOVE_ITEM GENERAL_OPS
     minus_op
     mul_op
     recurrent_op
-    scale_op
-    transpose_op)
+    scale_op)
 
 op_library(net_op SRCS net_op.cc)
 op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
@@ -60,7 +59,6 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
            DEPS framework_proto tensor operator net_op)
 op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
-op_library(transpose_op SRCS transpose_op.cc transpose_op.cu DEPS paddle_memory device_context)
 
 foreach(src ${GENERAL_OPS})
   op_library(${src} SRCS ${src}.cc ${src}.cu)
diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc
index b03d350151c025203a379df1669dab7446055f34..9b7812c79d39cacdd1a7a3749615effbe091a530 100644
--- a/paddle/operators/transpose_op.cc
+++ b/paddle/operators/transpose_op.cc
@@ -31,6 +31,7 @@ class TransposeOp : public framework::OperatorWithKernel {
     auto axis = ctx.GetAttr<std::vector<int>>("axis");
     size_t in_dim_size = in_dim.size();
     size_t axis_size = axis.size();
+
     PADDLE_ENFORCE_EQ(
         in_dim_size, axis_size,
         "the input tensor dimensions should be equal to the axis size");
@@ -42,7 +43,7 @@ class TransposeOp : public framework::OperatorWithKernel {
           "the sorted axis should be [0, 1, ... dims - 1], "
           "the dims equals to the input tensor dimensions");
     }
-    //
+
     framework::DDim out_dim(in_dim);
     for (size_t i = 0; i < axis.size(); i++) {
       out_dim[i] = in_dim[axis[i]];
@@ -60,11 +61,12 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out", "The output of transpose op");
     AddAttr<std::vector<int>>(
         "axis",
-        "a list of integers, and the num of integers should be "
-        "the same with the input tensor dimensions");
+        "a list of values whose size must equal the number of input "
+        "tensor dimensions; the tensor's axes will be permuted "
+        "according to the values given");
     AddComment(R"DOC(
-Transpose the input tensor.
-For example, input tensor shape(N, C, H, W) and axis {0, 2, 3, 1},
+The tensor will be permuted according to the axis values given.
+For example, given an input tensor of shape (N, C, H, W) and axis {0, 2, 3, 1},
 the output tensor shape will be (N, H, W, C)
 )DOC");
   }
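For review convenience, the InferShape hunk above reduces to a small rule: `axis` must be a permutation of `[0, ..., ndims - 1]`, and `out_dim[i] = in_dim[axis[i]]`. A minimal Python sketch of that rule (illustrative only, not Paddle API):

```python
def infer_transpose_shape(in_dim, axis):
    # Mirrors TransposeOp::InferShape: validate `axis`, then permute dims.
    assert len(in_dim) == len(axis), \
        "the input tensor dimensions should be equal to the axis size"
    assert sorted(axis) == list(range(len(in_dim))), \
        "the sorted axis should be [0, 1, ... dims - 1]"
    return [in_dim[a] for a in axis]  # out_dim[i] = in_dim[axis[i]]

print(infer_transpose_shape([8, 3, 32, 32], [0, 2, 3, 1]))  # [8, 32, 32, 3]
```

This matches the (N, C, H, W) to (N, H, W, C) example in the DOC string.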
diff --git a/paddle/operators/transpose_op.cu b/paddle/operators/transpose_op.cu
index 96e864e62aa0fc7c0851af4f39c3a3b1ac30ab52..853659e3c3ed6e1111d64a2c1fd2c1dbe5f994d2 100644
--- a/paddle/operators/transpose_op.cu
+++ b/paddle/operators/transpose_op.cu
@@ -12,6 +12,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
+#include <vector>
 #include "paddle/memory/memcpy.h"
 #include "paddle/memory/memory.h"
 #include "paddle/operators/transpose_op.h"
@@ -24,7 +25,7 @@ __global__ void transpose_kernel(int nthreads, const T* in_data, T* out_data,
                                  int* offset_buffer, int ndims) {
   int* in_offset = offset_buffer;
   int* out_offset = offset_buffer + ndims;
-  int* axis = offset_buffer + ndims;
+  int* axis = offset_buffer + ndims * 2;
 
   int to_index = blockIdx.x * blockDim.x + threadIdx.x;
 
@@ -51,31 +52,37 @@ void TransposeCUDA(const framework::ExecutionContext& context,
   size_t ndims = in_dim.size();
   std::vector<int> in_offset(ndims, 1);
   std::vector<int> out_offset(ndims, 1);
-  std::vector<int> buffer_dim_shape(1, ndims * 3);
+  auto cpu_place = platform::CPUPlace();
+  auto gpu_place = boost::get<platform::GPUPlace>(context.GetPlace());
+
+  // Get a host_buffer to cache the input offset, output offset and the axis.
+  std::vector<int> buffer_dim_shape(1, ndims * 3);
   auto buffer_dims = framework::make_ddim(buffer_dim_shape);
   framework::Tensor host_buffer;
-  platform::CPUPlace cpu_place;
-  platform::GPUPlace gpu_place;
   int* host_buffer_data = host_buffer.mutable_data<int>(buffer_dims, cpu_place);
-  auto offset_buffer =
-      memory::Alloc(context.GetPlace(), ndims * 3 * sizeof(int));
-
   for (int i = ndims - 2; i >= 0; i--) {
     in_offset[i] = in_offset[i + 1] * in_dim[i + 1];
     out_offset[i] = out_offset[i + 1] * out_dim[i + 1];
   }
-
+  // copy the data to the host_buffer
   for (int i = 0; i < ndims; i++) {
     host_buffer_data[i] = in_offset[i];
    host_buffer_data[i + ndims] = out_offset[i];
     host_buffer_data[i + ndims * 2] = axis[i];
   }
+
+  // Get a device_buffer to cache the input offset, output offset and the axis.
+  auto offset_buffer = memory::Alloc(gpu_place, ndims * 3 * sizeof(int));
+
+  auto* cuda_device_context = reinterpret_cast<platform::CUDADeviceContext*>(
+      const_cast<platform::DeviceContext*>(context.device_context_));
+
+  // copy the host_buffer data to the device_buffer
   memory::Copy(gpu_place, offset_buffer, cpu_place, host_buffer_data,
-               ndims * 3 * sizeof(int));
+               ndims * 3 * sizeof(int), cuda_device_context->stream());
+
   int block = 512;
   int grid = (data_size + block - 1) / block;
   transpose_kernel<T><<<grid, block>>>(data_size, in_data, out_data,
diff --git a/paddle/operators/transpose_op.h b/paddle/operators/transpose_op.h
index 1f24784eba4e629588ea77dbd8fb8a9b84c568eb..ca64b5a636417ed11720669e0378818ee42d6284 100644
--- a/paddle/operators/transpose_op.h
+++ b/paddle/operators/transpose_op.h
@@ -17,7 +17,6 @@
 #include <vector>
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {
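The kernel body sits outside these hunks, but the buffer layout fixed here (input strides, then output strides, then `axis`; hence the `ndims * 2` correction) implies the usual flat-index arithmetic: each thread decomposes its output index with `out_offset` and re-linearizes it with `in_offset[axis[i]]`. A Python model of that arithmetic, under my reading of the kernel (not code from this PR):

```python
import numpy as np

def make_strides(dim):
    # Mirrors the host loop: offset[i] = offset[i + 1] * dim[i + 1], last = 1.
    strides = [1] * len(dim)
    for i in range(len(dim) - 2, -1, -1):
        strides[i] = strides[i + 1] * dim[i + 1]
    return strides

def transpose_index(to_index, in_dim, out_dim, axis):
    # Map a flat output index to the flat input index it reads from.
    in_offset, out_offset = make_strides(in_dim), make_strides(out_dim)
    from_index, rest = 0, to_index
    for i in range(len(out_dim)):
        coord = rest // out_offset[i]  # i-th output coordinate
        rest %= out_offset[i]
        # The same coordinate lives on input dimension axis[i].
        from_index += coord * in_offset[axis[i]]
    return from_index

# Sanity check against numpy for a 2-D transpose.
x = np.arange(6).reshape(2, 3).astype("float32")
y = x.transpose((1, 0))
assert all(y.ravel()[t] == x.ravel()[transpose_index(t, (2, 3), (3, 2), (1, 0))]
           for t in range(x.size))
```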
diff --git a/python/paddle/v2/framework/tests/test_transpose_op.py b/python/paddle/v2/framework/tests/test_transpose_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..63021da6aaa4b09fe52ac98bd1d9ff46c143aa81
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_transpose_op.py
@@ -0,0 +1,27 @@
+import unittest
+import numpy as np
+from gradient_checker import GradientChecker
+from op_test_util import OpTestMeta
+from paddle.v2.framework.op import Operator
+
+
+class TestTransposeOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "transpose"
+        self.inputs = {'X': np.random.random((3, 4)).astype("float32"), }
+        self.attrs = {'axis': [1, 0]}
+        self.outputs = {'Out': self.inputs['X'].transpose((1, 0))}
+
+
+class TransposeGradOpTest(GradientChecker):
+    def test_transpose(self):
+        op = Operator("transpose", X="X", Out="Out", axis=[1, 0])
+        inputs = {'X': np.random.random((32, 84)).astype("float32"), }
+
+        self.check_grad(op, inputs, set(["X"]), "Out", max_relative_error=0.5)
+
+
+if __name__ == '__main__':
+    unittest.main()
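One note on the gradient test: since transpose only permutes elements, its gradient is exact, namely `dX = dOut.transpose(inverse_axis)`, so the `max_relative_error=0.5` tolerance above is much looser than this op should need. A standalone numpy sketch of that identity plus the finite-difference probe the harness performs in bulk (not part of this PR):

```python
import numpy as np

def transpose_grad(dout, axis):
    # If Out = X.transpose(axis), then dX = dOut.transpose(inv),
    # where inv = argsort(axis) inverts the permutation.
    return dout.transpose(np.argsort(axis))

axis = [2, 0, 1]
x = np.random.random((3, 4, 5)).astype("float64")
w = np.random.random(x.transpose(axis).shape)  # loss L = sum(w * Out)

analytic = transpose_grad(w, axis)

# Finite-difference check at one element, as GradientChecker does for all.
eps, idx = 1e-6, (1, 2, 3)
xp = x.copy()
xp[idx] += eps
numeric = ((w * xp.transpose(axis)).sum() - (w * x.transpose(axis)).sum()) / eps
assert abs(numeric - analytic[idx]) < 1e-4
```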