Commit 5e7aa8c7 authored by F fengjiayi

code clean

Parent 855c9e33
...
@@ -59,6 +59,14 @@ inline T* Tensor::mutable_data(platform::Place place) {
 }
 
 inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
+  int rank = src.dims().size();
+  PADDLE_ENFORCE_GE(
+      rank, 2,
+      "'ReshapeToMatrix()' is only used to flatten high-rank "
+      "tensors to matrices. It cannot be used to reshape vectors.");
+  if (rank == 2) {
+    return src;
+  }
   Tensor res;
   res.ShareDataWith(src);
   res.Resize(flatten_to_2d(src.dims(), num_col_dims));
...
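Review note: the new guard and early return make ReshapeToMatrix a no-op for tensors that are already matrices, which is what lets the call sites in the hunks below drop their "rank > 2 ?" guards. A minimal usage sketch of the updated contract follows; the concrete shapes, the helper function, and the include path are illustrative assumptions, not part of this commit.

    // Sketch only: exercises the ReshapeToMatrix behavior added above.
    #include "paddle/framework/tensor.h"  // assumed header location

    void ReshapeToMatrixSketch() {
      using paddle::framework::Tensor;
      using paddle::framework::make_ddim;
      using paddle::framework::ReshapeToMatrix;
      paddle::platform::CPUPlace cpu;

      Tensor src;
      src.Resize(make_ddim({2, 3, 4}));  // rank-3 tensor
      src.mutable_data<float>(cpu);
      // flatten_to_2d collapses the first num_col_dims axes into rows:
      // {2, 3, 4} with num_col_dims = 2 becomes {6, 4}.
      Tensor mat = ReshapeToMatrix(src, /*num_col_dims=*/2);

      Tensor already_2d;
      already_2d.Resize(make_ddim({6, 4}));  // rank-2 tensor
      already_2d.mutable_data<float>(cpu);
      // New behavior: a matrix is returned as-is instead of being flattened
      // again, and a rank-1 input would now trip the PADDLE_ENFORCE_GE check.
      Tensor same = ReshapeToMatrix(already_2d, /*num_col_dims=*/1);
    }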
...
@@ -45,11 +45,9 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
                        "Input(Label) should be 1.");
     }
 
-    auto out_dim_vec =
-        framework::vectorize(framework::slice_ddim(x_dims, 0, rank - 1));
-    out_dim_vec.push_back(1);
-    ctx->SetOutputDim("Y", framework::make_ddim(out_dim_vec));
+    auto y_dims = x_dims;
+    y_dims[rank - 1] = 1;
+    ctx->SetOutputDim("Y", y_dims);
     ctx->ShareLoD("X", /*->*/ "Y");
   }
...
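Review note: the rewritten InferShape keeps every leading dimension of X and collapses the class dimension to 1, avoiding the vectorize/slice_ddim/make_ddim round trip. Below is a standalone, plain-C++ illustration of that shape rule; the std::vector stand-in for framework::DDim and the example shapes are assumptions for illustration only.

    // Illustration only: the output-shape rule implemented by
    //   auto y_dims = x_dims; y_dims[rank - 1] = 1;
    #include <cstdint>
    #include <iostream>
    #include <vector>

    std::vector<int64_t> InferCrossEntropyYDims(std::vector<int64_t> x_dims) {
      x_dims.back() = 1;  // keep leading dims, collapse the class axis to 1
      return x_dims;
    }

    int main() {
      // A hypothetical X of shape [32, 10, 100] yields Y of shape [32, 10, 1].
      for (int64_t d : InferCrossEntropyYDims({32, 10, 100})) {
        std::cout << d << ' ';  // prints: 32 10 1
      }
      std::cout << '\n';
      return 0;
    }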
...
@@ -34,10 +34,9 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
     y->mutable_data<T>(ctx.GetPlace());
 
     int rank = x->dims().size();
-    Tensor x_2d = rank > 2 ? framework::ReshapeToMatrix(*x, rank - 1) : *x;
-    Tensor labels_2d =
-        rank > 2 ? framework::ReshapeToMatrix(*labels, rank - 1) : *labels;
-    Tensor y_2d = rank > 2 ? framework::ReshapeToMatrix(*y, rank - 1) : *y;
+    Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
+    Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+    Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
 
     math::CrossEntropyFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d,
...
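Review note: with the guards gone, every input is flattened unconditionally, so CrossEntropyFunctor always receives 2-D views; the same simplification is applied to the softmax kernels in the next hunk. The plain-C++ walkthrough below shows the resulting shapes; FlattenTo2D re-implements the flatten_to_2d rule on std::vector, and the concrete shapes plus the hard-label layout (trailing dimension of 1) are assumptions.

    // Walkthrough only: the 2-D views the kernel hands to CrossEntropyFunctor.
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    std::vector<int64_t> FlattenTo2D(const std::vector<int64_t>& dims,
                                     int num_col_dims) {
      // First num_col_dims axes collapse into rows, the rest into columns.
      int64_t rows = std::accumulate(dims.begin(), dims.begin() + num_col_dims,
                                     int64_t{1}, std::multiplies<int64_t>());
      int64_t cols = std::accumulate(dims.begin() + num_col_dims, dims.end(),
                                     int64_t{1}, std::multiplies<int64_t>());
      return {rows, cols};
    }

    int main() {
      const int rank = 3;                                  // e.g. X is [32, 10, 100]
      auto x_2d = FlattenTo2D({32, 10, 100}, rank - 1);    // {320, 100}
      auto label_2d = FlattenTo2D({32, 10, 1}, rank - 1);  // {320, 1}
      auto y_2d = FlattenTo2D({32, 10, 1}, rank - 1);      // {320, 1}
      std::cout << x_2d[0] << 'x' << x_2d[1] << ", " << label_2d[0] << 'x'
                << label_2d[1] << ", " << y_2d[0] << 'x' << y_2d[1] << '\n';
      return 0;
    }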
...
@@ -32,9 +32,8 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     Out->mutable_data<T>(context.GetPlace());
 
     int rank = X->dims().size();
-    Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
-    Tensor Out_2d =
-        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
+    Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
+    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
 
     math::SoftmaxFunctor<DeviceContext, T>()(
         context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
...
@@ -53,11 +52,9 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
     dX->mutable_data<T>(context.GetPlace());
 
     int rank = Out->dims().size();
-    Tensor Out_2d =
-        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
-    Tensor dOut_2d =
-        rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
-    Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
+    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+    Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
+    Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
 
     math::SoftmaxGradFunctor<DeviceContext, T>()(
         context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
...