提交 82d4f903 编写于 作者: D dengkaipeng

fix format. test=develop

上级 28949f8e
...@@ -17,7 +17,7 @@ namespace operators { ...@@ -17,7 +17,7 @@ namespace operators {
using framework::Tensor; using framework::Tensor;
class TemporalShiftOp: public framework::OperatorWithKernel { class TemporalShiftOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -29,23 +29,23 @@ class TemporalShiftOp: public framework::OperatorWithKernel { ...@@ -29,23 +29,23 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
"Output(Out) of TemporalShiftOp should not be null."); "Output(Out) of TemporalShiftOp should not be null.");
auto dim_x = ctx->GetInputDim("X"); auto dim_x = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, PADDLE_ENFORCE_EQ(dim_x.size(), 4,
"Input(X) rank should be 4 in shape of [N*T, C, H, W]."); "Input(X) rank should be 4 in shape of [N*T, C, H, W].");
int seg_num = ctx->Attrs().Get<int>("seg_num"); int seg_num = ctx->Attrs().Get<int>("seg_num");
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio"); float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
PADDLE_ENFORCE_GT(seg_num, 0, PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
"Attr(seg_num) should be greater than 0.");
PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5, PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5,
"Attr(shift_ratio) should be greater than 0 and less " "Attr(shift_ratio) should be greater than 0 and less "
"than 0.5."); "than 0.5.");
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0, PADDLE_ENFORCE_EQ(
"Input(X) dims[0] should be divided exactly by Attr(seg_num)."); dim_x[0] % seg_num, 0,
"Input(X) dims[0] should be divided exactly by Attr(seg_num).");
} }
ctx->SetOutputDim("Out", dim_x); ctx->SetOutputDim("Out", dim_x);
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
} }
...@@ -70,14 +70,15 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -70,14 +70,15 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
"The output tensor of temporal shift operator. " "The output tensor of temporal shift operator. "
"This is a 4-D tensor in the same shape with Input(X)."); "This is a 4-D tensor in the same shape with Input(X).");
AddAttr<int>("seg_num", AddAttr<int>("seg_num",
"The temporal segment number, this should be a positive " "The temporal segment number, this should be a positive "
"interger."); "interger.");
AddAttr<float>("shift_ratio", AddAttr<float>(
"The shift ratio of the channels, the first shift ratio part " "shift_ratio",
"of channels will be shifted by -1 along the temporal dimension, " "The shift ratio of the channels, the first shift ratio part "
"and the second shift ratio part of channels will be shifted by " "of channels will be shifted by -1 along the temporal dimension, "
"1 along the temporal dimension. Default 0.25.") "and the second shift ratio part of channels will be shifted by "
"1 along the temporal dimension. Default 0.25.")
.SetDefault(0.25); .SetDefault(0.25);
AddComment(R"DOC( AddComment(R"DOC(
...@@ -118,7 +119,7 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -118,7 +119,7 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class TemporalShiftOpGrad: public framework::OperatorWithKernel { class TemporalShiftOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -144,7 +145,8 @@ class TemporalShiftOpGrad: public framework::OperatorWithKernel { ...@@ -144,7 +145,8 @@ class TemporalShiftOpGrad: public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp, ops::TemporalShiftOpMaker, REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
ops::TemporalShiftOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad); REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>, REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
......
...@@ -17,70 +17,72 @@ namespace operators { ...@@ -17,70 +17,72 @@ namespace operators {
using framework::Tensor; using framework::Tensor;
template <typename T> template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw, __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
const int tchw, const int chw, const int hw, const int w, const int t, const int c, const int tchw, const int chw, const int hw,
const float shift_ratio) { const int w, const int t, const int c,
const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
int src_it = 0; int src_it = 0;
for (; tid < ntchw; tid += stride) { for (; tid < ntchw; tid += stride) {
int in = tid / tchw; int in = tid / tchw;
int it = (tid % tchw) / chw; int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw; int ic = (tid % chw) / hw;
int ih = (tid % hw) / w; int ih = (tid % hw) / w;
int iw = tid % w; int iw = tid % w;
const int c1 = static_cast<T>(c * shift_ratio); const int c1 = static_cast<T>(c * shift_ratio);
const int c2 = static_cast<T>(c * 2 * shift_ratio); const int c2 = static_cast<T>(c * 2 * shift_ratio);
if (ic < c1) { if (ic < c1) {
src_it = it - 1; src_it = it - 1;
} else if (ic < c2) { } else if (ic < c2) {
src_it = it + 1; src_it = it + 1;
} else { } else {
src_it = it; src_it = it;
} }
if (src_it < 0 || src_it >= t) { if (src_it < 0 || src_it >= t) {
output[tid] = 0; output[tid] = 0;
} else { } else {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w); int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
output[tid] = input[src_idx]; output[tid] = input[src_idx];
} }
} }
} }
template <typename T> template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int ntchw, __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
const int tchw, const int chw, const int hw, const int w, const int t, const int c, const int ntchw, const int tchw,
const float shift_ratio) { const int chw, const int hw, const int w,
const int t, const int c,
const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
int src_it = 0; int src_it = 0;
for (; tid < ntchw; tid += stride) { for (; tid < ntchw; tid += stride) {
int in = tid / tchw; int in = tid / tchw;
int it = (tid % tchw) / chw; int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw; int ic = (tid % chw) / hw;
int ih = (tid % hw) / w; int ih = (tid % hw) / w;
int iw = tid % w; int iw = tid % w;
const int c1 = static_cast<T>(c * shift_ratio); const int c1 = static_cast<T>(c * shift_ratio);
const int c2 = static_cast<T>(c * 2 * shift_ratio); const int c2 = static_cast<T>(c * 2 * shift_ratio);
if (ic < c1) { if (ic < c1) {
src_it = it - 1; src_it = it - 1;
} else if (ic < c2) { } else if (ic < c2) {
src_it = it + 1; src_it = it + 1;
} else { } else {
src_it = it; src_it = it;
} }
if (src_it >= 0 && src_it < t) { if (src_it >= 0 && src_it < t) {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w); int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
input_grad[src_idx] = output_grad[tid]; input_grad[src_idx] = output_grad[tid];
} }
} }
} }
...@@ -113,8 +115,8 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> { ...@@ -113,8 +115,8 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
grid_dim = grid_dim > 8 ? 8 : grid_dim; grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeTemporalShiftFw< KeTemporalShiftFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio); input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
} }
}; };
...@@ -138,7 +140,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -138,7 +140,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
const int ntchw = nt * chw; const int ntchw = nt * chw;
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace()); T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T>()( math::SetConstant<platform::CUDADeviceContext, T>()(
ctx.template device_context<platform::CUDADeviceContext>(), input_grad, ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
static_cast<T>(0)); static_cast<T>(0));
...@@ -148,8 +151,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -148,8 +151,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
grid_dim = grid_dim > 8 ? 8 : grid_dim; grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeTemporalShiftBw< KeTemporalShiftBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio); output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
shift_ratio);
} }
}; };
......
...@@ -18,13 +18,15 @@ namespace operators { ...@@ -18,13 +18,15 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih, int iw, static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
const int tchw, const int chw, const int hw, const int w) { int iw, const int tchw,
const int chw, const int hw,
const int w) {
return in * tchw + it * chw + ic * hw + ih * w + iw; return in * tchw + it * chw + ic * hw + ih * w + iw;
} }
template <typename T> template <typename T>
class TemporalShiftKernel: public framework::OpKernel<T> { class TemporalShiftKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
...@@ -62,7 +64,7 @@ class TemporalShiftKernel: public framework::OpKernel<T> { ...@@ -62,7 +64,7 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
} else { } else {
src_it = it; src_it = it;
} }
if (src_it < 0 || src_it >= t) { if (src_it < 0 || src_it >= t) {
output_data[i] = 0; output_data[i] = 0;
} else { } else {
...@@ -95,7 +97,8 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> { ...@@ -95,7 +97,8 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
const int tchw = t * chw; const int tchw = t * chw;
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace()); T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
int src_it = 0; int src_it = 0;
...@@ -113,7 +116,7 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> { ...@@ -113,7 +116,7 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
} else { } else {
src_it = it; src_it = it;
} }
if (src_it >= 0 && src_it < t) { if (src_it >= 0 && src_it < t) {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w); int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
input_grad_data[src_idx] = output_grad_data[i]; input_grad_data[src_idx] = output_grad_data[i];
......
...@@ -10301,10 +10301,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None): ...@@ -10301,10 +10301,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
type="temporal_shift", type="temporal_shift",
inputs={"X": x}, inputs={"X": x},
outputs={"Out": out}, outputs={"Out": out},
attrs={ attrs={"seg_num": seg_num,
"seg_num": seg_num, "shift_ratio": shift_ratio})
"shift_ratio": shift_ratio
})
return out return out
......
...@@ -24,15 +24,17 @@ from paddle.fluid import core ...@@ -24,15 +24,17 @@ from paddle.fluid import core
def temporal_shift(x, seg_num, shift_ratio): def temporal_shift(x, seg_num, shift_ratio):
shape = x.shape shape = x.shape
reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3])) reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), 'constant') pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
'constant')
c1 = int(shape[1] * shift_ratio) c1 = int(shape[1] * shift_ratio)
c2 = int(shape[1] * 2 * shift_ratio) c2 = int(shape[1] * 2 * shift_ratio)
slice1 = pad_x[:, :seg_num, :c1, :, :] slice1 = pad_x[:, :seg_num, :c1, :, :]
slice2 = pad_x[:, 2:seg_num+2, c1:c2, :, :] slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :]
slice3 = pad_x[:, 1:seg_num+1, c2:, :, :] slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :]
concat_x = np.concatenate([slice1, slice2, slice3], axis=2) concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
return concat_x.reshape(shape) return concat_x.reshape(shape)
class TestTemporalShift(OpTest): class TestTemporalShift(OpTest):
def setUp(self): def setUp(self):
self.initTestCase() self.initTestCase()
...@@ -44,9 +46,7 @@ class TestTemporalShift(OpTest): ...@@ -44,9 +46,7 @@ class TestTemporalShift(OpTest):
"shift_ratio": self.shift_ratio, "shift_ratio": self.shift_ratio,
} }
self.inputs = { self.inputs = {"X": x, }
"X": x,
}
output = temporal_shift(x, self.seg_num, self.shift_ratio) output = temporal_shift(x, self.seg_num, self.shift_ratio)
self.outputs = {"Out": output} self.outputs = {"Out": output}
...@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest): ...@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
self.seg_num = 3 self.seg_num = 3
self.shift_ratio = 0.25 self.shift_ratio = 0.25
class TestTemporalShift2(TestTemporalShift): class TestTemporalShift2(TestTemporalShift):
def initTestCase(self): def initTestCase(self):
self.x_shape = (4, 9, 7, 7) self.x_shape = (4, 9, 7, 7)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册