提交 e0395a53 编写于 作者: D dongzhihong

"remove unused commented code"

上级 632b320e
...@@ -62,10 +62,6 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -62,10 +62,6 @@ class MulOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
// PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
// "Input of MulOpGrad should be 3, X, Y, Out@GRAD");
// PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL,
// "Output of MulOpGrad should be 2, X@GRAD, Y@GRAD");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
......
...@@ -31,8 +31,6 @@ template <typename Place, typename T> ...@@ -31,8 +31,6 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel { class MulKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
// Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
// {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto* X = context.Input<Tensor>("X"); auto* X = context.Input<Tensor>("X");
auto* Y = context.Input<Tensor>("Y"); auto* Y = context.Input<Tensor>("Y");
auto* Z = context.Output<Tensor>("Out"); auto* Z = context.Output<Tensor>("Out");
...@@ -40,13 +38,6 @@ class MulKernel : public framework::OpKernel { ...@@ -40,13 +38,6 @@ class MulKernel : public framework::OpKernel {
auto* device_context = auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_); const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context); math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
// auto X = EigenMatrix<T>::From(*input0);
// auto Y = EigenMatrix<T>::From(*input1);
// auto Z = EigenMatrix<T>::From(*output);
// auto& place = context.GetEigenDevice<Place>();
// Z.device(place) = X.contract(Y, dim_pair);
} }
}; };
...@@ -60,25 +51,10 @@ class MulGradKernel : public framework::OpKernel { ...@@ -60,25 +51,10 @@ class MulGradKernel : public framework::OpKernel {
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
// auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
// auto* dYdata = dY->template mutable_data<T>(ctx.GetPlace());
auto* device_context = auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_); const_cast<platform::DeviceContext*>(ctx.device_context_);
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context); math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context); math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
// auto X = EigenMatrix<T>::From(*input0);
// auto Y = EigenMatrix<T>::From(*input1);
// auto dOut = EigenMatrix<T>::From(*input2);
// auto dX = EigenMatrix<T>::From(*output0);
// auto dY = EigenMatrix<T>::From(*output1);
// dX = Out@G * Y'
// dY = X' * Out@G
// auto place = ctx.GetEigenDevice<Place>();
// TODO(dzh,qijun) : need transpose feature of blas library
// Eigen Tensor does not support it very well
// dX.device(place) = matmul(input2, )
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册