Unverified commit 6b258317, authored by Leo Chen, committed via GitHub

fix TransferInplaceBack (#29830)

Parent commit: 59b47f3b
......@@ -1336,12 +1336,6 @@ Scope* OperatorWithKernel::PrepareData(
continue;
}
auto out_var_names = OutputVars(true);
if (std::find(out_var_names.begin(), out_var_names.end(), var_name) !=
out_var_names.end()) {
transfered_inplace_vars->emplace_back(var_name);
}
VLOG(3) << "Transform Variable " << var_name << " from "
<< kernel_type_for_var << " to " << expected_kernel_key;
......@@ -1383,13 +1377,33 @@ Scope* OperatorWithKernel::PrepareData(
if (enable_cache_runtime_context_) {
pre_scope_ = nullptr;
}
// Create new var with the same name in transfer scopes
auto* trans_var = new_scope->Var(var_name);
input_vars[i] = trans_var;
// Find if inplace exists between input and output
// If inplace exists, set the new created var to inplaced output, and
// record its name in transfered_inplace_vars.
for (auto& pair : Outputs()) {
for (size_t j = 0; j < pair.second.size(); ++j) {
if (pair.second[j] == var_name) {
VLOG(4) << "Found inplace between input(" << var_name_item.first
<< ") and output(" << pair.first
<< "), the variable name is " << var_name;
ctx->outputs[pair.first][j] = trans_var;
transfered_inplace_vars->emplace_back(var_name);
}
}
}
// Do transfer
Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
SetTensorToVariable(*var, out, trans_var);
}
}
// If pre_scope = &scope, it means that scope is cached and the op is not in
// while block. If new_scope = nullptr, it means that for each input of this
// Op, there is no need to do PrepareData. So PrepareData could be skipped at
......
......@@ -40,5 +40,19 @@ class TestIncrement(unittest.TestCase):
self.assertEqual((output.numpy() == expected_result).all(), True)
class TestInplaceApiWithDataTransform(unittest.TestCase):
    """Regression test: an inplace op whose input lives on a different
    device must still produce the right result after data transfer."""

    def test_increment(self):
        # The scenario only exists on CUDA builds: fill_constant is pinned
        # to the GPU while increment runs on the CPU, so PrepareData must
        # transfer the tensor and write the result back to the inplaced
        # output (the TransferInplaceBack fix under test).
        if fluid.core.is_compiled_with_cuda():
            paddle.enable_static()
            with paddle.fluid.device_guard("gpu:0"):
                var = paddle.fluid.layers.fill_constant([1], "float32", 0)
            with paddle.fluid.device_guard("cpu"):
                var = paddle.increment(var)
            executor = paddle.static.Executor(paddle.CUDAPlace(0))
            result, = executor.run(
                paddle.static.default_main_program(), fetch_list=[var])
            paddle.disable_static()
            # 0 incremented once across the device boundary must be 1.
            self.assertEqual(result[0], 1)
# Allow the test module to be executed directly.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.