Commit 4a1f7bd2 authored by Q qijun

add gpu python op test

Parent 47d8bca8
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <string>
 #include "paddle/memory/memcpy.h"
 namespace paddle {
...
@@ -62,9 +62,11 @@ inline T* Tensor::mutable_data(platform::Place place) {
   if (platform::is_cpu_place(place)) {
     holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
         boost::get<platform::CPUPlace>(place), size));
+  } else if (platform::is_gpu_place(place)) {
+#ifdef PADDLE_ONLY_CPU
+    PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
   }
-#ifndef PADDLE_ONLY_CPU
-  else if (platform::is_gpu_place(place)) {
+#else
     holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
         boost::get<platform::GPUPlace>(place), size));
   }
...
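
With this change, requesting GPU memory in a CPU-only build fails loudly through PADDLE_THROW instead of the GPU branch simply not being compiled in. A minimal sketch of how that surfaces on the Python side, assuming the `paddle.v2.framework.core` import path used by the tests in this commit and that `Tensor.set(arr, place)` reaches `Tensor::mutable_data`:

import numpy
import paddle.v2.framework.core as core

scope = core.Scope(None)
tensor = scope.create_var("X").get_tensor()
arr = numpy.ones((2, 2), dtype=numpy.float32)
tensor.set_dims(arr.shape)

try:
    # On a CPU-only build this is expected to raise, via the new
    # PADDLE_THROW("'GPUPlace' is not supported in CPU only device.")
    tensor.set(arr, core.GPUPlace(0))
except Exception as e:
    print("GPUPlace rejected on CPU-only build:", e)
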
...
@@ -132,12 +132,12 @@ inline void throw_on_error(T e) {
   throw_on_error(e, "");
 }
 
 #define PADDLE_THROW(...)                                              \
   do {                                                                 \
     throw ::paddle::platform::EnforceNotMet(                           \
         std::make_exception_ptr(                                       \
-            std::runtime_error(string::Sprintf(__VA_ARGS__))),         \
+            std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
         __FILE__, __LINE__);                                           \
   } while (0)
 
 #define PADDLE_ENFORCE(...) \
...
...
@@ -56,6 +56,14 @@ static size_t UniqueIntegerGenerator() {
   return generator.fetch_add(1);
 }
 
+bool IsCompileGPU() {
+#ifdef PADDLE_ONLY_CPU
+  return false;
+#else
+  return true;
+#endif
+}
+
 PYBIND11_PLUGIN(core) {
   py::module m("core", "C++ core of PaddlePaddle");
...
@@ -148,18 +156,23 @@ All parameter, weight, gradient are variables in Paddle.
       .def("temp", pd::OperatorBase::TMP_VAR_NAME);
 
   py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
-      .def_static("cpu_context",
-                  []() -> paddle::platform::DeviceContext* {
-                    return new paddle::platform::CPUDeviceContext();
-                  })
-#ifndef PADDLE_ONLY_CPU
-      .def_static("gpu_context",
-                  [](paddle::platform::GPUPlace& place)
-                      -> paddle::platform::DeviceContext* {
-                    return new paddle::platform::CUDADeviceContext(place);
-                  })
-#endif
-      ;  // NOLINT
+      .def_static("create",
+                  [](paddle::platform::CPUPlace& place)
+                      -> paddle::platform::DeviceContext* {
+                    return new paddle::platform::CPUDeviceContext();
+                  })
+      .def_static(
+          "create",
+          [](paddle::platform::GPUPlace& place)
+              -> paddle::platform::DeviceContext* {
+#ifdef PADDLE_ONLY_CPU
+            PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
+#else
+            return new paddle::platform::CUDADeviceContext(place);
+#endif
+          });
 
   py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());
   py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
...
@@ -198,5 +211,7 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("unique_integer", UniqueIntegerGenerator);
 
+  m.def("is_compile_gpu", IsCompileGPU);
+
   return m.ptr();
 }
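
Taken together, the pybind changes replace the separate `cpu_context`/`gpu_context` factories with two `create` overloads that dispatch on the place type, and export a build-level capability check. A sketch of the resulting Python-side API, assuming the same `core` module import as the tests:

import paddle.v2.framework.core as core

# pybind11 picks the "create" overload whose place argument type matches.
cpu_ctx = core.DeviceContext.create(core.CPUPlace())

# is_compile_gpu() reports whether the binary was built with GPU support.
if core.is_compile_gpu():
    gpu_ctx = core.DeviceContext.create(core.GPUPlace(0))

Using one overloaded name lets callers thread a place straight through to the context factory without branching on its type, which is what the place loop in OpTestMeta below relies on.
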
...
@@ -25,42 +25,48 @@ class OpTestMeta(type):
             self.assertIsNotNone(func)
 
             scope = core.Scope(None)
-            place = core.CPUPlace()
             kwargs = dict()
+            places = []
+            places.append(core.CPUPlace())
+            if core.is_compile_gpu():
+                places.append(core.GPUPlace(0))
 
-            for in_name in func.all_input_args:
-                if hasattr(self, in_name):
-                    kwargs[in_name] = in_name
-                    var = scope.create_var(in_name).get_tensor()
-                    arr = getattr(self, in_name)
-                    var.set_dims(arr.shape)
-                    var.set(arr, place)
-                else:
-                    kwargs[in_name] = "@EMPTY@"
+            for place in places:
+                for in_name in func.all_input_args:
+                    if hasattr(self, in_name):
+                        kwargs[in_name] = in_name
+                        var = scope.create_var(in_name).get_tensor()
+                        arr = getattr(self, in_name)
+                        var.set_dims(arr.shape)
+                        var.set(arr, place)
+                    else:
+                        kwargs[in_name] = "@EMPTY@"
 
-            for out_name in func.all_output_args:
-                if hasattr(self, out_name):
-                    kwargs[out_name] = out_name
-                    scope.create_var(out_name).get_tensor()
+                for out_name in func.all_output_args:
+                    if hasattr(self, out_name):
+                        kwargs[out_name] = out_name
+                        scope.create_var(out_name).get_tensor()
 
-            for attr_name in func.all_attr_args:
-                if hasattr(self, attr_name):
-                    kwargs[attr_name] = getattr(self, attr_name)
+                for attr_name in func.all_attr_args:
+                    if hasattr(self, attr_name):
+                        kwargs[attr_name] = getattr(self, attr_name)
 
-            op = func(**kwargs)
+                op = func(**kwargs)
 
-            op.infer_shape(scope)
+                op.infer_shape(scope)
 
-            ctx = core.DeviceContext.cpu_context()
-            op.run(scope, ctx)
+                ctx = core.DeviceContext.create(place)
+                op.run(scope, ctx)
 
-            for out_name in func.all_output_args:
-                actual = numpy.array(scope.get_var(out_name).get_tensor())
-                expect = getattr(self, out_name)
-                # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
-                # has some diff, and could not pass unittest. So I set decimal 3 here.
-                # And I will check this in future.
-                numpy.testing.assert_almost_equal(actual, expect, decimal=3)
+                for out_name in func.all_output_args:
+                    actual = numpy.array(scope.get_var(out_name).get_tensor())
+                    expect = getattr(self, out_name)
+                    # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
+                    # has some diff, and could not pass unittest. So I set decimal 3 here.
+                    # And I will check this in future.
+                    numpy.testing.assert_almost_equal(actual, expect, decimal=3)
 
         obj.test_all = test_all
         return obj
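
For illustration, a hypothetical test case showing how a concrete op picks up the new multi-place `test_all`; the `type` attribute, the `add_two` op name, and the input/output attribute names are assumptions modeled on the op tests of this era, not taken from this diff:

import unittest
import numpy
from op_test_util import OpTestMeta

class TestAddOp(unittest.TestCase):
    __metaclass__ = OpTestMeta  # injects test_all (Python 2 style)
    type = "add_two"            # assumed: op name looked up by the metaclass

    def setUp(self):
        # Attribute names must match the op's input/output argument names.
        self.X = numpy.random.random((100, 1)).astype("float32")
        self.Y = numpy.random.random((100, 1)).astype("float32")
        self.Out = self.X + self.Y

Because `places` now contains `GPUPlace(0)` whenever `core.is_compile_gpu()` is true, this one test exercises the op on both CPU and GPU with no GPU-specific code in the test itself.
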
...
@@ -33,7 +33,7 @@ class TestFc(unittest.TestCase):
         op.infer_shape(scope)
         self.assertEqual([1000, 100], tensor.shape())
 
-        ctx = core.DeviceContext.cpu_context()
+        ctx = core.DeviceContext.create(place)
         op.run(scope, ctx)
...