diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h
index c4bdd9e439c54db03f8fa8c4fe439ed6edbd0c7a..16864f28baaf91d548d611ef198bd8598fe8960a 100644
--- a/paddle/fluid/operators/gather.cu.h
+++ b/paddle/fluid/operators/gather.cu.h
@@ -20,8 +20,8 @@ limitations under the License. */
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 #include "paddle/fluid/platform/place.h"
-
 namespace paddle {
 namespace operators {
 
@@ -165,14 +165,16 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
                                 int out_index_dim_size,
                                 int input_index_dim_size, int size) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int outer_size = outer_dim_size * out_index_dim_size;
   for (; idx < size; idx += blockDim.x * gridDim.x) {
-    int inner_dim_index = idx / (outer_dim_size * out_index_dim_size);
-    int next_idx = idx % (outer_dim_size * out_index_dim_size);
-    int index_dim_index = next_idx / (outer_dim_size);
-    int out_dim_index = next_idx % outer_dim_size;
+    int inner_dim_index = idx / outer_size;
+    int next_idx = idx - outer_size * inner_dim_index;
+    int index_dim_index = next_idx / outer_dim_size;
+    int index_val = index[index_dim_index];
+    int out_dim_index = next_idx - outer_dim_size * index_dim_index;
     int input_index =
         inner_dim_index * (outer_dim_size * input_index_dim_size) +
-        index[index_dim_index] * outer_dim_size + out_dim_index;
+        index_val * outer_dim_size + out_dim_index;
     out[idx] = input[input_index];
   }
 }
@@ -234,10 +236,11 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
   auto* out_data = out->mutable_data<T>(place);
   int out_size = out->numel();
 
-  int threads = 512;
-  int grid = (out_size + threads - 1) / threads;
+  platform::GpuLaunchConfig config =
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_size);
   auto stream = ctx.cuda_device_context().stream();
-  GatherGPUKernel<T, U><<<grid, threads, 0, stream>>>(
+  GatherGPUKernel<
+      T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
       input_data, index_data, out_data, outer_dim_size, inner_dim_size,
       index_size, index_dim_size, out_size);
 }
@@ -280,10 +283,11 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
   int out_index_dim_size = out_dim[axis_index];
   operators::math::set_constant(*dev_ctx, out, 0.0);
 
-  int threads = 512;
-  int grid = (input_size + threads - 1) / threads;
+  platform::GpuLaunchConfig config =
+      platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_size);
   auto stream = ctx.cuda_device_context().stream();
-  GatherGradGPUKernel<T, U><<<grid, threads, 0, stream>>>(
+  GatherGradGPUKernel<
+      T, U><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
       input_data, index_data, out_data, outer_dim_size, inner_dim_size,
       input_index_dim_size, out_index_dim_size, input_size);
 }
diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc
index a99879316d684ca95e73ce8db43e988efcbab4c4..72b44b22f9c06060468c3ab9a11b18658082c716 100644
--- a/paddle/fluid/operators/gather_op.cc
+++ b/paddle/fluid/operators/gather_op.cc
@@ -66,6 +66,11 @@ class GatherOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateVarDataType(ctx, "X"),
         ctx.device_context());
   }
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    return expected_kernel_type;
+  }
 };
 
 class GatherGradOp : public framework::OperatorWithKernel {
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 19b88e122e4e2371eb435958d937f71dc972bff3..1d0785f97db0a11b559e240934dda4d085a025bf 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -16,7 +16,7 @@ from __future__ import print_function
 
 from ..fluid.layers import core
 from ..fluid.layer_helper import LayerHelper
-from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_
+from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_, device_guard
 from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
 from ..fluid.layers.tensor import fill_constant
 from ..fluid.layers import utils
@@ -794,7 +794,8 @@ def gather(x, index, axis=None, name=None):
         axis = 0
     axis_tensor = axis
     if not isinstance(axis, Variable):
-        axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis)
+        with device_guard("cpu"):
+            axis_tensor = fill_constant(shape=[1], dtype='int64', value=axis)
     if in_dygraph_mode():
         return core.ops.gather(x, index, axis_tensor)
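
For reference, a minimal usage sketch of the Python API this patch touches (not part of the diff; the tensor values are illustrative only):

    import paddle

    x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    index = paddle.to_tensor([0, 2])

    # Gather rows 0 and 2 along axis 0. With this patch, the auxiliary
    # axis tensor built by fill_constant is created on CPU under
    # device_guard("cpu") instead of on the default device.
    out = paddle.gather(x, index, axis=0)
    print(out.numpy())  # [[1. 2.] [5. 6.]]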