未验证 提交 9396f286 编写于 作者: A Aganlengzi 提交者: GitHub

[NPU] fix fill_constant and test_memcpy_op_npu (#37144)

上级 1773afd7
...@@ -87,6 +87,9 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -87,6 +87,9 @@ class FillConstantOp : public framework::OperatorWithKernel {
case 3: case 3:
kt.place_ = platform::XPUPlace(); kt.place_ = platform::XPUPlace();
break; break;
case 4:
kt.place_ = platform::NPUPlace();
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Could NOT determine the place of variable, place_type = %d .", "Could NOT determine the place of variable, place_type = %d .",
...@@ -164,7 +167,8 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -164,7 +167,8 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
"0: CPUPlace. " "0: CPUPlace. "
"1: CUDAPlace. " "1: CUDAPlace. "
"2: CUDAPinnedPlace. " "2: CUDAPinnedPlace. "
"3: XPUPlace. ") "3: XPUPlace. "
"4: NPUPlace. ")
.SetDefault(-1); .SetDefault(-1);
AddOutput("Out", AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled " "(Tensor) Tensor of specified shape will be filled "
......
...@@ -144,5 +144,26 @@ class TestFillConstantBool(OpTest): ...@@ -144,5 +144,26 @@ class TestFillConstantBool(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
class TestFillConstantWithPlaceType(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "fill_constant"
self.init_dtype()
self.inputs = {}
self.attrs = {'shape': [123, 92], 'value': 3.8, 'place_type': 4}
self.outputs = {'Out': np.full((123, 92), 3.8)}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -54,7 +54,7 @@ class TestMemcpy_FillConstant(unittest.TestCase): ...@@ -54,7 +54,7 @@ class TestMemcpy_FillConstant(unittest.TestCase):
"shape": [10, 10], "shape": [10, 10],
"dtype": npu_var.dtype, "dtype": npu_var.dtype,
"value": 1.0, "value": 1.0,
"place_type": 1 "place_type": 4
}) })
main_program.global_block().append_op( main_program.global_block().append_op(
type="fill_constant", type="fill_constant",
...@@ -63,7 +63,7 @@ class TestMemcpy_FillConstant(unittest.TestCase): ...@@ -63,7 +63,7 @@ class TestMemcpy_FillConstant(unittest.TestCase):
"shape": [10, 10], "shape": [10, 10],
"dtype": cpu_var.dtype, "dtype": cpu_var.dtype,
"value": 0.0, "value": 0.0,
"place_type": 2 "place_type": 0
}) })
return main_program, npu_var, cpu_var return main_program, npu_var, cpu_var
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册