未验证 提交 bfb79ee2 编写于 作者: I Infinity_lee 提交者: GitHub

【Hackathon No58】-fix set_value (#51197)

上级 2404847d
...@@ -29,5 +29,6 @@ PD_REGISTER_KERNEL(set_value_grad, ...@@ -29,5 +29,6 @@ PD_REGISTER_KERNEL(set_value_grad,
int64_t, int64_t,
bool, bool,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
...@@ -29,6 +29,7 @@ PD_REGISTER_KERNEL(set_value, ...@@ -29,6 +29,7 @@ PD_REGISTER_KERNEL(set_value,
int64_t, int64_t,
bool, bool,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(set_value_with_tensor, PD_REGISTER_KERNEL(set_value_with_tensor,
...@@ -41,5 +42,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor, ...@@ -41,5 +42,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
int64_t, int64_t,
bool, bool,
phi::dtype::float16, phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>, phi::dtype::complex<float>,
phi::dtype::complex<double>) {} phi::dtype::complex<double>) {}
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi { namespace phi {
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
namespace phi { namespace phi {
......
...@@ -18,8 +18,10 @@ import unittest ...@@ -18,8 +18,10 @@ import unittest
from functools import reduce from functools import reduce
import numpy as np import numpy as np
from op_test import OpTest, convert_float_to_uint16
import paddle import paddle
import paddle.fluid.core as core
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
...@@ -1521,5 +1523,31 @@ class TestSetValueInplaceLeafVar(unittest.TestCase): ...@@ -1521,5 +1523,31 @@ class TestSetValueInplaceLeafVar(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestSetValueBFloat16(OpTest):
    """Exercise the CUDA bfloat16 kernel of the `set_value` op via OpTest.

    bfloat16 tensors are carried as uint16 arrays on the OpTest side, so the
    float32 reference data is converted with `convert_float_to_uint16`.
    """

    def setUp(self):
        # uint16 is the OpTest-side representation of bfloat16.
        self.dtype = np.uint16
        self.shape = [2, 3, 4]
        # `op_type` must be assigned before mirroring it onto the class; the
        # original read `self.op_type` without ever setting it (AttributeError).
        self.op_type = 'set_value'
        self.__class__.op_type = self.op_type

        # Reference computation in float32: overwrite the x[0, 0] row.
        # The original called `np.random.rand([6])`, which raises TypeError
        # (rand takes separate int dims, not a list), and a (6,) array could
        # not be assigned to the (4,)-shaped slot `data[0, 0]` anyway.
        self.data = np.ones(self.shape).astype('float32')
        value = np.random.rand(4).astype('float32')
        x = self.data.copy()
        self.data[0, 0] = value
        out = self.data

        # NOTE(review): the attributes describing the slice (axes/starts/ends/
        # steps and the values written) are not visible in this view; they must
        # match the registered `set_value` op schema — confirm against the op
        # definition before relying on this test.
        self.inputs = {'X': convert_float_to_uint16(x)}
        self.outputs = {'Out': convert_float_to_uint16(out)}

    def test_check_output(self):
        # bfloat16 is only exercised on CUDA; the skipIf above guarantees a
        # usable device here.
        place = core.CUDAPlace(0)
        self.check_output_with_place(place)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(place, ['X'], 'Out')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册