Unverified commit 86a23818, authored by PuQing, committed by GitHub

[Numpy] Add FP16 dtype for CastNumpy2Scalar (#50002)

* add FP16 dtype for CastNumpy2Scalar

* fix throw message

* add test

* fix SyntaxWarning

* test skip for float16

* fix dtype mistakes
Parent 96a0ce60
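A minimal sketch of what this change enables in eager mode, adapted from the test added below (the float16 path assumes a CUDA build of Paddle, as the added test itself does):

import numpy as np
import paddle

paddle.disable_static()

# Before this commit, a numpy.float16 scalar hit the InvalidArgument branch of
# CastNumpy2Scalar; with the new branch it is accepted and the result stays FP16.
a = paddle.full([4, 5, 6], 1.5, dtype='float16')
b = np.array([1.5], dtype='float16')[0]   # a numpy.float16 scalar
c = a + b
print(c.dtype)                            # the added test asserts VarType.FP16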
@@ -1343,6 +1343,9 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
   } else if (type_name == "numpy.float32") {
     float value = CastPyArg2Float(obj, op_type, arg_pos);
     return paddle::experimental::Scalar(value);
+  } else if (type_name == "numpy.float16") {
+    float16 value = CastPyArg2Float16(obj, op_type, arg_pos);
+    return paddle::experimental::Scalar(value);
   } else if (type_name == "numpy.int64") {
     int64_t value = CastPyArg2Long(obj, op_type, arg_pos);
     return paddle::experimental::Scalar(value);
@@ -1352,7 +1355,7 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "%s(): argument (position %d) must be "
-        "numpy.float32/float64, numpy.int32/int64, but got %s",
+        "numpy.float16/float32/float64, numpy.int32/int64, but got %s",
         op_type,
         arg_pos + 1,
         type_name));  // NOLINT
......
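The branch added above dispatches on the fully qualified type name of the numpy scalar; the helper that produces type_name is outside this hunk, but the string the new branch matches can be reproduced from plain numpy (illustration only):

import numpy as np

b = np.array([1.5], dtype='float16')[0]
qualified = type(b).__module__ + '.' + type(b).__name__
print(qualified)   # 'numpy.float16' -- the name matched by the new branch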
@@ -184,6 +184,12 @@ void CastPyArg2AttrLong(PyObject* obj,
   attrs[key] = CastPyArg2Long(obj, op_type, arg_pos);
 }
 
+float16 CastPyArg2Float16(PyObject* obj,
+                          const std::string& op_type,
+                          ssize_t arg_pos) {
+  return static_cast<float16>(CastPyArg2Double(obj, op_type, arg_pos));
+}
+
 float CastPyArg2Float(PyObject* obj,
                       const std::string& op_type,
                       ssize_t arg_pos) {
......
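CastPyArg2Float16 above is intentionally thin: it parses the Python object as a double and narrows it with static_cast, so the value is rounded to the nearest representable half-precision number. A rough numpy analogue of that narrowing, for illustration only:

import numpy as np

print(float(np.float16(np.float64(1.5))))   # 1.5  (exactly representable in fp16)
print(float(np.float16(np.float64(0.1))))   # 0.0999755859375  (nearest fp16 value)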
@@ -55,6 +55,9 @@ int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos);
 int64_t CastPyArg2Long(PyObject* obj,
                        const std::string& op_type,
                        ssize_t arg_pos);
+float16 CastPyArg2Float16(PyObject* obj,
+                          const std::string& op_type,
+                          ssize_t arg_pos);
 float CastPyArg2Float(PyObject* obj,
                       const std::string& op_type,
                       ssize_t arg_pos);
......
@@ -737,6 +737,24 @@ class TestElementwiseAddop1(unittest.TestCase):
         paddle.enable_static()
 
 
+class TestTensorAddNumpyScalar(unittest.TestCase):
+    def test_float32_add(self):
+        paddle.disable_static()
+        a = paddle.full([4, 5, 6], 1.5, dtype='float32')
+        b = np.array([1.5], dtype='float32')[0]
+        c = a + b
+        self.assertTrue(c.dtype == core.VarDesc.VarType.FP32)
+
+    def test_float16_add(self):
+        if not core.is_compiled_with_cuda():
+            return
+        paddle.disable_static()
+        a = paddle.full([4, 5, 6], 1.5, dtype='float16')
+        b = np.array([1.5], dtype='float16')[0]
+        c = a + b
+        self.assertTrue(c.dtype == core.VarDesc.VarType.FP16)
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
@@ -516,12 +516,12 @@ class SingleProcessMultiThread(GradAllReduce):
     def _transpile_main_program(self):
         # not need loss scale and no dense param
         param_cnt = self._get_update_param_count()
-        if self.loss_scale is 0 and param_cnt is 0:
+        if self.loss_scale == 0 and param_cnt == 0:
             return
         # scale loss
         self._insert_scale_loss_grad_ops()
         # no param
-        if param_cnt is 0:
+        if param_cnt == 0:
             return
         # fuse allreduce
         if self.fuse_allreduce > 0:
......
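The last hunk replaces identity tests ("is 0") with equality tests ("== 0"), which is the "fix SyntaxWarning" item from the commit message: CPython 3.8+ warns about "is" with a literal, and identity is the wrong check here in any case. A small illustration with made-up values, not Paddle code:

loss_scale = 0.0
print(loss_scale == 0)   # True  -- numeric equality holds across int/float
print(loss_scale is 0)   # False -- different objects; CPython 3.8+ also emits
                         #          SyntaxWarning: "is" with a literal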