未验证 提交 3c14b38e 编写于 作者: D duanyanhui 提交者: GitHub

fix npu save_combine (#50496)

上级 3d5faa88
......@@ -195,45 +195,52 @@ class SaveCombineOpKernel : public framework::OpKernel<T> {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(place);
for (size_t i = 0; i < inp_var_names.size(); i++) {
PADDLE_ENFORCE_NOT_NULL(
inp_vars[i],
platform::errors::InvalidArgument("Cannot find variable %s to save.",
inp_var_names[i]));
PADDLE_ENFORCE_EQ(
inp_vars[i]->IsType<phi::DenseTensor>() ||
inp_vars[i]->IsType<framework::Vocab>(),
true,
platform::errors::InvalidArgument(
"SaveCombine operator only supports saving "
"phi::DenseTensor or Vocab variable, %s has wrong type.",
inp_var_names[i]));
if (inp_vars.size() > 0 && inp_vars[0]->IsType<phi::DenseTensor>()) {
std::vector<const phi::DenseTensor*> x(inp_vars.size());
for (size_t i = 0; i < inp_vars.size(); i++) {
x[i] = (&(inp_vars[i]->Get<phi::DenseTensor>()));
}
SaveCombineTensorKernel<T>(dev_ctx,
x,
filename,
overwrite,
save_as_fp16,
save_to_memory,
output);
} else {
std::vector<const phi::ExtendedTensor*> x(inp_vars.size());
for (size_t i = 0; i < inp_vars.size(); i++) {
x[i] = (&(inp_vars[i]->Get<framework::Vocab>()));
}
SaveCombineVocabKernel<T>(dev_ctx,
x,
filename,
overwrite,
save_as_fp16,
save_to_memory,
output);
if (inp_vars.size() > 0 && inp_vars[0]->IsType<phi::DenseTensor>()) {
std::vector<const phi::DenseTensor*> x(inp_vars.size());
for (size_t i = 0; i < inp_vars.size(); i++) {
PADDLE_ENFORCE_NOT_NULL(
inp_vars[i],
platform::errors::InvalidArgument(
"Cannot find variable %s to save.", inp_var_names[i]));
PADDLE_ENFORCE_EQ(
inp_vars[i]->IsType<phi::DenseTensor>(),
true,
platform::errors::InvalidArgument(
"SaveCombine operator only supports saving "
"phi::DenseTensor or Vocab variable, %s has wrong type.",
inp_var_names[i]));
x[i] = (&(inp_vars[i]->Get<phi::DenseTensor>()));
}
SaveCombineTensorKernel<T>(dev_ctx,
x,
filename,
overwrite,
save_as_fp16,
save_to_memory,
output);
} else {
std::vector<const phi::ExtendedTensor*> x(inp_vars.size());
for (size_t i = 0; i < inp_vars.size(); i++) {
PADDLE_ENFORCE_NOT_NULL(
inp_vars[i],
platform::errors::InvalidArgument(
"Cannot find variable %s to save.", inp_var_names[i]));
PADDLE_ENFORCE_EQ(
inp_vars[i]->IsType<framework::Vocab>(),
true,
platform::errors::InvalidArgument(
"SaveCombine operator only supports saving "
"phi::DenseTensor or Vocab variable, %s has wrong type.",
inp_var_names[i]));
x[i] = (&(inp_vars[i]->Get<framework::Vocab>()));
}
SaveCombineVocabKernel<T>(dev_ctx,
x,
filename,
overwrite,
save_as_fp16,
save_to_memory,
output);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册