diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc
index 9fc498d5a538c6ef35c44ef568335952f231f49f..3049f0f8ba0065ac89e5bc44a646525782b0b179 100644
--- a/paddle/operators/squared_l2_distance_op.cc
+++ b/paddle/operators/squared_l2_distance_op.cc
@@ -22,36 +22,52 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
+  void InferShape(const framework::InferShapeContext& ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                             "Input of SquaredL2DistanceOp "
                             "must be initialized.");
-    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
-                      ctx.Input<Tensor>("Y")->dims(),
-                      "Dimensions of SquaredL2DistanceOp's two inputs "
-                      "must be same.")
-    framework::DDim dims = ctx.Input<Tensor>("X")->dims();
-    ctx.Output<Tensor>("sub_result")->Resize(dims);
-    ctx.Output<Tensor>("Out")->Resize(framework::make_ddim({dims[0], 1}));
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
+                            "Target of SquaredL2DistanceOp "
+                            "must be initialized.");
+
+    auto* X = ctx.Input<Tensor>("X");
+    auto xDims = X->dims();
+    auto* Y = ctx.Input<Tensor>("Y");
+    auto yDims = Y->dims();
+
+    PADDLE_ENFORCE_EQ(framework::arity(xDims), framework::arity(yDims),
+                      "Tensor rank of both SquaredL2DistanceOp's "
+                      "inputs must be same.");
+    int rank = framework::arity(xDims);
+    PADDLE_ENFORCE(rank >= 2 && rank <= 6, "Tensor rank should be in [2, 6].");
+    PADDLE_ENFORCE(yDims[0] == 1 || yDims[0] == xDims[0],
+                   "First dimension of target must be equal to input "
+                   "or to 1.");
+
+    ctx.Output<Tensor>("sub_result")->Resize(xDims);
+    ctx.Output<Tensor>("Out")->Resize({xDims[0], 1});
   }
 };
 
 class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  SquaredL2DistanceOpMaker(framework::OpProto *proto,
-                           framework::OpAttrChecker *op_checker)
+  SquaredL2DistanceOpMaker(framework::OpProto* proto,
+                           framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "Input value.");
-    AddInput("Y", "Target value.");
+    AddInput("X", "Input of SquaredL2DistanceOp.");
+    AddInput("Y", "Target of SquaredL2DistanceOp.");
     AddOutput("sub_result",
               "Buffering substraction result which "
               "will be reused in backward.")
         .AsIntermediate();
     AddOutput("Out", "Squared l2 distance between input and target.");
     AddComment(R"DOC(
-    SquaredL2DistanceOp will cacluate the squared L2 distances for
+    SquaredL2DistanceOp will calculate the squared L2 distance for
    input and target. Number of distance value equals to the
-    first dimension of input.
+    first dimension of input. First dimension of target could be equal to
+    input or to 1. If the first dimension of target is 1, SquaredL2DistanceOp
+    will broadcast the first dimension to the first dimension of input.
+    You can decide whether to calculate the gradient of target.
)DOC"); } }; @@ -61,9 +77,23 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(framework::GradVarName("X")) - ->Resize(ctx.Input("X")->dims()); + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Gradient of Out should not be null"); + // check out grad dimensions + auto outDims = ctx.Input(framework::GradVarName("Out"))->dims(); + auto xDims = ctx.Input("X")->dims(); + auto yDims = ctx.Input("Y")->dims(); + PADDLE_ENFORCE_EQ(outDims[0], xDims[0], + "First dimension of output gradient and " + "input value must be equal."); + PADDLE_ENFORCE_EQ(outDims[1], 1, + "Second dimension of output gradient " + "must be 1."); + auto* xGrad = ctx.Output(framework::GradVarName("X")); + auto* yGrad = ctx.Output(framework::GradVarName("Y")); + if (xGrad != nullptr) xGrad->Resize(xDims); + if (yGrad != nullptr) yGrad->Resize(yDims); } }; diff --git a/paddle/operators/squared_l2_distance_op.h b/paddle/operators/squared_l2_distance_op.h index b350fd01177c293c17adf5bb47b4230495b240c2..e95364c7069e8baa093400b2eba20e745bed79d1 100644 --- a/paddle/operators/squared_l2_distance_op.h +++ b/paddle/operators/squared_l2_distance_op.h @@ -20,17 +20,44 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EigenMatrix = framework::EigenMatrix; +using EigenTensor = framework::EigenTensor; template -using EigenVector = framework::EigenVector; +using EigenMatrix = framework::EigenMatrix; template class SquaredL2DistanceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + auto* input0 = context.Input("X"); + const int rank = framework::arity(input0->dims()); + switch (rank) { + case 2: + Operate<2>(context); + break; + case 3: + Operate<3>(context); + break; + case 4: + Operate<4>(context); + break; + case 5: + Operate<5>(context); + break; + case 6: + Operate<6>(context); + break; + default: + // already asserted in SquaredL2DistanceOpMaker + break; + } + } + + private: + template + void Operate(const framework::ExecutionContext& context) const { auto* input0 = context.Input("X"); auto* input1 = context.Input("Y"); auto* output0 = context.Output("sub_result"); @@ -39,17 +66,28 @@ class SquaredL2DistanceKernel : public framework::OpKernel { output0->mutable_data(context.GetPlace()); output1->mutable_data(context.GetPlace()); - auto X = EigenMatrix::From(*input0); - auto Y = EigenMatrix::From(*input1); - auto subResult = EigenMatrix::From(*output0); + auto X = EigenTensor::From(*input0); + auto Y = EigenTensor::From(*input1); + auto subResult = EigenTensor::From(*output0); auto Z = EigenMatrix::From(*output1); + auto xDims = X.dimensions(); + auto yDims = Y.dimensions(); + auto place = context.GetEigenDevice(); + // buffer the substraction result - subResult.device(place) = X - Y; - const auto& inDims = X.dimensions(); + if (yDims[0] == 1 && xDims[0] != yDims[0]) { + auto yBroadcastDims = yDims; + yBroadcastDims[0] = xDims[0]; + subResult.device(place) = X - Y.broadcast(yBroadcastDims); + } else { + subResult.device(place) = X - Y; + } + + // create matrix view for substraction result const auto& subResMat = subResult.reshape(Eigen::array( - {static_cast(inDims[0]), static_cast(X.size() / inDims[0])})); + {static_cast(xDims[0]), 
     Z.device(place) = subResMat.pow(2).sum(Eigen::array<int, 1>({1}));
   }
 };
@@ -59,24 +97,78 @@ class SquaredL2DistanceGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* input0 = context.Input<Tensor>("sub_result");
-    auto* OG = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* IG = context.Output<Tensor>(framework::GradVarName("X"));
+    const int rank = framework::arity(input0->dims());
+    switch (rank) {
+      case 2:
+        Operate<2>(context);
+        break;
+      case 3:
+        Operate<3>(context);
+        break;
+      case 4:
+        Operate<4>(context);
+        break;
+      case 5:
+        Operate<5>(context);
+        break;
+      case 6:
+        Operate<6>(context);
+        break;
+      default:
+        // already asserted in SquaredL2DistanceOp's InferShape
+        break;
+    }
+  }
 
-    IG->mutable_data<T>(context.GetPlace());
+ private:
+  template <int Dims>
+  void Operate(const framework::ExecutionContext& context) const {
+    auto* input0 = context.Input<Tensor>("sub_result");
+    auto* OG = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* XG = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* YG = context.Output<Tensor>(framework::GradVarName("Y"));
 
-    auto subResult = EigenMatrix<T>::From(*input0);
+    auto subResult = EigenTensor<T, Dims>::From(*input0);
     auto outGrad = EigenMatrix<T>::From(*OG);
-    auto inGrad = EigenMatrix<T>::From(*IG);
-    const auto& subResDims = subResult.dimensions();
+    auto subResDims = subResult.dimensions();
     int firstDim = static_cast<int>(subResDims[0]);
     int cols = subResult.size() / firstDim;
     const auto subResMat =
         subResult.reshape(Eigen::array<int, 2>({firstDim, cols}));
-    // create a matrix view for input gradient tensor
-    auto inGradMat = inGrad.reshape(Eigen::array<int, 2>({firstDim, cols}));
-    inGradMat.device(context.GetEigenDevice<Place>()) =
+
+    // calculate gradient
+    auto gradMat =
         2 * (outGrad.broadcast(Eigen::array<int, 2>({1, cols}))) * subResMat;
+
+    // propagate back to input
+    auto eigenPlace = context.GetEigenDevice<Place>();
+    if (XG != nullptr) {
+      XG->mutable_data<T>(context.GetPlace());
+      auto xGrad = EigenTensor<T, Dims>::From(*XG);
+      // dimensions are the same as subResult's
+      auto xGradMat = xGrad.reshape(Eigen::array<int, 2>({firstDim, cols}));
+      xGradMat.device(eigenPlace) = gradMat;
+    }
+    if (YG != nullptr) {
+      YG->mutable_data<T>(context.GetPlace());
+      auto yGrad = EigenTensor<T, Dims>::From(*YG);
+      auto dimsYGrad = yGrad.dimensions();
+      auto yGradMat = yGrad.reshape(Eigen::array<int, 2>(
+          {static_cast<int>(dimsYGrad[0]),
+           static_cast<int>(yGrad.size() / dimsYGrad[0])}));
+
+      PADDLE_ENFORCE(dimsYGrad[0] <= firstDim,
+                     "First dimension of output gradient must be greater "
+                     "than or equal to the first dimension of target.");
+
+      if (dimsYGrad[0] == firstDim) {
+        yGradMat.device(eigenPlace) = -1 * gradMat;
+      } else {
+        yGradMat.device(eigenPlace) =
+            -1 * (gradMat.sum(Eigen::array<int, 1>({0})));
+      }
+    }
   }
 };
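
For reference, here is a minimal standalone sketch of the Eigen pattern the two kernels rely on, specialized to the rank-2 case: the forward pass subtracts a target whose first dimension is 1 by broadcasting it over the batch, and the backward pass reduces the input gradient over the batch dimension (with the opposite sign) to obtain the target gradient. This is not Paddle code; it assumes only Eigen's unsupported Tensor module, and the names (x, y, sub, out, d_out, d_x, d_y, n, k) are illustrative. Note that Eigen's broadcast() expects per-dimension replication factors rather than a target shape.

// Standalone sketch (not part of the PR): rank-2 squared L2 distance with a
// broadcast target, mirroring what SquaredL2DistanceKernel and
// SquaredL2DistanceGradKernel compute. Assumes Eigen's unsupported Tensor module.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  const int n = 4;  // batch size, first dimension of the input
  const int k = 3;  // remaining (flattened) width

  Eigen::Tensor<float, 2, Eigen::RowMajor> x(n, k);  // input
  Eigen::Tensor<float, 2, Eigen::RowMajor> y(1, k);  // target, first dim == 1
  x.setRandom();
  y.setRandom();

  // Forward: sub = x - broadcast(y); out(i) = sum_j sub(i, j)^2.
  // broadcast() takes replication factors, so y is repeated n times along
  // dimension 0 and left untouched along dimension 1.
  Eigen::Tensor<float, 2, Eigen::RowMajor> sub =
      x - y.broadcast(Eigen::array<int, 2>({n, 1}));
  Eigen::Tensor<float, 1, Eigen::RowMajor> out =
      sub.pow(2).sum(Eigen::array<int, 1>({1}));

  // Backward: with the output gradient broadcast over columns,
  // d_x = 2 * d_out * sub, and because y was broadcast over the batch,
  // d_y is the sum of d_x over the batch dimension with the opposite sign.
  Eigen::Tensor<float, 2, Eigen::RowMajor> d_out(n, 1);
  d_out.setConstant(1.0f);
  Eigen::Tensor<float, 2, Eigen::RowMajor> d_x =
      2.0f * (d_out.broadcast(Eigen::array<int, 2>({1, k}))) * sub;
  Eigen::Tensor<float, 1, Eigen::RowMajor> d_y =
      -1.0f * (d_x.sum(Eigen::array<int, 1>({0})));

  std::cout << "out(0) = " << out(0) << ", d_y(0) = " << d_y(0) << std::endl;
  return 0;
}

The same broadcast/reshape/sum calls are what Operate<Dims>() applies for each supported rank; the sketch only fixes the rank at 2 so it can be compiled and run on its own.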