Commit 3fc0d192 authored by phlrain

update

Parent a8e02ef1
@@ -19,56 +19,6 @@ namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
void TemporalShiftFwNCHW(const T* input, T* output, const int ntchw,
const int tchw, const int chw, const int hw,
const int t, const int c1, const int c2) {
int src_it = 0;
for (int i = 0; i < ntchw; i++) {
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[i] = 0;
} else {
output[i] = input[i + (src_it - it) * chw];
}
}
}
template <typename T>
void TemporalShiftFwNHWC(const T* input, T* output, const int nthwc,
const int thwc, const int hwc, const int t,
const int c, const int c1, const int c2) {
int src_it = 0;
for (int i = 0; i < nthwc; i++) {
int it = (i % thwc) / hwc;
int ic = i % c;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[i] = 0;
} else {
output[i] = input[i + (src_it - it) * hwc];
}
}
}
template <typename T>
void TemporalShiftBwNCHW(const T* output_grad, T* input_grad, const int ntchw,
const int tchw, const int chw, const int hw,
@@ -122,45 +72,7 @@ void TemporalShiftBwNHWC(const T* output_grad, T* input_grad, const int nthwc,
template <typename T>
class TemporalShiftKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
int t = ctx.Attr<int>("seg_num");
float shift_ratio = ctx.Attr<float>("shift_ratio");
const std::string data_format_str = ctx.Attr<std::string>("data_format");
const DataLayout data_layout =
framework::StringToDataLayout(data_format_str);
const int nt = input->dims()[0];
const int c = (data_layout == DataLayout::kNCHW ? input->dims()[1]
: input->dims()[3]);
const int h = (data_layout == DataLayout::kNCHW ? input->dims()[2]
: input->dims()[1]);
const int w = (data_layout == DataLayout::kNCHW ? input->dims()[3]
: input->dims()[2]);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
framework::DDim out_dims =
(data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w})
: phi::make_ddim({nt, h, w, c}));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(out_dims, ctx.GetPlace());
if (data_layout == DataLayout::kNCHW) {
TemporalShiftFwNCHW<T>(input_data, output_data, ntchw, tchw, chw, hw, t,
c1, c2);
} else {
TemporalShiftFwNHWC<T>(input_data, output_data, ntchw, tchw, chw, t, c,
c1, c2);
}
}
void Compute(const framework::ExecutionContext& ctx) const override {}
};
template <typename T>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
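// Clips the input tensor so that its L2 norm is at most max_norm:
// out = x * min(1, max_norm / ||x||_2). The sparse variant applies the same
// clipping to the values of a SelectedRows input.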
template <typename T, typename Context>
void ClipByNormKernel(const Context& ctx,
const DenseTensor& x,
float max_norm,
DenseTensor* out);
template <typename T, typename Context>
void ClipByNormSparseKernel(const Context& ctx,
const SelectedRows& x,
float max_norm,
SelectedRows* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_by_norm_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h"
PD_REGISTER_KERNEL(
clip_by_norm, CPU, ALL_LAYOUT, phi::ClipByNormKernel, float) {}
PD_REGISTER_KERNEL(
clip_by_norm_sparse, CPU, ALL_LAYOUT, phi::ClipByNormSparseKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/temporal_shift_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
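// Backward pass of the temporal shift. Gradient routing mirrors the forward
// shift with the direction reversed: channels [0, c1) read the gradient from
// time step it + 1 and channels [c1, c2) from it - 1; time steps shifted in
// from outside the segment receive zero gradient.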
template <typename T>
void TemporalShiftBwNCHW(const T* output_grad,
T* input_grad,
const int ntchw,
const int tchw,
const int chw,
const int hw,
const int t,
const int c1,
const int c2) {
int src_it = 0;
for (int i = 0; i < ntchw; i++) {
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
if (ic < c1) {
src_it = it + 1;
} else if (ic < c2) {
src_it = it - 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
input_grad[i] = output_grad[i + (src_it - it) * chw];
} else {
input_grad[i] = 0;
}
}
}
template <typename T>
void TemporalShiftBwNHWC(const T* output_grad,
T* input_grad,
const int nthwc,
const int thwc,
const int hwc,
const int t,
const int c,
const int c1,
const int c2) {
int src_it = 0;
for (int i = 0; i < nthwc; i++) {
int it = (i % thwc) / hwc;
int ic = i % c;
if (ic < c1) {
src_it = it + 1;
} else if (ic < c2) {
src_it = it - 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
input_grad[i] = output_grad[i + (src_it - it) * hwc];
} else {
input_grad[i] = 0;
}
}
}
template <typename T, typename Context>
void TemporalShiftGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int seg_num,
float shift_ratio,
const std::string& data_format_str,
DenseTensor* x_grad) {
auto* input_grad = x_grad;
auto* output_grad = &out_grad;
int t = seg_num;
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_format_str);
const int nt = output_grad->dims()[0];
const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
: output_grad->dims()[3]);
const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
: output_grad->dims()[1]);
const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
: output_grad->dims()[2]);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
DDim in_grad_dims =
(data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w})
: phi::make_ddim({nt, h, w, c}));
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>(in_grad_dims, dev_ctx.GetPlace());
if (data_layout == DataLayout::kNCHW) {
TemporalShiftBwNCHW<T>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, t, c1, c2);
} else {
TemporalShiftBwNHWC<T>(
output_grad_data, input_grad_data, ntchw, tchw, chw, t, c, c1, c2);
}
}
} // namespace phi
PD_REGISTER_KERNEL(temporal_shift_grad,
CPU,
ALL_LAYOUT,
phi::TemporalShiftGradKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/temporal_shift_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
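// Forward pass of the temporal shift (TSM): within each segment of t frames,
// the first c1 channels are shifted one frame in one temporal direction
// (output at time it reads input at it - 1), the next c2 - c1 channels are
// shifted the other way (input at it + 1), and the remaining channels are
// copied unchanged. Boundary frames are zero-padded.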
template <typename T>
void TemporalShiftFwNCHW(const T* input,
T* output,
const int ntchw,
const int tchw,
const int chw,
const int hw,
const int t,
const int c1,
const int c2) {
int src_it = 0;
for (int i = 0; i < ntchw; i++) {
int it = (i % tchw) / chw;
int ic = (i % chw) / hw;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[i] = 0;
} else {
output[i] = input[i + (src_it - it) * chw];
}
}
}
template <typename T>
void TemporalShiftFwNHWC(const T* input,
T* output,
const int nthwc,
const int thwc,
const int hwc,
const int t,
const int c,
const int c1,
const int c2) {
int src_it = 0;
for (int i = 0; i < nthwc; i++) {
int it = (i % thwc) / hwc;
int ic = i % c;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[i] = 0;
} else {
output[i] = input[i + (src_it - it) * hwc];
}
}
}
template <typename T, typename Context>
void TemporalShiftKernel(const Context& dev_ctx,
const DenseTensor& x,
int seg_num,
float shift_ratio,
const std::string& data_format_str,
DenseTensor* out) {
auto* input = &x;
auto* output = out;
int t = seg_num;
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_format_str);
const int nt = input->dims()[0];
const int c =
(data_layout == DataLayout::kNCHW ? input->dims()[1] : input->dims()[3]);
const int h =
(data_layout == DataLayout::kNCHW ? input->dims()[2] : input->dims()[1]);
const int w =
(data_layout == DataLayout::kNCHW ? input->dims()[3] : input->dims()[2]);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
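  // For example, with c = 8 and shift_ratio = 0.25 this gives c1 = 2 and
  // c2 = 4: channels [0, 2) and [2, 4) are shifted, the rest are copied.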
DDim out_dims =
(data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w})
: phi::make_ddim({nt, h, w, c}));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(out_dims, dev_ctx.GetPlace());
if (data_layout == DataLayout::kNCHW) {
TemporalShiftFwNCHW<T>(
input_data, output_data, ntchw, tchw, chw, hw, t, c1, c2);
} else {
TemporalShiftFwNHWC<T>(
input_data, output_data, ntchw, tchw, chw, t, c, c1, c2);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
temporal_shift, CPU, ALL_LAYOUT, phi::TemporalShiftKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/clip_by_norm_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/impl/clip_by_norm_kernel_impl.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
namespace phi {
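// float16 specialization: the squared-sum reduction accumulates in float
// rather than float16 so that computing ||x||_2 does not overflow or lose
// precision; the final scaling factor is cast back to float16.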
template <>
void ClipByNormKernel<phi::dtype::float16, phi::GPUContext>(
const GPUContext& dev_ctx,
const DenseTensor& x_in,
float max_norm,
DenseTensor* out_p) {
dev_ctx.template Alloc<dtype::float16>(out_p);
std::vector<int> reduce_dims;
reduce_dims.resize(x_in.dims().size());
for (size_t i = 0; i < reduce_dims.size(); ++i) {
  reduce_dims[i] = static_cast<int>(i);
}
DenseTensor tmp;
tmp.Resize({1});
dev_ctx.template Alloc<float>(&tmp);
kernels::TensorReduceImpl<dtype::float16,
float,
kps::AddFunctor,
kps::SquareFunctor<dtype::float16, float>>(
dev_ctx,
x_in,
&tmp,
kps::SquareFunctor<dtype::float16, float>(),
reduce_dims,
dev_ctx.stream());
auto tmp_eigen = EigenVector<float>::Flatten(tmp);
auto x_norm = tmp_eigen.sqrt();
auto x = EigenVector<dtype::float16>::Flatten(x_in);
auto out = EigenVector<dtype::float16>::Flatten(*out_p);
auto& place = *dev_ctx.eigen_device();
auto temp = (x_norm <= max_norm).template cast<float>();
auto epsilon =
((x_norm <= static_cast<float>(1e-30)).all().template cast<float>()) *
static_cast<float>(1e-6);
auto scaling =
(temp + (static_cast<float>(1) - temp) * max_norm / (x_norm + epsilon))
.template cast<dtype::float16>();
Eigen::array<int, 1> one_dim{{1}};
Eigen::DSizes<int, 1> m_dsize(x_in.numel());
out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);
}
template <>
void ClipByNormSparseKernel<phi::dtype::float16, phi::GPUContext>(
const phi::GPUContext& ctx,
const SelectedRows& x,
float max_norm,
SelectedRows* out) {
// merge ids in selected rows first
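// (MergeAdd sums values that share a row id, so the dense clip below
// operates on the aggregated values rather than counting duplicates twice)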
paddle::operators::math::scatter::MergeAdd<GPUContext, dtype::float16>
merge_func;
phi::SelectedRows merged_input;
merge_func(ctx, x, &merged_input);
auto input = merged_input.value();
phi::SelectedRows* output_selected_rows = out;
output_selected_rows->set_rows(merged_input.rows());
output_selected_rows->set_height(merged_input.height());
auto output = output_selected_rows->mutable_value();
output->Resize(merged_input.value().dims());
output->mutable_data<dtype::float16>(ctx.GetPlace());
ClipByNormKernel<dtype::float16>(ctx, input, max_norm, output);
}
} // namespace phi
// PD_REGISTER_KERNEL(
// clip_by_norm, GPU, ALL_LAYOUT, phi::ClipByNormKernel, float,
// phi::dtype::float16) {}
// PD_REGISTER_KERNEL(
// clip_by_norm_sparse, GPU, ALL_LAYOUT, phi::ClipByNormSparseKernel, float,
// phi::dtype::float16) {}
PD_REGISTER_KERNEL(
clip_by_norm, GPU, ALL_LAYOUT, phi::ClipByNormKernel, float) {}
PD_REGISTER_KERNEL(
clip_by_norm_sparse, GPU, ALL_LAYOUT, phi::ClipByNormSparseKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/temporal_shift_grad_kernel.h"
namespace phi {
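// CUDA counterparts of the CPU backward kernels: each thread walks the
// flattened gradient in a grid-stride loop and applies the same reversed
// shift routing (it + 1 for channels [0, c1), it - 1 for [c1, c2)).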
template <typename T>
__global__ void KeTemporalShiftBwNCHW(const T* output_grad,
T* input_grad,
const int ntchw,
const int tchw,
const int chw,
const int hw,
const int t,
const int c1,
const int c2) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < ntchw; tid += stride) {
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
if (ic < c1) {
src_it = it + 1;
} else if (ic < c2) {
src_it = it - 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
input_grad[tid] = output_grad[tid + (src_it - it) * chw];
} else {
input_grad[tid] = 0;
}
}
}
template <typename T>
__global__ void KeTemporalShiftBwNHWC(const T* output_grad,
T* input_grad,
const int nthwc,
const int thwc,
const int hwc,
const int t,
const int c,
const int c1,
const int c2) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < nthwc; tid += stride) {
int it = (tid % thwc) / hwc;
int ic = tid % c;
if (ic < c1) {
src_it = it + 1;
} else if (ic < c2) {
src_it = it - 1;
} else {
src_it = it;
}
if (src_it >= 0 && src_it < t) {
input_grad[tid] = output_grad[tid + (src_it - it) * hwc];
} else {
input_grad[tid] = 0;
}
}
}
template <typename T, typename Context>
void TemporalShiftGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int seg_num,
float shift_ratio,
const std::string& data_format_str,
DenseTensor* x_grad) {
auto* input_grad = x_grad;
auto* output_grad = &out_grad;
int t = seg_num;
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_format_str);
const int nt = output_grad->dims()[0];
const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
: output_grad->dims()[3]);
const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
: output_grad->dims()[1]);
const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
: output_grad->dims()[2]);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
DDim in_grad_dims =
(data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w})
: phi::make_ddim({nt, h, w, c}));
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>(in_grad_dims, dev_ctx.GetPlace());
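  // Launch configuration: one thread per element, with the grid capped at
  // roughly the number of blocks the device can keep resident
  // (SM count * blocks per SM); the grid-stride loop covers the remainder.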
int pixelNum = nt * chw;
int threads = 1024;
int grid = (pixelNum + threads - 1) / threads;
int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads;
grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);
if (data_layout == DataLayout::kNCHW) {
KeTemporalShiftBwNCHW<T><<<grid, threads, 0, dev_ctx.stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, hw, t, c1, c2);
} else {
KeTemporalShiftBwNHWC<T><<<grid, threads, 0, dev_ctx.stream()>>>(
output_grad_data, input_grad_data, ntchw, tchw, chw, t, c, c1, c2);
}
}
} // namespace phi
PD_REGISTER_KERNEL(temporal_shift_grad,
GPU,
ALL_LAYOUT,
phi::TemporalShiftGradKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/temporal_shift_kernel.h"
namespace phi {
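// CUDA counterparts of the CPU forward kernels, using the same grid-stride
// loop pattern as the backward kernels above.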
template <typename T>
__global__ void KeTemporalShiftFwNCHW(const T* input,
T* output,
const int ntchw,
const int tchw,
const int chw,
const int hw,
const int t,
const int c1,
const int c2) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < ntchw; tid += stride) {
int it = (tid % tchw) / chw;
int ic = (tid % chw) / hw;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[tid] = 0;
} else {
output[tid] = input[tid + (src_it - it) * chw];
}
}
}
template <typename T>
__global__ void KeTemporalShiftFwNHWC(const T* input,
T* output,
const int nthwc,
const int thwc,
const int hwc,
const int t,
const int c,
const int c1,
const int c2) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int src_it = 0;
for (; tid < nthwc; tid += stride) {
int it = (tid % thwc) / hwc;
int ic = tid % c;
if (ic < c1) {
src_it = it - 1;
} else if (ic < c2) {
src_it = it + 1;
} else {
src_it = it;
}
if (src_it < 0 || src_it >= t) {
output[tid] = 0;
} else {
output[tid] = input[tid + (src_it - it) * hwc];
}
}
}
template <typename T, typename Context>
void TemporalShiftKernel(const Context& dev_ctx,
const DenseTensor& x,
int seg_num,
float shift_ratio,
const std::string& data_format_str,
DenseTensor* out) {
auto* input = &x;
auto* output = out;
int t = seg_num;
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_format_str);
const int nt = input->dims()[0];
const int c =
(data_layout == DataLayout::kNCHW ? input->dims()[1] : input->dims()[3]);
const int h =
(data_layout == DataLayout::kNCHW ? input->dims()[2] : input->dims()[1]);
const int w =
(data_layout == DataLayout::kNCHW ? input->dims()[3] : input->dims()[2]);
const int hw = h * w;
const int chw = c * hw;
const int tchw = t * chw;
const int ntchw = nt * chw;
const int c1 = static_cast<int>(c * shift_ratio);
const int c2 = static_cast<int>(c * 2 * shift_ratio);
DDim out_dims =
(data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w})
: phi::make_ddim({nt, h, w, c}));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(out_dims, dev_ctx.GetPlace());
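  // Same occupancy-capped launch configuration as the backward kernel.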
int pixelNum = nt * chw;
int threads = 1024;
int grid = (pixelNum + threads - 1) / threads;
int blocks_per_sm = dev_ctx.GetMaxPhysicalThreadCount() / threads;
grid = std::min(dev_ctx.GetSMCount() * blocks_per_sm, grid);
if (data_layout == DataLayout::kNCHW) {
KeTemporalShiftFwNCHW<T><<<grid, threads, 0, dev_ctx.stream()>>>(
input_data, output_data, ntchw, tchw, chw, hw, t, c1, c2);
} else {
KeTemporalShiftFwNHWC<T><<<grid, threads, 0, dev_ctx.stream()>>>(
input_data, output_data, ntchw, tchw, chw, t, c, c1, c2);
}
}
} // namespace phi
PD_REGISTER_KERNEL(temporal_shift,
GPU,
ALL_LAYOUT,
phi::TemporalShiftKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
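// Dense clip-by-norm: computes ||x||_2 over the flattened tensor and scales
// x by min(1, max_norm / ||x||_2); epsilon guards against division by zero
// when the norm is (near) zero.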
template <typename T, typename Context>
void ClipByNormKernel(const Context& ctx,
const DenseTensor& x_in,
float max_norm,
DenseTensor* out_p) {
ctx.template Alloc<T>(out_p);
auto x = EigenVector<T>::Flatten(x_in);
auto out = EigenVector<T>::Flatten(*out_p);
auto x_norm = x.square().sum().sqrt();
auto& place = *ctx.eigen_device();
auto temp = (x_norm <= max_norm).template cast<T>();
auto epsilon = ((x_norm <= static_cast<T>(1e-30)).all().template cast<T>()) *
static_cast<T>(1e-6);
auto scaling =
temp + (static_cast<T>(1) - temp) * max_norm / (x_norm + epsilon);
Eigen::array<int, 1> one_dim{{1}};
Eigen::DSizes<int, 1> m_dsize(x_in.numel());
if (ctx.GetPlace() == phi::CPUPlace()) {
out.device(place) = x * scaling.reshape(one_dim).eval().broadcast(m_dsize);
} else {
out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);
}
}
template <typename T, typename Context>
void ClipByNormSparseKernel(const Context& ctx,
const SelectedRows& x,
float max_norm,
SelectedRows* out) {
// merge ids in selected rows first
paddle::operators::math::scatter::MergeAdd<Context, T> merge_func;
phi::SelectedRows merged_input;
merge_func(ctx, x, &merged_input);
auto input = merged_input.value();
phi::SelectedRows* output_selected_rows = out;
output_selected_rows->set_rows(merged_input.rows());
output_selected_rows->set_height(merged_input.height());
auto output = output_selected_rows->mutable_value();
output->Resize(merged_input.value().dims());
output->mutable_data<T>(ctx.GetPlace());
ClipByNormKernel<T>(ctx, input, max_norm, output);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TemporalShiftGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int seg_num,
float shift_ratio,
const std::string& data_format,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
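// Shifts channels of x along the temporal dimension; see the CPU/GPU kernel
// implementations above for the exact shift rule. data_format is "NCHW" or
// "NHWC".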
template <typename T, typename Context>
void TemporalShiftKernel(const Context& ctx,
const DenseTensor& x,
int seg_num,
float shift_ratio,
const std::string& data_format,
DenseTensor* out);
} // namespace phi