From 613303dbf66497e9e9baa98fb4e0eb0eb86d6f7f Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Wed, 10 Jun 2020 11:33:28 +0800 Subject: [PATCH] refine the slice Op to improve the performance of xlnet for fp16 training (#24967) --- paddle/fluid/operators/slice_op.cu | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/slice_op.cu b/paddle/fluid/operators/slice_op.cu index a4ea5ad1eed..d6945df9e18 100644 --- a/paddle/fluid/operators/slice_op.cu +++ b/paddle/fluid/operators/slice_op.cu @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/slice_op.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/float16.h" - namespace paddle { namespace operators { @@ -94,17 +94,22 @@ class SliceGradKernel(out_dims); - thrust::device_vector out_dims_vec(out_shape.begin(), - out_shape.end()); - auto in_shape = framework::vectorize(in_dims); - thrust::device_vector in_dims_vec(in_shape.begin(), - in_shape.end()); - thrust::device_vector offsets_vec(offsets.begin(), offsets.end()); - const int64_t* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data()); - const int64_t* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data()); - const int64_t* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data()); + const std::vector out_shape = + framework::vectorize(out_dims); + const std::vector in_shape = + framework::vectorize(in_dims); + + framework::Tensor out_dims_tensor; + framework::Tensor in_dims_tensor; + framework::Tensor offsets_tensor; + framework::TensorFromVector(out_shape, ctx.device_context(), + &out_dims_tensor); + framework::TensorFromVector(in_shape, ctx.device_context(), + &in_dims_tensor); + framework::TensorFromVector(offsets, ctx.device_context(), &offsets_tensor); + const int64_t* out_dims_ptr = out_dims_tensor.data(); + const int64_t* in_dims_ptr = in_dims_tensor.data(); + const int64_t* offsets_ptr = offsets_tensor.data(); switch (rank) { case 1: -- GitLab