提交 17b4b980 编写于 作者: X xzl

add the transpose op

上级 c1075126
...@@ -51,7 +51,8 @@ list(REMOVE_ITEM GENERAL_OPS ...@@ -51,7 +51,8 @@ list(REMOVE_ITEM GENERAL_OPS
minus_op minus_op
mul_op mul_op
recurrent_op recurrent_op
scale_op) scale_op
transpose_op)
op_library(net_op SRCS net_op.cc) op_library(net_op SRCS net_op.cc)
op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op) op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
...@@ -59,6 +60,7 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function) ...@@ -59,6 +60,7 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor operator net_op) DEPS framework_proto tensor operator net_op)
op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op) op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
op_library(transpose_op SRCS transpose_op.cc transpose_op.cu DEPS paddle_memory device_context)
foreach(src ${GENERAL_OPS}) foreach(src ${GENERAL_OPS})
op_library(${src} SRCS ${src}.cc ${src}.cu) op_library(${src} SRCS ${src}.cc ${src}.cu)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/transpose_op.h"
#include <vector>
#include "paddle/framework/ddim.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class TransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto in_dim = ctx.Input<Tensor>("X")->dims();
auto axis = ctx.GetAttr<std::vector<int>>("axis");
size_t in_dim_size = in_dim.size();
size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(
in_dim_size, axis_size,
"the input tensor dimensions should be equal to the axis size");
std::vector<int> axis_sorted(axis);
std::sort(axis_sorted.begin(), axis_sorted.end());
for (size_t i = 0; i < axis_sorted.size(); i++) {
PADDLE_ENFORCE_EQ(axis_sorted[i], (int)i,
"the sorted axis should be [0, 1, ... dims - 1], "
"the dims equals to the input tensor dimensions");
}
//
framework::DDim out_dim(in_dim);
for (size_t i = 0; i < axis.size(); i++) {
out_dim[i] = in_dim[axis[i]];
}
ctx.Output<Tensor>("Out")->Resize(out_dim);
}
};
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TransposeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of transpose op");
AddOutput("Out", "The output of transpose op");
AddAttr<std::vector<int>>(
"axis",
"a list of integers, and the num of integers should be "
"the same with the input tensor dimensions");
AddComment(R"DOC(
Transpose the input tensor.
For example, input tensor shape(N, C, H, W) and axis {0, 2, 3, 1},
the output tensor shape will be (N, H, W, C)
)DOC");
}
};
class TransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto out_grad_dims =
ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto out_dims = ctx.Input<Tensor>("Out")->dims();
PADDLE_ENFORCE(out_grad_dims == out_dims,
"Out@GRAD dims must equal to Input(X) dims");
x_grad->Resize(x_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(transpose, ops::TransposeOp, ops::TransposeOpMaker, transpose_grad,
ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL(transpose,
ops::TransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/operators/transpose_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void transpose_kernel(int nthreads, const T* in_data, T* out_data,
int* offset_buffer, int ndims) {
int* in_offset = offset_buffer;
int* out_offset = offset_buffer + ndims;
int* axis = offset_buffer + ndims;
int to_index = blockIdx.x * blockDim.x + threadIdx.x;
if (to_index < nthreads) {
int from_index = 0;
int temp = to_index;
for (size_t i = 0; i < ndims; i++) {
from_index += (temp / out_offset[i]) * in_offset[axis[i]];
temp = temp % out_offset[i];
}
out_data[to_index] = in_data[from_index];
}
}
template <typename T>
void TransposeCUDA(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor& out,
std::vector<int> axis) {
auto* in_data = in.template data<T>();
auto* out_data = out.template mutable_data<T>(context.GetPlace());
auto in_dim = in.dims();
auto out_dim = out.dims();
auto data_size = product(in_dim);
size_t ndims = in_dim.size();
std::vector<int> in_offset(ndims, 1);
std::vector<int> out_offset(ndims, 1);
std::vector<int64_t> buffer_dim_shape(1, ndims * 3);
auto buffer_dims = framework::make_ddim(buffer_dim_shape);
framework::Tensor host_buffer;
platform::CPUPlace cpu_place;
platform::GPUPlace gpu_place;
int* host_buffer_data = host_buffer.mutable_data<int>(buffer_dims, cpu_place);
auto offset_buffer =
memory::Alloc(context.GetPlace(), ndims * 3 * sizeof(int));
for (int i = ndims - 2; i >= 0; i--) {
in_offset[i] = in_offset[i + 1] * in_dim[i + 1];
out_offset[i] = out_offset[i + 1] * out_dim[i + 1];
}
for (int i = 0; i < ndims; i++) {
host_buffer_data[i] = in_offset[i];
host_buffer_data[i + ndims] = out_offset[i];
host_buffer_data[i + ndims * 2] = axis[i];
}
memory::Copy(gpu_place, offset_buffer, cpu_place, host_buffer_data,
ndims * 3 * sizeof(int));
int block = 512;
int grid = (data_size + block - 1) / block;
transpose_kernel<T><<<grid, block>>>(data_size, in_data, out_data,
static_cast<int*>(offset_buffer), ndims);
memory::Free(gpu_place, offset_buffer);
}
template <typename T>
class TransposeCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"It must use GPUPlace.");
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto axis = context.GetAttr<std::vector<int>>("axis");
TransposeCUDA<T>(context, *in, *out, axis);
}
};
template <typename T>
class TransposeGradCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"It must use GPUPlace.");
auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto axis_temp = context.GetAttr<std::vector<int>>("axis");
std::vector<int> axis(axis_temp);
for (size_t i = 0; i < axis.size(); i++) {
axis[axis_temp[i]] = i;
}
TransposeCUDA<T>(context, *in, *out, axis);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(transpose, ops::TransposeCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(transpose_grad, ops::TransposeGradCUDAKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
void NaiveCpuTranspose(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor& out,
std::vector<int> axis) {
auto in_data = in.data<T>();
auto out_data = out.mutable_data<T>(context.GetPlace());
auto in_dim = in.dims();
auto out_dim = out.dims();
size_t ndims = in_dim.size();
std::vector<int> in_offset(ndims, 1);
std::vector<int> out_offset(ndims, 1);
for (int i = ndims - 2; i >= 0; i--) {
in_offset[i] = in_offset[i + 1] * in_dim[i + 1];
out_offset[i] = out_offset[i + 1] * out_dim[i + 1];
}
size_t data_size = product(in_dim);
for (size_t to_index = 0; to_index < data_size; to_index++) {
int from_index = 0;
int temp = to_index;
for (size_t i = 0; i < ndims; i++) {
from_index += (temp / out_offset[i]) * in_offset[axis[i]];
temp = temp % out_offset[i];
}
out_data[to_index] = in_data[from_index];
}
}
template <typename Place, typename T, int Dims>
void DoTranspose(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor& out,
std::vector<int> axis) {
Eigen::array<int, Dims> permute;
for (int i = 0; i < Dims; i++) {
permute[i] = axis[i];
}
auto in_dim = in.dims();
auto out_dim = out.dims();
auto eigen_in = framework::EigenTensor<T, Dims>::From(in);
auto eigen_out = framework::EigenTensor<T, Dims>::From(out);
auto& dev = context.GetEigenDevice<Place>();
eigen_out.device(dev) = eigen_in.shuffle(permute);
}
template <typename Place, typename T>
class TransposeKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto axis = context.GetAttr<std::vector<int>>("axis");
int ndims = axis.size();
switch (ndims) {
case 2:
DoTranspose<Place, T, 2>(context, *in, *out, axis);
break;
case 3:
DoTranspose<Place, T, 3>(context, *in, *out, axis);
break;
case 4:
DoTranspose<Place, T, 4>(context, *in, *out, axis);
break;
case 5:
DoTranspose<Place, T, 5>(context, *in, *out, axis);
break;
default:
NaiveCpuTranspose<Place, T>(context, *in, *out, axis);
break;
}
}
};
template <typename Place, typename T>
class TransposeGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out = context.Output<framework::Tensor>(framework::GradVarName("X"));
out->mutable_data<T>(context.GetPlace());
auto axis_temp = context.GetAttr<std::vector<int>>("axis");
std::vector<int> axis(axis_temp);
for (size_t i = 0; i < axis.size(); i++) {
axis[axis_temp[i]] = i;
}
int ndims = axis.size();
switch (ndims) {
case 2:
DoTranspose<Place, T, 2>(context, *in, *out, axis);
break;
case 3:
DoTranspose<Place, T, 3>(context, *in, *out, axis);
break;
case 4:
DoTranspose<Place, T, 4>(context, *in, *out, axis);
break;
case 5:
DoTranspose<Place, T, 5>(context, *in, *out, axis);
break;
default:
NaiveCpuTranspose<Place, T>(context, *in, *out, axis);
break;
}
}
};
} // namespace operators
} // namespace paddle
...@@ -49,6 +49,7 @@ USE_OP(minus); ...@@ -49,6 +49,7 @@ USE_OP(minus);
USE_OP(cos_sim); USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter); USE_CPU_ONLY_OP(scatter);
USE_OP(transpose);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册