diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index d77c0607a08a701b6d6100d8d3469c47015f8718..95b495b87aa006c32ee777559d67f782ccdeefbd 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -62,10 +62,6 @@ class MulOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - // PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL, - // "Input of MulOpGrad should be 3, X, Y, Out@GRAD"); - // PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL, - // "Output of MulOpGrad should be 2, X@GRAD, Y@GRAD"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 279454c7f3f098a60d313e90499c15b9e28ede98..2afed81842924796b6f54589fb61627917b4f2f4 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -31,8 +31,6 @@ template class MulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - // Eigen::array, 1> dim_pair = { - // {Eigen::IndexPair(1, 0)}}; auto* X = context.Input("X"); auto* Y = context.Input("Y"); auto* Z = context.Output("Out"); @@ -40,13 +38,6 @@ class MulKernel : public framework::OpKernel { auto* device_context = const_cast(context.device_context_); math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); - - // auto X = EigenMatrix::From(*input0); - // auto Y = EigenMatrix::From(*input1); - // auto Z = EigenMatrix::From(*output); - // auto& place = context.GetEigenDevice(); - - // Z.device(place) = X.contract(Y, dim_pair); } }; @@ -60,25 +51,10 @@ class MulGradKernel : public framework::OpKernel { auto* dX = ctx.Output(framework::GradVarName("X")); auto* dY = ctx.Output(framework::GradVarName("Y")); - // auto* dXdata = dX->template mutable_data(ctx.GetPlace()); - // auto* dYdata = dY->template mutable_data(ctx.GetPlace()); auto* device_context = const_cast(ctx.device_context_); math::matmul(*dOut, false, *Y, true, 1, dX, 0, device_context); math::matmul(*X, true, *dOut, false, 1, dY, 0, device_context); - - // auto X = EigenMatrix::From(*input0); - // auto Y = EigenMatrix::From(*input1); - // auto dOut = EigenMatrix::From(*input2); - // auto dX = EigenMatrix::From(*output0); - // auto dY = EigenMatrix::From(*output1); - - // dX = Out@G * Y' - // dY = X' * Out@G - // auto place = ctx.GetEigenDevice(); - // TODO(dzh,qijun) : need transpose feature of blas library - // Eigen Tensor does not support it very well - // dX.device(place) = matmul(input2, ) } };