Commit f8b885f2 authored by yangyaming

Using EigenTensor to reshape tensors.

Parent a4df3f5b
...
@@ -22,36 +22,52 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
+  void InferShape(const framework::InferShapeContext& ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                             "Input of SquaredL2DistanceOp "
                             "must be initialized.");
-    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
-                      ctx.Input<Tensor>("Y")->dims(),
-                      "Dimensions of SquaredL2DistanceOp's two inputs "
-                      "must be same.")
-    framework::DDim dims = ctx.Input<Tensor>("X")->dims();
-    ctx.Output<Tensor>("sub_result")->Resize(dims);
-    ctx.Output<Tensor>("Out")->Resize(framework::make_ddim({dims[0], 1}));
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
+                            "Target of SquaredL2DistanceOp "
+                            "must be initialized.");
+
+    auto* X = ctx.Input<Tensor>("X");
+    auto xDims = X->dims();
+    auto* Y = ctx.Input<Tensor>("Y");
+    auto yDims = Y->dims();
+
+    PADDLE_ENFORCE_EQ(framework::arity(xDims), framework::arity(yDims),
+                      "Tensor rank of both SquaredL2DistanceOp's "
+                      "inputs must be same.");
+    int rank = framework::arity(xDims);
+    PADDLE_ENFORCE(rank >= 2 && rank <= 6, "Tensor rank should be in [2, 6].");
+    PADDLE_ENFORCE(yDims[0] == 1 || yDims[0] == xDims[0],
+                   "First dimension of target must be equal to input "
+                   "or to 1.");
+    ctx.Output<Tensor>("sub_result")->Resize(xDims);
+    ctx.Output<Tensor>("Out")->Resize({xDims[0], 1});
   }
 };
 
 class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SquaredL2DistanceOpMaker(framework::OpProto *proto,
-                           framework::OpAttrChecker *op_checker)
+  SquaredL2DistanceOpMaker(framework::OpProto* proto,
+                           framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "Input value.");
-    AddInput("Y", "Target value.");
+    AddInput("X", "Input of SquaredL2DistanceOp.");
+    AddInput("Y", "Target of SquaredL2DistanceOp.");
     AddOutput("sub_result",
               "Buffering subtraction result which "
               "will be reused in backward.")
         .AsIntermediate();
     AddOutput("Out", "Squared l2 distance between input and target.");
     AddComment(R"DOC(
-SquaredL2DistanceOp will cacluate the squared L2 distances for
-input and target. Number of distance value equals to the
-first dimension of input.
+SquaredL2DistanceOp will calculate the squared L2 distance for
+input and target. The number of distance values equals the
+first dimension of input. The first dimension of target may either match
+the first dimension of input or be 1; if it is 1, SquaredL2DistanceOp
+will broadcast target's first dimension to input's first dimension.
+You can choose whether to compute the gradient of target.
 )DOC");
   }
 };
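To make the broadcast rule described in the DOC comment concrete, here is a minimal, standalone Eigen sketch of the op's forward semantics (editor's illustration, not Paddle code; the sizes and values are made up):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // Input X has 3 rows; target Y has one row, so its first dimension
  // is broadcast to X's first dimension before the subtraction.
  Eigen::Tensor<float, 2, Eigen::RowMajor> X(3, 2), Y(1, 2);
  X.setValues({{1, 2}, {3, 4}, {5, 6}});
  Y.setValues({{1, 1}});

  // broadcast() takes per-dimension repetition factors.
  Eigen::array<int, 2> bcast({3, 1});
  Eigen::Tensor<float, 2, Eigen::RowMajor> sub = X - Y.broadcast(bcast);

  // One squared L2 distance per row: square, then reduce over dim 1.
  Eigen::Tensor<float, 1, Eigen::RowMajor> out =
      sub.pow(2).sum(Eigen::array<int, 1>({1}));
  std::cout << out << std::endl;  // expected distances: 1, 13, 41
  return 0;
}
```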
...
@@ -61,9 +77,23 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    ctx.Output<Tensor>(framework::GradVarName("X"))
-        ->Resize(ctx.Input<Tensor>("X")->dims());
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Gradient of Out should not be null");
+    // check out grad dimensions
+    auto outDims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    auto xDims = ctx.Input<Tensor>("X")->dims();
+    auto yDims = ctx.Input<Tensor>("Y")->dims();
+    PADDLE_ENFORCE_EQ(outDims[0], xDims[0],
+                      "First dimension of output gradient and "
+                      "input value must be equal.");
+    PADDLE_ENFORCE_EQ(outDims[1], 1,
+                      "Second dimension of output gradient "
+                      "must be 1.");
+    auto* xGrad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* yGrad = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    if (xGrad != nullptr) xGrad->Resize(xDims);
+    if (yGrad != nullptr) yGrad->Resize(yDims);
   }
 };
...
...
@@ -20,17 +20,44 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
 class SquaredL2DistanceKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto* input0 = context.Input<Tensor>("X");
+    const int rank = framework::arity(input0->dims());
+    switch (rank) {
+      case 2:
+        Operate<2>(context);
+        break;
+      case 3:
+        Operate<3>(context);
+        break;
+      case 4:
+        Operate<4>(context);
+        break;
+      case 5:
+        Operate<5>(context);
+        break;
+      case 6:
+        Operate<6>(context);
+        break;
+      default:
+        // rank already validated in SquaredL2DistanceOp's InferShape
+        break;
+    }
+  }
+
+ private:
+  template <int Dims>
+  void Operate(const framework::ExecutionContext& context) const {
     auto* input0 = context.Input<Tensor>("X");
     auto* input1 = context.Input<Tensor>("Y");
     auto* output0 = context.Output<Tensor>("sub_result");
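The switch above exists because EigenTensor takes the rank as a compile-time template parameter, while an operator only learns its input's rank at run time. A stripped-down sketch of the same fan-out pattern (hypothetical names, not Paddle APIs):

```cpp
#include <cstdio>
#include <stdexcept>

// Stand-in for the templated Operate<Dims>() member above.
template <int Dims>
void Operate() { std::printf("running the rank-%d instantiation\n", Dims); }

// Map a run-time rank onto one of a fixed set of compile-time
// instantiations; ranks outside [2, 6] were already rejected earlier.
void DispatchByRank(int rank) {
  switch (rank) {
    case 2: Operate<2>(); break;
    case 3: Operate<3>(); break;
    case 4: Operate<4>(); break;
    case 5: Operate<5>(); break;
    case 6: Operate<6>(); break;
    default: throw std::invalid_argument("rank must be in [2, 6]");
  }
}

int main() { DispatchByRank(4); }
```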
...
@@ -39,17 +66,28 @@ class SquaredL2DistanceKernel : public framework::OpKernel {
     output0->mutable_data<T>(context.GetPlace());
     output1->mutable_data<T>(context.GetPlace());
 
-    auto X = EigenMatrix<T>::From(*input0);
-    auto Y = EigenMatrix<T>::From(*input1);
-    auto subResult = EigenMatrix<T>::From(*output0);
+    auto X = EigenTensor<T, Dims>::From(*input0);
+    auto Y = EigenTensor<T, Dims>::From(*input1);
+    auto subResult = EigenTensor<T, Dims>::From(*output0);
     auto Z = EigenMatrix<T>::From(*output1);
 
+    auto xDims = X.dimensions();
+    auto yDims = Y.dimensions();
     auto place = context.GetEigenDevice<Place>();
     // buffer the subtraction result
-    subResult.device(place) = X - Y;
-    const auto& inDims = X.dimensions();
+    if (yDims[0] == 1 && xDims[0] != yDims[0]) {
+      // broadcast() takes per-dimension repetition factors: repeat the
+      // first dimension xDims[0] times and every other dimension once
+      auto yBroadcastDims = yDims;
+      yBroadcastDims[0] = xDims[0];
+      for (int i = 1; i < Dims; ++i) yBroadcastDims[i] = 1;
+      subResult.device(place) = X - Y.broadcast(yBroadcastDims);
+    } else {
+      subResult.device(place) = X - Y;
+    }
+    // create matrix view for subtraction result
     const auto& subResMat = subResult.reshape(Eigen::array<int, 2>(
-        {static_cast<int>(inDims[0]), static_cast<int>(X.size() / inDims[0])}));
+        {static_cast<int>(xDims[0]), static_cast<int>(X.size() / xDims[0])}));
     Z.device(place) = subResMat.pow(2).sum(Eigen::array<int, 1>({1}));
   }
 };
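The reshape the commit title refers to is the step that keeps the kernel rank-agnostic: whatever the rank, sub_result is viewed as a {d0, size/d0} matrix so a single row-wise reduction handles every case. A standalone sketch of that flattening trick (editor's illustration; shapes are arbitrary):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // A rank-3 tensor whose first dimension indexes samples.
  Eigen::Tensor<float, 3, Eigen::RowMajor> t(2, 3, 4);
  t.setConstant(1.0f);

  // View it as a 2 x 12 matrix: one row per sample, the trailing
  // dimensions flattened together; reshape copies no data.
  const int rows = static_cast<int>(t.dimension(0));
  const int cols = static_cast<int>(t.size()) / rows;
  Eigen::Tensor<float, 2, Eigen::RowMajor> rowSums =
      t.reshape(Eigen::array<int, 2>({rows, cols}))
          .sum(Eigen::array<int, 1>({1}))
          .reshape(Eigen::array<int, 2>({rows, 1}));
  std::cout << rowSums << std::endl;  // each row sums to 12
  return 0;
}
```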
...
@@ -59,24 +97,78 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* input0 = context.Input<Tensor>("sub_result");
-    auto* OG = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* IG = context.Output<Tensor>(framework::GradVarName("X"));
+    const int rank = framework::arity(input0->dims());
+    switch (rank) {
+      case 2:
+        Operate<2>(context);
+        break;
+      case 3:
+        Operate<3>(context);
+        break;
+      case 4:
+        Operate<4>(context);
+        break;
+      case 5:
+        Operate<5>(context);
+        break;
+      case 6:
+        Operate<6>(context);
+        break;
+      default:
+        // rank already validated in SquaredL2DistanceOp's InferShape
+        break;
+    }
+  }
 
-    IG->mutable_data<T>(context.GetPlace());
+ private:
+  template <int Dims>
+  void Operate(const framework::ExecutionContext& context) const {
+    auto* input0 = context.Input<Tensor>("sub_result");
+    auto* OG = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* XG = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* YG = context.Output<Tensor>(framework::GradVarName("Y"));
 
-    auto subResult = EigenMatrix<T>::From(*input0);
+    auto subResult = EigenTensor<T, Dims>::From(*input0);
     auto outGrad = EigenMatrix<T>::From(*OG);
-    auto inGrad = EigenMatrix<T>::From(*IG);
 
-    const auto& subResDims = subResult.dimensions();
+    auto subResDims = subResult.dimensions();
     int firstDim = static_cast<int>(subResDims[0]);
     int cols = subResult.size() / firstDim;
     const auto subResMat =
         subResult.reshape(Eigen::array<int, 2>({firstDim, cols}));
-    // create a matrix view for input gradient tensor
-    auto inGradMat = inGrad.reshape(Eigen::array<int, 2>({firstDim, cols}));
-    inGradMat.device(context.GetEigenDevice<Place>()) =
+
+    // calculate gradient
+    auto gradMat =
         2 * (outGrad.broadcast(Eigen::array<int, 2>({1, cols}))) * subResMat;
+
+    // propagate back to input
+    auto eigenPlace = context.GetEigenDevice<Place>();
+    if (XG != nullptr) {
+      XG->mutable_data<T>(context.GetPlace());
+      auto xGrad = EigenTensor<T, Dims>::From(*XG);
+      // dimensions are the same as subResult's
+      auto xGradMat = xGrad.reshape(Eigen::array<int, 2>({firstDim, cols}));
+      xGradMat.device(eigenPlace) = gradMat;
+    }
+    if (YG != nullptr) {
+      YG->mutable_data<T>(context.GetPlace());
+      auto yGrad = EigenTensor<T, Dims>::From(*YG);
+      auto dimsYGrad = yGrad.dimensions();
+      auto yGradMat = yGrad.reshape(Eigen::array<int, 2>(
+          {static_cast<int>(dimsYGrad[0]),
+           static_cast<int>(yGrad.size() / dimsYGrad[0])}));
+      PADDLE_ENFORCE(dimsYGrad[0] <= firstDim,
+                     "First dimension of target gradient must be less "
+                     "than or equal to first dimension of input.");
+      if (dimsYGrad[0] == firstDim) {
+        yGradMat.device(eigenPlace) = -1 * gradMat;
+      } else {
+        // sum the gradient over rows (dim 0), then restore the 1 x cols view
+        yGradMat.device(eigenPlace) =
+            -1 * gradMat.sum(Eigen::array<int, 1>({0}))
+                     .reshape(Eigen::array<int, 2>({1, cols}));
+      }
+    }
   }
 };
...
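For reference, the math the two kernels implement on the flattened {N, C} view (editor's summary, not part of the commit):

```latex
Out_i = \sum_{j=1}^{C} (X_{ij} - Y_{i'j})^2,
\qquad
i' = \begin{cases} i, & Y \text{ has } N \text{ rows} \\
                   0, & Y \text{ has one row (broadcast)} \end{cases}

dX_{ij} = 2\,(X_{ij} - Y_{i'j})\,dOut_i,
\qquad
dY = \begin{cases} -\,dX, & \text{matched rows} \\
     dY_{0j} = -\sum_{i=1}^{N} 2\,(X_{ij} - Y_{0j})\,dOut_i, & \text{broadcast} \end{cases}
```

The matched case is the `yGradMat = -1 * gradMat` branch; the broadcast case is the column-wise sum over rows.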