Unverified commit 187bf412, authored by wuhuanzhou, committed by GitHub

optimize compilation of operators using eigen (#31851)

Parent commit 78af100c
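
This commit reduces operator build times by moving Eigen broadcast expressions out of the individual operator translation units. The functor templates EigenBroadcast and EigenBroadcastGrad are declared in paddle/fluid/operators/eigen/eigen_function.h, defined and explicitly instantiated once in eigen_function.cc (CPU) and eigen_function.cu (GPU/ROCm), and the operator kernels touched below (addmm, expand, expand_as, expand_as_v2, expand_v2, meshgrid, tile) now call EigenBroadcast::Eval / EigenBroadcastGrad::Eval instead of spelling out the expression. A minimal, self-contained sketch of the same pattern follows; the names BroadcastFn, broadcast_fn.h and broadcast_fn.cc are illustrative and not part of the patch.

// broadcast_fn.h -- declaration only; callers never see the Eigen expression body.
#pragma once
#include "unsupported/Eigen/CXX11/Tensor"

template <typename Device, typename T, int Rank>
struct BroadcastFn {
  using In = Eigen::TensorMap<Eigen::Tensor<const T, Rank, Eigen::RowMajor>>;
  using Out = Eigen::TensorMap<Eigen::Tensor<T, Rank, Eigen::RowMajor>>;
  static void Eval(const Device& dev, Out out, In in,
                   const Eigen::DSizes<Eigen::DenseIndex, Rank>& bcast);
};

// broadcast_fn.cc -- the only translation unit that expands the expression template.
template <typename Device, typename T, int Rank>
void BroadcastFn<Device, T, Rank>::Eval(
    const Device& dev, Out out, In in,
    const Eigen::DSizes<Eigen::DenseIndex, Rank>& bcast) {
  out.device(dev) = in.broadcast(bcast);  // heavy Eigen instantiation happens here only
}

// Explicit instantiations: compiled once, linked by every operator that needs them.
template struct BroadcastFn<Eigen::DefaultDevice, float, 2>;
template struct BroadcastFn<Eigen::DefaultDevice, double, 2>;

// caller.cc -- an operator kernel only needs the declaration above.
int main() {
  float src[3] = {1.f, 2.f, 3.f};
  float dst[6];
  Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor>> in(src, 1, 3);
  Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> out(dst, 2, 3);
  Eigen::DSizes<Eigen::DenseIndex, 2> bcast(2, 1);
  Eigen::DefaultDevice dev;
  BroadcastFn<Eigen::DefaultDevice, float, 2>::Eval(dev, out, in, bcast);  // dst: 1 2 3 1 2 3
  return 0;
}

In the actual patch the CPU and GPU definitions live in separate .cc/.cu files built into the new eigen_cc_function and eigen_cu_function libraries, which are then added to COMMON_OP_DEPS so every operator links against the pre-instantiated code.
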
......@@ -10,6 +10,7 @@ file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operators/CMakeLists
copy_if_different(${pybind_file} ${pybind_file_final})
add_subdirectory(math)
add_subdirectory(eigen)
add_subdirectory(controlflow)
add_subdirectory(detection)
add_subdirectory(elementwise)
......@@ -110,8 +111,9 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_fun
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_cc_function)
if (WITH_GPU OR WITH_ROCM)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor eigen_cu_function)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -32,8 +33,8 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array1 = Eigen::DSizes<int64_t, 1>;
using Array2 = Eigen::DSizes<int64_t, 2>;
using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
using Tensor = framework::Tensor;
......@@ -105,7 +106,8 @@ class AddMMKernel : public framework::OpKernel<T> {
auto eigen_out = EigenTensor<T, 2>::From(*out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
eigen_out.device(place) = eigen_input.broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
place, eigen_out, eigen_input, bcast_dims);
blas.GEMM(false, false, x_dims[0], y_dims[1], x_dims[1], alpha,
x->data<T>(), x_dims[1], y->data<T>(), y_dims[1], beta,
......
file(GLOB EIGEN_CC_SOURCES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(eigen_cc_function SRCS ${EIGEN_CC_SOURCES} DEPS eigen3)
if(WITH_GPU OR WITH_ROCM)
file(GLOB EIGEN_CU_SOURCES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cu")
if(WITH_GPU)
nv_library(eigen_cu_function SRCS ${EIGEN_CU_SOURCES} DEPS eigen3)
elseif(WITH_ROCM)
hip_library(eigen_cu_function SRCS ${EIGEN_CU_SOURCES} DEPS eigen3)
endif()
endif()
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T, int Rank>
struct EigenBroadcast<Eigen::DefaultDevice, T, Rank> {
using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
using InType = Eigen::TensorMap<
Eigen::Tensor<const T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
using InType32BitIndex =
Eigen::TensorMap<Eigen::Tensor<const T, Rank, Eigen::RowMajor, int>,
Eigen::Aligned>;
using OutType = Eigen::TensorMap<
Eigen::Tensor<T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
using OutType32BitIndex =
Eigen::TensorMap<Eigen::Tensor<T, Rank, Eigen::RowMajor, int>,
Eigen::Aligned>;
static void Eval(const Eigen::DefaultDevice& dev, OutType out, InType in,
const Array& bcast) {
out.device(dev) = in.broadcast(bcast);
}
static void Eval(const Eigen::DefaultDevice& dev, OutType32BitIndex out,
InType32BitIndex in, const Array& bcast) {
out.device(dev) = in.broadcast(bcast);
}
};
template <typename T, int Rank>
struct EigenBroadcastGrad<Eigen::DefaultDevice, T, Rank> {
using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, Rank * 2>;
using InType = Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
using OutType =
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
static void Eval(const Eigen::DefaultDevice& dev, OutType out, InType in,
const Array& reduce_dims, const Array2& reshape_dims) {
out.device(dev) =
in.reshape(reshape_dims).sum(reduce_dims).reshape(out.dimensions());
}
};
#define INSTANTIATION(FUNCTOR, T) \
template struct FUNCTOR<Eigen::DefaultDevice, T, 1>; \
template struct FUNCTOR<Eigen::DefaultDevice, T, 2>; \
template struct FUNCTOR<Eigen::DefaultDevice, T, 3>; \
template struct FUNCTOR<Eigen::DefaultDevice, T, 4>; \
template struct FUNCTOR<Eigen::DefaultDevice, T, 5>; \
template struct FUNCTOR<Eigen::DefaultDevice, T, 6>
INSTANTIATION(EigenBroadcast, bool);
INSTANTIATION(EigenBroadcast, platform::float16);
INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, int);
INSTANTIATION(EigenBroadcast, int64_t);
INSTANTIATION(EigenBroadcastGrad, bool);
INSTANTIATION(EigenBroadcastGrad, float);
INSTANTIATION(EigenBroadcastGrad, platform::float16);
INSTANTIATION(EigenBroadcastGrad, double);
INSTANTIATION(EigenBroadcastGrad, int);
INSTANTIATION(EigenBroadcastGrad, int64_t);
template struct EigenBroadcastGrad<Eigen::DefaultDevice, float, 0>;
template struct EigenBroadcastGrad<Eigen::DefaultDevice, double, 0>;
template struct EigenBroadcastGrad<Eigen::DefaultDevice, int, 0>;
template struct EigenBroadcastGrad<Eigen::DefaultDevice, int64_t, 0>;
#undef INSTANTIATION
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T, int Rank>
struct EigenBroadcast<Eigen::GpuDevice, T, Rank> {
using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
using InType = Eigen::TensorMap<
Eigen::Tensor<const T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
using InType32BitIndex =
Eigen::TensorMap<Eigen::Tensor<const T, Rank, Eigen::RowMajor, int>,
Eigen::Aligned>;
using OutType = Eigen::TensorMap<
Eigen::Tensor<T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
using OutType32BitIndex =
Eigen::TensorMap<Eigen::Tensor<T, Rank, Eigen::RowMajor, int>,
Eigen::Aligned>;
static void Eval(const Eigen::GpuDevice& dev, OutType out, InType in,
const Array& bcast) {
out.device(dev) = in.broadcast(bcast);
}
static void Eval(const Eigen::GpuDevice& dev, OutType32BitIndex out,
InType32BitIndex in, const Array& bcast) {
out.device(dev) = in.broadcast(bcast);
}
};
template <typename T, int Rank>
struct EigenBroadcastGrad<Eigen::GpuDevice, T, Rank> {
using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, Rank * 2>;
using InType = Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
using OutType =
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
static void Eval(const Eigen::GpuDevice& dev, OutType out, InType in,
const Array& reduce_dims, const Array2& reshape_dims) {
out.device(dev) =
in.reshape(reshape_dims).sum(reduce_dims).reshape(out.dimensions());
}
};
#define INSTANTIATION(FUNCTOR, T) \
template struct FUNCTOR<Eigen::GpuDevice, T, 1>; \
template struct FUNCTOR<Eigen::GpuDevice, T, 2>; \
template struct FUNCTOR<Eigen::GpuDevice, T, 3>; \
template struct FUNCTOR<Eigen::GpuDevice, T, 4>; \
template struct FUNCTOR<Eigen::GpuDevice, T, 5>; \
template struct FUNCTOR<Eigen::GpuDevice, T, 6>
INSTANTIATION(EigenBroadcast, bool);
INSTANTIATION(EigenBroadcast, platform::float16);
INSTANTIATION(EigenBroadcast, float);
INSTANTIATION(EigenBroadcast, double);
INSTANTIATION(EigenBroadcast, int);
INSTANTIATION(EigenBroadcast, int64_t);
INSTANTIATION(EigenBroadcastGrad, bool);
INSTANTIATION(EigenBroadcastGrad, float);
INSTANTIATION(EigenBroadcastGrad, platform::float16);
INSTANTIATION(EigenBroadcastGrad, double);
INSTANTIATION(EigenBroadcastGrad, int);
INSTANTIATION(EigenBroadcastGrad, int64_t);
template struct EigenBroadcastGrad<Eigen::GpuDevice, float, 0>;
template struct EigenBroadcastGrad<Eigen::GpuDevice, platform::float16, 0>;
template struct EigenBroadcastGrad<Eigen::GpuDevice, double, 0>;
template struct EigenBroadcastGrad<Eigen::GpuDevice, int, 0>;
template struct EigenBroadcastGrad<Eigen::GpuDevice, int64_t, 0>;
#undef INSTANTIATION
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace operators {
template <typename EigenDevice, typename T, int Rank>
struct EigenBroadcast {
using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
using InType = Eigen::TensorMap<
Eigen::Tensor<const T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
using InType32BitIndex =
Eigen::TensorMap<Eigen::Tensor<const T, Rank, Eigen::RowMajor, int>,
Eigen::Aligned>;
using OutType = Eigen::TensorMap<
Eigen::Tensor<T, Rank, Eigen::RowMajor, Eigen::DenseIndex>>;
using OutType32BitIndex =
Eigen::TensorMap<Eigen::Tensor<T, Rank, Eigen::RowMajor, int>,
Eigen::Aligned>;
static void Eval(const EigenDevice& dev, OutType out, InType in,
const Array& bcast);
static void Eval(const EigenDevice& dev, OutType32BitIndex out,
InType32BitIndex in, const Array& bcast);
};
template <typename EigenDevice, typename T, int Rank>
struct EigenBroadcastGrad {
using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, Rank * 2>;
using InType = Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
using OutType =
Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
static void Eval(const EigenDevice& dev, OutType out, InType in,
const Array& reduce_dims, const Array2& reshape_dims);
};
} // namespace operators
} // namespace paddle
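
For reference, the two functors declared above encode the forward broadcast (out = in.broadcast(bcast)) and the usual reduction rule for its gradient: the flattened output gradient is reshaped so every repeated dimension gets its own axis, the repeat axes are summed away, and the result is reshaped back to the input gradient's shape, exactly how the kernels below build reshape_dims and reduce_dims. A small standalone illustration with the default CPU device (buffer names and sizes are made up for the example):

#include <iostream>
#include "unsupported/Eigen/CXX11/Tensor"

int main() {
  // x had shape (3) and was expanded with expand_times = {2} to y of shape (6).
  // dy is the incoming (flattened) gradient of y; dx is the gradient of x.
  float dy_data[6] = {1, 1, 1, 1, 1, 1};
  float dx_data[3];
  Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>> dy(dy_data, 6);
  Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>> dx(dx_data, 3);

  // reshape_dims splits the repeated dimension into (repeat, dim) = (2, 3);
  // reduce_dims lists the repeat axis that gets summed away.
  Eigen::DSizes<Eigen::DenseIndex, 2> reshape_dims(2, 3);
  Eigen::DSizes<Eigen::DenseIndex, 1> reduce_dims(0);

  Eigen::DefaultDevice dev;
  dx.device(dev) =
      dy.reshape(reshape_dims).sum(reduce_dims).reshape(dx.dimensions());

  std::cout << dx_data[0] << " " << dx_data[1] << " " << dx_data[2] << "\n";  // 2 2 2
  return 0;
}

This is the same computation EigenBroadcastGrad<Eigen::DefaultDevice, float, 1>::Eval(dev, dx, dy, reduce_dims, reshape_dims) performs.
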
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#define MAX_RANK_SUPPORTED 6
......@@ -75,7 +76,7 @@ class ExpandAsKernel : public framework::OpKernel<T> {
auto in_dims = in0->dims();
auto* target_tensor = context.Input<Tensor>("target_tensor");
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
int bcast_dims_remainder = 0;
auto x_dims = in0->dims();
auto y_dims = target_tensor->dims();
......@@ -104,7 +105,8 @@ class ExpandAsKernel : public framework::OpKernel<T> {
auto y = EigenTensor<T, Rank>::From(*out0);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
y.device(place) = x.broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
bcast_dims);
}
};
......@@ -165,20 +167,19 @@ class ExpandAsGradKernel : public framework::OpKernel<T> {
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims * 2> reshape_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<int, Dims> reduce_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad, reduce_dims, reshape_dims);
}
};
......
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#define MAX_RANK_SUPPORTED 6
......@@ -108,7 +109,7 @@ class ExpandAsV2Kernel : public framework::OpKernel<T> {
}
}
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
......@@ -122,7 +123,8 @@ class ExpandAsV2Kernel : public framework::OpKernel<T> {
auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
y.device(place) = x.broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
bcast_dims);
}
};
......@@ -191,20 +193,19 @@ class ExpandAsV2GradKernel : public framework::OpKernel<T> {
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims * 2> reshape_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<int, Dims> reduce_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad, reduce_dims, reshape_dims);
}
};
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#define MAX_RANK_SUPPORTED 6
......@@ -141,7 +142,7 @@ class ExpandKernel : public framework::OpKernel<T> {
"of dimensions (%d) of the input.",
expand_times.size(), static_cast<size_t>(in_dims.size())));
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < expand_times.size(); ++i) {
bcast_dims[i] = expand_times[i];
}
......@@ -160,9 +161,11 @@ class ExpandKernel : public framework::OpKernel<T> {
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
} else {
y.device(place) = x.broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
bcast_dims);
}
}
};
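
A note on the use_32bit_index branch kept above: the InType32BitIndex/OutType32BitIndex overload of EigenBroadcast::Eval maps the tensors with int as the index type, so Eigen does its index arithmetic in 32 bits, which is cheaper than the default 64-bit Eigen::DenseIndex, particularly inside GPU kernels. It is only safe when the element count fits in int, hence the y.size() check. An illustrative standalone declaration (buffer names are made up):

#include "unsupported/Eigen/CXX11/Tensor"

int main() {
  // 32-bit indexed, aligned tensor maps, matching the *32BitIndex aliases used above.
  alignas(64) static float src[4] = {1.f, 2.f, 3.f, 4.f};
  alignas(64) static float dst[12];
  Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor, int>, Eigen::Aligned>
      x32(src, 1, 4);
  Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor, int>, Eigen::Aligned>
      y32(dst, 3, 4);
  Eigen::DSizes<int, 2> bcast(3, 1);  // index arithmetic stays in 32 bits
  Eigen::DefaultDevice dev;
  y32.device(dev) = x32.broadcast(bcast);
  return 0;
}
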
......@@ -241,20 +244,19 @@ class ExpandGradKernel : public framework::OpKernel<T> {
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims * 2> reshape_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<int, Dims> reduce_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad, reduce_dims, reshape_dims);
}
};
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#define MAX_RANK_SUPPORTED 6
......@@ -174,7 +175,7 @@ class ExpandV2Kernel : public framework::OpKernel<T> {
}
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
......@@ -194,9 +195,11 @@ class ExpandV2Kernel : public framework::OpKernel<T> {
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
} else {
y.device(place) = x.broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
bcast_dims);
}
}
};
......@@ -275,20 +278,19 @@ class ExpandV2GradKernel : public framework::OpKernel<T> {
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims * 2> reshape_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<int, Dims> reduce_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad, reduce_dims, reshape_dims);
}
};
......
......@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/errors.h"
#define MAX_RANK_SUPPORTED 6
......@@ -106,19 +107,21 @@ class MeshgridKernel : public framework::OpKernel<T> {
reshape_ins_tensor.Resize(out_dims_reshape);
framework::DDim out_dims = framework::make_ddim(shape);
Eigen::DSizes<int, Rank> bcast_dims;
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (int64_t j = 0; j < size; j++) {
bcast_dims[j] = shape[j];
}
bcast_dims[i] = 1;
outs[i]->Resize(out_dims);
auto x = framework::EigenTensor<T, Rank>::From(reshape_ins_tensor);
auto x = framework::EigenTensor<T, Rank>::From(
static_cast<const framework::Tensor>(reshape_ins_tensor));
outs[i]->mutable_data<T>(context.GetPlace());
auto y = framework::EigenTensor<T, Rank>::From(*outs[i]);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
y.device(place) = x.broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
bcast_dims);
}
}
};
......@@ -169,21 +172,20 @@ class MeshgridGradKernel : public framework::OpKernel<T> {
}
}
Eigen::DSizes<int, Rank> reduce_dims;
Eigen::DSizes<Eigen::DenseIndex, Rank> reduce_dims;
for (int k = 0; k < n; k++) {
reduce_dims[k] = reduce_dims_vec[k];
}
Eigen::DSizes<int, Rank * 2> reshape_dims;
Eigen::DSizes<Eigen::DenseIndex, Rank * 2> reshape_dims;
for (int k = 0; k < n * 2; k++) {
reshape_dims[k] = reshape_dims_vec[k];
}
auto tensor_reduce_tmp =
out_grad_tmp.reshape(reshape_dims).sum(reduce_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
in_grad.device(place) = tensor_reduce_tmp.reshape(in_grad.dimensions());
EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, in_grad, out_grad_tmp, reduce_dims, reshape_dims);
}
}
};
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#define MAX_RANK_SUPPORTED 6
......@@ -155,7 +156,7 @@ class TileKernel : public framework::OpKernel<T> {
"'repeat_times' for tile op must match after promotion.",
vec_in_dims.size(), repeat_times.size()));
auto* out0 = context.Output<Tensor>("Out");
Eigen::DSizes<int, Rank> bcast_dims;
Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
for (size_t i = 0; i < repeat_times.size(); ++i) {
bcast_dims[i] = repeat_times[i];
}
......@@ -175,9 +176,11 @@ class TileKernel : public framework::OpKernel<T> {
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(
place, To32BitIndex(y), To32BitIndex(x), bcast_dims);
} else {
y.device(place) = x.broadcast(bcast_dims);
EigenBroadcast<std::decay_t<decltype(place)>, T, Rank>::Eval(place, y, x,
bcast_dims);
}
}
};
......@@ -255,21 +258,20 @@ class TileGradKernel : public framework::OpKernel<T> {
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenVector<T>::Flatten(*out0);
Eigen::DSizes<int, Dims * 2> reshape_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims * 2> reshape_dims;
for (size_t i = 0; i < reshape_size; ++i) {
reshape_dims[i] = reshape_dims_vec[i];
}
Eigen::DSizes<int, Dims> reduce_dims;
Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
for (size_t i = 0; i < reduce_size; ++i) {
reduce_dims[i] = reduce_dims_vec[i];
}
auto out_grad = EigenVector<T>::Flatten(*in0);
x_grad.device(
*context.template device_context<DeviceContext>().eigen_device()) =
out_grad.reshape(reshape_dims)
.sum(reduce_dims)
.reshape(x_grad.dimensions());
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
EigenBroadcastGrad<std::decay_t<decltype(place)>, T, Dims>::Eval(
place, x_grad, out_grad, reduce_dims, reshape_dims);
}
};
......