未验证 提交 fcb746cb 编写于 作者: D duanyanhui 提交者: GitHub

Expand mixed_precision to custom device (#50378)

* expand mix_precision to custom_device

* fix bug

* fix bug

* fix comment

* fix DEFINE bug
上级 4a7d9cd8
...@@ -22,6 +22,9 @@ ...@@ -22,6 +22,9 @@
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h" #include "paddle/phi/core/errors.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -44,16 +47,19 @@ bool PhiKernelSupportPrecision( ...@@ -44,16 +47,19 @@ bool PhiKernelSupportPrecision(
return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key); return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
} }
bool GpuKernelSupportPrecision( bool KernelSupportPrecision(
const std::string& op_type, const std::string& op_type,
phi::Backend backend,
phi::DataType precision, phi::DataType precision,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto phi_op_type = phi::TransToPhiKernelName(op_type); auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support = PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPU, precision, layout); bool support =
PhiKernelSupportPrecision(phi_op_type, backend, precision, layout);
if (backend == phi::Backend::GPU) {
support |= PhiKernelSupportPrecision( support |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, precision, layout); phi_op_type, phi::Backend::GPUDNN, precision, layout);
}
if (!support) { if (!support) {
const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels(); const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type); auto it = all_kernels.find(op_type);
...@@ -146,11 +152,15 @@ bool OpSupportPrecision(const std::string& op_type, ...@@ -146,11 +152,15 @@ bool OpSupportPrecision(const std::string& op_type,
const std::unordered_set<std::string>& black_list) { const std::unordered_set<std::string>& black_list) {
bool support = false; bool support = false;
if (black_list.count(op_type) == 0) { if (black_list.count(op_type) == 0) {
if (backend == phi::Backend::GPU) { // Actual custom backend will be added after the NUM_BACKENDS.
support = GpuKernelSupportPrecision(op_type, precision); // We use this feature to determine whether backend is custom device.
if (backend == phi::Backend::GPU ||
static_cast<size_t>(backend) >
static_cast<size_t>(phi::Backend::NUM_BACKENDS)) {
support = KernelSupportPrecision(op_type, backend, precision);
} else { } else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument( PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Now, only support backend of GPU.")); "Now, only support backend of GPU and Custom Device ."));
} }
} }
return support; return support;
...@@ -183,11 +193,28 @@ void AutoMixedPrecisionPass::SetDefaultBlacklist() const { ...@@ -183,11 +193,28 @@ void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
void AutoMixedPrecisionPass::Init(Graph* graph) const { void AutoMixedPrecisionPass::Init(Graph* graph) const {
bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed"); bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
bool enable_custom_device_mixed = false;
if (Has("enable_custom_device_mixed")) {
enable_custom_device_mixed = Get<bool>("enable_custom_device_mixed");
}
if (enable_gpu_mixed) { if (enable_gpu_mixed) {
backend_ = phi::Backend::GPU; backend_ = phi::Backend::GPU;
} } else if (enable_custom_device_mixed) {
// transform Backend::CUSTOM to actual backend.
skip_pass_ = !enable_gpu_mixed; // Here, we only consider one custom backend.
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_type = phi::DeviceManager::GetAllCustomDeviceTypes()[0];
backend_ = static_cast<phi::Backend>(
static_cast<size_t>(phi::Backend::NUM_BACKENDS) +
phi::CustomRegisteredDeviceMap::Instance()
.GetOrRegisterGlobalDeviceTypeId(device_type));
#else
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Paddle is not compiled with CustomDevice. "
"Cannot enable custom_device_mixed."));
#endif
}
skip_pass_ = !enable_gpu_mixed && !enable_custom_device_mixed;
low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode")); low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
...@@ -466,8 +493,8 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { ...@@ -466,8 +493,8 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
// when op_1 only support cpu kernel. if op_2's intput var is op_1's // when op_1 only support cpu kernel. if op_2's intput var is op_1's
// output var, then op_2 should not run at low precision. // output var, then op_2 should not run at low precision.
if (GetOpOriginalType(op_type) != "feed" && if (GetOpOriginalType(op_type) != "feed" &&
!GpuKernelSupportPrecision(GetOpOriginalType(op_type), !KernelSupportPrecision(
phi::DataType::FLOAT32)) { GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) { for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true); CHECK_EQ(out_var_node->IsVar(), true);
if (out_var_node->Var()->Persistable()) continue; if (out_var_node->Var()->Persistable()) continue;
......
...@@ -48,10 +48,10 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass( ...@@ -48,10 +48,10 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
"support fp16 and bf16.", "support fp16 and bf16.",
static_cast<int>(mixed_precision_))); static_cast<int>(mixed_precision_)));
} }
if (backend_ != phi::Backend::GPU) { if (backend_ != phi::Backend::GPU && backend_ != phi::Backend::CUSTOM) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument( PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported place %d, we now only " "mixed_precision currently not supported place %d, we now only "
"support gpu.", "support gpu and custom device .",
static_cast<int>(backend_))); static_cast<int>(backend_)));
} }
} }
...@@ -72,7 +72,13 @@ void ConvertToMixedPrecisionPass::Run() { ...@@ -72,7 +72,13 @@ void ConvertToMixedPrecisionPass::Run() {
pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)}); pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
pass.Set("mixed_black_list", pass.Set("mixed_black_list",
new std::unordered_set<std::string>{black_list_}); new std::unordered_set<std::string>{black_list_});
if (backend_ == phi::Backend::GPU) {
pass.Set("enable_gpu_mixed", new bool{true}); pass.Set("enable_gpu_mixed", new bool{true});
pass.Set("enable_custom_device_mixed", new bool{false});
} else if (backend_ == phi::Backend::CUSTOM) {
pass.Set("enable_gpu_mixed", new bool{false});
pass.Set("enable_custom_device_mixed", new bool{true});
}
pass.Set("keep_io_types", new bool{keep_io_types_}); pass.Set("keep_io_types", new bool{keep_io_types_});
pass.Apply(main_graph_.get()); pass.Apply(main_graph_.get());
......
...@@ -146,6 +146,8 @@ phi::Backend ConvertBackend(paddle_infer::PlaceType backend) { ...@@ -146,6 +146,8 @@ phi::Backend ConvertBackend(paddle_infer::PlaceType backend) {
return phi::Backend::CPU; return phi::Backend::CPU;
case paddle_infer::PlaceType::kIPU: case paddle_infer::PlaceType::kIPU:
return phi::Backend::IPU; return phi::Backend::IPU;
case paddle_infer::PlaceType::kCUSTOM:
return phi::Backend::CUSTOM;
default: default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument( PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Paddle Inference not support backend, we now only support GPU, XPU, " "Paddle Inference not support backend, we now only support GPU, XPU, "
......
...@@ -59,6 +59,9 @@ enum class Backend : uint8_t { ...@@ -59,6 +59,9 @@ enum class Backend : uint8_t {
// paddle kernel primitives backend // paddle kernel primitives backend
KPS, KPS,
// custom device reference
CUSTOM,
// end of backend types // end of backend types
NUM_BACKENDS, NUM_BACKENDS,
...@@ -207,7 +210,7 @@ inline std::string BackendToString(const Backend& backend) { ...@@ -207,7 +210,7 @@ inline std::string BackendToString(const Backend& backend) {
return "KPS"; return "KPS";
case Backend::IPU: case Backend::IPU:
return "IPU"; return "IPU";
default: default: {
size_t device_type_id_ = static_cast<size_t>(backend) - size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(Backend::NUM_BACKENDS); static_cast<size_t>(Backend::NUM_BACKENDS);
std::string device_type = std::string device_type =
...@@ -220,6 +223,7 @@ inline std::string BackendToString(const Backend& backend) { ...@@ -220,6 +223,7 @@ inline std::string BackendToString(const Backend& backend) {
"Invalid enum backend type `", static_cast<int>(backend), "`."); "Invalid enum backend type `", static_cast<int>(backend), "`.");
} }
} }
}
} }
} // namespace experimental } // namespace experimental
......
...@@ -258,6 +258,41 @@ class TestCustomCPUPlugin(unittest.TestCase): ...@@ -258,6 +258,41 @@ class TestCustomCPUPlugin(unittest.TestCase):
avg_loss.backward() avg_loss.backward()
sgd.step() sgd.step()
def _test_custom_device_mix_precision(self):
import tempfile
import paddle
from paddle.inference import (
PlaceType,
PrecisionType,
convert_to_mixed_precision,
)
from paddle.jit import to_static
from paddle.static import InputSpec
from paddle.vision.models import resnet50
self.temp_dir = tempfile.TemporaryDirectory()
model = resnet50(True)
net = to_static(
model, input_spec=[InputSpec(shape=[None, 3, 224, 224], name='x')]
)
paddle.jit.save(
net, os.path.join(self.temp_dir.name, 'resnet50/inference')
)
convert_to_mixed_precision(
os.path.join(self.temp_dir.name, 'resnet50/inference.pdmodel'),
os.path.join(self.temp_dir.name, 'resnet50/inference.pdiparams'),
os.path.join(
self.temp_dir.name, 'mixed_precision/inference.pdmodel'
),
os.path.join(
self.temp_dir.name, 'mixed_precision/inference.pdiparams'
),
backend=PlaceType.CUSTOM,
mixed_precision=PrecisionType.Half,
)
self.temp_dir.cleanup()
def _test_custom_device_py_api(self): def _test_custom_device_py_api(self):
import paddle import paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册