diff --git a/paddle/framework/var_type.h b/paddle/framework/var_type.h index d060196bb2c478b776851288cb71a1880d60660d..0f19870bec3e69d07278507cc556a86bbd25d12d 100644 --- a/paddle/framework/var_type.h +++ b/paddle/framework/var_type.h @@ -27,10 +27,32 @@ inline VarDesc::VarType ToVarType(std::type_index type) { return VarDesc_VarType_LOD_RANK_TABLE; } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { return VarDesc_VarType_LOD_TENSOR_ARRAY; + } else if (type.hash_code() == typeid(SelectedRows).hash_code()) { + return VarDesc_VarType_SELECTED_ROWS; } else { PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); } } +template +inline void VisitVarType(const Variable& var, Visitor visitor) { + switch (ToVarType(var.Type())) { + case VarDesc_VarType_LOD_TENSOR: + visitor(var.Get()); + return; + case VarDesc_VarType_LOD_RANK_TABLE: + visitor(var.Get()); + return; + case VarDesc_VarType_LOD_TENSOR_ARRAY: + visitor(var.Get()); + return; + case VarDesc_VarType_SELECTED_ROWS: + visitor(var.Get()); + return; + default: + PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type())); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/gserver/layers/ROIPoolLayer.cpp b/paddle/gserver/layers/ROIPoolLayer.cpp index 99cfddb0cf3337745a716a8c329713c18b99eda3..35d4b12d3d357800fe72899069b5377c252fac5f 100644 --- a/paddle/gserver/layers/ROIPoolLayer.cpp +++ b/paddle/gserver/layers/ROIPoolLayer.cpp @@ -98,7 +98,7 @@ void ROIPoolLayer::forward(PassType passType) { size_t roiStartH = round(bottomROIs[2] * spatialScale_); size_t roiEndW = round(bottomROIs[3] * spatialScale_); size_t roiEndH = round(bottomROIs[4] * spatialScale_); - CHECK_GE(roiBatchIdx, 0); + CHECK_GE(roiBatchIdx, 0UL); CHECK_LT(roiBatchIdx, batchSize); size_t roiHeight = std::max(roiEndH - roiStartH + 1, 1UL); size_t roiWidth = std::max(roiEndW - roiStartW + 1, 1UL); diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp index a0e039c2a33b586e21775ad06c1278a10804d654..a859e34c8996d81f14bf1edcb6e23d5a4f687e6b 100644 --- a/paddle/gserver/tests/test_MKLDNN.cpp +++ b/paddle/gserver/tests/test_MKLDNN.cpp @@ -297,7 +297,7 @@ static void getAddtoConfig(TestConfig& cfg, } void testAddtoLayer(const testImageDesc& pm, const size_t nInputs) { - CHECK_GE(nInputs, 1); + CHECK_GE(nInputs, 1UL); TestConfig dnnConfig; getAddtoConfig(dnnConfig, pm, nInputs); dnnConfig.layerConfig.set_type("mkldnn_addto"); diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 29ce44c23308cb5ae1c1df5c9be1412c28abe96f..709f7de2e43093114d096cbfca5b5d49293a6d3e 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -214,6 +214,7 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) +cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc rnn/recurrent_op_utils.cc diff --git a/paddle/operators/assign_op.cc b/paddle/operators/assign_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..609e915b932e2bc4d5abee1e5f868cc07a7619d3 --- /dev/null +++ b/paddle/operators/assign_op.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/data_type.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/var_type.h" + +namespace paddle { +namespace operators { +class AssignFunctor { + public: + AssignFunctor(framework::Variable *out, + const platform::DeviceContext &dev_ctx) + : out_(out), dev_ctx_(dev_ctx) {} + + void operator()(const framework::LoDTensor &lod_tensor) const { + auto &out_tensor = *out_->GetMutable(); + copy_tensor(lod_tensor, &out_tensor); + } + + void operator()(const framework::LoDTensorArray &array) const { + auto &out_array = *out_->GetMutable(); + out_array.resize(array.size()); + for (size_t i = 0; i < array.size(); ++i) { + copy_tensor(array[i], &out_array[i]); + } + } + + void operator()(const framework::SelectedRows &rows) const { + framework::SelectedRows &out_rows = + *out_->GetMutable(); + out_rows.set_rows(rows.rows()); + out_rows.set_height(rows.height()); + auto &t = rows.value(); + out_rows.mutable_value()->CopyFrom(t, t.place(), dev_ctx_); + } + + template + void operator()(const T &v) const { + PADDLE_THROW("Not support type for assign op %s", typeid(T).name()); + } + + private: + void copy_tensor(const framework::LoDTensor &lod_tensor, + framework::LoDTensor *out) const { + auto &out_tensor = *out; + out_tensor.CopyFrom(lod_tensor, lod_tensor.place(), dev_ctx_); + out_tensor.set_lod(lod_tensor.lod()); + } + + framework::Variable *out_; + const platform::DeviceContext &dev_ctx_; +}; + +class AssignOp : public framework::OperatorBase { + public: + AssignOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto *x = scope.FindVar(Input("X")); + if (x == nullptr) { + return; + } + auto *out = scope.FindVar(Output("Out")); + PADDLE_ENFORCE( + out != nullptr, + "The Output(Out) should not be null if the Input(X) is set."); + framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); + } +}; + +class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + AssignOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor, SelectedRows or LoDTensorArray) The input variable " + "could be LoDTensor, SelectedRows or LoDTensorArray.") + .AsDispensable(); + AddOutput("Out", + "(LoDTensor, SelectedRows or LoDTensorArray) The type of output " + "is the same as input X."); + AddComment(R"DOC(Assign Operator + +Out = X, when type in [LoDTensor/SelectedRows/LoDTensorArray] +raise error if the type is not listed above. +)DOC"); + } +}; + +class AssignInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + if (context->HasInput("X")) { + auto type = context->GetInputsVarType("X")[0]; + if (type == framework::VarDesc_VarType_SELECTED_ROWS || + type == framework::VarDesc_VarType_LOD_TENSOR) { + context->SetOutputDim("Out", context->GetInputDim("X")); + } + } + } +}; + +class AssignGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *op = new framework::OpDescBind(); + op->SetType("assign"); + op->SetInput("X", OutputGrad("Out")); + op->SetOutput("Out", InputGrad("X")); + return std::unique_ptr(op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker, + ops::AssignInferShape, ops::AssignOpProtoMaker); diff --git a/paddle/operators/beam_search_decode_op.cc b/paddle/operators/beam_search_decode_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1ba4dfcdaba498bfef98258f03664afebe14ec18 --- /dev/null +++ b/paddle/operators/beam_search_decode_op.cc @@ -0,0 +1,110 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/beam_search_decode_op.h" + +namespace paddle { +namespace operators { + +class BeamSearchDecodeOp : public framework::OperatorBase { + public: + BeamSearchDecodeOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override { + framework::ExecutionContext ctx(*this, scope, dev_ctx); + const LoDTensorArray* ids = ctx.Input("Ids"); + const LoDTensorArray* scores = ctx.Input("Scores"); + const size_t step_num = ids->size(); + PADDLE_ENFORCE_GT(step_num, 0UL, + "beam search steps should be larger than 0"); + const size_t source_num = ids->at(0).lod().at(0).size() - 1; + PADDLE_ENFORCE_GT(source_num, 0UL, "source num should be larger than 0"); + + for (size_t i = 0; i < step_num; ++i) { + PADDLE_ENFORCE_EQ(ids->at(i).lod().size(), 2UL, + "Level of LodTensor should be 2"); + } + + // prepare output + LoDTensor* sentenceIds = ctx.Output("SentenceIds"); + LoDTensor* sentenceScores = ctx.Output("SentenceScores"); + + BeamSearchDecoder beam_search_decoder; + beam_search_decoder.PackAllSteps(*ids, *scores, sentenceIds, + sentenceScores); + } +}; + +class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + BeamSearchDecodeOpProtoMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Ids", + "(LodTensorArray)" + "score of the candidate words in each step"); + AddInput("Scores", + "(LodTensorArray)" + "score of the candidate words in each step"); + AddOutput("SentenceIds", + "(LodTensor)" + "All possible result sentences of word ids"); + AddOutput("SentenceScores", + "(LodTensor)" + "All possible result sentences of word scores"); + AddComment(R"DOC( +Pack the result of Beam search op into SentenceIds and SentenceScores. +)DOC"); + } +}; + +class BeamSearchDecodeInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* context) const override { + PADDLE_ENFORCE(context->HasInput("Ids"), + "BeamSearchDecodeOp must has input Ids"); + PADDLE_ENFORCE(context->HasInput("Scores"), + "BeamSearchDecodeOp must has input Scores"); + PADDLE_ENFORCE(context->HasOutput("SentenceIds"), + "BeamSearchDecodeOp must has output SentenceIds"); + PADDLE_ENFORCE(context->HasOutput("SentenceScores"), + "BeamSearchDecodeOp must has output SentenceScores"); + } +}; + +class BeamSearchDecodeInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDescBind& op_desc, + framework::BlockDescBind* block) const override { + for (auto& o : op_desc.Output("SentenceIds")) { + block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR); + } + for (auto& o : op_desc.Output("SentenceScores")) { + block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR); + } + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(beam_search_decode, paddle::operators::BeamSearchDecodeOp, + paddle::operators::BeamSearchDecodeOpProtoMaker, + paddle::operators::BeamSearchDecodeInferShape, + paddle::operators::BeamSearchDecodeInferVarType, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/operators/beam_search_decode_op.h b/paddle/operators/beam_search_decode_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0f007ec22f9a66572971516a711317f348e1ec5a --- /dev/null +++ b/paddle/operators/beam_search_decode_op.h @@ -0,0 +1,280 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/lod_tensor_array.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using LoDTensorArray = framework::LoDTensorArray; + +// all the lod have 2 levels. +// The First is source level, the second is sentence level. +// source level describe how many candidate words for this source. +// sentence level describe these candidates belong to which prefix +const size_t kSourceLevel = 0; +const size_t kSentenceLevel = 1; + +template +struct BeamNode { + BeamNode(int64_t word_id, T score) : word_id_(word_id), score_(score) {} + + ~BeamNode() { + if (parent_) { + parent_->DropKid(this); + if (parent_->kids_.size() == 0UL) { + delete parent_; + } + } + VLOG(3) << "Delete BeamNode root with word_id:" << this->word_id_; + } + + void AppendTo(BeamNode* parent) { + parent_ = parent; + parent->kids_.insert(this); + } + + void DropKid(BeamNode* kid) { kids_.erase(kid); } + + BeamNode* parent_ = nullptr; + std::unordered_set kids_; + int64_t word_id_; + T score_; +}; + +template +using BeamNodeVector = std::vector>>; + +template +struct Sentence { + std::vector word_ids; + std::vector scores; +}; + +template +using SentenceVector = std::vector>; + +template +struct BeamSearchDecoder { + /** + * make a BeamNode and all it's related prefix BeanNode into a Sentence. + */ + Sentence MakeSentence(const BeamNode* node) const; + + /** + * Param: + * cur_ids: LoDTensor of One step for word ID + * cur_scores: LoDTensor of One Step for word score + * prefixes_list: prefixes for each source sentence. + * sentence_vector_list: result sentence_vector for each source sentence. + * Return: + * a new prefixes list for each source of current step + */ + std::vector> PackTwoSteps( + const LoDTensor& cur_ids, const LoDTensor& cur_scores, + std::vector>& prefixes_list, + std::vector>* sentence_vector_list) const; + + /** + * convert the result sentence_vector for each source sentence into two + * LodTensor. + * One is all candidate sentences with word id, one is all candidate sentences + * with word score. + * Param: + * sentence_vector_list: sentence_vector for each source sentence. + * id_tensor: result LoDTensor for sentences of id. + * score_tensor: result LoDTensor for sentences of score. + */ + void ConvertSentenceVectorToLodTensor( + std::vector> sentence_vector_list, LoDTensor* id_tensor, + LoDTensor* score_tensor) const; + + /** + * Pack all steps of id/score LodTensor into sentence LoDTensor + * it's main logic is: + * ```python + * prefix + * result_sentence + * result_lod_tensor + * + * for (step in steps): + * prefix = PackTwoSteps(prefix, step, &result_sentence) + * ConvertSentenceVectorToLodTensor(result_sentence, &result_lod_tensor) + * ``` + */ + void PackAllSteps(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, LoDTensor* id_tensor, + LoDTensor* score_tensor) const; +}; + +template +Sentence BeamSearchDecoder::MakeSentence(const BeamNode* node) const { + Sentence sentence; + while (node != nullptr) { + sentence.word_ids.emplace_back(node->word_id_); + sentence.scores.emplace_back(node->score_); + node = node->parent_; + } + + std::reverse(std::begin(sentence.word_ids), std::end(sentence.word_ids)); + std::reverse(std::begin(sentence.scores), std::end(sentence.scores)); + + return sentence; +} + +template +std::vector> BeamSearchDecoder::PackTwoSteps( + const LoDTensor& cur_ids, const LoDTensor& cur_scores, + std::vector>& prefixes_list, + std::vector>* sentence_vector_list) const { + std::vector> result; + + for (size_t src_idx = 0; src_idx < cur_ids.lod()[kSourceLevel].size() - 1; + ++src_idx) { + size_t src_start = cur_ids.lod().at(kSourceLevel)[src_idx]; + size_t src_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1]; + + BeamNodeVector beam_nodes; + + // if prefixes size is 0, it means this is the first step. In this step, + // all candidate id is the start of candidate sentences. + if (prefixes_list.empty()) { + PADDLE_ENFORCE_EQ(cur_ids.lod().at(kSourceLevel).back(), + cur_ids.lod().at(kSentenceLevel).back(), + "in the first step"); + for (size_t id_idx = src_start; id_idx < src_end; ++id_idx) { + beam_nodes.push_back(std::unique_ptr>(new BeamNode( + cur_ids.data()[id_idx], cur_scores.data()[id_idx]))); + } + } else { + BeamNodeVector& prefixes = prefixes_list[src_idx]; + SentenceVector& sentence_vector = (*sentence_vector_list)[src_idx]; + + PADDLE_ENFORCE_EQ(src_end - src_start, prefixes.size(), + "prefix and candidate set number should be the same"); + + auto candidate_offset = cur_ids.lod()[kSentenceLevel]; + for (size_t prefix_idx = 0; prefix_idx < prefixes.size(); ++prefix_idx) { + std::unique_ptr>& prefix = prefixes[prefix_idx]; + size_t candidate_start = candidate_offset[src_start + prefix_idx]; + size_t candidate_end = candidate_offset[src_start + prefix_idx + 1]; + if (candidate_start == candidate_end) { + VLOG(3) << "this sentence has no more candidate, " + "add to result sentence and rm it from beam tree"; + sentence_vector.push_back(MakeSentence(prefix.get())); + prefix.reset(); + } else { + for (size_t candidate_idx = candidate_start; + candidate_idx < candidate_end; ++candidate_idx) { + auto* candidate = + new BeamNode(cur_ids.data()[candidate_idx], + cur_scores.data()[candidate_idx]); + candidate->AppendTo(prefix.get()); + beam_nodes.push_back(std::unique_ptr>(candidate)); + } + prefix.release(); + } + } + } + result.push_back(std::move(beam_nodes)); + } + return result; +} + +template +void BeamSearchDecoder::ConvertSentenceVectorToLodTensor( + std::vector> sentence_vector_list, LoDTensor* id_tensor, + LoDTensor* score_tensor) const { + size_t src_num = sentence_vector_list.size(); + + PADDLE_ENFORCE_NE(src_num, 0, "src_num should not be 0"); + + std::vector source_level_lod = {0}; + std::vector sentence_level_lod = {0}; + std::vector id_data; + std::vector score_data; + + for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { + for (Sentence& sentence : sentence_vector_list[src_idx]) { + id_data.insert(id_data.end(), sentence.word_ids.begin(), + sentence.word_ids.end()); + score_data.insert(score_data.end(), sentence.scores.begin(), + sentence.scores.end()); + sentence_level_lod.push_back(sentence_level_lod.back() + + sentence.word_ids.size()); + } + source_level_lod.push_back(source_level_lod.back() + + sentence_vector_list[src_idx].size()); + } + + auto cpu_place = new paddle::platform::CPUPlace(); + paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place); + + framework::LoD lod; + lod.push_back(source_level_lod); + lod.push_back(sentence_level_lod); + + id_tensor->set_lod(lod); + id_tensor->Resize({static_cast(id_data.size())}); + id_tensor->mutable_data(paddle::platform::CPUPlace()); + id_tensor->CopyFromVector(id_data, cpu_ctx); + + score_tensor->set_lod(lod); + score_tensor->Resize({static_cast(score_data.size())}); + score_tensor->mutable_data(paddle::platform::CPUPlace()); + score_tensor->CopyFromVector(score_data, cpu_ctx); +} + +template +void BeamSearchDecoder::PackAllSteps(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, + LoDTensor* id_tensor, + LoDTensor* score_tensor) const { + PADDLE_ENFORCE(!step_ids.empty(), "step num should be larger than 0"); + PADDLE_ENFORCE_EQ(step_ids.size(), step_scores.size(), + "step_ids and step_scores should be the same"); + const size_t step_num = step_ids.size(); + const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1; + + PADDLE_ENFORCE_GT(src_num, 0UL, "source num should be larger than 0"); + + // previous prefixes for each step, + // the init length is 0, means this is the first step. + std::vector> beamnode_vector_list(0); + std::vector> sentence_vector_list(src_num); + + // pack all steps for one batch first, then another batch + for (size_t step_id = 0; step_id < step_num; ++step_id) { + beamnode_vector_list = + PackTwoSteps(step_ids.at(step_id), step_scores.at(step_id), + beamnode_vector_list, &sentence_vector_list); + } + // append last beam_node to result + for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { + for (auto& beam_node : beamnode_vector_list.at(src_idx)) { + sentence_vector_list[src_idx].push_back(MakeSentence(beam_node.get())); + beam_node.reset(); + } + } + + ConvertSentenceVectorToLodTensor(sentence_vector_list, id_tensor, + score_tensor); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/beam_search_decode_op_test.cc b/paddle/operators/beam_search_decode_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ac23991f3c7768abaf94f3a4b750697de0ef114 --- /dev/null +++ b/paddle/operators/beam_search_decode_op_test.cc @@ -0,0 +1,221 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/beam_search_decode_op.h" +#include "gtest/gtest.h" + +using CPUPlace = paddle::platform::CPUPlace; +using LoD = paddle::framework::LoD; +using LoDTensor = paddle::framework::LoDTensor; +using LoDTensorArray = paddle::framework::LoDTensorArray; + +template +using BeamNode = paddle::operators::BeamNode; +template +using BeamSearchDecoder = paddle::operators::BeamSearchDecoder; +template +using Sentence = paddle::operators::Sentence; +template +using BeamNodeVector = paddle::operators::BeamNodeVector; +template +using SentenceVector = paddle::operators::SentenceVector; + +namespace paddle { +namespace test { + +void GenerateExample(const std::vector& level_0, + const std::vector& level_1, + const std::vector& data, LoDTensorArray* ids, + LoDTensorArray* scores) { + PADDLE_ENFORCE_EQ(level_0.back(), level_1.size() - 1, + "source level is used to describe candidate set"); + PADDLE_ENFORCE_EQ(level_1.back(), data.size(), + "the lowest level is used to describe data" + ", so it's last element should be data length"); + + CPUPlace place; + + LoD lod; + lod.push_back(level_0); + lod.push_back(level_1); + + // Ids + LoDTensor tensor_id; + tensor_id.set_lod(lod); + tensor_id.Resize({static_cast(data.size())}); + // malloc memory + int64_t* id_ptr = tensor_id.mutable_data(place); + for (size_t i = 0; i < data.size(); ++i) { + id_ptr[i] = static_cast(data.at(i)); + } + + // Scores + LoDTensor tensor_score; + tensor_score.set_lod(lod); + tensor_score.Resize({static_cast(data.size())}); + // malloc memory + float* score_ptr = tensor_score.mutable_data(place); + for (size_t i = 0; i < data.size(); ++i) { + score_ptr[i] = static_cast(data.at(i)); + } + + ids->push_back(tensor_id); + scores->push_back(tensor_score); +} + +} // namespace test +} // namespace paddle + +TEST(BeamSearchDecodeOp, DeleteBeamNode) { + auto* root = new BeamNode(0, 0); + auto* b1 = new BeamNode(1, 1); + auto* b2 = new BeamNode(2, 2); + auto* b3 = new BeamNode(3, 3); + + b1->AppendTo(root); + b2->AppendTo(root); + b3->AppendTo(b1); + + delete b3; + delete b2; +} + +TEST(BeamSearchDecodeOp, MakeSentence) { + auto* root = new BeamNode(0, 0); + auto* b1 = new BeamNode(1, 1); + auto* end = new BeamNode(2, 2); + b1->AppendTo(root); + end->AppendTo(b1); + + BeamSearchDecoder helper; + Sentence sentence = helper.MakeSentence(end); + delete end; + + std::vector expect_ids = {0, 1, 2}; + ASSERT_EQ(sentence.word_ids, expect_ids); + + std::vector expect_scores = {0, 1, 2}; + ASSERT_EQ(sentence.scores, expect_scores); +} + +TEST(BeamSearchDecodeOp, PackTwoStepsFistStep) { + CPUPlace place; + + LoDTensorArray ids; + LoDTensorArray scores; + + paddle::test::GenerateExample( + std::vector{0, 2, 6}, std::vector{0, 1, 2, 3, 4, 5, 6}, + std::vector{1, 2, 3, 4, 5, 6}, &ids, &scores); + + std::vector> beamnode_vector_list; + std::vector> sentence_vector_list( + 2, SentenceVector()); + + BeamSearchDecoder helper; + beamnode_vector_list = helper.PackTwoSteps( + ids[0], scores[0], beamnode_vector_list, &sentence_vector_list); + ASSERT_EQ(beamnode_vector_list.size(), 2UL); + ASSERT_EQ(beamnode_vector_list[0].size(), 2UL); + ASSERT_EQ(beamnode_vector_list[1].size(), 4UL); +} + +TEST(BeamSearchDecodeOp, PackTwoSteps) { + CPUPlace place; + + // first source has three prefix + BeamNodeVector source0_prefixes; + source0_prefixes.push_back( + std::unique_ptr>(new BeamNode(1, 1))); + source0_prefixes.push_back( + std::unique_ptr>(new BeamNode(0, 0))); + source0_prefixes.push_back( + std::unique_ptr>(new BeamNode(3, 3))); + + // second source has two prefix + BeamNodeVector source1_prefixes; + source1_prefixes.push_back( + std::unique_ptr>(new BeamNode(4, 4))); + source1_prefixes.push_back( + std::unique_ptr>(new BeamNode(5, 5))); + + std::vector> beamnode_vector_list; + std::vector> sentence_vector_list( + 2, SentenceVector()); + + beamnode_vector_list.push_back(std::move(source0_prefixes)); + beamnode_vector_list.push_back(std::move(source1_prefixes)); + + // generate data for one step + LoDTensorArray ids; + LoDTensorArray scores; + + paddle::test::GenerateExample(std::vector{0, 3, 5}, + std::vector{0, 1, 1, 3, 4, 5}, + std::vector{0, 1, 2, 3, 4}, &ids, &scores); + + BeamSearchDecoder helper1; + beamnode_vector_list = helper1.PackTwoSteps( + ids[0], scores[0], beamnode_vector_list, &sentence_vector_list); + + ASSERT_EQ(sentence_vector_list[0].size(), 1UL); + ASSERT_EQ(sentence_vector_list[1].size(), 0UL); + ASSERT_EQ(beamnode_vector_list[0].size(), 3UL); + ASSERT_EQ(beamnode_vector_list[1].size(), 2UL); +} + +TEST(BeamSearchDecodeOp, PackAllSteps) { + CPUPlace place; + + // we will constuct a sample data with 3 steps and 2 source sentences + LoDTensorArray ids; + LoDTensorArray scores; + + paddle::test::GenerateExample( + std::vector{0, 3, 6}, std::vector{0, 1, 2, 3, 4, 5, 6}, + std::vector{1, 2, 3, 4, 5, 6}, &ids, &scores); + paddle::test::GenerateExample( + std::vector{0, 3, 6}, std::vector{0, 1, 1, 3, 5, 5, 6}, + std::vector{0, 1, 2, 3, 4, 5}, &ids, &scores); + paddle::test::GenerateExample(std::vector{0, 3, 6}, + std::vector{0, 0, 1, 2, 3, 4, 5}, + std::vector{0, 1, 2, 3, 4}, &ids, &scores); + + ASSERT_EQ(ids.size(), 3UL); + ASSERT_EQ(scores.size(), 3UL); + + BeamSearchDecoder helper; + + LoDTensor id_tensor; + LoDTensor score_tensor; + helper.PackAllSteps(ids, scores, &id_tensor, &score_tensor); + + LoD lod = id_tensor.lod(); + std::vector expect_source_lod = {0, 4, 8}; + EXPECT_EQ(lod[0], expect_source_lod); + std::vector expect_sentence_lod = {0, 1, 3, 6, 9, 10, 13, 16, 19}; + EXPECT_EQ(lod[1], expect_sentence_lod); + // 2| 1, 0| 3, 1, 0| 3, 2, 1| 5| 4, 3, 2| 4, 4, 3| 6, 5, 4 + std::vector expect_data = {2, 1, 0, 3, 1, 0, 3, 2, 1, 5, + 4, 3, 2, 4, 4, 3, 6, 5, 4}; + ASSERT_EQ(id_tensor.dims()[0], static_cast(expect_data.size())); + for (size_t i = 0; i < expect_data.size(); ++i) { + ASSERT_EQ(id_tensor.data()[i], + static_cast(expect_data[i])); + } + for (int64_t i = 0; i < id_tensor.dims()[0]; ++i) { + ASSERT_EQ(score_tensor.data()[i], + static_cast(id_tensor.data()[i])); + } +} diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c65ba7eb262f3aabe2c00837b79806c0b40b60fd --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -0,0 +1,159 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class BilinearTensorProductOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input(X) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input(Y) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL, + "The input(Weight) must be a 3D tensor."); + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The first dimension(batch_size) of input(X) must be " + "equal to the first dimension of the input(Y)."); + PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], + "The second dimension of input(X) must be equal to " + "the second dimension of the input(Weight)."); + PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], + "The second dimension of input(Y) must be equal to " + "the third dimension of the input(Weight)."); + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE(bias_dims.size() == 2UL && bias_dims[0] == 1UL, + "The Input(Bias) must be a 2-D tensor with " + "the 2nd dimension fixed to 1 (a row vector)."); + PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], + "The second dimension of input(Bias) must be equal " + "to the first dimension of the input(Weight)."); + } + + ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]}); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BilinearTensorProductOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The first input of bilinear_tensor_product operator."); + AddInput("Y", "The second input of bilinear_tensor_product operator."); + AddInput("Weight", + "The learnable parameters of bilinear_tensor_product operator."); + AddInput("Bias", "The learnable bias of bilinear_tensor_product operator.") + .AsDispensable(); + AddOutput("Out", "The output of bilinear_tensor_product operator."); + AddComment(R"DOC( +Bilinear Tensor Product operator. +Given input X and Y, a 3D tensor weight, and bias. Each column of the +output is computed by one slice i = 1, . . . , k of the tensor: + + M = (X W_i) \cdot Y + Out_i = \sum_i {M_i} + Bias_i + +)DOC"); + } +}; + +class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(out_dims.size(), 2UL, + "The input(Out@GRAD) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ( + x_dims[0], out_dims[0], + "The first dimension(batch_size) of input(Out@GRAD) must be " + "equal to the first dimension of the Input(X)."); + PADDLE_ENFORCE_EQ( + weight_dims[0], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the third dimension of the Input(Weight)."); + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ( + bias_dims[1], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the second dimension of the Input(Bias)."); + auto bias_grad_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(bias_grad_name)) + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + auto weight_grad_name = framework::GradVarName("Weight"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + if (ctx->HasOutput(weight_grad_name)) { + ctx->SetOutputDim(weight_grad_name, weight_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp, + ops::BilinearTensorProductOpMaker, bilinear_tensor_product_grad, + ops::BilinearTensorProductOpGrad); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.cu b/paddle/operators/bilinear_tensor_product_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..858d2668d01379afe8082cd1eda32a2a5d09bd18 --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ffa4f43a327418498c1f110504127e7d2878409d --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -0,0 +1,184 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +class BilinearTensorProductKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto y_mat = EigenMatrix::From(*y); + auto output_mat = EigenMatrix::From(*out); + + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + auto place = ctx.GetEigenDevice(); + + // Create the intermediate variable to caculate the result of + // Input(X) multiplied by Input(Weight_i), the formula is: + // left_mul = X Weight_i. + Tensor left_mul; + left_mul.mutable_data(framework::make_ddim({batch_size, y_dim}), + ctx.GetPlace()); + auto left_mul_mat = EigenMatrix::From(left_mul); + + for (int i = 0; i < out_dim; ++i) { + auto output_col_vec = output_mat.chip(i, 1); + Tensor weight_mat = + weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim})); + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, y_dim, x_dim, 1, x->data(), + weight_mat.data(), 0, left_mul.data()); + output_col_vec.device(place) = + (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); + } + if (bias) { + auto bias_vec = EigenMatrix::From(*bias); + Eigen::DSizes bcast(batch_size, 1); + output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; + } + } +}; + +template +class BilinearTensorProductGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* x = ctx.Input("X"); + const Tensor* y = ctx.Input("Y"); + const Tensor* weight = ctx.Input("Weight"); + Tensor* d_x = ctx.Output(framework::GradVarName("X")); + Tensor* d_y = ctx.Output(framework::GradVarName("Y")); + Tensor* d_weight = ctx.Output(framework::GradVarName("Weight")); + Tensor* d_bias = ctx.Output(framework::GradVarName("Bias")); + const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); + + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + + auto x_mat = EigenMatrix::From(*x); + auto y_mat = EigenMatrix::From(*y); + auto d_out_mat = EigenMatrix::From(*d_out); + auto place = ctx.GetEigenDevice(); + + // Create the intermediate variable to caculate the Output(Y@Grad). + Tensor x_scale; + x_scale.mutable_data(framework::make_ddim({batch_size, x_dim}), + ctx.GetPlace()); + auto x_scale_mat = EigenMatrix::From(x_scale); + + // Create the intermediate variable to caculate the Output(X@Grad). + Tensor y_scale; + y_scale.mutable_data(framework::make_ddim({batch_size, y_dim}), + ctx.GetPlace()); + auto y_scale_mat = EigenMatrix::From(y_scale); + + math::SetConstant set_zero; + + // Set Output(X@Grad) be zero. + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + set_zero(ctx.device_context(), d_x, static_cast(0)); + } + + // Set Output(Y@Grad) be zero. + if (d_y) { + d_y->mutable_data(ctx.GetPlace()); + set_zero(ctx.device_context(), d_y, static_cast(0)); + } + + // Caculate the Output(X@Grad) and Output(Y@Grad). + if (d_x || d_y) { + Eigen::DSizes bcast_for_x(1, y_dim); + Eigen::DSizes bcast_for_y(1, x_dim); + for (int i = 0; i < out_dim; ++i) { + Tensor weight_i = weight->Slice(i, i + 1).Resize( + framework::make_ddim({x_dim, y_dim})); + auto output_vec = d_out_mat.chip(i, 1); + if (d_x) { + y_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_x) * + y_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, + batch_size, x_dim, y_dim, 1, y_scale.data(), + weight_i.data(), 1, d_x->data()); + } + if (d_y) { + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_y) * + x_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, y_dim, x_dim, 1, x_scale.data(), + weight_i.data(), 1, d_y->data()); + } + } + } + + // Caculate the gradient of Input(Weight). + if (d_weight) { + d_weight->mutable_data(ctx.GetPlace()); + Eigen::DSizes bcast_for_weight(1, x_dim); + for (int i = 0; i < out_dim; ++i) { + Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( + framework::make_ddim({x_dim, y_dim})); + auto output_vec = d_out_mat.chip(i, 1); + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_weight) * + x_mat; + math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, + x_dim, y_dim, batch_size, 1, x_scale.data(), + y->data(), 0, d_weight_i.data()); + } + } + + // Caculate the gradient of Input(Bias). + if (d_bias) { + d_bias->mutable_data(ctx.GetPlace()); + auto d_bias_mat = EigenMatrix::From(*d_bias); + d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes(0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index 5ce30740c90b5cd0bd4f8ab183cf985ed5d827c1..4f565946d596b5e5fbf90f16c0c13c780c36886c 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -74,11 +74,10 @@ Tensor CombineBatchAndN(const framework::ExecutionContext& context, Tensor output; auto in_dims = input.dims(); if (in_dims.size() == 3) { - output.Resize(in_dims); + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); output.mutable_data(context.GetPlace()); EigenTranspose(context, input, output, {1, 0, 2}); - std::vector out_dims = {in_dims[1], in_dims[0] * in_dims[2]}; - output.Resize(make_ddim(out_dims)); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); } else { output.ShareDataWith(input); } diff --git a/paddle/operators/merge_lod_tensor_op.cc b/paddle/operators/merge_lod_tensor_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..80460c476921b63ec5228a9780880c7db3c85217 --- /dev/null +++ b/paddle/operators/merge_lod_tensor_op.cc @@ -0,0 +1,182 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" + +namespace paddle { +namespace operators { + +using LoD = framework::LoD; + +class MergeLoDTensorOp : public framework::OperatorBase { + public: + MergeLoDTensorOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto &x = scope.FindVar(Input("X"))->Get(); + auto &mask = scope.FindVar(Input("Mask"))->Get(); + auto &in_true = scope.FindVar(Input("InTrue"))->Get(); + auto &in_false = + scope.FindVar(Input("InFalse"))->Get(); + auto *out = + scope.FindVar(Output("Out"))->GetMutable(); + auto level = static_cast(Attr("level")); + + auto &mask_dim = mask.dims(); + + std::unique_ptr cpu_mask{new framework::LoDTensor()}; + if (platform::is_cpu_place(mask.place())) { + cpu_mask->ShareDataWith(mask); + } else if (platform::is_gpu_place(mask.place())) { +#ifdef PADDLE_WITH_CUDA + cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx); +#else + PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); +#endif + } + auto *mask_data = cpu_mask->data(); + + int rank = in_true.dims().size(); + platform::Place place = in_true.place(); + std::type_index data_type = in_true.type(); + framework::DDim in_true_dims = + framework::slice_ddim(in_true.dims(), 1, rank); + + int64_t batch_size = in_true.dims()[0] + in_false.dims()[0]; + + auto in_true_dim_vec = framework::vectorize(in_true_dims); + in_true_dim_vec.insert(in_true_dim_vec.begin(), batch_size); + + framework::DDim out_dims = framework::make_ddim(in_true_dim_vec); + out->Resize(out_dims); + out->mutable_data(place, data_type); + + auto *out_lod = out->mutable_lod(); + out_lod->clear(); + size_t out_offset = 0; + + // Build LoDTensor `out` + + size_t in_true_idx = 0; + size_t in_false_idx = 0; + for (size_t i = 0; i < static_cast(mask_dim[0]); i++) { + const framework::LoDTensor *input = nullptr; + size_t *in_idx = nullptr; + if (static_cast(mask_data[i]) == 0) { + input = &in_false; + in_idx = &in_false_idx; + } else { + input = &in_true; + in_idx = &in_true_idx; + } + auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset( + input->lod(), *in_idx, (*in_idx) + 1, 0); + auto &lod_length = lod_and_offset.first; + + framework::AppendLoD(out_lod, lod_length); + + size_t start_offset = lod_and_offset.second.first; + size_t end_offset = lod_and_offset.second.second; + + PADDLE_ENFORCE_GE(end_offset, start_offset); + size_t len = end_offset - start_offset; + if (len == 0) { + continue; + } + out->Slice(out_offset, out_offset + len) + .CopyFrom(input->Slice(start_offset, end_offset), place, dev_ctx); + out_offset += len; + (*in_idx) += 1; + } + + for (size_t i = 0; i < level; i++) { + out_lod->insert(out_lod->begin(), x.lod()[i]); + } + } +}; + +class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + MergeLoDTensorOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input LoDTensor, contains complete lod information to " + "construct the output"); + AddInput("Mask", "A bool column vector which mask the input"); + AddInput("InTrue", "The True branch to be merged"); + AddInput("InFalse", "The False branch to be merged"); + AddOutput("Out", "The merged output LoDTensor"); + AddAttr("level", "(int) the specific lod level to rank.") + .SetDefault(0) + .EqualGreaterThan(0); + AddComment( + R"DOC( + Merge True and False branches of LoDTensor into a single Output, + with a mask at certain lod level. X is used to obtain complete + lod information. Please refer to SplitLoDTensorOp.)DOC"); + } +}; + +class MergeLoDTensorInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE(context->HasInput("X"), + "MergeLoDTensorOp must has input X."); + PADDLE_ENFORCE(context->HasInput("Mask"), + "MergeLoDTensorOp must has input Mask."); + PADDLE_ENFORCE(context->HasInput("InTrue"), + "MergeLoDTensorOp must has input InTrue."); + PADDLE_ENFORCE(context->HasInput("InFalse"), + "MergeLoDTensorOp must has input InFalse."); + PADDLE_ENFORCE(context->HasOutput("Out"), + "MergeLoDTensorOp must has output Out"); + + auto mask_dim = context->GetInputDim("Mask"); + PADDLE_ENFORCE_EQ(mask_dim.size(), 2); + PADDLE_ENFORCE_EQ(mask_dim[1], 1); + + context->SetOutputDim("Out", context->GetInputDim("InTrue")); + } +}; + +class MergeLoDTensorGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *grad_op = new framework::OpDescBind(); + grad_op->SetType("split_lod_tensor"); + grad_op->SetInput("X", OutputGrad("Out")); + grad_op->SetInput("Mask", Input("Mask")); + grad_op->SetOutput("OutTrue", InputGrad("InTrue")); + grad_op->SetOutput("OutFalse", InputGrad("InFalse")); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(merge_lod_tensor, ops::MergeLoDTensorOp, + ops::MergeLoDTensorOpProtoMaker, + ops::MergeLoDTensorInferShape, ops::MergeLoDTensorGradMaker); diff --git a/paddle/operators/sequence_concat_op.cc b/paddle/operators/sequence_concat_op.cc index db737bed7a4d2dc5b60cbc6ac172caec95acd35e..d1de0b444712a8c304c33bd194e306dfe3c41f02 100644 --- a/paddle/operators/sequence_concat_op.cc +++ b/paddle/operators/sequence_concat_op.cc @@ -47,7 +47,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "(vector) Input is a vector of LoDTensor, " + "(LodTensorArray) Input is a vector of LoDTensor, " "each of which is a variable-length sequence or nested sequence.") .AsDuplicable(); AddOutput("Out", diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 2b8a25c2414c20efaffedfc8603697b3a104634f..7f136d8cf0e1eaae7b4de32988b60ae8a5034cc6 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -126,6 +126,7 @@ class SequencePoolGradKernel : public framework::OpKernel { int64_t h = static_cast(lod[i + 1] - lod[i]); auto in_g_e = EigenMatrix::From(in_g_t, {h, w}); auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); + auto out_g_e_v = EigenVector::Flatten(out_g_t); Eigen::DSizes bcast(h, 1); if (pooltype == "AVERAGE") { @@ -136,9 +137,9 @@ class SequencePoolGradKernel : public framework::OpKernel { in_g_e.device(place) = (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); } else if (pooltype == "LAST") { - in_g_e.chip(h - 1, 0).device(place) = out_g_e; + in_g_e.chip(h - 1, 0).device(place) = out_g_e_v; } else if (pooltype == "FIRST") { - in_g_e.chip(0, 0).device(place) = out_g_e; + in_g_e.chip(0, 0).device(place) = out_g_e_v; } else { PADDLE_THROW("unsupported pooling pooltype"); } diff --git a/paddle/operators/split_lod_tensor_op.cc b/paddle/operators/split_lod_tensor_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..db635f2ba0804143c9a2e04ff006dfbc8744f3fc --- /dev/null +++ b/paddle/operators/split_lod_tensor_op.cc @@ -0,0 +1,186 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" + +namespace paddle { +namespace operators { + +struct CopyRange { + size_t begin; + size_t end; +}; + +using LoD = framework::LoD; + +class SplitLoDTensorOp : public framework::OperatorBase { + public: + SplitLoDTensorOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto &x = scope.FindVar(Input("X"))->Get(); + auto &mask = scope.FindVar(Input("Mask"))->Get(); + auto *out_true = + scope.FindVar(Output("OutTrue"))->GetMutable(); + auto *out_false = + scope.FindVar(Output("OutFalse"))->GetMutable(); + auto level = static_cast(Attr("level")); + auto &x_lod = x.lod(); + auto &mask_dim = mask.dims(); + + std::unique_ptr cpu_mask{new framework::LoDTensor()}; + if (platform::is_cpu_place(mask.place())) { + cpu_mask->ShareDataWith(mask); + } else if (platform::is_gpu_place(mask.place())) { +#ifdef PADDLE_WITH_CUDA + cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx); +#else + PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); +#endif + } + auto *mask_data = cpu_mask->data(); + + std::vector> copy_ranges(mask_dim[0]); + + // set out_true/out_false lod + for (size_t t = 0; t < 2; t++) { + LoD *lod = nullptr; + if (t == 0) { + lod = out_false->mutable_lod(); + } else { + lod = out_true->mutable_lod(); + } + lod->clear(); + for (size_t i = 0; i < static_cast(mask_dim[0]); i++) { + if (static_cast(mask_data[i]) == t) { + size_t start_idx = i; + auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset( + x_lod, start_idx, start_idx + 1, level); + + auto &lod_length = lod_and_offset.first; + framework::AppendLoD(lod, lod_length); + + size_t start_offset = lod_and_offset.second.first; + size_t end_offset = lod_and_offset.second.second; + copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset}); + } + } + } + + for (size_t t = 0; t < 2; ++t) { + framework::LoDTensor *out; + if (t == 0) { + out = out_false; + } else { + out = out_true; + } + auto &ranges = copy_ranges[t]; + size_t height = std::accumulate( + ranges.begin(), ranges.end(), 0UL, + [](size_t a, const CopyRange &b) { return a + b.end - b.begin; }); + auto x_dim = x.dims(); + x_dim[0] = static_cast(height); + out->Resize(x_dim); + out->mutable_data(x.place(), x.type()); + size_t offset = 0; + for (auto &each_range : ranges) { + size_t len = each_range.end - each_range.begin; + if (len == 0) { + continue; + } + // out[offset: offset+len] = x[each_range.begin: each_range.end] + out->Slice(static_cast(offset), static_cast(offset + len)) + .CopyFrom(x.Slice(static_cast(each_range.begin), + static_cast(each_range.end)), + x.place(), dev_ctx); + offset += len; + } + } + } +}; + +class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + SplitLoDTensorOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input LoDTensor"); + AddInput("Mask", "A bool column vector which mask the input"); + AddOutput("OutTrue", "True branch of input LoDTensor"); + AddOutput("OutFalse", "False branch of input LoDTensor"); + AddAttr("level", "(int) the specific lod level to split.") + .SetDefault(0) + .EqualGreaterThan(0); + AddComment( + R"DOC( + Split a LoDTensor with a Mask at certain level. The input LoDTensor + has 3 sequence at certain lod level. The Mask is a bool column vector, + such as [0, 1, 0] at the same level. The first and third sequence will + be send to False Output LoDTensor; whereas the second sequence will + be send to True Output LoDTensor. Please refer to MergeLoDTensorOp.)DOC"); + } +}; + +class SplitLoDTensorInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE(context->HasInput("X"), + "SplitLoDTensorOp must has input X."); + PADDLE_ENFORCE(context->HasInput("Mask"), + "SplitLoDTensorOp must has input Mask."); + PADDLE_ENFORCE(context->HasOutput("OutTrue"), + "SplitLoDTensorOp must has output OutTrue."); + PADDLE_ENFORCE(context->HasOutput("OutFalse"), + "SplitLoDTensorOp must has output OutFalse."); + + auto mask_dim = context->GetInputDim("Mask"); + PADDLE_ENFORCE_EQ(mask_dim.size(), 2); + PADDLE_ENFORCE_EQ(mask_dim[1], 1); + + context->SetOutputDim("OutTrue", context->GetInputDim("X")); + context->SetOutputDim("OutFalse", context->GetInputDim("X")); + } +}; + +class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *grad_op = new framework::OpDescBind(); + grad_op->SetType("merge_lod_tensor"); + grad_op->SetInput("InTrue", OutputGrad("OutTrue")); + grad_op->SetInput("InFalse", OutputGrad("OutFalse")); + grad_op->SetInput("Mask", Input("Mask")); + grad_op->SetInput("X", Input("X")); + grad_op->SetOutput("Out", InputGrad("X")); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(split_lod_tensor, ops::SplitLoDTensorOp, + ops::SplitLoDTensorOpProtoMaker, + ops::SplitLoDTensorInferShape, + ops::SplitLoDTensorArrayGradMaker); diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index dab8a1474f6fc28dd61e8ff3a8399986ef8ba6fd..a2219465b7a1cdb974eeff13fc5fd801f39200bb 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -1,15 +1,17 @@ import paddle.v2.framework.core as core +import paddle.v2.framework.proto.framework_pb2 as framework_pb2 from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \ Operator from paddle.v2.framework.initializer import ConstantInitializer, \ NormalInitializer from paddle.v2.framework.layer_helper import LayerHelper, unique_name import re +import cStringIO __all__ = [ 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', 'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums', 'cos_sim', - 'batch_norm', 'accuracy' + 'batch_norm', 'accuracy', 'split_lod_tensor' ] @@ -240,6 +242,58 @@ def _convert_(name): return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() +def _generate_doc_string_(op_proto): + """ + Generate docstring by OpProto + + Args: + op_proto (framework_pb2.OpProto): a protobuf message typed OpProto + + Returns: + str: the document string + """ + + def _type_to_str_(tp): + return framework_pb2.AttrType.Name(tp) + + if not isinstance(op_proto, framework_pb2.OpProto): + raise TypeError("OpProto should be `framework_pb2.OpProto`") + + buf = cStringIO.StringIO() + buf.write(op_proto.comment) + buf.write('\nArgs:\n') + for each_input in op_proto.inputs: + line_begin = ' {0}: '.format(_convert_(each_input.name)) + buf.write(line_begin) + buf.write(each_input.comment) + buf.write('\n') + buf.write(' ' * len(line_begin)) + buf.write('Duplicable: ') + buf.write(str(each_input.duplicable)) + buf.write(' Optional: ') + buf.write(str(each_input.dispensable)) + buf.write('\n') + + for each_attr in op_proto.attrs: + buf.write(' ') + buf.write(each_attr.name) + buf.write(' (') + buf.write(_type_to_str_(each_attr.type)) + buf.write('): ') + buf.write(each_attr.comment) + buf.write('\n') + + if len(op_proto.outputs) != 0: + buf.write('\nReturns:\n') + buf.write(' ') + for each_opt in op_proto.outputs: + if not each_opt.intermediate: + break + buf.write(each_opt.comment) + + return buf.getvalue() + + def _create_op_func_(op_type): """ Create an Operator for a Function. @@ -298,11 +352,6 @@ def _create_op_func_(op_type): return dtype def func(**kwargs): - """ - This function implements the function for the operator. This process - involves doing the sanity check (using the function above), reading - inputs from protobuf and applying the activations on top. - """ helper = LayerHelper(op_type, **kwargs) dtype = infer_and_check_data_type(op_proto, **kwargs) @@ -326,6 +375,7 @@ def _create_op_func_(op_type): func.__name__ = op_type globals()[op_type] = func + func.__doc__ = _generate_doc_string_(op_proto) global __all__ __all__.append(op_type) @@ -401,6 +451,46 @@ def sums(input, main_program=None, startup_program=None): return out +def split_lod_tensor(input, + mask, + level, + main_program=None, + startup_program=None): + helper = LayerHelper('split_lod_tensor', **locals()) + out_true = helper.create_tmp_variable(dtype=input.data_type) + out_false = helper.create_tmp_variable(dtype=input.data_type) + helper.append_op( + type='split_lod_tensor', + inputs={ + 'X': input, + 'Mask': mask, + }, + outputs={'OutTrue': out_true, + 'OutFalse': out_false}, + attrs={'level': level}) + return out_true, out_false + + +def merge_lod_tensor(in_true, + in_false, + x, + mask, + level, + main_program=None, + startup_program=None): + helper = LayerHelper('merge_lod_tensor', **locals()) + out = helper.create_tmp_variable(dtype=x.data_type) + helper.append_op( + type='merge_lod_tensor', + inputs={'X': x, + 'Mask': mask, + 'InTrue': in_true, + 'InFalse': in_false}, + outputs={'Out': out}, + attrs={'level': level}) + return out + + def cos_sim(X, Y, **kwargs): """ This function performs the cosine similarity between two tensors diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 4d7664469e481344cf9eea84688f068b4fb99dee..e795627bfe9e8ad0c196349a332e62e975f20aa3 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -3,3 +3,5 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() + +add_subdirectory(book) diff --git a/python/paddle/v2/framework/tests/book/CMakeLists.txt b/python/paddle/v2/framework/tests/book/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4d7664469e481344cf9eea84688f068b4fb99dee --- /dev/null +++ b/python/paddle/v2/framework/tests/book/CMakeLists.txt @@ -0,0 +1,5 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) +endforeach() diff --git a/python/paddle/v2/framework/tests/test_fit_a_line.py b/python/paddle/v2/framework/tests/book/test_fit_a_line.py similarity index 100% rename from python/paddle/v2/framework/tests/test_fit_a_line.py rename to python/paddle/v2/framework/tests/book/test_fit_a_line.py diff --git a/python/paddle/v2/framework/tests/test_image_classification_train.py b/python/paddle/v2/framework/tests/book/test_image_classification_train.py similarity index 100% rename from python/paddle/v2/framework/tests/test_image_classification_train.py rename to python/paddle/v2/framework/tests/book/test_image_classification_train.py diff --git a/python/paddle/v2/framework/tests/test_recognize_digits_conv.py b/python/paddle/v2/framework/tests/book/test_recognize_digits_conv.py similarity index 100% rename from python/paddle/v2/framework/tests/test_recognize_digits_conv.py rename to python/paddle/v2/framework/tests/book/test_recognize_digits_conv.py diff --git a/python/paddle/v2/framework/tests/test_recognize_digits_mlp.py b/python/paddle/v2/framework/tests/book/test_recognize_digits_mlp.py similarity index 100% rename from python/paddle/v2/framework/tests/test_recognize_digits_mlp.py rename to python/paddle/v2/framework/tests/book/test_recognize_digits_mlp.py diff --git a/python/paddle/v2/framework/tests/test_recommender_system.py b/python/paddle/v2/framework/tests/book/test_recommender_system.py similarity index 100% rename from python/paddle/v2/framework/tests/test_recommender_system.py rename to python/paddle/v2/framework/tests/book/test_recommender_system.py diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_conv.py b/python/paddle/v2/framework/tests/book/test_understand_sentiment_conv.py similarity index 100% rename from python/paddle/v2/framework/tests/test_understand_sentiment_conv.py rename to python/paddle/v2/framework/tests/book/test_understand_sentiment_conv.py diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_dynamic_lstm.py b/python/paddle/v2/framework/tests/book/test_understand_sentiment_dynamic_lstm.py similarity index 100% rename from python/paddle/v2/framework/tests/test_understand_sentiment_dynamic_lstm.py rename to python/paddle/v2/framework/tests/book/test_understand_sentiment_dynamic_lstm.py diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py b/python/paddle/v2/framework/tests/book/test_understand_sentiment_lstm.py similarity index 100% rename from python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py rename to python/paddle/v2/framework/tests/book/test_understand_sentiment_lstm.py diff --git a/python/paddle/v2/framework/tests/test_word2vec.py b/python/paddle/v2/framework/tests/book/test_word2vec.py similarity index 100% rename from python/paddle/v2/framework/tests/test_word2vec.py rename to python/paddle/v2/framework/tests/book/test_word2vec.py diff --git a/python/paddle/v2/framework/tests/test_assign_op.py b/python/paddle/v2/framework/tests/test_assign_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0c145f1a69678b228bc70e4e4e273f5bcf9888 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_assign_op.py @@ -0,0 +1,21 @@ +import op_test +import numpy +import unittest + + +class TestAssignOp(op_test.OpTest): + def setUp(self): + self.op_type = "assign" + x = numpy.random.random(size=(100, 10)) + self.inputs = {'X': x} + self.outputs = {'Out': x} + + def test_forward(self): + self.check_output() + + def test_backward(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py new file mode 100644 index 0000000000000000000000000000000000000000..080ca43b8269e0f6a9f4d0ce3973f4d4a07a8e2a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py @@ -0,0 +1,37 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestBilinearTensorProductOp(OpTest): + def setUp(self): + self.op_type = "bilinear_tensor_product" + batch_size = 6 + size0 = 3 + size1 = 4 + size2 = 5 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + bias = np.random.random((1, size2)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) + self.inputs = { + 'X': a, + 'Y': b, + 'Weight': w, + 'Bias': bias, + } + self.outputs = {'Out': output + bias} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_create_op_doc_string.py b/python/paddle/v2/framework/tests/test_create_op_doc_string.py new file mode 100644 index 0000000000000000000000000000000000000000..d21e96df2a64d3fa418dca94690ea0b820db80de --- /dev/null +++ b/python/paddle/v2/framework/tests/test_create_op_doc_string.py @@ -0,0 +1,11 @@ +import unittest +import paddle.v2.framework.layers as layers + + +class TestDocString(unittest.TestCase): + def test_layer_doc_string(self): + print layers.dropout.__doc__ + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_split_and_merge_lod_tensor_op.py b/python/paddle/v2/framework/tests/test_split_and_merge_lod_tensor_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba1e568249d4a72820cc26193a8e0e030ae5f7c --- /dev/null +++ b/python/paddle/v2/framework/tests/test_split_and_merge_lod_tensor_op.py @@ -0,0 +1,181 @@ +import unittest +import paddle.v2.framework.core as core +import numpy as np +import paddle.v2.framework.layers as layers +from paddle.v2.framework.framework import Program +from paddle.v2.framework.executor import Executor +from paddle.v2.framework.backward import append_backward_ops + + +class TestCPULoDTensorArrayOps(unittest.TestCase): + def place(self): + return core.CPUPlace() + + def test_split_and_merge_lod_tensor_no_lod(self): + tensor = core.LoDTensor() + tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place()) + + mask_np = np.array([0, 0, 1, 1, 1, 1, 0, 0, 0, 0]).astype('bool') + mask_np = np.expand_dims(mask_np, axis=1) + + mask = core.LoDTensor() + mask.set(mask_np, self.place()) + + expect_true_tensor = np.array([2, 3, 4, 5]).astype('int32') + expect_true_tensor = np.expand_dims(expect_true_tensor, axis=1) + expect_true = core.LoDTensor() + expect_true.set(expect_true_tensor, self.place()) + + expect_false_tensor = np.array([0, 1, 6, 7, 8, 9]).astype('int32') + expect_false_tensor = np.expand_dims(expect_false_tensor, axis=1) + + expect_false = core.LoDTensor() + expect_false.set(expect_false_tensor, self.place()) + + self.main( + tensor=tensor, + mask=mask, + expect_true=expect_true, + expect_false=expect_false, + expect_out=tensor) + + def test_split_and_merge_lod_tensor_level_0(self): + tensor = core.LoDTensor() + tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place()) + tensor.set_lod([[0, 3, 9, 10]]) + + mask_np = np.array([0, 1, 0]).astype('bool') + mask_np = np.expand_dims(mask_np, axis=1) + + mask = core.LoDTensor() + mask.set(mask_np, self.place()) + + expect_true_tensor = np.array([3, 4, 5, 6, 7, 8]).astype('int32') + expect_true_tensor = np.expand_dims(expect_true_tensor, axis=1) + expect_true = core.LoDTensor() + expect_true.set(expect_true_tensor, self.place()) + expect_true.set_lod([[0, 6]]) + + expect_false_tensor = np.array([0, 1, 2, 9]).astype('int32') + expect_false_tensor = np.expand_dims(expect_false_tensor, axis=1) + expect_false_lod = [[0, 3, 4]] + + expect_false = core.LoDTensor() + expect_false.set(expect_false_tensor, self.place()) + expect_false.set_lod(expect_false_lod) + + self.main( + tensor=tensor, + mask=mask, + expect_true=expect_true, + expect_false=expect_false, + expect_out=tensor) + + def main(self, tensor, mask, expect_true, expect_false, expect_out, + level=0): + place = self.place() + program = Program() + x = layers.data(name='x', shape=[1], main_program=program) + x.persistable = True + + y = layers.data(name='y', shape=[1], main_program=program) + y.persistable = True + + out_true, out_false = layers.split_lod_tensor( + input=x, mask=y, level=level, main_program=program) + out_true.persistable = True + out_false.persistable = True + + out = layers.merge_lod_tensor( + in_true=out_true, + in_false=out_false, + mask=y, + x=x, + level=level, + main_program=program) + + out.persistable = True + + exe = Executor(place) + scope = core.Scope() + exe.run(program, feed={'x': tensor, 'y': mask}, scope=scope) + + var_true = scope.find_var(out_true.name).get_tensor() + + var_false = scope.find_var(out_false.name).get_tensor() + + var_out = scope.find_var(out.name).get_tensor() + + self.check_tensor_same(var_true, expect_true) + self.check_tensor_same(var_false, expect_false) + self.check_tensor_same(var_out, expect_out) + + def check_tensor_same(self, actual, expect): + self.assertTrue(np.allclose(np.array(actual), np.array(expect))) + self.assertEqual(actual.lod(), expect.lod()) + + +class TestCPUSplitMergeLoDTensorGrad(unittest.TestCase): + def test_grad(self): + place = core.CPUPlace() + program = Program() + + x = layers.data( + name='x', + shape=[1], + data_type='float32', + main_program=program, + stop_gradient=False) + y = layers.data( + name='y', + shape=[1], + data_type='bool', + main_program=program, + stop_gradient=False) + + level = 0 + + out_true, out_false = layers.split_lod_tensor( + input=x, mask=y, level=level, main_program=program) + out = layers.merge_lod_tensor( + in_true=out_true, + in_false=out_false, + mask=y, + x=x, + level=level, + main_program=program) + mean = layers.mean(x=out, main_program=program) + + append_backward_ops(mean) + + tensor = core.LoDTensor() + tensor.set(np.arange(10).reshape(10, 1).astype('float32'), place) + tensor.set_lod([[0, 3, 9, 10]]) + + mask_np = np.array([0, 1, 0]).astype('bool') + mask_np = np.expand_dims(mask_np, axis=1) + + mask = core.LoDTensor() + mask.set(mask_np, place) + + exe = Executor(place) + scope = core.Scope() + + g_vars = program.global_block().var(x.name + "@GRAD") + g_out = [ + item.sum() + for item in map(np.array, + exe.run(program, + feed={'x': tensor, + 'y': mask}, + fetch_list=[g_vars], + scope=scope)) + ] + + g_out_sum = np.array(g_out).sum() + + self.assertAlmostEqual(1.0, g_out_sum, delta=0.1) + + +if __name__ == '__main__': + unittest.main()