Unverified · Commit 62b5452d · Authored by Wilber · Committed by GitHub

conv_eltwiseadd_bn_fuse support fp16 (#45379)

Parent 3c14b094
@@ -17,8 +17,12 @@
 #include <string>
 #include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/common/data_type.h"

 namespace phi {
 class DenseTensor;
@@ -30,6 +34,23 @@ class Scope;
 }  // namespace framework
 }  // namespace paddle

+namespace {
+template <typename T1, typename T2>
+void ConvertTensorType(paddle::framework::LoDTensor* tensor) {
+  paddle::framework::Tensor tmp_tensor;
+  tmp_tensor.set_type(paddle::experimental::CppTypeToDataType<T2>::Type());
+  tmp_tensor.Resize(tensor->dims());
+  auto* tmp_data = tmp_tensor.mutable_data<T2>(paddle::platform::CPUPlace());
+  auto* data = tensor->mutable_data<T1>(paddle::platform::CPUPlace());
+  for (int i = 0; i < tensor->numel(); i++) {
+    tmp_data[i] = static_cast<T2>(data[i]);
+  }
+  tensor->clear();
+  paddle::framework::TensorCopySync(
+      tmp_tensor, paddle::platform::CPUPlace(), tensor);
+}
+}  // namespace
+
 namespace paddle {
 namespace framework {
 namespace ir {
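
In effect, ConvertTensorType<T1, T2> converts a LoDTensor in place: it casts every element from T1 to T2 into a temporary CPU tensor, clears the original, and copies the temporary back with TensorCopySync. The hunks below always use it as a matched pair around the fuse arithmetic, roughly like this sketch (`weights` is a hypothetical placeholder, not a name from this diff):

  ConvertTensorType<float16, float>(weights);   // upcast so the BN folding runs in fp32
  // ... fold the batch-norm scale and bias into the conv weights ...
  ConvertTensorType<float, float16>(weights);   // downcast the fused result back to fp16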
@@ -290,19 +311,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
   auto tensor_type = conv_weight_tensor->dtype();
   if (tensor_type == paddle::experimental::DataType::FLOAT16) {
-    framework::Tensor weight_float_tensor;
-    weight_float_tensor.set_type(paddle::experimental::DataType::FLOAT32);
-    weight_float_tensor.Resize(conv_weight_tensor->dims());
-    auto* weight_float_data =
-        weight_float_tensor.mutable_data<float>(platform::CPUPlace());
-    auto* data =
-        conv_weight_tensor->mutable_data<float16>(platform::CPUPlace());
-    for (int i = 0; i < conv_weight_tensor->numel(); i++) {
-      weight_float_data[i] = static_cast<float>(data[i]);
-    }
-    conv_weight_tensor->clear();
-    paddle::framework::TensorCopySync(
-        weight_float_tensor, platform::CPUPlace(), conv_weight_tensor);
+    ConvertTensorType<float16, float>(conv_weight_tensor);
   }

   // Get batch norm bias
@@ -341,40 +350,8 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
       conv_type());
   if (tensor_type == paddle::experimental::DataType::FLOAT16) {
-    {
-      framework::Tensor weight_float16_tensor;
-      weight_float16_tensor.set_type(paddle::experimental::DataType::FLOAT16);
-      weight_float16_tensor.Resize(conv_weight_tensor->dims());
-      auto* weight_float16_data =
-          weight_float16_tensor.mutable_data<float16>(platform::CPUPlace());
-      auto* data =
-          conv_weight_tensor->mutable_data<float>(platform::CPUPlace());
-      for (int i = 0; i < conv_weight_tensor->numel(); i++) {
-        weight_float16_data[i] = static_cast<float16>(data[i]);
-      }
-      conv_weight_tensor->clear();
-      paddle::framework::TensorCopySync(
-          weight_float16_tensor, platform::CPUPlace(), conv_weight_tensor);
-    }
-    {
-      framework::Tensor eltwise_y_in_float16_tensor;
-      eltwise_y_in_float16_tensor.set_type(
-          paddle::experimental::DataType::FLOAT16);
-      eltwise_y_in_float16_tensor.Resize(eltwise_y_in_tensor->dims());
-      auto* eltwise_y_in_float16_data =
-          eltwise_y_in_float16_tensor.mutable_data<float16>(
-              platform::CPUPlace());
-      auto* data =
-          eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace());
-      for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) {
-        eltwise_y_in_float16_data[i] = static_cast<float16>(data[i]);
-      }
-      eltwise_y_in_tensor->clear();
-      paddle::framework::TensorCopySync(eltwise_y_in_float16_tensor,
-                                        platform::CPUPlace(),
-                                        eltwise_y_in_tensor);
-    }
+    ConvertTensorType<float, float16>(conv_weight_tensor);
+    ConvertTensorType<float, float16>(eltwise_y_in_tensor);
   }

   // with MKL-DNN fuse conv+bn into conv with bias
@@ -612,6 +589,16 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
   float epsilon =
       PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));

+  // conv_weight fp16 --> fp32
+  auto* conv_weight_tensor =
+      scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
+  auto tensor_type = conv_weight_tensor->dtype();
+  if (tensor_type == paddle::experimental::DataType::FLOAT16) {
+    ConvertTensorType<float16, float>(conv_weight_tensor);
+    ConvertTensorType<float16, float>(eltwise_y_in_tensor);
+  }
+
   // if bias is an input to other ops as well then we cannot overwrite it
   // so we create separate elementwise Y in nodes
   if (eltwise_y_in->outputs.size() > 1) {
@@ -666,6 +653,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
       conv_type());
   }

+  if (tensor_type == paddle::experimental::DataType::FLOAT16) {
+    ConvertTensorType<float, float16>(conv_weight_tensor);
+    ConvertTensorType<float, float16>(eltwise_y_in_tensor);
+  }
+
   // Update the elementwise_add node
   eltwise->Op()->SetAttr("axis", 1);
   eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
...
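
For readers outside the Paddle codebase, the helper boils down to an element-by-element static_cast through a scratch buffer. A minimal standalone sketch of the same round-trip pattern follows (not Paddle code; float -> double stands in for float16 -> float, since standard C++ has no built-in half-precision type, and the folding math is faked):

#include <cstdio>
#include <vector>

// Same shape as ConvertTensorType: copy the buffer element by element,
// casting each value to the destination type.
template <typename T1, typename T2>
std::vector<T2> ConvertBuffer(const std::vector<T1>& src) {
  std::vector<T2> dst(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    dst[i] = static_cast<T2>(src[i]);
  }
  return dst;
}

int main() {
  std::vector<float> w = {0.5f, 1.25f, -2.0f};
  auto w64 = ConvertBuffer<float, double>(w);    // "upcast" before the fuse math
  for (double& v : w64) v = v * 3.0 + 1.0;       // stand-in for folding BN scale/bias
  auto w32 = ConvertBuffer<double, float>(w64);  // cast back afterwards
  std::printf("%g %g %g\n", w32[0], w32[1], w32[2]);  // prints: 2.5 4.75 -5
}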