Commit 6bef0796 authored by yangyaming

Follow coding style and move reshaping operation to paddle tensor.

Parent f8b885f2
@@ -30,22 +30,27 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
                             "Target of SquaredL2DistanceOp "
                             "must be initialized.");
-    auto* X = ctx.Input<Tensor>("X");
-    auto xDims = X->dims();
-    auto* Y = ctx.Input<Tensor>("Y");
-    auto yDims = Y->dims();
-    PADDLE_ENFORCE_EQ(framework::arity(xDims), framework::arity(yDims),
+    auto* x = ctx.Input<Tensor>("X");
+    auto x_dims = x->dims();
+    auto* y = ctx.Input<Tensor>("Y");
+    auto y_dims = y->dims();
+    PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims),
                       "Tensor rank of both SquaredL2DistanceOp's "
                       "inputs must be same.");
-    int rank = framework::arity(xDims);
-    PADDLE_ENFORCE(rank >= 2 || rank <= 6, "Tensor rank should be in [2, 6].");
-    PADDLE_ENFORCE(yDims[0] == 1 || yDims[0] == xDims[0],
+
+    int rank = framework::arity(x_dims);
+    PADDLE_ENFORCE(rank >= 2, "Tensor rank should be at least equal to 2.");
+    PADDLE_ENFORCE_EQ(framework::product(x_dims) / x_dims[0],
+                      framework::product(y_dims) / y_dims[0],
+                      "Product of dimensions except the first dimension of "
+                      "input and target must be equal.");
+    PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0],
                    "First dimension of target must be equal to input "
                    "or to 1.");
-    ctx.Output<Tensor>("sub_result")->Resize(xDims);
-    ctx.Output<Tensor>("Out")->Resize({xDims[0], 1});
+    ctx.Output<Tensor>("sub_result")->Resize(x_dims);
+    ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
   }
 };
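For reference, the shape rule this hunk enforces can be restated in NumPy terms; the infer_shapes helper below is illustrative only, not part of the commit:

    import numpy as np

    def infer_shapes(x_shape, y_shape):
        # Mirrors SquaredL2DistanceOp::InferShape: equal ranks of at least 2,
        # equal per-row element counts, and a first dimension of Y that is
        # either 1 (broadcast) or equal to X's.
        assert len(x_shape) == len(y_shape) and len(x_shape) >= 2
        assert np.prod(x_shape[1:]) == np.prod(y_shape[1:])
        assert y_shape[0] == 1 or y_shape[0] == x_shape[0]
        # shapes of "sub_result" and "Out"
        return tuple(x_shape), (x_shape[0], 1)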
@@ -66,8 +71,8 @@ class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker {
     input and target. The number of distance values equals the
     first dimension of input. First dimension of target could be equal to
     input or to 1. If the first dimension of target is 1, SquaredL2DistanceOp
-    will broadcast the first dimension to the first dimension of input.
-    You can decide whether calculate the gradient of target.
+    will broadcast target's first dimension to input's first dimension.
+    You can decide whether to calculate the gradient of input and target.
 )DOC");
   }
 };
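Continuing the illustrative sketch above (shapes chosen only for this example), the broadcast behavior described in the DOC string looks like:

    # X: (32, 10), Y: (1, 10) -> Y is broadcast over the 32 rows
    print(infer_shapes((32, 10), (1, 10)))   # ((32, 10), (32, 1))
    print(infer_shapes((32, 10), (32, 10)))  # ((32, 10), (32, 1))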
@@ -81,19 +86,19 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Gradient of Out should not be null");
     // check out grad dimensions
-    auto outDims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
-    auto xDims = ctx.Input<Tensor>("X")->dims();
-    auto yDims = ctx.Input<Tensor>("Y")->dims();
-    PADDLE_ENFORCE_EQ(outDims[0], xDims[0],
+    auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto y_dims = ctx.Input<Tensor>("Y")->dims();
+    PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
                       "First dimension of output gradient and "
                       "input value must be equal.");
-    PADDLE_ENFORCE_EQ(outDims[1], 1,
+    PADDLE_ENFORCE_EQ(out_dims[1], 1,
                       "Second dimension of output gradient "
                       "must be 1.");
-    auto* xGrad = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* yGrad = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    if (xGrad != nullptr) xGrad->Resize(xDims);
-    if (yGrad != nullptr) yGrad->Resize(yDims);
+    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    if (x_grad != nullptr) x_grad->Resize(x_dims);
+    if (y_grad != nullptr) y_grad->Resize(y_dims);
   }
 };
...
@@ -20,9 +20,6 @@ namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
-template <typename T, size_t D, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
@@ -31,64 +28,39 @@ template <typename Place, typename T>
 class SquaredL2DistanceKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* input0 = context.Input<Tensor>("X");
-    const int rank = framework::arity(input0->dims());
-    switch (rank) {
-      case 2:
-        Operate<2>(context);
-        break;
-      case 3:
-        Operate<3>(context);
-        break;
-      case 4:
-        Operate<4>(context);
-        break;
-      case 5:
-        Operate<5>(context);
-        break;
-      case 6:
-        Operate<6>(context);
-        break;
-      default:
-        // already asserted in SquaredL2DistanceOpMaker
-        break;
-    }
-  }
-
- private:
-  template <int Dims>
-  void Operate(const framework::ExecutionContext& context) const {
-    auto* input0 = context.Input<Tensor>("X");
-    auto* input1 = context.Input<Tensor>("Y");
-    auto* output0 = context.Output<Tensor>("sub_result");
-    auto* output1 = context.Output<Tensor>("Out");
-    output0->mutable_data<T>(context.GetPlace());
-    output1->mutable_data<T>(context.GetPlace());
-    auto X = EigenTensor<T, Dims>::From(*input0);
-    auto Y = EigenTensor<T, Dims>::From(*input1);
-    auto subResult = EigenTensor<T, Dims>::From(*output0);
-    auto Z = EigenMatrix<T>::From(*output1);
-    auto xDims = X.dimensions();
-    auto yDims = Y.dimensions();
+    auto* in0 = context.Input<Tensor>("X");
+    auto* in1 = context.Input<Tensor>("Y");
+    auto* out0 = context.Output<Tensor>("sub_result");
+    auto* out1 = context.Output<Tensor>("Out");
+
+    auto in0_dims = in0->dims();
+    auto in1_dims = in1->dims();
+
+    int cols = framework::product(in0_dims) / in0_dims[0];
+    // reduce dimensions except the first
+    auto x =
+        EigenMatrix<T>::From(*in0, framework::make_ddim({in0_dims[0], cols}));
+    auto y =
+        EigenMatrix<T>::From(*in1, framework::make_ddim({in1_dims[0], cols}));
+
+    out0->mutable_data<T>(context.GetPlace());
+    out1->mutable_data<T>(context.GetPlace());
+    auto sub_result = EigenMatrix<T>::From(*out0);
+    auto z = EigenMatrix<T>::From(*out1);
+
     auto place = context.GetEigenDevice<Place>();
+    auto x_dims = x.dimensions();
+    auto y_dims = y.dimensions();
     // buffer the subtraction result
-    if (yDims[0] == 1 && xDims[0] != yDims[0]) {
-      auto yBroadcastDims = yDims;
-      yBroadcastDims[0] = xDims[0];
-      subResult.device(place) = X - Y.broadcast(yBroadcastDims);
+    if (y_dims[0] == 1 && x_dims[0] > y_dims[0]) {
+      auto y_broadcast_dims = y_dims;
+      y_broadcast_dims[0] = x_dims[0];
+      sub_result.device(place) = x - y.broadcast(y_broadcast_dims);
     } else {
-      subResult.device(place) = X - Y;
+      sub_result.device(place) = x - y;
     }
-    // create matrix view for subtraction result
-    const auto& subResMat = subResult.reshape(Eigen::array<int, 2>(
-        {static_cast<int>(xDims[0]), static_cast<int>(X.size() / xDims[0])}));
-    Z.device(place) = subResMat.pow(2).sum(Eigen::array<int, 1>({1}));
+    z.device(place) = sub_result.pow(2).sum(Eigen::array<int, 1>({1}));
   }
 };
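A minimal NumPy sketch of what the rewritten forward kernel computes, under the same flatten-to-matrix view (the function name is illustrative, not from the commit):

    def squared_l2_distance_forward(x, y):
        # Flatten all dimensions after the first, as the kernel does via
        # EigenMatrix<T>::From with make_ddim({dims[0], cols}).
        x2 = x.reshape(x.shape[0], -1)
        y2 = y.reshape(y.shape[0], -1)
        sub_result = x2 - y2  # NumPy broadcasts when y2.shape[0] == 1
        out = (sub_result ** 2).sum(axis=1, keepdims=True)
        return sub_result, out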
@@ -96,77 +68,47 @@ template <typename Place, typename T>
 class SquaredL2DistanceGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* input0 = context.Input<Tensor>("sub_result");
-    const int rank = framework::arity(input0->dims());
-    switch (rank) {
-      case 2:
-        Operate<2>(context);
-        break;
-      case 3:
-        Operate<3>(context);
-        break;
-      case 4:
-        Operate<4>(context);
-        break;
-      case 5:
-        Operate<5>(context);
-        break;
-      case 6:
-        Operate<6>(context);
-        break;
-      default:
-        // already asserted in SquaredL2DistanceOpMaker
-        break;
-    }
-  }
-
- private:
-  template <int Dims>
-  void Operate(const framework::ExecutionContext& context) const {
-    auto* input0 = context.Input<Tensor>("sub_result");
-    auto* OG = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* XG = context.Output<Tensor>(framework::GradVarName("X"));
-    auto* YG = context.Output<Tensor>(framework::GradVarName("Y"));
-    auto subResult = EigenTensor<T, Dims>::From(*input0);
-    auto outGrad = EigenMatrix<T>::From(*OG);
-    auto subResDims = subResult.dimensions();
-    int firstDim = static_cast<int>(subResDims[0]);
-    int cols = subResult.size() / firstDim;
-    const auto subResMat =
-        subResult.reshape(Eigen::array<int, 2>({firstDim, cols}));
+    auto* in0 = context.Input<Tensor>("sub_result");
+    auto* in1 = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x_g = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* y_g = context.Output<Tensor>(framework::GradVarName("Y"));
+
+    auto sub_result = EigenMatrix<T>::From(*in0);
+    auto out_grad = EigenMatrix<T>::From(*in1);
+
+    auto x_dims = x_g->dims();
+    auto y_dims = y_g->dims();
+
+    int cols = framework::product(x_dims) / x_dims[0];
     // calculate gradient
-    auto gradMat =
-        2 * (outGrad.broadcast(Eigen::array<int, 2>({1, cols}))) * subResMat;
+    auto grad_mat =
+        2 * (out_grad.broadcast(Eigen::array<int, 2>({1, cols}))) * sub_result;
+
     // propagate back to input
-    auto eigenPlace = context.GetEigenDevice<Place>();
-    if (XG != nullptr) {
-      XG->mutable_data<T>(context.GetPlace());
-      auto xGrad = EigenTensor<T, Dims>::From(*XG);
-      // dimensions are same with subResult
-      auto xGradMat = xGrad.reshape(Eigen::array<int, 2>({firstDim, cols}));
-      xGradMat.device(eigenPlace) = gradMat;
+    auto eigen_place = context.GetEigenDevice<Place>();
+    if (x_g != nullptr) {
+      x_g->mutable_data<T>(context.GetPlace());
+      // eigen matrix
+      auto x_grad =
+          EigenMatrix<T>::From(*x_g, framework::make_ddim({x_dims[0], cols}));
+      // dimensions are the same as sub_result
+      x_grad.device(eigen_place) = grad_mat;
     }
-    if (YG != nullptr) {
-      YG->mutable_data<T>(context.GetPlace());
-      auto yGrad = EigenTensor<T, Dims>::From(*YG);
-      auto dimsYGrad = yGrad.dimensions();
-      auto yGradMat = yGrad.reshape(Eigen::array<int, 2>(
-          {static_cast<int>(dimsYGrad[0]),
-           static_cast<int>(yGrad.size() / dimsYGrad[0])}));
-      PADDLE_ENFORCE(dimsYGrad[0] <= firstDim,
+
+    if (y_g != nullptr) {
+      y_g->mutable_data<T>(context.GetPlace());
+      auto y_grad =
+          EigenMatrix<T>::From(*y_g, framework::make_ddim({y_dims[0], cols}));
+
+      PADDLE_ENFORCE(sub_result.dimensions()[0] >= y_dims[0],
                      "First dimension of gradient must be greater than "
                      "or equal to first dimension of target");
-      if (dimsYGrad[0] == firstDim) {
-        yGradMat.device(eigenPlace) = -1 * gradMat;
+
+      if (sub_result.dimensions()[0] == y_dims[0]) {
+        y_grad.device(eigen_place) = -1 * grad_mat;
       } else {
-        yGradMat.device(eigenPlace) =
-            -1 * (gradMat.sum(Eigen::array<int, 2>({0})));
+        y_grad.device(eigen_place) =
+            -1 * (grad_mat.sum(Eigen::array<int, 2>({0})));
       }
     }
   }
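The backward rule follows from d/dx (x - y)^2 = 2(x - y); below is a hedged NumPy sketch matching the kernel's matrix view (again, names are illustrative, not the commit's API):

    def squared_l2_distance_grad(sub_result, out_grad, y_rows):
        # out_grad has shape (batch, 1); multiplying broadcasts it across the
        # columns, matching out_grad.broadcast(Eigen::array<int, 2>({1, cols})).
        grad_mat = 2.0 * out_grad * sub_result
        x_grad = grad_mat
        if y_rows == sub_result.shape[0]:
            y_grad = -grad_mat
        else:
            # Y was broadcast along the first dimension in the forward pass,
            # so its gradient accumulates over the batch.
            y_grad = -grad_mat.sum(axis=0, keepdims=True)
        return x_grad, y_grad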
...
@@ -21,5 +21,15 @@ class TestSquaredL2DistanceOp(unittest.TestCase):
         }


+class TestSquaredL2DistanceGradOp(GradientChecker):
+    def test_squared_l2_distance(self):
+        op = create_op("squared_l2_distance")
+        inputs = {
+            'X': np.random.uniform(0.1, 1., (2, 3)).astype('float32'),
+            'Y': np.random.uniform(0.1, 1., (2, 3)).astype('float32')
+        }
+        self.check_grad(op, inputs, set(["X", "Y"]), "Out")
+
+
 if __name__ == '__main__':
     unittest.main()
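check_grad compares the operator's analytic gradients against finite differences; a simplified sketch of that idea (not the actual GradientChecker harness) is:

    import numpy as np

    def numeric_grad(f, x, eps=1e-4):
        # Central finite differences of the scalar sum of f's output,
        # perturbing one element of the (contiguous) array x at a time.
        grad = np.zeros_like(x)
        flat_x, flat_g = x.ravel(), grad.ravel()
        for i in range(flat_x.size):
            orig = flat_x[i]
            flat_x[i] = orig + eps
            hi = f().sum()
            flat_x[i] = orig - eps
            lo = f().sum()
            flat_x[i] = orig
            flat_g[i] = (hi - lo) / (2 * eps)
        return grad

    # e.g. numeric_grad(lambda: squared_l2_distance_forward(x, y)[1], x)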