diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index 0c27ce3e355b6accac96ca57bb08564742ffc318..d9941023fe69fcd6e924766ef88702b4818abeee 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -7,6 +7,8 @@ BATCH_SIZE = 100 scope = core.Scope() place = core.CPUPlace() +# if you want to test GPU training, you can use gpu place +# place = core.GPUPlace(0) dev_ctx = core.DeviceContext.create(place) init_net = core.Net.create()