Commit 855c9e33 authored by fengjiayi

clean softmax_op code

Parent 24d51de0
@@ -31,16 +31,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());
 
-    auto dims = X->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_x;
-    framework::LoDTensor flattened_out;
-    flattened_x.ShareDataWith(*X).Resize(flattened_dims);
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
+    int rank = X->dims().size();
+    Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
+    Tensor Out_2d =
+        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
 
     math::SoftmaxFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &flattened_x,
-        &flattened_out);
+        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
   }
 };
@@ -55,18 +52,16 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());
 
-    auto dims = Out->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_out;
-    framework::LoDTensor flattened_d_out;
-    framework::LoDTensor flattened_d_x;
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
-    flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
-    flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
+    int rank = Out->dims().size();
+    Tensor Out_2d =
+        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
+    Tensor dOut_2d =
+        rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
+    Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
 
     math::SoftmaxGradFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &flattened_out,
-        &flattened_d_out, &flattened_d_x);
+        context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
+        &dX_2d);
   }
 };
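In both kernels the change drops the manual flattening through framework::LoDTensor, ShareDataWith, and Resize in favor of framework::ReshapeToMatrix, which views an N-d tensor as a 2-D matrix whose rows come from the first rank - 1 dimensions and whose columns come from the last dimension; the rank > 2 guard leaves inputs that are already matrices untouched. The sketch below only illustrates the shape arithmetic involved; FlattenTo2D is a hypothetical stand-in written for this note, not Paddle's actual helper.

// Illustration only (hypothetical helper, not part of Paddle):
// collapse an N-d shape into [prod(dims[0:axis]), prod(dims[axis:])],
// mirroring what flatten_to_2d / ReshapeToMatrix do with axis = rank - 1.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t>& dims,
                                        int axis) {
  assert(axis > 0 && axis < static_cast<int>(dims.size()));
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < axis; ++i) rows *= dims[i];  // leading dims -> rows
  for (int i = axis; i < static_cast<int>(dims.size()); ++i) {
    cols *= dims[i];  // trailing dims -> cols
  }
  return {rows, cols};
}

int main() {
  // A [2, 3, 4] tensor with axis = rank - 1 = 2 becomes a [6, 4] matrix,
  // so the softmax functor normalizes each row over the last dimension.
  auto shape = FlattenTo2D({2, 3, 4}, 2);
  std::printf("%lld x %lld\n", static_cast<long long>(shape.first),
              static_cast<long long>(shape.second));  // prints "6 x 4"
  return 0;
}

Since the commit message describes this as a cleanup, the intent appears to be an equivalent but more direct expression of the old ShareDataWith + Resize path, with fewer temporaries.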