提交 5e7aa8c7 编写于 作者: F fengjiayi

code clean

上级 855c9e33
......@@ -59,6 +59,14 @@ inline T* Tensor::mutable_data(platform::Place place) {
}
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
int rank = src.dims().size();
PADDLE_ENFORCE_GE(
rank, 2,
"'ReshapeToMatrix()' is only used for flatten high rank "
"tensors to matrixs. Can not be used in reshaping vectors.");
if (rank == 2) {
return src;
}
Tensor res;
res.ShareDataWith(src);
res.Resize(flatten_to_2d(src.dims(), num_col_dims));
......
......@@ -45,11 +45,9 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
"Input(Label) should be 1.");
}
auto out_dim_vec =
framework::vectorize(framework::slice_ddim(x_dims, 0, rank - 1));
out_dim_vec.push_back(1);
ctx->SetOutputDim("Y", framework::make_ddim(out_dim_vec));
auto y_dims = x_dims;
y_dims[rank - 1] = 1;
ctx->SetOutputDim("Y", y_dims);
ctx->ShareLoD("X", /*->*/ "Y");
}
......
......@@ -34,10 +34,9 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
y->mutable_data<T>(ctx.GetPlace());
int rank = x->dims().size();
Tensor x_2d = rank > 2 ? framework::ReshapeToMatrix(*x, rank - 1) : *x;
Tensor labels_2d =
rank > 2 ? framework::ReshapeToMatrix(*labels, rank - 1) : *labels;
Tensor y_2d = rank > 2 ? framework::ReshapeToMatrix(*y, rank - 1) : *y;
Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
math::CrossEntropyFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d,
......
......@@ -32,9 +32,8 @@ class SoftmaxKernel : public framework::OpKernel<T> {
Out->mutable_data<T>(context.GetPlace());
int rank = X->dims().size();
Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
Tensor Out_2d =
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
......@@ -53,11 +52,9 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
dX->mutable_data<T>(context.GetPlace());
int rank = Out->dims().size();
Tensor Out_2d =
rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
Tensor dOut_2d =
rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册