Unverified commit 7f50bb7e, authored by Zhang Ting, committed by GitHub

support NHWC for temporal_shift op (#31642)

Parent commit: 402288ad
@@ -80,7 +80,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X",
"The input tensor of temporal shift operator. "
"This is a 4-D tensor with shape of [N*T, C, H, W]. "
"This is a 4-D tensor with shape of [N*T, C, H, W] "
"or [N*T, H, W, C]. "
"While N is the batch size, T is the temporal segment "
"number, C is the channel number, H is the height of "
"features and W is the width of features. "
@@ -100,15 +101,23 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
"by 1 along the temporal dimension. :attr:`shift_ratio` should be in "
"range [0, 0.5]. Default 0.25.")
.SetDefault(0.25);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"an optional string from: \"NHWC\", \"NCHW\". "
"Specify that the data format of the input and output data is "
"channel_first or channel_last.")
.SetDefault("NCHW");
AddComment(R"DOC(
This operator calculates the temporal shifting features for Input(X).
Input(X) should be in shape of [N*T, C, H, W], while N is the batch
size, T is the temporal segment number specified by :attr:`seg_num`,
C is the channel number, H and W is the height and width of features.
Input(X) should be in shape of [N*T, C, H, W] or [N*T, H, W, C], where
N is the batch size, T is the temporal segment number specified by
:attr:`seg_num`, C is the channel number, and H and W are the height and
width of features.
Temporal Shifting is calculated as follows:
Temporal Shifting is calculated as follows when data format is NCHW:
Step 1: Reshape Input(X) to [N, T, C, H, W].
......
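For illustration, a minimal NumPy sketch of the NCHW shift described in the DOC above (the helper name and shapes are made up for this example and are not part of the patch):

import numpy as np

def temporal_shift_nchw_ref(x, seg_num, shift_ratio=0.25):
    # x: [N*T, C, H, W] with T == seg_num
    nt, c, h, w = x.shape
    n = nt // seg_num
    c1 = int(c * shift_ratio)
    c2 = int(c * 2 * shift_ratio)
    x = x.reshape(n, seg_num, c, h, w)      # Step 1: reshape to [N, T, C, H, W]
    out = np.zeros_like(x)
    out[:, 1:, :c1] = x[:, :-1, :c1]        # first c1 channels read from frame t-1 (zero at t=0)
    out[:, :-1, c1:c2] = x[:, 1:, c1:c2]    # next channels read from frame t+1 (zero at t=T-1)
    out[:, :, c2:] = x[:, :, c2:]           # remaining channels are left unchanged
    return out.reshape(nt, c, h, w)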
@@ -19,22 +19,46 @@ namespace operators {
using framework::Tensor;
template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
const int tchw, const int chw, const int hw,
const int w, const int t, const int c,
const float shift_ratio) {
__global__ void KeTemporalShiftFwNCHW(const T* input, T* output,
const int ntchw, const int tchw,
const int chw, const int hw, const int t,
const int c1, const int c2) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < ntchw; tid += stride) {
int in = tid / tchw;
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
int ih = (tid % hw) / w;
int iw = tid % w;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[tid] = 0;
} else {
output[tid] = input[tid + (src_it - it) * chw];
}
}
}
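// Note on the indexing above: for a fixed (n, c, h, w) position, moving the temporal
// index from it to src_it changes the flattened NCHW offset by (src_it - it) * chw, so
// input[tid + (src_it - it) * chw] reads the same channel/spatial location at the source
// frame; the NHWC kernels below rely on the same property with a stride of hwc.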
template <typename T>
__global__ void KeTemporalShiftFwNHWC(const T* input, T* output,
const int nthwc, const int thwc,
const int hwc, const int t, const int c,
const int c1, const int c2) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < nthwc; tid += stride) {
int it = (tid % thwc) / hwc;
int ic = tid % c;
if (ic < c1) {
src_it = it - 1;
@@ -47,42 +71,65 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
if (src_it < 0 || src_it >= t) {
output[tid] = 0;
} else {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
output[tid] = input[src_idx];
output[tid] = input[tid + (src_it - it) * hwc];
}
}
}
template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
const int ntchw, const int tchw,
const int chw, const int hw, const int w,
const int t, const int c,
const float shift_ratio) {
__global__ void KeTemporalShiftBwNCHW(const T* output_grad, T* input_grad,
const int ntchw, const int tchw,
const int chw, const int hw, const int t,
const int c1, const int c2) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < ntchw; tid += stride) {
int in = tid / tchw;
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
int ih = (tid % hw) / w;
int iw = tid % w;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
if (ic < c1) {
src_it = it - 1;
src_it = it + 1;
} else if (ic < c2) {
src_it = it - 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
input_grad[tid] = output_grad[tid + (src_it - it) * chw];
} else {
input_grad[tid] = 0;
}
}
}
template <typename T>
__global__ void KeTemporalShiftBwNHWC(const T* output_grad, T* input_grad,
const int nthwc, const int thwc,
const int hwc, const int t, const int c,
const int c1, const int c2) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < nthwc; tid += stride) {
int it = (tid % thwc) / hwc;
int ic = tid % c;
if (ic < c1) {
src_it = it + 1;
} else if (ic < c2) {
src_it = it - 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
input_grad[src_idx] = output_grad[tid];
input_grad[tid] = output_grad[tid + (src_it - it) * hwc];
} else {
input_grad[tid] = 0;
}
}
}
@@ -98,27 +145,48 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
auto* output = ctx.Output<Tensor>("Out");
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const std::string data_format_str = ctx.Attr<std::string>("data_format");
const DataLayout data_layout =
framework::StringToDataLayout(data_format_str);
const int nt = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1]
: input->dims()[3]);
const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2]
: input->dims()[1]);
const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3]
: input->dims()[2]);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
framework::DDim out_dims = (data_layout == DataLayout::kNCHW
? framework::make_ddim({nt, c, h, w})
: framework::make_ddim({nt, h, w, c}));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
T* output_data = output->mutable_data<T>(out_dims, ctx.GetPlace());
int pixelNum = nt * chw;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
int threads = 1024;
int grid = (pixelNum + threads - 1) / threads;
const auto& dev_ctx = ctx.cuda_device_context();
int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads;
grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);
KeTemporalShiftFw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
if (data_layout == DataLayout::kNCHW) {
KeTemporalShiftFwNCHW<
T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
input_data, output_data, ntchw, tchw, chw, hw, t, c1, c2);
} else {
KeTemporalShiftFwNHWC<
T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
input_data, output_data, ntchw, tchw, chw, t, c, c1, c2);
}
}
};
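// Note on the launch configuration above: threads is fixed at 1024, the initial grid is
// ceil(pixelNum / 1024), and the grid is then capped at GetSMCount() * blocks_per_sm;
// elements beyond grid * threads are still covered by the grid-stride loops in the kernels.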
@@ -130,32 +198,49 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const std::string data_format_str = ctx.Attr<std::string>("data_format");
const DataLayout data_layout =
framework::StringToDataLayout(data_format_str);
const int nt = output_grad->dims()[0];
const int c = output_grad->dims()[1];
const int h = output_grad->dims()[2];
const int w = output_grad->dims()[3];
const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
: output_grad->dims()[3]);
const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
: output_grad->dims()[1]);
const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
: output_grad->dims()[2]);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
framework::DDim in_grad_dims = (data_layout == DataLayout::kNCHW
? framework::make_ddim({nt, c, h, w})
: framework::make_ddim({nt, h, w, c}));
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T>()(
ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
static_cast<T>(0));
input_grad->mutable_data<T>(in_grad_dims, ctx.GetPlace());
int pixelNum = nt * chw;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
int threads = 1024;
int grid = (pixelNum + threads - 1) / threads;
const auto& dev_ctx = ctx.cuda_device_context();
int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads;
grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);
KeTemporalShiftBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
shift_ratio);
if (data_layout == DataLayout::kNCHW) {
KeTemporalShiftBwNCHW<
T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, t, c1, c2);
} else {
KeTemporalShiftBwNHWC<
T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, t, c, c1, c2);
}
}
};
......
@@ -17,12 +17,106 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
int iw, const int tchw,
const int chw, const int hw,
const int w) {
return in * tchw + it * chw + ic * hw + ih * w + iw;
template <typename T>
void TemporalShiftFwNCHW(const T* input, T* output, const int ntchw,
const int tchw, const int chw, const int hw,
const int t, const int c1, const int c2) {
int src_it = 0;
for (int i = 0; i < ntchw; i++) {
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[i] = 0;
} else {
output[i] = input[i + (src_it - it) * chw];
}
}
}
template <typename T>
void TemporalShiftFwNHWC(const T* input, T* output, const int nthwc,
const int thwc, const int hwc, const int t,
const int c, const int c1, const int c2) {
int src_it = 0;
for (int i = 0; i < nthwc; i++) {
int it = (i % thwc) / hwc;
int ic = i % c;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[i] = 0;
} else {
output[i] = input[i + (src_it - it) * hwc];
}
}
}
template <typename T>
void TemporalShiftBwNCHW(const T* output_grad, T* input_grad, const int ntchw,
const int tchw, const int chw, const int hw,
const int t, const int c1, const int c2) {
int src_it = 0;
for (int i = 0; i < ntchw; i++) {
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
if (ic < c1) {
src_it = it + 1;
} else if (ic < c2) {
src_it = it - 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
input_grad[i] = output_grad[i + (src_it - it) * chw];
} else {
input_grad[i] = 0;
}
}
}
template <typename T>
void TemporalShiftBwNHWC(const T* output_grad, T* input_grad, const int nthwc,
const int thwc, const int hwc, const int t,
const int c, const int c1, const int c2) {
int src_it = 0;
for (int i = 0; i < nthwc; i++) {
int it = (i % thwc) / hwc;
int ic = i % c;
if (ic < c1) {
src_it = it + 1;
} else if (ic < c2) {
src_it = it - 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
input_grad[i] = output_grad[i + (src_it - it) * hwc];
} else {
input_grad[i] = 0;
}
}
}
template <typename T>
@@ -33,44 +127,38 @@ class TemporalShiftKernel : public framework::OpKernel<T> {
auto* output = ctx.Output<Tensor>("Out");
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const std::string data_format_str = ctx.Attr<std::string>("data_format");
const DataLayout data_layout =
framework::StringToDataLayout(data_format_str);
const int nt = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1]
: input->dims()[3]);
const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2]
: input->dims()[1]);
const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3]
: input->dims()[2]);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
framework::DDim out_dims = (data_layout == DataLayout::kNCHW
? framework::make_ddim({nt, c, h, w})
: framework::make_ddim({nt, h, w, c}));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
int src_it = 0;
for (int i = 0; i < output->numel(); i++) {
int in = i / tchw;
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
int ih = (i % hw) / w;
int iw = i % w;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output_data[i] = 0;
} else {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
output_data[i] = input_data[src_idx];
}
T* output_data = output->mutable_data<T>(out_dims, ctx.GetPlace());
if (data_layout == DataLayout::kNCHW) {
TemporalShiftFwNCHW<T>(input_data, output_data, ntchw, tchw, chw, hw, t,
c1, c2);
} else {
TemporalShiftFwNHWC<T>(input_data, output_data, ntchw, tchw, chw, t, c,
c1, c2);
}
}
};
@@ -83,44 +171,39 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const std::string data_format_str = ctx.Attr<std::string>("data_format");
const DataLayout data_layout =
framework::StringToDataLayout(data_format_str);
const int nt = output_grad->dims()[0];
const int c = output_grad->dims()[1];
const int h = output_grad->dims()[2];
const int w = output_grad->dims()[3];
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
: output_grad->dims()[3]);
const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
: output_grad->dims()[1]);
const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
: output_grad->dims()[2]);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
framework::DDim in_grad_dims = (data_layout == DataLayout::kNCHW
? framework::make_ddim({nt, c, h, w})
: framework::make_ddim({nt, h, w, c}));
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
int src_it = 0;
for (int i = 0; i < output_grad->numel(); i++) {
int in = i / tchw;
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
int ih = (i % hw) / w;
int iw = i % w;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
input_grad_data[src_idx] = output_grad_data[i];
}
input_grad->mutable_data<T>(in_grad_dims, ctx.GetPlace());
if (data_layout == DataLayout::kNCHW) {
TemporalShiftBwNCHW<T>(output_grad_data, input_grad_data, ntchw, tchw,
chw, hw, t, c1, c2);
} else {
TemporalShiftBwNHWC<T>(output_grad_data, input_grad_data, ntchw, tchw,
chw, t, c, c1, c2);
}
}
};
......
@@ -13334,7 +13334,7 @@ def shuffle_channel(x, group, name=None):
@templatedoc()
def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
"""
**Temporal Shift Operator**
@@ -13348,6 +13348,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually there is no need to set this
name; it is None by default.
data_format(str, optional): Data format that specifies the layout of input.
It can be "NCHW" or "NHWC". Default: "NCHW".
Returns:
out(Tensor): The temporal shifting result is a tensor with the
@@ -13365,6 +13367,13 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
input = paddle.randn([6, 4, 2, 2])
out = F.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
"""
if data_format not in ["NCHW", "NHWC"]:
raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
"Received Attr(data_format): {}.".format(data_format))
if in_dygraph_mode():
return core.ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
shift_ratio, 'data_format', data_format)
helper = LayerHelper("temporal_shift", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'temporal_shift')
check_type(seg_num, 'seg_num', int, 'temporal_shift')
@@ -13375,16 +13384,15 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
if not isinstance(seg_num, int):
raise TypeError("seg_num must be int type.")
if in_dygraph_mode():
return core.ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
shift_ratio)
helper.append_op(
type="temporal_shift",
inputs={"X": x},
outputs={"Out": out},
attrs={"seg_num": seg_num,
"shift_ratio": shift_ratio})
attrs={
"seg_num": seg_num,
"shift_ratio": shift_ratio,
"data_format": data_format
})
return out
......
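A sketch of how the new data_format argument can be used from the Python API (shapes are illustrative and mirror the NHWC test case added below):

import paddle
import paddle.nn.functional as F

x = paddle.randn([6, 5, 5, 4])   # [N*T, H, W, C] with N=2, T=3
out = F.temporal_shift(x, seg_num=3, shift_ratio=0.25, data_format="NHWC")
# out keeps the channel-last layout and shape: [6, 5, 5, 4]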
@@ -22,7 +22,9 @@ import paddle
from paddle.fluid import core
def temporal_shift(x, seg_num, shift_ratio):
def temporal_shift(x, seg_num, shift_ratio, data_format):
if data_format == "NHWC":
x = np.transpose(x, (0, 3, 1, 2))
shape = x.shape
reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
@@ -33,7 +35,10 @@ def temporal_shift(x, seg_num, shift_ratio):
slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :]
slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :]
concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
return concat_x.reshape(shape)
out = concat_x.reshape(shape)
if data_format == "NHWC":
out = np.transpose(out, (0, 2, 3, 1))
return out
class TestTemporalShift(OpTest):
@@ -45,11 +50,13 @@ class TestTemporalShift(OpTest):
self.attrs = {
"seg_num": self.seg_num,
"shift_ratio": self.shift_ratio,
"data_format": self.data_format
}
self.inputs = {"X": x, }
output = temporal_shift(x, self.seg_num, self.shift_ratio)
output = temporal_shift(x, self.seg_num, self.shift_ratio,
self.data_format)
self.outputs = {"Out": output}
def test_check_output(self):
@@ -63,6 +70,7 @@ class TestTemporalShift(OpTest):
self.seg_num = 3
self.shift_ratio = 0.25
self.dtype = 'float64'
self.data_format = 'NCHW'
class TestTemporalShift2(TestTemporalShift):
@@ -70,6 +78,7 @@ class TestTemporalShift2(TestTemporalShift):
self.x_shape = (4, 9, 7, 7)
self.seg_num = 2
self.shift_ratio = 0.2
self.data_format = 'NCHW'
class TestTemporalShift3(TestTemporalShift):
@@ -77,6 +86,15 @@ class TestTemporalShift3(TestTemporalShift):
self.x_shape = (3, 10, 5, 5)
self.seg_num = 1
self.shift_ratio = 0.3
self.data_format = 'NCHW'
class TestTemporalShift4(TestTemporalShift):
def initTestCase(self):
self.x_shape = (6, 5, 5, 4)
self.seg_num = 3
self.shift_ratio = 0.25
self.data_format = 'NHWC'
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -87,6 +105,7 @@ class TestTemporalShiftFP16(TestTemporalShift):
self.seg_num = 1
self.shift_ratio = 0.3
self.dtype = 'float16'
self.data_format = 'NCHW'
def test_check_output(self):
place = core.CUDAPlace(0)
@@ -114,6 +133,14 @@ class TestTemporalShiftAPI(unittest.TestCase):
out = paddle.nn.functional.temporal_shift(
x=input, seg_num=2, shift_ratio=0.2)
def test_error(self):
def attr_data_format():
input = paddle.randn([6, 4, 2, 2])
out = paddle.nn.functional.temporal_shift(
x=input, seg_num=2, shift_ratio=0.2, data_format="HWC")
self.assertRaises(ValueError, attr_data_format)
if __name__ == "__main__":
unittest.main()