未验证 提交 62b5452d 编写于 作者: W Wilber 提交者: GitHub

conv_eltwiseadd_bn_fuse support fp16 (#45379)

上级 3c14b094
......@@ -17,8 +17,12 @@
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/data_type.h"
namespace phi {
class DenseTensor;
......@@ -30,6 +34,23 @@ class Scope;
} // namespace framework
} // namespace paddle
namespace {
// Converts the element type of `tensor` from T1 to T2 in place.
//
// Every element is cast with static_cast<T2> into a temporary CPU tensor that
// has the same dims as `tensor`; `tensor` is then cleared and refilled from
// that temporary through TensorCopySync targeting CPUPlace.
//
// Template parameters:
//   T1 - element type the data is read as.
//   T2 - element type the tensor holds after the call.
//
// NOTE(review): the source pointer is obtained with mutable_data<T1>, so this
// presumably expects `tensor` to already hold T1 data on the CPU - confirm
// against the callers; nothing here verifies the current dtype.
template <typename T1, typename T2>
void ConvertTensorType(paddle::framework::LoDTensor* tensor) {
  // Staging tensor of the destination type, shaped like the input.
  paddle::framework::Tensor tmp_tensor;
  tmp_tensor.set_type(paddle::experimental::CppTypeToDataType<T2>::Type());
  tmp_tensor.Resize(tensor->dims());
  auto* tmp_data = tmp_tensor.mutable_data<T2>(paddle::platform::CPUPlace());
  auto* data = tensor->mutable_data<T1>(paddle::platform::CPUPlace());

  // Read the element count once and index with numel()'s own type rather than
  // `int`, so the counter cannot overflow when the count exceeds INT_MAX.
  auto numel = tensor->numel();
  for (decltype(numel) i = 0; i < numel; ++i) {
    tmp_data[i] = static_cast<T2>(data[i]);
  }

  // Drop the old storage before copying the converted values back.
  tensor->clear();
  paddle::framework::TensorCopySync(
      tmp_tensor, paddle::platform::CPUPlace(), tensor);
}
} // namespace
namespace paddle {
namespace framework {
namespace ir {
......@@ -290,19 +311,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
auto tensor_type = conv_weight_tensor->dtype();
if (tensor_type == paddle::experimental::DataType::FLOAT16) {
framework::Tensor weight_float_tensor;
weight_float_tensor.set_type(paddle::experimental::DataType::FLOAT32);
weight_float_tensor.Resize(conv_weight_tensor->dims());
auto* weight_float_data =
weight_float_tensor.mutable_data<float>(platform::CPUPlace());
auto* data =
conv_weight_tensor->mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < conv_weight_tensor->numel(); i++) {
weight_float_data[i] = static_cast<float>(data[i]);
}
conv_weight_tensor->clear();
paddle::framework::TensorCopySync(
weight_float_tensor, platform::CPUPlace(), conv_weight_tensor);
ConvertTensorType<float16, float>(conv_weight_tensor);
}
// Get batch norm bias
......@@ -341,40 +350,8 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
conv_type());
if (tensor_type == paddle::experimental::DataType::FLOAT16) {
{
framework::Tensor weight_float16_tensor;
weight_float16_tensor.set_type(paddle::experimental::DataType::FLOAT16);
weight_float16_tensor.Resize(conv_weight_tensor->dims());
auto* weight_float16_data =
weight_float16_tensor.mutable_data<float16>(platform::CPUPlace());
auto* data =
conv_weight_tensor->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < conv_weight_tensor->numel(); i++) {
weight_float16_data[i] = static_cast<float16>(data[i]);
}
conv_weight_tensor->clear();
paddle::framework::TensorCopySync(
weight_float16_tensor, platform::CPUPlace(), conv_weight_tensor);
}
{
framework::Tensor eltwise_y_in_float16_tensor;
eltwise_y_in_float16_tensor.set_type(
paddle::experimental::DataType::FLOAT16);
eltwise_y_in_float16_tensor.Resize(eltwise_y_in_tensor->dims());
auto* eltwise_y_in_float16_data =
eltwise_y_in_float16_tensor.mutable_data<float16>(
platform::CPUPlace());
auto* data =
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) {
eltwise_y_in_float16_data[i] = static_cast<float16>(data[i]);
}
eltwise_y_in_tensor->clear();
paddle::framework::TensorCopySync(eltwise_y_in_float16_tensor,
platform::CPUPlace(),
eltwise_y_in_tensor);
}
ConvertTensorType<float, float16>(conv_weight_tensor);
ConvertTensorType<float, float16>(eltwise_y_in_tensor);
}
// with MKL-DNN fuse conv+bn into conv with bias
......@@ -612,6 +589,16 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
float epsilon =
PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
// conv_weight fp16 --> fp32
auto* conv_weight_tensor =
scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto tensor_type = conv_weight_tensor->dtype();
if (tensor_type == paddle::experimental::DataType::FLOAT16) {
ConvertTensorType<float16, float>(conv_weight_tensor);
ConvertTensorType<float16, float>(eltwise_y_in_tensor);
}
// if bias is an input to other ops as well then we cannot overwrite it
// so we create separate elementwise Y in nodes
if (eltwise_y_in->outputs.size() > 1) {
......@@ -666,6 +653,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
conv_type());
}
if (tensor_type == paddle::experimental::DataType::FLOAT16) {
ConvertTensorType<float, float16>(conv_weight_tensor);
ConvertTensorType<float, float16>(eltwise_y_in_tensor);
}
// Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册