提交 82d4f903 编写于 作者: D dengkaipeng

fix format. test=develop

上级 28949f8e
......@@ -17,7 +17,7 @@ namespace operators {
using framework::Tensor;
class TemporalShiftOp: public framework::OperatorWithKernel {
class TemporalShiftOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -29,23 +29,23 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
"Output(Out) of TemporalShiftOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dim_x.size(), 4,
"Input(X) rank should be 4 in shape of [N*T, C, H, W].");
PADDLE_ENFORCE_EQ(dim_x.size(), 4,
"Input(X) rank should be 4 in shape of [N*T, C, H, W].");
int seg_num = ctx->Attrs().Get<int>("seg_num");
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
PADDLE_ENFORCE_GT(seg_num, 0,
"Attr(seg_num) should be greater than 0.");
PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5,
"Attr(shift_ratio) should be greater than 0 and less "
"than 0.5.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0,
"Input(X) dims[0] should be divided exactly by Attr(seg_num).");
PADDLE_ENFORCE_EQ(
dim_x[0] % seg_num, 0,
"Input(X) dims[0] should be divided exactly by Attr(seg_num).");
}
ctx->SetOutputDim("Out", dim_x);
ctx->SetOutputDim("Out", dim_x);
ctx->ShareLoD("X", "Out");
}
......@@ -70,14 +70,15 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
"The output tensor of temporal shift operator. "
"This is a 4-D tensor in the same shape with Input(X).");
AddAttr<int>("seg_num",
"The temporal segment number, this should be a positive "
"interger.");
AddAttr<float>("shift_ratio",
"The shift ratio of the channels, the first shift ratio part "
"of channels will be shifted by -1 along the temporal dimension, "
"and the second shift ratio part of channels will be shifted by "
"1 along the temporal dimension. Default 0.25.")
AddAttr<int>("seg_num",
"The temporal segment number, this should be a positive "
"interger.");
AddAttr<float>(
"shift_ratio",
"The shift ratio of the channels, the first shift ratio part "
"of channels will be shifted by -1 along the temporal dimension, "
"and the second shift ratio part of channels will be shifted by "
"1 along the temporal dimension. Default 0.25.")
.SetDefault(0.25);
AddComment(R"DOC(
......@@ -118,7 +119,7 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class TemporalShiftOpGrad: public framework::OperatorWithKernel {
class TemporalShiftOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -144,7 +145,8 @@ class TemporalShiftOpGrad: public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp, ops::TemporalShiftOpMaker,
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
ops::TemporalShiftOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
......
......@@ -17,70 +17,72 @@ namespace operators {
using framework::Tensor;
template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
const int tchw, const int chw, const int hw, const int w, const int t, const int c,
const float shift_ratio) {
const int tchw, const int chw, const int hw,
const int w, const int t, const int c,
const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < ntchw; tid += stride) {
int in = tid / tchw;
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
int ih = (tid % hw) / w;
int iw = tid % w;
const int c1 = static_cast<T>(c * shift_ratio);
const int c2 = static_cast<T>(c * 2 * shift_ratio);
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[tid] = 0;
} else {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
output[tid] = input[src_idx];
}
int in = tid / tchw;
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
int ih = (tid % hw) / w;
int iw = tid % w;
const int c1 = static_cast<T>(c * shift_ratio);
const int c2 = static_cast<T>(c * 2 * shift_ratio);
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[tid] = 0;
} else {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
output[tid] = input[src_idx];
}
}
}
template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int ntchw,
const int tchw, const int chw, const int hw, const int w, const int t, const int c,
const float shift_ratio) {
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
const int ntchw, const int tchw,
const int chw, const int hw, const int w,
const int t, const int c,
const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < ntchw; tid += stride) {
int in = tid / tchw;
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
int ih = (tid % hw) / w;
int iw = tid % w;
const int c1 = static_cast<T>(c * shift_ratio);
const int c2 = static_cast<T>(c * 2 * shift_ratio);
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
input_grad[src_idx] = output_grad[tid];
}
int in = tid / tchw;
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
int ih = (tid % hw) / w;
int iw = tid % w;
const int c1 = static_cast<T>(c * shift_ratio);
const int c2 = static_cast<T>(c * 2 * shift_ratio);
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
input_grad[src_idx] = output_grad[tid];
}
}
}
......@@ -113,8 +115,8 @@ class TemporalShiftOpCUDAKernel : public framework::OpKernel<T> {
grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeTemporalShiftFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, output_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
}
};
......@@ -138,7 +140,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
const int ntchw = nt * chw;
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T>()(
ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
static_cast<T>(0));
......@@ -148,8 +151,9 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
grid_dim = grid_dim > 8 ? 8 : grid_dim;
KeTemporalShiftBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
shift_ratio);
}
};
......
......@@ -18,13 +18,15 @@ namespace operators {
using Tensor = framework::Tensor;
static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih, int iw,
const int tchw, const int chw, const int hw, const int w) {
static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
int iw, const int tchw,
const int chw, const int hw,
const int w) {
return in * tchw + it * chw + ic * hw + ih * w + iw;
}
template <typename T>
class TemporalShiftKernel: public framework::OpKernel<T> {
class TemporalShiftKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
......@@ -62,7 +64,7 @@ class TemporalShiftKernel: public framework::OpKernel<T> {
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output_data[i] = 0;
} else {
......@@ -95,7 +97,8 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
const int tchw = t * chw;
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
int src_it = 0;
......@@ -113,7 +116,7 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
int src_idx = GetEntryIndex(in, src_it, ic, ih, iw, tchw, chw, hw, w);
input_grad_data[src_idx] = output_grad_data[i];
......
......@@ -10301,10 +10301,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
type="temporal_shift",
inputs={"X": x},
outputs={"Out": out},
attrs={
"seg_num": seg_num,
"shift_ratio": shift_ratio
})
attrs={"seg_num": seg_num,
"shift_ratio": shift_ratio})
return out
......
......@@ -24,15 +24,17 @@ from paddle.fluid import core
def temporal_shift(x, seg_num, shift_ratio):
shape = x.shape
reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), 'constant')
pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
'constant')
c1 = int(shape[1] * shift_ratio)
c2 = int(shape[1] * 2 * shift_ratio)
slice1 = pad_x[:, :seg_num, :c1, :, :]
slice2 = pad_x[:, 2:seg_num+2, c1:c2, :, :]
slice3 = pad_x[:, 1:seg_num+1, c2:, :, :]
slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :]
slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :]
concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
return concat_x.reshape(shape)
class TestTemporalShift(OpTest):
def setUp(self):
self.initTestCase()
......@@ -44,9 +46,7 @@ class TestTemporalShift(OpTest):
"shift_ratio": self.shift_ratio,
}
self.inputs = {
"X": x,
}
self.inputs = {"X": x, }
output = temporal_shift(x, self.seg_num, self.shift_ratio)
self.outputs = {"Out": output}
......@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
self.seg_num = 3
self.shift_ratio = 0.25
class TestTemporalShift2(TestTemporalShift):
def initTestCase(self):
self.x_shape = (4, 9, 7, 7)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册