提交 1f5192a2 编写于 作者: Q qijun

fix executor gpu unittest

上级 39f75a13
......@@ -30,7 +30,7 @@ Executor::Executor(const std::vector<platform::Place>& places) {
device_contexts_[i] = new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i]));
} else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_GPU
#ifdef PADDLE_WITH_CUDA
device_contexts_[i] = new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i]));
#else
......
......@@ -293,7 +293,7 @@ TEST_F(ExecutorTesterFeed, CPU) {
delete executor;
}
#ifdef PADDLE_WITH_GPU
#ifdef PADDLE_WITH_CUDA
TEST_F(ExecutorTesterRandom, GPU) {
std::vector<Place> places;
GPUPlace gpu_place(0);
......@@ -315,10 +315,20 @@ TEST_F(ExecutorTesterFeed, GPU) {
Executor* executor = new Executor(places);
// need to set feed variable before Executor::Run
set_feed_variable<float>(inputs_);
executor->Run(pdesc_, GetScope());
// 3 mini-batch
for (int i = 0; i < 3; i++) {
// need to set feed variable before Executor::Run
std::cout << "start mini-batch " << i << std::endl;
set_feed_variable<float>(inputs_);
executor->Run(pdesc_, GetScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>();
for (auto& vec : result) {
for (auto& num : vec) {
std::cout << num << " ";
}
std::cout << std::endl;
}
}
delete executor;
}
#endif
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/feed_op.h"
#include "paddle/operators/fetch_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel<float>);
......@@ -43,7 +43,8 @@ int GetCurrentDeviceId() {
}
void SetDeviceId(int id) {
PADDLE_ENFORCE(id < GetDeviceCount(), "id must less than GPU count");
// TODO(qijun): find a better way to cache the cuda device count
PADDLE_ENFORCE(id < GetCUDADeviceCount(), "id must less than GPU count");
PADDLE_ENFORCE(cudaSetDevice(id),
"cudaSetDevice failed in paddle::platform::SetDeviceId");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册