提交 211d8186 编写于 作者: Y Yu Yang

Process elemwise grad op's lod. mul_op's lod

上级 c44fb003
......@@ -137,9 +137,10 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
};
template <typename T>
class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
......@@ -136,9 +137,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
}
template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel<T> {
class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
......@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
namespace paddle {
namespace operators {
......@@ -53,9 +53,10 @@ struct DivGradDY {
};
template <typename DeviceContext, typename T>
class ElementwiseDivGradKernel : public framework::OpKernel<T> {
class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
namespace paddle {
......@@ -55,9 +56,10 @@ struct MaxGradDy {
};
template <typename DeviceContext, typename T>
class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
......
......@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
namespace paddle {
namespace operators {
......@@ -55,9 +55,10 @@ struct MinGradDy {
};
template <typename DeviceContext, typename T>
class ElementwiseMinGradKernel : public framework::OpKernel<T> {
class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
......
......@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
namespace paddle {
namespace operators {
......@@ -50,9 +50,10 @@ struct MulGradDY {
};
template <typename DeviceContext, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel<T> {
class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
......
......@@ -205,6 +205,20 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
}
};
template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dx =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (dx != nullptr) {
auto& dout =
*context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
dx->set_lod(dout.lod());
}
}
};
} // namespace operators
} // namespace paddle
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
namespace paddle {
......@@ -50,9 +51,10 @@ struct SubGradDY {
};
template <typename DeviceContext, typename T>
class ElementwiseSubGradKernel : public framework::OpKernel<T> {
class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
......@@ -62,23 +62,31 @@ class MulGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: *x;
const Tensor y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: *y;
const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: static_cast<const Tensor&>(*x);
auto y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: static_cast<const Tensor&>(*y);
auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
Tensor dout_mat;
dout_mat.ShareDataWith(*dout);
dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0],
framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (dx) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册