diff --git a/paddle/fluid/operators/temporal_shift_op.h b/paddle/fluid/operators/temporal_shift_op.h index ec43ed88cb0ded230f5b2f53d2f257ba1c951cec..141f2127da9609b48c08a3607db0e42bad1f94b7 100644 --- a/paddle/fluid/operators/temporal_shift_op.h +++ b/paddle/fluid/operators/temporal_shift_op.h @@ -19,56 +19,6 @@ namespace operators { using Tensor = framework::Tensor; using DataLayout = framework::DataLayout; -template -void TemporalShiftFwNCHW(const T* input, T* output, const int ntchw, - const int tchw, const int chw, const int hw, - const int t, const int c1, const int c2) { - int src_it = 0; - for (int i = 0; i < ntchw; i++) { - int it = (i % tchw) / chw; - int ic = (i % chw) / hw; - - if (ic < c1) { - src_it = it - 1; - } else if (ic < c2) { - src_it = it + 1; - } else { - src_it = it; - } - - if (src_it < 0 || src_it >= t) { - output[i] = 0; - } else { - output[i] = input[i + (src_it - it) * chw]; - } - } -} - -template -void TemporalShiftFwNHWC(const T* input, T* output, const int nthwc, - const int thwc, const int hwc, const int t, - const int c, const int c1, const int c2) { - int src_it = 0; - for (int i = 0; i < nthwc; i++) { - int it = (i % thwc) / hwc; - int ic = i % c; - - if (ic < c1) { - src_it = it - 1; - } else if (ic < c2) { - src_it = it + 1; - } else { - src_it = it; - } - - if (src_it < 0 || src_it >= t) { - output[i] = 0; - } else { - output[i] = input[i + (src_it - it) * hwc]; - } - } -} - template void TemporalShiftBwNCHW(const T* output_grad, T* input_grad, const int ntchw, const int tchw, const int chw, const int hw, @@ -122,45 +72,7 @@ void TemporalShiftBwNHWC(const T* output_grad, T* input_grad, const int nthwc, template class TemporalShiftKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - int t = ctx.Attr("seg_num"); - float shift_ratio = ctx.Attr("shift_ratio"); - const std::string data_format_str = ctx.Attr("data_format"); - const DataLayout data_layout = - framework::StringToDataLayout(data_format_str); - - const int nt = input->dims()[0]; - const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1] - : input->dims()[3]); - const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2] - : input->dims()[1]); - const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3] - : input->dims()[2]); - - const int hw = h * w; - const int chw = c * hw; - const int tchw = t * chw; - const int ntchw = nt * chw; - - const int c1 = static_cast(c * shift_ratio); - const int c2 = static_cast(c * 2 * shift_ratio); - - framework::DDim out_dims = - (data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w}) - : phi::make_ddim({nt, h, w, c})); - const T* input_data = input->data(); - T* output_data = output->mutable_data(out_dims, ctx.GetPlace()); - - if (data_layout == DataLayout::kNCHW) { - TemporalShiftFwNCHW(input_data, output_data, ntchw, tchw, chw, hw, t, - c1, c2); - } else { - TemporalShiftFwNHWC(input_data, output_data, ntchw, tchw, chw, t, c, - c1, c2); - } - } + void Compute(const framework::ExecutionContext& ctx) const override {} }; template diff --git a/paddle/phi/kernels/clip_by_norm_kernel.h b/paddle/phi/kernels/clip_by_norm_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..0d00a9f0061a931fe2439ea8749a55c3a9e50978 --- /dev/null +++ b/paddle/phi/kernels/clip_by_norm_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" + +namespace phi { + +template +void ClipByNormKernel(const Context& ctx, + const DenseTensor& x, + float max_norm, + DenseTensor* out); + +template +void ClipByNormSparseKernel(const Context& ctx, + const SelectedRows& x, + float max_norm, + SelectedRows* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/clip_by_norm_kernel.cc b/paddle/phi/kernels/cpu/clip_by_norm_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..aec0f8976b0e872744949266c81a245f9ad1f3e8 --- /dev/null +++ b/paddle/phi/kernels/cpu/clip_by_norm_kernel.cc @@ -0,0 +1,24 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/clip_by_norm_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h" + +PD_REGISTER_KERNEL( + clip_by_norm, CPU, ALL_LAYOUT, phi::ClipByNormKernel, float) {} + +PD_REGISTER_KERNEL( + clip_by_norm_sparse, CPU, ALL_LAYOUT, phi::ClipByNormSparseKernel, float) {} diff --git a/paddle/phi/kernels/cpu/temporal_shift_grad_kernel.cc b/paddle/phi/kernels/cpu/temporal_shift_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..400f7e8783932c1a3e40f0c4e3ec6fe45421d6db --- /dev/null +++ b/paddle/phi/kernels/cpu/temporal_shift_grad_kernel.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/temporal_shift_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void TemporalShiftBwNCHW(const T* output_grad, + T* input_grad, + const int ntchw, + const int tchw, + const int chw, + const int hw, + const int t, + const int c1, + const int c2) { + int src_it = 0; + for (int i = 0; i < ntchw; i++) { + int it = (i % tchw) / chw; + int ic = (i % chw) / hw; + + if (ic < c1) { + src_it = it + 1; + } else if (ic < c2) { + src_it = it - 1; + } else { + src_it = it; + } + + if (src_it >= 0 && src_it < t) { + input_grad[i] = output_grad[i + (src_it - it) * chw]; + } else { + input_grad[i] = 0; + } + } +} + +template +void TemporalShiftBwNHWC(const T* output_grad, + T* input_grad, + const int nthwc, + const int thwc, + const int hwc, + const int t, + const int c, + const int c1, + const int c2) { + int src_it = 0; + for (int i = 0; i < nthwc; i++) { + int it = (i % thwc) / hwc; + int ic = i % c; + + if (ic < c1) { + src_it = it + 1; + } else if (ic < c2) { + src_it = it - 1; + } else { + src_it = it; + } + + if (src_it >= 0 && src_it < t) { + input_grad[i] = output_grad[i + (src_it - it) * hwc]; + } else { + input_grad[i] = 0; + } + } +} + +template +void TemporalShiftGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + int seg_num, + float shift_ratio, + const std::string& data_format_str, + DenseTensor* x_grad) { + auto* input_grad = x_grad; + auto* output_grad = &out_grad; + int t = seg_num; + const DataLayout data_layout = + paddle::framework::StringToDataLayout(data_format_str); + + const int nt = output_grad->dims()[0]; + const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1] + : output_grad->dims()[3]); + const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2] + : output_grad->dims()[1]); + const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3] + : output_grad->dims()[2]); + + const int hw = h * w; + const int chw = c * hw; + const int tchw = t * chw; + const int ntchw = nt * chw; + + const int c1 = static_cast(c * shift_ratio); + const int c2 = static_cast(c * 2 * shift_ratio); + + DDim in_grad_dims = + (data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w}) + : phi::make_ddim({nt, h, w, c})); + const T* output_grad_data = output_grad->data(); + T* input_grad_data = + input_grad->mutable_data(in_grad_dims, dev_ctx.GetPlace()); + + if (data_layout == DataLayout::kNCHW) { + TemporalShiftBwNCHW( + output_grad_data, input_grad_data, ntchw, tchw, chw, hw, t, c1, c2); + } else { + TemporalShiftBwNHWC( + output_grad_data, input_grad_data, ntchw, tchw, chw, t, c, c1, c2); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(temporal_shift_grad, + CPU, + ALL_LAYOUT, + phi::TemporalShiftGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/temporal_shift_kernel.cc b/paddle/phi/kernels/cpu/temporal_shift_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6721117992dd538b436da8fde0e03b7c8714a831 --- /dev/null +++ b/paddle/phi/kernels/cpu/temporal_shift_kernel.cc @@ -0,0 +1,131 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/temporal_shift_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void TemporalShiftFwNCHW(const T* input, + T* output, + const int ntchw, + const int tchw, + const int chw, + const int hw, + const int t, + const int c1, + const int c2) { + int src_it = 0; + for (int i = 0; i < ntchw; i++) { + int it = (i % tchw) / chw; + int ic = (i % chw) / hw; + + if (ic < c1) { + src_it = it - 1; + } else if (ic < c2) { + src_it = it + 1; + } else { + src_it = it; + } + + if (src_it < 0 || src_it >= t) { + output[i] = 0; + } else { + output[i] = input[i + (src_it - it) * chw]; + } + } +} + +template +void TemporalShiftFwNHWC(const T* input, + T* output, + const int nthwc, + const int thwc, + const int hwc, + const int t, + const int c, + const int c1, + const int c2) { + int src_it = 0; + for (int i = 0; i < nthwc; i++) { + int it = (i % thwc) / hwc; + int ic = i % c; + + if (ic < c1) { + src_it = it - 1; + } else if (ic < c2) { + src_it = it + 1; + } else { + src_it = it; + } + + if (src_it < 0 || src_it >= t) { + output[i] = 0; + } else { + output[i] = input[i + (src_it - it) * hwc]; + } + } +} + +template +void TemporalShiftKernel(const Context& dev_ctx, + const DenseTensor& x, + int seg_num, + float shift_ratio, + const std::string& data_format_str, + DenseTensor* out) { + auto* input = &x; + auto* output = out; + int t = seg_num; + const DataLayout data_layout = + paddle::framework::StringToDataLayout(data_format_str); + + const int nt = input->dims()[0]; + const int c = + (data_layout == DataLayout::kNCHW ? input->dims()[1] : input->dims()[3]); + const int h = + (data_layout == DataLayout::kNCHW ? input->dims()[2] : input->dims()[1]); + const int w = + (data_layout == DataLayout::kNCHW ? input->dims()[3] : input->dims()[2]); + + const int hw = h * w; + const int chw = c * hw; + const int tchw = t * chw; + const int ntchw = nt * chw; + + const int c1 = static_cast(c * shift_ratio); + const int c2 = static_cast(c * 2 * shift_ratio); + + DDim out_dims = + (data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w}) + : phi::make_ddim({nt, h, w, c})); + const T* input_data = input->data(); + T* output_data = output->mutable_data(out_dims, dev_ctx.GetPlace()); + + if (data_layout == DataLayout::kNCHW) { + TemporalShiftFwNCHW( + input_data, output_data, ntchw, tchw, chw, hw, t, c1, c2); + } else { + TemporalShiftFwNHWC( + input_data, output_data, ntchw, tchw, chw, t, c, c1, c2); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + temporal_shift, CPU, ALL_LAYOUT, phi::TemporalShiftKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/clip_by_norm_kernel.cu b/paddle/phi/kernels/gpu/clip_by_norm_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..74efdcf733404017323b1cae45f51d105f9db4c2 --- /dev/null +++ b/paddle/phi/kernels/gpu/clip_by_norm_kernel.cu @@ -0,0 +1,112 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/clip_by_norm_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h" + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/kernels/gpu/reduce.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + +namespace phi { + +template <> +void ClipByNormKernel( + const GPUContext& dev_ctx, + const DenseTensor& x_in, + float max_norm, + DenseTensor* out_p) { + dev_ctx.template Alloc(out_p); + std::vector reduce_dims; + reduce_dims.resize(x_in.dims().size()); + for (int i = 0; i < reduce_dims.size(); ++i) { + reduce_dims[i] = i; + } + + DenseTensor tmp; + tmp.Resize({1}); + dev_ctx.template Alloc(&tmp); + kernels::TensorReduceImpl>( + dev_ctx, + x_in, + &tmp, + kps::SquareFunctor(), + reduce_dims, + dev_ctx.stream()); + + auto tmp_eigen = EigenVector::Flatten(tmp); + auto x_norm = tmp_eigen.sqrt(); + + auto x = EigenVector::Flatten(x_in); + auto out = EigenVector::Flatten(*out_p); + + auto& place = *dev_ctx.eigen_device(); + + auto temp = (x_norm <= max_norm).template cast(); + auto epsilon = + ((x_norm <= static_cast(1e-30)).all().template cast()) * + static_cast(1e-6); + + auto scaling = + (temp + (static_cast(1) - temp) * max_norm / (x_norm + epsilon)) + .template cast(); + Eigen::array one_dim{{1}}; + Eigen::DSizes m_dsize(x_in.numel()); + + out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize); +} + +template <> +void ClipByNormSparseKernel( + const phi::GPUContext& ctx, + const SelectedRows& x, + float max_norm, + SelectedRows* out) { + // merge ids in selected rows first + paddle::operators::math::scatter::MergeAdd + merge_func; + phi::SelectedRows merged_input; + merge_func(ctx, x, &merged_input); + auto input = merged_input.value(); + + phi::SelectedRows* output_selected_rows = out; + output_selected_rows->set_rows(merged_input.rows()); + output_selected_rows->set_height(merged_input.height()); + auto output = output_selected_rows->mutable_value(); + output->Resize(merged_input.value().dims()); + output->mutable_data(ctx.GetPlace()); + + ClipByNormKernel(ctx, input, max_norm, output); +} + +} // namespace phi + +// PD_REGISTER_KERNEL( +// clip_by_norm, GPU, ALL_LAYOUT, phi::ClipByNormKernel, float, +// phi::dtype::float16) {} + +// PD_REGISTER_KERNEL( +// clip_by_norm_sparse, GPU, ALL_LAYOUT, phi::ClipByNormSparseKernel, float, +// phi::dtype::float16) {} +PD_REGISTER_KERNEL( + clip_by_norm, GPU, ALL_LAYOUT, phi::ClipByNormKernel, float) {} + +PD_REGISTER_KERNEL( + clip_by_norm_sparse, GPU, ALL_LAYOUT, phi::ClipByNormSparseKernel, float) {} diff --git a/paddle/phi/kernels/gpu/temporal_shift_grad_kernel.cu b/paddle/phi/kernels/gpu/temporal_shift_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..065b1726dce5b0f525e421eb5d2e094f1615badb --- /dev/null +++ b/paddle/phi/kernels/gpu/temporal_shift_grad_kernel.cu @@ -0,0 +1,149 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/temporal_shift_grad_kernel.h" + +namespace phi { + +template +__global__ void KeTemporalShiftBwNCHW(const T* output_grad, + T* input_grad, + const int ntchw, + const int tchw, + const int chw, + const int hw, + const int t, + const int c1, + const int c2) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + int src_it = 0; + + for (; tid < ntchw; tid += stride) { + int it = (tid % tchw) / chw; + int ic = (tid % chw) / hw; + + if (ic < c1) { + src_it = it + 1; + } else if (ic < c2) { + src_it = it - 1; + } else { + src_it = it; + } + + if (src_it >= 0 && src_it < t) { + input_grad[tid] = output_grad[tid + (src_it - it) * chw]; + } else { + input_grad[tid] = 0; + } + } +} + +template +__global__ void KeTemporalShiftBwNHWC(const T* output_grad, + T* input_grad, + const int nthwc, + const int thwc, + const int hwc, + const int t, + const int c, + const int c1, + const int c2) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + int src_it = 0; + + for (; tid < nthwc; tid += stride) { + int it = (tid % thwc) / hwc; + int ic = tid % c; + + if (ic < c1) { + src_it = it + 1; + } else if (ic < c2) { + src_it = it - 1; + } else { + src_it = it; + } + + if (src_it >= 0 && src_it < t) { + input_grad[tid] = output_grad[tid + (src_it - it) * hwc]; + } else { + input_grad[tid] = 0; + } + } +} + +template +void TemporalShiftGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + int seg_num, + float shift_ratio, + const std::string& data_format_str, + DenseTensor* x_grad) { + auto* input_grad = x_grad; + auto* output_grad = &out_grad; + int t = seg_num; + const DataLayout data_layout = + paddle::framework::StringToDataLayout(data_format_str); + + const int nt = output_grad->dims()[0]; + const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1] + : output_grad->dims()[3]); + const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2] + : output_grad->dims()[1]); + const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3] + : output_grad->dims()[2]); + + const int hw = h * w; + const int chw = c * hw; + const int tchw = t * chw; + const int ntchw = nt * chw; + + const int c1 = static_cast(c * shift_ratio); + const int c2 = static_cast(c * 2 * shift_ratio); + + DDim in_grad_dims = + (data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w}) + : phi::make_ddim({nt, h, w, c})); + const T* output_grad_data = output_grad->data(); + T* input_grad_data = + input_grad->mutable_data(in_grad_dims, dev_ctx.GetPlace()); + + int pixelNum = nt * chw; + int threads = 1024; + int grid = (pixelNum + threads - 1) / threads; + int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads; + grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid); + + if (data_layout == DataLayout::kNCHW) { + KeTemporalShiftBwNCHW<<>>( + output_grad_data, input_grad_data, ntchw, tchw, chw, hw, t, c1, c2); + } else { + KeTemporalShiftBwNHWC<<>>( + output_grad_data, input_grad_data, ntchw, tchw, chw, t, c, c1, c2); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(temporal_shift_grad, + GPU, + ALL_LAYOUT, + phi::TemporalShiftGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/temporal_shift_kernel.cu b/paddle/phi/kernels/gpu/temporal_shift_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..34d80a1bc804b82bc93c77534ec7769b4b5406f4 --- /dev/null +++ b/paddle/phi/kernels/gpu/temporal_shift_kernel.cu @@ -0,0 +1,148 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/temporal_shift_kernel.h" + +namespace phi { + +template +__global__ void KeTemporalShiftFwNCHW(const T* input, + T* output, + const int ntchw, + const int tchw, + const int chw, + const int hw, + const int t, + const int c1, + const int c2) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + int src_it = 0; + + for (; tid < ntchw; tid += stride) { + int it = (tid % tchw) / chw; + int ic = (tid % chw) / hw; + + if (ic < c1) { + src_it = it - 1; + } else if (ic < c2) { + src_it = it + 1; + } else { + src_it = it; + } + + if (src_it < 0 || src_it >= t) { + output[tid] = 0; + } else { + output[tid] = input[tid + (src_it - it) * chw]; + } + } +} + +template +__global__ void KeTemporalShiftFwNHWC(const T* input, + T* output, + const int nthwc, + const int thwc, + const int hwc, + const int t, + const int c, + const int c1, + const int c2) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + int src_it = 0; + + for (; tid < nthwc; tid += stride) { + int it = (tid % thwc) / hwc; + int ic = tid % c; + + if (ic < c1) { + src_it = it - 1; + } else if (ic < c2) { + src_it = it + 1; + } else { + src_it = it; + } + + if (src_it < 0 || src_it >= t) { + output[tid] = 0; + } else { + output[tid] = input[tid + (src_it - it) * hwc]; + } + } +} + +template +void TemporalShiftKernel(const Context& dev_ctx, + const DenseTensor& x, + int seg_num, + float shift_ratio, + const std::string& data_format_str, + DenseTensor* out) { + auto* input = &x; + auto* output = out; + int t = seg_num; + const DataLayout data_layout = + paddle::framework::StringToDataLayout(data_format_str); + + const int nt = input->dims()[0]; + const int c = + (data_layout == DataLayout::kNCHW ? input->dims()[1] : input->dims()[3]); + const int h = + (data_layout == DataLayout::kNCHW ? input->dims()[2] : input->dims()[1]); + const int w = + (data_layout == DataLayout::kNCHW ? input->dims()[3] : input->dims()[2]); + + const int hw = h * w; + const int chw = c * hw; + const int tchw = t * chw; + const int ntchw = nt * chw; + + const int c1 = static_cast(c * shift_ratio); + const int c2 = static_cast(c * 2 * shift_ratio); + + DDim out_dims = + (data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w}) + : phi::make_ddim({nt, h, w, c})); + const T* input_data = input->data(); + T* output_data = output->mutable_data(out_dims, dev_ctx.GetPlace()); + + int pixelNum = nt * chw; + int threads = 1024; + int grid = (pixelNum + threads - 1) / threads; + int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads; + grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid); + + if (data_layout == DataLayout::kNCHW) { + KeTemporalShiftFwNCHW<<>>( + input_data, output_data, ntchw, tchw, chw, hw, t, c1, c2); + } else { + KeTemporalShiftFwNHWC<<>>( + input_data, output_data, ntchw, tchw, chw, t, c, c1, c2); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(temporal_shift, + GPU, + ALL_LAYOUT, + phi::TemporalShiftKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h b/paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..f759d22148e7bf1235083eb86693f0f5871b8711 --- /dev/null +++ b/paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void ClipByNormKernel(const Context& ctx, + const DenseTensor& x_in, + float max_norm, + DenseTensor* out_p) { + ctx.template Alloc(out_p); + auto x = EigenVector::Flatten(x_in); + auto out = EigenVector::Flatten(*out_p); + auto x_norm = x.square().sum().sqrt(); + auto& place = *ctx.eigen_device(); + + auto temp = (x_norm <= max_norm).template cast(); + auto epsilon = ((x_norm <= static_cast(1e-30)).all().template cast()) * + static_cast(1e-6); + + auto scaling = + temp + (static_cast(1) - temp) * max_norm / (x_norm + epsilon); + Eigen::array one_dim{{1}}; + Eigen::DSizes m_dsize(x_in.numel()); + if (ctx.GetPlace() == phi::CPUPlace()) { + out.device(place) = x * scaling.reshape(one_dim).eval().broadcast(m_dsize); + } else { + out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize); + } +} + +template +void ClipByNormSparseKernel(const Context& ctx, + const SelectedRows& x, + float max_norm, + SelectedRows* out) { + // merge ids in selected rows first + paddle::operators::math::scatter::MergeAdd merge_func; + phi::SelectedRows merged_input; + merge_func(ctx, x, &merged_input); + auto input = merged_input.value(); + + phi::SelectedRows* output_selected_rows = out; + output_selected_rows->set_rows(merged_input.rows()); + output_selected_rows->set_height(merged_input.height()); + auto output = output_selected_rows->mutable_value(); + output->Resize(merged_input.value().dims()); + output->mutable_data(ctx.GetPlace()); + + ClipByNormKernel(ctx, input, max_norm, output); +} + +} // namespace phi diff --git a/paddle/phi/kernels/temporal_shift_grad_kernel.h b/paddle/phi/kernels/temporal_shift_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..1bcd3d61c26f538827896411a21a5b1e5aa1c127 --- /dev/null +++ b/paddle/phi/kernels/temporal_shift_grad_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TemporalShiftGradKernel(const Context& ctx, + const DenseTensor& out_grad, + int seg_num, + float shift_ratio, + const std::string& data_format, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/temporal_shift_kernel.h b/paddle/phi/kernels/temporal_shift_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..a927d7fb23aae68e6422ed0ac80dbfaa8ac0da58 --- /dev/null +++ b/paddle/phi/kernels/temporal_shift_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TemporalShiftKernel(const Context& ctx, + const DenseTensor& x, + int seg_num, + float shift_ratio, + const std::string& data_format, + DenseTensor* out); + +} // namespace phi