提交 7b4b9d3e 编写于 作者: D dongzhihong

"format style"

上级 88360100
...@@ -66,10 +66,10 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -66,10 +66,10 @@ class MulOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto x_dims = ctx.Output<Tensor>("X")->dims();
auto y_dims = ctx.Output<Tensor>("Y")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y")); auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto x_dims = ctx.Output<Tensor>(framework::GradVarName("X"))->dims();
auto y_dims = ctx.Output<Tensor>(framework::GradVarName("Y"))->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
PADDLE_ENFORCE(x_dims[0] == out_dims[0], PADDLE_ENFORCE(x_dims[0] == out_dims[0],
"Out@GRAD M X N must equal to X dims 0, M "); "Out@GRAD M X N must equal to X dims 0, M ");
......
...@@ -53,7 +53,9 @@ class MulGradKernel : public framework::OpKernel { ...@@ -53,7 +53,9 @@ class MulGradKernel : public framework::OpKernel {
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* device_context = auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_); const_cast<platform::DeviceContext*>(ctx.device_context_);
// dX = dOut' * Y. dX: M x K, dOut : M x N, Y : K x N
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context); math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context); math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册