diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index a6ed2a30a08e3c58bd565bf942fd0dfdf54a4d3a..85e9424f2eae97232511f2186ca740f7fe509678 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -193,8 +193,16 @@ def create_global_var(shape,
                     numpy.int64), 'create_global_var')
 
     check_dtype(dtype, 'dtype', [
-        'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32',
-        'int64', 'uint8'
+        'bool',
+        'float16',
+        'float32',
+        'float64',
+        'int8',
+        'int16',
+        'int32',
+        'int64',
+        'uint8',
+        'uint16',
     ], 'create_global_var')
 
     helper = LayerHelper("global_var", **locals())
diff --git a/python/paddle/fluid/tests/unittests/test_sgd_op_bf16.py b/python/paddle/fluid/tests/unittests/test_sgd_op_bf16.py
index e60b04257dbbd931385f3ab2989698e6c8df2ab7..8b65fc4e431f916be6beadb44d6f29edacf76982 100644
--- a/python/paddle/fluid/tests/unittests/test_sgd_op_bf16.py
+++ b/python/paddle/fluid/tests/unittests/test_sgd_op_bf16.py
@@ -20,8 +20,10 @@ import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 from paddle.fluid.tests.unittests.op_test import (
-    OpTest, convert_float_to_uint16, convert_uint16_to_float)
+    convert_float_to_uint16, convert_uint16_to_float, OpTest, OpTestTool)
 import paddle
+import paddle.static.amp as amp
+import struct
 
 
 @unittest.skipIf(not core.supports_bfloat16(),
@@ -209,6 +211,152 @@ class TestSparseGradParamSGDOpBF16Case2(TestSparseGradParamSGDOpBF16):
         self.param_rows = [a for a in range(self.grad_height)]
 
 
+@OpTestTool.skip_if_not_cpu_bf16()
+class TestSGDOpBF16API(unittest.TestCase):
+    def setUp(self):
+        self.sample_count = 20
+        self.value = np.random.random()
+
+        self.ids_shape = (32, 1)
+        self.w_shape = (64, 16)
+        self.y_shape = (32, 16)
+        self.learning_rate = 0.1
+
+        np.random.seed(12345)
+        self._set_initializer()
+        fluid.set_flags({'FLAGS_use_mkldnn': True})
+
+    def _fp322bf16(self, val: np.float32):
+        return np.uint16(struct.unpack('<I', struct.pack('<f', val))[0] >> 16)
+
+    def _bf162fp32(self, val: np.uint16):
+        return np.float32(struct.unpack('<f', struct.pack('<I', np.uint32(val) << 16))[0])
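
Note on the first hunk: Paddle stores bfloat16 tensors with a uint16 dtype (the same convention as convert_float_to_uint16), so whitelisting 'uint16' is what lets bf16 globals pass check_dtype. A minimal usage sketch, assuming the patch is applied and a Paddle 2.x static-mode session; the variable name is illustrative, not part of the patch:

import paddle
import paddle.fluid as fluid

paddle.enable_static()
# 'uint16' is the storage dtype Paddle uses for bf16 values, so this
# no longer trips the check_dtype call patched above.
w_bf16 = fluid.layers.create_global_var(
    shape=[64, 16], value=0.0, dtype='uint16', persistable=True, name='w_bf16')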
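
Note on the new _fp322bf16/_bf162fp32 test helpers: bfloat16 is simply the top 16 bits of an IEEE-754 float32 (same sign and exponent, mantissa truncated to 7 explicit bits), so each conversion is a pure pack/unpack plus a 16-bit shift. A standalone sketch of the same trick, runnable without Paddle; the function names here are illustrative:

import struct

import numpy as np


def fp32_to_bf16(val):
    # Reinterpret the float32 bit pattern as a little-endian uint32 and
    # keep the high 16 bits: sign, 8 exponent bits, 7 mantissa bits.
    return np.uint16(struct.unpack('<I', struct.pack('<f', val))[0] >> 16)


def bf16_to_fp32(val):
    # Shift the 16 stored bits back into the high half of a uint32 and
    # reinterpret that pattern as float32; the low mantissa bits are zero.
    return np.float32(struct.unpack('<f', struct.pack('<I', np.uint32(val) << 16))[0])


x = np.float32(1.2345678)
print(bf16_to_fp32(fp32_to_bf16(x)))  # 1.234375 -- the round trip truncates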