未验证 提交 e14ed71c 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the performance of gather Op (#28458)

上级 e29ab5ea
...@@ -20,8 +20,8 @@ limitations under the License. */ ...@@ -20,8 +20,8 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -165,14 +165,16 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out, ...@@ -165,14 +165,16 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
int out_index_dim_size, int out_index_dim_size,
int input_index_dim_size, int size) { int input_index_dim_size, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x; int idx = blockDim.x * blockIdx.x + threadIdx.x;
int outer_size = outer_dim_size * out_index_dim_size;
for (; idx < size; idx += blockDim.x * gridDim.x) { for (; idx < size; idx += blockDim.x * gridDim.x) {
int inner_dim_index = idx / (outer_dim_size * out_index_dim_size); int inner_dim_index = idx / outer_size;
int next_idx = idx % (outer_dim_size * out_index_dim_size); int next_idx = idx - outer_size * inner_dim_index;
int index_dim_index = next_idx / (outer_dim_size); int index_dim_index = next_idx / outer_dim_size;
int out_dim_index = next_idx % outer_dim_size; int index_val = index[index_dim_index];
int out_dim_index = next_idx - outer_dim_size * index_dim_index;
int input_index = int input_index =
inner_dim_index * (outer_dim_size * input_index_dim_size) + inner_dim_index * (outer_dim_size * input_index_dim_size) +
index[index_dim_index] * outer_dim_size + out_dim_index; index_val * outer_dim_size + out_dim_index;
out[idx] = input[input_index]; out[idx] = input[input_index];
} }
} }
...@@ -234,10 +236,11 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index, ...@@ -234,10 +236,11 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
auto* out_data = out->mutable_data<T>(place); auto* out_data = out->mutable_data<T>(place);
int out_size = out->numel(); int out_size = out->numel();
int threads = 512; platform::GpuLaunchConfig config =
int grid = (out_size + threads - 1) / threads; platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_size);
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
GatherGPUKernel<T, U><<<grid, threads, 0, stream>>>( GatherGPUKernel<
T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
input_data, index_data, out_data, outer_dim_size, inner_dim_size, input_data, index_data, out_data, outer_dim_size, inner_dim_size,
index_size, index_dim_size, out_size); index_size, index_dim_size, out_size);
} }
...@@ -280,10 +283,11 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index, ...@@ -280,10 +283,11 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
int out_index_dim_size = out_dim[axis_index]; int out_index_dim_size = out_dim[axis_index];
operators::math::set_constant(*dev_ctx, out, 0.0); operators::math::set_constant(*dev_ctx, out, 0.0);
int threads = 512; platform::GpuLaunchConfig config =
int grid = (input_size + threads - 1) / threads; platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_size);
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
GatherGradGPUKernel<T, U><<<grid, threads, 0, stream>>>( GatherGradGPUKernel<
T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
input_data, index_data, out_data, outer_dim_size, inner_dim_size, input_data, index_data, out_data, outer_dim_size, inner_dim_size,
input_index_dim_size, out_index_dim_size, input_size); input_index_dim_size, out_index_dim_size, input_size);
} }
......
...@@ -66,6 +66,11 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -66,6 +66,11 @@ class GatherOp : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateVarDataType(ctx, "X"), OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return expected_kernel_type;
}
}; };
class GatherGradOp : public framework::OperatorWithKernel { class GatherGradOp : public framework::OperatorWithKernel {
......
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
from ..fluid.layers import core from ..fluid.layers import core
from ..fluid.layer_helper import LayerHelper from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_ from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_, device_guard
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..fluid.layers.tensor import fill_constant from ..fluid.layers.tensor import fill_constant
from ..fluid.layers import utils from ..fluid.layers import utils
...@@ -794,6 +794,7 @@ def gather(x, index, axis=None, name=None): ...@@ -794,6 +794,7 @@ def gather(x, index, axis=None, name=None):
axis = 0 axis = 0
axis_tensor = axis axis_tensor = axis
if not isinstance(axis, Variable): if not isinstance(axis, Variable):
with device_guard("cpu"):
axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis) axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis)
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.gather(x, index, axis_tensor) return core.ops.gather(x, index, axis_tensor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册