Commit 211d8186 authored by Yu Yang

Process elemwise grad op's LoD; also handle mul_op's LoD

Parent c44fb003
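In Fluid, a LoDTensor pairs its data with level-of-detail (LoD) offsets recording where each variable-length sequence begins and ends. Before this commit, the elementwise grad kernels and MulGradKernel produced gradient tensors with an empty LoD, so sequence structure was lost in the backward pass; the hunks below copy the LoD onto each gradient output. A minimal, self-contained sketch of the invariant being enforced (std::vector stands in for framework::LoD, and every name here is illustrative, not Paddle API):

#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for framework::LoDTensor's sequence bookkeeping.
using Lod = std::vector<std::vector<std::size_t>>;

struct ToyLoDTensor {
  std::vector<float> data;
  Lod lod;  // offsets, e.g. {{0, 2, 5}} = two sequences of 2 and 3 rows
};

// d(x + y)/dx = 1, so dx equals dout element-wise; since dx is aligned
// with dout row for row, its LoD must be copied, never left empty.
ToyLoDTensor ElementwiseAddGrad(const ToyLoDTensor& dout) {
  ToyLoDTensor dx;
  dx.data = dout.data;
  dx.lod = dout.lod;  // what the new ElemwiseGradKernel::Compute does
  return dx;
}

int main() {
  ToyLoDTensor dout{{1.f, 2.f, 3.f, 4.f, 5.f}, {{0, 2, 5}}};
  ToyLoDTensor dx = ElementwiseAddGrad(dout);
  assert(dx.lod == dout.lod);  // sequence structure survives backward
  return 0;
}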
paddle/fluid/operators/elementwise_add_mkldnn_op.cc
@@ -137,9 +137,10 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
 };
 
 template <typename T>
-class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
+class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
paddle/fluid/operators/elementwise_add_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
@@ -136,9 +137,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
 }
 
 template <typename DeviceContext, typename T>
-class ElementwiseAddGradKernel : public framework::OpKernel<T> {
+class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
+
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
paddle/fluid/operators/elementwise_div_op.h
@@ -14,8 +14,8 @@ limitations under the License. */
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
@@ -53,9 +53,10 @@ struct DivGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseDivGradKernel : public framework::OpKernel<T> {
+class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
......
paddle/fluid/operators/elementwise_max_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
@@ -55,9 +56,10 @@ struct MaxGradDy {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
+class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
......
paddle/fluid/operators/elementwise_min_op.h
@@ -14,8 +14,8 @@ limitations under the License. */
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
@@ -55,9 +55,10 @@ struct MinGradDy {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMinGradKernel : public framework::OpKernel<T> {
+class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
......
paddle/fluid/operators/elementwise_mul_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
@@ -50,9 +50,10 @@ struct MulGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMulGradKernel : public framework::OpKernel<T> {
+class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
......
paddle/fluid/operators/elementwise_op.h
@@ -205,6 +205,20 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
   }
 };
 
+template <typename T>
+class ElemwiseGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* dx =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    if (dx != nullptr) {
+      auto& dout =
+          *context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+      dx->set_lod(dout.lod());
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
......
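Every grad kernel in the surrounding hunks follows the same two-step recipe: inherit from ElemwiseGradKernel<T> instead of framework::OpKernel<T>, and invoke the base Compute before the gradient math, so the LoD copy lives in exactly one place. A condensed sketch of that recipe, assuming the Paddle headers above (the kernel name is a placeholder and the elided math is not part of this commit):

#include "paddle/fluid/operators/elementwise_op.h"

namespace paddle {
namespace operators {

// Hypothetical kernel illustrating the pattern each hunk applies.
template <typename DeviceContext, typename T>
class SomeElementwiseGradKernel : public ElemwiseGradKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // The base class copies dOut's LoD onto dX when dX is requested.
    ElemwiseGradKernel<T>::Compute(ctx);
    // ... element-wise gradient math, unchanged by this commit ...
  }
};

}  // namespace operators
}  // namespace paddle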
paddle/fluid/operators/elementwise_sub_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
@@ -50,9 +51,10 @@ struct SubGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseSubGradKernel : public framework::OpKernel<T> {
+class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
paddle/fluid/operators/mul_op.h
@@ -62,23 +62,31 @@ class MulGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
     int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* y = ctx.Input<Tensor>("Y");
-    const Tensor x_matrix = x->dims().size() > 2
-                                ? framework::ReshapeToMatrix(*x, x_num_col_dims)
-                                : *x;
-    const Tensor y_matrix = y->dims().size() > 2
-                                ? framework::ReshapeToMatrix(*y, y_num_col_dims)
-                                : *y;
-    const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto x_matrix = x->dims().size() > 2
+                        ? framework::ReshapeToMatrix(*x, x_num_col_dims)
+                        : static_cast<const Tensor&>(*x);
+    auto y_matrix = y->dims().size() > 2
+                        ? framework::ReshapeToMatrix(*y, y_num_col_dims)
+                        : static_cast<const Tensor&>(*y);
+    auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
 
     Tensor dout_mat;
     dout_mat.ShareDataWith(*dout);
     dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0],
                      framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
 
-    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
+
+    if (dx != nullptr) {
+      dx->set_lod(x->lod());
+    }
+
+    if (dy != nullptr) {
+      dy->set_lod(y->lod());
+    }
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     if (dx) {
......
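MulGradKernel cannot reuse ElemwiseGradKernel: dOut is shaped like the flattened matrix product, while dX must mirror X row for row and dY must mirror Y, so the LoD is copied from the corresponding forward inputs rather than from dOut. A toy check of that correspondence (illustrative names only, not Paddle API):

#include <cassert>
#include <cstddef>
#include <vector>

using Lod = std::vector<std::vector<std::size_t>>;

int main() {
  Lod x_lod{{0, 2, 5}};  // X: two sequences of 2 and 3 rows
  Lod dout_lod;          // the matmul's dOut carries no sequence offsets
  Lod dx_lod = x_lod;    // mirrors dx->set_lod(x->lod()) in the hunk above
  assert(dx_lod == x_lod && dx_lod != dout_lod);
  return 0;
}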