未验证 提交 73bf9673 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet exe] supprot fp16 feed and fetch on cpp side (#39758)

上级 68631ed4
...@@ -52,6 +52,8 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data, ...@@ -52,6 +52,8 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
input_tensor_ptr = input_tensor->mutable_data<float>(dims, place); input_tensor_ptr = input_tensor->mutable_data<float>(dims, place);
} else if (input_data.dtype == DistModelDataType::INT32) { } else if (input_data.dtype == DistModelDataType::INT32) {
input_tensor_ptr = input_tensor->mutable_data<int32_t>(dims, place); input_tensor_ptr = input_tensor->mutable_data<int32_t>(dims, place);
} else if (input_data.dtype == DistModelDataType::FLOAT16) {
input_tensor_ptr = input_tensor->mutable_data<float16>(dims, place);
} else { } else {
LOG(ERROR) << "unsupported feed type " << input_data.dtype; LOG(ERROR) << "unsupported feed type " << input_data.dtype;
return false; return false;
...@@ -412,6 +414,8 @@ bool DistModel::PrepareFeedAndFetch() { ...@@ -412,6 +414,8 @@ bool DistModel::PrepareFeedAndFetch() {
feeds_to_dtype_.insert({var_name, DistModelDataType::INT32}); feeds_to_dtype_.insert({var_name, DistModelDataType::INT32});
} else if (real_var->GetDataType() == framework::proto::VarType::INT64) { } else if (real_var->GetDataType() == framework::proto::VarType::INT64) {
feeds_to_dtype_.insert({var_name, DistModelDataType::INT64}); feeds_to_dtype_.insert({var_name, DistModelDataType::INT64});
} else if (real_var->GetDataType() == framework::proto::VarType::FP16) {
feeds_to_dtype_.insert({var_name, DistModelDataType::FLOAT16});
} else { } else {
LOG(ERROR) << "Don't support feed var dtype for: " LOG(ERROR) << "Don't support feed var dtype for: "
<< real_var->GetDataType(); << real_var->GetDataType();
...@@ -503,9 +507,13 @@ bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data, ...@@ -503,9 +507,13 @@ bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data,
} else if (type == framework::proto::VarType::INT32) { } else if (type == framework::proto::VarType::INT32) {
rst = FetchResult<int32_t>(fetch, output); rst = FetchResult<int32_t>(fetch, output);
output->dtype = DistModelDataType::INT32; output->dtype = DistModelDataType::INT32;
} else if (type == framework::proto::VarType::FP16) {
rst = FetchResult<float16>(fetch, output);
output->dtype = DistModelDataType::FLOAT16;
} else { } else {
LOG(ERROR) << "DistModel meets unknown fetch data type. DistModel only " LOG(ERROR) << "DistModel meets unknown fetch data type. DistModel only "
"supports float32, int64 and int32 fetch type for now."; "supports float32, float16, int64 and int32 fetch type "
"for now.";
} }
if (!rst) { if (!rst) {
LOG(ERROR) << "DistModel fails to fetch result " << idx_to_fetches_[idx]; LOG(ERROR) << "DistModel fails to fetch result " << idx_to_fetches_[idx];
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
namespace paddle { namespace paddle {
...@@ -40,6 +41,11 @@ constexpr DistModelDataType DistModelGetDtype<float>() { ...@@ -40,6 +41,11 @@ constexpr DistModelDataType DistModelGetDtype<float>() {
return DistModelDataType::FLOAT32; return DistModelDataType::FLOAT32;
} }
template <>
constexpr DistModelDataType DistModelGetDtype<platform::float16>() {
return DistModelDataType::FLOAT16;
}
class DistModelDataBuf { class DistModelDataBuf {
public: public:
explicit DistModelDataBuf(size_t length) explicit DistModelDataBuf(size_t length)
......
...@@ -24,10 +24,41 @@ ...@@ -24,10 +24,41 @@
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "pybind11/pybind11.h"
namespace py = pybind11; namespace py = pybind11;
namespace pybind11 {
namespace detail {
// Note: use same enum number of float16 in numpy.
// import numpy as np
// print np.dtype(np.float16).num # 23
constexpr int NPY_FLOAT16_ = 23;
// Note: Since float16 is not a builtin type in C++, we register
// paddle::platform::float16 as numpy.float16.
// Ref: https://github.com/pybind/pybind11/issues/1776
template <>
struct npy_format_descriptor<paddle::platform::float16> {
static py::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_);
return reinterpret_borrow<py::dtype>(ptr);
}
static std::string format() {
// Note: "e" represents float16.
// Details at:
// https://docs.python.org/3/library/struct.html#format-characters.
return "e";
}
static constexpr auto name = _("float16");
};
} // namespace detail
} // namespace pybind11
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -175,6 +206,7 @@ void BindFleetExecutor(py::module* m) { ...@@ -175,6 +206,7 @@ void BindFleetExecutor(py::module* m) {
.def(py::init(&DistModelDataBufCreate<int32_t>)) .def(py::init(&DistModelDataBufCreate<int32_t>))
.def(py::init(&DistModelDataBufCreate<int64_t>)) .def(py::init(&DistModelDataBufCreate<int64_t>))
.def(py::init(&DistModelDataBufCreate<float>)) .def(py::init(&DistModelDataBufCreate<float>))
.def(py::init(&DistModelDataBufCreate<paddle::platform::float16>))
.def("reset", .def("reset",
[](DistModelDataBuf& self, std::vector<float>& data) { [](DistModelDataBuf& self, std::vector<float>& data) {
self.Resize(data.size() * sizeof(float)); self.Resize(data.size() * sizeof(float));
...@@ -183,29 +215,35 @@ void BindFleetExecutor(py::module* m) { ...@@ -183,29 +215,35 @@ void BindFleetExecutor(py::module* m) {
.def("reset", &DistModelDataBufReset<int32_t>) .def("reset", &DistModelDataBufReset<int32_t>)
.def("reset", &DistModelDataBufReset<int64_t>) .def("reset", &DistModelDataBufReset<int64_t>)
.def("reset", &DistModelDataBufReset<float>) .def("reset", &DistModelDataBufReset<float>)
.def("reset", &DistModelDataBufReset<paddle::platform::float16>)
.def("length", &DistModelDataBuf::length) .def("length", &DistModelDataBuf::length)
.def("tolist", .def("tolist", [](DistModelDataBuf& self,
[](DistModelDataBuf& self, const std::string& dtype) -> py::list { const std::string& dtype) -> py::list {
py::list l; py::list l;
if (dtype == "int32") { if (dtype == "int32") {
auto* data = static_cast<int32_t*>(self.data()); auto* data = static_cast<int32_t*>(self.data());
auto size = self.length() / sizeof(int32_t); auto size = self.length() / sizeof(int32_t);
l = py::cast(std::vector<int32_t>(data, data + size)); l = py::cast(std::vector<int32_t>(data, data + size));
} else if (dtype == "int64") { } else if (dtype == "int64") {
auto* data = static_cast<int64_t*>(self.data()); auto* data = static_cast<int64_t*>(self.data());
auto size = self.length() / sizeof(int64_t); auto size = self.length() / sizeof(int64_t);
l = py::cast(std::vector<int64_t>(data, data + size)); l = py::cast(std::vector<int64_t>(data, data + size));
} else if (dtype == "float32") { } else if (dtype == "float32") {
auto* data = static_cast<float*>(self.data()); auto* data = static_cast<float*>(self.data());
auto size = self.length() / sizeof(float); auto size = self.length() / sizeof(float);
l = py::cast(std::vector<float>(data, data + size)); l = py::cast(std::vector<float>(data, data + size));
} else { } else if (dtype == "float16") {
PADDLE_THROW(platform::errors::Unimplemented( auto* data = static_cast<paddle::platform::float16*>(self.data());
"Unsupported data type. Now only supports INT32, INT64 and " auto size = self.length() / sizeof(paddle::platform::float16);
"FLOAT32.")); l = py::cast(
} std::vector<paddle::platform::float16>(data, data + size));
return l; } else {
}); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64, "
"FLOAT16 and FLOAT32."));
}
return l;
});
py::class_<DistModelTensor>(*m, "DistModelTensor") py::class_<DistModelTensor>(*m, "DistModelTensor")
.def(py::init<>()) .def(py::init<>())
...@@ -221,6 +259,10 @@ void BindFleetExecutor(py::module* m) { ...@@ -221,6 +259,10 @@ void BindFleetExecutor(py::module* m) {
py::arg("name") = "", py::arg("name") = "",
py::arg("lod") = std::vector<std::vector<size_t>>(), py::arg("lod") = std::vector<std::vector<size_t>>(),
py::arg("copy") = true) py::arg("copy") = true)
.def(py::init(&DistModelTensorCreate<paddle::platform::float16>),
py::arg("data"), py::arg("name") = "",
py::arg("lod") = std::vector<std::vector<size_t>>(),
py::arg("copy") = true)
.def_readwrite("name", &DistModelTensor::name) .def_readwrite("name", &DistModelTensor::name)
.def_readwrite("shape", &DistModelTensor::shape) .def_readwrite("shape", &DistModelTensor::shape)
.def_readwrite("data", &DistModelTensor::data) .def_readwrite("data", &DistModelTensor::data)
...@@ -231,7 +273,8 @@ void BindFleetExecutor(py::module* m) { ...@@ -231,7 +273,8 @@ void BindFleetExecutor(py::module* m) {
py::enum_<DistModelDataType>(*m, "DistModelDataType") py::enum_<DistModelDataType>(*m, "DistModelDataType")
.value("FLOAT32", DistModelDataType::FLOAT32) .value("FLOAT32", DistModelDataType::FLOAT32)
.value("INT64", DistModelDataType::INT64) .value("INT64", DistModelDataType::INT64)
.value("INT32", DistModelDataType::INT32); .value("INT32", DistModelDataType::INT32)
.value("FLOAT16", DistModelDataType::FLOAT16);
} }
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
...@@ -58,6 +58,19 @@ class TestDistModelTensor(unittest.TestCase): ...@@ -58,6 +58,19 @@ class TestDistModelTensor(unittest.TestCase):
self.assertEqual(dist_tensor_float.as_ndarray().ravel().tolist(), self.assertEqual(dist_tensor_float.as_ndarray().ravel().tolist(),
tensor_float.ravel().tolist()) tensor_float.ravel().tolist())
tensor_float_16 = np.random.randn(20, 2).astype('float16')
dist_tensor_float_16 = DistModelTensor(tensor_float_16,
'float_tensor_16')
self.assertEqual(dist_tensor_float_16.dtype, DistModelDataType.FLOAT16)
self.assertEqual(
dist_tensor_float_16.data.tolist('float16'),
tensor_float_16.ravel().tolist())
self.assertEqual(dist_tensor_float_16.data.length(), 40 * 2)
self.assertEqual(dist_tensor_float_16.name, 'float_tensor_16')
dist_tensor_float_16.data.reset(tensor_float_16)
self.assertEqual(dist_tensor_float_16.as_ndarray().ravel().tolist(),
tensor_float_16.ravel().tolist())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册