Commit 5c1920b7 authored by dengkaipeng

add Attr shift_ratio. test=develop

Parent 71101c9c
...@@ -33,8 +33,12 @@ class TemporalShiftOp: public framework::OperatorWithKernel { ...@@ -33,8 +33,12 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
"Input(X) rank should be 4 in shape of [N*T, C, H, W]."); "Input(X) rank should be 4 in shape of [N*T, C, H, W].");
int seg_num = ctx->Attrs().Get<int>("seg_num"); int seg_num = ctx->Attrs().Get<int>("seg_num");
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
PADDLE_ENFORCE_GT(seg_num, 0, PADDLE_ENFORCE_GT(seg_num, 0,
"Attr(seg_num) should be greater then 0."); "Attr(seg_num) should be greater than 0.");
PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5,
"Attr(shift_ratio) should be greater than 0 and less "
"than 0.5.");
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0, PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0,
@@ -69,6 +73,12 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("seg_num",
                  "The temporal segment number, this should be a positive "
                  "integer.");
+    AddAttr<float>("shift_ratio",
+                   "The shift ratio of the channels: the first shift_ratio "
+                   "portion of channels will be shifted by -1 along the temporal "
+                   "dimension, and the second shift_ratio portion of channels "
+                   "will be shifted by 1 along the temporal dimension. Default 0.25.")
+        .SetDefault(0.25);
     AddComment(R"DOC(
           This operator calculates the temporal shifting features for Input(X).
@@ -85,7 +95,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
           padding width as 1 on each side, padding result will be in shape
           of [N, T+2, C, H, W].

-          Step 3: Slice padding result as follows:
+          Step 3: Assume :attr:`shift_ratio` is :math:`0.25`, slice padding
+          result as follows:

           slice1 = x[:, :T, :C/4, :, :]
           slice2 = x[:, 2:T+2, C/4:C/2, :, :]
...
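For reference, the three steps described in the DOC string above can be sketched in NumPy. This is only an illustration, not part of the commit; it mirrors the reference implementation added to the unit test further down, with `x` assumed to be a float array of shape [N*T, C, H, W]:

```python
import numpy as np

def temporal_shift_ref(x, seg_num, shift_ratio=0.25):
    """NumPy sketch of the temporal shift: reshape, pad, slice, concat."""
    nt, c, h, w = x.shape
    # Step 1: reshape [N*T, C, H, W] -> [N, T, C, H, W].
    x5 = x.reshape((-1, seg_num, c, h, w))
    # Step 2: pad the temporal dimension with one zero frame on each side.
    pad = np.pad(x5, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), 'constant')
    c1 = int(c * shift_ratio)       # channels shifted by -1 in time
    c2 = int(c * 2 * shift_ratio)   # end of the channels shifted by +1
    # Step 3: slice the padded tensor and concatenate along channels.
    slice1 = pad[:, :seg_num, :c1, :, :]
    slice2 = pad[:, 2:seg_num + 2, c1:c2, :, :]
    slice3 = pad[:, 1:seg_num + 1, c2:, :, :]
    return np.concatenate([slice1, slice2, slice3], axis=2).reshape(x.shape)
```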
@@ -20,7 +20,8 @@ using framework::Tensor;

 template <typename T>
 __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
-    const int tchw, const int chw, const int hw, const int w, const int t, const int c) {
+    const int tchw, const int chw, const int hw, const int w, const int t, const int c,
+    const float shift_ratio) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   int src_it = 0;
@@ -31,9 +32,12 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
     int ih = (tid % hw) / w;
     int iw = tid % w;

-    if (ic < c / 4) {
+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+    if (ic < c1) {
       src_it = it - 1;
-    } else if (ic < c / 2) {
+    } else if (ic < c2) {
       src_it = it + 1;
     } else {
       src_it = it;
@@ -50,7 +54,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,

 template <typename T>
 __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int ntchw,
-    const int tchw, const int chw, const int hw, const int w, const int t, const int c) {
+    const int tchw, const int chw, const int hw, const int w, const int t, const int c,
+    const float shift_ratio) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   int src_it = 0;
@@ -61,9 +66,12 @@ __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int
     int ih = (tid % hw) / w;
     int iw = tid % w;

-    if (ic < c / 4) {
+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+    if (ic < c1) {
       src_it = it - 1;
-    } else if (ic < c / 2) {
+    } else if (ic < c2) {
       src_it = it + 1;
     } else {
       src_it = it;
@@ -85,6 +93,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
     auto* input = ctx.Input<Tensor>("X");
     auto* output = ctx.Output<Tensor>("Out");
     int t = ctx.Attr<int>("seg_num");
+    float shift_ratio = ctx.Attr<float>("shift_ratio");

     const int nt = input->dims()[0];
     const int c = input->dims()[1];
@@ -105,7 +114,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
     KeTemporalShiftFw<
         T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
-        input_data, output_data, ntchw, tchw, chw, hw, w, t, c);
+        input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
   }
 };
@@ -116,6 +125,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
     auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     int t = ctx.Attr<int>("seg_num");
+    float shift_ratio = ctx.Attr<float>("shift_ratio");

     const int nt = output_grad->dims()[0];
     const int c = output_grad->dims()[1];
@@ -139,7 +149,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
     KeTemporalShiftBw<
         T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
-        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c);
+        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
   }
 };
...
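Both CUDA kernels (and the CPU kernels in the next file) work elementwise: each flat index is decomposed into (n, t, c, h, w) coordinates, and the source time step is chosen from the channel's position relative to c1 and c2. A minimal Python sketch of that mapping for the forward pass is shown below; it is not part of the commit, the function name is illustrative, and the zero-fill at the temporal boundaries is assumed to match the zero padding used by the NumPy reference in the test further down:

```python
import numpy as np

def temporal_shift_elementwise(x, seg_num, shift_ratio=0.25):
    """Elementwise view of the shift: for each output element, pick the
    source time step from the channel's position relative to c1/c2."""
    nt, c, h, w = x.shape
    n = nt // seg_num
    c1 = int(c * shift_ratio)
    c2 = int(c * 2 * shift_ratio)
    out = np.zeros_like(x)
    for i_n in range(n):
        for it in range(seg_num):
            for ic in range(c):
                if ic < c1:
                    src_it = it - 1      # first block reads the previous frame
                elif ic < c2:
                    src_it = it + 1      # second block reads the next frame
                else:
                    src_it = it          # remaining channels are untouched
                if 0 <= src_it < seg_num:
                    out[i_n * seg_num + it, ic] = x[i_n * seg_num + src_it, ic]
                # else: out stays zero, matching the zero padding in time
    return out
```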
...@@ -30,12 +30,16 @@ class TemporalShiftKernel: public framework::OpKernel<T> { ...@@ -30,12 +30,16 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
int t = ctx.Attr<int>("seg_num"); int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = input->dims()[0]; const int nt = input->dims()[0];
const int c = input->dims()[1]; const int c = input->dims()[1];
const int h = input->dims()[2]; const int h = input->dims()[2];
const int w = input->dims()[3]; const int w = input->dims()[3];
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
const int hw = h * w; const int hw = h * w;
const int chw = c * hw; const int chw = c * hw;
const int tchw = t * chw; const int tchw = t * chw;
...@@ -51,9 +55,9 @@ class TemporalShiftKernel: public framework::OpKernel<T> { ...@@ -51,9 +55,9 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
int ih = (i % hw) / w; int ih = (i % hw) / w;
int iw = i % w; int iw = i % w;
if (ic < c / 4) { if (ic < c1) {
src_it = it - 1; src_it = it - 1;
} else if (ic < c / 2) { } else if (ic < c2) {
src_it = it + 1; src_it = it + 1;
} else { } else {
src_it = it; src_it = it;
...@@ -76,12 +80,16 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> { ...@@ -76,12 +80,16 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
int t = ctx.Attr<int>("seg_num"); int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = output_grad->dims()[0]; const int nt = output_grad->dims()[0];
const int c = output_grad->dims()[1]; const int c = output_grad->dims()[1];
const int h = output_grad->dims()[2]; const int h = output_grad->dims()[2];
const int w = output_grad->dims()[3]; const int w = output_grad->dims()[3];
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
const int hw = h * w; const int hw = h * w;
const int chw = c * hw; const int chw = c * hw;
const int tchw = t * chw; const int tchw = t * chw;
...@@ -98,9 +106,9 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> { ...@@ -98,9 +106,9 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
int ih = (i % hw) / w; int ih = (i % hw) / w;
int iw = i % w; int iw = i % w;
if (ic < c / 4) { if (ic < c1) {
src_it = it - 1; src_it = it - 1;
} else if (ic < c / 2) { } else if (ic < c2) {
src_it = it + 1; src_it = it + 1;
} else { } else {
src_it = it; src_it = it;
......
@@ -10266,7 +10266,7 @@ def shuffle_channel(x, group, name=None):

 @templatedoc()
-def temporal_shift(x, seg_num, name=None):
+def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
     """
     **Temporal Shift Operator**
@@ -10275,6 +10275,7 @@ def temporal_shift(x, seg_num, name=None):
     Args:
         x(Variable): ${x_comment}
         seg_num(int): ${seg_num_comment}
+        shift_ratio(float): ${shift_ratio_comment}

     Returns:
         out(Variable): The temporal shifting result is a tensor variable with the
@@ -10287,7 +10288,7 @@ def temporal_shift(x, seg_num, name=None):
        .. code-block:: python

            input = fluid.layers.data(name='input', shape=[4,2,2], dtype='float32')
-           out = fluid.layers.temporal_shift(x=input, seg_num=2)
+           out = fluid.layers.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
     """
     helper = LayerHelper("temporal_shift", **locals())
@@ -10300,7 +10301,10 @@ def temporal_shift(x, seg_num, name=None):
         type="temporal_shift",
         inputs={"X": x},
         outputs={"Out": out},
-        attrs={"seg_num": seg_num})
+        attrs={
+            "seg_num": seg_num,
+            "shift_ratio": shift_ratio
+        })
     return out
...
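A minimal end-to-end sketch of the updated Python API is shown below. It is not part of the commit and assumes the legacy fluid 1.x executor workflow; the feed name 'shift_in' is illustrative. The first dimension of the fed array is N*T, and the output keeps the input shape:

```python
import numpy as np
import paddle.fluid as fluid

# Variable shape is [-1, C=4, H=2, W=2]; the batch dimension carries N*T.
x = fluid.layers.data(name='shift_in', shape=[4, 2, 2], dtype='float32')
out = fluid.layers.temporal_shift(x=x, seg_num=2, shift_ratio=0.2)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# N=2 clips with T=2 segments each -> first dimension N*T = 4.
feed_x = np.random.random((4, 4, 2, 2)).astype('float32')
result, = exe.run(fluid.default_main_program(),
                  feed={'shift_in': feed_x},
                  fetch_list=[out])
print(result.shape)  # expected (4, 4, 2, 2), same as the input
```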
@@ -1052,7 +1052,7 @@ class TestBook(unittest.TestCase):
         program = Program()
         with program_guard(program):
             x = layers.data(name="X", shape=[16, 4, 4], dtype="float32")
-            out = layers.temporal_shift(x, seg_num=4)
+            out = layers.temporal_shift(x, seg_num=4, shift_ratio=0.2)
             self.assertIsNotNone(out)
         print(str(program))
...
@@ -21,13 +21,15 @@ from op_test import OpTest
 from paddle.fluid import core


-def temporal_shift(x, seg_num):
+def temporal_shift(x, seg_num, shift_ratio):
     shape = x.shape
     reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
     pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), 'constant')
-    slice1 = pad_x[:, :seg_num, :shape[1]//4, :, :]
-    slice2 = pad_x[:, 2:seg_num+2, shape[1]//4:shape[1]//2, :, :]
-    slice3 = pad_x[:, 1:seg_num+1, shape[1]//2:, :, :]
+    c1 = int(shape[1] * shift_ratio)
+    c2 = int(shape[1] * 2 * shift_ratio)
+    slice1 = pad_x[:, :seg_num, :c1, :, :]
+    slice2 = pad_x[:, 2:seg_num+2, c1:c2, :, :]
+    slice3 = pad_x[:, 1:seg_num+1, c2:, :, :]
     concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
     return concat_x.reshape(shape)
@@ -39,13 +41,14 @@ class TestTemporalShift(OpTest):
         self.attrs = {
             "seg_num": self.seg_num,
+            "shift_ratio": self.shift_ratio,
         }

         self.inputs = {
             "X": x,
         }

-        output = temporal_shift(x, self.seg_num)
+        output = temporal_shift(x, self.seg_num, self.shift_ratio)
         self.outputs = {"Out": output}

     def test_check_output(self):
@@ -57,17 +60,20 @@ class TestTemporalShift(OpTest):
     def initTestCase(self):
         self.x_shape = (6, 4, 4, 4)
         self.seg_num = 3
+        self.shift_ratio = 0.25


 class TestTemporalShift2(TestTemporalShift):
     def initTestCase(self):
         self.x_shape = (4, 9, 7, 7)
         self.seg_num = 2
+        self.shift_ratio = 0.2


 class TestTemporalShift3(TestTemporalShift):
     def initTestCase(self):
         self.x_shape = (3, 10, 5, 5)
         self.seg_num = 1
+        self.shift_ratio = 0.3


 if __name__ == "__main__":
...
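As a quick sanity check on the channel arithmetic above (illustrative only, not part of the test file, and assuming the `temporal_shift` reference function from the diff is in scope): with C = 4 channels and shift_ratio = 0.25, c1 = 1 and c2 = 2, so one channel is shifted backward in time, one forward, and the remaining two are left in place.

```python
import numpy as np

# Tiny worked example: 1 clip, 3 segments, 4 channels, 1x1 spatial size,
# so each element is identified by (t, c) and x[t, c] = 4*t + c.
x = np.arange(12, dtype='float32').reshape((3, 4, 1, 1))
y = temporal_shift(x, seg_num=3, shift_ratio=0.25)

# Channel 0 reads the previous frame (zero at t=0), channel 1 reads the
# next frame (zero at t=T-1), channels 2 and 3 are unchanged.
print(y[:, 0, 0, 0])  # [0., 0., 4.]   shifted by -1 in time
print(y[:, 1, 0, 0])  # [5., 9., 0.]   shifted by +1 in time
print(y[:, 2, 0, 0])  # [2., 6., 10.]  unchanged
```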