From 5c1920b731be024bbef9be757b83b12d2fc03470 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Fri, 8 Mar 2019 09:40:45 +0000 Subject: [PATCH] add Attr shift_ratio. test=develop --- paddle/fluid/operators/temporal_shift_op.cc | 15 +++++++++-- paddle/fluid/operators/temporal_shift_op.cu | 26 +++++++++++++------ paddle/fluid/operators/temporal_shift_op.h | 16 +++++++++--- python/paddle/fluid/layers/nn.py | 10 ++++--- .../fluid/tests/unittests/test_layers.py | 2 +- .../tests/unittests/test_temporal_shift_op.py | 16 ++++++++---- 6 files changed, 62 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc index a71d372c7..4f1cad367 100644 --- a/paddle/fluid/operators/temporal_shift_op.cc +++ b/paddle/fluid/operators/temporal_shift_op.cc @@ -33,8 +33,12 @@ class TemporalShiftOp: public framework::OperatorWithKernel { "Input(X) rank should be 4 in shape of [N*T, C, H, W]."); int seg_num = ctx->Attrs().Get("seg_num"); + float shift_ratio = ctx->Attrs().Get("shift_ratio"); PADDLE_ENFORCE_GT(seg_num, 0, - "Attr(seg_num) should be greater then 0."); + "Attr(seg_num) should be greater than 0."); + PADDLE_ENFORCE(shift_ratio > 0 && shift_ratio < .5, + "Attr(shift_ratio) should be greater than 0 and less " + "than 0.5."); if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0, @@ -69,6 +73,12 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("seg_num", "The temporal segment number, this should be a positive " "interger."); + AddAttr("shift_ratio", + "The shift ratio of the channels, the first shift ratio part " + "of channels will be shifted by -1 along the temporal dimension, " + "and the second shift ratio part of channels will be shifted by " + "1 along the temporal dimension. Default 0.25.") + .SetDefault(0.25); AddComment(R"DOC( This operator calculates the temporal shifting features for Input(X). 
@@ -85,7 +95,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker { padding width as 1 on each side, padding result will be in shape of [N, T+2, C, H, W]. - Step 3: Slice padding result as follows: + Step 3: Assume :attr:`shift_ratio` is :math:`0.25`, slice padding + result as follows: slice1 = x[:, :T, :C/4, :, :] slice2 = x[:, 2:T+2, C/4:C/2, :, :] diff --git a/paddle/fluid/operators/temporal_shift_op.cu b/paddle/fluid/operators/temporal_shift_op.cu index b555c08c2..3d9c9ddd5 100644 --- a/paddle/fluid/operators/temporal_shift_op.cu +++ b/paddle/fluid/operators/temporal_shift_op.cu @@ -20,7 +20,8 @@ using framework::Tensor; template __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw, - const int tchw, const int chw, const int hw, const int w, const int t, const int c) { + const int tchw, const int chw, const int hw, const int w, const int t, const int c, + const float shift_ratio) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; int src_it = 0; @@ -31,9 +32,12 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw, int ih = (tid % hw) / w; int iw = tid % w; - if (ic < c / 4) { + const int c1 = static_cast(c * shift_ratio); + const int c2 = static_cast(c * 2 * shift_ratio); + + if (ic < c1) { src_it = it - 1; - } else if (ic < c / 2) { + } else if (ic < c2) { src_it = it + 1; } else { src_it = it; @@ -50,7 +54,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw, template __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int ntchw, - const int tchw, const int chw, const int hw, const int w, const int t, const int c) { + const int tchw, const int chw, const int hw, const int w, const int t, const int c, + const float shift_ratio) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; int src_it = 0; @@ -61,9 +66,12 @@ __global__ void KeTemporalShiftBw(const T* 
output_grad, T* input_grad, const int int ih = (tid % hw) / w; int iw = tid % w; - if (ic < c / 4) { + const int c1 = static_cast(c * shift_ratio); + const int c2 = static_cast(c * 2 * shift_ratio); + + if (ic < c1) { src_it = it - 1; - } else if (ic < c / 2) { + } else if (ic < c2) { src_it = it + 1; } else { src_it = it; @@ -85,6 +93,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); int t = ctx.Attr("seg_num"); + float shift_ratio = ctx.Attr("shift_ratio"); const int nt = input->dims()[0]; const int c = input->dims()[1]; @@ -105,7 +114,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel { KeTemporalShiftFw< T><<>>( - input_data, output_data, ntchw, tchw, chw, hw, w, t, c); + input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio); } }; @@ -116,6 +125,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel { auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Out")); int t = ctx.Attr("seg_num"); + float shift_ratio = ctx.Attr("shift_ratio"); const int nt = output_grad->dims()[0]; const int c = output_grad->dims()[1]; @@ -139,7 +149,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel { KeTemporalShiftBw< T><<>>( - output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c); + output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio); } }; diff --git a/paddle/fluid/operators/temporal_shift_op.h b/paddle/fluid/operators/temporal_shift_op.h index 3342a8b4a..6b8001596 100644 --- a/paddle/fluid/operators/temporal_shift_op.h +++ b/paddle/fluid/operators/temporal_shift_op.h @@ -30,12 +30,16 @@ class TemporalShiftKernel: public framework::OpKernel { auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); int t = ctx.Attr("seg_num"); + float shift_ratio = ctx.Attr("shift_ratio"); const int nt = input->dims()[0]; const int c = 
input->dims()[1]; const int h = input->dims()[2]; const int w = input->dims()[3]; + const int c1 = static_cast(c * shift_ratio); + const int c2 = static_cast(c * 2 * shift_ratio); + const int hw = h * w; const int chw = c * hw; const int tchw = t * chw; @@ -51,9 +55,9 @@ class TemporalShiftKernel: public framework::OpKernel { int ih = (i % hw) / w; int iw = i % w; - if (ic < c / 4) { + if (ic < c1) { src_it = it - 1; - } else if (ic < c / 2) { + } else if (ic < c2) { src_it = it + 1; } else { src_it = it; @@ -76,12 +80,16 @@ class TemporalShiftGradKernel : public framework::OpKernel { auto* input_grad = ctx.Output(framework::GradVarName("X")); auto* output_grad = ctx.Input(framework::GradVarName("Out")); int t = ctx.Attr("seg_num"); + float shift_ratio = ctx.Attr("shift_ratio"); const int nt = output_grad->dims()[0]; const int c = output_grad->dims()[1]; const int h = output_grad->dims()[2]; const int w = output_grad->dims()[3]; + const int c1 = static_cast(c * shift_ratio); + const int c2 = static_cast(c * 2 * shift_ratio); + const int hw = h * w; const int chw = c * hw; const int tchw = t * chw; @@ -98,9 +106,9 @@ class TemporalShiftGradKernel : public framework::OpKernel { int ih = (i % hw) / w; int iw = i % w; - if (ic < c / 4) { + if (ic < c1) { src_it = it - 1; - } else if (ic < c / 2) { + } else if (ic < c2) { src_it = it + 1; } else { src_it = it; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 29b3ff903..1280baae5 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10266,7 +10266,7 @@ def shuffle_channel(x, group, name=None): @templatedoc() -def temporal_shift(x, seg_num, name=None): +def temporal_shift(x, seg_num, shift_ratio=0.25, name=None): """ **Temporal Shift Operator** @@ -10275,6 +10275,7 @@ def temporal_shift(x, seg_num, name=None): Args: x(Variable): ${x_comment} seg_num(int): ${seg_num_comment} + shift_ratio(float): ${shift_ratio_comment} Returns: out(Variable): The 
temporal shifting result is a tensor variable with the @@ -10287,7 +10288,7 @@ def temporal_shift(x, seg_num, name=None): .. code-block:: python input = fluid.layers.data(name='input', shape=[4,2,2], dtype='float32') - out = fluid.layers.temporal_shift(x=input, seg_num=2) + out = fluid.layers.temporal_shift(x=input, seg_num=2, shift_ratio=0.2) """ helper = LayerHelper("temporal_shift", **locals()) @@ -10300,7 +10301,10 @@ def temporal_shift(x, seg_num, name=None): type="temporal_shift", inputs={"X": x}, outputs={"Out": out}, - attrs={"seg_num": seg_num}) + attrs={ + "seg_num": seg_num, + "shift_ratio": shift_ratio + }) return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index e8ba63be6..75411f5dd 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1052,7 +1052,7 @@ class TestBook(unittest.TestCase): program = Program() with program_guard(program): x = layers.data(name="X", shape=[16, 4, 4], dtype="float32") - out = layers.temporal_shift(x, seg_num=4) + out = layers.temporal_shift(x, seg_num=4, shift_ratio=0.2) self.assertIsNotNone(out) print(str(program)) diff --git a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py index 55ebc880c..dbef184d6 100644 --- a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py +++ b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py @@ -21,13 +21,15 @@ from op_test import OpTest from paddle.fluid import core -def temporal_shift(x, seg_num): +def temporal_shift(x, seg_num, shift_ratio): shape = x.shape reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3])) pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), 'constant') - slice1 = pad_x[:, :seg_num, :shape[1]//4, :, :] - slice2 = pad_x[:, 2:seg_num+2, shape[1]//4:shape[1]//2, :, :] - slice3 = pad_x[:, 1:seg_num+1, 
shape[1]//2:, :, :] + c1 = int(shape[1] * shift_ratio) + c2 = int(shape[1] * 2 * shift_ratio) + slice1 = pad_x[:, :seg_num, :c1, :, :] + slice2 = pad_x[:, 2:seg_num+2, c1:c2, :, :] + slice3 = pad_x[:, 1:seg_num+1, c2:, :, :] concat_x = np.concatenate([slice1, slice2, slice3], axis=2) return concat_x.reshape(shape) @@ -39,13 +41,14 @@ class TestTemporalShift(OpTest): self.attrs = { "seg_num": self.seg_num, + "shift_ratio": self.shift_ratio, } self.inputs = { "X": x, } - output = temporal_shift(x, self.seg_num) + output = temporal_shift(x, self.seg_num, self.shift_ratio) self.outputs = {"Out": output} def test_check_output(self): @@ -57,17 +60,20 @@ class TestTemporalShift(OpTest): def initTestCase(self): self.x_shape = (6, 4, 4, 4) self.seg_num = 3 + self.shift_ratio = 0.25 class TestTemporalShift2(TestTemporalShift): def initTestCase(self): self.x_shape = (4, 9, 7, 7) self.seg_num = 2 + self.shift_ratio = 0.2 class TestTemporalShift2(TestTemporalShift): def initTestCase(self): self.x_shape = (3, 10, 5, 5) self.seg_num = 1 + self.shift_ratio = 0.3 if __name__ == "__main__": -- GitLab