未验证 提交 cb73feea 编写于 作者: P pangyoki 提交者: GitHub

add Wait after TensorCopy (#34005)

上级 cbf22d65
......@@ -56,10 +56,13 @@ class NPUUniformRandomKernel : public framework::OpKernel<T> {
"unsupport type: %s.",
framework::ToTypeName(out_var->Type())));
}
T *data = tensor->mutable_data<T>(ctx.GetPlace());
tensor->mutable_data<T>(ctx.GetPlace());
int64_t size = tensor->numel();
std::unique_ptr<T[]> data_cpu(new T[size]);
Tensor cpu_tensor(tensor->type());
cpu_tensor.Resize(tensor->dims());
T *data_cpu = cpu_tensor.mutable_data<T>(platform::CPUPlace());
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
......@@ -90,12 +93,10 @@ class NPUUniformRandomKernel : public framework::OpKernel<T> {
}
// copy to NPU
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
memory::Copy(BOOST_GET_CONST(platform::NPUPlace, ctx.GetPlace()), data,
platform::CPUPlace(), reinterpret_cast<void *>(data_cpu.get()),
size * sizeof(T), stream);
framework::TensorCopy(
cpu_tensor, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), tensor);
ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
}
};
......
......@@ -67,7 +67,7 @@ class TestNPUUniformRandomOp(OpTest):
self.dtype = np.float32
def test_check_output(self):
self.check_output_customized(self.verify_output)
self.check_output_customized(self.verify_output, self.place)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
......
......@@ -1357,8 +1357,10 @@ class OpTest(unittest.TestCase):
if self.op_type not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST:
self.check_compile_vs_runtime(fetch_list, outs)
def check_output_customized(self, checker):
def check_output_customized(self, checker, custom_place=None):
places = self._get_places()
if custom_place:
places.append(custom_place)
for place in places:
outs = self.calc_output(place)
outs = [np.array(out) for out in outs]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册