未验证 提交 e14ed71c 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the performance of gather Op (#28458)

上级 e29ab5ea
......@@ -20,8 +20,8 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
......@@ -165,14 +165,16 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
int out_index_dim_size,
int input_index_dim_size, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int outer_size = outer_dim_size * out_index_dim_size;
for (; idx < size; idx += blockDim.x * gridDim.x) {
int inner_dim_index = idx / (outer_dim_size * out_index_dim_size);
int next_idx = idx % (outer_dim_size * out_index_dim_size);
int index_dim_index = next_idx / (outer_dim_size);
int out_dim_index = next_idx % outer_dim_size;
int inner_dim_index = idx / outer_size;
int next_idx = idx - outer_size * inner_dim_index;
int index_dim_index = next_idx / outer_dim_size;
int index_val = index[index_dim_index];
int out_dim_index = next_idx - outer_dim_size * index_dim_index;
int input_index =
inner_dim_index * (outer_dim_size * input_index_dim_size) +
index[index_dim_index] * outer_dim_size + out_dim_index;
index_val * outer_dim_size + out_dim_index;
out[idx] = input[input_index];
}
}
......@@ -234,10 +236,11 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
auto* out_data = out->mutable_data<T>(place);
int out_size = out->numel();
int threads = 512;
int grid = (out_size + threads - 1) / threads;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_size);
auto stream = ctx.cuda_device_context().stream();
GatherGPUKernel<T, U><<<grid, threads, 0, stream>>>(
GatherGPUKernel<
T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
input_data, index_data, out_data, outer_dim_size, inner_dim_size,
index_size, index_dim_size, out_size);
}
......@@ -280,10 +283,11 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
int out_index_dim_size = out_dim[axis_index];
operators::math::set_constant(*dev_ctx, out, 0.0);
int threads = 512;
int grid = (input_size + threads - 1) / threads;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_size);
auto stream = ctx.cuda_device_context().stream();
GatherGradGPUKernel<T, U><<<grid, threads, 0, stream>>>(
GatherGradGPUKernel<
T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
input_data, index_data, out_data, outer_dim_size, inner_dim_size,
input_index_dim_size, out_index_dim_size, input_size);
}
......
......@@ -66,6 +66,11 @@ class GatherOp : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
// Overrides the per-variable kernel-type hook for this operator: whatever
// variable is asked about, the kernel type already chosen for the op
// (`expected_kernel_type`) is returned unchanged. `var_name` and `tensor`
// are not consulted.
// NOTE(review): presumably this keeps the framework from transforming or
// copying inputs (e.g. an axis tensor that lives on the CPU) to the kernel's
// place before the kernel runs -- confirm against OperatorWithKernel.
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type;
}
};
class GatherGradOp : public framework::OperatorWithKernel {
......
......@@ -16,7 +16,7 @@ from __future__ import print_function
from ..fluid.layers import core
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_
from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_, device_guard
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils
......@@ -794,7 +794,8 @@ def gather(x, index, axis=None, name=None):
axis = 0
axis_tensor = axis
if not isinstance(axis, Variable):
axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis)
with device_guard("cpu"):
axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis)
if in_dygraph_mode():
return core.ops.gather(x, index, axis_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册