Unverified | Commit e26f1123 authored by Kexin Zhao, committed by GitHub

Add fp16 mul op support and bind paddle fp16 to numpy fp16 (#9017)

* add fp16 mul op support

* small fix

* fix bug

* small fix

* fix PADDLE_WITH_CUDA compiling issue

* reorg code

* test for pybind

* treat float16 as uint16_t in pybind

* bind np.float16 to paddle float16

* small fix

* clean code

* remove redundancy

* fix mul_op test

* address comments

* small fix

* add is_float16_supported func
Parent: 71400711
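Taken together, the changes below let a NumPy float16 array round-trip through a Paddle Tensor: Tensor.set() accepts the array's uint16 bit pattern, and the buffer protocol reports it back as np.dtype('e') (float16). A minimal sketch of that flow, assuming the usual fluid scope/var/get_tensor idiom of this era (not part of this diff):

    import numpy as np
    import paddle.fluid.core as core

    scope = core.Scope()
    tensor = scope.var("x").get_tensor()
    arr = np.random.random((2, 3)).astype(np.float16)
    # view(np.uint16) reinterprets the float16 bits; no values are converted
    tensor.set(arr.view(np.uint16), core.CPUPlace())
    back = np.array(tensor)  # buffer protocol exposes dtype 'e' == np.float16
    assert back.dtype == np.float16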
@@ -17,11 +17,14 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using framework::OpKernelType;
 using framework::Tensor;
 
-class MulOpShapeInference : public framework::InferShapeBase {
+class MulOp : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext* ctx) const override {
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -122,7 +125,7 @@ or not. But the output only shares the LoD information with input $X$.
   }
 };
 
-class MulOpGrad : public framework::OperatorWithKernel {
+class MulGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -156,10 +159,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(mul, paddle::framework::OperatorWithKernel, ops::MulOpMaker,
-                  ops::MulOpShapeInference,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(mul_grad, ops::MulOpGrad);
+REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulGradOp);
 REGISTER_OP_CPU_KERNEL(
     mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>);
 REGISTER_OP_CPU_KERNEL(
...
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/mul_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    mul, ops::MulKernel<paddle::platform::CUDADeviceContext, float>);
-REGISTER_OP_CUDA_KERNEL(
-    mul_grad, ops::MulGradKernel<paddle::platform::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
+                        ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
+REGISTER_OP_CUDA_KERNEL(mul_grad,
+                        ops::MulGradKernel<plat::CUDADeviceContext, float>);
@@ -48,7 +48,7 @@ class MulKernel : public framework::OpKernel<T> {
     }
     math::matmul<DeviceContext, T>(
         context.template device_context<DeviceContext>(), x_matrix, false,
-        y_matrix, false, 1, z, 0);
+        y_matrix, false, static_cast<T>(1), z, static_cast<T>(0));
     if (z_dim.size() != 2) {
       z->Resize(z_dim);
     }
...
@@ -31,6 +31,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/cond_op.h"
 #include "paddle/fluid/operators/net_op.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
 #include "paddle/fluid/pybind/const_value.h"
@@ -103,12 +104,14 @@ PYBIND11_PLUGIN(core) {
       .def("set", PyCPUTensorSetFromArray<double>)
       .def("set", PyCPUTensorSetFromArray<int64_t>)
      .def("set", PyCPUTensorSetFromArray<bool>)
+      .def("set", PyCPUTensorSetFromArray<uint16_t>)
 #ifdef PADDLE_WITH_CUDA
       .def("set", PyCUDATensorSetFromArray<float>)
       .def("set", PyCUDATensorSetFromArray<int>)
       .def("set", PyCUDATensorSetFromArray<double>)
       .def("set", PyCUDATensorSetFromArray<int64_t>)
       .def("set", PyCUDATensorSetFromArray<bool>)
+      .def("set", PyCUDATensorSetFromArray<uint16_t>)
 #endif
       .def("shape", [](Tensor &self) { return vectorize(self.dims()); })
       .def("set_float_element", TensorSetElement<float>)
@@ -315,7 +318,6 @@ All parameter, weight, gradient are variables in Paddle.
 #endif
       });
 // clang-format on
-
 #ifdef PADDLE_WITH_CUDA
   py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
 #endif
@@ -423,6 +425,12 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("init_devices", &framework::InitDevices);
   m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
+#ifdef PADDLE_WITH_CUDA
+  m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
+    // Only GPUs with Compute Capability >= 53 support float16
+    return platform::GetCUDAComputeCapability(place.device) >= 53;
+  });
+#endif
   m.def("set_feed_variable", framework::SetFeedVariable);
   m.def("get_fetch_variable", framework::GetFetchVariable);
...
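The new is_float16_supported binding is intended as a guard before launching fp16 kernels, since only GPUs with compute capability >= 5.3 (the "53" above) support float16 math. A short usage sketch, mirroring the tests added at the end of this commit:

    import paddle.fluid.core as core

    if core.is_compiled_with_cuda():
        place = core.CUDAPlace(0)
        if core.is_float16_supported(place):
            print("float16 kernels can run on GPU 0")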
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/float16.h"
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
@@ -77,21 +78,32 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
       } else if (paddle::platform::is_cpu_place(tensor.place())) {
         dst_tensor = tensor;
       }
-      return py::buffer_info(dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
-                             py::format_descriptor<CUR_TYPE>::format(),
-                             (size_t)framework::arity(dst_tensor.dims()),
-                             dims_outside, strides);
+
+      if (std::type_index(typeid(CUR_TYPE)) ==
+          std::type_index(typeid(platform::float16))) {
+        return py::buffer_info(dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
+                               "e", /* np.dtype('e') == np.float16 */
+                               (size_t)framework::arity(dst_tensor.dims()),
+                               dims_outside, strides);
+      } else {
+        return py::buffer_info(dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
+                               py::format_descriptor<CUR_TYPE>::format(),
+                               (size_t)framework::arity(dst_tensor.dims()),
+                               dims_outside, strides);
+      }
     } else {
       constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
       return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
     }
   }
 };
 
 }  // namespace details
 
 inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
   auto buffer_info =
-      details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool>()(
-          tensor);
+      details::CastToPyBufferImpl<true, 0, float, int, double, int64_t, bool,
+                                  platform::float16>()(tensor);
   return buffer_info;
 }
@@ -136,6 +148,22 @@ void PyCPUTensorSetFromArray(
   std::memcpy(dst, array.data(), sizeof(T) * array.size());
 }
 
+template <>
+void PyCPUTensorSetFromArray(
+    framework::Tensor &self,
+    py::array_t<uint16_t, py::array::c_style | py::array::forcecast> array,
+    paddle::platform::CPUPlace &place) {
+  std::vector<int64_t> dims;
+  dims.reserve(array.ndim());
+  for (size_t i = 0; i < array.ndim(); ++i) {
+    dims.push_back((int)array.shape()[i]);
+  }
+  self.Resize(framework::make_ddim(dims));
+  auto *dst = self.mutable_data<platform::float16>(place);
+  std::memcpy(dst, array.data(), sizeof(uint16_t) * array.size());
+}
+
 #ifdef PADDLE_WITH_CUDA
 template <typename T>
 void PyCUDATensorSetFromArray(
@@ -157,6 +185,28 @@ void PyCUDATensorSetFromArray(
   paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(),
                                    cudaMemcpyHostToDevice, dev_ctx->stream());
 }
+
+template <>
+void PyCUDATensorSetFromArray(
+    framework::Tensor &self,
+    py::array_t<uint16_t, py::array::c_style | py::array::forcecast> array,
+    paddle::platform::CUDAPlace &place) {
+  std::vector<int64_t> dims;
+  dims.reserve(array.ndim());
+  for (size_t i = 0; i < array.ndim(); ++i) {
+    dims.push_back((int)array.shape()[i]);
+  }
+  self.Resize(framework::make_ddim(dims));
+  auto *dst = self.mutable_data<platform::float16>(place);
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto dev_ctx =
+      static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
+  paddle::platform::GpuMemcpyAsync(dst, array.data(),
+                                   sizeof(uint16_t) * array.size(),
+                                   cudaMemcpyHostToDevice, dev_ctx->stream());
+}
 #endif
 
 }  // namespace pybind
...
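The uint16_t specializations above are safe because np.float16 and np.uint16 share the same 16-bit storage: a NumPy view (not an astype cast) hands the raw bits to memcpy/GpuMemcpyAsync unchanged. A quick demonstration:

    import numpy as np

    x = np.random.random((2, 3)).astype(np.float16)
    bits = x.view(np.uint16)   # reinterpret the same buffer, no conversion
    assert bits.nbytes == x.nbytes
    assert np.array_equal(bits.view(np.float16), x)  # bit-for-bit round trip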
@@ -14,6 +14,7 @@
 
 import unittest
 import numpy as np
+import paddle.fluid.core as core
 from op_test import OpTest
@@ -69,5 +70,42 @@ class TestMulOp2(OpTest):
             ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
 
 
+class TestFP16MulOp1(OpTest):
+    def setUp(self):
+        self.op_type = "mul"
+        x = np.random.random((32, 84)).astype("float16")
+        y = np.random.random((84, 100)).astype("float16")
+        self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)}
+        self.outputs = {'Out': np.dot(x, y)}
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=1e-1)
+
+
+class TestFP16MulOp2(OpTest):
+    def setUp(self):
+        self.op_type = "mul"
+        x = np.random.random((15, 4, 12, 10)).astype("float16")
+        y = np.random.random((4, 30, 8, 2, 9)).astype("float16")
+        self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)}
+        self.attrs = {
+            'x_num_col_dims': 2,
+            'y_num_col_dims': 2,
+        }
+        result = np.dot(
+            x.reshape(15 * 4, 12 * 10), y.reshape(4 * 30, 8 * 2 * 9))
+        result = result.reshape(15, 4, 8, 2, 9)
+        self.outputs = {'Out': result}
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-1)
+
+
 if __name__ == "__main__":
     unittest.main()
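The loose tolerances (atol=1e-1 and 2e-1) reflect float16's limited precision: with a 10-bit mantissa, rounding error accumulates quickly over the 84- and 120-element dot products in these tests. Machine epsilon makes the gap to float32 concrete:

    import numpy as np

    print(np.finfo(np.float16).eps)  # ~9.77e-04 (2**-10)
    print(np.finfo(np.float32).eps)  # ~1.19e-07 (2**-23)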