diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index e7bed2c39735b66c19e738c91f4977e46571143b..90af1e2d602ac039b4d98a69a889ff8b1b85ffc6 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -17,11 +17,14 @@ limitations under the License. */ namespace paddle { namespace operators { +using framework::OpKernelType; using framework::Tensor; -class MulOpShapeInference : public framework::InferShapeBase { +class MulOp : public framework::OperatorWithKernel { public: - void operator()(framework::InferShapeContext* ctx) const override { + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -122,7 +125,7 @@ or not. But the output only shares the LoD information with input $X$. } }; -class MulOpGrad : public framework::OperatorWithKernel { +class MulGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -156,10 +159,7 @@ class MulOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(mul, paddle::framework::OperatorWithKernel, ops::MulOpMaker, - ops::MulOpShapeInference, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(mul_grad, ops::MulOpGrad); +REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulGradOp); REGISTER_OP_CPU_KERNEL( mul, ops::MulKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc index 0667530e943856576ae8c9fe4856cb6aa1448e4e..757f9c3ee2665c7ac654659416fe8dd727dca16d 100644 --- a/paddle/fluid/operators/mul_op.cu.cc +++ b/paddle/fluid/operators/mul_op.cu.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/mul_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - mul, ops::MulKernel); -REGISTER_OP_CUDA_KERNEL( - mul_grad, ops::MulGradKernel); +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel, + ops::MulKernel); +REGISTER_OP_CUDA_KERNEL(mul_grad, + ops::MulGradKernel); diff --git a/paddle/fluid/operators/mul_op.h b/paddle/fluid/operators/mul_op.h index 38311cf87265ad0f1f815734cbf69bd682d62e62..b1260d36ebe11f65529ac274c959479dcb38ee5f 100644 --- a/paddle/fluid/operators/mul_op.h +++ b/paddle/fluid/operators/mul_op.h @@ -48,7 +48,7 @@ class MulKernel : public framework::OpKernel { } math::matmul( context.template device_context(), x_matrix, false, - y_matrix, false, 1, z, 0); + y_matrix, false, static_cast(1), z, static_cast(0)); if (z_dim.size() != 2) { z->Resize(z_dim); } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index d2e883caccdd34a9d662f06b83cf9a71d3d4a51e..6c05442466f5f3d8e04a8f0a2206443b1007a107 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -31,6 +31,7 @@ limitations under the License. */ #include "paddle/fluid/operators/cond_op.h" #include "paddle/fluid/operators/net_op.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/pybind/const_value.h" @@ -103,12 +104,14 @@ PYBIND11_PLUGIN(core) { .def("set", PyCPUTensorSetFromArray) .def("set", PyCPUTensorSetFromArray) .def("set", PyCPUTensorSetFromArray) + .def("set", PyCPUTensorSetFromArray) #ifdef PADDLE_WITH_CUDA .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) + .def("set", PyCUDATensorSetFromArray) #endif .def("shape", [](Tensor &self) { return vectorize(self.dims()); }) .def("set_float_element", TensorSetElement) @@ -315,7 +318,6 @@ All parameter, weight, gradient are variables in Paddle. #endif }); // clang-format on - #ifdef PADDLE_WITH_CUDA py::class_(m, "Communicator").def(py::init<>()); #endif @@ -423,6 +425,12 @@ All parameter, weight, gradient are variables in Paddle. m.def("init_devices", &framework::InitDevices); m.def("is_compiled_with_cuda", IsCompiledWithCUDA); +#ifdef PADDLE_WITH_CUDA + m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool { + // Only GPUs with Compute Capability >= 53 support float16 + return platform::GetCUDAComputeCapability(place.device) >= 53; + }); +#endif m.def("set_feed_variable", framework::SetFeedVariable); m.def("get_fetch_variable", framework::GetFetchVariable); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 1b0916ea0370d95a0c7dd149ee3f7b294c5e2351..3b206f2f87abe01363fb7e61c319559c6dd24594 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/float16.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" @@ -77,21 +78,32 @@ struct CastToPyBufferImpl { } else if (paddle::platform::is_cpu_place(tensor.place())) { dst_tensor = tensor; } - return py::buffer_info(dst_tensor.data(), sizeof(CUR_TYPE), - py::format_descriptor::format(), - (size_t)framework::arity(dst_tensor.dims()), - dims_outside, strides); + + if (std::type_index(typeid(CUR_TYPE)) == + std::type_index(typeid(platform::float16))) { + return py::buffer_info(dst_tensor.data(), sizeof(CUR_TYPE), + "e", /* np.dtype('e') == np.float16 */ + (size_t)framework::arity(dst_tensor.dims()), + dims_outside, strides); + } else { + return py::buffer_info(dst_tensor.data(), sizeof(CUR_TYPE), + py::format_descriptor::format(), + (size_t)framework::arity(dst_tensor.dims()), + dims_outside, strides); + } } else { constexpr bool less = I + 1 < std::tuple_size>::value; return CastToPyBufferImpl()(tensor); } } }; + } // namespace details + inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { auto buffer_info = - details::CastToPyBufferImpl()( - tensor); + details::CastToPyBufferImpl()(tensor); return buffer_info; } @@ -136,6 +148,22 @@ void PyCPUTensorSetFromArray( std::memcpy(dst, array.data(), sizeof(T) * array.size()); } +template <> +void PyCPUTensorSetFromArray( + framework::Tensor &self, + py::array_t array, + paddle::platform::CPUPlace &place) { + std::vector dims; + dims.reserve(array.ndim()); + for (size_t i = 0; i < array.ndim(); ++i) { + dims.push_back((int)array.shape()[i]); + } + + self.Resize(framework::make_ddim(dims)); + auto *dst = self.mutable_data(place); + std::memcpy(dst, array.data(), sizeof(uint16_t) * array.size()); +} + #ifdef PADDLE_WITH_CUDA template void PyCUDATensorSetFromArray( @@ -157,6 +185,28 @@ void PyCUDATensorSetFromArray( paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice, dev_ctx->stream()); } + +template <> +void PyCUDATensorSetFromArray( + framework::Tensor &self, + py::array_t array, + paddle::platform::CUDAPlace &place) { + std::vector dims; + dims.reserve(array.ndim()); + for (size_t i = 0; i < array.ndim(); ++i) { + dims.push_back((int)array.shape()[i]); + } + + self.Resize(framework::make_ddim(dims)); + auto *dst = self.mutable_data(place); + + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto dev_ctx = + static_cast(pool.Get(place)); + paddle::platform::GpuMemcpyAsync(dst, array.data(), + sizeof(uint16_t) * array.size(), + cudaMemcpyHostToDevice, dev_ctx->stream()); +} #endif } // namespace pybind diff --git a/python/paddle/fluid/tests/unittests/test_mul_op.py b/python/paddle/fluid/tests/unittests/test_mul_op.py index 9d1da420c7f70bd2a89d183a5f0a2b145f0ff475..40440bea1267112b84b66002a0bf921be3029265 100644 --- a/python/paddle/fluid/tests/unittests/test_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_mul_op.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import paddle.fluid.core as core from op_test import OpTest @@ -69,5 +70,42 @@ class TestMulOp2(OpTest): ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y')) +class TestFP16MulOp1(OpTest): + def setUp(self): + self.op_type = "mul" + x = np.random.random((32, 84)).astype("float16") + y = np.random.random((84, 100)).astype("float16") + self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)} + self.outputs = {'Out': np.dot(x, y)} + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-1) + + +class TestFP16MulOp2(OpTest): + def setUp(self): + self.op_type = "mul" + x = np.random.random((15, 4, 12, 10)).astype("float16") + y = np.random.random((4, 30, 8, 2, 9)).astype("float16") + self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)} + self.attrs = { + 'x_num_col_dims': 2, + 'y_num_col_dims': 2, + } + result = np.dot( + x.reshape(15 * 4, 12 * 10), y.reshape(4 * 30, 8 * 2 * 9)) + result = result.reshape(15, 4, 8, 2, 9) + self.outputs = {'Out': result} + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=2e-1) + + if __name__ == "__main__": unittest.main()