提交 6824c09d 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #3050 from QiJune/op_gpu_test

enable operator gpu unittest
...@@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag) ...@@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag)
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif() endif()
# TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem.
# Use Debug mode instead for now.
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9)
set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE)
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers. # Apple Clang is a different compiler than upstream Clang which havs different version numbers.
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/memory/memcpy.h" #include "paddle/memory/memcpy.h"
namespace paddle { namespace paddle {
...@@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) { ...@@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>( holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size)); boost::get<platform::CPUPlace>(place), size));
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
} }
#ifndef PADDLE_ONLY_CPU #else
else if (platform::is_gpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>( holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), size)); boost::get<platform::GPUPlace>(place), size));
} }
......
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h" #include "paddle/operators/add_op.h"
......
#define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h" #include "paddle/operators/cross_entropy_op.h"
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
\ No newline at end of file
#define EIGEN_USE_GPU
#include "paddle/operators/rowwise_add_op.h" #include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL(rowwise_add, REGISTER_OP_GPU_KERNEL(rowwise_add,
......
#define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h" #include "paddle/operators/sgd_op.h"
REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
#define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h" #include "paddle/operators/sigmoid_op.h"
REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h" #include "paddle/operators/softmax_op.h"
......
...@@ -144,12 +144,12 @@ inline void throw_on_error(T e) { ...@@ -144,12 +144,12 @@ inline void throw_on_error(T e) {
throw_on_error(e, ""); throw_on_error(e, "");
} }
#define PADDLE_THROW(...) \ #define PADDLE_THROW(...) \
do { \ do { \
throw ::paddle::platform::EnforceNotMet( \ throw ::paddle::platform::EnforceNotMet( \
std::make_exception_ptr( \ std::make_exception_ptr( \
std::runtime_error(string::Sprintf(__VA_ARGS__))), \ std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \ __FILE__, __LINE__); \
} while (0) } while (0)
#define PADDLE_ENFORCE(...) \ #define PADDLE_ENFORCE(...) \
......
...@@ -20,6 +20,8 @@ limitations under the License. */ ...@@ -20,6 +20,8 @@ limitations under the License. */
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/pybind/tensor_bind.h" #include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
...@@ -55,6 +57,14 @@ static size_t UniqueIntegerGenerator() { ...@@ -55,6 +57,14 @@ static size_t UniqueIntegerGenerator() {
return generator.fetch_add(1); return generator.fetch_add(1);
} }
bool IsCompileGPU() {
#ifdef PADDLE_ONLY_CPU
return false;
#else
return true;
#endif
}
PYBIND11_PLUGIN(core) { PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle"); py::module m("core", "C++ core of PaddlePaddle");
...@@ -69,15 +79,27 @@ PYBIND11_PLUGIN(core) { ...@@ -69,15 +79,27 @@ PYBIND11_PLUGIN(core) {
self.Resize(pd::make_ddim(dim)); self.Resize(pd::make_ddim(dim));
}) })
.def("alloc_float", .def("alloc_float",
[](pd::Tensor& self) { [](pd::Tensor& self, paddle::platform::GPUPlace& place) {
self.mutable_data<float>(paddle::platform::CPUPlace()); self.mutable_data<float>(place);
})
.def("alloc_float",
[](pd::Tensor& self, paddle::platform::CPUPlace& place) {
self.mutable_data<float>(place);
}) })
.def("alloc_int", .def("alloc_int",
[](pd::Tensor& self) { [](pd::Tensor& self, paddle::platform::CPUPlace& place) {
self.mutable_data<int>(paddle::platform::CPUPlace()); self.mutable_data<int>(place);
}) })
.def("set", paddle::pybind::PyTensorSetFromArray<float>) .def("alloc_int",
.def("set", paddle::pybind::PyTensorSetFromArray<int>) [](pd::Tensor& self, paddle::platform::GPUPlace& place) {
self.mutable_data<int>(place);
})
.def("set", paddle::pybind::PyCPUTensorSetFromArray<float>)
.def("set", paddle::pybind::PyCPUTensorSetFromArray<int>)
#ifndef PADDLE_ONLY_CPU
.def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
.def("set", paddle::pybind::PyCUDATensorSetFromArray<int>)
#endif
.def("shape", .def("shape",
[](pd::Tensor& self) { return pd::vectorize(self.dims()); }); [](pd::Tensor& self) { return pd::vectorize(self.dims()); });
...@@ -136,11 +158,27 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -136,11 +158,27 @@ All parameter, weight, gradient are variables in Paddle.
"The module will return special predefined variable name in Paddle") "The module will return special predefined variable name in Paddle")
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME) .def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
.def("temp", pd::OperatorBase::TMP_VAR_NAME); .def("temp", pd::OperatorBase::TMP_VAR_NAME);
// clang-format off
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext") py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("cpu_context", []() -> paddle::platform::DeviceContext* { .def_static("create",
return new paddle::platform::CPUDeviceContext(); [](paddle::platform::CPUPlace& place)
}); -> paddle::platform::DeviceContext* {
return new paddle::platform::CPUDeviceContext();
})
.def_static("create",
[](paddle::platform::GPUPlace& place)
-> paddle::platform::DeviceContext* {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("GPUPlace is not supported in CPU device.");
#else
return new paddle::platform::CUDADeviceContext(place);
#endif
});
// clang-format on
py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base( py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
m, "Operator"); m, "Operator");
...@@ -176,5 +214,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -176,5 +214,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("unique_integer", UniqueIntegerGenerator); m.def("unique_integer", UniqueIntegerGenerator);
m.def("is_compile_gpu", IsCompileGPU);
return m.ptr(); return m.ptr();
} }
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <paddle/framework/tensor.h> #include <string>
#include <pybind11/numpy.h> #include "paddle/framework/tensor.h"
#include <pybind11/pybind11.h> #include "paddle/memory/memcpy.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
namespace py = pybind11; namespace py = pybind11;
...@@ -40,9 +42,6 @@ template <size_t I, typename... ARGS> ...@@ -40,9 +42,6 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> { struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type; using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
py::buffer_info operator()(framework::Tensor &tensor) { py::buffer_info operator()(framework::Tensor &tensor) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
"Only CPU tensor can cast to numpy array");
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
auto dim_vec = framework::vectorize(tensor.dims()); auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside; std::vector<size_t> dims_outside;
...@@ -56,12 +55,17 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -56,12 +55,17 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
strides[i - 1] = sizeof(CUR_TYPE) * prod; strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1]; prod *= dims_outside[i - 1];
} }
framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_cpu_place(tensor.holder_->place())) {
dst_tensor = tensor;
}
return py::buffer_info( return py::buffer_info(
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()), dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()),
sizeof(CUR_TYPE), sizeof(CUR_TYPE),
py::format_descriptor<CUR_TYPE>::format(), py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(tensor.dims()), (size_t)framework::arity(dst_tensor.dims()),
dims_outside, dims_outside,
strides); strides);
} else { } else {
...@@ -77,9 +81,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { ...@@ -77,9 +81,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
} }
template <typename T> template <typename T>
void PyTensorSetFromArray( void PyCPUTensorSetFromArray(
framework::Tensor &self, framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array) { py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::CPUPlace &place) {
std::vector<int> dims; std::vector<int> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (size_t i = 0; i < array.ndim(); ++i) {
...@@ -87,9 +92,28 @@ void PyTensorSetFromArray( ...@@ -87,9 +92,28 @@ void PyTensorSetFromArray(
} }
self.Resize(framework::make_ddim(dims)); self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace()); auto *dst = self.mutable_data<T>(place);
std::memcpy(dst, array.data(), sizeof(T) * array.size()); std::memcpy(dst, array.data(), sizeof(T) * array.size());
} }
#ifndef PADDLE_ONLY_CPU
template <typename T>
void PyCUDATensorSetFromArray(
framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::GPUPlace &place) {
std::vector<int> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
dims.push_back((int)array.shape()[i]);
}
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
paddle::platform::GpuMemcpySync(
dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
}
#endif
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
...@@ -8,7 +8,6 @@ add_python_test(test_framework ...@@ -8,7 +8,6 @@ add_python_test(test_framework
test_fc_op.py test_fc_op.py
test_add_two_op.py test_add_two_op.py
test_sgd_op.py test_sgd_op.py
test_cross_entropy_op.py
test_mul_op.py test_mul_op.py
test_mean_op.py test_mean_op.py
test_sigmoid_op.py test_sigmoid_op.py
......
...@@ -26,40 +26,45 @@ class OpTestMeta(type): ...@@ -26,40 +26,45 @@ class OpTestMeta(type):
scope = core.Scope() scope = core.Scope()
kwargs = dict() kwargs = dict()
places = []
places.append(core.CPUPlace())
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for in_name in func.all_input_args: for place in places:
if hasattr(self, in_name): for in_name in func.all_input_args:
kwargs[in_name] = in_name if hasattr(self, in_name):
var = scope.new_var(in_name).get_tensor() kwargs[in_name] = in_name
arr = getattr(self, in_name) var = scope.new_var(in_name).get_tensor()
var.set_dims(arr.shape) arr = getattr(self, in_name)
var.set(arr) var.set_dims(arr.shape)
else: var.set(arr, place)
kwargs[in_name] = "@EMPTY@" else:
kwargs[in_name] = "@EMPTY@"
for out_name in func.all_output_args: for out_name in func.all_output_args:
if hasattr(self, out_name): if hasattr(self, out_name):
kwargs[out_name] = out_name kwargs[out_name] = out_name
scope.new_var(out_name).get_tensor() scope.new_var(out_name).get_tensor()
for attr_name in func.all_attr_args: for attr_name in func.all_attr_args:
if hasattr(self, attr_name): if hasattr(self, attr_name):
kwargs[attr_name] = getattr(self, attr_name) kwargs[attr_name] = getattr(self, attr_name)
op = func(**kwargs) op = func(**kwargs)
op.infer_shape(scope) op.infer_shape(scope)
ctx = core.DeviceContext.cpu_context() ctx = core.DeviceContext.create(place)
op.run(scope, ctx) op.run(scope, ctx)
for out_name in func.all_output_args: for out_name in func.all_output_args:
actual = numpy.array(scope.find_var(out_name).get_tensor()) actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = getattr(self, out_name) expect = getattr(self, out_name)
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
# has some diff, and could not pass unittest. So I set decimal 3 here. # has some diff, and could not pass unittest. So I set decimal 3 here.
# And I will check this in future. # And I will check this in future.
numpy.testing.assert_almost_equal(actual, expect, decimal=3) numpy.testing.assert_almost_equal(actual, expect, decimal=3)
obj.test_all = test_all obj.test_all = test_all
return obj return obj
...@@ -8,8 +8,8 @@ class TestAddOp(unittest.TestCase): ...@@ -8,8 +8,8 @@ class TestAddOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "add_two" self.type = "add_two"
self.X = numpy.random.random((342, 345)).astype("float32") self.X = numpy.random.random((102, 105)).astype("float32")
self.Y = numpy.random.random((342, 345)).astype("float32") self.Y = numpy.random.random((102, 105)).astype("float32")
self.Out = self.X + self.Y self.Out = self.X + self.Y
......
...@@ -7,17 +7,19 @@ import paddle.v2.framework.create_op_creation_methods as creation ...@@ -7,17 +7,19 @@ import paddle.v2.framework.create_op_creation_methods as creation
class TestFc(unittest.TestCase): class TestFc(unittest.TestCase):
def test_fc(self): def test_fc(self):
scope = core.Scope() scope = core.Scope()
place = core.CPUPlace()
x = scope.new_var("X") x = scope.new_var("X")
x_tensor = x.get_tensor() x_tensor = x.get_tensor()
x_tensor.set_dims([1000, 784]) x_tensor.set_dims([1000, 784])
x_tensor.alloc_float() x_tensor.alloc_float(place)
w = scope.new_var("W") w = scope.new_var("W")
w_tensor = w.get_tensor() w_tensor = w.get_tensor()
w_tensor.set_dims([784, 100]) w_tensor.set_dims([784, 100])
w_tensor.alloc_float() w_tensor.alloc_float(place)
w_tensor.set(numpy.random.random((784, 100)).astype("float32")) w_tensor.set(numpy.random.random((784, 100)).astype("float32"), place)
# Set a real numpy array here. # Set a real numpy array here.
# x_tensor.set(numpy.array([])) # x_tensor.set(numpy.array([]))
...@@ -32,7 +34,7 @@ class TestFc(unittest.TestCase): ...@@ -32,7 +34,7 @@ class TestFc(unittest.TestCase):
op.infer_shape(scope) op.infer_shape(scope)
self.assertEqual([1000, 100], tensor.shape()) self.assertEqual([1000, 100], tensor.shape())
ctx = core.DeviceContext.cpu_context() ctx = core.DeviceContext.create(place)
op.run(scope, ctx) op.run(scope, ctx)
......
...@@ -8,8 +8,8 @@ class TestMulOp(unittest.TestCase): ...@@ -8,8 +8,8 @@ class TestMulOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "mul" self.type = "mul"
self.X = np.random.random((32, 784)).astype("float32") self.X = np.random.random((32, 84)).astype("float32")
self.Y = np.random.random((784, 100)).astype("float32") self.Y = np.random.random((84, 100)).astype("float32")
self.Out = np.dot(self.X, self.Y) self.Out = np.dot(self.X, self.Y)
......
...@@ -8,8 +8,8 @@ class TestRowwiseAddOp(unittest.TestCase): ...@@ -8,8 +8,8 @@ class TestRowwiseAddOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "rowwise_add" self.type = "rowwise_add"
self.X = np.random.random((32, 784)).astype("float32") self.X = np.random.random((32, 84)).astype("float32")
self.b = np.random.random(784).astype("float32") self.b = np.random.random(84).astype("float32")
self.Out = np.add(self.X, self.b) self.Out = np.add(self.X, self.b)
......
...@@ -8,8 +8,8 @@ class TestSGD(unittest.TestCase): ...@@ -8,8 +8,8 @@ class TestSGD(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "sgd" self.type = "sgd"
self.param = numpy.random.random((342, 345)).astype("float32") self.param = numpy.random.random((102, 105)).astype("float32")
self.grad = numpy.random.random((342, 345)).astype("float32") self.grad = numpy.random.random((102, 105)).astype("float32")
self.learning_rate = 0.1 self.learning_rate = 0.1
self.param_out = self.param - self.learning_rate * self.grad self.param_out = self.param - self.learning_rate * self.grad
......
...@@ -7,16 +7,17 @@ class TestScope(unittest.TestCase): ...@@ -7,16 +7,17 @@ class TestScope(unittest.TestCase):
def test_int_tensor(self): def test_int_tensor(self):
scope = core.Scope() scope = core.Scope()
var = scope.new_var("test_tensor") var = scope.new_var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set_dims([1000, 784]) tensor.set_dims([1000, 784])
tensor.alloc_int() tensor.alloc_int(place)
tensor_array = numpy.array(tensor) tensor_array = numpy.array(tensor)
self.assertEqual((1000, 784), tensor_array.shape) self.assertEqual((1000, 784), tensor_array.shape)
tensor_array[3, 9] = 1 tensor_array[3, 9] = 1
tensor_array[19, 11] = 2 tensor_array[19, 11] = 2
tensor.set(tensor_array) tensor.set(tensor_array, place)
tensor_array_2 = numpy.array(tensor) tensor_array_2 = numpy.array(tensor)
self.assertEqual(1.0, tensor_array_2[3, 9]) self.assertEqual(1.0, tensor_array_2[3, 9])
...@@ -25,16 +26,18 @@ class TestScope(unittest.TestCase): ...@@ -25,16 +26,18 @@ class TestScope(unittest.TestCase):
def test_float_tensor(self): def test_float_tensor(self):
scope = core.Scope() scope = core.Scope()
var = scope.new_var("test_tensor") var = scope.new_var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set_dims([1000, 784]) tensor.set_dims([1000, 784])
tensor.alloc_float() tensor.alloc_float(place)
tensor_array = numpy.array(tensor) tensor_array = numpy.array(tensor)
self.assertEqual((1000, 784), tensor_array.shape) self.assertEqual((1000, 784), tensor_array.shape)
tensor_array[3, 9] = 1.0 tensor_array[3, 9] = 1.0
tensor_array[19, 11] = 2.0 tensor_array[19, 11] = 2.0
tensor.set(tensor_array) tensor.set(tensor_array, place)
tensor_array_2 = numpy.array(tensor) tensor_array_2 = numpy.array(tensor)
self.assertAlmostEqual(1.0, tensor_array_2[3, 9]) self.assertAlmostEqual(1.0, tensor_array_2[3, 9])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册