Unverified commit 7f50bb7e authored by Zhang Ting, committed by GitHub

support NHWC for temporal_shift op (#31642)

Parent 402288ad
@@ -80,7 +80,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
  void Make() override {
    AddInput("X",
             "The input tensor of temporal shift operator. "
             "This is a 4-D tensor with shape of [N*T, C, H, W] "
             "or [N*T, H, W, C]. "
             "While N is the batch size, T is the temporal segment "
             "number, C is the channel number, H is the height of "
             "features and W is the width of features. "
@@ -100,15 +101,23 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
        "by 1 along the temporal dimension. :attr:`shift_ratio` should be in "
        "range [0, 0.5]. Default 0.25.")
        .SetDefault(0.25);
    AddAttr<std::string>(
        "data_format",
        "(string, default NCHW) An optional string from: \"NHWC\", \"NCHW\". "
        "Specify whether the data format of the input and output data is "
        "channel_first (NCHW) or channel_last (NHWC).")
        .SetDefault("NCHW");

    AddComment(R"DOC(
          This operator calculates the temporal shifting features for Input(X).

          Input(X) should be in shape of [N*T, C, H, W] or [N*T, H, W, C], while
          N is the batch size, T is the temporal segment number specified by
          :attr:`seg_num`, C is the channel number, and H and W are the height
          and width of features.

          Temporal Shifting is calculated as follows when data format is NCHW:

          Step 1: Reshape Input(X) to [N, T, C, H, W].
......
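The three steps in the DOC boil down to: split the channels at c1 = C * shift_ratio and c2 = 2 * C * shift_ratio, shift the first block one frame back in time, the second block one frame forward, and leave the rest alone, zero-filling at the temporal boundaries. A minimal NumPy sketch of that computation for the NCHW layout (the helper name is mine, not part of the patch):

    import numpy as np

    def temporal_shift_nchw(x, seg_num, shift_ratio=0.25):
        """Reference temporal shift for an [N*T, C, H, W] array."""
        nt, c, h, w = x.shape
        x = x.reshape(-1, seg_num, c, h, w)    # [N, T, C, H, W]
        c1 = int(c * shift_ratio)
        c2 = int(c * 2 * shift_ratio)
        out = np.zeros_like(x)
        out[:, 1:, :c1] = x[:, :-1, :c1]       # src_it = it - 1, zero at t = 0
        out[:, :-1, c1:c2] = x[:, 1:, c1:c2]   # src_it = it + 1, zero at t = T-1
        out[:, :, c2:] = x[:, :, c2:]          # unshifted remainder
        return out.reshape(nt, c, h, w)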
@@ -19,22 +19,46 @@ namespace operators {

using framework::Tensor;

template <typename T>
__global__ void KeTemporalShiftFwNCHW(const T* input, T* output,
                                      const int ntchw, const int tchw,
                                      const int chw, const int hw, const int t,
                                      const int c1, const int c2) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;

  for (; tid < ntchw; tid += stride) {
    int it = (tid % tchw) / chw;
    int ic = (tid % chw) / hw;

    if (ic < c1) {
      src_it = it - 1;
    } else if (ic < c2) {
      src_it = it + 1;
    } else {
      src_it = it;
    }

    if (src_it < 0 || src_it >= t) {
      output[tid] = 0;
    } else {
      output[tid] = input[tid + (src_it - it) * chw];
    }
  }
}

template <typename T>
__global__ void KeTemporalShiftFwNHWC(const T* input, T* output,
                                      const int nthwc, const int thwc,
                                      const int hwc, const int t, const int c,
                                      const int c1, const int c2) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;

  for (; tid < nthwc; tid += stride) {
    int it = (tid % thwc) / hwc;
    int ic = tid % c;

    if (ic < c1) {
      src_it = it - 1;
@@ -47,42 +71,65 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
    if (src_it < 0 || src_it >= t) {
      output[tid] = 0;
    } else {
      output[tid] = input[tid + (src_it - it) * hwc];
    }
  }
}

template <typename T>
__global__ void KeTemporalShiftBwNCHW(const T* output_grad, T* input_grad,
                                      const int ntchw, const int tchw,
                                      const int chw, const int hw, const int t,
                                      const int c1, const int c2) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;

  for (; tid < ntchw; tid += stride) {
    int it = (tid % tchw) / chw;
    int ic = (tid % chw) / hw;

    if (ic < c1) {
      src_it = it + 1;
    } else if (ic < c2) {
      src_it = it - 1;
    } else {
      src_it = it;
    }

    if (src_it >= 0 && src_it < t) {
      input_grad[tid] = output_grad[tid + (src_it - it) * chw];
    } else {
      input_grad[tid] = 0;
    }
  }
}

template <typename T>
__global__ void KeTemporalShiftBwNHWC(const T* output_grad, T* input_grad,
                                      const int nthwc, const int thwc,
                                      const int hwc, const int t, const int c,
                                      const int c1, const int c2) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  int src_it = 0;

  for (; tid < nthwc; tid += stride) {
    int it = (tid % thwc) / hwc;
    int ic = tid % c;

    if (ic < c1) {
      src_it = it + 1;
    } else if (ic < c2) {
      src_it = it - 1;
    } else {
      src_it = it;
    }

    if (src_it >= 0 && src_it < t) {
      input_grad[tid] = output_grad[tid + (src_it - it) * hwc];
    } else {
      input_grad[tid] = 0;
    }
  }
}
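A detail worth noting in all four kernels: once `it` and `ic` are decoded from the flat index, the source element differs from the destination only in its frame coordinate, so the read is a constant-stride offset from `tid` (`chw` per frame for NCHW, `hwc` per frame for NHWC) instead of the old full `GetEntryIndex` recomputation. A quick NumPy check of that identity (shapes and names here are illustrative):

    import numpy as np

    n, t, c, h, w = 2, 3, 4, 5, 5
    x = np.arange(n * t * c * h * w).reshape(n, t, c, h, w)
    flat = x.reshape(-1)
    chw = c * h * w

    # Flat index of element (n0=1, it=1, ic=2, ih=3, iw=4) in NCHW order.
    i = np.ravel_multi_index((1, 1, 2, 3, 4), (n, t, c, h, w))
    it, src_it = 1, 2
    # The same (n0, ic, ih, iw) in frame src_it is one chw-stride away.
    assert flat[i + (src_it - it) * chw] == x[1, src_it, 2, 3, 4]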
@@ -98,27 +145,48 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
    auto* output = ctx.Output<Tensor>("Out");

    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");
    const std::string data_format_str = ctx.Attr<std::string>("data_format");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_format_str);

    const int nt = input->dims()[0];
    const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1]
                                                    : input->dims()[3]);
    const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2]
                                                    : input->dims()[1]);
    const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3]
                                                    : input->dims()[2]);

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    framework::DDim out_dims = (data_layout == DataLayout::kNCHW
                                    ? framework::make_ddim({nt, c, h, w})
                                    : framework::make_ddim({nt, h, w, c}));
    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(out_dims, ctx.GetPlace());

    int pixelNum = nt * chw;
    int threads = 1024;
    int grid = (pixelNum + threads - 1) / threads;
    const auto& dev_ctx = ctx.cuda_device_context();
    int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads;
    grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);

    if (data_layout == DataLayout::kNCHW) {
      KeTemporalShiftFwNCHW<
          T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
          input_data, output_data, ntchw, tchw, chw, hw, t, c1, c2);
    } else {
      KeTemporalShiftFwNHWC<
          T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
          input_data, output_data, ntchw, tchw, chw, t, c, c1, c2);
    }
  }
};
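The launch configuration also changes: instead of `platform::GetGpuLaunchConfig1D`, the kernel now uses 1024-thread blocks and caps the grid at roughly the number of blocks the device can keep resident at once. Because the kernels iterate with a grid-stride loop, the capped grid still covers every element. The same arithmetic in a few lines of Python (the device numbers here are made up):

    # Hypothetical device: 80 SMs, 2048 resident threads per SM.
    sm_count, max_threads_per_sm, threads = 80, 2048, 1024

    pixel_num = 6 * 4 * 32 * 32                    # nt * chw for a [6, 4, 32, 32] input
    grid = (pixel_num + threads - 1) // threads    # blocks to cover all elements
    blocks_per_sm = max_threads_per_sm // threads  # blocks one SM can host
    grid = min(sm_count * blocks_per_sm, grid)     # cap at resident capacity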
@@ -130,32 +198,49 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));

    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");
    const std::string data_format_str = ctx.Attr<std::string>("data_format");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_format_str);

    const int nt = output_grad->dims()[0];
    const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
                                                    : output_grad->dims()[3]);
    const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
                                                    : output_grad->dims()[1]);
    const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
                                                    : output_grad->dims()[2]);

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    framework::DDim in_grad_dims = (data_layout == DataLayout::kNCHW
                                        ? framework::make_ddim({nt, c, h, w})
                                        : framework::make_ddim({nt, h, w, c}));
    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>(in_grad_dims, ctx.GetPlace());

    int pixelNum = nt * chw;
    int threads = 1024;
    int grid = (pixelNum + threads - 1) / threads;
    const auto& dev_ctx = ctx.cuda_device_context();
    int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads;
    grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);

    if (data_layout == DataLayout::kNCHW) {
      KeTemporalShiftBwNCHW<
          T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
          output_grad_data, input_grad_data, ntchw, tchw, chw, hw, t, c1, c2);
    } else {
      KeTemporalShiftBwNHWC<
          T><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
          output_grad_data, input_grad_data, ntchw, tchw, chw, t, c, c1, c2);
    }
  }
};
......
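Two things distinguish the backward kernels from the forward ones. First, the shift directions are inverted: channels that read from frame it - 1 in the forward take their gradient from frame it + 1, and vice versa, because the backward of a gather is its transpose. Second, the backward is itself written as a gather (every thread writes only input_grad[tid]), which removes the old scatter into input_grad[src_idx] and with it the need to pre-zero the buffer via math::SetConstant. A NumPy sanity check of the inversion (toy shapes, helper name is mine):

    import numpy as np

    def shift(x, c1, c2, d1, d2):
        """Shift channels [0,c1) by d1 frames, [c1,c2) by d2, zero-padded."""
        out = np.zeros_like(x)
        t = x.shape[1]
        for it in range(t):
            for ic in range(x.shape[2]):
                src = it + (d1 if ic < c1 else d2 if ic < c2 else 0)
                if 0 <= src < t:
                    out[:, it, ic] = x[:, src, ic]
        return out

    x = np.random.rand(2, 4, 6, 3, 3)   # [N, T, C, H, W]
    gy = np.random.rand(2, 4, 6, 3, 3)
    y = shift(x, 2, 4, -1, +1)          # forward: src_it = it - 1 / it + 1
    gx = shift(gy, 2, 4, +1, -1)        # backward: directions inverted
    # For a linear gather S, <S x, gy> == <x, S^T gy>.
    assert np.isclose((y * gy).sum(), (x * gx).sum())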
@@ -17,12 +17,106 @@ namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;

template <typename T>
void TemporalShiftFwNCHW(const T* input, T* output, const int ntchw,
                         const int tchw, const int chw, const int hw,
                         const int t, const int c1, const int c2) {
  int src_it = 0;
  for (int i = 0; i < ntchw; i++) {
    int it = (i % tchw) / chw;
    int ic = (i % chw) / hw;

    if (ic < c1) {
      src_it = it - 1;
    } else if (ic < c2) {
      src_it = it + 1;
    } else {
      src_it = it;
    }

    if (src_it < 0 || src_it >= t) {
      output[i] = 0;
    } else {
      output[i] = input[i + (src_it - it) * chw];
    }
  }
}

template <typename T>
void TemporalShiftFwNHWC(const T* input, T* output, const int nthwc,
                         const int thwc, const int hwc, const int t,
                         const int c, const int c1, const int c2) {
  int src_it = 0;
  for (int i = 0; i < nthwc; i++) {
    int it = (i % thwc) / hwc;
    int ic = i % c;

    if (ic < c1) {
      src_it = it - 1;
    } else if (ic < c2) {
      src_it = it + 1;
    } else {
      src_it = it;
    }

    if (src_it < 0 || src_it >= t) {
      output[i] = 0;
    } else {
      output[i] = input[i + (src_it - it) * hwc];
    }
  }
}

template <typename T>
void TemporalShiftBwNCHW(const T* output_grad, T* input_grad, const int ntchw,
                         const int tchw, const int chw, const int hw,
                         const int t, const int c1, const int c2) {
  int src_it = 0;
  for (int i = 0; i < ntchw; i++) {
    int it = (i % tchw) / chw;
    int ic = (i % chw) / hw;

    if (ic < c1) {
      src_it = it + 1;
    } else if (ic < c2) {
      src_it = it - 1;
    } else {
      src_it = it;
    }

    if (src_it >= 0 && src_it < t) {
      input_grad[i] = output_grad[i + (src_it - it) * chw];
    } else {
      input_grad[i] = 0;
    }
  }
}

template <typename T>
void TemporalShiftBwNHWC(const T* output_grad, T* input_grad, const int nthwc,
                         const int thwc, const int hwc, const int t,
                         const int c, const int c1, const int c2) {
  int src_it = 0;
  for (int i = 0; i < nthwc; i++) {
    int it = (i % thwc) / hwc;
    int ic = i % c;

    if (ic < c1) {
      src_it = it + 1;
    } else if (ic < c2) {
      src_it = it - 1;
    } else {
      src_it = it;
    }

    if (src_it >= 0 && src_it < t) {
      input_grad[i] = output_grad[i + (src_it - it) * hwc];
    } else {
      input_grad[i] = 0;
    }
  }
}

template <typename T>
@@ -33,44 +127,38 @@ class TemporalShiftKernel : public framework::OpKernel<T> {
    auto* output = ctx.Output<Tensor>("Out");

    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");
    const std::string data_format_str = ctx.Attr<std::string>("data_format");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_format_str);

    const int nt = input->dims()[0];
    const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1]
                                                    : input->dims()[3]);
    const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2]
                                                    : input->dims()[1]);
    const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3]
                                                    : input->dims()[2]);

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    framework::DDim out_dims = (data_layout == DataLayout::kNCHW
                                    ? framework::make_ddim({nt, c, h, w})
                                    : framework::make_ddim({nt, h, w, c}));
    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(out_dims, ctx.GetPlace());

    if (data_layout == DataLayout::kNCHW) {
      TemporalShiftFwNCHW<T>(input_data, output_data, ntchw, tchw, chw, hw, t,
                             c1, c2);
    } else {
      TemporalShiftFwNHWC<T>(input_data, output_data, ntchw, tchw, chw, t, c,
                             c1, c2);
    }
  }
};
@@ -83,44 +171,39 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));

    int t = ctx.Attr<int>("seg_num");
    float shift_ratio = ctx.Attr<float>("shift_ratio");
    const std::string data_format_str = ctx.Attr<std::string>("data_format");
    const DataLayout data_layout =
        framework::StringToDataLayout(data_format_str);

    const int nt = output_grad->dims()[0];
    const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
                                                    : output_grad->dims()[3]);
    const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
                                                    : output_grad->dims()[1]);
    const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
                                                    : output_grad->dims()[2]);

    const int hw = h * w;
    const int chw = c * hw;
    const int tchw = t * chw;
    const int ntchw = nt * chw;

    const int c1 = static_cast<int>(c * shift_ratio);
    const int c2 = static_cast<int>(c * 2 * shift_ratio);

    framework::DDim in_grad_dims = (data_layout == DataLayout::kNCHW
                                        ? framework::make_ddim({nt, c, h, w})
                                        : framework::make_ddim({nt, h, w, c}));
    const T* output_grad_data = output_grad->data<T>();
    T* input_grad_data =
        input_grad->mutable_data<T>(in_grad_dims, ctx.GetPlace());

    if (data_layout == DataLayout::kNCHW) {
      TemporalShiftBwNCHW<T>(output_grad_data, input_grad_data, ntchw, tchw,
                             chw, hw, t, c1, c2);
    } else {
      TemporalShiftBwNHWC<T>(output_grad_data, input_grad_data, ntchw, tchw,
                             chw, t, c, c1, c2);
    }
  }
};
......
@@ -13334,7 +13334,7 @@ def shuffle_channel(x, group, name=None):

@templatedoc()
def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
    """
    **Temporal Shift Operator**

@@ -13348,6 +13348,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
        name(str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name does not need to be set and
            is None by default.
        data_format(str, optional): Data format that specifies the layout of input.
            It can be "NCHW" or "NHWC". Default: "NCHW".

    Returns:
        out(Tensor): The temporal shifting result is a tensor with the
@@ -13365,6 +13367,13 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
            input = paddle.randn([6, 4, 2, 2])
            out = F.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
    """
    if data_format not in ["NCHW", "NHWC"]:
        raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'. "
                         "Received Attr(data_format): {}.".format(data_format))
    if in_dygraph_mode():
        return core.ops.temporal_shift(x, 'seg_num', seg_num, 'shift_ratio',
                                       shift_ratio, 'data_format', data_format)

    helper = LayerHelper("temporal_shift", **locals())
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'temporal_shift')
    check_type(seg_num, 'seg_num', int, 'temporal_shift')
@@ -13375,16 +13384,15 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
    if not isinstance(seg_num, int):
        raise TypeError("seg_num must be int type.")

    helper.append_op(
        type="temporal_shift",
        inputs={"X": x},
        outputs={"Out": out},
        attrs={
            "seg_num": seg_num,
            "shift_ratio": shift_ratio,
            "data_format": data_format
        })
    return out
......
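With the new attribute, a channel-last tensor can be shifted without first transposing it to NCHW. A minimal usage sketch of the extended API (mirroring the docstring example; the output shape equals the input shape):

    import paddle
    import paddle.nn.functional as F

    # [N*T, H, W, C] input: batch 3, seg_num 2, 4x4 features, 8 channels
    input = paddle.randn([6, 4, 4, 8])
    out = F.temporal_shift(x=input, seg_num=2, shift_ratio=0.25,
                           data_format="NHWC")
    print(out.shape)  # [6, 4, 4, 8]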
@@ -22,7 +22,9 @@ import paddle
from paddle.fluid import core


def temporal_shift(x, seg_num, shift_ratio, data_format):
    if data_format == "NHWC":
        x = np.transpose(x, (0, 3, 1, 2))
    shape = x.shape
    reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
    pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
@@ -33,7 +35,10 @@ def temporal_shift(x, seg_num, shift_ratio):
    slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :]
    slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :]
    concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
    out = concat_x.reshape(shape)
    if data_format == "NHWC":
        out = np.transpose(out, (0, 2, 3, 1))
    return out


class TestTemporalShift(OpTest):
@@ -45,11 +50,13 @@ class TestTemporalShift(OpTest):
        self.attrs = {
            "seg_num": self.seg_num,
            "shift_ratio": self.shift_ratio,
            "data_format": self.data_format
        }

        self.inputs = {"X": x, }

        output = temporal_shift(x, self.seg_num, self.shift_ratio,
                                self.data_format)
        self.outputs = {"Out": output}

    def test_check_output(self):
@@ -63,6 +70,7 @@ class TestTemporalShift(OpTest):
        self.seg_num = 3
        self.shift_ratio = 0.25
        self.dtype = 'float64'
        self.data_format = 'NCHW'


class TestTemporalShift2(TestTemporalShift):
@@ -70,6 +78,7 @@ class TestTemporalShift2(TestTemporalShift):
        self.x_shape = (4, 9, 7, 7)
        self.seg_num = 2
        self.shift_ratio = 0.2
        self.data_format = 'NCHW'


class TestTemporalShift3(TestTemporalShift):
@@ -77,6 +86,15 @@ class TestTemporalShift3(TestTemporalShift):
        self.x_shape = (3, 10, 5, 5)
        self.seg_num = 1
        self.shift_ratio = 0.3
        self.data_format = 'NCHW'


class TestTemporalShift4(TestTemporalShift):
    def initTestCase(self):
        self.x_shape = (6, 5, 5, 4)
        self.seg_num = 3
        self.shift_ratio = 0.25
        self.data_format = 'NHWC'


@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -87,6 +105,7 @@ class TestTemporalShiftFP16(TestTemporalShift):
        self.seg_num = 1
        self.shift_ratio = 0.3
        self.dtype = 'float16'
        self.data_format = 'NCHW'

    def test_check_output(self):
        place = core.CUDAPlace(0)
@@ -114,6 +133,14 @@ class TestTemporalShiftAPI(unittest.TestCase):
        out = paddle.nn.functional.temporal_shift(
            x=input, seg_num=2, shift_ratio=0.2)

    def test_error(self):
        def attr_data_format():
            input = paddle.randn([6, 4, 2, 2])
            out = paddle.nn.functional.temporal_shift(
                x=input, seg_num=2, shift_ratio=0.2, data_format="HWC")

        self.assertRaises(ValueError, attr_data_format)


if __name__ == "__main__":
    unittest.main()