未验证 提交 9396f286 编写于 作者: A Aganlengzi 提交者: GitHub

[NPU] fix fill_constant and test_memcpy_op_npu (#37144)

上级 1773afd7
......@@ -87,6 +87,9 @@ class FillConstantOp : public framework::OperatorWithKernel {
case 3:
kt.place_ = platform::XPUPlace();
break;
case 4:
kt.place_ = platform::NPUPlace();
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Could NOT determine the place of variable, place_type = %d .",
......@@ -164,7 +167,8 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
"0: CPUPlace. "
"1: CUDAPlace. "
"2: CUDAPinnedPlace. "
"3: XPUPlace. ")
"3: XPUPlace. "
"4: NPUPlace. ")
.SetDefault(-1);
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
......
......@@ -144,5 +144,26 @@ class TestFillConstantBool(OpTest):
self.check_output_with_place(self.place)
class TestFillConstantWithPlaceType(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "fill_constant"
self.init_dtype()
self.inputs = {}
self.attrs = {'shape': [123, 92], 'value': 3.8, 'place_type': 4}
self.outputs = {'Out': np.full((123, 92), 3.8)}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
if __name__ == '__main__':
unittest.main()
......@@ -54,7 +54,7 @@ class TestMemcpy_FillConstant(unittest.TestCase):
"shape": [10, 10],
"dtype": npu_var.dtype,
"value": 1.0,
"place_type": 1
"place_type": 4
})
main_program.global_block().append_op(
type="fill_constant",
......@@ -63,7 +63,7 @@ class TestMemcpy_FillConstant(unittest.TestCase):
"shape": [10, 10],
"dtype": cpu_var.dtype,
"value": 0.0,
"place_type": 2
"place_type": 0
})
return main_program, npu_var, cpu_var
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册