Commit 1f5192a2 authored by Q qijun

fix executor gpu unittest

Parent 39f75a13
@@ -30,7 +30,7 @@ Executor::Executor(const std::vector<platform::Place>& places) {
       device_contexts_[i] = new platform::CPUDeviceContext(
           boost::get<platform::CPUPlace>(places[i]));
     } else if (platform::is_gpu_place(places[i])) {
-#ifdef PADDLE_WITH_GPU
+#ifdef PADDLE_WITH_CUDA
       device_contexts_[i] = new platform::CUDADeviceContext(
           boost::get<platform::GPUPlace>(places[i]));
 #else
...
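The hunk above keys GPU context creation on a single build macro, PADDLE_WITH_CUDA; with the misspelled guard PADDLE_WITH_GPU, the GPU branch was never compiled in even in CUDA builds. The following standalone sketch (not Paddle's actual classes; Place, MakeContext, and the context types here are illustrative stand-ins) shows the same per-place dispatch pattern guarded by that macro:

// Minimal sketch of per-place device-context dispatch guarded by one macro.
// Assumes C++17; the types below are stand-ins, not Paddle's real API.
#include <iostream>
#include <memory>
#include <stdexcept>
#include <variant>
#include <vector>

struct CPUPlace {};
struct GPUPlace { int device; };
using Place = std::variant<CPUPlace, GPUPlace>;

struct DeviceContext { virtual ~DeviceContext() = default; };
struct CPUDeviceContext : DeviceContext {};
#ifdef PADDLE_WITH_CUDA  // hypothetical: defined only in a CUDA build
struct CUDADeviceContext : DeviceContext { explicit CUDADeviceContext(GPUPlace) {} };
#endif

std::unique_ptr<DeviceContext> MakeContext(const Place& place) {
  if (std::holds_alternative<CPUPlace>(place)) {
    return std::make_unique<CPUDeviceContext>();
  }
#ifdef PADDLE_WITH_CUDA
  // Only reachable when the build defines PADDLE_WITH_CUDA.
  return std::make_unique<CUDADeviceContext>(std::get<GPUPlace>(place));
#else
  throw std::runtime_error("GPUPlace is not supported in a CPU-only build");
#endif
}

int main() {
  std::vector<Place> places = {CPUPlace{}, GPUPlace{0}};
  for (const auto& p : places) {
    try {
      MakeContext(p);
      std::cout << "context created\n";
    } catch (const std::exception& e) {
      std::cout << e.what() << "\n";
    }
  }
}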
@@ -293,7 +293,7 @@ TEST_F(ExecutorTesterFeed, CPU) {
   delete executor;
 }
-#ifdef PADDLE_WITH_GPU
+#ifdef PADDLE_WITH_CUDA
 TEST_F(ExecutorTesterRandom, GPU) {
   std::vector<Place> places;
   GPUPlace gpu_place(0);
@@ -315,10 +315,20 @@ TEST_F(ExecutorTesterFeed, GPU) {
   Executor* executor = new Executor(places);
-  // need to set feed variable before Executor::Run
-  set_feed_variable<float>(inputs_);
-  executor->Run(pdesc_, GetScope());
+  // 3 mini-batch
+  for (int i = 0; i < 3; i++) {
+    // need to set feed variable before Executor::Run
+    std::cout << "start mini-batch " << i << std::endl;
+    set_feed_variable<float>(inputs_);
+    executor->Run(pdesc_, GetScope());
+    std::vector<std::vector<float>> result = get_fetch_variable<float>();
+    for (auto& vec : result) {
+      for (auto& num : vec) {
+        std::cout << num << " ";
+      }
+      std::cout << std::endl;
+    }
+  }
   delete executor;
 }
 #endif
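The reworked test body follows a feed -> Run -> fetch cycle repeated over three mini-batches. The sketch below mirrors only that control flow with the Paddle pieces stubbed out so it runs in isolation; set_feed_variable, RunProgram, and get_fetch_variable here are stand-ins, not Paddle's real signatures.

// Stubbed feed/run/fetch loop, runnable without Paddle.
#include <iostream>
#include <vector>

std::vector<std::vector<float>> g_feed;   // stand-in for the feed variable
std::vector<std::vector<float>> g_fetch;  // stand-in for the fetch variable

void set_feed_variable(const std::vector<std::vector<float>>& inputs) {
  g_feed = inputs;
}

void RunProgram() {
  // Stand-in for executor->Run(pdesc_, GetScope()): just copy feed to fetch.
  g_fetch = g_feed;
}

std::vector<std::vector<float>> get_fetch_variable() { return g_fetch; }

int main() {
  std::vector<std::vector<float>> inputs = {{1.f, 2.f}, {3.f, 4.f}};
  // 3 mini-batches, mirroring the loop added in the test above.
  for (int i = 0; i < 3; ++i) {
    std::cout << "start mini-batch " << i << std::endl;
    set_feed_variable(inputs);
    RunProgram();
    for (const auto& vec : get_fetch_variable()) {
      for (float num : vec) std::cout << num << " ";
      std::cout << std::endl;
    }
  }
}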
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/operators/feed_op.h"
+#include "paddle/operators/fetch_op.h"
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel<float>);
@@ -43,7 +43,8 @@ int GetCurrentDeviceId() {
 }
 void SetDeviceId(int id) {
-  PADDLE_ENFORCE(id < GetDeviceCount(), "id must less than GPU count");
+  // TODO(qijun): find a better way to cache the cuda device count
+  PADDLE_ENFORCE(id < GetCUDADeviceCount(), "id must less than GPU count");
   PADDLE_ENFORCE(cudaSetDevice(id),
                  "cudaSetDevice failed in paddle::platform::SetDeviceId");
 }
...
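SetDeviceId validates the requested id against the device count before switching devices. A minimal standalone version of that check, using only the CUDA runtime API directly (cudaGetDeviceCount, cudaSetDevice) rather than Paddle's PADDLE_ENFORCE and GetCUDADeviceCount, might look like the following; note it re-queries the count on every call, which is exactly what the TODO above suggests avoiding by caching:

// Standalone sketch of a bounds-checked device switch; compile with nvcc or
// link against the CUDA runtime.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

void SetDeviceIdChecked(int id) {
  int count = 0;
  cudaError_t err = cudaGetDeviceCount(&count);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaGetDeviceCount failed: %s\n", cudaGetErrorString(err));
    std::exit(1);
  }
  if (id >= count) {
    std::fprintf(stderr, "id %d must be less than GPU count %d\n", id, count);
    std::exit(1);
  }
  err = cudaSetDevice(id);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaSetDevice failed: %s\n", cudaGetErrorString(err));
    std::exit(1);
  }
}

int main() {
  SetDeviceIdChecked(0);
  std::printf("device 0 selected\n");
  return 0;
}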