未验证 提交 033ef5e9 编写于 作者: R ronnywang 提交者: GitHub

[CustomKernel] capi add eager mode support (#44164)

* [CustomKernel] add capi eager mode support

* add ut

* add capi test
上级 961d6cce
......@@ -11,4 +11,8 @@ if(WITH_CUSTOM_DEVICE)
custom_device_test
SRCS custom_device_test.cc
DEPS device_manager device_context)
cc_test(
capi_test
SRCS capi_test.cc
DEPS phi_capi)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cstring>
#include <string>
#include "paddle/phi/capi/all.h"
#ifndef UNUSED
#define UNUSED __attribute__((unused))
#endif
#include "paddle/phi/capi/capi.h"
TEST(CustomKernel, CAPI) {
std::string str = "capi";
EXPECT_EQ(str.data(), PD_StringAttr(&str));
std::vector<int32_t> int32_vec({1, 2, 3});
auto int32_list = PD_ListInt32Attr(&int32_vec);
EXPECT_EQ(int32_list.data, int32_vec.data());
EXPECT_EQ(int32_list.size, int32_vec.size());
std::vector<int64_t> int64_vec({1, 2, 3});
auto int64_list = PD_ListInt64Attr(&int64_vec);
EXPECT_EQ(int64_list.data, int64_vec.data());
EXPECT_EQ(int64_list.size, int64_vec.size());
std::vector<float> float_vec({1, 2, 3});
auto float_list = PD_ListFloatAttr(&float_vec);
EXPECT_EQ(float_list.data, float_vec.data());
EXPECT_EQ(float_list.size, float_vec.size());
std::vector<double> double_vec({1, 2, 3});
auto double_list = PD_ListDoubleAttr(&double_vec);
EXPECT_EQ(double_list.data, double_vec.data());
EXPECT_EQ(double_list.size, double_vec.size());
std::vector<std::string> string_vec{"capi", "api"};
auto string_list = PD_ListStringAttr(&string_vec);
auto string_data = reinterpret_cast<void**>(string_list.data);
for (size_t i = 0; i < string_vec.size(); ++i) {
EXPECT_EQ(string_data[i], string_vec[i].data());
}
std::vector<bool> bool_vec{true, false, true};
auto bool_list = PD_ListBoolAttr(&bool_vec);
auto bool_data = reinterpret_cast<uint8_t*>(bool_list.data);
for (size_t i = 0; i < bool_vec.size(); ++i) {
EXPECT_EQ(bool_data[i], static_cast<uint8_t>(bool_vec[i]));
}
std::vector<float*> ptr_vec;
for (size_t i = 0; i < float_vec.size(); ++i) {
ptr_vec.push_back(&float_vec[i]);
}
auto ptr_list = PD_TensorVectorToList(reinterpret_cast<PD_Tensor*>(&ptr_vec));
EXPECT_EQ(ptr_list.data, ptr_vec.data());
EXPECT_EQ(ptr_list.size, ptr_vec.size());
}
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
......@@ -87,6 +87,26 @@ PD_List PD_KernelContextListScalarAttrAt(PD_KernelContext *ctx, size_t index);
PD_Place *PD_KernelContextPlaceAttrAt(PD_KernelContext *ctx, size_t index);
const char *PD_StringAttr(void *attr);
PD_DataType PD_DatatTypeAttr(void *attr);
PD_DataLayout PD_DatatLayoutAttr(void *attr);
PD_List PD_ListInt32Attr(void *attr);
PD_List PD_ListInt64Attr(void *attr);
PD_List PD_ListFloatAttr(void *attr);
PD_List PD_ListDoubleAttr(void *attr);
PD_List PD_ListScalarAttr(void *attr);
PD_List PD_ListStringAttr(void *attr);
PD_List PD_ListBoolAttr(void *attr);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -82,6 +82,10 @@ void PD_TensorShareLoDWith(PD_Tensor *dst,
const PD_Tensor *src,
PD_Status *status);
PD_Tensor *PD_OptionalTensorGetPointer(PD_Tensor *tensor);
PD_List PD_TensorVectorToList(PD_Tensor *tensor);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -19,7 +19,129 @@
namespace phi {
namespace capi {
// eager mode
inline std::vector<phi::capi::DenseTensor> PD_TensorVector(PD_Tensor *tensor) {
std::vector<phi::capi::DenseTensor> ret;
auto list = PD_TensorVectorToList(tensor);
auto data = reinterpret_cast<PD_Tensor **>(list.data);
for (size_t i = 0; i < list.size; ++i) {
ret.emplace_back(data[i]);
}
return ret;
}
inline paddle::optional<phi::capi::DenseTensor> PD_OptionalTensor(
PD_Tensor *tensor) {
auto ptr = PD_OptionalTensorGetPointer(tensor);
return ptr ? paddle::optional<phi::capi::DenseTensor>(
phi::capi::DenseTensor(ptr))
: paddle::optional<phi::capi::DenseTensor>(paddle::none);
}
template <typename T>
inline T PD_Attr(void *attr) {
return *reinterpret_cast<T *>(attr);
}
template <>
inline std::string PD_Attr<std::string>(void *attr) {
return PD_StringAttr(attr);
}
template <>
inline PD_DataType PD_Attr<PD_DataType>(void *attr) {
return PD_DatatTypeAttr(attr);
}
template <>
inline PD_DataLayout PD_Attr<PD_DataLayout>(void *attr) {
return PD_DatatLayoutAttr(attr);
}
template <>
inline std::vector<int32_t> PD_Attr<std::vector<int32_t>>(void *attr) {
auto list = PD_ListInt32Attr(attr);
auto data = reinterpret_cast<int32_t *>(list.data);
std::vector<int32_t> cc_list(data, data + list.size);
return cc_list;
}
template <>
inline std::vector<int64_t> PD_Attr<std::vector<int64_t>>(void *attr) {
auto list = PD_ListInt64Attr(attr);
auto data = reinterpret_cast<int64_t *>(list.data);
std::vector<int64_t> cc_list(data, data + list.size);
return cc_list;
}
template <>
inline std::vector<float> PD_Attr<std::vector<float>>(void *attr) {
auto list = PD_ListFloatAttr(attr);
auto data = reinterpret_cast<float *>(list.data);
std::vector<float> cc_list(data, data + list.size);
return cc_list;
}
template <>
inline std::vector<double> PD_Attr<std::vector<double>>(void *attr) {
auto list = PD_ListDoubleAttr(attr);
auto data = reinterpret_cast<double *>(list.data);
std::vector<double> cc_list(data, data + list.size);
return cc_list;
}
template <>
inline phi::capi::Scalar PD_Attr<phi::capi::Scalar>(void *attr) {
return phi::capi::Scalar(reinterpret_cast<PD_Scalar *>(attr));
}
template <>
inline phi::capi::IntArray PD_Attr<phi::capi::IntArray>(void *attr) {
return phi::capi::IntArray(reinterpret_cast<PD_IntArray *>(attr));
}
template <>
inline phi::capi::Place PD_Attr<phi::capi::Place>(void *attr) {
return phi::capi::Place(reinterpret_cast<PD_Place *>(attr));
}
template <>
inline std::vector<phi::capi::Scalar> PD_Attr<std::vector<phi::capi::Scalar>>(
void *attr) {
auto c_list = PD_ListScalarAttr(attr);
auto data = reinterpret_cast<PD_Scalar **>(c_list.data);
std::vector<phi::capi::Scalar> list;
for (size_t i = 0; i < c_list.size; ++i) {
list.emplace_back(data[i]);
}
PD_DeletePointerList(c_list);
return list;
}
template <>
inline std::vector<std::string> PD_Attr<std::vector<std::string>>(void *attr) {
auto c_list = PD_ListStringAttr(attr);
auto data = reinterpret_cast<char **>(c_list.data);
std::vector<std::string> list;
for (size_t i = 0; i < c_list.size; ++i) {
list.emplace_back(data[i]);
}
PD_DeletePointerList(c_list);
return list;
}
template <>
inline std::vector<bool> PD_Attr<std::vector<bool>>(void *attr) {
auto c_list = PD_ListBoolAttr(attr);
std::vector<bool> list;
auto data = reinterpret_cast<uint8_t *>(c_list.data);
for (size_t i = 0; i < c_list.size; ++i) {
list[i] = static_cast<bool>(data[i]);
}
PD_DeleteUInt8List(c_list);
return list;
}
//
inline phi::capi::DeviceContext PD_GetDeviceContext(PD_KernelContext *ctx) {
return phi::capi::DeviceContext(PD_KernelContextGetDeviceContext(ctx));
}
......@@ -189,7 +311,7 @@ inline std::vector<phi::capi::Scalar> PD_AttrAt<std::vector<phi::capi::Scalar>>(
template <>
inline std::vector<std::string> PD_AttrAt<std::vector<std::string>>(
PD_KernelContext *ctx, size_t index) {
auto c_list = PD_KernelContextListScalarAttrAt(ctx, index);
auto c_list = PD_KernelContextListStringAttrAt(ctx, index);
auto data = reinterpret_cast<char **>(c_list.data);
std::vector<std::string> list;
for (size_t i = 0; i < c_list.size; ++i) {
......
......@@ -220,4 +220,89 @@ PD_DataLayout PD_KernelContextDataLayoutAttrAt(PD_KernelContext* ctx,
kernel_context->AttrAt<phi::DataLayout>(index));
}
// eager
const char* PD_StringAttr(void* attr) {
auto* str = reinterpret_cast<std::string*>(attr);
return str->c_str();
}
PD_DataType PD_DatatTypeAttr(void* attr) {
auto* dtype = reinterpret_cast<phi::DataType*>(attr);
return phi::capi::ToPDDataType(*dtype);
}
PD_DataLayout PD_DatatLayoutAttr(void* attr) {
auto* layout = reinterpret_cast<phi::DataLayout*>(attr);
return phi::capi::ToPDDataLayout(*layout);
}
PD_List PD_ListInt32Attr(void* attr) {
PD_List list;
const auto& cc_list = *reinterpret_cast<std::vector<int32_t>*>(attr);
list.size = cc_list.size();
list.data = const_cast<int32_t*>(cc_list.data());
return list;
}
PD_List PD_ListInt64Attr(void* attr) {
PD_List list;
const auto& cc_list = *reinterpret_cast<std::vector<int64_t>*>(attr);
list.size = cc_list.size();
list.data = const_cast<int64_t*>(cc_list.data());
return list;
}
PD_List PD_ListFloatAttr(void* attr) {
PD_List list;
const auto& cc_list = *reinterpret_cast<std::vector<float>*>(attr);
list.size = cc_list.size();
list.data = const_cast<float*>(cc_list.data());
return list;
}
PD_List PD_ListDoubleAttr(void* attr) {
PD_List list;
const auto& cc_list = *reinterpret_cast<std::vector<double>*>(attr);
list.size = cc_list.size();
list.data = const_cast<double*>(cc_list.data());
return list;
}
PD_List PD_ListScalarAttr(void* attr) {
PD_List list;
const auto& cc_list = *reinterpret_cast<std::vector<phi::Scalar>*>(attr);
list.size = cc_list.size();
auto data = new PD_Scalar*[list.size];
for (size_t i = 0; i < list.size; ++i) {
data[i] =
const_cast<PD_Scalar*>(reinterpret_cast<const PD_Scalar*>(&cc_list[i]));
}
list.data = data;
return list;
}
PD_List PD_ListStringAttr(void* attr) {
PD_List list;
const auto& cc_list = *reinterpret_cast<std::vector<std::string>*>(attr);
list.size = cc_list.size();
auto data = new char*[list.size];
for (size_t i = 0; i < list.size; ++i) {
data[i] = const_cast<char*>(cc_list[i].data());
}
list.data = reinterpret_cast<void*>(data);
return list;
}
PD_List PD_ListBoolAttr(void* attr) {
PD_List list;
const auto& cc_list = *reinterpret_cast<std::vector<bool>*>(attr);
list.size = cc_list.size();
auto data = reinterpret_cast<uint8_t*>(new uint8_t[cc_list.size()]);
for (size_t i = 0; i < cc_list.size(); ++i) {
data[i] = static_cast<uint8_t>(cc_list[i]);
}
list.data = data;
return list;
}
PD_REGISTER_CAPI(kernel_context);
......@@ -299,4 +299,19 @@ void PD_TensorShareLoDWith(PD_Tensor* dst,
meta_dst.share_lod(meta_src);
}
PD_Tensor* PD_OptionalTensorGetPointer(PD_Tensor* tensor) {
auto cc_tensor =
reinterpret_cast<paddle::optional<phi::DenseTensor>*>(tensor);
return reinterpret_cast<PD_Tensor*>(cc_tensor->get_ptr());
}
PD_List PD_TensorVectorToList(PD_Tensor* tensor) {
auto cc_tensor =
reinterpret_cast<std::vector<const phi::DenseTensor*>*>(tensor);
PD_List list;
list.size = cc_tensor->size();
list.data = cc_tensor->data();
return list;
}
PD_REGISTER_CAPI(tensor);
if(WITH_CUSTOM_DEVICE)
py_test(test_custom_device_data_loader SRCS test_custom_device_data_loader.py)
py_test(test_custom_cpu_plugin SRCS test_custom_cpu_plugin.py)
set_tests_properties(test_custom_cpu_plugin PROPERTIES TIMEOUT 120)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include "paddle/phi/backends/device_ext.h"
#define MEMORY_FRACTION 0.5f
C_Status Init() { return C_SUCCESS; }
C_Status InitDevice(const C_Device device) { return C_SUCCESS; }
C_Status SetDevice(const C_Device device) { return C_SUCCESS; }
C_Status GetDevice(const C_Device device) {
device->id = 0;
return C_SUCCESS;
}
C_Status DestroyDevice(const C_Device device) { return C_SUCCESS; }
C_Status Finalize() { return C_SUCCESS; }
C_Status GetDevicesCount(size_t *count) {
*count = 1;
return C_SUCCESS;
}
C_Status GetDevicesList(size_t *devices) {
devices[0] = 0;
return C_SUCCESS;
}
C_Status MemCpy(const C_Device device,
void *dst,
const void *src,
size_t size) {
memcpy(dst, src, size);
return C_SUCCESS;
}
C_Status AsyncMemCpy(const C_Device device,
C_Stream stream,
void *dst,
const void *src,
size_t size) {
memcpy(dst, src, size);
return C_SUCCESS;
}
C_Status MemCpyP2P(const C_Device dst_device,
const C_Device src_device,
void *dst,
const void *src,
size_t size) {
memcpy(dst, src, size);
return C_SUCCESS;
}
C_Status AsyncMemCpyP2P(const C_Device dst_device,
const C_Device src_device,
C_Stream stream,
void *dst,
const void *src,
size_t size) {
memcpy(dst, src, size);
return C_SUCCESS;
}
C_Status Allocate(const C_Device device, void **ptr, size_t size) {
auto data = malloc(size);
if (data) {
*ptr = data;
return C_SUCCESS;
} else {
*ptr = nullptr;
}
return C_FAILED;
}
C_Status Deallocate(const C_Device device, void *ptr, size_t size) {
free(ptr);
return C_SUCCESS;
}
C_Status CreateStream(const C_Device device, C_Stream *stream) {
stream = nullptr;
return C_SUCCESS;
}
C_Status DestroyStream(const C_Device device, C_Stream stream) {
return C_SUCCESS;
}
C_Status CreateEvent(const C_Device device, C_Event *event) {
return C_SUCCESS;
}
C_Status RecordEvent(const C_Device device, C_Stream stream, C_Event event) {
return C_SUCCESS;
}
C_Status DestroyEvent(const C_Device device, C_Event event) {
return C_SUCCESS;
}
C_Status SyncDevice(const C_Device device) { return C_SUCCESS; }
C_Status SyncStream(const C_Device device, C_Stream stream) {
return C_SUCCESS;
}
C_Status SyncEvent(const C_Device device, C_Event event) { return C_SUCCESS; }
C_Status StreamWaitEvent(const C_Device device,
C_Stream stream,
C_Event event) {
return C_SUCCESS;
}
C_Status VisibleDevices(size_t *devices) { return C_SUCCESS; }
C_Status DeviceMemStats(const C_Device device,
size_t *total_memory,
size_t *free_memory) {
float memusage;
FILE *fp;
char buffer[1024];
size_t byte_read;
char *pos;
fp = fopen("/proc/meminfo", "r");
byte_read = fread(buffer, 1, sizeof(buffer), fp);
fclose(fp);
buffer[byte_read] = '\0';
pos = strstr(buffer, "MemTotal:");
sscanf(pos, "MemTotal: %lu kB", total_memory);
pos = strstr(pos, "MemFree:");
sscanf(pos, "MemFree: %lu kB", free_memory);
*total_memory = *total_memory * 1024;
*free_memory = *free_memory * 1024;
*free_memory = *free_memory * MEMORY_FRACTION;
return C_SUCCESS;
}
C_Status DeviceMinChunkSize(const C_Device device, size_t *size) {
*size = 512;
return C_SUCCESS;
}
void InitPlugin(CustomRuntimeParams *params) {
PADDLE_CUSTOM_RUNTIME_CHECK_VERSION(params);
params->device_type = "custom_cpu";
params->sub_device_type = "v0.1";
memset(reinterpret_cast<void *>(params->interface),
0,
sizeof(C_DeviceInterface));
params->interface->initialize = Init;
params->interface->finalize = Finalize;
params->interface->init_device = InitDevice;
params->interface->set_device = SetDevice;
params->interface->get_device = GetDevice;
params->interface->deinit_device = DestroyDevice;
params->interface->create_stream = CreateStream;
params->interface->destroy_stream = DestroyStream;
params->interface->create_event = CreateEvent;
params->interface->destroy_event = DestroyEvent;
params->interface->record_event = RecordEvent;
params->interface->synchronize_device = SyncDevice;
params->interface->synchronize_stream = SyncStream;
params->interface->synchronize_event = SyncEvent;
params->interface->stream_wait_event = StreamWaitEvent;
params->interface->memory_copy_h2d = MemCpy;
params->interface->memory_copy_d2d = MemCpy;
params->interface->memory_copy_d2h = MemCpy;
params->interface->memory_copy_p2p = MemCpyP2P;
params->interface->async_memory_copy_h2d = AsyncMemCpy;
params->interface->async_memory_copy_d2d = AsyncMemCpy;
params->interface->async_memory_copy_d2h = AsyncMemCpy;
params->interface->async_memory_copy_p2p = AsyncMemCpyP2P;
params->interface->device_memory_allocate = Allocate;
params->interface->host_memory_allocate = Allocate;
params->interface->unified_memory_allocate = Allocate;
params->interface->device_memory_deallocate = Deallocate;
params->interface->host_memory_deallocate = Deallocate;
params->interface->unified_memory_deallocate = Deallocate;
params->interface->get_device_count = GetDevicesCount;
params->interface->get_device_list = GetDevicesList;
params->interface->device_memory_stats = DeviceMemStats;
params->interface->device_min_chunk_size = DeviceMinChunkSize;
}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import site
from paddle.fluid import core
from distutils.sysconfig import get_python_lib
from distutils.core import setup, Extension
from setuptools.command.build_ext import build_ext
# refer: https://note.qidong.name/2018/03/setup-warning-strict-prototypes
# Avoid a gcc warning below:
# cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid
# for C/ObjC but not for C++
class BuildExt(build_ext):
def build_extensions(self):
if '-Wstrict-prototypes' in self.compiler.compiler_so:
self.compiler.compiler_so.remove('-Wstrict-prototypes')
super(BuildExt, self).build_extensions()
# cc flags
paddle_extra_compile_args = [
'-std=c++14',
'-shared',
'-fPIC',
'-Wno-parentheses',
'-DPADDLE_WITH_CUSTOM_KERNEL',
'-DPADDLE_WITH_CUSTOM_DEVICE',
]
if core.is_compiled_with_npu():
paddle_extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI=0']
# include path
site_packages_path = site.getsitepackages()
include_dirs = list(
map(lambda path: os.path.join(path, 'paddle', 'include'),
site_packages_path))
# include path third_party
compile_third_party_path = os.path.join(os.environ['PADDLE_ROOT'],
'build/third_party')
include_dirs += [
os.path.join(compile_third_party_path, 'boost/src/extern_boost'), # boost
os.path.join(compile_third_party_path, 'install/gflags/include'), # gflags
os.path.join(compile_third_party_path, 'install/glog/include'), # glog
]
# libs path
library_dirs = list(
map(lambda path: os.path.join(path, 'paddle', 'fluid'), site_packages_path))
# libs
libs = [':core_avx.so']
if not core.has_avx_core and core.has_noavx_core:
libs = [':core_noavx.so']
custom_cpu_plugin_so = Extension('custom_cpu_runtime',
sources=['custom_cpu_runtime.cc'],
include_dirs=include_dirs,
library_dirs=library_dirs,
libraries=libs,
extra_compile_args=paddle_extra_compile_args)
setup(name='custom_kernel_dot',
version='1.0',
description='custom kernel fot compiling',
cmdclass={'build_ext': BuildExt},
ext_modules=[custom_cpu_plugin_so])
......@@ -19,24 +19,29 @@ import unittest
import numpy as np
class TestCustomDeviceDataLoader(unittest.TestCase):
class TestCustomCPUPlugin(unittest.TestCase):
def setUp(self):
# compile so and set to current path
cur_dir = os.path.dirname(os.path.abspath(__file__))
# --inplace to place output so file to current dir
cmd = 'cd {} && {} custom_cpu_setup.py build_ext --inplace'.format(
cur_dir, sys.executable)
cmd = 'rm -rf PaddleCustomDevice && git clone https://github.com/PaddlePaddle/PaddleCustomDevice.git && cd PaddleCustomDevice/backends/custom_cpu && mkdir build && cd build && cmake .. && make -j8'
os.system(cmd)
# set environment for loading and registering compiled custom kernels
# only valid in current process
os.environ['CUSTOM_DEVICE_ROOT'] = cur_dir
os.environ['CUSTOM_DEVICE_ROOT'] = os.path.join(
cur_dir, 'PaddleCustomDevice/backends/custom_cpu/build')
def test_custom_device_dataloader(self):
import paddle
with paddle.fluid.framework._test_eager_guard():
self._test_custom_device_dataloader()
self._test_custom_device_dataloader()
def _test_custom_device_dataloader(self):
import paddle
paddle.set_device('custom_cpu')
dataset = paddle.vision.datasets.MNIST(
mode='test',
......@@ -55,6 +60,66 @@ class TestCustomDeviceDataLoader(unittest.TestCase):
self.assertTrue(label.place.is_custom_place())
break
def test_custom_device_mnist(self):
import paddle
with paddle.fluid.framework._test_eager_guard():
self._test_custom_device_mnist()
self._test_custom_device_mnist()
def _test_custom_device_mnist(self):
import paddle
class MNIST(paddle.nn.Layer):
def __init__(self):
super(MNIST, self).__init__()
self.shape = 1 * 28 * 28
self.size = 10
self.output_weight = self.create_parameter(
[self.shape, self.size])
self.accuracy = paddle.metric.Accuracy()
def forward(self, inputs, label=None):
x = paddle.reshape(inputs, shape=[-1, self.shape])
x = paddle.matmul(x, self.output_weight)
x = paddle.nn.functional.softmax(x)
if label is not None:
self.accuracy.reset()
correct = self.accuracy.compute(x, label)
self.accuracy.update(correct)
acc = self.accuracy.accumulate()
return x, acc
else:
return x
paddle.set_device('custom_cpu')
dataset = paddle.vision.datasets.MNIST(
mode='train',
transform=paddle.vision.transforms.Compose(
[paddle.vision.transforms.ToTensor()]))
loader = paddle.io.DataLoader(dataset,
batch_size=64,
num_workers=1,
shuffle=True)
mnist = MNIST()
sgd = paddle.optimizer.SGD(learning_rate=0.01,
parameters=mnist.parameters())
data = next(loader())
img = data[0]
label = data[1]
label_int32 = paddle.cast(label, 'int32')
pred, acc = mnist(img, label_int32)
avg_loss = paddle.nn.functional.cross_entropy(pred, label_int32)
avg_loss.backward()
sgd.step()
sgd.clear_grad()
self.assertTrue(pred.place.is_custom_place())
def tearDown(self):
del os.environ['CUSTOM_DEVICE_ROOT']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册