提交 1d4f02cc 编写于 作者: D dongzhihong

Merge remote-tracking branch 'origin/develop' into go_optimizer

......@@ -49,6 +49,7 @@ option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF)
option(ON_TRAVIS "Exclude special unit test on Travis CI" OFF)
option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF)
option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
# CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE)
......@@ -129,6 +130,10 @@ if(WITH_GPU)
endif(NOT WITH_DSO)
endif(WITH_GPU)
if(USE_NNPACK)
list(APPEND EXTERNAL_LIBS ${NNPACK_LIB} ${PTHREADPOOL_LIB} "rt")
endif(USE_NNPACK)
add_subdirectory(proto)
# "add_subdirectory(paddle)" and "add_subdirectory(python)" should be
......
......@@ -101,23 +101,16 @@ function(merge_static_libs TARGET_NAME)
# First get the file names of the libraries to be merged
foreach(lib ${libs})
get_target_property(libtype ${lib} TYPE)
if(NOT libtype STREQUAL "STATIC_LIBRARY")
message(FATAL_ERROR "merge_static_libs can only process static libraries")
endif()
set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
endforeach()
if(APPLE) # Use OSX's libtool to merge archives
add_custom_target(${TARGET_NAME}_archive
COMMAND libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${libs}
)
add_library(${TARGET_NAME} STATIC IMPORTED GLOBAL)
set_property(TARGET ${TARGET_NAME} PROPERTY
IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a")
add_dependencies(${TARGET_NAME} ${TARGET_NAME}_archive)
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
add_library(${TARGET_NAME} STATIC ${dummyfile})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a"
COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles})
else() # general UNIX: use "ar" to extract objects and re-add to a common lib
foreach(lib ${libs})
set(objlistfile ${lib}.objlist) # list of objects in the input library
......
......@@ -105,3 +105,48 @@ shared_library(api
### Implementation
As above example CMakeLists.txt executes, each function invocation adds "nodes" to a dependency graph. It also use this graph to generate CMake commands including `add_executable`, `add_dependencies`, `target_link_libraries`, and `add_test`.
### Using Package Manager For Go
Building Go binaries and libraries need to satisfy their dependencies, generally
we can do `go get ./...` to download and compile all external dependencies. The
problems are:
1. `go get` will always get the latest code from the default branch of the
remote repo, so changes of dependents might break the build. This is very
different with what we already have in `cmake/external` which download a
specific version or commit id of the dependency.
1. Some locations can not access external dependencies through the internet, as mentioned
in https://github.com/PaddlePaddle/Paddle/issues/2605. Using package management
tools can package the dependencies as a "vendor" package, which can be mirrored
at many cloud file hosting, so users what to compile paddle by themselves can
download this "vendor" package from a mirror site.
#### Choose A Suitable Tool
As mentioned by @wangkuiyi, [Here](https://github.com/golang/go/wiki/PackageManagementTools)
list dozens of Go package managers. We choose the tool using following principles:
- Most "active" projects with more stars, more pull requests or commits
- Widely used project
After comparing all these projects, we shall choose between the most popular
tools: Godep and Glide.
Here's a brief comparison between Godep and Glide
: https://github.com/Masterminds/glide/wiki/Go-Package-Manager-Comparison. There are
also many complaints about using `Godep`. There's also a new "official" pakcage
management tool has been started at: https://github.com/golang/dep to resolve
such problems, but it's currently at Alpha stage. So the best choice now is
glide obviously.
#### Manage Go Packages
- Dependencies: `go/glide.yaml` will store the dependencies and their versions which
is directly imported by paddle. `go/glide.lock` will store all dependencies recursively
with their commit id. Builds will "lock" to these packages if we don't `glide up`
them
- Vendor package: `go/vendor` directory will generated when running `cmake` command. `cmake`
will download the code corresponding to `go/glide.lock`. If we put a vendor folder
under `go/`, cmake will just check the commit id to the packages under the folder,
if commit id matches, there will be no download at all.
# Design Doc: Save Model
## Overview
The model is the output of the training process. There are two
ways from which user can obtain a model:
- Save model triggered by user code: user code asks PaddlePaddle to
save a model.
- Convert model from the checkpoint: model being converted from
pservers' periodic checkpoint. In this way, the user can cancel a
job at any time, and still have a relatively fresh model (we
checkpoint around every 5 minutes).
### Trainer Saving Model vs. Pservers Saving Model
Both trainers and pservers have access to the model. So the model can
be saved from a trainer or pservers. We need to decide where the model
is saved from.
#### Dense Update vs. Sparse Update
There are two types of model update methods: dense update and sparse
update (when the model parameter is configured to be sparse).
- Dense update
Every trainer has it's own full copy of the model. Every model
update will update the entire model.
- Sparse update
The training input is sparse, and the trainer does not have the
entire model. It will only download the sub-model necessary related
to the input. When updating the model, only the sub-model related to
the training input is updated.
#### Pservers Saving Model
The benefit of letting pservers save model is they have the entire
model all the time. However, since pservers are on different nodes, it
requires a merging process to merge model shards into the same
model. Thus requires the pservers to write models to a distributed
filesystem, making the checkpoint shards visible to the merge program.
#### Trainer Saving Model
The benefit of letting one trainer to save the model is it does not
require a distributed filesystem. And it's reusing the same save model
logic when training locally - except when doing sparse update, the
trainer needs to download the entire model during the saving process.
#### Conclusion
Given trainer saving model does not require a distributed filesystem,
and is an intuitive extension to trainer saving model when training
locally, we decide to let the trainer save the model when doing
distributed training.
### Convert Model from Checkpoint
TODO
## Timeline
We first implement trainer save the model. Converting the latest
snapshot to a model will be a TODO for future.
## Trainer Save Model
### Trainer Election
One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to
elect one trainer. When not using etcd, unique trainer IDs will be
given by the administrator, the trainer whose ID is "0" is elected to
save the model.
### Model Save Path
Each trainer will be given the directory to save the model. The
elected trainer will save the model to
`given-directory/trainerID`. Since the trainer ID is unique, this
would prevent concurrent save to the same file when multiple trainers
are elected to save the model when split-brain problem happens.
### What Happens When Model Is Saving
It takes some time to save model, we need to define what will happen
when save model is taking place.
When doing dense update, the trainer uses the local model. Pservers
does not need to pause model update.
When doing sparse update. The trainer needs to download the entire
model while saving. To get the most accurate model, the model update
needs to be paused before the download starts and resumed after the
download finishes. Otherwise, the trainer gets a model that is
"polluted": some part of the model is old, some part of the model is
new.
It's unclear that the "polluted" model will be inferior due to the
stochastic nature of deep learning, and pausing the model update will
add more complexity to the system. Since supporting sparse update is a
TODO item. We defer the evaluation of pause the model update or not
during saving model to the future.
......@@ -31,7 +31,7 @@ def event_handler(event):
# define training dataset reader
def train_reader():
train_x = np.array([[1, 1], [1, 2], [3, 4], [5, 2]])
train_y = np.array([-2, -3, -7, -7])
train_y = np.array([[-2], [-3], [-7], [-7]])
def reader():
for i in xrange(train_y.shape[0]):
......
......@@ -10,6 +10,14 @@ if(WITH_GPU)
cuda_compile(cu_objs ${cu_files})
endif()
if(USE_NNPACK)
include(nnpack/nnpack.cmake)
list(APPEND cpp_files nnpack/NNPACKConvOp.cpp)
if(WITH_TESTING)
add_unittest(NNPACKConvOpTest nnpack/NNPACKConvOpTest.cpp)
endif()
endif()
add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
add_dependencies(paddle_function ${external_project_dependencies})
add_dependencies(paddle_function paddle_proto)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "nnpack.h"
#include "paddle/function/ConvOp.h"
DEFINE_bool(nnpack_allocate_outside,
false,
"Allocate and free workspace memory outside the NNPACK interface.");
DEFINE_int32(nnpack_num_threads,
0,
"The number of nnpack threads"
"default: 0; 0 to disable threadpool.");
namespace paddle {
nnp_convolution_algorithm get_nnp_convolution_algorithm(
const std::string& algorithm) {
if (algorithm == "auto") {
return nnp_convolution_algorithm_auto;
} else if (algorithm == "ft8x8") {
return nnp_convolution_algorithm_ft8x8;
} else if (algorithm == "ft16x16") {
return nnp_convolution_algorithm_ft16x16;
} else if (algorithm == "wt8x8") {
return nnp_convolution_algorithm_wt8x8;
} else if (algorithm == "implicit-gemm") {
return nnp_convolution_algorithm_implicit_gemm;
} else if (algorithm == "direct") {
return nnp_convolution_algorithm_direct;
} else {
return nnp_convolution_algorithm_auto;
}
}
template <DeviceType Device>
class NNPACKConvFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
CHECK_EQ(groups_, (size_t)1);
algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo"));
// algorithm_ = nnp_convolution_algorithm_auto;
transform_strategy_ = nnp_convolution_transform_strategy_compute;
nnp_status status = nnp_initialize();
CHECK_EQ(status, nnp_status_success);
workspaceBuffer_ = nullptr;
workspaceSize_ = 0;
threadpool_ = nullptr;
if (FLAGS_nnpack_num_threads) {
threadpool_ = pthreadpool_create(FLAGS_nnpack_num_threads);
VLOG(3) << "Number of threads "
<< pthreadpool_get_threads_count(threadpool_);
}
}
~NNPACKConvFunction() {
if (threadpool_) {
pthreadpool_destroy(threadpool_);
}
if (workspaceBuffer_) {
free(workspaceBuffer_);
}
}
virtual void check(const BufferArgs& inputs,
const BufferArgs& outputs) override {
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
checkShape(input, filter, output);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
check(inputs, outputs);
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
// size_t outputHeight = output[2];
// size_t outputWidth = output[3];
nnp_size inputSize = {.width = inputWidth, .height = inputHeight};
nnp_padding padding = {.top = (size_t)paddingH(),
.right = (size_t)paddingW(),
.bottom = (size_t)paddingH(),
.left = (size_t)paddingW()};
nnp_size kernelSize = {.width = filterWidth, .height = filterHeight};
nnp_size outputSubsampling = {.width = (size_t)strideW(),
.height = (size_t)strideH()};
float* inputData = inputs[0].data<float>();
float* filterData = inputs[1].data<float>();
float* outputData = outputs[0].data<float>();
void* bufferPtr = nullptr;
size_t* sizePtr = nullptr;
size_t needSize;
if (FLAGS_nnpack_allocate_outside) {
if (batchSize == 1) {
nnp_status status = nnp_convolution_inference(algorithm_,
transform_strategy_,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
outputSubsampling,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
&needSize,
nnp_activation_identity,
nullptr,
nullptr,
nullptr);
CHECK_EQ(status, nnp_status_success);
} else {
// only supports stride = 1
CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1);
nnp_status status = nnp_convolution_output(algorithm_,
batchSize,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
&needSize,
nnp_activation_identity,
nullptr,
nullptr,
nullptr);
CHECK_EQ(status, nnp_status_success);
}
VLOG(3) << "workspace size is " << needSize;
if (needSize > workspaceSize_) {
workspaceSize_ = needSize;
if (workspaceBuffer_) {
free(workspaceBuffer_);
} else {
posix_memalign(&workspaceBuffer_, 64, needSize);
}
}
if (needSize) {
bufferPtr = workspaceBuffer_;
sizePtr = &needSize;
}
}
if (batchSize == 1) {
nnp_status status =
nnp_convolution_inference(algorithm_,
transform_strategy_,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
outputSubsampling,
inputData,
filterData,
nullptr, /* bias */
outputData,
bufferPtr,
sizePtr,
nnp_activation_identity,
nullptr,
threadpool_, /* threadpool */
nullptr);
CHECK_EQ(status, nnp_status_success);
} else {
// only supports stride = 1
CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1);
nnp_status status = nnp_convolution_output(algorithm_,
batchSize,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
inputData,
filterData,
nullptr, /* bias */
outputData,
bufferPtr,
sizePtr,
nnp_activation_identity,
nullptr,
threadpool_, /* threadpool */
nullptr);
CHECK_EQ(status, nnp_status_success);
}
}
private:
nnp_convolution_algorithm algorithm_;
nnp_convolution_transform_strategy transform_strategy_;
void* workspaceBuffer_;
size_t workspaceSize_;
pthreadpool_t threadpool_;
};
REGISTER_TYPED_FUNC(NNPACKConv, CPU, NNPACKConvFunction);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/function/Function.h"
#include "paddle/function/FunctionTest.h"
DEFINE_string(algo,
"auto",
"The algorithm (auto, ft8x8, ft16x16, wt8x8, "
"implicit-gemm, or direct) for computing convolution of NNPACK.");
namespace paddle {
#define IS_NNPACK_SUPPORT(algo, filterSize, stride) \
if (algo == "direct" && filterSize != 1) continue; \
if (algo == "direct" && batchSize != 1) continue; \
if (algo == "wt8x8" && filterSize != 3) continue; \
if (algo == "implicit-gemm" && batchSize != 1) continue; \
if (algo != "auto" && algo != "implicit-gemm" && stride > 1) continue;
class ConvolutionTest {
public:
ConvolutionTest(const std::string& conv1,
const std::string& conv2,
std::string algo = "auto") {
for (size_t batchSize : {1, 32}) {
for (size_t inputSize : {7, 14, 54}) {
for (size_t filterSize : {1, 3, 5}) {
for (size_t inputChannels : {3, 64}) {
for (size_t outputChannels : {3, 64, 128}) {
if (inputChannels < outputChannels) break;
for (size_t stride : {1, 2}) {
// if batchSize > 1 NNPACKConv only supports stride = 1
if (batchSize > 1 && stride > 1) break;
for (size_t padding : {0, 1}) {
if (padding >= filterSize) break;
size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride;
IS_NNPACK_SUPPORT(algo, filterSize, stride);
LOG(INFO) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape shape0{
batchSize, inputChannels, inputSize, inputSize};
TensorShape shape1{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape shape2{
batchSize, outputChannels, outputSize, outputSize};
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape0));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape1));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, shape2));
test.run();
}
}
}
}
}
}
}
}
};
TEST(Convolution, NNPACK) {
// NNPACK only supports stride = 1
ConvolutionTest test("GemmConv-CPU", "NNPACKConv-CPU", FLAGS_algo);
}
} // namespace paddle
# Find the NNPACK library
# NNPACK_ROOT - where to find NNPACK include and library.
#
set(NNPACK_FOUND OFF)
set(NNPACK_ROOT $ENV{NNPACK_ROOT} CACHE PATH "Folder contains NNPACK")
find_path(NNPACK_INC_DIR nnpack.h PATHS ${NNPACK_ROOT}/include)
find_library(NNPACK_LIB NAMES nnpack PATHS ${NNPACK_ROOT}/lib)
find_library(PTHREADPOOL_LIB NAMES pthreadpool PATHS ${NNPACK_ROOT}/lib)
if(NNPACK_INC_DIR AND NNPACK_LIB AND PTHREADPOOL_LIB)
set(NNPACK_FOUND ON)
INCLUDE_DIRECTORIES(${NNPACK_INC_DIR})
else()
message(FATAL_ERROR "Cannot find NNPACK in (${NNPACK_ROOT})")
endif()
......@@ -16,6 +16,10 @@ limitations under the License. */
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
DEFINE_bool(use_nnpack,
false,
"Whether to use nnpack for convolution calculation.");
namespace paddle {
/*
......@@ -37,6 +41,17 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
for (int i = 0; i < config_.inputs_size(); i++) {
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
if (FLAGS_use_nnpack) {
CHECK_EQ(isDeconv_, false);
createFunction(forward_,
"NNPACKConv",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i])
.set("algo", std::string("auto")));
} else {
createFunction(forward_,
!isDeconv_ ? "GemmConv" : "GemmConvGradInput",
FuncConfig()
......@@ -58,6 +73,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
}
}
return true;
}
......
......@@ -13,8 +13,11 @@ set(PY_FILES paddle/__init__.py
${V2_PY_FILES})
add_custom_target(copy_paddle_master)
SET(COPY_PADDLE_MASTER "")
if(WITH_GOLANG)
add_custom_command(TARGET copy_paddle_master
SET(COPY_PADDLE_MASTER "copy_paddle_master")
add_custom_command(TARGET ${COPY_PADDLE_MASTER}
COMMAND cp ${paddle_master_LIB_PATH} ${PROJ_ROOT}/python/paddle/v2/master/
)
add_dependencies(copy_paddle_master paddle_master)
......@@ -26,7 +29,7 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
COMMAND ${CMAKE_COMMAND} -E touch ${OUTPUT_DIR}/.timestamp
DEPENDS gen_proto_py ${PY_FILES} ${external_project_dependencies} copy_paddle_master)
DEPENDS gen_proto_py ${PY_FILES} ${external_project_dependencies} ${COPY_PADDLE_MASTER})
add_custom_target(paddle_python ALL DEPENDS
${OUTPUT_DIR}/.timestamp)
......
......@@ -2082,10 +2082,10 @@ class MaxOutLayer(LayerBase):
class RowConvLayer(LayerBase):
def __init__(self, name, inputs, context_length, **xargs):
super(RowConvLayer, self).__init__(
name, 'maxout', 0, inputs=inputs, **xargs)
name, 'row_conv', 0, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'TransLayer must have one and only one input')
'row convolution layer must have one and only one input.')
input_layer = self.get_input_layer(0)
row_conv_conf = self.config.inputs[0].row_conv_conf
row_conv_conf.context_length = context_length
......
......@@ -7,7 +7,7 @@ layers {
}
layers {
name: "__row_conv_layer_0__"
type: "maxout"
type: "row_conv"
size: 2560
active_type: "relu"
inputs {
......
......@@ -30,6 +30,7 @@ http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
import cPickle
import itertools
import functools
from common import download
import tarfile
import scipy.io as scio
......@@ -54,21 +55,26 @@ TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(sample):
def default_mapper(is_train, sample):
'''
map image bytes data to type needed by model input layer
'''
img, label = sample
img = load_image_bytes(img)
img = simple_transform(img, 256, 224, True)
img = simple_transform(
img, 256, 224, is_train, mean=[103.94, 116.78, 123.68])
return img.flatten().astype('float32'), label
train_mapper = functools.partial(default_mapper, True)
test_mapper = functools.partial(default_mapper, False)
def reader_creator(data_file,
label_file,
setid_file,
dataset_name,
mapper=default_mapper,
mapper,
buffered_size=1024,
use_xmap=True):
'''
......@@ -118,7 +124,7 @@ def reader_creator(data_file,
return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
......@@ -141,7 +147,7 @@ def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
......@@ -164,7 +170,7 @@ def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is
......
......@@ -262,7 +262,12 @@ def left_right_flip(im):
return im[:, ::-1, :]
def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
def simple_transform(im,
resize_size,
crop_size,
is_train,
is_color=True,
mean=None):
"""
Simply data argumentation for training. These operations include
resizing, croping and flipping.
......@@ -288,8 +293,20 @@ def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
im = left_right_flip(im)
else:
im = center_crop(im, crop_size)
if len(im.shape) == 3:
im = to_chw(im)
im = im.astype('float32')
if mean is not None:
mean = np.array(mean, dtype=np.float32)
# mean value, may be one value per channel
if mean.ndim == 1:
mean = mean[:, np.newaxis, np.newaxis]
else:
# elementwise mean
assert len(mean.shape) == len(im)
im -= mean
return im
......@@ -297,7 +314,8 @@ def load_and_transform(filename,
resize_size,
crop_size,
is_train,
is_color=True):
is_color=True,
mean=None):
"""
Load image from the input file `filename` and transform image for
data argumentation. Please refer to the `simple_transform` interface
......@@ -318,5 +336,5 @@ def load_and_transform(filename,
:type is_train: bool
"""
im = load_image(filename)
im = simple_transform(im, resize_size, crop_size, is_train, is_color)
im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean)
return im
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册