提交 ac29d00c 编写于 作者: R ranqiu

Update doc of layers.py

......@@ -21,7 +21,7 @@ third_party/
cmake-build-*
# generated while compiling
python/paddle/v2/framework/core.so
python/paddle/v2/fluid/core.so
paddle/pybind/pybind.h
CMakeFiles
cmake_install.cmake
......
......@@ -377,6 +377,12 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs;
}
static BlockDescBind* CreateStepBlock(
ProgramDescBind& program_desc,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx);
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx,
std::unordered_set<std::string>* no_grad_vars,
......@@ -392,13 +398,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
if ((*it)->Type() == "recurrent") {
int step_block_idx = (*it)->GetBlockAttr("step_block");
auto backward_block_op_descs = MakeBlockBackward(
program_desc, step_block_idx, no_grad_vars, grad_to_var);
BlockDescBind* backward_block = CreateStepBlock(
program_desc, no_grad_vars, grad_to_var, step_block_idx);
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else if ((*it)->Type() == "conditional_block") {
BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(std::move(ptr));
}
CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
(*it)->GetBlockAttr("block"));
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else {
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
......@@ -449,6 +455,21 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs;
}
static BlockDescBind* CreateStepBlock(
ProgramDescBind& program_desc,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx) {
auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
no_grad_vars, grad_to_var);
BlockDescBind* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(move(ptr));
}
return backward_block;
}
ParamGradInfoMap AppendBackward(
ProgramDescBind& program_desc, const VarDescBind& target,
const std::unordered_set<std::string>& no_grad_vars) {
......
......@@ -27,10 +27,32 @@ inline VarDesc::VarType ToVarType(std::type_index type) {
return VarDesc_VarType_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return VarDesc_VarType_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return VarDesc_VarType_SELECTED_ROWS;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
}
template <typename Visitor>
inline void VisitVarType(const Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) {
case VarDesc_VarType_LOD_TENSOR:
visitor(var.Get<framework::LoDTensor>());
return;
case VarDesc_VarType_LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>());
return;
case VarDesc_VarType_LOD_TENSOR_ARRAY:
visitor(var.Get<LoDTensorArray>());
return;
case VarDesc_VarType_SELECTED_ROWS:
visitor(var.Get<SelectedRows>());
return;
default:
PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
}
}
} // namespace framework
} // namespace paddle
......@@ -54,7 +54,6 @@ void MKLDNNAddtoLayer::reshape(
ow = iw;
reshapeOutput(oh, ow);
resizeOutput(bs, oc * oh * ow);
printSizeInfo();
}
void MKLDNNAddtoLayer::resetFwd(std::vector<primitive>& pipeline,
......
......@@ -125,7 +125,6 @@ void MKLDNNBatchNormLayer::reshape(
<< "Input channel can not be changed";
reshapeOutput(oh, ow);
resizeOutput(bs, oc * oh * ow);
printSizeInfo();
}
void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
......
......@@ -102,8 +102,6 @@ void MKLDNNConvLayer::reshape(
reshapeOutput(oh, ow);
resizeOutput(bs, oc * oh * ow);
printSizeInfo();
}
void MKLDNNConvLayer::resetFwd(std::vector<primitive>& pipeline,
......
......@@ -92,7 +92,7 @@ public:
void printSizeInfo() override {
MKLDNNLayer::printSizeInfo();
VLOG(MKLDNN_SIZES) << getName() << ": fh: " << fh_ << ", fw: " << fw_
<< ": ph: " << ph_ << ", pw: " << pw_ << ", sh: " << sh_
<< ", ph: " << ph_ << ", pw: " << pw_ << ", sh: " << sh_
<< ", sw: " << sw_ << ", dh: " << dh_ << ", dw: " << dw_;
}
......
......@@ -84,8 +84,6 @@ void MKLDNNFcLayer::reshape(
reshapeOutput(oh, ow);
resizeOutput(bs, oc);
printSizeInfo();
}
void MKLDNNFcLayer::resetFwd(std::vector<primitive>& pipeline,
......
......@@ -71,8 +71,6 @@ void MKLDNNPoolLayer::reshape(
reshapeOutput(oh, ow);
resizeOutput(bs, oc * oh * ow);
printSizeInfo();
}
void MKLDNNPoolLayer::resetFwd(std::vector<primitive>& pipeline,
......
......@@ -98,7 +98,7 @@ void ROIPoolLayer::forward(PassType passType) {
size_t roiStartH = round(bottomROIs[2] * spatialScale_);
size_t roiEndW = round(bottomROIs[3] * spatialScale_);
size_t roiEndH = round(bottomROIs[4] * spatialScale_);
CHECK_GE(roiBatchIdx, 0);
CHECK_GE(roiBatchIdx, 0UL);
CHECK_LT(roiBatchIdx, batchSize);
size_t roiHeight = std::max(roiEndH - roiStartH + 1, 1UL);
size_t roiWidth = std::max(roiEndW - roiStartW + 1, 1UL);
......
......@@ -297,7 +297,7 @@ static void getAddtoConfig(TestConfig& cfg,
}
void testAddtoLayer(const testImageDesc& pm, const size_t nInputs) {
CHECK_GE(nInputs, 1);
CHECK_GE(nInputs, 1UL);
TestConfig dnnConfig;
getAddtoConfig(dnnConfig, pm, nInputs);
dnnConfig.layerConfig.set_type("mkldnn_addto");
......
......@@ -214,6 +214,7 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc
rnn/recurrent_op_utils.cc
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_type.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/var_type.h"
namespace paddle {
namespace operators {
class AssignFunctor {
public:
AssignFunctor(framework::Variable *out,
const platform::DeviceContext &dev_ctx)
: out_(out), dev_ctx_(dev_ctx) {}
void operator()(const framework::LoDTensor &lod_tensor) const {
auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
copy_tensor(lod_tensor, &out_tensor);
}
void operator()(const framework::LoDTensorArray &array) const {
auto &out_array = *out_->GetMutable<framework::LoDTensorArray>();
out_array.resize(array.size());
for (size_t i = 0; i < array.size(); ++i) {
copy_tensor(array[i], &out_array[i]);
}
}
void operator()(const framework::SelectedRows &rows) const {
framework::SelectedRows &out_rows =
*out_->GetMutable<framework::SelectedRows>();
out_rows.set_rows(rows.rows());
out_rows.set_height(rows.height());
auto &t = rows.value();
out_rows.mutable_value()->CopyFrom(t, t.place(), dev_ctx_);
}
template <typename T>
void operator()(const T &v) const {
PADDLE_THROW("Not support type for assign op %s", typeid(T).name());
}
private:
void copy_tensor(const framework::LoDTensor &lod_tensor,
framework::LoDTensor *out) const {
auto &out_tensor = *out;
out_tensor.CopyFrom(lod_tensor, lod_tensor.place(), dev_ctx_);
out_tensor.set_lod(lod_tensor.lod());
}
framework::Variable *out_;
const platform::DeviceContext &dev_ctx_;
};
class AssignOp : public framework::OperatorBase {
public:
AssignOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto *x = scope.FindVar(Input("X"));
if (x == nullptr) {
return;
}
auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(
out != nullptr,
"The Output(Out) should not be null if the Input(X) is set.");
framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
}
};
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
AssignOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
"could be LoDTensor, SelectedRows or LoDTensorArray.")
.AsDispensable();
AddOutput("Out",
"(LoDTensor, SelectedRows or LoDTensorArray) The type of output "
"is the same as input X.");
AddComment(R"DOC(Assign Operator
Out = X, when type in [LoDTensor/SelectedRows/LoDTensorArray]
raise error if the type is not listed above.
)DOC");
}
};
class AssignInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
if (context->HasInput("X")) {
auto type = context->GetInputsVarType("X")[0];
if (type == framework::VarDesc_VarType_SELECTED_ROWS ||
type == framework::VarDesc_VarType_LOD_TENSOR) {
context->SetOutputDim("Out", context->GetInputDim("X"));
}
}
}
};
class AssignGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *op = new framework::OpDescBind();
op->SetType("assign");
op->SetInput("X", OutputGrad("Out"));
op->SetOutput("Out", InputGrad("X"));
return std::unique_ptr<framework::OpDescBind>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
ops::AssignInferShape, ops::AssignOpProtoMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/beam_search_decode_op.h"
namespace paddle {
namespace operators {
class BeamSearchDecodeOp : public framework::OperatorBase {
public:
BeamSearchDecodeOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
framework::ExecutionContext ctx(*this, scope, dev_ctx);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
const size_t step_num = ids->size();
PADDLE_ENFORCE_GT(step_num, 0UL,
"beam search steps should be larger than 0");
const size_t source_num = ids->at(0).lod().at(0).size() - 1;
PADDLE_ENFORCE_GT(source_num, 0UL, "source num should be larger than 0");
for (size_t i = 0; i < step_num; ++i) {
PADDLE_ENFORCE_EQ(ids->at(i).lod().size(), 2UL,
"Level of LodTensor should be 2");
}
// prepare output
LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
BeamSearchDecoder<float> beam_search_decoder;
beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds,
sentenceScores);
}
};
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
BeamSearchDecodeOpProtoMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids",
"(LodTensorArray)"
"score of the candidate words in each step");
AddInput("Scores",
"(LodTensorArray)"
"score of the candidate words in each step");
AddOutput("SentenceIds",
"(LodTensor)"
"All possible result sentences of word ids");
AddOutput("SentenceScores",
"(LodTensor)"
"All possible result sentences of word scores");
AddComment(R"DOC(
Pack the result of Beam search op into SentenceIds and SentenceScores.
)DOC");
}
};
class BeamSearchDecodeInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* context) const override {
PADDLE_ENFORCE(context->HasInput("Ids"),
"BeamSearchDecodeOp must has input Ids");
PADDLE_ENFORCE(context->HasInput("Scores"),
"BeamSearchDecodeOp must has input Scores");
PADDLE_ENFORCE(context->HasOutput("SentenceIds"),
"BeamSearchDecodeOp must has output SentenceIds");
PADDLE_ENFORCE(context->HasOutput("SentenceScores"),
"BeamSearchDecodeOp must has output SentenceScores");
}
};
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDescBind& op_desc,
framework::BlockDescBind* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) {
block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR);
}
for (auto& o : op_desc.Output("SentenceScores")) {
block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR);
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(beam_search_decode, paddle::operators::BeamSearchDecodeOp,
paddle::operators::BeamSearchDecodeOpProtoMaker,
paddle::operators::BeamSearchDecodeInferShape,
paddle::operators::BeamSearchDecodeInferVarType,
paddle::framework::EmptyGradOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using LoDTensorArray = framework::LoDTensorArray;
// all the lod have 2 levels.
// The First is source level, the second is sentence level.
// source level describe how many candidate words for this source.
// sentence level describe these candidates belong to which prefix
const size_t kSourceLevel = 0;
const size_t kSentenceLevel = 1;
template <typename T>
struct BeamNode {
BeamNode(int64_t word_id, T score) : word_id_(word_id), score_(score) {}
~BeamNode() {
if (parent_) {
parent_->DropKid(this);
if (parent_->kids_.size() == 0UL) {
delete parent_;
}
}
VLOG(3) << "Delete BeamNode root with word_id:" << this->word_id_;
}
void AppendTo(BeamNode* parent) {
parent_ = parent;
parent->kids_.insert(this);
}
void DropKid(BeamNode* kid) { kids_.erase(kid); }
BeamNode* parent_ = nullptr;
std::unordered_set<BeamNode*> kids_;
int64_t word_id_;
T score_;
};
template <typename T>
using BeamNodeVector = std::vector<std::unique_ptr<BeamNode<T>>>;
template <typename T>
struct Sentence {
std::vector<int64_t> word_ids;
std::vector<T> scores;
};
template <typename T>
using SentenceVector = std::vector<Sentence<T>>;
template <typename T>
struct BeamSearchDecoder {
/**
* make a BeamNode and all it's related prefix BeanNode into a Sentence.
*/
Sentence<T> MakeSentence(const BeamNode<T>* node) const;
/**
* Param:
* cur_ids: LoDTensor of One step for word ID
* cur_scores: LoDTensor of One Step for word score
* prefixes_list: prefixes for each source sentence.
* sentence_vector_list: result sentence_vector for each source sentence.
* Return:
* a new prefixes list for each source of current step
*/
std::vector<BeamNodeVector<T>> PackTwoSteps(
const LoDTensor& cur_ids, const LoDTensor& cur_scores,
std::vector<BeamNodeVector<T>>& prefixes_list,
std::vector<SentenceVector<T>>* sentence_vector_list) const;
/**
* convert the result sentence_vector for each source sentence into two
* LodTensor.
* One is all candidate sentences with word id, one is all candidate sentences
* with word score.
* Param:
* sentence_vector_list: sentence_vector for each source sentence.
* id_tensor: result LoDTensor for sentences of id.
* score_tensor: result LoDTensor for sentences of score.
*/
void ConvertSentenceVectorToLodTensor(
std::vector<SentenceVector<T>> sentence_vector_list, LoDTensor* id_tensor,
LoDTensor* score_tensor) const;
/**
* Pack all steps of id/score LodTensor into sentence LoDTensor
* it's main logic is:
* ```python
* prefix
* result_sentence
* result_lod_tensor
*
* for (step in steps):
* prefix = PackTwoSteps(prefix, step, &result_sentence)
* ConvertSentenceVector<T>ToLodTensor(result_sentence, &result_lod_tensor)
* ```
*/
void PackAllSteps(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores, LoDTensor* id_tensor,
LoDTensor* score_tensor) const;
};
template <typename T>
Sentence<T> BeamSearchDecoder<T>::MakeSentence(const BeamNode<T>* node) const {
Sentence<T> sentence;
while (node != nullptr) {
sentence.word_ids.emplace_back(node->word_id_);
sentence.scores.emplace_back(node->score_);
node = node->parent_;
}
std::reverse(std::begin(sentence.word_ids), std::end(sentence.word_ids));
std::reverse(std::begin(sentence.scores), std::end(sentence.scores));
return sentence;
}
template <typename T>
std::vector<BeamNodeVector<T>> BeamSearchDecoder<T>::PackTwoSteps(
const LoDTensor& cur_ids, const LoDTensor& cur_scores,
std::vector<BeamNodeVector<T>>& prefixes_list,
std::vector<SentenceVector<T>>* sentence_vector_list) const {
std::vector<BeamNodeVector<T>> result;
for (size_t src_idx = 0; src_idx < cur_ids.lod()[kSourceLevel].size() - 1;
++src_idx) {
size_t src_start = cur_ids.lod().at(kSourceLevel)[src_idx];
size_t src_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1];
BeamNodeVector<T> beam_nodes;
// if prefixes size is 0, it means this is the first step. In this step,
// all candidate id is the start of candidate sentences.
if (prefixes_list.empty()) {
PADDLE_ENFORCE_EQ(cur_ids.lod().at(kSourceLevel).back(),
cur_ids.lod().at(kSentenceLevel).back(),
"in the first step");
for (size_t id_idx = src_start; id_idx < src_end; ++id_idx) {
beam_nodes.push_back(std::unique_ptr<BeamNode<T>>(new BeamNode<T>(
cur_ids.data<int64_t>()[id_idx], cur_scores.data<T>()[id_idx])));
}
} else {
BeamNodeVector<T>& prefixes = prefixes_list[src_idx];
SentenceVector<T>& sentence_vector = (*sentence_vector_list)[src_idx];
PADDLE_ENFORCE_EQ(src_end - src_start, prefixes.size(),
"prefix and candidate set number should be the same");
auto candidate_offset = cur_ids.lod()[kSentenceLevel];
for (size_t prefix_idx = 0; prefix_idx < prefixes.size(); ++prefix_idx) {
std::unique_ptr<BeamNode<T>>& prefix = prefixes[prefix_idx];
size_t candidate_start = candidate_offset[src_start + prefix_idx];
size_t candidate_end = candidate_offset[src_start + prefix_idx + 1];
if (candidate_start == candidate_end) {
VLOG(3) << "this sentence has no more candidate, "
"add to result sentence and rm it from beam tree";
sentence_vector.push_back(MakeSentence(prefix.get()));
prefix.reset();
} else {
for (size_t candidate_idx = candidate_start;
candidate_idx < candidate_end; ++candidate_idx) {
auto* candidate =
new BeamNode<T>(cur_ids.data<int64_t>()[candidate_idx],
cur_scores.data<T>()[candidate_idx]);
candidate->AppendTo(prefix.get());
beam_nodes.push_back(std::unique_ptr<BeamNode<T>>(candidate));
}
prefix.release();
}
}
}
result.push_back(std::move(beam_nodes));
}
return result;
}
template <typename T>
void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
std::vector<SentenceVector<T>> sentence_vector_list, LoDTensor* id_tensor,
LoDTensor* score_tensor) const {
size_t src_num = sentence_vector_list.size();
PADDLE_ENFORCE_NE(src_num, 0, "src_num should not be 0");
std::vector<size_t> source_level_lod = {0};
std::vector<size_t> sentence_level_lod = {0};
std::vector<int64_t> id_data;
std::vector<T> score_data;
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
for (Sentence<T>& sentence : sentence_vector_list[src_idx]) {
id_data.insert(id_data.end(), sentence.word_ids.begin(),
sentence.word_ids.end());
score_data.insert(score_data.end(), sentence.scores.begin(),
sentence.scores.end());
sentence_level_lod.push_back(sentence_level_lod.back() +
sentence.word_ids.size());
}
source_level_lod.push_back(source_level_lod.back() +
sentence_vector_list[src_idx].size());
}
auto cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place);
framework::LoD lod;
lod.push_back(source_level_lod);
lod.push_back(sentence_level_lod);
id_tensor->set_lod(lod);
id_tensor->Resize({static_cast<int64_t>(id_data.size())});
id_tensor->mutable_data<int64_t>(paddle::platform::CPUPlace());
id_tensor->CopyFromVector<int64_t>(id_data, cpu_ctx);
score_tensor->set_lod(lod);
score_tensor->Resize({static_cast<int64_t>(score_data.size())});
score_tensor->mutable_data<T>(paddle::platform::CPUPlace());
score_tensor->CopyFromVector<T>(score_data, cpu_ctx);
}
template <typename T>
void BeamSearchDecoder<T>::PackAllSteps(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor,
LoDTensor* score_tensor) const {
PADDLE_ENFORCE(!step_ids.empty(), "step num should be larger than 0");
PADDLE_ENFORCE_EQ(step_ids.size(), step_scores.size(),
"step_ids and step_scores should be the same");
const size_t step_num = step_ids.size();
const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1;
PADDLE_ENFORCE_GT(src_num, 0UL, "source num should be larger than 0");
// previous prefixes for each step,
// the init length is 0, means this is the first step.
std::vector<BeamNodeVector<T>> beamnode_vector_list(0);
std::vector<SentenceVector<T>> sentence_vector_list(src_num);
// pack all steps for one batch first, then another batch
for (size_t step_id = 0; step_id < step_num; ++step_id) {
beamnode_vector_list =
PackTwoSteps(step_ids.at(step_id), step_scores.at(step_id),
beamnode_vector_list, &sentence_vector_list);
}
// append last beam_node to result
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
for (auto& beam_node : beamnode_vector_list.at(src_idx)) {
sentence_vector_list[src_idx].push_back(MakeSentence(beam_node.get()));
beam_node.reset();
}
}
ConvertSentenceVectorToLodTensor(sentence_vector_list, id_tensor,
score_tensor);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/beam_search_decode_op.h"
#include "gtest/gtest.h"
using CPUPlace = paddle::platform::CPUPlace;
using LoD = paddle::framework::LoD;
using LoDTensor = paddle::framework::LoDTensor;
using LoDTensorArray = paddle::framework::LoDTensorArray;
template <typename T>
using BeamNode = paddle::operators::BeamNode<T>;
template <typename T>
using BeamSearchDecoder = paddle::operators::BeamSearchDecoder<T>;
template <typename T>
using Sentence = paddle::operators::Sentence<T>;
template <typename T>
using BeamNodeVector = paddle::operators::BeamNodeVector<T>;
template <typename T>
using SentenceVector = paddle::operators::SentenceVector<T>;
namespace paddle {
namespace test {
void GenerateExample(const std::vector<size_t>& level_0,
const std::vector<size_t>& level_1,
const std::vector<int>& data, LoDTensorArray* ids,
LoDTensorArray* scores) {
PADDLE_ENFORCE_EQ(level_0.back(), level_1.size() - 1,
"source level is used to describe candidate set");
PADDLE_ENFORCE_EQ(level_1.back(), data.size(),
"the lowest level is used to describe data"
", so it's last element should be data length");
CPUPlace place;
LoD lod;
lod.push_back(level_0);
lod.push_back(level_1);
// Ids
LoDTensor tensor_id;
tensor_id.set_lod(lod);
tensor_id.Resize({static_cast<int64_t>(data.size())});
// malloc memory
int64_t* id_ptr = tensor_id.mutable_data<int64_t>(place);
for (size_t i = 0; i < data.size(); ++i) {
id_ptr[i] = static_cast<int64_t>(data.at(i));
}
// Scores
LoDTensor tensor_score;
tensor_score.set_lod(lod);
tensor_score.Resize({static_cast<int64_t>(data.size())});
// malloc memory
float* score_ptr = tensor_score.mutable_data<float>(place);
for (size_t i = 0; i < data.size(); ++i) {
score_ptr[i] = static_cast<float>(data.at(i));
}
ids->push_back(tensor_id);
scores->push_back(tensor_score);
}
} // namespace test
} // namespace paddle
TEST(BeamSearchDecodeOp, DeleteBeamNode) {
auto* root = new BeamNode<float>(0, 0);
auto* b1 = new BeamNode<float>(1, 1);
auto* b2 = new BeamNode<float>(2, 2);
auto* b3 = new BeamNode<float>(3, 3);
b1->AppendTo(root);
b2->AppendTo(root);
b3->AppendTo(b1);
delete b3;
delete b2;
}
TEST(BeamSearchDecodeOp, MakeSentence) {
auto* root = new BeamNode<float>(0, 0);
auto* b1 = new BeamNode<float>(1, 1);
auto* end = new BeamNode<float>(2, 2);
b1->AppendTo(root);
end->AppendTo(b1);
BeamSearchDecoder<float> helper;
Sentence<float> sentence = helper.MakeSentence(end);
delete end;
std::vector<int64_t> expect_ids = {0, 1, 2};
ASSERT_EQ(sentence.word_ids, expect_ids);
std::vector<float> expect_scores = {0, 1, 2};
ASSERT_EQ(sentence.scores, expect_scores);
}
TEST(BeamSearchDecodeOp, PackTwoStepsFistStep) {
CPUPlace place;
LoDTensorArray ids;
LoDTensorArray scores;
paddle::test::GenerateExample(
std::vector<size_t>{0, 2, 6}, std::vector<size_t>{0, 1, 2, 3, 4, 5, 6},
std::vector<int>{1, 2, 3, 4, 5, 6}, &ids, &scores);
std::vector<BeamNodeVector<float>> beamnode_vector_list;
std::vector<SentenceVector<float>> sentence_vector_list(
2, SentenceVector<float>());
BeamSearchDecoder<float> helper;
beamnode_vector_list = helper.PackTwoSteps(
ids[0], scores[0], beamnode_vector_list, &sentence_vector_list);
ASSERT_EQ(beamnode_vector_list.size(), 2UL);
ASSERT_EQ(beamnode_vector_list[0].size(), 2UL);
ASSERT_EQ(beamnode_vector_list[1].size(), 4UL);
}
TEST(BeamSearchDecodeOp, PackTwoSteps) {
CPUPlace place;
// first source has three prefix
BeamNodeVector<float> source0_prefixes;
source0_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(1, 1)));
source0_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(0, 0)));
source0_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(3, 3)));
// second source has two prefix
BeamNodeVector<float> source1_prefixes;
source1_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(4, 4)));
source1_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(5, 5)));
std::vector<BeamNodeVector<float>> beamnode_vector_list;
std::vector<SentenceVector<float>> sentence_vector_list(
2, SentenceVector<float>());
beamnode_vector_list.push_back(std::move(source0_prefixes));
beamnode_vector_list.push_back(std::move(source1_prefixes));
// generate data for one step
LoDTensorArray ids;
LoDTensorArray scores;
paddle::test::GenerateExample(std::vector<size_t>{0, 3, 5},
std::vector<size_t>{0, 1, 1, 3, 4, 5},
std::vector<int>{0, 1, 2, 3, 4}, &ids, &scores);
BeamSearchDecoder<float> helper1;
beamnode_vector_list = helper1.PackTwoSteps(
ids[0], scores[0], beamnode_vector_list, &sentence_vector_list);
ASSERT_EQ(sentence_vector_list[0].size(), 1UL);
ASSERT_EQ(sentence_vector_list[1].size(), 0UL);
ASSERT_EQ(beamnode_vector_list[0].size(), 3UL);
ASSERT_EQ(beamnode_vector_list[1].size(), 2UL);
}
TEST(BeamSearchDecodeOp, PackAllSteps) {
CPUPlace place;
// we will constuct a sample data with 3 steps and 2 source sentences
LoDTensorArray ids;
LoDTensorArray scores;
paddle::test::GenerateExample(
std::vector<size_t>{0, 3, 6}, std::vector<size_t>{0, 1, 2, 3, 4, 5, 6},
std::vector<int>{1, 2, 3, 4, 5, 6}, &ids, &scores);
paddle::test::GenerateExample(
std::vector<size_t>{0, 3, 6}, std::vector<size_t>{0, 1, 1, 3, 5, 5, 6},
std::vector<int>{0, 1, 2, 3, 4, 5}, &ids, &scores);
paddle::test::GenerateExample(std::vector<size_t>{0, 3, 6},
std::vector<size_t>{0, 0, 1, 2, 3, 4, 5},
std::vector<int>{0, 1, 2, 3, 4}, &ids, &scores);
ASSERT_EQ(ids.size(), 3UL);
ASSERT_EQ(scores.size(), 3UL);
BeamSearchDecoder<float> helper;
LoDTensor id_tensor;
LoDTensor score_tensor;
helper.PackAllSteps(ids, scores, &id_tensor, &score_tensor);
LoD lod = id_tensor.lod();
std::vector<size_t> expect_source_lod = {0, 4, 8};
EXPECT_EQ(lod[0], expect_source_lod);
std::vector<size_t> expect_sentence_lod = {0, 1, 3, 6, 9, 10, 13, 16, 19};
EXPECT_EQ(lod[1], expect_sentence_lod);
// 2| 1, 0| 3, 1, 0| 3, 2, 1| 5| 4, 3, 2| 4, 4, 3| 6, 5, 4
std::vector<int> expect_data = {2, 1, 0, 3, 1, 0, 3, 2, 1, 5,
4, 3, 2, 4, 4, 3, 6, 5, 4};
ASSERT_EQ(id_tensor.dims()[0], static_cast<int64_t>(expect_data.size()));
for (size_t i = 0; i < expect_data.size(); ++i) {
ASSERT_EQ(id_tensor.data<int64_t>()[i],
static_cast<int64_t>(expect_data[i]));
}
for (int64_t i = 0; i < id_tensor.dims()[0]; ++i) {
ASSERT_EQ(score_tensor.data<float>()[i],
static_cast<float>(id_tensor.data<int64_t>()[i]));
}
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/bilinear_tensor_product_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class BilinearTensorProductOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto weight_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input(X) must be a 2D Tensor.");
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input(Y) must be a 2D Tensor.");
PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL,
"The input(Weight) must be a 3D tensor.");
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The first dimension(batch_size) of input(X) must be "
"equal to the first dimension of the input(Y).");
PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1],
"The second dimension of input(X) must be equal to "
"the second dimension of the input(Weight).");
PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2],
"The second dimension of input(Y) must be equal to "
"the third dimension of the input(Weight).");
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE(bias_dims.size() == 2UL && bias_dims[0] == 1UL,
"The Input(Bias) must be a 2-D tensor with "
"the 2nd dimension fixed to 1 (a row vector).");
PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0],
"The second dimension of input(Bias) must be equal "
"to the first dimension of the input(Weight).");
}
ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]});
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BilinearTensorProductOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of bilinear_tensor_product operator.");
AddInput("Y", "The second input of bilinear_tensor_product operator.");
AddInput("Weight",
"The learnable parameters of bilinear_tensor_product operator.");
AddInput("Bias", "The learnable bias of bilinear_tensor_product operator.")
.AsDispensable();
AddOutput("Out", "The output of bilinear_tensor_product operator.");
AddComment(R"DOC(
Bilinear Tensor Product operator.
Given input X and Y, a 3D tensor weight, and bias. Each column of the
output is computed by one slice i = 1, . . . , k of the tensor:
M = (X W_i) \cdot Y
Out_i = \sum_i {M_i} + Bias_i
)DOC");
}
};
class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto weight_dims = ctx->GetInputDim("Weight");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(out_dims.size(), 2UL,
"The input(Out@GRAD) must be a 2D Tensor.");
PADDLE_ENFORCE_EQ(
x_dims[0], out_dims[0],
"The first dimension(batch_size) of input(Out@GRAD) must be "
"equal to the first dimension of the Input(X).");
PADDLE_ENFORCE_EQ(
weight_dims[0], out_dims[1],
"The second dimension of input(Out@GRAD) must be equal to "
"the third dimension of the Input(Weight).");
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(
bias_dims[1], out_dims[1],
"The second dimension of input(Out@GRAD) must be equal to "
"the second dimension of the Input(Bias).");
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims);
}
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
auto weight_grad_name = framework::GradVarName("Weight");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
if (ctx->HasOutput(weight_grad_name)) {
ctx->SetOutputDim(weight_grad_name, weight_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp,
ops::BilinearTensorProductOpMaker, bilinear_tensor_product_grad,
ops::BilinearTensorProductOpGrad);
REGISTER_OP_CPU_KERNEL(
bilinear_tensor_product,
ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, float>,
ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
bilinear_tensor_product_grad,
ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float>,
ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/bilinear_tensor_product_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
bilinear_tensor_product,
ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, float>,
ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
bilinear_tensor_product_grad,
ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, float>,
ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class BilinearTensorProductKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto y_mat = EigenMatrix<T>::From(*y);
auto output_mat = EigenMatrix<T>::From(*out);
auto batch_size = x->dims()[0];
auto weight_dims = weight->dims();
int out_dim = weight_dims[0];
auto x_dim = weight_dims[1];
auto y_dim = weight_dims[2];
auto place = ctx.GetEigenDevice<Place>();
// Create the intermediate variable to caculate the result of
// Input(X) multiplied by Input(Weight_i), the formula is:
// left_mul = X Weight_i.
Tensor left_mul;
left_mul.mutable_data<T>(framework::make_ddim({batch_size, y_dim}),
ctx.GetPlace());
auto left_mul_mat = EigenMatrix<T>::From(left_mul);
for (int i = 0; i < out_dim; ++i) {
auto output_col_vec = output_mat.chip(i, 1);
Tensor weight_mat =
weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim}));
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
batch_size, y_dim, x_dim, 1, x->data<T>(),
weight_mat.data<T>(), 0, left_mul.data<T>());
output_col_vec.device(place) =
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
}
if (bias) {
auto bias_vec = EigenMatrix<T>::From(*bias);
Eigen::DSizes<int, 2> bcast(batch_size, 1);
output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat;
}
}
};
template <typename Place, typename T>
class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor* weight = ctx.Input<Tensor>("Weight");
Tensor* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* d_y = ctx.Output<Tensor>(framework::GradVarName("Y"));
Tensor* d_weight = ctx.Output<Tensor>(framework::GradVarName("Weight"));
Tensor* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto batch_size = x->dims()[0];
auto weight_dims = weight->dims();
int out_dim = weight_dims[0];
auto x_dim = weight_dims[1];
auto y_dim = weight_dims[2];
auto x_mat = EigenMatrix<T>::From(*x);
auto y_mat = EigenMatrix<T>::From(*y);
auto d_out_mat = EigenMatrix<T>::From(*d_out);
auto place = ctx.GetEigenDevice<Place>();
// Create the intermediate variable to caculate the Output(Y@Grad).
Tensor x_scale;
x_scale.mutable_data<T>(framework::make_ddim({batch_size, x_dim}),
ctx.GetPlace());
auto x_scale_mat = EigenMatrix<T>::From(x_scale);
// Create the intermediate variable to caculate the Output(X@Grad).
Tensor y_scale;
y_scale.mutable_data<T>(framework::make_ddim({batch_size, y_dim}),
ctx.GetPlace());
auto y_scale_mat = EigenMatrix<T>::From(y_scale);
math::SetConstant<Place, T> set_zero;
// Set Output(X@Grad) be zero.
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_x, static_cast<T>(0));
}
// Set Output(Y@Grad) be zero.
if (d_y) {
d_y->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_y, static_cast<T>(0));
}
// Caculate the Output(X@Grad) and Output(Y@Grad).
if (d_x || d_y) {
Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
Eigen::DSizes<int, 2> bcast_for_y(1, x_dim);
for (int i = 0; i < out_dim; ++i) {
Tensor weight_i = weight->Slice(i, i + 1).Resize(
framework::make_ddim({x_dim, y_dim}));
auto output_vec = d_out_mat.chip(i, 1);
if (d_x) {
y_scale_mat.device(place) =
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_x) *
y_mat;
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans,
batch_size, x_dim, y_dim, 1, y_scale.data<T>(),
weight_i.data<T>(), 1, d_x->data<T>());
}
if (d_y) {
x_scale_mat.device(place) =
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_y) *
x_mat;
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
batch_size, y_dim, x_dim, 1, x_scale.data<T>(),
weight_i.data<T>(), 1, d_y->data<T>());
}
}
}
// Caculate the gradient of Input(Weight).
if (d_weight) {
d_weight->mutable_data<T>(ctx.GetPlace());
Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);
for (int i = 0; i < out_dim; ++i) {
Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
framework::make_ddim({x_dim, y_dim}));
auto output_vec = d_out_mat.chip(i, 1);
x_scale_mat.device(place) =
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
.broadcast(bcast_for_weight) *
x_mat;
math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
x_dim, y_dim, batch_size, 1, x_scale.data<T>(),
y->data<T>(), 0, d_weight_i.data<T>());
}
}
// Caculate the gradient of Input(Bias).
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_mat = EigenMatrix<T>::From(*d_bias);
d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes<int, 1>(0));
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/framework/executor.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class ConditionalOp : public framework::OperatorBase {
public:
ConditionalOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
protected:
std::vector<const framework::LoDTensor *> InputTensors(
const framework::Scope &scope) const {
std::vector<const framework::LoDTensor *> retv;
auto xs = Inputs("X");
retv.resize(xs.size(), nullptr);
std::transform(
xs.begin(), xs.end(), retv.begin(),
[&scope](const std::string &var_name) -> const framework::LoDTensor * {
auto *var = scope.FindVar(var_name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", var_name);
return &var->Get<framework::LoDTensor>();
});
return retv;
}
};
class ConditionalBlockOp : public ConditionalOp {
public:
ConditionalBlockOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ConditionalOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto xs = InputTensors(scope);
bool need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
if (need_run) {
auto *scope_var = scope.FindVar(Output("Scope"));
PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
scopes->resize(1);
scopes->front() = &scope.NewScope();
auto &cur_scope = *scopes->front();
auto *block = Attr<framework::BlockDescBind *>("block");
framework::Executor exec(dev_ctx);
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
}
}
};
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ConditionalBlockOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The conditional variable of this operator. If X is empty, the "
"whole sub-block will not be executed.")
.AsDuplicable();
AddInput("Params", "The input variables of the sub-block.").AsDuplicable();
AddOutput("Out", "The output variables of the sub-block.").AsDuplicable();
AddOutput("Scope",
"(std::vector<Scope*>) The step scope of conditional block. To "
"unify the conditional block, rnn and while op, the type of "
"scope is std::vector<Scope*>");
AddAttr<framework::BlockDescBind *>(
"block", "The step block of conditional block operator");
AddComment(R"DOC(Conditional block operator
Run the sub-block if X is not empty. Params is the other inputs and Out is the
outputs of the sub-block.
)DOC");
}
};
class ConditionalBlockGradOp : public ConditionalOp {
public:
ConditionalBlockGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: ConditionalOp(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto xs = this->InputTensors(scope);
bool need_run = std::all_of(
xs.begin(), xs.end(),
[](const framework::LoDTensor *t) { return t->numel() != 0; });
if (need_run) {
auto *scope_var = scope.FindVar(Input("Scope"));
PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
framework::Scope &cur_scope = *scopes[0];
auto *block = Attr<framework::BlockDescBind *>("block");
framework::Executor exec(dev_ctx);
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("Params"),
Outputs(framework::GradVarName("Params")));
AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("X"),
Outputs(framework::GradVarName("X")));
}
}
private:
void AssignLocalGradientToGlobal(
const platform::DeviceContext &dev_ctx, const framework::Scope &cur_scope,
const std::vector<std::string> &p_names,
const std::vector<std::string> &pg_names) const {
for (size_t i = 0; i < p_names.size(); ++i) {
auto out_grad_name = pg_names[i];
auto in_grad_name = framework::GradVarName(p_names[i]);
auto *in_var = cur_scope.FindVar(in_grad_name);
if (in_var == nullptr) {
continue;
}
auto new_in_grad_name = cur_scope.Rename(in_grad_name);
auto assign =
framework::OpRegistry::CreateOp("assign", {{"X", {new_in_grad_name}}},
{{"Out", {out_grad_name}}}, {});
assign->Run(cur_scope, dev_ctx);
cur_scope.Rename(new_in_grad_name, in_grad_name);
}
}
};
class ConditionalBlockGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInputs("X"));
if (context->HasInputs("Params")) {
PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Params")));
context->SetOutputsDim(framework::GradVarName("Params"),
context->GetInputsDim("Params"));
}
PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("X")));
context->SetOutputsDim(framework::GradVarName("X"),
context->GetInputsDim("X"));
}
};
class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto grad_op = new framework::OpDescBind();
grad_op->SetType("conditional_block_grad");
grad_op->SetInput("X", Input("X"));
grad_op->SetInput("Params", Input("Params"));
grad_op->SetInput("Out", Output("Out"));
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
grad_op->SetInput("Scope", Output("Scope"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params"));
grad_op->SetBlockAttr("block", *this->grad_block_[0]);
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
ops::ConditionalBlockOpProtoMaker,
ops::ConditionalBlockGradMaker);
REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp,
ops::ConditionalBlockGradInferShape);
......@@ -29,7 +29,7 @@ class L1NormKernel : public framework::OpKernel<T> {
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto out = framework::EigenScalar<T>::From(*Out);
auto place = context.GetEigenDevice<Place>();
out.device(place) = x.abs().sum();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lod_reset_op.h"
namespace paddle {
namespace operators {
class LoDResetOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
// input check
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LoDResetOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LoDResetOp should not be null.");
// If target LoD is not set form Input(), then it must be set from Attr().
if (!ctx->HasInput("TargetLoD")) {
auto level0 = ctx->Attrs().Get<std::vector<int>>("target_lod");
PADDLE_ENFORCE(level0.size() > 1,
"Target LoD is not found, should be set to be a valid one "
"through Input() or Attr().");
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDResetOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The input tensor of lod_reset operator.");
AddInput("TargetLoD",
"(Tensor, optional) The target level 0 LoD from Input().")
.AsDispensable();
AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator.");
AddAttr<std::vector<int>>("target_lod",
"The target level 0 LoD from Attr().")
.SetDefault(std::vector<int>{});
AddComment(R"DOC(LoDReset operator
Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or
Attr(target_lod), or set LoD for Input(X) if it doesn't have one.
Currently the lod_reset operator only supports the reset of level 0 LoD.
At least one of Input(TargetLoD) and Attr(target_lod) must be set,
and if both of them are set, Input(TargetLoD) will be chosen as the
target LoD.
An example:
Given a float LoDTensor X with shape (6, 1), its transpose form represents
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like
[1.0, 2.0], [3.0, 4.0, 5.0], [6.0].
If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and
the sequences that the LoDTensor Output(Out) contains becomes:
[1.0, 2.0, 3.0, 4.0], [5.0, 6.0].
)DOC");
}
};
class LoDResetGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad,
ops::LoDResetGradOp);
REGISTER_OP_CPU_KERNEL(lod_reset,
ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lod_reset_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lod_reset,
ops::LoDResetKernel<paddle::platform::GPUPlace, float>,
ops::LoDResetKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
lod_reset_grad, ops::LoDResetGradKernel<paddle::platform::GPUPlace, float>,
ops::LoDResetGradKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class LoDResetKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto* in = ctx.Input<framework::LoDTensor>("X");
auto* lod_t = ctx.Input<framework::Tensor>("TargetLoD");
std::vector<int> level0;
if (lod_t) {
auto* lod = lod_t->data<int>();
if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor lod_cpu;
lod_cpu.CopyFrom(*lod_t, platform::CPUPlace(), ctx.device_context());
lod = lod_cpu.data<int>();
}
level0 = std::vector<int>(lod, lod + lod_t->numel());
} else {
level0 = ctx.Attr<std::vector<int>>("target_lod");
}
PADDLE_ENFORCE(level0.size() > 1UL,
"The size of target LoD should be greater than 1.");
PADDLE_ENFORCE(level0[0] == 0,
"Target LoD should be a vector starting from 0.");
PADDLE_ENFORCE(level0.back() == in->dims()[0],
"Target LoD should be a vector end with the "
"first dimension of Input(X).");
for (size_t i = 0; i < level0.size() - 1; ++i) {
PADDLE_ENFORCE(level0[i + 1] > level0[i],
"Target LoD should be an ascending vector.");
}
out->ShareDataWith(*in);
// cast level0 to size_t
std::vector<size_t> ulevel0(level0.size(), 0);
std::transform(level0.begin(), level0.end(), ulevel0.begin(),
[](int a) { return static_cast<size_t>(a); });
framework::LoD target_lod;
target_lod.push_back(ulevel0);
out->set_lod(target_lod);
}
};
template <typename Place, typename T>
class LoDResetGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->ShareDataWith(*d_out);
}
};
} // namespace operators
} // namespace paddle
......@@ -74,11 +74,10 @@ Tensor CombineBatchAndN(const framework::ExecutionContext& context,
Tensor output;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
output.Resize(in_dims);
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
EigenTranspose<Place, T, 3>(context, input, output, {1, 0, 2});
std::vector<int64_t> out_dims = {in_dims[1], in_dims[0] * in_dims[2]};
output.Resize(make_ddim(out_dims));
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
} else {
output.ShareDataWith(input);
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
namespace paddle {
namespace operators {
using LoD = framework::LoD;
class MergeLoDTensorOp : public framework::OperatorBase {
public:
MergeLoDTensorOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
auto &in_true = scope.FindVar(Input("InTrue"))->Get<framework::LoDTensor>();
auto &in_false =
scope.FindVar(Input("InFalse"))->Get<framework::LoDTensor>();
auto *out =
scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto level = static_cast<size_t>(Attr<int>("level"));
auto &mask_dim = mask.dims();
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
if (platform::is_cpu_place(mask.place())) {
cpu_mask->ShareDataWith(mask);
} else if (platform::is_gpu_place(mask.place())) {
#ifdef PADDLE_WITH_CUDA
cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx);
#else
PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
#endif
}
auto *mask_data = cpu_mask->data<bool>();
int rank = in_true.dims().size();
platform::Place place = in_true.place();
std::type_index data_type = in_true.type();
framework::DDim in_true_dims =
framework::slice_ddim(in_true.dims(), 1, rank);
int64_t batch_size = in_true.dims()[0] + in_false.dims()[0];
auto in_true_dim_vec = framework::vectorize(in_true_dims);
in_true_dim_vec.insert(in_true_dim_vec.begin(), batch_size);
framework::DDim out_dims = framework::make_ddim(in_true_dim_vec);
out->Resize(out_dims);
out->mutable_data(place, data_type);
auto *out_lod = out->mutable_lod();
out_lod->clear();
size_t out_offset = 0;
// Build LoDTensor `out`
size_t in_true_idx = 0;
size_t in_false_idx = 0;
for (size_t i = 0; i < static_cast<size_t>(mask_dim[0]); i++) {
const framework::LoDTensor *input = nullptr;
size_t *in_idx = nullptr;
if (static_cast<int>(mask_data[i]) == 0) {
input = &in_false;
in_idx = &in_false_idx;
} else {
input = &in_true;
in_idx = &in_true_idx;
}
auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset(
input->lod(), *in_idx, (*in_idx) + 1, 0);
auto &lod_length = lod_and_offset.first;
framework::AppendLoD(out_lod, lod_length);
size_t start_offset = lod_and_offset.second.first;
size_t end_offset = lod_and_offset.second.second;
PADDLE_ENFORCE_GE(end_offset, start_offset);
size_t len = end_offset - start_offset;
if (len == 0) {
continue;
}
out->Slice(out_offset, out_offset + len)
.CopyFrom(input->Slice(start_offset, end_offset), place, dev_ctx);
out_offset += len;
(*in_idx) += 1;
}
for (size_t i = 0; i < level; i++) {
out_lod->insert(out_lod->begin(), x.lod()[i]);
}
}
};
class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
MergeLoDTensorOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input LoDTensor, contains complete lod information to "
"construct the output");
AddInput("Mask", "A bool column vector which mask the input");
AddInput("InTrue", "The True branch to be merged");
AddInput("InFalse", "The False branch to be merged");
AddOutput("Out", "The merged output LoDTensor");
AddAttr<int>("level", "(int) the specific lod level to rank.")
.SetDefault(0)
.EqualGreaterThan(0);
AddComment(
R"DOC(
Merge True and False branches of LoDTensor into a single Output,
with a mask at certain lod level. X is used to obtain complete
lod information. Please refer to SplitLoDTensorOp.)DOC");
}
};
class MergeLoDTensorInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"),
"MergeLoDTensorOp must has input X.");
PADDLE_ENFORCE(context->HasInput("Mask"),
"MergeLoDTensorOp must has input Mask.");
PADDLE_ENFORCE(context->HasInput("InTrue"),
"MergeLoDTensorOp must has input InTrue.");
PADDLE_ENFORCE(context->HasInput("InFalse"),
"MergeLoDTensorOp must has input InFalse.");
PADDLE_ENFORCE(context->HasOutput("Out"),
"MergeLoDTensorOp must has output Out");
auto mask_dim = context->GetInputDim("Mask");
PADDLE_ENFORCE_EQ(mask_dim.size(), 2);
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
context->SetOutputDim("Out", context->GetInputDim("InTrue"));
}
};
class MergeLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *grad_op = new framework::OpDescBind();
grad_op->SetType("split_lod_tensor");
grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetInput("Mask", Input("Mask"));
grad_op->SetOutput("OutTrue", InputGrad("InTrue"));
grad_op->SetOutput("OutFalse", InputGrad("InFalse"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(merge_lod_tensor, ops::MergeLoDTensorOp,
ops::MergeLoDTensorOpProtoMaker,
ops::MergeLoDTensorInferShape, ops::MergeLoDTensorGradMaker);
......@@ -47,7 +47,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(vector<LoDTensor>) Input is a vector of LoDTensor, "
"(LodTensorArray) Input is a vector of LoDTensor, "
"each of which is a variable-length sequence or nested sequence.")
.AsDuplicable();
AddOutput("Out",
......
......@@ -126,6 +126,7 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
auto out_g_e_v = EigenVector<T>::Flatten(out_g_t);
Eigen::DSizes<int, 2> bcast(h, 1);
if (pooltype == "AVERAGE") {
......@@ -136,9 +137,9 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
} else if (pooltype == "LAST") {
in_g_e.chip(h - 1, 0).device(place) = out_g_e;
in_g_e.chip(h - 1, 0).device(place) = out_g_e_v;
} else if (pooltype == "FIRST") {
in_g_e.chip(0, 0).device(place) = out_g_e;
in_g_e.chip(0, 0).device(place) = out_g_e_v;
} else {
PADDLE_THROW("unsupported pooling pooltype");
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
namespace paddle {
namespace operators {
struct CopyRange {
size_t begin;
size_t end;
};
using LoD = framework::LoD;
class SplitLoDTensorOp : public framework::OperatorBase {
public:
SplitLoDTensorOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
auto *out_true =
scope.FindVar(Output("OutTrue"))->GetMutable<framework::LoDTensor>();
auto *out_false =
scope.FindVar(Output("OutFalse"))->GetMutable<framework::LoDTensor>();
auto level = static_cast<size_t>(Attr<int>("level"));
auto &x_lod = x.lod();
auto &mask_dim = mask.dims();
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
if (platform::is_cpu_place(mask.place())) {
cpu_mask->ShareDataWith(mask);
} else if (platform::is_gpu_place(mask.place())) {
#ifdef PADDLE_WITH_CUDA
cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx);
#else
PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
#endif
}
auto *mask_data = cpu_mask->data<bool>();
std::vector<std::vector<CopyRange>> copy_ranges(mask_dim[0]);
// set out_true/out_false lod
for (size_t t = 0; t < 2; t++) {
LoD *lod = nullptr;
if (t == 0) {
lod = out_false->mutable_lod();
} else {
lod = out_true->mutable_lod();
}
lod->clear();
for (size_t i = 0; i < static_cast<size_t>(mask_dim[0]); i++) {
if (static_cast<size_t>(mask_data[i]) == t) {
size_t start_idx = i;
auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset(
x_lod, start_idx, start_idx + 1, level);
auto &lod_length = lod_and_offset.first;
framework::AppendLoD(lod, lod_length);
size_t start_offset = lod_and_offset.second.first;
size_t end_offset = lod_and_offset.second.second;
copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset});
}
}
}
for (size_t t = 0; t < 2; ++t) {
framework::LoDTensor *out;
if (t == 0) {
out = out_false;
} else {
out = out_true;
}
auto &ranges = copy_ranges[t];
size_t height = std::accumulate(
ranges.begin(), ranges.end(), 0UL,
[](size_t a, const CopyRange &b) { return a + b.end - b.begin; });
auto x_dim = x.dims();
x_dim[0] = static_cast<int64_t>(height);
out->Resize(x_dim);
out->mutable_data(x.place(), x.type());
size_t offset = 0;
for (auto &each_range : ranges) {
size_t len = each_range.end - each_range.begin;
if (len == 0) {
continue;
}
// out[offset: offset+len] = x[each_range.begin: each_range.end]
out->Slice(static_cast<int>(offset), static_cast<int>(offset + len))
.CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)),
x.place(), dev_ctx);
offset += len;
}
}
}
};
class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitLoDTensorOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input LoDTensor");
AddInput("Mask", "A bool column vector which mask the input");
AddOutput("OutTrue", "True branch of input LoDTensor");
AddOutput("OutFalse", "False branch of input LoDTensor");
AddAttr<int>("level", "(int) the specific lod level to split.")
.SetDefault(0)
.EqualGreaterThan(0);
AddComment(
R"DOC(
Split a LoDTensor with a Mask at certain level. The input LoDTensor
has 3 sequence at certain lod level. The Mask is a bool column vector,
such as [0, 1, 0] at the same level. The first and third sequence will
be send to False Output LoDTensor; whereas the second sequence will
be send to True Output LoDTensor. Please refer to MergeLoDTensorOp.)DOC");
}
};
class SplitLoDTensorInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"),
"SplitLoDTensorOp must has input X.");
PADDLE_ENFORCE(context->HasInput("Mask"),
"SplitLoDTensorOp must has input Mask.");
PADDLE_ENFORCE(context->HasOutput("OutTrue"),
"SplitLoDTensorOp must has output OutTrue.");
PADDLE_ENFORCE(context->HasOutput("OutFalse"),
"SplitLoDTensorOp must has output OutFalse.");
auto mask_dim = context->GetInputDim("Mask");
PADDLE_ENFORCE_EQ(mask_dim.size(), 2);
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
context->SetOutputDim("OutTrue", context->GetInputDim("X"));
context->SetOutputDim("OutFalse", context->GetInputDim("X"));
}
};
class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *grad_op = new framework::OpDescBind();
grad_op->SetType("merge_lod_tensor");
grad_op->SetInput("InTrue", OutputGrad("OutTrue"));
grad_op->SetInput("InFalse", OutputGrad("OutFalse"));
grad_op->SetInput("Mask", Input("Mask"));
grad_op->SetInput("X", Input("X"));
grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(split_lod_tensor, ops::SplitLoDTensorOp,
ops::SplitLoDTensorOpProtoMaker,
ops::SplitLoDTensorInferShape,
ops::SplitLoDTensorArrayGradMaker);
......@@ -29,7 +29,7 @@ class SquaredL2NormKernel : public framework::OpKernel<T> {
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto out = framework::EigenScalar<T>::From(*Out);
auto place = context.GetEigenDevice<Place>();
out.device(place) = x.square().sum();
......
......@@ -27,20 +27,22 @@ namespace platform {
This wrap is a hack to avoid this bug.
*/
template <class Callable, class... Args>
template <typename Callable, typename... Args>
inline void call_once(std::once_flag& flag, Callable&& f, Args&&... args) {
bool good = false;
std::exception ex;
std::call_once(flag, [&]() {
try {
f(args...);
good = true;
} catch (const std::exception& e) {
ex = e;
} catch (...) {
ex = std::runtime_error("excption caught in call_once");
}
});
std::call_once(flag,
[&](Args&&... args) {
try {
f(args...);
good = true;
} catch (const std::exception& e) {
ex = e;
} catch (...) {
ex = std::runtime_error("excption caught in call_once");
}
},
args...);
if (!good) {
throw std::exception(ex);
}
......
......@@ -42,6 +42,9 @@ limitations under the License. */
#include "paddle/platform/gpu_info.h"
#endif
// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);
namespace paddle {
namespace pybind {
static size_t UniqueIntegerGenerator(const std::string &prefix) {
......
......@@ -37,10 +37,10 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
${CMAKE_CURRENT_BINARY_DIR}/setup.py)
add_custom_command(OUTPUT ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/core.so
COMMAND cmake -E copy $<TARGET_FILE:paddle_pybind> ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/core.so
add_custom_command(OUTPUT ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/core.so
COMMAND cmake -E copy $<TARGET_FILE:paddle_pybind> ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/core.so
DEPENDS paddle_pybind)
add_custom_target(copy_paddle_pybind ALL DEPENDS ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/core.so)
add_custom_target(copy_paddle_pybind ALL DEPENDS ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/core.so)
add_custom_command(OUTPUT ${PADDLE_PYTHON_BUILD_DIR}/.timestamp
......@@ -66,7 +66,7 @@ if (WITH_TESTING)
add_subdirectory(paddle/v2/tests)
add_subdirectory(paddle/v2/reader/tests)
add_subdirectory(paddle/v2/plot/tests)
add_subdirectory(paddle/v2/framework/tests)
add_subdirectory(paddle/v2/fluid/tests)
endif()
endif()
install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR}
......
......@@ -3658,9 +3658,10 @@ def gru_step_layer(input,
:param gate_act: Activation type of this layer's two gates. SigmoidActivation is
the default activation.
:type gate_act: BaseActivation
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute, no bias
is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param param_attr: The parameter attribute. See ParameterAttribute for details.
:type param_attr: ParameterAttribute
......@@ -3728,9 +3729,10 @@ def gru_step_naive_layer(input,
:param gate_act: Activation type of this layer's two gates. SigmoidActivation
is the default activation.
:type gate_act: BaseActivation
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute, no bias
is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param param_attr: The parameter attribute. See ParameterAttribute for details.
:type param_attr: ParameterAttribute
......@@ -3863,9 +3865,10 @@ def recurrent_layer(input,
:type input: LayerOutput
:param act: Activation type. TanhActivation is the default activation.
:type act: BaseActivation
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute,
no bias is defined. If the parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param param_attr: The parameter attribute. See ParameterAttribute for
details.
......@@ -4885,9 +4888,10 @@ def tensor_layer(a,
:param param_attr: The parameter attribute. See ParameterAttribute for
details.
:type param_attr: ParameterAttribute
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute,
no bias is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
......@@ -4961,9 +4965,10 @@ def selective_fc_layer(input,
:param param_attr: The parameter attribute. See ParameterAttribute for
details.
:type param_attr: ParameterAttribute
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute,
no bias is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
......@@ -5662,10 +5667,10 @@ def nce_layer(input,
to the num_classes. Each member of the list defines
the probability of a class given input x.
:type neg_distribution: list | tuple | collections.Sequence | None
:param bias_attr: The attribute for bias. If this parameter is set False or
any object whose type is not ParameterAttribute, no bias
is added. If this parameter is set True, the bias is
initialized to zero.
:param bias_attr: The parameter attribute for bias. If this parameter is set to
False or an object whose type is not ParameterAttribute,
no bias is defined. If this parameter is set to True,
the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
......@@ -6578,9 +6583,9 @@ def gated_unit_layer(input,
:param gate_param_attr: The parameter attribute of the gate. See ParameterAttribute
for details.
:type gate_param_attr: ParameterAttribute
:param gate_bias_attr: The bias attribute of the gate. If the parameter is set to False or
:param gate_bias_attr: The bias attribute of the gate. If this parameter is set to False or
an object whose type is not ParameterAttribute, no bias is defined.
If the parameter is set to True, the bias is initialized to zero.
If this parameter is set to True, the bias is initialized to zero.
:type gate_bias_attr: ParameterAttribute | bool | None | Any
:param inproj_attr: Extra layer attributes of the projection. See ExtraLayerAttribute for
details.
......@@ -6588,9 +6593,9 @@ def gated_unit_layer(input,
:param inproj_param_attr: The parameter attribute of the projection. See ParameterAttribute
for details.
:type inproj_param_attr: ParameterAttribute
:param inproj_bias_attr: The bias attribute of the projection. If the parameter is set to False
:param inproj_bias_attr: The bias attribute of the projection. If this parameter is set to False
or an object whose type is not ParameterAttribute, no bias is defined.
If the parameter is set to True, the bias is initialized to zero.
If this parameter is set to True, the bias is initialized to zero.
:type inproj_bias_attr: ParameterAttribute | bool | None | Any
:param layer_attr: Extra layer attribute of the product. See ExtraLayerAttribute for
details.
......
......@@ -681,34 +681,42 @@ def lstmemory_unit(input,
state_act=TanhActivation())
:param input: input layer.
:param input: Input layer.
:type input: LayerOutput
:param out_memory: output of previous time step
:param out_memory: The output of previous time step.
:type out_memory: LayerOutput | None
:param name: lstmemory unit name.
:param name: The lstmemory unit name.
:type name: basestring
:param size: lstmemory unit size.
:param size: The lstmemory unit size.
:type size: int
:param param_attr: parameter attribute, None means default attribute.
:param param_attr: The parameter attribute for the weights in
input to hidden projection.
None means default attribute.
:type param_attr: ParameterAttribute
:param act: last activiation type of lstm.
:param act: The last activiation type of lstm.
:type act: BaseActivation
:param gate_act: gate activiation type of lstm.
:param gate_act: The gate activiation type of lstm.
:type gate_act: BaseActivation
:param state_act: state activiation type of lstm.
:param state_act: The state activiation type of lstm.
:type state_act: BaseActivation
:param input_proj_bias_attr: bias attribute for input to hidden projection.
False means no bias, None means default bias.
:type input_proj_bias_attr: ParameterAttribute|False|None
:param input_proj_layer_attr: extra layer attribute for input to hidden
projection of the LSTM unit, such as dropout, error clipping.
:param input_proj_bias_attr: The parameter attribute for the bias in
input to hidden projection.
False or None means no bias.
If this parameter is set to True,
the bias is initialized to zero.
:type input_proj_bias_attr: ParameterAttribute|bool|None
:param input_proj_layer_attr: The extra layer attribute for
input to hidden projection of the LSTM unit,
such as dropout, error clipping.
:type input_proj_layer_attr: ExtraLayerAttribute
:param lstm_bias_attr: bias parameter attribute of lstm layer.
False means no bias, None means default bias.
:type lstm_bias_attr: ParameterAttribute|False|None
:param lstm_layer_attr: extra attribute of lstm layer.
:param lstm_bias_attr: The parameter attribute for the bias in lstm layer.
False or None means no bias.
If this parameter is set to True,
the bias is initialized to zero.
:type lstm_bias_attr: ParameterAttribute|True|None
:param lstm_layer_attr: The extra attribute of lstm layer.
:type lstm_layer_attr: ExtraLayerAttribute
:return: lstmemory unit name.
:return: The lstmemory unit name.
:rtype: LayerOutput
"""
if size is None:
......@@ -786,34 +794,42 @@ def lstmemory_group(input,
gate_act=SigmoidActivation(),
state_act=TanhActivation())
:param input: input layer.
:param input: Input layer.
:type input: LayerOutput
:param size: lstmemory group size.
:param size: The lstmemory group size.
:type size: int
:param name: name of lstmemory group.
:param name: The name of lstmemory group.
:type name: basestring
:param out_memory: output of previous time step.
:param out_memory: The output of previous time step.
:type out_memory: LayerOutput | None
:param reverse: process the input in a reverse order or not.
:param reverse: Process the input in a reverse order or not.
:type reverse: bool
:param param_attr: parameter attribute, None means default attribute.
:param param_attr: The parameter attribute for the weights in
input to hidden projection.
None means default attribute.
:type param_attr: ParameterAttribute
:param act: last activiation type of lstm.
:param act: The last activiation type of lstm.
:type act: BaseActivation
:param gate_act: gate activiation type of lstm.
:param gate_act: The gate activiation type of lstm.
:type gate_act: BaseActivation
:param state_act: state activiation type of lstm.
:param state_act: The state activiation type of lstm.
:type state_act: BaseActivation
:param lstm_bias_attr: bias parameter attribute of lstm layer.
False means no bias, None means default bias.
:type lstm_bias_attr: ParameterAttribute|False|None
:param input_proj_bias_attr: bias attribute for input to hidden projection.
False means no bias, None means default bias.
:type input_proj_bias_attr: ParameterAttribute|False|None
:param input_proj_layer_attr: extra layer attribute for input to hidden
projection of the LSTM unit, such as dropout, error clipping.
:param input_proj_bias_attr: The parameter attribute for the bias in
input to hidden projection.
False or None means no bias.
If this parameter is set to True,
the bias is initialized to zero.
:type input_proj_bias_attr: ParameterAttribute|bool|None
:param input_proj_layer_attr: The extra layer attribute for
input to hidden projection of the LSTM unit,
such as dropout, error clipping.
:type input_proj_layer_attr: ExtraLayerAttribute
:param lstm_layer_attr: lstm layer's extra attribute.
:param lstm_bias_attr: The parameter attribute for the bias in lstm layer.
False or None means no bias.
If this parameter is set to True,
the bias is initialized to zero.
:type lstm_bias_attr: ParameterAttribute|True|None
:param lstm_layer_attr: The extra attribute of lstm layer.
:type lstm_layer_attr: ExtraLayerAttribute
:return: the lstmemory group.
:rtype: LayerOutput
......
from paddle.v2.framework import framework as framework
from paddle.v2.fluid import framework as framework
__all__ = ['append_backward_ops']
......
......@@ -13,7 +13,7 @@ A `scoped_function` will take a `function` as input. That function will be
invoked in a new local scope.
"""
import paddle.v2.framework.core
import paddle.v2.fluid.core
import threading
__tl_scope__ = threading.local()
......@@ -27,13 +27,13 @@ __all__ = [
def get_cur_scope():
"""
Get current scope.
:rtype: paddle.v2.framework.core.Scope
:rtype: paddle.v2.fluid.core.Scope
"""
cur_scope_stack = getattr(__tl_scope__, 'cur_scope', None)
if cur_scope_stack is None:
__tl_scope__.cur_scope = list()
if len(__tl_scope__.cur_scope) == 0:
__tl_scope__.cur_scope.append(paddle.v2.framework.core.Scope())
__tl_scope__.cur_scope.append(paddle.v2.fluid.core.Scope())
return __tl_scope__.cur_scope[-1]
......
import paddle.v2.framework.op as op
import paddle.v2.fluid.op as op
import numpy as np
import paddle.v2.framework.core as core
import paddle.v2.fluid.core as core
def avg_accumulate(accumulated_var, per_eval, num_batches, place):
......@@ -22,7 +22,7 @@ class Evaluator(object):
NOTE: default run on CPUPlace(), running on GPUPlace doesn't improve performance much.
:param scope: the scope instance contains the input.
:type scope: paddle.v2.framework.core.scope
:type scope: paddle.v2.fluid.core.scope
:param operator: operator name for caculating the evaluation for each mini-batch.
:type operator: string
:param input: output variable name of forward network.
......
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Block, Program, g_main_program
import paddle.v2.fluid.core as core
from paddle.v2.fluid.framework import Block, Program, g_main_program
g_scope = core.Scope()
......
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import paddle.v2.fluid.core as core
import paddle.v2.fluid.proto.framework_pb2 as framework_pb2
import collections
import numpy as np
import copy
......@@ -285,7 +285,7 @@ class Operator(object):
self.desc.check_attrs()
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'while'
'rnn_memory_helper_grad', 'conditional_block', 'while'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
......
import paddle.v2.framework.framework as framework
import paddle.v2.fluid.framework as framework
import numpy as np
__all__ = [
......
import os
import cPickle as pickle
from paddle.v2.framework.framework import Program, Parameter, g_main_program, \
from paddle.v2.fluid.framework import Program, Parameter, g_main_program, \
Variable
__all__ = [
......
import copy
import itertools
from paddle.v2.framework.framework import Variable, g_main_program, \
from paddle.v2.fluid.framework import Variable, g_main_program, \
g_startup_program, unique_name, Program
from paddle.v2.framework.initializer import ConstantInitializer, \
UniformInitializer
from paddle.v2.fluid.initializer import ConstantInitializer, \
UniformInitializer, XavierInitializer
class LayerHelper(object):
......@@ -61,7 +61,7 @@ class LayerHelper(object):
@property
def param_attr(self):
default = {'name': None, 'initializer': UniformInitializer()}
default = {'name': None, 'initializer': XavierInitializer()}
actual = self.kwargs.get('param_attr', None)
if actual is None:
actual = default
......@@ -70,10 +70,11 @@ class LayerHelper(object):
actual[default_field] = default[default_field]
return actual
@property
def bias_attr(self):
default = {'name': None, 'initializer': ConstantInitializer()}
default = {'name': None, 'initializer': XavierInitializer()}
bias_attr = self.kwargs.get('bias_attr', None)
if bias_attr is True:
if bias_attr is None:
bias_attr = default
if isinstance(bias_attr, dict):
......@@ -166,7 +167,7 @@ class LayerHelper(object):
num_flatten_dims = 1
size = list(input_var.shape[num_flatten_dims:])
bias_attr = self.bias_attr()
bias_attr = self.bias_attr
if not bias_attr:
return input_var
......
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
import paddle.v2.fluid.core as core
import paddle.v2.fluid.proto.framework_pb2 as framework_pb2
from paddle.v2.fluid.framework import OpProtoHolder, Variable, Program, \
Operator
from paddle.v2.framework.initializer import ConstantInitializer, \
from paddle.v2.fluid.initializer import ConstantInitializer, \
NormalInitializer
from paddle.v2.framework.layer_helper import LayerHelper, unique_name
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
import re
import cStringIO
__all__ = [
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums', 'cos_sim',
'batch_norm', 'accuracy'
'batch_norm', 'accuracy', 'split_lod_tensor'
]
def fc(input,
size,
param_attr=None,
bias_attr=True,
bias_attr=None,
name=None,
act=None,
num_flatten_dims=1,
......@@ -125,6 +127,55 @@ def embedding(input,
return tmp
# TODO(qijun): expose H0 and C0
def dynamic_lstm(input,
size,
data_type='float32',
param_attr=None,
bias_attr=None,
use_peepholes=True,
is_reverse=False,
gate_activation='sigmoid',
cell_activation='tanh',
candidate_activation='tanh',
main_program=None,
startup_program=None):
helper = LayerHelper('lstm', **locals())
size = size / 4
weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 4 * size], dtype=data_type)
bias_size = [1, 7 * size]
if not use_peepholes:
bias_size[1] = 4 * size
bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=data_type, suffix='b')
hidden = helper.create_tmp_variable(data_type)
cell = helper.create_tmp_variable(data_type)
batch_gate = helper.create_tmp_variable(data_type)
batch_cell_pre_act = helper.create_tmp_variable(data_type)
helper.append_op(
type='lstm',
inputs={'Input': input,
'Weight': weight,
'Bias': bias},
outputs={
'Hidden': hidden,
'Cell': cell,
'BatchGate': batch_gate,
'BatchCellPreAct': batch_cell_pre_act
},
attrs={
'use_peepholes': use_peepholes,
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'cell_activation': cell_activation,
'candidate_activation': candidate_activation
})
return hidden, cell
def data(name,
shape,
data_type='float32',
......@@ -175,6 +226,11 @@ def data(name,
stop_gradient=stop_gradient)
def create_tensor(dtype, name=None, main_program=None):
helper = LayerHelper("create_tensor", **locals())
return helper.create_variable(name=helper.name, dtype=dtype)
def _convert_(name):
"""
Formatting.
......@@ -191,6 +247,58 @@ def _convert_(name):
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
def _generate_doc_string_(op_proto):
"""
Generate docstring by OpProto
Args:
op_proto (framework_pb2.OpProto): a protobuf message typed OpProto
Returns:
str: the document string
"""
def _type_to_str_(tp):
return framework_pb2.AttrType.Name(tp)
if not isinstance(op_proto, framework_pb2.OpProto):
raise TypeError("OpProto should be `framework_pb2.OpProto`")
buf = cStringIO.StringIO()
buf.write(op_proto.comment)
buf.write('\nArgs:\n')
for each_input in op_proto.inputs:
line_begin = ' {0}: '.format(_convert_(each_input.name))
buf.write(line_begin)
buf.write(each_input.comment)
buf.write('\n')
buf.write(' ' * len(line_begin))
buf.write('Duplicable: ')
buf.write(str(each_input.duplicable))
buf.write(' Optional: ')
buf.write(str(each_input.dispensable))
buf.write('\n')
for each_attr in op_proto.attrs:
buf.write(' ')
buf.write(each_attr.name)
buf.write(' (')
buf.write(_type_to_str_(each_attr.type))
buf.write('): ')
buf.write(each_attr.comment)
buf.write('\n')
if len(op_proto.outputs) != 0:
buf.write('\nReturns:\n')
buf.write(' ')
for each_opt in op_proto.outputs:
if not each_opt.intermediate:
break
buf.write(each_opt.comment)
return buf.getvalue()
def _create_op_func_(op_type):
"""
Create an Operator for a Function.
......@@ -249,11 +357,6 @@ def _create_op_func_(op_type):
return dtype
def func(**kwargs):
"""
This function implements the function for the operator. This process
involves doing the sanity check (using the function above), reading
inputs from protobuf and applying the activations on top.
"""
helper = LayerHelper(op_type, **kwargs)
dtype = infer_and_check_data_type(op_proto, **kwargs)
......@@ -277,6 +380,7 @@ def _create_op_func_(op_type):
func.__name__ = op_type
globals()[op_type] = func
func.__doc__ = _generate_doc_string_(op_proto)
global __all__
__all__.append(op_type)
......@@ -352,6 +456,56 @@ def sums(input, main_program=None, startup_program=None):
return out
def assign(input, output, main_program=None):
helper = LayerHelper('assign', **locals())
helper.append_op(
type='scale',
inputs={'X': [input]},
outputs={'Out': [output]},
attrs={'scale': 1.0})
return output
def split_lod_tensor(input,
mask,
level,
main_program=None,
startup_program=None):
helper = LayerHelper('split_lod_tensor', **locals())
out_true = helper.create_tmp_variable(dtype=input.data_type)
out_false = helper.create_tmp_variable(dtype=input.data_type)
helper.append_op(
type='split_lod_tensor',
inputs={
'X': input,
'Mask': mask,
},
outputs={'OutTrue': out_true,
'OutFalse': out_false},
attrs={'level': level})
return out_true, out_false
def merge_lod_tensor(in_true,
in_false,
x,
mask,
level,
main_program=None,
startup_program=None):
helper = LayerHelper('merge_lod_tensor', **locals())
out = helper.create_tmp_variable(dtype=x.data_type)
helper.append_op(
type='merge_lod_tensor',
inputs={'X': x,
'Mask': mask,
'InTrue': in_true,
'InFalse': in_false},
outputs={'Out': out},
attrs={'level': level})
return out
def cos_sim(X, Y, **kwargs):
"""
This function performs the cosine similarity between two tensors
......@@ -685,6 +839,23 @@ def batch_norm(input,
return helper.append_activation(batch_norm_out)
def beam_search_decode(ids, scores, main_program=None, startup_program=None):
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.data_type)
sentence_scores = helper.create_tmp_variable(dtype=ids.data_type)
helper.append_op(
type="beam_search_decode",
inputs={"Ids": ids,
"Scores": scores},
outputs={
"SentenceIds": sentence_ids,
"SentenceScores": sentence_scores
})
return sentence_ids, sentence_scores
class BlockGuard(object):
"""
BlockGuard class.
......@@ -1276,3 +1447,73 @@ def array_length(array, main_program=None):
helper.append_op(
type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]})
return tmp
class ConditionalBlockGuard(BlockGuard):
def __init__(self, block):
if not isinstance(block, ConditionalBlock):
raise TypeError("block should be conditional block")
super(ConditionalBlockGuard, self).__init__(block.helper.main_program)
self.block = block
def __enter__(self):
return super(ConditionalBlockGuard, self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
self.block.complete()
return super(ConditionalBlockGuard, self).__exit__(exc_type, exc_val,
exc_tb)
class ConditionalBlock(object):
def __init__(self, inputs, name=None, main_program=None):
for each_input in inputs:
if not isinstance(each_input, Variable):
raise TypeError("Each input should be variable")
self.inputs = inputs
self.helper = LayerHelper(
'conditional_block', name=name, main_program=main_program)
def block(self):
return ConditionalBlockGuard(self)
def complete(self):
inside_block = self.helper.main_program.current_block()
parent_block = self.helper.main_program.block(inside_block.parent_idx)
intermediate = set()
params = set()
for each_op in inside_block.ops:
assert isinstance(each_op, Operator)
for iname in each_op.input_names:
for in_var_name in each_op.input(iname):
if in_var_name not in intermediate:
params.add(in_var_name)
for oname in each_op.output_names:
for out_var_name in each_op.output(oname):
intermediate.add(out_var_name)
input_set = set([ipt.name for ipt in self.inputs])
param_list = [
parent_block.var(each_name) for each_name in params
if each_name not in input_set
]
out_list = [
parent_block.var(var_name) for var_name in parent_block.vars
if var_name not in intermediate
]
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
parent_block.append_op(
type='conditional_block',
inputs={
'X': self.inputs,
'Params': param_list,
},
outputs={'Out': out_list,
'Scope': [step_scope]},
attrs={'block': inside_block})
......@@ -3,8 +3,8 @@ import json
import logging
from collections import defaultdict
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import paddle.v2.fluid.core as core
import paddle.v2.fluid.proto.framework_pb2 as framework_pb2
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
......
import paddle.v2.framework.layers as layers
import paddle.v2.fluid.layers as layers
__all__ = ["simple_img_conv_pool", "sequence_conv_pool"]
......
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import paddle.v2.fluid.core as core
import paddle.v2.fluid.proto.framework_pb2 as framework_pb2
def get_all_op_protos():
......
from collections import defaultdict
import paddle.v2.framework.framework as framework
from paddle.v2.framework.framework import unique_name, Program
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.initializer import ConstantInitializer
from paddle.v2.framework.regularizer import append_regularization_ops
from paddle.v2.framework.layer_helper import LayerHelper
import paddle.v2.fluid.framework as framework
from paddle.v2.fluid.framework import unique_name, Program
from paddle.v2.fluid.backward import append_backward_ops
from paddle.v2.fluid.initializer import ConstantInitializer
from paddle.v2.fluid.regularizer import append_regularization_ops
from paddle.v2.fluid.layer_helper import LayerHelper
__all__ = [
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
......
import paddle.v2.framework.framework as framework
import paddle.v2.fluid.framework as framework
__all__ = [
'append_regularization_ops', 'L2DecayRegularizer', 'L1DecayRegularizer'
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
add_subdirectory(book)
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.io import save_persistables, load_persistables
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.io import save_persistables, load_persistables
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import numpy as np
import paddle.v2 as paddle
import paddle.v2.framework.core as core
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import g_startup_program, g_main_program
from paddle.v2.framework.initializer import XavierInitializer
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.framework import g_startup_program, g_main_program
from paddle.v2.fluid.initializer import XavierInitializer
def resnet_cifar10(input, depth=32, main_program=None, startup_program=None):
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.regularizer import L2DecayRegularizer
from paddle.v2.framework.initializer import UniformInitializer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.regularizer import L2DecayRegularizer
from paddle.v2.fluid.initializer import UniformInitializer
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.nets as nets
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_main_program, g_startup_program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program, g_main_program, g_startup_program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.nets as nets
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.fluid.framework import Program, g_main_program, g_startup_program
from paddle.v2.fluid.executor import Executor
import numpy as np
def stacked_lstm_net(input_dim,
class_dim=2,
emb_dim=128,
hid_dim=512,
stacked_num=3):
assert stacked_num % 2 == 1
data = layers.data(name="words", shape=[1], data_type="int64")
label = layers.data(name="label", shape=[1], data_type="int64")
emb = layers.embedding(input=data, size=[input_dim, emb_dim])
# add bias attr
# TODO(qijun) linear act
fc1 = layers.fc(input=emb, size=hid_dim)
lstm1, cell1 = layers.dynamic_lstm(input=fc1, size=hid_dim)
inputs = [fc1, lstm1]
for i in range(2, stacked_num + 1):
fc = layers.fc(input=inputs, size=hid_dim)
lstm, cell = layers.dynamic_lstm(
input=fc, size=hid_dim, is_reverse=(i % 2) == 0)
inputs = [fc, lstm]
fc_last = layers.sequence_pool(input=inputs[0], pool_type='max')
lstm_last = layers.sequence_pool(input=inputs[1], pool_type='max')
prediction = layers.fc(input=[fc_last, lstm_last],
size=class_dim,
act='softmax')
cost = layers.cross_entropy(input=prediction, label=label)
avg_cost = layers.mean(x=cost)
adam_optimizer = optimizer.AdamOptimizer(learning_rate=0.002)
opts = adam_optimizer.minimize(avg_cost)
acc = layers.accuracy(input=prediction, label=label)
return avg_cost, acc
def to_lodtensor(data, place):
seq_lens = [len(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = core.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def main():
BATCH_SIZE = 100
PASS_NUM = 5
word_dict = paddle.dataset.imdb.word_dict()
print "load word dict successfully"
dict_dim = len(word_dict)
class_dim = 2
cost, acc = stacked_lstm_net(input_dim=dict_dim, class_dim=class_dim)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=1000),
batch_size=BATCH_SIZE)
place = core.CPUPlace()
exe = Executor(place)
exe.run(g_startup_program)
for pass_id in xrange(PASS_NUM):
for data in train_data():
tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
label = np.array(map(lambda x: x[1], data)).astype("int64")
label = label.reshape([BATCH_SIZE, 1])
tensor_label = core.LoDTensor()
tensor_label.set(label, place)
outs = exe.run(g_main_program,
feed={"words": tensor_words,
"label": tensor_label},
fetch_list=[cost, acc])
cost_val = np.array(outs[0])
acc_val = np.array(outs[1])
print("cost=" + str(cost_val) + " acc=" + str(acc_val))
if cost_val < 1.0 and acc_val > 0.7:
exit(0)
exit(1)
if __name__ == '__main__':
main()
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import g_main_program, g_startup_program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import g_main_program, g_startup_program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.core as core
import paddle.v2.fluid.optimizer as optimizer
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
import numpy as np
......
......@@ -2,12 +2,12 @@ import unittest
import numpy as np
import random
import itertools
import paddle.v2.framework.core as core
import paddle.v2.fluid.core as core
import collections
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.op import Operator
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.framework import Program, OpProtoHolder
from paddle.v2.fluid.backward import append_backward_ops
from paddle.v2.fluid.op import Operator
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.framework import Program, OpProtoHolder
def randomize_probability(batch_size, class_num, dtype='float32'):
......
import unittest
import paddle.v2.framework.core as core
import paddle.v2.framework.layers as layers
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.framework import g_main_program
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.backward import append_backward_ops
from paddle.v2.fluid.framework import g_main_program
import numpy
......
import op_test
import numpy
import unittest
class TestAssignOp(op_test.OpTest):
def setUp(self):
self.op_type = "assign"
x = numpy.random.random(size=(100, 10))
self.inputs = {'X': x}
self.outputs = {'Out': x}
def test_forward(self):
self.check_output()
def test_backward(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
import paddle.v2.fluid.core as core
from paddle.v2.fluid.op import Operator
def grad_var_name(var_name):
......
import unittest
import numpy as np
from op_test import OpTest
class TestBilinearTensorProductOp(OpTest):
def setUp(self):
self.op_type = "bilinear_tensor_product"
batch_size = 6
size0 = 3
size1 = 4
size2 = 5
a = np.random.random((batch_size, size0)).astype("float32")
b = np.random.random((batch_size, size1)).astype("float32")
w = np.random.random((size2, size0, size1)).astype("float32")
bias = np.random.random((1, size2)).astype("float32")
output = np.zeros((batch_size, size2)).astype("float32")
for i in range(size2):
w_i = w[i, :, :]
output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1)
self.inputs = {
'X': a,
'Y': b,
'Weight': w,
'Bias': bias,
}
self.outputs = {'Out': output + bias}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out')
if __name__ == "__main__":
unittest.main()
import op_test
import unittest
import numpy as np
import paddle.v2.framework.core as core
import paddle.v2.fluid.core as core
class TestCastOp(op_test.OpTest):
......
import logging
import paddle.v2.framework.core as core
import paddle.v2.fluid.core as core
import unittest
import numpy as np
from paddle.v2.framework.op import Operator, CondOp
from paddle.v2.fluid.op import Operator, CondOp
class PySimpleCond(object):
......
import unittest
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.core as core
from paddle.v2.fluid.framework import g_startup_program, g_main_program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.backward import append_backward_ops
import numpy
class ConditionalBlock(unittest.TestCase):
def test_forward(self):
data = layers.data(name='X', shape=[1], data_type='float32')
data.stop_gradient = False
cond = layers.ConditionalBlock(inputs=[data])
out = layers.create_tensor(dtype='float32')
with cond.block():
hidden = layers.fc(input=data, size=10)
layers.assign(hidden, out)
cpu = core.CPUPlace()
exe = Executor(cpu)
exe.run(g_startup_program)
x = core.LoDTensor()
x.set(numpy.random.random(size=(10, 1)).astype('float32'), cpu)
outs = map(numpy.array, exe.run(feed={'X': x}, fetch_list=[out]))[0]
print outs
loss = layers.mean(x=out)
append_backward_ops(loss=loss)
outs = map(numpy.array,
exe.run(feed={'X': x},
fetch_list=[
g_main_program.block(0).var(data.name + "@GRAD")
]))[0]
print outs
if __name__ == '__main__':
unittest.main()
import unittest
import paddle.v2.fluid.layers as layers
class TestDocString(unittest.TestCase):
def test_layer_doc_string(self):
print layers.dropout.__doc__
if __name__ == '__main__':
unittest.main()
from paddle.v2.framework.default_scope_funcs import *
from paddle.v2.fluid.default_scope_funcs import *
import unittest
......
import logging
import paddle.v2.framework.core as core
import paddle.v2.fluid.core as core
import unittest
from paddle.v2.framework.op import Operator, DynamicRecurrentOp
from paddle.v2.fluid.op import Operator, DynamicRecurrentOp
import numpy as np
# for siplicity, just one level LoD
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册