未验证 提交 cb73feea 作者: pangyoki 提交者: GitHub

Add Wait() after TensorCopy in the NPU uniform_random kernel (#34005)

上级 cbf22d65
...@@ -56,10 +56,13 @@ class NPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -56,10 +56,13 @@ class NPUUniformRandomKernel : public framework::OpKernel<T> {
"unsupport type: %s.", "unsupport type: %s.",
framework::ToTypeName(out_var->Type()))); framework::ToTypeName(out_var->Type())));
} }
T *data = tensor->mutable_data<T>(ctx.GetPlace()); tensor->mutable_data<T>(ctx.GetPlace());
int64_t size = tensor->numel(); int64_t size = tensor->numel();
std::unique_ptr<T[]> data_cpu(new T[size]);
Tensor cpu_tensor(tensor->type());
cpu_tensor.Resize(tensor->dims());
T *data_cpu = cpu_tensor.mutable_data<T>(platform::CPUPlace());
std::uniform_real_distribution<T> dist( std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")), static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max"))); static_cast<T>(ctx.Attr<float>("max")));
...@@ -90,12 +93,10 @@ class NPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -90,12 +93,10 @@ class NPUUniformRandomKernel : public framework::OpKernel<T> {
} }
// copy to NPU // copy to NPU
auto stream = framework::TensorCopy(
ctx.template device_context<paddle::platform::NPUDeviceContext>() cpu_tensor, ctx.GetPlace(),
.stream(); ctx.template device_context<platform::DeviceContext>(), tensor);
memory::Copy(BOOST_GET_CONST(platform::NPUPlace, ctx.GetPlace()), data, ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
platform::CPUPlace(), reinterpret_cast<void *>(data_cpu.get()),
size * sizeof(T), stream);
} }
}; };
......
...@@ -67,7 +67,7 @@ class TestNPUUniformRandomOp(OpTest): ...@@ -67,7 +67,7 @@ class TestNPUUniformRandomOp(OpTest):
self.dtype = np.float32 self.dtype = np.float32
def test_check_output(self): def test_check_output(self):
self.check_output_customized(self.verify_output) self.check_output_customized(self.verify_output, self.place)
def verify_output(self, outs): def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0])) hist, prob = self.output_hist(np.array(outs[0]))
......
...@@ -1357,8 +1357,10 @@ class OpTest(unittest.TestCase): ...@@ -1357,8 +1357,10 @@ class OpTest(unittest.TestCase):
if self.op_type not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST: if self.op_type not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST:
self.check_compile_vs_runtime(fetch_list, outs) self.check_compile_vs_runtime(fetch_list, outs)
def check_output_customized(self, checker): def check_output_customized(self, checker, custom_place=None):
places = self._get_places() places = self._get_places()
if custom_place:
places.append(custom_place)
for place in places: for place in places:
outs = self.calc_output(place) outs = self.calc_output(place)
outs = [np.array(out) for out in outs] outs = [np.array(out) for out in outs]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册