From 7f50bb7ec162c42285d3822e643c93685a9c917e Mon Sep 17 00:00:00 2001
From: Zhang Ting
Date: Wed, 17 Mar 2021 19:22:29 +0800
Subject: [PATCH] support NHWC for temporal_shift op (#31642)

---
 paddle/fluid/operators/temporal_shift_op.cc        |  19 +-
 paddle/fluid/operators/temporal_shift_op.cu        | 179 +++++++++++----
 paddle/fluid/operators/temporal_shift_op.h         | 211 ++++++++++++------
 python/paddle/fluid/layers/nn.py                   |  22 +-
 .../tests/unittests/test_temporal_shift_op.py      |  33 ++-
 5 files changed, 338 insertions(+), 126 deletions(-)

diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc
index 2e87447ed1..acf99d09ff 100644
--- a/paddle/fluid/operators/temporal_shift_op.cc
+++ b/paddle/fluid/operators/temporal_shift_op.cc
@@ -80,7 +80,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput("X",
              "The input tensor of temporal shift operator. "
-             "This is a 4-D tensor with shape of [N*T, C, H, W]. "
+             "This is a 4-D tensor with shape of [N*T, C, H, W] "
+             "or [N*T, H, W, C]. "
              "While N is the batch size, T is the temporal segment "
              "number, C is the channel number, H is the height of "
             "features and W is the width of features. "
@@ -100,15 +101,23 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
         "by 1 along the temporal dimension. :attr:`shift_ratio` should be in "
         "range [0, 0.5]. Default 0.25.")
         .SetDefault(0.25);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default NCHW) "
+        "An optional string from: \"NHWC\", \"NCHW\". "
+        "Specify whether the data format of the input and output data is "
+        "channel_first or channel_last.")
+        .SetDefault("NCHW");

     AddComment(R"DOC(
           This operator calculates the temporal shifting features for Input(X).

-          Input(X) should be in shape of [N*T, C, H, W], while N is the batch
-          size, T is the temporal segment number specified by :attr:`seg_num`,
-          C is the channel number, H and W is the height and width of features.
+          Input(X) should be in shape of [N*T, C, H, W] or [N*T, H, W, C], while
+          N is the batch size, T is the temporal segment number specified by
+          :attr:`seg_num`, C is the channel number, H and W are the height and
+          width of features.

-          Temporal Shifting is calculated as follows:
+          Temporal Shifting is calculated as follows when data format is NCHW:

           Step 1: Reshape Input(X) to [N, T, C, H, W].
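The DOC comment above describes the shifting rule in prose. As a reference for reading the kernels that follow, here is a minimal NumPy sketch of the NCHW semantics, mirroring the reference implementation in the unit test at the end of this patch (the function name is illustrative): the first c1 channels read from the previous segment, channels [c1, c2) read from the next segment, and out-of-range reads produce zeros.

    import numpy as np

    def temporal_shift_nchw_ref(x, seg_num, shift_ratio):
        # x: [N*T, C, H, W] -> view as [N, T, C, H, W]
        nt, c, h, w = x.shape
        xr = x.reshape(nt // seg_num, seg_num, c, h, w)
        c1 = int(c * shift_ratio)
        c2 = int(c * 2 * shift_ratio)
        out = np.zeros_like(xr)
        out[:, 1:, :c1] = xr[:, :-1, :c1]      # channels [0, c1): read from t-1
        out[:, :-1, c1:c2] = xr[:, 1:, c1:c2]  # channels [c1, c2): read from t+1
        out[:, :, c2:] = xr[:, :, c2:]         # remaining channels: unshifted
        return out.reshape(nt, c, h, w)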
diff --git a/paddle/fluid/operators/temporal_shift_op.cu b/paddle/fluid/operators/temporal_shift_op.cu
index 4f2d7ce3cf..cb1ff5335c 100644
--- a/paddle/fluid/operators/temporal_shift_op.cu
+++ b/paddle/fluid/operators/temporal_shift_op.cu
@@ -19,22 +19,46 @@ namespace operators {
 using framework::Tensor;

 template <typename T>
-__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
-                                  const int tchw, const int chw, const int hw,
-                                  const int w, const int t, const int c,
-                                  const float shift_ratio) {
+__global__ void KeTemporalShiftFwNCHW(const T* input, T* output,
+                                      const int ntchw, const int tchw,
+                                      const int chw, const int hw, const int t,
+                                      const int c1, const int c2) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   int src_it = 0;
+
   for (; tid < ntchw; tid += stride) {
-    int in = tid / tchw;
     int it = (tid % tchw) / chw;
     int ic = (tid % chw) / hw;
-    int ih = (tid % hw) / w;
-    int iw = tid % w;

-    const int c1 = static_cast<int>(c * shift_ratio);
-    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+    if (ic < c1) {
+      src_it = it - 1;
+    } else if (ic < c2) {
+      src_it = it + 1;
+    } else {
+      src_it = it;
+    }
+
+    if (src_it < 0 || src_it >= t) {
+      output[tid] = 0;
+    } else {
+      output[tid] = input[tid + (src_it - it) * chw];
+    }
+  }
+}
+
+template <typename T>
+__global__ void KeTemporalShiftFwNHWC(const T* input, T* output,
+                                      const int nthwc, const int thwc,
+                                      const int hwc, const int t, const int c,
+                                      const int c1, const int c2) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int src_it = 0;
+
+  for (; tid < nthwc; tid += stride) {
+    int it = (tid % thwc) / hwc;
+    int ic = tid % c;

     if (ic < c1) {
       src_it = it - 1;
@@ -47,42 +71,65 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
     if (src_it < 0 || src_it >= t) {
       output[tid] = 0;
     } else {
-      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
-      output[tid] = input[src_idx];
+      output[tid] = input[tid + (src_it - it) * hwc];
     }
   }
 }

 template <typename T>
-__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
-                                  const int ntchw, const int tchw,
-                                  const int chw, const int hw, const int w,
-                                  const int t, const int c,
-                                  const float shift_ratio) {
+__global__ void KeTemporalShiftBwNCHW(const T* output_grad, T* input_grad,
+                                      const int ntchw, const int tchw,
+                                      const int chw, const int hw, const int t,
+                                      const int c1, const int c2) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   int src_it = 0;
+
   for (; tid < ntchw; tid += stride) {
-    int in = tid / tchw;
     int it = (tid % tchw) / chw;
     int ic = (tid % chw) / hw;
-    int ih = (tid % hw) / w;
-    int iw = tid % w;
-
-    const int c1 = static_cast<int>(c * shift_ratio);
-    const int c2 = static_cast<int>(c * 2 * shift_ratio);

     if (ic < c1) {
-      src_it = it - 1;
+      src_it = it + 1;
     } else if (ic < c2) {
+      src_it = it - 1;
+    } else {
+      src_it = it;
+    }
+
+    if (src_it >= 0 && src_it < t) {
+      input_grad[tid] = output_grad[tid + (src_it - it) * chw];
+    } else {
+      input_grad[tid] = 0;
+    }
+  }
+}
+
+template <typename T>
+__global__ void KeTemporalShiftBwNHWC(const T* output_grad, T* input_grad,
+                                      const int nthwc, const int thwc,
+                                      const int hwc, const int t, const int c,
+                                      const int c1, const int c2) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int src_it = 0;
+
+  for (; tid < nthwc; tid += stride) {
+    int it = (tid % thwc) / hwc;
+    int ic = tid % c;
+
+    if (ic < c1) {
       src_it = it + 1;
+    } else if (ic < c2) {
+      src_it = it - 1;
     } else {
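All four kernels rely on the same flat-index identity: in a contiguous [N, T, C, H, W] (respectively [N, T, H, W, C]) view, moving one temporal step at a fixed channel/spatial position shifts the linear index by exactly chw (respectively hwc) elements, which is why `input[tid + (src_it - it) * chw]` needs no full index decomposition. A quick NumPy check of that identity (shapes are arbitrary):

    import numpy as np

    # For a contiguous [N, T, C, H, W] view, stepping one unit in T at a
    # fixed (n, c, h, w) moves the flat index by exactly chw = C*H*W.
    n, t, c, h, w = 2, 3, 4, 5, 5
    x = np.arange(n * t * c * h * w).reshape(n, t, c, h, w)
    chw = c * h * w
    idx = np.ravel_multi_index((1, 1, 2, 3, 4), x.shape)
    assert x.ravel()[idx + chw] == x[1, 2, 2, 3, 4]  # same position, t+1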
       src_it = it;
     }

     if (src_it >= 0 && src_it < t) {
-      int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
-      input_grad[src_idx] = output_grad[tid];
+      input_grad[tid] = output_grad[tid + (src_it - it) * hwc];
+    } else {
+      input_grad[tid] = 0;
     }
   }
 }
@@ -98,27 +145,48 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<Tensor>("Out");
     int t = ctx.Attr<int>("seg_num");
     float shift_ratio = ctx.Attr<float>("shift_ratio");
+    const std::string data_format_str = ctx.Attr<std::string>("data_format");
+    const DataLayout data_layout =
+        framework::StringToDataLayout(data_format_str);

     const int nt = input->dims()[0];
-    const int c = input->dims()[1];
-    const int h = input->dims()[2];
-    const int w = input->dims()[3];
+    const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1]
+                                                    : input->dims()[3]);
+    const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2]
+                                                    : input->dims()[1]);
+    const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3]
+                                                    : input->dims()[2]);

     const int hw = h * w;
     const int chw = c * hw;
     const int tchw = t * chw;
     const int ntchw = nt * chw;

+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+
+    framework::DDim out_dims = (data_layout == DataLayout::kNCHW
+                                    ? framework::make_ddim({nt, c, h, w})
+                                    : framework::make_ddim({nt, h, w, c}));
     const T* input_data = input->data<T>();
-    T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
+    T* output_data = output->mutable_data<T>(out_dims, ctx.GetPlace());

     int pixelNum = nt * chw;
-    platform::GpuLaunchConfig config =
-        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
+    int threads = 1024;
+    int grid = (pixelNum + threads - 1) / threads;
+    const auto& dev_ctx = ctx.cuda_device_context();
+    int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads;
+    grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);

-    KeTemporalShiftFw<T><<<config.block_per_grid, config.thread_per_block, 0,
-                           ctx.cuda_device_context().stream()>>>(
-        input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
+    if (data_layout == DataLayout::kNCHW) {
+      KeTemporalShiftFwNCHW<
+          T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
+          input_data, output_data, ntchw, tchw, chw, hw, t, c1, c2);
+    } else {
+      KeTemporalShiftFwNHWC<
+          T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
+          input_data, output_data, ntchw, tchw, chw, t, c, c1, c2);
+    }
   }
 };
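Note that the new launch configuration caps `grid` at roughly what the device can keep resident (SM count times blocks per SM) instead of covering every one of the `pixelNum` elements with a dedicated thread; that is safe only because the kernels iterate with a grid-stride loop. A small Python model of why a capped grid still covers every element (the numbers are illustrative):

    # Thread `tid` handles tid, tid + stride, tid + 2 * stride, ...
    total, threads, grid = 10_000, 1024, 4   # grid deliberately too small
    stride = threads * grid
    covered = set()
    for tid in range(stride):                # one pass per resident thread
        while tid < total:
            covered.add(tid)
            tid += stride
    assert len(covered) == total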
@@ -130,32 +198,49 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
     auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     int t = ctx.Attr<int>("seg_num");
     float shift_ratio = ctx.Attr<float>("shift_ratio");
+    const std::string data_format_str = ctx.Attr<std::string>("data_format");
+    const DataLayout data_layout =
+        framework::StringToDataLayout(data_format_str);

     const int nt = output_grad->dims()[0];
-    const int c = output_grad->dims()[1];
-    const int h = output_grad->dims()[2];
-    const int w = output_grad->dims()[3];
+    const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
+                                                    : output_grad->dims()[3]);
+    const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
+                                                    : output_grad->dims()[1]);
+    const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
+                                                    : output_grad->dims()[2]);

     const int hw = h * w;
     const int chw = c * hw;
     const int tchw = t * chw;
     const int ntchw = nt * chw;

+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+
+    framework::DDim in_grad_dims = (data_layout == DataLayout::kNCHW
+                                        ? framework::make_ddim({nt, c, h, w})
+                                        : framework::make_ddim({nt, h, w, c}));
     const T* output_grad_data = output_grad->data<T>();
     T* input_grad_data =
-        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
-    math::SetConstant<platform::CUDADeviceContext, T>()(
-        ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
-        static_cast<T>(0));
+        input_grad->mutable_data<T>(in_grad_dims, ctx.GetPlace());

     int pixelNum = nt * chw;
-    platform::GpuLaunchConfig config =
-        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
+    int threads = 1024;
+    int grid = (pixelNum + threads - 1) / threads;
+    const auto& dev_ctx = ctx.cuda_device_context();
+    int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads;
+    grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);

-    KeTemporalShiftBw<T><<<config.block_per_grid, config.thread_per_block, 0,
-                           ctx.cuda_device_context().stream()>>>(
-        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
-        shift_ratio);
+    if (data_layout == DataLayout::kNCHW) {
+      KeTemporalShiftBwNCHW<
+          T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
+          output_grad_data, input_grad_data, ntchw, tchw, chw, hw, t, c1, c2);
+    } else {
+      KeTemporalShiftBwNHWC<
+          T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
+          output_grad_data, input_grad_data, ntchw, tchw, chw, t, c, c1, c2);
+    }
   }
 };
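The backward kernels are the forward kernels with the two shift directions swapped (`it - 1` and `it + 1` exchanged), and scattered writes are replaced by gathered reads so no SetConstant/memset pre-pass is needed. The swap is exactly the adjoint of the forward permutation: each shift is a zero-padded permutation, so its transpose is the reverse shift. A NumPy sketch checking the adjoint identity <Ax, g> = <x, A'g> on a reduced [N, T, C] layout (helper names are illustrative):

    import numpy as np

    def shift_fw(x, c1, c2):
        # forward shift along T for [N, T, C] (H, W folded away for brevity)
        out = np.zeros_like(x)
        out[:, 1:, :c1] = x[:, :-1, :c1]
        out[:, :-1, c1:c2] = x[:, 1:, c1:c2]
        out[:, :, c2:] = x[:, :, c2:]
        return out

    def shift_bw(g, c1, c2):
        # backward: the same pattern with the temporal directions swapped
        out = np.zeros_like(g)
        out[:, :-1, :c1] = g[:, 1:, :c1]
        out[:, 1:, c1:c2] = g[:, :-1, c1:c2]
        out[:, :, c2:] = g[:, :, c2:]
        return out

    rng = np.random.default_rng(0)
    x, g = rng.standard_normal((2, 4, 8)), rng.standard_normal((2, 4, 8))
    assert np.isclose(np.vdot(shift_fw(x, 2, 4), g),
                      np.vdot(x, shift_bw(g, 2, 4)))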
diff --git a/paddle/fluid/operators/temporal_shift_op.h b/paddle/fluid/operators/temporal_shift_op.h
index 4c7eed5af4..05364b94c9 100644
--- a/paddle/fluid/operators/temporal_shift_op.h
+++ b/paddle/fluid/operators/temporal_shift_op.h
@@ -17,12 +17,106 @@ namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
+using DataLayout = framework::DataLayout;

-static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
-                                           int iw, const int tchw,
-                                           const int chw, const int hw,
-                                           const int w) {
-  return in * tchw + it * chw + ic * hw + ih * w + iw;
+template <typename T>
+void TemporalShiftFwNCHW(const T* input, T* output, const int ntchw,
+                         const int tchw, const int chw, const int hw,
+                         const int t, const int c1, const int c2) {
+  int src_it = 0;
+  for (int i = 0; i < ntchw; i++) {
+    int it = (i % tchw) / chw;
+    int ic = (i % chw) / hw;
+
+    if (ic < c1) {
+      src_it = it - 1;
+    } else if (ic < c2) {
+      src_it = it + 1;
+    } else {
+      src_it = it;
+    }
+
+    if (src_it < 0 || src_it >= t) {
+      output[i] = 0;
+    } else {
+      output[i] = input[i + (src_it - it) * chw];
+    }
+  }
+}
+
+template <typename T>
+void TemporalShiftFwNHWC(const T* input, T* output, const int nthwc,
+                         const int thwc, const int hwc, const int t,
+                         const int c, const int c1, const int c2) {
+  int src_it = 0;
+  for (int i = 0; i < nthwc; i++) {
+    int it = (i % thwc) / hwc;
+    int ic = i % c;
+
+    if (ic < c1) {
+      src_it = it - 1;
+    } else if (ic < c2) {
+      src_it = it + 1;
+    } else {
+      src_it = it;
+    }
+
+    if (src_it < 0 || src_it >= t) {
+      output[i] = 0;
+    } else {
+      output[i] = input[i + (src_it - it) * hwc];
+    }
+  }
+}
+
+template <typename T>
+void TemporalShiftBwNCHW(const T* output_grad, T* input_grad, const int ntchw,
+                         const int tchw, const int chw, const int hw,
+                         const int t, const int c1, const int c2) {
+  int src_it = 0;
+  for (int i = 0; i < ntchw; i++) {
+    int it = (i % tchw) / chw;
+    int ic = (i % chw) / hw;
+
+    if (ic < c1) {
+      src_it = it + 1;
+    } else if (ic < c2) {
+      src_it = it - 1;
+    } else {
+      src_it = it;
+    }
+
+    if (src_it >= 0 && src_it < t) {
+      input_grad[i] = output_grad[i + (src_it - it) * chw];
+    } else {
+      input_grad[i] = 0;
+    }
+  }
+}
+
+template <typename T>
+void TemporalShiftBwNHWC(const T* output_grad, T* input_grad, const int nthwc,
+                         const int thwc, const int hwc, const int t,
+                         const int c, const int c1, const int c2) {
+  int src_it = 0;
+  for (int i = 0; i < nthwc; i++) {
+    int it = (i % thwc) / hwc;
+    int ic = i % c;
+
+    if (ic < c1) {
+      src_it = it + 1;
+    } else if (ic < c2) {
+      src_it = it - 1;
+    } else {
+      src_it = it;
+    }
+
+    if (src_it >= 0 && src_it < t) {
+      input_grad[i] = output_grad[i + (src_it - it) * hwc];
+    } else {
+      input_grad[i] = 0;
+    }
+  }
 }

 template <typename T>
@@ -33,44 +127,38 @@ class TemporalShiftKernel : public framework::OpKernel<T> {
     auto* output = ctx.Output<Tensor>("Out");
     int t = ctx.Attr<int>("seg_num");
     float shift_ratio = ctx.Attr<float>("shift_ratio");
+    const std::string data_format_str = ctx.Attr<std::string>("data_format");
+    const DataLayout data_layout =
+        framework::StringToDataLayout(data_format_str);

     const int nt = input->dims()[0];
-    const int c = input->dims()[1];
-    const int h = input->dims()[2];
-    const int w = input->dims()[3];
-
-    const int c1 = static_cast<int>(c * shift_ratio);
-    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+    const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1]
+                                                    : input->dims()[3]);
+    const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2]
+                                                    : input->dims()[1]);
+    const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3]
+                                                    : input->dims()[2]);

     const int hw = h * w;
     const int chw = c * hw;
     const int tchw = t * chw;
+    const int ntchw = nt * chw;
+
+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+
+    framework::DDim out_dims = (data_layout == DataLayout::kNCHW
+                                    ? framework::make_ddim({nt, c, h, w})
+                                    : framework::make_ddim({nt, h, w, c}));

     const T* input_data = input->data<T>();
-    T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
-
-    int src_it = 0;
-    for (int i = 0; i < output->numel(); i++) {
-      int in = i / tchw;
-      int it = (i % tchw) / chw;
-      int ic = (i % chw) / hw;
-      int ih = (i % hw) / w;
-      int iw = i % w;
-
-      if (ic < c1) {
-        src_it = it - 1;
-      } else if (ic < c2) {
-        src_it = it + 1;
-      } else {
-        src_it = it;
-      }
-
-      if (src_it < 0 || src_it >= t) {
-        output_data[i] = 0;
-      } else {
-        int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
-        output_data[i] = input_data[src_idx];
-      }
+    T* output_data = output->mutable_data<T>(out_dims, ctx.GetPlace());
+
+    if (data_layout == DataLayout::kNCHW) {
+      TemporalShiftFwNCHW<T>(input_data, output_data, ntchw, tchw, chw, hw, t,
+                             c1, c2);
+    } else {
+      TemporalShiftFwNHWC<T>(input_data, output_data, ntchw, tchw, chw, t, c,
+                             c1, c2);
     }
   }
 };
@@ -83,44 +171,39 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
     auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     int t = ctx.Attr<int>("seg_num");
     float shift_ratio = ctx.Attr<float>("shift_ratio");
+    const std::string data_format_str = ctx.Attr<std::string>("data_format");
+    const DataLayout data_layout =
+        framework::StringToDataLayout(data_format_str);

     const int nt = output_grad->dims()[0];
-    const int c = output_grad->dims()[1];
-    const int h = output_grad->dims()[2];
-    const int w = output_grad->dims()[3];
-
-    const int c1 = static_cast<int>(c * shift_ratio);
-    const int c2 = static_cast<int>(c * 2 * shift_ratio);
+    const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
+                                                    : output_grad->dims()[3]);
+    const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
+                                                    : output_grad->dims()[1]);
+    const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
+                                                    : output_grad->dims()[2]);

     const int hw = h * w;
     const int chw = c * hw;
     const int tchw = t * chw;
+    const int ntchw = nt * chw;
+
+    const int c1 = static_cast<int>(c * shift_ratio);
+    const int c2 = static_cast<int>(c * 2 * shift_ratio);

+    framework::DDim in_grad_dims = (data_layout == DataLayout::kNCHW
+                                        ? framework::make_ddim({nt, c, h, w})
+                                        : framework::make_ddim({nt, h, w, c}));
     const T* output_grad_data = output_grad->data<T>();
     T* input_grad_data =
-        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
-    memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
-
-    int src_it = 0;
-    for (int i = 0; i < output_grad->numel(); i++) {
-      int in = i / tchw;
-      int it = (i % tchw) / chw;
-      int ic = (i % chw) / hw;
-      int ih = (i % hw) / w;
-      int iw = i % w;
-
-      if (ic < c1) {
-        src_it = it - 1;
-      } else if (ic < c2) {
-        src_it = it + 1;
-      } else {
-        src_it = it;
-      }
-
-      if (src_it >= 0 && src_it < t) {
-        int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
-        input_grad_data[src_idx] = output_grad_data[i];
-      }
+        input_grad->mutable_data<T>(in_grad_dims, ctx.GetPlace());
+
+    if (data_layout == DataLayout::kNCHW) {
+      TemporalShiftBwNCHW<T>(output_grad_data, input_grad_data, ntchw, tchw,
+                             chw, hw, t, c1, c2);
+    } else {
+      TemporalShiftBwNHWC<T>(output_grad_data, input_grad_data, ntchw, tchw,
+                             chw, t, c, c1, c2);
     }
   }
 };
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 8d96e46f83..fa8df14c86 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -13334,7 +13334,7 @@ def shuffle_channel(x, group, name=None):


 @templatedoc()
-def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
+def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
     """
     **Temporal Shift Operator**

@@ -13348,6 +13348,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
         name(str, optional): For detailed information, please refer
             to :ref:`api_guide_Name`. Usually name is no need to set and
             None by default.
+        data_format(str, optional): Data format that specifies the layout of input.
+            It can be "NCHW" or "NHWC". Default: "NCHW".

     Returns:
         out(Tensor): The temporal shifting result is a tensor with the
@@ -13365,6 +13367,13 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
             input = paddle.randn([6, 4, 2, 2])
             out = F.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
     """
+    if data_format not in ["NCHW", "NHWC"]:
+        raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
+                         "Received Attr(data_format): {}.".format(data_format))
+    if in_dygraph_mode():
+        return core.ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
+                                       shift_ratio, 'data_format', data_format)
+
     helper = LayerHelper("temporal_shift", **locals())
     check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'temporal_shift')
     check_type(seg_num, 'seg_num', int, 'temporal_shift')
@@ -13375,16 +13384,15 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
     if not isinstance(seg_num, int):
         raise TypeError("seg_num must be int type.")

-    if in_dygraph_mode():
-        return core.ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
-                                       shift_ratio)
-
     helper.append_op(
         type="temporal_shift",
         inputs={"X": x},
         outputs={"Out": out},
-        attrs={"seg_num": seg_num,
-               "shift_ratio": shift_ratio})
+        attrs={
+            "seg_num": seg_num,
+            "shift_ratio": shift_ratio,
+            "data_format": data_format
+        })
     return out
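With the Python API change above, a channel-last tensor can be shifted without transposing it first. A usage sketch (shapes are arbitrary):

    import paddle
    import paddle.nn.functional as F

    # 2 videos * 3 segments of 5x5 frames with 4 channels, channel-last layout
    input = paddle.randn([6, 5, 5, 4])
    out = F.temporal_shift(x=input, seg_num=3, shift_ratio=0.25,
                           data_format="NHWC")
    print(out.shape)  # [6, 5, 5, 4]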
" + "Received Attr(data_format): {}.".format(data_format)) + if in_dygraph_mode(): + return core.ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio', + shift_ratio, 'data_format', data_format) + helper = LayerHelper("temporal_shift", **locals()) check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'temporal_shift') check_type(seg_num, 'seg_num', int, 'temporal_shift') @@ -13375,16 +13384,15 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None): if not isinstance(seg_num, int): raise TypeError("seg_num must be int type.") - if in_dygraph_mode(): - return core.ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio', - shift_ratio) - helper.append_op( type="temporal_shift", inputs={"X": x}, outputs={"Out": out}, - attrs={"seg_num": seg_num, - "shift_ratio": shift_ratio}) + attrs={ + "seg_num": seg_num, + "shift_ratio": shift_ratio, + "data_format": data_format + }) return out diff --git a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py index 050c38e549..5bab4a52bf 100644 --- a/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py +++ b/python/paddle/fluid/tests/unittests/test_temporal_shift_op.py @@ -22,7 +22,9 @@ import paddle from paddle.fluid import core -def temporal_shift(x, seg_num, shift_ratio): +def temporal_shift(x, seg_num, shift_ratio, data_format): + if data_format == "NHWC": + x = np.transpose(x, (0, 3, 1, 2)) shape = x.shape reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3])) pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), @@ -33,7 +35,10 @@ def temporal_shift(x, seg_num, shift_ratio): slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :] slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :] concat_x = np.concatenate([slice1, slice2, slice3], axis=2) - return concat_x.reshape(shape) + out = concat_x.reshape(shape) + if data_format == "NHWC": + out = np.transpose(out, (0, 2, 3, 1)) + return out class TestTemporalShift(OpTest): @@ -45,11 +50,13 @@ class TestTemporalShift(OpTest): self.attrs = { "seg_num": self.seg_num, "shift_ratio": self.shift_ratio, + "data_format": self.data_format } self.inputs = {"X": x, } - output = temporal_shift(x, self.seg_num, self.shift_ratio) + output = temporal_shift(x, self.seg_num, self.shift_ratio, + self.data_format) self.outputs = {"Out": output} def test_check_output(self): @@ -63,6 +70,7 @@ class TestTemporalShift(OpTest): self.seg_num = 3 self.shift_ratio = 0.25 self.dtype = 'float64' + self.data_format = 'NCHW' class TestTemporalShift2(TestTemporalShift): @@ -70,6 +78,7 @@ class TestTemporalShift2(TestTemporalShift): self.x_shape = (4, 9, 7, 7) self.seg_num = 2 self.shift_ratio = 0.2 + self.data_format = 'NCHW' class TestTemporalShift3(TestTemporalShift): @@ -77,6 +86,15 @@ class TestTemporalShift3(TestTemporalShift): self.x_shape = (3, 10, 5, 5) self.seg_num = 1 self.shift_ratio = 0.3 + self.data_format = 'NCHW' + + +class TestTemporalShift4(TestTemporalShift): + def initTestCase(self): + self.x_shape = (6, 5, 5, 4) + self.seg_num = 3 + self.shift_ratio = 0.25 + self.data_format = 'NHWC' @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -87,6 +105,7 @@ class TestTemporalShiftFP16(TestTemporalShift): self.seg_num = 1 self.shift_ratio = 0.3 self.dtype = 'float16' + self.data_format = 'NCHW' def test_check_output(self): place = core.CUDAPlace(0) @@ -114,6 +133,14 @@ class TestTemporalShiftAPI(unittest.TestCase): out = paddle.nn.functional.temporal_shift( x=input, seg_num=2, shift_ratio=0.2) + def test_error(self): + def 
+        def attr_data_format():
+            input = paddle.randn([6, 4, 2, 2])
+            out = paddle.nn.functional.temporal_shift(
+                x=input, seg_num=2, shift_ratio=0.2, data_format="HWC")
+
+        self.assertRaises(ValueError, attr_data_format)
+

 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
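As a sanity check on the layout handling, the reference temporal_shift function from the test file above can be used to confirm that shifting in NHWC agrees with shifting the NCHW transpose and transposing back:

    import numpy as np

    # Uses the reference temporal_shift() defined in test_temporal_shift_op.py.
    x_nhwc = np.random.random((6, 5, 5, 4)).astype("float64")
    x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))

    out_nhwc = temporal_shift(x_nhwc, seg_num=3, shift_ratio=0.25,
                              data_format="NHWC")
    out_nchw = temporal_shift(x_nchw, seg_num=3, shift_ratio=0.25,
                              data_format="NCHW")
    assert np.allclose(out_nhwc, np.transpose(out_nchw, (0, 2, 3, 1)))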