Unverified commit f265a313, authored by zhupengyang, committed by GitHub

[XPU] support convert fp16 model (#50790)

Parent 569b018e
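This commit lets Paddle Inference convert a saved fp32 model to fp16 for the XPU backend through the convert_to_mixed_precision API. A minimal usage sketch follows (paths are placeholders; the full version is the unit test added at the end of this diff):

    # Minimal sketch (placeholder paths): convert a saved fp32 inference model
    # to fp16 for XPU, mirroring the unit test added at the end of this commit.
    from paddle.inference import (
        PlaceType,
        PrecisionType,
        convert_to_mixed_precision,
    )

    convert_to_mixed_precision(
        'resnet50/inference.pdmodel',            # src model file
        'resnet50/inference.pdiparams',          # src params file
        'mixed_precision/inference.pdmodel',     # dst model file
        'mixed_precision/inference.pdiparams',   # dst params file
        backend=PlaceType.XPU,
        mixed_precision=PrecisionType.Half,
    )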
@@ -215,7 +215,7 @@ if(WITH_XPU)
   cc_library(
     xpu_quant_utils
     SRCS xpu/quant_utils.cc
-    DEPS pass)
+    DEPS pass phi)
   cc_library(
     xpu_pass_utils
     SRCS xpu/pass_utils.cc
...
@@ -47,6 +47,23 @@ bool PhiKernelSupportPrecision(
   return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
 }
 
+static phi::Backend ConvertPlaceToBackend(const phi::Place& place) {
+  switch (place.GetType()) {
+    case phi::AllocationType::CPU:
+      return phi::Backend::CPU;
+    case phi::AllocationType::GPU:
+      return phi::Backend::GPU;
+    case phi::AllocationType::XPU:
+      return phi::Backend::XPU;
+    case phi::AllocationType::NPU:
+      return phi::Backend::NPU;
+    default:
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Cannot convert place(%d).", static_cast<int>(place.GetType())));
+  }
+  return phi::Backend::UNDEFINED;
+}
+
 bool KernelSupportPrecision(
     const std::string& op_type,
     phi::Backend backend,
@@ -65,7 +82,7 @@ bool KernelSupportPrecision(
   auto it = all_kernels.find(op_type);
   if (it != all_kernels.end()) {
     for (const auto& kern_pair : it->second) {
-      if (platform::is_gpu_place(kern_pair.first.place_) &&
+      if (ConvertPlaceToBackend(kern_pair.first.place_) == backend &&
           kern_pair.first.data_type_ ==
               framework::TransToProtoVarType(precision)) {
         support = true;
@@ -150,20 +167,8 @@ bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
                         const std::unordered_set<std::string>& black_list) {
-  bool support = false;
-  if (black_list.count(op_type) == 0) {
-    // Actual custom backend will be added after the NUM_BACKENDS.
-    // We use this feature to determine whether backend is custom device.
-    if (backend == phi::Backend::GPU ||
-        static_cast<size_t>(backend) >
-            static_cast<size_t>(phi::Backend::NUM_BACKENDS)) {
-      support = KernelSupportPrecision(op_type, backend, precision);
-    } else {
-      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
-          "Now, only support backend of GPU and Custom Device ."));
-    }
-  }
-  return support;
+  return black_list.count(op_type) == 0 &&
+         KernelSupportPrecision(op_type, backend, precision);
 }
 
 // The set of ops that support fp16 calculation and are considered
@@ -192,15 +197,13 @@ void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
 }
 
 void AutoMixedPrecisionPass::Init(Graph* graph) const {
-  bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
-  bool enable_custom_device_mixed = false;
-  if (Has("enable_custom_device_mixed")) {
-    enable_custom_device_mixed = Get<bool>("enable_custom_device_mixed");
-  }
-  if (enable_gpu_mixed) {
+  if (Has("enable_gpu_mixed") && Get<bool>("enable_gpu_mixed")) {
     backend_ = phi::Backend::GPU;
-  } else if (enable_custom_device_mixed) {
-    // transform Backend::CUSTOM to actual backend.
+  } else if (Has("enable_xpu_mixed") && Get<bool>("enable_xpu_mixed")) {
+    backend_ = phi::Backend::XPU;
+  } else if (Has("enable_custom_device_mixed") &&
+             Get<bool>("enable_custom_device_mixed")) {
+    // transform Backend::CUSTOM to actual backend.
     // Here, we only consider one custom backend.
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
     auto device_type = phi::DeviceManager::GetAllCustomDeviceTypes()[0];
@@ -214,7 +217,7 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
         "Cannot enable custom_device_mixed."));
 #endif
   }
 
-  skip_pass_ = !enable_gpu_mixed && !enable_custom_device_mixed;
+  skip_pass_ = backend_ == phi::Backend::UNDEFINED;
 
   low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
@@ -225,7 +228,6 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
     VLOG(4) << " - " << name;
   }
 
-  keep_io_types_ = true;
   if (Has("keep_io_types")) {
     keep_io_types_ = Get<bool>("keep_io_types");
   }
@@ -607,6 +609,20 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
       return true;
     }
   }
+
+  if (backend_ == phi::Backend::XPU) {
+    if (GetOpOriginalType(op_desc->Type()) == "layer_norm") {
+      auto vecs = op_desc->Input("Bias");
+      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+        return true;
+      }
+      vecs = op_desc->Input("Scale");
+      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+        return true;
+      }
+    }
+  }
+
   return false;
 }
@@ -632,6 +648,20 @@ bool AutoMixedPrecisionPass::OutputVarsNotConvert(
       return true;
     }
   }
+
+  if (backend_ == phi::Backend::XPU) {
+    if (GetOpOriginalType(op_desc->Type()) == "layer_norm") {
+      auto vecs = op_desc->Output("Mean");
+      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+        return true;
+      }
+      vecs = op_desc->Output("Variance");
+      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+        return true;
+      }
+    }
+  }
+
   return false;
 }
...
@@ -68,11 +68,11 @@ class AutoMixedPrecisionPass : public FusePassBase {
  private:
   mutable bool skip_pass_{false};
 
-  mutable bool keep_io_types_{false};
+  mutable bool keep_io_types_{true};
 
   // float16 or bfloat16 now
   mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
 
-  mutable phi::Backend backend_{phi::Backend::GPU};
+  mutable phi::Backend backend_{phi::Backend::UNDEFINED};
 
   mutable std::unordered_set<std::string> black_list_;
...
@@ -245,6 +245,12 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     QuantWeight<int16_t>(mul_w_tensor, mul_w_max_tensor, !transpose_w);
   }
 
+  if (bias != nullptr) {
+    auto* bias_tensor =
+        scope->Var(bias->Name())->GetMutable<phi::DenseTensor>();
+    CastToFp32(bias_tensor);
+  }
+
   std::string fc_out_name;
   if (act_out) {
     fc_out_name = act_out->Name();
...
@@ -31,6 +31,7 @@
 #include "paddle/fluid/framework/ir/xpu/quant_utils.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/kernels/concat_kernel.h"
 
 namespace phi {
 class DenseTensor;
@@ -617,6 +618,9 @@ class MultiEncoderXPUFusePass : public FusePassBase {
   bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;
 
+  // Mask must be fp32 even if model is fp16
+  int CastMask(ir::Graph* graph) const;
+
   // 1. Transpose q_w, k_w, v_w
   // 2. Concat q_w, k_w, v_w
   // 3. Generate qkv_w_max tensor
@@ -674,8 +678,11 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
       }
     }
   }
+
+  int cast_mask_counts = CastMask(graph);
   AddStatis(single_encoder_fused_counts);
   AddStatis(multi_encoder_fused_counts);
+  AddStatis(cast_mask_counts);
 }
 void MultiEncoderXPUFusePass::PrepareQKVWeight(
@@ -685,29 +692,28 @@ void MultiEncoderXPUFusePass::PrepareQKVWeight(
     phi::DenseTensor* qkv_w,
     phi::DenseTensor* qkv_w_max) const {
   // Transpose
-  phi::DenseTensor q_w_trans;
-  phi::DenseTensor k_w_trans;
-  phi::DenseTensor v_w_trans;
-  Transpose2D<float>(q_w, &q_w_trans);
-  Transpose2D<float>(k_w, &k_w_trans);
-  Transpose2D<float>(v_w, &v_w_trans);
+  phi::DenseTensor q_w_t;
+  phi::DenseTensor k_w_t;
+  phi::DenseTensor v_w_t;
+  Assign(q_w, &q_w_t);
+  Assign(k_w, &k_w_t);
+  Assign(v_w, &v_w_t);
+  Transpose2D(&q_w_t);
+  Transpose2D(&k_w_t);
+  Transpose2D(&v_w_t);
 
   // Concat
-  auto q_w_trans_dims = q_w_trans.dims();
-  auto k_w_trans_dims = k_w_trans.dims();
-  auto v_w_trans_dims = v_w_trans.dims();
-  qkv_w->Resize(DDim({q_w_trans_dims[0] + k_w_trans_dims[0] + v_w_trans_dims[0],
-                      q_w_trans_dims[1]}));
+  qkv_w->Resize(DDim(
+      {q_w_t.dims()[0] + k_w_t.dims()[0] + v_w_t.dims()[0], q_w_t.dims()[1]}));
   qkv_w->set_type(q_w.type());
   auto* dev_ctx = static_cast<phi::CPUContext*>(
       platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
-  int size = q_w.numel();
-  auto* qkv_w_data = dev_ctx->Alloc<float>(qkv_w);
-  memcpy(qkv_w_data, q_w_trans.data(), size * sizeof(float));
-  qkv_w_data += size;
-  memcpy(qkv_w_data, k_w_trans.data(), size * sizeof(float));
-  qkv_w_data += size;
-  memcpy(qkv_w_data, v_w_trans.data(), size * sizeof(float));
+  std::vector<const phi::DenseTensor*> in_tensors{&q_w_t, &k_w_t, &v_w_t};
+  if (q_w.type() == phi::DataType::FLOAT16) {
+    phi::ConcatKernel<float16>(*dev_ctx, in_tensors, 0, qkv_w);
+  } else {
+    phi::ConcatKernel<float>(*dev_ctx, in_tensors, 0, qkv_w);
+  }
 
   // Quant to int16
   QuantWeight<int16_t>(qkv_w, qkv_w_max, false);
@@ -846,6 +852,9 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
   auto* block = q_matmul->Op()->Block();
   auto* scope = param_scope();
+  bool enable_fp16 =
+      scope->FindVar(q_matmul_w->Name())->Get<phi::DenseTensor>().dtype() ==
+      phi::DataType::FLOAT16;
 
   // Prepare q,k,v weight
   std::string q_w_name = q_matmul_w->Name();
@@ -905,12 +914,32 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
   auto* qkv_add_bias = graph->CreateVarNode(&qkv_add_bias_desc);
   auto* qkv_add_bias_var = block->Var(qkv_add_bias_name);
   qkv_add_bias_var->SetPersistable(true);
+  auto* q_add_bias_tensor =
+      scope->FindVar(q_add_bias_name)->GetMutable<phi::DenseTensor>();
+  auto* k_add_bias_tensor =
+      scope->FindVar(k_add_bias_name)->GetMutable<phi::DenseTensor>();
+  auto* v_add_bias_tensor =
+      scope->FindVar(v_add_bias_name)->GetMutable<phi::DenseTensor>();
+  CastToFp32(q_add_bias_tensor);
+  CastToFp32(k_add_bias_tensor);
+  CastToFp32(v_add_bias_tensor);
   ConcatQKVBias(
-      scope->FindVar(q_add_bias_name)->Get<phi::DenseTensor>(),
-      scope->FindVar(k_add_bias_name)->Get<phi::DenseTensor>(),
-      scope->FindVar(v_add_bias_name)->Get<phi::DenseTensor>(),
+      *q_add_bias_tensor,
+      *k_add_bias_tensor,
+      *v_add_bias_tensor,
       scope->Var(qkv_add_bias_name)->GetMutable<phi::DenseTensor>());
+
+  // Prepare qkv_add_0_bias, qkv_add_2_bias, qkv_add_3_bias
+  auto qkv_add_0_bias_name = qkv_add_0_bias->Name();
+  CastToFp32(
+      scope->FindVar(qkv_add_0_bias_name)->GetMutable<phi::DenseTensor>());
+  auto qkv_add_2_bias_name = qkv_add_2_bias->Name();
+  CastToFp32(
+      scope->FindVar(qkv_add_2_bias_name)->GetMutable<phi::DenseTensor>());
+  auto qkv_add_3_bias_name = qkv_add_3_bias->Name();
+  CastToFp32(
+      scope->FindVar(qkv_add_3_bias_name)->GetMutable<phi::DenseTensor>());
 
   // Generate single_encoder_xpu op
   framework::OpDesc op_desc(block);
   op_desc.SetType("single_encoder_xpu");
@@ -927,9 +956,9 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
                      qkv_matmul_3_w_max_name});
   op_desc.SetInput("fc_bias",
                    {qkv_add_bias_name,
-                    qkv_add_0_bias->Name(),
-                    qkv_add_2_bias->Name(),
-                    qkv_add_3_bias->Name()});
+                    qkv_add_0_bias_name,
+                    qkv_add_2_bias_name,
+                    qkv_add_3_bias_name});
   if (norm_before) {
     op_desc.SetInput("ln_scale", {ln_0_scale->Name(), ln_1_scale->Name()});
     op_desc.SetInput("ln_bias", {ln_0_bias->Name(), ln_1_bias->Name()});
@@ -953,6 +982,7 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
       static_cast<int>(qkv_matmul_2_w_shape[1] / qkv_matmul_2_w_shape[0]));
   op_desc.SetAttr("act_type", ConvertActivationType(act_type));
   op_desc.SetAttr("relative_type", static_cast<int>(0));
+  op_desc.SetAttr("enable_fp16", enable_fp16);
   if (norm_before) {
     op_desc.SetOutput("out", {qkv_add_4_out->Name()});
   } else {
@@ -1186,6 +1216,9 @@ bool MultiEncoderXPUFusePass::ApplyMultiEncoderXPUFuse(ir::Graph* graph) const {
         PADDLE_GET_CONST(int, single_encoders[0]->Op()->GetAttr(attr_name)));
   }
   op_desc.SetAttr("slice_idx", static_cast<int>(-1));
+  op_desc.SetAttr(
+      "enable_fp16",
+      PADDLE_GET_CONST(bool, single_encoders[0]->Op()->GetAttr("enable_fp16")));
   op_desc.SetOutput("out", {out_name});
   op_desc.SetOutput("x_fp16", {x_fp16_name});
   op_desc.SetOutput("out_fp16", {out_fp16_name});
@@ -1213,6 +1246,61 @@ bool MultiEncoderXPUFusePass::ApplyMultiEncoderXPUFuse(ir::Graph* graph) const {
   return true;
 }
 
+int MultiEncoderXPUFusePass::CastMask(ir::Graph* graph) const {
+  int cast_counts = 0;
+  auto nodes = graph->Nodes();
+  for (auto node : nodes) {
+    if (node->IsVar()) continue;
+    auto op_desc = node->Op();
+    if (node->IsVar() ||  //
+        op_desc->Type() != "multi_encoder_xpu" ||
+        !op_desc->GetAttrIfExists<bool>("enable_fp16") ||
+        op_desc->Inputs().count("mask") == 0)
+      continue;
+
+    auto* block = op_desc->Block();
+    auto* scope = param_scope();
+
+    // Find mask node
+    std::string mask_name = op_desc->Inputs().at("mask")[0];
+    Node* mask = nullptr;
+    for (auto* in_node : node->inputs) {
+      if (in_node->Var()->Name() == mask_name) {
+        mask = in_node;
+        break;
+      }
+    }
+
+    // Create new_mask node/var/tensor
+    std::string new_mask_name = mask_name + "_fp32";
+    VarDesc new_mask_desc(new_mask_name);
+    auto* new_mask = graph->CreateVarNode(&new_mask_desc);
+    block->Var(new_mask_name);
+    scope->Var(new_mask_name)->GetMutable<phi::DenseTensor>();
+
+    // Create cast op
+    framework::OpDesc cast_op_desc(block);
+    cast_op_desc.SetType("cast");
+    cast_op_desc.SetInput("X", {mask_name});
+    cast_op_desc.SetAttr("in_dtype",
+                         static_cast<int>(framework::proto::VarType::FP16));
+    cast_op_desc.SetAttr("out_dtype",
+                         static_cast<int>(framework::proto::VarType::FP32));
+    cast_op_desc.SetOutput("Out", {new_mask_name});
+    auto* cast = graph->CreateOpNode(&cast_op_desc);
+    IR_NODE_LINK_TO(mask, cast);
+    IR_NODE_LINK_TO(cast, new_mask);
+
+    // Update encoder
+    op_desc->SetInput("mask", {new_mask_name});
+    IR_NODE_LINK_TO(new_mask, node);
+    IR_NODE_UNLINK(node, mask);
+
+    cast_counts++;
+  }
+  return cast_counts;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
...
@@ -16,33 +16,92 @@
 #include <vector>
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/core/enforce.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/assign_kernel.h"
+#include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
-template <typename T>
-void Transpose2D(const phi::DenseTensor& in, phi::DenseTensor* out) {
-  auto in_dims = in.dims();
+void Assign(const phi::DenseTensor& in, phi::DenseTensor* out) {
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  out->Resize(in.dims());
+  out->set_type(in.dtype());
+  out->set_layout(in.layout());
+  phi::AssignKernel(*cpu_ctx, in, out);
+}
+
+void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out) {
+  auto in_dims = in->dims();
   PADDLE_ENFORCE_EQ(
       in_dims.size(),
       2,
       platform::errors::InvalidArgument(
           "In dims rank should be 2, but received in dims size is [%d].",
           in_dims.size()));
-  out->Resize({in_dims[1], in_dims[0]});
-  out->set_type(in.type());
-  auto* dev_ctx = static_cast<phi::CPUContext*>(
+
+  phi::DenseTensor trans_tensor;
+  phi::DenseTensor* out_ptr = out == nullptr ? &trans_tensor : out;
+  out_ptr->Resize({in_dims[1], in_dims[0]});
+  out_ptr->set_type(in->type());
+  out_ptr->set_layout(in->layout());
+
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
       platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
-  dev_ctx->Alloc<T>(out);
   std::vector<int> axis{1, 0};
-  phi::funcs::Transpose<phi::CPUContext, T, 2> trans2d;
-  trans2d(*dev_ctx, in, out, axis);
+  switch (in->dtype()) {
+    case phi::DataType::FLOAT16:
+      phi::TransposeKernel<float16>(*cpu_ctx, *in, axis, out_ptr);
+      break;
+    case phi::DataType::FLOAT32:
+      phi::TransposeKernel<float>(*cpu_ctx, *in, axis, out_ptr);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only support fp16 and fp32, but received dtype is %s.",
+          phi::DataTypeToString(in->dtype())));
+      break;
+  }
+
+  if (out == nullptr) {
+    Assign(*out_ptr, in);
+  }
 }
 
-template void Transpose2D<float>(const phi::DenseTensor& in,
-                                 phi::DenseTensor* out);
+void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+
+  phi::DenseTensor fp32_tensor;
+  phi::DenseTensor* out_ptr = out == nullptr ? &fp32_tensor : out;
+  out_ptr->Resize(in->dims());
+  out_ptr->set_type(phi::DataType::FLOAT32);
+  out_ptr->set_layout(in->layout());
+
+  switch (in->dtype()) {
+    case phi::DataType::FLOAT16:
+      phi::CastKernel<float16>(*cpu_ctx, *in, phi::DataType::FLOAT32, out_ptr);
+      break;
+    case phi::DataType::FLOAT32:
+      if (out == nullptr) {
+        return;
+      } else {
+        phi::AssignKernel(*cpu_ctx, *in, out_ptr);
+      }
+      break;
+    default:
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only support fp16 and fp32, but received dtype is %s.",
+          phi::DataTypeToString(in->dtype())));
+      break;
+  }
+
+  if (out == nullptr) {
+    Assign(*out_ptr, in);
+  }
+}
 
 static float FindMaxAbs(const float* data, int len) {
   float max_f = 0.0f;
@@ -151,14 +210,15 @@ template <typename T>
 void QuantWeight(phi::DenseTensor* weight,
                  phi::DenseTensor* weight_max,
                  bool transpose) {
+  // Convert fp16 to fp32
+  phi::DenseTensor weight_fp32;
+  CastToFp32(weight, &weight_fp32);
   // Transpose
-  auto* weight_data = weight->data<float>();
-  phi::DenseTensor weight_trans;
   if (transpose) {
-    Transpose2D<float>(*weight, &weight_trans);
-    weight_data = weight_trans.data<float>();
-    weight->Resize(weight_trans.dims());
+    Transpose2D(&weight_fp32);
   }
 
   // Find max
   paddle::platform::DeviceContextPool& pool =
       paddle::platform::DeviceContextPool::Instance();
@@ -171,21 +231,22 @@ void QuantWeight(phi::DenseTensor* weight,
   }
   phi::XPUContext* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
   int max_ptr_size = xpu_ctx->x_context()->max_ptr_size();
-  int size = weight->numel();
+  int size = weight_fp32.numel();
+  auto* weight_data = weight_fp32.data<float>();
   float max_val = FindMaxAbs(weight_data, size);
   std::vector<float> max_vec(max_ptr_size, max_val);
-  weight_max->set_type(paddle::experimental::CppTypeToDataType<float>::Type());
+  weight_max->set_type(phi::DataType::FLOAT32);
   weight_max->Resize({max_ptr_size});
-  auto* dev_ctx = static_cast<phi::CPUContext*>(
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
       platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
-  memcpy(dev_ctx->Alloc<float>(weight_max),
+  memcpy(cpu_ctx->Alloc<float>(weight_max),
          max_vec.data(),
          max_ptr_size * sizeof(float));
 
   // Quant
-  std::vector<T> quant_data(size);
-  QuantFP32ToIntX(weight_data, quant_data.data(), max_val, size);
   weight->set_type(paddle::experimental::CppTypeToDataType<T>::Type());
-  memcpy(dev_ctx->Alloc<T>(weight), quant_data.data(), size * sizeof(T));
+  weight->Resize(weight_fp32.dims());
+  QuantFP32ToIntX(weight_data, cpu_ctx->Alloc<T>(weight), max_val, size);
 }
 
 template void QuantWeight<int16_t>(phi::DenseTensor* weight,
...
@@ -19,8 +19,11 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-template <typename T>
-void Transpose2D(const phi::DenseTensor& in, phi::DenseTensor* out);
+void Assign(const phi::DenseTensor& in, phi::DenseTensor* out);
+
+void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
+
+void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
 
 // 1. Quant weight from fp32 to int16/int31
 // 2. Weight data is in-place update.
...
@@ -41,18 +41,31 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
       backend_(backend),
       keep_io_types_(keep_io_types),
       black_list_(black_list) {
-  if (mixed_precision_ != phi::DataType::FLOAT16 &&
-      mixed_precision_ != phi::DataType::BFLOAT16) {
-    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
-        "mixed_precision currently not supported dtype %d, we now only "
-        "support fp16 and bf16.",
-        static_cast<int>(mixed_precision_)));
-  }
-  if (backend_ != phi::Backend::GPU && backend_ != phi::Backend::CUSTOM) {
-    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
-        "mixed_precision currently not supported place %d, we now only "
-        "support gpu and custom device .",
-        static_cast<int>(backend_)));
+  switch (backend_) {
+    case phi::Backend::GPU:
+      PADDLE_ENFORCE(mixed_precision_ == phi::DataType::FLOAT16 ||
+                         mixed_precision_ == phi::DataType::BFLOAT16,
+                     platform::errors::InvalidArgument(
+                         "mixed_precision of %s currently only supported fp16 "
+                         "and bf16, not support %s.",
+                         experimental::BackendToString(backend_),
+                         phi::DataTypeToString(mixed_precision_)));
+      break;
+    case phi::Backend::XPU:
+    case phi::Backend::CUSTOM:
+      PADDLE_ENFORCE(mixed_precision_ == phi::DataType::FLOAT16,
+                     platform::errors::InvalidArgument(
+                         "mixed_precision of %s currently only supported fp16 "
+                         "and bf16, not support %s.",
+                         experimental::BackendToString(backend_),
+                         phi::DataTypeToString(mixed_precision_)));
+      break;
+    default:
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "mixed_precision currently not supported place GPU or XPU or CUSTOM, "
+          "not support %s.",
+          experimental::BackendToString(backend_)));
+      break;
   }
 }
@@ -70,17 +83,16 @@ void ConvertToMixedPrecisionPass::Run() {
   framework::ir::AutoMixedPrecisionPass pass;
   pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
-  pass.Set("mixed_black_list",
-           new std::unordered_set<std::string>{black_list_});
   if (backend_ == phi::Backend::GPU) {
     pass.Set("enable_gpu_mixed", new bool{true});
-    pass.Set("enable_custom_device_mixed", new bool{false});
+  } else if (backend_ == phi::Backend::XPU) {
+    pass.Set("enable_xpu_mixed", new bool{true});
   } else if (backend_ == phi::Backend::CUSTOM) {
-    pass.Set("enable_gpu_mixed", new bool{false});
     pass.Set("enable_custom_device_mixed", new bool{true});
   }
+  pass.Set("mixed_black_list",
+           new std::unordered_set<std::string>{black_list_});
   pass.Set("keep_io_types", new bool{keep_io_types_});
   pass.Apply(main_graph_.get());
 
   SaveMixedModel();
...
@@ -1302,18 +1302,23 @@ void AnalysisPredictor::PrepareArgument() {
             << ", we will use a new PassStrategy. Note that only the GPU "
                "backend is supported for now.";
     if (!config_.use_cinn_compiler_) {
-      pass_builder->ClearPasses();
       const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
       if (config_.tensorrt_engine_enabled()) {
+        pass_builder->ClearPasses();
         for (const auto &pass : kTrtLowerPrecisionPasses) {
           if (deleted_passes.count(pass)) continue;
           pass_builder->AppendPass(pass);
         }
       } else if (config_.use_gpu()) {
+        pass_builder->ClearPasses();
        for (const auto &pass : kGpuLowerPrecisionPasses) {
          if (deleted_passes.count(pass)) continue;
          pass_builder->AppendPass(pass);
        }
+      } else if (config_.use_xpu()) {
+        // All passes support fp16. Not reset pass_builder.
+      } else {
+        pass_builder->ClearPasses();
       }
     }
   }
...
@@ -519,9 +519,9 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "delete_dropout_op_pass",
       "identity_scale_op_clean_pass",
       "generate_sequence_xpu_fuse_pass",
-      "embedding_with_eltwise_add_xpu_fuse_pass",
       "multi_encoder_xpu_fuse_pass",
       "multi_encoder_xpu_slice_fuse_pass",
+      "embedding_with_eltwise_add_xpu_fuse_pass",
       "fc_xpu_fuse_pass",
       "link_xpu_op_max_pass",
   });
...
@@ -253,7 +253,8 @@ XPUOpMap& get_kl2_ops() {
                      phi::DataType::BOOL,
                      phi::DataType::FLOAT16,
                      phi::DataType::FLOAT32})},
-      {"fc_xpu", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"fc_xpu",
+       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"fill",
        XPUKernelSet({phi::DataType::INT64,
                      phi::DataType::INT32,
@@ -461,7 +462,8 @@ XPUOpMap& get_kl2_ops() {
                      phi::DataType::FLOAT16,
                      phi::DataType::INT32,
                      phi::DataType::INT64})},
-      {"multi_encoder_xpu", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"multi_encoder_xpu",
+       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
       {"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"not_equal",
...
@@ -210,6 +210,8 @@ inline std::string BackendToString(const Backend& backend) {
       return "KPS";
     case Backend::IPU:
       return "IPU";
+    case Backend::CUSTOM:
+      return "CUSTOM";
     default: {
       size_t device_type_id_ = static_cast<size_t>(backend) -
                                static_cast<size_t>(Backend::NUM_BACKENDS);
...
@@ -33,44 +33,53 @@ void FcXPUKernel(const Context& ctx,
                  float act_alpha,
                  DenseTensor* out,
                  DenseTensor* out_max) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   auto in_mat_dims = flatten_to_2d(x.dims(), in_num_col_dims);
   int m = in_mat_dims[0];
   int k = in_mat_dims[1];
   int n = w.dims()[0];
+  auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
   const float* x_max_data =
       x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data<float>();
   const float* bias_data =
       bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
+  auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
   xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
   if (act_type == 5) {
     act.leaky_alpha = act_alpha;
   } else if (act_type == 15) {
     act.hard_sigmoid_slope = act_alpha;
   }
-  int r = xpu::fc_fusion<T, int16_t, T, int16_t>(  // TX, TW. TY, TGEMM
-      ctx.x_context(),                     // ctx
-      x.data<T>(),                         // x
-      w.data<int16_t>(),                   // w
-      ctx.template Alloc<T>(out),          // y
-      m,                                   // m
-      n,                                   // n
-      k,                                   // k
-      transpose_x,                         // x_trans
-      true,                                // w_trans
-      x_max_data,                          // x_maxptr
-      w_max.data<float>(),                 // w_maxptr
-      ctx.template Alloc<float>(out_max),  // y_maxptr
-      transpose_x ? m : k,                 // ldx
-      k,                                   // ldw
-      n,                                   // ldy
-      alpha,                               // alpha
-      beta,                                // beta
-      bias_data,                           // bias
-      act);
+  int r =
+      xpu::fc_fusion<XPUType, int16_t, XPUType, int16_t>(  // TX, TW. TY, TGEMM
+          ctx.x_context(),                     // ctx
+          x_data,                              // x
+          w.data<int16_t>(),                   // w
+          out_data,                            // y
+          m,                                   // m
+          n,                                   // n
+          k,                                   // k
+          transpose_x,                         // x_trans
+          true,                                // w_trans
+          x_max_data,                          // x_maxptr
+          w_max.data<float>(),                 // w_maxptr
+          ctx.template Alloc<float>(out_max),  // y_maxptr
+          transpose_x ? m : k,                 // ldx
+          k,                                   // ldw
+          n,                                   // ldy
+          alpha,                               // alpha
+          beta,                                // beta
+          bias_data,                           // bias
+          act);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
 }
 
 }  // namespace fusion
 }  // namespace phi
 
-PD_REGISTER_KERNEL(fc_xpu, XPU, ALL_LAYOUT, phi::fusion::FcXPUKernel, float) {}
+PD_REGISTER_KERNEL(fc_xpu,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FcXPUKernel,
+                   float,
+                   phi::dtype::float16) {}
@@ -40,18 +40,26 @@ void MultiEncoderXPUKernel(const Context& ctx,
                            DenseTensor* out,
                            DenseTensor* x_fp16,
                            DenseTensor* out_fp16) {
+  using float16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
   // XPU2 only support fp16 input/output.
-  float16* x_fp16_data = reinterpret_cast<float16*>(
-      ctx.template Alloc<phi::dtype::float16>(x_fp16));
-  int r_cast_x = xpu::cast_v2<float, float16>(
-      ctx.x_context(), x.data<T>(), x_fp16_data, x.numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r_cast_x,
-                              "multi_encoder_xpu(cast x from fp32 to fp16)");
-  float16* out_fp16_data = reinterpret_cast<float16*>(
-      ctx.template Alloc<phi::dtype::float16>(out_fp16));
+  auto x_dtype = x.dtype();
+  const float16* x_fp16_data = nullptr;
+  float16* out_fp16_data = nullptr;
+  if (x_dtype == phi::DataType::FLOAT32) {
+    auto* x_fp16_data_t = reinterpret_cast<float16*>(
+        ctx.template Alloc<phi::dtype::float16>(x_fp16));
+    int r_cast_x = xpu::cast_v2<float, float16>(
+        ctx.x_context(), x.data<float>(), x_fp16_data_t, x.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r_cast_x,
+                                "multi_encoder_xpu(cast x from fp32 to fp16)");
+    x_fp16_data = x_fp16_data_t;
+    out_fp16_data = reinterpret_cast<float16*>(
+        ctx.template Alloc<phi::dtype::float16>(out_fp16));
+  } else {
+    x_fp16_data =
+        reinterpret_cast<const float16*>(x.data<phi::dtype::float16>());
+    out_fp16_data = reinterpret_cast<float16*>(
+        ctx.template Alloc<phi::dtype::float16>(out));
+  }
 
   // q,k,v weight are fused.
   // Each encoder's weight should be: w0, null, null, w3, w4, w5
@@ -78,8 +86,8 @@ void MultiEncoderXPUKernel(const Context& ctx,
     ln_scale_data.push_back(ln_scale[i]->data<float>());
     ln_bias_data.push_back(ln_bias[i]->data<float>());
   }
-  const T* mask_data =
-      mask.get_ptr() == nullptr ? nullptr : mask.get_ptr()->data<T>();
+  const float* mask_data =
+      mask.get_ptr() == nullptr ? nullptr : mask.get_ptr()->data<float>();
   xpu::Activation_t qkv_act(static_cast<xpu::Activation_t::act_enum>(act_type));
 
   int batch = x.dims()[0];
@@ -152,10 +160,15 @@ void MultiEncoderXPUKernel(const Context& ctx,
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "multi_encoder_xpu");
   }
 
-  int r_cast_out = xpu::cast_v2<float16, float>(
-      ctx.x_context(), out_fp16_data, ctx.template Alloc<T>(out), out->numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r_cast_out,
-                              "multi_encoder_xpu(cast out from fp16 to fp32)");
+  if (x_dtype == phi::DataType::FLOAT32) {
+    int r_cast_out =
+        xpu::cast_v2<float16, float>(ctx.x_context(),
+                                     out_fp16_data,
+                                     ctx.template Alloc<float>(out),
+                                     out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(
+        r_cast_out, "multi_encoder_xpu(cast out from fp16 to fp32)");
+  }
 }
 
 }  // namespace fusion
@@ -165,4 +178,5 @@ PD_REGISTER_KERNEL(multi_encoder_xpu,
                    XPU,
                    ALL_LAYOUT,
                    phi::fusion::MultiEncoderXPUKernel,
-                   float) {}
+                   float,
+                   phi::dtype::float16) {}
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import paddle
from paddle.inference import (
PlaceType,
PrecisionType,
convert_to_mixed_precision,
)
from paddle.jit import to_static
from paddle.static import InputSpec
from paddle.vision.models import resnet50
class ConvertMixedPrecison(unittest.TestCase):
def test(self):
self.temp_dir = tempfile.TemporaryDirectory()
model = resnet50(True)
net = to_static(
model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')]
)
paddle.jit.save(
net, os.path.join(self.temp_dir.name, 'resnet50/inference')
)
convert_to_mixed_precision(
os.path.join(self.temp_dir.name, 'resnet50/inference.pdmodel'),
os.path.join(self.temp_dir.name, 'resnet50/inference.pdiparams'),
os.path.join(
self.temp_dir.name, 'mixed_precision/inference.pdmodel'
),
os.path.join(
self.temp_dir.name, 'mixed_precision/inference.pdiparams'
),
backend=PlaceType.XPU,
mixed_precision=PrecisionType.Half,
)
self.temp_dir.cleanup()
if __name__ == "__main__":
unittest.main()