提交 855c9e33 编写于 作者: F fengjiayi

clean softmax_op code

上级 24d51de0
...@@ -31,16 +31,13 @@ class SoftmaxKernel : public framework::OpKernel<T> { ...@@ -31,16 +31,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
// allocate memory on device. // allocate memory on device.
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto dims = X->dims(); int rank = X->dims().size();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
framework::LoDTensor flattened_x; Tensor Out_2d =
framework::LoDTensor flattened_out; rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
math::SoftmaxFunctor<DeviceContext, T>()( math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &flattened_x, context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
&flattened_out);
} }
}; };
...@@ -55,18 +52,16 @@ class SoftmaxGradKernel : public framework::OpKernel<T> { ...@@ -55,18 +52,16 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
// allocate memory on device. // allocate memory on device.
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
auto dims = Out->dims(); int rank = Out->dims().size();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); Tensor Out_2d =
framework::LoDTensor flattened_out; rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
framework::LoDTensor flattened_d_out; Tensor dOut_2d =
framework::LoDTensor flattened_d_x; rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
flattened_out.ShareDataWith(*Out).Resize(flattened_dims); Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
math::SoftmaxGradFunctor<DeviceContext, T>()( math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &flattened_out, context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
&flattened_d_out, &flattened_d_x); &dX_2d);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册