提交 660f3e25 编写于 作者: Y Yu Yang

Fix CI

上级 12246a9d
...@@ -277,12 +277,16 @@ class Vector { ...@@ -277,12 +277,16 @@ class Vector {
kDirty = 0x10 kDirty = 0x10
}; };
void MutableCPU() { void CopyToCPU() const {
if (IsInCUDA() && IsDirty()) {
// COPY GPU Data To CPU // COPY GPU Data To CPU
Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_); Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_);
WaitPlace(cuda_vec_.place()); WaitPlace(cuda_vec_.place());
} }
void MutableCPU() {
if (IsInCUDA() && IsDirty()) {
CopyToCPU();
}
flag_ = kDirty | kDataInCPU; flag_ = kDirty | kDataInCPU;
} }
...@@ -311,8 +315,10 @@ class Vector { ...@@ -311,8 +315,10 @@ class Vector {
SetFlag(kDataInCUDA); SetFlag(kDataInCUDA);
} else if (!(place == cuda_vec_.place())) { } else if (!(place == cuda_vec_.place())) {
framework::Tensor tmp; framework::Tensor tmp;
WaitPlace(cuda_vec_.place());
Copy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp); Copy(cuda_vec_, boost::get<platform::CUDAPlace>(place), &tmp);
WaitPlace(cuda_vec_.place()); WaitPlace(cuda_vec_.place());
WaitPlace(place);
cuda_vec_.ShareDataWith(tmp); cuda_vec_.ShareDataWith(tmp);
} else { } else {
// Not Dirty && DataInCUDA && Device is same // Not Dirty && DataInCUDA && Device is same
...@@ -324,8 +330,7 @@ class Vector { ...@@ -324,8 +330,7 @@ class Vector {
void ImmutableCPU() const { void ImmutableCPU() const {
if (IsDirty() && if (IsDirty() &&
!IsInCPU()) { // If data has been changed in CUDA, or CPU has no data. !IsInCPU()) { // If data has been changed in CUDA, or CPU has no data.
Copy(cuda_vec_, platform::CPUPlace(), &cpu_vec_); CopyToCPU();
WaitPlace(cuda_vec_.place());
UnsetFlag(kDirty); UnsetFlag(kDirty);
} }
SetFlag(kDataInCPU); SetFlag(kDataInCPU);
......
...@@ -81,10 +81,12 @@ TEST(mixed_vector, MultiGPU) { ...@@ -81,10 +81,12 @@ TEST(mixed_vector, MultiGPU) {
} }
ASSERT_EQ(tmp.size(), 10); ASSERT_EQ(tmp.size(), 10);
paddle::platform::CUDAPlace gpu0(0); paddle::platform::CUDAPlace gpu0(0);
paddle::platform::SetDeviceId(0);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu0)>>>(tmp.MutableData(gpu0)); multiply_10<<<1, 1, 0, GetCUDAStream(gpu0)>>>(tmp.MutableData(gpu0));
paddle::platform::CUDAPlace gpu1(1); paddle::platform::CUDAPlace gpu1(1);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu1)>>>(tmp.MutableData(gpu1)); auto* gpu1_ptr = tmp.MutableData(gpu1);
paddle::platform::SetDeviceId(1);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu1)>>>(gpu1_ptr);
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp[i], i * 100); ASSERT_EQ(tmp[i], i * 100);
} }
......
...@@ -154,7 +154,9 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> { ...@@ -154,7 +154,9 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
auto* in2_value = input2->mutable_value(); auto* in2_value = input2->mutable_value();
// concat rows // concat rows
if (in1_rows.size()) {
in2_rows.Extend(in1_rows.begin(), in1_rows.end()); in2_rows.Extend(in1_rows.begin(), in1_rows.end());
}
auto in1_place = input1.place(); auto in1_place = input1.place();
PADDLE_ENFORCE(platform::is_gpu_place(in1_place)); PADDLE_ENFORCE(platform::is_gpu_place(in1_place));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册