diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 534be0abe246ac70950d85ad05441825c8ca768a..41b9b5928958ae31799c396a8d77fd7cff557905 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -187,7 +187,13 @@ function(cc_library TARGET_NAME) endif() # cpplint code style - add_style_check_target(${TARGET_NAME} ${cc_library_SRCS}) + foreach(source_file ${cc_library_SRCS}) + string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file}) + if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + endif() + endforeach() + add_style_check_target(${TARGET_NAME} ${cc_library_SRCS} ${cc_library_HEADERS}) else(cc_library_SRCS) if (cc_library_DEPS) @@ -239,6 +245,14 @@ function(nv_library TARGET_NAME) add_dependencies(${TARGET_NAME} ${nv_library_DEPS}) target_link_libraries(${TARGET_NAME} ${nv_library_DEPS}) endif() + # cpplint code style + foreach(source_file ${nv_library_SRCS}) + string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file}) + if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + list(APPEND nv_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + endif() + endforeach() + add_style_check_target(${TARGET_NAME} ${nv_library_SRCS} ${nv_library_HEADERS}) else(nv_library_SRCS) if (nv_library_DEPS) merge_static_libs(${TARGET_NAME} ${nv_library_DEPS}) diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 9fcc657edcd5459d0a42a64d708603a4bcd53cf0..5aa5af0c19be5a209c760282cb1a090fc57a53ad 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -25,18 +25,15 @@ limitations under the License. */ namespace paddle { namespace framework { -namespace { -typedef boost::variant, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>, - Dim<8>, Dim<9>> - DDimVar; -} - /** * \brief A dynamically sized dimension. * * The number of dimensions must be between [1, 9].
*/ struct DDim { + typedef boost::variant, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>, + Dim<8>, Dim<9>> + DDimVar; DDimVar var; DDim() : var(Dim<1>()) {} diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 3aefbb3fffb064c6fa9f41d7296088a24758af80..6d032fb78f099f5142d64e531d1a03c10ed5e68e 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -26,7 +26,7 @@ using VarIndexMap = std::unordered_map; enum class OpArgType { IN, OUT }; static std::vector* GetOpFormat(OperatorBase* op, const OpArgType& type) { - std::string key = type == OpArgType::IN ? "input_format" : "output_name"; + std::string key = type == OpArgType::IN ? "input_format" : "output_format"; return op->attrs_.count(key) ? &boost::get>(op->attrs_.at(key)) : nullptr; @@ -34,7 +34,7 @@ static std::vector* GetOpFormat(OperatorBase* op, const OpArgType& type) { static const std::vector* GetOpFormat(const OperatorBase* op, const OpArgType& type) { - std::string key = type == OpArgType::IN ? "input_format" : "output_name"; + std::string key = type == OpArgType::IN ? "input_format" : "output_format"; return op->attrs_.count(key) ? &boost::get>(op->attrs_.at(key)) : nullptr; @@ -82,7 +82,7 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { grad_op->attrs_ = op->attrs_; grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("output_format"); - if (GetOpFormat(op, OpArgType::OUT) != nullptr) { + if (GetOpFormat(op, OpArgType::IN) != nullptr) { grad_op->attrs_["output_format"] = std::vector({0}); } if (GetOpFormat(op, OpArgType::IN) != nullptr || diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h index cf235de6c267a4a1feb7afd3e4dbe7a6a668ee5e..998f8ebbb5f2f4fb8b7e938b5916afd0f8a7930d 100644 --- a/paddle/framework/grad_op_builder.h +++ b/paddle/framework/grad_op_builder.h @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include "paddle/framework/operator.h" diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index e9cf3b9798db2cbfb8d26259ae9a6741fbae8278..96d7f309d67b15c000ab8ce3769931322fbca880 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -8,10 +8,49 @@ USE_OP(add_two); namespace paddle { namespace framework { +class NOP : public OperatorBase { + public: + void InferShape(const Scope &scope) const override {} + void Run(const Scope &scope, + const platform::DeviceContext &dev_ctx) const override {} +}; + +class MutiInOutOpMaker : public OpProtoAndCheckerMaker { + public: + MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("In1", "a single input"); + AddInput("In2_mult", "a multiple input").SetMultiple(); + AddInput("In3", "another single input"); + AddOutput("Out1", "a single output"); + AddOutput("Out2_mult", "a multiple output").SetMultiple(); + AddComment("test op with multiple inputs and outputs"); + } +}; + +class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { + public: + IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("In1", "a single input"); + AddInput("In2_mult", "a multiple input").SetMultiple().IgnoreGradient(); + AddInput("In3_mult", "another multiple input").SetMultiple(); + 
AddOutput("Out1_mult", "a multiple output").SetMultiple(); + AddOutput("Out2", "a single output").IgnoreGradient(); + AddComment("op with inputs and outputs ignored in gradient calculating"); + } +}; + +} // namespace framework +} // namespace paddle + +namespace f = paddle::framework; + TEST(GradOpBuilder, AddTwo) { - std::shared_ptr add_op( - OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); - std::shared_ptr grad_add_op = OpRegistry::CreateGradOp(*add_op); + std::shared_ptr add_op( + f::OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); + std::shared_ptr grad_add_op = + f::OpRegistry::CreateGradOp(*add_op); EXPECT_EQ(static_cast(grad_add_op->inputs_.size()), 4); EXPECT_EQ(static_cast(grad_add_op->outputs_.size()), 2); EXPECT_EQ(grad_add_op->Input("X"), "x"); @@ -22,5 +61,85 @@ TEST(GradOpBuilder, AddTwo) { EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD"); } -} // namespace framework -} // namespace paddle \ No newline at end of file +REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker); +REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP); +REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker); +REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP); + +TEST(GradOpBuilder, MutiInOut) { + f::AttributeMap attrs{{"input_format", std::vector{0, 1, 4, 5}}, + {"output_format", std::vector{0, 1, 3}}}; + std::shared_ptr test_op(f::OpRegistry::CreateOp( + "mult_io", {"in1", "in2_1", "in2_2", "in2_3", "in3"}, + {"out1", "out2_1", "out2_2"}, attrs)); + std::shared_ptr grad_test_op = + f::OpRegistry::CreateGradOp(*test_op); + + ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL); + EXPECT_EQ(grad_test_op->Input("In1"), "in1"); + EXPECT_EQ(grad_test_op->Inputs("In2_mult"), + std::vector({"in2_1", "in2_2", "in2_3"})); + EXPECT_EQ(grad_test_op->Input("In3"), "in3"); + EXPECT_EQ(grad_test_op->Input("Out1"), "out1"); + EXPECT_EQ(grad_test_op->Inputs("Out2_mult"), + std::vector({"out2_1", "out2_2"})); + EXPECT_EQ(grad_test_op->Input("Out1" + 
f::OperatorBase::GRAD_VAR_SUFFIX()), + "out1" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ( + grad_test_op->Inputs("Out2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector( + {"out2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "out2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()})); + + ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); + EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "in1" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ( + grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "in2_3" + f::OperatorBase::GRAD_VAR_SUFFIX()})); + EXPECT_EQ(grad_test_op->Output("In3" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "in3" + f::OperatorBase::GRAD_VAR_SUFFIX()); +} + +TEST(GradOpBuilder, IOIgnoredInGradient) { + f::AttributeMap attrs{{"input_format", std::vector{0, 1, 3, 5}}, + {"output_format", std::vector{0, 2, 3}}}; + std::shared_ptr test_op(f::OpRegistry::CreateOp( + "io_ignored", {"in1", "in2_1", "in2_2", "in3_1", "in3_2"}, + {"out1_1", "out1_2", "out2"}, attrs)); + std::shared_ptr grad_test_op = + f::OpRegistry::CreateGradOp(*test_op); + + // 'In2' and 'Out2' are ignored in gradient calculating + ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL); + EXPECT_EQ(grad_test_op->Input("In1"), "in1"); + EXPECT_EQ(grad_test_op->Inputs("In2_mult"), + std::vector({f::OperatorBase::EMPTY_VAR_NAME(), + f::OperatorBase::EMPTY_VAR_NAME()})); + EXPECT_EQ(grad_test_op->Inputs("In3_mult"), + std::vector({"in3_1", "in3_2"})); + EXPECT_EQ(grad_test_op->Inputs("Out1_mult"), + std::vector({"out1_1", "out1_2"})); + EXPECT_EQ(grad_test_op->Input("Out2"), f::OperatorBase::EMPTY_VAR_NAME()); + EXPECT_EQ( + grad_test_op->Inputs("Out1_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector( + {"out1_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "out1_2" + f::OperatorBase::GRAD_VAR_SUFFIX()})); + 
EXPECT_EQ(grad_test_op->Input("Out2" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out2" + f::OperatorBase::GRAD_VAR_SUFFIX()); + + ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); + EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "in1" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ( + grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()})); + EXPECT_EQ( + grad_test_op->Outputs("In3_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector({"in3_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "in3_2" + f::OperatorBase::GRAD_VAR_SUFFIX()})); +} diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index b58e7d34ebd2dc7231c54799fc911da1eb3262ee..24ce7930f110f9b7b398f879713158d96c7712da 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -314,7 +314,7 @@ class OpRegistry { static std::unordered_map& protos() { static std::unordered_map protos_; return protos_; - }; + } static std::unordered_map& grad_ops() { static std::unordered_map grad_ops_; @@ -336,7 +336,7 @@ class OpRegistry { static std::unordered_map& op_checkers() { static std::unordered_map op_checkers_; return op_checkers_; - }; + } static void GenerateTempVariableName(OperatorBase* op) { static std::atomic gUniqId(0UL); @@ -353,7 +353,7 @@ class OpRegistry { template class OpRegisterHelper { public: - OpRegisterHelper(const char* op_type) { + explicit OpRegisterHelper(const char* op_type) { OpRegistry::RegisterOp(op_type); } }; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index c4e23c3350fc64fa10b187c39eefa5846ba29238..6786ad080fd0fd26a572735290f6ac6d9fdab857 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -285,7 +285,7 @@ class OperatorWithKernel : public OperatorBase { platform::Place place_; OpKernelKey() = default; - OpKernelKey(const 
platform::DeviceContext& dev_ctx) { + explicit OpKernelKey(const platform::DeviceContext& dev_ctx) { place_ = dev_ctx.GetPlace(); } diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 1f30cb10f62cf019ed8fab78af4ff18151fba72c..cbb86c4195a6c7e976fc5e0dd69d77be46dfb17c 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -105,7 +105,16 @@ PYBIND11_PLUGIN(core) { .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) #endif - .def("shape", [](Tensor &self) { return vectorize(self.dims()); }); + .def("shape", [](Tensor &self) { return vectorize(self.dims()); }) + .def("set_float_element", + [](Tensor &self, size_t offset, float f) { + // TODO(yuyang18): Only support GPU now. + self.data()[offset] = f; + }) + .def("get_float_element", [](Tensor &self, size_t offset) -> float { + // TODO(yuyang18): Only support GPU now. + return self.data()[offset]; + }); py::class_(m, "Variable", R"DOC(Variable Class. diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp index 9ddd449de7500f5682d59469328f06971c6e83bf..f98bf95064fa539b990309dfe0bff10c1e99d096 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp @@ -967,8 +967,9 @@ void RecurrentGradientMachine::generateSequence() { size_t numSequences = getGenBatchSize(); resizeBootFrame(numSequences); - // We create only two sub-network in generation for alternate use. - // Thus, we can reduce total memory of output_ in layer forward. + // We create only two sub-network in generation, one stores states of all + // layers in previous time step and the other storing the states at current + // time step. 
resizeOrCreateFrames(2); // outFrameLines_.size() > 1UL @@ -1001,10 +1002,9 @@ void RecurrentGradientMachine::generateSequence() { // init outArg size_t resultNum = generator_.config.num_results_per_sample(); - IVector::resizeOrCreate( - generator_.outArg.ids, - generator_.config.max_num_frames() * numSequences * resultNum, - false); + size_t maxGenWordCount = + generator_.config.max_num_frames() * numSequences * resultNum; + IVector::resizeOrCreate(generator_.outArg.ids, maxGenWordCount, false); if (resultNum > 1) { CHECK_LE(resultNum, static_cast(generator_.config.beam_size())); Matrix::resizeOrCreate(generator_.outArg.in, @@ -1012,6 +1012,11 @@ void RecurrentGradientMachine::generateSequence() { /* width */ resultNum, false, /* useGpu */ false); + Matrix::resizeOrCreate(generator_.outArg.value, + /* height */ maxGenWordCount, + /* width */ 1, + false, + /* useGpu */ false); } ICpuGpuVector::resizeOrCreate(generator_.outArg.sequenceStartPositions, numSequences + 1, @@ -1313,13 +1318,20 @@ void RecurrentGradientMachine::fillGenOutputs() { starts[0] = 0; if (numResults > 1) { real* probs = generator_.outArg.in->getData(); + real* idsProb = generator_.outArg.value->getData(); + size_t curPos = 0; for (size_t i = 0; i < finalPaths_.size(); ++i) { for (size_t j = 0; j < finalPaths_[i].size(); ++j) { Path& path = finalPaths_[i][j]; - generator_.ids.push_back(path.ids.size()); // sequence size + size_t genLen = path.ids.size(); + generator_.ids.push_back(genLen); // sequence size generator_.ids.insert( generator_.ids.end(), path.ids.begin(), path.ids.end()); generator_.ids.push_back(-1); // end of sequence + + memcpy(idsProb + curPos, path.idsProb.data(), sizeof(real) * genLen); + curPos += genLen; + idsProb[curPos++] = -1.0; probs[i * numResults + j] = path.logProb; if (!j && dataArgsSize_) { diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h index 
f245620cf668bb341df99cf498105cbd996a6b24..fb3fc5877ac96323e891f800db80af83b6809831 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h @@ -189,6 +189,11 @@ public: */ std::vector ids; + /** + * @brief idsProb, log probability of each generated words. + */ + std::vector idsProb; + /** * @brief logProb, current probability of path. */ @@ -228,11 +233,13 @@ public: */ Path(Path& old, int newId, real logProb, int machineId, int topIndex) : ids(old.ids), + idsProb(old.idsProb), logProb(old.logProb + logProb), machineId(machineId), topIndex(topIndex), seqId(old.seqId) { ids.push_back(newId); + idsProb.push_back(logProb); if (!old.probHistory.empty()) { this->probHistory = old.probHistory; // probHistory store current prob, not sum @@ -411,8 +418,9 @@ protected: struct Generator { GeneratorConfig config; - std::vector ids; // store generated sequences - Argument outArg; // final output argument + std::vector ids; // store generated sequences + std::vector idsProb; // log probability of each generated word + Argument outArg; // final output argument }; bool generating_; Generator generator_; diff --git a/paddle/gserver/tests/LayerGradUtil.cpp b/paddle/gserver/tests/LayerGradUtil.cpp index 9eca58f1a1baa6fb1c404a91a345bc7f9d6b4acc..fd9cfa1dc7a9028cb2c5c98baca98ffb2a837bac 100644 --- a/paddle/gserver/tests/LayerGradUtil.cpp +++ b/paddle/gserver/tests/LayerGradUtil.cpp @@ -400,7 +400,6 @@ void initDataLayer(TestConfig testConf, const std::vector& labelSeqStartPositions = testConf.inputDefs[i].labelSeqStartPositions; if (labelSeqStartPositions.size() != 0) { - CHECK(!sequenceStartPositions); CHECK_GE(static_cast(labelSeqStartPositions.size()), 2); sequenceStartPositions = @@ -410,6 +409,19 @@ void initDataLayer(TestConfig testConf, useGpu); data.sequenceStartPositions = sequenceStartPositions; } + + const std::vector& labelSubSeqStartPositions = + 
testConf.inputDefs[i].labelSubSeqStartPositions; + if (labelSubSeqStartPositions.size() != 0) { + CHECK_GE(static_cast(labelSubSeqStartPositions.size()), 2); + + subSequenceStartPositions = + ICpuGpuVector::create(labelSubSeqStartPositions.size(), useGpu); + subSequenceStartPositions->copyFrom(labelSubSeqStartPositions.data(), + labelSubSeqStartPositions.size(), + useGpu); + data.subSequenceStartPositions = subSequenceStartPositions; + } break; } default: diff --git a/paddle/gserver/tests/LayerGradUtil.h b/paddle/gserver/tests/LayerGradUtil.h index d299b4dd09418589514d99a72f83e1103ace7de1..5debedf5ef6a3262578ca01b335e664f9a334d35 100644 --- a/paddle/gserver/tests/LayerGradUtil.h +++ b/paddle/gserver/tests/LayerGradUtil.h @@ -67,6 +67,7 @@ struct InputDef { bool isStatic; std::vector labelInitValue; std::vector labelSeqStartPositions; + std::vector labelSubSeqStartPositions; MatrixPtr selfDefinedData; InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) { @@ -81,8 +82,10 @@ struct InputDef { InputDef(InputType type, string nameIn, MatrixPtr selfDefinedData, - std::vector selfDefinedSeqStartPos = {}) + std::vector selfDefinedSeqStartPos = {}, + std::vector selfDefinedSubSeqStartPos = {}) : labelSeqStartPositions(selfDefinedSeqStartPos), + labelSubSeqStartPositions(selfDefinedSubSeqStartPos), selfDefinedData(selfDefinedData) { inputType = type; name = nameIn; diff --git a/paddle/memory/detail/buddy_allocator.h b/paddle/memory/detail/buddy_allocator.h index 4fa3fb0ee5f826d2b084c0ba184c505aee3acc48..9c41378483993101a098fc4ad1068c1ef908e566 100644 --- a/paddle/memory/detail/buddy_allocator.h +++ b/paddle/memory/detail/buddy_allocator.h @@ -39,7 +39,7 @@ class BuddyAllocator { public: void* Alloc(size_t unaligned_size); - void Free(void*); + void Free(void* ptr); size_t Used(); public: diff --git a/paddle/memory/detail/meta_cache.h b/paddle/memory/detail/meta_cache.h index ca0789779e273fb71c3d6282c0a921cda2d776cc..cf5815644284c23a1d2abc904f8c5053ce107a72 
100644 --- a/paddle/memory/detail/meta_cache.h +++ b/paddle/memory/detail/meta_cache.h @@ -33,17 +33,17 @@ namespace detail { */ class MetadataCache { public: - MetadataCache(bool uses_gpu); + explicit MetadataCache(bool uses_gpu); public: /*! \brief Load the associated metadata for the specified memory block. */ - Metadata load(const MemoryBlock*); + Metadata load(const MemoryBlock* memory_block); /*! \brief Store the associated metadata for the specified memory block. */ - void store(MemoryBlock*, const Metadata&); + void store(MemoryBlock* memory_block, const Metadata& meta_data); /*! \brief Indicate that the specified metadata will no longer be used. */ - void invalidate(MemoryBlock*); + void invalidate(MemoryBlock* memory_block); public: MetadataCache(const MetadataCache&) = delete; diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 44f567caf9c19775f17988b5142b7693b41a126d..72351b9dfa63513713463bb47a3684f0dfd84ad3 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -68,7 +68,7 @@ class PODDeleter { static_assert(std::is_pod::value, "T must be POD"); public: - PODDeleter(Place place) : place_(place) {} + explicit PODDeleter(Place place) : place_(place) {} void operator()(T* ptr) { Free(place_, static_cast(ptr)); } private: diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index f961b37565f400b5c26844b9e7a3cff5e682340b..9bd08634da96c5595d6dd702ad9afafb94632b03 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/add_op.h" diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 926a0c616b957d8e542c1f3dee227a718fb29f07..2f453f8379ca7ce0612fed757719acb2d2cf0ad8 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -1,5 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/operators/cross_entropy_op.h" REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); \ No newline at end of file + ops::OnehotCrossEntropyOpKernel); diff --git a/paddle/operators/fill_zeros_like_op.cu b/paddle/operators/fill_zeros_like_op.cu index 55ad58f4f17cd4a3e737c01b001675d2690d273e..ed1068219c8fee8c6e8809f450a9d38c8226f317 100644 --- a/paddle/operators/fill_zeros_like_op.cu +++ b/paddle/operators/fill_zeros_like_op.cu @@ -1,6 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #include "paddle/framework/op_registry.h" #include "paddle/operators/fill_zeros_like_op.h" REGISTER_OP_GPU_KERNEL( fill_zeros_like, - paddle::operators::FillZerosLikeKernel); \ No newline at end of file + paddle::operators::FillZerosLikeKernel); diff --git a/paddle/operators/mean_op.cu b/paddle/operators/mean_op.cu index e15de2fd0dd84e4015ee0e3b5343d7651b027a88..8b97b0154ccdc8c41a90f7580af829c5c8663b60 100644 --- a/paddle/operators/mean_op.cu +++ b/paddle/operators/mean_op.cu @@ -1,6 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + #define EIGEN_USE_GPU #include "paddle/operators/mean_op.h" REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel); -REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel); \ No newline at end of file +REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel); diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index dc9236701627dc9335b844d2a82e18eb1f7dfd42..1dc04c4297daed7a7861a09cf6b99446c296ffa5 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -15,4 +15,4 @@ #define EIGEN_USE_GPU #include "paddle/operators/mul_op.h" -REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); \ No newline at end of file +REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index 2a0964fff326500b6215dd4afac63c75d64c4a06..35e6d9d50dd04048da7ffb384014d5909cd659a4 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -19,7 +19,7 @@ namespace paddle { namespace operators { -using namespace paddle::framework; +using namespace paddle::framework; // NOLINT namespace rnn { @@ -94,7 +94,7 @@ void InitArgument(const ArgumentName& name, Argument* arg); }; // namespace rnn // The sequence format in RecurrentOp is Tensor now. -// TODO: +// TODO(Yan Chunwei): // 1. No-padding computing for sequences with indifinite length in one batch. // 2. Hierarchical RNN for sequence with sub-sequence. // 3. Internal Memory. @@ -172,12 +172,10 @@ public: /** * InferShape must be called before Run. */ - virtual void InferShape(const Scope& scope) const override { - alg_.InferShape(scope); - } + void InferShape(const Scope& scope) const override { alg_.InferShape(scope); } - virtual void Run(const Scope& scope, - const platform::DeviceContext& dev_ctx) const override { + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override { alg_.Run(scope, dev_ctx); } @@ -194,12 +192,10 @@ public: /** * InferShape must be called before Run. 
*/ - virtual void InferShape(const Scope& scope) const override { - alg_.InferShape(scope); - } + void InferShape(const Scope& scope) const override { alg_.InferShape(scope); } - virtual void Run(const Scope& scope, - const platform::DeviceContext& dev_ctx) const override { + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override { alg_.Run(scope, dev_ctx); } diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index 82338ceccc06653791b26472e18d804f62735649..f76faa0a3a93a1ac277a1d1aa83c3fa6c3944648 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/operators/rowwise_add_op.h" diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu index d79258cbf13c699cfb2afaee229cf96a3e377b5e..72629ccfbb8bc8ec53045289bd985c721c62fa10 100644 --- a/paddle/operators/sgd_op.cu +++ b/paddle/operators/sgd_op.cu @@ -1,4 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/operators/sgd_op.h" -REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel); \ No newline at end of file +REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel); diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu index c9d11a2e1f9dcc563765c9e8cc1bae6beff57f18..2123b17e4b5e90c22c2d6e9177f2a8956f8a4ac9 100644 --- a/paddle/operators/sigmoid_op.cu +++ b/paddle/operators/sigmoid_op.cu @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/operators/sigmoid_op.h" diff --git a/paddle/operators/softmax_op.cu b/paddle/operators/softmax_op.cu index 8c652213f2e4c0e0ea1a31987fcb37c86374cd2a..b79228580a7ea0f70b62eb2dc7a61cf85bc0b5fb 100644 --- a/paddle/operators/softmax_op.cu +++ b/paddle/operators/softmax_op.cu @@ -1,6 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/softmax_op.h" REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel); -REGISTER_OP_GPU_KERNEL(softmax_grad, ops::SoftmaxGradKernel); +REGISTER_OP_GPU_KERNEL(softmax_grad, + ops::SoftmaxGradKernel); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 2038fafe2e15ec2631726643695ac6cbc317fed9..48b9f5dcb5cc578f9e70ed7abe076b66b68dc719 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -40,7 +40,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - CPUDeviceContext(CPUPlace); + explicit CPUDeviceContext(CPUPlace); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -55,7 +55,7 @@ class CPUDeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace); + CUDADeviceContext(GPUPlace); // NOLINT virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -69,10 +69,10 @@ class CUDADeviceContext : public DeviceContext { // clang-format off /*! \brief Return cublas handle in the device context. */ - cublasHandle_t cublas_handle (); + cublasHandle_t cublas_handle(); /*! \brief Return cudnn handle in the device context. */ - cudnnHandle_t cudnn_handle (); + cudnnHandle_t cudnn_handle(); /*! \brief Return curand handle in the device context. 
*/ curandGenerator_t curand_generator(); diff --git a/paddle/platform/dynload/cublas.cc b/paddle/platform/dynload/cublas.cc index 4e3dfdaefb2348346e8f917b1f6c758bf6d91a1a..9cd2a1f565526f8dc45932ba6168f4e25c6ad238 100644 --- a/paddle/platform/dynload/cublas.cc +++ b/paddle/platform/dynload/cublas.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include namespace paddle { diff --git a/paddle/platform/dynload/cudnn.cc b/paddle/platform/dynload/cudnn.cc index 8b5e15b5efcdae6a1eed09f002eb2f4f2163035f..d3e4cb567d71b987724366b6a0896f5df0eb6055 100644 --- a/paddle/platform/dynload/cudnn.cc +++ b/paddle/platform/dynload/cudnn.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + #include namespace paddle { @@ -25,4 +39,4 @@ CUDNN_DNN_ROUTINE_EACH_R5(DEFINE_WRAP); } // namespace dynload } // namespace platform -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/paddle/platform/dynload/curand.cc b/paddle/platform/dynload/curand.cc index 5c1fab992c98569d4a95b6e699d97d428511e48e..d05dd88126bfee7278e553710a717b8f2eb02ae0 100644 --- a/paddle/platform/dynload/curand.cc +++ b/paddle/platform/dynload/curand.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include namespace paddle { @@ -10,6 +24,7 @@ void *curand_dso_handle; #define DEFINE_WRAP(__name) DynLoad__##__name __name CURAND_RAND_ROUTINE_EACH(DEFINE_WRAP); -} -} -} \ No newline at end of file + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 60a42c777d1c2ebbc22fdb77b1100cc6fcf7ff35..bc0715656a7d61774d53d4a0643ec1c105706085 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -162,5 +162,50 @@ inline void throw_on_error(T e) { } \ } while (0) +/* + * Some enforce helpers here, usage: + * int a = 1; + * int b = 2; + * PADDLE_ENFORCE_EQ(a, b); + * + * will raise an exception described as follows: + * "enforce a == b failed, 1 != 2" with detailed stack information.
+ * + * extra messages are also supported, for example: + * PADDLE_ENFORCE_EQ(a, b, "some simple enforce failed between %d numbers", 2) + */ + +#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__) +#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, !=, ==, __VA_ARGS__) +#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >, <=, __VA_ARGS__) +#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >=, <, __VA_ARGS__) +#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__) +#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ + __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) + +// if two values have different data types, choose a compatible type for them. +template +struct CompatibleType { + static const bool t1_to_t2 = std::is_convertible::value; + typedef typename std::conditional::type type; +}; + +#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...)
\ + PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \ + __CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \ + "enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \ + #__VAL0, #__VAL1, std::to_string(__VAL0), \ + std::to_string(__VAL1), \ + paddle::string::Sprintf("" __VA_ARGS__)); + +#define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \ + typename paddle::platform::CompatibleType::type(__VAL) + } // namespace platform } // namespace paddle diff --git a/paddle/platform/enforce_test.cc b/paddle/platform/enforce_test.cc index 2ac31812a80d8dd57ce82234cb5835e029a46067..7117b49474044af08ae9db79c2fae6693e966af2 100644 --- a/paddle/platform/enforce_test.cc +++ b/paddle/platform/enforce_test.cc @@ -34,3 +34,165 @@ TEST(ENFORCE, FAILED) { } ASSERT_TRUE(in_catch); } + +TEST(ENFORCE, NO_ARG_OK) { + int a = 2; + int b = 2; + PADDLE_ENFORCE_EQ(a, b); + // test enforce with extra message. + PADDLE_ENFORCE_EQ(a, b, "some thing wrong %s", "info"); +} + +TEST(ENFORCE_EQ, NO_EXTRA_MSG_FAIL) { + int a = 2; + bool in_catch = false; + + try { + PADDLE_ENFORCE_EQ(a, 1 + 3); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce a == 1 + 3 failed, 2 != 4"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_EQ, EXTRA_MSG_FAIL) { + int a = 2; + bool in_catch = false; + + try { + PADDLE_ENFORCE_EQ(a, 1 + 3, "%s size not match", "their"); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = + "enforce a == 1 + 3 failed, 2 != 4\ntheir size not match"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_NE, OK) { + PADDLE_ENFORCE_NE(1, 2); + PADDLE_ENFORCE_NE(1.0, 2UL); +} +TEST(ENFORCE_NE, FAIL) { + bool in_catch = false; + + try { + // 2UL here to check data 
type compatible + PADDLE_ENFORCE_NE(1.0, 1UL); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1.0 != 1UL failed, 1.000000 == 1"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_GT, OK) { PADDLE_ENFORCE_GT(2, 1); } +TEST(ENFORCE_GT, FAIL) { + bool in_catch = false; + + try { + // 2UL here to check data type compatible + PADDLE_ENFORCE_GT(1, 2UL); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1 > 2UL failed, 1 <= 2"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_GE, OK) { + PADDLE_ENFORCE_GE(2, 2UL); + PADDLE_ENFORCE_GE(3, 2UL); + PADDLE_ENFORCE_GE(3, 2); + PADDLE_ENFORCE_GE(3.21, 2UL); +} +TEST(ENFORCE_GE, FAIL) { + bool in_catch = false; + + try { + PADDLE_ENFORCE_GE(1, 2UL); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1 >= 2UL failed, 1 < 2"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_LE, OK) { + PADDLE_ENFORCE_LE(1, 1); + PADDLE_ENFORCE_LE(1, 1UL); + PADDLE_ENFORCE_LE(2, 3UL); + PADDLE_ENFORCE_LE(2UL, 3); + PADDLE_ENFORCE_LE(2UL, 3.2); +} +TEST(ENFORCE_LE, FAIL) { + bool in_catch = false; + + try { + PADDLE_ENFORCE_GT(1, 2UL); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1 > 2UL failed, 1 <= 2"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} + +TEST(ENFORCE_LT, OK) { + PADDLE_ENFORCE_LT(3, 10); + PADDLE_ENFORCE_LT(2, 3UL); + PADDLE_ENFORCE_LT(2UL, 3); +} 
+TEST(ENFORCE_LT, FAIL) { + bool in_catch = false; + + try { + PADDLE_ENFORCE_LT(1UL, 0.12); + + } catch (paddle::platform::EnforceNotMet error) { + in_catch = true; + const std::string msg = "enforce 1UL < 0.12 failed, 1 >= 0.12"; + const char* what = error.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + + ASSERT_TRUE(in_catch); +} diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 7cead183884bc9379355cd931921b40d6c11ce90..a37ad38a8fb030192fa4c871106c6eb54816768a 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -32,7 +32,7 @@ struct CPUPlace { struct GPUPlace { GPUPlace() : GPUPlace(0) {} - GPUPlace(int d) : device(d) {} + GPUPlace(int d) : device(d) {} // NOLINT // needed for variant equality comparison inline bool operator==(const GPUPlace &o) const { return device == o.device; } diff --git a/paddle/string/piece.h b/paddle/string/piece.h index 0272529d1c9b2cb6000a26f1d4d80276d06bf27b..3b887490b5c6c016bc30d8db060c5c1c01b8bf54 100644 --- a/paddle/string/piece.h +++ b/paddle/string/piece.h @@ -39,8 +39,8 @@ public: // size_ is 0. 
Piece(); Piece(const char* d, size_t n); - Piece(const char* d); - Piece(const std::string& s); + Piece(const char* d); // NOLINT + Piece(const std::string& s); // NOLINT const char* data() const { return data_; } size_t len() const { return size_; } diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 4619b0edc3dd7e253e01f7fee5e6a8641340d291..e66197030e2dd9e113e4564aaacb1c5dab25771b 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -13,4 +13,5 @@ add_python_test(test_framework test_sigmoid_op.py test_softmax_op.py test_rowwise_add_op.py - test_network.py) + test_network.py + gradient_checker.py) diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..4022de1c40e41aa77a7f31d82b55b63585cbd5f5 --- /dev/null +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -0,0 +1,90 @@ +import paddle.v2.framework.core as core +from paddle.v2.framework.create_op_creation_methods import op_creations +import numpy +import unittest + +__all__ = ['get_numeric_gradient'] + + +def get_numeric_gradient(op, + input_values, + output_name, + input_to_check, + delta=1e-2, + local_scope=None): + """ + Get Numeric Gradient for an operator's input. + + :param op: C++ operator instance, could be a network + :param input_values: The input variables. Should be a dictionary, key is + variable name. Value is numpy array. + :param output_name: The final output variable name. + :param input_to_check: The input variable whose gradient is computed. + :param delta: The perturbation value for numeric gradient method. The + smaller delta is, the more accurate the result will be. But if delta is + too small, it could cause numerical stability problems. + :param local_scope: The local scope used for get_numeric_gradient.
+ :return: The gradient array in numpy format. + """ + if local_scope is None: + local_scope = core.Scope() + + # Create all input variable in local_scope + for var_name in input_values: + var = local_scope.new_var(var_name) + tensor = var.get_tensor() + tensor.set_dims(input_values[var_name].shape) + tensor.alloc_float(core.CPUPlace()) + tensor.set(input_values[var_name], core.CPUPlace()) + + # Create all output variable in local_scope + for output in op.outputs(): + if local_scope.find_var(output) is None: + local_scope.new_var(output).get_tensor() + + op.infer_shape(local_scope) + + # allocate output memory + for output in op.outputs(): + local_scope.find_var(output).get_tensor().alloc_float(core.CPUPlace()) + + # TODO(yuyang18): Only CPU is support now. + cpu_ctx = core.DeviceContext.create(core.CPUPlace()) + + def get_output(): + op.run(local_scope, cpu_ctx) + return numpy.array(local_scope.find_var(output_name).get_tensor()).sum() + + def product(dim): + return reduce(lambda a, b: a * b, dim, 1) + + tensor_to_check = local_scope.find_var(input_to_check).get_tensor() + tensor_size = product(tensor_to_check.get_dims()) + gradient_flat = numpy.zeros(shape=(tensor_size, ), dtype='float32') + for i in xrange(tensor_size): + origin = tensor_to_check.get_float_element(i) + x_pos = origin + delta + tensor_to_check.set_float_element(i, x_pos) + y_pos = get_output() + + x_neg = origin - delta + tensor_to_check.set_float_element(i, x_neg) + y_neg = get_output() + + tensor_to_check.set_float_element(i, origin) # restore old value + gradient_flat[i] = (y_pos - y_neg) / delta / 2 + return gradient_flat.reshape(tensor_to_check.get_dims()) + + +if __name__ == '__main__': + + class GetNumericGradientTest(unittest.TestCase): + def test_add_op(self): + add_op = op_creations.add_two(X="X", Y="Y", Out="Z") + x = numpy.random.random((10, 1)).astype("float32") + y = numpy.random.random((10, 1)).astype("float32") + + arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X') + 
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2) + + unittest.main()