Unverified · Commit 2238a535 authored by Guoxia Wang, committed by GitHub

remove fp32 tmp tensor and cast op for initializer.Normal and initializer.Constant (#38818)

Parent 04f73d89
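
As context for the diff below, here is a minimal usage sketch (not taken from the commit; paddle 2.x dygraph API assumed) of the path this change affects: initializing a float16 parameter with initializer.Constant or initializer.Normal previously built an fp32 temporary tensor plus a cast op, and now fills the parameter directly in its own dtype.

```python
import paddle
from paddle.nn.initializer import Constant, Normal

# Constant init of an fp16 parameter: after this commit the fill_constant op
# writes the float16 tensor directly (no fp32 tmp, no cast).
w_const = paddle.create_parameter(
    shape=[4, 8], dtype='float16',
    default_initializer=Constant(value=2.3))
print(w_const.dtype)  # paddle.float16

# Normal init in fp16 relies on a float16 gaussian_random kernel; the CUDA
# one is added by this commit, so run this part on a GPU build.
if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    w_norm = paddle.create_parameter(
        shape=[4, 8], dtype='float16',
        default_initializer=Normal(mean=2.3, std=1.9))
    print(w_norm.dtype)  # paddle.float16
```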
@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fill_constant_op.h"
namespace paddle {
@@ -38,10 +39,12 @@ struct GaussianGenerator {
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::normal_distribution<T> dist(mean_, std_);
using MT = typename details::MPTypeTrait<T>::Type;
thrust::normal_distribution<MT> dist(mean_, std_);
unsigned int new_n = n + offset_;
rng.discard(new_n);
return dist(rng);
MT out = dist(rng);
return static_cast<T>(out);
}
};
@@ -124,10 +127,14 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(gaussian_random,
REGISTER_OP_CUDA_KERNEL(
gaussian_random,
paddle::operators::GPUGaussianRandomKernel<paddle::platform::float16>,
paddle::operators::GPUGaussianRandomKernel<float>,
paddle::operators::GPUGaussianRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(
gaussian_random_batch_size_like,
paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<
paddle::platform::float16>,
paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<float>,
paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<double>);
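
The kernel change above samples in the "multi-precision" type of T (float32 when T is float16) and casts the result back to T. A rough NumPy analogue of that behaviour, purely for illustration (this is not the CUDA kernel):

```python
import numpy as np

def gaussian_sample_lowp(shape, mean, std, out_dtype=np.float16, seed=0):
    """Sample in float32 (the MPType of float16), then cast each value back,
    mirroring `static_cast<T>(dist(rng))` in GaussianGenerator."""
    rng = np.random.default_rng(seed)
    samples_fp32 = rng.normal(loc=mean, scale=std, size=shape).astype(np.float32)
    return samples_fp32.astype(out_dtype)

x = gaussian_sample_lowp((2, 3), mean=0.0, std=1.0)
print(x.dtype)  # float16
```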
@@ -137,54 +137,27 @@ class ConstantInitializer(Initializer):
isinstance(var, framework.EagerParamBase))
assert isinstance(block, framework.Block)
# to be compatible of fp16 initializers
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
['constant_init', var.name, 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
if framework.in_dygraph_mode():
out_var = _C_ops.fill_constant(
out_var, 'value',
var = _C_ops.fill_constant(
var, 'value',
float(self._value), 'force_cpu', self._force_cpu, 'dtype',
int(out_dtype), 'str_value',
int(var.dtype), 'str_value',
str(float(self._value)), 'shape', var.shape)
if var.dtype == VarDesc.VarType.FP16:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, False)
return None
else:
# fill constant should set the "str_value" to preserve precision
op = block.append_op(
type="fill_constant",
outputs={"Out": out_var},
outputs={"Out": var},
attrs={
"shape": var.shape,
"dtype": int(out_dtype),
"dtype": int(var.dtype),
"value": float(self._value),
'str_value': str(float(self._value)),
'force_cpu': self._force_cpu
},
stop_gradient=True)
if var.dtype == VarDesc.VarType.FP16:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
var.op = op
return op
@@ -361,38 +334,12 @@ class NormalInitializer(Initializer):
if self._seed == 0:
self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
['gaussian_random', var.name, 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
if framework.in_dygraph_mode():
out_var = _C_ops.gaussian_random(
'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean,
'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype,
'out_dtype', var.dtype)
var.copy_(var_tmp, False)
else:
var.copy_(out_var, False)
return None
else:
op = block.append_op(
type="gaussian_random",
outputs={"Out": out_var},
outputs={"Out": var},
attrs={
"shape": var.shape,
"dtype": out_dtype,
"dtype": var.dtype,
"mean": self._mean,
"std": self._std_dev,
"seed": self._seed,
@@ -400,15 +347,11 @@ class NormalInitializer(Initializer):
},
stop_gradient=True)
if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
if not framework.in_dygraph_mode():
var.op = op
return op
else:
return None
class TruncatedNormalInitializer(Initializer):
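
To see the effect of the initializer.py changes above in static graph mode, here is a short check (mirroring the updated unit tests; paddle 2.x fluid API assumed) that an fp16 parameter initialized with ConstantInitializer now produces a single fill_constant op carrying the parameter's own dtype:

```python
import paddle
from paddle.fluid import framework, initializer

paddle.enable_static()
block = framework.Program().global_block()
block.create_parameter(
    dtype="float16", shape=[5, 10], lod_level=0, name="param",
    initializer=initializer.ConstantInitializer(2.3))
# Previously this produced ['fill_constant', 'cast']; now a single op.
print([op.type for op in block.ops])  # expected: ['fill_constant']
```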
@@ -65,7 +65,7 @@ class TestConstantInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.ConstantInitializer())
num_ops = 2 if dtype == "float16" else 1
num_ops = 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'fill_constant')
@@ -84,7 +84,7 @@ class TestConstantInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.ConstantInitializer(2.3))
num_ops = 2 if dtype == "float16" else 1
num_ops = 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'fill_constant')
@@ -94,10 +94,8 @@ class TestConstantInitializer(unittest.TestCase):
def test_constant_initializer_fp16(self):
"""Test constant initializer with float16
"""
block = self.test_constant_initializer_default_value("float16")
self.assertTrue(check_cast_op(block.ops[1]))
block = self.test_constant_initializer("float16")
self.assertTrue(check_cast_op(block.ops[1]))
self.test_constant_initializer_default_value("float16")
self.test_constant_initializer("float16")
def test_constant_initializer_bf16(self):
"""Test constant initializer with bfloat16
@@ -246,7 +244,7 @@ class TestNormalInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.NormalInitializer(2.3, 1.9, 123))
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
@@ -258,14 +256,12 @@ class TestNormalInitializer(unittest.TestCase):
def test_normal_initializer_fp16(self):
"""Test normal initializer with float16
"""
block = self.test_normal_initializer("float16")
self.assertTrue(check_cast_op(block.ops[1]))
self.test_normal_initializer("float16")
def test_normal_initializer_bf16(self):
"""Test normal initializer with bfloat16
"""
block = self.test_normal_initializer("uint16")
self.assertTrue(check_cast_op(block.ops[1]))
self.test_normal_initializer("uint16")
class TestXavierInitializer(unittest.TestCase):
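
The same single-op behaviour can be checked for NormalInitializer with both float16 and bfloat16 (declared as uint16 in these tests). A sketch along the lines of the test above, with hypothetical parameter names:

```python
import paddle
from paddle.fluid import framework, initializer

paddle.enable_static()
block = framework.Program().global_block()
for dtype in ("float16", "uint16"):  # uint16 is how bfloat16 is declared here
    block.create_parameter(
        dtype=dtype, shape=[5, 10], lod_level=0, name="param_" + dtype,
        initializer=initializer.NormalInitializer(2.3, 1.9, 123))
# One gaussian_random op per parameter, no trailing cast ops.
print([op.type for op in block.ops])  # expected: ['gaussian_random', 'gaussian_random']
```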
@@ -54,7 +54,7 @@ class TestConstantInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=init_inst)
num_ops = 2 if dtype in ["float16"] else 1
num_ops = 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'fill_constant')
@@ -103,9 +103,7 @@ class TestConstantInitializer(unittest.TestCase):
"""Test constant initializer with float16
"""
block = self.test_constant_initializer_default_value_static("float16")
self.assertTrue(check_cast_op(block.ops[1]))
block = self.test_constant_initializer_static("float16")
self.assertTrue(check_cast_op(block.ops[1]))
self.test_constant_initializer_default_value_dygraph("float16")
self.test_constant_initializer_dygraph("float16")
@@ -402,7 +400,7 @@ class TestNormal(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.Normal(2.3, 1.9))
num_ops = 2 if dtype in ["float16", "uint16"] else 1
num_ops = 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
@@ -417,13 +415,11 @@ class TestNormal(unittest.TestCase):
"""Test normal initializer with float16
"""
block = self.test_normal_initializer("float16")
self.assertTrue(check_cast_op(block.ops[1]))
def test_normal_initializer_bf16(self):
"""Test normal initializer with bfloat16
"""
block = self.test_normal_initializer("uint16") #bfloat16
self.assertTrue(check_cast_op(block.ops[1]))
def test_normal_initializer_dygraph(self):
"""Test normal initializer in dygraph model.
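
Finally, a dygraph-side sketch for the paddle.nn.initializer entry points covered by this last test file (assumed usage, not taken from the tests): building an fp16 layer whose weight uses nn.initializer.Constant now fills the float16 weight in place, with no fp32 temporary.

```python
import paddle
import paddle.nn as nn

paddle.disable_static()
paddle.set_default_dtype('float16')  # make the Linear parameters float16

linear = nn.Linear(
    2, 4,
    weight_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(value=2.0)))
print(linear.weight.dtype)          # paddle.float16
print(linear.weight.numpy()[0][0])  # 2.0, written directly as fp16
```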