未验证 提交 5f448135 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #7529 from JiayiFeng/remove_functor1

remove `functor1` of ElementwiseGradCompute
...@@ -81,23 +81,6 @@ struct ElementwiseAddGradFunctor { ...@@ -81,23 +81,6 @@ struct ElementwiseAddGradFunctor {
} }
}; };
template <typename T>
struct ElementwiseAddOneGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = dz_e.sum();
}
}
};
template <typename T> template <typename T>
struct ElementwiseAddBroadCastGradFunctor { struct ElementwiseAddBroadCastGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX, template <typename Device, typename X, typename Y, typename Z, typename dX,
...@@ -142,7 +125,6 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> { ...@@ -142,7 +125,6 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>,
ElementwiseAddOneGradFunctor<T>,
ElementwiseAddBroadCastGradFunctor<T>, ElementwiseAddBroadCastGradFunctor<T>,
ElementwiseAddBroadCast2GradFunctor<T>>(ctx); ElementwiseAddBroadCast2GradFunctor<T>>(ctx);
} }
......
...@@ -107,7 +107,6 @@ class ElementwiseDivGradKernel : public framework::OpKernel<T> { ...@@ -107,7 +107,6 @@ class ElementwiseDivGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>,
ElementwiseDivGradFunctor<T>,
ElementwiseDivBroadCastGradFunctor<T>, ElementwiseDivBroadCastGradFunctor<T>,
ElementwiseDivBroadCast2GradFunctor<T>>(ctx); ElementwiseDivBroadCast2GradFunctor<T>>(ctx);
} }
......
...@@ -106,7 +106,6 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> { ...@@ -106,7 +106,6 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>,
ElementwiseMulGradFunctor<T>,
ElementwiseMulBroadCastGradFunctor<T>, ElementwiseMulBroadCastGradFunctor<T>,
ElementwiseMulBroadCast2GradFunctor<T>>(ctx); ElementwiseMulBroadCast2GradFunctor<T>>(ctx);
} }
......
...@@ -311,8 +311,7 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL); ...@@ -311,8 +311,7 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL);
EIGEN_FUNCTOR(Div, EIGEN_DIV); EIGEN_FUNCTOR(Div, EIGEN_DIV);
template <typename DeviceContext, typename T, typename functor, template <typename DeviceContext, typename T, typename functor,
typename functor1, typename broadcastfunctor, typename broadcastfunctor, typename broadcast2functor>
typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
......
...@@ -43,23 +43,6 @@ struct ElementwiseSubGradFunctor { ...@@ -43,23 +43,6 @@ struct ElementwiseSubGradFunctor {
} }
}; };
template <typename T>
struct ElementwiseSubOneGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (-1.0) * dz_e.sum();
}
}
};
template <typename T> template <typename T>
struct ElementwiseSubBroadCastGradFunctor { struct ElementwiseSubBroadCastGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX, template <typename Device, typename X, typename Y, typename Z, typename dX,
...@@ -106,7 +89,6 @@ class ElementwiseSubGradKernel : public framework::OpKernel<T> { ...@@ -106,7 +89,6 @@ class ElementwiseSubGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>,
ElementwiseSubOneGradFunctor<T>,
ElementwiseSubBroadCastGradFunctor<T>, ElementwiseSubBroadCastGradFunctor<T>,
ElementwiseSubBroadCast2GradFunctor<T>>(ctx); ElementwiseSubBroadCast2GradFunctor<T>>(ctx);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册