未验证 提交 bf91d11e 编写于 作者: C chengduo 提交者: GitHub

Clean elementwise_op_function (#15502)

test=develop
上级 5cfc40de
......@@ -277,68 +277,6 @@ class TransformFunctor {
Functor func_;
};
#define EIGEN_FUNCTOR(name, eigen_op) \
struct Eigen##name##Functor { \
template <typename DeviceContext, typename T> \
inline void Run(const framework::Tensor *x, const framework::Tensor *y, \
framework::Tensor *z, \
const framework::ExecutionContext &ctx) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \
z_e.device( \
*ctx.template device_context<DeviceContext>().eigen_device()) = \
eigen_op(x_e, y_e); \
} \
template <typename DeviceContext, typename T> \
inline void RunBroadCast(const framework::Tensor *x, \
const framework::Tensor *y, framework::Tensor *z, \
const framework::ExecutionContext &ctx, int pre, \
int n) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n)) \
.broadcast(Eigen::DSizes<int, 2>(pre, 1)) \
.reshape(Eigen::DSizes<int, 1>(x_e.size())); \
z_e.device( \
*ctx.template device_context<DeviceContext>().eigen_device()) = \
eigen_op(x_e, y_bcast); \
} \
template <typename DeviceContext, typename T> \
inline void RunBroadCast2(const framework::Tensor *x, \
const framework::Tensor *y, \
framework::Tensor *z, \
const framework::ExecutionContext &ctx, int pre, \
int n, int post) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1)) \
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post)) \
.reshape(Eigen::DSizes<int, 1>(x_e.size())); \
z_e.device( \
*ctx.template device_context<DeviceContext>().eigen_device()) = \
eigen_op(x_e, y_bcast); \
} \
}
#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);
#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);
#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
const T *x_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册