Commit 5c1920b7 authored by: D dengkaipeng

add Attr shift_ratio. test=develop

Parent 71101c9c
......@@ -33,8 +33,12 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
"Input(X) rank should be 4 in shape of [N*T, C, H, W].");
int seg_num = ctx->Attrs().Get<int>("seg_num");
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
    PADDLE_ENFORCE_GT(seg_num, 0,
-                     "Attr(seg_num) should be greater then 0.");
+                     "Attr(seg_num) should be greater than 0.");
    PADDLE_ENFORCE(shift_ratio > 0 && shift_ratio < .5,
                   "Attr(shift_ratio) should be greater than 0 and less "
                   "than 0.5.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0,
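Note: a small Python illustration of the shape and attribute contract these checks enforce (toy sizes, not from the patch). The first dimension of Input(X) bundles N and T and must divide evenly by seg_num, and shift_ratio must stay below 0.5 so the two shifted channel slices cannot overflow the channel dimension:

```python
# Input(X) is laid out as [N*T, C, H, W]; toy sizes, not from the patch.
N, T, C, H, W = 2, 3, 8, 7, 7
x_shape = (N * T, C, H, W)        # (6, 8, 7, 7)
seg_num = T
assert x_shape[0] % seg_num == 0  # mirrors the PADDLE_ENFORCE_EQ above

# 0 < shift_ratio < 0.5 keeps c2 = int(2 * shift_ratio * C) below C, so the
# backward-shifted slice [0, c1) and forward-shifted slice [c1, c2) both fit.
shift_ratio = 0.25
assert 0.0 < shift_ratio < 0.5    # mirrors the shift_ratio check above
```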
......@@ -69,6 +73,12 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("seg_num",
"The temporal segment number, this should be a positive "
"interger.");
AddAttr<float>("shift_ratio",
"The shift ratio of the channels, the first shift ratio part "
"of channels will be shifted by -1 along the temporal dimension, "
"and the second shift ratio part of channels will be shifted by "
"1 along the temporal dimension. Default 0.25.")
.SetDefault(0.25);
AddComment(R"DOC(
This operator calculates the temporal shifting features for Input(X).
......@@ -85,7 +95,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
padding width as 1 on each side, padding result will be in shape
of [N, T+2, C, H, W].
-    Step 3: Slice padding result as follows:
+    Step 3: Assuming :attr:`shift_ratio` is :math:`0.25`, slice the padding
+            result as follows:
slice1 = x[:, :T, :C/4, :, :]
slice2 = x[:, 2:T+2, C/4:C/2, :, :]
......
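To make the doc string above concrete, here is a minimal NumPy sketch of the whole reshape/pad/slice/concat procedure; it mirrors the reference implementation in the test file below. The function name `temporal_shift_ref` and the toy input are illustrative, not part of this patch:

```python
import numpy as np

def temporal_shift_ref(x, seg_num, shift_ratio=0.25):
    # Step 1: reshape [N*T, C, H, W] -> [N, T, C, H, W].
    nt, c, h, w = x.shape
    x5 = x.reshape((-1, seg_num, c, h, w))
    # Step 2: zero-pad along T on both sides -> [N, T+2, C, H, W].
    pad = np.pad(x5, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), 'constant')
    c1 = int(c * shift_ratio)
    c2 = int(c * 2 * shift_ratio)
    # Step 3: the first c1 channels read from t-1, the next c2-c1
    # channels read from t+1, and the rest stay at t.
    s1 = pad[:, :seg_num, :c1, :, :]
    s2 = pad[:, 2:seg_num + 2, c1:c2, :, :]
    s3 = pad[:, 1:seg_num + 1, c2:, :, :]
    return np.concatenate([s1, s2, s3], axis=2).reshape(x.shape)

out = temporal_shift_ref(np.random.rand(6, 4, 2, 2).astype('float32'), seg_num=3)
print(out.shape)  # (6, 4, 2, 2)
```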
......@@ -20,7 +20,8 @@ using framework::Tensor;
template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
-                                  const int tchw, const int chw, const int hw, const int w, const int t, const int c) {
+                                  const int tchw, const int chw, const int hw, const int w, const int t, const int c,
+                                  const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
......@@ -31,9 +32,12 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
int ih = (tid % hw) / w;
int iw = tid % w;
-    if (ic < c / 4) {
+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+    if (ic < c1) {
       src_it = it - 1;
-    } else if (ic < c / 2) {
+    } else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
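As a sanity check on the index arithmetic in this kernel, a hedged Python sketch of how a flat offset decomposes into (it, ic, ih, iw); the `it`/`ic` lines are inferred from the truncated hunk, and the helper name is illustrative:

```python
def decompose(tid, tchw, chw, hw, w):
    # Mirrors the kernel: strip the batch, then peel off t, c, h, w.
    it = (tid % tchw) // chw  # temporal index within one segment group
    ic = (tid % chw) // hw    # channel index, compared against c1/c2
    ih = (tid % hw) // w      # height index
    iw = tid % w              # width index
    return it, ic, ih, iw
```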
......@@ -50,7 +54,8 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int ntchw,
-                                  const int tchw, const int chw, const int hw, const int w, const int t, const int c) {
+                                  const int tchw, const int chw, const int hw, const int w, const int t, const int c,
+                                  const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
......@@ -61,9 +66,12 @@ __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int
int ih = (tid % hw) / w;
int iw = tid % w;
-    if (ic < c / 4) {
+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+    if (ic < c1) {
       src_it = it - 1;
-    } else if (ic < c / 2) {
+    } else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
......@@ -85,6 +93,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = input->dims()[0];
const int c = input->dims()[1];
......@@ -105,7 +114,7 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
KeTemporalShiftFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
-        input_data, output_data, ntchw, tchw, chw, hw, w, t, c);
+        input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
}
};
......@@ -116,6 +125,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = output_grad->dims()[0];
const int c = output_grad->dims()[1];
......@@ -139,7 +149,7 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
KeTemporalShiftBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
-        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c);
+        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
}
};
......
......@@ -30,12 +30,16 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
......@@ -51,9 +55,9 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
int ih = (i % hw) / w;
int iw = i % w;
-      if (ic < c / 4) {
+      if (ic < c1) {
         src_it = it - 1;
-      } else if (ic < c / 2) {
+      } else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
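A quick worked example of the c1/c2 boundaries computed above (numbers chosen for illustration only):

```python
c, shift_ratio = 10, 0.25
c1 = int(c * shift_ratio)      # 2 -> channels [0, 2) read from t-1
c2 = int(c * 2 * shift_ratio)  # 5 -> channels [2, 5) read from t+1
# channels [5, 10) are copied through unshifted (src_it = it)
```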
......@@ -76,12 +80,16 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = output_grad->dims()[0];
const int c = output_grad->dims()[1];
const int h = output_grad->dims()[2];
const int w = output_grad->dims()[3];
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
......@@ -98,9 +106,9 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
int ih = (i % hw) / w;
int iw = i % w;
-      if (ic < c / 4) {
+      if (ic < c1) {
         src_it = it - 1;
-      } else if (ic < c / 2) {
+      } else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
......
......@@ -10266,7 +10266,7 @@ def shuffle_channel(x, group, name=None):
@templatedoc()
-def temporal_shift(x, seg_num, name=None):
+def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
"""
**Temporal Shift Operator**
......@@ -10275,6 +10275,7 @@ def temporal_shift(x, seg_num, name=None):
Args:
x(Variable): ${x_comment}
seg_num(int): ${seg_num_comment}
shift_ratio(float): ${shift_ratio_comment}
Returns:
out(Variable): The temporal shifting result is a tensor variable with the
......@@ -10287,7 +10288,7 @@ def temporal_shift(x, seg_num, name=None):
.. code-block:: python
input = fluid.layers.data(name='input', shape=[4,2,2], dtype='float32')
-            out = fluid.layers.temporal_shift(x=input, seg_num=2)
+            out = fluid.layers.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
"""
helper = LayerHelper("temporal_shift", **locals())
......@@ -10300,7 +10301,10 @@ def temporal_shift(x, seg_num, name=None):
type="temporal_shift",
inputs={"X": x},
outputs={"Out": out},
attrs={"seg_num": seg_num})
attrs={
"seg_num": seg_num,
"shift_ratio": shift_ratio
})
return out
......
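For completeness, a runnable end-to-end usage sketch of the updated layer; the feed shape and the fluid 1.x executor boilerplate are assumptions for illustration, not part of this patch:

```python
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[4, 2, 2], dtype='float32')
out = fluid.layers.temporal_shift(x=x, seg_num=2, shift_ratio=0.2)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
# Batch dim is N*T and must be divisible by seg_num; 6 = 3 * 2 here.
data = np.random.random((6, 4, 2, 2)).astype('float32')
res, = exe.run(fluid.default_main_program(),
               feed={'x': data}, fetch_list=[out])
print(res.shape)  # (6, 4, 2, 2)
```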
......@@ -1052,7 +1052,7 @@ class TestBook(unittest.TestCase):
program = Program()
with program_guard(program):
x = layers.data(name="X", shape=[16, 4, 4], dtype="float32")
-            out = layers.temporal_shift(x, seg_num=4)
+            out = layers.temporal_shift(x, seg_num=4, shift_ratio=0.2)
self.assertIsNotNone(out)
print(str(program))
......
......@@ -21,13 +21,15 @@ from op_test import OpTest
from paddle.fluid import core
-def temporal_shift(x, seg_num):
+def temporal_shift(x, seg_num, shift_ratio):
shape = x.shape
reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), 'constant')
-    slice1 = pad_x[:, :seg_num, :shape[1]//4, :, :]
-    slice2 = pad_x[:, 2:seg_num+2, shape[1]//4:shape[1]//2, :, :]
-    slice3 = pad_x[:, 1:seg_num+1, shape[1]//2:, :, :]
+    c1 = int(shape[1] * shift_ratio)
+    c2 = int(shape[1] * 2 * shift_ratio)
+    slice1 = pad_x[:, :seg_num, :c1, :, :]
+    slice2 = pad_x[:, 2:seg_num+2, c1:c2, :, :]
+    slice3 = pad_x[:, 1:seg_num+1, c2:, :, :]
concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
return concat_x.reshape(shape)
......@@ -39,13 +41,14 @@ class TestTemporalShift(OpTest):
self.attrs = {
"seg_num": self.seg_num,
"shift_ratio": self.shift_ratio,
}
self.inputs = {
"X": x,
}
-        output = temporal_shift(x, self.seg_num)
+        output = temporal_shift(x, self.seg_num, self.shift_ratio)
self.outputs = {"Out": output}
def test_check_output(self):
......@@ -57,17 +60,20 @@ class TestTemporalShift(OpTest):
def initTestCase(self):
self.x_shape = (6, 4, 4, 4)
self.seg_num = 3
self.shift_ratio = 0.25
class TestTemporalShift2(TestTemporalShift):
def initTestCase(self):
self.x_shape = (4, 9, 7, 7)
self.seg_num = 2
self.shift_ratio = 0.2
class TestTemporalShift3(TestTemporalShift):
def initTestCase(self):
self.x_shape = (3, 10, 5, 5)
self.seg_num = 1
self.shift_ratio = 0.3
if __name__ == "__main__":
......