未验证 提交 613303db 编写于 作者: W wangchaochaohu 提交者: GitHub

Refine the slice op to improve the performance of XLNet fp16 training (#24967)

上级 37bdb526
...@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h" #include "paddle/fluid/operators/slice_op.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -94,17 +94,22 @@ class SliceGradKernel<paddle::platform::CUDADeviceContext, ...@@ -94,17 +94,22 @@ class SliceGradKernel<paddle::platform::CUDADeviceContext,
dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1); dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1);
dim3 threads(PADDLE_CUDA_NUM_THREADS); dim3 threads(PADDLE_CUDA_NUM_THREADS);
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
const std::vector<int64_t> out_shape =
auto out_shape = framework::vectorize<int64_t>(out_dims); framework::vectorize<int64_t>(out_dims);
thrust::device_vector<int64_t> out_dims_vec(out_shape.begin(), const std::vector<int64_t> in_shape =
out_shape.end()); framework::vectorize<int64_t>(in_dims);
auto in_shape = framework::vectorize<int64_t>(in_dims);
thrust::device_vector<int64_t> in_dims_vec(in_shape.begin(), framework::Tensor out_dims_tensor;
in_shape.end()); framework::Tensor in_dims_tensor;
thrust::device_vector<int64_t> offsets_vec(offsets.begin(), offsets.end()); framework::Tensor offsets_tensor;
const int64_t* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data()); framework::TensorFromVector(out_shape, ctx.device_context(),
const int64_t* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data()); &out_dims_tensor);
const int64_t* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data()); framework::TensorFromVector(in_shape, ctx.device_context(),
&in_dims_tensor);
framework::TensorFromVector(offsets, ctx.device_context(), &offsets_tensor);
const int64_t* out_dims_ptr = out_dims_tensor.data<int64_t>();
const int64_t* in_dims_ptr = in_dims_tensor.data<int64_t>();
const int64_t* offsets_ptr = offsets_tensor.data<int64_t>();
switch (rank) { switch (rank) {
case 1: case 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册