diff --git a/paddle/operators/conv_shift_op.cc b/paddle/operators/conv_shift_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e1e321ed5fce6ce4e4089cc5c5e488a2cbad6c82
--- /dev/null
+++ b/paddle/operators/conv_shift_op.cc
@@ -0,0 +1,206 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/conv_shift_op.h"
+#include "paddle/framework/eigen.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+template <typename T>
+using EigenMatrix = framework::EigenMatrix<T>;
+
+class ConvShiftOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
+                      "The 1st dimension of Input(X) and Input(Y) should "
+                      "be equal.");
+    PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
+                      "The 2nd dimension of Input(Y) should be odd.");
+    PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
+                      "The 2nd dimension of Input(Y) should be less than or "
+                      "equal to the 2nd dimension of Input(X).");
+    ctx->SetOutputDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+class ConvShiftGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      auto x_dims = ctx->GetInputDim("X");
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(y_grad_name)) {
+      auto y_dims = ctx->GetInputDim("Y");
+      ctx->SetOutputDim(y_grad_name, y_dims);
+    }
+  }
+};
+
+class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ConvShiftOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
+             "where B is the batch size and M is the data dimension.");
+    AddInput("Y",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape B x N, "
+             "where B is the batch size and N is the data dimension. N must "
N must " + "be odd."); + AddOutput("Out", + "(Tensor, default Tensor), a 2-D tensor with shape B x M, " + "i.e., the same shape as X."); + AddComment(R"DOC( +ConvShift Operator. + +A layer for circular convolution of two vectors, +as used in the Neural Turing Machine: https://arxiv.org/abs/1410.5401 + +The equation is: + + \f[ + Out[i] = \sum_{j=-(N-1)/2}^{(N-1)/2} X_{i+j} * Y_{j} + \f] + +where X's index is computed modulo M, and b's index is computed modulo N. + +Both of the input `X` and `Y` can carry LoD (Level of Details) information. +However, the output only shares the LoD information with input `X`. +)DOC"); + } +}; + +template +class ConvShiftKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Y = context.Input("Y"); + auto *Out = context.Output("Out"); + Out->mutable_data(context.GetPlace()); + + auto x = EigenMatrix::From(*X); + auto y = EigenMatrix::From(*Y); + auto out = EigenMatrix::From(*Out); + out.setZero(); + + size_t batch_size = X->dims()[0]; + size_t x_width = X->dims()[1]; + size_t y_width = Y->dims()[1]; + size_t y_half_width = (y_width - 1) / 2; + + for (size_t k = 0; k < batch_size; ++k) { + for (size_t i = 0; i < x_width; ++i) { + for (size_t j = 0; j < y_width; ++j) { + int index = (i + j - y_half_width + x_width) % x_width; + out(k, i) += x(k, index) * y(k, j); + } + } + } + } +}; + +template +class ConvShiftGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *X = context.Input("X"); + auto *Y = context.Input("Y"); + auto *dOut = context.Input(framework::GradVarName("Out")); + auto *dX = context.Output(framework::GradVarName("X")); + auto *dY = context.Output(framework::GradVarName("Y")); + + auto x = EigenMatrix::From(*X); + auto y = EigenMatrix::From(*Y); + auto dout = EigenMatrix::From(*dOut); + + auto x_dims = X->dims(); + auto y_dims = Y->dims(); + size_t batch_size = x_dims[0]; + size_t x_width = x_dims[1]; + size_t y_width = y_dims[1]; + size_t y_half_width = (y_width - 1) / 2; + + // The below trades code duplication for efficiency (keeping the if + // statement outside of the loop). 
+    if (dX) {
+      dX->mutable_data<T>(context.GetPlace());
+      auto dx = EigenMatrix<T>::From(*dX);
+      dx.setZero();
+      for (size_t k = 0; k < batch_size; ++k) {
+        for (size_t i = 0; i < x_width; ++i) {
+          for (size_t j = 0; j < y_width; ++j) {
+            int index = (i + j - y_half_width + x_width) % x_width;
+            dx(k, index) += dout(k, i) * y(k, j);
+          }
+        }
+      }
+    }
+
+    if (dY) {
+      dY->mutable_data<T>(context.GetPlace());
+      auto dy = EigenMatrix<T>::From(*dY);
+      dy.setZero();
+      for (size_t k = 0; k < batch_size; ++k) {
+        for (size_t i = 0; i < x_width; ++i) {
+          for (size_t j = 0; j < y_width; ++j) {
+            int index = (i + j - y_half_width + x_width) % x_width;
+            dy(k, j) += x(k, index) * dout(k, i);
+          }
+        }
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker,
+            conv_shift_grad, ops::ConvShiftGradOp);
+REGISTER_OP_CPU_KERNEL(conv_shift,
+                       ops::ConvShiftKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    conv_shift_grad,
+    ops::ConvShiftGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..145e966fe9caa68f7485bb258fa78fd34bfd4c04
--- /dev/null
+++ b/paddle/operators/conv_shift_op.cu
@@ -0,0 +1,194 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/conv_shift_op.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+namespace {
+
+inline int div_up(int x, int y) { return (x + y - 1) / y; }
+
+// Some notes on the design:
+//
+// Each thread is responsible for computing a single output out[k, i].
+// Thread blocks are based on tiles of x with height 1 in the batch dimension.
+//
+// This design is based on the typical use case where the filter
+// y is fairly small. For large y, it would probably be more efficient
+// to also tile across y.
+template <typename T>
+__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width,
+                                   int y_width, int y_half_width,
+                                   int batch_size) {
+  extern __shared__ T mem[];
+
+  int tx = threadIdx.x;
+  int i = blockIdx.x * blockDim.x + tx;  // global x index
+  int k = blockIdx.y;                    // batch index
+
+  // Check if we are in a boundary block with fewer x's to process than
+  // blockDim.x. (If blockDim.x divides x_width evenly, the last block is
+  // still a full block.)
+  int num_x = (blockIdx.x == gridDim.x - 1 && x_width % blockDim.x != 0)
+                  ? (x_width % blockDim.x)
+                  : blockDim.x;
+
+  T *sx = mem;
+  T *sx_pad = &mem[num_x];
+  T *sy = &mem[blockDim.x + y_width];
+
+  // Collaboratively load y[k, :] and length-y padding of x into shared memory.
+  int pad_start = blockIdx.x * blockDim.x + num_x + x_width - y_half_width;
+  for (int j = tx; j < y_width; j += blockDim.x) {
+    sy[j] = y[k * y_width + j];
+    sx_pad[j] = x[k * x_width + (pad_start + j) % x_width];
+  }
+
+  // Load a cyclically shifted slice of x into shared memory.
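+  // Each thread stores x[k, (i - y_half_width) mod x_width] at sx[tx], so
+  // that sx[tx + j] (continuing into sx_pad for the wrap-around tail) lines
+  // up with sy[j] in the dot product below.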
+  if (tx < num_x) {
+    int load_i = (i - y_half_width + x_width) % x_width;
+    sx[tx] = x[k * x_width + load_i];
+  } else {
+    return;
+  }
+  __syncthreads();
+
+  // Compute dot product of sx[tx:tx + y_width] and sy.
+  T sum = 0;
+  for (int j = 0; j < y_width; ++j) {
+    sum += sx[tx + j] * sy[j];
+  }
+
+  // Save to out[k, i].
+  out[k * x_width + i] = sum;
+}
+
+// Compute x gradient - initial naive implementation with atomic add.
+template <typename T>
+__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width,
+                              int y_width, int y_half_width, int batch_size) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
+  int j = blockIdx.y;                             // y index
+  int k = blockIdx.z;                             // batch index
+
+  if (i < x_width) {
+    int index = (i + j - y_half_width + x_width) % x_width;
+    atomicAdd(&dx[k * x_width + index],
+              dout[k * x_width + i] * y[k * y_width + j]);
+  }
+}
+
+// Compute y gradient - initial naive implementation with atomic add.
+template <typename T>
+__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width,
+                              int y_width, int y_half_width, int batch_size) {
+  int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
+  int j = blockIdx.y;                             // y index
+  int k = blockIdx.z;                             // batch index
+
+  if (i < x_width) {
+    int index = (i + j - y_half_width + x_width) % x_width;
+    atomicAdd(&dy[k * y_width + j],
+              x[k * x_width + index] * dout[k * x_width + i]);
+  }
+}
+}  // namespace
+
+template <typename T>
+class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    const Tensor *X = context.Input<Tensor>("X");
+    const Tensor *Y = context.Input<Tensor>("Y");
+    Tensor *Out = context.Output<Tensor>("Out");
+    const T *x_data = X->data<T>();
+    const T *y_data = Y->data<T>();
+    T *out_data = Out->mutable_data<T>(context.GetPlace());
+
+    int batch_size = X->dims()[0];
+    int x_width = X->dims()[1];
+    int y_width = Y->dims()[1];
+    int y_half_width = (y_width - 1) / 2;
+
+    const int x_per_block = 256;
+    int num_x_blocks = div_up(x_width, x_per_block);
+    int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T);
+
+    dim3 grid_dim(num_x_blocks, batch_size);
+
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
+                      context.device_context())
+                      .stream();
+
+    conv_shift_forward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
+        x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
+  }
+};
+
+template <typename T>
+class ConvShiftGradKernel<platform::GPUPlace, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    const Tensor *X = context.Input<Tensor>("X");
+    const Tensor *Y = context.Input<Tensor>("Y");
+    const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
+    const T *x_data = X->data<T>();
+    const T *y_data = Y->data<T>();
+    const T *dout_data = dOut->data<T>();
+
+    Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
+    Tensor *dY = context.Output<Tensor>(framework::GradVarName("Y"));
+
+    int batch_size = X->dims()[0];
+    int x_width = X->dims()[1];
+    int y_width = Y->dims()[1];
+    int y_half_width = (y_width - 1) / 2;
+
+    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
+                      context.device_context())
+                      .stream();
+
+    const int x_per_block = 256;
+    int num_x_blocks = div_up(x_width, x_per_block);
+    dim3 grid_dim(num_x_blocks, y_width, batch_size);
+
+    if (dX) {
+      T *dx_data = dX->mutable_data<T>(context.GetPlace());
+      cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream);
+      conv_shift_dx<T><<<grid_dim, x_per_block, 0, stream>>>(
+          dout_data, y_data, dx_data, x_width, y_width, y_half_width,
+          batch_size);
+    }
+    if (dY) {
+      T *dy_data = dY->mutable_data<T>(context.GetPlace());
+      cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream);
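+      // Zero-initialization matters here: conv_shift_dy only accumulates
+      // into dy via atomicAdd (as conv_shift_dx does for dx above).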
+      conv_shift_dy<T><<<grid_dim, x_per_block, 0, stream>>>(
+          x_data, dout_data, dy_data, x_width, y_width, y_half_width,
+          batch_size);
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(conv_shift,
+                       ops::ConvShiftKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    conv_shift_grad,
+    ops::ConvShiftGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/conv_shift_op.h b/paddle/operators/conv_shift_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..5a160b0f1696c70868fc48d219b38cde2018e8a3
--- /dev/null
+++ b/paddle/operators/conv_shift_op.h
@@ -0,0 +1,33 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class ConvShiftKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override;
+};
+
+template <typename Place, typename T>
+class ConvShiftGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override;
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_conv_shift_op.py b/python/paddle/v2/framework/tests/test_conv_shift_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9ab21a06a1c6e8e2d1e936a0b4b8a07a59f57b9
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_conv_shift_op.py
@@ -0,0 +1,47 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def conv_shift_forward(x, y):
+    out = np.zeros_like(x)
+    M = x.shape[1]
+    N = y.shape[1]
+    y_half_width = (N - 1) / 2
+    for i in xrange(M):
+        for j in xrange(N):
+            out[:, i] += x[:, (i + j + M - y_half_width) % M] * y[:, j]
+    return out
+
+
+class TestConvShiftOp(OpTest):
+    def setUp(self):
+        self.op_type = "conv_shift"
+
+        batch_size = 4
+        x_dim = 17
+        y_dim = 3  # must be odd and <= x_dim
+        x = np.random.random((batch_size, x_dim)).astype("float32")
+        y = np.random.random((batch_size, y_dim)).astype("float32")
+        self.inputs = {'X': x, 'Y': y}
+
+        out = conv_shift_forward(x, y)
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(
+            ['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X"))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(
+            ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
+
+
+if __name__ == '__main__':
+    unittest.main()
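Note (not part of the patch): the test above checks gradients only numerically
via check_grad. For readers verifying the gradient formulas implemented in
ConvShiftGradKernel by hand, a minimal NumPy sketch of the backward pass is
given below; the helper name conv_shift_grad_ref is hypothetical and simply
mirrors the CPU loops.

    import numpy as np

    def conv_shift_grad_ref(x, y, dout):
        # Mirrors ConvShiftGradKernel: index = (i + j - (N - 1) // 2) mod M
        # selects the x entry that contributed to Out[k, i], so dX and dY
        # accumulate the chain-rule terms dOut * Y and X * dOut at that index.
        B, M = x.shape
        N = y.shape[1]
        half = (N - 1) // 2
        dx = np.zeros_like(x)
        dy = np.zeros_like(y)
        for k in range(B):
            for i in range(M):
                for j in range(N):
                    index = (i + j - half + M) % M
                    dx[k, index] += dout[k, i] * y[k, j]
                    dy[k, j] += x[k, index] * dout[k, i]
        return dx, dy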