提交 0cf5bdec 编写于 作者: D dongzhihong

"tensor mutable data"

上级 2799da66
......@@ -51,9 +51,11 @@ class MulGradKernel : public framework::OpKernel {
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
dX->mutable_data<T>(ctx.GetPlace());
dY->mutable_data<T>(ctx.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
// dX = dOut' * Y. dX: M x K, dOut : M x N, Y : K x N
// dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册