diff --git a/CMakeLists.txt b/CMakeLists.txt index dcd1218a5b0b62f2739b727391aca31b48ed9ccb..ad559672ad2f83a3d62cdf332b47c6cf1e730f70 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,7 @@ option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(GLIDE_INSTALL "Download and install go dependencies " ON) option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) +option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF) # CMAKE_BUILD_TYPE if(NOT CMAKE_BUILD_TYPE) @@ -137,9 +138,9 @@ set(EXTERNAL_LIBS ) if(WITH_GPU) - list(APPEND EXTERNAL_LIB ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) + list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) if(NOT WITH_DSO) - list(APPEND EXTERNAL_LIB ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) + list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) endif(NOT WITH_DSO) endif(WITH_GPU) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 209f9078a637ac581d90212a48216eb388c477ed..51c3b918cc4ef4cf6c8052ccc14028a872309fcf 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -28,6 +28,10 @@ if(NOT WITH_TIMER) add_definitions(-DPADDLE_DISABLE_TIMER) endif(NOT WITH_TIMER) +if(USE_EIGEN_FOR_BLAS) + add_definitions(-DPADDLE_USE_EIGEN_FOR_BLAS) +endif(USE_EIGEN_FOR_BLAS) + if(NOT WITH_PROFILER) add_definitions(-DPADDLE_DISABLE_PROFILER) endif(NOT WITH_PROFILER) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 69f40df51680a104c47d9335c070c570dcaff59a..2c84061ff572de4687b4d496f8ded6deee8d1011 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -2,7 +2,7 @@ if(NOT WITH_GPU) return() endif() -set(CUDNN_ROOT "" CACHE PATH "CUDNN ROOT") +set(CUDNN_ROOT "/usr" CACHE PATH "CUDNN ROOT") find_path(CUDNN_INCLUDE_DIR cudnn.h PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include $ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE} diff --git a/paddle/capi/gradient_machine.cpp b/paddle/capi/gradient_machine.cpp index b3287552db87d25edbf6e7f3d5e68121df49e9d6..629449bbd497a7444144c533ad079b3ae6b51438 100644 --- a/paddle/capi/gradient_machine.cpp +++ b/paddle/capi/gradient_machine.cpp @@ -146,3 +146,19 @@ paddle_error paddle_gradient_machine_randomize_param( m->machine->randParameters(); return kPD_NO_ERROR; } + +paddle_error paddle_gradient_machine_get_layer_output( + paddle_gradient_machine machine, + const char* layerName, + paddle_arguments args) { + auto m = cast(machine); + auto out = paddle::capi::cast(args); + if (m == nullptr || layerName == nullptr || out == nullptr || + m->machine == nullptr) { + return kPD_NULLPTR; + } + + auto layerOutput = m->machine->getLayerOutput(layerName); + out->args.push_back(layerOutput); + return kPD_NO_ERROR; +} diff --git a/paddle/capi/gradient_machine.h b/paddle/capi/gradient_machine.h index c613ade5b24efbbf52f21c7ee86dd3189981c5ef..28eeb23e3bbdd4cc22a25c14170bf56c294f8cd7 100644 --- a/paddle/capi/gradient_machine.h +++ b/paddle/capi/gradient_machine.h @@ -39,7 +39,11 @@ PD_API paddle_error paddle_gradient_machine_create_for_inference( /** * @brief Create a gradient machine used for model inference, using config with * parameters which is generated by `paddle merge_model`. - * @param [out] machine that used for model inference. + * Example: + * paddle merge_model \ + * --model_dir="pass-00000" \ + * --model_file="merged_model.paddle" + * @param [out] machine that used for model inference * @param [in] mergedModel * @param [in] size * @return paddle_error @@ -97,6 +101,18 @@ paddle_gradient_machine_randomize_param(paddle_gradient_machine machine); PD_API paddle_error paddle_gradient_machine_destroy(paddle_gradient_machine machine); +/** + * @brief Get the output of the layer named `layerName`. + * @param [in] gradient machine that have run a inference + * @param [in] layerName name of specified layer + * @param [out] args output of the specified layer + * @return paddle_error + */ +PD_API paddle_error +paddle_gradient_machine_get_layer_output(paddle_gradient_machine machine, + const char* layerName, + paddle_arguments args); + #ifdef __cplusplus } #endif diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index c226e4e3d2a58d1a647e204c4cd26f4eb6bcd968..9d30887224fe0020ff5665f362e7403bf5c724ee 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -15,6 +15,8 @@ #include "paddle/framework/backward.h" #include +#include + #include "paddle/framework/op_registry.h" #include "paddle/operators/net_op.h" #include "paddle/operators/recurrent_op.h" @@ -43,11 +45,11 @@ static bool AllInSet( return all_in_set; } -static std::shared_ptr NOP() { - auto net_op = std::make_shared(); +static std::unique_ptr NOP() { + auto net_op = new operators::NetOp(); net_op->SetType("@NOP@"); net_op->CompleteAddOp(); - return net_op; + return std::unique_ptr(net_op); } // Get backward operator from a forward operator, a recursive implementation. @@ -62,11 +64,7 @@ static std::shared_ptr NOP() { // operator, in a complex situation, it maybe a NetOp. // // See Backward.h for details -static std::shared_ptr BackwardRecursive( - const OperatorBase& forwardOp, - std::unordered_set& no_grad_names, size_t& uniq_id); - -std::shared_ptr BackwardRecursive( +static std::unique_ptr BackwardRecursive( const OperatorBase& forwardOp, std::unordered_set& no_grad_names, size_t& uniq_id) { // If all input gradients of forwarding operator do not need to calculate, @@ -91,7 +89,7 @@ std::shared_ptr BackwardRecursive( } // Returned gradient network - auto net = std::make_shared(); + auto net = std::unique_ptr(new operators::NetOp()); if (forwardOp.IsNetOp()) { // Because forwardOp is a net op, it can static_cast. @@ -105,14 +103,14 @@ std::shared_ptr BackwardRecursive( // reversely travel forwardNet and collect all duplicate outputs. for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); ++it, ++local_op_id) { - auto fwd = *it; + auto& fwd = *it; auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); - net->AddOp(bwd); ForEachVarName(bwd->Outputs(), [&dup_output_ops, local_op_id](const std::string& out) { dup_output_ops[out].emplace_back(local_op_id); return false; }); + net->AddOp(std::move(bwd)); } // Get unique ID for this method. auto uid = uniq_id++; @@ -122,7 +120,7 @@ std::shared_ptr BackwardRecursive( // to handle this case. For each duplicate output, rename it to an alias // (original name with a offset), append an `add` op for its operator, // and finally sum all the alias variable to the final output variable y. - using Pos = std::pair>; + using Pos = std::pair>; std::list insert_position; for (auto& dup_output_op : dup_output_ops) { const std::string& name = dup_output_op.first; @@ -150,13 +148,13 @@ std::shared_ptr BackwardRecursive( [](const Pos& l, const Pos& r) { return l.first > r.first; }); for (auto& pos : insert_position) { - net->InsertOp(pos.first + 1, pos.second); + net->InsertOp(pos.first + 1, std::move(pos.second)); } } else { - std::shared_ptr grad_op = OpRegistry::CreateGradOp(forwardOp); + std::unique_ptr grad_op(OpRegistry::CreateGradOp(forwardOp)); - ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, - grad_op](const std::string& grad_input) { + ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( + const std::string& grad_input) { if (no_grad_names.count(grad_input)) { // +1 for \0 std::string prefix = grad_input.substr( @@ -190,23 +188,23 @@ std::shared_ptr BackwardRecursive( const auto& stepnet_op = *static_cast(&rnnop.stepnet()); // create stepnet's gradient op - auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id); rnn_grad_op->set_stepnet( - std::static_pointer_cast(grad_stepnet)); + BackwardRecursive(stepnet_op, no_grad_names, uniq_id)); } if (net->ops_.empty()) { // Current no aux op is added to network return grad_op; } - net->AddOp(grad_op); + net->AddOp(std::move(grad_op)); } net->SetType("@GENERATED_BACKWARD@"); net->CompleteAddOp(); - return net; -} // namespace framework + return std::unique_ptr( + static_cast(net.release())); +} // See header for comments -std::shared_ptr Backward( +std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars) { std::unordered_set no_grad_names; diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index c181919dc165cf0b49362f85e22ceb4131bbd387..1ecf69881b3126c2904920b9f4b77bfcccc9cf86 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -20,7 +20,7 @@ namespace framework { // Create the backward operator from a forward operator. // TODO(yuyang18): Add more API reference comment. -extern std::shared_ptr Backward( +extern std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars); } // namespace framework diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index d942604bf05998ab9e1ee147b28586e7e4e9777d..2c5ec76dfeb8b8485951e4d94896b6758e0cb930 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -32,9 +32,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { public: RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input X of Add").AsNoGradient(); - AddInput("b", "Bias of Add").AsNoGradient(); - AddOutput("Out", "Out of Add").AsNoGradient(); + AddInput("X", "Input X of Add").NotInGradient(); + AddInput("b", "Bias of Add").NotInGradient(); + AddOutput("Out", "Out of Add").NotInGradient(); AddComment("Add Op"); } }; @@ -180,8 +180,7 @@ TEST(Backward, simple_op_not_need_grad) { auto no_input_gop = f::Backward(*fwd, {"x", "b"}); ASSERT_NE(no_input_gop, nullptr); ASSERT_TRUE(no_input_gop->IsNetOp()); - ASSERT_EQ(0UL, - std::static_pointer_cast(no_input_gop)->ops_.size()); + ASSERT_EQ(0UL, static_cast(no_input_gop.get())->ops_.size()); } TEST(Backward, net_fc_backward_normal) { diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 7077e8aa2c77c24efdbb34ed3a13821fe7678455..ae44a1ffd45dacdc44a72edc630e771e7a2f2990 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -60,7 +60,7 @@ message OpProto { optional bool duplicable = 3 [ default = false ]; optional bool intermediate = 4 [ default = false ]; - optional bool no_gradient = 5 [ default = false ]; + optional bool not_in_gradient = 5 [ default = false ]; } // AttrProto describes the C++ type Attribute. diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index b73dac22d029876de9a012de533647db3dd74cbb..0a2a41f6b62658ac8633a6e384d099f8d6641f33 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -28,7 +28,7 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, const auto& src_arg_list = src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); for (const auto& arg : src_arg_list) { - if (arg.no_gradient() && !is_grad) continue; + if (arg.not_in_gradient() && !is_grad) continue; const std::string src_name = arg.name(); std::string dst_name = is_grad ? GradVarName(src_name) : src_name; dst_inout[dst_name].reserve(src_inout.at(src_name).size()); diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 0c26293fd29d24a7a40c47bdf055d2758846709b..902c2655e9182d74a48ad13e17a39a3304d5fa57 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -26,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient(); + AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient(); AddInput("In3_mult", "another multiple input").AsDuplicable(); AddOutput("Out1_mult", "a multiple output").AsDuplicable(); - AddOutput("Out2", "a single output").AsNoGradient(); + AddOutput("Out2", "a single output").NotInGradient(); AddComment("op with inputs and outputs ignored in gradient calculating"); } }; diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 1caa02a2a1d046778f875d04eeaef957be741302..8eae86e9605da74cdc37caeb9569e7500aac2a63 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -17,5 +17,48 @@ limitations under the License. */ #include namespace paddle { -namespace framework {} // namespace framework +namespace framework { + +std::unique_ptr OpRegistry::CreateOp(const std::string& type, + const VarNameMap& inputs, + const VarNameMap& outputs, + AttributeMap attrs) { + auto it = op_info_map().find(type); + PADDLE_ENFORCE(it != op_info_map().end(), + "Operator '%s' has not been registered.", type); + it->second.checker_->Check(attrs); + auto op = it->second.creator_(type, inputs, outputs, attrs); + return std::unique_ptr(op); +} + +std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { + VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); + VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); + AttributeMap attrs; + for (auto& attr : op_desc.attrs()) { + attrs[attr.name()] = GetAttrValue(attr); + } + + return CreateOp(op_desc.type(), inputs, outputs, attrs); +} + +OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap( + const google::protobuf::RepeatedPtrField& op_desc_vars) { + VarNameMap ret_val; + for (auto& var : op_desc_vars) { + auto& var_names = ret_val[var.parameter()]; + auto& var_names_in_proto = var.arguments(); + var_names.reserve(static_cast(var_names_in_proto.size())); + std::copy(var_names_in_proto.begin(), var_names_in_proto.end(), + std::back_inserter(var_names)); + } + return ret_val; +} + +std::unique_ptr OpRegistry::CreateGradOp(const OperatorBase& op) { + PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); + return std::unique_ptr(BuildGradOp(&op)); +} + +} // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 120f4ede6ba21e31683a8d19f0b39072c3f5c309..4c2d13d639005d2d2710c19f63988333d89bce13 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -29,103 +29,6 @@ limitations under the License. */ namespace paddle { namespace framework { -// this class not only make proto but also init attribute checkers. -class OpProtoAndCheckerMaker { - public: - OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) - : proto_(proto), op_checker_(op_checker) {} - - ~OpProtoAndCheckerMaker() { - PADDLE_ENFORCE(validated_, "should call Validate after build"); - } - - void Validate() { - validated_ = true; - CheckNoDuplicatedInOutAttrs(); - } - - protected: - struct VariableBuilder { - OpProto::Var* var_; - - VariableBuilder& AsDuplicable() { - var_->set_duplicable(true); - return *this; - } - - VariableBuilder& AsIntermediate() { - var_->set_intermediate(true); - return *this; - } - - // TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it - // means that input/output is not needed when calculate gradient. It does - // not mean no gradient when backward. It should be changed soon. - VariableBuilder& AsNoGradient() { - var_->set_no_gradient(true); - return *this; - } - }; - - VariableBuilder AddInput(const std::string& name, - const std::string& comment) { - auto* input = proto_->add_inputs(); - input->set_name(name); - input->set_comment(comment); - return VariableBuilder{input}; - } - - VariableBuilder AddOutput(const std::string& name, - const std::string& comment) { - auto* output = proto_->add_outputs(); - output->set_name(name); - output->set_comment(comment); - return VariableBuilder{output}; - } - - template - TypedAttrChecker& AddAttr(const std::string& name, - const std::string& comment, - bool generated = false) { - auto* attr = proto_->add_attrs(); - attr->set_name(name); - attr->set_comment(comment); - attr->set_generated(generated); - attr->set_type(AttrTypeID()); - return op_checker_->AddAttrChecker(name); - } - - void AddComment(const std::string& comment) { proto_->set_comment(comment); } - - private: - void CheckNoDuplicatedInOutAttrs() { - std::unordered_set names; - auto checker = [&](const std::string& name) { - PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); - names.insert(name); - }; - for (auto& attr : proto_->attrs()) { - checker(attr.name()); - } - for (auto& input : proto_->inputs()) { - checker(input.name()); - } - for (auto& output : proto_->outputs()) { - checker(output.name()); - } - } - - OpProto* proto_; - OpAttrChecker* op_checker_; - bool validated_{false}; -}; - -class NOPMaker : public OpProtoAndCheckerMaker { - public: - NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) {} -}; - class OpRegistry { using VarNameMap = OperatorBase::VarNameMap; using OpCreator = std::function CreateOp(const std::string& type, + static std::unique_ptr CreateOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, - AttributeMap attrs) { - auto it = op_info_map().find(type); - PADDLE_ENFORCE(it != op_info_map().end(), - "Operator '%s' has not been registered.", type); - it->second.checker_->Check(attrs); - auto op = it->second.creator_(type, inputs, outputs, attrs); - return std::shared_ptr(op); - } - - static VarNameMap ConvertOpDescVarsToVarNameMap( - const google::protobuf::RepeatedPtrField& op_desc_vars) { - VarNameMap ret_val; - for (auto& var : op_desc_vars) { - auto& var_names = ret_val[var.parameter()]; - auto& var_names_in_proto = var.arguments(); - var_names.reserve(static_cast(var_names_in_proto.size())); - std::copy(var_names_in_proto.begin(), var_names_in_proto.end(), - std::back_inserter(var_names)); - } - return ret_val; - } + AttributeMap attrs); - static std::shared_ptr CreateOp(const OpDesc& op_desc) { - VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); - VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); - AttributeMap attrs; - for (auto& attr : op_desc.attrs()) { - attrs[attr.name()] = GetAttrValue(attr); - } + static std::unique_ptr CreateOp(const OpDesc& op_desc); - return CreateOp(op_desc.type(), inputs, outputs, attrs); - } + static VarNameMap ConvertOpDescVarsToVarNameMap( + const google::protobuf::RepeatedPtrField& op_desc_vars); - static std::shared_ptr CreateGradOp(const OperatorBase& op) { - PADDLE_ENFORCE(!op.IsNetOp(), - "Use framework::Backward to get backward ops"); - std::shared_ptr grad_op(BuildGradOp(&op)); - return grad_op; - } + static std::unique_ptr CreateGradOp(const OperatorBase& op); static std::unordered_map& op_info_map() { static std::unordered_map op_info_map_; @@ -272,8 +144,18 @@ class OpKernelRegistrar : public Registrar { grad_op_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ - static ::paddle::framework::OpRegistrar \ + class _OpClass_##op_type##_ : public op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ + DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ + }; \ + class _OpGradClass_##op_type##_ : public grad_op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \ + DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \ + }; \ + static ::paddle::framework::OpRegistrar< \ + _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \ __op_registrar_##op_type##__(#op_type, #grad_op_type); \ int TouchOpRegistrar_##op_type() { \ __op_registrar_##op_type##__.Touch(); \ @@ -304,7 +186,8 @@ class OpKernelRegistrar : public Registrar { REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) /** - * Macro to mark what Operator and Kernel we will use and tell the compiler to + * Macro to mark what Operator and Kernel + * we will use and tell the compiler to * link them into target. */ #define USE_OP_ITSELF(op_type) \ @@ -324,7 +207,8 @@ class OpKernelRegistrar : public Registrar { __attribute__((unused)) = \ TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() -// TODO(fengjiayi): The following macros seems ugly, do we have better method? +// TODO(fengjiayi): The following macros +// seems ugly, do we have better method? #ifdef PADDLE_ONLY_CPU #define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU) diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 1a85d568350dc04ca1df28129de19cd45b5204b8..50c45919c53af22665feeeebe753da283ded2b0c 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -76,8 +76,7 @@ TEST(OpRegistry, CreateOp) { attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_f(scale); - std::shared_ptr op = - paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); @@ -118,8 +117,7 @@ TEST(OpRegistry, DefaultValue) { ASSERT_TRUE(op_desc.IsInitialized()); - std::shared_ptr op = - paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 0daf12e7f5f3539d460ce67d39ca1c06f5aa2237..eadd8f3316ff1ebffb94a56b2e62d661e4e0b38f 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -164,5 +164,43 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { return ret_val; } +void OpProtoAndCheckerMaker::Validate() { + validated_ = true; + CheckNoDuplicatedInOutAttrs(); +} + +OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( + const std::string& name, const std::string& comment) { + auto* input = proto_->add_inputs(); + input->set_name(name); + input->set_comment(comment); + return OpProtoAndCheckerMaker::VariableBuilder{input}; +} + +OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput( + const std::string& name, const std::string& comment) { + auto* output = proto_->add_outputs(); + output->set_name(name); + output->set_comment(comment); + return OpProtoAndCheckerMaker::VariableBuilder{output}; +} + +void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { + std::unordered_set names; + auto checker = [&](const std::string& name) { + PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); + names.insert(name); + }; + for (auto& attr : proto_->attrs()) { + checker(attr.name()); + } + for (auto& input : proto_->inputs()) { + checker(input.name()); + } + for (auto& output : proto_->outputs()) { + checker(output.name()); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 60d4f06c7e6f8849800f238242a340ef5dbf771e..807298088981b969622174be753ea0da72067243 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -67,10 +67,6 @@ class OperatorBase { OperatorBase(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const AttributeMap& attrs); - OperatorBase(const OperatorBase& o) = delete; - OperatorBase& operator=(const OperatorBase& o) = delete; - OperatorBase(OperatorBase&& o) = delete; - virtual ~OperatorBase() {} template @@ -116,10 +112,14 @@ class OperatorBase { void SetType(const std::string& type) { type_ = type; } const AttributeMap& Attrs() const { return attrs_; } + // Return a new operator instance, which is as same as this. + // Use unique_ptr to prevent caller forget to delete this pointer. + virtual std::unique_ptr Clone() const = 0; + protected: std::string type_; // NOTE: in case of OpGrad, inputs_ contains: - // I (Inputs) + // I (Inputs)opear // O (Outputs) // OG (Output Gradients) VarNameMap inputs_; @@ -130,12 +130,97 @@ class OperatorBase { AttributeMap attrs_; }; +// Macro for define a clone method. +// If you are writing an kernel operator, `Clone` will be defined when you +// register it. i.e. `Clone` method is not needed to define by yourself. +#define DEFINE_OP_CLONE_METHOD(CLS) \ + std::unique_ptr Clone() const final { \ + return std::unique_ptr(new CLS(*this)); \ + } + +// Macro for define a default constructor for Operator. +// You can also use +// using PARENT_CLASS::PARENT_CLASS; +// to use parent's constructor. +#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \ + CLS(const std::string& type, const VarNameMap& inputs, \ + const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \ + : PARENT_CLS(type, inputs, outputs, attrs) {} + class NOP : public OperatorBase { public: using OperatorBase::OperatorBase; void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} + std::unique_ptr Clone() const override { + return std::unique_ptr(new NOP(*this)); + } +}; + +// this class not only make proto but also init attribute checkers. +class OpProtoAndCheckerMaker { + public: + OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : proto_(proto), op_checker_(op_checker) {} + + ~OpProtoAndCheckerMaker() { + PADDLE_ENFORCE(validated_, "should call Validate after build"); + } + + void Validate(); + + protected: + struct VariableBuilder { + OpProto::Var* var_; + + VariableBuilder& AsDuplicable() { + var_->set_duplicable(true); + return *this; + } + + VariableBuilder& AsIntermediate() { + var_->set_intermediate(true); + return *this; + } + + VariableBuilder& NotInGradient() { + var_->set_not_in_gradient(true); + return *this; + } + }; + + VariableBuilder AddInput(const std::string& name, const std::string& comment); + + VariableBuilder AddOutput(const std::string& name, + const std::string& comment); + + template + TypedAttrChecker& AddAttr(const std::string& name, + const std::string& comment, + bool generated = false) { + auto* attr = proto_->add_attrs(); + attr->set_name(name); + attr->set_comment(comment); + attr->set_generated(generated); + attr->set_type(AttrTypeID()); + return op_checker_->AddAttrChecker(name); + } + + void AddComment(const std::string& comment) { proto_->set_comment(comment); } + + private: + void CheckNoDuplicatedInOutAttrs(); + + OpProto* proto_; + OpAttrChecker* op_checker_; + bool validated_{false}; +}; + +class NOPMaker : public OpProtoAndCheckerMaker { + public: + NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) {} }; class InferShapeContext { diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 0441cec9f6d10246fba38b02b4de3cbe2ee4766b..2425b87779f6af01b0e8a91b5f574a28385f0efd 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -245,3 +245,21 @@ TEST(OpKernel, multi_inputs) { auto op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(scope, cpu_device_context); } + +class OperatorClone : public paddle::framework::OperatorBase { + public: + DEFINE_OP_CLONE_METHOD(OperatorClone); + OperatorClone(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, + const paddle::framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void InferShape(const paddle::framework::Scope& scope) const override {} + void Run(const paddle::framework::Scope& scope, + const paddle::platform::DeviceContext& dev_ctx) const override {} +}; + +TEST(Operator, Clone) { + OperatorClone a("ABC", {}, {}, {}); + auto b = a.Clone(); + ASSERT_EQ(a.Type(), b->Type()); +} \ No newline at end of file diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index fe0c87bc570825014222807cb90a3bb341b44e8e..f0114b9e4908d65b3fddb493230777f9e500b4e1 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -48,29 +48,6 @@ namespace framework { using Tensor = framework::Tensor; -template -void ExposeOperator(ClassType &m) { - m.def("infer_shape", &ClassType::type::InferShape) - .def("run", &ClassType::type::Run) - .def("type", - [](const typename ClassType::type &op) -> std::string { - return op.Type(); - }) - .def("outputs", - [](const typename ClassType::type &op) - -> std::map> { - return op.Outputs(); - }) - .def("inputs", - [](const typename ClassType::type &op) { return op.Inputs(); }) - .def("__str__", &ClassType::type::DebugString) - .def("no_intermediate_outputs", - [](const typename ClassType::type &op) { - return op.OutputVars(false); - }) - .def("support_gpu", &ClassType::type::SupportGPU); -} - static size_t UniqueIntegerGenerator() { static std::atomic generator; return generator.fetch_add(1); @@ -207,75 +184,69 @@ All parameter, weight, gradient are variables in Paddle. .def(py::init<>()) .def("__str__", string::to_string); - py::class_> operator_base( - m, "Operator"); - - operator_base.def_static("create", [](py::bytes protobin) { - OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - return OpRegistry::CreateOp(desc); - }); - - operator_base.def("backward", - [](const OperatorBase &forwardOp, - const std::unordered_set &no_grad_vars) { - return Backward(forwardOp, no_grad_vars); - }); - - ExposeOperator(operator_base); - - py::class_> net(m, "Net"); - - net.def_static("create", - []() -> std::shared_ptr { - auto retv = std::make_shared(); - retv->SetType("plain_net"); - return retv; - }) - .def("add_op", &operators::NetOp::AddOp) - .def("add_op", - [](operators::NetOp &self, - const std::shared_ptr &net) -> void { - self.AddOp(std::static_pointer_cast(net)); - }) - .def("add_op", - [](operators::NetOp &self, - const std::shared_ptr &rnn) -> void { - self.AddOp(std::static_pointer_cast(rnn)); + py::class_(m, "Operator") + .def_static("create", + [](py::bytes protobin) { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return OpRegistry::CreateOp(desc); + }) + .def("backward", + [](const OperatorBase &forwardOp, + const std::unordered_set &no_grad_vars) { + return Backward(forwardOp, no_grad_vars).release(); }) + .def("infer_shape", &OperatorBase::InferShape) + .def("run", &OperatorBase::Run) + .def("type", + [](const OperatorBase &op) -> std::string { return op.Type(); }) + .def("outputs", + [](const OperatorBase &op) + -> std::map> { + return op.Outputs(); + }) + .def("inputs", [](const OperatorBase &op) { return op.Inputs(); }) + .def("__str__", &OperatorBase::DebugString) + .def("no_intermediate_outputs", + [](const OperatorBase &op) { return op.OutputVars(false); }) + .def("support_gpu", &OperatorBase::SupportGPU); + + py::class_(m, "Net") + .def_static("create", + []() -> operators::NetOp * { + auto *retv = new operators::NetOp; + retv->SetType("plain_net"); + return retv; + }) + .def("add_op", [](operators::NetOp &self, + const OperatorBase &op) { self.AddOp(op); }) .def("complete_add_op", &operators::NetOp::CompleteAddOp) .def("complete_add_op", [](std::shared_ptr &self) { self->CompleteAddOp(); }); - ExposeOperator(net); - // recurrent_op - py::class_> - rnn(m, "RecurrentOp"); - - rnn.def_static( - "create", - [](py::bytes protobin) -> std::shared_ptr { - OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - auto rnn_op = OpRegistry::CreateOp(desc); - return std::dynamic_pointer_cast(rnn_op); - }) - .def("set_stepnet", - [](operators::RecurrentOp &self, - const std::shared_ptr &net) -> void { - self.set_stepnet(net); - }); - ExposeOperator(rnn); + py::class_(m, "RecurrentOp") + .def_static( + "create", + [](py::bytes protobin) -> operators::RecurrentOp * { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + auto rnn_op = OpRegistry::CreateOp(desc); + return static_cast(rnn_op.release()); + }) + .def("set_stepnet", [](operators::RecurrentOp &self, + const operators::NetOp &net) -> void { + self.set_stepnet(net.Clone()); + }); m.def("unique_integer", UniqueIntegerGenerator); diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 7dfb6f61c50959f7269725a00dbc4f9c27474bdf..c572a9d433bc16e6733b8fc9367970bef28e699a 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -4,6 +4,10 @@ file(GLOB cpp_files . *Op.cpp) list(APPEND h_files Function.h) list(APPEND cpp_files Function.cpp) list(APPEND cpp_files BufferArg.cpp) +list(APPEND cpp_files GemmFunctor.cpp) +if(USE_EIGEN_FOR_BLAS) + list(APPEND cpp_files EigenGemm.cpp) +endif(USE_EIGEN_FOR_BLAS) if(WITH_GPU) file(GLOB cu_files . *OpGpu.cu) diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp index 490e8d546cbd460217abe95f6291b13fa207faa9..2f3112fe657cd381891dc53c7179e7520911e8c9 100644 --- a/paddle/function/DepthwiseConvOp.cpp +++ b/paddle/function/DepthwiseConvOp.cpp @@ -14,7 +14,6 @@ limitations under the License. */ #include "DepthwiseConvOp.h" #include "ConvOp.h" -#include "GemmFunctor.h" namespace paddle { diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu index 33463805cbd4746c05548028e0bc4a0e2a90453e..2d722dfcfca0f328edeecf185ea37b8512b91907 100644 --- a/paddle/function/DepthwiseConvOpGpu.cu +++ b/paddle/function/DepthwiseConvOpGpu.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "DepthwiseConvOp.h" -#include "GemmFunctor.h" #include "paddle/math/BaseMatrix.h" namespace paddle { diff --git a/paddle/function/EigenGemm.cpp b/paddle/function/EigenGemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..674141ed39b7f5573948348e3ba3bb526ae43c66 --- /dev/null +++ b/paddle/function/EigenGemm.cpp @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { + +template +struct EigenBlasGemm { + typedef Eigen::TensorMap, + Eigen::Aligned> + Matrix; + + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + Eigen::array sizeA; + if (transA) { + sizeA[0] = K; + sizeA[1] = M; + CHECK_EQ(M, lda); + } else { + sizeA[0] = M; + sizeA[1] = K; + CHECK_EQ(K, lda); + } + Eigen::array sizeB; + if (transB) { + sizeB[0] = N; + sizeB[1] = K; + CHECK_EQ(K, ldb); + } else { + sizeB[0] = K; + sizeB[1] = N; + CHECK_EQ(N, ldb); + } + Eigen::array sizeC; + sizeC[0] = M; + sizeC[1] = N; + CHECK_EQ(N, ldc); + + const Matrix a(const_cast(A), sizeA); + const Matrix b(const_cast(B), sizeB); + Matrix c(C, sizeC); + + typedef typename Eigen::Tensor::DimensionPair DimPair; + Eigen::array dims; + dims[0] = DimPair(1, 0); + dims[0].first = transA ? 0 : 1; + dims[0].second = transB ? 1 : 0; + + Eigen::DefaultDevice device; + if (alpha == T(1) && beta == T(0)) { + c.device(device) = a.contract(b, dims); + } else if (alpha == T(1) && beta == T(1)) { + c.device(device) += a.contract(b, dims); + } else { + c.device(device) = alpha * a.contract(b, dims) + beta * c; + } + } +}; + +#ifdef PADDLE_TYPE_DOUBLE +template class EigenBlasGemm; +#else +template class EigenBlasGemm; +#endif + +} // namespace paddle diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87..f8cf4ebea8d724f0291b981647622b63e3d84495 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -85,7 +85,6 @@ public: } Im2ColFunctor im2col; - GemmFunctor gemm; size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -108,19 +107,19 @@ public: int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset, - K, - colData, - N, - beta, - outputData + g * outputOffset, - N); + BlasGemm::compute(false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + colData, + N, + beta, + outputData + g * outputOffset, + N); } inputData += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; @@ -188,8 +187,6 @@ public: } Col2ImFunctor col2im; - GemmFunctor gemm; - size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -205,19 +202,19 @@ public: colData = inputGrad + g * inputOffset; scale = 1.0f; } - gemm(CblasTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset, - M, - outputGrad + g * outputOffset, - N, - scale, - colData, - N); + BlasGemm::compute(true, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + M, + outputGrad + g * outputOffset, + N, + scale, + colData, + N); if (needIm2col) { col2im(inputGrad + g * inputOffset, imShape, @@ -299,7 +296,6 @@ public: } Im2ColFunctor im2col; - GemmFunctor gemm; size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -321,19 +317,19 @@ public: int M = outputChannels / groups_; int K = outputHeight * outputWidth; int N = inputChannels / groups_ * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasTrans, - M, - N, - K, - 1.0f, - outputGrad + g * outputOffset, - K, - colData, - K, - i == 0 ? beta : 1.0f, - filterGrad + g * filterOffset, - N); + BlasGemm::compute(false, + true, + M, + N, + K, + 1.0f, + outputGrad + g * outputOffset, + K, + colData, + K, + i == 0 ? beta : 1.0f, + filterGrad + g * filterOffset, + N); } inputData += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; diff --git a/paddle/function/GemmFunctor.cpp b/paddle/function/GemmFunctor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dc83278d8e3b70101962521b76b902046629c2dd --- /dev/null +++ b/paddle/function/GemmFunctor.cpp @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "GemmFunctor.h" +#include "paddle/math/MathFunctions.h" + +namespace paddle { + +template +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { +#ifdef PADDLE_USE_EIGEN_FOR_BLAS + EigenBlasGemm::compute( + transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); +#else + gemm(transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +#endif + } +}; + +template +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + hl_matrix_mul((T*)A, + transA == false ? HPPL_OP_N : HPPL_OP_T, + (T*)B, + transB == false ? HPPL_OP_N : HPPL_OP_T, + C, + M, + N, + K, + alpha, + beta, + lda, + ldb, + ldc); + } +}; + +template class BlasGemm; +template class BlasGemm; + +} // namespace paddle diff --git a/paddle/function/GemmFunctor.h b/paddle/function/GemmFunctor.h index d5db5cf5e7a855d89b262fe8cf42aa2c55f419f1..0809953b4eb17c25eadcce7f474a3dab0469bba1 100644 --- a/paddle/function/GemmFunctor.h +++ b/paddle/function/GemmFunctor.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/math/MathFunctions.h" +#include "TensorType.h" namespace paddle { @@ -24,73 +24,42 @@ namespace paddle { // of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul // interface. template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc); +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc); }; +// TODO(hedaoyuan): Since the definition of the real type in the Paddle +// conflicts with the Eigen library, so compile the Eigen code can not +// include the Paddle header file. And need an EigenBlasGemm template class +// that does not contain the DeviceType parameter. +// I will fix this problem and merge BlasGemm and EigenBlasGemm into one. template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc) { - gemm(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); - } -}; - -template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc) { - hl_matrix_mul((T*)A, - transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T, - (T*)B, - TransB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T, - C, - M, - N, - K, - alpha, - beta, - lda, - ldb, - ldc); - } +struct EigenBlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc); }; } // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index 30f567eaf8248a8fba1b461a2bdbf2aab13f9e08..d201fac65e0459050304195140e1aae081468f43 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -57,11 +57,14 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap, } void MKLDNNFcLayer::convertWeightsFromPaddle() { - if (FLAGS_use_mkldnn_wgt) { + if (hasInitedWgt_) { return; } - if (hasInitedWgt_) { + // TODO(TJ): dst format should get from wgtVal_ + int dstFmt = PARAM_FORMAT_MKLDNN_OI; + int srcFmt = weight_->getParameterPtr()->getHeaderFormat(); + if (srcFmt == dstFmt) { return; } @@ -78,6 +81,7 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() { MatrixPtr paddleWgtT; paddleWgt->transpose(paddleWgtT, true); weight_->getW()->copyFrom(*paddleWgtT); + weight_->getParameterPtr()->setHeaderFormat(dstFmt); hasInitedWgt_ = true; } diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp index 99c8c4948c9b05ad15d1217ebb70026bbd48453f..de1635be2af37cd0ba49010199a417090865b0e4 100644 --- a/paddle/gserver/tests/MKLDNNTester.cpp +++ b/paddle/gserver/tests/MKLDNNTester.cpp @@ -330,9 +330,7 @@ void MKLDNNTester::run(const TestConfig& dnn, log_ = log; lvl_ = level; - // Firstly test FLAGS_use_mkldnn_wgt = false - FLAGS_use_mkldnn_wgt = false; - // reset and run once + // Firstly test mkldnn init from PARAM_FORMAT_ORIGINAL weight reset(dnn, ref, batchSize); randomWgtDatas(); clearWgtDiffs(); @@ -342,17 +340,32 @@ void MKLDNNTester::run(const TestConfig& dnn, runOnce(); } - // Then test FLAGS_use_mkldnn_wgt = true - FLAGS_use_mkldnn_wgt = true; - // after run once the mkldnn weight has been stored in dnnlayer + if (parameters_[DNN].empty()) { + // has no paramters + return; + } + + // After run some iterations, the mkldnn weight has been stored in dnnLayer + // and we can also get the mkldnn weight parameter header format. + // Weight parameter should always be index 0 (and bias index 1). + // TODO(TJ): should also consider mean and var format when batchnorm ready + int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat(); + int refWgtFmt = parameters_[REF][0]->getHeaderFormat(); + if (dnnWgtFmt == refWgtFmt) { + // weight format are equal, so no need check more + return; + } + // then save the weights and restart again vector dnnWgts, refWgts; CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size()); saveWgt(parameters_[DNN], dnnWgts); saveWgt(parameters_[REF], refWgts); - // restart again with flag true + // restart again with dnn weight format reset(dnn, ref, batchSize); + // TODO(TJ): should also considerate mean and var format when batchnorm ready + parameters_[DNN][0]->setHeaderFormat(dnnWgtFmt); // restore wgt restoreWgt(dnnWgts, parameters_[DNN]); diff --git a/paddle/gserver/tests/MKLDNNTester.h b/paddle/gserver/tests/MKLDNNTester.h index 522eeaf24b1949abac057a1e59e9977610be23c0..e55e4493ffdfe45b8cfdee423febd1878b8b3d8a 100644 --- a/paddle/gserver/tests/MKLDNNTester.h +++ b/paddle/gserver/tests/MKLDNNTester.h @@ -108,7 +108,7 @@ private: * if many(>failRate) wrong(abs(dnn-ref)/abs(ref)>thres) points return the * max(diff/ref) * else return sum(abs(a-b)) / sum(abs(b)) - * The return value should smaller than eps when passing. + * The return value should be smaller than eps when passing. */ double getDelta(const real* d1, const real* d2, diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index f930c72fde3f5e0a6a45cb6bfd3507a4f48028fc..d36f72360f8ebd2033fb3e8c0e1b30911abba362 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -269,7 +269,8 @@ TEST(Compare, img_conv2) { bool useGpu = FLAGS_use_gpu; double eps = FLAGS_checkgrad_eps; FLAGS_use_gpu = true; - FLAGS_checkgrad_eps = 1e-2; + // Sometimes, this unit test will fail with 1e-2 + FLAGS_checkgrad_eps = 4e-2; compareNetwork(config_file_a, config_file_b); FLAGS_use_gpu = useGpu; FLAGS_checkgrad_eps = eps; diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index f61e67a32906083881dd7f47433521876be9b355..a270bd59581520859d43cddd2fc0cfa72080f46d 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -27,7 +27,7 @@ limitations under the License. */ // between host and device. Allocates too much would reduce the amount // of memory available to the system for paging. So, by default, we // should set false to use_pinned_memory. -DEFINE_bool(use_pinned_memory, false, "If set, allocate cpu pinned memory."); +DEFINE_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory."); namespace paddle { namespace memory { diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 207025f9b1c64f0f8943f9fae5edefc9328a1d26..29bc26f9d3bca0e30896657431f9a9bb1dac0d1d 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -13,22 +13,38 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/memory/memory.h" + +#include // for transform +#include // for memcpy +#include // for unique_ptr +#include // for call_once + +#include "glog/logging.h" + #include "paddle/memory/detail/buddy_allocator.h" #include "paddle/memory/detail/system_allocator.h" +#include "paddle/platform/gpu_info.h" -#include // for memcpy +DECLARE_double(fraction_of_gpu_memory_to_use); namespace paddle { namespace memory { -detail::BuddyAllocator* GetCPUBuddyAllocator() { - static detail::BuddyAllocator* a = nullptr; - if (a == nullptr) { - a = new detail::BuddyAllocator(new detail::CPUAllocator, - platform::CpuMinChunkSize(), - platform::CpuMaxChunkSize()); - } - return a; +using BuddyAllocator = detail::BuddyAllocator; + +std::once_flag cpu_allocator_flag; +std::once_flag gpu_allocator_flag; + +BuddyAllocator* GetCPUBuddyAllocator() { + static std::unique_ptr a{nullptr}; + + std::call_once(cpu_allocator_flag, [&]() { + a.reset(new BuddyAllocator(new detail::CPUAllocator, + platform::CpuMinChunkSize(), + platform::CpuMaxChunkSize())); + }); + + return a.get(); } template <> @@ -48,20 +64,36 @@ size_t Used(platform::CPUPlace place) { #ifndef PADDLE_ONLY_CPU -detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { - static detail::BuddyAllocator** as = NULL; - if (as == NULL) { +BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { + using BuddyAllocVec = std::vector; + static std::unique_ptr as{ + new BuddyAllocVec, [](BuddyAllocVec* p) { + std::for_each(p->begin(), p->end(), + [](BuddyAllocator* p) { delete p; }); + }}; + + // GPU buddy allocators + auto& allocators = *as.get(); + + // GPU buddy allocator initialization + std::call_once(gpu_allocator_flag, [&]() { int gpu_num = platform::GetDeviceCount(); - as = new detail::BuddyAllocator*[gpu_num]; + allocators.reserve(gpu_num); for (int gpu = 0; gpu < gpu_num; gpu++) { platform::SetDeviceId(gpu); - as[gpu] = new detail::BuddyAllocator(new detail::GPUAllocator, - platform::GpuMinChunkSize(), - platform::GpuMaxChunkSize()); + allocators.emplace_back(new BuddyAllocator(new detail::GPUAllocator, + platform::GpuMinChunkSize(), + platform::GpuMaxChunkSize())); } - } + VLOG(3) << "\n\nNOTE: each GPU device use " + << FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n" + << "You can set environment variable '" + << platform::kEnvFractionGpuMemoryToUse + << "' to change the fraction of GPU usage.\n\n"; + }); + platform::SetDeviceId(gpu_id); - return as[gpu_id]; + return allocators[gpu_id]; } template <> diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 72351b9dfa63513713463bb47a3684f0dfd84ad3..11bbb881874ec50e1132547336fc6fb6b42bcc4f 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include "paddle/platform/gpu_info.h" #include "paddle/platform/place.h" namespace paddle { diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc index d24d83f299fdb071e60fa3cc7b223c0228cb29af..0ae1e99452973feb6d085dd6ef51e2afca988f59 100644 --- a/paddle/operators/gather_test.cc +++ b/paddle/operators/gather_test.cc @@ -45,4 +45,8 @@ TEST(Gather, GatherData) { for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); + + delete src; + delete index; + delete output; } diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index affdd1ac2cd486930881ee6b34a4b32f41df7ee9..1e86fc3d166077265e0f433a6712b0665ea5a152 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -25,8 +25,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const float alpha, const float* A, const float* B, const float beta, float* C, platform::DeviceContext* context) { - int lda = K; - int ldb = N; + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); @@ -40,8 +40,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const double* B, const double beta, double* C, platform::DeviceContext* context) { - int lda = K; - int ldb = N; + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 49d0f43508b1ee3df0c6b5987942970e1649e310..d3d0e55a674587fb04f43f24d0790de4358f035a 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op").AsNoGradient(); + AddOutput("Out", "The output of mean op").NotInGradient(); AddComment("Mean Operator"); } }; diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index fcb703e63bd5a82f9ffac2bf64e06fd0218dbdaa..9848af280b62729bef9243052ceae0b7d8f4c6f5 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -55,9 +55,10 @@ class MeanGradKernel : public framework::OpKernel { IG->mutable_data(context.GetPlace()); T ig_size = (T)framework::product(IG->dims()); + Eigen::DSizes bcast(ig_size); EigenVector::Flatten(*IG).device(context.GetEigenDevice()) = - EigenScalar::From(*OG) / ig_size; + (EigenVector::From(*OG) / ig_size).broadcast(bcast); } }; diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 95d19fb6aad37143e65759b03e12e3e78bce5915..460e458ca4f7f40746f0dbf7e258a165faa88e1a 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -18,6 +18,8 @@ namespace paddle { namespace operators { +using framework::Tensor; + class MulOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -59,10 +61,23 @@ class MulOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "MulGrad"; - return ""; + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + auto *y_grad = ctx.Output(framework::GradVarName("Y")); + PADDLE_ENFORCE(x_dims[0] == out_dims[0], + "Out@GRAD M X N must equal to X dims 0, M "); + PADDLE_ENFORCE(y_dims[1] == out_dims[1], + "Out@GRAD M X N must equal to Y dims 1, N "); + + x_grad->Resize(x_dims); + y_grad->Resize(y_dims); } }; @@ -72,3 +87,5 @@ class MulOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); +REGISTER_OP_CPU_KERNEL(mul_grad, + ops::MulGradKernel); diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index 346a7e505d123b5e4e831daa39a1f6349b3dcccf..a81444dbe63edeecedc5d822c65ff56c42b5db90 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -17,3 +17,5 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); +REGISTER_OP_GPU_KERNEL(mul_grad, + ops::MulGradKernel); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index b7812fd1a7a72f5ce543e18c8b7b5b51deff2204..8facc0281449785bf40726f23ca2fd5d166ff272 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -31,18 +31,34 @@ template class MulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - Eigen::array, 1> dim_pair = { - {Eigen::IndexPair(1, 0)}}; - auto* input0 = context.Input("X"); - auto* input1 = context.Input("Y"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); - auto X = EigenMatrix::From(*input0); - auto Y = EigenMatrix::From(*input1); - auto Z = EigenMatrix::From(*output); - auto& place = context.GetEigenDevice(); - - Z.device(place) = X.contract(Y, dim_pair); + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* Z = context.Output("Out"); + Z->mutable_data(context.GetPlace()); + auto* device_context = + const_cast(context.device_context_); + math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); + } +}; + +template +class MulGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Y = ctx.Input("Y"); + auto* dOut = ctx.Input(framework::GradVarName("Out")); + + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dY = ctx.Output(framework::GradVarName("Y")); + dX->mutable_data(ctx.GetPlace()); + dY->mutable_data(ctx.GetPlace()); + auto* device_context = + const_cast(ctx.device_context_); + // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N + math::matmul(*dOut, false, *Y, true, 1, dX, 0, device_context); + // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K + math::matmul(*X, true, *dOut, false, 1, dY, 0, device_context); } }; diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index c36fe8d6b58a0afa568e31e43567baa5f261c7d0..a7d710511093dfbe13a13b1222b0230bba0398bd 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -85,7 +85,14 @@ NetOp::NetOp(const std::string& type, const framework::OperatorBase::VarNameMap& inputs, const framework::OperatorBase::VarNameMap& outputs, const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} + : framework::OperatorBase(type, inputs, outputs, attrs) {} + +std::unique_ptr NetOp::Clone() const { + PADDLE_ENFORCE( + add_op_done_, + "Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone"); + return std::unique_ptr(new NetOp(*this)); +} } // namespace operators } // namespace paddle diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 4a3408c158a029a96740717280c1562671fa938f..885ac6eeca65998dea62c1db40b9261cceb97805 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -41,6 +41,16 @@ class NetOp : public framework::OperatorBase { NetOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); + NetOp(const NetOp& o) : framework::OperatorBase(o.type_, {}, {}, o.attrs_) { + this->ops_.reserve(o.ops_.size()); + std::transform( + o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_), + [](const std::unique_ptr& op) { + return std::unique_ptr(op->Clone()); + }); + this->CompleteAddOp(); + } + /** * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch @@ -74,21 +84,27 @@ class NetOp : public framework::OperatorBase { return true; } + void AddOp(const framework::OperatorBase& op) { AddOp(op.Clone()); } + /** * @brief Add an operator by ptr */ - void AddOp(const std::shared_ptr& op) { + void AddOp(std::unique_ptr op) { PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); - ops_.push_back(op); + ops_.push_back(std::move(op)); } - void InsertOp(size_t pos, const std::shared_ptr& op) { + void InsertOp(size_t pos, std::unique_ptr op) { PADDLE_ENFORCE(!add_op_done_, "Cannot InsertOp when this network is sealed"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); PADDLE_ENFORCE_LE(pos, ops_.size(), "Out of range"); - ops_.insert(ops_.begin() + pos, op); + ops_.insert(ops_.begin() + pos, std::move(op)); + } + + void InsertOp(size_t pos, const framework::OperatorBase& op) { + InsertOp(pos, op.Clone()); } void CompleteAddOp(bool calculate = true); @@ -98,7 +114,9 @@ class NetOp : public framework::OperatorBase { bool IsNetOp() const override; std::vector OutputVars(bool has_intermediate) const override; - std::vector> ops_; + std::unique_ptr Clone() const override; + + std::vector> ops_; private: bool add_op_done_{false}; diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index 0cef71de6a032674a54387986f65f17ca99b400e..e9598610c0a74e08a613a397109ad65994821498 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -13,6 +13,7 @@ static int run_cnt = 0; class TestOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; + DEFINE_OP_CLONE_METHOD(TestOp); void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { @@ -37,15 +38,12 @@ TEST(OpKernel, all) { auto net = std::make_shared(); ASSERT_NE(net, nullptr); - auto op1 = std::shared_ptr( + net->AddOp(std::unique_ptr( new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, - {{"Out", {"y"}}}, {})); - net->AddOp(op1); - - auto op2 = std::shared_ptr( + {{"Out", {"y"}}}, {}))); + net->AddOp(std::unique_ptr( new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}, - {{"Out", {"z"}}}, {})); - net->AddOp(op2); + {{"Out", {"z"}}}, {}))); net->CompleteAddOp(); AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, @@ -60,15 +58,31 @@ TEST(OpKernel, all) { TEST(NetOp, insert_op) { NetOp net; - auto op1 = std::shared_ptr( + auto op1 = std::unique_ptr( new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"Out", {"y"}}}, {})); - net.AddOp(op1); - net.InsertOp(0, op1); + net.AddOp(*op1); + net.InsertOp(0, *op1); ASSERT_EQ(2UL, net.ops_.size()); - net.InsertOp(2, op1); + net.InsertOp(2, std::move(op1)); ASSERT_EQ(3UL, net.ops_.size()); } +TEST(NetOp, Clone) { + NetOp net; + net.AddOp( + std::unique_ptr(new framework::NOP{"empty", {}, {}, {}})); + net.AddOp(std::unique_ptr( + new framework::NOP{"empty2", {}, {}, {}})); + net.CompleteAddOp(true); + auto new_net_op = net.Clone(); + ASSERT_NE(new_net_op, nullptr); + ASSERT_TRUE(new_net_op->IsNetOp()); + auto* new_net = static_cast(new_net_op.get()); + ASSERT_EQ(2, new_net->ops_.size()); + ASSERT_EQ(new_net->ops_[0]->Type(), "empty"); + ASSERT_EQ(new_net->ops_[1]->Type(), "empty2"); +} + } // namespace operators } // namespace paddle diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index 171a0bd2ae80245939a9237f51d195e8bc9178fc..bcfa817de8242153b164fa091309f19a6ad8a246 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -34,7 +34,8 @@ class RecurrentAlgorithm { void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const; - void Init(rnn::Argument* arg, std::shared_ptr* stepnet) { + void Init(rnn::Argument* arg, + std::unique_ptr* stepnet) { PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before."); arg_ = arg; stepnet_ = stepnet; @@ -63,7 +64,7 @@ class RecurrentAlgorithm { void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const; private: - std::shared_ptr* stepnet_; + std::unique_ptr* stepnet_; rnn::Argument* arg_; mutable size_t seq_len_; }; @@ -80,7 +81,8 @@ class RecurrentGradientAlgorithm { * operator. */ public: - void Init(rnn::Argument* arg, std::shared_ptr* stepnet) { + void Init(rnn::Argument* arg, + std::unique_ptr* stepnet) { PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before."); arg_ = std::move(arg); stepnet_ = stepnet; @@ -107,16 +109,23 @@ class RecurrentGradientAlgorithm { private: rnn::Argument* arg_; mutable size_t seq_len_; - std::shared_ptr* stepnet_; + std::unique_ptr* stepnet_; }; -class RecurrentOp final : public framework::OperatorBase { +class RecurrentOp : public framework::OperatorBase { public: RecurrentOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); + + RecurrentOp(const RecurrentOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement copy ctor well. + PADDLE_THROW("Not implemented"); + } /** - * InferShape must be called before Run. - */ + * InferShape must be called before Run. + */ void InferShape(const framework::Scope& scope) const override { alg_.InferShape(scope); } @@ -126,23 +135,32 @@ class RecurrentOp final : public framework::OperatorBase { alg_.Run(scope, dev_ctx); } - void set_stepnet(std::shared_ptr net) { stepnet_ = net; } - const NetOp& stepnet() const { return *stepnet_; } + void set_stepnet(std::unique_ptr net) { + stepnet_ = std::move(net); + } + const OperatorBase& stepnet() const { return *stepnet_; } static const rnn::ArgumentName kArgName; private: RecurrentAlgorithm alg_; rnn::Argument arg_; - std::shared_ptr stepnet_; + std::unique_ptr stepnet_; }; -class RecurrentGradientOp final : public framework::OperatorBase { +class RecurrentGradientOp : public framework::OperatorBase { public: RecurrentGradientOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, const framework::AttributeMap& attrs); + RecurrentGradientOp(const RecurrentGradientOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement Copy ctor. + PADDLE_THROW("Not Implemented"); + } + /** * InferShape must be called before Run. */ @@ -157,12 +175,14 @@ class RecurrentGradientOp final : public framework::OperatorBase { static const rnn::ArgumentName kArgName; - void set_stepnet(const std::shared_ptr& net) { stepnet_ = net; } - const NetOp& stepnet() const { return *stepnet_; } + void set_stepnet(std::unique_ptr net) { + stepnet_ = std::move(net); + } + const OperatorBase& stepnet() const { return *stepnet_; } private: RecurrentGradientAlgorithm alg_; - std::shared_ptr stepnet_; + std::unique_ptr stepnet_; rnn::Argument arg_; }; diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 8375d988045dc24fa1109646b46ff477e2a78132..6825dce332adc0dc11dda187d1bd367875b8603e 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -17,7 +17,9 @@ namespace paddle { namespace operators { -class RowWiseAddOp : public framework::OperatorWithKernel { +using framework::Tensor; + +class RowwiseAddOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -34,9 +36,9 @@ class RowWiseAddOp : public framework::OperatorWithKernel { } }; -class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker { +class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker { public: - RowWiseAddOpMaker(framework::OpProto *proto, + RowwiseAddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The left input of row-wise add op, must be matrix"); @@ -49,12 +51,32 @@ for i in xrange(X.shape[0]): )DOC"); } }; +class RowwiseAddGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto dims0 = ctx.Input("X")->dims(); + auto dims1 = ctx.Input("b")->dims(); + PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1") + ctx.Output(framework::GradVarName("X"))->Resize(dims0); + ctx.Output(framework::GradVarName("b"))->Resize(dims1); + } +}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp, - ops::RowWiseAddOpMaker); +REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker, + rowwise_add_grad, ops::RowwiseAddGradOp); +REGISTER_OP_CPU_KERNEL( + rowwise_add, ops::RowwiseAddKernel); REGISTER_OP_CPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add_grad, + ops::RowwiseAddGradKernel); diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index 86f80b81228a69ac4c05a4693901570f2b9966e0..cbc61ad3e117fc79a674ca21831d3fec59d1ec5b 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -17,4 +17,4 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add, ops::RowwiseAddKernel); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 01f88f2198774fbaa4c98ff9bf286f2f08496a9a..232135c38de68d4002e044972b282b43a1374c72 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -28,7 +28,7 @@ template ; template -class RowWiseAddKernel : public framework::OpKernel { +class RowwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto out = context.Output("Out"); @@ -47,5 +47,25 @@ class RowWiseAddKernel : public framework::OpKernel { } }; +template +class RowwiseAddGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dOut = context.Input(framework::GradVarName("Out")); + auto* dX = context.Output(framework::GradVarName("X")); + auto* db = context.Output(framework::GradVarName("b")); + dX->mutable_data(context.GetPlace()); + db->mutable_data(context.GetPlace()); + + auto OutGrad = EigenMatrix::From(*dOut); + auto place = context.GetEigenDevice(); + EigenMatrix::From(*dX).device(place) = OutGrad; + + // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html + // colwise add + Eigen::array dims{{1}}; /* dimension to reduce */ + EigenVector::Flatten(*db).device(place) = OutGrad.sum(dims); + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc index 4449ce6564396f1971506efb7458c00c834db19f..26fdaff1460a297fa638181641991f732533fe52 100644 --- a/paddle/operators/scatter_test.cc +++ b/paddle/operators/scatter_test.cc @@ -49,4 +49,8 @@ TEST(scatter, ScatterUpdate) { EXPECT_EQ(output->data()[i], float(i - 4)); for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], float(0)); for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output->data()[i], float(0)); + + delete src; + delete index; + delete output; } diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index bfb449d0b029409eda4177fc7643810ee6a1df3d..a0b5000ffbf54364e15f87870913926a071fa972 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -30,7 +30,7 @@ class SGDOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto param = ctx.Input("param"); auto grad = ctx.Input("grad"); - auto param_out = ctx.Output(0); + auto param_out = ctx.Output("param_out"); float lr = ctx.op_.GetAttr("learning_rate"); param_out->mutable_data(ctx.GetPlace()); diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index d773a4f2d50e82146a729b1cda085ce86ade89cc..761c6de8d4d2150b30b97b58da95da3d5f33db63 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -44,7 +44,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("Y")->dims()); } }; diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h index 11ab923eb346c1f8de3a6bbebdfa874b6530004a..b01a9b3f23283471f8846325075719ba0e75ed35 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -37,7 +37,7 @@ class SigmoidKernel : public framework::OpKernel { auto Y = EigenVector::Flatten(*output); auto place = context.GetEigenDevice(); - Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); + Y.device(place) = 1. / (1. + (-X).exp()); } }; diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index ebe36d49376882fe4c1013e19dcf71f452b3e501..f0311095012d944768d80abe423d4a9bfc0e97f5 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -48,7 +48,8 @@ Parameter::Parameter(const ParameterConfig& config, bool useGpu, bool doInit) deviceId_(-1), sharedCount_(0), updateCounter_(0), - updated_(false) { + updated_(false), + headerFormat_(PARAM_FORMAT_ORIGINAL) { setID(-1); /* capture uninitialized id */ if (useGpu_ && FLAGS_parallel_nn) { /* gpu environment is specified by device property */ @@ -285,7 +286,7 @@ bool Parameter::save(const std::string& filename) const { bool Parameter::save(std::ostream& s) const { CpuVector vec(*bufs_[PARAMETER_VALUE].get()); Header header; - header.version = kFormatVersion; + header.format = headerFormat_; header.valueSize = sizeof(real); header.size = getSize(); @@ -344,8 +345,9 @@ bool Parameter::load(std::istream& s) { Header header; CHECK(s.read(reinterpret_cast(&header), sizeof(header))) << "Fail to read parameter " << getName(); - CHECK_EQ(header.version, kFormatVersion) << "Incorrect format version: " - << header.version; + CHECK(isHeaderFormatSupported(header.format)) << "Incorrect format version: " + << header.format; + headerFormat_ = header.format; CHECK_EQ(header.size, getSize()) << "The size (" << header.size << ") in the file does not match the size " << "(" << getSize() << ") of the parameter: " << getName(); diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index 0bac76f068ec22bec52766b43e331fe109a34188..e31cbc3dee6c57851c241e117dbbd9b701db9d2c 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -34,6 +34,20 @@ limitations under the License. */ namespace paddle { +typedef enum { + /// The paddle original basic format + PARAM_FORMAT_ORIGINAL = 0, + + /// See mkldnn_memory_format_t in + /// https://github.com/01org/mkl-dnn/blob/master/include/mkldnn_types.h + /// for a detailed description. + /// 2D weights tensor in the format (output channels, input channels). + PARAM_FORMAT_MKLDNN_OI, + + /// The total format items numbers + PARAM_FORMAT_ITEMS, +} PARAM_FORMAT; + class SparsePrefetchRowCpuMatrix; class Parameter; @@ -242,14 +256,30 @@ public: /// Initialize the value to 0 void zeroMem(); - static const int kFormatVersion = 0; /// file header structure struct Header { - int32_t version; // = 0, file format version + int32_t format; // = PARAM_FORMAT uint32_t valueSize; // = sizeof(real) uint64_t size; // = getSize() }; + /** + * @brief Is the header format supported. + */ + static bool isHeaderFormatSupported(int32_t fmt) { + return fmt < PARAM_FORMAT_ITEMS; + } + + /** + * @brief Get the format in header. + */ + int getHeaderFormat() { return headerFormat_; } + + /** + * @brief Set the format in header. + */ + void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; } + /** * @brief Parameter Update Hook. * @@ -321,6 +351,9 @@ protected: bool updated_; SparseFormat format_; + /// The header format for saving or loading param + int32_t headerFormat_; + std::vector> updaterHooks_; public: diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index acfc0639736beb82df41b851664e7bcd079b5eb1..120eb1e4af9cef43e76e27d4ad66acfbbd597a36 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,7 +1,7 @@ cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) -nv_library(gpu_info SRCS gpu_info.cc DEPS gflags) +nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) @@ -9,6 +9,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece) +cc_test(environment_test SRCS environment_test.cc DEPS stringpiece) IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 15fdf7a94f462a87f7edae1429eb0c4da0b17a84..81448897e95eb05f4ce7de8683d98e05bade77cb 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -86,7 +86,7 @@ struct EnforceNotMet : public std::exception { 2 + sizeof(void*) * 2, call_stack[i], demangled, addr_offset); } else { - sout << string::Sprintf("%-3d %*0p %s\n", i, 2 + sizeof(void*) * 2, + sout << string::Sprintf("%-3d %*0p\n", i, 2 + sizeof(void*) * 2, call_stack[i]); } } diff --git a/paddle/platform/environment.h b/paddle/platform/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..4edcce932edc61453cef74f2c4ee0f72496b3677 --- /dev/null +++ b/paddle/platform/environment.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/platform/enforce.h" +#include "paddle/string/piece.h" + +extern char** environ; // for environment variables + +namespace paddle { +namespace platform { + +inline void SetEnvVariable(const std::string& name, const std::string& value) { + PADDLE_ENFORCE_NE(setenv(name.c_str(), value.c_str(), 1), -1, + "Failed to set environment variable %s=%s", name, value); +} + +inline void UnsetEnvVariable(const std::string& name) { + PADDLE_ENFORCE_NE(unsetenv(name.c_str()), -1, + "Failed to unset environment variable %s", name); +} + +inline bool IsEnvVarDefined(const std::string& name) { + return std::getenv(name.c_str()) != nullptr; +} + +inline std::string GetEnvValue(const std::string& name) { + PADDLE_ENFORCE(IsEnvVarDefined(name), + "Tried to access undefined environment variable %s", name); + return std::getenv(name.c_str()); +} + +inline std::vector GetAllEnvVariables() { + std::vector vars; + for (auto var = environ; *var != nullptr; ++var) { + auto tail = string::Index(*var, "="); + auto name = string::SubStr(*var, 0, tail).ToString(); + vars.push_back(name); + } + return vars; +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/environment_test.cc b/paddle/platform/environment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f136527215d6a676cfa1a3b08f09dfd3ab24a90 --- /dev/null +++ b/paddle/platform/environment_test.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/environment.h" + +#include "glog/logging.h" +#include "gtest/gtest.h" + +TEST(ENVIRONMENT, ACCESS) { + namespace platform = paddle::platform; + namespace string = paddle::string; + + platform::SetEnvVariable("PADDLE_USE_ENV", "TRUE"); + + EXPECT_TRUE(platform::IsEnvVarDefined("PADDLE_USE_ENV")); + EXPECT_EQ(platform::GetEnvValue("PADDLE_USE_ENV"), "TRUE"); + + platform::UnsetEnvVariable("PADDLE_USE_ENV"); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV")); + + platform::SetEnvVariable("PADDLE_USE_ENV1", "Hello "); + platform::SetEnvVariable("PADDLE_USE_ENV2", "World, "); + platform::SetEnvVariable("PADDLE_USE_ENV3", "PaddlePaddle!"); + + std::string env_info; + auto vars = platform::GetAllEnvVariables(); + for_each(vars.begin(), vars.end(), [&](const std::string& var) { + env_info += platform::GetEnvValue(var); + }); + + EXPECT_TRUE(string::Contains(env_info, "Hello World, PaddlePaddle!")); + platform::UnsetEnvVariable("PADDLE_USE_ENV1"); + platform::UnsetEnvVariable("PADDLE_USE_ENV2"); + platform::UnsetEnvVariable("PADDLE_USE_ENV3"); + + env_info.clear(); + vars = platform::GetAllEnvVariables(); + for_each(vars.begin(), vars.end(), [&](const std::string& var) { + env_info += platform::GetEnvValue(var); + }); + + EXPECT_FALSE(string::Contains(env_info, "Hello World, PaddlePaddle!")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV1")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV2")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV3")); +} diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index edeb3ecd7bf8b87333813eee5b40f71030f6609f..be381a4e26cf0eb41f5b3de88bd03ad8901683cc 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/platform/gpu_info.h" + #include "gflags/gflags.h" + #include "paddle/platform/enforce.h" +#include "paddle/platform/environment.h" DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, "Default use 95% of GPU memory for PaddlePaddle," @@ -70,6 +73,13 @@ size_t GpuMaxChunkSize() { GpuMemoryUsage(available, total); + if (IsEnvVarDefined(kEnvFractionGpuMemoryToUse)) { + auto val = std::stod(GetEnvValue(kEnvFractionGpuMemoryToUse)); + PADDLE_ENFORCE_GT(val, 0.0); + PADDLE_ENFORCE_LE(val, 1.0); + FLAGS_fraction_of_gpu_memory_to_use = val; + } + // Reserving the rest memory for page tables, etc. size_t reserving = (1 - FLAGS_fraction_of_gpu_memory_to_use) * total; diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index d3a5f5f13fdd3dd59eb43465da4a64b0d8d95e5b..ed2420b8740e583d307f6836a70fe7e1c780e28b 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -18,10 +18,15 @@ limitations under the License. */ #include #include +#include namespace paddle { namespace platform { +//! Environment variable: fraction of GPU memory to use on each device. +const std::string kEnvFractionGpuMemoryToUse = + "PADDLE_FRACTION_GPU_MEMORY_TO_USE"; + //! Get the total number of GPU devices in system. int GetDeviceCount(); diff --git a/paddle/pserver/ParameterServer2.cpp b/paddle/pserver/ParameterServer2.cpp index d7c1d4f788f44c6bfcec040ba24bdc454348c911..54f5c4c0fb4994871edc7a1e52237c9f903ce63b 100644 --- a/paddle/pserver/ParameterServer2.cpp +++ b/paddle/pserver/ParameterServer2.cpp @@ -1032,8 +1032,8 @@ void ParameterServer2::loadValueVector(const LoadValueRequest& request, Parameter::Header header; CHECK(fs.read(reinterpret_cast(&header), sizeof(header))) << "Fail to read parameters in pserver"; - CHECK_EQ(header.version, Parameter::kFormatVersion) - << "Incorrect format version: " << header.version; + CHECK(Parameter::isHeaderFormatSupported(header.format)) + << "Incorrect format version: " << header.format; CHECK_EQ(header.size, (size_t)size_) << "The size (" << header.size << ") in the file does not match the size " << "(" << size_ << ") of the pserver: " << serverId_; @@ -1063,7 +1063,8 @@ void ParameterServer2::saveValueVector(const SaveValueRequest& request, CpuVector& vec = vectors_[PARAMETER_APPLY] ? *vectors_[PARAMETER_APPLY] : *vectors_[PARAMETER_VALUE]; Parameter::Header header; - header.version = Parameter::kFormatVersion; + // TODO(TJ): save param headerFormat_ + header.format = PARAM_FORMAT_ORIGINAL; header.valueSize = sizeof(real); header.size = size_; diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 7c12664aed7957d72c0b2d12ca2afe448eac5a98..2941662f349baf57d1fe8188e88ce21d5de07750 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -146,7 +146,8 @@ RUN apt-get update &&\ pip install /*.whl; apt-get install -f -y && \ apt-get clean -y && \ rm -f /*.whl && \ - paddle version + paddle version && \ + ldconfig ${DOCKERFILE_CUDNN_DSO} ${DOCKERFILE_GPU_ENV} ADD go/cmd/pserver/pserver /usr/bin/ diff --git a/paddle/trainer/TrainerConfigHelper.cpp b/paddle/trainer/TrainerConfigHelper.cpp index eba40862b926cfe863c569e73a6a3ceabcf1f3b4..a0a365aa0bb0ac26939a02c1cd626d0c17c6a9fe 100644 --- a/paddle/trainer/TrainerConfigHelper.cpp +++ b/paddle/trainer/TrainerConfigHelper.cpp @@ -29,7 +29,6 @@ DECLARE_bool(with_gpu); DECLARE_bool(parallel_nn); DECLARE_string(config_args); DECLARE_bool(use_mkldnn); -DECLARE_bool(use_mkldnn_wgt); const char *kConfigParserModuleName = "paddle.trainer.config_parser"; const char *kConfigParserFuncName = "parse_config_and_serialize"; @@ -47,7 +46,6 @@ TrainerConfigHelper::TrainerConfigHelper(const std::string &configFilePath) << ",with_cost=" << FLAGS_with_cost << ",use_gpu=" << FLAGS_use_gpu << ",parallel_nn=" << FLAGS_parallel_nn << ",use_mkldnn=" << FLAGS_use_mkldnn - << ",use_mkldnn_wgt=" << FLAGS_use_mkldnn_wgt << ",cudnn_version=" << hl_get_cudnn_lib_version(); if (!FLAGS_config_args.empty()) { configArgs << "," << FLAGS_config_args; diff --git a/paddle/utils/Flags.cpp b/paddle/utils/Flags.cpp index 600c83a8487191895de635dd8433f6c44e86ce77..ab1c181c62cdbee8cc5f804ec9aaf63ac5464ad6 100644 --- a/paddle/utils/Flags.cpp +++ b/paddle/utils/Flags.cpp @@ -27,7 +27,6 @@ DEFINE_bool(use_mkldnn, false, "Default still keep use CPU training"); DEFINE_bool(use_mkldnn, false, "Only support CPU training"); #endif -DEFINE_bool(use_mkldnn_wgt, false, "Init weight from CPU weight"); DEFINE_bool(parallel_nn, false, "Whether to use multi-threads to calculate one neural network." diff --git a/paddle/utils/Flags.h b/paddle/utils/Flags.h index 0aca4c0ee036ee8490c0ceca7279df876dc21947..1832bb515ec85df3d7733e01b063a01ad6a3b282 100644 --- a/paddle/utils/Flags.h +++ b/paddle/utils/Flags.h @@ -41,4 +41,3 @@ DECLARE_string(predict_file); DECLARE_bool(prev_batch_state); DECLARE_string(init_model_path); DECLARE_bool(use_mkldnn); -DECLARE_bool(use_mkldnn_wgt); diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 96fad9b42e04a88fdcbda093683b57451b2a3e41..ce57a0713092723b6a99b2416e06ff1a436f043b 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -25,3 +25,5 @@ py_test(test_operator SRCS test_operator.py) # py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) +py_test(test_sgd_op SRCS test_sgd_op.py) +py_test(test_gradient_checker SRCS test_gradient_checker.py) diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index 501cf6110ff745b8a6022b463bc9cc3a70145c60..8b8e2f444be1169c23784321721c5d8154541fcf 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -1,6 +1,7 @@ import unittest import numpy +import itertools import paddle.v2.framework.core as core from paddle.v2.framework.op import Operator @@ -8,6 +9,7 @@ __all__ = ['get_numeric_gradient'] def create_op(op_type): + # TODO need to set attrs kwargs = dict() for in_name in Operator.get_op_input_names(op_type): kwargs[in_name] = in_name @@ -66,7 +68,6 @@ def get_numeric_gradient(op, local_scope.find_var(output).get_tensor().alloc_float(core.CPUPlace( )) - # TODO(yuyang18): Only CPU is support now. cpu_ctx = core.DeviceContext.create(core.CPUPlace()) def get_output(): @@ -109,12 +110,110 @@ def get_numeric_gradient(op, class GradientChecker(unittest.TestCase): - def assert_is_close(self, numeric_grads, scope, max_relative_error, - msg_prefix): - for name in numeric_grads: - b = numpy.array(scope.find_var(grad_var_name(name)).get_tensor()) - a = numeric_grads[name] + def __get_gradient(self, forward_op, backward_op, input_value, grad_names, + place): + """Get the input gradients after running forward and backward operators + on the given places. + + :param forward_op: forward operator + :type forward_op: Operator + :param backward_op: backward operator + :type backward_op: Operator + :param input_value: input values. + :type input_value: dict{string:numpy.array} + :param grad_names: the names of returned input gradients. + :type input_value: a list of string + :param place: the device type. + :type place: CPUPlace or GPUPlace + :return: the input grdients of given grad_names. + :rtype: a list of numpy.array + """ + scope = core.Scope() + ctx = core.DeviceContext.create(place) + + inputs = forward_op.inputs() + in_names = [item for k in inputs for item in inputs[k]] + outputs = forward_op.outputs() + out_names = [item for k in outputs for item in outputs[k]] + + # create input var and set value + for name, value in input_value.iteritems(): + if name not in in_names: + raise ValueError(name + "does not exist in Op's inputs.") + var = scope.new_var(name).get_tensor() + var.set_dims(value.shape) + var.set(value, place) + + # run forward op + for out_name in out_names: + scope.new_var(out_name) + forward_op.infer_shape(scope) + forward_op.run(scope, ctx) + + # set output var's shape + # set output grad to ones + for name in out_names: + out_tensor = scope.find_var(name).get_tensor() + grad_tensor = scope.new_var(grad_var_name(name)).get_tensor() + grad_tensor.set_dims(out_tensor.shape()) + data = numpy.ones(out_tensor.shape(), dtype=numpy.float32) + grad_tensor.set(data, place) + + # run backward op + for name in backward_op.outputs(): + scope.new_var(name) + backward_op.infer_shape(scope) + backward_op.run(scope, ctx) + + outs = [ + numpy.array(scope.find_var(name).get_tensor()) + for name in grad_names + ] + return outs + + def compare_grad(self, forward_op, input_value): + """ Compare the input gradients between CPU and GPU for the given forward + operator. + + :param forward_op: forward operator + :type forward_op: Operator + :param input_value: input values. + :type input_value: dict{string:numpy.array} + :raises: AssertionError, there is different gradient value. + """ + backward_op = core.Operator.backward(forward_op, set()) + # return if not compile with GPU or not implementing GPU kernel + if not (core.is_compile_gpu() and backward_op.support_gpu()): + return + outputs = backward_op.outputs() + out_names = [item for k in outputs for item in outputs[k]] + cpu_grads = self.__get_gradient(forward_op, backward_op, input_value, + out_names, core.CPUPlace()) + gpu_grads = self.__get_gradient(forward_op, backward_op, input_value, + out_names, core.GPUPlace(0)) + + for c_grad, g_grad, name in itertools.izip(cpu_grads, gpu_grads, + out_names): + self.assertTrue( + numpy.allclose( + c_grad, g_grad, atol=1e-4), + "output name: " + name + " has diff") + + def __assert_is_close(self, numeric_grads, analytic_grads, names, + max_relative_error, msg_prefix): + """Use relative error for the comparison. + + :param numeric_grads: the numerical graidents. + :type numeric_grads: a list of numpy.array + :param analytic_grads: the analytical graidents. + :type analytic_grads: a list of numpy.array + :param name: the names of gradients, used to print for debug. + :type names: a list of string + :param msg_prefix: string info, used to print for debug. + :type msf_prefix: string + """ + for a, b, name in itertools.izip(numeric_grads, analytic_grads, names): abs_a = numpy.abs(a) # if abs_a is nearly zero, then use abs error for a, not relative # error. @@ -159,106 +258,26 @@ class GradientChecker(unittest.TestCase): inputs = forward_op.inputs() in_names = [item for k in inputs for item in inputs[k]] - outputs = forward_op.outputs() - out_names = [item for k in outputs for item in outputs[k]] - for no_grad in no_grad_set: if no_grad not in in_names: raise ValueError("no_grad should be in in_names") - backward_op = core.Operator.backward(forward_op, no_grad_set) - bwd_outputs = backward_op.outputs() - bwd_out_names = [item for k in bwd_outputs for item in bwd_outputs[k]] - places = [core.CPUPlace()] if not only_cpu and core.is_compile_gpu() and backward_op.support_gpu(): places.append(core.GPUPlace(0)) - numeric_grad = dict() - # get numeric gradient - for check_name in inputs_to_check: - numeric_grad[check_name] = \ - get_numeric_gradient(forward_op, input_vars, output_name, - check_name) + # get numerical gradients + numeric_grads = [ + get_numeric_gradient(forward_op, input_vars, output_name, name) + for name in inputs_to_check + ] - # get operator gradient according to different device + check_names = [grad_var_name(name) for name in inputs_to_check] for place in places: - scope = core.Scope() - ctx = core.DeviceContext.create(place) - - # create input var and set value - for name, value in input_vars.iteritems(): - if name not in in_names: - raise ValueError(name + " not in op.inputs_") - var = scope.new_var(name).get_tensor() - var.set_dims(value.shape) - var.set(value, place) - - # create output var - for out_name in out_names: - scope.new_var(out_name).get_tensor() - - # infer the shape of output var and compute/set value of output var - forward_op.infer_shape(scope) - forward_op.run(scope, ctx) - - # create output grad var - # set shape as the output var - # set value of this grad to ones - for name in out_names: - out_tensor = scope.find_var(name).get_tensor() - grad_tensor = scope.new_var(grad_var_name(name)).get_tensor() - grad_tensor.set_dims(out_tensor.shape()) - data = 1.0 * numpy.ones(out_tensor.shape()) - grad_tensor.set(data, place) - - # create input grad var - for name in bwd_out_names: - scope.new_var(name).get_tensor() - - # infer the shape of input gradient var and compute/set it's value - # with backward op - backward_op.infer_shape(scope) - backward_op.run(scope, ctx) - - self.assert_is_close(numeric_grad, scope, max_relative_error, - "Gradient Check On %s" % str(place)) - - -if __name__ == '__main__': - - class GetNumericGradientTest(unittest.TestCase): - def test_add_op(self): - add_op = Operator('add_two', X="X", Y="Y", Out="Z") - x = numpy.random.random((10, 1)).astype("float32") - y = numpy.random.random((10, 1)).astype("float32") - - arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X') - self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2) - - def test_softmax_op(self): - def stable_softmax(x): - """Compute the softmax of vector x in a numerically stable way.""" - shiftx = x - numpy.max(x) - exps = numpy.exp(shiftx) - return exps / numpy.sum(exps) - - def label_softmax_grad(Y, dY): - dX = Y * 0.0 - for i in range(Y.shape[0]): - d = numpy.dot(Y[i, :], dY[i, :]) - dX[i, :] = Y[i, :] * (dY[i, :] - d) - return dX - - softmax_op = Operator("softmax", X="X", Y="Y") - - X = numpy.random.random((2, 2)).astype("float32") - Y = numpy.apply_along_axis(stable_softmax, 1, X) - dY = numpy.ones(Y.shape) - dX = label_softmax_grad(Y, dY) - - arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X') - numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2) - - unittest.main() + # get analytical gradients according to different device + analytic_grads = self.__get_gradient(forward_op, backward_op, + input_vars, check_names, place) + self.__assert_is_close(numeric_grads, analytic_grads, check_names, + max_relative_error, + "Gradient Check On %s" % str(place)) diff --git a/python/paddle/v2/framework/tests/test_gradient_checker.py b/python/paddle/v2/framework/tests/test_gradient_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b315120862bea284e067070492dcdfbb661081 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gradient_checker.py @@ -0,0 +1,43 @@ +import unittest +import numpy +from paddle.v2.framework.op import Operator +from gradient_checker import GradientChecker +from gradient_checker import get_numeric_gradient + + +class GetNumericGradientTest(unittest.TestCase): + def test_add_op(self): + add_op = Operator('add_two', X="X", Y="Y", Out="Z") + x = numpy.random.random((10, 1)).astype("float32") + y = numpy.random.random((10, 1)).astype("float32") + + arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X') + self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4) + + def test_softmax_op(self): + def stable_softmax(x): + """Compute the softmax of vector x in a numerically stable way.""" + shiftx = x - numpy.max(x) + exps = numpy.exp(shiftx) + return exps / numpy.sum(exps) + + def label_softmax_grad(Y, dY): + dX = Y * 0.0 + for i in range(Y.shape[0]): + d = numpy.dot(Y[i, :], dY[i, :]) + dX[i, :] = Y[i, :] * (dY[i, :] - d) + return dX + + softmax_op = Operator("softmax", X="X", Y="Y") + + X = numpy.random.random((2, 2)).astype("float32") + Y = numpy.apply_along_axis(stable_softmax, 1, X) + dY = numpy.ones(Y.shape) + dX = label_softmax_grad(Y, dY) + + arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X') + numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_mean_op.py b/python/paddle/v2/framework/tests/test_mean_op.py index b5d52b90567bcd0c9f376147145d8638049f7bab..f32b3160d651a290823223c46c45bb3b6950a505 100644 --- a/python/paddle/v2/framework/tests/test_mean_op.py +++ b/python/paddle/v2/framework/tests/test_mean_op.py @@ -1,5 +1,6 @@ import unittest from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op import numpy as np @@ -12,5 +13,12 @@ class TestMeanOp(unittest.TestCase): self.outputs = {'Out': np.mean(self.inputs['X'])} +class MeanGradOpTest(GradientChecker): + def test_normal(self): + op = create_op("mean") + inputs = {"X": np.random.random((10, 10)).astype("float32")} + self.check_grad(op, inputs, set("X"), "Out") + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index ec0ac99156a546dd3fb7b27778032bece38ab5a9..ee0d81a64efcb81bae8b11b856c201a86da274e9 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta class TestMulOp(unittest.TestCase): @@ -15,5 +16,19 @@ class TestMulOp(unittest.TestCase): self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} +class MulGradOpTest(GradientChecker): + def test_mul(self): + op = create_op("mul") + inputs = { + 'X': np.random.random((32, 84)).astype("float32"), + 'Y': np.random.random((84, 100)).astype("float32") + } + # mul op will enlarge the relative error + self.check_grad( + op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5) + + +# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py index f8521eb517057fbeb104b28af7da4fffe54f37de..29d72e850099734a9828ccceed47cc0a57fc3d6b 100644 --- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op class TestRowwiseAddOp(unittest.TestCase): @@ -15,5 +16,15 @@ class TestRowwiseAddOp(unittest.TestCase): self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])} +class RowwiseAddGradOpTest(GradientChecker): + def test_rowwise_add(self): + op = create_op("rowwise_add") + inputs = { + "X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"), + "b": np.random.uniform(0.1, 1, [10]).astype("float32") + } + self.check_grad(op, inputs, set(["X", "b"]), "Out") + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py index 2a57a41ed8b718fd420062ba68e853a4861b7359..273c2e5ab1a84d12621fe9568c4cf22073b6aed4 100644 --- a/python/paddle/v2/framework/tests/test_sigmoid_op.py +++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op class TestSigmoidOp(unittest.TestCase): @@ -8,12 +9,20 @@ class TestSigmoidOp(unittest.TestCase): def setUp(self): self.type = "sigmoid" - self.inputs = {'X': np.random.random((32, 100)).astype("float32")} + self.inputs = {'X': np.random.random((15, 31)).astype("float32")} self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))} -#class TestSigmoidGradOp(unittest.TestCase): -#TODO(qingqing) add unit test +class TestSigmoidGradOp(GradientChecker): + def test_grad(self): + op = create_op("sigmoid") + inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")} + # compare gpu and cpu results for backward op. + # this test will be skiped if only compiling CPU version. + self.compare_grad(op, inputs) + # check gradients + self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 9c4dd5f25083d210bbd218a85d8dbb3cce2c3d0e..0654a301049dcb347b79879076a869a0c14a07ae 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -27,16 +27,24 @@ class SGD(object): SGD Trainer combines data reader, network topolopy and update_equation together to train/test a neural network. - :param update_equation: The optimizer object. - :type update_equation: paddle.v2.optimizer.Optimizer :param cost: Target cost that neural network should be optimized. :type cost: paddle.v2.config_base.Layer :param parameters: The parameters dictionary. :type parameters: paddle.v2.parameters.Parameters + :param update_equation: The optimizer object. + :type update_equation: paddle.v2.optimizer.Optimizer :param extra_layers: Some layers in the neural network graph are not in the path of cost layer. - :param pserver_spec: pserver location, eg: localhost:3000 :type extra_layers: paddle.v2.config_base.Layer + :param is_local: Whether trainning locally + :type is_local: bool + :param pserver_spec: comma string for pserver location, + eg:127.10.0.10:3000,127.10.0.11:3000, + and this parameter is only used for fault + tolerant mode cluster training. + :type pserver_spec: string + :param use_etcd: Whether using etcd pserver. + :param use_etcd: bool """ def __init__(self,