diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 7eec0b31556a93ec2ee390772a260a0a5d90c87f..295b580e53b71f7f6cb6d109a4ff420f9bb01e72 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -216,6 +216,7 @@ paddle.fluid.layers.merge_selected_rows (ArgSpec(args=['x', 'name'], varargs=Non
 paddle.fluid.layers.get_tensor_from_selected_rows (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7ffc849e71f31dfe29030ff94e662de6'))
 paddle.fluid.layers.lstm (ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)), ('document', 'd5e6c494ac35100e2ed4d4bd9a1ed932'))
 paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '2fa6782d43d02ae64482d21235a82949'))
+paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '2fa6782d43d02ae64482d21235a82949'))
 paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '8404e472ac12b4a30a505d3d3a3e5fdb'))
 paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '1546136806fef5c08f6918544bd9151d'))
 paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '2f6ff96864054a31aa4bb659c6722c99'))
diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8cb9fedfb3a6aa6b54089b18c6bab37db7193fa4
--- /dev/null
+++ b/paddle/fluid/operators/temporal_shift_op.cc
@@ -0,0 +1,115 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+       http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+#include "paddle/fluid/operators/temporal_shift_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class TemporalShiftOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of TemporalShiftOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of TemporalShiftOp should not be null.");
+
+    auto dim_x = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_EQ(dim_x.size(), 4,
+                      "Input(X) rank should be 4 in shape of [N*T, C, H, W].");
+
+    int seg_num = ctx->Attrs().Get<int>("seg_num");
+    PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
+
+    if (ctx->IsRuntime()) {
+      PADDLE_ENFORCE_EQ(
+          dim_x[0] % seg_num, 0,
+          "Input(X) dims[0] should be divided exactly by Attr(seg_num).");
+    }
+
+    ctx->SetOutputDim("Out", dim_x);
+    ctx->ShareLoD("X", "Out");
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.GetPlace());
+  }
+};
+
+class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input tensor of temporal shift operator. "
+             "This is a 4-D tensor with shape of [N*T, C, H, W]. "
+             "While N is the batch size, T is the temporal segment "
+             "number, C is the channel number, H is the height of "
+             "features and W is the width of features.");
+    AddOutput("Out",
+              "The output tensor of temporal shift operator. "
+              "This is a 4-D tensor in the same shape with Input(X).");
+
+    AddAttr<int>("seg_num",
+                 "The temporal segment number, this should be a positive "
+                 "integer.");
+
+    AddComment(R"DOC(
+      This operator calculates the temporal shifting features for Input(X).
+
+      Input(X) should be in shape of [N*T, C, H, W]. The first quarter of
+      input channels is shifted forward along the temporal dimension (the
+      output at segment t takes the input at segment t-1), the second
+      quarter is shifted backward (taking the input at segment t+1), and
+      the remaining channels are left unshifted. Positions shifted beyond
+      the segment boundary are padded with zeros.
+
+      For details of temporal shifting, please refer to the paper:
+      `Temporal Shift Module <https://arxiv.org/abs/1811.08383>`_ .
+
+    )DOC");
+  }
+};
+
+class TemporalShiftOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+    auto dim_x = ctx->GetInputDim("X");
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
+    }
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.GetPlace());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
+                  ops::TemporalShiftOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
+REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
+                       ops::TemporalShiftKernel<double>);
+REGISTER_OP_CPU_KERNEL(temporal_shift_grad,
+                       ops::TemporalShiftGradKernel<float>,
+                       ops::TemporalShiftGradKernel<double>);
diff --git a/paddle/fluid/operators/temporal_shift_op.cu b/paddle/fluid/operators/temporal_shift_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b62b4703e2cf4e8c825e2802387aba9a3ba3cd0a
--- /dev/null
+++ b/paddle/fluid/operators/temporal_shift_op.cu
@@ -0,0 +1,151 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+       http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+#include "paddle/fluid/operators/temporal_shift_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+template <typename T>
+__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
+                                  const int tchw, const int chw, const int hw,
+                                  const int w, const int t, const int c) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int src_it = 0;
+  for (; tid < ntchw; tid += stride) {
+    int in = tid / tchw;
+    int it = (tid % tchw) / chw;
+    int ic = (tid % chw) / hw;
+    int ih = (tid % hw) / w;
+    int iw = tid % w;
+
+    if (ic < c / 4) {
+      src_it = it - 1;
+    } else if (ic < c / 2) {
+      src_it = it + 1;
+    } else {
+      src_it = it;
+    }
+
+    if (src_it < 0 || src_it >= t) {
+      output[tid] = 0;
+    } else {
+      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
+      output[tid] = input[src_idx];
+    }
+  }
+}
+
+template <typename T>
+__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
+                                  const int ntchw, const int tchw,
+                                  const int chw, const int hw, const int w,
+                                  const int t, const int c) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int src_it = 0;
+  for (; tid < ntchw; tid += stride) {
+    int in = tid / tchw;
+    int it = (tid % tchw) / chw;
+    int ic = (tid % chw) / hw;
+    int ih = (tid % hw) / w;
+    int iw = tid % w;
+
+    if (ic < c / 4) {
+      src_it = it - 1;
+    } else if (ic < c / 2) {
+      src_it = it + 1;
+    } else {
+      src_it = it;
+    }
+
+    if (src_it >= 0 && src_it < t) {
+      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
+      input_grad[src_idx] = output_grad[tid];
+    }
+  }
+}
+
+template <typename T>
+class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "This kernel only runs on GPU device.");
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    int t = ctx.Attr<int>("seg_num");
+
+    const int nt = input->dims()[0];
+    const int c = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+
+    const int hw = h * w;
+    const int chw = c * hw;
+    const int tchw = t * chw;
+    const int ntchw = nt * chw;
+
+    const T* input_data = input->data<T>();
+    T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
+
+    int pixelNum = nt * chw;
+    int grid_dim = (pixelNum + 512 - 1) / 512;
+    grid_dim = grid_dim > 8 ? 8 : grid_dim;
+
+    KeTemporalShiftFw<
+        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
+        input_data, output_data, ntchw, tchw, chw, hw, w, t, c);
+  }
+};
+
+template <typename T>
+class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int t = ctx.Attr<int>("seg_num");
+
+    const int nt = output_grad->dims()[0];
+    const int c = output_grad->dims()[1];
+    const int h = output_grad->dims()[2];
+    const int w = output_grad->dims()[3];
+
+    const int hw = h * w;
+    const int chw = c * hw;
+    const int tchw = t * chw;
+    const int ntchw = nt * chw;
+
+    const T* output_grad_data = output_grad->data<T>();
+    T* input_grad_data =
+        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
+    // Zero-initialize input_grad: entries at the temporal borders never
+    // appear as a shift source, so they would otherwise keep whatever
+    // uninitialized memory mutable_data returned.
+    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
+    set_zero(ctx.cuda_device_context(), input_grad, static_cast<T>(0));
+
+    int pixelNum = nt * chw;
+    int grid_dim = (pixelNum + 512 - 1) / 512;
+    grid_dim = grid_dim > 8 ? 8 : grid_dim;
+
+    KeTemporalShiftBw<
+        T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
+        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
+                        ops::TemporalShiftOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
+                        ops::TemporalShiftGradOpCUDAKernel<float>,
+                        ops::TemporalShiftGradOpCUDAKernel<double>);
diff --git a/paddle/fluid/operators/temporal_shift_op.h b/paddle/fluid/operators/temporal_shift_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..9b96def3c728ac7d14a43b15ed0ecebfd667d0de
--- /dev/null
+++ b/paddle/fluid/operators/temporal_shift_op.h
@@ -0,0 +1,117 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+       http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
+                                           int iw, const int tchw,
+                                           const int chw, const int hw,
+                                           const int w) {
+  return in * tchw + it * chw + ic * hw + ih * w + iw;
+}
+
+template <typename T>
+class TemporalShiftKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    int t = ctx.Attr<int>("seg_num");
+
+    const int nt = input->dims()[0];
+    const int c = input->dims()[1];
+    const int h = input->dims()[2];
+    const int w = input->dims()[3];
+
+    const int hw = h * w;
+    const int chw = c * hw;
+    const int tchw = t * chw;
+
+    const T* input_data = input->data<T>();
+    T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
+
+    int src_it = 0;
+    for (int i = 0; i < output->numel(); i++) {
+      int in = i / tchw;
+      int it = (i % tchw) / chw;
+      int ic = (i % chw) / hw;
+      int ih = (i % hw) / w;
+      int iw = i % w;
+
+      // The first quarter of channels reads from segment it-1, the second
+      // quarter from segment it+1, and the remaining half is not shifted.
+      if (ic < c / 4) {
+        src_it = it - 1;
+      } else if (ic < c / 2) {
+        src_it = it + 1;
+      } else {
+        src_it = it;
+      }
+
+      if (src_it < 0 || src_it >= t) {
+        output_data[i] = 0;
+      } else {
+        int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
+        output_data[i] = input_data[src_idx];
+      }
+    }
+  }
+};
+
+template <typename T>
+class TemporalShiftGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int t = ctx.Attr<int>("seg_num");
+
+    const int nt = output_grad->dims()[0];
+    const int c = output_grad->dims()[1];
+    const int h = output_grad->dims()[2];
+    const int w = output_grad->dims()[3];
+
+    const int hw = h * w;
+    const int chw = c * hw;
+    const int tchw = t * chw;
+
+    const T* output_grad_data = output_grad->data<T>();
+    T* input_grad_data =
+        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
+    // Zero-initialize input_grad: entries at the temporal borders never
+    // appear as a shift source and would otherwise stay uninitialized.
+    memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
+
+    int src_it = 0;
+    for (int i = 0; i < output_grad->numel(); i++) {
+      int in = i / tchw;
+      int it = (i % tchw) / chw;
+      int ic = (i % chw) / hw;
+      int ih = (i % hw) / w;
+      int iw = i % w;
+
+      if (ic < c / 4) {
+        src_it = it - 1;
+      } else if (ic < c / 2) {
+        src_it = it + 1;
+      } else {
+        src_it = it;
+      }
+
+      if (src_it >= 0 && src_it < t) {
+        int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
+        input_grad_data[src_idx] = output_grad_data[i];
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 5b4f1efe479b12cb8ec390b8753d097764d70860..29b3ff903703a4b5695c7a9f292199b9a611f27a 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -182,6 +182,7 @@ __all__ = [
     'get_tensor_from_selected_rows',
     'lstm',
     'shuffle_channel',
+    'temporal_shift',
     'py_func',
     'psroi_pool',
     'teacher_student_sigmoid_loss',
@@ -10264,6 +10265,45 @@ def shuffle_channel(x, group, name=None):
     return out
 
 
+@templatedoc()
+def temporal_shift(x, seg_num, name=None):
+    """
+    **Temporal Shift Operator**
+
+    ${comment}
+
+    Args:
+        x(Variable): ${x_comment}
+        seg_num(int): ${seg_num_comment}
+        name(str, default None): The name of this layer.
+
+    Returns:
+        out(Variable): The temporal shifting result is a tensor variable with
+        the same shape and same data type as the input.
+
+    Raises:
+        TypeError: seg_num must be int type.
+
+    Examples:
+        .. code-block:: python
+
+            input = fluid.layers.data(name='input', shape=[4, 2, 2], dtype='float32')
+            out = fluid.layers.temporal_shift(x=input, seg_num=2)
+    """
+    helper = LayerHelper("temporal_shift", **locals())
+
+    if not isinstance(seg_num, int):
+        raise TypeError("seg_num must be int type.")
+
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type="temporal_shift",
+        inputs={"X": x},
+        outputs={"Out": out},
+        attrs={"seg_num": seg_num})
+    return out
+
+
 class PyFuncRegistry(object):
     _register_funcs = []
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index ff49c1be979a2076952963ec54302fb68361eedf..e8ba63be675f6b5f5a0241af81baf676a3f3e999 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -1048,6 +1048,14 @@ class TestBook(unittest.TestCase):
             print(str(program))
 
+    def test_temporal_shift(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name="X", shape=[16, 4, 4], dtype="float32")
+            out = layers.temporal_shift(x, seg_num=4)
+            self.assertIsNotNone(out)
+        print(str(program))
+
     def test_shuffle_channel(self):
         program = Program()
         with program_guard(program):
diff --git a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2ab34e4d630309f7f4d9d2dc4fdb6af7974dba4
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+from paddle.fluid import core
+
+
+def temporal_shift(x, seg_num):
+    shape = x.shape
+    reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
+    pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
+                   'constant')
+    slice1 = pad_x[:, :seg_num, :shape[1] // 4, :, :]
+    slice2 = pad_x[:, 2:seg_num + 2, shape[1] // 4:shape[1] // 2, :, :]
+    slice3 = pad_x[:, 1:seg_num + 1, shape[1] // 2:, :, :]
+    concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
+    return concat_x.reshape(shape)
+
+
+class TestTemporalShift(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        self.op_type = 'temporal_shift'
+        x = np.random.random(self.x_shape).astype('float32')
+
+        self.attrs = {"seg_num": self.seg_num}
+
+        self.inputs = {"X": x}
+
+        output = temporal_shift(x, self.seg_num)
+        self.outputs = {"Out": output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.01)
+
+    def initTestCase(self):
+        self.x_shape = (6, 4, 4, 4)
+        self.seg_num = 3
+
+
+class TestTemporalShift2(TestTemporalShift):
+    def initTestCase(self):
+        self.x_shape = (4, 9, 7, 7)
+        self.seg_num = 2
+
+
+class TestTemporalShift3(TestTemporalShift):
+    def initTestCase(self):
+        self.x_shape = (3, 10, 5, 5)
+        self.seg_num = 1
+
+
+if __name__ == "__main__":
+    unittest.main()
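Reviewer note: to make the shift rule concrete, here is a tiny worked example driven by the numpy reference `temporal_shift` defined in the unit test above. The clip size and values are invented for illustration. With C = 4 channels, channel 0 reads from segment t-1, channel 1 from segment t+1, and channels 2-3 pass through unshifted, with zero padding at the clip borders:

import numpy as np

from test_temporal_shift_op import temporal_shift  # numpy reference above

# One clip (N=1) of T=3 segments with C=4 channels of 1x1 features,
# flattened to the op's expected [N*T, C, H, W] layout.
x = np.arange(1, 13, dtype=np.float32).reshape((3, 4, 1, 1))

out = temporal_shift(x, seg_num=3).reshape((3, 4))
print(out)
# [[ 0.  6.  3.  4.]    <- segment 0: channel 0 has no t-1 source, zero-filled
#  [ 1. 10.  7.  8.]    <- segment 1: channel 0 = x[0, 0], channel 1 = x[2, 1]
#  [ 5.  0. 11. 12.]]   <- segment 2: channel 1 has no t+1 source, zero-filled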
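Also for reviewers, a minimal end-to-end sketch of the new Python API under a standard fluid 1.x executor setup. All sizes here are invented; the only constraint is that the leading dimension N*T of the input must be divisible by seg_num:

import numpy as np
import paddle.fluid as fluid

# Two clips (N=2) of four segments each (T=4), stacked along the batch
# dimension as the op expects: [N*T, C, H, W] = [8, 8, 7, 7].
frames = fluid.layers.data(name='frames', shape=[8, 7, 7], dtype='float32')
shifted = fluid.layers.temporal_shift(x=frames, seg_num=4)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

x = np.random.random((8, 8, 7, 7)).astype('float32')
out, = exe.run(fluid.default_main_program(),
               feed={'frames': x},
               fetch_list=[shifted])
print(out.shape)  # (8, 8, 7, 7) -- same shape and dtype as the input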