提交 996a7473 编写于 作者: C chengduo 提交者: Yibing Liu

Fix cross_entropy bug (#16237)

test=release/1.3
上级 2c5fee53
...@@ -439,7 +439,8 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> { ...@@ -439,7 +439,8 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>(); context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
Tensor* logit_grad = Tensor* logit_grad =
context.Output<Tensor>(framework::GradVarName("Logits")); context.Output<Tensor>(framework::GradVarName("Logits"));
logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax")); framework::TensorCopy(*context.Input<Tensor>("Softmax"), context.GetPlace(),
context.device_context(), logit_grad);
T* logit_grad_data = logit_grad->data<T>(); T* logit_grad_data = logit_grad->data<T>();
const int batch_size = logit_grad->dims()[0]; const int batch_size = logit_grad->dims()[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册