未验证 提交 bfb79ee2 编写于 作者: I Infinity_lee 提交者: GitHub

【Hackathon No58】-fix set_value (#51197)

上级 2404847d
......@@ -29,5 +29,6 @@ PD_REGISTER_KERNEL(set_value_grad,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -29,6 +29,7 @@ PD_REGISTER_KERNEL(set_value,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(set_value_with_tensor,
......@@ -41,5 +42,6 @@ PD_REGISTER_KERNEL(set_value_with_tensor,
int64_t,
bool,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -16,6 +16,7 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
......
......@@ -18,8 +18,10 @@ import unittest
from functools import reduce
import numpy as np
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.fluid.core as core
from paddle.fluid.layer_helper import LayerHelper
......@@ -1521,5 +1523,31 @@ class TestSetValueInplaceLeafVar(unittest.TestCase):
paddle.enable_static()
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestSetValueBFloat16(OpTest):
def setUp(self):
self.dtype = np.uint16
self.shape = [2, 3, 4]
self.__class__.op_type = self.op_type
self.data = np.ones(self.shape).astype(self.dtype)
x = np.random.rand([6]).astype('float32')
self.data[0, 0] = np.random.rand([6]).astype('float32')
out = self.data[0, 0]
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': convert_float_to_uint16(out)}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册