提交 ceb31d30 编写于 作者: D dengkaipeng

fix formax. test=develop

上级 d54005a7
......@@ -40,7 +40,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
int axis_dim = logits->dims()[logits->dims().size()-1];
int axis_dim = logits->dims()[logits->dims().size() - 1];
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
......
......@@ -67,9 +67,10 @@ class CudnnCTCKernel : public framework::OpKernel<T> {
softmax_logits.mutable_data<T>(logits->dims(), ctx.GetPlace());
softmax_logits.set_lod(logits_lod);
int rank = logits->dims().size();
int axis_dim = logits->dims()[rank - 1];
Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, -1, &in_2d,
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, axis_dim, &in_2d,
&out_2d);
// ctc needs sequences data stored in transposed padding format
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册