提交 167933f6 编写于 作者: C chengduozh

set constant for loss

test=develop
上级 b4aca8ed
...@@ -144,6 +144,8 @@ class CudnnCTCKernel : public framework::OpKernel<T> { ...@@ -144,6 +144,8 @@ class CudnnCTCKernel : public framework::OpKernel<T> {
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size)); CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size));
T* loss_data = loss->mutable_data<T>(loss_dims, ctx.GetPlace()); T* loss_data = loss->mutable_data<T>(loss_dims, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), loss, static_cast<T>(0));
auto temp_allocation = auto temp_allocation =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate( platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册