提交 95440685 编写于 作者: Y Yi Wang

Resovle conflicts manually

......@@ -38,7 +38,7 @@ before_install:
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker
- pip install rarfile
- pip install rarfile nltk==3.2.2 scipy==0.19.0 recordio matplotlib Pillow
- curl https://glide.sh/get | bash
- eval "$(GIMME_GO_VERSION=1.8.3 gimme)"
- go get -u github.com/alecthomas/gometalinter
......
......@@ -38,17 +38,16 @@ RUN apt-get update && \
RUN pip --no-cache-dir install 'numpy>=1.12.0'
# Install Go and glide
RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \
tar -C /usr/local -xzf go.tgz && \
RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \
tar -xz -C /usr/local && \
mkdir /root/gopath && \
mkdir /root/gopath/bin && \
mkdir /root/gopath/src && \
rm go.tgz
mkdir /root/gopath/src
ENV GOROOT=/usr/local/go GOPATH=/root/gopath
# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
# install glide
RUN curl -q https://glide.sh/get | sh
RUN curl -s -q https://glide.sh/get | sh
# git credential to skip password typing
RUN git config --global credential.helper store
......
......@@ -257,6 +257,16 @@ seq_concat
.. autoclass:: paddle.v2.layer.seq_concat
:noindex:
kmax_sequence_score
-------------------
.. autoclass:: paddle.v2.layer.kmax_sequence_score
:noindex:
sub_nested_seq
--------------
.. autoclass:: paddle.v2.layer.sub_nested_seq
:noindex:
Reshaping Layers
================
......
......@@ -11,6 +11,15 @@ Paddle每次发新的版本,遵循以下流程:
* 编译这个版本的Ubuntu Deb包。如果失败,修复Ubuntu Deb包编译问题,Patch号加一,返回第二步。
* 使用Regression Test List作为检查列表,测试Docker镜像/ubuntu安装包的功能正确性
* 如果失败,记录下所有失败的例子,在这个`release/版本号`分支中,修复所有bug后,Patch号加一,返回第二步
* 编译这个版本的python wheel包,并发布到pypi。
* 由于pypi.python.org目前遵循[严格的命名规范PEP 513](https://www.python.org/dev/peps/pep-0513),在使用twine上传之前,需要重命名wheel包中platform相关的后缀,比如将`linux_x86_64`修改成`manylinux1_x86_64`
* pypi上的package名称为paddlepaddle和paddlepaddle_gpu,如果要上传GPU版本的包,需要修改build/python/setup.py中,name: "paddlepaddle_gpu"并重新打包wheel包:`python setup.py bdist_wheel`
* 上传方法:
```
cd build/python
pip install twine
twine upload dist/[package to upload]
```
4. 第三步完成后,将`release/版本号`分支合入master分支,并删除`release/版本号`分支。将master分支的合入commit打上tag,tag为`版本号`。同时再将`master`分支合入`develop`分支。最后删除`release/版本号`分支。
5. 编译master分支的Docker发行镜像,发布到dockerhub。编译ubuntu的deb包,发布到github release页面
6. 协同完成Release Note的书写
......
......@@ -13,15 +13,11 @@
# serve to show the default.
import sys
import os, subprocess
sys.path.insert(0, os.path.abspath('@PROJ_ROOT@/python'))
import shlex
from recommonmark import parser, transform
try:
import py_paddle
import paddle
import paddle.v2
except ImportError:
print("Must install paddle python package before generating documentation")
sys.exit(1)
import paddle
import paddle.v2
MarkdownParser = parser.CommonMarkParser
AutoStructify = transform.AutoStructify
......
......@@ -13,15 +13,11 @@
# serve to show the default.
import sys
import os, subprocess
sys.path.insert(0, os.path.abspath('@PROJ_ROOT@/python'))
import shlex
from recommonmark import parser, transform
try:
import py_paddle
import paddle
import paddle.v2
except ImportError:
print("Must install paddle python package before generating documentation")
sys.exit(1)
import paddle
import paddle.v2
MarkdownParser = parser.CommonMarkParser
......
......@@ -31,13 +31,17 @@ add_dependencies(framework_py_proto framework_py_proto_init)
cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward)
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python backward
fc_op
sgd_op
add_op
mean_op
cross_entropy_op
fill_zeros_like_op
recurrent_op)
fc_op
sgd_op
add_op
mean_op
cross_entropy_op
recurrent_op
uniform_random_op
fill_zeros_like_op)
endif(WITH_PYTHON)
......@@ -13,6 +13,7 @@
limitations under the License. */
#include "paddle/framework/backward.h"
#include <list>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
......
......@@ -17,16 +17,21 @@
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace framework {
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase {
public:
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
......@@ -71,7 +76,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
}
};
class FcOp : public ops::NetOp {
class FcOp : public operators::NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul",
......@@ -145,6 +150,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
} // namespace paddle
namespace f = paddle::framework;
namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);
......
......@@ -204,12 +204,6 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
static bool SupportGPU(const std::string& op_type) {
OperatorWithKernel::OpKernelKey key;
key.place_ = platform::GPUPlace();
return OperatorWithKernel::AllOpKernels().at(op_type).count(key) != 0;
}
static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
PADDLE_ENFORCE(!op.IsNetOp(),
"Use framework::Backward to get backward ops");
......
......@@ -87,6 +87,8 @@ class OperatorBase {
virtual bool IsNetOp() const { return false; }
virtual bool SupportGPU() const { return false; }
/// rename inputs outputs name
void Rename(const std::string& old_name, const std::string& new_name);
......@@ -160,14 +162,14 @@ class OperatorContext {
template <typename T>
const T* Input(const std::string& name) const {
auto var = InputVar(name);
PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name);
PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
return &var->Get<T>();
}
template <typename T>
T* Output(const std::string& name) const {
auto var = OutputVar(name);
PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name);
PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name);
return var->GetMutable<T>();
}
......@@ -179,9 +181,9 @@ class OperatorContext {
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiInput(%s:%s) should not be nullptr",
name, sub_name);
PADDLE_ENFORCE_NOT_NULL(
var, "MultiInput(%s:%s) should not be nullptr", name,
sub_name);
return &var->Get<T>();
});
return res;
......@@ -195,9 +197,9 @@ class OperatorContext {
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiOutput(%s:%s) should not be nullptr",
name, sub_name);
PADDLE_ENFORCE_NOT_NULL(
var, "MultiOutput(%s:%s) should not be nullptr", name,
sub_name);
return var->GetMutable<T>();
});
return res;
......@@ -283,7 +285,7 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void InferShape(const Scope& scope) const {
void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(this, scope));
}
......@@ -299,6 +301,12 @@ class OperatorWithKernel : public OperatorBase {
return g_all_op_kernels;
}
bool SupportGPU() const override {
OperatorWithKernel::OpKernelKey key;
key.place_ = platform::GPUPlace();
return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
}
protected:
virtual void InferShape(const InferShapeContext& ctx) const = 0;
};
......
......@@ -18,11 +18,8 @@ limitations under the License. */
#include "paddle/framework/backward.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor_py.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "pybind11/numpy.h"
......@@ -42,8 +39,12 @@ USE_OP(softmax);
USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
USE_OP_WITHOUT_KERNEL(recurrent_op);
USE_OP(uniform_random);
namespace paddle {
namespace framework {
using Tensor = framework::Tensor;
template <typename ClassType>
void ExposeOperator(ClassType &m) {
m.def("infer_shape", &ClassType::type::InferShape)
......@@ -130,8 +131,8 @@ All parameter, weight, gradient are variables in Paddle.
[](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
py::return_value_policy::reference)
.def("get_net",
[](Variable &self) -> ops::NetOp * {
return self.GetMutable<ops::NetOp>();
[](Variable &self) -> operators::NetOp * {
return self.GetMutable<operators::NetOp>();
},
py::return_value_policy::reference);
......@@ -202,8 +203,6 @@ All parameter, weight, gradient are variables in Paddle.
return OpRegistry::CreateOp(desc);
});
operator_base.def_static("support_gpu", &OpRegistry::SupportGPU);
operator_base.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
......@@ -212,23 +211,24 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator(operator_base);
py::class_<ops::NetOp, std::shared_ptr<ops::NetOp>> net(m, "Net");
py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");
net.def_static("create",
[]() -> std::shared_ptr<ops::NetOp> {
auto retv = std::make_shared<ops::NetOp>();
[]() -> std::shared_ptr<operators::NetOp> {
auto retv = std::make_shared<operators::NetOp>();
retv->type_ = "plain_net";
return retv;
})
.def("add_op", &ops::NetOp::AddOp)
.def(
"add_op",
[](ops::NetOp &self, const std::shared_ptr<ops::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("complete_add_op", &ops::NetOp::CompleteAddOp)
.def("complete_add_op",
[](std::shared_ptr<ops::NetOp> &self) { self->CompleteAddOp(); });
.def("add_op", &operators::NetOp::AddOp)
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
});
ExposeOperator(net);
......
......@@ -127,8 +127,8 @@ class Tensor {
memory::PODDeleter<T, Place>(place)),
place_(place),
size_(size) {
PADDLE_ENFORCE(ptr_ != nullptr, "Insufficient %s memory to allocation.",
is_cpu_place(place_) ? "CPU" : "GPU");
PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.",
(is_cpu_place(place_) ? "CPU" : "GPU"));
}
virtual size_t size() const { return size_; }
......
......@@ -14,17 +14,18 @@ limitations under the License. */
#pragma once
#include "paddle/memory/memcpy.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
template <typename T>
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(holder_->size(), product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
}
template <typename T>
......@@ -51,9 +52,9 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
template <typename T>
inline T* Tensor::mutable_data(platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD");
PADDLE_ENFORCE(product(dims_) > 0,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
PADDLE_ENFORCE_GT(product(dims_), 0,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
/* some versions of boost::variant don't have operator!= */
size_t size = product(dims_) * sizeof(T);
if (holder_ == nullptr || !(holder_->place() == place) ||
......@@ -120,11 +121,11 @@ inline void Tensor::CopyFrom(const Tensor& src,
template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
check_memory_size<T>();
PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
PADDLE_ENFORCE_GE(begin_idx, 0, "Slice begin index is less than zero.");
PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE_LT(begin_idx, end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
int base = product(dims_) / dims_[0];
Tensor dst;
dst.holder_ = holder_;
......
......@@ -36,7 +36,8 @@ TEST(Tensor, DataAssert) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
"holder_ should not be null\nTenosr holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
......@@ -111,7 +112,8 @@ TEST(Tensor, ShareDataWith) {
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
"holder_ should not be null\nTenosr holds no memory. Call "
"Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
namespace paddle {
class KmaxSeqScoreLayer : public Layer {
private:
MatrixPtr scores_;
size_t beamSize_;
void kmaxScorePerSeq(const real* score,
real* sortedRes,
const ICpuGpuVectorPtr seqStartPos);
public:
explicit KmaxSeqScoreLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(kmax_seq_score, KmaxSeqScoreLayer);
bool KmaxSeqScoreLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
bool ret = Layer::init(layerMap, parameterMap);
CHECK_EQ(1U, inputLayers_.size());
beamSize_ = config_.beam_size();
CHECK_GE(beamSize_, 1U);
setNeedSequenceInfo(false);
setNeedGradient(false);
return ret;
}
void KmaxSeqScoreLayer::kmaxScorePerSeq(const real* scores,
real* sortedIds,
const ICpuGpuVectorPtr seqStartPos) {
int* starts = seqStartPos->getMutableData(false);
std::vector<real> indices;
for (size_t i = 0; i < seqStartPos->getSize() - 1; ++i) {
int seqLen = starts[i + 1] - starts[i];
int k = std::min(static_cast<int>(beamSize_), seqLen);
indices.resize(seqLen, 0);
std::iota(begin(indices), end(indices), 0.);
std::vector<real> tmpScore(scores + starts[i], scores + starts[i + 1]);
std::partial_sort(
begin(indices),
begin(indices) + k,
end(indices),
[&](size_t a, size_t b) { return tmpScore[a] > tmpScore[b]; });
memcpy(sortedIds + (i * beamSize_), indices.data(), k * sizeof(real));
}
}
void KmaxSeqScoreLayer::forward(PassType passType) {
Layer::forward(passType);
const Argument& input = getInput(0);
const MatrixPtr inputScore = getInputValue(0);
CHECK(input.hasSeq() || input.hasSubseq())
<< "input of " << getName()
<< " must be a sequence or a nested sequence.";
CHECK_EQ(input.value->getWidth(), 1UL)
<< "input of " << getName()
<< " is score over a sequence or a nested sequence, so its width "
<< " must be 1.";
if (useGpu_) {
// this Layer runs only in CPU, if the model is runing on GPU,
// then copy the input to this layer from GPU to CPU.
Matrix::resizeOrCreate(scores_,
inputScore->getHeight(),
1,
false /* trans */,
false /* useGpu */);
scores_->copyFrom(*inputScore);
} else {
scores_ = inputScore;
}
Matrix::resizeOrCreate(
output_.value,
input.hasSubseq() ? input.getNumSubSequences() : input.getNumSequences(),
beamSize_,
false,
false);
output_.value->one();
output_.value->mulScalar(-1.);
kmaxScorePerSeq(scores_->getData(),
output_.value->getData(),
input.hasSubseq() ? input.subSequenceStartPositions
: input.sequenceStartPositions);
}
void KmaxSeqScoreLayer::backward(const UpdateCallback& callback) {}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/Vector.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
class SubNestedSequenceLayer : public Layer {
public:
explicit SubNestedSequenceLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
private:
/*
* This functions generates the indices of rows in a batch according to the
* indices of selected sub-sequence in each sequence.
*
* Examples:
* selectedIndices:
* [
* [0, 1, -1],
* [0, 1, 2],
* [0, -1, -1],
* [0, 2, 3],
* ]
* inputSeqInfo:
* [
* [0,3,4],
* [4,5,7,10,15],
* [15,20],
* [20,22,23,25,28]
* ]
*
* ths output is saved to private member rowIndice_;
* [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,
* 16,17,18,19,20,21,22,23,24,25,26,27]
*/
void calSelectedCols(const MatrixPtr selectedIndices,
const std::vector<std::vector<int>>& inputSeqInfo);
// if the second input of this layer is on GPU memory, copy it to CPU memory.
MatrixPtr selIdsCpu_;
// reorganized sequenceStartPositions and subSequenceStartPositions
// into a 2d vector to facilitate the sequence selection process.
std::vector<std::vector<int>> inputSeqInfoVec_;
// the final selected row indices in a batch,
// rowIdx_ and selectedRows_ actually share a same memory.
IVectorPtr rowIndice_;
std::vector<int> selectedRows_;
};
REGISTER_LAYER(sub_nested_seq, SubNestedSequenceLayer);
bool SubNestedSequenceLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
CHECK_EQ(2U, inputLayers_.size());
setNeedSequenceInfo(false);
return true;
}
void SubNestedSequenceLayer::calSelectedCols(
const MatrixPtr selectedIndices,
const std::vector<std::vector<int>>& inputSeqInfo) {
selectedRows_.clear();
std::vector<int> outSeqStartInfo(1, 0);
std::vector<int> outSubSeqStartInfo(1, 0);
size_t seqNum = selectedIndices->getHeight();
size_t beamSize = selectedIndices->getWidth();
for (size_t i = 0; i < seqNum; ++i) {
for (size_t j = 0; j < beamSize; ++j) {
if (selectedIndices->getElement(i, j) == -1.) break;
int selSubSeqIdx = selectedIndices->getElement(i, j);
CHECK_GT(inputSeqInfoVec_[i].size() - 1, selSubSeqIdx);
size_t subSeqLen = inputSeqInfoVec_[i][selSubSeqIdx + 1] -
inputSeqInfoVec_[i][selSubSeqIdx];
for (size_t k = 0; k < subSeqLen; ++k)
selectedRows_.push_back(inputSeqInfoVec_[i][selSubSeqIdx] + k);
outSubSeqStartInfo.push_back(outSubSeqStartInfo.back() + subSeqLen);
}
outSeqStartInfo.push_back(outSubSeqStartInfo.back());
}
if (useGpu_) {
rowIndice_ = IVector::create(selectedRows_.size(), useGpu_);
rowIndice_->copyFrom(selectedRows_.data(), selectedRows_.size());
} else {
rowIndice_ =
IVector::create(selectedRows_.data(), selectedRows_.size(), useGpu_);
}
// create the sequence information for the output.
ICpuGpuVector::resizeOrCreate(
output_.sequenceStartPositions, outSeqStartInfo.size(), false);
output_.sequenceStartPositions->copyFrom(
outSeqStartInfo.data(), outSeqStartInfo.size(), false);
ICpuGpuVector::resizeOrCreate(
output_.subSequenceStartPositions, outSubSeqStartInfo.size(), false);
output_.subSequenceStartPositions->copyFrom(
outSubSeqStartInfo.data(), outSubSeqStartInfo.size(), false);
}
void SubNestedSequenceLayer::forward(PassType passType) {
Layer::forward(passType);
const Argument& inputSeq = getInput(0);
CHECK(inputSeq.hasSubseq()) << "The first input of SubNestSequence layer "
<< "must be a nested sequence.";
const MatrixPtr selectedIndices = getInputValue(1);
CHECK_EQ(inputSeq.getNumSequences(), selectedIndices->getHeight());
if (dynamic_cast<GpuMatrix*>(selectedIndices.get())) {
/*
* Currently, the second input for this layer is generated by
* kmax_sequence_score_layer whose output is always stored on CPU,
* or a data_layer which canbe on GPU.
*
* If the second input is on GPU, copy it to CPU memory, because this
* input always uses very few memory, and operations related to it are
* all logic control, not computations.
*/
Matrix::resizeOrCreate(selIdsCpu_,
selectedIndices->getHeight(),
selectedIndices->getWidth(),
false /* trans */,
false /* useGpu */);
selIdsCpu_->copyFrom(*selectedIndices);
} else {
selIdsCpu_ = selectedIndices;
}
Argument::reorganizeSeqInfo(inputSeq.sequenceStartPositions,
inputSeq.subSequenceStartPositions,
inputSeqInfoVec_);
calSelectedCols(selIdsCpu_, inputSeqInfoVec_);
resetOutput(selectedRows_.size(), getSize());
getOutputValue()->selectRows(*getInputValue(0), *rowIndice_);
}
void SubNestedSequenceLayer::backward(const UpdateCallback& callback) {
MatrixPtr inputSeqGrad = getInputGrad(0);
MatrixPtr outputGrad = getOutputGrad();
if (inputSeqGrad) outputGrad->addToRows(*inputSeqGrad, *rowIndice_);
}
} // namespace paddle
......@@ -66,6 +66,16 @@ add_unittest_without_exec(test_BatchNorm
add_test(NAME test_BatchNorm
COMMAND test_BatchNorm)
################# test_KmaxSeqScore #######################
add_unittest_without_exec(test_KmaxSeqScore
test_KmaxSeqScore.cpp
LayerGradUtil.cpp)
add_test(NAME test_KmaxSeqScore
COMMAND test_KmaxSeqScore)
################## test_Evaluator #######################
add_unittest(test_Evaluator
test_Evaluator.cpp)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <algorithm>
#include <string>
#include <vector>
#include "ModelConfig.pb.h"
#include "paddle/gserver/layers/DataLayer.h"
#include "paddle/trainer/Trainer.h"
#include "paddle/utils/GlobalConstants.h"
#include "LayerGradUtil.h"
#include "paddle/testing/TestUtil.h"
using namespace paddle; // NOLINT
using namespace std; // NOLINT
DECLARE_bool(use_gpu);
DECLARE_int32(gpu_id);
DECLARE_bool(thread_local_rand_use_global_seed);
vector<int> randSampling(int range, int n) {
CHECK_GE(range, n);
vector<int> num(range);
iota(begin(num), end(num), 0);
if (range == n) return num;
random_shuffle(begin(num), end(num));
num.resize(n);
return num;
}
void genRandomSeqInfo(vector<int>& seqStartPosition,
vector<int>& subSeqStartPosition) {
const int maxSeqNum = 100;
// generate random start position information
int seqNum = 1 + (rand() % maxSeqNum);
seqStartPosition.resize(seqNum + 1, 0);
subSeqStartPosition.resize(1, 0);
for (int i = 0; i < seqNum; ++i) {
int subSeqLen = 1 + (rand() % maxSeqNum);
for (int j = 0; j < subSeqLen; ++j)
subSeqStartPosition.push_back(subSeqStartPosition.back() + subSeqLen);
seqStartPosition[i + 1] = subSeqStartPosition.back();
}
}
void genRandomGroundTruth(real* values,
vector<vector<int>>& groundTruth,
vector<int>& startPos,
size_t beamSize) {
groundTruth.resize(startPos.size() - 1, vector<int>(beamSize, -1));
for (size_t i = 0; i < startPos.size() - 1; ++i) {
int seqLen = startPos[i + 1] - startPos[i];
vector<int> pos =
randSampling(seqLen, min(static_cast<int>(beamSize), seqLen));
for (size_t j = 0; j < pos.size(); ++j) {
groundTruth[i][j] = pos[j];
values[startPos[i] + pos[j]] = 1.;
}
}
}
void checkLayerOut(vector<vector<int>> groundTruth,
real* layerOut,
size_t beamSize) {
for (size_t i = 0; i < groundTruth.size(); ++i) {
int begPos = i * beamSize;
vector<real> tmp(layerOut + begPos, layerOut + begPos + beamSize);
sort(begin(tmp), end(tmp));
sort(begin(groundTruth[i]), end(groundTruth[i]));
for (size_t j = 0; j < beamSize; ++j) CHECK_EQ(tmp[j], groundTruth[i][j]);
}
}
TEST(Layer, kmaxSeqScoreLayer) {
const size_t maxBeamSize = 100;
int beamSize = 1 + (rand() % maxBeamSize);
vector<int> seqStartPosition;
vector<int> subSeqStartPosition;
genRandomSeqInfo(seqStartPosition, subSeqStartPosition);
MatrixPtr inValue =
Matrix::create(subSeqStartPosition.back(), 1, false, false);
for (auto hasSubseq : {false, true}) {
vector<vector<int>> groundTruth;
inValue->randomizeUniform();
genRandomGroundTruth(inValue->getData(),
groundTruth,
hasSubseq ? subSeqStartPosition : seqStartPosition,
beamSize);
for (auto useGpu : {false, true}) {
TestConfig config;
config.layerConfig.set_type("kmax_seq_score");
config.layerConfig.set_beam_size(beamSize);
if (hasSubseq) {
config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA,
"scores",
inValue,
seqStartPosition,
subSeqStartPosition});
} else {
config.inputDefs.push_back(
{INPUT_SELF_DEFINE_DATA, "scores", inValue, seqStartPosition});
}
config.layerConfig.add_inputs();
// data layer initialize
std::vector<DataLayerPtr> dataLayers;
LayerMap layerMap;
vector<Argument> datas;
initDataLayer(
config,
&dataLayers,
&datas,
&layerMap,
"kmax_seq_score",
100 /* actually this parameter is unused in self-defined input*/,
false,
useGpu);
// test layer initialize
std::vector<ParameterPtr> parameters;
LayerPtr kmaxSeqScoreLayer;
FLAGS_use_gpu = useGpu;
initTestLayer(config, &layerMap, &parameters, &kmaxSeqScoreLayer);
kmaxSeqScoreLayer->forward(PASS_TRAIN);
const MatrixPtr outValue = kmaxSeqScoreLayer->getOutputValue();
CHECK_EQ(outValue->getHeight(),
hasSubseq ? subSeqStartPosition.size() - 1
: seqStartPosition.size() - 1);
CHECK_EQ(outValue->getWidth(), beamSize);
checkLayerOut(groundTruth, outValue->getData(), beamSize);
}
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
FLAGS_thread_local_rand_use_global_seed = true;
srand((size_t)(time(NULL)));
return RUN_ALL_TESTS();
}
......@@ -1899,6 +1899,84 @@ TEST(Layer, CropLayer) {
}
}
vector<real> randSampling(real range, int n) {
CHECK_GE(range, n);
vector<real> num(range);
iota(begin(num), end(num), 0.);
if (range == n) return num;
random_shuffle(begin(num), end(num));
num.resize(n);
sort(begin(num), end(num));
return num;
}
TEST(Layer, SubNestedSequenceLayer) {
// layer size is not crutial for this layer,
// so use a small layer size in unittest
const int layerSize = 4;
const int maxSeqNum = 50;
const int maxSeqLen = 50;
const int maxBeamSize = 32;
srand((size_t)(time(NULL)));
int beamSize = 1 + (rand() % maxBeamSize);
TestConfig config;
config.layerConfig.set_type("sub_nested_seq");
config.layerConfig.set_name("sub_nested_seq_layer");
config.layerConfig.set_size(layerSize);
int seqNum = 1 + (rand() % maxSeqNum);
// sequence information for the first input, it is a nested sequence
vector<int> seqStartPos(seqNum + 1, 0);
vector<int> subSeqStartPos(1, 0);
// selected indices
MatrixPtr selectedIndices = Matrix::create(seqNum, beamSize, false, false);
selectedIndices->one();
selectedIndices->mulScalar(-1.);
real* indicesData = selectedIndices->getData();
for (int i = 0; i < seqNum; ++i) {
int subSeqNum = 1 + (rand() % maxSeqNum);
for (int j = 0; j < subSeqNum; ++j) {
subSeqStartPos.push_back(subSeqStartPos.back() +
(1 + (rand() % maxSeqLen)));
}
vector<real> selSeqs =
randSampling(static_cast<real>(subSeqNum), min(beamSize, subSeqNum));
memcpy(indicesData + (i * beamSize),
selSeqs.data(),
selSeqs.size() * sizeof(real));
seqStartPos[i + 1] = subSeqStartPos.back();
}
MatrixPtr seqInputPtr =
Matrix::create(seqStartPos.back(), layerSize, false, false);
seqInputPtr->randomizeUniform();
config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA,
"nested_seq_input",
seqInputPtr,
seqStartPos,
subSeqStartPos});
config.layerConfig.add_inputs();
config.inputDefs.push_back(
{INPUT_SELF_DEFINE_DATA, "selected_indices", selectedIndices});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config,
"sub_nested_seq",
/* batchSize */ seqNum,
/* trans */ false,
/* useGpu*/ useGpu,
/* useWeight */ false);
}
}
TEST(Layer, ClipLayer) {
const size_t batchSize = 128;
const size_t size = 512;
......
......@@ -59,6 +59,7 @@ op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
cc_test(sgd_op_test SRCS sgd_op_test.cc DEPS sgd_op)
op_library(fc_op
SRCS fc_op.cc
......@@ -66,3 +67,5 @@ op_library(fc_op
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu)
......@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
class AddOp : public OperatorWithKernel {
class AddOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
ctx.Input<Tensor>("Y")->dims(),
"Two input of Add Op's dimension must be same.");
......@@ -27,9 +27,9 @@ class AddOp : public OperatorWithKernel {
}
};
class AddOpMaker : public OpProtoAndCheckerMaker {
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
......@@ -42,14 +42,17 @@ The equation is: Out = X + Y
}
};
class AddOpGrad : public OperatorWithKernel {
class AddOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {}
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::CPUPlace, float>);
......@@ -16,4 +16,6 @@
#include "paddle/framework/op_registry.h"
#include "paddle/operators/add_op.h"
REGISTER_OP_GPU_KERNEL(add_two, ops::AddKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::GPUPlace, float>);
......@@ -13,15 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class AddKernel : public OpKernel {
class AddKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Y");
auto* output = context.Output<Tensor>("Out");
......
......@@ -14,9 +14,9 @@ limitations under the License. */
#include <gtest/gtest.h>
#define private public
#include <paddle/framework/op_registry.h>
#include "paddle/framework/op_registry.h"
USE_OP(add_two);
// USE_OP(add_two_grad);
TEST(AddOp, GetOpProto) {
auto& protos = paddle::framework::OpRegistry::protos();
......
......@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
auto *label = ctx.Input<Tensor>("label");
......@@ -30,9 +30,9 @@ class OnehotCrossEntropyOp : public OperatorWithKernel {
}
};
class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
......@@ -41,9 +41,10 @@ class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
}
};
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
class OnehotCrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
OnehotCrossEntropyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of OnehotCrossEntropyOp");
AddInput("label", "The second input of OnehotCrossEntropyOp");
......@@ -59,11 +60,14 @@ OnehotCrossEntropy Operator.
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<ops::CPUPlace, float>);
ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
......@@ -14,3 +14,8 @@
#define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>);
......@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
T tolerable_value(T x) {
static_assert(std::is_floating_point<T>::value,
......@@ -38,9 +40,9 @@ T tolerable_value(T x) {
}
template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public OpKernel {
class OnehotCrossEntropyOpKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>("label")->data<int>();
......@@ -61,9 +63,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
};
template <typename Place, typename T>
class OnehotCrossEntropyGradientOpKernel : public OpKernel {
class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
......
......@@ -12,11 +12,16 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "type_alias.h"
#include "paddle/operators/net_op.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using OpRegistry = framework::OpRegistry;
class FullyConnectedOp : public NetOp {
public:
void Init() override {
......@@ -39,9 +44,10 @@ class FullyConnectedOp : public NetOp {
}
};
class FullyConnectedOpMaker : public OpProtoAndCheckerMaker {
class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FullyConnectedOpMaker(OpProto *proto, OpAttrChecker *op_checker)
FullyConnectedOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator");
AddInput("W", "the weight of fc operator");
......@@ -66,4 +72,5 @@ USE_OP(rowwise_add);
USE_OP(sigmoid);
USE_OP(softmax);
namespace ops = paddle::operators;
REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker);
......@@ -42,8 +42,8 @@ The output will have the same size with input.
} // namespace operators
} // namespace paddle
REGISTER_OP(fill_zeros_like, paddle::operators::FillZerosLikeOp,
paddle::operators::FillZerosLikeOpMaker);
namespace ops = paddle::operators;
REGISTER_OP(fill_zeros_like, ops::FillZerosLikeOp, ops::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
......@@ -16,6 +16,7 @@
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_zeros_like_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::GPUPlace, float>);
ops::FillZerosLikeKernel<paddle::platform::GPUPlace, float>);
......@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
......
......@@ -17,18 +17,18 @@ limitations under the License. */
namespace paddle {
namespace operators {
class MeanOp : public OperatorWithKernel {
class MeanOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputVar("X") != nullptr,
"Input of MeanOp must be initialized.");
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input of MeanOp must be initialized.");
ctx.Output<Tensor>("Out")->Resize({1});
}
};
class MeanOpMaker : public OpProtoAndCheckerMaker {
class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").IgnoreGradient();
......@@ -36,9 +36,9 @@ class MeanOpMaker : public OpProtoAndCheckerMaker {
}
};
class MeanGradOp : public OperatorWithKernel {
class MeanGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
->Resize(ctx.Input<Tensor>("X")->dims());
}
......@@ -47,7 +47,10 @@ class MeanGradOp : public OperatorWithKernel {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::CPUPlace, float>);
......@@ -16,5 +16,8 @@
#include "paddle/operators/mean_op.h"
REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::GPUPlace, float>);
......@@ -13,15 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class MeanKernel : public OpKernel {
class MeanKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0);
......@@ -36,9 +45,9 @@ class MeanKernel : public OpKernel {
};
template <typename Place, typename T>
class MeanGradKernel : public OpKernel {
class MeanGradKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>("Out" + framework::kGradVarSuffix);
PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
"Mean Gradient should be scalar");
......
......@@ -17,9 +17,9 @@
namespace paddle {
namespace operators {
class MulOp : public OperatorWithKernel {
class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
auto dim1 = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_EQ(dim0.size(), 2,
......@@ -35,9 +35,9 @@ class MulOp : public OperatorWithKernel {
}
};
class MulOpMaker : public OpProtoAndCheckerMaker {
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
......@@ -50,9 +50,9 @@ The equation is: Out = X * Y
}
};
class MulOpGrad : public OperatorWithKernel {
class MulOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {}
void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "MulGrad";
return "";
......@@ -62,7 +62,8 @@ class MulOpGrad : public OperatorWithKernel {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
......@@ -15,4 +15,6 @@
#define EIGEN_USE_GPU
#include "paddle/operators/mul_op.h"
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
......@@ -13,16 +13,21 @@
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class MulKernel : public OpKernel {
class MulKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
......@@ -40,5 +45,6 @@ class MulKernel : public OpKernel {
Z.device(place) = X.contract(Y, dim_pair);
}
};
} // namespace operators
} // namespace paddle
......@@ -16,10 +16,6 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -64,20 +60,29 @@ class NetOp : public framework::OperatorBase {
}
}
bool SupportGPU() const override {
for (auto& op : ops_) {
if (!op->SupportGPU()) {
return false;
}
}
return true;
}
/**
* @brief Add an operator by ptr
*/
void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op");
PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
ops_.push_back(op);
}
void InsertOp(size_t pos, const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_,
"Cannot InsertOp when this network is sealed");
PADDLE_ENFORCE(op != nullptr, "Cannot Insert Null op");
PADDLE_ENFORCE(pos <= ops_.size(), "Out of range");
PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op");
PADDLE_ENFORCE_LE(pos, ops_.size(), "Out of range");
ops_.insert(ops_.begin() + pos, op);
}
......
......@@ -2,31 +2,27 @@
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;
static int infer_shape_cnt = 0;
static int run_cnt = 0;
class TestOp : public OperatorBase {
class TestOp : public framework::OperatorBase {
public:
void InferShape(const framework::Scope& scope) const override {
++infer_shape_cnt;
}
void Run(const framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
++run_cnt;
}
};
class EmptyOp : public OperatorBase {
class EmptyOp : public framework::OperatorBase {
public:
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
};
template <typename T>
......@@ -73,7 +69,7 @@ TEST(OpKernel, all) {
net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
ASSERT_THROW(net->AddOp(op2), platform::EnforceNotMet);
}
TEST(NetOp, insert_op) {
......
......@@ -14,17 +14,19 @@
#include "paddle/operators/recurrent_op.h"
#include <glog/logging.h>
#include <cstring>
#include <sstream>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
void RecurrentAlgorithm::InferShape(const Scope& scope) const {
seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
->GetMutable<Tensor>()
......@@ -140,10 +142,11 @@ void RecurrentOp::Init() {
alg_.Init(std::move(arg));
}
class RecurrentAlgorithmProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class RecurrentAlgorithmProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
public:
RecurrentAlgorithmProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
RecurrentAlgorithmProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = RecurrentOp::kArgName;
// inputs and outputs stored in proto
......
......@@ -18,7 +18,9 @@ namespace paddle {
namespace operators {
namespace rnn {
namespace fmw = paddle::framework;
namespace f = paddle::framework;
using Tensor = framework::Tensor;
void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks, const size_t seq_len,
......@@ -30,10 +32,10 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
inlinks[i].external);
Tensor* input = input_var->GetMutable<Tensor>();
fmw::DDim dims = input->dims();
f::DDim dims = input->dims();
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
"all the inlinks must have same length");
fmw::DDim step_dims = slice_ddim(dims, 1, dims.size());
f::DDim step_dims = slice_ddim(dims, 1, dims.size());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input =
step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
......@@ -58,11 +60,10 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
outlinks[i].internal);
fmw::DDim step_dims =
step_scope_var->template GetMutable<Tensor>()->dims();
f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
std::vector<int> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len);
output->Resize(fmw::make_ddim(dims_vec));
output->Resize(f::make_ddim(dims_vec));
} else {
output->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) {
......@@ -104,7 +105,7 @@ void LinkMemories(const std::vector<Scope*>& scopes,
}
void InitArgument(const ArgumentName& name, Argument* arg,
const OperatorBase& op) {
const framework::OperatorBase& op) {
arg->step_net = op.Input(name.step_net);
arg->step_scopes = op.Output(name.step_scopes);
......
......@@ -17,12 +17,13 @@
#include <string>
#include "paddle/framework/operator.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace operators {
namespace rnn {
using Scope = framework::Scope;
/**
* Memory of a RNN (same as the role of `Momory` in PaddlePaddle).
*
......@@ -86,7 +87,7 @@ void LinkMemories(const std::vector<Scope*>& step_scopes,
const int offset, bool infer_shape_mode);
void InitArgument(const ArgumentName& name, Argument* arg,
const OperatorBase& op);
const framework::OperatorBase& op);
} // namespace rnn
} // namespace operators
......
......@@ -13,12 +13,13 @@
limitations under the License. */
#include "paddle/operators/rowwise_add_op.h"
namespace paddle {
namespace operators {
class RowWiseAddOp : public OperatorWithKernel {
class RowWiseAddOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
auto dim1 = ctx.Input<Tensor>("b")->dims();
......@@ -30,9 +31,10 @@ class RowWiseAddOp : public OperatorWithKernel {
}
};
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
RowWiseAddOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix");
AddInput("b", "The right input of row-wise add op, must be vector");
......@@ -48,6 +50,7 @@ for i in xrange(X.shape[0]):
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker);
REGISTER_OP_CPU_KERNEL(rowwise_add,
ops::RowWiseAddKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
......@@ -15,5 +15,6 @@
#define EIGEN_USE_GPU
#include "paddle/operators/rowwise_add_op.h"
REGISTER_OP_GPU_KERNEL(rowwise_add,
ops::RowWiseAddKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
rowwise_add, ops::RowWiseAddKernel<paddle::platform::GPUPlace, float>);
......@@ -13,15 +13,24 @@
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class RowWiseAddKernel : public OpKernel {
class RowWiseAddKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<Tensor>(0);
out->mutable_data<T>(context.GetPlace());
......
......@@ -17,9 +17,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
class SGDOp : public OperatorWithKernel {
class SGDOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(
ctx.Input<Tensor>("param")->dims() == ctx.Input<Tensor>("grad")->dims(),
"Two input of SGD Op's dimension must be same.");
......@@ -27,9 +27,9 @@ class SGDOp : public OperatorWithKernel {
}
};
class SGDOpMaker : public OpProtoAndCheckerMaker {
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("param", "input parameter");
AddInput("grad", "input gradient");
......@@ -47,5 +47,7 @@ param_out = param - learning_rate * grad;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker);
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(sgd,
ops::SGDOpKernel<paddle::platform::CPUPlace, float>);
......@@ -15,4 +15,6 @@
#define EIGEN_USE_GPU
#include "paddle/operators/sgd_op.h"
REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sgd,
ops::SGDOpKernel<paddle::platform::GPUPlace, float>);
......@@ -13,15 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class SGDOpKernel : public OpKernel {
class SGDOpKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& ctx) const override {
void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>(0);
......
......@@ -13,19 +13,21 @@
limitations under the License. */
#include "paddle/operators/sigmoid_op.h"
namespace paddle {
namespace operators {
class SigmoidOp : public OperatorWithKernel {
class SigmoidOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
}
};
class SigmoidOpMaker : public OpProtoAndCheckerMaker {
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
SigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input");
AddOutput("Y", "sigmoid output");
......@@ -33,9 +35,9 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
}
};
class SigmoidOpGrad : public OperatorWithKernel {
class SigmoidOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
}
};
......@@ -43,9 +45,11 @@ class SigmoidOpGrad : public OperatorWithKernel {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(sigmoid_grad,
ops::SigmoidGradKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::CPUPlace, float>);
......@@ -15,6 +15,9 @@
#define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h"
REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(sigmoid_grad,
ops::SigmoidGradKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::GPUPlace, float>);
......@@ -13,16 +13,21 @@
limitations under the License. */
#pragma once
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class SigmoidKernel : public OpKernel {
class SigmoidKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
......@@ -37,9 +42,9 @@ class SigmoidKernel : public OpKernel {
};
template <typename Place, typename T>
class SigmoidGradKernel : public OpKernel {
class SigmoidGradKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto Y_t = context.Input<Tensor>("Y");
auto dY_t = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX_t = context.Output<Tensor>(framework::GradVarName("X"));
......
......@@ -17,18 +17,19 @@ limitations under the License. */
namespace paddle {
namespace operators {
class SoftmaxOp : public OperatorWithKernel {
class SoftmaxOp : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
"The input of softmax op must be matrix");
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
}
};
class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftmaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
SoftmaxOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "input of softmax");
AddOutput("Y", "output of softmax");
......@@ -36,12 +37,12 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
}
};
class SoftmaxOpGrad : public OperatorWithKernel {
class SoftmaxOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
PADDLE_ENFORCE(ctx.InputVar(framework::GradVarName("Y")) != nullptr,
"Input(Y@GRAD) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
"Input(Y@GRAD) should not be null");
PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
"the shape of Input(0) and Input(1) should be the same");
......@@ -53,8 +54,11 @@ class SoftmaxOpGrad : public OperatorWithKernel {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_grad,
ops::SoftmaxGradKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
softmax_grad, ops::SoftmaxGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -13,9 +13,11 @@
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/softmax_op.h"
REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(softmax_grad,
ops::SoftmaxGradKernel<ops::GPUPlace, float>);
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
softmax_grad, ops::SoftmaxGradKernel<paddle::platform::GPUPlace, float>);
......@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/ddim.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/type_alias.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class SoftmaxKernel : public OpKernel {
class SoftmaxKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>("X");
auto output = context.Output<Tensor>("Y");
output->mutable_data<T>(context.GetPlace());
......@@ -62,9 +64,9 @@ class SoftmaxKernel : public OpKernel {
};
template <typename Place, typename T>
class SoftmaxGradKernel : public OpKernel {
class SoftmaxGradKernel : public framework::OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
void Compute(const framework::ExecutionContext& context) const override {
std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
auto Y = context.Input<Tensor>("Y");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <random>
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(
static_cast<T>(context.op_.GetAttr<float>("min")),
static_cast<T>(context.op_.GetAttr<float>("max")));
for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) {
data[i] = dist(engine);
}
}
};
class UniformRandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
"uniform_random's min must less then max");
auto* tensor = ctx.Output<framework::Tensor>(0);
auto dims = GetAttr<std::vector<int>>("dims");
tensor->Resize(framework::make_ddim(dims));
}
};
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
UniformRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The output tensor of uniform random op");
AddComment(R"DOC(Uniform random operator.
Used to initialize tensor with uniform random generator.
)DOC");
AddAttr<std::vector<int>>("dims", "the dimension of random tensor");
AddAttr<float>("min", "Minimum value of uniform random").SetDefault(-1.0f);
AddAttr<float>("max", "Maximun value of uniform random").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of uniform random. "
"0 means generate a seed by system")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker);
REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>);
......@@ -12,44 +12,59 @@
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
using OpKernel = framework::OpKernel;
using OperatorBase = framework::OperatorBase;
using InferShapeContext = framework::InferShapeContext;
using ExecutionContext = framework::ExecutionContext;
using Variable = framework::Variable;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using Scope = framework::Scope;
using OperatorWithKernel = framework::OperatorWithKernel;
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using OpRegistry = framework::OpRegistry;
template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
__host__ __device__ UniformGenerator(T min, T max, int seed)
: min_(min), max_(max), seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
return dist(rng);
}
};
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename T>
class GPUUniformRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
if (seed == 0) {
seed = std::random_device()();
}
T min = static_cast<T>(context.op_.GetAttr<float>("min"));
T max = static_cast<T>(context.op_.GetAttr<float>("max"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(uniform_random,
paddle::operators::GPUUniformRandomKernel<float>);
......@@ -666,4 +666,24 @@ void Argument::subArgFrom(const Argument& input,
}
}
void Argument::reorganizeSeqInfo(
const ICpuGpuVectorPtr seqStartPos,
const ICpuGpuVectorPtr subSeqStartPos,
std::vector<std::vector<int>>& reorganizedSeqInfo) {
int* seqStarts = seqStartPos->getMutableData(false);
int* subSeqStarts = subSeqStartPos->getMutableData(false);
int seqNum = seqStartPos->getSize() - 1;
reorganizedSeqInfo.resize(seqNum, std::vector<int>());
int seqIdx = 0;
for (size_t i = 0; i < subSeqStartPos->getSize(); ++i) {
reorganizedSeqInfo[seqIdx].push_back(subSeqStarts[i]);
if (subSeqStarts[i] == seqStarts[seqIdx + 1]) {
seqIdx++;
if (seqIdx == seqNum) return;
reorganizedSeqInfo[seqIdx].push_back(subSeqStarts[i]);
}
}
}
} // namespace paddle
......@@ -317,6 +317,30 @@ struct Argument {
*/
void printValueString(std::ostream& stream,
const std::string& prefix = "") const;
/**
* @brief reorganizeSeqInfo will reorganize sequenceStartPositions and
* subSequenceStartPositions into a 2 dimensional arrary: reorganizedSeqInfo.
*
* @param seqStartPos: sequenceStartPositions of an Argument.
* @param subSeqStartPos: subSequenceStartPositions of an Argument.
* @param the reorganized sequence start position information.
*
* Examples:
* seqStartPos: [0, 4, 15, 20, 28]
* subSeqStartPos: [0, 3, 4, 5, 7, 10, 15, 20, 22, 23, 25, 28]
* reorganizedSeqInfo:
* [
* [0,3,4],
* [4,5,7,10,15],
* [15,20],
* [20,22,23,25,28]
* ]
*/
static void reorganizeSeqInfo(
const ICpuGpuVectorPtr seqStartPos,
const ICpuGpuVectorPtr subSeqStartPos,
std::vector<std::vector<int>>& reorganizedSeqInfo);
};
} // namespace paddle
......@@ -187,13 +187,9 @@ inline void throw_on_error(T e) {
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
// if two values have different data types, choose a compatible type for them.
template <typename T1, typename T2>
struct CompatibleType {
static const bool t1_to_t2 = std::is_convertible<T1, T2>::value;
typedef typename std::conditional<t1_to_t2, T2, T1>::type type;
};
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \
paddle::string::Sprintf("" __VA_ARGS__));
template <typename T>
inline std::string enforce_to_string(const T& val) {
......@@ -211,17 +207,12 @@ inline std::string enforce_to_string(const char* const& val) {
}
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \
__CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \
PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
#__VAL0, #__VAL1, \
paddle::platform::enforce_to_string(__VAL0), \
paddle::platform::enforce_to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__));
#define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \
typename paddle::platform::CompatibleType<decltype(__VAL0), \
decltype(__VAL1)>::type(__VAL)
} // namespace platform
} // namespace paddle
......@@ -9,8 +9,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/enforce.h"
#include <memory>
#include "gtest/gtest.h"
#include "paddle/platform/enforce.h"
TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
......@@ -196,3 +198,27 @@ TEST(ENFORCE_LT, FAIL) {
ASSERT_TRUE(in_catch);
}
TEST(ENFORCE_NOT_NULL, OK) {
int* a = new int;
PADDLE_ENFORCE_NOT_NULL(a);
delete a;
}
TEST(ENFORCE_NOT_NULL, FAIL) {
bool in_catch = false;
int* a{nullptr};
try {
PADDLE_ENFORCE_NOT_NULL(a);
} catch (paddle::platform::EnforceNotMet error) {
in_catch = true;
const std::string msg = "a should not be null";
const char* what = error.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(in_catch);
}
cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python backward
fc_op
sgd_op
add_op
mean_op
cross_entropy_op
recurrent_op
fill_zeros_like_op)
......@@ -33,6 +33,9 @@ Configuring cmake in /paddle/build ...
-DWITH_AVX=${WITH_AVX:-OFF}
-DWITH_GOLANG=${WITH_GOLANG:-OFF}
-DWITH_SWIG_PY=ON
-DWITH_C_API=${WITH_C_API:-OFF}
-DWITH_PYTHON=${WITH_PYTHON:-ON}
-DWITH_SWIG_PY=${WITH_SWIG_PY:-ON}
-DCUDNN_ROOT=/usr/
-DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
-DWITH_TESTING=${WITH_TESTING:-OFF}
......@@ -49,7 +52,9 @@ cmake .. \
-DWITH_GPU=${WITH_GPU:-OFF} \
-DWITH_AVX=${WITH_AVX:-OFF} \
-DWITH_GOLANG=${WITH_GOLANG:-OFF} \
-DWITH_SWIG_PY=ON \
-DWITH_SWIG_PY=${WITH_SWIG_PY:-ON} \
-DWITH_C_API=${WITH_C_API:-OFF} \
-DWITH_PYTHON=${WITH_PYTHON:-ON} \
-DCUDNN_ROOT=/usr/ \
-DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF} \
-DWITH_TESTING=${WITH_TESTING:-OFF} \
......
......@@ -5,15 +5,9 @@ set -e
mkdir -p $TRAVIS_BUILD_DIR/build
cd $TRAVIS_BUILD_DIR/build
# Compile paddle binaries first
cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_MKLDNN=OFF -DWITH_MKLML=OFF -DWITH_GOLANG=ON -DWITH_STYLE_CHECK=OFF
mkdir output
make -j `nproc`
find .. -name '*whl' | xargs pip install # install all wheels.
rm -rf *
# Compile Documentation only.
cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DWITH_MKLML=OFF -DWITH_DOC=ON
make -j `nproc` gen_proto_py
make -j `nproc` paddle_docs paddle_docs_cn
# check websites for broken links
......@@ -35,6 +29,7 @@ TARGET_BRANCH="gh-pages"
SOURCE_BRANCH="master"
# Clone the repo to output directory
mkdir output
git clone $REPO output
cd output
......
......@@ -17,7 +17,7 @@ foreach(filename ${proto_filenames})
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
ARGS "--python_out=${PROJ_ROOT}/python/paddle/proto"
"-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}
DEPENDS ${ABS_FIL} ${external_project_dependencies})
DEPENDS ${ABS_FIL} protoc)
endforeach()
add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY})
......@@ -2657,6 +2657,31 @@ class SubSequenceLayer(LayerBase):
self.create_bias_parameter(bias, size)
@config_layer('sub_nested_seq')
class SubNestedSequenceLayer(LayerBase):
def __init__(self, name, inputs, selected_indices, bias=False, **xargs):
if isinstance(inputs, list):
assert len(inputs) == 1, ('the first input of sub_nested_seq '
'layer is a single nested sequence.')
inputs = inputs[0]
if isinstance(selected_indices, list):
assert len(selected_indices) == 1, (
'the second input of '
'sub_nested_seq layer is a single layer which is a '
'set of selected indices.')
selected_indices = selected_indices[0]
super(SubNestedSequenceLayer, self).__init__(
name,
'sub_nested_seq',
0,
inputs=[inputs, selected_indices],
**xargs)
input_layer0 = self.get_input_layer(0)
size = input_layer0.size
self.set_layer_size(size)
@config_layer('out_prod')
class OuterProdLayer(LayerBase):
def __init__(self, name, inputs, device=None):
......@@ -3223,6 +3248,16 @@ class CTCLayer(LayerBase):
config_assert(len(self.inputs) == 2, 'CTCLayer must have 2 inputs')
@config_layer('kmax_seq_score')
class KmaxSeqScoreLayer(LayerBase):
def __init__(self, name, inputs, beam_size, **xargs):
super(KmaxSeqScoreLayer, self).__init__(
name, 'kmax_seq_score', 0, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1, 'KmaxSeqScoreLayer has only one input.')
self.config.beam_size = beam_size
@config_layer('warp_ctc')
class WarpCTCLayer(LayerBase):
def __init__(self,
......
......@@ -129,8 +129,10 @@ __all__ = [
'prelu_layer',
'gated_unit_layer',
'crop_layer',
'sub_nested_seq_layer',
'clip_layer',
'slice_projection',
'kmax_sequence_score_layer',
]
......@@ -224,8 +226,11 @@ class LayerType(object):
PRELU = 'prelu'
CROP_LAYER = 'crop'
SUB_NESTED_SEQ = 'sub_nested_seq'
CLIP_LAYER = 'clip'
KMAX_SEQ_SCORE = 'kmax_seq_score'
@staticmethod
def is_layer_type(type_name):
"""
......@@ -6088,6 +6093,53 @@ def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
size=l.config.size)
@wrap_name_default()
@layer_support()
def sub_nested_seq_layer(input, selected_indices, name=None):
"""
The sub_nested_seq_layer accepts two inputs: the first one is a nested
sequence; the second one is a set of selceted indices in the nested sequence.
Then sub_nest_seq_layer trims the first nested sequence input according
to the selected indices to form a new output. This layer is useful in
beam training.
The example usage is:
.. code-block:: python
sub_nest_seq = sub_nested_seq_layer(input=[data, selected_indices])
:param input: A nested sequence.
:type input: LayerOutput
:param selected_indices: a set of sequence indices in the nested sequence.
:type input: LayerOutput
:param name: name of this layer.
:type name: basestring
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput), (
'The first input of '
'sub_nested_seq_layer must be a Paddle layer.')
assert isinstance(selected_indices, LayerOutput), (
'The second input of '
'sub_nested_seq_layer must be a Paddle layer.')
l = Layer(
inputs=input.name,
selected_indices=selected_indices.name,
name=name,
type=LayerType.SUB_NESTED_SEQ)
return LayerOutput(
name=name,
layer_type=LayerType.SUB_NESTED_SEQ,
parents=input,
size=l.config.size)
@wrap_name_default("clip")
def clip_layer(input, min, max, name=None):
"""
......@@ -6109,7 +6161,8 @@ def clip_layer(input, min, max, name=None):
:type min: double
:param max: The upper threshold for clipping.
:type max: double
:return: LayerOutput
:return: LayerOutput object.
:rtype: LayerOutput
"""
Layer(
name=name,
......@@ -6119,3 +6172,41 @@ def clip_layer(input, min, max, name=None):
max=max)
return LayerOutput(
name, LayerType.CLIP_LAYER, parents=[input], size=input.size)
@wrap_name_default()
@layer_support()
def kmax_sequence_score_layer(input, name=None, beam_size=1):
"""
This layer accepts one input which are scores over a sequence or a nested
sequence, and returns indices of beam_size sequences with highest scores.
.. code-block:: python
kmax_indices = kmax_sequence_score_layer(input=input_layer, beam_size)
:param name: The Layer Name.
:type name: basestring
:param input: The input layer. It stores scores over a sequence or a nested
sequence and its size must be 1.
:type input: LayerOutput.
:param beam_size: squence indices with top beam_size scores are returned.
:type beam_size: double
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput), ("kmax_sequence_score_layer "
"accepts only one input.")
assert input.size == 1, (
"input of kmax_sequence_score_layer is a score"
"over a sequence or a nested sequence, so its width must be 1.")
Layer(
name=name,
type=LayerType.KMAX_SEQ_SCORE,
inputs=[input.name],
beam_size=beam_size)
return LayerOutput(
name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size)
......@@ -7,6 +7,7 @@ test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer)
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_seq_select_layers)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "input"
type: "data"
size: 300
active_type: ""
}
layers {
name: "data"
type: "data"
size: 128
active_type: ""
}
layers {
name: "__fc_layer_0__"
type: "fc"
size: 1
active_type: "exponential"
inputs {
input_layer_name: "data"
input_parameter_name: "___fc_layer_0__.w0"
}
bias_parameter_name: "___fc_layer_0__.wbias"
}
layers {
name: "__kmax_sequence_score_layer_0__"
type: "kmax_seq_score"
active_type: ""
inputs {
input_layer_name: "__fc_layer_0__"
}
beam_size: 5
}
parameters {
name: "___fc_layer_0__.w0"
size: 128
initial_mean: 0.0
initial_std: 0.0883883476483
dims: 128
dims: 1
initial_strategy: 0
initial_smart: true
}
parameters {
name: "___fc_layer_0__.wbias"
size: 1
initial_mean: 0.0
initial_std: 0.0
dims: 1
dims: 1
initial_strategy: 0
initial_smart: false
}
input_layer_names: "data"
output_layer_names: "__kmax_sequence_score_layer_0__"
sub_models {
name: "root"
layer_names: "input"
layer_names: "data"
layer_names: "__fc_layer_0__"
layer_names: "__kmax_sequence_score_layer_0__"
input_layer_names: "data"
output_layer_names: "__kmax_sequence_score_layer_0__"
is_recurrent_layer_group: false
}
type: "nn"
layers {
name: "input_seq"
type: "data"
size: 300
active_type: ""
}
layers {
name: "input"
type: "data"
size: 5
active_type: ""
}
layers {
name: "__sub_nested_seq_layer_0__"
type: "sub_nested_seq"
size: 300
active_type: ""
inputs {
input_layer_name: "input_seq"
}
inputs {
input_layer_name: "input"
}
}
input_layer_names: "input_seq"
output_layer_names: "__sub_nested_seq_layer_0__"
sub_models {
name: "root"
layer_names: "input_seq"
layer_names: "input"
layer_names: "__sub_nested_seq_layer_0__"
input_layer_names: "input_seq"
output_layer_names: "__sub_nested_seq_layer_0__"
is_recurrent_layer_group: false
}
#!/usr/bin/env python
#coding=utf-8
from paddle.trainer_config_helpers import *
data = data_layer(name='input', size=300)
data = data_layer(name="data", size=128)
scores = fc_layer(input=data, size=1, act=ExpActivation())
kmax_seq_id = kmax_sequence_score_layer(input=scores, beam_size=5)
outputs(kmax_seq_id)
#!/usr/bin/env python
#coding=utf-8
from paddle.trainer_config_helpers import *
beam_size = 5
data = data_layer(name='input_seq', size=300)
selected_ids = data_layer(name='input', size=beam_size)
sub_nest_seq = sub_nested_seq_layer(input=data, selected_indices=selected_ids)
outputs(sub_nest_seq)
......@@ -13,6 +13,7 @@ py_test(test_protobuf SRCS test_protobuf.py)
py_test(test_add_two_op SRCS test_add_two_op.py)
py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
py_test(test_softmax_op SRCS test_softmax_op.py)
py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(gradient_checker SRCS gradient_checker.py)
......@@ -21,3 +22,4 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py)
py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py)
py_test(test_operator SRCS test_operator.py)
py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
import unittest
import numpy
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
import numpy
import unittest
__all__ = ['get_numeric_gradient']
def create_op(op_type):
kwargs = dict()
for in_name in Operator.get_op_input_names(op_type):
kwargs[in_name] = in_name
for out_name in Operator.get_op_output_names(op_type):
kwargs[out_name] = out_name
return Operator(op_type, **kwargs)
def grad_var_name(var_name):
return var_name + "@GRAD"
def get_numeric_gradient(op,
input_values,
output_name,
input_to_check,
delta=1e-2,
delta=0.005,
local_scope=None):
"""
Get Numeric Gradient for an operator's input.
......@@ -76,6 +91,113 @@ def get_numeric_gradient(op,
return gradient_flat.reshape(tensor_to_check.get_dims())
class GradientChecker(unittest.TestCase):
def __is_close(self, numeric_grads, scope, max_relative_error):
for name in numeric_grads:
op_grad = numpy.array(
scope.find_var(grad_var_name(name)).get_tensor())
is_close = numpy.allclose(
numeric_grads[name], op_grad, rtol=max_relative_error, atol=100)
if not is_close:
return False
return True
def check_grad(self,
forward_op,
input_vars,
inputs_to_check,
output_name,
no_grad_set=None,
only_cpu=False,
max_relative_error=0.005):
"""
:param forward_op: used to create backward_op
:param input_vars: numpy value of input variable. The following
computation will use these variables.
:param inputs_to_check: inputs var names that should check gradient.
:param output_name: output name that used to
:param max_relative_error: The relative tolerance parameter.
:param no_grad_set: used when create backward ops
:param only_cpu: only compute and check gradient on cpu kernel.
:return:
"""
if no_grad_set is None:
no_grad_set = set()
tmp_outs = forward_op.temp_outputs()
no_tmp_out = filter(lambda name: name not in tmp_outs,
forward_op.outputs())
if len(no_tmp_out) != 1:
raise ValueError("non temp out_names should be 1")
in_names = forward_op.inputs()
for no_grad in no_grad_set:
if no_grad not in in_names:
raise ValueError("no_grad should be in in_names")
backward_op = core.Operator.backward(forward_op, no_grad_set)
places = [core.CPUPlace()]
if not only_cpu and core.is_compile_gpu() and backward_op.support_gpu():
places.append(core.GPUPlace(0))
numeric_grad = dict()
# get numeric gradient
for check_name in inputs_to_check:
numeric_grad[check_name] = \
get_numeric_gradient(forward_op, input_vars, output_name, check_name)
# get operator gradient according to different device
for place in places:
scope = core.Scope()
ctx = core.DeviceContext.create(place)
# create input var and set value
for name, value in input_vars.iteritems():
if name not in in_names:
raise ValueError(name + " not in op.inputs_")
var = scope.new_var(name).get_tensor()
var.set_dims(value.shape)
var.set(value, place)
# create output var
for out_name in forward_op.outputs():
scope.new_var(out_name).get_tensor()
# infer the shape of output var and compute/set value of output var
forward_op.infer_shape(scope)
forward_op.run(scope, ctx)
# create output grad var
# set shape as the output var
# set value of this grad to ones
for name in forward_op.outputs():
out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.new_var(grad_var_name(name)).get_tensor()
grad_tensor.set_dims(out_tensor.shape())
data = 1.0 * numpy.ones(out_tensor.shape())
grad_tensor.set(data, place)
# create input grad var
for name in backward_op.outputs():
scope.new_var(name).get_tensor()
# infer the shape of input gradient var and compute/set it's value
# with backward op
backward_op.infer_shape(scope)
backward_op.run(scope, ctx)
if isinstance(place, core.CPUPlace):
msg = "CPU kernel gradient is not close to numeric gradient"
else:
if isinstance(place, core.GPUPlace):
msg = "GPU kernel gradient is not close to numeric gradient"
else:
raise ValueError("unknown place " + type(place))
self.assertTrue(
self.__is_close(numeric_grad, scope, max_relative_error), msg)
if __name__ == '__main__':
class GetNumericGradientTest(unittest.TestCase):
......@@ -87,4 +209,28 @@ if __name__ == '__main__':
arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2)
def test_softmax_op(self):
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - numpy.max(x)
exps = numpy.exp(shiftx)
return exps / numpy.sum(exps)
def label_softmax_grad(Y, dY):
dX = Y * 0.0
for i in range(Y.shape[0]):
d = numpy.dot(Y[i, :], dY[i, :])
dX[i, :] = Y[i, :] * (dY[i, :] - d)
return dX
softmax_op = Operator("softmax", X="X", Y="Y")
X = numpy.random.random((2, 2)).astype("float32")
Y = numpy.apply_along_axis(stable_softmax, 1, X)
dY = numpy.ones(Y.shape)
dX = label_softmax_grad(Y, dY)
arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2)
unittest.main()
import paddle.v2.framework.core as core
import unittest
import numpy
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
......@@ -24,7 +23,7 @@ class OpTestMeta(type):
scope = core.Scope()
kwargs = dict()
places = [core.CPUPlace()]
if core.is_compile_gpu() and core.Operator.support_gpu(self.type):
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for place in places:
......@@ -53,6 +52,8 @@ class OpTestMeta(type):
kwargs[attr_name] = self.attrs[attr_name]
op = Operator(self.type, **kwargs)
if isinstance(place, core.GPUPlace) and not op.support_gpu():
return
op.infer_shape(scope)
......
import unittest
import numpy
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
class TestSGD(unittest.TestCase):
class TestCrossEntropy(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
......@@ -20,7 +21,18 @@ class TestSGD(unittest.TestCase):
self.outputs = {'Y': numpy.array(Y).astype("float32")}
# TODO(superjom) add gradient check
class CrossEntropyGradOpTest(GradientChecker):
def test_softmax_grad(self):
op = create_op("onehot_cross_entropy")
batch_size = 100
class_num = 10
inputs = {
"X": numpy.random.uniform(
0.1, 1.0, [batch_size, class_num]).astype("float32"),
"label": (class_num / 2) * numpy.ones(batch_size).astype("int32")
}
self.check_grad(op, inputs, set("X"), "Y")
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
......@@ -25,62 +24,11 @@ class TestSoftmaxOp(unittest.TestCase):
}
class TestSoftmaxGradOp(unittest.TestCase):
def test_softmax_grad(self):
op = Operator('softmax', X="X", Y="Y")
backward_op = core.Operator.backward(op, set())
self.assertEqual(backward_op.type(), "softmax_grad")
expected = '''Op(softmax_grad), inputs:(X, Y, Y@GRAD), outputs:(X@GRAD).'''
self.assertEqual(expected, str(backward_op))
batch_size = 3
class_num = 5
# Initialize X and add 1e-2 for numerical stability
Y = np.random.rand(batch_size, class_num).astype(np.float32)
Y = Y + 1e-2
dY = np.random.rand(batch_size, class_num).astype(np.float32)
# Reference implementation of cross entropy with soft labels
def label_softmax_grad(Y, dY):
dX = Y * 0.0
for i in range(batch_size):
d = np.dot(Y[i, :], dY[i, :])
dX[i, :] = Y[i, :] * (dY[i, :] - d)
return dX
expected = label_softmax_grad(Y, dY)
scope = core.Scope()
places = []
places.append(core.CPUPlace())
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for place in places:
y = scope.new_var("Y")
y_tensor = y.get_tensor()
y_tensor.set_dims([batch_size, class_num])
y_tensor.alloc_float(place)
y_tensor.set(Y, place)
dy = scope.new_var("Y@GRAD")
dy_tensor = dy.get_tensor()
dy_tensor.set_dims([batch_size, class_num])
dy_tensor.alloc_float(place)
dy_tensor.set(dY, place)
x = scope.new_var("X")
dx = scope.new_var("X@GRAD")
tensor = scope.find_var("X@GRAD").get_tensor()
backward_op.infer_shape(scope)
self.assertEqual([batch_size, class_num], tensor.shape())
ctx = core.DeviceContext.create(place)
backward_op.run(scope, ctx)
actual = np.array(tensor)
np.testing.assert_almost_equal(actual, expected, decimal=3)
class SoftmaxGradOpTest(GradientChecker):
def test_softmax(self):
op = create_op("softmax")
inputs = {"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32")}
self.check_grad(op, inputs, set("X"), "Y")
if __name__ == '__main__':
......
import unittest
from paddle.v2.framework.op import Operator
import paddle.v2.framework.core as core
import numpy
class UniformRandomTest(unittest.TestCase):
def test_uniform_random_cpu(self):
self.uniform_random_test(place=core.CPUPlace())
def test_uniform_random_gpu(self):
if core.is_compile_gpu():
self.uniform_random_test(place=core.GPUPlace(0))
def uniform_random_test(self, place):
scope = core.Scope()
scope.new_var("X").get_tensor()
op = Operator(
"uniform_random",
Out="X",
dims=[1000, 784],
min=-5.0,
max=10.0,
seed=10)
op.infer_shape(scope)
ctx = core.DeviceContext.create(place)
op.run(scope, ctx)
tensor = numpy.array(scope.find_var("X").get_tensor())
self.assertAlmostEqual(tensor.mean(), 2.5, delta=0.1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册