提交 ee0a794c 编写于 作者: S sweetsky0901

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into my_unpool_max_2d

......@@ -459,11 +459,11 @@ function(py_test TARGET_NAME)
if(WITH_TESTING)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python
${PYTHON_EXECUTABLE} ${py_test_SRCS}
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction()
......@@ -54,7 +54,7 @@ img_conv
.. _api_v2.layer_context_projection:
context_projection
context_projection
------------------
.. autoclass:: paddle.v2.layer.context_projection
:noindex:
......@@ -70,7 +70,7 @@ Image Pooling Layer
img_pool
--------
.. autoclass:: paddle.v2.layer.img_pool
:noindex:
:noindex:
spp
---
......@@ -104,7 +104,7 @@ sum_to_one_norm
---------------
.. autoclass:: paddle.v2.layer.sum_to_one_norm
:noindex:
cross_channel_norm
------------------
.. autoclass:: paddle.v2.layer.cross_channel_norm
......@@ -114,7 +114,7 @@ row_l2_norm
-----------
.. autoclass:: paddle.v2.layer.row_l2_norm
:noindex:
Recurrent Layers
================
......@@ -415,6 +415,13 @@ multiplex
.. autoclass:: paddle.v2.layer.multiplex
:noindex:
Factorization Machine Layer
============================
factorization_machine
---------------------
.. autoclass:: paddle.v2.layer.factorization_machine
:noindex:
Slicing and Joining Layers
==========================
......
......@@ -55,7 +55,7 @@ paddle_error paddle_matrix_set_row(paddle_matrix mat,
}
PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat,
paddle_real* value) {
paddle_real* value) {
if (mat == nullptr || value == nullptr) return kPD_NULLPTR;
auto ptr = cast(mat);
if (ptr->mat == nullptr) return kPD_NULLPTR;
......@@ -75,7 +75,7 @@ PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat,
}
PD_API paddle_error paddle_matrix_get_value(paddle_matrix mat,
paddle_real* result) {
paddle_real* result) {
if (mat == nullptr || result == nullptr) return kPD_NULLPTR;
auto ptr = cast(mat);
if (ptr->mat == nullptr) return kPD_NULLPTR;
......
......@@ -79,7 +79,7 @@ PD_API paddle_error paddle_matrix_set_row(paddle_matrix mat,
* @note value should contain enough element of data to init the mat
*/
PD_API paddle_error paddle_matrix_set_value(paddle_matrix mat,
paddle_real* value);
paddle_real* value);
/**
* @brief PDMatGetRow Get raw row buffer from matrix
......@@ -93,14 +93,14 @@ PD_API paddle_error paddle_matrix_get_row(paddle_matrix mat,
paddle_real** rawRowBuffer);
/**
* @brief copy data from the matrix
* @brief copy data from the matrix
* @param [in] mat Target matrix
* @param [out] result pointer to store the matrix data
* @param [out] result pointer to store the matrix data
* @return paddle_error
* @note the space of the result should allocated before invoke this API
*/
PD_API paddle_error paddle_matrix_get_value(paddle_matrix mat,
paddle_real* result);
paddle_real* result);
/**
* @brief PDMatCreateNone Create None Matrix
* @return
......
......@@ -135,18 +135,17 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx,
auto dst_ptr = static_cast<void*>(dst->data());
if (platform::is_cpu_place(src.place())) {
memory::Copy(dst_place, dst_ptr, boost::get<platform::CPUPlace>(src.place()),
src_ptr, size);
memory::Copy(dst_place, dst_ptr,
boost::get<platform::CPUPlace>(src.place()), src_ptr, size);
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src.place())) { // NOLINT
memory::Copy(
dst_place, dst_ptr, boost::get<platform::GPUPlace>(src.place()), src_ptr,
size,
dst_place, dst_ptr, boost::get<platform::GPUPlace>(src.place()),
src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
} // namespace framework
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "FactorizationMachineLayer.h"
#include <algorithm>
#include <vector>
#include "paddle/math/SparseMatrix.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(factorization_machine, FactorizationMachineLayer);
bool FactorizationMachineLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
factorSize_ = config_.factor_size();
/* initialize the latentVectors_ */
CHECK_EQ(inputLayers_.size(), 1UL);
size_t inputSize = inputLayers_[0]->getSize();
CHECK_EQ(parameters_[0]->getSize(), inputSize * factorSize_);
latentVectors_ = std::unique_ptr<Weight>(
new Weight(inputSize, factorSize_, parameters_[0]));
return true;
}
void FactorizationMachineLayer::forward(PassType passType) {
Layer::forward(passType);
const MatrixPtr& inputV = getInputValue(0);
size_t batchSize = inputV->getHeight();
size_t outputSize = getSize();
size_t inputSize = inputLayers_[0]->getSize();
reserveOutput(batchSize, outputSize);
MatrixPtr outV = getOutputValue();
Matrix::resizeOrCreate(
latentVectorsSquare_, inputSize, factorSize_, false, useGpu_);
Matrix::resizeOrCreate(
inputMulFactor_, batchSize, factorSize_, false, useGpu_);
Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_);
REGISTER_TIMER_INFO("FmInputMulFactorTimer", getName().c_str());
inputMulFactor_->mul(*inputV, *latentVectors_->getW());
inputMulFactor_->square2(*tmpOut_);
outV->sumRows(*tmpOut_, 0.5, 0);
if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
Matrix::resizeOrCreateSparseMatrix(inputSquare_,
inputV->getHeight(),
inputV->getWidth(),
inputV->getElementCnt(),
inputV->getValueType());
inputSquare_->copyFrom(*inputV);
(dynamic_cast<CpuSparseMatrix*>(inputSquare_.get()))->square2();
} else {
Matrix::resizeOrCreate(
inputSquare_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
inputV->square2(*inputSquare_);
}
latentVectors_->getW()->square2(*latentVectorsSquare_);
tmpOut_->mul(*inputSquare_, *latentVectorsSquare_);
outV->sumRows(*tmpOut_, -0.5, 1.0);
/* activation */ {
REGISTER_TIMER_INFO("FmFwAtvTimer", getName().c_str());
forwardActivation();
}
}
void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
/* Do derivation */ { backwardActivation(); }
const MatrixPtr& inputV = getInputValue(0);
const MatrixPtr& oGrad = getOutputGrad();
Matrix::resizeOrCreate(
tmpSum_, 1, latentVectors_->getW()->getHeight(), false, useGpu_);
MatrixPtr tmpSumTrans = Matrix::create(tmpSum_->getRowBuf(0),
latentVectors_->getW()->getHeight(),
1,
false,
useGpu_);
/* Calculate the gradients of the latentVectors_ matrix */
if (latentVectors_->getWGrad()) {
if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
Matrix::resizeOrCreateSparseMatrix(tmpInput_,
inputV->getHeight(),
inputV->getWidth(),
inputV->getElementCnt());
CpuSparseMatrix* sparseInputV =
dynamic_cast<CpuSparseMatrix*>(inputV.get());
CpuSparseMatrix* sparseInputSquare =
dynamic_cast<CpuSparseMatrix*>(inputSquare_.get());
CpuSparseMatrix* sparseTmpInput =
dynamic_cast<CpuSparseMatrix*>(tmpInput_.get());
sparseTmpInput->copyFrom(*sparseInputV);
sparseTmpInput->rowScale(0, *sparseInputV, *oGrad);
latentVectors_->getWGrad()->mul(
*sparseTmpInput->getTranspose(), *inputMulFactor_, 1, 1);
sparseTmpInput->rowScale(0, *sparseInputSquare, *oGrad);
Matrix::resizeOrCreate(negOnes_, 1, inputV->getHeight(), false, useGpu_);
negOnes_->zeroMem();
negOnes_->add(-1);
tmpSum_->mul(*negOnes_, *sparseTmpInput, 1, 0);
} else {
Matrix::resizeOrCreate(
tmpInput_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
tmpInput_->rowScale(0, *inputV, *oGrad);
latentVectors_->getWGrad()->mul(
*tmpInput_->getTranspose(), *inputMulFactor_, 1, 1);
tmpInput_->rowScale(0, *inputSquare_, *oGrad);
tmpSum_->sumCols(*tmpInput_, -1, 0);
}
latentVectors_->getWGrad()->addRowScale(
0, *latentVectors_->getW(), *tmpSumTrans);
/* Increasing the number of gradient */
latentVectors_->getParameterPtr()->incUpdate(callback);
}
/* Calculate the input layers gradient */
MatrixPtr inGrad = getInputGrad(0);
if (inGrad != NULL) {
inGrad->mul(
*inputMulFactor_, *latentVectors_->getW()->getTranspose(), 1, 1);
tmpSumTrans->sumRows(*latentVectorsSquare_, -1, 0);
inGrad->addColScale(0, *inputV, *tmpSum_);
inGrad->rowScale(0, *inGrad, *oGrad);
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/ThreadLocal.h"
namespace paddle {
/**
* @brief The Factorization Machine models pairwise (order-2) feature
* interactions as inner product of the learned latent vectors corresponding
* to each input feature.
*
* The Factorization Machine can effectively capture feature interactions
* especially when the input is sparse. While in principle FM can model higher
* order feature interaction, in practice usually only order-2 feature
* interactions are considered. The Factorization Machine Layer here only
* computes the order-2 interations with the formula:
*
* \f[
* y = \sum_{i=1}^{n-1}\sum_{j=i+1}^n\langle v_i, v_j \rangle x_i x_j
* \f]
*
* The detailed calculation for forward and backward can be found at this paper:
*
* Factorization machines.
*
* The config file api is factorization_machine.
*/
class FactorizationMachineLayer : public Layer {
protected:
// The latent vectors, shape: (size, factorSize_)
// Each row of the latentVectors_ matrix is the latent vector
// corresponding to one input feature dimension
std::unique_ptr<Weight> latentVectors_;
// The hyperparameter that defines the dimensionality of the factorization
size_t factorSize_;
private:
// Store the square values of the letent vectors matrix
MatrixPtr latentVectorsSquare_;
// Store the square values of input matrix
MatrixPtr inputSquare_;
// The result of input matrix * latent vector matrix that will be used in
// both forward and backward step
MatrixPtr inputMulFactor_;
// Store temporary calculation result
MatrixPtr tmpOut_;
MatrixPtr tmpSum_;
MatrixPtr tmpInput_;
// Negative identity matrix
MatrixPtr negOnes_;
public:
explicit FactorizationMachineLayer(const LayerConfig& config)
: Layer(config) {}
~FactorizationMachineLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
} // namespace paddle
......@@ -2464,6 +2464,25 @@ TEST(Layer, L2DistanceLayer) {
}
}
void testFactorizationMachineLayer(InputType type, bool useGpu) {
const int FACTOR_SIZE = 10;
TestConfig config;
config.layerConfig.set_type("factorization_machine");
config.layerConfig.set_factor_size(FACTOR_SIZE);
config.layerConfig.set_size(1);
config.biasSize = 0;
config.inputDefs.push_back({type, "layer_0", 128, 1280});
config.layerConfig.add_inputs();
testLayerGrad(config, "factorization_machine", 16, false, useGpu, false);
}
TEST(Layer, FactorizationMachineLayer) {
for (auto useGpu : {false, true}) {
testFactorizationMachineLayer(INPUT_DATA, useGpu);
}
testFactorizationMachineLayer(INPUT_SPARSE_FLOAT_VALUE_DATA, false);
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......
......@@ -260,6 +260,35 @@ void CpuSparseMatrix::printOneRow(std::ostream& os, size_t idx) const {
os << ";";
}
void CpuSparseMatrix::rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c) {
CHECK(getFormat() != SPARSE_CSC) << "Not supported";
CHECK_EQ(height_, b.getHeight());
CHECK_EQ(width_, b.getWidth());
real* A = getValue();
real* B = b.getValue();
if (b.getValueType() == FLOAT_VALUE) {
for (size_t i = 0; i < height_; i++) {
size_t start = getRowStartIdx(i);
size_t end = getRowStartIdx(i + 1);
CHECK_EQ(start, b.getRowStartIdx(i));
CHECK_EQ(end, b.getRowStartIdx(i + 1));
for (size_t j = start; j < end; j++) {
A[j] = B[j] * c.getElement(i, cCol);
}
}
} else if (b.getValueType() == NO_VALUE) {
for (size_t i = 0; i < height_; i++) {
size_t start = getRowStartIdx(i);
size_t end = getRowStartIdx(i + 1);
CHECK_EQ(start, b.getRowStartIdx(i));
CHECK_EQ(end, b.getRowStartIdx(i + 1));
for (size_t j = start; j < end; j++) {
A[j] = c.getElement(i, cCol);
}
}
}
}
void CpuSparseMatrix::randomizeUniform() {
CHECK_LE(elementCnt_, height_ * width_);
if (valueType_ == FLOAT_VALUE) {
......
......@@ -239,6 +239,15 @@ public:
const unsigned int* cols,
const real* values);
/**
* @brief this_row = b_row * c_row[cCol]
*
* @param[in] cCol the column of matrix c used to scale each row of b
* @param[in] b CpuSparseMatrix
* @param[in] c Matrix
*/
void rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c);
void randomizeUniform();
void copyFrom(const GpuSparseMatrix& src, hl_stream_t stream);
......
......@@ -23,8 +23,7 @@ template <typename T>
class MaxOutFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * output,
const framework::Tensor& input, framework::Tensor* output,
int groups) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
......@@ -37,34 +36,30 @@ class MaxOutFunctor<platform::CPUPlace, T> {
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
int new_bindex = c_size * i;
int new_bindex = c_size * i;
for (int c = 0; c < output_channels; ++c) {
int new_cindex = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
T ele = static_cast<T>(-FLT_MAX);
for (int ph = 0; ph < groups; ++ph) {
T x = input_data[(new_bindex + new_cindex) * groups
+ ph * fea_size + f];
T x = input_data[(new_bindex + new_cindex) * groups +
ph * fea_size + f];
ele = ele > x ? ele : x;
}
output_data[(new_bindex+new_cindex+f)] = ele;
output_data[(new_bindex + new_cindex + f)] = ele;
}
}
}
}
};
template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
public:
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * input_grad,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups) {
const framework::Tensor& output_grad, int groups) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -84,11 +79,11 @@ public:
bool continue_match = true;
int output_idx = blen + clen + f;
for (int g = 0; g < groups && continue_match; ++g) {
int input_idx = input_idx0 + fea_size * g;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
}
int input_idx = input_idx0 + fea_size * g;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
}
}
}
}
......
......@@ -21,9 +21,9 @@ namespace math {
template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
const int channels,
const int input_height, const int input_width,
int groups, T* output_data ) {
const int channels, const int input_height,
const int input_width, int groups,
T* output_data) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -34,7 +34,7 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
T ele = static_cast<T>(-FLT_MAX);
for (int g = 0; g < groups; ++g) {
T x = input_data[data_idx + g * feat_len];
......@@ -44,34 +44,35 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
}
}
template <typename T>
__global__ void KernelMaxoutGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_height, const int input_width, int groups) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
__global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
const T* output_data, const T* output_grad,
T* input_grad, const int channels,
const int input_height, const int input_width,
int groups) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
int max_index = -1;
bool continue_match = true;
for (int g = 0; g < groups && continue_match; ++g) {
if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len;
continue_match = false;
break;
}
}
if (max_index != -1) {
input_grad[max_index] += output_grad[index];
int max_index = -1;
bool continue_match = true;
for (int g = 0; g < groups && continue_match; ++g) {
if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len;
continue_match = false;
break;
}
}
if (max_index != -1) {
input_grad[max_index] += output_grad[index];
}
}
}
/*
* All tensors are in NCHW format.
......@@ -80,7 +81,7 @@ template <typename T>
class MaxOutFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor * output,
const framework::Tensor& input, framework::Tensor* output,
int groups) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
......@@ -92,7 +93,7 @@ class MaxOutFunctor<platform::GPUPlace, T> {
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = output->numel();
int nthreads = output->numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
......@@ -101,8 +102,7 @@ class MaxOutFunctor<platform::GPUPlace, T> {
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, input_channels,
input_height, input_width, groups,
output_data);
input_height, input_width, groups, output_data);
}
};
/*
......@@ -112,11 +112,9 @@ template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * input_grad,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups) {
const framework::Tensor& output_grad, int groups) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
......@@ -129,7 +127,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = output.numel();
int nthreads = output.numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
......@@ -137,9 +135,9 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
KernelMaxoutGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, groups);
.stream()>>>(nthreads, input_data, output_data,
output_grad_data, input_grad_data, input_channels,
input_height, input_width, groups);
}
};
......
......@@ -21,15 +21,14 @@ namespace paddle {
namespace operators {
namespace math {
#define FLT_MAX \
__FLT_MAX__
#define FLT_MAX __FLT_MAX__
template <typename Place, typename T>
class MaxOutFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor * output,
const framework::Tensor& input, framework::Tensor* output,
int groups);
};
......@@ -37,8 +36,7 @@ template <typename Place, class T>
class MaxOutGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input,
framework::Tensor * input_grad,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups);
};
......
......@@ -22,16 +22,17 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
AddInput(
"X",
"(Tensor) The input tensor of maxout operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"(Tensor) The output tensor of maxout operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
"(Tensor) The output tensor of maxout operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
AddAttr<int>(
"groups",
R"DOC("Specifies how many groups the input tensor will be split"
......@@ -59,21 +60,19 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class MaxOutOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MaxoutOp"
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MaxoutOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MaxoutOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
int groups = ctx->Attrs().Get<int>("groups");
// check groups > 1
PADDLE_ENFORCE_GT(
groups, 1,
"groups should be larger than 1 in maxoutop");
PADDLE_ENFORCE_GT(groups, 1, "groups should be larger than 1 in maxoutop");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
output_shape.push_back(in_x_dims[2]);
output_shape.push_back(in_x_dims[3]);
......@@ -87,18 +86,17 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad,
ops::MaxOutOpGrad);
REGISTER_OP_CPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::CPUPlace,
float>);
REGISTER_OP_CPU_KERNEL(maxout_grad,
ops::MaxOutGradKernel<paddle::platform::CPUPlace,
float>);
ops::MaxOutOpGrad);
REGISTER_OP_CPU_KERNEL(maxout,
ops::MaxOutKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
maxout_grad, ops::MaxOutGradKernel<paddle::platform::CPUPlace, float>);
......@@ -18,8 +18,6 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(maxout,
ops::MaxOutKernel<paddle::platform::GPUPlace, float>,
ops::MaxOutKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(maxout_grad,
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
float>,
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
double>);
REGISTER_OP_GPU_KERNEL(
maxout_grad, ops::MaxOutGradKernel<paddle::platform::GPUPlace, float>,
ops::MaxOutGradKernel<paddle::platform::GPUPlace, double>);
......@@ -53,7 +53,7 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
math::MaxOutGradFunctor<Place, T> maxout_backward;
maxout_backward(context.device_context(), *in_x, in_x_grad, *out,
*out_grad, groups);
*out_grad, groups);
}
}
};
......
......@@ -43,8 +43,8 @@ class ROIPoolOp : public framework::OperatorWithKernel {
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
PADDLE_ENFORCE(rois_dims[1] == kROISize,
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
"given as [[batch_id, x1, y1, x2, y2], …].");
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
......@@ -65,7 +65,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("Argmax", out_dims);
}
}
protected:
framework::OpKernelType GetKernelType(
......@@ -100,7 +100,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ROIPoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor), "
......@@ -125,21 +125,22 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor), "
"Argmaxes corresponding to indices in X used "
"for gradient computation. Only output "
"if arg “is_test” is false.").AsIntermediate();
"if arg “is_test” is false.")
.AsIntermediate();
AddAttr<float>("spatial_scale",
"(float, default 1.0), "
"Multiplicative spatial scale factor "
"to translate ROI coords from their input scale "
"to the scale used when pooling.")
.SetDefault(1.0);
.SetDefault(1.0);
AddAttr<int>("pooled_height",
"(int, default 1), "
"The pooled output height.")
.SetDefault(1);
.SetDefault(1);
AddAttr<int>("pooled_width",
"(int, default 1), "
"The pooled output width.")
.SetDefault(1);
.SetDefault(1);
AddComment(R"DOC(
ROIPool operator
......@@ -153,11 +154,10 @@ https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker, roi_pool_grad,
ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
roi_pool,
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>,
roi_pool, ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>,
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
roi_pool_grad,
......
......@@ -29,101 +29,95 @@ static inline int NumBlocks(const int N) {
kNumMaxinumNumBlocks);
}
template <typename T>
__global__ void GPUROIPoolForward(
const int nthreads, const T* input_data, const int64_t* input_rois,
const float spatial_scale, const int channels, const int height,
const int width, const int pooled_height, const int pooled_width,
T* output_data, int64_t* argmax_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int roi_start_w = round(offset_input_rois[1] * spatial_scale);
int roi_start_h = round(offset_input_rois[2] * spatial_scale);
int roi_end_w = round(offset_input_rois[3] * spatial_scale);
int roi_end_h = round(offset_input_rois[4] * spatial_scale);
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
int maxidx = -1;
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_data_index = h * width + w;
if (offset_input_data[input_data_index] > maxval) {
maxval = offset_input_data[input_data_index];
maxidx = input_data_index;
}
template <typename T>
__global__ void GPUROIPoolForward(const int nthreads, const T* input_data,
const int64_t* input_rois,
const float spatial_scale, const int channels,
const int height, const int width,
const int pooled_height,
const int pooled_width, T* output_data,
int64_t* argmax_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (size_t i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int roi_start_w = round(offset_input_rois[1] * spatial_scale);
int roi_start_h = round(offset_input_rois[2] * spatial_scale);
int roi_end_w = round(offset_input_rois[3] * spatial_scale);
int roi_end_h = round(offset_input_rois[4] * spatial_scale);
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
int maxidx = -1;
const T* offset_input_data =
input_data + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_data_index = h * width + w;
if (offset_input_data[input_data_index] > maxval) {
maxval = offset_input_data[input_data_index];
maxidx = input_data_index;
}
}
output_data[index] = maxval;
if (argmax_data) {
argmax_data[index] = maxidx;
}
}
output_data[index] = maxval;
if (argmax_data) {
argmax_data[index] = maxidx;
}
}
}
template <typename T>
__global__ void GPUROIPoolBackward(
const int nthreads,
const int64_t* input_rois,
const T* output_grad,
const int64_t* argmax_data,
const int num_rois,
const float spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int input_offset = (roi_batch_ind * channels + c) * height * width;
int output_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_output_grad = output_grad + output_offset;
T* offset_input_grad = input_grad + input_offset;
const int64_t* offset_argmax_data = argmax_data + output_offset;
int argmax = offset_argmax_data[ph * pooled_width + pw];
if (argmax != -1) {
platform::CudaAtomicAdd(offset_input_grad + argmax,
const int nthreads, const int64_t* input_rois, const T* output_grad,
const int64_t* argmax_data, const int num_rois, const float spatial_scale,
const int channels, const int height, const int width,
const int pooled_height, const int pooled_width, T* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const int64_t* offset_input_rois = input_rois + n * kROISize;
int roi_batch_ind = offset_input_rois[0];
int input_offset = (roi_batch_ind * channels + c) * height * width;
int output_offset = (n * channels + c) * pooled_height * pooled_width;
const T* offset_output_grad = output_grad + output_offset;
T* offset_input_grad = input_grad + input_offset;
const int64_t* offset_argmax_data = argmax_data + output_offset;
int argmax = offset_argmax_data[ph * pooled_width + pw];
if (argmax != -1) {
platform::CudaAtomicAdd(
offset_input_grad + argmax,
static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
}
}
}
}
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
......@@ -145,25 +139,18 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
int width = in_dims[3];
size_t rois_num = rois->dims()[0];
if (rois_num== 0) return;
if (rois_num == 0) return;
int output_size = out->numel();
int blocks = NumBlocks(output_size);
int threads = kNumCUDAThreads;
GPUROIPoolForward<T>
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size,
in->data<T>(),
rois->data<int64_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
GPUROIPoolForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_size, in->data<T>(), rois->data<int64_t>(), spatial_scale,
channels, height, width, pooled_height, pooled_width,
out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
}
};
......@@ -175,10 +162,8 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto* rois = ctx.Input<Tensor>("ROIs");
auto* argmax = ctx.Input<Tensor>("Argmax");
auto* out_grad =
ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<Tensor>(framework::GradVarName("X"));
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
......@@ -199,21 +184,13 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUROIPoolBackward<T>
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size,
rois->data<int64_t>(),
out_grad->data<T>(),
argmax->data<int64_t>(),
rois_num,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
x_grad->mutable_data<T>(ctx.GetPlace()));
}
GPUROIPoolBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
output_grad_size, rois->data<int64_t>(), out_grad->data<T>(),
argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
width, pooled_height, pooled_width,
x_grad->mutable_data<T>(ctx.GetPlace()));
}
}
}
};
......@@ -223,8 +200,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
roi_pool,
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>,
roi_pool, ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>,
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
roi_pool_grad,
......
......@@ -133,54 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::Tensor>("ROIs");
auto* argmax = ctx.Input<framework::Tensor>("Argmax");
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
if (x_grad) {
int channels = in->dims()[1];
auto in_stride = framework::stride(in->dims());
auto roi_stride = framework::stride(rois->dims());
if (in_grad) {
const int64_t* rois_data = rois->data<int64_t>();
int rois_num = rois->dims()[0];
T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
const T* out_grad_data = out_grad->data<T>();
const int64_t* argmax_data = argmax->data<int64_t>();
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
set_zero(ctx.device_context(), in_grad, static_cast<T>(0));
size_t roi_offset = roi_stride[0];
size_t batch_offset = in_stride[0];
size_t channel_offset = in_stride[1];
auto in_stride = framework::stride(in->dims());
auto argmax_stride = framework::stride(argmax->dims());
auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out_grad->dims());
const T* out_grad_data = out_grad->data<T>();
size_t pool_channel_offset = pooled_height * pooled_width;
const int64_t* argmax_data = argmax->data<int64_t>();
int rois_num = rois->dims()[0];
int channels = in->dims()[1];
for (size_t n = 0; n < rois_num; ++n) {
size_t roi_batch_idx = rois_data[0];
T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx;
for (int n = 0; n < rois_num; ++n) {
int roi_batch_idx = rois_data[0];
T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0];
for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
size_t pool_index = ph * pooled_width + pw;
int pool_index = ph * pooled_width + pw;
if (argmax_data[pool_index] >= 0) {
size_t index = static_cast<size_t>(argmax_data[pool_index]);
auto index = argmax_data[pool_index];
batch_grad_data[index] += out_grad_data[pool_index];
}
}
}
batch_grad_data += channel_offset;
out_grad_data += pool_channel_offset;
argmax_data += pool_channel_offset;
batch_grad_data += in_stride[1];
out_grad_data += out_stride[1];
argmax_data += argmax_stride[1];
}
rois_data += roi_offset;
rois_data += roi_stride[0];
}
}
}
......
......@@ -45,7 +45,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
// Initialize the output's dims to maximum,
// and re-set to real dims by the value of Offset and Length at kernel
ctx->SetOutputDim("Out", input_dims);
}
}
protected:
framework::OpKernelType GetKernelType(
......@@ -93,8 +93,7 @@ class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor), "
"a vector<int> to describe the length of every input sequence for "
"sub sequence item.");
AddOutput("Out",
"(LoDTensor), the output of SequenceSliceOp.");
AddOutput("Out", "(LoDTensor), the output of SequenceSliceOp.");
AddComment(R"DOC(
Sequence slice operator
......
......@@ -55,7 +55,7 @@ SGD operator
This operator implements one step of the stochastic gradient descent algorithm.
$$param_out = param - learning_rate * grad$$
$$param\_out = param - learning\_rate * grad$$
)DOC");
}
......
......@@ -57,11 +57,21 @@ class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "");
AddInput("RankTable", "");
AddInput("I", "");
AddOutput("Out", "");
AddComment("");
AddInput("X", "(LoDTensor) The RNN step memory to be shrinked.");
AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN.");
AddInput("I",
"(LoDTensor) The step index. The RNN step memory 'X' will be "
"shrinked to match the size of the input of the index'th step.");
AddOutput("Out", "(LoDTensor) The shrinked RNN step memory.");
AddComment(
R"DOC(
In dynamic RNN, we are able to handle sequences of different lengths.
Because of the multiple lengths, the size of each step input can be
different, which may lead to a mismatching between the input of
the current step and the memory generated by the previous one. This
operator shrinks memory according to the size of the next step input,
to make sure that they can match each other.
)DOC");
}
};
......
......@@ -544,6 +544,9 @@ message LayerConfig {
// for batch normalization layer
// The small constant added to the variance to improve numeric stability.
optional double epsilon = 60 [ default = 0.00001 ];
// for factorization machine layer
optional uint32 factor_size = 61;
}
message EvaluatorConfig {
......
......@@ -3870,6 +3870,21 @@ class ScaleSubRegionLayer(LayerBase):
image_conf.channels)
@config_layer('factorization_machine')
class FactorizationMachineLayer(LayerBase):
def __init__(self, name, inputs, factor_size, **xargs):
super(FactorizationMachineLayer, self).__init__(
name, 'factorization_machine', size=1, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'factorization machine layer must have one and only one input.')
self.config.factor_size = factor_size
input_layer = self.get_input_layer(0)
psize = input_layer.size * factor_size
dims = [input_layer.size, factor_size]
self.create_input_parameter(0, psize, dims)
# Deprecated, use a new layer specific class instead
@config_func
def Layer(name, type, **xargs):
......
......@@ -148,6 +148,7 @@ __all__ = [
'resize_layer',
'sub_seq_layer',
'scale_sub_region_layer',
'factorization_machine',
]
......@@ -264,6 +265,8 @@ class LayerType(object):
SCALE_SUB_REGION_LAYER = 'scale_sub_region'
FACTORIZATION_MACHINE = 'factorization_machine'
@staticmethod
def is_layer_type(type_name):
"""
......@@ -1900,9 +1903,12 @@ def repeat_layer(input,
A layer for repeating the input for num_repeats times.
If as_row_vector:
.. math::
y = [x_1,\cdots, x_n, \cdots, x_1, \cdots, x_n]
If not as_row_vector:
.. math::
y = [x_1,\cdots, x_1, \cdots, x_n, \cdots, x_n]
......@@ -1915,19 +1921,19 @@ def repeat_layer(input,
:param input: The input of this layer.
:type input: LayerOutput
:param num_repeats: Repeat the input so many times
:param num_repeats: The times of repeating the input.
:type num_repeats: int
:param name: The name of this layer. It is optional.
:param as_row_vector: True for treating input as row vector and repeating
in the column direction. This is equivalent to apply
concat_layer() with num_repeats same input.
False for treating input as column vector and repeating
in the row direction.
:type name: basestring
:param as_row_vector: Whether to treat the input as row vectors or not. If
the parameter is set to True, the repeating operation
will be performed in the column direction. Otherwise,
it will be performed in the row direction.
:type as_row_vector: bool
:param act: Activation type. IdentityActivation is the default activation.
:type act: BaseActivation
:type name: basestring
:param layer_attr: extra layer attributes.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -1974,13 +1980,14 @@ def seq_reshape_layer(input,
:param input: The input of this layer.
:type input: LayerOutput
:param reshape_size: the size of reshaped sequence.
:param reshape_size: The dimension of the reshaped sequence.
:type reshape_size: int
:param name: The name of this layer. It is optional.
:type name: basestring
:param act: Activation type. IdentityActivation is the default activation.
:type act: BaseActivation
:param layer_attr: extra layer attributes.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute.
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
......@@ -2008,7 +2015,7 @@ def seq_reshape_layer(input,
@layer_support()
def interpolation_layer(input, weight, name=None, layer_attr=None):
"""
This layer is for linear interpolation with two inputs,
This layer performs linear interpolation on two inputs,
which is used in NEURAL TURING MACHINE.
.. math::
......@@ -2030,7 +2037,8 @@ def interpolation_layer(input, weight, name=None, layer_attr=None):
:type weight: LayerOutput
:param name: The name of this layer. It is optional.
:type name: basestring
:param layer_attr: extra layer attributes.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -2064,7 +2072,7 @@ def bilinear_interp_layer(input,
name=None,
layer_attr=None):
"""
This layer is to implement bilinear interpolation on conv layer output.
This layer implements bilinear interpolation on convolutional layer's output.
Please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation
......@@ -2074,18 +2082,19 @@ def bilinear_interp_layer(input,
bilinear = bilinear_interp_layer(input=layer1, out_size_x=64, out_size_y=64)
:param input: A input layer.
:type input: LayerOutput.
:param out_size_x: bilinear interpolation output width.
:type out_size_x: int | None
:param out_size_y: bilinear interpolation output height.
:type out_size_y: int | None
:param name: The layer's name, which cna not be specified.
:type name: None | basestring
:param layer_attr: Extra Layer attribute.
:type layer_attr: ExtraLayerAttribute
:param input: The input of this layer.
:type input: LayerOutput.
:param out_size_x: The width of the output.
:type out_size_x: int
:param out_size_y: The height of the output.
:type out_size_y: int
:param name: The name of this layer. It is optional.
:type name: basestring
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
:rtype: LayerOutput
"""
assert input.layer_type == LayerType.CONV_LAYER
assert isinstance(input.activation, LinearActivation)
......@@ -2120,8 +2129,8 @@ def power_layer(input, weight, name=None, layer_attr=None):
.. math::
y = x^w
where :math:`x` is a input vector, :math:`w` is scalar weight,
and :math:`y` is a output vector.
where :math:`x` is an input vector, :math:`w` is a scalar exponent,
and :math:`y` is an output vector.
The example usage is:
......@@ -2131,11 +2140,12 @@ def power_layer(input, weight, name=None, layer_attr=None):
:param input: The input of this layer.
:type input: LayerOutput
:param weight: Weight layer.
:param weight: The exponent of the power.
:type weight: LayerOutput
:param name: The name of this layer. It is optional.
:type name: basestring
:param layer_attr: extra layer attributes.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -2175,11 +2185,12 @@ def scaling_layer(input, weight, name=None, layer_attr=None):
:param input: The input of this layer.
:type input: LayerOutput
:param weight: Weight layer.
:param weight: The weight of each sample.
:type weight: LayerOutput
:param name: The name of this layer. It is optional.
:type name: basestring
:param layer_attr: extra layer attributes.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -2217,7 +2228,8 @@ def trans_layer(input, name=None, layer_attr=None):
:type input: LayerOutput
:param name: The name of this layer. It is optional.
:type name: basestring
:param layer_attr: extra layer attributes.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -2253,11 +2265,14 @@ def rotate_layer(input, height, width, name=None, layer_attr=None):
:param input: The input of this layer.
:type input: LayerOutput
:param height: The height of the sample matrix
:param height: The height of the sample matrix.
:type height: int
:param width: The width of the sample matrix.
:type width: int
:param name: The name of this layer. It is optional.
:type name: basestring
:param layer_attr: extra layer attributes.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -2302,15 +2317,15 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
:param name: The name of this layer. It is optional.
:type name: basestring
:param a: input layer a
:param a: The first input of this layer.
:type a: LayerOutput
:param b: input layer b
:param b: The second input of this layer.
:type b: LayerOutput
:param scale: scale for cosine value. default is 5.
:param scale: The scale of the cosine similarity. 1 is the default value.
:type scale: float
:param size: layer size. NOTE size_a * size should equal size_b.
:param size: The dimension of this layer. NOTE size_a * size should equal size_b.
:type size: int
:param layer_attr: Extra Layer Attribute.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -2395,8 +2410,10 @@ def hsigmoid(input,
"""
Organize the classes into a binary tree. At each node, a sigmoid function
is used to calculate the probability of belonging to the right branch.
This idea is from "F. Morin, Y. Bengio (AISTATS 05):
Hierarchical Probabilistic Neural Network Language Model."
Reference:
`Hierarchical Probabilistic Neural Network Language Model
<http://www.gatsby.ucl.ac.uk/aistats/fullpapers/208.pdf>`_
The example usage is:
......@@ -2407,19 +2424,21 @@ def hsigmoid(input,
:param input: The input of this layer.
:type input: LayerOutput | list | tuple
:param label: Label layer.
:param label: The input label.
:type label: LayerOutput
:param num_classes: number of classes.
:type num_classes: int | None
:param num_classes: The number of classes. And it should be larger than 2. If the parameter
is not set or set to None, its actual value will be automatically set to
the number of labels.
:type num_classes: int
:param name: The name of this layer. It is optional.
:type name: basestring
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param param_attr: Parameter Attribute. None means default parameter.
:type param_attr: ParameterAttribute | None
:param layer_attr: Extra Layer Attribute.
:param param_attr: The parameter attribute. See ParameterAttribute for details.
:type param_attr: ParameterAttribute
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for details.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
......@@ -2969,8 +2988,8 @@ def spp_layer(input,
A layer performs spatial pyramid pooling.
Reference:
Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition
https://arxiv.org/abs/1406.4729
`Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition
https://arxiv.org/abs/1406.4729`_
The example usage is:
......@@ -3071,8 +3090,8 @@ def img_cmrnorm_layer(input,
Response normalization across feature maps.
Reference:
ImageNet Classification with Deep Convolutional Neural Networks
http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
`ImageNet Classification with Deep Convolutional Neural Networks
http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf`_
The example usage is:
......@@ -3138,9 +3157,9 @@ def batch_norm_layer(input,
y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift
Reference:
Batch Normalization: Accelerating Deep Network Training by Reducing
`Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift
http://arxiv.org/abs/1502.03167
http://arxiv.org/abs/1502.03167`_
The example usage is:
......@@ -4241,7 +4260,7 @@ def dot_prod_layer(input1, input2, name=None, layer_attr=None):
:param name: The name of this layer. It is optional.
:type name: basestring
:param input1: The first input layer.
:type input: LayerOutput
:type input1: LayerOutput
:param input2: The second input layer.
:type input2: LayerOutput
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
......@@ -5397,10 +5416,10 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None):
to be devided by groups.
Reference:
Maxout Networks
http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
Multi-digit Number Recognition from Street View Imagery using Deep Convolutional Neural Networks
https://arxiv.org/pdf/1312.6082v4.pdf
`Maxout Networks
http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf`_
`Multi-digit Number Recognition from Street View Imagery using Deep Convolutional Neural Networks
https://arxiv.org/pdf/1312.6082v4.pdf`_
.. math::
y_{si+j} = \max_k x_{gsi + sk + j}
......@@ -5465,9 +5484,9 @@ def ctc_layer(input,
alignment between the inputs and the target labels is unknown.
Reference:
Connectionist Temporal Classification: Labelling Unsegmented Sequence Data
`Connectionist Temporal Classification: Labelling Unsegmented Sequence Data
with Recurrent Neural Networks
http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf
http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf`_
Note:
Considering the 'blank' label needed by CTC, you need to use (num_classes + 1)
......@@ -5539,9 +5558,9 @@ def warp_ctc_layer(input,
install it to :code:`third_party/install/warpctc` directory.
Reference:
Connectionist Temporal Classification: Labelling Unsegmented Sequence Data
`Connectionist Temporal Classification: Labelling Unsegmented Sequence Data
with Recurrent Neural Networks
http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf
http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf`_
Note:
- Let num_classes represents the category number. Considering the 'blank'
......@@ -5761,8 +5780,8 @@ def nce_layer(input,
Noise-contrastive estimation.
Reference:
A fast and simple algorithm for training neural probabilistic language
models. https://www.cs.toronto.edu/~amnih/papers/ncelm.pdf
`A fast and simple algorithm for training neural probabilistic language
models. https://www.cs.toronto.edu/~amnih/papers/ncelm.pdf`_
The example usage is:
......@@ -5877,8 +5896,8 @@ def rank_cost(left,
A cost Layer for learning to rank using gradient descent.
Reference:
Learning to Rank using Gradient Descent
http://research.microsoft.com/en-us/um/people/cburges/papers/ICML_ranking.pdf
`Learning to Rank using Gradient Descent
http://research.microsoft.com/en-us/um/people/cburges/papers/ICML_ranking.pdf`_
.. math::
......@@ -6413,8 +6432,8 @@ def smooth_l1_cost(input, label, name=None, coeff=1.0, layer_attr=None):
smooth_{L1}(x) = \\begin{cases} 0.5x^2& \\text{if} \\ |x| < 1 \\\\ |x|-0.5& \\text{otherwise} \end{cases}
Reference:
Fast R-CNN
https://arxiv.org/pdf/1504.08083v2.pdf
`Fast R-CNN
https://arxiv.org/pdf/1504.08083v2.pdf`_
The example usage is:
......@@ -6620,8 +6639,8 @@ def prelu_layer(input,
The Parametric Relu activation that actives outputs with a learnable weight.
Reference:
Delving Deep into Rectifiers: Surpassing Human-Level Performance on
ImageNet Classification http://arxiv.org/pdf/1502.01852v1.pdf
`Delving Deep into Rectifiers: Surpassing Human-Level Performance on
ImageNet Classification http://arxiv.org/pdf/1502.01852v1.pdf`_
.. math::
z_i &\\quad if \\quad z_i > 0 \\\\
......@@ -6717,8 +6736,8 @@ def gated_unit_layer(input,
product between :match:`X'` and :math:`\sigma` is finally returned.
Reference:
Language Modeling with Gated Convolutional Networks
https://arxiv.org/abs/1612.08083
`Language Modeling with Gated Convolutional Networks
https://arxiv.org/abs/1612.08083`_
.. math::
y=\\text{act}(X \cdot W + b)\otimes \sigma(X \cdot V + c)
......@@ -7387,3 +7406,73 @@ def scale_sub_region_layer(input, indices, value, name=None):
parents=[input, indices],
num_filters=input.num_filters,
size=input.size)
@wrap_name_default()
@wrap_act_default(act=LinearActivation())
@wrap_param_attr_default()
@layer_support()
def factorization_machine(input,
factor_size,
act=None,
name=None,
param_attr=None,
layer_attr=None):
"""
The Factorization Machine models pairwise feature interactions as inner
product of the learned latent vectors corresponding to each input feature.
The Factorization Machine can effectively capture feature interactions
especially when the input is sparse.
This implementation only consider the 2-order feature interactions using
Factorization Machine with the formula:
.. math::
y = \sum_{i=1}^{n-1}\sum_{j=i+1}^n\langle v_i, v_j \rangle x_i x_j
Note:
X is the input vector with size n. V is the factor matrix. Each row of V
is the latent vector corresponding to each input dimesion. The size of
each latent vector is k.
For details of Factorization Machine, please refer to the paper:
Factorization machines.
.. code-block:: python
first_order = paddle.layer.fc(input=input,
size=1,
act=paddle.activation.Linear())
second_order = paddle.layer.factorization_machine(input=input,
factor_size=10)
fm = paddle.layer.addto(input=[first_order, second_order],
act=paddle.activation.Linear(),
bias_attr=False)
:param input: The input layer. Supported input types: all input data types
on CPU, and only dense input types on GPU.
:type input: LayerOutput
:param factor_size: The hyperparameter that defines the dimensionality of
the latent vector size.
:type context_len: int
:param act: Activation Type. Default is linear activation.
:type act: BaseActivation
:param param_attr: The parameter attribute. See ParameterAttribute for
details.
:type param_attr: ParameterAttribute
:param layer_attr: Extra Layer config.
:type layer_attr: ExtraLayerAttribute|None
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
assert factor_size > 0, "the factor_size must be greater than 0."
Layer(
inputs=[Input(input.name, **param_attr.attr)],
name=name,
factor_size=factor_size,
type=LayerType.FACTORIZATION_MACHINE,
active_type=act.name,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name, LayerType.FACTORIZATION_MACHINE, input, activation=act, size=1)
......@@ -11,6 +11,7 @@ test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_l
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
test_scale_sub_region_layer test_dot_prod_layer test_l2_distance_layer)
test_scale_sub_region_layer test_dot_prod_layer test_l2_distance_layer
test_factorization_machine)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "data"
type: "data"
size: 1024
active_type: ""
}
layers {
name: "__factorization_machine_0__"
type: "factorization_machine"
size: 1
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "___factorization_machine_0__.w0"
}
factor_size: 10
}
parameters {
name: "___factorization_machine_0__.w0"
size: 10240
initial_mean: 0.0
initial_std: 0.03125
dims: 1024
dims: 10
initial_strategy: 0
initial_smart: true
}
input_layer_names: "data"
output_layer_names: "__factorization_machine_0__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "__factorization_machine_0__"
input_layer_names: "data"
output_layer_names: "__factorization_machine_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
data = data_layer(name='data', size=1024)
fm = factorization_machine(input=data, factor_size=10)
outputs(fm)
......@@ -38,6 +38,7 @@ UCI_TEST_DATA = None
URL_MODEL = 'https://github.com/PaddlePaddle/book/raw/develop/01.fit_a_line/fit_a_line.tar'
MD5_MODEL = '52fc3da8ef3937822fcdd87ee05c0c9b'
def feature_range(maximums, minimums):
import matplotlib
matplotlib.use('Agg')
......@@ -114,7 +115,8 @@ def test():
def model():
tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar', MD5_MODEL)
tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar',
MD5_MODEL)
with open(tar_file, 'r') as f:
parameters = Parameters.from_tar(f)
return parameters
......
......@@ -395,7 +395,11 @@ class Block(object):
return v
def all_parameters(self):
return {v for k, v in self.vars.iteritems() if isinstance(v, Parameter)}
return list(self.iter_parameters())
def iter_parameters(self):
return (item[1] for item in self.vars.iteritems()
if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs):
var = Variable(self, *args, **kwargs)
......@@ -469,6 +473,37 @@ class Block(object):
for index in range(len(self.ops)):
assert self.ops[index].desc == ops_in_cpp[index]
def copy_param_info_from(self, other):
"""
Copy the information of parameters from other block
Args:
other(Block): other block
Returns:
None
"""
if not isinstance(other, Block):
raise TypeError("copy_param_info_from should be invoked with Block")
for p in other.iter_parameters():
assert isinstance(p, Parameter)
v = self.vars.get(p.name, None)
if v is None:
raise ValueError("copy_param_info_from should be invoked with "
"same topology")
assert isinstance(v, Variable)
new_p = Parameter(
block=self,
shape=v.shape,
dtype=v.dtype,
type=v.type,
lod_level=v.lod_level,
stop_gradient=p.stop_gradient,
trainable=p.trainable,
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
name=v.name)
self.vars[new_p.name] = new_p
class Program(object):
def __init__(self):
......@@ -489,6 +524,7 @@ class Program(object):
p.desc = core.ProgramDesc(self.desc)
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
p.sync_with_cpp()
p.copy_param_info_from(self)
return p
def prune(self, targets):
......@@ -572,6 +608,24 @@ class Program(object):
for block in self.blocks:
block.sync_with_cpp()
def copy_param_info_from(self, other):
"""
Copy the information of parameters from other program.
Args:
other(Program): Other program
Returns:
None
"""
if not isinstance(other, Program):
raise TypeError("copy_param_info_from should be invoked with "
"Program")
if len(self.blocks) != len(other.blocks):
raise ValueError("copy_param_info_from should be invoked with two "
"program, with represent the same topology")
self.global_block().copy_param_info_from(other.global_block())
def list_vars(self):
for each_block in self.blocks:
for each_var in each_block.vars.itervalues():
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS test_image_classification_train)
py_test(test_image_classification_train_resnet SRCS test_image_classification_train.py ARGS resnet)
py_test(test_image_classification_train_vgg SRCS test_image_classification_train.py ARGS vgg)
# default test
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
from __future__ import print_function
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import sys
def resnet_cifar10(input, depth=32):
......@@ -80,11 +82,18 @@ data_shape = [3, 32, 32]
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Add neural network config
# option 1. resnet
# net = resnet_cifar10(images, 32)
# option 2. vgg
net = vgg16_bn_drop(images)
net_type = "vgg"
if len(sys.argv) >= 2:
net_type = sys.argv[1]
if net_type == "vgg":
print("train vgg net")
net = vgg16_bn_drop(images)
elif net_type == "resnet":
print("train resnet")
net = resnet_cifar10(images, 32)
else:
raise ValueError("%s network is not supported" % net_type)
predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
......
......@@ -35,6 +35,13 @@ opts = optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
inference_program = fluid.default_main_program().clone()
test_accuracy = fluid.evaluator.Accuracy(
input=predict, label=label, main_program=inference_program)
test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
inference_program = fluid.io.get_inference_program(
test_target, main_program=inference_program)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
......@@ -69,11 +76,6 @@ for pass_id in range(PASS_NUM):
acc = np.array(outs[1])
pass_acc = accuracy.eval(exe)
test_accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
inference_program = fluid.io.get_inference_program(test_target)
test_accuracy.reset(exe)
for data in test_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
......
......@@ -30,9 +30,7 @@ class TestMaxOutOp(OpTest):
def init_test_case(self):
self.MaxOut_forward_naive = maxout_forward_naive
self.shape = [100, 6, 2, 2]
self.groups=2
self.groups = 2
if __name__ == '__main__':
......
from __future__ import print_function
import unittest
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.framework import g_main_program
import paddle.v2.fluid.layers as layers
class TestProgram(unittest.TestCase):
......@@ -48,8 +50,8 @@ class TestProgram(unittest.TestCase):
# FIXME(yuyang18): We manual compare the output string, since the order
# of variable could be changed.
print prog
print prog.clone()
print(prog)
print(prog.clone())
def test_parse_program_from_string(self):
prog = Program()
......@@ -67,8 +69,8 @@ class TestProgram(unittest.TestCase):
binary_str = prog.desc.serialize_to_string()
prog_restored = Program.parse_from_string(binary_str)
print prog
print prog_restored
print(prog)
print(prog_restored)
def test_append_backward(self):
prog = Program()
......@@ -123,6 +125,20 @@ class TestProgram(unittest.TestCase):
actual_ops.append(op.type)
self.assertEqual(actual_ops, expect_ops)
def test_program_clone_with_parameter(self):
main_program = Program()
startup_program = Program()
kwargs = {
'main_program': main_program,
'startup_program': startup_program
}
d = layers.data(name='x', shape=[784], dtype='float32', **kwargs)
hidden = layers.fc(input=d, size=100, **kwargs)
layers.fc(input=hidden, size=100, **kwargs)
new_program = main_program.clone()
self.assertNotEqual(0, len(new_program.blocks[0].all_parameters()))
if __name__ == '__main__':
unittest.main()
......@@ -4,24 +4,22 @@ import math
import sys
from op_test import OpTest
class TestROIPoolOp(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.calc_roi_pool()
self.inputs = {
'X': self.x,
'ROIs': self.rois}
self.inputs = {'X': self.x, 'ROIs': self.rois}
self.attrs = {
'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width}
'pooled_width': self.pooled_width
}
self.outputs = {
'Out': self.outs,
'Argmax': self.argmaxes}
self.outputs = {'Out': self.outs, 'Argmax': self.argmaxes}
def init_test_case(self):
self.batch_size = 5
......@@ -30,10 +28,9 @@ class TestROIPoolOp(OpTest):
self.width = 4
# n, c, h, w
self.x_dim = (self.batch_size, self.channels,
self.height, self.width)
self.x_dim = (self.batch_size, self.channels, self.height, self.width)
self.spatial_scale = 1.0/4.0
self.spatial_scale = 1.0 / 4.0
self.pooled_height = 2
self.pooled_width = 2
self.rois_num = 2
......@@ -41,13 +38,11 @@ class TestROIPoolOp(OpTest):
self.x = np.random.random(self.x_dim).astype('float32')
def calc_roi_pool(self):
out_data = np.zeros(
(self.rois_num, self.channels,
self.pooled_height, self.pooled_width))
argmax_data = np.zeros(
(self.rois_num, self.channels,
self.pooled_height, self.pooled_width))
out_data = np.zeros((self.rois_num, self.channels, self.pooled_height,
self.pooled_width))
argmax_data = np.zeros((self.rois_num, self.channels,
self.pooled_height, self.pooled_width))
for i in range(self.rois_num):
roi = self.rois[i]
roi_batch_id = roi[0]
......@@ -56,8 +51,8 @@ class TestROIPoolOp(OpTest):
roi_end_w = int(round(roi[3] * self.spatial_scale))
roi_end_h = int(round(roi[4] * self.spatial_scale))
roi_height = int(max(roi_end_h - roi_start_h + 1, 1));
roi_width = int(max(roi_end_w - roi_start_w + 1, 1));
roi_height = int(max(roi_end_h - roi_start_h + 1, 1))
roi_width = int(max(roi_end_w - roi_start_w + 1, 1))
x_i = self.x[roi_batch_id]
......@@ -84,7 +79,7 @@ class TestROIPoolOp(OpTest):
out_data[i, c, ph, pw] = -sys.float_info.max
argmax_data[i, c, ph, pw] = -1
for h in range(hstart, hend):
for w in range(wstart, wend):
if x_i[c, h, w] > out_data[i, c, ph, pw]:
......@@ -104,11 +99,11 @@ class TestROIPoolOp(OpTest):
y1 = np.random.random_integers(
0, self.height / self.spatial_scale - self.pooled_height)
x2 = np.random.random_integers(
x1 + self.pooled_width, self.width / self.spatial_scale)
y2 = np.random.random_integers(
y1 + self.pooled_height, self.height / self.spatial_scale)
x2 = np.random.random_integers(x1 + self.pooled_width,
self.width / self.spatial_scale)
y2 = np.random.random_integers(y1 + self.pooled_height,
self.height / self.spatial_scale)
roi = [batch_ids[i], x1, y1, x2, y2]
rois.append(roi)
self.rois = np.array(rois).astype("int64")
......@@ -123,5 +118,6 @@ class TestROIPoolOp(OpTest):
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册