未验证 提交 de6f15b6 编写于 作者: W Wilber 提交者: GitHub

reconstruct code for convert_fp16 (#46428) (#47087)

上级 2cc8797e
......@@ -30,7 +30,7 @@ namespace paddle {
namespace inference {
namespace analysis {
-bool OpSupportPrecision(const std::string& phi_op_type,
+bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& blacklist);
......
......@@ -140,39 +140,12 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
auto var_data_type = var_node->Var()->GetDataType();
VLOG(5) << "var_name is " << var_name << ", data type is "
<< var_data_type;
-if (var_data_type == paddle::framework::proto::VarType::FP16 &&
-t->dtype() != paddle::experimental::DataType::FLOAT16) {
-framework::Tensor half_tensor;
-half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
-half_tensor.Resize(t->dims());
-auto *half_data =
-half_tensor.mutable_data<float16>(platform::CPUPlace());
-for (int i = 0; i < t->numel(); i++) {
-auto *data = t->mutable_data<float16>(platform::CPUPlace());
-half_data[i] = static_cast<float16>(data[i]);
-}
-t->clear();
-paddle::framework::TensorCopySync(half_tensor, place, t);
-} else if (var_data_type == paddle::framework::proto::VarType::BF16) {
-framework::Tensor bf16_tensor;
-bf16_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
-bf16_tensor.Resize(t->dims());
-auto *bf16_data = bf16_tensor.mutable_data<platform::bfloat16>(
-platform::CPUPlace());
-for (int i = 0; i < t->numel(); i++) {
-auto *data = t->mutable_data<bfloat16>(platform::CPUPlace());
-bf16_data[i] = static_cast<platform::bfloat16>(data[i]);
-}
-t->clear();
-paddle::framework::TensorCopySync(bf16_tensor, place, t);
-} else {
-platform::CPUPlace cpu_place;
-framework::LoDTensor temp_tensor;
-temp_tensor.Resize(t->dims());
-paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
-t->clear();
-paddle::framework::TensorCopySync(temp_tensor, place, t);
-}
+platform::CPUPlace cpu_place;
+framework::LoDTensor temp_tensor;
+temp_tensor.Resize(t->dims());
+paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
+t->clear();
+paddle::framework::TensorCopySync(temp_tensor, place, t);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册