提交 fa5a5a3a 编写于 作者: Y Yu Yang

Merge branch 'develop' of github.com:baidu/Paddle into feature/move_pybind_to_framework_dir

......@@ -27,7 +27,7 @@ RUN apt-get update && \
git python-pip python-dev openssh-server bison \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
curl sed grep graphviz libjpeg-dev zlib1g-dev \
python-numpy python-matplotlib gcc g++ \
python-numpy python-matplotlib gcc-4.8 g++-4.8 \
automake locales clang-format-3.8 swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev \
......
......@@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag)
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif()
# TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem.
# Use Debug mode instead for now.
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9)
set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE)
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers.
......
......@@ -104,6 +104,11 @@ cross_channel_norm
------------------
.. autoclass:: paddle.v2.layer.cross_channel_norm
:noindex:
row_l2_norm
-----------
.. autoclass:: paddle.v2.layer.row_l2_norm
:noindex:
Recurrent Layers
================
......@@ -320,6 +325,11 @@ scaling
.. autoclass:: paddle.v2.layer.scaling
:noindex:
clip
----
.. autoclass:: paddle.v2.layer.clip
:noindex:
slope_intercept
---------------
.. autoclass:: paddle.v2.layer.slope_intercept
......
......@@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
real alpha = 1.0f;
real beta = 1.0f;
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size;
if (batch_size > 1024 && g_cudnn_lib_version < 6000) {
LOG(INFO) << " To process current batch data with size " << batch_size
<< " (>1024), cudnnBatchNorm requires cuDNN version >= 6000."
<< " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED,"
<< " just recompile PaddlePaddle with cuDNN >= 6000, replacing"
<< " current version " << g_cudnn_lib_version;
}
CHECK_CUDNN(
dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle,
mode,
......
......@@ -38,7 +38,7 @@ cc_library(backward SRCS backward.cc DEPS net)
cc_test(backward_test SRCS backward_test.cc DEPS backward)
cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python
DEPS pybind python backward
fc_op
sgd_op
add_op
......
......@@ -400,6 +400,14 @@ class GradOpRegisterHelper {
return 0; \
}
/**
* Macro to Forbid user register Gradient Operator.
*/
#define NO_GRADIENT(__op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type##__op_type##_grad, \
"NO_GRADIENT must be in global namespace")
/**
* Macro to Register OperatorKernel.
*/
......
......@@ -20,16 +20,16 @@ namespace paddle {
namespace framework {
template <>
Eigen::DefaultDevice* ExecutionContext::GetEigenDevice<
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice*
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
return *device_context_.get_eigen_device<Eigen::GpuDevice>();
}
#endif
......
......@@ -253,7 +253,7 @@ class ExecutionContext : public OperatorContext {
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;
DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); }
......
......@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......@@ -16,11 +16,14 @@ limitations under the License. */
#include <fstream>
#include <vector>
#include "paddle/framework/backward.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor_bind.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
......@@ -43,6 +46,10 @@ template <typename ClassType>
void ExposeOperator(ClassType &m) {
m.def("infer_shape", &ClassType::type::InferShape)
.def("run", &ClassType::type::Run)
.def("type",
[](const typename ClassType::type &op) -> std::string {
return op.type_;
})
.def("outputs",
[](const typename ClassType::type &op) -> std::vector<std::string> {
return op.outputs_;
......@@ -55,6 +62,14 @@ static size_t UniqueIntegerGenerator() {
return generator.fetch_add(1);
}
bool IsCompileGPU() {
#ifdef PADDLE_ONLY_CPU
return false;
#else
return true;
#endif
}
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle");
......@@ -68,16 +83,29 @@ PYBIND11_PLUGIN(core) {
self.Resize(make_ddim(dim));
})
.def("alloc_float",
[](Tensor &self) {
self.mutable_data<float>(paddle::platform::CPUPlace());
[](pd::Tensor &self, paddle::platform::GPUPlace &place) {
self.mutable_data<float>(place);
})
.def("alloc_float",
[](pd::Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<float>(place);
})
.def("alloc_int",
[](Tensor &self) {
self.mutable_data<int>(paddle::platform::CPUPlace());
[](pd::Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<int>(place);
})
.def("set", PyTensorSetFromArray<float>)
.def("set", PyTensorSetFromArray<int>)
.def("shape", [](Tensor &self) { return vectorize(self.dims()); });
.def("alloc_int",
[](pd::Tensor &self, paddle::platform::GPUPlace &place) {
self.mutable_data<int>(place);
})
.def("set", paddle::pybind::PyCPUTensorSetFromArray<float>)
.def("set", paddle::pybind::PyCPUTensorSetFromArray<int>)
#ifndef PADDLE_ONLY_CPU
.def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
.def("set", paddle::pybind::PyCUDATensorSetFromArray<int>)
#endif
.def("shape",
[](pd::Tensor &self) { return pd::vectorize(self.dims()); });
py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
......@@ -124,13 +152,29 @@ All parameter, weight, gradient are variables in Paddle.
m.def_submodule(
"var_names",
"The module will return special predefined variable name in Paddle")
.def("empty", OperatorBase::EMPTY_VAR_NAME)
.def("temp", OperatorBase::TMP_VAR_NAME);
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME)
.def("temp", pd::OperatorBase::TMP_VAR_NAME);
// clang-format off
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("cpu_context", []() -> paddle::platform::DeviceContext * {
return new paddle::platform::CPUDeviceContext();
});
.def_static("create",
[](paddle::platform::CPUPlace& place)
-> paddle::platform::DeviceContext* {
return new paddle::platform::CPUDeviceContext();
})
.def_static("create",
[](paddle::platform::GPUPlace& place)
-> paddle::platform::DeviceContext* {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("GPUPlace is not supported in CPU device.");
#else
return new paddle::platform::CUDADeviceContext(place);
#endif
});
// clang-format on
py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
m, "Operator");
......@@ -144,6 +188,13 @@ All parameter, weight, gradient are variables in Paddle.
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
});
operator_base.def("backward",
[](const pd::OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
return pd::Backward(forwardOp, no_grad_vars);
});
ExposeOperator(operator_base);
py::class_<NetOp, std::shared_ptr<NetOp>> net(m, "Net");
......@@ -166,6 +217,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("unique_integer", UniqueIntegerGenerator);
m.def("is_compile_gpu", IsCompileGPU);
return m.ptr();
}
} // namespace framework
......
......@@ -165,4 +165,4 @@ class Tensor {
} // namespace framework
} // namespace paddle
#include "paddle/framework/detail/tensor-inl.h"
#include "paddle/framework/tensor_impl.h"
......@@ -13,9 +13,11 @@
limitations under the License. */
#pragma once
#include <paddle/framework/tensor.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <string>
#include "paddle/framework/tensor.h"
#include "paddle/memory/memcpy.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
......@@ -40,9 +42,6 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
py::buffer_info operator()(framework::Tensor &tensor) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
"Only CPU tensor can cast to numpy array");
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside;
......@@ -56,11 +55,16 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1];
}
framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_cpu_place(tensor.holder_->place())) {
dst_tensor = tensor;
}
return py::buffer_info(
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()),
sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(tensor.dims()), dims_outside, strides);
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
} else {
constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
......@@ -74,9 +78,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
}
template <typename T>
void PyTensorSetFromArray(
void PyCPUTensorSetFromArray(
framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array) {
py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::CPUPlace &place) {
std::vector<int> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
......@@ -84,9 +89,28 @@ void PyTensorSetFromArray(
}
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace());
auto *dst = self.mutable_data<T>(place);
std::memcpy(dst, array.data(), sizeof(T) * array.size());
}
#ifndef PADDLE_ONLY_CPU
template <typename T>
void PyCUDATensorSetFromArray(
framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::GPUPlace &place) {
std::vector<int> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
dims.push_back((int)array.shape()[i]);
}
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(),
cudaMemcpyHostToDevice);
}
#endif
} // namespace pybind
} // namespace paddle
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/memcpy.h"
namespace paddle {
......@@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size));
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
}
#ifndef PADDLE_ONLY_CPU
else if (platform::is_gpu_place(place)) {
#else
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), size));
}
......
......@@ -109,6 +109,13 @@ protected:
return filter[filter.ndims() - 1];
}
// determine whether im2col needs to be performed
inline bool isNeedIm2col(const TensorShape& filter) const {
return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 &&
strideH() == 1 && strideW() == 1 && paddingH() == 0 &&
paddingW() == 0);
}
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
......
......@@ -66,16 +66,23 @@ public:
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
bool needIm2col = isNeedIm2col(filter);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
TensorShape colShape;
real* colData = NULL;
if (needIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
colData = reinterpret_cast<real*>(memory_->getBuf());
}
Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
......@@ -86,15 +93,18 @@ public:
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
if (needIm2col) {
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
} else {
colData = inputData + g * inputOffset;
}
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth;
......@@ -159,19 +169,27 @@ public:
real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>();
bool needIm2col = isNeedIm2col(filter);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
TensorShape colShape;
real* colData = NULL;
if (needIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
colData = reinterpret_cast<real*>(memory_->getBuf());
}
Col2ImFunctor<kCFO, Device, real> col2im;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements();
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
......@@ -182,6 +200,11 @@ public:
int K = outputChannels / groups_;
int N = outputHeight * outputWidth;
int M = inputChannels / groups_ * filterHeight * filterWidth;
real scale = 0.0f;
if (!needIm2col) {
colData = inputGrad + g * inputOffset;
scale = 1.0f;
}
gemm(CblasTrans,
CblasNoTrans,
M,
......@@ -192,17 +215,19 @@ public:
M,
outputGrad + g * outputOffset,
N,
0.0f,
scale,
colData,
N);
col2im(inputGrad + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
if (needIm2col) {
col2im(inputGrad + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
}
}
inputGrad += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
......@@ -255,16 +280,23 @@ public:
real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>();
bool needIm2col = isNeedIm2col(filter);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
TensorShape colShape;
real* colData = NULL;
if (needIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
colData = reinterpret_cast<real*>(memory_->getBuf());
}
Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
......@@ -274,15 +306,18 @@ public:
size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
if (needIm2col) {
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
} else {
colData = inputData + g * inputOffset;
}
int M = outputChannels / groups_;
int K = outputHeight * outputWidth;
int N = inputChannels / groups_ * filterHeight * filterWidth;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
namespace paddle {
/**
* A layer for clipping the input value by the threshold.
* \f[
* out[i] = \min\left(\max\left(in[i],p_{1}\right),p_{2}\right)
* \f]
*/
class ClipLayer : public Layer {
protected:
double min_;
double max_;
public:
explicit ClipLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(clip, ClipLayer);
bool ClipLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 1U);
auto layerConf = config_.inputs(0).clip_conf();
min_ = layerConf.min();
max_ = layerConf.max();
CHECK_LT(min_, max_);
return true;
}
void ClipLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr inV = getInputValue(0);
resetOutput(inV->getHeight(), inV->getWidth());
MatrixPtr outV = getOutputValue();
outV->copyFrom(*inV);
outV->clip(min_, max_);
}
void ClipLayer::backward(const UpdateCallback& callback) {
MatrixPtr inV = getInputValue(0);
MatrixPtr inG = getInputGrad(0);
if (inG) {
MatrixPtr outV = getOutputValue();
MatrixPtr outG = getOutputGrad();
MatrixPtr tmpMtx;
Matrix::resizeOrCreate(
tmpMtx, outG->getHeight(), outG->getWidth(), false, useGpu_);
tmpMtx->clipDerivative(*inV, min_, max_);
inG->addDotMul(*outG, *tmpMtx, 1, 1);
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
namespace paddle {
/**
* A layer for L2 normalization in each row,
* \f[
* out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}}
* \f]
* where the size of \f$in\f$ is (batchSize x dataDim),
* and the size of \f$out\f$ is (batchSize x dataDim).
*/
class RowL2NormLayer : public Layer {
protected:
MatrixPtr inSquare_;
MatrixPtr l2NormReciprocal_;
MatrixPtr dotSum_;
public:
explicit RowL2NormLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(row_l2_norm, RowL2NormLayer);
bool RowL2NormLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 1U);
return true;
}
void RowL2NormLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr inV = getInputValue(0);
/* malloc memory for the output_ if necessary */
size_t batchSize = inV->getHeight();
size_t dataDim = getSize();
CHECK_EQ(dataDim, inV->getWidth());
resetOutput(batchSize, dataDim);
MatrixPtr outV = getOutputValue();
Matrix::resizeOrCreate(inSquare_, batchSize, dataDim, false, useGpu_);
inV->square2(*inSquare_);
Matrix::resizeOrCreate(l2NormReciprocal_, batchSize, 1, false, useGpu_);
inSquare_->rowSum(*l2NormReciprocal_);
l2NormReciprocal_->sqrt2(*l2NormReciprocal_);
l2NormReciprocal_->scalarDiv(*l2NormReciprocal_, 1.0);
outV->rowScale(0, *inV, *l2NormReciprocal_);
}
void RowL2NormLayer::backward(const UpdateCallback& callback) {
MatrixPtr inV = getInputValue(0);
MatrixPtr inG = getInputGrad(0);
MatrixPtr outV = getOutputValue();
MatrixPtr outG = getOutputGrad();
size_t batchSize = inV->getHeight();
// inG[ij] += outG[ij] / l2NormReciprocal
// inG[ij] += -inV[ij] * l2NormReciprocal * l2NormReciprocal * DotMul(outG[i],
// inV[i])
if (inG) {
Matrix::resizeOrCreate(dotSum_, batchSize, 1, false, useGpu_);
dotSum_->zeroMem();
dotSum_->rowDotMul(0, *outG, *outV);
dotSum_->dotMul(*dotSum_, *l2NormReciprocal_);
dotSum_->dotMul(*dotSum_, *l2NormReciprocal_);
inSquare_->rowScale(0, *inV, *dotSum_);
inG->sub(*inSquare_);
inG->addRowScale(0, *outG, *l2NormReciprocal_);
}
}
} // namespace paddle
......@@ -1899,6 +1899,36 @@ TEST(Layer, CropLayer) {
}
}
TEST(Layer, ClipLayer) {
const size_t batchSize = 128;
const size_t size = 512;
TestConfig config;
config.layerConfig.set_type("clip");
config.inputDefs.push_back({INPUT_DATA, "input", size, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
ClipConfig* layerConf = input->mutable_clip_conf();
double p1 = std::rand() / (double)RAND_MAX;
double p2 = std::rand() / (double)RAND_MAX;
layerConf->set_min(std::min(p1, p2));
layerConf->set_max(std::max(p1, p2));
for (auto useGpu : {false, true}) {
testLayerGrad(config, "clip", batchSize, false, useGpu, false);
}
}
TEST(Layer, RowL2NormLayer) {
const size_t batchSize = 128;
const size_t size = 512;
TestConfig config;
config.layerConfig.set_type("row_l2_norm");
config.layerConfig.set_size(size);
config.inputDefs.push_back({INPUT_DATA, "input", size, 0});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "row_l2_norm", batchSize, false, useGpu, false);
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......
......@@ -442,6 +442,12 @@ DEFINE_MATRIX_UNARY_PARAMETER_OP(Clip, TWO_PARAMETER,
template<class T>
void BaseMatrixT<T>::clip(T p1, T p2) { applyUnary(unary::Clip<T>(p1, p2)); }
DEFINE_MATRIX_BINARY_PARAMETER_OP(ClipDerivative, TWO_PARAMETER, a = b < p1 ? 0 : (b > p2 ? 0 : 1));
template<class T>
void BaseMatrixT<T>::clipDerivative(BaseMatrixT& b, T p1, T p2) {
applyBinary(binary::ClipDerivative<T>(p1, p2), b);
}
DEFINE_MATRIX_UNARY_PARAMETER_OP(BiggerThanScalar, ONE_PARAMETER,
a = a > p ? 1.0f : 0.0f);
template<class T>
......
......@@ -488,6 +488,13 @@ public:
*/
void clip(T p1, T p2);
/**
* this = b < low ? 0 : 1
*
* this = b > high ? 0 : 1
*/
void clipDerivative(BaseMatrixT& b, T p1, T p2);
/**
* @code
* a = a > p ? 1.0f : 0.0f
......
......@@ -60,10 +60,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(fc_op
SRCS fc_op.cc
DEPS mul_op rowwise_add_op sigmoid_op softmax_op net)
op_library(recurrent_network_op
SRCS recurrent_network_op.cc
DEPS op_desc tensor net)
cc_test(recurrent_network_op_test
SRCS recurrent_network_op_test.cc
DEPS recurrent_network_op mul_op add_op)
op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
......@@ -50,10 +50,6 @@ The equation is: Out = X + Y
class AddOpGrad : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "AddOpGrad";
return "";
}
};
} // namespace operators
......
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h"
......
......@@ -28,10 +28,13 @@ public:
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(*input0) +
framework::EigenVector<T>::Flatten(*input1);
auto X = EigenVector<T>::Flatten(*input0);
auto Y = EigenVector<T>::Flatten(*input1);
auto Z = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
Z.device(place) = X + Y;
}
};
......
#define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h"
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
......
......@@ -33,13 +33,23 @@ public:
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op");
AddOutput("Out", "The output of mean op").IgnoreGradient();
AddComment("Mean Operator");
}
};
class MeanGradOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + GRAD_VAR_SUFFIX())
->Resize(ctx.Input<Tensor>("X")->dims());
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);
REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::CPUPlace, float>);
......@@ -3,3 +3,4 @@
#include "paddle/operators/mean_op.h"
REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::GPUPlace, float>);
\ No newline at end of file
......@@ -27,8 +27,28 @@ public:
output->mutable_data<T>(context.GetPlace());
EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(*input).mean();
auto X = EigenVector<T>::Flatten(*input);
auto y = EigenScalar<T>::From(*output);
auto place = context.GetEigenDevice<Place>();
y.device(place) = X.mean();
}
};
template <typename Place, typename T>
class MeanGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + OperatorBase::GRAD_VAR_SUFFIX());
PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar");
auto IG = context.Output<Tensor>("X" + OperatorBase::GRAD_VAR_SUFFIX());
IG->mutable_data<T>(context.GetPlace());
T ig_size = (T)framework::product(IG->dims());
EigenVector<T>::Flatten(*IG).device(*(context.GetEigenDevice<Place>())) =
EigenScalar<T>::From(*OG) / ig_size;
}
};
......
......@@ -12,6 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/mul_op.h"
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
\ No newline at end of file
......@@ -26,13 +26,18 @@ public:
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input<Tensor>("X");
auto input1 = context.Input<Tensor>("Y");
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(*context.Input<Tensor>("X"))
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
dim_pair);
auto X = EigenMatrix<T>::From(*input0);
auto Y = EigenMatrix<T>::From(*input1);
auto Z = EigenMatrix<T>::From(*output);
auto place = context.GetEigenDevice<Place>();
Z.device(place) = X.contract(Y, dim_pair);
}
};
} // namespace operators
......
......@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/recurrent_network_op.h"
#include "paddle/operators/recurrent_op.h"
#include <glog/logging.h>
#include <cstring>
......@@ -29,11 +29,15 @@ namespace rnn {
void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks,
const size_t seq_len) {
const size_t seq_len,
bool infer_shape_mode) {
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
for (size_t i = 0; i < inlinks.size(); ++i) {
Tensor* input =
step_scopes[0]->FindVar(inlinks[i].external)->GetMutable<Tensor>();
auto input_var = step_scopes[0]->FindVar(inlinks[i].external);
PADDLE_ENFORCE(input_var != nullptr,
"input link [%s] is not in scope.",
inlinks[i].external);
Tensor* input = input_var->GetMutable<Tensor>();
DDim dims = input->dims();
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
"all the inlinks must have same length");
......@@ -41,7 +45,9 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input =
step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
*step_input = input->Slice<float>(j, j + 1);
if (!infer_shape_mode) {
*step_input = input->Slice<float>(j, j + 1);
}
step_input->Resize(step_dims);
}
}
......@@ -49,36 +55,41 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& outlinks,
const size_t seq_len) {
const size_t seq_len,
bool infer_shape_mode) {
for (size_t i = 0; i < outlinks.size(); i++) {
Tensor* output =
step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>();
// TODO(qingiqng) remove following code after adding
// InferShape in RecurrentGradientOp
DDim step_dims = step_scopes[0]
->FindVar(outlinks[i].internal)
->GetMutable<Tensor>()
->dims();
std::vector<int> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len);
output->mutable_data<float>(make_ddim(dims_vec), platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_output =
step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
// TODO(luotao02) data type and platform::DeviceContext() should set
// correctly
(output->Slice<float>(j, j + 1))
.CopyFrom<float>(*step_output, platform::CPUPlace());
auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
PADDLE_ENFORCE(output_var != nullptr,
"output link [%s] is not in scope.",
outlinks[i].external);
Tensor* output = output_var->GetMutable<Tensor>();
if (infer_shape_mode) {
DDim step_dims = step_scopes[0]
->FindVar(outlinks[i].internal)
->GetMutable<Tensor>()
->dims();
std::vector<int> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len);
output->Resize(make_ddim(dims_vec));
} else {
output->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_output =
step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
// TODO(luotao02) data type and platform::DeviceContext() should set
// correctly
(output->Slice<float>(j, j + 1))
.CopyFrom<float>(*step_output, platform::CPUPlace());
}
}
}
}
void LinkMemories(const std::vector<Scope*>& scopes,
const std::vector<rnn::MemoryAttr>& memories,
size_t step_id,
int offset) {
const size_t step_id,
const int offset,
bool infer_shape_mode) {
PADDLE_ENFORCE(step_id < scopes.size(),
"step [%d] is out of range of step scopes' size [%d]",
step_id,
......@@ -95,18 +106,13 @@ void LinkMemories(const std::vector<Scope*>& scopes,
auto scope = scopes[step_id];
auto linked_scope = scopes[step_id + offset];
for (auto& attr : memories) {
auto mem = scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
// maybe share variable is better?
auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>();
auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>();
mem->ShareDataWith<float>(*linked_mem);
// TODO(qingqing) remove following code
// the memory of current step should be allocated in step net
auto m = scope->NewVar(attr.var)->GetMutable<Tensor>();
// for unit test, as addOp and mulOp are null currently, if not
// mutable_data, mem.data() in output will be error. We will
// remove this line after merge the correct addOp and mulOp.
m->mutable_data<float>(mem->dims(), platform::CPUPlace());
if (infer_shape_mode) {
mem->Resize(linked_mem->dims());
} else {
mem->ShareDataWith<float>(*linked_mem);
}
}
}
......@@ -175,60 +181,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
->dims()[0];
CreateScopes(scope);
auto step_scopes = GetStepScopes(scope);
// SegmentInputs is called in InferShape. The input must hold memory in
// SegmentInputs. But the other op only set dimension for the output in
// InferShape. That's a problem. Wether the RNN op needs InferShape or not?
// Wether the following functions (SegmentInputs, InitMemories, ...) need
// to rewrite for RNN op?
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
InitMemories(step_scopes[0]);
PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
"stepnet [%s] is not in scope.",
arg_->step_net);
rnn::SegmentInputs(
step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
InitMemories(step_scopes[0], true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
// If the InferShape is called in OperatorBase's run function,
// the rnn op only needs to do InferShape for the first time step
for (size_t i = 0; i < seq_len_; i++) {
if (i > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
rnn::LinkMemories(
step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
}
auto outlinks = arg_->outlinks;
for (size_t i = 0; i < outlinks.size(); i++) {
DDim step_dims = step_scopes[0]
->FindVar(outlinks[i].internal)
->GetMutable<Tensor>()
->dims();
std::vector<int> dims_vec = vectorize(step_dims);
// now only support fixed length
dims_vec.insert(dims_vec.begin(), seq_len_);
Tensor* output =
step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>();
output->Resize(make_ddim(dims_vec));
}
rnn::ConcatOutputs(
step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
}
void RecurrentAlgorithm::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(
step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
InitMemories(step_scopes[0], false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
for (size_t step_id = 0; step_id < seq_len_; step_id++) {
// the link memory is done in InferShape
// maybe remove following code after testing
if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
rnn::LinkMemories(
step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
rnn::ConcatOutputs(
step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
}
void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
......@@ -244,18 +229,19 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// Now all variables in scope must be created outside of op.
auto net_op = scope.FindVar(arg_->step_net)->GetMutable<NetOp>();
for (auto& input : net_op->inputs_) {
// the weight are located in parent scope
if (!step_scope.FindVar(input)) step_scope.NewVar(input);
}
for (auto& output : net_op->outputs_) {
step_scope.NewVar(output);
}
step_scopes->emplace_back(&step_scope);
}
}
}
void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
void RecurrentAlgorithm::InitMemories(Scope* step_scope,
bool infer_shape_mode) const {
for (auto& attr : arg_->memories) {
Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
......@@ -263,13 +249,11 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
attr.var,
attr.boot_var);
Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>();
pre_mem->ShareDataWith<float>(*boot_mem);
// TODO(qingqing) remove following code
// the memory of current step should be allocated in step net
// here for unit test
auto cur_step_mem = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
cur_step_mem->mutable_data<float>(boot_mem->dims(), platform::CPUPlace());
if (infer_shape_mode) {
pre_mem->Resize(boot_mem->dims());
} else {
pre_mem->ShareDataWith<float>(*boot_mem);
}
}
}
......@@ -307,13 +291,14 @@ public:
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto
AddInput(name.inlinks, "the input that need to be segmented for each step.")
AddInput(name.inlinks,
"the inputs that need to be segmented for each step.")
.SetMultiple();
AddInput(name.boot_memories, "variables to initialize memories.")
.SetMultiple();
AddInput(name.step_net, "network shared by all steps.");
AddOutput(name.outlinks, "the output that need to concated for all steps.")
AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
.SetMultiple();
AddOutput(name.step_scopes, "step scopes");
......@@ -331,34 +316,39 @@ public:
void RecurrentGradientAlgorithm::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
"step net is not in scope.");
rnn::SegmentInputs(
step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
rnn::LinkMemories(
step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
}
LinkBootMemoryGradients(step_scopes[0]);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
LinkBootMemoryGradients(step_scopes[0], false);
rnn::ConcatOutputs(
step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/);
}
void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
Scope* step_scope) const {
Scope* step_scope, bool infer_shape_mode) const {
for (auto& attr : arg_->memories) {
Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
PADDLE_ENFORCE(mem_grad != nullptr,
"boot_tensor should be retrieved before");
PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
"memory variable [%s] does not exists",
attr.var);
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"memory [%s]'s boot variable [%s] not exists",
attr.var,
"boot variable [%s] does not exists",
attr.boot_var);
Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
Tensor* boot_mem_grad =
step_scope->NewVar(attr.boot_var)->GetMutable<Tensor>();
boot_mem_grad->ShareDataWith<float>(*mem_grad);
if (infer_shape_mode) {
boot_mem_grad->Resize(mem_grad->dims());
} else {
boot_mem_grad->ShareDataWith<float>(*mem_grad);
}
}
}
......@@ -367,34 +357,20 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
->GetMutable<Tensor>()
->dims()[0];
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_);
PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr,
"step net is not in scope.");
rnn::SegmentInputs(
step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
rnn::LinkMemories(
step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
}
auto outlinks = arg_->outlinks;
for (size_t i = 0; i < outlinks.size(); i++) {
DDim step_dims = step_scopes[0]
->FindVar(outlinks[i].internal)
->GetMutable<Tensor>()
->dims();
std::vector<int> dims_vec = vectorize(step_dims);
// now only support fixed length
dims_vec.insert(dims_vec.begin(), seq_len_);
Tensor* output =
step_scopes[0]->FindVar(outlinks[i].external)->GetMutable<Tensor>();
output->Resize(make_ddim(dims_vec));
}
LinkBootMemoryGradients(step_scopes[0]);
rnn::ConcatOutputs(
step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/);
LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
}
void RecurrentGradientOp::Init() {
......
......@@ -72,19 +72,22 @@ struct ArgumentName {
*/
void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks,
const size_t seq_len);
const size_t seq_len,
bool infer_shape_mode);
/**
* Process outputs of step nets and merge to variables.
*/
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& outlinks,
const size_t seq_len);
const size_t seq_len,
bool infer_shape_mode);
void LinkMemories(const std::vector<Scope*>& step_scopes,
const std::vector<MemoryAttr>& memories,
size_t step_id,
int offset);
const size_t step_id,
const int offset,
bool infer_shape_mode);
void InitArgument(const ArgumentName& name, Argument* arg);
......@@ -122,7 +125,7 @@ protected:
return *scope.FindVar(arg_->step_scopes)->GetMutable<std::vector<Scope*>>();
}
void InitMemories(Scope* step_scopes) const;
void InitMemories(Scope* step_scopes, bool infer_shape_mode) const;
private:
std::unique_ptr<rnn::Argument> arg_;
......@@ -145,7 +148,7 @@ public:
void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const;
void LinkBootMemoryGradients(Scope* step_scopes) const;
void LinkBootMemoryGradients(Scope* step_scopes, bool infer_shape_mode) const;
/**
* InferShape must be called before Run.
......
......@@ -18,7 +18,7 @@
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/recurrent_network_op.h"
#include "paddle/operators/recurrent_op.h"
namespace paddle {
namespace operators {
......@@ -55,7 +55,7 @@ protected:
w->GetMutable<Tensor>()->mutable_data<float>(
make_ddim(std::vector<int>{30, 30}), platform::CPUPlace());
for (auto boot : std::vector<std::string>{"x_boot", "h_boot"}) {
for (auto boot : std::vector<std::string>{"h_boot"}) {
LOG(INFO) << "create global variable " << boot;
Variable* h_boot = scope_.NewVar(boot);
h_boot->GetMutable<Tensor>()->mutable_data<float>(
......@@ -79,7 +79,6 @@ protected:
op_desc.add_inputs("x0");
op_desc.add_inputs("x1");
// boot_memories 3
op_desc.add_inputs("x_boot");
op_desc.add_inputs("h_boot");
// step net 5
op_desc.add_inputs("step_net");
......@@ -91,7 +90,7 @@ protected:
auto _input_format = std::vector<int>{
0, // in_link
3, // memories
5 // step_net
4 // step_net
};
auto input_format = op_desc.add_attrs();
input_format->set_name("input_format");
......@@ -129,12 +128,11 @@ protected:
inlink_alias->add_strings(item);
}
// pre memories
for (const auto& item :
std::vector<std::string>{"rnn/x@pre", "rnn/h@pre"}) {
for (const auto& item : std::vector<std::string>{"rnn/h@pre"}) {
pre_memories->add_strings(item);
}
// memories
for (const auto& item : std::vector<std::string>{"rnn/x", "rnn/h"}) {
for (const auto& item : std::vector<std::string>{"rnn/h"}) {
memories->add_strings(item);
}
// output alias
......@@ -151,14 +149,11 @@ protected:
LOG(INFO) << "create variable step_net";
Variable* var = scope_.NewVar("step_net");
auto net = var->GetMutable<NetOp>();
// rnn/s is net's input or output?
net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"};
net->inputs_ = {"rnn/s", "rnn/h"};
net->AddOp(
OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {}));
net->AddOp(
OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {}));
net->CompleteAddOp();
}
......@@ -297,7 +292,10 @@ protected:
inlink.internal = "rnn/x";
auto step_scopes =
scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
rnn::SegmentInputs(*step_scopes, std::vector<rnn::Link>{inlink}, 10);
rnn::SegmentInputs(*step_scopes,
std::vector<rnn::Link>{inlink},
10,
true /*infer_shape_mode*/);
}
void LinkeMemories() {
......@@ -311,7 +309,8 @@ protected:
auto step_scopes =
scope_.FindVar("step_scopes")->GetMutable<std::vector<Scope*>>();
for (int i = 1; i < 10; ++i) {
rnn::LinkMemories(*step_scopes, memories, i, -1);
rnn::LinkMemories(
*step_scopes, memories, i, -1, true /*infer_shape_mode*/);
}
}
......@@ -333,14 +332,14 @@ TEST(RecurrentOp, LinkMemories) {
using namespace paddle::operators;
// create and init step scopes
int len = 10;
size_t len = 10;
std::vector<Scope*> step_scopes;
for (int i = 0; i < len; ++i) {
for (size_t i = 0; i < len; ++i) {
auto scope = new Scope();
scope->NewVar("pre_h");
auto tensor = scope->NewVar("h")->GetMutable<Tensor>();
float* data = tensor->mutable_data<float>({15, 20}, CPUPlace());
for (int j = 0; j < 15 * 20; ++j) {
for (size_t j = 0; j < 15 * 20; ++j) {
data[j] = rand() * (1. / (double)RAND_MAX);
}
step_scopes.push_back(scope);
......@@ -354,24 +353,24 @@ TEST(RecurrentOp, LinkMemories) {
std::vector<rnn::MemoryAttr> memories;
memories.push_back(mem_attr);
for (int i = 1; i < len; ++i) {
rnn::LinkMemories(step_scopes, memories, i, -1);
for (size_t i = 1; i < len; ++i) {
rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/);
}
// check
for (int i = 0; i < len - 1; ++i) {
for (size_t i = 0; i < len - 1; ++i) {
const float* a =
step_scopes[i]->FindVar("h")->GetMutable<Tensor>()->data<float>();
const float* b = step_scopes[i + 1]
->FindVar("pre_h")
->GetMutable<Tensor>()
->data<float>();
for (size_t i = 0; i < 15 * 20; ++i) {
ASSERT_FLOAT_EQ(a[i], b[i]);
for (size_t j = 0; j < 15 * 20; ++j) {
ASSERT_FLOAT_EQ(a[j], b[j]);
}
}
for (int i = len - 2; i >= 0; --i) {
rnn::LinkMemories(step_scopes, memories, i, 1);
rnn::LinkMemories(step_scopes, memories, i, 1, false /*infer_shape_mode*/);
}
// check
for (int i = len - 2; i >= 0; --i) {
......@@ -379,8 +378,8 @@ TEST(RecurrentOp, LinkMemories) {
step_scopes[i]->FindVar("pre_h")->GetMutable<Tensor>()->data<float>();
const float* b =
step_scopes[i + 1]->FindVar("h")->GetMutable<Tensor>()->data<float>();
for (size_t i = 0; i < 15 * 20; ++i) {
ASSERT_FLOAT_EQ(a[i], b[i]);
for (size_t j = 0; j < 15 * 20; ++j) {
ASSERT_FLOAT_EQ(a[j], b[j]);
}
}
......@@ -391,9 +390,3 @@ TEST(RecurrentOp, LinkMemories) {
USE_OP(add_two);
USE_OP(mul);
// int main() {
// //! TODO(yuyang18): Temporary disable this unit-test because implementation
// //! error.
// return 0;
//}
\ No newline at end of file
#define EIGEN_USE_GPU
#include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL(rowwise_add,
......
......@@ -33,7 +33,7 @@ public:
const int rest_size = input.size() / bias_size;
Eigen::DSizes<int, 1> one_d(input.size());
Eigen::DSizes<int, 1> bcast(rest_size);
output.reshape(one_d).device(*(context.GetEigenDevice<Place>())) =
output.reshape(one_d).device(context.GetEigenDevice<Place>()) =
input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
}
};
......
#define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h"
REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
\ No newline at end of file
......@@ -29,8 +29,12 @@ public:
param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
auto p = EigenVector<T>::Flatten(*param);
auto g = EigenVector<T>::Flatten(*grad);
auto o = EigenVector<T>::Flatten(*param_out);
auto place = ctx.GetEigenDevice<Place>();
o.device(place) = p - lr * g;
}
};
......
#define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h"
REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
......@@ -27,9 +27,11 @@ public:
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
auto X = EigenVector<T>::Flatten(*input);
auto Y = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
}
};
} // namespace operators
......
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h"
......
......@@ -46,9 +46,9 @@ public:
.reshape(batch_by_one)
.broadcast(one_by_class));
softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(*(context.GetEigenDevice<Place>())) =
softmax.device(context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()
......
......@@ -51,6 +51,7 @@ using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using NetOp = framework::NetOp;
using OpRegistry = framework::OpRegistry;
using OperatorBase = framework::OperatorBase;
} // namespace operators
} // namespace paddle
......
......@@ -144,12 +144,12 @@ inline void throw_on_error(T e) {
throw_on_error(e, "");
}
#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
std::make_exception_ptr( \
std::runtime_error(string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
std::make_exception_ptr( \
std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
} while (0)
#define PADDLE_ENFORCE(...) \
......
cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python backward
fc_op
sgd_op
add_op
mean_op
cross_entropy_op
recurrent_op)
......@@ -148,7 +148,7 @@ cat >> /paddle/build/Dockerfile <<EOF
ADD *.deb /
# run paddle version to install python packages first
RUN apt-get update &&\
apt-get install -y python-pip && pip install -U pip && \
apt-get install -y wget python-pip && pip install -U pip && \
dpkg -i /*.deb ; apt-get install -f -y && \
apt-get clean -y && \
rm -f /*.deb && \
......
......@@ -298,6 +298,11 @@ message DetectionOutputConfig {
optional uint32 width = 9 [default = 1];
}
message ClipConfig {
required double min = 1;
required double max = 2;
}
message LayerInputConfig {
required string input_layer_name = 1;
optional string input_parameter_name = 2;
......@@ -318,6 +323,7 @@ message LayerInputConfig {
optional RowConvConfig row_conv_conf = 15;
optional MultiBoxLossConfig multibox_loss_conf = 16;
optional DetectionOutputConfig detection_output_conf = 17;
optional ClipConfig clip_conf = 18;
}
message LayerConfig {
......
......@@ -2198,6 +2198,20 @@ class RowConvLayer(LayerBase):
self.create_input_parameter(0, psize, dims)
@config_layer('clip')
class ClipLayer(LayerBase):
def __init__(self, name, inputs, min, max, **xargs):
super(ClipLayer, self).__init__(name, 'clip', 0, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'ClipLayer must have one and only one input.')
config_assert(min < max, 'min must be less than max.')
input_layer = self.get_input_layer(0)
self.set_layer_size(input_layer.size)
self.config.inputs[0].clip_conf.min = min
self.config.inputs[0].clip_conf.max = max
# key: cost type
# value: cost class
g_cost_map = {}
......@@ -2754,6 +2768,16 @@ class SumToOneNormLayer(LayerBase):
self.set_layer_size(input_layer0.size)
@config_layer('row_l2_norm')
class RowL2NormLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
super(RowL2NormLayer, self).__init__(
name, 'row_l2_norm', 0, inputs=inputs, **xargs)
config_assert(len(self.inputs) == 1, 'RowL2NormLayer must have 1 input')
input_layer = self.get_input_layer(0)
self.set_layer_size(input_layer.size)
@config_layer('cos_vm')
class CosSimVecMatLayer(LayerBase):
def __init__(self, name, size, inputs, cos_scale=1.0, device=None):
......
......@@ -76,6 +76,7 @@ __all__ = [
'trans_layer',
'rotate_layer',
'sum_to_one_norm_layer',
'row_l2_norm_layer',
'get_output_layer',
'LayerType',
'context_projection',
......@@ -128,6 +129,7 @@ __all__ = [
'prelu_layer',
'gated_unit_layer',
'crop_layer',
'clip_layer',
'slice_projection',
]
......@@ -160,6 +162,7 @@ class LayerType(object):
BATCH_NORM_LAYER = 'batch_norm'
NORM_LAYER = 'norm'
SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm'
ROW_L2_NORM_LAYER = 'row_l2_norm'
ADDTO_LAYER = 'addto'
CONCAT_LAYER = 'concat'
......@@ -221,6 +224,7 @@ class LayerType(object):
PRELU = 'prelu'
CROP_LAYER = 'crop'
CLIP_LAYER = 'clip'
@staticmethod
def is_layer_type(type_name):
......@@ -2889,6 +2893,42 @@ def sum_to_one_norm_layer(input, name=None, layer_attr=None):
name, LayerType.SUM_TO_ONE_NORM_LAYER, parents=[input], size=input.size)
@wrap_name_default()
@layer_support()
def row_l2_norm_layer(input, name=None, layer_attr=None):
"""
A layer for L2-normalization in each row.
.. math::
out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}}
where the size of :math:`in` is (batchSize x dataDim) ,
and the size of :math:`out` is a (batchSize x dataDim) .
The example usage is:
.. code-block:: python
row_l2_norm_layer = row_l2_norm_layer(input=layer)
:param input: Input layer.
:type input: LayerOutput
:param name: Layer name.
:type name: basestring
:param layer_attr: extra layer attributes.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
"""
Layer(
name=name,
type=LayerType.ROW_L2_NORM_LAYER,
inputs=[input.name],
**ExtraAttr.to_kwargs(layer_attr))
return LayerOutput(
name, LayerType.ROW_L2_NORM_LAYER, parents=[input], size=input.size)
@wrap_name_default("addto")
@wrap_act_default(act=LinearActivation())
@wrap_bias_attr_default(has_bias=False)
......@@ -6046,3 +6086,36 @@ def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
layer_type=LayerType.CROP_LAYER,
parents=input,
size=l.config.size)
@wrap_name_default("clip")
def clip_layer(input, min, max, name=None):
"""
A layer for clipping the input value by the threshold.
.. math::
out[i] = \min\left(\max\left(in[i],p_{1}\right),p_{2}\right)
.. code-block:: python
clip = clip_layer(input=input_layer, min=-10, max=10)
:param name: The Layer Name.
:type name: basestring
:param input: The input layer.
:type input: LayerOutput.
:param min: The lower threshold for clipping.
:type min: double
:param max: The upper threshold for clipping.
:type max: double
:return: LayerOutput
"""
Layer(
name=name,
type=LayerType.CLIP_LAYER,
inputs=[input.name],
min=min,
max=max)
return LayerOutput(
name, LayerType.CLIP_LAYER, parents=[input], size=input.size)
......@@ -7,6 +7,6 @@ test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_recursive_topology test_gated_unit_layer)
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "input"
type: "data"
size: 300
active_type: ""
}
layers {
name: "__clip_0__"
type: "clip"
size: 300
active_type: ""
inputs {
input_layer_name: "input"
clip_conf {
min: -10
max: 10
}
}
}
input_layer_names: "input"
output_layer_names: "__clip_0__"
sub_models {
name: "root"
layer_names: "input"
layer_names: "__clip_0__"
input_layer_names: "input"
output_layer_names: "__clip_0__"
is_recurrent_layer_group: false
}
type: "nn"
layers {
name: "input"
type: "data"
size: 300
active_type: ""
}
layers {
name: "__row_l2_norm_layer_0__"
type: "row_l2_norm"
size: 300
active_type: ""
inputs {
input_layer_name: "input"
}
}
input_layer_names: "input"
output_layer_names: "__row_l2_norm_layer_0__"
sub_models {
name: "root"
layer_names: "input"
layer_names: "__row_l2_norm_layer_0__"
input_layer_names: "input"
output_layer_names: "__row_l2_norm_layer_0__"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
data = data_layer(name='input', size=300)
clip = clip_layer(input=data, min=-10, max=10)
outputs(clip)
from paddle.trainer_config_helpers import *
data = data_layer(name='input', size=300)
row_l2_norm = row_l2_norm_layer(input=data)
outputs(row_l2_norm)
......@@ -8,7 +8,6 @@ add_python_test(test_framework
test_fc_op.py
test_add_two_op.py
test_sgd_op.py
test_cross_entropy_op.py
test_mul_op.py
test_mean_op.py
test_sigmoid_op.py
......
......@@ -26,40 +26,45 @@ class OpTestMeta(type):
scope = core.Scope()
kwargs = dict()
places = []
places.append(core.CPUPlace())
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for in_name in func.all_input_args:
if hasattr(self, in_name):
kwargs[in_name] = in_name
var = scope.new_var(in_name).get_tensor()
arr = getattr(self, in_name)
var.set_dims(arr.shape)
var.set(arr)
else:
kwargs[in_name] = "@EMPTY@"
for place in places:
for in_name in func.all_input_args:
if hasattr(self, in_name):
kwargs[in_name] = in_name
var = scope.new_var(in_name).get_tensor()
arr = getattr(self, in_name)
var.set_dims(arr.shape)
var.set(arr, place)
else:
kwargs[in_name] = "@EMPTY@"
for out_name in func.all_output_args:
if hasattr(self, out_name):
kwargs[out_name] = out_name
scope.new_var(out_name).get_tensor()
for out_name in func.all_output_args:
if hasattr(self, out_name):
kwargs[out_name] = out_name
scope.new_var(out_name).get_tensor()
for attr_name in func.all_attr_args:
if hasattr(self, attr_name):
kwargs[attr_name] = getattr(self, attr_name)
for attr_name in func.all_attr_args:
if hasattr(self, attr_name):
kwargs[attr_name] = getattr(self, attr_name)
op = func(**kwargs)
op = func(**kwargs)
op.infer_shape(scope)
op.infer_shape(scope)
ctx = core.DeviceContext.cpu_context()
op.run(scope, ctx)
ctx = core.DeviceContext.create(place)
op.run(scope, ctx)
for out_name in func.all_output_args:
actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = getattr(self, out_name)
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
# has some diff, and could not pass unittest. So I set decimal 3 here.
# And I will check this in future.
numpy.testing.assert_almost_equal(actual, expect, decimal=3)
for out_name in func.all_output_args:
actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = getattr(self, out_name)
# TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul
# has some diff, and could not pass unittest. So I set decimal 3 here.
# And I will check this in future.
numpy.testing.assert_almost_equal(actual, expect, decimal=3)
obj.test_all = test_all
return obj
import unittest
from op_test_util import OpTestMeta
import numpy
import paddle.v2.framework.core as core
import paddle.v2.framework.create_op_creation_methods as creation
from op_test_util import OpTestMeta
class TestAddOp(unittest.TestCase):
......@@ -8,10 +12,19 @@ class TestAddOp(unittest.TestCase):
def setUp(self):
self.type = "add_two"
self.X = numpy.random.random((342, 345)).astype("float32")
self.Y = numpy.random.random((342, 345)).astype("float32")
self.X = numpy.random.random((102, 105)).astype("float32")
self.Y = numpy.random.random((102, 105)).astype("float32")
self.Out = self.X + self.Y
class TestAddGradOp(unittest.TestCase):
def test_add_grad(self):
op = creation.op_creations.add_two(X="X", Y="Y", Out="Out")
backward_op = core.Operator.backward(op, set())
self.assertEqual(backward_op.type(), "add_two_grad")
expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
self.assertEqual(expected, str(backward_op))
if __name__ == '__main__':
unittest.main()
......@@ -7,17 +7,19 @@ import paddle.v2.framework.create_op_creation_methods as creation
class TestFc(unittest.TestCase):
def test_fc(self):
scope = core.Scope()
place = core.CPUPlace()
x = scope.new_var("X")
x_tensor = x.get_tensor()
x_tensor.set_dims([1000, 784])
x_tensor.alloc_float()
x_tensor.alloc_float(place)
w = scope.new_var("W")
w_tensor = w.get_tensor()
w_tensor.set_dims([784, 100])
w_tensor.alloc_float()
w_tensor.alloc_float(place)
w_tensor.set(numpy.random.random((784, 100)).astype("float32"))
w_tensor.set(numpy.random.random((784, 100)).astype("float32"), place)
# Set a real numpy array here.
# x_tensor.set(numpy.array([]))
......@@ -32,7 +34,7 @@ class TestFc(unittest.TestCase):
op.infer_shape(scope)
self.assertEqual([1000, 100], tensor.shape())
ctx = core.DeviceContext.cpu_context()
ctx = core.DeviceContext.create(place)
op.run(scope, ctx)
......
......@@ -8,8 +8,8 @@ class TestMulOp(unittest.TestCase):
def setUp(self):
self.type = "mul"
self.X = np.random.random((32, 784)).astype("float32")
self.Y = np.random.random((784, 100)).astype("float32")
self.X = np.random.random((32, 84)).astype("float32")
self.Y = np.random.random((84, 100)).astype("float32")
self.Out = np.dot(self.X, self.Y)
......
......@@ -8,8 +8,8 @@ class TestRowwiseAddOp(unittest.TestCase):
def setUp(self):
self.type = "rowwise_add"
self.X = np.random.random((32, 784)).astype("float32")
self.b = np.random.random(784).astype("float32")
self.X = np.random.random((32, 84)).astype("float32")
self.b = np.random.random(84).astype("float32")
self.Out = np.add(self.X, self.b)
......
......@@ -8,8 +8,8 @@ class TestSGD(unittest.TestCase):
def setUp(self):
self.type = "sgd"
self.param = numpy.random.random((342, 345)).astype("float32")
self.grad = numpy.random.random((342, 345)).astype("float32")
self.param = numpy.random.random((102, 105)).astype("float32")
self.grad = numpy.random.random((102, 105)).astype("float32")
self.learning_rate = 0.1
self.param_out = self.param - self.learning_rate * self.grad
......
......@@ -7,16 +7,17 @@ class TestScope(unittest.TestCase):
def test_int_tensor(self):
scope = core.Scope()
var = scope.new_var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor()
tensor.set_dims([1000, 784])
tensor.alloc_int()
tensor.alloc_int(place)
tensor_array = numpy.array(tensor)
self.assertEqual((1000, 784), tensor_array.shape)
tensor_array[3, 9] = 1
tensor_array[19, 11] = 2
tensor.set(tensor_array)
tensor.set(tensor_array, place)
tensor_array_2 = numpy.array(tensor)
self.assertEqual(1.0, tensor_array_2[3, 9])
......@@ -25,16 +26,18 @@ class TestScope(unittest.TestCase):
def test_float_tensor(self):
scope = core.Scope()
var = scope.new_var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor()
tensor.set_dims([1000, 784])
tensor.alloc_float()
tensor.alloc_float(place)
tensor_array = numpy.array(tensor)
self.assertEqual((1000, 784), tensor_array.shape)
tensor_array[3, 9] = 1.0
tensor_array[19, 11] = 2.0
tensor.set(tensor_array)
tensor.set(tensor_array, place)
tensor_array_2 = numpy.array(tensor)
self.assertAlmostEqual(1.0, tensor_array_2[3, 9])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册