未验证 提交 c2cd02de 编写于 作者: Y YuhangLi 提交者: GitHub

init output 4 all backend (#53124)

上级 ba899b5c
...@@ -418,11 +418,8 @@ void FakeInitializeOutputsForFunctionKernel( ...@@ -418,11 +418,8 @@ void FakeInitializeOutputsForFunctionKernel(
runtime_ctx.inputs.find("Beta1Pow")->second.at(0)); runtime_ctx.inputs.find("Beta1Pow")->second.at(0));
phi::TensorBase* beta2_pow = GetTensorFormVar( phi::TensorBase* beta2_pow = GetTensorFormVar(
runtime_ctx.inputs.find("Beta2Pow")->second.at(0)); runtime_ctx.inputs.find("Beta2Pow")->second.at(0));
if (beta1_pow->place() == CPUPlace() && if (beta1_pow->place() == beta2_pow->place()) {
beta2_pow->place() == CPUPlace()) { backend = phi::TransToPhiBackend(beta1_pow->place());
backend = phi::TransToPhiBackend(CPUPlace());
} else {
backend = phi::TransToPhiBackend(GPUPlace());
} }
} else { } else {
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册