Commit 82d4f903 authored by D dengkaipeng

fix format. test=develop

Parent 28949f8e
@@ -17,7 +17,7 @@ namespace operators {
 
 using framework::Tensor;
 
-class TemporalShiftOp: public framework::OperatorWithKernel {
+class TemporalShiftOp : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -34,14 +34,14 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
     int seg_num = ctx->Attrs().Get<int>("seg_num");
     float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
-    PADDLE_ENFORCE_GT(seg_num, 0,
-                      "Attr(seg_num) should be greater than 0.");
+    PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
     PADDLE_ENFORCE(shift_ratio > 0 && shift_ratio < .5,
                    "Attr(shift_ratio) should be greater than 0 and less "
                    "than 0.5.");
 
     if (ctx->IsRuntime()) {
-      PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0,
+      PADDLE_ENFORCE_EQ(
+          dim_x[0] % seg_num, 0,
           "Input(X) dims[0] should be divided exactly by Attr(seg_num).");
     }
@@ -73,7 +73,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("seg_num",
                  "The temporal segment number, this should be a positive "
                  "integer.");
-    AddAttr<float>("shift_ratio",
+    AddAttr<float>(
+        "shift_ratio",
         "The shift ratio of the channels, the first shift ratio part "
         "of channels will be shifted by -1 along the temporal dimension, "
         "and the second shift ratio part of channels will be shifted by "
@@ -118,7 +119,7 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-class TemporalShiftOpGrad: public framework::OperatorWithKernel {
+class TemporalShiftOpGrad : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -144,7 +145,8 @@ class TemporalShiftOpGrad: public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp, ops::TemporalShiftOpMaker,
+REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
+                  ops::TemporalShiftOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
 REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
......
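For context, a minimal NumPy sketch of the shape contract that the InferShape checks above enforce: the batch dimension of the 4-D input (N*T, C, H, W) must fold evenly into (N, T, C, H, W), and the output keeps the input shape. The helper name and tensor sizes below are illustrative, not part of the operator.

```python
import numpy as np

def check_temporal_shift_shapes(x, seg_num, shift_ratio):
    # Illustrative mirror of TemporalShiftOp::InferShape's runtime checks.
    assert seg_num > 0, "Attr(seg_num) should be greater than 0."
    assert 0.0 < shift_ratio < 0.5, \
        "Attr(shift_ratio) should be greater than 0 and less than 0.5."
    nt, c, h, w = x.shape
    assert nt % seg_num == 0, \
        "Input(X) dims[0] should be divided exactly by Attr(seg_num)."
    # The op views (N*T, C, H, W) as (N, T, C, H, W); output shape == input shape.
    return x.reshape((-1, seg_num, c, h, w)).reshape((nt, c, h, w)).shape

x = np.zeros((8, 6, 4, 4), dtype=np.float32)   # N*T = 8 divides by seg_num = 4
print(check_temporal_shift_shapes(x, seg_num=4, shift_ratio=0.25))  # (8, 6, 4, 4)
```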
@@ -17,10 +17,10 @@ namespace operators {
 using framework::Tensor;
 
 template <typename T>
 __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
-    const int tchw, const int chw, const int hw, const int w, const int t, const int c,
+                                  const int tchw, const int chw, const int hw,
+                                  const int w, const int t, const int c,
                                   const float shift_ratio) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
@@ -53,8 +53,10 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
 }
 
 template <typename T>
-__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int ntchw,
-    const int tchw, const int chw, const int hw, const int w, const int t, const int c,
+__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
+                                  const int ntchw, const int tchw,
+                                  const int chw, const int hw, const int w,
+                                  const int t, const int c,
                                   const float shift_ratio) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
@@ -138,7 +140,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
     const int ntchw = nt * chw;
 
     const T* output_grad_data = output_grad->data<T>();
-    T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
+    T* input_grad_data =
+        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
     math::SetConstant<platform::CUDADeviceContext, T>()(
         ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
         static_cast<T>(0));
@@ -149,7 +152,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
     KeTemporalShiftBw<
         T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
-        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio);
+        output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
+        shift_ratio);
   }
 };
......
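The reformatted kernel signatures above pass precomputed row-major strides (ntchw, tchw, chw, hw, w) so each thread can decode its flat grid-stride index. The sketch below shows the indexing scheme those parameters imply; the kernel bodies are not shown in this diff, so treat it as an assumption about the scheme, not a line-for-line port.

```python
def shifted_source_offset(tid, tchw, chw, hw, t, c, shift_ratio):
    # Sketch (assumption): decode the flat index and pick the source frame.
    it = (tid % tchw) // chw        # temporal index within a segment
    ic = (tid % chw) // hw          # channel index
    c1 = int(c * shift_ratio)
    c2 = int(c * 2 * shift_ratio)
    if ic < c1:
        src_it = it - 1             # first c1 channels read the previous frame
    elif ic < c2:
        src_it = it + 1             # next chunk reads the following frame
    else:
        src_it = it                 # remaining channels are untouched
    if src_it < 0 or src_it >= t:
        return None                 # out-of-range frames act as zero padding
    return tid + (src_it - it) * chw
```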
@@ -18,13 +18,15 @@ namespace operators {
 using Tensor = framework::Tensor;
 
-static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih, int iw,
-    const int tchw, const int chw, const int hw, const int w) {
+static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
+                                           int iw, const int tchw,
+                                           const int chw, const int hw,
+                                           const int w) {
   return in * tchw + it * chw + ic * hw + ih * w + iw;
 }
 
 template <typename T>
-class TemporalShiftKernel: public framework::OpKernel<T> {
+class TemporalShiftKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<Tensor>("X");
@@ -95,7 +97,8 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
     const int tchw = t * chw;
 
     const T* output_grad_data = output_grad->data<T>();
-    T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
+    T* input_grad_data =
+        input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
     memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
 
     int src_it = 0;
......
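GetEntryIndex above is simply the row-major offset into the (N, T, C, H, W) view; a quick NumPy check with illustrative sizes (parameter names are adjusted because `in` is a Python keyword):

```python
import numpy as np

def get_entry_index(n_i, t_i, c_i, h_i, w_i, tchw, chw, hw, w):
    # Direct transcription of GetEntryIndex from temporal_shift_op.h.
    return n_i * tchw + t_i * chw + c_i * hw + h_i * w + w_i

n, t, c, h, w = 2, 3, 4, 5, 6
hw, chw, tchw = h * w, c * h * w, t * c * h * w
x = np.arange(n * tchw).reshape((n, t, c, h, w))
# Each element's value equals its flat offset, so the formula checks directly.
assert x[1, 2, 3, 4, 5] == get_entry_index(1, 2, 3, 4, 5, tchw, chw, hw, w)
```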
@@ -10301,10 +10301,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
         type="temporal_shift",
         inputs={"X": x},
         outputs={"Out": out},
-        attrs={
-            "seg_num": seg_num,
-            "shift_ratio": shift_ratio
-        })
+        attrs={"seg_num": seg_num,
+               "shift_ratio": shift_ratio})
     return out
......
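Calling the layer from Python is straightforward; a short usage sketch against the fluid 1.x data-layer API (the variable name and C/H/W sizes are illustrative, and the implicit batch dimension must be N * seg_num at run time):

```python
import paddle.fluid as fluid

# 4-D input (N*T, C, H, W); the batch dimension is implicit in the data layer.
x = fluid.layers.data(name="x", shape=[16, 32, 32], dtype="float32")
out = fluid.layers.temporal_shift(x, seg_num=4, shift_ratio=0.25)
```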
@@ -24,15 +24,17 @@ from paddle.fluid import core
 def temporal_shift(x, seg_num, shift_ratio):
     shape = x.shape
     reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
-    pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), 'constant')
+    pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
+                   'constant')
     c1 = int(shape[1] * shift_ratio)
     c2 = int(shape[1] * 2 * shift_ratio)
     slice1 = pad_x[:, :seg_num, :c1, :, :]
-    slice2 = pad_x[:, 2:seg_num+2, c1:c2, :, :]
-    slice3 = pad_x[:, 1:seg_num+1, c2:, :, :]
+    slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :]
+    slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :]
     concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
     return concat_x.reshape(shape)
 
 
 class TestTemporalShift(OpTest):
     def setUp(self):
         self.initTestCase()
@@ -44,9 +46,7 @@ class TestTemporalShift(OpTest):
             "shift_ratio": self.shift_ratio,
         }
 
-        self.inputs = {
-            "X": x,
-        }
+        self.inputs = {"X": x, }
 
         output = temporal_shift(x, self.seg_num, self.shift_ratio)
         self.outputs = {"Out": output}
@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
         self.seg_num = 3
         self.shift_ratio = 0.25
 
+
 class TestTemporalShift2(TestTemporalShift):
     def initTestCase(self):
         self.x_shape = (4, 9, 7, 7)
......
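A tiny worked example of the NumPy reference above makes the slicing concrete: with C = 4 and shift_ratio = 0.25, channel 0 reads from the previous frame, channel 1 from the next frame, and channels 2-3 pass through. The input values are chosen so that x[t, ch] = 4*t + ch:

```python
import numpy as np

# Reuses the temporal_shift NumPy reference defined in the test above.
seg_num, c = 4, 4                  # shift_ratio = 0.25 -> c1 = 1, c2 = 2
x = np.arange(seg_num * c, dtype=np.float32).reshape((seg_num, c, 1, 1))
y = temporal_shift(x, seg_num, 0.25)
print(y[0, :, 0, 0])   # [0. 5. 2. 3.]     ch0 zero-padded, ch1 from frame 1
print(y[3, :, 0, 0])   # [ 8.  0. 14. 15.] ch0 from frame 2, ch1 zero-padded
```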