提交 82d4f903 编写于 作者: D dengkaipeng

fix format. test=develop

上级 28949f8e
...@@ -17,7 +17,7 @@ namespace operators { ...@@ -17,7 +17,7 @@ namespace operators {
using framework::Tensor; using framework::Tensor;
class TemporalShiftOp: public framework::OperatorWithKernel { class TemporalShiftOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -34,14 +34,14 @@ class TemporalShiftOp: public framework::OperatorWithKernel { ...@@ -34,14 +34,14 @@ class TemporalShiftOp: public framework::OperatorWithKernel {
int seg_num = ctx->Attrs().Get<int>("seg_num"); int seg_num = ctx->Attrs().Get<int>("seg_num");
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio"); float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
PADDLE_ENFORCE_GT(seg_num, 0, PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
"Attr(seg_num) should be greater than 0.");
PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5, PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5,
"Attr(shift_ratio) should be greater than 0 and less " "Attr(shift_ratio) should be greater than 0 and less "
"than 0.5."); "than 0.5.");
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(dim_x[0] % seg_num, 0, PADDLE_ENFORCE_EQ(
dim_x[0] % seg_num, 0,
"Input(X) dims[0] should be divided exactly by Attr(seg_num)."); "Input(X) dims[0] should be divided exactly by Attr(seg_num).");
} }
...@@ -73,7 +73,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -73,7 +73,8 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("seg_num", AddAttr<int>("seg_num",
"The temporal segment number, this should be a positive " "The temporal segment number, this should be a positive "
"interger."); "interger.");
AddAttr<float>("shift_ratio", AddAttr<float>(
"shift_ratio",
"The shift ratio of the channels, the first shift ratio part " "The shift ratio of the channels, the first shift ratio part "
"of channels will be shifted by -1 along the temporal dimension, " "of channels will be shifted by -1 along the temporal dimension, "
"and the second shift ratio part of channels will be shifted by " "and the second shift ratio part of channels will be shifted by "
...@@ -118,7 +119,7 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -118,7 +119,7 @@ class TemporalShiftOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class TemporalShiftOpGrad: public framework::OperatorWithKernel { class TemporalShiftOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -144,7 +145,8 @@ class TemporalShiftOpGrad: public framework::OperatorWithKernel { ...@@ -144,7 +145,8 @@ class TemporalShiftOpGrad: public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp, ops::TemporalShiftOpMaker, REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
ops::TemporalShiftOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad); REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>, REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
......
...@@ -17,10 +17,10 @@ namespace operators { ...@@ -17,10 +17,10 @@ namespace operators {
using framework::Tensor; using framework::Tensor;
template <typename T> template <typename T>
__global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw, __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
const int tchw, const int chw, const int hw, const int w, const int t, const int c, const int tchw, const int chw, const int hw,
const int w, const int t, const int c,
const float shift_ratio) { const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
...@@ -53,8 +53,10 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw, ...@@ -53,8 +53,10 @@ __global__ void KeTemporalShiftFw(const T* input, T* output, const int ntchw,
} }
template <typename T> template <typename T>
__global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad, const int ntchw, __global__ void KeTemporalShiftBw(const T* output_grad, T* input_grad,
const int tchw, const int chw, const int hw, const int w, const int t, const int c, const int ntchw, const int tchw,
const int chw, const int hw, const int w,
const int t, const int c,
const float shift_ratio) { const float shift_ratio) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
...@@ -138,7 +140,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -138,7 +140,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
const int ntchw = nt * chw; const int ntchw = nt * chw;
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace()); T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T>()( math::SetConstant<platform::CUDADeviceContext, T>()(
ctx.template device_context<platform::CUDADeviceContext>(), input_grad, ctx.template device_context<platform::CUDADeviceContext>(), input_grad,
static_cast<T>(0)); static_cast<T>(0));
...@@ -149,7 +152,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -149,7 +152,8 @@ class TemporalShiftGradOpCUDAKernel : public framework::OpKernel<T> {
KeTemporalShiftBw< KeTemporalShiftBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>( T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c, shift_ratio); output_grad_data, input_grad_data, ntchw, tchw, chw, hw, w, t, c,
shift_ratio);
} }
}; };
......
...@@ -18,13 +18,15 @@ namespace operators { ...@@ -18,13 +18,15 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih, int iw, static HOSTDEVICE inline int GetEntryIndex(int in, int it, int ic, int ih,
const int tchw, const int chw, const int hw, const int w) { int iw, const int tchw,
const int chw, const int hw,
const int w) {
return in * tchw + it * chw + ic * hw + ih * w + iw; return in * tchw + it * chw + ic * hw + ih * w + iw;
} }
template <typename T> template <typename T>
class TemporalShiftKernel: public framework::OpKernel<T> { class TemporalShiftKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
...@@ -95,7 +97,8 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> { ...@@ -95,7 +97,8 @@ class TemporalShiftGradKernel : public framework::OpKernel<T> {
const int tchw = t * chw; const int tchw = t * chw;
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace()); T* input_grad_data =
input_grad->mutable_data<T>({nt, c, h, w}, ctx.GetPlace());
memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); memset(input_grad_data, 0, input_grad->numel() * sizeof(T));
int src_it = 0; int src_it = 0;
......
...@@ -10301,10 +10301,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None): ...@@ -10301,10 +10301,8 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None):
type="temporal_shift", type="temporal_shift",
inputs={"X": x}, inputs={"X": x},
outputs={"Out": out}, outputs={"Out": out},
attrs={ attrs={"seg_num": seg_num,
"seg_num": seg_num, "shift_ratio": shift_ratio})
"shift_ratio": shift_ratio
})
return out return out
......
...@@ -24,15 +24,17 @@ from paddle.fluid import core ...@@ -24,15 +24,17 @@ from paddle.fluid import core
def temporal_shift(x, seg_num, shift_ratio): def temporal_shift(x, seg_num, shift_ratio):
shape = x.shape shape = x.shape
reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3])) reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)), 'constant') pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
'constant')
c1 = int(shape[1] * shift_ratio) c1 = int(shape[1] * shift_ratio)
c2 = int(shape[1] * 2 * shift_ratio) c2 = int(shape[1] * 2 * shift_ratio)
slice1 = pad_x[:, :seg_num, :c1, :, :] slice1 = pad_x[:, :seg_num, :c1, :, :]
slice2 = pad_x[:, 2:seg_num+2, c1:c2, :, :] slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :]
slice3 = pad_x[:, 1:seg_num+1, c2:, :, :] slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :]
concat_x = np.concatenate([slice1, slice2, slice3], axis=2) concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
return concat_x.reshape(shape) return concat_x.reshape(shape)
class TestTemporalShift(OpTest): class TestTemporalShift(OpTest):
def setUp(self): def setUp(self):
self.initTestCase() self.initTestCase()
...@@ -44,9 +46,7 @@ class TestTemporalShift(OpTest): ...@@ -44,9 +46,7 @@ class TestTemporalShift(OpTest):
"shift_ratio": self.shift_ratio, "shift_ratio": self.shift_ratio,
} }
self.inputs = { self.inputs = {"X": x, }
"X": x,
}
output = temporal_shift(x, self.seg_num, self.shift_ratio) output = temporal_shift(x, self.seg_num, self.shift_ratio)
self.outputs = {"Out": output} self.outputs = {"Out": output}
...@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest): ...@@ -62,6 +62,7 @@ class TestTemporalShift(OpTest):
self.seg_num = 3 self.seg_num = 3
self.shift_ratio = 0.25 self.shift_ratio = 0.25
class TestTemporalShift2(TestTemporalShift): class TestTemporalShift2(TestTemporalShift):
def initTestCase(self): def initTestCase(self):
self.x_shape = (4, 9, 7, 7) self.x_shape = (4, 9, 7, 7)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册