diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
index 458a26a762f416a2aea2336f1778ddb085f6add2..94d5c4bac58fdaee3a2b86775c5bc8ac562c332e 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -17,8 +17,12 @@
 #include <string>
 
 #include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/common/data_type.h"
 
 namespace phi {
 class DenseTensor;
@@ -30,6 +34,23 @@ class Scope;
 }  // namespace framework
 }  // namespace paddle
 
+namespace {
+template <typename InType, typename OutType>
+void ConvertTensorType(paddle::framework::LoDTensor* tensor) {
+  paddle::framework::Tensor tmp_tensor;
+  tmp_tensor.set_type(paddle::experimental::CppTypeToDataType<OutType>::Type());
+  tmp_tensor.Resize(tensor->dims());
+  auto* tmp_data = tmp_tensor.mutable_data<OutType>(paddle::platform::CPUPlace());
+  auto* data = tensor->mutable_data<InType>(paddle::platform::CPUPlace());
+  for (int i = 0; i < tensor->numel(); i++) {
+    tmp_data[i] = static_cast<OutType>(data[i]);
+  }
+  tensor->clear();
+  paddle::framework::TensorCopySync(
+      tmp_tensor, paddle::platform::CPUPlace(), tensor);
+}
+}  // namespace
+
 namespace paddle {
 namespace framework {
 namespace ir {
@@ -290,19 +311,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
     auto tensor_type = conv_weight_tensor->dtype();
 
     if (tensor_type == paddle::experimental::DataType::FLOAT16) {
-      framework::Tensor weight_float_tensor;
-      weight_float_tensor.set_type(paddle::experimental::DataType::FLOAT32);
-      weight_float_tensor.Resize(conv_weight_tensor->dims());
-      auto* weight_float_data =
-          weight_float_tensor.mutable_data<float>(platform::CPUPlace());
-      auto* data =
-          conv_weight_tensor->mutable_data<float16>(platform::CPUPlace());
-      for (int i = 0; i < conv_weight_tensor->numel(); i++) {
-        weight_float_data[i] = static_cast<float>(data[i]);
-      }
-      conv_weight_tensor->clear();
-      paddle::framework::TensorCopySync(
-          weight_float_tensor, platform::CPUPlace(), conv_weight_tensor);
+      ConvertTensorType<float16, float>(conv_weight_tensor);
     }
 
     // Get batch norm bias
@@ -341,40 +350,8 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
                                conv_type());
 
     if (tensor_type == paddle::experimental::DataType::FLOAT16) {
-      {
-        framework::Tensor weight_float16_tensor;
-        weight_float16_tensor.set_type(paddle::experimental::DataType::FLOAT16);
-        weight_float16_tensor.Resize(conv_weight_tensor->dims());
-        auto* weight_float16_data =
-            weight_float16_tensor.mutable_data<float16>(platform::CPUPlace());
-        auto* data =
-            conv_weight_tensor->mutable_data<float>(platform::CPUPlace());
-        for (int i = 0; i < conv_weight_tensor->numel(); i++) {
-          weight_float16_data[i] = static_cast<float16>(data[i]);
-        }
-        conv_weight_tensor->clear();
-        paddle::framework::TensorCopySync(
-            weight_float16_tensor, platform::CPUPlace(), conv_weight_tensor);
-      }
-
-      {
-        framework::Tensor eltwise_y_in_float16_tensor;
-        eltwise_y_in_float16_tensor.set_type(
-            paddle::experimental::DataType::FLOAT16);
-        eltwise_y_in_float16_tensor.Resize(eltwise_y_in_tensor->dims());
-        auto* eltwise_y_in_float16_data =
-            eltwise_y_in_float16_tensor.mutable_data<float16>(
-                platform::CPUPlace());
-        auto* data =
-            eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace());
-        for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) {
-          eltwise_y_in_float16_data[i] = static_cast<float16>(data[i]);
-        }
-        eltwise_y_in_tensor->clear();
-        paddle::framework::TensorCopySync(eltwise_y_in_float16_tensor,
-                                          platform::CPUPlace(),
-                                          eltwise_y_in_tensor);
-      }
+      ConvertTensorType<float, float16>(conv_weight_tensor);
+      ConvertTensorType<float, float16>(eltwise_y_in_tensor);
     }
 
     // with MKL-DNN fuse conv+bn into conv with bias
@@ -612,6 +589,16 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
     float epsilon =
         PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
 
+    // conv_weight fp16 --> fp32
+    auto* conv_weight_tensor =
+        scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
+    auto tensor_type = conv_weight_tensor->dtype();
+
+    if (tensor_type == paddle::experimental::DataType::FLOAT16) {
+      ConvertTensorType<float16, float>(conv_weight_tensor);
+      ConvertTensorType<float16, float>(eltwise_y_in_tensor);
+    }
+
     // if bias is an input to other ops as well then we cannot overwrite it
     // so we create separate elementwise Y in nodes
     if (eltwise_y_in->outputs.size() > 1) {
@@ -666,6 +653,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
                                  conv_type());
     }
 
+    if (tensor_type == paddle::experimental::DataType::FLOAT16) {
+      ConvertTensorType<float, float16>(conv_weight_tensor);
+      ConvertTensorType<float, float16>(eltwise_y_in_tensor);
+    }
+
     // Update the elementwise_add node
     eltwise->Op()->SetAttr("axis", 1);
     eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));