diff --git a/CMakeLists.txt b/CMakeLists.txt index 06dd5a1332cb9dda8f942a342693ec74fb6184d9..ad559672ad2f83a3d62cdf332b47c6cf1e730f70 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,7 @@ option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(GLIDE_INSTALL "Download and install go dependencies " ON) option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) +option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF) # CMAKE_BUILD_TYPE if(NOT CMAKE_BUILD_TYPE) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 209f9078a637ac581d90212a48216eb388c477ed..51c3b918cc4ef4cf6c8052ccc14028a872309fcf 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -28,6 +28,10 @@ if(NOT WITH_TIMER) add_definitions(-DPADDLE_DISABLE_TIMER) endif(NOT WITH_TIMER) +if(USE_EIGEN_FOR_BLAS) + add_definitions(-DPADDLE_USE_EIGEN_FOR_BLAS) +endif(USE_EIGEN_FOR_BLAS) + if(NOT WITH_PROFILER) add_definitions(-DPADDLE_DISABLE_PROFILER) endif(NOT WITH_PROFILER) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 69f40df51680a104c47d9335c070c570dcaff59a..2c84061ff572de4687b4d496f8ded6deee8d1011 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -2,7 +2,7 @@ if(NOT WITH_GPU) return() endif() -set(CUDNN_ROOT "" CACHE PATH "CUDNN ROOT") +set(CUDNN_ROOT "/usr" CACHE PATH "CUDNN ROOT") find_path(CUDNN_INCLUDE_DIR cudnn.h PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include $ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE} diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index 9a5901616f8a61d686929a11e09d278df3ee51d6..1329b77bb44f52c66a703740715b890c47234e72 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -257,6 +257,11 @@ seq_concat .. autoclass:: paddle.v2.layer.seq_concat :noindex: +seq_slice +--------- +.. autoclass:: paddle.v2.layer.seq_slice + :noindex: + kmax_sequence_score ------------------- .. autoclass:: paddle.v2.layer.kmax_sequence_score @@ -362,6 +367,11 @@ trans .. autoclass:: paddle.v2.layer.trans :noindex: +scale_shift +----------- +.. autoclass:: paddle.v2.layer.scale_shift + :noindex: + Sampling Layers =============== diff --git a/go/master/client.go b/go/master/client.go index 62801b9b7fe85fe27147b12160f48d988623d547..f04cf50ce3cf765a79cbe555d3edb68f3dbb911e 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -63,13 +63,24 @@ func WithAddr(addr string) func(c *Client) error { // WithEtcd sets the client to use etcd for master discovery. func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error { return func(c *Client) error { - cli, err := clientv3.New(clientv3.Config{ - Endpoints: endpoints, - DialTimeout: timeout, - }) - if err != nil { + var cli *clientv3.Client + f := func() error { + var err error + cli, err = clientv3.New(clientv3.Config{ + Endpoints: endpoints, + DialTimeout: timeout, + }) return err } + for { + err := f() + if err != nil { + log.Warningln(err) + } else { + break + } + time.Sleep(time.Second) + } ch := make(chan string, 1) a, err := GetKey(cli, DefaultAddrPath, timeout) @@ -101,9 +112,6 @@ func NewClient(opts ...func(*Client) error) (*Client, error) { } } c.ch = make(chan record, c.bufSize) - // FIXME: connection is created asyncrosly in monitorMaster go routine, - // ensure the connection is ready for use before calling c.addClient. - time.Sleep(time.Second) return c, nil } diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index cf61a243e9df2fd4a580e41f07cb0a22dcc72083..ec866b2907d4623e8a94a249bc9af624071ade97 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -15,6 +15,7 @@ if(Boost_FOUND) add_subdirectory(platform) add_subdirectory(framework) add_subdirectory(operators) + add_subdirectory(pybind) endif() if(WITH_C_API) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 68304c9fc8b8fa13cb1f99b82517abc87c71496c..c0838d9b759110fd706577386d2c81bda6876223 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -18,8 +18,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) - -cc_library(operator SRCS operator.cc DEPS framework_proto device_context tensor scope attribute) +cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator) @@ -39,21 +39,3 @@ add_custom_command(TARGET framework_py_proto POST_BUILD cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) - -if(WITH_PYTHON) -cc_library(paddle_pybind SHARED - SRCS pybind.cc - DEPS pybind python backward - sgd_op - add_op - mul_op - rowwise_add_op - sigmoid_op - softmax_op - mean_op - cross_entropy_op - recurrent_op - uniform_random_op - gaussian_random_op - fill_zeros_like_op) -endif(WITH_PYTHON) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 9d30887224fe0020ff5665f362e7403bf5c724ee..bfda18724cc8ed23a40e0626ff07a290d26aa9d2 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -110,7 +110,7 @@ static std::unique_ptr BackwardRecursive( dup_output_ops[out].emplace_back(local_op_id); return false; }); - net->AddOp(std::move(bwd)); + net->AppendOp(std::move(bwd)); } // Get unique ID for this method. auto uid = uniq_id++; @@ -163,8 +163,9 @@ static std::unique_ptr BackwardRecursive( // If part of input gradient of that operator is not calculated, fill // zero variables to that input gradient. - net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {{"Src", {prefix}}}, - {{"Dst", {grad_input}}}, {})); + net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", + {{"Src", {prefix}}}, + {{"Dst", {grad_input}}}, {})); } return false; }); @@ -195,7 +196,7 @@ static std::unique_ptr BackwardRecursive( if (net->ops_.empty()) { // Current no aux op is added to network return grad_op; } - net->AddOp(std::move(grad_op)); + net->AppendOp(std::move(grad_op)); } net->SetType("@GENERATED_BACKWARD@"); net->CompleteAddOp(); diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 2c5ec76dfeb8b8485951e4d94896b6758e0cb930..f100c4d05489ac3bd4ceb5f11ae871985f0e5d83 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -72,16 +72,16 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { class FcOp : public operators::NetOp { public: - FcOp(const std::string &type, const VarNameMap &inputs, - const VarNameMap &outputs, const AttributeMap &attrs) + FcOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs) : NetOp(type, inputs, outputs, attrs) { - AddOp(OpRegistry::CreateOp("mul", - {{"X", {Input("X")}}, {"Y", {Input("W")}}}, - {{"Out", {Output("mul_result")}}}, {})); + AppendOp(OpRegistry::CreateOp("mul", + {{"X", {Input("X")}}, {"Y", {Input("W")}}}, + {{"Out", {Output("mul_result")}}}, {})); auto input_b = Inputs("b"); std::string before_act = "mul_result"; if (input_b.size() != 0) { - AddOp(OpRegistry::CreateOp( + AppendOp(OpRegistry::CreateOp( "rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}}, {{"Out", {Output("add_result")}}}, {})); before_act = "add_result"; @@ -92,8 +92,8 @@ class FcOp : public operators::NetOp { } } - AddOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}}, - {{"Out", {Output("Out")}}}, {})); + AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}}, + {{"Out", {Output("Out")}}}, {})); CompleteAddOp(false); } }; @@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) { TEST(Backward, net_input_of_network_not_need_grad) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}}, {{"mul_result", {"mul_tmp_0"}}, {"add_result", {"add_tmp_0"}}, {"Out", {"hidden0"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}}, {{"mul_result", {"mul_tmp_1"}}, {"add_result", {"add_tmp_1"}}, @@ -273,10 +273,10 @@ TEST(Backward, net_input_of_network_not_need_grad) { TEST(Backward, net_shared_weight) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}}, - {{"Out", {"out"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}}, - {{"Out", {"FinalOut"}}}, {})); + net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}}, + {{"Out", {"out"}}}, {})); + net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}}, + {{"Out", {"FinalOut"}}}, {})); net.CompleteAddOp(); auto bwd = f::Backward(net, {}); @@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) { TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"mul_result", {"mul_out1"}}, {"add_result", {"add_out1"}}, {"Out", {"out1"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}}, {{"mul_result", {"mul_out2"}}, {"add_result", {"tmp_out2"}}, {"Out", {"out2"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}}, {{"mul_result", {"mul_out3"}}, {"add_result", {"tmp_out3"}}, diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 0a2a41f6b62658ac8633a6e384d099f8d6641f33..b02a599a800668b22e7fe39a10fa6dc132e305bd 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -20,13 +20,13 @@ namespace framework { enum class OpArgType { IN, OUT }; static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, - bool is_grad, OperatorBase::VarNameMap* vars) { + bool is_grad, VariableNameMap* vars) { const auto& src_inout = src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); auto& dst_inout = *vars; - const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_; + auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto(); const auto& src_arg_list = - src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); + src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); for (const auto& arg : src_arg_list) { if (arg.not_in_gradient() && !is_grad) continue; const std::string src_name = arg.name(); @@ -40,26 +40,18 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, } OperatorBase* BuildGradOp(const OperatorBase* op) { - auto it = OpRegistry::op_info_map().find(op->Type()); - PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), - "'%s' has not been registered.", op->Type()); - PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.", - op->Type()); - std::string grad_op_type = it->second.grad_op_type_; - PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.", - op->Type()); + auto& info = OpInfoMap::Instance().Get(op->Type()); + PADDLE_ENFORCE(info.HasGradientOp()); - OperatorBase::VarNameMap inputs; - OperatorBase::VarNameMap outputs; + VariableNameMap inputs; + VariableNameMap outputs; TransOpArg(op, OpArgType::IN, false, &inputs); // I TransOpArg(op, OpArgType::OUT, false, &inputs); // O TransOpArg(op, OpArgType::OUT, true, &inputs); // OG TransOpArg(op, OpArgType::IN, true, &outputs); // IG - it = OpRegistry::op_info_map().find(grad_op_type); - PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), - "'%s' has not been registered.", grad_op_type); - return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs()); + auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_); + return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs()); } } // namespace framework diff --git a/paddle/framework/op_info.cc b/paddle/framework/op_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..81ba29797c5f478e5d6a91236f3e8de1e6b43e49 --- /dev/null +++ b/paddle/framework/op_info.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_info.h" + +namespace paddle { +namespace framework { + +static OpInfoMap* g_op_info_map = nullptr; + +OpInfoMap& OpInfoMap::Instance() { + if (g_op_info_map == nullptr) { + g_op_info_map = new OpInfoMap(); + } + return *g_op_info_map; +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h new file mode 100644 index 0000000000000000000000000000000000000000..94245c6c44aca962b0db890947a9dc5550ac0799 --- /dev/null +++ b/paddle/framework/op_info.h @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include +#include + +#include "paddle/framework/attribute.h" + +namespace paddle { +namespace framework { +class OperatorBase; +using VariableNameMap = std::map>; + +using OpCreator = std::function; + +struct OpInfo { + OpCreator creator_; + std::string grad_op_type_; + OpProto* proto_; + OpAttrChecker* checker_; + + bool HasOpProtoAndChecker() const { + return proto_ != nullptr && checker_ != nullptr; + } + + const OpProto& Proto() const { + PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered"); + PADDLE_ENFORCE(proto_->IsInitialized(), + "Operator Proto must be initialized in op info"); + return *proto_; + } + + const OpAttrChecker& Checker() const { + PADDLE_ENFORCE_NOT_NULL(checker_, + "Operator Checker has not been registered"); + return *checker_; + } + + const OpCreator& Creator() const { + PADDLE_ENFORCE_NOT_NULL(creator_, + "Operator Creator has not been registered"); + return creator_; + } + + bool HasGradientOp() const { return !grad_op_type_.empty(); } +}; + +class OpInfoMap { + public: + static OpInfoMap& Instance(); + + OpInfoMap(const OpInfoMap& o) = delete; + OpInfoMap(OpInfoMap&& o) = delete; + OpInfoMap& operator=(const OpInfoMap& o) = delete; + OpInfoMap& operator=(OpInfoMap&& o) = delete; + + bool Has(const std::string& op_type) const { + return map_.find(op_type) != map_.end(); + } + + void Insert(const std::string& type, const OpInfo& info) { + PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type); + map_.insert({type, info}); + } + + const OpInfo& Get(const std::string& type) const { + auto it = map_.find(type); + PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type); + return it->second; + } + + template + void IterAllInfo(Callback callback) { + for (auto& it : map_) { + callback(it.first, it.second); + } + } + + private: + OpInfoMap() = default; + std::unordered_map map_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 8eae86e9605da74cdc37caeb9569e7500aac2a63..b0e85dd49f97da4a7f889fde0b5f060954947be8 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -19,32 +19,18 @@ limitations under the License. */ namespace paddle { namespace framework { -std::unique_ptr OpRegistry::CreateOp(const std::string& type, - const VarNameMap& inputs, - const VarNameMap& outputs, - AttributeMap attrs) { - auto it = op_info_map().find(type); - PADDLE_ENFORCE(it != op_info_map().end(), - "Operator '%s' has not been registered.", type); - it->second.checker_->Check(attrs); - auto op = it->second.creator_(type, inputs, outputs, attrs); +std::unique_ptr OpRegistry::CreateOp( + const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, AttributeMap attrs) { + auto& info = OpInfoMap::Instance().Get(type); + info.Checker().Check(attrs); + auto op = info.Creator()(type, inputs, outputs, attrs); return std::unique_ptr(op); } -std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { - VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); - VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); - AttributeMap attrs; - for (auto& attr : op_desc.attrs()) { - attrs[attr.name()] = GetAttrValue(attr); - } - - return CreateOp(op_desc.type(), inputs, outputs, attrs); -} - -OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap( +static VariableNameMap ConvertOpDescVarsToVarNameMap( const google::protobuf::RepeatedPtrField& op_desc_vars) { - VarNameMap ret_val; + VariableNameMap ret_val; for (auto& var : op_desc_vars) { auto& var_names = ret_val[var.parameter()]; auto& var_names_in_proto = var.arguments(); @@ -55,6 +41,17 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap( return ret_val; } +std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { + VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); + VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); + AttributeMap attrs; + for (auto& attr : op_desc.attrs()) { + attrs[attr.name()] = GetAttrValue(attr); + } + + return CreateOp(op_desc.type(), inputs, outputs, attrs); +} + std::unique_ptr OpRegistry::CreateGradOp(const OperatorBase& op) { PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); return std::unique_ptr(BuildGradOp(&op)); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 4c2d13d639005d2d2710c19f63988333d89bce13..2d09cde41e3f5086279f9441e0fdc52549bed5ab 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/framework/attribute.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/grad_op_builder.h" +#include "paddle/framework/op_info.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" @@ -30,28 +31,16 @@ namespace paddle { namespace framework { class OpRegistry { - using VarNameMap = OperatorBase::VarNameMap; - using OpCreator = std::function; - public: - struct OpInfo { - OpCreator creator_; - std::string grad_op_type_; - OpProto* proto_; - OpAttrChecker* checker_; - }; - template static void RegisterOp(const std::string& op_type, const std::string& grad_op_type) { - PADDLE_ENFORCE(op_info_map().count(op_type) == 0, + PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), "'%s' is registered more than once.", op_type); OpInfo op_info; - op_info.creator_ = [](const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, - const AttributeMap& attrs) { + op_info.creator_ = []( + const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) { return new OpType(type, inputs, outputs, attrs); }; op_info.grad_op_type_ = grad_op_type; @@ -70,7 +59,7 @@ class OpRegistry { op_info.proto_ = nullptr; op_info.checker_ = nullptr; } - op_info_map().insert(std::make_pair(op_type, op_info)); + OpInfoMap::Instance().Insert(op_type, op_info); // register gradient op if (!grad_op_type.empty()) { RegisterOp(grad_op_type, ""); @@ -78,21 +67,13 @@ class OpRegistry { } static std::unique_ptr CreateOp(const std::string& type, - const VarNameMap& inputs, - const VarNameMap& outputs, + const VariableNameMap& inputs, + const VariableNameMap& outputs, AttributeMap attrs); static std::unique_ptr CreateOp(const OpDesc& op_desc); - static VarNameMap ConvertOpDescVarsToVarNameMap( - const google::protobuf::RepeatedPtrField& op_desc_vars); - static std::unique_ptr CreateGradOp(const OperatorBase& op); - - static std::unordered_map& op_info_map() { - static std::unordered_map op_info_map_; - return op_info_map_; - } }; class Registrar { diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index eadd8f3316ff1ebffb94a56b2e62d661e4e0b38f..7abbde610f1e9c530393b9a9cabe40b826712212 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -115,8 +115,8 @@ void OperatorBase::Rename(const std::string& old_name, } OperatorBase::OperatorBase(const std::string& type, - const OperatorBase::VarNameMap& inputs, - const OperatorBase::VarNameMap& outputs, + const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { static std::atomic gUniqId(0UL); @@ -141,18 +141,10 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { } return ret_val; } - auto it = OpRegistry::op_info_map().find(type_); - PADDLE_ENFORCE( - it != OpRegistry::op_info_map().end(), - "Operator %s not registered, cannot figure out intermediate outputs", - type_); - PADDLE_ENFORCE( - it->second.proto_ != nullptr, - "Operator %s has no OpProto, cannot figure out intermediate outputs", - type_); + auto& info = OpInfoMap::Instance().Get(Type()); // get all OpProto::Var for outputs - for (auto& o : it->second.proto_->outputs()) { + for (auto& o : info.Proto().outputs()) { // ignore all intermediate output if (o.intermediate()) continue; auto out = outputs_.find(o.name()); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 807298088981b969622174be753ea0da72067243..8397570d26f06f0238e9c5afc85d721df7679257 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include +#include "op_info.h" #include "paddle/framework/attribute.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/scope.h" @@ -62,10 +63,8 @@ class ExecutionContext; */ class OperatorBase { public: - using VarNameMap = std::map>; - - OperatorBase(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs); + OperatorBase(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs); virtual ~OperatorBase() {} @@ -93,8 +92,8 @@ class OperatorBase { /// rename inputs outputs name void Rename(const std::string& old_name, const std::string& new_name); - const VarNameMap& Inputs() const { return inputs_; } - const VarNameMap& Outputs() const { return outputs_; } + const VariableNameMap& Inputs() const { return inputs_; } + const VariableNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; //! Get a input which has multiple variables. @@ -122,30 +121,32 @@ class OperatorBase { // I (Inputs)opear // O (Outputs) // OG (Output Gradients) - VarNameMap inputs_; + VariableNameMap inputs_; // NOTE: in case of OpGrad, outputs_ contains // IG (Inputs Gradients) - VarNameMap outputs_; + VariableNameMap outputs_; AttributeMap attrs_; }; // Macro for define a clone method. // If you are writing an kernel operator, `Clone` will be defined when you // register it. i.e. `Clone` method is not needed to define by yourself. -#define DEFINE_OP_CLONE_METHOD(CLS) \ +#define DEFINE_OP_CLONE_METHOD(cls) \ std::unique_ptr Clone() const final { \ - return std::unique_ptr(new CLS(*this)); \ + return std::unique_ptr(new cls(*this)); \ } // Macro for define a default constructor for Operator. // You can also use // using PARENT_CLASS::PARENT_CLASS; // to use parent's constructor. -#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \ - CLS(const std::string& type, const VarNameMap& inputs, \ - const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \ - : PARENT_CLS(type, inputs, outputs, attrs) {} +#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \ + cls(const std::string& type, \ + const ::paddle::framework::VariableNameMap& inputs, \ + const ::paddle::framework::VariableNameMap& outputs, \ + const paddle::framework::AttributeMap& attrs) \ + : parent_cls(type, inputs, outputs, attrs) {} class NOP : public OperatorBase { public: @@ -389,8 +390,8 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; - OperatorWithKernel(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs) + OperatorWithKernel(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} void InferShape(const Scope& scope) const override { diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 2425b87779f6af01b0e8a91b5f574a28385f0efd..1d7efb7b9403f7c1c6bdbb27a0258f79ae032f43 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -23,8 +23,8 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: - OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs) + OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs), x(1) {} void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, @@ -249,8 +249,9 @@ TEST(OpKernel, multi_inputs) { class OperatorClone : public paddle::framework::OperatorBase { public: DEFINE_OP_CLONE_METHOD(OperatorClone); - OperatorClone(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, + OperatorClone(const std::string& type, + const paddle::framework::VariableNameMap& inputs, + const paddle::framework::VariableNameMap& outputs, const paddle::framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} void InferShape(const paddle::framework::Scope& scope) const override {} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index b8c779f4e5fc7bc51298cdd35b26c2c8ac98edf6..643f875491724bf443bd7727391734377ee6180c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -105,7 +105,10 @@ class Tensor { template inline Tensor Slice(const int& begin_idx, const int& end_idx) const; - platform::Place place() const { return holder_->place(); } + platform::Place place() const { + PADDLE_ENFORCE_NOT_NULL(holder_, "Tensor get place() must contains holder"); + return holder_->place(); + } private: template diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 7dfb6f61c50959f7269725a00dbc4f9c27474bdf..c572a9d433bc16e6733b8fc9367970bef28e699a 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -4,6 +4,10 @@ file(GLOB cpp_files . *Op.cpp) list(APPEND h_files Function.h) list(APPEND cpp_files Function.cpp) list(APPEND cpp_files BufferArg.cpp) +list(APPEND cpp_files GemmFunctor.cpp) +if(USE_EIGEN_FOR_BLAS) + list(APPEND cpp_files EigenGemm.cpp) +endif(USE_EIGEN_FOR_BLAS) if(WITH_GPU) file(GLOB cu_files . *OpGpu.cu) diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp index 490e8d546cbd460217abe95f6291b13fa207faa9..2f3112fe657cd381891dc53c7179e7520911e8c9 100644 --- a/paddle/function/DepthwiseConvOp.cpp +++ b/paddle/function/DepthwiseConvOp.cpp @@ -14,7 +14,6 @@ limitations under the License. */ #include "DepthwiseConvOp.h" #include "ConvOp.h" -#include "GemmFunctor.h" namespace paddle { diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu index 33463805cbd4746c05548028e0bc4a0e2a90453e..2d722dfcfca0f328edeecf185ea37b8512b91907 100644 --- a/paddle/function/DepthwiseConvOpGpu.cu +++ b/paddle/function/DepthwiseConvOpGpu.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "DepthwiseConvOp.h" -#include "GemmFunctor.h" #include "paddle/math/BaseMatrix.h" namespace paddle { diff --git a/paddle/function/EigenGemm.cpp b/paddle/function/EigenGemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..674141ed39b7f5573948348e3ba3bb526ae43c66 --- /dev/null +++ b/paddle/function/EigenGemm.cpp @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { + +template +struct EigenBlasGemm { + typedef Eigen::TensorMap, + Eigen::Aligned> + Matrix; + + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + Eigen::array sizeA; + if (transA) { + sizeA[0] = K; + sizeA[1] = M; + CHECK_EQ(M, lda); + } else { + sizeA[0] = M; + sizeA[1] = K; + CHECK_EQ(K, lda); + } + Eigen::array sizeB; + if (transB) { + sizeB[0] = N; + sizeB[1] = K; + CHECK_EQ(K, ldb); + } else { + sizeB[0] = K; + sizeB[1] = N; + CHECK_EQ(N, ldb); + } + Eigen::array sizeC; + sizeC[0] = M; + sizeC[1] = N; + CHECK_EQ(N, ldc); + + const Matrix a(const_cast(A), sizeA); + const Matrix b(const_cast(B), sizeB); + Matrix c(C, sizeC); + + typedef typename Eigen::Tensor::DimensionPair DimPair; + Eigen::array dims; + dims[0] = DimPair(1, 0); + dims[0].first = transA ? 0 : 1; + dims[0].second = transB ? 1 : 0; + + Eigen::DefaultDevice device; + if (alpha == T(1) && beta == T(0)) { + c.device(device) = a.contract(b, dims); + } else if (alpha == T(1) && beta == T(1)) { + c.device(device) += a.contract(b, dims); + } else { + c.device(device) = alpha * a.contract(b, dims) + beta * c; + } + } +}; + +#ifdef PADDLE_TYPE_DOUBLE +template class EigenBlasGemm; +#else +template class EigenBlasGemm; +#endif + +} // namespace paddle diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87..f8cf4ebea8d724f0291b981647622b63e3d84495 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -85,7 +85,6 @@ public: } Im2ColFunctor im2col; - GemmFunctor gemm; size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -108,19 +107,19 @@ public: int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset, - K, - colData, - N, - beta, - outputData + g * outputOffset, - N); + BlasGemm::compute(false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + colData, + N, + beta, + outputData + g * outputOffset, + N); } inputData += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; @@ -188,8 +187,6 @@ public: } Col2ImFunctor col2im; - GemmFunctor gemm; - size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -205,19 +202,19 @@ public: colData = inputGrad + g * inputOffset; scale = 1.0f; } - gemm(CblasTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset, - M, - outputGrad + g * outputOffset, - N, - scale, - colData, - N); + BlasGemm::compute(true, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + M, + outputGrad + g * outputOffset, + N, + scale, + colData, + N); if (needIm2col) { col2im(inputGrad + g * inputOffset, imShape, @@ -299,7 +296,6 @@ public: } Im2ColFunctor im2col; - GemmFunctor gemm; size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -321,19 +317,19 @@ public: int M = outputChannels / groups_; int K = outputHeight * outputWidth; int N = inputChannels / groups_ * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasTrans, - M, - N, - K, - 1.0f, - outputGrad + g * outputOffset, - K, - colData, - K, - i == 0 ? beta : 1.0f, - filterGrad + g * filterOffset, - N); + BlasGemm::compute(false, + true, + M, + N, + K, + 1.0f, + outputGrad + g * outputOffset, + K, + colData, + K, + i == 0 ? beta : 1.0f, + filterGrad + g * filterOffset, + N); } inputData += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; diff --git a/paddle/function/GemmFunctor.cpp b/paddle/function/GemmFunctor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9e25ee58a12490a1454436b3fe4a89176478d5c0 --- /dev/null +++ b/paddle/function/GemmFunctor.cpp @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "GemmFunctor.h" +#include "paddle/math/MathFunctions.h" + +namespace paddle { + +template +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { +#ifdef PADDLE_USE_EIGEN_FOR_BLAS + EigenBlasGemm::compute( + transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); +#else + gemm(transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +#endif + } +}; + +template +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + hl_matrix_mul((T*)A, + transA == false ? HPPL_OP_N : HPPL_OP_T, + (T*)B, + transB == false ? HPPL_OP_N : HPPL_OP_T, + C, + M, + N, + K, + alpha, + beta, + lda, + ldb, + ldc); + } +}; + +template struct BlasGemm; +template struct BlasGemm; + +} // namespace paddle diff --git a/paddle/function/GemmFunctor.h b/paddle/function/GemmFunctor.h index d5db5cf5e7a855d89b262fe8cf42aa2c55f419f1..0809953b4eb17c25eadcce7f474a3dab0469bba1 100644 --- a/paddle/function/GemmFunctor.h +++ b/paddle/function/GemmFunctor.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/math/MathFunctions.h" +#include "TensorType.h" namespace paddle { @@ -24,73 +24,42 @@ namespace paddle { // of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul // interface. template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc); +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc); }; +// TODO(hedaoyuan): Since the definition of the real type in the Paddle +// conflicts with the Eigen library, so compile the Eigen code can not +// include the Paddle header file. And need an EigenBlasGemm template class +// that does not contain the DeviceType parameter. +// I will fix this problem and merge BlasGemm and EigenBlasGemm into one. template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc) { - gemm(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); - } -}; - -template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc) { - hl_matrix_mul((T*)A, - transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T, - (T*)B, - TransB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T, - C, - M, - N, - K, - alpha, - beta, - lda, - ldb, - ldc); - } +struct EigenBlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc); }; } // namespace paddle diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index cfa80a89365af5111746eec9599d16e37532a9f7..26cff3e67710b2f38d93572c5d58849aa94a5135 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector& inArgs) { auto mat = dynamic_cast( para->getMat(PARAMETER_VALUE).get()); para->clearGradient(); - mat->clearIndices(); + if (mat) mat->clearIndices(); } } } diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp index f98bf95064fa539b990309dfe0bff10c1e99d096..d00d408ab8d2b1e6951e9fe58981ba85b9077908 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp @@ -184,7 +184,7 @@ public: } void backward(const UpdateCallback& callback) override { - if (biases_) { + if (biases_ && biases_->getWGrad()) { backwardActivation(); biases_->getWGrad()->collectBias(*getOutputGrad(), 1); biases_->getParameterPtr()->incUpdate(callback); @@ -1012,11 +1012,6 @@ void RecurrentGradientMachine::generateSequence() { /* width */ resultNum, false, /* useGpu */ false); - Matrix::resizeOrCreate(generator_.outArg.value, - /* height */ maxGenWordCount, - /* width */ 1, - false, - /* useGpu */ false); } ICpuGpuVector::resizeOrCreate(generator_.outArg.sequenceStartPositions, numSequences + 1, @@ -1026,7 +1021,7 @@ void RecurrentGradientMachine::generateSequence() { } else { oneWaySearch(numSequences); } - if (dataArgsSize_) createDataOutlink(batchMachineIdVec_); + if (dataArgsSize_) createDataOutlink(); size_t size = generator_.ids.size(); generator_.outArg.ids->resize(size); @@ -1106,6 +1101,7 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) { } batchMachineIdVec_.clear(); + batchMachineStartPos_.clear(); int* starts = generator_.outArg.sequenceStartPositions->getMutableData(false); starts[0] = 0; generator_.ids.clear(); @@ -1312,13 +1308,20 @@ void RecurrentGradientMachine::fillGenOutputs() { finalPaths_[i].resize(minFinalPathsSize); } - batchMachineIdVec_.clear(); generator_.ids.clear(); int* starts = generator_.outArg.sequenceStartPositions->getMutableData(false); starts[0] = 0; if (numResults > 1) { - real* probs = generator_.outArg.in->getData(); + int idsProbSaveSize = 0; + for (auto inSeq : finalPaths_) { + for (auto path : inSeq) idsProbSaveSize += path.ids.size(); + idsProbSaveSize += inSeq.size(); + } + Matrix::resizeOrCreate( + generator_.outArg.value, idsProbSaveSize, 1, false, false); real* idsProb = generator_.outArg.value->getData(); + + real* probs = generator_.outArg.in->getData(); size_t curPos = 0; for (size_t i = 0; i < finalPaths_.size(); ++i) { for (size_t j = 0; j < finalPaths_[i].size(); ++j) { @@ -1333,24 +1336,16 @@ void RecurrentGradientMachine::fillGenOutputs() { curPos += genLen; idsProb[curPos++] = -1.0; probs[i * numResults + j] = path.logProb; - - if (!j && dataArgsSize_) { - // in beam search, here only reserved the top 1 generated result - // for out_links that are not the generated word indices. - batchMachineIdVec_.insert(batchMachineIdVec_.end(), - path.machineIdVec.begin(), - path.machineIdVec.end()); - } } starts[i + 1] = generator_.ids.size(); } } else { for (size_t i = 0; i < finalPaths_.size(); ++i) { CHECK(!finalPaths_[i].empty()); - generator_.ids.insert(generator_.ids.begin(), - finalPaths_[i][0].ids.begin(), - finalPaths_[i][0].ids.end()); - starts[i + 1] = starts[i] + finalPaths_[i][0].ids.size(); + Path& path = finalPaths_[i][0]; + generator_.ids.insert( + generator_.ids.begin(), path.ids.begin(), path.ids.end()); + starts[i + 1] = starts[i] + path.ids.size(); } } } @@ -1364,25 +1359,76 @@ void RecurrentGradientMachine::copyDataOutlinkFrame(size_t machineCur) { } } -void RecurrentGradientMachine::createDataOutlink( - std::vector& machineIdVec) { - size_t seqNum = - getBeamSize() > 1UL ? finalPaths_.size() : finalPaths_[0].size(); - std::vector starts(seqNum + 1, 0); - for (size_t i = 0; i < seqNum; ++i) { - size_t seqLen = getBeamSize() > 1UL ? finalPaths_[i][0].ids.size() - : finalPaths_[0][i].ids.size(); - starts[i + 1] = starts[i] + seqLen; +void RecurrentGradientMachine::createDataOutlinkSelRowsInfo( + bool isSeq, std::vector& outArgs) { + batchMachineIdVec_.clear(); + + size_t seqIdx = 0; + for (size_t i = 0; i < finalPaths_.size(); ++i) { + for (size_t j = 0; j < finalPaths_[i].size(); ++j) { + std::vector& machineIdVec = finalPaths_[i][j].machineIdVec; + if (isSeq) { + for (size_t i = 0; i < machineIdVec.size(); ++i) { + size_t rowId = machineIdVec[i]; + int* seqPos = + outArgs[i].sequenceStartPositions->getMutableData(false); + batchMachineIdVec_.push_back(seqPos[rowId]); + } + } else { + batchMachineIdVec_.insert( + batchMachineIdVec_.end(), machineIdVec.begin(), machineIdVec.end()); + } + seqIdx++; + } + } +} + +void RecurrentGradientMachine::createDataOutlinkCopySizeInfo( + bool isSeq, std::vector& outArgs, std::vector& copySize) { + size_t totalSeqNum = std::accumulate( + finalPaths_.begin(), + finalPaths_.end(), + 0UL, + [](size_t a, const std::vector& b) { return a + b.size(); }); + copySize.resize(totalSeqNum, 1); + + batchMachineStartPos_.resize(totalSeqNum + 1, 0); + if (isSeq) { + ICpuGpuVectorPtr inputSeqStartPos = outArgs[0].sequenceStartPositions; + CHECK_EQ(static_cast(inputSeqStartPos->getSize() - 1), + getBeamSize() > 1 ? finalPaths_.size() : finalPaths_[0].size()); + int* starts = inputSeqStartPos->getMutableData(false); + int seqId = 0; + for (size_t i = 0; i < finalPaths_.size(); ++i) { + for (size_t j = 0; j < finalPaths_[i].size(); ++j) { + copySize[seqId] = getBeamSize() > 1 ? starts[i + 1] - starts[i] + : starts[j + 1] - starts[j]; + batchMachineStartPos_[seqId + 1] = + batchMachineStartPos_[seqId] + finalPaths_[i][j].ids.size(); + seqId++; + } + } + } else { + for (size_t i = 0; i < finalPaths_[0].size(); ++i) + batchMachineStartPos_[i + 1] = + batchMachineStartPos_[i] + finalPaths_[0][i].ids.size(); } +} +void RecurrentGradientMachine::createDataOutlink() { for (size_t i = 0; i < dataArgsSize_; i++) { + bool isSeq = dataArgsFrame_[i][0].hasSeq(); + std::vector copySize; + createDataOutlinkCopySizeInfo(isSeq, dataArgsFrame_[i], copySize); + createDataOutlinkSelRowsInfo(isSeq, dataArgsFrame_[i]); + dataArgs_[i].concat(dataArgsFrame_[i], - machineIdVec, - starts, + batchMachineIdVec_, + batchMachineStartPos_, + copySize, useGpu_, HPPL_STREAM_1, PASS_TEST); - auto dataAgent = dynamic_cast(outFrameLines_[i + 1].agentLayer.get()); CHECK_NOTNULL(dataAgent); diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h index fb3fc5877ac96323e891f800db80af83b6809831..c16fae6d1770e616fdcfabd440624c9be9753c91 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h @@ -190,7 +190,7 @@ public: std::vector ids; /** - * @brief idsProb, log probability of each generated words. + * @brief idsProb, log probability of each generated word. */ std::vector idsProb; @@ -472,15 +472,43 @@ private: void copyDataOutlinkFrame(size_t machineCur); /* - * @brief In generation, if the layer group has more than 1 outlink, outlinks - * except the first one are data outlinks. This function creates the data - * outlinks. - * @note In beam search, only one generated sequence with the hightest log - * probabilites are retained. - * @param machineIdVec : select a row of output matrix in each frame - * that the generation process expanded. + * @brief In generation, if the layer group has more than 1 outlink, outlink + * except the first one is a data outlink. In RecurrentLayerGroup, each time + * step is a separate Network, outputs of a layer inside the + * RecurrentLayerGroup are stored in separate Arguments. If one layer is + * specified as an outlink of RecurrentLayerGroup. This function will + * collect outputs in each time step of each generated sequence which are + * dispersed in separate Arguments to form a new single Argument as output of + * RecurrentLayerGroup. */ - void createDataOutlink(std::vector& machineIdVec); + void createDataOutlink(); + + /* + * @brief decide to select how many rows from the Matrix stored the forward + * pass results from a start position. + * + * @param isSeq: a flag indicating whetehr the layer to be output of the + * RecurrentGradientMachine is a sequence or not + * @param outArgs: all of the the returned Arguments of the forward pass + * during the generation process. + * @param copySize: the returned result, number of rows to select from the + * Matrix stored the forward pass results from a start position. + */ + void createDataOutlinkCopySizeInfo(bool isSeq, + std::vector& outArgs, + std::vector& copySize); + + /* + * @brief decide index of the start row for each time step of a generated + * sequence in Matrix stored the entire beam search batch's forward pass + * results. + * + * @param isSeq: a flag indicating whether the layer to be output of the + * RecurrentGradientMachine is a sequence or not + * @param outArgs: all of the returned Arguments of the forward pass + * during the generation process. + */ + void createDataOutlinkSelRowsInfo(bool isSeq, std::vector& outArgs); /* * @brief used in beam search, connect previous frame to form recurrent link @@ -543,6 +571,7 @@ private: std::vector topIds_; std::vector seqIds_; std::vector batchMachineIdVec_; + std::vector batchMachineStartPos_; std::vector> finalPaths_; std::vector minFinalPathLogProb_; BeamSearchControlCallbacks* beamSearchCtrlCallbacks_; diff --git a/paddle/gserver/layers/KmaxSeqScoreLayer.cpp b/paddle/gserver/layers/KmaxSeqScoreLayer.cpp index 8ce591d4762466e1ed4b2970cb9cae9203bc0a2b..d5407555b248d79a5156a5ea354042d43ecda02c 100644 --- a/paddle/gserver/layers/KmaxSeqScoreLayer.cpp +++ b/paddle/gserver/layers/KmaxSeqScoreLayer.cpp @@ -80,13 +80,14 @@ void KmaxSeqScoreLayer::forward(PassType passType) { << "input of " << getName() << " must be a sequence or a nested sequence."; CHECK_EQ(input.value->getWidth(), 1UL) - << "input of " << getName() - << " is score over a sequence or a nested sequence, so its width " - << " must be 1."; + << "input of " << getName() << " are scores over a sequence or " + << "a nested sequence, so its width must be 1."; if (useGpu_) { - // this Layer runs only in CPU, if the model is runing on GPU, - // then copy the input to this layer from GPU to CPU. + /* + * currently, this Layer only runs in CPU, if the other part of the model is + * runing on GPU, then copy the input to this layer from GPU to CPU. + */ Matrix::resizeOrCreate(scores_, inputScore->getHeight(), 1, @@ -97,6 +98,14 @@ void KmaxSeqScoreLayer::forward(PassType passType) { scores_ = inputScore; } + /* + * TODO(caoying) + * In PaddePaddle, currently all matrices are real number types, + * but output of this layer which is some selected indices of the give + * sequence are actually filled with int types so that storing int types + * information in a real number matrix is dangerous, since real numbers will + * be convered to int types. + */ Matrix::resizeOrCreate( output_.value, input.hasSubseq() ? input.getNumSubSequences() : input.getNumSequences(), diff --git a/paddle/gserver/layers/ScaleShiftLayer.cpp b/paddle/gserver/layers/ScaleShiftLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35fd038ab43a8a8b08bc328b3d1b08a7bbedd0a1 --- /dev/null +++ b/paddle/gserver/layers/ScaleShiftLayer.cpp @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer applies a linear transformation to each element in each row of + * the input matrix. For each element, the layer first re-scale it and then + * adds a bias to it. + * + * \f[ + * y = wx + b + * \f] + * + * Here, w is the scale and b is the bias. Both w and b are trainable scalars. + * + */ + +class ScaleShiftLayer : public Layer { +protected: + std::unique_ptr scale_; + std::unique_ptr offset_; + +public: + explicit ScaleShiftLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(scale_shift, ScaleShiftLayer); + +bool ScaleShiftLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + CHECK_EQ(inputLayers_.size(), 1U); + scale_.reset(new Weight(1, 1, parameters_[0])); + if (biasParameter_.get() != NULL) { + offset_ = std::unique_ptr(new Weight(1, 1, biasParameter_)); + } + return true; +} + +void ScaleShiftLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + resetOutput(inV->getHeight(), inV->getWidth()); + MatrixPtr outV = getOutputValue(); + real scaleValue = scale_->getW()->getElement(0, 0); + outV->mulScalar(*inV, scaleValue); + if (offset_) { + real offsetValue = offset_->getW()->getElement(0, 0); + outV->add(offsetValue); + } +} + +void ScaleShiftLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + + /* Calculate the parameter gradient for the current layer */ + if (scale_->getWGrad()) { + MatrixPtr rowSumMtx; + Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_); + // this_i = scaleDest * this_i + scaleSum * \sum_j b_{ij} * c_{ij} + rowSumMtx->sumOfProducts( + /* b= */ *inV, /* c= */ *outG, /* scaleSum= */ 1, /* scaleDest= */ 0.); + // this_i = scaleDest * this_i + scaleSum * \sum_j b_{ji} + scale_->getWGrad()->sumCols( + /* b= */ *rowSumMtx, /* scaleSum= */ 1., /* scaleDest= */ 1.); + scale_->getParameterPtr()->incUpdate(callback); + } + if (offset_ && offset_->getWGrad()) { + MatrixPtr rowSumMtx; + Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_); + rowSumMtx->sumRows(*outG, 1., 0.); + offset_->getWGrad()->sumCols(*rowSumMtx, 1., 1.); + offset_->getParameterPtr()->incUpdate(callback); + } + + /* Calculate the input layers error */ + if (inG) { + real scaleValue = scale_->getW()->getElement(0, 0); + inG->add(*outG, scaleValue); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/SequenceSliceLayer.cpp b/paddle/gserver/layers/SequenceSliceLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f6577445f56c17d2c43d21a19a086c985714658 --- /dev/null +++ b/paddle/gserver/layers/SequenceSliceLayer.cpp @@ -0,0 +1,221 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" +#include "paddle/math/Matrix.h" +#include "paddle/math/Vector.h" +#include "paddle/utils/Logging.h" +#include "paddle/utils/Stat.h" + +namespace paddle { + +class SequenceSliceLayer : public Layer { +public: + explicit SequenceSliceLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; + +private: + /* + * TODO(caoying) + * In PaddePaddle, currently all matrices are real number types, + * but the second and the (optional) third input which are some + * selected indices of the give sequence to trim the sequence, are actually + * filled with int types so that storing int types information in real number + * matrices is very dangerous, since real numbers will be convered to int + * types. If a user fills this matrix himself, invalid data may occor. + */ + + MatrixPtr startIdsOnCpu_; + MatrixPtr endIdsOnCpu_; + + std::vector selectedRows_; + IVectorPtr rowIndice_; + std::vector> inputSeqInfoVec_; + std::vector outSubSeqStartPos_; + std::vector outSeqStartPos_; + + void checkInputs(); + void copySliceIdsToCpu(); + void calSelectedRows(const MatrixPtr starts, const MatrixPtr ends); +}; + +REGISTER_LAYER(seq_slice, SequenceSliceLayer); + +bool SequenceSliceLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); + CHECK_GE(inputLayers_.size(), 2U); + CHECK_LE(inputLayers_.size(), 3U); + + setNeedSequenceInfo(false); + return true; +} + +void SequenceSliceLayer::checkInputs() { + const Argument& inputSeq = getInput(0); + CHECK(inputSeq.hasSeq()) << "The first input of sequence slice layer " + << "must be a sequence."; + const MatrixPtr indices1 = getInputValue(1); + CHECK_EQ(static_cast(indices1->getHeight()), + inputSeq.hasSubseq() ? inputSeq.getNumSubSequences() + : inputSeq.getNumSequences()) + << "Height of the second input should be equal to number of sequence " + << "in the first input."; + if (inputLayers_.size() == 3) { + const MatrixPtr indices2 = getInputValue(2); + CHECK_EQ(indices2->getHeight(), indices1->getHeight()) + << "start indices and end indices should have the same height."; + CHECK_EQ(indices2->getWidth(), indices1->getWidth()) + << "start indices and end indices should have the same Width."; + } +} + +void SequenceSliceLayer::copySliceIdsToCpu() { + const MatrixPtr indices1 = getInputValue(1); + if (inputLayers_.size() == 2U) { + if (config_.select_first()) { + Matrix::resizeOrCreate(startIdsOnCpu_, + indices1->getHeight(), + indices1->getWidth(), + false /* trans */, + false /* useGpu */); + startIdsOnCpu_->copyFrom(*indices1); + endIdsOnCpu_ = nullptr; + } else { + Matrix::resizeOrCreate(endIdsOnCpu_, + indices1->getHeight(), + indices1->getWidth(), + false /* trans */, + false /* useGpu */); + endIdsOnCpu_->copyFrom(*indices1); + startIdsOnCpu_ = nullptr; + } + } else if (inputLayers_.size() == 3U) { + Matrix::resizeOrCreate(startIdsOnCpu_, + indices1->getHeight(), + indices1->getWidth(), + false /* trans */, + false /* useGpu */); + startIdsOnCpu_->copyFrom(*indices1); + + const MatrixPtr indices2 = getInputValue(2); + Matrix::resizeOrCreate(endIdsOnCpu_, + indices2->getHeight(), + indices2->getWidth(), + false /* trans */, + false /* useGpu */); + endIdsOnCpu_->copyFrom(*indices2); + } +} + +void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts, + const MatrixPtr ends) { + CHECK(starts || ends) << "At least one of the start or end indices " + << "should be given."; + + outSeqStartPos_.resize(1, 0); + outSubSeqStartPos_.resize(1, 0); + selectedRows_.clear(); + + size_t beamSize = starts ? starts->getWidth() : ends->getWidth(); + size_t rowIdx = 0; + for (size_t i = 0; i < inputSeqInfoVec_.size(); ++i) { + for (size_t j = 0; j < inputSeqInfoVec_[i].size() - 1; ++j) { + for (size_t k = 0; k < beamSize; ++k) { + if (starts && starts->getElement(rowIdx, k) == -1.) break; + if (ends && ends->getElement(rowIdx, k) == -1.) break; + + int begPos = inputSeqInfoVec_[i][j]; + if (starts) begPos += starts->getElement(rowIdx, k); + + int endPos = inputSeqInfoVec_[i][j + 1] - 1; + if (ends) endPos = inputSeqInfoVec_[i][j] + ends->getElement(rowIdx, k); + + int seqLen = endPos - begPos + 1; + CHECK_GT(seqLen, 0U); + for (int m = begPos; m <= endPos; ++m) selectedRows_.push_back(m); + inputSeqInfoVec_.size() > 1 + ? outSubSeqStartPos_.push_back(outSubSeqStartPos_.back() + seqLen) + : outSeqStartPos_.push_back(outSeqStartPos_.back() + seqLen); + } + rowIdx++; + } + if (inputSeqInfoVec_.size() > 1) + outSeqStartPos_.push_back(outSubSeqStartPos_.back()); + } + + if (useGpu_) { + rowIndice_ = IVector::create(selectedRows_.size(), useGpu_); + rowIndice_->copyFrom(selectedRows_.data(), selectedRows_.size()); + } else { + rowIndice_ = + IVector::create(selectedRows_.data(), selectedRows_.size(), useGpu_); + } + + // create the sequence information for the output. + ICpuGpuVector::resizeOrCreate( + output_.sequenceStartPositions, outSeqStartPos_.size(), false); + output_.sequenceStartPositions->copyFrom( + outSeqStartPos_.data(), outSeqStartPos_.size(), false); + + if (inputSeqInfoVec_.size() > 1) { + ICpuGpuVector::resizeOrCreate( + output_.subSequenceStartPositions, outSubSeqStartPos_.size(), false); + output_.subSequenceStartPositions->copyFrom( + outSubSeqStartPos_.data(), outSubSeqStartPos_.size(), false); + } +} + +void SequenceSliceLayer::forward(PassType passType) { + Layer::forward(passType); + checkInputs(); + + const Argument& inputSeq = getInput(0); + inputSeqInfoVec_.clear(); + Argument::reorganizeSeqInfo(inputSeq.sequenceStartPositions, + inputSeq.subSequenceStartPositions, + inputSeqInfoVec_); + if (!useGpu_) { + if (inputLayers_.size() == 2U) { + startIdsOnCpu_ = config_.select_first() ? getInputValue(1) : nullptr; + endIdsOnCpu_ = config_.select_first() ? nullptr : getInputValue(1); + } else if (inputLayers_.size() == 3U) { + startIdsOnCpu_ = getInputValue(1); + endIdsOnCpu_ = getInputValue(2); + } + } else { + copySliceIdsToCpu(); + } + + // calculate the selected row indices in a batch, + // and build the output sequence information. + calSelectedRows(startIdsOnCpu_ ? startIdsOnCpu_ : nullptr, + endIdsOnCpu_ ? endIdsOnCpu_ : nullptr); + + resetOutput(selectedRows_.size(), getSize()); + + getOutputValue()->selectRows(*getInputValue(0), *rowIndice_); +} + +void SequenceSliceLayer::backward(const UpdateCallback& callback) { + getOutputGrad()->addToRows(*getInputGrad(0), *rowIndice_); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/SubNestedSequenceLayer.cpp b/paddle/gserver/layers/SubNestedSequenceLayer.cpp index 648d3908f391450f276d8a900ebb3bccb8d5532c..e9bee77212065effdac78cba590caed2e9155f0a 100644 --- a/paddle/gserver/layers/SubNestedSequenceLayer.cpp +++ b/paddle/gserver/layers/SubNestedSequenceLayer.cpp @@ -52,23 +52,34 @@ private: * ] * * ths output is saved to private member rowIndice_; - * [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15, - * 16,17,18,19,20,21,22,23,24,25,26,27] + * [0,1,2,3,4,5,6,7,8,9,15,16,17,18,19,20,21,23,24,25,26,27] */ - void calSelectedCols(const MatrixPtr selectedIndices, + void calSelectedRows(const MatrixPtr selectedIndices, const std::vector>& inputSeqInfo); - // if the second input of this layer is on GPU memory, copy it to CPU memory. + /* + * TODO(caoying) + * In PaddePaddle, currently all matrices are real number types, + * but the second is some selected indices of the give sequence to trim + * the nested sequence, are actually filled with int types so that storing + * int types information in real number matrices is very dangerous, since + * real numbers will be convered to int types. If a user fills this matrix + * himself, invalid data may occor. + * + * if the second input of this layer is on GPU memory, copy it to CPU memory. + */ MatrixPtr selIdsCpu_; - // reorganized sequenceStartPositions and subSequenceStartPositions - // into a 2d vector to facilitate the sequence selection process. + /* + * reorganize sequenceStartPositions and subSequenceStartPositions + * into a 2d vector to facilitate the sequence selection process. + */ std::vector> inputSeqInfoVec_; - // the final selected row indices in a batch, - // rowIdx_ and selectedRows_ actually share a same memory. + /* store the final selected row indices in a batch */ IVectorPtr rowIndice_; + /* rowIndice_ and selectedRows_ actually share a same memory. */ std::vector selectedRows_; }; @@ -83,7 +94,7 @@ bool SubNestedSequenceLayer::init(const LayerMap& layerMap, return true; } -void SubNestedSequenceLayer::calSelectedCols( +void SubNestedSequenceLayer::calSelectedRows( const MatrixPtr selectedIndices, const std::vector>& inputSeqInfo) { selectedRows_.clear(); @@ -160,7 +171,7 @@ void SubNestedSequenceLayer::forward(PassType passType) { Argument::reorganizeSeqInfo(inputSeq.sequenceStartPositions, inputSeq.subSequenceStartPositions, inputSeqInfoVec_); - calSelectedCols(selIdsCpu_, inputSeqInfoVec_); + calSelectedRows(selIdsCpu_, inputSeqInfoVec_); resetOutput(selectedRows_.size(), getSize()); getOutputValue()->selectRows(*getInputValue(0), *rowIndice_); diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index c2a2993620492a9ec5dae932ff1292ced2c00064..346c01ced648e47a5516c810e1e975a3a5ed2394 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -34,6 +34,12 @@ add_unittest_without_exec(test_CRFLayerGrad add_test(NAME test_CRFLayerGrad COMMAND test_CRFLayerGrad) +################ test_SeqSliceLayerGrad #################### +add_unittest_without_exec(test_SeqSliceLayerGrad + test_SeqSliceLayerGrad.cpp + LayerGradUtil.cpp) +add_test(NAME test_SeqSliceLayerGrad + COMMAND test_SeqSliceLayerGrad) add_unittest_without_exec(test_ActivationGrad test_ActivationGrad.cpp diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index c522b20f0e7f0a9951bfdab488b40aab390a3068..b974dc5d573884fd099b9755a7e60202e9cfeb6c 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2025,6 +2025,21 @@ TEST(Layer, RowL2NormLayer) { } } +TEST(Layer, ScaleShiftLayer) { + const size_t batchSize = 16; + const size_t size = 32; + TestConfig config; + config.layerConfig.set_type("scale_shift"); + config.layerConfig.set_size(size); + config.biasSize = 1; + config.inputDefs.push_back( + {INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1}); + config.layerConfig.add_inputs(); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index f930c72fde3f5e0a6a45cb6bfd3507a4f48028fc..d36f72360f8ebd2033fb3e8c0e1b30911abba362 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -269,7 +269,8 @@ TEST(Compare, img_conv2) { bool useGpu = FLAGS_use_gpu; double eps = FLAGS_checkgrad_eps; FLAGS_use_gpu = true; - FLAGS_checkgrad_eps = 1e-2; + // Sometimes, this unit test will fail with 1e-2 + FLAGS_checkgrad_eps = 4e-2; compareNetwork(config_file_a, config_file_b); FLAGS_use_gpu = useGpu; FLAGS_checkgrad_eps = eps; diff --git a/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp b/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d560ca650bc5b156de280a2a0d698b67eb032907 --- /dev/null +++ b/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp @@ -0,0 +1,223 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "ModelConfig.pb.h" +#include "paddle/gserver/layers/DataLayer.h" +#include "paddle/trainer/Trainer.h" + +#include "LayerGradUtil.h" +#include "paddle/testing/TestUtil.h" + +using namespace paddle; // NOLINT +using namespace std; // NOLINT + +DECLARE_int32(gpu_id); +DECLARE_bool(thread_local_rand_use_global_seed); + +const int MAX_SEQ_NUM = 17; +const int MAX_SEQ_LEN = 23; +const int MAX_BEAM_SIZE = 13; + +vector randSampling(real range, int n) { + CHECK_GE(range, n); + vector num(range); + iota(begin(num), end(num), 0.); + if (range == n) return num; + + random_shuffle(begin(num), end(num)); + num.resize(n); + sort(begin(num), end(num)); + return num; +} + +void genSeqInfo(vector& seqStartPos, vector& subSeqStartPos) { + seqStartPos.resize(1, 0); + subSeqStartPos.resize(1, 0); + + srand((size_t)(time(NULL))); + int seqNum = 1 + (rand() % MAX_SEQ_NUM); + for (int i = 0; i < seqNum; ++i) { + int subSeqNum = 1 + (rand() % MAX_SEQ_NUM); + for (int j = 0; j < subSeqNum; ++j) + subSeqStartPos.push_back(subSeqStartPos.back() + + (1 + (rand() % MAX_SEQ_LEN))); + seqStartPos.push_back(subSeqStartPos.back()); + } +} + +/* + generate start indices according to sequence start positions. + */ +void genStarts(vector& seqStartPos, + vector>& starts, + size_t beamSize) { + starts.clear(); + starts.resize(seqStartPos.size() - 1, vector(beamSize, -1.)); + + for (size_t i = 0; i < seqStartPos.size() - 1; ++i) { + int seqLen = seqStartPos[i + 1] - seqStartPos[i]; + vector randStarts = + randSampling(seqLen, min(seqLen, static_cast(beamSize))); + copy(begin(randStarts), end(randStarts), begin(starts[i])); + } +} + +/* + generate end indices according to sequence start positions and start indices. + */ +void genEnds(vector& seqStartPos, + vector>& starts, + vector>& ends, + size_t beamSize) { + CHECK_EQ(seqStartPos.size() - 1, starts.size()); + ends.clear(); + ends.resize(seqStartPos.size() - 1, vector(beamSize, -1.)); + + for (size_t i = 0; i < starts.size(); ++i) { + for (size_t j = 0; j < starts[i].size(); ++j) { + int seqLen = seqStartPos[i + 1] - seqStartPos[i]; + CHECK_GE(seqLen - 1, starts[i][j]); + if (starts[i][j] == -1.) break; + if (starts[i][j] == (seqLen - 1)) { + ends[i][j] = starts[i][j]; + } else { + ends[i][j] = starts[i][j] + randSampling(seqLen - starts[i][j], 1)[0]; + } + } + } +} + +void genTestData(vector& seqStartPos, + vector& subSeqStartPos, + vector>& starts, + vector>& ends, + bool hasSubseq) { + size_t beamSize = 1 + (rand() % MAX_BEAM_SIZE); + genSeqInfo(seqStartPos, subSeqStartPos); + + genStarts(hasSubseq ? subSeqStartPos : seqStartPos, starts, beamSize); + genEnds(hasSubseq ? subSeqStartPos : seqStartPos, starts, ends, beamSize); +} + +template +void flatten2dVector(vector>& inVec, vector& outVec) { + size_t totalSize{0}; + for (auto const& items : inVec) totalSize += items.size(); + outVec.reserve(totalSize); + + for (auto& items : inVec) + move(items.begin(), items.end(), back_inserter(outVec)); +} + +void testSeqSliceLayer(bool hasSubseq, + bool useGpu, + vector& seqStartPos, + vector& subSeqStartPos, + vector>& starts, + vector>& ends) { + // layer size is not crutial for this layer, + // so here use a small layer size in the unittest. + const size_t layerSize{4}; + TestConfig config; + config.layerConfig.set_type("seq_slice"); + config.layerConfig.set_size(layerSize); + + // add the first input + MatrixPtr seqInputPtr = + Matrix::create(hasSubseq ? subSeqStartPos.back() : seqStartPos.back(), + layerSize, + false, + false); + seqInputPtr->randomizeUniform(); + + if (hasSubseq) { + config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, + "seq_input", + seqInputPtr, + seqStartPos, + subSeqStartPos}); + } else { + config.inputDefs.push_back( + {INPUT_SELF_DEFINE_DATA, "seq_input", seqInputPtr, seqStartPos}); + } + config.layerConfig.add_inputs(); + + // add start indices + if (starts.size()) { + vector startsToVec; + flatten2dVector(starts, startsToVec); + + MatrixPtr startMatrixPtr = + Matrix::create(starts.size(), starts[0].size(), false, false); + startMatrixPtr->copyFrom(startsToVec.data(), startsToVec.size()); + + config.inputDefs.push_back( + {INPUT_SELF_DEFINE_DATA, "starts", startMatrixPtr}); + config.layerConfig.add_inputs(); + config.layerConfig.set_select_first(true); + } + + // add end indices + if (ends.size()) { + vector endsToVec; + flatten2dVector(ends, endsToVec); + + MatrixPtr endMatrixPtr = + Matrix::create(ends.size(), ends[0].size(), false, false); + endMatrixPtr->copyFrom(endsToVec.data(), endsToVec.size()); + + config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, "ends", endMatrixPtr}); + config.layerConfig.add_inputs(); + config.layerConfig.set_select_first(false); + } + + testLayerGrad(config, "seq_slice", /*batchSize*/ 100, false, useGpu, false); +} + +TEST(Layer, SeqSliceLayer) { + vector seqStartPos; + vector subSeqStartPos; + vector> starts; + vector> ends; + + std::vector mode = {false}; +#ifndef PADDLE_ONLY_CPU + mode.push_back(true); +#endif + genSeqInfo(seqStartPos, subSeqStartPos); + for (bool hasSubseq : {true, false}) { + LOG(INFO) << "hasSubSeq : " << hasSubseq; + genTestData(seqStartPos, subSeqStartPos, starts, ends, hasSubseq); + for (bool useGpu : mode) { + vector> tmp; + testSeqSliceLayer( + hasSubseq, useGpu, seqStartPos, subSeqStartPos, tmp, ends); + testSeqSliceLayer( + hasSubseq, useGpu, seqStartPos, subSeqStartPos, starts, tmp); + testSeqSliceLayer( + hasSubseq, useGpu, seqStartPos, subSeqStartPos, starts, ends); + } + } +} + +int main(int argc, char** argv) { + initMain(argc, argv); + hl_start(); + hl_init(FLAGS_gpu_id); + FLAGS_thread_local_rand_use_global_seed = true; + srand(1); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 0266bf4f7d65c7aafd4242af41cbd1c71f44bff8..29bc26f9d3bca0e30896657431f9a9bb1dac0d1d 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -19,8 +19,13 @@ limitations under the License. */ #include // for unique_ptr #include // for call_once +#include "glog/logging.h" + #include "paddle/memory/detail/buddy_allocator.h" #include "paddle/memory/detail/system_allocator.h" +#include "paddle/platform/gpu_info.h" + +DECLARE_double(fraction_of_gpu_memory_to_use); namespace paddle { namespace memory { @@ -80,6 +85,11 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { platform::GpuMinChunkSize(), platform::GpuMaxChunkSize())); } + VLOG(3) << "\n\nNOTE: each GPU device use " + << FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n" + << "You can set environment variable '" + << platform::kEnvFractionGpuMemoryToUse + << "' to change the fraction of GPU usage.\n\n"; }); platform::SetDeviceId(gpu_id); diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 72351b9dfa63513713463bb47a3684f0dfd84ad3..11bbb881874ec50e1132547336fc6fb6b42bcc4f 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include "paddle/platform/gpu_info.h" #include "paddle/platform/place.h" namespace paddle { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a7c89787e43df6173791bc54b3faffc034867f7d..b56a45b6bd1e4d834a3c11da989b4a0707a24bf6 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -43,6 +43,7 @@ endfunction() add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) +op_library(gather_op SRCS gather_op.cc gather_op.cu) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) @@ -68,3 +69,5 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu) +op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op) +op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op) diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index a623c551e1088365ade6f73bc6149977b6ef017e..ab1e1c101a10e09a81f7785d2f1514822e3bdf15 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -39,11 +39,10 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto X_grad = ctx.Output(framework::GradVarName("X")); + auto dX = ctx.Output(framework::GradVarName("X")); auto X = ctx.Input("X"); - // TODO(superjom) add enforce here after helper functions ready - X_grad->Resize(X->dims()); + dX->Resize(X->dims()); } }; @@ -70,9 +69,7 @@ namespace ops = paddle::operators; REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, ops::OnehotCrossEntropyGradientOp); -REGISTER_OP_CPU_KERNEL( - onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); -REGISTER_OP_CPU_KERNEL( - onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOpKernel); +REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, + ops::OnehotCrossEntropyOpKernel); +REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOpKernel); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 4bbc8f093a794d46737a16488684a6a0cc25e285..d999bfce58c8a6db5c811aad677c07094b881841 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -12,10 +12,122 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU -#include "paddle/operators/cross_entropy_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/assert.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__host__ __device__ T clipping_log(const T x) { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + T v = log(x); + if (v == INFINITY) { + return kApproInf; + } + if (v == -INFINITY) { + return -kApproInf; + } + return v; +} + +template +__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, + const int N, const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + PADDLE_ASSERT(label[i] >= 0 && label[i] < D); + Y[i] = -clipping_log(X[i * D + label[i]]); + } +} + +// TODO(qingqing): make zero setting an common function. +template +__global__ void zero(T* X, const int N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + X[i] = 0.0; + } +} + +template +__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, + const int* label, const int N, + const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + int idx = i * D + label[i]; + dX[idx] = -dY[i] / X[idx]; + } +} + +template +class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto X = ctx.Input("X"); + const T* Xdata = X->data(); + const int* label_data = ctx.Input("label")->data(); + auto Y = ctx.Output("Y"); + Y->mutable_data(ctx.GetPlace()); + T* Ydata = Y->data(); + + int N = X->dims()[0]; + int D = X->dims()[1]; + int block = 512; + int grid = (N + block - 1) / block; + // TODO(qingqing) launch kernel on specified stream + // base on ExecutionContext. + CrossEntropyKernel<<>>(Ydata, Xdata, label_data, N, D); + } +}; + +template +class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto X = ctx.Input("X"); + auto dX = ctx.Output(framework::GradVarName("X")); + auto dY = ctx.Input(framework::GradVarName("Y")); + auto label = ctx.Input("label"); + + auto* dXdata = dX->template mutable_data(ctx.GetPlace()); + auto* dYdata = dY->template data(); + auto* Xdata = X->template data(); + auto* label_data = label->data(); + + int N = X->dims()[0]; + int D = X->dims()[1]; + int block = 512; + int grid = (N * D + block - 1) / block; + zero<<>>(dXdata, N * D); + + grid = (N + block - 1) / block; + // TODO(qingqing): launch kernel on specified stream + // base on ExecutionContext. + CrossEntropyGradientKernel<<>>(dXdata, dYdata, Xdata, + label_data, N, D); + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); +REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, + ops::OnehotCrossEntropyOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOpCUDAKernel); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index b7df92c9a98ebf12b72a8d3d8e8e4e1a950f06c9..eb4d1348de1d940e2648c83c8ba94b289f10c5b2 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -21,7 +21,7 @@ namespace operators { using Tensor = framework::Tensor; template -T tolerable_value(T x) { +inline T tolerable_value(const T x) { static_assert(std::is_floating_point::value, "tolerable_value works only on float, " "double and double double."); @@ -39,10 +39,13 @@ T tolerable_value(T x) { return x; } -template +template class OnehotCrossEntropyOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto X = ctx.Input("X"); const T* Xdata = X->data(); const int* label_data = ctx.Input("label")->data(); @@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel { } }; -template +template class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto X = ctx.Input("X"); auto dX = ctx.Output(framework::GradVarName("X")); auto dY = ctx.Input(framework::GradVarName("Y")); @@ -79,6 +85,8 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { const int batch_size = X->dims()[0]; const int class_num = X->dims()[1]; + // TODO(qingqing): make zero setting an common function. + memset(dXdata, 0, sizeof(T) * batch_size * class_num); for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index d6e6990394e46ba06c4bacfe33ca522f3ff1413a..92fb51ec17709bc6f8abb2f516a9240fb5dc3a77 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/framework/ddim.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/place.h" @@ -25,13 +26,13 @@ namespace operators { // Implementation of CPU copy template -void CPUGather(const T* params, const int* indices, const int slice_size, +void CPUGather(const T* src, const int* indices, const int slice_size, const int index_size, T* output) { const size_t slice_bytes = slice_size * sizeof(T); for (int i = 0; i < index_size; ++i) { int index_ = indices[i]; - memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); + memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes); } } @@ -55,7 +56,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, int index_size = index->dims()[0]; auto src_dims = src->dims(); - paddle::framework::DDim output_dims(src_dims); + framework::DDim output_dims(src_dims); output_dims[0] = index_size; // slice size diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..123bed296c462c30bddd3bfbd530098fdbfe4856 --- /dev/null +++ b/paddle/operators/gather_op.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/gather_op.h" +#include "paddle/framework/ddim.h" + +namespace paddle { +namespace operators { + +class GatherOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + int batch_size = ctx.Input("Index")->dims()[0]; + PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); + framework::DDim output_dims(ctx.Input("X")->dims()); + output_dims[0] = batch_size; + ctx.Output("Out")->Resize(output_dims); + } +}; + +class GatherGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto X_grad = ctx.Output(framework::GradVarName("X")); + auto X = ctx.Input("X"); + + X_grad->Resize(X->dims()); + } +}; + +class GatherOpMaker : public framework::OpProtoAndCheckerMaker { + public: + GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The source input of gather op"); + AddInput("Index", "The index input of gather op"); + AddOutput("Out", "The output of add op"); + AddComment(R"DOC( +Gather Operator by selecting from the first axis, + +Out = X[Index] +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad, + ops::GatherGradOp); +REGISTER_OP_CPU_KERNEL(gather, + ops::GatherOpKernel); +REGISTER_OP_CPU_KERNEL( + gather_grad, + ops::GatherGradientOpKernel); diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..3f04a7b3f8142106917975cd1e0413fa1633a298 --- /dev/null +++ b/paddle/operators/gather_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/gather_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(gather, + ops::GatherOpKernel); diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h new file mode 100644 index 0000000000000000000000000000000000000000..381854f301870beadb72d9e9b4eb17ff199960fb --- /dev/null +++ b/paddle/operators/gather_op.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "gather.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "scatter.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GatherOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *X = ctx.Input("X"); + auto *Index = ctx.Input("Index"); + auto *Y = ctx.Output("Out"); + + Y->mutable_data(ctx.GetPlace()); + Gather(ctx.GetPlace(), X, Index, Y); + } +}; + +template +class GatherGradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *Index = ctx.Input("Index"); + auto *dX = ctx.Output(framework::GradVarName("X")); + auto *dO = ctx.Input(framework::GradVarName("Out")); + + dX->mutable_data(ctx.GetPlace()); + ScatterUpdate(ctx.GetPlace(), dO, Index, dX); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index f30bbce9586d61063b4b61d98695bb568ef73c8d..a85363ad81d2a23e7267026c067f74f8c94c4786 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,25 +16,25 @@ namespace paddle { namespace operators { template -class GaussianRandomKernel : public framework::OpKernel { +class CPUGaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { float mean = context.op_.GetAttr("mean"); float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - // TODO(dzh): attribute does not support unsigned int. - // And we need a global random seed configuration. - int seed = context.op_.GetAttr("seed"); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; if (seed == 0) { seed = std::random_device()(); } - std::mt19937 g(seed); - std::normal_distribution distribution(mean, std); + engine.seed(seed); + std::normal_distribution dist(mean, std); ssize_t size = framework::product(tensor->dims()); - for (int i = 0; i < size; ++i) { - data[i] = distribution(g); + for (ssize_t i = 0; i < size; ++i) { + data[i] = dist(engine); } } }; @@ -48,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& context) const override { - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); auto dims = GetAttr>("dims"); PADDLE_ENFORCE(dims.size() > 0UL, "dims can be one int or array. dims must be set."); @@ -68,8 +65,8 @@ Use to initialize tensor with gaussian random generator. )DOC"); AddAttr>("dims", "The dimension of random tensor."); - AddAttr("mean", "mean value of random.").SetDefault(.0f); - AddAttr("std", "minimum value of random value.").SetDefault(1.0f); + AddAttr("mean", "mean of random tensor.").SetDefault(.0f); + AddAttr("std", "std of random tensor.").SetDefault(1.0f); AddAttr("seed", "Random seed of generator." "0 means use system wide seed") @@ -83,4 +80,4 @@ Use to initialize tensor with gaussian random generator. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel); diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 1340b1e1e9f19fd96ced9e57fab75fe9d33bc84e..018a4bfcb26b9008c054000c91edf01e371fd82b 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -1,53 +1,65 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "paddle/platform/dynload/curand.h" -#include "paddle/platform/gpu_info.h" - +#include +#include +#include +#include #include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" namespace paddle { namespace operators { template -class GaussianRandomKernel : public framework::OpKernel { +struct GaussianGenerator { + T mean_, std_; + unsigned int seed_; + + __host__ __device__ GaussianGenerator(T mean, T std, int seed) + : mean_(mean), std_(std), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::normal_distribution dist(mean_, std_); + rng.discard(n); + return dist(rng); + } +}; + +template +class GPUGaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr("mean"); - float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - - int seed = context.op_.GetAttr("seed"); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); if (seed == 0) { std::random_device rd; seed = rd(); } - curandGenerator_t g; - PADDLE_ENFORCE(platform::dynload::curandCreateGenerator( - &g, CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); - platform::dynload::curandGenerateNormal( - g, data, framework::product(tensor->dims()), mean, std); + T mean = static_cast(context.op_.GetAttr("mean")); + T std = static_cast(context.op_.GetAttr("std")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + GaussianGenerator(mean, std, seed)); } }; } // namespace operators } // namespace paddle -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_GPU_KERNEL(gaussian_random, + paddle::operators::GPUGaussianRandomKernel); diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index affdd1ac2cd486930881ee6b34a4b32f41df7ee9..1e86fc3d166077265e0f433a6712b0665ea5a152 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -25,8 +25,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const float alpha, const float* A, const float* B, const float beta, float* C, platform::DeviceContext* context) { - int lda = K; - int ldb = N; + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); @@ -40,8 +40,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const double* B, const double beta, double* C, platform::DeviceContext* context) { - int lda = K; - int ldb = N; + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; int ldc = N; cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1eee9644babbdfac68821ca774845ad8ebbd5aee --- /dev/null +++ b/paddle/operators/minus_op.cc @@ -0,0 +1,87 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/minus_op.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +class MinusOp : public framework::OperatorWithKernel { + public: + MinusOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto *left_tensor = ctx.Input("X"); + auto *right_tensor = ctx.Input("Y"); + + PADDLE_ENFORCE_EQ( + framework::product(left_tensor->dims()), + framework::product(right_tensor->dims()), + "Minus operator must take two tensor with same num of elements"); + ctx.Output("Out")->Resize(left_tensor->dims()); + } +}; + +class MinusOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The left tensor of minus operator.").NotInGradient(); + AddInput("Y", "The right tensor of minus operator.").NotInGradient(); + AddOutput("Out", "The output tensor of minus operator.").NotInGradient(); + + AddComment(R"DOC(Minus Operator + +Equation: Out = X - Y +)DOC"); + } +}; +template +class MinusGradOp : public NetOp { + public: + MinusGradOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + auto out_grad = Input(framework::GradVarName("Out")); + auto x_grad = Output(framework::GradVarName("X")); + auto y_grad = Output(framework::GradVarName("Y")); + + // x_grad = out_grad + AppendOp(framework::OpRegistry::CreateOp("identity", {{"X", {out_grad}}}, + {{"Out", {x_grad}}}, {})); + + framework::AttributeMap scale_attr; + scale_attr["scale"] = static_cast(-1); + AppendOp(framework::OpRegistry::CreateOp("scale", {{"X", {out_grad}}}, + {{"Out", {y_grad}}}, scale_attr)); + CompleteAddOp(false); + } +}; + +} // namespace operators +} // namespace paddle + +USE_OP(scale); +USE_OP_ITSELF(identity); +namespace ops = paddle::operators; +REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad, + ops::MinusGradOp); +REGISTER_OP_CPU_KERNEL(minus, + ops::MinusKernel); diff --git a/paddle/operators/minus_op.cu b/paddle/operators/minus_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..a8375cc6301b2c1a917299c3933b03226bb72907 --- /dev/null +++ b/paddle/operators/minus_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/minus_op.h" + +REGISTER_OP_GPU_KERNEL( + minus, paddle::operators::MinusKernel); diff --git a/paddle/operators/minus_op.h b/paddle/operators/minus_op.h new file mode 100644 index 0000000000000000000000000000000000000000..6310a4fd5141516cff4fc7acbe1d17913a1b5506 --- /dev/null +++ b/paddle/operators/minus_op.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class MinusKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* left_tensor = context.Input("X"); + auto* right_tensor = context.Input("Y"); + auto* out_tensor = context.Output("Out"); + + out_tensor->mutable_data(context.GetPlace()); + auto& dev = context.GetEigenDevice(); + framework::EigenVector::Flatten(*out_tensor).device(dev) = + framework::EigenVector::Flatten(*left_tensor) - + framework::EigenVector::Flatten(*right_tensor); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 95d19fb6aad37143e65759b03e12e3e78bce5915..173cc3850ca9d97200e272ec59d1bd3fe09b5053 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -13,11 +13,12 @@ limitations under the License. */ #include "paddle/operators/mul_op.h" -#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { +using framework::Tensor; + class MulOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -59,10 +60,23 @@ class MulOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "MulGrad"; - return ""; + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + auto *y_grad = ctx.Output(framework::GradVarName("Y")); + PADDLE_ENFORCE(x_dims[0] == out_dims[0], + "Out@GRAD M X N must equal to X dims 0, M "); + PADDLE_ENFORCE(y_dims[1] == out_dims[1], + "Out@GRAD M X N must equal to Y dims 1, N "); + + x_grad->Resize(x_dims); + y_grad->Resize(y_dims); } }; @@ -72,3 +86,5 @@ class MulOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); +REGISTER_OP_CPU_KERNEL(mul_grad, + ops::MulGradKernel); diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index 346a7e505d123b5e4e831daa39a1f6349b3dcccf..a81444dbe63edeecedc5d822c65ff56c42b5db90 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -17,3 +17,5 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); +REGISTER_OP_GPU_KERNEL(mul_grad, + ops::MulGradKernel); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index b7812fd1a7a72f5ce543e18c8b7b5b51deff2204..8facc0281449785bf40726f23ca2fd5d166ff272 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -31,18 +31,34 @@ template class MulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - Eigen::array, 1> dim_pair = { - {Eigen::IndexPair(1, 0)}}; - auto* input0 = context.Input("X"); - auto* input1 = context.Input("Y"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); - auto X = EigenMatrix::From(*input0); - auto Y = EigenMatrix::From(*input1); - auto Z = EigenMatrix::From(*output); - auto& place = context.GetEigenDevice(); - - Z.device(place) = X.contract(Y, dim_pair); + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* Z = context.Output("Out"); + Z->mutable_data(context.GetPlace()); + auto* device_context = + const_cast(context.device_context_); + math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); + } +}; + +template +class MulGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Y = ctx.Input("Y"); + auto* dOut = ctx.Input(framework::GradVarName("Out")); + + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dY = ctx.Output(framework::GradVarName("Y")); + dX->mutable_data(ctx.GetPlace()); + dY->mutable_data(ctx.GetPlace()); + auto* device_context = + const_cast(ctx.device_context_); + // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N + math::matmul(*dOut, false, *Y, true, 1, dX, 0, device_context); + // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K + math::matmul(*X, true, *dOut, false, 1, dY, 0, device_context); } }; diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index a7d710511093dfbe13a13b1222b0230bba0398bd..44d925f0b0cc5ff20d52e548816f118c2027343a 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -68,10 +68,15 @@ std::string NetOp::DebugString() const { bool NetOp::IsNetOp() const { return true; } std::vector NetOp::OutputVars(bool has_intermediate) const { + std::vector all; + for (auto& pair : this->outputs_) { + for (auto& var_name : pair.second) { + all.push_back(var_name); + } + } if (has_intermediate) { - return this->outputs_.at(kAll); + return all; } - auto& all = this->outputs_.at(kAll); std::vector ret_val; for (auto& each : all) { if (!Contains(intermediate_outputs_, each)) { @@ -81,9 +86,8 @@ std::vector NetOp::OutputVars(bool has_intermediate) const { return ret_val; } -NetOp::NetOp(const std::string& type, - const framework::OperatorBase::VarNameMap& inputs, - const framework::OperatorBase::VarNameMap& outputs, +NetOp::NetOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 885ac6eeca65998dea62c1db40b9261cceb97805..fcd8134b2c19cae6a4d006a4cd6fe32d2d627c34 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -38,8 +38,10 @@ class NetOp : public framework::OperatorBase { public: static const char kAll[]; NetOp() : framework::OperatorBase("plain_net", {}, {}, {}) {} - NetOp(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const framework::AttributeMap& attrs); + + NetOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs); NetOp(const NetOp& o) : framework::OperatorBase(o.type_, {}, {}, o.attrs_) { this->ops_.reserve(o.ops_.size()); @@ -84,13 +86,14 @@ class NetOp : public framework::OperatorBase { return true; } - void AddOp(const framework::OperatorBase& op) { AddOp(op.Clone()); } + void AppendOp(const framework::OperatorBase& op) { AppendOp(op.Clone()); } /** * @brief Add an operator by ptr */ - void AddOp(std::unique_ptr op) { - PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); + void AppendOp(std::unique_ptr op) { + PADDLE_ENFORCE(!add_op_done_, + "Cannot AppendOp when this network is sealed"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); ops_.push_back(std::move(op)); } diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index e9598610c0a74e08a613a397109ad65994821498..f2e98ee7a1e14ee739abba01e97608845ce557f4 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -38,10 +38,10 @@ TEST(OpKernel, all) { auto net = std::make_shared(); ASSERT_NE(net, nullptr); - net->AddOp(std::unique_ptr( + net->AppendOp(std::unique_ptr( new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"Out", {"y"}}}, {}))); - net->AddOp(std::unique_ptr( + net->AppendOp(std::unique_ptr( new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}, {{"Out", {"z"}}}, {}))); @@ -61,7 +61,7 @@ TEST(NetOp, insert_op) { auto op1 = std::unique_ptr( new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"Out", {"y"}}}, {})); - net.AddOp(*op1); + net.AppendOp(*op1); net.InsertOp(0, *op1); ASSERT_EQ(2UL, net.ops_.size()); net.InsertOp(2, std::move(op1)); @@ -70,16 +70,16 @@ TEST(NetOp, insert_op) { TEST(NetOp, Clone) { NetOp net; - net.AddOp( + net.AppendOp( std::unique_ptr(new framework::NOP{"empty", {}, {}, {}})); - net.AddOp(std::unique_ptr( + net.AppendOp(std::unique_ptr( new framework::NOP{"empty2", {}, {}, {}})); net.CompleteAddOp(true); auto new_net_op = net.Clone(); ASSERT_NE(new_net_op, nullptr); ASSERT_TRUE(new_net_op->IsNetOp()); auto* new_net = static_cast(new_net_op.get()); - ASSERT_EQ(2, new_net->ops_.size()); + ASSERT_EQ(2UL, new_net->ops_.size()); ASSERT_EQ(new_net->ops_[0]->Type(), "empty"); ASSERT_EQ(new_net->ops_[1]->Type(), "empty2"); } diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 78ce0ba3c0fa4fe380e49a848c2434fe593cd00b..16bd249cb3d989c695ec9378f09d48833d70be58 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -131,8 +131,8 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{ "memories", "pre_memories", "boot_memories@grad"}; RecurrentOp::RecurrentOp(const std::string& type, - const framework::OperatorBase::VarNameMap& inputs, - const framework::OperatorBase::VarNameMap& outputs, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { rnn::InitArgument(kArgName, &arg_, *this); @@ -223,8 +223,8 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { } RecurrentGradientOp::RecurrentGradientOp( - const std::string& type, const framework::OperatorBase::VarNameMap& inputs, - const framework::OperatorBase::VarNameMap& outputs, + const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { rnn::InitArgument(kArgName, &arg_, *this); diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index bcfa817de8242153b164fa091309f19a6ad8a246..1033d657a3a8f96c8b3dae8dd93d3f1f6840b59b 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -114,8 +114,9 @@ class RecurrentGradientAlgorithm { class RecurrentOp : public framework::OperatorBase { public: - RecurrentOp(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const framework::AttributeMap& attrs); + RecurrentOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs); RecurrentOp(const RecurrentOp& o) : framework::OperatorBase( @@ -150,8 +151,9 @@ class RecurrentOp : public framework::OperatorBase { class RecurrentGradientOp : public framework::OperatorBase { public: - RecurrentGradientOp(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, + RecurrentGradientOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs); RecurrentGradientOp(const RecurrentGradientOp& o) diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 8375d988045dc24fa1109646b46ff477e2a78132..6825dce332adc0dc11dda187d1bd367875b8603e 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -17,7 +17,9 @@ namespace paddle { namespace operators { -class RowWiseAddOp : public framework::OperatorWithKernel { +using framework::Tensor; + +class RowwiseAddOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -34,9 +36,9 @@ class RowWiseAddOp : public framework::OperatorWithKernel { } }; -class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker { +class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker { public: - RowWiseAddOpMaker(framework::OpProto *proto, + RowwiseAddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The left input of row-wise add op, must be matrix"); @@ -49,12 +51,32 @@ for i in xrange(X.shape[0]): )DOC"); } }; +class RowwiseAddGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto dims0 = ctx.Input("X")->dims(); + auto dims1 = ctx.Input("b")->dims(); + PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1") + ctx.Output(framework::GradVarName("X"))->Resize(dims0); + ctx.Output(framework::GradVarName("b"))->Resize(dims1); + } +}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp, - ops::RowWiseAddOpMaker); +REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker, + rowwise_add_grad, ops::RowwiseAddGradOp); +REGISTER_OP_CPU_KERNEL( + rowwise_add, ops::RowwiseAddKernel); REGISTER_OP_CPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add_grad, + ops::RowwiseAddGradKernel); diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index 86f80b81228a69ac4c05a4693901570f2b9966e0..4a57f64c890ce99d6060faec6a4a01b107403344 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -17,4 +17,7 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add, ops::RowwiseAddKernel); +REGISTER_OP_GPU_KERNEL( + rowwise_add_grad, + ops::RowwiseAddGradKernel); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 01f88f2198774fbaa4c98ff9bf286f2f08496a9a..1cbd8bb31ad90a32d8a4e3bb59617d0b5384e470 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" @@ -28,7 +28,7 @@ template ; template -class RowWiseAddKernel : public framework::OpKernel { +class RowwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto out = context.Output("Out"); @@ -47,5 +47,25 @@ class RowWiseAddKernel : public framework::OpKernel { } }; +template +class RowwiseAddGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dOut = context.Input(framework::GradVarName("Out")); + auto* dX = context.Output(framework::GradVarName("X")); + auto* db = context.Output(framework::GradVarName("b")); + dX->mutable_data(context.GetPlace()); + db->mutable_data(context.GetPlace()); + + auto OutGrad = EigenMatrix::From(*dOut); + auto place = context.GetEigenDevice(); + EigenMatrix::From(*dX).device(place) = OutGrad; + + // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html + // colwise add + Eigen::array dims{{0}}; /* dimension to reduce */ + EigenVector::Flatten(*db).device(place) = OutGrad.sum(dims); + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8e96a74c94ab7ff4d8c3266695e5157aff67905b --- /dev/null +++ b/paddle/operators/scale_op.cc @@ -0,0 +1,105 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/scale_op.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +class ScaleOp : public framework::OperatorWithKernel { + public: + ScaleOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto *in = ctx.Input("X"); + auto *out = ctx.Output("Out"); + out->Resize(in->dims()); + } +}; + +template +class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input tensor of scale operator.").NotInGradient(); + AddOutput("Out", "The output tensor of scale operator.").NotInGradient(); + AddComment(R"DOC(Scale operator + +The equation is: Out = scale*X +)DOC"); + AddAttr("scale", "scale of scale operator.").SetDefault(1.0); + } +}; + +// Identity Op's gradient is identity op, too. +// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out)) +template +class ScaleGradOp : public NetOp { + public: + ScaleGradOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + AppendOp(framework::OpRegistry::CreateOp( + "scale", {{"X", {Input(framework::GradVarName("Out"))}}}, + {{"Out", {Output(framework::GradVarName("X"))}}}, + {{"scale", GetAttr("scale")}})); + CompleteAddOp(false); + } +}; + +// identity is a alias of scale op. This is also a example for creating a alias +// operator. +template +class IdentityOpMaker : public framework::OpProtoAndCheckerMaker { + public: + IdentityOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "input tensor of identity op"); + AddOutput("Out", "output tensor of identity op"); + AddComment("identity operator. Just a alias of scale op which scale = 1.0"); + } +}; + +template +class IdentityOp : public NetOp { + public: + IdentityOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + AppendOp(framework::OpRegistry::CreateOp( + "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Out")}}}, + {{"scale", static_cast(1)}})); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker, scale_grad, + ops::ScaleGradOp); +REGISTER_OP_CPU_KERNEL(scale, + ops::ScaleKernel); +REGISTER_OP_WITHOUT_GRADIENT(identity, ops::IdentityOp, + ops::IdentityOpMaker); diff --git a/paddle/operators/scale_op.cu b/paddle/operators/scale_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..63efbe0da8a90dd237d2d692076075339179acf6 --- /dev/null +++ b/paddle/operators/scale_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/scale_op.h" + +REGISTER_OP_GPU_KERNEL( + scale, paddle::operators::ScaleKernel); diff --git a/paddle/operators/scale_op.h b/paddle/operators/scale_op.h new file mode 100644 index 0000000000000000000000000000000000000000..aea64f1b0428ffe79ba8d90cf79dbfd2b5ef36f4 --- /dev/null +++ b/paddle/operators/scale_op.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { +template +class ScaleKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& context) const { + auto* tensor = context.Output("Out"); + auto* in = context.Input("X"); + tensor->mutable_data(in->place()); + + auto scale = static_cast(context.op_.GetAttr("scale")); + + auto eigen_out = framework::EigenVector::Flatten(*tensor); + auto eigen_in = framework::EigenVector::Flatten(*in); + auto& dev = context.GetEigenDevice(); + eigen_out.device(dev) = scale * eigen_in; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index a0a0d4d914b37fca4250e5218a953f573611a086..29491137e6d8b4bfa2d0d07d48ffed1212a6131f 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -39,7 +36,8 @@ class CPUUniformRandomKernel : public framework::OpKernel { std::uniform_real_distribution dist( static_cast(context.op_.GetAttr("min")), static_cast(context.op_.GetAttr("max"))); - for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { + ssize_t size = framework::product(tensor->dims()); + for (ssize_t i = 0; i < size; ++i) { data[i] = dist(engine); } } @@ -66,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "The output tensor of uniform random op"); AddComment(R"DOC(Uniform random operator. - Used to initialize tensor with uniform random generator. )DOC"); AddAttr>("dims", "the dimension of random tensor"); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 7a243555b6385af690e9632dfa81bf96d70f925d..1d6709934cbbcf50265eabef87c857654f783ed8 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp index 0547ac93cd183afbcede41d280c6b4b16ed7dab1..2b945de18a4cdc3712ac7e282494ed7d3ecc600d 100644 --- a/paddle/parameter/Argument.cpp +++ b/paddle/parameter/Argument.cpp @@ -276,17 +276,21 @@ int32_t Argument::resizeAndCopyFrom(const Argument& src, void Argument::concat(const std::vector& args, const std::vector& selectRows, const std::vector& seqStartPos, + const std::vector& copySize, bool useGpu, hl_stream_t stream, PassType passType) { CHECK(!subSequenceStartPositions) << "undefined behavior for subsequence positions"; - size_t batchSize = selectRows.size(); + size_t batchSize = 0; + for (size_t i = 0; i < copySize.size(); ++i) + batchSize += copySize[i] * (seqStartPos[i + 1] - seqStartPos[i]); + auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src, - int startRow, - int pos, + int desStartRow, + int srcStartRow, int size, bool useGpu) { if (!src) { @@ -300,14 +304,14 @@ void Argument::concat(const std::vector& args, dst->resize(batchSize, width); } - MatrixPtr tmpMatrix = dst->subMatrix(startRow, size); - tmpMatrix->copyFrom(*src->subMatrix(pos, size), stream); + MatrixPtr tmpMatrix = dst->subMatrix(desStartRow, size); + tmpMatrix->copyFrom(*src->subMatrix(srcStartRow, size), stream); }; auto copyIds = [batchSize, stream](IVectorPtr& dst, const IVectorPtr& src, - int startRow, - int pos, + int desStartRow, + int srcStartRow, int size, bool useGpu) { if (!src) { @@ -315,13 +319,14 @@ void Argument::concat(const std::vector& args, return; } IVector::resizeOrCreate(dst, batchSize, useGpu); - dst->subVec(startRow, size)->copyFrom(*src->subVec(pos, size), stream); + dst->subVec(desStartRow, size) + ->copyFrom(*src->subVec(srcStartRow, size), stream); }; auto copyStrs = [batchSize, stream](SVectorPtr& dst, const SVectorPtr& src, - int startRow, - int pos, + int desStartRow, + int srcStartRow, int size, bool useGpu) { if (!src) { @@ -333,30 +338,31 @@ void Argument::concat(const std::vector& args, } else { dst->resize(batchSize); } - std::copy( - src->begin() + pos, src->begin() + pos + size, dst->begin() + startRow); + std::copy(src->begin() + srcStartRow, + src->begin() + srcStartRow + size, + dst->begin() + desStartRow); }; dataId = args[0].dataId; CHECK_NE(seqStartPos.size(), 0UL); - size_t sampleNum = seqStartPos.size() - 1; - for (size_t i = 0; i < sampleNum; ++i) { + int desStartRow = 0; + for (size_t i = 0; i < copySize.size(); ++i) { int startPos = seqStartPos[i]; int endPos = seqStartPos[i + 1]; CHECK_GE(args.size(), static_cast(endPos - startPos)); for (int j = startPos; j < endPos; ++j) { const Argument& arg = args[j - startPos]; - CHECK_EQ(arg.dataId, dataId) << "Arguments in concat should have" - << " same dataId"; - const int copySize = 1; - const int rowIdx = selectRows[j]; - copyArg(in, arg.in, j, rowIdx, copySize, useGpu); - copyArg(value, arg.value, j, rowIdx, copySize, useGpu); + CHECK_EQ(arg.dataId, dataId) << "Arguments to concatenate should have " + << "the same dataId."; + const int srcStartRow = selectRows[j]; + copyArg(in, arg.in, desStartRow, srcStartRow, copySize[i], useGpu); + copyArg(value, arg.value, desStartRow, srcStartRow, copySize[i], useGpu); if (passType != PASS_TEST) { - copyArg(grad, arg.grad, j, rowIdx, copySize, useGpu); + copyArg(grad, arg.grad, desStartRow, srcStartRow, copySize[i], useGpu); } - copyIds(ids, arg.ids, j, rowIdx, copySize, useGpu); - copyStrs(strs, arg.strs, j, rowIdx, copySize, useGpu); + copyIds(ids, arg.ids, desStartRow, srcStartRow, copySize[i], useGpu); + copyStrs(strs, arg.strs, desStartRow, srcStartRow, copySize[i], useGpu); + desStartRow += copySize[i]; } } ICpuGpuVector::resizeOrCreate( @@ -670,19 +676,28 @@ void Argument::reorganizeSeqInfo( const ICpuGpuVectorPtr seqStartPos, const ICpuGpuVectorPtr subSeqStartPos, std::vector>& reorganizedSeqInfo) { - int* seqStarts = seqStartPos->getMutableData(false); - int* subSeqStarts = subSeqStartPos->getMutableData(false); + CHECK(seqStartPos); int seqNum = seqStartPos->getSize() - 1; - reorganizedSeqInfo.resize(seqNum, std::vector()); - int seqIdx = 0; - for (size_t i = 0; i < subSeqStartPos->getSize(); ++i) { - reorganizedSeqInfo[seqIdx].push_back(subSeqStarts[i]); - if (subSeqStarts[i] == seqStarts[seqIdx + 1]) { - seqIdx++; - if (seqIdx == seqNum) return; + int* seqStarts = seqStartPos->getMutableData(false); + + if (subSeqStartPos) { + int* subSeqStarts = subSeqStartPos->getMutableData(false); + reorganizedSeqInfo.resize(seqNum, std::vector()); + int seqIdx = 0; + for (size_t i = 0; i < subSeqStartPos->getSize(); ++i) { reorganizedSeqInfo[seqIdx].push_back(subSeqStarts[i]); + if (subSeqStarts[i] == seqStarts[seqIdx + 1]) { + seqIdx++; + if (seqIdx == seqNum) return; + reorganizedSeqInfo[seqIdx].push_back(subSeqStarts[i]); + } } + } else { + reorganizedSeqInfo.resize(1, std::vector(seqNum + 1, 0)); + memcpy(reorganizedSeqInfo[0].data(), + seqStarts, + sizeof(int) * seqStartPos->getSize()); } } diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h index d8d7a4398f99a2794c5d25528a7d582f5ed629ba..38797a76f55c311070192bd307103143d67cabca 100644 --- a/paddle/parameter/Argument.h +++ b/paddle/parameter/Argument.h @@ -240,6 +240,7 @@ struct Argument { void concat(const std::vector& args, const std::vector& selectRows, const std::vector& seqStartPos, + const std::vector& copySize, bool useGpu, hl_stream_t stream, PassType passType); diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index e31cbc3dee6c57851c241e117dbbd9b701db9d2c..321f4275d8e68d7d3fbbc19acf0afacf689474e5 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -65,7 +65,10 @@ public: size_t getSize() const { return config_.size(); } bool isFullSize() const { - return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); + if (bufs_[PARAMETER_VALUE]) { + return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); + } + return false; } inline bool useGpu() const { return useGpu_; } diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index acfc0639736beb82df41b851664e7bcd079b5eb1..120eb1e4af9cef43e76e27d4ad66acfbbd597a36 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,7 +1,7 @@ cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) -nv_library(gpu_info SRCS gpu_info.cc DEPS gflags) +nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) @@ -9,6 +9,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece) +cc_test(environment_test SRCS environment_test.cc DEPS stringpiece) IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index f92c15ae450e94de44d27e77763e791e6bae4426..ad212c5b2c47312743362db4926c80bf056e100d 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -114,9 +114,6 @@ CUDADeviceContext::~CUDADeviceContext() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } - if (curand_generator_) { - PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_)); - } eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -152,19 +149,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { cudaStream_t CUDADeviceContext::stream() { return stream_; } -curandGenerator_t CUDADeviceContext::curand_generator() { - if (!curand_generator_) { - SetDeviceId(place_.device); - PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, - CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); - - PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); - } - return curand_generator_; -} - #endif // PADDLE_ONLY_CPU } // namespace platform diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index c5042ae33e47e04521e59e0d91ddd8d4efffe50a..11528e1194e4516891034fa8febdac3ba6eed204 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -17,7 +17,6 @@ limitations under the License. */ #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -40,7 +39,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace); + explicit CPUDeviceContext(CPUPlace place); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -56,7 +55,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace); + explicit CUDADeviceContext(GPUPlace place); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -75,9 +74,6 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - /*! \brief Return curand handle in the device context. */ - curandGenerator_t curand_generator(); - /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); // clang-format on @@ -85,18 +81,13 @@ class CUDADeviceContext : public DeviceContext { private: GPUPlace place_; - private: std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - private: - uint64_t seed_; - // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; - curandGenerator_t curand_generator_{nullptr}; // clang-format on }; diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 8b764bdcd9d92e6b2203e45160acee35ec110538..5883a55272f0f24c94d48bc43c62ddb7bef15465 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, cudnn_handle); cublasHandle_t cublas_handle = device_context->cublas_handle(); ASSERT_NE(nullptr, cublas_handle); - curandGenerator_t curand_handle = device_context->curand_generator(); - ASSERT_NE(nullptr, curand_handle); ASSERT_NE(nullptr, device_context->stream()); delete device_context; } diff --git a/paddle/platform/environment.h b/paddle/platform/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..4edcce932edc61453cef74f2c4ee0f72496b3677 --- /dev/null +++ b/paddle/platform/environment.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/platform/enforce.h" +#include "paddle/string/piece.h" + +extern char** environ; // for environment variables + +namespace paddle { +namespace platform { + +inline void SetEnvVariable(const std::string& name, const std::string& value) { + PADDLE_ENFORCE_NE(setenv(name.c_str(), value.c_str(), 1), -1, + "Failed to set environment variable %s=%s", name, value); +} + +inline void UnsetEnvVariable(const std::string& name) { + PADDLE_ENFORCE_NE(unsetenv(name.c_str()), -1, + "Failed to unset environment variable %s", name); +} + +inline bool IsEnvVarDefined(const std::string& name) { + return std::getenv(name.c_str()) != nullptr; +} + +inline std::string GetEnvValue(const std::string& name) { + PADDLE_ENFORCE(IsEnvVarDefined(name), + "Tried to access undefined environment variable %s", name); + return std::getenv(name.c_str()); +} + +inline std::vector GetAllEnvVariables() { + std::vector vars; + for (auto var = environ; *var != nullptr; ++var) { + auto tail = string::Index(*var, "="); + auto name = string::SubStr(*var, 0, tail).ToString(); + vars.push_back(name); + } + return vars; +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/environment_test.cc b/paddle/platform/environment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f136527215d6a676cfa1a3b08f09dfd3ab24a90 --- /dev/null +++ b/paddle/platform/environment_test.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/environment.h" + +#include "glog/logging.h" +#include "gtest/gtest.h" + +TEST(ENVIRONMENT, ACCESS) { + namespace platform = paddle::platform; + namespace string = paddle::string; + + platform::SetEnvVariable("PADDLE_USE_ENV", "TRUE"); + + EXPECT_TRUE(platform::IsEnvVarDefined("PADDLE_USE_ENV")); + EXPECT_EQ(platform::GetEnvValue("PADDLE_USE_ENV"), "TRUE"); + + platform::UnsetEnvVariable("PADDLE_USE_ENV"); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV")); + + platform::SetEnvVariable("PADDLE_USE_ENV1", "Hello "); + platform::SetEnvVariable("PADDLE_USE_ENV2", "World, "); + platform::SetEnvVariable("PADDLE_USE_ENV3", "PaddlePaddle!"); + + std::string env_info; + auto vars = platform::GetAllEnvVariables(); + for_each(vars.begin(), vars.end(), [&](const std::string& var) { + env_info += platform::GetEnvValue(var); + }); + + EXPECT_TRUE(string::Contains(env_info, "Hello World, PaddlePaddle!")); + platform::UnsetEnvVariable("PADDLE_USE_ENV1"); + platform::UnsetEnvVariable("PADDLE_USE_ENV2"); + platform::UnsetEnvVariable("PADDLE_USE_ENV3"); + + env_info.clear(); + vars = platform::GetAllEnvVariables(); + for_each(vars.begin(), vars.end(), [&](const std::string& var) { + env_info += platform::GetEnvValue(var); + }); + + EXPECT_FALSE(string::Contains(env_info, "Hello World, PaddlePaddle!")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV1")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV2")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV3")); +} diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index edeb3ecd7bf8b87333813eee5b40f71030f6609f..be381a4e26cf0eb41f5b3de88bd03ad8901683cc 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/platform/gpu_info.h" + #include "gflags/gflags.h" + #include "paddle/platform/enforce.h" +#include "paddle/platform/environment.h" DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, "Default use 95% of GPU memory for PaddlePaddle," @@ -70,6 +73,13 @@ size_t GpuMaxChunkSize() { GpuMemoryUsage(available, total); + if (IsEnvVarDefined(kEnvFractionGpuMemoryToUse)) { + auto val = std::stod(GetEnvValue(kEnvFractionGpuMemoryToUse)); + PADDLE_ENFORCE_GT(val, 0.0); + PADDLE_ENFORCE_LE(val, 1.0); + FLAGS_fraction_of_gpu_memory_to_use = val; + } + // Reserving the rest memory for page tables, etc. size_t reserving = (1 - FLAGS_fraction_of_gpu_memory_to_use) * total; diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index d3a5f5f13fdd3dd59eb43465da4a64b0d8d95e5b..ed2420b8740e583d307f6836a70fe7e1c780e28b 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -18,10 +18,15 @@ limitations under the License. */ #include #include +#include namespace paddle { namespace platform { +//! Environment variable: fraction of GPU memory to use on each device. +const std::string kEnvFractionGpuMemoryToUse = + "PADDLE_FRACTION_GPU_MEMORY_TO_USE"; + //! Get the total number of GPU devices in system. int GetDeviceCount(); diff --git a/paddle/pserver/ParameterClient2.cpp b/paddle/pserver/ParameterClient2.cpp index f7e391f76324a09c203dfbbb449feb050caa8fb4..54063a809a4f9e558f8d364f5c437f2b6d98925b 100644 --- a/paddle/pserver/ParameterClient2.cpp +++ b/paddle/pserver/ParameterClient2.cpp @@ -65,7 +65,6 @@ void ParameterClient2::initThreads() { LOG(INFO) << "parallel_thread_num dosent need to set"; } syncThreadPool_.reset(new SyncThreadPool(threadNum_)); - startThreads(); } @@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData( request.set_cost(cost); request.set_batch_status(batchStatus); CHECK_EQ(request.blocks_size(), 0); + VLOG(10) << "request: trainer_id: " << request.trainer_id() + << " update_mode" << request.update_mode() + << " send_back_parameter: " << request.send_back_parameter() + << " send_back_parameter_type: " + << request.send_back_parameter_type() + << " num_samples: " << request.num_samples() + << " cost: " << request.cost() + << " batch_status: " << request.batch_status(); } for (const auto& segments : parameterSegments) { const auto it = parameterMap_.find(segments.id); @@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData( CHECK(sendMat != nullptr) << "sendMat is nullptr"; syncThreadPool_->exec([&](int tid, size_t numThreads) { + std::lock_guard guard(sparseAutoGrowthMutex_); const auto& localIndices = prefetchMat->getLocalIndices(); /// num of sparse rows size_t nLocalBlocks = localIndices.size(); uint64_t beginDim = 0; uint64_t endDim = 0; + + // FIXME(typhoonzero): let it resize first + prefetchMat->getLocalRow(nLocalBlocks + 1); + sendMat->getLocalRow(nLocalBlocks + 1); + for (size_t row = 0; row < nLocalBlocks; ++row) { int64_t blockId = localIndices[row]; // local row -> sparse row int serverId = std::abs((blockId + nameHash) % serviceNum_); @@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData( block->set_begin_pos(row * blockSize); /// block len block->set_block_size(endDim - beginDim); - if (sendingPara) { sendJob->parallelInputIovs[serverId].push_back( {sendMat->getLocalRow(row), sizeof(real) * (size_t)blockSize}); diff --git a/paddle/pserver/ParameterClient2.h b/paddle/pserver/ParameterClient2.h index 89b3ddd502151e537b81bdbb09f171dd6e13ba26..29b9eeacddf2945dd22b7b17fc87c7c74b868896 100644 --- a/paddle/pserver/ParameterClient2.h +++ b/paddle/pserver/ParameterClient2.h @@ -583,6 +583,7 @@ protected: #ifndef PADDLE_DISABLE_TIMER uint64_t forwardbackwordTime_; #endif + std::mutex sparseAutoGrowthMutex_; /// map id to parameter used for decoding protobuf data std::unordered_map parameterMap_; diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..40db811767a9c273f073f4715e6ddfbf05887730 --- /dev/null +++ b/paddle/pybind/CMakeLists.txt @@ -0,0 +1,20 @@ +if(WITH_PYTHON) +cc_library(paddle_pybind SHARED + SRCS pybind.cc + DEPS pybind python backward + sgd_op + gather_op + add_op + mul_op + rowwise_add_op + sigmoid_op + softmax_op + mean_op + cross_entropy_op + recurrent_op + uniform_random_op + gaussian_random_op + fill_zeros_like_op + scale_op + minus_op) +endif(WITH_PYTHON) diff --git a/paddle/framework/pybind.cc b/paddle/pybind/pybind.cc similarity index 91% rename from paddle/framework/pybind.cc rename to paddle/pybind/pybind.cc index f0114b9e4908d65b3fddb493230777f9e500b4e1..27b98e77db80505f7498deb75164e184b900262b 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -18,11 +18,11 @@ limitations under the License. */ #include "paddle/framework/backward.h" #include "paddle/framework/op_registry.h" -#include "paddle/framework/tensor_py.h" #include "paddle/operators/net_op.h" #include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" +#include "paddle/pybind/tensor_py.h" #include "paddle/string/to_string.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" @@ -31,7 +31,7 @@ limitations under the License. */ namespace py = pybind11; USE_OP(add_two); -USE_CPU_ONLY_OP(onehot_cross_entropy); +USE_OP(onehot_cross_entropy); USE_OP(sgd); USE_OP(mul); USE_OP(mean); @@ -42,6 +42,10 @@ USE_OP(fill_zeros_like); USE_OP_ITSELF(recurrent_op); USE_OP(gaussian_random); USE_OP(uniform_random); +USE_OP(scale); +USE_OP_ITSELF(identity); +USE_OP(minus); +USE_CPU_ONLY_OP(gather); namespace paddle { namespace framework { @@ -131,26 +135,24 @@ All parameter, weight, gradient are variables in Paddle. py::return_value_policy::reference) .def("find_var", &Scope::FindVar, py::return_value_policy::reference) .def(py::init<>()) - .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); }, + .def("new_scope", + [](Scope &self) -> Scope * { return &self.NewScope(); }, py::return_value_policy::reference) .def("drop_kids", &Scope::DropKids); //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { - auto &op_info_map = OpRegistry::op_info_map(); std::vector ret_values; - for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) { - const OpProto *proto = it->second.proto_; - if (proto == nullptr) { - continue; - } - PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized"); + + OpInfoMap::Instance().IterAllInfo([&ret_values](const std::string &type, + const OpInfo &info) { + if (!info.HasOpProtoAndChecker()) return; std::string str; - PADDLE_ENFORCE(proto->SerializeToString(&str), + PADDLE_ENFORCE(info.Proto().SerializeToString(&str), "Serialize OpProto Error. This could be a bug of Paddle."); - ret_values.push_back(py::bytes(str)); - } + ret_values.emplace_back(str); + }); return ret_values; }); m.def_submodule( @@ -222,8 +224,10 @@ All parameter, weight, gradient are variables in Paddle. retv->SetType("plain_net"); return retv; }) - .def("add_op", [](operators::NetOp &self, - const OperatorBase &op) { self.AddOp(op); }) + .def("append_op", + [](operators::NetOp &self, const OperatorBase &op) { + self.AppendOp(op); + }) .def("complete_add_op", &operators::NetOp::CompleteAddOp) .def("complete_add_op", [](std::shared_ptr &self) { self->CompleteAddOp(); @@ -243,10 +247,9 @@ All parameter, weight, gradient are variables in Paddle. auto rnn_op = OpRegistry::CreateOp(desc); return static_cast(rnn_op.release()); }) - .def("set_stepnet", [](operators::RecurrentOp &self, - const operators::NetOp &net) -> void { - self.set_stepnet(net.Clone()); - }); + .def("set_stepnet", + [](operators::RecurrentOp &self, const operators::NetOp &net) + -> void { self.set_stepnet(net.Clone()); }); m.def("unique_integer", UniqueIntegerGenerator); diff --git a/paddle/framework/tensor_py.h b/paddle/pybind/tensor_py.h similarity index 92% rename from paddle/framework/tensor_py.h rename to paddle/pybind/tensor_py.h index 4e1ab77b157fe1adaeac55c271c056236f2d40de..39ba60b4dc7ebe3f39a0aa4023b34540b340a841 100644 --- a/paddle/framework/tensor_py.h +++ b/paddle/pybind/tensor_py.h @@ -63,8 +63,11 @@ struct CastToPyBufferImpl { } return py::buffer_info( dst_tensor.mutable_data(dst_tensor.holder_->place()), - sizeof(CUR_TYPE), py::format_descriptor::format(), - (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); + sizeof(CUR_TYPE), + py::format_descriptor::format(), + (size_t)framework::arity(dst_tensor.dims()), + dims_outside, + strides); } else { constexpr bool less = I + 1 < std::tuple_size>::value; return CastToPyBufferImpl()(tensor); @@ -107,8 +110,8 @@ void PyCUDATensorSetFromArray( self.Resize(framework::make_ddim(dims)); auto *dst = self.mutable_data(place); - paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(), - cudaMemcpyHostToDevice); + paddle::platform::GpuMemcpySync( + dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice); } #endif diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index a3ca3f251099bfbcf3dfef74c10f744d5081f333..0031f30e1da4068cda78c6b49578659e0f4ab15d 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -338,7 +338,8 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name, in_links_count += 1 layer_name = MakeLayerNameInParentSubmodel(name) layer = g_layer_map[layer_name] - ScatterAgentLayer(name=name, size=layer.size) + ScatterAgentLayer( + name=name, size=layer.size, width=layer.width, height=layer.height) pair = g_current_submodel.in_links.add() pair.layer_name = layer_name @@ -2197,8 +2198,8 @@ class MaxOutLayer(LayerBase): maxout_conf = self.config.inputs[0].maxout_conf parse_maxout(self.inputs[0].maxout, input_layer.name, maxout_conf) out_channels = maxout_conf.image_conf.channels / maxout_conf.groups - self.set_cnn_layer(name, g_layer_map[input_layer.name].height, - g_layer_map[input_layer.name].width, out_channels) + self.set_cnn_layer(name, maxout_conf.image_conf.img_size_y, + maxout_conf.image_conf.img_size, out_channels) @config_layer('row_conv') @@ -2232,6 +2233,20 @@ class ClipLayer(LayerBase): self.config.inputs[0].clip_conf.max = max +@config_layer('scale_shift') +class ScaleShiftLayer(LayerBase): + def __init__(self, name, inputs, bias=True, **xargs): + super(ScaleShiftLayer, self).__init__( + name, 'scale_shift', 0, inputs=inputs, **xargs) + config_assert( + len(self.inputs) == 1, + 'ScaleShiftLayer must have one and only one input.') + input_layer = self.get_input_layer(0) + self.set_layer_size(input_layer.size) + self.create_input_parameter(0, 1, [1, 1]) + self.create_bias_parameter(bias, 1) + + # key: cost type # value: cost class g_cost_map = {} @@ -2402,9 +2417,11 @@ class GatherAgentLayer(LayerBase): @config_layer('scatter_agent') class ScatterAgentLayer(LayerBase): - def __init__(self, name, size, device=None): + def __init__(self, name, size, width=None, height=None, device=None): super(ScatterAgentLayer, self).__init__( name, 'scatter_agent', size, inputs=[], device=device) + if height and width: + self.set_layer_height_width(height, width) @config_layer('multiplex') @@ -2688,6 +2705,49 @@ class SubSequenceLayer(LayerBase): self.create_bias_parameter(bias, size) +@config_layer('seq_slice') +class SeqSliceLayer(LayerBase): + def __init__(self, name, inputs, starts, ends, bias=False, **xargs): + if isinstance(inputs, list): + assert len(inputs) == 1, ('the first input of sequence slice layer ' + 'is a single sequence input.') + else: + inputs = [inputs] + + if starts is not None: + if isinstance(starts, list): + assert len(starts) == 1, ( + 'the start indices for sequence slice layer cannot ' + 'be a list having more than one element.') + starts = starts[0] + inputs.append(starts) + + if ends is not None: + if isinstance(ends, list): + assert len(ends) == 1, ( + 'the end indices for sequence slice layer cannot ' + 'be a list having more than one element.') + ends = ends[0] + inputs.append(ends) + assert len(inputs) >= 2, ( + 'the sequence slice layer has at least two inputs.') + + super(SeqSliceLayer, self).__init__( + name, 'seq_slice', 0, inputs=inputs, **xargs) + + input_layer0 = self.get_input_layer(0) + size = input_layer0.size + self.set_layer_size(size) + + if len(inputs) == 3: + assert ( + self.get_input_layer(1).size == self.get_input_layer(2).size), ( + 'If start and end indices are both given to' + 'sequence slice layer, they should have the same width.') + elif len(inputs) == 2: + self.config.select_first = (starts is not None) + + @config_layer('sub_nested_seq') class SubNestedSequenceLayer(LayerBase): def __init__(self, name, inputs, selected_indices, bias=False, **xargs): diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index d61c94dc8278cf84b8b5e6c5c2e3e6ab875afcbf..103fa349d64f0dd2b0b2da7b90d21fcb2cc7da97 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -16,11 +16,13 @@ import functools import collections import inspect +import paddle.trainer.config_parser as cp from paddle.trainer.config_parser import * from .activations import LinearActivation, SigmoidActivation, TanhActivation, \ ReluActivation, IdentityActivation, SoftmaxActivation, BaseActivation from .evaluators import * -from .poolings import MaxPooling, AvgPooling, BasePoolingType +from .poolings import MaxPooling, AvgPooling, BasePoolingType, \ + CudnnAvgPooling, CudnnMaxPooling from .attrs import * from .default_decorators import * @@ -133,7 +135,9 @@ __all__ = [ 'sub_nested_seq_layer', 'clip_layer', 'slice_projection', + 'seq_slice_layer', 'kmax_sequence_score_layer', + 'scale_shift_layer', ] @@ -230,8 +234,10 @@ class LayerType(object): CROP_LAYER = 'crop' SUB_NESTED_SEQ = 'sub_nested_seq' CLIP_LAYER = 'clip' + SEQ_SLICE = 'seq_slice' KMAX_SEQ_SCORE = 'kmax_seq_score' + SCALE_SHIFT_LAYER = 'scale_shift' @staticmethod def is_layer_type(type_name): @@ -330,6 +336,14 @@ class LayerOutput(object): self.outputs = outputs self.reverse = reverse + @property + def width(self): + return cp.g_layer_map[self.full_name].width + + @property + def height(self): + return cp.g_layer_map[self.full_name].height + def set_input(self, input): """ Set the input for a memory layer. Can only be used for memory layer @@ -911,7 +925,13 @@ def data_layer(name, size, height=None, width=None, layer_attr=None): width=width, **ExtraLayerAttribute.to_kwargs(layer_attr)) - return LayerOutput(name, LayerType.DATA, size=size) + num_filters = None + if height is not None and width is not None: + num_filters = size / (width * height) + assert num_filters * width * height == size, \ + "size=%s width=%s height=%s" % (size, width, height) + + return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters) @wrap_name_default("embedding") @@ -2571,6 +2591,10 @@ def img_pool_layer(input, assert input.num_filters is not None num_channels = input.num_filters + assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling, + CudnnMaxPooling], \ + "only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported" + if pool_type is None: pool_type = MaxPooling() elif isinstance(pool_type, AvgPooling): @@ -2580,7 +2604,6 @@ def img_pool_layer(input, if ( isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \ else pool_type.name - pool_size_y = pool_size if pool_size_y is None else pool_size_y stride_y = stride if stride_y is None else stride_y padding_y = padding if padding_y is None else padding_y @@ -4204,8 +4227,7 @@ def conv_operator(img, num_channels = img.num_filters assert isinstance(filter, LayerOutput) - if filter.size is not None: - filter.size = filter_size * filter_size_y * num_filters * num_channels + assert filter.size is not None opCls = ConvTransOperator if trans else ConvOperator @@ -4916,7 +4938,6 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None): :return: LayerOutput object. :rtype: LayerOutput """ - assert input.layer_type == LayerType.CONV_LAYER assert isinstance(input.activation, LinearActivation) assert groups > 1 if num_channels is None: @@ -6238,6 +6259,72 @@ def clip_layer(input, min, max, name=None): name, LayerType.CLIP_LAYER, parents=[input], size=input.size) +@wrap_name_default() +def seq_slice_layer(input, starts, ends, name=None): + """ + seq_slice_layer will return one or several sub-sequences from the + input sequence layer given start and end indices. + + - If only start indices are given, and end indices are set to None, + this layer slices the input sequence from the given start indices + to its end. + - If only end indices are given, and start indices are set to None, + this layer slices the input sequence from its beginning to the + given end indices. + - If start and end indices are both given, they should have the same + number of elements. + + If start or end indices contains more than one elements, the input sequence + will be sliced for multiple times. + + + .. code-block:: python + + seq_silce = seq_slice_layer(input=input_seq, + starts=start_pos, ends=end_pos) + + :param name: name of this layer. + :type name: basestring + :param input: input for this layer, it should be a sequence. + :type input: LayerOutput + :param starts: start indices to slice the input sequence. + :type starts: LayerOutput|None + :param ends: end indices to slice the input sequence. + :type ends: LayerOutput|None + :return: LayerOutput object. + :rtype: LayerOutput + + """ + + assert isinstance(input, LayerOutput), ( + 'The first input of seq_slice layer must be a PaddlePaddle layer.') + + if starts is not None: + assert isinstance(starts, LayerOutput), ( + 'The start indices for seq_slice layer ' + 'must be a PaddlePaddle layer.') + if ends is not None: + assert isinstance(ends, LayerOutput), ( + 'The end indices for seq_slice layer must be a PaddlePaddle layer.') + assert starts is not None or ends is not None, ( + 'start and end indices ' + 'cannot be set to None at the same time, at least one of ' + 'them should be given.') + if starts is not None and ends is not None: + assert starts.size == ends.size, ( + 'If start and end indices are both given to seq_slice_layer, ' + 'they should have the same width.') + + Layer( + name=name, + type=LayerType.SEQ_SLICE, + inputs=input.name, + starts=starts.name if starts is not None else None, + ends=ends.name if ends is not None else None) + return LayerOutput( + name, LayerType.SEQ_SLICE, parents=[input], size=input.size) + + @wrap_name_default() @layer_support() def kmax_sequence_score_layer(input, name=None, beam_size=1): @@ -6274,3 +6361,43 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): return LayerOutput( name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size) + + +@wrap_name_default("scale_shift") +@wrap_param_attr_default() +@wrap_bias_attr_default() +def scale_shift_layer(input, name=None, param_attr=None, bias_attr=None): + """ + A layer applies a linear transformation to each element in each row of + the input matrix. For each element, the layer first re-scale it and then + adds a bias to it. + + This layer is very like the SlopeInterceptLayer, except the scale and + bias are trainable. + + .. math:: + + y = w * x + b + + .. code-block:: python + + scale_shift = scale_shift_layer(input=input_layer, bias_attr=False) + + :param name: The Layer Name. + :type name: basestring + :param input: The input layer. + :type input: LayerOutput. + :param param_attr: The parameter attribute of scaling. + :type param_attr: ParameterAttribute + :param bias_attr: The parameter attribute of shifting. + :type bias_attr: ParameterAttribute + :return: LayerOutput object. + :rtype: LayerOutput + """ + Layer( + name=name, + type=LayerType.SCALE_SHIFT_LAYER, + inputs=Input(input.name, **param_attr.attr), + bias=ParamAttr.to_bias(bias_attr)) + return LayerOutput( + name, LayerType.SCALE_SHIFT_LAYER, parents=[input], size=input.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index a61beb871ad064c617fa141451afcb2a5ac64854..1ca5c8a07ebb7a7d842445bbe75cc3bf7bfb295a 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -8,6 +8,7 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer -test_kmax_seq_socre_layer test_seq_select_layers) +test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer +test_seq_slice_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr index 81bd71f68eb3f2c04ccd46ee3b77a07543395c60..3d32220bfbf5f4c67f88303cb9773ecfa484da4b 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr @@ -1,12 +1,6 @@ type: "nn" layers { - name: "input" - type: "data" - size: 300 - active_type: "" -} -layers { - name: "data" + name: "input_seq" type: "data" size: 128 active_type: "" @@ -17,7 +11,7 @@ layers { size: 1 active_type: "exponential" inputs { - input_layer_name: "data" + input_layer_name: "input_seq" input_parameter_name: "___fc_layer_0__.w0" } bias_parameter_name: "___fc_layer_0__.wbias" @@ -51,15 +45,14 @@ parameters { initial_strategy: 0 initial_smart: false } -input_layer_names: "data" +input_layer_names: "input_seq" output_layer_names: "__kmax_sequence_score_layer_0__" sub_models { name: "root" - layer_names: "input" - layer_names: "data" + layer_names: "input_seq" layer_names: "__fc_layer_0__" layer_names: "__kmax_sequence_score_layer_0__" - input_layer_names: "data" + input_layer_names: "input_seq" output_layer_names: "__kmax_sequence_score_layer_0__" is_recurrent_layer_group: false } diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..35ade126a2586a8e3eee6f0ac3c7e49523c8f5c5 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr @@ -0,0 +1,72 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 100 + active_type: "" +} +layers { + name: "__scale_shift_0__" + type: "scale_shift" + size: 100 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___scale_shift_0__.w0" + } +} +layers { + name: "__scale_shift_1__" + type: "scale_shift" + size: 100 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___scale_shift_1__.w0" + } + bias_parameter_name: "___scale_shift_1__.wbias" +} +parameters { + name: "___scale_shift_0__.w0" + size: 1 + initial_mean: 0.0 + initial_std: 1.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: true +} +parameters { + name: "___scale_shift_1__.w0" + size: 1 + initial_mean: 0.0 + initial_std: 1.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: true +} +parameters { + name: "___scale_shift_1__.wbias" + size: 1 + initial_mean: 0.0 + initial_std: 0.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "data" +output_layer_names: "__scale_shift_0__" +output_layer_names: "__scale_shift_1__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "__scale_shift_0__" + layer_names: "__scale_shift_1__" + input_layer_names: "data" + output_layer_names: "__scale_shift_0__" + output_layer_names: "__scale_shift_1__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_slice_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_slice_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..5b73d614fe862e74c8dc5c24a776c0020334224c --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_slice_layer.protostr @@ -0,0 +1,79 @@ +type: "nn" +layers { + name: "word" + type: "data" + size: 128 + active_type: "" +} +layers { + name: "starts" + type: "data" + size: 5 + active_type: "" +} +layers { + name: "ends" + type: "data" + size: 5 + active_type: "" +} +layers { + name: "__seq_slice_layer_0__" + type: "seq_slice" + size: 128 + active_type: "" + inputs { + input_layer_name: "word" + } + inputs { + input_layer_name: "starts" + } + inputs { + input_layer_name: "ends" + } +} +layers { + name: "__seq_slice_layer_1__" + type: "seq_slice" + size: 128 + active_type: "" + inputs { + input_layer_name: "word" + } + inputs { + input_layer_name: "starts" + } + select_first: true +} +layers { + name: "__seq_slice_layer_2__" + type: "seq_slice" + size: 128 + active_type: "" + inputs { + input_layer_name: "word" + } + inputs { + input_layer_name: "ends" + } + select_first: false +} +input_layer_names: "word" +output_layer_names: "__seq_slice_layer_0__" +output_layer_names: "__seq_slice_layer_1__" +output_layer_names: "__seq_slice_layer_2__" +sub_models { + name: "root" + layer_names: "word" + layer_names: "starts" + layer_names: "ends" + layer_names: "__seq_slice_layer_0__" + layer_names: "__seq_slice_layer_1__" + layer_names: "__seq_slice_layer_2__" + input_layer_names: "word" + output_layer_names: "__seq_slice_layer_0__" + output_layer_names: "__seq_slice_layer_1__" + output_layer_names: "__seq_slice_layer_2__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py index d245c5a41c793e1f02f306bfe64071bd9885906e..48d0cd55da2481743de66ea95190c0856e7ddc39 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py +++ b/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py @@ -2,9 +2,7 @@ #coding=utf-8 from paddle.trainer_config_helpers import * -data = data_layer(name='input', size=300) - -data = data_layer(name="data", size=128) +data = data_layer(name="input_seq", size=128) scores = fc_layer(input=data, size=1, act=ExpActivation()) kmax_seq_id = kmax_sequence_score_layer(input=scores, beam_size=5) diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd589116fa9932144ca066d3fa4c929d1433a7f1 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py @@ -0,0 +1,9 @@ +from paddle.trainer_config_helpers import * + +data = data_layer(name='data', size=100) + +scale = scale_shift_layer(input=data, bias_attr=False) + +scale_shift = scale_shift_layer(input=data) + +outputs(scale, scale_shift) diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_seq_slice_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_seq_slice_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..510ad3220893fddac278ba691307d00d57e440a3 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_seq_slice_layer.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +#coding=utf-8 +from paddle.trainer_config_helpers import * + +input_seq = data_layer("word", size=128) +starts = data_layer("starts", size=5) +ends = data_layer("ends", size=5) + +seq_slice1 = seq_slice_layer(input=input_seq, starts=starts, ends=ends) +seq_slice2 = seq_slice_layer(input=input_seq, starts=starts, ends=None) +seq_slice3 = seq_slice_layer(input=input_seq, starts=None, ends=ends) + +outputs(seq_slice1, seq_slice2, seq_slice3) diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ce57a0713092723b6a99b2416e06ff1a436f043b..2849ee7c8d0404432fcf6156552f40657d094983 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -13,6 +13,7 @@ py_test(test_add_two_op SRCS test_add_two_op.py) py_test(test_sigmoid_op SRCS test_sigmoid_op.py) py_test(test_softmax_op SRCS test_softmax_op.py) py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py) +py_test(test_gather_op SRCS test_gather_op.py) py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py) py_test(gradient_checker SRCS gradient_checker.py) @@ -22,8 +23,10 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_operator SRCS test_operator.py) -# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) +py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) py_test(test_gradient_checker SRCS test_gradient_checker.py) +py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py) +py_test(mnist SRCS mnist.py) diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index 8b8e2f444be1169c23784321721c5d8154541fcf..c22c6f8831b2551d9a83747bc0d15789a78a101e 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -160,8 +160,13 @@ class GradientChecker(unittest.TestCase): grad_tensor.set(data, place) # run backward op - for name in backward_op.outputs(): + backward_outs = backward_op.outputs() + backward_names = [ + item for key in backward_outs for item in backward_outs[key] + ] + for name in backward_names: scope.new_var(name) + backward_op.infer_shape(scope) backward_op.run(scope, ctx) diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0b109850e92c66e69f74c5cd0853a09b5551a1 --- /dev/null +++ b/python/paddle/v2/framework/tests/mnist.py @@ -0,0 +1,249 @@ +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator +import numpy +import paddle.v2 as paddle + +BATCH_SIZE = 100 + +scope = core.Scope() +place = core.CPUPlace() +# if you want to test GPU training, you can use gpu place +# place = core.GPUPlace(0) +dev_ctx = core.DeviceContext.create(place) + +init_net = core.Net.create() +forward_net = core.Net.create() +backward_net = None +optimize_net = core.Net.create() + + +def atomic_id(): + id = 0 + while True: + yield id + id += 1 + + +uniq_id = atomic_id().next + + +def data_layer(name, dims): + var = scope.new_var(name) + tensor = var.get_tensor() + tensor.set_dims(dims) # 1 is batch size holder. + return name + + +def feed_data(name, data): + assert isinstance(data, numpy.ndarray) + tensor = scope.find_var(name).get_tensor() + tensor.set_dims(data.shape) + if data.dtype == numpy.dtype('int32'): + tensor.alloc_int(place) + elif data.dtype == numpy.dtype('float32'): + tensor.alloc_float(place) + else: + raise ValueError("data type not supported") + tensor.set(data, place) + + +def grad_var_name(var_name): + return var_name + "@GRAD" + + +def sgd_optimizer(net, param_name, learning_rate=0.005): + grad_name = grad_var_name(param_name) + optimize_op = Operator( + "sgd", + param=param_name, + grad=grad_name, + param_out=param_name, + learning_rate=learning_rate) + net.append_op(optimize_op) + + +# should use operator and add these to the init_network +def init_param(net, param_name, dims): + scope.new_var(param_name) + op = Operator( + "uniform_random", Out=param_name, dims=dims, min=-0.5, max=0.5, seed=10) + op.infer_shape(scope) + net.append_op(op) + + +# fc_layer +def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): + """ + Add a fc layer to net + + :param input: input variable name. + :type input: str + :param size: fully connected layer size. + :param act: activation name + :param param: parameter attribute, used for initialize parameters. + :param bias: bias attribute. False will not have a bias. + :param name: the name of fc layer. If not set, model will generate a + readable name + :return: output variable name. + """ + if name is None: + name = 'fc_%d' % uniq_id() + if not isinstance(name, str): + raise ValueError("name should be string") + + input_dims = scope.find_var(input).get_tensor().get_dims() + + w_name = param or name + ".w" + init_param(net=init_net, param_name=w_name, dims=[input_dims[1], size]) + sgd_optimizer(net=optimize_net, param_name=w_name, learning_rate=0.01) + + pre_activation = name + ".mul.out" + scope.new_var(pre_activation) + mul_op = Operator("mul", X=input, Y=w_name, Out=pre_activation) + net.append_op(mul_op) + + # create bias variable if needed + if bias: + bias_name = name + ".b" + init_param(net=init_net, param_name=bias_name, dims=[size]) + sgd_optimizer( + net=optimize_net, param_name=bias_name, learning_rate=0.001) + bias_out = name + ".rowwise_add.out" + scope.new_var(bias_out) + rowwise_append_op = Operator( + "rowwise_add", X=pre_activation, b=bias_name, Out=bias_out) + net.append_op(rowwise_append_op) + pre_activation = bias_out + + activation_op = Operator(act, X=pre_activation, Y=name) + net.append_op(activation_op) + scope.new_var(name) + net.infer_shape(scope) + return name + + +def cross_entropy_layer(net, input, label): + cost_name = 'cross_entropy_%d' % uniq_id() + cross_entropy_op = Operator( + "onehot_cross_entropy", X=input, label=label, Y=cost_name) + net.append_op(cross_entropy_op) + scope.new_var(cost_name) + net.infer_shape(scope) + return cost_name + + +def create_backward_net(forward_net): + net = core.Operator.backward(forward_net, set()) + for input in net.inputs()["all"]: + var = scope.new_var(input) + var.get_tensor() + for output in net.outputs()["all"]: + var = scope.new_var(output) + var.get_tensor() + return net + + +def debug_print_op(op): + print("===============" + op.type() + "==============") + print("***inputs:***") + for input in op.inputs()["all"]: + print input, scope.find_var(input).get_tensor().get_dims() + print("\n***outputs:***") + for output in op.outputs()["all"]: + print output, scope.find_var(output).get_tensor().get_dims() + print("") + print("") + + +def set_cost(cost): + cost_shape = numpy.array(scope.find_var(cost).get_tensor()).shape + cost_grad = \ + scope.find_var(grad_var_name(cost)).get_tensor() + cost_grad.set_dims(cost_shape) + cost_grad.alloc_float(place) + cost_grad.set(numpy.ones(cost_shape).astype("float32"), place) + + +def get_cost_mean(cost): + cost_data = numpy.array(scope.find_var(cost).get_tensor()) + return cost_data.sum() / len(cost_data) + + +def error_rate(predict, label): + predict_var = numpy.array(scope.find_var(predict).get_tensor()).argmax( + axis=1) + label = numpy.array(scope.find_var(label).get_tensor()) + error_num = numpy.sum(predict_var != label) + return error_num / float(len(label)) + + +images = data_layer(name='pixel', dims=[BATCH_SIZE, 784]) +labels = data_layer(name='label', dims=[BATCH_SIZE]) +fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid") +fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid") +predict = fc_layer(net=forward_net, input=fc2, size=100, act="softmax") +cost = cross_entropy_layer(net=forward_net, input=predict, label=labels) + +init_net.complete_add_op(True) +forward_net.complete_add_op(True) +backward_net = create_backward_net(forward_net) +optimize_net.complete_add_op(True) + +print(init_net) +print(forward_net) +print(backward_net) +print(optimize_net) + +debug_print_op(forward_net) +debug_print_op(backward_net) +debug_print_op(optimize_net) + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=8192), + batch_size=BATCH_SIZE) + + +def test(cost_name): + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=BATCH_SIZE) + cost = [] + error = [] + for data in test_reader(): + image_data = numpy.array(map(lambda x: x[0], data)).astype("float32") + label_data = numpy.array(map(lambda x: x[1], data)).astype("int32") + feed_data(images, image_data) + feed_data(labels, label_data) + + forward_net.infer_shape(scope) + forward_net.run(scope, dev_ctx) + cost.append(get_cost_mean(cost_name)) + error.append(error_rate(predict, "label")) + print("cost=" + str(sum(cost) / float(len(cost))) + " error_rate=" + str( + sum(error) / float(len(error)))) + + +PASS_NUM = 1 + +init_net.run(scope, dev_ctx) +for pass_id in range(PASS_NUM): + batch_id = 0 + + for data in train_reader(): + image_data = numpy.array(map(lambda x: x[0], data)).astype("float32") + label_data = numpy.array(map(lambda x: x[1], data)).astype("int32") + feed_data(images, image_data) + feed_data(labels, label_data) + + forward_net.infer_shape(scope) + forward_net.run(scope, dev_ctx) + set_cost(cost) + backward_net.infer_shape(scope) + backward_net.run(scope, dev_ctx) + + optimize_net.run(scope, dev_ctx) + if batch_id % 100 == 0: + print("pass[" + str(pass_id) + "] batch_id[" + str(batch_id) + "]") + test(cost) + + batch_id = batch_id + 1 diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index dd65e0f2dc23d3f657ff16c55fb297dae210b2d7..3bc05a0feccbbd3d5e7852d85bd3dc8edaccfd07 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -64,7 +64,8 @@ class OpTestMeta(type): actual = numpy.array(scope.find_var(out_name).get_tensor()) expect = self.outputs[out_name] self.assertTrue( - numpy.allclose(actual, expect), + numpy.allclose( + actual, expect, atol=1e-05), "output name: " + out_name + "has diff") obj.test_all = test_all diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index 4815192e255c6e0429db3f50918a76a773b30131..d4277f2a42ce2e66e37405ccd3b2ee444d403d1a 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -8,9 +8,8 @@ class TestCrossEntropy(unittest.TestCase): __metaclass__ = OpTestMeta def setUp(self): - # TODO this unit test is not passed self.type = "onehot_cross_entropy" - batch_size = 100 + batch_size = 30 class_num = 10 X = numpy.random.random((batch_size, class_num)).astype("float32") label = 5 * numpy.ones(batch_size).astype("int32") @@ -22,9 +21,9 @@ class TestCrossEntropy(unittest.TestCase): class CrossEntropyGradOpTest(GradientChecker): - def test_softmax_grad(self): + def test_check_grad(self): op = create_op("onehot_cross_entropy") - batch_size = 100 + batch_size = 30 class_num = 10 inputs = { "X": numpy.random.uniform( diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/framework/tests/test_gather_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e86898304252d08be718e40fed46c5e921596af7 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gather_op.py @@ -0,0 +1,34 @@ +import unittest +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op +import numpy +import paddle.v2.framework.core as core +from paddle.v2.framework.op import Operator + + +class TestGatherOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "gather" + xnp = numpy.random.random((10, 20)).astype("float32") + self.inputs = { + 'X': xnp, + 'Index': numpy.array([1, 3, 5]).astype("int32") + } + self.outputs = {'Out': self.inputs['X'][self.inputs['Index']]} + + +class TestGatherGradOp(GradientChecker): + def test_gather_grad(self): + print 'creating op' + op = create_op("gather") + print 'creating op done' + xnp = numpy.random.random((10, 20)).astype("float32") + inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")} + print 'correct before check gradient' + self.check_grad(op, inputs, set("X"), "Out") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_minus_op.py b/python/paddle/v2/framework/tests/test_minus_op.py new file mode 100644 index 0000000000000000000000000000000000000000..5abdd4a69bf3faa2f3341f338e195815389a7cef --- /dev/null +++ b/python/paddle/v2/framework/tests/test_minus_op.py @@ -0,0 +1,30 @@ +import unittest +import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta + + +class MinusOpTest(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "minus" + self.inputs = { + 'X': np.random.random((32, 84)).astype("float32"), + 'Y': np.random.random((32, 84)).astype("float32") + } + self.outputs = {'Out': (self.inputs['X'] - self.inputs['Y'])} + + +class MinusGradTest(GradientChecker): + def test_left(self): + op = create_op("minus") + inputs = { + "X": np.random.random((10, 10)).astype("float32"), + "Y": np.random.random((10, 10)).astype("float32") + } + self.check_grad(op, inputs, ["X", 'Y'], "Out") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index ec0ac99156a546dd3fb7b27778032bece38ab5a9..ee0d81a64efcb81bae8b11b856c201a86da274e9 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta class TestMulOp(unittest.TestCase): @@ -15,5 +16,19 @@ class TestMulOp(unittest.TestCase): self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} +class MulGradOpTest(GradientChecker): + def test_mul(self): + op = create_op("mul") + inputs = { + 'X': np.random.random((32, 84)).astype("float32"), + 'Y': np.random.random((84, 100)).astype("float32") + } + # mul op will enlarge the relative error + self.check_grad( + op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5) + + +# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_net.py b/python/paddle/v2/framework/tests/test_net.py index b42cadd11ab75abbc35763c8d12e8c27e995f0dc..9339cf28dabc95b46b958777200fb1db9dcf284f 100644 --- a/python/paddle/v2/framework/tests/test_net.py +++ b/python/paddle/v2/framework/tests/test_net.py @@ -6,8 +6,8 @@ import unittest def fc(X, W, Y): ret_v = core.Net.create() - ret_v.add_op(Operator("mul", X="X", Y="W", Out="pre_activation")) - ret_v.add_op(Operator("sigmoid", X="pre_activation", Y=Y)) + ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation")) + ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y)) ret_v.complete_add_op(True) return ret_v @@ -16,12 +16,12 @@ class TestNet(unittest.TestCase): def test_net_all(self): net = core.Net.create() op1 = Operator("add_two", X="X", Y="Y", Out="Out") - net.add_op(op1) + net.append_op(op1) net2 = core.Net.create() - net2.add_op(fc(X="X", W="w", Y="fc.out")) + net2.append_op(fc(X="X", W="w", Y="fc.out")) net2.complete_add_op(True) - net.add_op(net2) + net.append_op(net2) net.complete_add_op(True) expected = ''' diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/framework/tests/test_recurrent_op.py index 3d4a34d8d713ff1beeeba8ac48ad95176f7a29f2..d6000ab9f9d5b969f96128b183f48d49000c8a5e 100644 --- a/python/paddle/v2/framework/tests/test_recurrent_op.py +++ b/python/paddle/v2/framework/tests/test_recurrent_op.py @@ -150,7 +150,7 @@ class TestRecurrentOp(unittest.TestCase): sig_op = Operator("sigmoid", X="sum", Y="h@alias") for op in [x_fc_op, h_fc_op, sum_op, sig_op]: - stepnet.add_op(op) + stepnet.append_op(op) stepnet.complete_add_op(True) self.rnnop.set_stepnet(stepnet) diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py index f8521eb517057fbeb104b28af7da4fffe54f37de..45d569da29d13cf8e2a3cb9d67c2d01e8b365453 100644 --- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op class TestRowwiseAddOp(unittest.TestCase): @@ -15,5 +16,15 @@ class TestRowwiseAddOp(unittest.TestCase): self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])} +class RowwiseAddGradOpTest(GradientChecker): + def test_rowwise_add(self): + op = create_op("rowwise_add") + inputs = { + "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"), + "b": np.random.uniform(0.1, 1, [10]).astype("float32") + } + self.check_grad(op, inputs, set(["X", "b"]), "Out") + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_scale_and_identity_op.py b/python/paddle/v2/framework/tests/test_scale_and_identity_op.py new file mode 100644 index 0000000000000000000000000000000000000000..69b301c376ee7a4ebb2e2dadc645c7d10f823a08 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_scale_and_identity_op.py @@ -0,0 +1,43 @@ +import unittest +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op +import numpy as np +from paddle.v2.framework.op import Operator + + +class IdentityTest(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "identity" + self.inputs = {'X': np.random.random((32, 784)).astype("float32")} + self.outputs = {'Out': self.inputs['X']} + + +class IdentityGradOpTest(GradientChecker): + def test_normal(self): + op = create_op("identity") + inputs = {"X": np.random.random((10, 10)).astype("float32")} + self.check_grad(op, inputs, set("X"), "Out") + + +class ScaleTest(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "scale" + self.inputs = {'X': np.random.random((32, 784)).astype("float32")} + self.attrs = {'scale': -2.3} + self.outputs = {'Out': self.inputs['X'] * self.attrs['scale']} + + +class ScaleGradTest(GradientChecker): + def test_normal(self): + op = Operator("scale", X="X", Out="Out", scale=3.2) + self.check_grad(op, + {"X": np.random.random((10, 10)).astype("float32")}, + set("X"), "Out") + + +if __name__ == '__main__': + unittest.main()