提交 d11b8e56 编写于 作者: F fengjiayi

fix

上级 211d56ed
......@@ -107,7 +107,7 @@ T TensorGetElement(const framework::Tensor &self, size_t offset) {
return self.data<T>()[offset];
} else {
std::shared_ptr<framework::Tensor> dst(new framework::Tensor);
framework::TensorCopy(self, platform::CPUPlace(), dst.get());
framework::TensorCopySync(self, platform::CPUPlace(), dst.get());
return dst->data<T>()[offset];
}
}
......@@ -117,9 +117,9 @@ template <typename T>
void TensorSetElement(framework::Tensor *self, size_t offset, T elem) {
if (platform::is_gpu_place(self->place())) {
std::shared_ptr<framework::Tensor> dst(new framework::Tensor);
framework::TensorCopy(*self, platform::CPUPlace(), dst.get());
framework::TensorCopySync(*self, platform::CPUPlace(), dst.get());
dst->data<T>()[offset] = elem;
framework::TensorCopy(*dst.get(), self->place(), self);
framework::TensorCopySync(*dst.get(), self->place(), self);
} else if (platform::is_cpu_place(self->place())) {
self->data<T>()[offset] = elem;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册