From 187bf412692d9f36bf50bd0e65024b8314063884 Mon Sep 17 00:00:00 2001
From: wuhuanzhou
Date: Tue, 6 Apr 2021 17:21:54 +0800
Subject: [PATCH] optimize compilation of operators using eigen (#31851)

---
 paddle/fluid/operators/CMakeLists.txt         |  4 +-
 paddle/fluid/operators/addmm_op.h             |  8 +-
 paddle/fluid/operators/eigen/CMakeLists.txt   | 10 +++
 paddle/fluid/operators/eigen/broadcast.cc     | 86 ++++++++++++++++++
 paddle/fluid/operators/eigen/broadcast.cu     | 87 +++++++++++++++++++
 paddle/fluid/operators/eigen/eigen_function.h | 52 +++++++++++
 paddle/fluid/operators/expand_as_op.h         | 19 ++--
 paddle/fluid/operators/expand_as_v2_op.h      | 19 ++--
 paddle/fluid/operators/expand_op.h            | 22 ++---
 paddle/fluid/operators/expand_v2_op.h         | 22 ++---
 paddle/fluid/operators/meshgrid_op.h          | 18 ++--
 paddle/fluid/operators/tile_op.h              | 22 ++---
 12 files changed, 309 insertions(+), 60 deletions(-)
 create mode 100644 paddle/fluid/operators/eigen/CMakeLists.txt
 create mode 100644 paddle/fluid/operators/eigen/broadcast.cc
 create mode 100644 paddle/fluid/operators/eigen/broadcast.cu
 create mode 100644 paddle/fluid/operators/eigen/eigen_function.h

diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 467a5ff906..ed87872753 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -10,6 +10,7 @@ file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operators/CMakeLists
 copy_if_different(${pybind_file} ${pybind_file_final})
 
 add_subdirectory(math)
+add_subdirectory(eigen)
 add_subdirectory(controlflow)
 add_subdirectory(detection)
 add_subdirectory(elementwise)
@@ -110,8 +111,9 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_fun
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
+set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_cc_function)
 if (WITH_GPU OR WITH_ROCM)
-  set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor)
+  set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor eigen_cu_function)
 endif()
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
diff --git a/paddle/fluid/operators/addmm_op.h b/paddle/fluid/operators/addmm_op.h
index 97e3ed9c1a..ecfd10d2fa 100644
--- a/paddle/fluid/operators/addmm_op.h
+++ b/paddle/fluid/operators/addmm_op.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
 
@@ -32,8 +33,8 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 
-using Array1 = Eigen::DSizes<int, 1>;
-using Array2 = Eigen::DSizes<int, 2>;
+using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
+using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
 
 using Tensor = framework::Tensor;
@@ -105,7 +106,8 @@ class AddMMKernel : public framework::OpKernel<T> {
     auto eigen_out = EigenTensor<T, 2>::From(*out);
 
     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
-    eigen_out.device(place) = eigen_input.broadcast(bcast_dims);
+    EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
+        place, eigen_out, eigen_input, bcast_dims);
 
     blas.GEMM(false, false, x_dims[0], y_dims[1], x_dims[1], alpha,
               x->data<T>(), x_dims[1], y->data<T>(), y_dims[1], beta,
diff --git a/paddle/fluid/operators/eigen/CMakeLists.txt b/paddle/fluid/operators/eigen/CMakeLists.txt
new file mode 100644
index 0000000000..848bf2433c
--- /dev/null
+++ b/paddle/fluid/operators/eigen/CMakeLists.txt
@@ -0,0 +1,10 @@
+file(GLOB EIGEN_CC_SOURCES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
+cc_library(eigen_cc_function SRCS ${EIGEN_CC_SOURCES} DEPS eigen3)
+if(WITH_GPU OR WITH_ROCM)
+  file(GLOB EIGEN_CU_SOURCES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cu")
+  if(WITH_GPU)
+    nv_library(eigen_cu_function SRCS ${EIGEN_CU_SOURCES} DEPS eigen3)
+  elseif(WITH_ROCM)
+    hip_library(eigen_cu_function SRCS ${EIGEN_CU_SOURCES} DEPS eigen3)
+  endif()
+endif()
diff --git a/paddle/fluid/operators/eigen/broadcast.cc b/paddle/fluid/operators/eigen/broadcast.cc
new file mode 100644
index 0000000000..dab25f9549
--- /dev/null
+++ b/paddle/fluid/operators/eigen/broadcast.cc
@@ -0,0 +1,86 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T, int Rank>
+struct EigenBroadcast<Eigen::DefaultDevice, T, Rank> {
+  using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
+  using InType = Eigen::TensorMap<
+      Eigen::Tensor<const T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using InType32BitIndex =
+      Eigen::TensorMap<Eigen::Tensor<const T, Rank, Eigen::RowMajor, int>,
+                       Eigen::Aligned>;
+  using OutType = Eigen::TensorMap<
+      Eigen::Tensor<T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using OutType32BitIndex =
+      Eigen::TensorMap<Eigen::Tensor<T, Rank, Eigen::RowMajor, int>,
+                       Eigen::Aligned>;
+
+  static void Eval(const Eigen::DefaultDevice& dev, OutType out, InType in,
+                   const Array& bcast) {
+    out.device(dev) = in.broadcast(bcast);
+  }
+
+  static void Eval(const Eigen::DefaultDevice& dev, OutType32BitIndex out,
+                   InType32BitIndex in, const Array& bcast) {
+    out.device(dev) = in.broadcast(bcast);
+  }
+};
+
+template <typename T, int Rank>
+struct EigenBroadcastGrad<Eigen::DefaultDevice, T, Rank> {
+  using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
+  using Array2 = Eigen::DSizes<Eigen::DenseIndex, Rank * 2>;
+  using InType = Eigen::TensorMap<
+      Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using OutType = Eigen::TensorMap<
+      Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  static void Eval(const Eigen::DefaultDevice& dev, OutType out, InType in,
+                   const Array& reduce_dims, const Array2& reshape_dims) {
+    out.device(dev) =
+        in.reshape(reshape_dims).sum(reduce_dims).reshape(out.dimensions());
+  }
+};
+
+#define INSTANTIATION(FUNCTOR, T)                      \
+  template struct FUNCTOR<Eigen::DefaultDevice, T, 1>; \
+  template struct FUNCTOR<Eigen::DefaultDevice, T, 2>; \
+  template struct FUNCTOR<Eigen::DefaultDevice, T, 3>; \
+  template struct FUNCTOR<Eigen::DefaultDevice, T, 4>; \
+  template struct FUNCTOR<Eigen::DefaultDevice, T, 5>; \
+  template struct FUNCTOR<Eigen::DefaultDevice, T, 6>
+INSTANTIATION(EigenBroadcast, bool);
+INSTANTIATION(EigenBroadcast, platform::float16);
+INSTANTIATION(EigenBroadcast, float);
+INSTANTIATION(EigenBroadcast, double);
+INSTANTIATION(EigenBroadcast, int);
+INSTANTIATION(EigenBroadcast, int64_t);
+INSTANTIATION(EigenBroadcastGrad, bool);
+INSTANTIATION(EigenBroadcastGrad, float);
+INSTANTIATION(EigenBroadcastGrad, platform::float16);
+INSTANTIATION(EigenBroadcastGrad, double);
+INSTANTIATION(EigenBroadcastGrad, int);
+INSTANTIATION(EigenBroadcastGrad, int64_t);
+template struct EigenBroadcastGrad<Eigen::DefaultDevice, float, 0>;
+template struct EigenBroadcastGrad<Eigen::DefaultDevice, double, 0>;
+template struct EigenBroadcastGrad<Eigen::DefaultDevice, int, 0>;
+template struct EigenBroadcastGrad<Eigen::DefaultDevice, int64_t, 0>;
+#undef INSTANTIATION
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/eigen/broadcast.cu b/paddle/fluid/operators/eigen/broadcast.cu
new file mode 100644
index 0000000000..63e244d393
--- /dev/null
+++ b/paddle/fluid/operators/eigen/broadcast.cu
@@ -0,0 +1,87 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T, int Rank>
+struct EigenBroadcast<Eigen::GpuDevice, T, Rank> {
+  using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
+  using InType = Eigen::TensorMap<
+      Eigen::Tensor<const T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using InType32BitIndex =
+      Eigen::TensorMap<Eigen::Tensor<const T, Rank, Eigen::RowMajor, int>,
+                       Eigen::Aligned>;
+  using OutType = Eigen::TensorMap<
+      Eigen::Tensor<T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using OutType32BitIndex =
+      Eigen::TensorMap<Eigen::Tensor<T, Rank, Eigen::RowMajor, int>,
+                       Eigen::Aligned>;
+
+  static void Eval(const Eigen::GpuDevice& dev, OutType out, InType in,
+                   const Array& bcast) {
+    out.device(dev) = in.broadcast(bcast);
+  }
+
+  static void Eval(const Eigen::GpuDevice& dev, OutType32BitIndex out,
+                   InType32BitIndex in, const Array& bcast) {
+    out.device(dev) = in.broadcast(bcast);
+  }
+};
+
+template <typename T, int Rank>
+struct EigenBroadcastGrad<Eigen::GpuDevice, T, Rank> {
+  using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
+  using Array2 = Eigen::DSizes<Eigen::DenseIndex, Rank * 2>;
+  using InType = Eigen::TensorMap<
+      Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using OutType = Eigen::TensorMap<
+      Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  static void Eval(const Eigen::GpuDevice& dev, OutType out, InType in,
+                   const Array& reduce_dims, const Array2& reshape_dims) {
+    out.device(dev) =
+        in.reshape(reshape_dims).sum(reduce_dims).reshape(out.dimensions());
+  }
+};
+
+#define INSTANTIATION(FUNCTOR, T)                  \
+  template struct FUNCTOR<Eigen::GpuDevice, T, 1>; \
+  template struct FUNCTOR<Eigen::GpuDevice, T, 2>; \
+  template struct FUNCTOR<Eigen::GpuDevice, T, 3>; \
+  template struct FUNCTOR<Eigen::GpuDevice, T, 4>; \
+  template struct FUNCTOR<Eigen::GpuDevice, T, 5>; \
+  template struct FUNCTOR<Eigen::GpuDevice, T, 6>
+INSTANTIATION(EigenBroadcast, bool);
+INSTANTIATION(EigenBroadcast, platform::float16);
+INSTANTIATION(EigenBroadcast, float);
+INSTANTIATION(EigenBroadcast, double);
+INSTANTIATION(EigenBroadcast, int);
+INSTANTIATION(EigenBroadcast, int64_t);
+INSTANTIATION(EigenBroadcastGrad, bool);
+INSTANTIATION(EigenBroadcastGrad, float);
+INSTANTIATION(EigenBroadcastGrad, platform::float16);
+INSTANTIATION(EigenBroadcastGrad, double);
+INSTANTIATION(EigenBroadcastGrad, int);
+INSTANTIATION(EigenBroadcastGrad, int64_t);
+template struct EigenBroadcastGrad<Eigen::GpuDevice, platform::float16, 0>;
+template struct EigenBroadcastGrad<Eigen::GpuDevice, float, 0>;
+template struct EigenBroadcastGrad<Eigen::GpuDevice, double, 0>;
+template struct EigenBroadcastGrad<Eigen::GpuDevice, int, 0>;
+template struct EigenBroadcastGrad<Eigen::GpuDevice, int64_t, 0>;
+#undef INSTANTIATION
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/eigen/eigen_function.h b/paddle/fluid/operators/eigen/eigen_function.h
new file mode 100644
index 0000000000..5966950595
--- /dev/null
+++ b/paddle/fluid/operators/eigen/eigen_function.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+#include "unsupported/Eigen/CXX11/Tensor"
+
+namespace paddle {
+namespace operators {
+
+template <typename EigenDevice, typename T, int Rank>
+struct EigenBroadcast {
+  using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
+  using InType = Eigen::TensorMap<
+      Eigen::Tensor<const T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using InType32BitIndex =
+      Eigen::TensorMap<Eigen::Tensor<const T, Rank, Eigen::RowMajor, int>,
+                       Eigen::Aligned>;
+  using OutType = Eigen::TensorMap<
+      Eigen::Tensor<T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using OutType32BitIndex =
+      Eigen::TensorMap<Eigen::Tensor<T, Rank, Eigen::RowMajor, int>,
+                       Eigen::Aligned>;
+  static void Eval(const EigenDevice& dev, OutType out, InType in,
+                   const Array& bcast);
+  static void Eval(const EigenDevice& dev, OutType32BitIndex out,
+                   InType32BitIndex in, const Array& bcast);
+};
+
+template <typename EigenDevice, typename T, int Rank>
+struct EigenBroadcastGrad {
+  using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
+  using Array2 = Eigen::DSizes<Eigen::DenseIndex, Rank * 2>;
+  using InType = Eigen::TensorMap<
+      Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using OutType = Eigen::TensorMap<
+      Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  static void Eval(const EigenDevice& dev, OutType out, InType in,
+                   const Array& reduce_dims, const Array2& reshape_dims);
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/expand_as_op.h b/paddle/fluid/operators/expand_as_op.h
index cbaeb0c4e4..4cefadb24e 100644
--- a/paddle/fluid/operators/expand_as_op.h
+++ b/paddle/fluid/operators/expand_as_op.h
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
 
 #define MAX_RANK_SUPPORTED 6
 
@@ -75,7 +76,7 @@ class ExpandAsKernel : public framework::OpKernel<T> {
     auto in_dims = in0->dims();
     auto* target_tensor = context.Input<Tensor>("target_tensor");
     auto* out0 = context.Output<Tensor>("Out");
-    Eigen::DSizes<int, Rank> bcast_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
     int bcast_dims_remainder = 0;
     auto x_dims = in0->dims();
     auto y_dims = target_tensor->dims();
@@ -104,7 +105,8 @@ class ExpandAsKernel : public framework::OpKernel<T> {
     auto y = EigenTensor<T, Rank>::From(*out0);
     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
-    y.device(place) = x.broadcast(bcast_dims);
+    EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
+                                                                 bcast_dims);
   }
 };
 
@@ -165,20 +167,19 @@ class ExpandAsGradKernel : public framework::OpKernel<T> {
     auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
-    Eigen::DSizes<int, Dims * 2> reshape_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
     for (size_t i = 0; i < reshape_size; ++i) {
       reshape_dims[i] = reshape_dims_vec[i];
     }
-    Eigen::DSizes<int, Dims> reduce_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
     for (size_t i = 0; i < reduce_size; ++i) {
       reduce_dims[i] = reduce_dims_vec[i];
     }
     auto out_grad = EigenVector<T>::Flatten(*in0);
-    x_grad.device(
-        *context.template device_context<DeviceContext>().eigen_device()) =
-        out_grad.reshape(reshape_dims)
-            .sum(reduce_dims)
-            .reshape(x_grad.dimensions());
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
+        place, x_grad, out_grad, reduce_dims, reshape_dims);
   }
 };
 
diff --git a/paddle/fluid/operators/expand_as_v2_op.h b/paddle/fluid/operators/expand_as_v2_op.h
index c36e461926..441dd35380 100644
--- a/paddle/fluid/operators/expand_as_v2_op.h
+++ b/paddle/fluid/operators/expand_as_v2_op.h
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
 
 #define MAX_RANK_SUPPORTED 6
 
@@ -108,7 +109,7 @@ class ExpandAsV2Kernel : public framework::OpKernel<T> {
       }
     }
     auto* out0 = context.Output<Tensor>("Out");
-    Eigen::DSizes<int, Rank> bcast_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
     for (size_t i = 0; i < repeat_times.size(); ++i) {
       bcast_dims[i] = repeat_times[i];
     }
@@ -122,7 +123,8 @@ class ExpandAsV2Kernel : public framework::OpKernel<T> {
     auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
-    y.device(place) = x.broadcast(bcast_dims);
+    EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
+                                                                 bcast_dims);
   }
 };
 
@@ -191,20 +193,19 @@ class ExpandAsV2GradKernel : public framework::OpKernel<T> {
     auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
-    Eigen::DSizes<int, Dims * 2> reshape_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
     for (size_t i = 0; i < reshape_size; ++i) {
       reshape_dims[i] = reshape_dims_vec[i];
     }
-    Eigen::DSizes<int, Dims> reduce_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
     for (size_t i = 0; i < reduce_size; ++i) {
       reduce_dims[i] = reduce_dims_vec[i];
     }
     auto out_grad = EigenVector<T>::Flatten(*in0);
-    x_grad.device(
-        *context.template device_context<DeviceContext>().eigen_device()) =
-        out_grad.reshape(reshape_dims)
-            .sum(reduce_dims)
-            .reshape(x_grad.dimensions());
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
+        place, x_grad, out_grad, reduce_dims, reshape_dims);
   }
 };
 
diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h
index 8b79a1feb8..abd525497d 100644
--- a/paddle/fluid/operators/expand_op.h
+++ b/paddle/fluid/operators/expand_op.h
@@ -25,6 +25,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
 
 #define MAX_RANK_SUPPORTED 6
 
@@ -141,7 +142,7 @@ class ExpandKernel : public framework::OpKernel<T> {
                           "of dimensions (%d) of the input.",
                           expand_times.size(), static_cast<size_t>(in_dims.size())));
     auto* out0 = context.Output<Tensor>("Out");
-    Eigen::DSizes<int, Rank> bcast_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
     for (size_t i = 0; i < expand_times.size(); ++i) {
       bcast_dims[i] = expand_times[i];
     }
@@ -160,9 +161,11 @@ class ExpandKernel : public framework::OpKernel<T> {
     // use 32-bit index to speed up
     bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
     if (use_32bit_index) {
-      To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
+      EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
+          place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
     } else {
-      y.device(place) = x.broadcast(bcast_dims);
+      EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
+                                                                   bcast_dims);
     }
   }
 };
 
@@ -241,20 +244,19 @@ class ExpandGradKernel : public framework::OpKernel<T> {
     auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
-    Eigen::DSizes<int, Dims * 2> reshape_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
     for (size_t i = 0; i < reshape_size; ++i) {
       reshape_dims[i] = reshape_dims_vec[i];
     }
-    Eigen::DSizes<int, Dims> reduce_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
     for (size_t i = 0; i < reduce_size; ++i) {
       reduce_dims[i] = reduce_dims_vec[i];
     }
     auto out_grad = EigenVector<T>::Flatten(*in0);
-    x_grad.device(
-        *context.template device_context<DeviceContext>().eigen_device()) =
-        out_grad.reshape(reshape_dims)
-            .sum(reduce_dims)
-            .reshape(x_grad.dimensions());
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
+        place, x_grad, out_grad, reduce_dims, reshape_dims);
   }
 };
 
diff --git a/paddle/fluid/operators/expand_v2_op.h b/paddle/fluid/operators/expand_v2_op.h
index ec9c6e62f2..af5fdf22cd 100644
--- a/paddle/fluid/operators/expand_v2_op.h
+++ b/paddle/fluid/operators/expand_v2_op.h
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
 
 #define MAX_RANK_SUPPORTED 6
 
@@ -174,7 +175,7 @@ class ExpandV2Kernel : public framework::OpKernel<T> {
     }
 
     auto* out0 = context.Output<Tensor>("Out");
-    Eigen::DSizes<int, Rank> bcast_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
     for (size_t i = 0; i < repeat_times.size(); ++i) {
       bcast_dims[i] = repeat_times[i];
     }
@@ -194,9 +195,11 @@ class ExpandV2Kernel : public framework::OpKernel<T> {
     // use 32-bit index to speed up
     bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
     if (use_32bit_index) {
-      To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
+      EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
+          place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
     } else {
-      y.device(place) = x.broadcast(bcast_dims);
+      EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
+                                                                   bcast_dims);
     }
   }
 };
 
@@ -275,20 +278,19 @@ class ExpandV2GradKernel : public framework::OpKernel<T> {
     auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
-    Eigen::DSizes<int, Dims * 2> reshape_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
     for (size_t i = 0; i < reshape_size; ++i) {
       reshape_dims[i] = reshape_dims_vec[i];
     }
-    Eigen::DSizes<int, Dims> reduce_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
     for (size_t i = 0; i < reduce_size; ++i) {
       reduce_dims[i] = reduce_dims_vec[i];
     }
    auto out_grad = EigenVector<T>::Flatten(*in0);
-    x_grad.device(
-        *context.template device_context<DeviceContext>().eigen_device()) =
-        out_grad.reshape(reshape_dims)
-            .sum(reduce_dims)
-            .reshape(x_grad.dimensions());
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
+        place, x_grad, out_grad, reduce_dims, reshape_dims);
   }
 };
 
diff --git a/paddle/fluid/operators/meshgrid_op.h b/paddle/fluid/operators/meshgrid_op.h
index 162622c7d0..345e007de4 100644
--- a/paddle/fluid/operators/meshgrid_op.h
+++ b/paddle/fluid/operators/meshgrid_op.h
@@ -25,6 +25,7 @@
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/fluid/platform/errors.h"
 
 #define MAX_RANK_SUPPORTED 6
 
@@ -106,19 +107,21 @@ class MeshgridKernel : public framework::OpKernel<T> {
       reshape_ins_tensor.Resize(out_dims_reshape);
       framework::DDim out_dims = framework::make_ddim(shape);
 
-      Eigen::DSizes<int, Rank> bcast_dims;
+      Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
       for (int64_t j = 0; j < size; j++) {
        bcast_dims[j] = shape[j];
       }
       bcast_dims[i] = 1;
 
       outs[i]->Resize(out_dims);
-      auto x = framework::EigenTensor<T, Rank>::From(reshape_ins_tensor);
+      auto x = framework::EigenTensor<T, Rank>::From(
+          static_cast<const framework::Tensor>(reshape_ins_tensor));
       outs[i]->mutable_data<T>(context.GetPlace());
       auto y = framework::EigenTensor<T, Rank>::From(*outs[i]);
       auto& place =
           *context.template device_context<DeviceContext>().eigen_device();
-      y.device(place) = x.broadcast(bcast_dims);
+      EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
+                                                                   bcast_dims);
     }
   }
 };
 
@@ -169,21 +172,20 @@ class MeshgridGradKernel : public framework::OpKernel<T> {
        }
      }
 
-      Eigen::DSizes<int, Rank> reduce_dims;
+      Eigen::DSizes<Eigen::DenseIndex, Rank> reduce_dims;
       for (int k = 0; k < n; k++) {
         reduce_dims[k] = reduce_dims_vec[k];
       }
 
-      Eigen::DSizes<int, Rank * 2> reshape_dims;
+      Eigen::DSizes<Eigen::DenseIndex, Rank * 2> reshape_dims;
       for (int k = 0; k < n * 2; k++) {
         reshape_dims[k] = reshape_dims_vec[k];
       }
 
-      auto tensor_reduce_tmp =
-          out_grad_tmp.reshape(reshape_dims).sum(reduce_dims);
       auto& place =
           *context.template device_context<DeviceContext>().eigen_device();
-      in_grad.device(place) = tensor_reduce_tmp.reshape(in_grad.dimensions());
+      EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Rank>::Eval(
+          place, in_grad, out_grad_tmp, reduce_dims, reshape_dims);
     }
   }
 };
diff --git a/paddle/fluid/operators/tile_op.h b/paddle/fluid/operators/tile_op.h
index dffd3e5864..4bbde8d08e 100644
--- a/paddle/fluid/operators/tile_op.h
+++ b/paddle/fluid/operators/tile_op.h
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
 
 #define MAX_RANK_SUPPORTED 6
 
@@ -155,7 +156,7 @@ class TileKernel : public framework::OpKernel<T> {
                           "'repeat_times' for tile op must match after promotion.",
                           vec_in_dims.size(), repeat_times.size()));
     auto* out0 = context.Output<Tensor>("Out");
-    Eigen::DSizes<int, Rank> bcast_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
     for (size_t i = 0; i < repeat_times.size(); ++i) {
       bcast_dims[i] = repeat_times[i];
     }
@@ -175,9 +176,11 @@ class TileKernel : public framework::OpKernel<T> {
     // use 32-bit index to speed up
     bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
     if (use_32bit_index) {
-      To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
+      EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
+          place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
     } else {
-      y.device(place) = x.broadcast(bcast_dims);
+      EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
+                                                                   bcast_dims);
     }
   }
 };
 
@@ -255,21 +258,20 @@ class TileGradKernel : public framework::OpKernel<T> {
     auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
-    Eigen::DSizes<int, Dims * 2> reshape_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
     for (size_t i = 0; i < reshape_size; ++i) {
       reshape_dims[i] = reshape_dims_vec[i];
     }
-    Eigen::DSizes<int, Dims> reduce_dims;
+    Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
     for (size_t i = 0; i < reduce_size; ++i) {
       reduce_dims[i] = reduce_dims_vec[i];
     }
     auto out_grad = EigenVector<T>::Flatten(*in0);
-    x_grad.device(
-        *context.template device_context<DeviceContext>().eigen_device()) =
-        out_grad.reshape(reshape_dims)
-            .sum(reduce_dims)
-            .reshape(x_grad.dimensions());
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
+        place, x_grad, out_grad, reduce_dims, reshape_dims);
   }
 };
-- 
GitLab
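
Note (not part of the patch above): the compile-time saving comes from declaring the broadcast functor in a header while defining and explicitly instantiating it in a single source file per device, so Eigen's heavy broadcast expression templates are compiled once per device/type/rank instead of in every operator translation unit. Below is a minimal, single-file C++ sketch of that pattern, assuming only Eigen; the main() driver, the sample 1x3/2x3 data, and the instantiation choice are invented for illustration and are not taken from the patch.

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

// "Header" part (cf. eigen_function.h): declaration only. Files that include
// this do not instantiate Eigen's broadcast expression templates themselves.
template <typename EigenDevice, typename T, int Rank>
struct EigenBroadcast {
  using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
  using InType = Eigen::TensorMap<
      Eigen::Tensor<const T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
  using OutType = Eigen::TensorMap<
      Eigen::Tensor<T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
  static void Eval(const EigenDevice& dev, OutType out, InType in,
                   const Array& bcast);
};

// "Source" part (cf. broadcast.cc): out-of-line definition plus an explicit
// instantiation for each device/type/rank combination the callers need.
template <typename EigenDevice, typename T, int Rank>
void EigenBroadcast<EigenDevice, T, Rank>::Eval(const EigenDevice& dev,
                                                OutType out, InType in,
                                                const Array& bcast) {
  out.device(dev) = in.broadcast(bcast);
}
template struct EigenBroadcast<Eigen::DefaultDevice, float, 2>;

// Invented driver: broadcast a 1x3 row into a 2x3 tensor on the CPU device,
// mirroring what a kernel's call to Eval(place, y, x, bcast_dims) does.
int main() {
  const float src[3] = {1.f, 2.f, 3.f};
  float dst[6] = {0.f};
  using BC = EigenBroadcast<Eigen::DefaultDevice, float, 2>;
  BC::InType in(src, 1, 3);    // view the 3 values as a 1x3 tensor
  BC::OutType out(dst, 2, 3);  // destination 2x3 tensor
  Eigen::DefaultDevice dev;
  BC::Eval(dev, out, in, BC::Array(2, 1));    // repeat the single row twice
  for (float v : dst) std::cout << v << ' ';  // prints: 1 2 3 1 2 3
  std::cout << '\n';
  return 0;
}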