提交 2ddb1122 编写于 作者: D dongzhihong

"on hold"

上级 56faf513
...@@ -54,10 +54,27 @@ The equation is: Out = X * Y ...@@ -54,10 +54,27 @@ The equation is: Out = X * Y
class MulOpGrad : public framework::OperatorWithKernel { class MulOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {
std::string DebugString() const override { PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
LOG(INFO) << "MulGrad"; "Input of MulOpGrad should be 3, X, Y, Out@GRAD");
return ""; PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL,
"Output of MulOpGrad should be 2, X@GRAD, Y@GRAD");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = ctx.Input<Tensor>(1)->dims();
auto out_dims = ctx.Input<Tensor>(2)->dims();
PADDLE_ENFORCE(dim0[0] * dim1[0] == out_dims[0],
"Out@GRAD[0] must equal to X[0] * Y[0]");
PADDLE_ENFORCE(dim0[1] * dim1[1] == out_dims[1],
"Out@GRAD shape must equal to X[1] * Y[1]");
x_grad->Resize(dim1);
y_grad->Resize(dim0);
} }
}; };
...@@ -69,3 +86,5 @@ REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker); ...@@ -69,3 +86,5 @@ REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad); REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>);
...@@ -16,5 +16,6 @@ ...@@ -16,5 +16,6 @@
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::GPUPlace, float>);
...@@ -46,5 +46,33 @@ class MulKernel : public framework::OpKernel { ...@@ -46,5 +46,33 @@ class MulKernel : public framework::OpKernel {
} }
}; };
template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input0 = ctx.Input<Tensor>("X");
auto* input1 = ctx.Input<Tensor>("Y");
auto* input2 = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* output0 = ctx.Output<Tensor>(0);
auto* output1 = ctx.Output<Tensor>(1);
output0->mutable_data<T>(ctx.GetPlace());
output1->mutable_data<T>(ctx.GetPlace());
auto X = EigenMatrix<T>::From(*input0);
auto Y = EigenMatrix<T>::From(*input1);
auto dOut = EigenMatrix<T>::From(*input2);
auto dX = EigenMatrix<T>::From(*output0);
auto dY = EigenMatrix<T>::From(*output1);
// dX = Out@G * Y'
// dY = X' * Out@G
auto place = ctx.GetEigenDevice<Place>();
// TODO(dzh,qijun) : need transpose feature of blas library
// Eigen Tensor does not support it very well
// dX.device(place) = dOut.contract(dOut, transpose)
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -15,5 +15,7 @@ class TestMulOp(unittest.TestCase): ...@@ -15,5 +15,7 @@ class TestMulOp(unittest.TestCase):
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册