未验证 提交 63ac947e 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #16135 from heavengate/shift

Add temporal_shift op for TSM model
......@@ -225,6 +225,7 @@ paddle.fluid.layers.merge_selected_rows (ArgSpec(args=['x', 'name'], varargs=Non
paddle.fluid.layers.get_tensor_from_selected_rows (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7ffc849e71f31dfe29030ff94e662de6'))
paddle.fluid.layers.lstm (ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)), ('document', 'd5e6c494ac35100e2ed4d4bd9a1ed932'))
paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '2fa6782d43d02ae64482d21235a82949'))
paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', 'fe4481fb31363b09cfdd228fc6776ddf'))
paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '8404e472ac12b4a30a505d3d3a3e5fdb'))
paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '1546136806fef5c08f6918544bd9151d'))
paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '2f6ff96864054a31aa4bb659c6722c99'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/temporal_shift_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class TemporalShiftOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of TemporalShiftOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of TemporalShiftOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dim_x.size(), 4,
"Input(X) rank should be 4 in shape of [N*T, C, H, W].");
int seg_num = ctx->Attrs().Get<int>("seg_num");
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5,
"Attr(shift_ratio) should be greater than 0 and less "
"than 0.5.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
dim_x[0] % seg_num, 0,
"Input(X) dims[0] should be divided exactly by Attr(seg_num).");
}
ctx->SetOutputDim("Out", dim_x);
ctx->ShareLoD("X", "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of temporal shift operator. "
"This is a 4-D tensor with shape of [N*T, C, H, W]. "
"While N is the batch size, T is the temporal segment "
"number, C is the channel number, H is the height of "
"features and W is the width of features.");
AddOutput("Out",
"The output tensor of temporal shift operator. "
"This is a 4-D tensor in the same shape with Input(X).");
AddAttr<int>("seg_num",
"The temporal segment number, this should be a positive "
"integer.");
AddAttr<float>(
"shift_ratio",
"The shift ratio of the channels, the first :attr:`shift_ratio` part "
"of channels will be shifted by -1 along the temporal dimension, "
"and the second :attr:`shift_ratio` part of channels will be shifted "
"by 1 along the temporal dimension. Default 0.25.")
.SetDefault(0.25);
AddComment(R"DOC(
This operator calculates the temporal shifting features for Input(X).
Input(X) should be in shape of [N*T, C, H, W], while N is the batch
size, T is the temporal segment number specified by :attr:`seg_num`,
C is the channel number, H and W is the height and width of features.
Temporal Shifting is calculated as follows:
Step 1: Reshape Input(X) to [N, T, C, H, W].
Step 2: Pad 0 to reshaping result in the 2nd(T) dimension with
padding width as 1 on each side, padding result will be in shape
of [N, T+2, C, H, W].
Step 3: Assume :attr:`shift_ratio` is :math:`1/4`, slice padding
result as follows:
$$
slice1 = x[:, :T, :C/4, :, :]
$$
$$
slice2 = x[:, 2:T+2, C/4:C/2, :, :]
$$
$$
slice3 = x[:, 1:T+1, C/2:, :, :]
$$
Step 4: Concatenate three slices along the 3rd(C) dimension and
reshape result to [N*T, C, H, W].
For details of temporal shifting, please refer to paper:
`Temporal Shift Module <http://arxiv.org/abs/1811.08383>`_ .
)DOC");
}
};
class TemporalShiftOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
ops::TemporalShiftOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
ops::TemporalShiftKernel<double>);
REGISTER_OP_CPU_KERNEL(temporal_shift_grad, ops::TemporalShiftGradKernel<float>,
ops::TemporalShiftGradKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/temporal_shift_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
const int tchw, const int chw, const int hw,
const int w, const int t, const int c,
const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < ntchw; tid += stride) {
int in = tid / tchw;
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
int ih = (tid % hw) / w;
int iw = tid % w;
const int c1 = static_cast<T>(c * shift_ratio);
const int c2 = static_cast<T>(c * 2 * shift_ratio);
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[tid] = 0;
} else {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
output[tid] = input[src_idx];
}
}
}
template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
const int ntchw, const int tchw,
const int chw, const int hw, const int w,
const int t, const int c,
const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < ntchw; tid += stride) {
int in = tid / tchw;
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
int ih = (tid % hw) / w;
int iw = tid % w;
const int c1 = static_cast<T>(c * shift_ratio);
const int c2 = static_cast<T>(c * 2 * shift_ratio);
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
input_grad[src_idx] = output_grad[tid];
}
}
}
template <typename T>
class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
int pixelNum = nt * chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeTemporalShiftFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
}
};
template <typename T>
class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = output_grad->dims()[0];
const int c = output_grad->dims()[1];
const int h = output_grad->dims()[2];
const int w = output_grad->dims()[3];
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T>()(
ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
static_cast<T>(0));
int pixelNum = nt * chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeTemporalShiftBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
shift_ratio);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(temporal_shift, ops::TemporalShiftOpCUDAKernel<float>,
ops::TemporalShiftOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(temporal_shift_grad,
ops::TemporalShiftGradOpCUDAKernel<float>,
ops::TemporalShiftGradOpCUDAKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
int iw, const int tchw,
const int chw, const int hw,
const int w) {
return in * tchw + it * chw + ic * hw + ih * w + iw;
}
template <typename T>
class TemporalShiftKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
int src_it = 0;
for (int i = 0; i < output->numel(); i++) {
int in = i / tchw;
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
int ih = (i % hw) / w;
int iw = i % w;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output_data[i] = 0;
} else {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
output_data[i] = input_data[src_idx];
}
}
}
};
template <typename T>
class TemporalShiftGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const int nt = output_grad->dims()[0];
const int c = output_grad->dims()[1];
const int h = output_grad->dims()[2];
const int w = output_grad->dims()[3];
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
int src_it = 0;
for (int i = 0; i < output_grad->numel(); i++) {
int in = i / tchw;
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
int ih = (i % hw) / w;
int iw = i % w;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
input_grad_data[src_idx] = output_grad_data[i];
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -183,6 +183,7 @@ __all__ = [
'get_tensor_from_selected_rows',
'lstm',
'shuffle_channel',
'temporal_shift',
'py_func',
'psroi_pool',
'teacher_student_sigmoid_loss',
......@@ -10391,6 +10392,48 @@ def shuffle_channel(x, group, name=None):
return out
@templatedoc()
def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
"""
**Temporal Shift Operator**
${comment}
Args:
x(Variable): ${x_comment}
seg_num(int): ${seg_num_comment}
shift_ratio(float): ${shift_ratio_comment}
name (str, default None): The name of this layer.
Returns:
out(Variable): The temporal shifting result is a tensor variable with the
same shape and same type as the input.
Raises:
TypeError: seg_num must be int type.
Examples:
.. code-block:: python
input = fluid.layers.data(name='input', shape=[4,2,2], dtype='float32')
out = fluid.layers.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
"""
helper = LayerHelper("temporal_shift", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
if not isinstance(seg_num, int):
raise TypeError("seg_num must be int type.")
helper.append_op(
type="temporal_shift",
inputs={"X": x},
outputs={"Out": out},
attrs={"seg_num": seg_num,
"shift_ratio": shift_ratio})
return out
class PyFuncRegistry(object):
_register_funcs = []
......
......@@ -1593,6 +1593,14 @@ class TestBook(unittest.TestCase):
print(str(program))
def test_temporal_shift(self):
program = Program()
with program_guard(program):
x = layers.data(name="X", shape=[16, 4, 4], dtype="float32")
out = layers.temporal_shift(x, seg_num=4, shift_ratio=0.2)
self.assertIsNotNone(out)
print(str(program))
def test_shuffle_channel(self):
program = Program()
with program_guard(program):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
def temporal_shift(x, seg_num, shift_ratio):
shape = x.shape
reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
'constant')
c1 = int(shape[1] * shift_ratio)
c2 = int(shape[1] * 2 * shift_ratio)
slice1 = pad_x[:, :seg_num, :c1, :, :]
slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :]
slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :]
concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
return concat_x.reshape(shape)
class TestTemporalShift(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'temporal_shift'
x = np.random.random(self.x_shape).astype('float32')
self.attrs = {
"seg_num": self.seg_num,
"shift_ratio": self.shift_ratio,
}
self.inputs = {"X": x, }
output = temporal_shift(x, self.seg_num, self.shift_ratio)
self.outputs = {"Out": output}
def test_check_output(self):
self.check_output()
def test_check_grad_ignore_uv(self):
self.check_grad(['X'], 'Out')
def initTestCase(self):
self.x_shape = (6, 4, 4, 4)
self.seg_num = 3
self.shift_ratio = 0.25
class TestTemporalShift2(TestTemporalShift):
def initTestCase(self):
self.x_shape = (4, 9, 7, 7)
self.seg_num = 2
self.shift_ratio = 0.2
class TestTemporalShift3(TestTemporalShift):
def initTestCase(self):
self.x_shape = (3, 10, 5, 5)
self.seg_num = 1
self.shift_ratio = 0.3
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册