diff --git a/CMakeLists.txt b/CMakeLists.txt index 1252e7539816016dfdf1b90b8941fa42e6bb85e0..264420ad830ed39b38f1918951d8d66c84fd5ee9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -127,6 +127,7 @@ include(external/warpctc) # download, build, install warpctc include(external/any) # download libn::any include(external/eigen) # download eigen3 include(external/pybind11) # download pybind11 +include(external/nccl) # download nccl include(cudnn) # set cudnn libraries, must before configure include(configure) # add paddle env configuration @@ -159,7 +160,7 @@ set(EXTERNAL_LIBS if(WITH_GPU) list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) if(NOT WITH_DSO) - list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) + list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY} ${NCCL_LIBRARY}) endif(NOT WITH_DSO) endif(WITH_GPU) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index db8f5ab0456792f903093b9cf20e2541f00add5c..24ddb24399dabeec9b8e5faf36be3eb21f420111 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -62,11 +62,11 @@ else() FIND_PACKAGE(CUDA REQUIRED) if(${CUDA_VERSION_MAJOR} VERSION_LESS 7) - message(FATAL_ERROR "Paddle need CUDA >= 7.0 to compile") + message(FATAL_ERROR "Paddle needs CUDA >= 7.0 to compile") endif() if(NOT CUDNN_FOUND) - message(FATAL_ERROR "Paddle need cudnn to compile") + message(FATAL_ERROR "Paddle needs cudnn to compile") endif() set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SIMD_FLAG}") diff --git a/cmake/external/nccl.cmake b/cmake/external/nccl.cmake new file mode 100644 index 0000000000000000000000000000000000000000..dfbbed58c9ed7cc57809b3d33a29ce26a35d75a2 --- /dev/null +++ b/cmake/external/nccl.cmake @@ -0,0 +1,50 @@ +INCLUDE(ExternalProject) + +SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl) + +INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src) + + +if(WITH_DSO) + # If we use DSO, we do not build nccl, just download the dependencies + set(NCCL_BUILD_COMMAND "") + set(NCCL_INSTALL_COMMAND "") + set(NCCL_INSTALL_DIR "") +else() + # otherwise, we build nccl and link it. + set(NCCL_BUILD_COMMAND "make -j 8") + set(NCCL_INSTALL_COMMAND "make install") + SET(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl) +endif() + +ExternalProject_Add( + extern_nccl + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git" + GIT_TAG "v1.3.4-1" + PREFIX "${NCCL_SOURCE_DIR}" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "${NCCL_BUILD_COMMAND}" + INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}" + INSTALL_DIR "${NCCL_INSTALL_DIR}" + TEST_COMMAND "" +) + +if (WITH_DSO) + if (${CMAKE_VERSION} VERSION_LESS "3.3.0") + set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_nccl_dummy.c) + file(WRITE ${dummyfile} "const char * dummy_nccl = \"${dummyfile}\";") + add_library(nccl STATIC ${dummyfile}) + else() + add_library(nccl INTERFACE) + endif() +else() + ADD_LIBRARY(nccl STATIC IMPORTED GLOBAL) + SET_PROPERTY(TARGET nccl PROPERTY IMPORTED_LOCATION + ${NCCL_INSTALL_DIR}/lib/libnccl.a) +endif() + +add_dependencies(nccl extern_nccl) + +LIST(APPEND external_project_dependencies nccl) diff --git a/doc/faq/local/index_cn.rst b/doc/faq/local/index_cn.rst index 75c4ba028e497e29e9030a86514348726d9c0a80..0e939a2671ace8682c90cdc1c1bb2da1dda0d568 100644 --- a/doc/faq/local/index_cn.rst +++ b/doc/faq/local/index_cn.rst @@ -174,7 +174,7 @@ decoder_inputs = paddle.layer.fc( 1. 两者都是对梯度的截断，但截断时机不同，前者在 :code:`optimzier` 更新网络参数时应用；后者在激活函数反向计算时被调用； 2. 
截断对象不同:前者截断可学习参数的梯度,后者截断回传给前层的梯度; -除此之外,还可以通过减小学习律或者对数据进行归一化处理来解决这类问题。 +除此之外,还可以通过减小学习率或者对数据进行归一化处理来解决这类问题。 5. 如何调用 infer 接口输出多个layer的预测结果 ----------------------------------------------- diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 21d4fdaf0680036a484ee4258e47c6c8854967c3..251e340e6ddcc17ba16bdcab63f2a8c907122eab 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -41,6 +41,19 @@ bool BlockDescBind::HasVar(const std::string &name) const { return vars_.find(name) != vars_.end(); } +VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const { + auto it = vars_.find(name); + if (it == vars_.end()) { + return Parent() == kNoneBlockIndex ? nullptr + : ParentBlock()->FindVarRecursive(name); + } + return it->second.get(); +} + +bool BlockDescBind::HasVarRecursive(const std::string &name) const { + return FindVarRecursive(name) != nullptr; +} + std::vector BlockDescBind::AllVars() const { std::vector res; for (const auto &p : vars_) { @@ -97,7 +110,7 @@ void BlockDescBind::Flush() { } BlockDescBind *BlockDescBind::ParentBlock() const { - if (this->desc_->parent_idx() == -1) { + if (this->desc_->parent_idx() == kNoneBlockIndex) { return nullptr; } return prog_->Block(static_cast(this->desc_->parent_idx())); diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 7d1d33f6860aa90518abb379a5e9964d6029c6fa..c685050850dc25f346df49b5ce1d897974870460 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -21,6 +21,7 @@ limitations under the License. */ #include #include "paddle/framework/op_desc.h" +#include "paddle/framework/proto_desc.h" #include "paddle/framework/var_desc.h" #include "paddle/platform/macros.h" @@ -56,6 +57,10 @@ class BlockDescBind { bool HasVar(const std::string &var_name) const; + VarDescBind *FindVarRecursive(const std::string &name_bytes) const; + + bool HasVarRecursive(const std::string &var_name) const; + std::set LocalVarNames() const { std::set var_names; for (auto &var : vars_) { diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index e926180780609c0a8ffc6270627835c50bbce782..59a64d71371b546f76eabdeed7e7514e8fb0f84a 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -87,11 +87,8 @@ class OpInfoMap { } } - template - void IterAllInfo(Callback callback) { - for (auto& it : map_) { - callback(it.first, it.second); - } + const std::unordered_map& map() const { + return map_; } private: diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 9d7fe1f5ba293227e67cf6bfcd566a1247c567ed..0d0304ac9e13089ef533b0a47f0ec989c8fd7078 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -327,37 +327,47 @@ class CompileTimeInferShapeContext : public InferShapeContext { bool HasInput(const std::string& name) const override { const std::vector& input_names = op_.Input(name); auto length = input_names.size(); + if (length == 0) { + return false; + } PADDLE_ENFORCE_EQ(length, 1UL, "Input(%s) should have only one value, " "but it have %d now", name, length); - return block_.HasVar(input_names[0]); + return block_.HasVarRecursive(input_names[0]); } bool HasOutput(const std::string& name) const override { const std::vector& output_names = op_.Output(name); auto length = output_names.size(); + if (length == 0) { + return false; + } PADDLE_ENFORCE_EQ(length, 1UL, "Output(%s) should have only one value, " "but it have %d now", name, length); - return 
block_.HasVar(output_names[0]); + return block_.HasVarRecursive(output_names[0]); } bool HasInputs(const std::string& name) const override { const std::vector& input_names = op_.Input(name); - PADDLE_ENFORCE(!input_names.empty(), "Inputs(%s) length is 0", name); + if (input_names.empty()) { + return false; + } for (auto& input : input_names) { - if (!block_.HasVar(input)) return false; + if (!block_.HasVarRecursive(input)) return false; } return true; } bool HasOutputs(const std::string& name) const override { const std::vector& output_names = op_.Output(name); - PADDLE_ENFORCE(!output_names.empty(), "Inputs(%s) length is 0", name); + if (output_names.empty()) { + return false; + } for (auto& output : output_names) { - if (!block_.HasVar(output)) return false; + if (!block_.HasVarRecursive(output)) return false; } return true; } @@ -404,11 +414,11 @@ class CompileTimeInferShapeContext : public InferShapeContext { private: DDim GetDim(const std::string& name) const override { - return framework::make_ddim(block_.FindVar(name)->Shape()); + return framework::make_ddim(block_.FindVarRecursive(name)->Shape()); } void SetDim(const std::string& name, const DDim& dim) override { - block_.FindVar(name)->SetShape(framework::vectorize(dim)); + block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim)); } const OpDescBind& op_; @@ -421,13 +431,27 @@ class RuntimeInferShapeContext : public InferShapeContext { : op_(op), scope_(scope) {} bool HasInput(const std::string& name) const override { - auto ipt = op_.Input(name); + auto& ins = Inputs(name); + size_t length = ins.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should not have more than one input", + name); + auto ipt = ins[0]; auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; } bool HasOutput(const std::string& name) const override { - auto ipt = op_.Output(name); + auto& outs = Outputs(name); + size_t length = outs.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should not have more than one output", + name); + auto ipt = outs[0]; auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); return var != nullptr; } diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index e2349cefe09a6c1e0b11f77775426fe5c7594c9d..8e99bba81117c9cc50227122527d6ab9a421c251 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -35,8 +35,8 @@ ProgramDesc *ProgramDescBind::Proto() { ProgramDescBind::ProgramDescBind() { auto *block = prog_.mutable_blocks()->Add(); - block->set_idx(0); - block->set_parent_idx(-1); + block->set_idx(kRootBlockIndex); + block->set_parent_idx(kNoneBlockIndex); blocks_.emplace_back(new BlockDescBind(this, block)); } diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 20cc1a2325ffd6f8ef52879a749f106c268376d4..dc4cd7cc735b5e4e3466d9b82dc5eb8647c80ef9 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include #include #include "paddle/framework/framework.pb.h" +#include "paddle/framework/proto_desc.h" #include "paddle/platform/macros.h" namespace paddle { diff --git a/paddle/framework/proto_desc.h b/paddle/framework/proto_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..fa01224fefce50eb3688ff407f0a7c948c5b7cfc --- /dev/null +++ b/paddle/framework/proto_desc.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace framework { + +// The index of the first block in a program, also called the root block. +constexpr int kRootBlockIndex = 0; +// The parent index of the root block; the root block itself has no parent. +constexpr int kNoneBlockIndex = -1; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index c302217e5aacdc17800238770d689b7fb65804f3..8e92c81d1137472737230be79d71824593d3256f 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -18,6 +18,10 @@ limitations under the License. */ namespace paddle { namespace framework { +VarDesc::VarType VarDescBind::GetType() const { return desc_.type(); } + +void VarDescBind::SetType(VarDesc::VarType type) { desc_.set_type(type); } + void VarDescBind::SetShape(const std::vector &dims) { VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims()); } diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index af4c26ca0a77b444852cc01545a8b585a5c3afcc..929de1f836fa906966ff125c70380d85d062afdf 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -75,9 +75,9 @@ class VarDescBind { int32_t GetLodLevel() const; - VarDesc::VarType GetType() const { return desc_.type(); } + VarDesc::VarType GetType() const; - void SetType(VarDesc::VarType type) { desc_.set_type(type); } + void SetType(VarDesc::VarType type); bool Persistable() const { return desc_.persistable(); } diff --git a/paddle/gserver/activations/MKLDNNActivation.cpp b/paddle/gserver/activations/MKLDNNActivation.cpp index 18c5638100065109fba1f0647a1c5f91256f7b9d..f3ccd68160859795f28a40f8d0d4032adb289ccf 100644 --- a/paddle/gserver/activations/MKLDNNActivation.cpp +++ b/paddle/gserver/activations/MKLDNNActivation.cpp @@ -126,7 +126,7 @@ void MKLDNNEltwiseActivation::resetFwd(Argument& act) { copyInVal_ = nullptr; if (act.grad && algo == algorithm::eltwise_tanh) { // tanh need save src input for backward - inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc()); + inVal_ = MKLDNNMatrix::create(val_->getPrimitiveDesc()); copyInVal_ = std::make_shared(*val_, *inVal_); CHECK(copyInVal_) << "should not be emptry"; pipelineFwd_.push_back(*copyInVal_); @@ -145,7 +145,7 @@ void MKLDNNEltwiseActivation::resetBwd(Argument& act) { algorithm algo = getAlgo(this->getName()); float alpha = getBwdAlpha(); float beta = getBeta(); - grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc()); + grad_ = MKLDNNMatrix::create(val_->getPrimitiveDesc(), 
act.grad); auto eng = CPUEngine::Instance().getEngine(); auto bwdDesc = eltwise_bwd::desc( algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta); @@ -230,7 +230,7 @@ void MKLDNNActivation::resetFwd(Argument& act) { int ic = cnt_ / bs / ih / iw; CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw); val_ = MKLDNNMatrix::create( - act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_); + {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_, act.value); CHECK(val_); val_->downSpatial(); } diff --git a/paddle/gserver/layers/MKLDNNBase.h b/paddle/gserver/layers/MKLDNNBase.h index 4c0234e7b3a91053596c32cea581fa5d1e26b9d5..af02a37cad668708f77ecf423549a8ec993e54fb 100644 --- a/paddle/gserver/layers/MKLDNNBase.h +++ b/paddle/gserver/layers/MKLDNNBase.h @@ -21,8 +21,8 @@ namespace paddle { typedef enum { MKLDNN_BASE = 1, // basical info of MKLDNN MKLDNN_TESTS = 1, // gtest info of MKLDNN - MKLDNN_SIZES = 2, // size info of MKLDNN - MKLDNN_FMTS = 3, // format info of MKLDNN + MKLDNN_FMTS = 2, // format info of MKLDNN + MKLDNN_SIZES = 3, // size info of MKLDNN MKLDNN_ALL = 4, // show all info of MKLDNN } MKLDNN_LOG_LEVEL; diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index 26810a648343d6203f7937740641325ae8ea6879..83f4e4e6151d727b3e6cf367bb7ecae55dd7df73 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -116,8 +116,6 @@ void MKLDNNConvLayer::resetFwd(std::vector& pipeline, resetFwdBuffers(fwdPD_, in, wgt, bias, out); resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out); - - printValueFormatFlow(); } void MKLDNNConvLayer::resetBwd(std::vector& pipeline, @@ -135,12 +133,6 @@ void MKLDNNConvLayer::resetBwd(std::vector& pipeline, resetBwdBuffers(bwdWgtPD, bwdDataPD, in, wgt, bias, out); resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out); - - printGradFormatFlow(); -} - -void MKLDNNConvLayer::updateInputData() { - cpuInVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); } void MKLDNNConvLayer::updateWeights(const UpdateCallback& callback) { @@ -211,11 +203,18 @@ void MKLDNNConvLayer::resetFwdBuffers( MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) { CHECK(pd); - resetInValue(pd, in); + resetInValue( + in, std::make_shared(pd->src_primitive_desc())); + + resetOutValue(out, pd->dst_primitive_desc()); - resetWgtBiasValue(pd, wgt, bias); + resetWithMatrix(wgt, weight_->getW(), pd->weights_primitive_desc()); - resetOutValue(pd, out); + if (biases_ && biases_->getW()) { + resetWithMatrix(bias, biases_->getW(), pd->bias_primitive_desc()); + } else { + bias = nullptr; + } } void MKLDNNConvLayer::resetFwdPipeline( @@ -225,104 +224,12 @@ void MKLDNNConvLayer::resetFwdPipeline( MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) { - if (cvtInVal_) { - pipeline.push_back(*cvtInVal_); - } - if (bias) { fwd_.reset(new conv_fwd(*pd, *in, *wgt, *bias, *out)); } else { fwd_.reset(new conv_fwd(*pd, *in, *wgt, *out)); } pipeline.push_back(*fwd_); - - if (cvtOutVal_) { - pipeline.push_back(*cvtOutVal_); - } -} - -void MKLDNNConvLayer::resetInValue( - std::shared_ptr& pd, MKLDNNMatrixPtr& in) { - const MatrixPtr& inMat = inputLayers_[0]->getOutputValue(); - in = MKLDNNMatrix::create(inMat, pd->src_primitive_desc()); - - // create buffer and reorder if input value do not match - cpuInVal_ = nullptr; - cvtInVal_ = nullptr; - - MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast(inMat); - CHECK_EQ(inputIsOnlyMKLDNN(), dnnIn != nullptr); - if (dnnIn != nullptr && 
dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) { - in = dnnIn; - return; - } - if (dnnIn) { - if (dnnIn->getFormat() == format::nc) { - CHECK(ih_ == 1 && iw_ == 1) << "when input is nc format"; - // create a new one with nchw format and same data - memory::dims inDims = memory::dims{bs_, ic_, 1, 1}; - dnnIn = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_); - } - if (dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) { - in = dnnIn; - return; - } - cpuInVal_ = dnnIn; - in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc()); - cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in); - CHECK(cvtInVal_) << "should not be emptry"; - } else { - memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_}; - cpuInVal_ = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_); - if (cpuInVal_->getPrimitiveDesc() != in->getPrimitiveDesc()) { - // create new mkldnn matrix - in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc()); - cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in); - CHECK(cvtInVal_) << "should not be emptry"; - } else { - in = cpuInVal_; - } - } -} - -void MKLDNNConvLayer::resetWgtBiasValue( - std::shared_ptr& pd, - MKLDNNMatrixPtr& wgt, - MKLDNNMatrixPtr& bias) { - wgt = MKLDNNMatrix::create(weight_->getW(), pd->weights_primitive_desc()); - VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat(); - - bias = (biases_ && biases_->getW()) - ? MKLDNNMatrix::create(biases_->getW(), pd->bias_primitive_desc()) - : nullptr; -} - -void MKLDNNConvLayer::resetOutValue( - std::shared_ptr& pd, MKLDNNMatrixPtr& out) { - out = MKLDNNMatrix::create(output_.value, pd->dst_primitive_desc()); - - // create reorder if output value has cpu device and pd do not match - cpuOutVal_ = nullptr; - cvtOutVal_ = nullptr; - if (!outputIsOnlyMKLDNN()) { - const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).value; - memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_}; - cpuOutVal_ = MKLDNNMatrix::create(cpuOut, outDims, format::nchw, engine_); - if (cpuOutVal_->getPrimitiveDesc() != pd->dst_primitive_desc()) { - out = MKLDNNMatrix::create(nullptr, pd->dst_primitive_desc()); - cvtOutVal_ = MKLDNNMatrix::createReorder(out, cpuOutVal_); - CHECK(cvtOutVal_) << "should not be empty"; - } else { - cpuOut->setData(output_.value->getData()); - cpuOutVal_ = out; - } - // when output is cpu device, change the mkldnn output value and make them - // share the same data. Then if next layer use inputlayer->getOuputValue() - // to achieve the input value, it will get the right data. 
- output_.value = std::dynamic_pointer_cast(cpuOutVal_); - return; - } - output_.value = std::dynamic_pointer_cast(out); } void MKLDNNConvLayer::resetBwdWgtPD( @@ -331,8 +238,8 @@ void MKLDNNConvLayer::resetBwdWgtPD( loadConvSettings(wgtDims, biasDims, strides, dilations, padL, padR); // create backward weight using input, output and weight value memory desc - CHECK(inVal_) << "Should have input value"; - CHECK(outVal_) << "Should have output value"; + CHECK(inVal_) << "Should have internal input value"; + CHECK(outVal_) << "Should have internal output value"; CHECK(wgtVal_) << "Should have weight value"; algorithm algo = algorithm::convolution_direct; padding_kind padKind = padding_kind::zero; @@ -372,8 +279,8 @@ void MKLDNNConvLayer::resetBwdDataPD( memory::dims wgtDims, biasDims, strides, dilations, padL, padR; loadConvSettings(wgtDims, biasDims, strides, dilations, padL, padR); - CHECK(inVal_) << "Should have input value"; - CHECK(outVal_) << "Should have output value"; + CHECK(inVal_) << "Should have internal input value"; + CHECK(outVal_) << "Should have internal output value"; // create backward data using input and output value memory desc // but using weight memory desc with any format auto bwdDataDesc = conv_bwdData::desc(algorithm::convolution_direct, @@ -399,12 +306,27 @@ void MKLDNNConvLayer::resetBwdBuffers( MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) { CHECK(wgtPD); - resetOutGrad(wgtPD, out); + resetOutGrad(out, wgtPD->diff_dst_primitive_desc()); - resetWgtBiasGrad(wgtPD, wgt, bias); + resetWithMatrix( + wgt, weight_->getWGrad(), wgtPD->diff_weights_primitive_desc()); + CHECK(wgtVal_ != nullptr && + wgt->getPrimitiveDesc() == wgtVal_->getPrimitiveDesc()) + << "primitive desc of weight grad and value should be equal"; - resetInGrad(dataPD, in); + bias = nullptr; + if (biases_ && biases_->getWGrad()) { + resetWithMatrix( + bias, biases_->getWGrad(), wgtPD->diff_bias_primitive_desc()); + CHECK(bias && biasVal_ && + bias->getPrimitiveDesc() == biasVal_->getPrimitiveDesc()) + << "primitive desc of bias grad should equal the bias value"; + } + if (dataPD == nullptr) { + return; + } + resetInGrad(in, dataPD->diff_src_primitive_desc()); resetWgtValBwdData(dataPD, wgtValBwdData_); } @@ -416,10 +338,7 @@ void MKLDNNConvLayer::resetBwdPipeline( MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) { - if (cvtOutGrad_) { - pipeline.push_back(*cvtOutGrad_); - } - + CHECK(inVal_); // add bwdWgt handle if (bias) { bwdWgt_.reset(new conv_bwdWgt(*wgtPD, *inVal_, *out, *wgt, *bias)); @@ -431,99 +350,13 @@ void MKLDNNConvLayer::resetBwdPipeline( if (dataPD == nullptr) { return; } - if (cvtWgtVal_) { pipeline.push_back(*cvtWgtVal_); } - // add bwdData handle CHECK(wgtValBwdData_) << "Should have weight memory"; bwdData_.reset(new conv_bwdData(*dataPD, *out, *wgtValBwdData_, *in)); pipeline.push_back(*bwdData_); - - if (cvtInGrad_) { - pipeline.push_back(*cvtInGrad_); - } -} - -void MKLDNNConvLayer::resetOutGrad( - std::shared_ptr& wgtPD, MKLDNNMatrixPtr& out) { - cpuOutGrad_ = nullptr; - cvtOutGrad_ = nullptr; - CHECK(outVal_ != nullptr && - outVal_->getPrimitiveDesc() == wgtPD->diff_dst_primitive_desc()) - << "primitive desc of out grad and value should be equal"; - if (outputIsOnlyMKLDNN()) { - MKLDNNLayer::resetOutGrad(out, outVal_->getPrimitiveDesc()); - } else { - const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).grad; - // always share the same grad data of CPU output - // then the activation can get the right grad from output_.grad - output_.grad->setData(cpuOut->getData()); 
- // same PrimitiveDesc with cpuInVal_ - CHECK(cpuOutVal_); - cpuOutGrad_ = MKLDNNMatrix::create(cpuOut, cpuOutVal_->getPrimitiveDesc()); - // create reorder if primitive desc does not match - if (cpuOutGrad_->getPrimitiveDesc() != outVal_->getPrimitiveDesc()) { - out = MKLDNNMatrix::create(nullptr, outVal_->getPrimitiveDesc()); - cvtOutGrad_ = MKLDNNMatrix::createReorder(cpuOutGrad_, out); - CHECK(cvtOutGrad_); - } else { - out = cpuOutGrad_; - } - } -} - -void MKLDNNConvLayer::resetWgtBiasGrad( - std::shared_ptr& wgtPD, - MKLDNNMatrixPtr& wgt, - MKLDNNMatrixPtr& bias) { - wgt = MKLDNNMatrix::create(weight_->getWGrad(), - wgtPD->diff_weights_primitive_desc()); - CHECK(nullptr != wgtVal_ && - wgt->getPrimitiveDesc() == wgtVal_->getPrimitiveDesc()) - << "primitive desc of weight grad and value should be equal"; - VLOG(MKLDNN_FMTS) << "weight grad format: " << wgt->getFormat(); - - bias = nullptr; - if (biasVal_ == nullptr) { - return; - } - bias = MKLDNNMatrix::create(biases_->getWGrad(), - wgtPD->diff_bias_primitive_desc()); - CHECK(bias->getPrimitiveDesc() == biasVal_->getPrimitiveDesc()) - << "primitive desc of bias grad should equal the bias value"; -} - -void MKLDNNConvLayer::resetInGrad( - std::shared_ptr& dataPD, - MKLDNNMatrixPtr& in) { - in = nullptr; - cpuInGrad_ = nullptr; - cvtInGrad_ = nullptr; - if (dataPD == nullptr) { - return; - } - - if (inputIsOnlyMKLDNN()) { - MKLDNNLayer::resetInGrad(in, dataPD->diff_src_primitive_desc()); - CHECK(nullptr != inVal_ && - in->getPrimitiveDesc() == inVal_->getPrimitiveDesc()) - << "primitive desc of input grad and value should be equal"; - } else { - const MatrixPtr& cpuIn = getInputGrad(0, CPU_DEVICE); - // same PrimitiveDesc with cpuInVal_ - CHECK(cpuInVal_); - cpuInGrad_ = MKLDNNMatrix::create(cpuIn, cpuInVal_->getPrimitiveDesc()); - in = cpuInGrad_; - // create reorder if PrimitiveDesc does not match - if (cpuInGrad_->getPrimitiveDesc() != dataPD->diff_src_primitive_desc()) { - in = MKLDNNMatrix::create(getInputGrad(0, MKLDNN_DEVICE), - dataPD->diff_src_primitive_desc()); - cvtInGrad_ = MKLDNNMatrix::createReorder(in, cpuInGrad_); - CHECK(cvtInGrad_); - } - } } void MKLDNNConvLayer::resetWgtValBwdData( @@ -537,8 +370,7 @@ void MKLDNNConvLayer::resetWgtValBwdData( // since the primitive_desc would be different with wgtVal_ CHECK(wgtVal_) << "should have weight value"; if (dataPD->weights_primitive_desc() != wgtVal_->getPrimitiveDesc()) { - wgtValBwdData_ = - MKLDNNMatrix::create(nullptr, dataPD->weights_primitive_desc()); + wgtValBwdData_ = MKLDNNMatrix::create(dataPD->weights_primitive_desc()); cvtWgtVal_ = MKLDNNMatrix::createReorder(wgtVal_, wgtValBwdData_); CHECK(cvtWgtVal_); } else { diff --git a/paddle/gserver/layers/MKLDNNConvLayer.h b/paddle/gserver/layers/MKLDNNConvLayer.h index f84f2f737c47a1b8adc2b83360a0396ffbc6ae24..1fed0e1c6565b763a3ee73a0853f560ddfbd44c6 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.h +++ b/paddle/gserver/layers/MKLDNNConvLayer.h @@ -48,17 +48,6 @@ protected: // save forward primitive_desc, which can be used backward std::shared_ptr fwdPD_; - // MKLDNNMatrixPtr which should be created from CPU Device - MKLDNNMatrixPtr cpuInVal_; - MKLDNNMatrixPtr cpuInGrad_; - MKLDNNMatrixPtr cpuOutVal_; - MKLDNNMatrixPtr cpuOutGrad_; - // convert handle between CPU device and MKLDNN device - std::shared_ptr cvtInVal_; - std::shared_ptr cvtInGrad_; - std::shared_ptr cvtOutVal_; - std::shared_ptr cvtOutGrad_; - // whether the weight has been init bool hasInitedWgt_; @@ -94,8 +83,6 @@ public: MKLDNNMatrixPtr& bias, 
MKLDNNMatrixPtr& out) override; - void updateInputData() override; - void updateWeights(const UpdateCallback& callback) override; void convertWeightsFromPaddle() override; @@ -109,26 +96,6 @@ public: << ", sw: " << sw_ << ", dh: " << dh_ << ", dw: " << dw_; } - void printValueFormatFlow() override { - if (cpuInVal_) { - VLOG(MKLDNN_FMTS) << cpuInVal_->getFormat() << " >>>"; - } - MKLDNNLayer::printValueFormatFlow(); - if (cpuOutVal_) { - VLOG(MKLDNN_FMTS) << " >>> " << cpuOutVal_->getFormat(); - } - } - - void printGradFormatFlow() override { - if (cpuInGrad_) { - VLOG(MKLDNN_FMTS) << cpuInGrad_->getFormat() << " <<<"; - } - MKLDNNLayer::printGradFormatFlow(); - if (cpuOutGrad_) { - VLOG(MKLDNN_FMTS) << " <<< " << cpuOutGrad_->getFormat(); - } - } - protected: /** * load the dims settings of this conv @@ -162,23 +129,6 @@ protected: MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out); - /** - * reset MKLDNNMatrix of input value - */ - void resetInValue(std::shared_ptr& pd, - MKLDNNMatrixPtr& in); - /** - * reset MKLDNNMatrix of weight and bias value - */ - void resetWgtBiasValue(std::shared_ptr& pd, - MKLDNNMatrixPtr& wgt, - MKLDNNMatrixPtr& bias); - /** - * reset MKLDNNMatrix of output value - */ - void resetOutValue(std::shared_ptr& pd, - MKLDNNMatrixPtr& out); - /** * reset the backward weight primitive descriptor. */ @@ -207,22 +157,6 @@ protected: MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out); - /** - * reset MKLDNNMatrix of output grad - */ - void resetOutGrad(std::shared_ptr& wgtPD, - MKLDNNMatrixPtr& out); - /** - * reset MKLDNNMatrix of weight and bias grad - */ - void resetWgtBiasGrad(std::shared_ptr& wgtPD, - MKLDNNMatrixPtr& wgt, - MKLDNNMatrixPtr& bias); - /** - * reset MKLDNNMatrix of input grad - */ - void resetInGrad(std::shared_ptr& dataPD, - MKLDNNMatrixPtr& in); /** * reset MKLDNNMatrix of weight value for backward data * since the primitive_desc would be different with wgtVal_ diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index cf19a155681f3a1ceb20af67245c8f2b8fa8fa73..d82063a7130ca928ba042e210eb216f90c7207cd 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -62,7 +62,7 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() { CHECK(wgtVal_) << "should have been initialized"; bool hasNoSpatial_ = ih_ == 1 && iw_ == 1; auto targetDim = wgtVal_->getDims(); - auto srcFmt = hasNoSpatial_ ? memory::format::io : memory::format::ihwo; + auto srcFmt = hasNoSpatial_ ? format::io : format::ihwo; wgtVal_->reorderDataFrom(wgtVal_, srcFmt, targetDim); hasInitedWgt_ = true; } @@ -71,7 +71,7 @@ void MKLDNNFcLayer::convertWeightsToPaddle() { CHECK(wgtVal_) << "should have been initialized"; bool hasNoSpatial_ = ih_ == 1 && iw_ == 1; auto targetDim = wgtVal_->getDims(); - auto dstFmt = hasNoSpatial_ ? memory::format::io : memory::format::ihwo; + auto dstFmt = hasNoSpatial_ ? 
format::io : format::ihwo; wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim); } @@ -100,8 +100,6 @@ void MKLDNNFcLayer::resetFwd(std::vector& pipeline, resetFwdPD(fwdPD_, in, wgt, bias, out); resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out); - - printValueFormatFlow(); } void MKLDNNFcLayer::resetBwd(std::vector& pipeline, @@ -119,12 +117,6 @@ void MKLDNNFcLayer::resetBwd(std::vector& pipeline, resetBwdDataPD(bwdDataPD, in, out); resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out); - - printGradFormatFlow(); -} - -void MKLDNNFcLayer::updateInputData() { - inVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); } void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) { @@ -139,51 +131,30 @@ void MKLDNNFcLayer::resetFwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) { resetInValue(in); - - resetWgtBiasValue(wgt, bias); - - resetOutValue(out); -} - -void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) { - if (inputIsOnlyMKLDNN()) { - const MatrixPtr& dnnIn = getInputValue(0); - in = std::dynamic_pointer_cast(dnnIn); - CHECK(in) << "Input should be MKLDNNMatrix"; - } else { - CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet"; - const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE); - in = MKLDNNMatrix::create( - cpuIn, {bs_, ic_, ih_, iw_}, format::nchw, engine_); - } + CHECK(in); in->downSpatial(); -} -void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt, - MKLDNNMatrixPtr& bias) { + auto outPD = + MKLDNNMatrix::createPrimitiveDesc({bs_, oc_}, format::nc, engine_); + resetOutValue(out, outPD); + format wgtFmt = format::oihw; - if (inVal_->getFormat() == format::nChw8c) { + if (in->getFormat() == format::nChw8c) { wgtFmt = format::oIhw8i; - } else if (inVal_->getFormat() == format::nChw16c) { + } else if (in->getFormat() == format::nChw16c) { wgtFmt = format::oIhw16i; } - wgt = MKLDNNMatrix::create( - weight_->getW(), {oc_, ic_, ih_, iw_}, wgtFmt, engine_); + auto wgtPD = + MKLDNNMatrix::createPrimitiveDesc({oc_, ic_, ih_, iw_}, wgtFmt, engine_); + resetWithMatrix(wgt, weight_->getW(), wgtPD); wgt->downSpatial(); - VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat(); - - bias = (biases_ && biases_->getW()) - ? 
MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_) - : nullptr; -} -void MKLDNNFcLayer::resetOutValue(MKLDNNMatrixPtr& out) { - out = MKLDNNMatrix::create(output_.value, {bs_, oc_}, format::nc, engine_); - if (!outputIsOnlyMKLDNN()) { - // fc cpu output value do not need create convert, just share data - getOutput(CPU_DEVICE).value->setData(out->getData()); + if (biases_ && biases_->getW()) { + auto biasPD = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_); + resetWithMatrix(bias, biases_->getW(), biasPD); + } else { + bias = nullptr; } - output_.value = std::dynamic_pointer_cast(out); } void MKLDNNFcLayer::resetFwdPD(std::shared_ptr& pd, @@ -219,7 +190,6 @@ void MKLDNNFcLayer::resetFwdPipeline( } else { fwd_.reset(new fc_fwd(*pd, *in, *wgt, *out)); } - pipeline.push_back(*fwd_); } @@ -227,44 +197,18 @@ void MKLDNNFcLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) { - resetOutGrad(out); - - resetWgtBiasGrad(wgt, bias); - - resetInGrad(in); -} - -void MKLDNNFcLayer::resetOutGrad(MKLDNNMatrixPtr& out) { - CHECK(outVal_); - if (outputIsOnlyMKLDNN()) { - MKLDNNLayer::resetOutGrad(out, outVal_->getPrimitiveDesc()); - } else { - const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).grad; - output_.grad->setData(cpuOut->getData()); - out = MKLDNNMatrix::create(cpuOut, outVal_->getPrimitiveDesc()); - } -} + CHECK(inVal_ && outVal_); + resetOutGrad(out, outVal_->getPrimitiveDesc()); + resetInGrad(in, inVal_->getPrimitiveDesc()); -void MKLDNNFcLayer::resetWgtBiasGrad(MKLDNNMatrixPtr& wgt, - MKLDNNMatrixPtr& bias) { CHECK(wgtVal_); - wgt = MKLDNNMatrix::create(weight_->getWGrad(), wgtVal_->getPrimitiveDesc()); + resetWithMatrix(wgt, weight_->getWGrad(), wgtVal_->getPrimitiveDesc()); - bias = nullptr; - if (biasVal_ == nullptr) { - return; - } - bias = - MKLDNNMatrix::create(biases_->getWGrad(), biasVal_->getPrimitiveDesc()); -} - -void MKLDNNFcLayer::resetInGrad(MKLDNNMatrixPtr& in) { - in = nullptr; - if (inputLayers_[0]->getOutput().grad == nullptr) { - return; + if (biasVal_) { + resetWithMatrix(bias, biases_->getWGrad(), biasVal_->getPrimitiveDesc()); + } else { + bias = nullptr; } - CHECK(inVal_); - MKLDNNLayer::resetInGrad(in, inVal_->getPrimitiveDesc()); } void MKLDNNFcLayer::resetBwdWgtPD( diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h index c76878aafab7e986d2bf478eaba02f2f0aced293..ee861763ff3dc10ddb4c119358b80dbe1614aecb 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.h +++ b/paddle/gserver/layers/MKLDNNFcLayer.h @@ -66,8 +66,6 @@ public: MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) override; - void updateInputData() override; - void updateWeights(const UpdateCallback& callback) override; void convertWeightsFromPaddle() override; @@ -84,9 +82,6 @@ protected: MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out); - void resetInValue(MKLDNNMatrixPtr& in); - void resetWgtBiasValue(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias); - void resetOutValue(MKLDNNMatrixPtr& out); void resetFwdPD(std::shared_ptr& pd, MKLDNNMatrixPtr in, MKLDNNMatrixPtr wgt, @@ -109,9 +104,6 @@ protected: MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out); - void resetOutGrad(MKLDNNMatrixPtr& out); - void resetWgtBiasGrad(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias); - void resetInGrad(MKLDNNMatrixPtr& in); void resetBwdWgtPD(std::shared_ptr& pd, MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias, diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp 
b/paddle/gserver/layers/MKLDNNLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6bb19976b5552fcd2e420f03de45c77a90ffb9d2 --- /dev/null +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -0,0 +1,333 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "MKLDNNLayer.h" + +using namespace mkldnn; // NOLINT +typedef memory::format format; + +namespace paddle { + +bool MKLDNNLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn." + << "Please set WITH_MKLDNN=ON " + << "and set use_mkldnn=True"; + CHECK(!useGpu_) << "Do not support GPU yet"; + + // set device id before Layer::init + setDevice(MKLDNN_DEVICE); + // change param device to MKLDNN device + setParamsDevice(MKLDNN_DEVICE, parameterMap); + if (!Layer::init(layerMap, parameterMap)) { + return false; + } + setOutputMap(); + checkCPUOutputsNumber(); + + stream_.reset(new MKLDNNStream()); + engine_ = CPUEngine::Instance().getEngine(); + return true; +} + +void MKLDNNLayer::forward(PassType passType) { + passType_ = passType; + + { + REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); + CHECK(!inputLayers_.empty()); + copySeqInfoToOutputs(); + size_t elemenCnt = inputLayers_[0]->getOutputValue()->getElementCnt(); + if (inputElemenCnt_ != elemenCnt) { + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; + // reset when input total sizes changed, not only the batchsize + inputElemenCnt_ = elemenCnt; + pipelineFwd_.clear(); + reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_); + // all outputs on cpu devices share the value and grad of output_ + shareCPUDevice(); + resetFwd(pipelineFwd_, inVal_, wgtVal_, biasVal_, outVal_); + // the output value of a MKLDNNLayer should be a MKLDNNMatrix, + // so an external output value is always necessary. + // an external input value is not always necessary, + // since the input may already be a mkldnn internal buffer. + CHECK(extOutVal_) << "external output value is necessary"; + output_.value = std::dynamic_pointer_cast(extOutVal_); + CHECK(inVal_ && outVal_) << "internal memories are necessary"; + if (cvtInVal_) { + pipelineFwd_.insert(pipelineFwd_.begin(), *cvtInVal_); + } + if (cvtOutVal_) { + pipelineFwd_.push_back(*cvtOutVal_); + } + convertWeightsFromPaddle(); + printSizeInfo(); + printValueFormat(); + needResetBwd_ = true; + } + + if (inputLayers_[0]->getType() == "data") { + // Update input value data when input layer is "data" type, + // since the input value data address might be changed. 
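+ // i.e. rebind extInVal_ so the external input buffer keeps pointing at the data layer's current batch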
+ CHECK(extInVal_); + extInVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); + } + + if (!outputOnlyMKLDNN_) { + clearGrads(); + } + stream_->submit(pipelineFwd_); + } + { + REGISTER_TIMER_INFO("FwActTimer", getName().c_str()); + forwardActivation(); + } +} + +void MKLDNNLayer::backward(const UpdateCallback& callback) { + if (needResetBwd_) { + VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; + pipelineBwd_.clear(); + pipelineMergeGrad_.clear(); + mergeGrad_ = nullptr; + resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); + // an external output grad is not always necessary, + // since the output may be a mkldnn internal buffer, or the grads may be merged directly. + CHECK(outGrad_) << "internal output grad is necessary"; + if (extOutGrad_) { + CHECK_EQ(extOutGrad_->getData(), output_.grad->getData()) + << "the external buffer should share the same data with output_.grad"; + } + if (cvtOutGrad_) { + pipelineBwd_.insert(pipelineBwd_.begin(), *cvtOutGrad_); + } + if (cvtInGrad_) { + pipelineBwd_.push_back(*cvtInGrad_); + } + printGradFormat(); + needResetBwd_ = false; + } + + // merging grad must be done before the backward activation + if (mergeGrad_) { + REGISTER_TIMER_INFO("MergeBpGrad", getName().c_str()); + stream_->submit(pipelineMergeGrad_); + } + { + REGISTER_TIMER_INFO("BpActTimer", getName().c_str()); + backwardActivation(); + } + { + REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str()); + stream_->submit(pipelineBwd_); + } + { + REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); + updateWeights(callback); + } +} + +void MKLDNNLayer::reshapeInput(int& batchsize, int& height, int& width) { + const Argument& input = inputLayers_[0]->getOutput(); + batchsize = input.getBatchSize(); + int h = input.getFrameHeight(); + int w = input.getFrameWidth(); + if (h != 0) { + height = h; + } + if (w != 0) { + width = w; + } +} + +void MKLDNNLayer::reshapeOutput(size_t height, size_t width) { + output_.setFrameHeight(height); + output_.setFrameWidth(width); + for (size_t i = 0; i < outputOtherDevice_.size(); i++) { + outputOtherDevice_[i].setFrameHeight(height); + outputOtherDevice_[i].setFrameWidth(width); + } +} + +void MKLDNNLayer::resetWithMatrix(MKLDNNMatrixPtr& dnn, + const MatrixPtr& mat, + memory::primitive_desc pd) { + dnn = nullptr; + if (mat == nullptr) { + return; + } + dnn = MKLDNNMatrix::create(pd, mat); +} + +void MKLDNNLayer::resetInValue( + MKLDNNMatrixPtr& in, const std::shared_ptr& intPD) { + cvtInVal_ = nullptr; + extInVal_ = nullptr; + in = nullptr; + CHECK_GT(bs_ * ic_ * ih_ * iw_, 0); + auto extPD = MKLDNNMatrix::createPrimitiveDesc( + {bs_, ic_, ih_, iw_}, format::nchw, engine_); + const MatrixPtr& inMat = inputLayers_[0]->getOutputValue(); + in = std::dynamic_pointer_cast(inMat); + CHECK_EQ(inputIsOnlyMKLDNN(), in != nullptr); + if (in == nullptr || in->getFormat() == format::nc) { + in = MKLDNNMatrix::create(extPD, inMat); + } + extInVal_ = isPaddleFormat(in->getFormat()) ? in : nullptr; + if (in->getFormat() == format::nc) { + CHECK(ih_ == 1 && iw_ == 1); + } + if (nullptr == intPD || in->getPrimitiveDesc() == *intPD) { + return; + } + // need create reorder + in = MKLDNNMatrix::create(*intPD); + extInVal_ = extInVal_ ? 
extInVal_ : MKLDNNMatrix::create(extPD, inMat); + cvtInVal_ = MKLDNNMatrix::createReorder(extInVal_, in); + CHECK(cvtInVal_) << "should not be empty"; +} + +void MKLDNNLayer::resetOutValue(MKLDNNMatrixPtr& out, + memory::primitive_desc intPD) { + cvtOutVal_ = nullptr; + out = MKLDNNMatrix::create(intPD, output_.value); + extOutVal_ = out; + if (outputIsOnlyMKLDNN() || isPaddleFormat(extOutVal_->getFormat())) { + return; + } + // need create reorder + CHECK_GT(bs_ * oc_ * oh_ * ow_, 0); + extOutVal_ = MKLDNNMatrix::create( + memory::dims{bs_, oc_, oh_, ow_}, format::nchw, engine_, output_.value); + out = MKLDNNMatrix::create(intPD); + cvtOutVal_ = MKLDNNMatrix::createReorder(out, extOutVal_); + CHECK(cvtOutVal_) << "should not be empty"; +} + +void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in, + memory::primitive_desc intPD) { + cvtInGrad_ = nullptr; + extInGrad_ = nullptr; + in = nullptr; + LayerPtr& input = inputLayers_[0]; + if (input->getOutputGrad() == nullptr) { + // no need for input grad + return; + } + CHECK(inputIsOnlyMKLDNN() || input->getOutputMapSize() <= 1) + << "only support the case that input is a MKLDNN layer or has only one output layer"; + // when input is a mkldnn branch node, + // this layer will save input grad to an internal buffer, + // and the mkldnn input layer will merge them to actual prev->output_.grad + const MatrixPtr& inMat = + input->getOutputMapSize() <= 1 ? input->getOutputGrad() : nullptr; + in = MKLDNNMatrix::create(intPD, inMat); + Argument& arg = input->getOutput(this->getName()); + arg.grad = std::dynamic_pointer_cast(in); + CHECK(inVal_); + CHECK(inVal_->getPrimitiveDesc() == intPD) << "the primitive desc must be equal"; + if (inputIsOnlyMKLDNN()) { + return; + } + + extInGrad_ = in; + if (isPaddleFormat(extInGrad_->getFormat())) { + return; + } + // need create reorder + // TODO(TJ): add macro definition to simplify it + CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat())) + << "should have external input value and the format must be nchw(nc)"; + extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat); + CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD) + << "should have internal input value and the primitive desc must be equal"; + in = MKLDNNMatrix::create(intPD); + cvtInGrad_ = MKLDNNMatrix::createReorder(in, extInGrad_); + CHECK(cvtInGrad_); +} + +void MKLDNNLayer::resetOutGrad(MKLDNNMatrixPtr& out, + memory::primitive_desc intPD) { + cvtOutGrad_ = nullptr; + extOutGrad_ = nullptr; + out = nullptr; + MatrixPtr& outMat = output_.grad; + out = MKLDNNMatrix::create(intPD, outMat); + resetMergeGrad(out); + if (outputIsOnlyMKLDNN()) { + return; + } + CHECK_LE(outputMap_.size(), 1U) << "do not support mixed with cpu device"; + extOutGrad_ = out; + if (isPaddleFormat(extOutGrad_->getFormat())) { + return; + } + // need create reorder + CHECK(extOutVal_ != nullptr && isPaddleFormat(extOutVal_->getFormat())) + << "should have external output value and the format must be nchw(nc)"; + extOutGrad_ = MKLDNNMatrix::create(extOutVal_->getPrimitiveDesc(), outMat); + CHECK(outVal_ != nullptr && outVal_->getPrimitiveDesc() == intPD) + << "should have internal output value and the primitive desc must be equal"; + out = MKLDNNMatrix::create(intPD); + cvtOutGrad_ = MKLDNNMatrix::createReorder(extOutGrad_, out); + CHECK(cvtOutGrad_); +} + +void MKLDNNLayer::resetMergeGrad(MKLDNNMatrixPtr& out) { + mergeGrad_ = nullptr; + pipelineMergeGrad_.clear(); + if (outputMap_.size() <= 1 || !outputIsOnlyMKLDNN()) { + // do not merge when output is not all MKLDNN or 
there is only one output + return; + } + CHECK(out) << "should have reset the internal output grad"; + std::vector scales(outputMap_.size(), 1.0); + std::vector srcPDs; + std::vector srcs; + for (auto it = outputMap_.begin(); it != outputMap_.end(); ++it) { + MKLDNNMatrixPtr src = + std::dynamic_pointer_cast(it->second->grad); + CHECK(src) << "should be MKLDNNMatrix"; + auto srcDims = src->getDims(); + auto dstDims = out->getDims(); + CHECK_EQ(srcDims.size(), dstDims.size()); + for (size_t i = 0; i < srcDims.size(); ++i) { + CHECK_EQ(srcDims[i], dstDims[i]); + } + VLOG(MKLDNN_BASE) << getName() << " has output grad " << it->first + << ", format " << src->getFormat(); + srcPDs.push_back(src->getPrimitiveDesc()); + srcs.push_back(*src); + } + + // TODO(TJ): remove me when mkldnn sum supports different formats + for (size_t i = 1; i < srcPDs.size(); ++i) { + CHECK(srcPDs[0] == srcPDs[i]); + } + tmpOutGrad_ = out; + tmpCvt_ = nullptr; + if (out->getPrimitiveDesc() != srcPDs[0]) { + tmpOutGrad_ = MKLDNNMatrix::create(srcPDs[0]); + tmpCvt_ = MKLDNNMatrix::createReorder(tmpOutGrad_, out); + CHECK(tmpCvt_); + pipelineMergeGrad_.push_back(*tmpCvt_); + } + + auto sumPD = + sum::primitive_desc(tmpOutGrad_->getMemoryDesc(), scales, srcPDs); + mergeGrad_.reset(new sum(sumPD, srcs, *tmpOutGrad_)); + pipelineMergeGrad_.insert(pipelineMergeGrad_.begin(), *mergeGrad_); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 4e2753eba2350d2c3df81b57fe98270a3c38cb24..2c21a5b2aaecb17a52a5de9a98664068f2255d83 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -58,11 +58,31 @@ protected: std::vector pipelineFwd_; std::vector pipelineBwd_; - // MKLDNNMatrixPtr with internal format + /* Value and grad are separated as internal and external buffers. + * Each MKLDNNLayer must at least init or reset its internal buffers, + * and the external buffer format is always nchw or nc (when h==w==1), + * which are the formats paddle uses. + * The output_.value and output_.grad always save the external data + * when mixed with a cpu device. + * When all layers are mkldnn layers, they can save internal data directly. + */ + // below MKLDNNMatrix buffers are all internal buffers MKLDNNMatrixPtr inVal_; MKLDNNMatrixPtr inGrad_; MKLDNNMatrixPtr outVal_; MKLDNNMatrixPtr outGrad_; + // below are external value and grad + MKLDNNMatrixPtr extInVal_; + MKLDNNMatrixPtr extInGrad_; + MKLDNNMatrixPtr extOutVal_; + MKLDNNMatrixPtr extOutGrad_; + // convert handle between external and internal buffers + std::shared_ptr cvtInVal_; + std::shared_ptr cvtInGrad_; + std::shared_ptr cvtOutVal_; + std::shared_ptr cvtOutGrad_; + + // weight and bias are always internal buffers MKLDNNMatrixPtr wgtVal_; MKLDNNMatrixPtr wgtGrad_; MKLDNNMatrixPtr biasVal_; @@ -91,6 +111,7 @@ public: oh_(0), ow_(0), needResetBwd_(true), + outputOnlyMKLDNN_(false), engine_(mkldnn::engine::cpu, 0), stream_(nullptr), fwd_(nullptr), @@ -99,92 +120,9 @@ public: ~MKLDNNLayer() {} - virtual bool init(const LayerMap& layerMap, - const ParameterMap& parameterMap) { - CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn." 
- << "Please set WITH_MKLDNN=ON " - << "and set use_mkldnn=True"; - CHECK(!useGpu_) << "Do not support GPU yet"; - - // set device id before Layer::init - setDevice(MKLDNN_DEVICE); - // change param device to MKLDNN device - setParamsDevice(MKLDNN_DEVICE, parameterMap); - if (!Layer::init(layerMap, parameterMap)) { - return false; - } - setOutputMap(); - checkCPUOutputsNumber(); - - stream_.reset(new MKLDNNStream()); - engine_ = CPUEngine::Instance().getEngine(); - return true; - } - - void forward(PassType passType) override { - passType_ = passType; - - { - REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); - CHECK(!inputLayers_.empty()); - copySeqInfoToOutputs(); - size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt(); - if (inputElemenCnt_ != elemenCnt) { - VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; - // reset when input total sizes changed, not only the batchsize - inputElemenCnt_ = elemenCnt; - pipelineFwd_.clear(); - reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_); - resetFwd(pipelineFwd_, inVal_, wgtVal_, biasVal_, outVal_); - convertWeightsFromPaddle(); - needResetBwd_ = true; - } - - if (inputLayers_[0]->getType() == "data") { - updateInputData(); - } - - if (!outputOnlyMKLDNN_) { - clearGrads(); - } - stream_->submit(pipelineFwd_); - } - - /* activation */ { - REGISTER_TIMER_INFO("FwActTimer", getName().c_str()); - forwardActivation(); - } - } - - void backward(const UpdateCallback& callback) override { - if (needResetBwd_) { - VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; - pipelineBwd_.clear(); - pipelineMergeGrad_.clear(); - mergeGrad_ = nullptr; - resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); - needResetBwd_ = false; - } - - // merge grad must before backward activation - if (mergeGrad_) { - REGISTER_TIMER_INFO("MergeBpGrad", getName().c_str()); - stream_->submit(pipelineMergeGrad_); - } - { - REGISTER_TIMER_INFO("BpActTimer", getName().c_str()); - backwardActivation(); - } - { - REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str()); - stream_->submit(pipelineBwd_); - } - - { - REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); - updateWeights(callback); - } - } + virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + virtual void forward(PassType passType); + virtual void backward(const UpdateCallback& callback); /** * reshape the input image sizes @@ -195,7 +133,7 @@ public: int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) = 0; /** - * reset the mkldnn forward primitve and memory + * reset the mkldnn forward primitve and memories * only would be called when input size changes */ virtual void resetFwd(std::vector& pipeline, @@ -205,7 +143,7 @@ public: MKLDNNMatrixPtr& out) = 0; /** - * reset the mkldnn backward primitve and memory for mkldnn fc + * reset the mkldnn backward primitve and memories * only would be called when needed */ virtual void resetBwd(std::vector& pipeline, @@ -214,12 +152,6 @@ public: MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) = 0; - /** - * Update input value data when input layer is "data" type. - * Since the input value data address might be changed. - */ - virtual void updateInputData() {} - /** * Update weights and biases if necessary. 
*/ @@ -246,131 +178,78 @@ protected: /** * reshape the input image sizes and input batchsize */ - virtual void reshapeInput(int& batchsize, int& height, int& width) { - const Argument& input = inputLayers_[0]->getOutput(); - batchsize = input.getBatchSize(); - int h = input.getFrameHeight(); - int w = input.getFrameWidth(); - if (h != 0) { - height = h; - } - if (w != 0) { - width = w; - } - } + void reshapeInput(int& batchsize, int& height, int& width); /** * reshape output image sizes */ - virtual void reshapeOutput(size_t height, size_t width) { - output_.setFrameHeight(height); - output_.setFrameWidth(width); - for (size_t i = 0; i < outputOtherDevice_.size(); i++) { - outputOtherDevice_[i].setFrameHeight(height); - outputOtherDevice_[i].setFrameWidth(width); - } - } + void reshapeOutput(size_t height, size_t width); /** - * reset the output grad matrix from primitive desc. - * and reset the merge grad primitive if needed. - * note: when this layer has serval outputs, - * it could not be mixed with cpu device, - * since it can not get memory desc from cpu device. + * reset MKLDNNMatrix from Matrix and internal primitive desc. + * reset nullptr if matrix or primitive desc is empty */ - virtual void resetOutGrad(MKLDNNMatrixPtr& out, - mkldnn::memory::primitive_desc pd) { - CHECK(outputIsOnlyMKLDNN()) << "do not support mixed with other device yet"; - mergeGrad_ = nullptr; - pipelineMergeGrad_.clear(); - out = MKLDNNMatrix::create(output_.grad, pd); - if (outputMap_.size() <= 1) { - return; - } - std::vector scales(outputMap_.size(), 1.0); - std::vector srcPDs; - std::vector srcs; - for (auto it = outputMap_.begin(); it != outputMap_.end(); ++it) { - MKLDNNMatrixPtr src = - std::dynamic_pointer_cast(it->second->grad); - VLOG(MKLDNN_BASE) << getName() << " has output grad " << it->first; - CHECK(src) << "should be MKLDNNMatrix"; - auto srcDims = src->getDims(); - auto dstDims = out->getDims(); - CHECK_EQ(srcDims.size(), dstDims.size()); - for (size_t i = 0; i < srcDims.size(); ++i) { - CHECK_EQ(srcDims[i], dstDims[i]); - } - srcPDs.push_back(src->getPrimitiveDesc()); - srcs.push_back(*src); - } + void resetWithMatrix(MKLDNNMatrixPtr& dnn, + const MatrixPtr& mat, + mkldnn::memory::primitive_desc pd); - // TODO(TJ): remove me when mkldnn sum support different formats - for (size_t i = 1; i < srcPDs.size(); ++i) { - CHECK(srcPDs[0] == srcPDs[i]); - } - tmpOutGrad_ = nullptr; - tmpCvt_ = nullptr; - if (out->getPrimitiveDesc() != srcPDs[0]) { - tmpOutGrad_ = MKLDNNMatrix::create(nullptr, srcPDs[0]); - tmpCvt_ = MKLDNNMatrix::createReorder(tmpOutGrad_, out); - CHECK(tmpCvt_); - pipelineMergeGrad_.push_back(*tmpCvt_); - } else { - tmpOutGrad_ = out; - } + /** + * reset input value from input MKLDNNMatrix and internal primitive desc. + * reset both internal and external buffer and create reorder if necessary. + */ + void resetInValue( + MKLDNNMatrixPtr& in, + const std::shared_ptr& intPD = nullptr); - auto sumPD = mkldnn::sum::primitive_desc( - tmpOutGrad_->getMemoryDesc(), scales, srcPDs); - mergeGrad_.reset(new mkldnn::sum(sumPD, srcs, *tmpOutGrad_)); - pipelineMergeGrad_.insert(pipelineMergeGrad_.begin(), *mergeGrad_); - } + /** + * reset output value from internal primitive desc. + * reset both internal and external buffer and create reorder if necessary. + */ + void resetOutValue(MKLDNNMatrixPtr& out, + mkldnn::memory::primitive_desc intPD); /** - * reset input grad from primitive desc. 
- * this function is avaiable for input is only mkldnn - * or input do not care cpu device + * reset input grad from internal primitive desc. + * reset both internal and external buffer and create reorder if necessary. */ - virtual void resetInGrad(MKLDNNMatrixPtr& in, - mkldnn::memory::primitive_desc pd) { - LayerPtr& input = inputLayers_[0]; - const MatrixPtr& grad = - input->getOutputMapSize() > 1 ? nullptr : input->getOutput().grad; - in = MKLDNNMatrix::create(grad, pd); - Argument& arg = input->getOutput(this->getName()); - arg.grad = std::dynamic_pointer_cast(in); - } + void resetInGrad(MKLDNNMatrixPtr& in, mkldnn::memory::primitive_desc intPD); /** - * print info about sizes + * reset output grad from internal primitive desc. + * merge grad if necessary. + * reset both internal and external buffer and create reorder if necessary. + * note: about merge grad, when this layer has several outputs, + * it could not be mixed with cpu device, + * since it can not get memory desc from cpu device. */ - virtual void printSizeInfo() { - VLOG(MKLDNN_SIZES) << getName() << ": bs: " << bs_ << ", ic: " << ic_ - << ", ih: " << ih_ << ", iw: " << iw_ << ", oc: " << oc_ - << ", oh: " << oh_ << ", ow: " << ow_; - } + void resetOutGrad(MKLDNNMatrixPtr& out, mkldnn::memory::primitive_desc intPD); /** - * Print the mkldnn memory format flow of value + * reset the merge grad primitive if necessary. + * note: do not support the grads mixed with cpu device, + * since it can not get memory desc from cpu device. */ - virtual void printValueFormatFlow() { - if (inVal_ && outVal_) { - VLOG(MKLDNN_FMTS) << inVal_->getFormat() << " >>> " - << outVal_->getFormat(); - } - } + void resetMergeGrad(MKLDNNMatrixPtr& out); + +protected: + /** + * Set deviceId of this layer. + */ + void setDevice(int id) { deviceId_ = id; } /** - * Print the mkldnn memory format flow of grad + * check the format is nchw or nc, + * which is supported by Paddle default memory layout */ - virtual void printGradFormatFlow() { - if (inGrad_ && outGrad_) { - VLOG(MKLDNN_FMTS) << inGrad_->getFormat() << " <<< " - << outGrad_->getFormat(); + bool isPaddleFormat(mkldnn::memory::format fmt) { + if (fmt == mkldnn::memory::format::nchw || + fmt == mkldnn::memory::format::nc) { + return true; + } else { + return false; } } -protected: /** * If input only has MKLDNN device. * Otherwise, only support the previous layer using CPU device. @@ -380,7 +259,6 @@ protected: if (prevDevice == MKLDNN_DEVICE) { return true; } else { - // do not support GPU yet CHECK_EQ(prevDevice, CPU_DEVICE) << "Only support CPU yet"; return false; } @@ -400,18 +278,74 @@ protected: } /** - * Set deviceId of this layer. 
+ * print info about sizes */ - void setDevice(int id) { deviceId_ = id; } + virtual void printSizeInfo() { + VLOG(MKLDNN_SIZES) << getName() << ": bs: " << bs_ << ", ic: " << ic_ + << ", ih: " << ih_ << ", iw: " << iw_ << ", oc: " << oc_ + << ", oh: " << oh_ << ", ow: " << ow_; + } + + /** + * print the mkldnn memory format of value + */ + virtual void printValueFormat() { + if (extInVal_) { + VLOG(MKLDNN_FMTS) << extInVal_->getFormat() << " >>> "; + } + if (inVal_) { + VLOG(MKLDNN_FMTS) << inVal_->getFormat() << " >>>"; + } + if (outVal_) { + VLOG(MKLDNN_FMTS) << outVal_->getFormat() << " >>> "; + } + if (extOutVal_) { + VLOG(MKLDNN_FMTS) << extOutVal_->getFormat(); + } + if (wgtVal_) { + VLOG(MKLDNN_FMTS) << "Weight value format: " << wgtVal_->getFormat(); + } + if (biasVal_) { + VLOG(MKLDNN_FMTS) << "Bias value format: " << biasVal_->getFormat(); + } + } + + /** + * print the mkldnn memory format of grad + */ + virtual void printGradFormat() { + if (extOutGrad_) { + VLOG(MKLDNN_FMTS) << extOutGrad_->getFormat(); + } + if (outGrad_) { + VLOG(MKLDNN_FMTS) << outGrad_->getFormat() << " <<< "; + } + if (inGrad_) { + VLOG(MKLDNN_FMTS) << inGrad_->getFormat() << " <<<"; + } + if (extInGrad_) { + VLOG(MKLDNN_FMTS) << extInGrad_->getFormat() << " <<< "; + } + if (wgtGrad_) { + VLOG(MKLDNN_FMTS) << "Weight grad format: " << wgtGrad_->getFormat(); + } + if (biasGrad_) { + VLOG(MKLDNN_FMTS) << "Bias grad format: " << biasGrad_->getFormat(); + } + } private: /** * clear all grad */ void clearGrads() { - output_.grad->zeroMem(); + if (output_.grad) { + output_.grad->zeroMem(); + } for (size_t i = 0; i < outputOtherDevice_.size(); i++) { - outputOtherDevice_[i].grad->zeroMem(); + if (outputOtherDevice_[i].grad) { + outputOtherDevice_[i].grad->zeroMem(); + } } } @@ -449,6 +383,19 @@ } } + /** + * if a cpu device is involved, share the value and grad data with output_ + */ + void shareCPUDevice() { + if (outputIsOnlyMKLDNN()) { + return; + } + for (size_t i = 0; i < outputOtherDevice_.size(); i++) { + outputOtherDevice_[i].value = output_.value; + outputOtherDevice_[i].grad = output_.grad; + } + } + /** * Check the cpu device number of outputOtherDevice_. * should have only one at most.
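As a rough sketch of the contract documented above for resetWithMatrix (reset the MKLDNNMatrix to nullptr when the Matrix is empty, otherwise wrap it with the internal primitive desc), and assuming the MKLDNNMatrix::create(pd, m) overload introduced later in this patch, the moved-out body could look like the following; the actual implementation lives in MKLDNNLayer.cpp, which this patch does not show:

void MKLDNNLayer::resetWithMatrix(MKLDNNMatrixPtr& dnn,
                                  const MatrixPtr& mat,
                                  mkldnn::memory::primitive_desc pd) {
  dnn = nullptr;
  if (mat == nullptr) {
    return;  // empty Matrix: leave the MKLDNNMatrix as nullptr
  }
  dnn = MKLDNNMatrix::create(pd, mat);  // wrap it with the internal primitive desc
}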
diff --git a/paddle/gserver/layers/MKLDNNPoolLayer.cpp b/paddle/gserver/layers/MKLDNNPoolLayer.cpp index 0e53e2d1b7e6691909955eeacd345981a9960ec6..6e89260f49979d4edb4da138507a73dc2bf120de 100644 --- a/paddle/gserver/layers/MKLDNNPoolLayer.cpp +++ b/paddle/gserver/layers/MKLDNNPoolLayer.cpp @@ -85,8 +85,6 @@ void MKLDNNPoolLayer::resetFwd(std::vector& pipeline, resetFwdPD(fwdPD_, in, out); resetFwdPipeline(pipeline, fwdPD_, in, out); - - printValueFormatFlow(); } void MKLDNNPoolLayer::resetBwd(std::vector& pipeline, @@ -101,65 +99,22 @@ void MKLDNNPoolLayer::resetBwd(std::vector& pipeline, resetBwdPD(pd, in, out); resetBwdPipeline(pipeline, pd, in, out); - - printGradFormatFlow(); -} - -void MKLDNNPoolLayer::updateInputData() { - inVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); } void MKLDNNPoolLayer::resetFwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out) { resetInValue(in); - resetOutValue(out); -} - -void MKLDNNPoolLayer::resetInValue(MKLDNNMatrixPtr& in) { - if (inputIsOnlyMKLDNN()) { - const MatrixPtr& dnnIn = getInputValue(0); - in = std::dynamic_pointer_cast(dnnIn); - CHECK(in) << "Input should be MKLDNNMatrix"; - } else { - CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet"; - const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE); - in = MKLDNNMatrix::create( - cpuIn, {bs_, ic_, ih_, iw_}, format::nchw, engine_); - } -} - -void MKLDNNPoolLayer::resetOutValue(MKLDNNMatrixPtr& out) { - CHECK(inVal_) << "Should reset input value first"; memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_}; - out = MKLDNNMatrix::create( - output_.value, outDims, inVal_->getFormat(), engine_); - - // create reorder if output value has cpu device and pd do not match - cpuOutVal_ = nullptr; - cvtOutVal_ = nullptr; - if (!outputIsOnlyMKLDNN()) { - const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).value; - cpuOutVal_ = MKLDNNMatrix::create(cpuOut, outDims, format::nchw, engine_); - if (cpuOutVal_->getPrimitiveDesc() != out->getPrimitiveDesc()) { - out = MKLDNNMatrix::create(nullptr, out->getPrimitiveDesc()); - cvtOutVal_ = MKLDNNMatrix::createReorder(out, cpuOutVal_); - CHECK(cvtOutVal_) << "should not be emptry"; - } else { - cpuOut->setData(output_.value->getData()); - cpuOutVal_ = out; - } - output_.value = std::dynamic_pointer_cast(cpuOutVal_); - return; - } - output_.value = std::dynamic_pointer_cast(outVal_); + CHECK(in); + auto outPD = + MKLDNNMatrix::createPrimitiveDesc(outDims, in->getFormat(), engine_); + resetOutValue(out, outPD); } void MKLDNNPoolLayer::resetFwdPD(std::shared_ptr& pd, MKLDNNMatrixPtr in, MKLDNNMatrixPtr out) { - memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_}; - memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_}; memory::dims kernels = memory::dims{fh_, fw_}; memory::dims strides = memory::dims{sh_, sw_}; memory::dims padL = memory::dims{ph_, pw_}; @@ -194,58 +149,26 @@ void MKLDNNPoolLayer::resetFwdPipeline( ? 
std::make_shared(pool_fwd(*pd, *in, *out, *workspace_)) : std::make_shared(pool_fwd(*pd, *in, *out)); pipeline.push_back(*fwd_); - - if (cvtOutVal_) { - pipeline.push_back(*cvtOutVal_); - } } void MKLDNNPoolLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out) { - resetOutGrad(out); - - resetInGrad(in); -} -void MKLDNNPoolLayer::resetOutGrad(MKLDNNMatrixPtr& out) { - cpuOutGrad_ = nullptr; - cvtOutGrad_ = nullptr; - CHECK(outVal_); - if (outputIsOnlyMKLDNN()) { - MKLDNNLayer::resetOutGrad(out, outVal_->getPrimitiveDesc()); - } else { - const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).grad; - // always share the same grad data of CPU output - // then the activation can get the right grad from output_.grad - output_.grad->setData(cpuOut->getData()); - cpuOutGrad_ = MKLDNNMatrix::create( - cpuOut, memory::dims{bs_, oc_, oh_, ow_}, format::nchw, engine_); - if (cpuOutGrad_->getPrimitiveDesc() != outVal_->getPrimitiveDesc()) { - out = MKLDNNMatrix::create(nullptr, outVal_->getPrimitiveDesc()); - cvtOutGrad_ = MKLDNNMatrix::createReorder(cpuOutGrad_, out); - CHECK(cvtOutGrad_) << "should not be emptry"; - } else { - out = cpuOutGrad_; - } - } -} - -void MKLDNNPoolLayer::resetInGrad(MKLDNNMatrixPtr& in) { - in = nullptr; - if (inputLayers_[0]->getOutput().grad == nullptr) { - return; - } - CHECK(inVal_); - MKLDNNLayer::resetInGrad(in, inVal_->getPrimitiveDesc()); + CHECK(inVal_ && outVal_); + resetOutGrad(out, outVal_->getPrimitiveDesc()); + resetInGrad(in, inVal_->getPrimitiveDesc()); } void MKLDNNPoolLayer::resetBwdPD(std::shared_ptr& pd, MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out) { + pd = nullptr; + if (in == nullptr) { + return; + } memory::dims kernels = memory::dims{fh_, fw_}; memory::dims strides = memory::dims{sh_, sw_}; memory::dims padL = memory::dims{ph_, pw_}; memory::dims padR = getPaddingR(); - CHECK(in); CHECK(out); auto bwdDesc = pool_bwd::desc(poolAlgo_, in->getMemoryDesc(), @@ -263,8 +186,8 @@ void MKLDNNPoolLayer::resetBwdPipeline( std::shared_ptr& pd, MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out) { - if (cvtOutGrad_) { - pipeline.push_back(*cvtOutGrad_); + if (pd == nullptr) { + return; } bwdData_ = diff --git a/paddle/gserver/layers/MKLDNNPoolLayer.h b/paddle/gserver/layers/MKLDNNPoolLayer.h index 891e15a7efcdd2e54f61352efc1ba7345b91c76b..c5ec87828bfb28b4502b4ec6b47287089c514204 100644 --- a/paddle/gserver/layers/MKLDNNPoolLayer.h +++ b/paddle/gserver/layers/MKLDNNPoolLayer.h @@ -38,13 +38,6 @@ protected: // pooling_avg or pooling_max mkldnn::algorithm poolAlgo_; - // MKLDNNMatrixPtr which should be created from CPU Device - MKLDNNMatrixPtr cpuOutVal_; - MKLDNNMatrixPtr cpuOutGrad_; - // convert handle between CPU device and MKLDNN device - std::shared_ptr cvtOutVal_; - std::shared_ptr cvtOutGrad_; - // save forward primitive_desc, which can be used backward std::shared_ptr fwdPD_; // according to https://github.com/01org/mkl-dnn/blob/master/tests/gtests/ @@ -74,8 +67,6 @@ public: MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& out) override; - void updateInputData() override; - void printSizeInfo() override { MKLDNNLayer::printSizeInfo(); VLOG(MKLDNN_SIZES) << getName() << ": fh: " << fh_ << ", fw: " << fw_ @@ -90,8 +81,6 @@ protected: * reset pipeline. */ void resetFwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out); - void resetInValue(MKLDNNMatrixPtr& in); - void resetOutValue(MKLDNNMatrixPtr& out); void resetFwdPD(std::shared_ptr& pd, MKLDNNMatrixPtr in, MKLDNNMatrixPtr out); @@ -106,8 +95,6 @@ protected: * reset pipeline. 
*/ void resetBwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out); - void resetOutGrad(MKLDNNMatrixPtr& out); - void resetInGrad(MKLDNNMatrixPtr& in); void resetBwdPD(std::shared_ptr& pd, MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out); diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp index 3bf6a9e176cc1235aa5ddefcedd4253e6afc1342..0a19fe23336ea943cb8a572dc40f8c0fbbd7236a 100644 --- a/paddle/gserver/tests/MKLDNNTester.cpp +++ b/paddle/gserver/tests/MKLDNNTester.cpp @@ -97,7 +97,7 @@ void MKLDNNTester::randomWgtDatas() { parameters_[REF][i]->randomize(); dnnValue->copyFrom(*refValue); - VLOG(lvl_) << "Random weight data " << parameters_[DNN][i]->getName(); + VLOG(MKLDNN_TESTS) << "Random weight " << parameters_[DNN][i]->getName(); printVector(dnnValue); } } @@ -109,7 +109,7 @@ void MKLDNNTester::randomBotDatas() { dataLayers_[REF][i]->getOutputValue()->randomizeUniform(); dataLayers_[DNN][i]->getOutputValue()->copyFrom( *(dataLayers_[REF][i]->getOutputValue())); - VLOG(lvl_) << "Input " << i << " data:"; + VLOG(MKLDNN_TESTS) << "Random Forward, InputValue " << i; printMatrix(dataLayers_[REF][i]->getOutputValue()); } } @@ -118,12 +118,12 @@ void MKLDNNTester::randomTopDiffs() { refLayer_->getOutputGrad()->randomizeUniform(); dnnLayer_->getOutput(CPU_DEVICE) .grad->copyFrom(*(refLayer_->getOutputGrad())); - VLOG(lvl_) << "Random Backward Input, TopDiff: "; + VLOG(MKLDNN_TESTS) << "Random Backward, OutputGrad"; printMatrix(refLayer_->getOutputGrad()); } void MKLDNNTester::checkForward() { - VLOG(MKLDNN_ALL) << "Check Forward"; + VLOG(MKLDNN_TESTS) << "Check Forward"; printTopDatas(); double delta = compareMatrix(dnnLayer_->getOutputValue(), refLayer_->getOutputValue()); @@ -131,15 +131,15 @@ } void MKLDNNTester::checkBackwardData() { - VLOG(MKLDNN_ALL) << "Check Backward Data"; + VLOG(MKLDNN_TESTS) << "Check Backward Data"; // TODO(TJ): uncomment me when batch norm ready // const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm"; for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) { const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad(); const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad(); - VLOG(lvl_) << "Mkldnn Backward Output BotDiff " << i; + VLOG(MKLDNN_ALL) << "MKLDNN Backward Result: InputGrad " << i; printMatrix(dnnDiff); - VLOG(lvl_) << "Reference Backward Output BotDiff " << i; + VLOG(MKLDNN_ALL) << "Reference Backward Result: InputGrad " << i; printMatrix(refDiff); double delta = compareMatrix(dnnDiff, refDiff); @@ -153,7 +153,7 @@ } void MKLDNNTester::checkBackwardWgts() { - VLOG(MKLDNN_ALL) << "Check Backward Weight"; + VLOG(MKLDNN_TESTS) << "Check Backward Weight"; CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size()); vector dnnWgts; // used to temply save mkldnn weights saveWgt(parameters_[DNN], dnnWgts); @@ -165,9 +165,11 @@ for (size_t i = 0; i < parameters_[DNN].size(); ++i) { const VectorPtr& dnn = parameters_[DNN][i]->getBuf(PARAMETER_VALUE); const VectorPtr& ref = parameters_[REF][i]->getBuf(PARAMETER_VALUE); - VLOG(lvl_) << "Mkldnn Output weight " << parameters_[DNN][i]->getName(); + VLOG(MKLDNN_ALL) << "MKLDNN Result: weight value " + << parameters_[DNN][i]->getName(); printVector(dnn); - VLOG(lvl_) << "Reference Output weight " << parameters_[REF][i]->getName(); + VLOG(MKLDNN_ALL) << "Reference Result: weight value " + << parameters_[REF][i]->getName(); printVector(ref); double delta =
compareVector(dnn, ref); @@ -240,7 +242,8 @@ void MKLDNNTester::printTopDatas() { } for (int n = 0; n < NUM; ++n) { - VLOG(lvl_) << testLayers_[n]->getType() << " forward output TopData: "; + VLOG(MKLDNN_ALL) << testLayers_[n]->getType() + << " Forward Result: OutputValue"; printMatrix(testLayers_[n]->getOutputValue()); } } @@ -252,7 +255,7 @@ void MKLDNNTester::printMatrix(const MatrixPtr& m) { std::ostringstream ostr; m->print(ostr); - VLOG(lvl_) << std::endl << ostr.str(); + VLOG(MKLDNN_ALL) << std::endl << ostr.str(); } void MKLDNNTester::printVector(const VectorPtr& v) { @@ -262,7 +265,7 @@ std::ostringstream ostr; v->print(ostr, v->getSize()); - VLOG(lvl_) << std::endl << ostr.str(); + VLOG(MKLDNN_ALL) << std::endl << ostr.str(); } double MKLDNNTester::getDelta(const real* d1, @@ -314,7 +317,7 @@ void MKLDNNTester::runOnce() { UpdateCallback updateCallback = [](Parameter* para) { auto& grad = para->getBuf(PARAMETER_GRADIENT); auto& value = para->getBuf(PARAMETER_VALUE); - real lr = 1e-3; + real lr = 1e-2; value->add(*grad, lr); grad->zeroMem(); }; @@ -340,10 +343,9 @@ void MKLDNNTester::run(const TestConfig& dnn, size_t batchSize, size_t inputImgH, size_t inputImgW, + bool printDetails, size_t iter, - float epsilon, - bool log, - int level) { + float epsilon) { CHECK(dnn.layerConfig.type().compare(0, 7, "mkldnn_") == 0 || dnn.layerConfig.active_type().compare(0, 7, "mkldnn_") == 0) << "should be MKLDNN layer or MKLDNN activation"; @@ -359,10 +361,9 @@ ih_ = inputImgH; iw_ = inputImgW; + log_ = printDetails; iter_ = iter; eps_ = epsilon; - log_ = log; - lvl_ = level; // Firstly test mkldnn init from PARAM_FORMAT_ORIGINAL weight reset(dnn, ref, batchSize); @@ -531,9 +532,11 @@ void MKLDNNTester::getOutResult(const std::string& configPath, void MKLDNNTester::compareResult(DataOut& ref, DataOut& dnn, float eps) { CHECK_EQ(ref.outValues.size(), dnn.outValues.size()); CHECK_EQ(ref.paraValues.size(), dnn.paraValues.size()); + VLOG(MKLDNN_TESTS) << "compare value size: " << ref.outValues.size(); for (size_t i = 0; i < ref.outValues.size(); i++) { EXPECT_LE(fabs(compareMatrix(ref.outValues[i], dnn.outValues[i])), eps); } + VLOG(MKLDNN_TESTS) << "compare param size: " << ref.paraValues.size(); for (size_t i = 0; i < ref.paraValues.size(); i++) { EXPECT_LE(fabs(compareVector(ref.paraValues[i], dnn.paraValues[i])), eps); } @@ -544,9 +547,10 @@ void MKLDNNTester::runBranchesTest(const std::string& configPath, float eps) { DataIn in; initArgument(in, configPath, iter); - DataOut outCpu, outDnn; + VLOG(MKLDNN_TESTS) << "running cpu network"; getOutResult(configPath, in, outCpu, false, iter); + VLOG(MKLDNN_TESTS) << "running mkldnn network"; getOutResult(configPath, in, outDnn, true, iter); compareResult(outCpu, outDnn, eps); diff --git a/paddle/gserver/tests/MKLDNNTester.h b/paddle/gserver/tests/MKLDNNTester.h index 51abfcb67e2ec35fe1b0179e742a7d18f08f8a2c..c385d1c72717d120211f167b5c5eb9a557da3714 100644 --- a/paddle/gserver/tests/MKLDNNTester.h +++ b/paddle/gserver/tests/MKLDNNTester.h @@ -58,8 +58,6 @@ protected: size_t iter_; /// whether to print out the details bool log_; - /// vlog level to print the matrix details datas - int lvl_; /// epsilon float eps_; /// input image size, default 1 @@ -70,7 +68,6 @@ public: iter_ = iter; eps_ = epsilon; log_ = false; - lvl_ = MKLDNN_ALL; } ~MKLDNNTester() {} @@ -81,10 +78,9 @@ size_t batchSize, size_t inputImgH = 1, size_t inputImgW = 1, + bool 
printDetails = false, size_t iter = 3, - float epsilon = 1e-4, - bool log = false, - int level = MKLDNN_ALL); + float epsilon = 1e-4); static void runBranchesTest(const std::string& configPath, size_t iter = 3, float eps = 1e-4); diff --git a/paddle/gserver/tests/test_PyDataProvider2.py b/paddle/gserver/tests/test_PyDataProvider2.py index 2e6225519f4681238f4b40fb33764ead4a16b24a..0d0fe476ff5eac8bf8ad1c9fe09b32c1a8f73ebc 100644 --- a/paddle/gserver/tests/test_PyDataProvider2.py +++ b/paddle/gserver/tests/test_PyDataProvider2.py @@ -51,7 +51,10 @@ def test_sparse_non_value_no_seq(setting, filename): yield [(i + 1) * (j + 1) for j in xrange(10)] -@provider(input_types=[sparse_vector(30000, seq_type=SequenceType.NO_SEQUENCE)]) +@provider(input_types=[ + sparse_float_vector( + 30000, seq_type=SequenceType.NO_SEQUENCE) +]) def test_sparse_value_no_seq(setting, filename): for i in xrange(200): yield [((i + 1) * (j + 1), float(j) / float(i + 1)) for j in xrange(10)] diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp index 0778bb63b7b3bca9b3d2647ca43dad72d783950a..21a8f73c3e650d4b3c3b86247594cd965f4ead35 100644 --- a/paddle/math/MKLDNNMatrix.cpp +++ b/paddle/math/MKLDNNMatrix.cpp @@ -18,7 +18,7 @@ using namespace mkldnn; // NOLINT namespace paddle { -MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) { +MKLDNNMatrixPtr MKLDNNMatrix::create(memory::primitive_desc pd, MatrixPtr m) { memory::desc md = pd.desc(); size_t ndims = md.data.ndims; int* dims = md.data.dims; @@ -41,12 +41,12 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) { return std::make_shared(cpuMatrix, pd); } -MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, - memory::dims dims, +MKLDNNMatrixPtr MKLDNNMatrix::create(memory::dims dims, memory::format fmt, engine& eg, + MatrixPtr m, mkldnn::memory::data_type dtype) { - return create(m, memory::primitive_desc(memory::desc(dims, dtype, fmt), eg)); + return create(createPrimitiveDesc(dims, fmt, eg, dtype), m); } std::shared_ptr MKLDNNMatrix::createReorder(const MKLDNNMatrixPtr& src, diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h index c843115eb9a5be50d6ff873f1510844228c9d89f..fe755d096da9713e39581a909e5d21aa93d69f0f 100644 --- a/paddle/math/MKLDNNMatrix.h +++ b/paddle/math/MKLDNNMatrix.h @@ -40,24 +40,37 @@ public: /** * Create MKLDNNMatrix from a MatrixPtr and memory primitive_desc */ - static MKLDNNMatrixPtr create(MatrixPtr m, mkldnn::memory::primitive_desc pd); + static MKLDNNMatrixPtr create(mkldnn::memory::primitive_desc pd, + MatrixPtr m = nullptr); /** * Create MKLDNNMatrix from a MatrixPtr and memory details info */ static MKLDNNMatrixPtr create( - MatrixPtr m, mkldnn::memory::dims dims, mkldnn::memory::format fmt, mkldnn::engine& eg, + MatrixPtr m = nullptr, mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32); + /** + * Create primitive descriptor. + * default with f32 dtype + */ + static mkldnn::memory::primitive_desc createPrimitiveDesc( + const mkldnn::memory::dims dims, + const mkldnn::memory::format& fmt, + const mkldnn::engine& eg, + const mkldnn::memory::data_type& dtype = mkldnn::memory::data_type::f32) { + return mkldnn::memory::primitive_desc(memory::desc(dims, dtype, fmt), eg); + } + /** * Create Memory descriptor. 
* default with any format and f32 dtype */ static mkldnn::memory::desc createMemoryDesc( - const mkldnn::memory::dims& dims, + const mkldnn::memory::dims dims, const mkldnn::memory::format& fmt = mkldnn::memory::format::any, const mkldnn::memory::data_type& dtype = mkldnn::memory::data_type::f32) { return mkldnn::memory::desc(dims, dtype, fmt); diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 1919d86c33b6e89d5edf78a1f0caa6403550601a..c9a93cd653d0fade7e878af9adc9e33ba2d1c95b 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -116,7 +116,8 @@ set(DEPS_OPS sum_op pool_op pool_with_index_op - sequence_conv_op) + sequence_conv_op + lstm_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc @@ -128,7 +129,7 @@ op_library(sum_op DEPS net_op) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) op_library(sequence_conv_op DEPS sequence_project) - +op_library(lstm_op DEPS sequence2batch lstm_compute) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc index 2d029394dd97a9c33c9c57fd3565345139cdff92..f80204c6833d6436f2cf21610beea45b36787eea 100644 --- a/paddle/operators/clip_op.cc +++ b/paddle/operators/clip_op.cc @@ -27,8 +27,8 @@ class ClipOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of ClipOp should not be null."); auto x_dims = ctx->GetInputDim("X"); - auto max = Attr("max"); - auto min = Attr("min"); + auto max = ctx->Attrs().Get("max"); + auto min = ctx->Attrs().Get("min"); PADDLE_ENFORCE_LT(min, max, "max should be greater than min."); ctx->SetOutputDim("Out", x_dims); ctx->ShareLoD("X", /*->*/ "Out"); diff --git a/paddle/operators/conv2d_op.h b/paddle/operators/conv2d_op.h index f629728f68d65ce81b4910cae7f89ab06d6d94b8..0621389a79eee6b5e75b1eab309b49f8aa4a97ca 100644 --- a/paddle/operators/conv2d_op.h +++ b/paddle/operators/conv2d_op.h @@ -114,7 +114,7 @@ class GemmConv2DKernel : public framework::OpKernel { // im2col Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(context.device_context(), in_slice, col, strides[0], strides[1], - paddings[0], paddings[1]); + paddings[0], paddings[0], paddings[1], paddings[1]); // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); @@ -213,7 +213,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { Tensor in_grad_slice = in_grad_batch.Slice(g * in_step, (g + 1) * in_step); col2im(context.device_context(), in_grad_slice, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); } } } @@ -235,7 +236,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(context.device_context(), in_slice, col, strides[0], - strides[1], paddings[0], paddings[1]); + strides[1], paddings[0], paddings[0], paddings[1], + paddings[1]); // gemm Tensor filter_grad_slice = diff --git a/paddle/operators/conv2dtranspose_op.cc b/paddle/operators/conv2dtranspose_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1b231906e2f172b6f9cee55f850d1a5ec6c3221 --- /dev/null +++ b/paddle/operators/conv2dtranspose_op.cc @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv2dtranspose_op.h" + +namespace paddle { +namespace operators { + +void Conv2DTransposeOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of Conv2DTransposeOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Filter"), + "Input(Filter) of Conv2DTransposeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output(Output) of Conv2DTransposeOp should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + + for (size_t i = 0; i < paddings.size(); ++i) { + PADDLE_ENFORCE_EQ(paddings[i], 0, + "No padding is allowed in the conv transpose op."); + } + + PADDLE_ENFORCE_EQ(in_dims.size(), 4, + "Conv2DTransposeOp input should be a 4-D tensor."); + PADDLE_ENFORCE_EQ(filter_dims.size(), 4, + "Conv2DTransposeOp filter should be a 4-D tensor."); + PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], + "The input and filter input channel dimensions should be equal."); + + auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2]; + auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3]; + ctx->SetOutputDim("Output", + {in_dims[0], filter_dims[1], output_height, output_width}); +} + +Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( + framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "Input", + "(Tensor) The input tensor of the convolution transpose operator. " + "The format of the input tensor is NCHW, where N is the batch size, " + "C is the number of input channels, and H and W are the height and " + "width of the image."); + AddInput("Filter", + "(Tensor) The filter tensor of the convolution transpose operator. " + "The format of the filter tensor is CMHW, where C is the number of " + "output image channels, M is the number of input image channels, and " + "H and W are the height and width of the filter. " + "We enforce groups number == 1 and padding == 0 in " + "the convolution transpose scenario."); + AddOutput("Output", + "(Tensor) The output tensor of the convolution transpose operator. " + "The format of the output tensor is also NCHW."); + AddAttr>("strides", + "strides of the convolution transpose operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", + "paddings of the convolution transpose operator.") + .SetDefault({0, 0}); + AddComment(R"DOC( +The convolution transpose operation calculates the output based on the input, +the filter, and the strides, paddings and groups parameters. The size of each +dimension of the parameters is checked during shape inference.
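+ +For example, given an NCHW input of shape (1 x 3 x 4 x 4), a CMHW filter of +shape (3 x 8 x 3 x 3) and strides {2, 2}, the inferred output shape is +(1 x 8 x 9 x 9), since (4 - 1) * 2 + 3 = 9 along each spatial dimension.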
+)DOC"); +} + +void Conv2DTransposeOpGrad::InferShape( + framework::InferShapeContext* ctx) const { + auto in_dims = ctx->GetInputDim("Input"); + auto filter_dims = ctx->GetInputDim("Filter"); + if (ctx->HasOutput(framework::GradVarName("Input"))) { + ctx->SetOutputDim(framework::GradVarName("Input"), in_dims); + } + if (ctx->HasOutput(framework::GradVarName("Filter"))) { + ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims); + } +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(conv2dtranspose, ops::Conv2DTransposeOp, + ops::Conv2DTransposeOpMaker, conv2dtranspose_grad, + ops::Conv2DTransposeOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv2dtranspose, + ops::GemmConv2DTransposeKernel); +REGISTER_OP_CPU_KERNEL( + conv2dtranspose_grad, + ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.cu b/paddle/operators/conv2dtranspose_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..761bc1959e69be94f43571728e6b92a322558b99 --- /dev/null +++ b/paddle/operators/conv2dtranspose_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/conv2dtranspose_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + conv2dtranspose, + ops::GemmConv2DTransposeKernel); +REGISTER_OP_GPU_KERNEL( + conv2dtranspose_grad, + ops::GemmConv2DTransposeGradKernel); diff --git a/paddle/operators/conv2dtranspose_op.h b/paddle/operators/conv2dtranspose_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8c70b3dcec1e26ab3d8a42d88040764c643b5ae6 --- /dev/null +++ b/paddle/operators/conv2dtranspose_op.h @@ -0,0 +1,254 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/im2col.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +// Define Op classes in .h file so that other conv transpose +// operator implementations can reuse the code. 
+class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Conv2DTransposeOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker); +}; + +class Conv2DTransposeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +class Conv2DTransposeOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override; +}; + +template +class GemmConv2DTransposeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + // The filter will be reshaped, so it should not be a constant pointer. + Tensor filter = *context.Input("Filter"); + + Tensor* output = context.Output("Output"); + + std::vector strides = context.Attr>("strides"); + + // TODO(Zhuoyuan): Paddings can be added in the future. + // groups will always be disabled in conv2dtranspose. + + const int batch_size = input->dims()[0]; + const int m = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + + const int k_h = filter.dims()[2]; + const int k_w = filter.dims()[3]; + + const int c = output->dims()[1]; // output channels + const int o_h = output->dims()[2]; + const int o_w = output->dims()[3]; + + paddle::operators::math::Col2ImFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + col2im; + + // use col_shape in the im2col and col2im calculation + DDim col_shape = {c, k_h, k_w, h, w}; + + // use col_matrix_shape in the gemm calculation + DDim col_matrix_shape = {c * k_h * k_w, h * w}; + + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface.
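+ // With the shapes above, col is a 5-D tensor (c, k_h, k_w, h, w) and
+ // col_matrix views the same memory as a (c * k_h * k_w, h * w) matrix,
+ // which is exactly what the gemm in the batch loop below computes:
+ // filter^T (c * k_h * k_w, m) * input_batch (m, h * w).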
+ Tensor col_matrix; + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + + DDim output_shape = {c, o_h, o_w}; + DDim input_matrix_shape = {m, h * w}; + + DDim filter_matrix_shape = {m, c * k_h * k_w}; + filter.Resize(filter_matrix_shape); + + // convolution transpose: gemm + col2im (similar to conv-backward on input) + + output->mutable_data(context.GetPlace()); + auto t = framework::EigenVector::Flatten(*output); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + // batch with size (M, h * w) + Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); + // filter size: (M, c * k_h * k_w) + + // output size: (c, o_h, o_w) + Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape); + + // col_matrix = filter * input_batch + // of shape (c * k_h * k_w, h * w) + math::matmul(context.device_context(), filter, true, + input_batch, false, T(1.0), &col_matrix, T(0.0)); + col2im(context.device_context(), output_batch, col, strides[0], + strides[1], 0, 0, 0, 0); + } + } +}; + +template +class GemmConv2DTransposeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("Input"); + const Tensor* output_grad = + context.Input(framework::GradVarName("Output")); + + // For filter, we do not use const pointer b/c we will do reshape, + // but we should avoid modifying its value. + Tensor filter = *context.Input("Filter"); + + Tensor* input_grad = + context.Output(framework::GradVarName("Input")); + Tensor* filter_grad = + context.Output(framework::GradVarName("Filter")); + + std::vector strides = context.Attr>("strides"); + // Actually, no paddings and groups allowed in conv transpose. + std::vector paddings = context.Attr>("paddings"); + + const int batch_size = input->dims()[0]; + const int m = input->dims()[1]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + + const int k_h = filter.dims()[2]; + const int k_w = filter.dims()[3]; + + const int c = output_grad->dims()[1]; // output channels + const int o_h = output_grad->dims()[2]; + const int o_w = output_grad->dims()[3]; + + // Only im2col functor required for bp to get to the right shape + paddle::operators::math::Im2ColFunctor< + paddle::operators::math::ColFormat::kCFO, Place, T> + im2col; + + // use col_shape in the im2col and col2im calculation + DDim col_shape = {c, k_h, k_w, h, w}; + + // use col_matrix_shape in the gemm calculation + DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; + + Tensor col; + col.mutable_data(col_shape, context.GetPlace()); + // col_matrix shares the same piece of data with col, + // but will be reshaped into a two-dimensional matrix shape + // to call the matrix multiplication interface. 
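+ // In this gradient kernel the 2-D view of col is created lazily below:
+ // the same buffer is viewed as (c * k_h * k_w, h * w) when computing the
+ // input gradient, and as (c * h * w, k_h * k_w) when computing the
+ // filter gradient.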
+ + DDim output_shape = {c, o_h, o_w}; + DDim input_matrix_shape = {m, h * w}; + + DDim filter_matrix_shape = {m, c * k_h * k_w}; + filter.Resize(filter_matrix_shape); + + // convolution transpose grad on input: + // im2col + gemm (similar to conv-forward) + // input need to compute gradient + if (input_grad) { + Tensor col_matrix; + col_matrix.ShareDataWith(col); + DDim col_matrix_shape = {c * k_h * k_w, h * w}; + col_matrix.Resize(col_matrix_shape); + + input_grad->mutable_data(context.GetPlace()); + auto t = framework::EigenVector::Flatten(*input_grad); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; i++) { + // batch with size (c, o_h * o_w) + Tensor output_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_shape); + // filter of size (m, c * k_h * k_w) + + // batch with size (m, h, w) + Tensor input_grad_batch = + input_grad->Slice(i, i + 1).Resize(input_matrix_shape); + + // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w) + im2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); + + // gemm: dx = filter * dy + // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h) + math::matmul(context.device_context(), filter, false, + col_matrix, false, T(1.0), &input_grad_batch, + T(0.0)); + } + } + + // filter gradient required + if (filter_grad) { + Tensor col_matrix_f; + col_matrix_f.ShareDataWith(col); + DDim col_matrix_shape_f = {c * h * w, k_h * k_w}; + col_matrix_f.Resize(col_matrix_shape_f); + + filter_grad->mutable_data(context.GetPlace()); + Tensor filter_grad_ = *filter_grad; + filter_grad_.Resize(filter_matrix_shape); + auto t = framework::EigenVector::Flatten(filter_grad_); + t.device(context.GetEigenDevice()) = t.constant(static_cast(0)); + + for (int i = 0; i < batch_size; ++i) { + // batch with size (c, o_h, o_w) + Tensor output_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_shape); + // input batch + Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); + + // im2col: (c * h * w, k_h * k_w) + im2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], paddings[0], paddings[0], paddings[1], paddings[1]); + + // gemm: d_filter = x * y_grad^T + // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h) + math::matmul(context.device_context(), in_batch, false, + col_matrix_f, true, T(1.0), &filter_grad_, + T(1.0)); + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc index a994d916761da3b44cb60bd4c1c767cd1987522f..ed78e9e3a3a49b7ff0990b8d13cfe2dae594b722 100644 --- a/paddle/operators/crop_op.cc +++ b/paddle/operators/crop_op.cc @@ -59,7 +59,8 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker { "The input should be a k-D tensor(k > 0 and k < 7)"); AddInput("Y", "The input used as reference for cropping" - " with the same dimension as X. "); + " with the same dimension as X. ") + .AsDispensable(); AddOutput("Out", "The output of crop op " "with the same dimension as X."); diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc deleted file mode 100644 index 7c422c81fc479fa2e317bdee1b66017096381d27..0000000000000000000000000000000000000000 --- a/paddle/operators/fc_op.cc +++ /dev/null @@ -1,200 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/framework/op_registry.h" -#include "paddle/operators/net_op.h" - -namespace paddle { -namespace operators { - -class FCOp : public NetOp { - public: - FCOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : NetOp(type, inputs, outputs, attrs) { - PADDLE_ENFORCE(!Inputs("X").empty(), - "Inputs(X) of FCOp should not be null."); - PADDLE_ENFORCE(!Inputs("W").empty(), - "Inputs(W) of FCOp should not be null."); - PADDLE_ENFORCE(!Outputs("MulOut").empty(), - "Outputs(MulOut) of FCOp should not be null."); - PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, - "Output(Out) of FCOp should not be null."); - - auto x = Inputs("X"); - auto w = Inputs("W"); - auto mul_out = Outputs("MulOut"); - PADDLE_ENFORCE_EQ( - x.size(), w.size(), - "The size of inputs X(%d) should be the same as that of weights W(%d).", - x.size(), w.size()); - PADDLE_ENFORCE_EQ(mul_out.size(), x.size(), - "The size of intermediate mul_out(%d) should be the same " - "as that of inputs X(%d).", - mul_out.size(), x.size()); - - size_t n = x.size(); - PADDLE_ENFORCE_GE(n, static_cast(1), - "The size of inputs X(%d) should be no less than 1.", n); - - auto x_num_col_dims = Attr>("xNumColDims"); - - // Set all values or set no values (use the default value) - if (!x_num_col_dims.empty()) { - PADDLE_ENFORCE_EQ(x_num_col_dims.size(), n, - "The size of attribute xNumColDims(%d) should be the " - "same as that of inputs X(%d).", - x_num_col_dims.size(), n); - } else { - x_num_col_dims.resize(n); - for (size_t i = 0; i < n; i++) { - x_num_col_dims[i] = 1; - } - } - - // mul_out[i] = X[i] * W[i] - for (size_t i = 0; i < n; i++) { - framework::AttributeMap mul_attr; - mul_attr["x_num_col_dims"] = static_cast(x_num_col_dims[i]); - mul_attr["y_num_col_dims"] = static_cast(1); - AppendOp( - framework::OpRegistry::CreateOp("mul", {{"X", {x[i]}}, {"Y", {w[i]}}}, - {{"Out", {mul_out[i]}}}, mul_attr)); - } - - // sum_out = X[0] * W[0] + ... 
+ X[n-1] * W[n-1] - auto sum_out = mul_out[0]; - if (n > 1) { - PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName, - "Output(SumOut) of FCOp should not be null when the " - "size of Inputs(X) > 1."); - - sum_out = Output("SumOut"); - AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}}, - {{"Out", {sum_out}}}, {})); - } else { - if (Output("SumOut") != framework::kEmptyVarName) { - this->Rename(Output("SumOut"), framework::kEmptyVarName); - } - } - - // add_out = sum_out + b - auto b = Input("B"); - auto add_out = sum_out; - if (b != framework::kEmptyVarName) { - PADDLE_ENFORCE_NE( - Output("AddOut"), framework::kEmptyVarName, - "Output(AddOut) of FCOp should not be null when Input(B) is set."); - - add_out = Output("AddOut"); - AppendOp(framework::OpRegistry::CreateOp( - "elementwise_add", {{"X", {sum_out}}, {"Y", {Input("B")}}}, - {{"Out", {add_out}}}, {})); - } else { - if (Output("AddOut") != framework::kEmptyVarName) { - this->Rename(Output("AddOut"), framework::kEmptyVarName); - } - } - - auto activation = Attr("activation"); - AppendOp(framework::OpRegistry::CreateOp(activation, {{"X", {add_out}}}, - {{"Y", {Output("Out")}}}, {})); - CompleteAddOp(false); - } -}; - -class FCOpMaker : public framework::OpProtoAndCheckerMaker { - public: - FCOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", - "(A vector of Tensors) each input Tensor can be of arbitrary " - "dimension, and will be reshaped to a 2-D matrix of size " - "(minibatch, number_of_input_features) according to attribute " - "xNumColDims.") - .AsDuplicable(); - AddInput("W", - "(A vector of Tensors) the weights of FC operator, a " - "vector of 2-D matrix of size " - "(number_of_input_features, number_of_neurons).") - .AsDuplicable(); - AddInput("B", - "(Tensor) the bias of FC operator, a 1-D vector of size " - "number_of_neurons."); - - AddOutput("Out", - "(Tensor) the activated output matrix of FC operator, a 2-D " - "matrix of size (minibatch, number_of_neurons)."); - AddOutput("MulOut", - "(A vector of Tensors) the intermediate outputs of FC operator, " - "each Tensor saving the product of X_i * W_i.") - .AsIntermediate() - .AsDuplicable(); - AddOutput( - "SumOut", - "(Tensor) the intermediate output of FC operator, " - "saving the sum of the products of X and W, that is sum{X_i * W_i}.") - .AsIntermediate(); - AddOutput("AddOut", - "(Tensor) the non-actived output of FC operator, " - "saving sum{X_i * W_i} + B.") - .AsIntermediate(); - AddAttr( - "activation", - "(string, default identity) the activation type of FC operator.") - .SetDefault("identity") - .InEnum({"identity", "sigmoid", "softmax"}); - AddAttr>( - "xNumColDims", - "(std::vector) The inputs Tensors of FC operator can be of " - "more than 2 dimensions. In that case, each input Tensor `X_i` will be " - "reshaped to a 2-D matrix. The matrix's first dimension " - "(the length of column) will be the product of `X_i`'s last " - "`xNumColDims_i` dimensions, that is " - "`X_i.dims[0] x ... x X_i.dims[xNumColDims_i - 1]`. " - "The matrix's second dimension (the length of row) will be the product " - "of `X_i`'s first `rank - xNumColDims_i` dimensions, that is " - "`X_i.dims[xNumColDims_i] x ... x X_i.dims[rank - 1]`)") - .SetDefault(std::vector{}); - - AddComment(R"DOC( -Fully Connected Operator, known as Fully Connected Layer or Inner Product Layer -in Convolutional Neural Networks. 
Neurons in a fully connected layer have -full connections to all activations in the previous layer. -It computes an inner product of a set of -learned weights with a matrix multiplication followed by a bias offset -(optionally). - -Equation: - Out = Act(sum_n{X_i * W_i} + B) - -where X_i is Tensor that will be reshaped to a 2-D matrix of size (M x K), -usually M is the minibatch size and K is the number of input features. -W_i is a 2-D matrix of size (K x N), where N means the number of neurons -in the fully connected layer. B is a 1-D vector of size N. -Thus, the output Out is a 2-D matrix of size (M x N). -Activation type can be set to `identity` (default), `sigmoid` or `softmax`. - -All the inputs can carry the LoD (Level of Details) information, -or not. But the output only shares the LoD with first input (`X[0]`). -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(fc, ops::FCOp, ops::FCOpMaker); diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index f59f497d9f32069b764a9f777c7e9d6da9cdb108..04dfdf7c48381240108cf924979764966599151f 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -59,7 +59,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { protected: framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return static_cast(Attr("data_type")); + return static_cast(ctx.Attr("data_type")); } }; diff --git a/paddle/operators/gru_unit_op.cc b/paddle/operators/gru_unit_op.cc index 72dd841c85ce9934a57263d10c366e679693c471..a596f93769780419d27b7c0b40631d3da78e6700 100644 --- a/paddle/operators/gru_unit_op.cc +++ b/paddle/operators/gru_unit_op.cc @@ -54,8 +54,7 @@ class GRUUnitOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( weight_width, frame_size * 3, "The shape of Weight matrix must be [frame_size, frame_size * 3]."); - auto bias = Input("Bias"); - if (bias != framework::kEmptyVarName) { + if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); int bias_height = bias_dims[0]; int bias_width = bias_dims[1]; @@ -89,7 +88,8 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker { "weights of output candidate with shape [frame_size, frame_size]"); AddInput("Bias", "(Tensor) Bias vector with shape [1, frame_size * 3] concating " - "bias of the update gate, reset gate and output candidate."); + "bias of the update gate, reset gate and output candidate.") + .AsDispensable(); AddOutput("Gate", "(Tensor) Matrix with shape [batch_size, frame_size * 3] for the " "output of update gate, reset gate and output candidate") diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc deleted file mode 100644 index 2cc632205e63abbe412b09af4b894420ac512ec5..0000000000000000000000000000000000000000 --- a/paddle/operators/identity_op.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/operators/net_op.h" -#include "paddle/operators/scale_op.h" - -namespace paddle { -namespace operators { - -// The identity operator is an alias of the scale operator. This is also an -// example for creating an alias for an existing operator. -template -class IdentityOpMaker : public framework::OpProtoAndCheckerMaker { - public: - IdentityOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The input tensor of identity operator."); - AddOutput("Y", "The output tensor of identity operator."); - AddComment(R"DOC( -The identity operator is an alias of the scale operator -with the attribute scale fixed to 1.0. -)DOC"); - } -}; - -template -class IdentityOp : public NetOp { - public: - IdentityOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : NetOp(type, inputs, outputs, attrs) { - PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName, - "Input(X) of IdentityOp should not be null."); - PADDLE_ENFORCE_NE(Output("Y"), framework::kEmptyVarName, - "Output(Y) of IdentityOp should not be null."); - - AppendOp(framework::OpRegistry::CreateOp( - "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Y")}}}, - {{"scale", static_cast(1)}})); - CompleteAddOp(false); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_WITHOUT_GRADIENT(identity, ops::IdentityOp, - ops::IdentityOpMaker); diff --git a/paddle/operators/interp_op.cc b/paddle/operators/interp_op.cc deleted file mode 100644 index d02b01c3f3a1b30ec27253140203b076a98ce0c2..0000000000000000000000000000000000000000 --- a/paddle/operators/interp_op.cc +++ /dev/null @@ -1,113 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/framework/op_registry.h" -#include "paddle/operators/net_op.h" - -namespace paddle { -namespace operators { - -class InterpOp : public NetOp { - public: - InterpOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : NetOp(type, inputs, outputs, attrs) { - PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName, - "Input(X) of InterpOp should not be null."); - PADDLE_ENFORCE_NE(Input("Y"), framework::kEmptyVarName, - "Input(Y) of InterpOp should not be null."); - PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName, - "Input(W) of InterpOp should not be null."); - PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName, - "Output(SubOut) of InterpOp should not be null."); - PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName, - "Output(MulOut) of InterpOp should not be null."); - PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, - "Output(Out) of InterpOp should not be null."); - - // SubOut = X - Y - auto x = Input("X"); - auto y = Input("Y"); - auto sub_out = Output("SubOut"); - AppendOp(framework::OpRegistry::CreateOp( - "elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {})); - - // MulOut = SubOut * W = (X - Y) * W - auto w = Input("W"); - auto mul_out = Output("MulOut"); - AppendOp(framework::OpRegistry::CreateOp( - "elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}}, - {{"axis", 0}})); - - // Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W) - AppendOp(framework::OpRegistry::CreateOp("elementwise_add", - {{"X", {mul_out}}, {"Y", {y}}}, - {{"Out", {Output("Out")}}}, {})); - - CompleteAddOp(false); - } -}; - -class InterpOpMaker : public framework::OpProtoAndCheckerMaker { - public: - InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", - "(Tensor), 2-D Matrix of shape [batch_size, data_dim]" - "containing data samples, the first input of interp_op"); - AddInput("Y", - "(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`" - "containing data samples, the second input of interp_op"); - AddInput("W", - "(Tensor), 1-D Vector of shape [batch_size]," - "the interpolated values in the half-open interval [0.0, 1.0)"); - AddOutput("SubOut", - "(Tensor), the intermediate subtraction outputs, saving X - Y.") - .AsIntermediate(); - AddOutput("MulOut", - "(Tensor), the intermediate multiplication outputs," - "saving the elementwise multiplication of (X - Y) and W.") - .AsIntermediate(); - AddOutput("Out", - "(Tensor), the output of interp_op, same shape with X," - "returns the first-dimensional piecewise linear interpolant " - "between X and Y"); - AddComment(R"DOC( - Linear Interpolation with two inputs, used in NEURAL TURING MACHINE. 
- - Equation: - Out.row[i] = X.row[i] * W[i] + Y.row[i] * (1 - W[i]) - = (X.row[i] - Y.row[i]) * W[i] + Y.row[i] - - Example: - X = [[1,2],[3,4]], - Y = [[2,1],[4,3]], - W = [0.3, 0.4] - - Then, Out = [[1.7,1.3],[3.6,3.4]] - - where 1.7 = 1*0.3+2*(1-0.3), - 1.3 = 2*0.3+1*(1-0.3), - 3.6 = 3*0.4+4*(1-0.4), - 3.4 = 4*0.4+3*(1-0.4) -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(interp, ops::InterpOp, ops::InterpOpMaker); diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index b88cd14d78f616b0e57386ab891dad1a872bfe65..ad86a2e5bc23b2b0ea853971cf79dec745e9706a 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -32,6 +32,9 @@ class LookupTableOp : public framework::OperatorWithKernel { auto table_dims = ctx->GetInputDim("W"); auto ids_dims = ctx->GetInputDim("Ids"); + PADDLE_ENFORCE_EQ(ids_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[1], 1); + ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->ShareLoD("Ids", /*->*/ "Out"); } @@ -53,7 +56,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { " which is a learnable parameter."); AddInput("Ids", "An input with type int32 or int64" - "contains the ids to be looked up in W."); + "contains the ids to be looked up in W." + "Ids must be a column vector with rank = 2." + "The 2nd dimension size must be 1"); AddOutput("Out", "The lookup results, which have the same type with W."); AddComment(R"DOC( This operator is used to perform lookups on the parameter W, diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..0a089b7c2dc1e05224525bc4fe5399ec39036d01 --- /dev/null +++ b/paddle/operators/lstm_op.cc @@ -0,0 +1,226 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/lstm_op.h" + +namespace paddle { +namespace operators { + +class LSTMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(Hidden) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Cell"), + "Output(Cell) of LSTM should not be null."); + + auto x_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + + if (ctx->HasInput("H0")) { + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(Cell) and Input(Hidden) of LSTM should not " + "be null at the same time."); + auto h_dims = ctx->GetInputDim("H0"); + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } + + int frame_size = x_dims[1] / 4; + auto w_dims = ctx->GetInputDim("Weight"); + PADDLE_ENFORCE_EQ(w_dims.size(), 2, + "The rank of Input(Weight) should be 2."); + PADDLE_ENFORCE_EQ(w_dims[0], frame_size, + "The first dimension of Input(Weight) " + "should be %d.", + frame_size); + PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size, + "The second dimension of Input(Weight) " + "should be 4 * %d.", + frame_size); + auto b_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, + "The first dimension of Input(Bias) should be 1."); + if (ctx->Attrs().Get("usePeepholes")) { + PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, + "The second dimension of Input(Bias) should be " + "7 * %d if enable peepholes connection", + frame_size); + } else { + PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, + "The second dimension of Input(Bias) should be " + "4 * %d if disable peepholes connection", + frame_size); + } + ctx->SetOutputDim("Hidden", {x_dims[0], frame_size}); + ctx->SetOutputDim("Cell", {x_dims[0], frame_size}); + ctx->SetOutputDim("BatchGate", x_dims); + ctx->ShareLoD("Input", "Hidden"); + ctx->ShareLoD("Input", "Cell"); + } +}; + +class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(LoDTensor) the first input is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTensor is a matrix with shape (T X 4D), where, T is the " + "total time steps in this mini-batch, D is the hidden size."); + AddInput("H0", + "(Tensor, optional) the initial hidden state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size, D is the hidden size."); + AddInput("C0", + "(Tensor, optional) the initial cell state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size. `H0` and `C0` can be NULL but only at the same time"); + AddInput("Weight", + "(Tensor) the learnable hidden-hidden weights." + " - The shape is (D x 4D), where D is the hidden size. " + " - Weight = {W_ch, W_ih, W_fh, W_oh}"); + AddInput("Bias", + "(Tensor) the learnable weights, which contains two parts: " + "input-hidden bias weight and peephole connections weight if " + "setting `usePeepholes` True. " + "1. `usePeepholes = False` " + " - The shape is (1 x 4D). 
" + " - Bias = {b_c, b_i, b_f, b_o}." + "2. `usePeepholes = True` " + " - The shape is (1 x 7D). " + " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); + AddOutput("BatchGate", + "(LoDTensor) This LoDTensor contains input gate, forget gate " + "and output gate after the nonlinear computation. This " + "LoDTensor has the same shape with the reorganized input, which " + "was also be called batch input. The LoD size is 2. The first " + "LoD is the batch offsets and the second LoD contains the " + "indexes, which denote the position of reorganized sequence " + "in the raw input.") + .AsIntermediate(); + AddOutput("Hidden", + "(LoDTensor) the hidden state lod tensor of LSTM operator. " + "The shape and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) the cell state lod tensor of LSTM operator. " + "The shape and lod is the same with the `Input`."); + AddAttr("usePeepholes", + "(bool, defalut: True) " + "whether to enable diagonal/peephole connections.") + .SetDefault(true); + AddAttr("isReverse", + "(bool, defalut: False) " + "whether to compute reversed LSTM.") + .SetDefault(false); + AddAttr( + "gateActivation", + "(string, default: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by default.") + .SetDefault("sigmoid"); + AddAttr("cellActivation", + "(string, default: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh"); + AddAttr("candidateActivation", + "(string, default: tanh)" + "The activation for candidate hidden state, " + "`tanh` by default.") + .SetDefault("tanh"); + AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator + +The defalut implementation is diagonal/peephole connection [1], the formula is +as follows + + i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i) + + f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f) + + \tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c) + + o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o) + + c_t = f_t ⊙ c_{t-1} + i_t ⊙ \tilde{c_t} + + h_t = o_t ⊙ act_h(c_t) + +where the W terms denote weight matrices (e.g. \f$W_{xi}\f$ is the matrix +of weights from the input gate to the input), \f$W_{ic}, W_{fc}, W_{oc}\f$ +are diagonal weight matrices for peephole connections. In our implenmention, +We use vectors to reprenset these diagonal weight matrices. The b terms +denote bias vectors (\f$b_i\f$ is the input gate bias vector), \f$\sigma\f$ +is the non-line actications, such as logistic sigmoid function, and +\f$i, f, o\f$ and \f$c\f$ are respectively the input gate, forget gate, +output gate and cell activation vectors, all of which are the same size as +the cell output activation vector \f$h\f$. + +The ⊙ is the element-wise product of the vectors, \f$act_g\f$ and \f$act_h\f$ +are the cell input and cell output activation functions, `tanh` is usually +used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state, +which is computed based on the current input and the previous hidden state. + +Set `usePeepholes` False to disable peephole connection [2]. The formula +is omitted here. + +@note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$ +operations on the input x_{t} were NOT included in this operator. +Users can choose to use fully-connect operator before LSTM operator. + +[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory +recurrent neural network architectures for large scale acoustic modeling. +INTERSPEECH, 2014. + +[2] S. Hochreiter and J. Schmidhuber. 
Long Short-Term Memory. +Neural Computation, 9(8):1735-1780, 1997. + +)DOC"); + } +}; + +class LSTMGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), + "Input(Hidden@GRAD) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")), + "Input(Cell@GRAD) should not be null"); + ctx->SetOutputDim(framework::GradVarName("Weight"), + ctx->GetInputDim("Weight")); + ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lstm, ops::LSTMOp, ops::LSTMOpMaker, lstm_grad, ops::LSTMGradOp); +REGISTER_OP_CPU_KERNEL(lstm, ops::LSTMKernel, + ops::LSTMKernel); +REGISTER_OP_CPU_KERNEL(lstm_grad, + ops::LSTMGradKernel, + ops::LSTMGradKernel); diff --git a/paddle/operators/lstm_op.cu b/paddle/operators/lstm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..9ad56941553bf19a56c25f41f76fe20dfa3a106f --- /dev/null +++ b/paddle/operators/lstm_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/lstm_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(lstm, ops::LSTMKernel, + ops::LSTMKernel); +REGISTER_OP_GPU_KERNEL(lstm_grad, + ops::LSTMGradKernel, + ops::LSTMGradKernel); diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0af5694c48fcb4437e3acd422606de013bb2e145 --- /dev/null +++ b/paddle/operators/lstm_op.h @@ -0,0 +1,139 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/lstm_compute.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { + +using framework::LoDTensor; +using framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class LSTMKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + + auto* batch_gate = ctx.Output("BatchGate"); + batch_gate->mutable_data(ctx.GetPlace()); + auto* hidden_out = ctx.Output("Hidden"); + hidden_out->mutable_data(ctx.GetPlace()); + auto* cell_out = ctx.Output("Cell"); + cell_out->mutable_data(ctx.GetPlace()); + + // Now the function ShareLoD in InferShape is not implemented. + // So copy LoD here. + ctx.ShareLoD("Input", "Hidden"); + ctx.ShareLoD("Input", "Cell"); + + bool is_reverse = ctx.Attr("isReverse"); + math::LoDTensor2BatchFunctor to_batch; + to_batch(ctx.device_context(), *input, *batch_gate, is_reverse); + + auto in_dims = input->dims(); + int frame_size = static_cast(in_dims[1] / 4); + framework::DDim dims({in_dims[0], frame_size}); + + if (bias) { + Eigen::array extents({{1, 4 * frame_size}}); + Eigen::array offsets({{0, 0}}); + auto b = EigenMatrix::From(*bias); + auto gate = EigenMatrix::From(*batch_gate); + gate.device(ctx.GetEigenDevice()) = + gate + + b.slice(offsets, extents) + .reshape(Eigen::array({{1, frame_size * 4}})) + .broadcast( + Eigen::array({{static_cast(in_dims[0]), 1}})); + } + + math::LstmMetaValue lstm_value; + T* bias_data = const_cast(bias->data()); + // the code style in LstmMetaValue will be updated later. 
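+ // With usePeepholes enabled, Bias is laid out as + // [b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc], each slice frame_size wide, + // so the three peephole weight vectors start right after the + // 4 * frame_size gate biases: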
+ lstm_value.checkIg = bias_data + 4 * frame_size; + lstm_value.checkFg = lstm_value.checkIg + frame_size; + lstm_value.checkOg = lstm_value.checkFg + frame_size; + lstm_value.prevStateValue = nullptr; + + framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act; + batch_out.mutable_data(dims, ctx.GetPlace()); + batch_cell.mutable_data(dims, ctx.GetPlace()); + batch_cell_pre_act.mutable_data(dims, ctx.GetPlace()); + + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto gate_act = ctx.Attr("gateActivation"); + auto cell_act = ctx.Attr("cellActivation"); + auto cand_act = ctx.Attr("candidateActivation"); + + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor out_t = batch_out.Slice(bstart, bend); + Tensor cell_t = batch_cell.Slice(bstart, bend); + Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend); + + int cur_batch_size = bend - bstart; + + if (n != 0) { + int pre_h_start = static_cast(batch_starts[n - 1]); + int pre_h_end = pre_h_start + cur_batch_size; + auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end); + math::matmul(ctx.device_context(), pre_hidden_t, false, + *weight, false, static_cast(1.0), &gate_t, + static_cast(1.0)); + } + // else if : FIXME support the initial hidden and cell + + lstm_value.gateValue = gate_t.data(); + lstm_value.outputValue = out_t.data(); + lstm_value.stateValue = cell_t.data(); + lstm_value.stateActiveValue = cell_pre_act_t.data(); + math::LstmUnitFunctor::compute(ctx.device_context(), lstm_value, + frame_size, cur_batch_size, + gate_act, cell_act, cand_act); + lstm_value.prevStateValue = lstm_value.stateValue; + } + + math::Batch2LoDTensorFunctor to_seq; + batch_out.set_lod(batch_gate->lod()); + // restore the output hidden in LoDTensor from the batch hidden + to_seq(ctx.device_context(), batch_out, *hidden_out); + + batch_cell.set_lod(batch_gate->lod()); + // restore the output cell state in LoDTensor from the batch cell + to_seq(ctx.device_context(), batch_cell, *cell_out); + } +}; + +template +class LSTMGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override {} +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h index a0ff498c1d3ed2aaa10f5473ef91de168c250649..625b1852c2f0eb2ed435f73fea251c40c614a7dd 100644 --- a/paddle/operators/lstm_unit_op.h +++ b/paddle/operators/lstm_unit_op.h @@ -19,7 +19,6 @@ namespace paddle { namespace operators { -using framework::LoDTensor; using framework::Tensor; template diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 2560c0a5aace83b111dcf02b3349335cfdc78274..a3a744e5f7022be7c1630070bcc5476fda7bed17 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(detail) + if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor) @@ -8,6 +10,8 @@ if(WITH_GPU) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) nv_library(sequence_project SRCS sequence_project.cc sequence_project.cu DEPS device_context) + 
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) + nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -16,6 +20,8 @@ else() cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(vol2col SRCS vol2col.cc DEPS device_context) cc_library(sequence_project SRCS sequence_project.cc DEPS device_context) + cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) + cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu index 367190e6b0682ec62550e869e2f04c3a2b2cbec3..db878129d650d663e187ecabb106eea0e39db6fa 100644 --- a/paddle/operators/math/cross_entropy.cu +++ b/paddle/operators/math/cross_entropy.cu @@ -22,8 +22,6 @@ namespace { template __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { PADDLE_ASSERT(label[i] >= 0 && label[i] < D); diff --git a/paddle/operators/math/detail/CMakeLists.txt b/paddle/operators/math/detail/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..49cf228de2204cb4888cf645a0cb68ed04cc3371 --- /dev/null +++ b/paddle/operators/math/detail/CMakeLists.txt @@ -0,0 +1,5 @@ +if(WITH_AVX) + cc_library(activation_functions SRCS hl_cpu_functions.cc hl_avx_functions.cc) +else() + cc_library(activation_functions SRCS hl_cpu_functions.cc) +endif() diff --git a/paddle/operators/math/detail/hl_activation_functions.h b/paddle/operators/math/detail/hl_activation_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..9d7d9914f0090bff17049038dfa2288d84f3dbda --- /dev/null +++ b/paddle/operators/math/detail/hl_activation_functions.h @@ -0,0 +1,188 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_ACTIVATION_FUNCTIONS_H_ +#define HL_ACTIVATION_FUNCTIONS_H_ + +#include "hl_functions.h" +#include "paddle/operators/math/lstm_compute.h" + +/** + * Active functions: sigmoid, relu, tanh and linear. 
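+ * + * The *_ACTIVE_FUNCTION macros below expand to initializer lists of + * function pointers, so the forward[]/backward[] tables defined later in + * this header can be indexed directly by activation_mode_t at run time + * instead of branching on the activation kind.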
+ */ +#define FLOAT_ACTIVE_FUNCTION \ + { \ + hppl::typef::sigmoid, hppl::typef::relu, hppl::typef::tanh, \ + hppl::typef::linear \ + } + +#define DOUBLE_ACTIVE_FUNCTION \ + { \ + hppl::typed::sigmoid, hppl::typed::relu, hppl::typed::tanh, \ + hppl::typed::linear \ + } + +#define AVX_ACTIVE_FUNCTION \ + { hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear } + +namespace hppl { + +using activation_mode_t = paddle::operators::math::activation_mode_t; + +/** + * Hppl supports sigmoid, relu, tanh, linear active functions + * for neural networks' forward and backward activation. + */ +template +class Active { + public: + typedef T (*forward)(T); + typedef T (*backward)(T, T); +}; + +template +struct ForwardActType; + +template <> +struct ForwardActType { + using type = Active::forward; +}; + +template <> +struct ForwardActType { + using type = Active::forward; +}; + +template +struct BackwardActType; + +template <> +struct BackwardActType { + using type = Active::backward; +}; + +template <> +struct BackwardActType { + using type = Active::backward; +}; + +#ifdef __NVCC__ +namespace gpu { +static __device__ Active::forward forward[] = FLOAT_ACTIVE_FUNCTION; +static __device__ Active::backward backward[] = FLOAT_ACTIVE_FUNCTION; + +static __device__ Active::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION; +static __device__ Active::backward backward_d[] = + DOUBLE_ACTIVE_FUNCTION; + +template +struct ForwardAct { + __device__ typename ForwardActType::type operator()( + activation_mode_t type); +}; + +template <> +struct ForwardAct { + __device__ ForwardActType::type operator()(activation_mode_t type) { + return forward[type]; + } +}; + +template <> +struct ForwardAct { + __device__ ForwardActType::type operator()(activation_mode_t type) { + return forward_d[type]; + } +}; + +template +struct BackwardAct { + __device__ typename BackwardActType::type operator()( + activation_mode_t type); +}; + +template <> +struct BackwardAct { + __device__ BackwardActType::type operator()(activation_mode_t type) { + return backward[type]; + } +}; + +template <> +struct BackwardAct { + __device__ BackwardActType::type operator()(activation_mode_t type) { + return backward_d[type]; + } +}; + +} // namespace gpu +#else +namespace cpu { +static Active::forward forward[] = FLOAT_ACTIVE_FUNCTION; +static Active::backward backward[] = FLOAT_ACTIVE_FUNCTION; + +static Active::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION; +static Active::backward backward_d[] = DOUBLE_ACTIVE_FUNCTION; + +template +struct ForwardAct { + typename ForwardActType::type operator()(activation_mode_t type); +}; + +template <> +struct ForwardAct { + ForwardActType::type operator()(activation_mode_t type) { + return forward[type]; + } +}; + +template <> +struct ForwardAct { + ForwardActType::type operator()(activation_mode_t type) { + return forward_d[type]; + } +}; + +template +struct BackwardAct { + typename BackwardActType::type operator()(activation_mode_t type); +}; + +template <> +struct BackwardAct { + BackwardActType::type operator()(activation_mode_t type) { + return backward[type]; + } +}; + +template <> +struct BackwardAct { + BackwardActType::type operator()(activation_mode_t type) { + return backward_d[type]; + } +}; + +} // namespace cpu + +#ifdef __AVX__ +namespace avx { +static Active<__m256>::forward forward[] = AVX_ACTIVE_FUNCTION; +static Active<__m256>::backward backward[] = AVX_ACTIVE_FUNCTION; +} // namespace avx +#endif +#endif + +} // namespace hppl + +#endif // HL_ACTIVATION_FUNCTIONS_H_ diff --git 
a/paddle/operators/math/detail/hl_avx_functions.cc b/paddle/operators/math/detail/hl_avx_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..415bac5d93ee00244d072b0998c6941b14d4f8d8 --- /dev/null +++ b/paddle/operators/math/detail/hl_avx_functions.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "hl_functions.h" +// TODO(qingqing) refine this dependence +#include "paddle/cuda/src/avx_mathfun.h" + +namespace hppl { + +__m256 exp(__m256 a) { return exp256_ps(a); } + +__m256 relu(const __m256 a) { + __m256 tmp = _mm256_set1_ps(0.0f); + return _mm256_max_ps(a, tmp); +} + +__m256 sigmoid(const __m256 a) { + __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); + __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); + __m256 tmp = _mm256_max_ps(a, min); + tmp = _mm256_min_ps(tmp, max); + tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); + tmp = exp(tmp); + tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); + tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp); + return tmp; +} + +__m256 tanh(const __m256 a) { + __m256 max = _mm256_set1_ps(EXP_MAX_INPUT); + __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); + tmp = _mm256_min_ps(tmp, max); + tmp = exp(tmp); + return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f), + _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)), + _mm256_set1_ps(1.0f)); +} + +__m256 linear(const __m256 a) { return a; } + +__m256 relu(const __m256 a, const __m256 b) { + return _mm256_mul_ps( + a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS), + _mm256_set1_ps(1.0f))); +} + +__m256 sigmoid(const __m256 a, const __m256 b) { + return _mm256_mul_ps(_mm256_mul_ps(a, b), + _mm256_sub_ps(_mm256_set1_ps(1.0f), b)); +} + +__m256 tanh(const __m256 a, const __m256 b) { + return _mm256_mul_ps( + a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b))); +} + +__m256 linear(const __m256 a, const __m256 b) { return a; } +} // namespace hppl diff --git a/paddle/operators/math/detail/hl_avx_functions.h b/paddle/operators/math/detail/hl_avx_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..35f4eabb4c07c6cc9d2edded02e5b6290b1232f8 --- /dev/null +++ b/paddle/operators/math/detail/hl_avx_functions.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifndef HL_AVX_FUNCTIONS_H_ +#define HL_AVX_FUNCTIONS_H_ + +#include + +namespace hppl { +__m256 relu(const __m256 a); +__m256 sigmoid(const __m256 a); +__m256 tanh(const __m256 a); +__m256 linear(const __m256 a); + +__m256 relu(const __m256 a, const __m256 b); +__m256 sigmoid(const __m256 a, const __m256 b); +__m256 tanh(const __m256 a, const __m256 b); +__m256 linear(const __m256 a, const __m256 b); +} // namespace hppl + +#endif // HL_AVX_FUNCTIONS_H_ diff --git a/paddle/operators/math/detail/hl_cpu_functions.cc b/paddle/operators/math/detail/hl_cpu_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..21ec78f9629af0e4673a56517d76ac6734f57db8 --- /dev/null +++ b/paddle/operators/math/detail/hl_cpu_functions.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "hl_functions.h" + +namespace hppl { +namespace typef { + +float relu(const float a) { + return a > static_cast(0.0) ? a : static_cast(0.0); +} + +float sigmoid(const float a) { + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + float tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +float tanh(const float a) { + float tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +float linear(const float a) { return a; } + +float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); } + +float sigmoid(const float a, const float b) { + return a * b * (static_cast(1) - b); +} + +float tanh(const float a, const float b) { + return a * (static_cast(1) - b * b); +} + +float linear(const float a, const float b) { return a; } + +} // namespace typef + +namespace typed { +double relu(const double a) { + return a > static_cast(0.0) ? a : static_cast(0.0); +} + +double sigmoid(const double a) { + const double min = SIGMOID_THRESHOLD_MIN; + const double max = SIGMOID_THRESHOLD_MAX; + double tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +double tanh(const double a) { + double tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +double linear(const double a) { return a; } + +double relu(const double a, const double b) { + return a * (b > 0.0 ? 
1.0 : 0.0); +} + +double sigmoid(const double a, const double b) { + return a * b * (static_cast<double>(1) - b); +} + +double tanh(const double a, const double b) { + return a * (static_cast<double>(1) - b * b); +} + +double linear(const double a, const double b) { return a; } + +} // namespace typed +} // namespace hppl diff --git a/paddle/operators/math/detail/hl_functions.h b/paddle/operators/math/detail/hl_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..3e2f0c9ee6d3ae2ed598c4d5f09b85b7d61fdd51 --- /dev/null +++ b/paddle/operators/math/detail/hl_functions.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_FUNCTIONS_H_ +#define HL_FUNCTIONS_H_ + +/** + * sigmoid threshold minimum + */ +#define SIGMOID_THRESHOLD_MIN -40.0 + +/** + * sigmoid threshold maximum + */ +#define SIGMOID_THRESHOLD_MAX 13.0 + +/** + * The maximum input value for exp, used to avoid overflow problems; + * currently only used for the tanh function. + */ +#define EXP_MAX_INPUT 40.0 + +#ifndef __NVCC__ +namespace hppl { +namespace typef { +float relu(const float a); +float sigmoid(const float a); +float tanh(const float a); +float linear(const float a); + +float relu(const float a, const float b); +float sigmoid(const float a, const float b); +float tanh(const float a, const float b); +float linear(const float a, const float b); + +} // namespace typef + +namespace typed { +double relu(const double a); +double sigmoid(const double a); +double tanh(const double a); +double linear(const double a); + +double relu(const double a, const double b); +double sigmoid(const double a, const double b); +double tanh(const double a, const double b); +double linear(const double a, const double b); +} // namespace typed + +} // namespace hppl + +#ifdef __AVX__ +#include "hl_avx_functions.h" +#endif + +#else +#include "hl_gpu_functions.h" +#endif + +#endif // HL_FUNCTIONS_H_ diff --git a/paddle/operators/math/detail/hl_gpu_functions.h b/paddle/operators/math/detail/hl_gpu_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..72f2204e7b2cfdba1367b51e3731dde11fb292d6 --- /dev/null +++ b/paddle/operators/math/detail/hl_gpu_functions.h @@ -0,0 +1,93 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#ifndef HL_GPU_FUNCTIONS_CUH_ +#define HL_GPU_FUNCTIONS_CUH_ + +#include "hl_base.h" + +namespace hppl { +namespace typef { + +__device__ static float relu(const float a) { return a > 0.0f ? a : 0.0f; } + +__device__ static float sigmoid(const float a) { + const float min = SIGMOID_THRESHOLD_MIN; + const float max = SIGMOID_THRESHOLD_MAX; + float tmp = (a < min) ? min : ((a > max) ? max : a); + return __fdividef(1.0f, 1.0f + __expf(-tmp)); +} + +__device__ static float tanh(const float a) { + float tmp = -2.0f * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return __fdividef(2.0f, (1.0f + __expf(tmp))) - 1.0f; +} + +__device__ static float linear(const float a) { return a; } + +__device__ static float relu(const float a, const float b) { + return a * (b > 0.0f ? 1.0f : 0.0f); +} + +__device__ static float sigmoid(const float a, const float b) { + return a * b * (1.0f - b); +} + +__device__ static float tanh(const float a, const float b) { + return a * (1.0f - b * b); +} + +__device__ static float linear(const float a, const float b) { return a; } + +} // namespace typef + +namespace typed { + +__device__ static double relu(const double a) { return a > 0.0 ? a : 0.0; } + +__device__ static double sigmoid(const double a) { + const double min = SIGMOID_THRESHOLD_MIN; + const double max = SIGMOID_THRESHOLD_MAX; + double tmp = (a < min) ? min : ((a > max) ? max : a); + return 1.0 / (1.0 + exp(-tmp)); +} + +__device__ static double tanh(const double a) { + double tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +__device__ static double linear(const double a) { return a; } + +__device__ static double relu(const double a, const double b) { + return a * (b > 0.0 ? 1.0 : 0.0); +} + +__device__ static double sigmoid(const double a, const double b) { + return a * b * (1 - b); +} + +__device__ static double tanh(const double a, const double b) { + return a * (1.0 - b * b); +} + +__device__ static double linear(const double a, const double b) { return a; } + +} // namespace typed + +} // namespace hppl + +#endif // HL_GPU_FUNCTIONS_CUH_ diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..74d51d7bc9b91f4c8088384d77183131f57aafab --- /dev/null +++ b/paddle/operators/math/detail/lstm_cpu_kernel.h @@ -0,0 +1,310 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#pragma once +#include +#include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/operators/math/lstm_compute.h" + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +#ifndef __NVCC__ + +template +void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, + int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + T rValueIn; + T rValueIg; + T rValueFg; + T rValueOg; + T rCheckI; + T rCheckF; + T rCheckO; + T rState; + T rPrevState = 0; + T rStateAtv; + T rOut; + + T *valueIn = value.gateValue; + T *valueIg = value.gateValue + frameSize; + T *valueFg = value.gateValue + frameSize * 2; + T *valueOg = value.gateValue + frameSize * 3; + + for (int i = 0; i < frameSize; i++) { + rValueIn = valueIn[i]; + rValueIg = valueIg[i]; + rValueFg = valueFg[i]; + rValueOg = valueOg[i]; + rCheckI = value.checkIg[i]; + rCheckF = value.checkFg[i]; + rCheckO = value.checkOg[i]; + + if (value.prevStateValue) { + rPrevState = value.prevStateValue[i]; + } + + hppl::cpu::ForwardAct act; + op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, + rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), + act(active_state)); + + valueIn[i] = rValueIn; + valueIg[i] = rValueIg; + valueFg[i] = rValueFg; + valueOg[i] = rValueOg; + value.stateValue[i] = rState; + value.stateActiveValue[i] = rStateAtv; + value.outputValue[i] = rOut; + } +} + +template +void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, + LstmMetaGrad grad, int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + T rValueIn; + T rValueIg; + T rValueFg; + T rValueOg; + T rGradIn; + T rGradIg; + T rGradFg; + T rGradOg; + T rPrevState = 0; + T rPrevStateGrad; + T rState; + T rStateGrad; + T rStateAtv; + T rOutputGrad; + T rCheckI; + T rCheckF; + T rCheckO; + T rCheckIGrad; + T rCheckFGrad; + T rCheckOGrad; + + T *valueIn = value.gateValue; + T *valueIg = value.gateValue + frameSize; + T *valueFg = value.gateValue + frameSize * 2; + T *valueOg = value.gateValue + frameSize * 3; + T *gradIn = grad.gateGrad; + T *gradIg = grad.gateGrad + frameSize; + T *gradFg = grad.gateGrad + frameSize * 2; + T *gradOg = grad.gateGrad + frameSize * 3; + + for (int i = 0; i < frameSize; i++) { + rValueIn = valueIn[i]; + rValueIg = valueIg[i]; + rValueFg = valueFg[i]; + rValueOg = valueOg[i]; + rCheckI = value.checkIg[i]; + rCheckF = value.checkFg[i]; + rCheckO = value.checkOg[i]; + rState = value.stateValue[i]; + rStateAtv = value.stateActiveValue[i]; + rOutputGrad = grad.outputGrad[i]; + rStateGrad = grad.stateGrad[i]; + if (value.prevStateValue) { + rPrevState = value.prevStateValue[i]; + } + + hppl::cpu::BackwardAct act; + op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, + rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, + rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, + rCheckOGrad, act(active_node), act(active_gate), act(active_state)); + + gradIn[i] = rGradIn; + gradIg[i] = rGradIg; + gradFg[i] = rGradFg; + gradOg[i] = rGradOg; + grad.stateGrad[i] = rStateGrad; + + if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad; + if (value.prevStateValue) { + if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad; + if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad; + } + if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad; + } +} + +template +void 
avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { +#ifdef __AVX__ + __m256 rValueIn; + __m256 rValueIg; + __m256 rValueFg; + __m256 rValueOg; + __m256 rCheckI; + __m256 rCheckF; + __m256 rCheckO; + __m256 rState; + __m256 rPrevState = _mm256_set1_ps(0.0f); + __m256 rStateAtv; + __m256 rOut; + + __m256 *valueIn = (__m256 *)value.gateValue; + __m256 *valueIg = (__m256 *)(value.gateValue + frameSize); + __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); + __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); + + for (int i = 0; i < frameSize / 8; i++) { + rValueIn = valueIn[i]; + rValueIg = valueIg[i]; + rValueFg = valueFg[i]; + rValueOg = valueOg[i]; + rCheckI = ((__m256 *)value.checkIg)[i]; + rCheckF = ((__m256 *)value.checkFg)[i]; + rCheckO = ((__m256 *)value.checkOg)[i]; + + if (value.prevStateValue) { + rPrevState = ((__m256 *)value.prevStateValue)[i]; + } + + op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, + rOut, rCheckI, rCheckF, rCheckO, hppl::avx::forward[active_node], + hppl::avx::forward[active_gate], hppl::avx::forward[active_state]); + + valueIn[i] = rValueIn; + valueIg[i] = rValueIg; + valueFg[i] = rValueFg; + valueOg[i] = rValueOg; + ((__m256 *)value.stateValue)[i] = rState; + ((__m256 *)value.stateActiveValue)[i] = rStateAtv; + ((__m256 *)value.outputValue)[i] = rOut; + } +#endif +} + +template +void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, + LstmMetaGrad grad, int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { +#ifdef __AVX__ + __m256 rValueIn; + __m256 rValueIg; + __m256 rValueFg; + __m256 rValueOg; + __m256 rGradIn; + __m256 rGradIg; + __m256 rGradFg; + __m256 rGradOg; + __m256 rPrevState = _mm256_set1_ps(0.0f); + __m256 rPrevStateGrad; + __m256 rStateGrad; + __m256 rState; + __m256 rStateAtv; + __m256 rOutputGrad; + __m256 rCheckI; + __m256 rCheckF; + __m256 rCheckO; + __m256 rCheckIGrad; + __m256 rCheckFGrad; + __m256 rCheckOGrad; + + __m256 *valueIn = (__m256 *)value.gateValue; + __m256 *valueIg = (__m256 *)(value.gateValue + frameSize); + __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2); + __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3); + __m256 *gradIn = (__m256 *)grad.gateGrad; + __m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize); + __m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2); + __m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3); + + for (int i = 0; i < frameSize / 8; i++) { + rValueIn = valueIn[i]; + rValueIg = valueIg[i]; + rValueFg = valueFg[i]; + rValueOg = valueOg[i]; + rCheckI = ((__m256 *)value.checkIg)[i]; + rCheckF = ((__m256 *)value.checkFg)[i]; + rCheckO = ((__m256 *)value.checkOg)[i]; + rState = ((__m256 *)value.stateValue)[i]; + rStateAtv = ((__m256 *)value.stateActiveValue)[i]; + rOutputGrad = ((__m256 *)grad.outputGrad)[i]; + rStateGrad = ((__m256 *)grad.stateGrad)[i]; + if (value.prevStateValue) { + rPrevState = ((__m256 *)value.prevStateValue)[i]; + } + + op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, + rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, + rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, + rCheckOGrad, hppl::avx::backward[active_node], + hppl::avx::backward[active_gate], hppl::avx::backward[active_state]); + + gradIn[i] = rGradIn; + gradIg[i] = rGradIg; + gradFg[i] = 
rGradFg; + gradOg[i] = rGradOg; + ((__m256 *)grad.stateGrad)[i] = rStateGrad; + + if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad; + if (value.prevStateValue) { + if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad; + if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad; + } + if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad; + } +#endif +} + +template +void cpu_lstm_forward(Op op, LstmMetaValue value, int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same::value)) { + avx_lstm_forward_one_sequence(op, value, frameSize, active_node, + active_gate, active_state); + } else { + naive_lstm_forward_one_sequence(op, value, frameSize, active_node, + active_gate, active_state); + } +} + +template +void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, + int frameSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same::value)) { + avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, + active_gate, active_state); + } else { + naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, + active_gate, active_state); + } +} + +#endif + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9573eaefb6a9d678ef70f2e2bffdc6a3011b21ea --- /dev/null +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -0,0 +1,256 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/operators/math/lstm_compute.h" +#include "paddle/platform/cuda_helper.h" +#include "paddle/platform/device_context.h" + +#include + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +/* + * threads(framePerBlock, batchPerBlock) + * grid(frameBlocks, batchBlocks) + */ +template +__global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (frameIdx >= frameSize) return; + + int batchIdx = 0; + if (isBatch) { + batchIdx = blockIdx.y * blockDim.y + threadIdx.y; + if (batchIdx >= batchSize) return; + value.gateValue += batchIdx * frameSize * 4; + value.outputValue += batchIdx * frameSize; + value.stateValue += batchIdx * frameSize; + value.stateActiveValue += batchIdx * frameSize; + } + + T rState; + T rPrevState = 0; + T rStateAtv; + T rOut; + T rValueIn; + T rValueIg; + T rValueFg; + T rValueOg; + T rCheckI = value.checkIg[frameIdx]; + T rCheckF = value.checkFg[frameIdx]; + T rCheckO = value.checkOg[frameIdx]; + + rValueIn = value.gateValue[frameIdx]; + rValueIg = value.gateValue[frameIdx + frameSize]; + rValueFg = value.gateValue[frameIdx + frameSize * 2]; + rValueOg = value.gateValue[frameIdx + frameSize * 3]; + + if (value.prevStateValue) { + if (isBatch) value.prevStateValue += batchIdx * frameSize; + rPrevState = value.prevStateValue[frameIdx]; + } + + hppl::gpu::ForwardAct act; + op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, + rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate), + act(active_state)); + + value.gateValue[frameIdx] = rValueIn; + value.gateValue[frameIdx + frameSize] = rValueIg; + value.gateValue[frameIdx + frameSize * 2] = rValueFg; + value.gateValue[frameIdx + frameSize * 3] = rValueOg; + + value.stateValue[frameIdx] = rState; + value.stateActiveValue[frameIdx] = rStateAtv; + value.outputValue[frameIdx] = rOut; +} + +/* + * threads(framePerBlock, batchPerBlock) + * grid(frameBlocks, batchBlocks) + */ +template +__global__ void KeLstmBackward(Op op, LstmMetaValue value, + LstmMetaGrad grad, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (frameIdx >= frameSize) return; + + int batchIdx = 0; + if (isBatch) { + batchIdx = blockIdx.y * blockDim.y + threadIdx.y; + if (batchIdx >= batchSize) return; + value.gateValue += batchIdx * frameSize * 4; + value.stateValue += batchIdx * frameSize; + value.stateActiveValue += batchIdx * frameSize; + grad.gateGrad += batchIdx * frameSize * 4; + grad.stateGrad += batchIdx * frameSize; + grad.outputGrad += batchIdx * frameSize; + } + + T rValueIn; + T rValueIg; + T rValueFg; + T rValueOg; + T rGradIn; + T rGradIg; + T rGradFg; + T rGradOg; + T rPrevState = 0; + T rPrevStateGrad; + T rState; + T rStateGrad; + T rStateAtv; + T rOutputGrad; + T rCheckI = value.checkIg[frameIdx]; + T rCheckF = value.checkFg[frameIdx]; + T rCheckO = value.checkOg[frameIdx]; + T rCheckIGrad; + T rCheckFGrad; + T rCheckOGrad; + + rValueIn = value.gateValue[frameIdx]; + rValueIg = value.gateValue[frameIdx + frameSize]; + rValueFg = value.gateValue[frameIdx + frameSize * 2]; + rValueOg = value.gateValue[frameIdx + frameSize * 3]; + 
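// Layout note: gateValue packs the four gate buffers contiguously as + // [input, input gate, forget gate, output gate], each frameSize wide, + // which is why the loads above step by multiples of frameSize. +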
rState = value.stateValue[frameIdx]; + rStateAtv = value.stateActiveValue[frameIdx]; + rOutputGrad = grad.outputGrad[frameIdx]; + rStateGrad = grad.stateGrad[frameIdx]; + + if (value.prevStateValue) { + if (isBatch) value.prevStateValue += batchIdx * frameSize; + rPrevState = value.prevStateValue[frameIdx]; + } + + hppl::gpu::BackwardAct act; + op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, + rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, + rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, + act(active_node), act(active_gate), act(active_state)); + + grad.gateGrad[frameIdx] = rGradIn; + grad.gateGrad[frameIdx + frameSize] = rGradIg; + grad.gateGrad[frameIdx + frameSize * 2] = rGradFg; + grad.gateGrad[frameIdx + frameSize * 3] = rGradOg; + grad.stateGrad[frameIdx] = rStateGrad; + if (grad.prevStateGrad) { + if (isBatch) grad.prevStateGrad += batchIdx * frameSize; + grad.prevStateGrad[frameIdx] = rPrevStateGrad; + } + + if (isBatch) { + if (value.prevStateValue) { + if (grad.checkIgGrad) + paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx, + rCheckIGrad); + if (grad.checkFgGrad) + paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx, + rCheckFGrad); + } + if (grad.checkOgGrad) + paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad); + } else { + if (value.prevStateValue) { + if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad; + if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad; + } + if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad; + } +} + +template +void gpu_lstm_forward(const platform::DeviceContext& context, Op op, + LstmMetaValue value, int frameSize, int batchSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + dim3 threads; + dim3 grid; + if (batchSize == 1) { + int framePerBlock = frameSize <= 1024 ? frameSize : 1024; + int frameBlocks = (frameSize + 1024 - 1) / 1024; + threads = dim3(framePerBlock, 1); + grid = dim3(frameBlocks, 1); + } else { + /* framePerBlock = 32 batchPerBlock = 32 */ + threads = dim3(32, 32); + grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + } + + auto stream = + reinterpret_cast(context).stream(); + if (batchSize == 1) { + KeLstmForward<<>>( + op, value, frameSize, batchSize, active_node, active_gate, + active_state); + } else { + KeLstmForward<<>>( + op, value, frameSize, batchSize, active_node, active_gate, + active_state); + } +} + +template +void gpu_lstm_backward(const platform::DeviceContext& context, Op op, + LstmMetaValue value, LstmMetaGrad grad, + int frameSize, int batchSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + dim3 threads; + dim3 grid; + if (batchSize == 1) { + int framePerBlock = frameSize <= 1024 ? 
frameSize : 1024; + int frameBlocks = (frameSize + 1024 - 1) / 1024; + threads = dim3(framePerBlock, 1); + grid = dim3(frameBlocks, 1); + } else { + /* framePerBlock = 32 batchPerBlock = 32 */ + threads = dim3(32, 32); + grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + } + + auto stream = + reinterpret_cast(context).stream(); + if (batchSize == 1) { + KeLstmBackward<<>>( + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); + } else { + KeLstmBackward<<>>( + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); + } +} + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..6f3ead2397d5131b4468d0ad288513cedb289594 --- /dev/null +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/platform/hostdevice.h" + +#include + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +namespace forward { + +template +class lstm { + public: + HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, + T &prevState, T &state, T &stateAtv, T &output, + T &checkI, T &checkF, T &checkO, + typename hppl::ForwardActType::type actInput, + typename hppl::ForwardActType::type actGate, + typename hppl::ForwardActType::type actState) { + valueIn = actInput(valueIn); + valueIg = actGate(valueIg + prevState * checkI); + valueFg = actGate(valueFg + prevState * checkF); + state = valueIn * valueIg + prevState * valueFg; + valueOg = actGate(valueOg + state * checkO); + stateAtv = actState(state); + output = valueOg * stateAtv; + } +#ifndef __NVCC__ +#ifndef __AVX__ // If not compiled with AVX instructs. 
Disable AVX by default + static const bool avx = false; +#else + // Only float support AVX optimization + static const bool avx = std::is_same::value; + + HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, + __m256 &valueOg, __m256 &prevState, __m256 &state, + __m256 &stateAtv, __m256 &output, __m256 &checkI, + __m256 &checkF, __m256 &checkO, + hppl::Active<__m256>::forward actInput, + hppl::Active<__m256>::forward actGate, + hppl::Active<__m256>::forward actState) { + valueIn = actInput(valueIn); + valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI))); + valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF))); + state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg), + _mm256_mul_ps(prevState, valueFg)); + valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO))); + stateAtv = actState(state); + output = _mm256_mul_ps(valueOg, stateAtv); + } +#endif +#endif +}; + +} // namespace forward + +namespace backward { + +template +class lstm { + public: + HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, + T &gradIn, T &gradIg, T &gradFg, T &gradOg, + T &prevState, T &prevStateGrad, T &state, + T &stateGrad, T &stateAtv, T &outputGrad, + T &checkI, T &checkF, T &checkO, T &checkIGrad, + T &checkFGrad, T &checkOGrad, + typename hppl::BackwardActType::type actInput, + typename hppl::BackwardActType::type actGate, + typename hppl::BackwardActType::type actState) { + gradOg = actGate(outputGrad * stateAtv, valueOg); + stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; + gradIn = actInput(stateGrad * valueIg, valueIn); + gradIg = actGate(stateGrad * valueIn, valueIg); + gradFg = actGate(stateGrad * prevState, valueFg); + prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; + checkIGrad = gradIg * prevState; + checkFGrad = gradFg * prevState; + checkOGrad = gradOg * state; + } +#ifndef __NVCC__ +#ifndef __AVX__ // If not compiled with AVX instructs. 
Disable AVX by default + static const bool avx = false; +#else + // Only float support AVX optimization + static const bool avx = std::is_same::value; + HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, + __m256 &valueOg, __m256 &gradIn, __m256 &gradIg, + __m256 &gradFg, __m256 &gradOg, __m256 &prevState, + __m256 &prevStateGrad, __m256 &state, + __m256 &stateGrad, __m256 &stateAtv, + __m256 &outputGrad, __m256 &checkI, __m256 &checkF, + __m256 &checkO, __m256 &checkIGrad, + __m256 &checkFGrad, __m256 &checkOGrad, + hppl::Active<__m256>::backward actInput, + hppl::Active<__m256>::backward actGate, + hppl::Active<__m256>::backward actState) { + gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg); + stateGrad = _mm256_add_ps( + actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad); + stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad); + gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn); + gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg); + gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg); + prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI), + _mm256_mul_ps(gradFg, checkF)); + prevStateGrad = + _mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad); + checkIGrad = _mm256_mul_ps(gradIg, prevState); + checkFGrad = _mm256_mul_ps(gradFg, prevState); + checkOGrad = _mm256_mul_ps(gradOg, state); + } +#endif +#endif +}; + +} // namespace backward + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 729ba8665cf9d3aa342dc10f5f5a5fe3803bf75f..3b1b0bd71dd3768b932864e185af8dc839b4653e 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -29,8 +29,8 @@ class Im2ColFunctor(); @@ -52,16 +68,14 @@ class Im2ColFunctor= input_height || - (im_col_idx - padding_width) < 0 || - (im_col_idx - padding_width) >= input_width) { + int im_row_idx = h * stride_height + h_offset - padding_up; + int im_col_idx = w * stride_width + w_offset - padding_left; + + if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 || + im_col_idx >= input_width) { col_data[(c * output_height + h) * output_width + w] = T(0); } else { - im_row_idx += c_im * input_height - padding_height; - im_col_idx -= padding_width; + im_row_idx += c_im * input_height; col_data[(c * output_height + h) * output_width + w] = im_data[im_row_idx * input_width + im_col_idx]; } @@ -82,7 +96,8 @@ class Col2ImFunctor(); @@ -103,14 +134,12 @@ class Col2ImFunctor= 0 && - (im_row_idx - padding_height) < input_height && - (im_col_idx - padding_width) >= 0 && - (im_col_idx - padding_width) < input_width) { - im_row_idx += c_im * input_height - padding_height; - im_col_idx -= padding_width; + int im_row_idx = h * stride_height + h_offset - padding_up; + int im_col_idx = w * stride_width + w_offset - padding_left; + + if ((im_row_idx) >= 0 && (im_row_idx) < input_height && + (im_col_idx) >= 0 && (im_col_idx) < input_width) { + im_row_idx += c_im * input_height; im_data[im_row_idx * input_width + im_col_idx] += col_data[(c * output_height + h) * output_width + w]; } @@ -140,8 +169,8 @@ class Im2ColFunctor= down_pad) { - row_begin = 0; - } else { - row_begin = down_pad - up_pad; - } - row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) / - stride_height + - 1); + PADDLE_ENFORCE_EQ( + (input_height + padding_up + padding_down - filter_height) / + stride_height 
+ + 1, + output_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ( + (input_width + padding_left + padding_right - filter_width) / + stride_width + + 1, + output_width, + "output_width and padding(padding_left, padding_right) are " + "inconsistent."); const T* im_data = im.data(); T* col_data = col.data(); - for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) { + for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int channel = 0; channel < input_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; @@ -175,17 +207,16 @@ class Im2ColFunctor= input_height || im_col_offset < 0 || im_col_offset >= input_width) { col_data[col_offset] = T(0); @@ -214,7 +245,8 @@ class Col2ImFunctor= down_pad) { - row_begin = 0; - } else { - row_begin = down_pad - up_pad; - } - row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) / - stride_height + - 1); + PADDLE_ENFORCE_EQ( + (input_height + padding_up + padding_down - filter_height) / + stride_height + + 1, + output_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ( + (input_width + padding_left + padding_right - filter_width) / + stride_width + + 1, + output_width, + "output_width and padding(padding_left, padding_right) are " + "inconsistent."); T* im_data = im.data(); const T* col_data = col.data(); - for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) { + for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int channel = 0; channel < input_channels; ++channel) { for (int filter_row_idx = 0; filter_row_idx < filter_height; ++filter_row_idx) { for (int filter_col_idx = 0; filter_col_idx < filter_width; ++filter_col_idx) { - int im_row_offset = // change or not ??? 
- col_row_idx * stride_height + filter_row_idx - padding_height; + int im_row_offset = + col_row_idx * stride_height + filter_row_idx - padding_up; int im_col_offset = - col_col_idx * stride_width + filter_col_idx - padding_width; - int col_offset = - ((((col_row_idx - row_begin) * output_width + col_col_idx) * - input_channels + - channel) * - filter_height + - filter_row_idx) * - filter_width + - filter_col_idx; + col_col_idx * stride_width + filter_col_idx - padding_left; + int col_offset = (((col_row_idx * output_width + col_col_idx) * + input_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; if (im_row_offset >= 0 && im_row_offset < input_height && im_col_offset >= 0 && im_col_offset < input_width) { int im_offset = diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 2416758629938a2a69a26503dc10f60aaaa7df76..7b201fdbf3c5dd7d336d359e00b7323cecc0231a 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -66,8 +66,8 @@ class Im2ColFunctor(context) .stream()>>>( im.data(), num_outputs, input_height, input_width, filter_height, - filter_width, stride_height, stride_width, padding_height, - padding_width, output_height, output_width, col.data()); + filter_width, stride_height, stride_width, padding_up, padding_left, + output_height, output_width, col.data()); } }; @@ -152,7 +161,8 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), input_height + 2 * padding_height, - input_width + 2 * padding_width, input_channels, filter_height, - filter_width, stride_height, stride_width, padding_height, - padding_width, output_height, output_width, im.data()); + num_kernels, col.data(), input_height + padding_up + padding_down, + input_width + padding_left + padding_right, input_channels, + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width, im.data()); } }; @@ -199,8 +219,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, int input_height, int input_width, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int output_height, int output_width, int row_begin, - int row_end) { + int output_height, int output_width) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < input_channels; @@ -208,8 +227,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; - int height_offset = - idy + (shid + row_begin) * stride_height - padding_height; + int height_offset = idy + shid * stride_height - padding_height; int im_offset = width_offset + height_offset * input_width + channelid * input_height * input_width; @@ -240,8 +258,8 @@ class Im2ColFunctor= down_pad) { - row_begin = 0; - } else { - row_begin = down_pad - up_pad; - } - row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) / - stride_height + - 1); - - int output_height = row_end - row_begin; // col.dims()[0]; + int output_height = col.dims()[0]; int output_width = col.dims()[1]; + PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / + stride_height + + 1 == + output_height); + PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / + stride_width 
+ + 1 == + output_width); + int block_dim_x = 0; int block_dim_y = 0; if (filter_height <= 4 && filter_width <= 4) { @@ -289,9 +303,8 @@ class Im2ColFunctor(context) .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, - filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width, row_begin, - row_end); + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width); } }; @@ -300,8 +313,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, int input_height, int input_width, int filter_height, int filter_width, int stride_height, int stride_width, int padding_height, int padding_width, - int output_height, int output_width, int row_begin, - int row_end) { + int output_height, int output_width) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < input_channels; @@ -309,8 +321,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { int width_offset = idx + swid * stride_width - padding_width; - int height_offset = - idy + (shid + row_begin) * stride_height - padding_height; + int height_offset = idy + shid * stride_height - padding_height; int im_offset = width_offset + height_offset * input_width + channelid * input_height * input_width; @@ -340,7 +351,8 @@ class Col2ImFunctor= down_pad) { - row_begin = 0; - } else { - row_begin = down_pad - up_pad; - } - row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) / - stride_height + - 1); - - int output_height = row_end - row_begin; // col.dims()[0]; + int output_height = col.dims()[0]; int output_width = col.dims()[1]; + PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / + stride_height + + 1 == + output_height); + PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / + stride_width + + 1 == + output_width); + int block_dim_x = 0; int block_dim_y = 0; if (filter_height <= 4 && filter_width <= 4) { @@ -388,9 +396,8 @@ class Col2ImFunctor(context) .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, - filter_height, filter_width, stride_height, stride_width, - padding_height, padding_width, output_height, output_width, row_begin, - row_end); + filter_height, filter_width, stride_height, stride_width, padding_up, + padding_left, output_height, output_width); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index 7b717e1603c94cd77c74cb0d86f1d23e2692f9d8..c736d4fa523c2af3e3dd7a11114d7f84021bc5c1 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -74,8 +74,8 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& im, framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width); + int stride_height, int stride_width, int padding_up, + int padding_down, int padding_left, int padding_right); }; template @@ -83,7 +83,8 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, framework::Tensor& im, const framework::Tensor& col, int stride_height, - int stride_width, int padding_height, int padding_width); + int stride_width, int padding_up, int padding_down, + int padding_left, int padding_right); }; 
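// NOTE(editor): illustrative sketch only, not part of this patch. The two
// functors above now take four independent paddings (up, down, left, right)
// in place of the old symmetric (padding_height, padding_width) pair. The
// hypothetical helper below shows the intended call shape by reproducing the
// old symmetric behaviour on top of the new interface.
template <typename Place, typename T>
inline void SymmetricIm2Col(const platform::DeviceContext& context,
                            const framework::Tensor& im, framework::Tensor& col,
                            int stride_height, int stride_width,
                            int padding_height, int padding_width) {
  Im2ColFunctor<ColFormat::kCFO, Place, T> im2col;
  // Pad the same amount on the two opposite sides of each axis, exactly as
  // the previous two-argument padding interface did.
  im2col(context, im, col, stride_height, stride_width, padding_height,
         padding_height, padding_width, padding_width);
}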
} // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 16b1396d37b2b5bfbb9b3d3dc265053e47d237c6..5763782c4edec87f44dabef2ccffe3097eeb2421 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -85,10 +85,10 @@ void testIm2col() { paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - im2col(*context, input, output_cfo, stride, stride, padding, padding); - im2col_ocf(*context, input, output_ocf, /*stride_height*/ stride, - /*stride_width*/ stride, /*up_pad*/ padding, - /*down_pad*/ padding); + im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, + padding); + im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding, + padding, padding); float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; @@ -131,7 +131,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im(*context, input, output_cfo, stride, stride, padding, padding); + col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, + padding); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { @@ -152,9 +153,8 @@ void testIm2col() { input.CopyFrom(input_tmp, *place, *context); } - col2im_ocf(*context, input, output_ocf, /*stride_height*/ stride, - /*stride_width*/ stride, /*up_pad*/ padding, - /*down_pad*/ padding); + col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, + padding, padding); if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..0febf8e3b70111d12f858cf6259a2801a42d9a90 --- /dev/null +++ b/paddle/operators/math/lstm_compute.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/math/lstm_compute.h" +#include "paddle/operators/math/detail/lstm_cpu_kernel.h" +#include "paddle/operators/math/detail/lstm_kernel.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct LstmUnitFunctor { + static void compute(const platform::DeviceContext& context, + LstmMetaValue value, int frame_size, int batch_size, + const std::string& gate_act, const std::string& cell_act, + const std::string& cand_act) { + for (int b = 0; b < batch_size; b++) { + detail::cpu_lstm_forward(detail::forward::lstm(), value, frame_size, + ActiveType(cand_act), ActiveType(gate_act), + ActiveType(cell_act)); + value.gateValue += frame_size * 4; + value.stateValue += frame_size; + value.stateActiveValue += frame_size; + value.outputValue += frame_size; + if (value.prevStateValue) { + value.prevStateValue += frame_size; + } + } + } +}; + +template +struct LstmUnitGradFunctor { + static void compute(const platform::DeviceContext& context, + LstmMetaValue value, LstmMetaGrad grad, + int frame_size, int batch_size, + const std::string& gate_act, const std::string& cell_act, + const std::string& cand_act) { + for (int b = 0; b < batch_size; b++) { + detail::cpu_lstm_backward(detail::backward::lstm(), value, grad, + frame_size, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); + + value.gateValue += frame_size * 4; + value.stateValue += frame_size; + value.stateActiveValue += frame_size; + value.outputValue += frame_size; + if (value.prevStateValue) { + value.prevStateValue += frame_size; + } + + grad.gateGrad += frame_size * 4; + grad.stateGrad += frame_size; + grad.stateActiveGrad += frame_size; + grad.outputGrad += frame_size; + if (grad.prevStateGrad) { + grad.prevStateGrad += frame_size; + } + } + } +}; + +template class LstmUnitFunctor; +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; +template class LstmUnitGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..b2122f2a5c08a6d9d53293833177f0ba2c3ab860 --- /dev/null +++ b/paddle/operators/math/lstm_compute.cu @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/math/detail/lstm_gpu_kernel.h" +#include "paddle/operators/math/detail/lstm_kernel.h" +#include "paddle/operators/math/lstm_compute.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct LstmUnitFunctor { + static void compute(const platform::DeviceContext& context, + LstmMetaValue value, int frame_size, int batch_size, + const std::string& gate_act, const std::string& cell_act, + const std::string& cand_act) { + detail::gpu_lstm_forward(context, detail::forward::lstm(), value, + frame_size, batch_size, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); + } +}; + +template +struct LstmUnitGradFunctor { + static void compute(const platform::DeviceContext& context, + LstmMetaValue value, LstmMetaGrad grad, + int frame_size, int batch_size, + const std::string& gate_act, const std::string& cell_act, + const std::string& cand_act) { + detail::gpu_lstm_backward(context, detail::backward::lstm(), value, grad, + frame_size, batch_size, ActiveType(cand_act), + ActiveType(gate_act), ActiveType(cell_act)); + } +}; + +template class LstmUnitFunctor; +template class LstmUnitFunctor; +template class LstmUnitGradFunctor; +template class LstmUnitGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..28d2c6fd3b0d8143da90c37f241072e37397f98b --- /dev/null +++ b/paddle/operators/math/lstm_compute.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace math { + +typedef enum { + HL_ACTIVATION_SIGMOID = 0, + HL_ACTIVATION_RELU = 1, + HL_ACTIVATION_TANH = 2, + HL_ACTIVATION_LINEAR = 3, + HL_ACTIVATION_END +} activation_mode_t; + +template +struct LstmMetaValue { + T *gateValue; + T *prevStateValue; + T *stateValue; + T *stateActiveValue; + T *outputValue; + T *checkIg; + T *checkFg; + T *checkOg; +}; + +template +struct LstmMetaGrad { + T *gateGrad; + T *prevStateGrad; + T *stateGrad; + T *stateActiveGrad; + T *outputGrad; + T *checkIgGrad; + T *checkFgGrad; + T *checkOgGrad; +}; + +inline activation_mode_t ActiveType(const std::string &type) { + if (type == "sigmoid") { + return HL_ACTIVATION_SIGMOID; + } else if (type == "relu") { + return HL_ACTIVATION_RELU; + } else if (type == "tanh") { + return HL_ACTIVATION_TANH; + } else if (type == "linear" || type == "identity" || type == "") { + return HL_ACTIVATION_LINEAR; + } else { + PADDLE_THROW("Do not support activation type."); + } +} + +template +class LstmUnitFunctor { + public: + static void compute(const platform::DeviceContext &context, + LstmMetaValue value, int frame_size, int batch_size, + const std::string &gate_act, const std::string &cell_act, + const std::string &cand_act); +}; + +template +class LstmUnitGradFunctor { + public: + static void compute(const platform::DeviceContext &context, + LstmMetaValue value, LstmMetaGrad grad, + int frame_size, int batch_size, + const std::string &gate_act, const std::string &cell_act, + const std::string &cand_act); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc new file mode 100644 index 0000000000000000000000000000000000000000..10c6e105b950b9d510e7a14828d72531e8eb0028 --- /dev/null +++ b/paddle/operators/math/sequence2batch.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class CopyMatrixRowsFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::LoDTensor& src, const size_t* index, + framework::LoDTensor& dst, bool is_src_index) { + auto src_dims = src.dims(); + auto dst_dims = dst.dims(); + PADDLE_ENFORCE_EQ(src_dims.size(), 2UL, + "The src must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL, + "The dst must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1], + "The width of src and dst must be same."); + auto height = dst_dims[0]; + auto width = dst_dims[1]; + auto* src_data = src.data(); + auto* dst_data = dst.data(); + for (int i = 0; i < height; ++i) { + if (is_src_index) { + memcpy(dst_data + i * width, src_data + index[i] * width, + width * sizeof(T)); + } else { + memcpy(dst_data + index[i] * width, src_data + i * width, + width * sizeof(T)); + } + } + } +}; + +template class CopyMatrixRowsFunctor; +template class CopyMatrixRowsFunctor; + +template class LoDTensor2BatchFunctor; +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensorFunctor; +template class Batch2LoDTensorFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu new file mode 100644 index 0000000000000000000000000000000000000000..4f349946785171e6c59b22163ba76791c7244f88 --- /dev/null +++ b/paddle/operators/math/sequence2batch.cu @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index, + int64_t height, int64_t width, + bool is_src_index) { + int idx = threadIdx.x; + int idy = threadIdx.y; + int id = blockIdx.x + idy * GridDimX; + while (id < height) { + int src_idx = is_src_index ? index[id] : id; + int dst_idx = is_src_index ? 
id : index[id]; + const T* src_data = src + src_idx * width; + T* dst_data = dst + dst_idx * width; + for (int i = idx; i < width; i += BlockDimX) { + dst_data[i] = src_data[i]; + } + id += BlockDimY * GridDimX; + } +} + +template +class CopyMatrixRowsFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::LoDTensor& src, const size_t* index, + framework::LoDTensor& dst, bool is_src_index) { + auto src_dims = src.dims(); + auto dst_dims = dst.dims(); + PADDLE_ENFORCE_EQ(src_dims.size(), 2, + "The src must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(dst_dims.size(), 2, + "The dst must be matrix with rank 2."); + PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1], + "The width of src and dst must be same."); + auto height = dst_dims[0]; + auto width = dst_dims[1]; + auto* src_data = src.data(); + auto* dst_data = dst.data(); + + dim3 threads(128, 8); + dim3 grid(8, 1); + auto stream = + reinterpret_cast(context).stream(); + CopyMatrixRowsKernel<<>>( + src_data, dst_data, index, height, width, is_src_index); + } +}; + +template class CopyMatrixRowsFunctor; +template class CopyMatrixRowsFunctor; + +template class LoDTensor2BatchFunctor; +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensorFunctor; +template class Batch2LoDTensorFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h new file mode 100644 index 0000000000000000000000000000000000000000..03cd018e46e90c9bbe689c9686377e0e998ee513 --- /dev/null +++ b/paddle/operators/math/sequence2batch.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class CopyMatrixRowsFunctor { + public: + // If is_src_index is true, + // copy the indexed rows of input src to the output dst. + // If is_src_index is false, + // copy the input src to the indexed rows of output dst. + // The indexed rows are based on the input index. + void operator()(const platform::DeviceContext& context, + const framework::LoDTensor& src, const size_t* index, + framework::LoDTensor& dst, bool is_src_index); +}; + +template +class LoDTensor2BatchFunctor { + // Calculate the length of each sequence and + // sort sequence index by the length. 
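// NOTE(editor): clarification, not in the original patch. Each SeqInfo
// entry below is (start offset in the LoDTensor, length, original sequence
// index), and the list is sorted by length in descending order, so
// seq_info[0] always refers to the longest sequence (s1 in this example).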
+ // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} + // + struct SeqInfo { + SeqInfo(int start, int length, int seq_idx) + : start(start), length(length), seq_idx(seq_idx) {} + int start; + int length; + int seq_idx; + }; + + public: + void operator()(const platform::DeviceContext& context, + const framework::LoDTensor& lod_tensor, + framework::LoDTensor& batch, bool is_reverse) const { + auto lods = lod_tensor.lod(); + PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); + auto lod = lods[0]; + + std::vector seq_info; + for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { + int length = lod[seq_id + 1] - lod[seq_id]; + seq_info.emplace_back(lod[seq_id], length, seq_id); + } + + std::sort(seq_info.begin(), seq_info.end(), + [](SeqInfo a, SeqInfo b) { return a.length > b.length; }); + + // calculate the start position of each batch + // (numBatch equals the maxLength of sequences) + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // num_batch = 5, + // batchIndex = {b0, b1, b2, b3, b4} + // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 + // batch_start_positions[6] = {0, 3, 6, 9, 11, 12} + // batch_start_positions[1] = len(b0) + // batch_start_positions[2] = len(b0) + len(b1) + // batch_start_positions[3] = len(b0) + len(b1) + len(b2) + // ... + // seq2batch_idx[12] = {4, 0, 9, + // 5, 1, 10, + // 6, 2, 11, + // 7, 3, + // 8} + // The batch number represents batch size after rearranging the + // input LoDTensor. It is also the maximum length of input sequence. + + paddle::framework::LoD batch_lods; + batch_lods.emplace_back(std::vector{0}); + batch_lods.emplace_back(std::vector{0}); + + // batch_lods[0] is the start positions for batch LoDTensor + int num_batch = seq_info[0].length; + batch_lods[0].resize(static_cast(num_batch + 1)); + // batch_lods[1] is the raw index in the input LoDTensor + auto dims = lod_tensor.dims(); + batch_lods[1].resize(static_cast(dims[0])); + + size_t* batch_starts = batch_lods[0].data(); + size_t* seq2batch_idx = batch_lods[1].data(); + batch_starts[0] = 0; + for (size_t n = 0; n < num_batch; n++) { + auto batch_id = static_cast(batch_starts[n]); + for (size_t i = 0; i < seq_info.size(); ++i) { + size_t seq_len = seq_info[i].length; + int start = seq_info[i].start; + if (n < seq_len) { + seq2batch_idx[batch_id] = + is_reverse ? 
start + seq_len - 1 - n : start + n; + batch_id++; + } else { + break; + } + } + batch_starts[n + 1] = static_cast(batch_id); + } + batch.set_lod(batch_lods); + + CopyMatrixRowsFunctor to_batch; + to_batch(context, lod_tensor, seq2batch_idx, batch, true); + } +}; + +template +class Batch2LoDTensorFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::LoDTensor& batch, + framework::LoDTensor& lod_tensor) const { + auto in_lod = batch.lod(); + PADDLE_ENFORCE_EQ(in_lod.size(), 2UL, + "The LoD size of input `batch` should be 2."); + auto out_lod = lod_tensor.lod()[0]; + auto num = out_lod[out_lod.size() - 1]; + PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]); + PADDLE_ENFORCE_EQ(num, in_lod[1].size()); + PADDLE_ENFORCE_EQ(num, batch.dims()[0]); + CopyMatrixRowsFunctor to_seq; + size_t* index = in_lod[1].data(); + to_seq(context, batch, index, lod_tensor, false); + } +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence_project.h b/paddle/operators/math/sequence_project.h index a2ab86f790df313b04866b259883d716a4ce5502..53b61ce16c99f27bb0612f5c6df4ff1c1276356a 100644 --- a/paddle/operators/math/sequence_project.h +++ b/paddle/operators/math/sequence_project.h @@ -133,8 +133,8 @@ class SequenceProjectFunctor { in_t.Resize(framework::make_ddim(input_shape)); im2col_ocf(context, in_t, out_t, - /*stride_height*/ context_stride, /*stride_width*/ 0, up_pad, - down_pad); + /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, + down_pad, 0, 0); } if (padding_trainable) { diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 065800f250d8b35a626060bac271e1bce6bb784b..b9b9cd7ca05b4373c27f672cc1ee20daab6827a8 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -49,7 +49,19 @@ class MulOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( x_mat_dims[1], y_mat_dims[0], "First matrix's width must be equal with second matrix's height."); - ctx->SetOutputDim("Out", {x_mat_dims[0], y_mat_dims[1]}); + std::vector output_dims; + output_dims.reserve( + static_cast(x_num_col_dims + y_dims.size() - y_num_col_dims)); + + for (int i = 0; i < x_num_col_dims; ++i) { + output_dims.push_back(x_dims[i]); + } + + for (int i = y_num_col_dims; i < y_dims.size(); ++i) { + output_dims.push_back(y_dims[i]); + } + + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -109,15 +121,6 @@ class MulOpGrad : public framework::OperatorWithKernel { auto y_mat_dims = framework::flatten_to_2d( y_dims, ctx->Attrs().Get("y_num_col_dims")); - PADDLE_ENFORCE_EQ( - x_mat_dims[0], out_dims[0], - "The first dimension of Out@GRAD must equal to the first dimension of " - "the first operand."); - PADDLE_ENFORCE_EQ( - y_mat_dims[1], out_dims[1], - "The second dimension of Out@GRAD must equal to the second " - "dimension of the second operand."); - auto x_grad_name = framework::GradVarName("X"); auto y_grad_name = framework::GradVarName("Y"); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 3f3e77595b701d428a728fc4727dd3ff4abee45f..bd1bdb4f81b88256822d663fe42ad314338c91ff 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -46,8 +46,15 @@ class MulKernel : public framework::OpKernel { : *y; z->mutable_data(context.GetPlace()); + auto z_dim = z->dims(); + if (z_dim.size() != 2) { + z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } math::matmul(context.device_context(), x_matrix, false, 
y_matrix, false, 1, z, 0); + if (z_dim.size() != 2) { + z->Resize(z_dim); + } } }; @@ -67,6 +74,11 @@ class MulGradKernel : public framework::OpKernel { : *y; const Tensor* dout = ctx.Input(framework::GradVarName("Out")); + Tensor dout_mat; + dout_mat.ShareDataWith(*dout); + dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0], + framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]}); + Tensor* dx = ctx.Output(framework::GradVarName("X")); Tensor* dy = ctx.Output(framework::GradVarName("Y")); if (dx) { @@ -74,9 +86,10 @@ class MulGradKernel : public framework::OpKernel { Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix(*dx, x_num_col_dims) : *dx; + // dx = dout * y'. dx: M x K, dout : M x N, y : K x N - math::matmul(ctx.device_context(), *dout, false, y_matrix, true, - 1, &dx_matrix, 0); + math::matmul(ctx.device_context(), dout_mat, false, y_matrix, + true, 1, &dx_matrix, 0); } if (dy) { dy->mutable_data(ctx.GetPlace()); @@ -84,8 +97,8 @@ class MulGradKernel : public framework::OpKernel { ? framework::ReshapeToMatrix(*dy, y_num_col_dims) : *dy; // dy = x' * dout. dy K x N, dout : M x N, x : M x K - math::matmul(ctx.device_context(), x_matrix, true, *dout, false, - 1, &dy_matrix, 0); + math::matmul(ctx.device_context(), x_matrix, true, dout_mat, + false, 1, &dy_matrix, 0); } } }; diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc index 46f66a1370a35593d1911fc9b3ce76beb38c0956..0599daa7688a5658ebea8902c4e15e63570539fb 100644 --- a/paddle/operators/reduce_op.cc +++ b/paddle/operators/reduce_op.cc @@ -160,66 +160,6 @@ class ReduceMinOpMaker : public ReduceOpMaker { } }; -class NormOp : public NetOp { - public: - NormOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : NetOp(type, inputs, outputs, attrs) { - PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName, - "Input(X) of NormOp should not be null."); - PADDLE_ENFORCE_NE(Output("AbsOut"), framework::kEmptyVarName, - "Output(AbsOut) of NormOp should not be null."); - PADDLE_ENFORCE_NE(Output("PowOut"), framework::kEmptyVarName, - "Output(PowOut) of NormOp should not be null."); - PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName, - "Output(SumOut) of NormOp should not be null."); - PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, - "Output(Out) of NormOp should not be null."); - auto dim = Attr("dim"); - auto keep_dim = Attr("keep_dim"); - auto p = Attr("p"); - PADDLE_ENFORCE_GT(p, 0, "Order of the norm should be positive."); - AppendOp(framework::OpRegistry::CreateOp("abs", {{"X", {Input("X")}}}, - {{"Y", {Output("AbsOut")}}}, {})); - AppendOp(framework::OpRegistry::CreateOp("pow", {{"X", {Output("AbsOut")}}}, - {{"Y", {Output("PowOut")}}}, - {{"factor", p}})); - framework::AttributeMap sum_attr; - sum_attr["dim"] = dim; - sum_attr["keep_dim"] = keep_dim; - AppendOp(framework::OpRegistry::CreateOp( - "reduce_sum", {{"X", {Output("PowOut")}}}, - {{"Out", {Output("SumOut")}}}, sum_attr)); - AppendOp(framework::OpRegistry::CreateOp( - "pow", {{"X", {Output("SumOut")}}}, {{"Y", {Output("Out")}}}, - {{"factor", static_cast(1. 
/ p)}})); - CompleteAddOp(false); - } -}; - -class NormOpMaker : public ReduceOpMaker { - public: - NormOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) - : ReduceOpMaker(proto, op_checker) { - AddOutput("AbsOut", - "(Tensor) The intermediate output of Norm operator, " - "saving the absolute value of the input tensor X.") - .AsIntermediate(); - AddOutput("PowOut", - "(Tensor) The intermediate output of Norm operator, " - "saving the p-th power of the output tensor AbsOut.") - .AsIntermediate(); - AddOutput("SumOut", - "(Tensor) the intermediate output of Norm operator, " - "saving the sum of PowOut reduced on the given dimension.") - .AsIntermediate(); - AddAttr("p", "(float, default 2) The order of Norm.").SetDefault(2); - SetComment("Norm", "vector p-norm"); - AddComment(comment_); - } -}; - } // namespace operators } // namespace paddle @@ -237,8 +177,6 @@ REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad, REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, reduce_min_grad, ops::ReduceGradOp); -REGISTER_OP_WITHOUT_GRADIENT(norm, ops::NormOp, ops::NormOpMaker); - #define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \ REGISTER_OP_CPU_KERNEL( \ reduce_type, \ diff --git a/paddle/operators/sequence_conv_op.h b/paddle/operators/sequence_conv_op.h index b6ae12f6bb0d4f6650db29b9a358ac43b0052945..4735fa4a5fa249e28892afd46960c2a81ab42f07 100644 --- a/paddle/operators/sequence_conv_op.h +++ b/paddle/operators/sequence_conv_op.h @@ -167,8 +167,8 @@ class SequenceConvGradKernel : public framework::OpKernel { in_t.Resize(framework::make_ddim(input_shape)); col2im_ocf(context.device_context(), in_t, col_t, - /*stride_height*/ context_stride, /*stride_width*/ 0, - up_pad, down_pad); + /*stride_height*/ context_stride, /*stride_width*/ 1, + up_pad, down_pad, 0, 0); } col_t.Resize(framework::make_ddim( {sequence_height, context_length * sequence_width})); diff --git a/paddle/operators/smooth_l1_loss_op.cc b/paddle/operators/smooth_l1_loss_op.cc index a4f0f37764667c43d48c6aa7646d61cdf4f3fd2d..758481943d463f22eb6c6e0be9a99ad99161da5b 100644 --- a/paddle/operators/smooth_l1_loss_op.cc +++ b/paddle/operators/smooth_l1_loss_op.cc @@ -62,11 +62,13 @@ class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("InsideWeight", "Optional input tensor of smooth l1 loss op with the same shape " "as X. If provided, the result of (X - Y) will be multiplied " - "by this tensor element by element."); + "by this tensor element by element.") + .AsDispensable(); AddInput("OutsideWeight", "Optional input of smooth l1 loss op with the same shape as X. " 
"If provided, the output smooth l1 loss will be multiplied by " - "this tensor element by element."); + "this tensor element by element.") + .AsDispensable(); AddOutput("Diff", "Intermediate variable to cache InsideWeight*(X-Y).") .AsIntermediate(); AddOutput("Out", "Smooth l1 loss."); diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index f244ddc51fab3a6a82ffe517e35a97bc77f61b3e..39b53948e3cc58ff1d0ab481143b066b1a2fae16 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -65,7 +65,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { protected: framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return static_cast(Attr("data_type")); + return static_cast(ctx.Attr("data_type")); } }; diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index daf519b91d623d4369774dc4e37dcb7b1733666b..eb850b658583f2256629d63fdb64248dbf249937 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -25,3 +25,4 @@ nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_ nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context) +nv_test(nccl_test SRCS nccl_test.cu DEPS dynload_cuda gpu_info device_context) diff --git a/paddle/platform/dynload/CMakeLists.txt b/paddle/platform/dynload/CMakeLists.txt index ceb66f84b6b01892cbaf61c79a47ae60d2589164..bb3fec1be9e811c26cc6851314e960e96fc366b3 100644 --- a/paddle/platform/dynload/CMakeLists.txt +++ b/paddle/platform/dynload/CMakeLists.txt @@ -1,2 +1,3 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags) -nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc DEPS dynamic_loader) +nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc + DEPS dynamic_loader nccl) diff --git a/paddle/platform/dynload/dynamic_loader.cc b/paddle/platform/dynload/dynamic_loader.cc index ae9a0a982c73de05821579d22b7f9ad99f24a92b..6feba42c0d9d618d27da12e6a6752058b296995e 100644 --- a/paddle/platform/dynload/dynamic_loader.cc +++ b/paddle/platform/dynload/dynamic_loader.cc @@ -35,6 +35,11 @@ DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so."); DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so."); +DEFINE_string(nccl_dir, "", + "Specify path for loading nccl library, such as libcublas, " + "libcurand. For instance, /usr/local/cuda/lib64. 
If default, " + "dlopen will search cuda from LD_LIBRARY_PATH"); + namespace paddle { namespace platform { namespace dynload { @@ -157,6 +162,14 @@ void GetLapackDsoHandle(void** dso_handle) { #endif } +void GetNCCLDsoHandle(void** dso_handle) { +#if defined(__APPLE__) || defined(__OSX__) + GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.dylib", dso_handle); +#else + GetDsoHandleFromSearchPath(FLAGS_nccl_dir, "libnccl.so", dso_handle); +#endif +} + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/platform/dynload/dynamic_loader.h b/paddle/platform/dynload/dynamic_loader.h index a99b05443feb909f10b2c56f4d8bdf3c6fa11e3f..c0e5452e5ae723ec314ebafde86a6ff63980be00 100644 --- a/paddle/platform/dynload/dynamic_loader.h +++ b/paddle/platform/dynload/dynamic_loader.h @@ -58,6 +58,14 @@ void GetWarpCTCDsoHandle(void** dso_handle); */ void GetLapackDsoHandle(void** dso_handle); +/** + * @brief load the DSO of NVIDIA nccl + * + * @param **dso_handle dso handler + * + */ +void GetNCCLDsoHandle(void** dso_handle); + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/platform/dynload/nccl.cc b/paddle/platform/dynload/nccl.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f92b8d94d56047b7d3fb43b15e3c06575c8d57b --- /dev/null +++ b/paddle/platform/dynload/nccl.cc @@ -0,0 +1,30 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/platform/dynload/nccl.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag nccl_dso_flag; +void *nccl_dso_handle; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +NCCL_RAND_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/dynload/nccl.h b/paddle/platform/dynload/nccl.h new file mode 100644 index 0000000000000000000000000000000000000000..0618c7414fd1235e81ee9d92a3a07b53d6ad6ebc --- /dev/null +++ b/paddle/platform/dynload/nccl.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include "paddle/platform/dynload/dynamic_loader.h" + +namespace paddle { +namespace platform { +namespace dynload { + +extern std::once_flag nccl_dso_flag; +extern void* nccl_dso_handle; + +#ifdef PADDLE_USE_DSO +#define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + using nccl_func = decltype(__name(args...)) (*)(Args...); \ + std::call_once(nccl_dso_flag, \ + paddle::platform::dynload::GetNCCLDsoHandle, \ + &nccl_dso_handle); \ + void* p_##__name = dlsym(nccl_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name +#else +#define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + ncclResult_t operator()(Args... args) { \ + return __name(args...); \ + } \ + }; \ + extern DynLoad__##__name __name +#endif + +#define NCCL_RAND_ROUTINE_EACH(__macro) \ + __macro(ncclCommInitAll); \ + __macro(ncclGetUniqueId); \ + __macro(ncclCommInitRank); \ + __macro(ncclCommDestroy); \ + __macro(ncclCommCount); \ + __macro(ncclCommCuDevice); \ + __macro(ncclCommUserRank); \ + __macro(ncclAllReduce); \ + __macro(ncclBcast); \ + __macro(ncclAllGather); \ + __macro(ncclReduce); \ + __macro(ncclGetErrorString); + +NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP) + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index cd906c3fa9375cd6edaed0377a596771e25043d4..bfe708748a62ff9ac5d151bc652142e1f4925c83 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -29,11 +29,14 @@ limitations under the License. */ #include // for __cxa_demangle #endif +#include + #ifdef PADDLE_WITH_CUDA #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" +#include "paddle/platform/dynload/nccl.h" #include #include @@ -172,6 +175,17 @@ inline typename std::enable_if::type throw_on_error( throw std::runtime_error(err + string::Sprintf(args...)); } +template +inline typename std::enable_if::type throw_on_error( + ncclResult_t stat, const Args&... args) { + if (stat == ncclSuccess) { + return; + } else { + throw std::runtime_error(platform::dynload::ncclGetErrorString(stat) + + string::Sprintf(args...)); + } +} + #endif // PADDLE_ONLY_CPU template diff --git a/paddle/platform/nccl_test.cu b/paddle/platform/nccl_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..ab8b96f7263aed83407866fedf9e529ce0affe3f --- /dev/null +++ b/paddle/platform/nccl_test.cu @@ -0,0 +1,139 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/dynload/nccl.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/gpu_info.h" + +#include +#include +#include + +static int dev_count = 0; + +namespace paddle { +namespace platform { + +TEST(NCCL, init) { + std::vector comms; + comms.resize(dev_count); + + auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr); + PADDLE_ENFORCE(status); + for (int i = 0; i < dev_count; ++i) { + dynload::ncclCommDestroy(comms[i]); + } +} + +template +struct PerThreadData { + thrust::device_vector send_buff; + thrust::device_vector recv_buff; + CUDADeviceContext dev_ctx; + + T* SendBuff() { return thrust::raw_pointer_cast(send_buff.data()); } + + T* RecvBuff() { return thrust::raw_pointer_cast(recv_buff.data()); } + + PerThreadData(int gpu_id, size_t size) : dev_ctx(GPUPlace(gpu_id)) { + send_buff.resize(size); + for (size_t i = 0; i < size; ++i) { + send_buff[i] = static_cast(i); + } + recv_buff.resize(size); + } +}; + +static constexpr int ELEM_COUNT = 10000; + +TEST(NCCL, all_reduce) { + std::vector comms; + comms.resize(dev_count); + VLOG(1) << "Initializing ncclComm"; + auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr); + PADDLE_ENFORCE(status); + VLOG(1) << "ncclComm initialized"; + VLOG(1) << "Creating thread data"; + std::vector>> data; + data.reserve(dev_count); + for (int i = 0; i < dev_count; ++i) { + VLOG(1) << "Creating thread data for device " << i; + SetDeviceId(i); + data.emplace_back(new PerThreadData(i, ELEM_COUNT)); + } + VLOG(1) << "Thread data created"; + + VLOG(1) << "Check send_buf data"; + for (int i = 0; i < dev_count; ++i) { + VLOG(1) << "Check on device " << i; + SetDeviceId(i); + thrust::host_vector tmp = data[i]->send_buff; + for (size_t j = 0; j < tmp.size(); ++j) { + ASSERT_NEAR(static_cast(j), tmp[j], 1e-5); + } + } + + VLOG(1) << "Invoking ncclAllReduce"; + + for (int i = 0; i < dev_count; ++i) { + VLOG(1) << "Invoking ncclAllReduce with device " << i; + SetDeviceId(i); + PADDLE_ENFORCE(dynload::ncclAllReduce( + data[i]->SendBuff(), data[i]->RecvBuff(), ELEM_COUNT, ncclDouble, + ncclSum, comms[i], data[i]->dev_ctx.stream())); + VLOG(1) << "Invoked ncclAllReduce for device " << i; + } + + VLOG(1) << "Invoked ncclAllReduce"; + + VLOG(1) << "Sync devices"; + for (int i = 0; i < dev_count; ++i) { + VLOG(1) << "Sync device " << i; + SetDeviceId(i); + data[i]->dev_ctx.Wait(); + } + VLOG(1) << "device synced"; + + for (int i = 0; i < dev_count; ++i) { + SetDeviceId(i); + VLOG(1) << "Checking vector on device " << i; + thrust::host_vector tmp = data[i]->recv_buff; + for (size_t j = 0; j < tmp.size(); ++j) { + auto elem = static_cast(j); + elem *= dev_count; + ASSERT_NEAR(tmp[j], elem, 1e-4); + } + } + + for (int i = 0; i < dev_count; ++i) { + dynload::ncclCommDestroy(comms[i]); + } +} +} // namespace platform +} // namespace paddle + +int main(int argc, char** argv) { + dev_count = paddle::platform::GetCUDADeviceCount(); + if (dev_count <= 1) { + LOG(WARNING) + << "Cannot test multi-gpu nccl, because the CUDA device count is " + << dev_count; + return 0; + } + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 0efc6932349a5b3ad295d195a16737a642e18943..5370360a7de26e409a1545182a12d3df1f37658b 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -35,6 +35,7 @@ struct GPUPlace { GPUPlace() : GPUPlace(0) 
{} explicit GPUPlace(int d) : device(d) {} + inline int GetDeviceId() const { return device; } // needed for variant equality comparison inline bool operator==(const GPUPlace &o) const { return device == o.device; } inline bool operator!=(const GPUPlace &o) const { return !(*this == o); } diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 46c24e2cd53c068a25e1a5c8c6df600c3111e20a..d7cd738828a10b431370c92026b89d62add1275e 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -4,3 +4,5 @@ if(WITH_PYTHON) DEPS pybind python backward proto_desc tensor_array paddle_memory executor ${GLOB_OP_LIB}) endif(WITH_PYTHON) + +cc_binary(print_operators_doc SRCS print_operators_doc.cc DEPS ${GLOB_OP_LIB} tensor_array) diff --git a/paddle/pybind/print_operators_doc.cc b/paddle/pybind/print_operators_doc.cc new file mode 100644 index 0000000000000000000000000000000000000000..24f2a9383f7a069f1a8c7ed2bf3da46720470efa --- /dev/null +++ b/paddle/pybind/print_operators_doc.cc @@ -0,0 +1,132 @@ +#include +#include // std::stringstream +#include + +#include "paddle/framework/op_info.h" +#include "paddle/framework/op_registry.h" +#include "paddle/pybind/pybind.h" + +std::string Escape(const std::string& s) { + std::string r; + for (size_t i = 0; i < s.size(); i++) { + switch (s[i]) { + case '\"': + r += "\\\""; + break; + case '\\': + r += "\\\\"; + break; + case '\n': + r += "\\n"; + break; + case '\t': + r += "\\t"; + case '\r': + break; + default: + r += s[i]; + break; + } + } + return r; +} + +std::string AttrType(paddle::framework::AttrType at) { + switch (at) { + case paddle::framework::INT: + return "int"; + case paddle::framework::FLOAT: + return "float"; + case paddle::framework::STRING: + return "string"; + case paddle::framework::BOOLEAN: + return "bool"; + case paddle::framework::INTS: + return "int array"; + case paddle::framework::FLOATS: + return "float array"; + case paddle::framework::STRINGS: + return "string array"; + case paddle::framework::BOOLEANS: + return "bool array"; + case paddle::framework::BLOCK: + return "block id"; + } + return "UNKNOWN"; // not possible +} + +void PrintVar(const paddle::framework::OpProto::Var& v, std::stringstream& ss) { + ss << " { " + << "\n" + << " \"name\" : \"" << Escape(v.name()) << "\",\n" + << " \"comment\" : \"" << Escape(v.comment()) << "\",\n" + << " \"duplicable\" : " << v.duplicable() << ",\n" + << " \"intermediate\" : " << v.intermediate() << "\n" + << " },"; +} + +void PrintAttr(const paddle::framework::OpProto::Attr& a, + std::stringstream& ss) { + ss << " { " + << "\n" + << " \"name\" : \"" << Escape(a.name()) << "\",\n" + << " \"type\" : \"" << AttrType(a.type()) << "\",\n" + << " \"comment\" : \"" << Escape(a.comment()) << "\",\n" + << " \"generated\" : " << a.generated() << "\n" + << " },"; +} + +void PrintOpProto(const std::string& type, + const paddle::framework::OpInfo& opinfo, + std::stringstream& ss) { + std::cerr << "Processing " << type << "\n"; + + const paddle::framework::OpProto* p = opinfo.proto_; + if (p == nullptr) { + return; // It is possible that an operator doesn't have OpProto. 
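// NOTE(editor): clarification, not in the original patch. OpInfo entries
// registered without a proto/maker (gradient operators, for instance) hit
// this branch and are simply omitted from the generated document.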
+ } + + ss << "{\n" + << " \"type\" : \"" << Escape(p->type()) << "\",\n" + << " \"comment\" : \"" << Escape(p->comment()) << "\",\n"; + + ss << " \"inputs\" : [ " + << "\n"; + for (int i = 0; i < p->inputs_size(); i++) { + PrintVar(p->inputs(i), ss); + } + ss.seekp(-1, ss.cur); // remove the trailing comma + ss << " ], " + << "\n"; + + ss << " \"outputs\" : [ " + << "\n"; + for (int i = 0; i < p->outputs_size(); i++) { + PrintVar(p->outputs(i), ss); + } + ss.seekp(-1, ss.cur); // remove the trailing comma + ss << " ], " + << "\n"; + + ss << " \"attrs\" : [ " + << "\n"; + for (int i = 0; i < p->attrs_size(); i++) { + PrintAttr(p->attrs(i), ss); + } + ss.seekp(-1, ss.cur); // remove the trailing comma + ss << " ] " + << "\n"; + + ss << "},"; +} + +int main() { + std::stringstream ss; + ss << "[\n"; + for (auto& iter : paddle::framework::OpInfoMap::Instance().map()) { + PrintOpProto(iter.first, iter.second, ss); + } + ss.seekp(-1, ss.cur); // remove the trailing comma + ss << "]\n"; + std::cout << ss.str(); +} diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 405ac544e10f19a33399a649f76699fefc3d49b9..5d43ecea11202fa3f9e21fbde907c9d1d7dd4025 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -257,6 +257,7 @@ void BindOpDesc(py::module &m) { .def("block_attr", &OpDescBind::GetBlockAttr) .def("check_attrs", &OpDescBind::CheckAttrs) .def("infer_shape", &OpDescBind::InferShape) + .def("infer_var_type", &OpDescBind::InferVarType) .def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes { const OpDesc *desc = op_desc.Proto(); PADDLE_ENFORCE(desc->IsInitialized(), diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 26b793a4bbf5df7a2635838a6c6a8264ca8ebb67..b6e44fdbad6e2817e3077901f58177adc4bb0c71 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -225,15 +225,16 @@ All parameter, weight, gradient are variables in Paddle. //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { std::vector ret_values; - - OpInfoMap::Instance().IterAllInfo([&ret_values](const std::string &type, - const OpInfo &info) { - if (!info.HasOpProtoAndChecker()) return; - std::string str; - PADDLE_ENFORCE(info.Proto().SerializeToString(&str), - "Serialize OpProto Error. This could be a bug of Paddle."); - ret_values.emplace_back(str); - }); + for (auto &iter : OpInfoMap::Instance().map()) { + auto &info = iter.second; + if (info.HasOpProtoAndChecker()) { + std::string str; + PADDLE_ENFORCE( + info.Proto().SerializeToString(&str), + "Serialize OpProto Error. 
This could be a bug of Paddle."); + ret_values.emplace_back(str); + } + } return ret_values; }); m.def_submodule( diff --git a/paddle/trainer/tests/sample_trainer_config_branch_net.conf b/paddle/trainer/tests/sample_trainer_config_branch_net.conf index c2594bc13c250a877a7b8a77e11405671c4d8907..a073708a184d6392a4eead69272e684013f1c751 100644 --- a/paddle/trainer/tests/sample_trainer_config_branch_net.conf +++ b/paddle/trainer/tests/sample_trainer_config_branch_net.conf @@ -17,7 +17,7 @@ from paddle.trainer_config_helpers import * ################################### Data Configuration ################################### TrainData(ProtoData(files = "trainer/tests/mnist.list")) ################################### Algorithm Configuration ################################### -settings(batch_size = 256, +settings(batch_size = 128, learning_method = MomentumOptimizer(momentum=0.5, sparse=False)) ################################### Network Configuration ################################### data = data_layer(name ="input", size=784) @@ -44,10 +44,11 @@ a2 = img_conv_layer(input=tmp, shared_biases=True, act=ReluActivation()) -tmp = concat_layer(input=[a1, a2]) +tmp = addto_layer(input=[a1, a2], + act=ReluActivation(), + bias_attr=False) tmp = img_pool_layer(input=tmp, - num_channels=64, pool_size=3, stride=2, padding=1, @@ -55,35 +56,34 @@ tmp = img_pool_layer(input=tmp, b1 = img_conv_layer(input=tmp, filter_size=3, - num_filters=64, + num_filters=32, padding=1, shared_biases=True, act=ReluActivation()) b1 = img_pool_layer(input=b1, pool_size=3, - stride=1, - padding=1, + stride=2, + padding=0, pool_type=MaxPooling()) b2 = img_conv_layer(input=tmp, - filter_size=5, + filter_size=3, num_filters=64, - padding=2, + padding=1, shared_biases=True, act=ReluActivation()) b2 = img_pool_layer(input=b2, pool_size=5, - stride=1, - padding=2, + stride=2, + padding=1, pool_type=MaxPooling()) -tmp = addto_layer(input=[b1, b2], - act=ReluActivation(), - bias_attr=False) +tmp = concat_layer(input=[b1, b2]) tmp = img_pool_layer(input=tmp, + num_channels=96, pool_size=3, stride=2, padding=1, diff --git a/paddle/trainer/tests/sample_trainer_config_simple_net.conf b/paddle/trainer/tests/sample_trainer_config_simple_net.conf index 77f78161535c49da4ef7fc1563cff58c021aecef..2ba71884d0953dc721808732fde12e695c6a757d 100644 --- a/paddle/trainer/tests/sample_trainer_config_simple_net.conf +++ b/paddle/trainer/tests/sample_trainer_config_simple_net.conf @@ -17,7 +17,7 @@ from paddle.trainer_config_helpers import * ################################### Data Configuration ################################### TrainData(ProtoData(files = "trainer/tests/mnist.list")) ################################### Algorithm Configuration ################################### -settings(batch_size = 1000, +settings(batch_size = 128, learning_method = MomentumOptimizer(momentum=0.5, sparse=False)) ################################### Network Configuration ################################### data = data_layer(name ="input", size=784) diff --git a/python/paddle/trainer/PyDataProvider2.py b/python/paddle/trainer/PyDataProvider2.py index 248da4ae8d1fb24652625ae8fc9ef314a028b912..05635833bf1645f78f5ba15caee3e9b8da9f5544 100644 --- a/python/paddle/trainer/PyDataProvider2.py +++ b/python/paddle/trainer/PyDataProvider2.py @@ -175,7 +175,7 @@ def index_slot(value_range, seq_type=SequenceType.NO_SEQUENCE): dense_vector = dense_slot sparse_binary_vector = sparse_non_value_slot -sparse_vector = sparse_value_slot +sparse_float_vector = sparse_value_slot integer_value 
= index_slot # dense_array can be used for variable-length input feature. @@ -216,7 +216,7 @@ def sparse_binary_vector_sub_sequence(dim): return sparse_binary_vector(dim, seq_type=SequenceType.SUB_SEQUENCE) -def sparse_vector_sequence(dim): +def sparse_float_vector_sequence(dim): """ Data type of a sequence of sparse vector, which most elements are zero, others could be any float value. @@ -226,11 +226,11 @@ def sparse_vector_sequence(dim): :return: An input type object :rtype: InputType """ - return sparse_vector(dim, seq_type=SequenceType.SEQUENCE) + return sparse_float_vector(dim, seq_type=SequenceType.SEQUENCE) -def sparse_vector_sub_sequence(dim): - return sparse_vector(dim, seq_type=SequenceType.SUB_SEQUENCE) +def sparse_float_vector_sub_sequence(dim): + return sparse_float_vector(dim, seq_type=SequenceType.SUB_SEQUENCE) def integer_value_sequence(value_range): diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 120c9d11a5ebaa72b94590e596fd4362c552f979..3821d075cba5d39b5808a39093b8570d9302b667 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -1457,11 +1457,13 @@ def dot_product_attention(encoded_sequence, expanded = expand_layer( input=transformed_state, - expanded_as=encoded_sequence, + expand_as=encoded_sequence, name='%s_expand' % name) m = linear_comb_layer( - weights=expanded, vectors=encoded_sequence, name='%s_dot-product') + weights=expanded, + vectors=encoded_sequence, + name='%s_dot-product' % name) attention_weight = fc_layer( input=m, diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index 03a3dacf25c2ad5514e914d2f6e9637493ba80f4..40b9008d67b4e42093b9b9cbdecd1dbff4150b41 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -53,8 +53,8 @@ class Variable(object): if is_new_var: self.desc.set_data_type(dtype) else: - old_dtype = self.data_type() - if dtype != old_shape: + old_dtype = self.data_type + if dtype != old_dtype: raise ValueError("Variable {0} has been created before. " "The previous data type is {1}; the new " "data type is {2}. They are not " @@ -113,6 +113,10 @@ class Variable(object): def lod_level(self): return self.desc.lod_level() + @property + def type(self): + return self.desc.type() + @staticmethod def _unique_var_name_(): uid = core.unique_integer() # unique during whole process. @@ -192,31 +196,32 @@ class Operator(object): self.desc.set_type(type) proto = OpProtoHolder.instance().get_op_proto(type) - if inputs is not None: - given = set() - need = set() - for n in inputs: - given.add(n) - for m in proto.inputs: - need.add(m.name) - if not given == need: - raise ValueError( - "Incorrect setting for input(s) of operator \"%s\". Need: [%s] Given: [%s]" - % (type, ", ".join(str(e) for e in need), ", ".join( - str(e) for e in given))) + def find_name(var_list, name): + for var_name in var_list: + if var_name == name: + return True + return False + if inputs is not None: for in_proto in proto.inputs: - in_argus = inputs[in_proto.name] - if not isinstance(in_argus, list): - in_argus = [in_argus] - if not in_proto.duplicable and len(in_argus) > 1: - raise ValueError( - "Input %s expects only one input, but %d are given." 
% - (in_proto.name, len(in_argus))) - in_argu_names = [] - for argu in in_argus: - in_argu_names.append(argu.name) - self.desc.set_input(in_proto.name, in_argu_names) + found = find_name(inputs, in_proto.name) + assert found or in_proto.dispensable, "Input {} not found".format( + in_proto.name) + + if found: + in_argus = inputs[in_proto.name] + if not isinstance(in_argus, list): + in_argus = [in_argus] + if not in_proto.duplicable and len(in_argus) > 1: + raise ValueError( + "Input %s expects only one input, but %d are given." + % (in_proto.name, len(in_argus))) + in_argu_names = [] + for argu in in_argus: + in_argu_names.append(argu.name) + self.desc.set_input(in_proto.name, in_argu_names) + else: + self.desc.set_input(in_proto.name, []) if outputs is not None: given = set() @@ -250,13 +255,14 @@ class Operator(object): attr_name = attr.name if (not attr_name in attrs) or (attrs[attr_name] is None): continue - if not isinstance(attrs[attr_name], Block): - self.desc.set_attr(attr_name, attrs[attr_name]) - else: + if isinstance(attrs[attr_name], Block): self.desc.set_block_attr(attr_name, attrs[attr_name].desc) + else: + self.desc.set_attr(attr_name, attrs[attr_name]) self.desc.check_attrs() if type not in {'feed', 'fetch'}: + self.desc.infer_var_type(self.block.desc) self.desc.infer_shape(self.block.desc) def __str__(self): diff --git a/python/paddle/v2/framework/layer_helper.py b/python/paddle/v2/framework/layer_helper.py index 849a6f43065ae95e908e449e9ef9300b64692e5e..f3da32f0e07a22204b3feaed5d1d8d01556e4655 100644 --- a/python/paddle/v2/framework/layer_helper.py +++ b/python/paddle/v2/framework/layer_helper.py @@ -1,8 +1,11 @@ -from paddle.v2.framework.framework import Variable, OpProtoHolder, g_program, g_init_program -import paddle.v2.framework.core as core import copy import itertools +import paddle.v2.framework.core as core + +from paddle.v2.framework.framework import Variable, g_program, \ + g_init_program + def unique_name(prefix): uid = core.unique_integer() # unique during whole process. 
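[Editor's note] `unique_name` above derives process-unique identifiers by pairing the caller's prefix with a fresh counter from `core.unique_integer()`; a minimal sketch of the intended behavior (the exact "prefix_counter" join format is an assumption, not shown in this hunk):

    # Hypothetical sketch, not part of the commit: repeated calls with the
    # same prefix never collide, because the counter only moves forward.
    _counter = [0]
    def unique_name_sketch(prefix):
        _counter[0] += 1  # stands in for core.unique_integer()
        return "%s_%d" % (prefix, _counter[0])  # e.g. "fc.w_1", then "fc.w_2"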
@@ -120,10 +123,7 @@ class LayerHelper(object): if attr['name'] is None: attr['name'] = unique_name(".".join([self.name, suffix])) self.init_program.global_block().create_parameter( - name=attr['name'], - dtype=dtype, - shape=shape, - init_attr=attr['init_attr']) + dtype=dtype, shape=shape, **attr) return self.program.global_block().create_parameter( name=attr['name'], dtype=dtype, shape=shape) @@ -133,6 +133,9 @@ class LayerHelper(object): dtype=dtype, persistable=False) + def create_variable(self, *args, **kwargs): + return self.program.current_block().create_var(*args, **kwargs) + def create_global_variable(self, *args, **kwargs): return self.program.global_block().create_var( *args, persistable=False, **kwargs) diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index ac77aefa15333b06f9803ce1d91071df803483d1..6894c40c3a6514f448133f029c4de8cc30405515 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -1,9 +1,12 @@ -from paddle.v2.framework.layer_helper import LayerHelper +from paddle.v2.framework.layer_helper import LayerHelper, unique_name import paddle.v2.framework.core as core -from paddle.v2.framework.framework import OpProtoHolder, Variable +from paddle.v2.framework.framework import OpProtoHolder, Variable, Program import re -__all__ = ['fc', 'data', 'cross_entropy', 'conv2d', 'pool2d'] +__all__ = [ + 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', + 'StaticRNN' +] def fc(input, @@ -24,7 +27,9 @@ def fc(input, mul_results = [] for input_var, param_attr in helper.iter_inputs_and_params(): input_shape = input_var.shape - param_shape = list(input_shape[num_flatten_dims:]) + [size] + param_shape = [ + reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1) + ] + [size] w = helper.create_parameter( attr=param_attr, shape=param_shape, dtype=dtype) @@ -36,10 +41,8 @@ def fc(input, "Y": w, }, outputs={"Out": tmp}, - attrs={ - 'x_num_col_dims': num_flatten_dims, - 'y_num_col_dims': len(input_shape) - num_flatten_dims - }) + attrs={'x_num_col_dims': num_flatten_dims, + 'y_num_col_dims': 1}) mul_results.append(tmp) # sum @@ -55,6 +58,24 @@ def fc(input, return helper.append_activation(pre_activation) +def embedding(input, + size, + data_type='float32', + param_attr=None, + program=None, + init_program=None): + helper = LayerHelper('embedding', **locals()) + w = helper.create_parameter( + attr=helper.param_attr, shape=size, dtype=data_type) + tmp = helper.create_tmp_variable(data_type) + helper.append_op( + type='lookup_table', + inputs={'Ids': input, + 'W': w}, + outputs={'Out': tmp}) + return tmp + + def data(name, shape, data_type='float32', @@ -122,6 +143,19 @@ _create_op_func_('mean') _create_op_func_('mul') +def concat(input, axis, program=None, init_program=None): + helper = LayerHelper('concat', **locals()) + if not isinstance(input, list) and not isinstance(input, tuple): + input = [input] + out = helper.create_tmp_variable(dtype=input[0].data_type) + helper.append_op( + type='concat', + inputs={'X': input}, + outputs={'Out': [out]}, + attrs={'axis': axis}) + return out + + def cross_entropy(input, label, **kwargs): helper = LayerHelper('cross_entropy', **kwargs) out = helper.create_tmp_variable(dtype=input.data_type) @@ -240,3 +274,170 @@ def pool2d(input, }) return pool_out + + +class BlockGuard(object): + """ + A BlockGuard is used to create a sub-block in a program by using the Python `with` + keyword.
+ """ + + def __init__(self, program): + if not isinstance(program, Program): + raise TypeError("BlockGuard takes a program") + self.program = program + + def __enter__(self): + self.program.create_block() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.program.rollback() + if exc_type is not None: + return False # re-raise exception + return True + + +class StaticRNNGuard(BlockGuard): + def __init__(self, rnn): + if not isinstance(rnn, StaticRNN): + raise TypeError("StaticRNNGuard takes a StaticRNN") + super(StaticRNNGuard, self).__init__(rnn.helper.program) + self.rnn = rnn + + def __enter__(self): + self.rnn.status = StaticRNN.IN_RNN_BLOCK + return super(StaticRNNGuard, self).__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.rnn.status = StaticRNN.AFTER_RNN_BLOCK + self.rnn.complete_rnn_op() + return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb) + + +class StaticRNNMemoryLink(object): + """ + :param init: the initial variable for Memory + :type init: Variable + :param pre_mem: the memory variable in previous time step + :type pre_mem: Variable + :param mem: the memory variable in current time step + :type mem: Variable + """ + + def __init__(self, init, pre_mem, mem=None): + self.init = init + self.pre_mem = pre_mem + self.mem = mem + + +class StaticRNN(object): + BEFORE_RNN_BLOCK = 0 + IN_RNN_BLOCK = 1 + AFTER_RNN_BLOCK = 2 + + def __init__(self, name=None, program=None): + self.helper = LayerHelper("static_rnn", name=name, program=program) + self.memories = {} # memory map, from pre_mem.name --> MemoryLink + self.inputs = [] # input variable list in current block + self.outputs = [] # output variable list in parent block + self.status = StaticRNN.BEFORE_RNN_BLOCK # status flag. + # sequence length; since this is a static RNN, the sequence length is fixed.
+ self.seq_len = None + + def step(self): + return StaticRNNGuard(self) + + def _assert_in_rnn_block_(self, method): + if self.status != StaticRNN.IN_RNN_BLOCK: + raise ValueError("You must invoke {0} in rnn block".format(method)) + + def memory(self, init=None, shape=None, dtype=None, init_value=0): + self._assert_in_rnn_block_('memory') + if init is None: + if shape is None or dtype is None: + raise ValueError( + "if init is None, memory needs at least a shape and dtype") + parent_block = self.parent_block() + var_name = unique_name("@".join([self.helper.name, "memory_boot"])) + boot_var = parent_block.create_var( + name=var_name, shape=shape, dtype=dtype, persistable=False) + + parent_block.append_op( + type="fill_constant", + inputs={}, + outputs={'Out': [boot_var]}, + attrs={ + 'value': init_value, + 'shape': boot_var.shape, + 'data_type': boot_var.data_type + }) + + return self.memory(init=boot_var) + else: + pre_mem = self.helper.create_variable( + name=unique_name("@".join([self.helper.name, "mem"])), + dtype=init.data_type, + shape=init.shape) + self.memories[pre_mem.name] = StaticRNNMemoryLink( + init=init, pre_mem=pre_mem) + return pre_mem + + def step_input(self, x): + self._assert_in_rnn_block_('step_input') + if not isinstance(x, Variable): + raise TypeError("step input takes a Variable") + if self.seq_len is None: + self.seq_len = x.shape[1] + elif self.seq_len != x.shape[1]: + raise ValueError("StaticRNN only takes inputs with a fixed seq_len") + + ipt = self.helper.create_variable( + name=x.name, + dtype=x.data_type, + shape=[-1] + list(x.shape[2:]), + type=x.type) + self.inputs.append(ipt) + return ipt + + def step_output(self, o): + self._assert_in_rnn_block_('step_output') + if not isinstance(o, Variable): + raise TypeError("step output takes a Variable") + + out_var = self.parent_block().create_var( + name=o.name, + shape=[-1, self.seq_len] + list(o.shape[1:]), + dtype=o.data_type) + + self.outputs.append(out_var) + + def output(self, *outputs): + for each in outputs: + self.step_output(each) + + def update_memory(self, mem, var): + if not isinstance(mem, Variable) or not isinstance(var, Variable): + raise TypeError("update memory should take variables") + self.memories[mem.name].mem = var + + def parent_block(self): + prog = self.helper.program + parent_idx = prog.current_block().parent_idx + assert parent_idx >= 0 + parent_block = prog.block(parent_idx) + return parent_block + + def __call__(self, *args, **kwargs): + if self.status != StaticRNN.AFTER_RNN_BLOCK: + raise ValueError("RNN output can only be retrieved after rnn block") + if len(self.outputs) == 0: + raise ValueError("RNN has no output") + elif len(self.outputs) == 1: + return self.outputs[0] + else: + return self.outputs + + def complete_rnn_op(self): + # TODO(yuyang18): Create RNN Op here. + # Implement this method once the RNN op is complete.
+ pass diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 215fa0b94e423755b7bc3f05a2b14a8c85451202..0fdc21ef5133d17b33860a0e095574d3136b2fd1 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -4,6 +4,8 @@ import random import itertools import paddle.v2.framework.core as core from paddle.v2.framework.op import Operator +from paddle.v2.framework.executor import Executor +from paddle.v2.framework.framework import Program, OpProtoHolder def grad_var_name(var_name): @@ -177,7 +179,12 @@ def get_backward_op(scope, op, no_grad_set): return backward_op -def get_gradient(scope, op, inputs, outputs, grad_name, place, +def get_gradient(scope, + op, + inputs, + outputs, + grad_names, + place, no_grad_set=None): ctx = core.DeviceContext.create(place) @@ -193,8 +200,52 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place, backward_op.run(scope, ctx) - out = np.array(scope.find_var(grad_name).get_tensor()) - return out + return [ + np.array(scope.find_var(grad_name).get_tensor()) + for grad_name in grad_names + ] + + +def append_input_output(block, op_proto, np_list, is_input): + '''Insert VarDesc and generate Python variable instance''' + proto_list = op_proto.inputs if is_input else op_proto.outputs + + def create_var(block, name, np_list, var_proto): + if name not in np_list: + assert var_proto.intermediate, "{} not found".format(name) + shape = None + lod_level = None + else: + np_value = np_list[name] + if isinstance(np_value, tuple): + shape = list(np_value[0].shape) + lod_level = len(np_value[1]) + else: + shape = list(np_value.shape) + lod_level = 0 + return block.create_var( + dtype="float32", shape=shape, lod_level=lod_level, name=name) + + var_dict = {} + for var_proto in proto_list: + var_name = str(var_proto.name) + if is_input: + if (var_name not in np_list) and var_proto.dispensable: + continue + assert (var_name in np_list) or (var_proto.dispensable), \ + "Missing {} as input".format(var_name) + if var_proto.duplicable: + assert isinstance(np_list[var_name], list), \ + "Duplicable {} should be set as list".format(var_name) + var_list = [] + for (name, np_value) in np_list[var_name]: + var_list.append( + create_var(block, name, {name: np_value}, var_proto)) + var_dict[var_name] = var_list + else: + var_dict[var_name] = create_var(block, var_name, np_list, var_proto) + + return var_dict class OpTest(unittest.TestCase): @@ -213,48 +264,93 @@ class OpTest(unittest.TestCase): np.random.set_state(cls._np_rand_state) random.setstate(cls._py_rand_state) + def feed_var(self, input_vars, place): + feed_map = {} + for var_name in input_vars: + if isinstance(input_vars[var_name], list): + for name, np_value in self.inputs[var_name]: + tensor = core.LoDTensor() + tensor.set(np_value, place) + feed_map[name] = tensor + else: + tensor = core.LoDTensor() + if isinstance(self.inputs[var_name], tuple): + tensor.set(self.inputs[var_name][0], place) + tensor.set_lod(self.inputs[var_name][1]) + else: + tensor.set(self.inputs[var_name], place) + feed_map[var_name] = tensor + + return feed_map + def check_output_with_place(self, place, atol): - self.scope = core.Scope() - op_inputs = self.inputs if hasattr(self, "inputs") else dict() - op_outputs = self.outputs if hasattr(self, "outputs") else dict() - op_attrs = self.attrs if hasattr(self, "attrs") else dict() - self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, - op_attrs) - if isinstance(place, core.GPUPlace) and 
not self.op.support_gpu(): - return - set_input(self.scope, self.op, self.inputs, place) - ctx = core.DeviceContext.create(place) - self.op.run(self.scope, ctx) + op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) + + program = Program() + block = program.global_block() + + inputs = append_input_output(block, op_proto, self.inputs, True) + outputs = append_input_output(block, op_proto, self.outputs, False) + + op = block.append_op( + type=self.op_type, + inputs=inputs, + outputs=outputs, + attrs=self.attrs if hasattr(self, "attrs") else dict()) + + fetch_list = [] + for var_name, var in outputs.iteritems(): + if var_name in self.outputs: + if isinstance(var, list): + for v in var: + fetch_list.append(v) + else: + fetch_list.append(var) - for out_name, out_dup in Operator.get_op_outputs(self.op.type()): + feed_map = self.feed_var(inputs, place) + + exe = Executor(place) + outs = exe.run(program, feed=feed_map, fetch_list=fetch_list) + + for out_name, out_dup in Operator.get_op_outputs(self.op_type): if out_name not in self.outputs: continue + def find_actual(target_name, fetch_list): + found = [ + i for i, var in enumerate(fetch_list) + if var.name == target_name + ] + self.assertTrue( + len(found) == 1, "Found {} {}".format( + len(found), target_name)) + return found[0] + if out_dup: sub_out = self.outputs[out_name] if not isinstance(sub_out, list): raise AssertionError("sub_out type %s is not list" % type(sub_out)) - for sub_out_name, expect in sub_out: - actual = np.array( - self.scope.find_var(sub_out_name).get_tensor()) + idx = find_actual(sub_out_name, fetch_list) + actual = outs[idx] self.assertTrue( np.allclose( actual, expect, atol=atol), - "output name: " + out_name + " has diff.") + "Output (" + sub_out_name + ") has diff at " + + str(place)) else: - actual = np.array(self.scope.find_var(out_name).get_tensor()) + idx = find_actual(out_name, fetch_list) + actual = outs[idx] expect = self.outputs[out_name] - self.assertTrue( np.allclose( actual, expect, atol=atol), - "output name: " + out_name + " has diff.") + "Output (" + out_name + ") has diff at " + str(place)) def check_output(self, atol=1e-5): places = [core.CPUPlace()] - if core.is_compile_gpu(): + if core.is_compile_gpu() and core.op_support_gpu(self.op_type): places.append(core.GPUPlace(0)) for place in places: self.check_output_with_place(place, atol) @@ -310,11 +406,9 @@ class OpTest(unittest.TestCase): ] cpu_place = core.CPUPlace() - cpu_analytic_grads = [ - get_gradient(self.scope, self.op, self.inputs, self.outputs, - grad_name, cpu_place, no_grad_set) - for grad_name in grad_names - ] + cpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs, + self.outputs, grad_names, cpu_place, + no_grad_set) self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names, max_relative_error, @@ -322,11 +416,9 @@ class OpTest(unittest.TestCase): if core.is_compile_gpu() and self.op.support_gpu(): gpu_place = core.GPUPlace(0) - gpu_analytic_grads = [ - get_gradient(self.scope, self.op, self.inputs, self.outputs, - grad_name, gpu_place, no_grad_set) - for grad_name in grad_names - ] + gpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs, + self.outputs, grad_names, + gpu_place, no_grad_set) self.__assert_is_close(numeric_grads, gpu_analytic_grads, grad_names, max_relative_error, diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py index b6f3a35d6f58ba90b39e3f6296ae635220a2e965..02be9a02910bee3eae63e12cceaa51cf53591539 100644
--- a/python/paddle/v2/framework/tests/test_accuracy_op.py +++ b/python/paddle/v2/framework/tests/test_accuracy_op.py @@ -16,7 +16,9 @@ class TestAccuracyOp(OpTest): if ele == label[rowid]: num_correct += 1 break - self.outputs = {'Accuracy': [num_correct / float(n)]} + self.outputs = { + 'Accuracy': np.array([num_correct / float(n)]).astype("float32") + } def test_check_output(self): self.check_output() diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py index 5831b880e4c5ef881929920e87ac64d6c87a2ab5..c1668cd00ff6c3782dd17a789e4ad93b92e5209d 100644 --- a/python/paddle/v2/framework/tests/test_activation_op.py +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -172,8 +172,8 @@ class TestBRelu(OpTest): def setUp(self): self.op_type = "brelu" x = np.random.uniform(-1, 1, [4, 4]).astype("float32") - t_min = 1 - t_max = 4 + t_min = 1.0 + t_max = 4.0 # The same with TestAbs x[np.abs(x - t_min) < 0.005] = t_min + 0.02 x[np.abs(x - t_max) < 0.005] = t_max + 0.02 @@ -218,7 +218,7 @@ class TestSoftRelu(OpTest): def setUp(self): self.op_type = "soft_relu" x = np.random.uniform(-3, 3, [4, 4]).astype("float32") - threshold = 2 + threshold = 2.0 # The same reason with TestAbs x[np.abs(x - threshold) < 0.005] = threshold + 0.02 x[np.abs(x + threshold) < 0.005] = -threshold + 0.02 @@ -303,7 +303,7 @@ class TestPow(OpTest): def setUp(self): self.op_type = "pow" self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")} - self.attrs = {'factor': 3} + self.attrs = {'factor': 3.0} self.outputs = {'Y': np.power(self.inputs['X'], 3)} def test_check_output(self): diff --git a/python/paddle/v2/framework/tests/test_clip_op.py b/python/paddle/v2/framework/tests/test_clip_op.py index 5df6a494989017bab0416e0af962b2a85db046ba..a7e1bf174408e4139db0435d9f4bb0c885f76705 100644 --- a/python/paddle/v2/framework/tests/test_clip_op.py +++ b/python/paddle/v2/framework/tests/test_clip_op.py @@ -37,14 +37,14 @@ class TestCase1(TestClipOp): def initTestCase(self): self.shape = (8, 16, 8) self.max = 0.7 - self.min = 0 + self.min = 0.0 class TestCase2(TestClipOp): def initTestCase(self): self.shape = (8, 16) - self.max = 1 - self.min = 0 + self.max = 1.0 + self.min = 0.0 class TestCase3(TestClipOp): diff --git a/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py new file mode 100644 index 0000000000000000000000000000000000000000..71ca262f00378381d2d65e87d198d6b1755e9a2b --- /dev/null +++ b/python/paddle/v2/framework/tests/test_conv2dtranspose_op.py @@ -0,0 +1,102 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param): + # [2, 3, 5, 5] + in_n, in_c, in_h, in_w = input_.shape + # [3, 6, 3, 3] + f_c, out_c, f_h, f_w = filter_.shape + assert in_c == f_c + + stride, pad = conv2dtranspose_param['stride'], conv2dtranspose_param['pad'] + out_h = (in_h - 1) * stride[0] + f_h + out_w = (in_w - 1) * stride[1] + f_w + + out = np.zeros((in_n, out_c, out_h, out_w)) + + for n in range(in_n): + for i in range(in_h): + for j in range(in_w): + input_masked = input_[n, :, i, j] # (c) + input_masked = np.reshape(input_masked, (in_c, 1, 1)) + input_masked = np.tile(input_masked, (1, f_h, f_w)) + + for k in range(out_c): + tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0) + i1, i2 = i * stride[0], i * stride[0] + f_h + j1, j2 = j * stride[1], j * stride[1] + f_w + out[n, k, i1:i2,
j1:j2] += tmp_out + + return out + + +class TestConv2dTransposeOp(OpTest): + def setUp(self): + # init as conv transpose + self.init_op_type() + + # [2, 3, 5, 5] -> kernel [3, 6, 3, 3] -> output [2, 6, 7, 7] + self.init_test_case() + + conv2dtranspose_param = {'stride': self.stride, 'pad': self.pad} + input_ = np.random.random(self.input_size).astype("float32") + filter_ = np.random.random(self.filter_size).astype("float32") + output = conv2dtranspose_forward_naive(input_, filter_, + conv2dtranspose_param) + # print 'deconv output py', output, output.shape + + self.inputs = {'Input': input_, 'Filter': filter_} + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + # 'dilations': self.dilations + } + self.outputs = {'Output': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + set(['Input', 'Filter']), 'Output', max_relative_error=0.05) + + def test_check_grad_no_filter(self): + self.check_grad( + ['Input'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + self.check_grad( + ['Filter'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Input'])) + + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.op_type = "conv2dtranspose" + + +""" +class TestCudnn(TestConv2dOp): + def init_group(self): + self.groups = 1 + + def init_op_type(self): + self.op_type = "conv_cudnn" +""" + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py deleted file mode 100644 index 9f56fe5049c66aa5fce40ce815105e7871ebc3b2..0000000000000000000000000000000000000000 --- a/python/paddle/v2/framework/tests/test_fc_op.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest -import numpy as np -from op_test import OpTest - - -class TestFCOp1(OpTest): - def setUp(self): - x0 = np.random.random((16, 32)).astype("float32") - w0 = np.random.random((32, 10)).astype("float32") - - mul_out0 = np.dot(x0, w0) - identity_out = mul_out0 - - self.op_type = "fc" - self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)]} - self.outputs = {"MulOut": [("MulOut0", mul_out0)], "Out": identity_out} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X0", "W0"], "Out", max_relative_error=0.01) - - -class TestFCOp2(OpTest): - def setUp(self): - x0 = np.random.random((16, 4, 8)).astype("float32") - x1 = np.random.random((4, 4, 32)).astype("float32") - w0 = np.random.random((32, 10)).astype("float32") - w1 = np.random.random((32, 10)).astype("float32") - b = np.random.random(10).astype("float32") - - mul_out0 = np.dot(x0.reshape(16, 4 * 8), w0) - mul_out1 = np.dot(x1.reshape(4 * 4, 32), w1) - sum_out = mul_out0 + mul_out1 - add_out = np.add(sum_out, b) - sigmoid_out = 1 / (1 + np.exp(-add_out)) - - self.op_type = "fc" - self.inputs = { - "X": [("X0", x0), ("X1", x1)], - "W": [("W0", w0), ("W1", w1)], - "B": b - } - self.attrs = {"xNumColDims": [1, 2], "activation": "sigmoid"} - self.outputs = { - "MulOut": [("MulOut0", mul_out0), ("MulOut1", mul_out1)], - "SumOut": sum_out, - "AddOut": add_out, - "Out": sigmoid_out - } - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad( - ["X0", "X1", "W0", "W1",
"B"], "Out", max_relative_error=0.01) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/v2/framework/tests/test_identity_op.py b/python/paddle/v2/framework/tests/test_identity_op.py deleted file mode 100644 index 26cec1fcc3ad003281c9c41571d475b55bd30026..0000000000000000000000000000000000000000 --- a/python/paddle/v2/framework/tests/test_identity_op.py +++ /dev/null @@ -1,20 +0,0 @@ -import unittest -import numpy as np -from op_test import OpTest - - -class TestIdentityOp(OpTest): - def setUp(self): - self.op_type = "identity" - self.inputs = {'X': np.random.random((10, 10)).astype("float32")} - self.outputs = {'Y': self.inputs['X']} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Y') - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/v2/framework/tests/test_interp_op.py b/python/paddle/v2/framework/tests/test_interp_op.py deleted file mode 100644 index 066569b96c9611bd20e7192f8bd6caa6e467202f..0000000000000000000000000000000000000000 --- a/python/paddle/v2/framework/tests/test_interp_op.py +++ /dev/null @@ -1,28 +0,0 @@ -import unittest -import numpy as np -from op_test import OpTest - - -class TestInterpOp(OpTest): - def setUp(self): - self.op_type = "interp" - x = np.random.random((2, 3)).astype("float32") - y = np.random.random((2, 3)).astype("float32") - w = np.random.random(2).astype("float32") - - sub_out = x - y - mul_out = sub_out * w.reshape(2, 1) - out = mul_out + y - - self.inputs = {'X': x, 'Y': y, 'W': w} - self.outputs = {'Out': out, 'SubOut': sub_out, 'MulOut': mul_out} - - def test_check_output(self): - self.check_output() - - def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out') - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/v2/framework/tests/test_layers.py b/python/paddle/v2/framework/tests/test_layers.py index 4ecc02b12d8db53e897dea10186bc053d05be303..7aedb985f98f2d8953e0968d19ece9c70d792246 100644 --- a/python/paddle/v2/framework/tests/test_layers.py +++ b/python/paddle/v2/framework/tests/test_layers.py @@ -88,6 +88,77 @@ class TestBook(unittest.TestCase): print str(program) + def test_word_embedding(self): + program = Program() + dict_size = 10000 + embed_size = 32 + first_word = layers.data( + name='firstw', shape=[1], data_type='int32', program=program) + second_word = layers.data( + name='secondw', shape=[1], data_type='int32', program=program) + third_word = layers.data( + name='thirdw', shape=[1], data_type='int32', program=program) + forth_word = layers.data( + name='forthw', shape=[1], data_type='int32', program=program) + next_word = layers.data( + name='nextw', shape=[1], data_type='int32', program=program) + + embed_param_attr_1 = { + 'name': 'shared_w', + 'init_attr': { + 'max': 1.0, + 'type': 'uniform_random', + 'min': -1.0 + } + } + embed_param_attr_2 = {'name': 'shared_w'} + + embed_first = layers.embedding( + input=first_word, + size=[dict_size, embed_size], + data_type='float32', + param_attr=embed_param_attr_1, + program=program) + embed_second = layers.embedding( + input=second_word, + size=[dict_size, embed_size], + data_type='float32', + param_attr=embed_param_attr_2, + program=program) + + embed_third = layers.embedding( + input=third_word, + size=[dict_size, embed_size], + data_type='float32', + param_attr=embed_param_attr_2, + program=program) + embed_forth = layers.embedding( + input=forth_word, + size=[dict_size, embed_size], + data_type='float32', + param_attr=embed_param_attr_2, 
+ program=program) + + concat_embed = layers.concat( + input=[embed_first, embed_second, embed_third, embed_forth], + axis=1, + program=program) + + hidden1 = layers.fc(input=concat_embed, + size=256, + act='sigmoid', + program=program) + predict_word = layers.fc(input=hidden1, + size=dict_size, + act='softmax', + program=program) + cost = layers.cross_entropy( + input=predict_word, label=next_word, program=program) + avg_cost = layers.mean(x=cost, program=program) + self.assertIsNotNone(avg_cost) + + print str(program) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_lookup_table_op.py b/python/paddle/v2/framework/tests/test_lookup_table_op.py index b259bb67e832adcb31b0ab4e992738be2b85f884..2c48f9bf93b939aa631cd54e8fb14b5cba22f2e0 100644 --- a/python/paddle/v2/framework/tests/test_lookup_table_op.py +++ b/python/paddle/v2/framework/tests/test_lookup_table_op.py @@ -8,7 +8,8 @@ class TestLookupTableOp(OpTest): self.op_type = "lookup_table" table = np.random.random((17, 31)).astype("float32") ids = np.random.randint(0, 17, 4).astype("int32") - self.inputs = {'W': table, 'Ids': ids} + ids_expand = np.expand_dims(ids, axis=1) + self.inputs = {'W': table, 'Ids': ids_expand} self.outputs = {'Out': table[ids]} def test_check_output(self): diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..bcce8d32c944a39e6d6aad4c99f8aa152222c3c1 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -0,0 +1,185 @@ +import unittest +import numpy as np +from op_test import OpTest + +SIGMOID_THRESHOLD_MIN = -40.0 +SIGMOID_THRESHOLD_MAX = 13.0 +EXP_MAX_INPUT = 40.0 + + +def identity(x): + return x + + +def sigmoid(x): + y = np.copy(x) + y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN + y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX + return 1. / (1. + np.exp(-y)) + + +def tanh(x): + y = -2. * x + y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT + return (2. / (1. + np.exp(y))) - 1. 
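[Editor's note] The `tanh` helper above avoids overflow in `np.exp` by clipping `-2. * x` at `EXP_MAX_INPUT` before applying the identity tanh(x) = 2 / (1 + e^(-2x)) - 1. A quick self-check of that identity, assuming the module-level constant `EXP_MAX_INPUT = 40.0`:

    # Sketch, not part of the commit: the clipped form agrees with np.tanh
    # to float tolerance, since exp(-40) is negligible next to 1.
    import numpy as np
    x = np.linspace(-30.0, 30.0, 13)
    y = np.minimum(-2.0 * x, 40.0)  # same clip as the helper above
    assert np.allclose(2.0 / (1.0 + np.exp(y)) - 1.0, np.tanh(x), atol=1e-6)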
+ + + def relu(x): + return np.maximum(x, 0) + + +ACTIVATION = { + 'identity': identity, + 'sigmoid': sigmoid, + 'tanh': tanh, + 'relu': relu +} + + +def lstm( + input, # T x 4D + lod, # 1 x N + h0=None, # N x D + c0=None, # N x D + w_h=None, # D x 4D + w_b=None, # 1 x 4D + w_c=None, # 1 x 3D + is_reverse=False, + act_gate=None, + act_cell=None, + act_cand=None): + def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand): + g = np.dot(h_pre, w_h) # 1 x 4D + g = g + x + g = np.reshape(g, (1, g.size)) + c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1) + if w_c is None: + g_i = act_gate(g_i) # 1 x D + g_f = act_gate(g_f) # 1 x D + else: + w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1) + g_i = act_gate(g_i + w_ic * c_pre) # 1 x D + g_f = act_gate(g_f + w_fc * c_pre) # 1 x D + c = g_f * c_pre + g_i * act_cand(c_tmp) # 1 x D + + if w_c is None: + g_o = act_gate(g_o) # 1 x D + else: + _, _, w_oc = np.split(w_c, 3, axis=1) + g_o = act_gate(g_o + w_oc * c) # 1 x D + h = g_o * act_cell(c) + bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1) + return h, c, bg + + def _reverse(x, lod): + y = np.zeros_like(x) + for i in range(len(lod) - 1): + b, e = lod[i], lod[i + 1] + y[b:e, :] = np.flip(x[b:e, :], 0) + return y + + offset = lod[0] + batch_size = len(offset) - 1 + hidden = [] + cell = [] + gate = [] + input = _reverse(input, offset) if is_reverse else input + if w_b is not None: + input = input + np.tile(w_b, (offset[-1], 1)) + for i in range(batch_size): + # compute one sequence + seq_len = offset[i + 1] - offset[i] + x = input[offset[i]:offset[i + 1], :] + h_pre = h0[i] # 1 x D + c_pre = c0[i] # 1 x D + for j in range(seq_len): + # compute one step + h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate, + act_cell, act_cand) + hidden.append(h_pre.flatten()) + cell.append(c_pre.flatten()) + gate.append(g_pre.flatten()) + + hidden = np.array(hidden).astype("float64") + cell = np.array(cell).astype("float64") + gate = np.array(gate).astype("float64") + + hidden = _reverse(hidden, offset) if is_reverse else hidden + cell = _reverse(cell, offset) if is_reverse else cell + + assert gate.shape == input.shape + assert hidden.shape == (input.shape[0], input.shape[1] / 4) + assert cell.shape == (input.shape[0], input.shape[1] / 4) + return hidden, cell, gate + + +class TestLstmOp(OpTest): + def set_data(self): + self.lod = [[0, 2, 6, 9]] + self.D = 64 + self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5] + + self.act_gate = "sigmoid" + self.act_cell = "tanh" + self.act_cand = "tanh" + + self.is_reverse = False + + def setUp(self): + self.set_data() + self.op_type = "lstm" + + T = self.lod[0][-1] + N = len(self.lod[0]) - 1 + + x = np.random.normal(size=(T, 4 * self.D)).astype("float64") + h0 = np.zeros((N, self.D)).astype("float64") + c0 = np.zeros((N, self.D)).astype("float64") + w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64") + b = np.random.normal(size=(1, 7 * self.D)).astype("float64") + + w_b = b[:, 0:4 * self.D] + w_c = b[:, 4 * self.D:] + h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, + ACTIVATION[self.act_gate], ACTIVATION[self.act_cell], + ACTIVATION[self.act_cand]) + + g_sort = np.zeros_like(x) + for i, j in enumerate(self.sort_idx): + g_sort[i, :] = g[j, :] + + self.inputs = { + 'Input': (x, self.lod), + 'H0': h0, + 'C0': c0, + 'Weight': w, + 'Bias': b + } + self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort} + self.attrs = { + 'usePeepholes': True, + 'isReverse': self.is_reverse, + 'gateActivation': 'sigmoid', + 'cellActivation':
'tanh', + 'candidateActivation': 'tanh' + } + + def test_check_output(self): + self.check_output() + + +class TestLstmOpReverse(TestLstmOp): + def set_data(self): + self.lod = [[0, 2, 6, 9]] + self.D = 64 + self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5] + + self.act_gate = "sigmoid" + self.act_cell = "tanh" + self.act_cand = "tanh" + + self.is_reverse = True + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index b3d95a56b88e510734da54f36ff21ccd7e1baabb..57d6d7e7e095cab2c3afb60d229fc09da98aed8b 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -35,10 +35,10 @@ class TestMulOp2(OpTest): 'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32") } self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2} - self.outputs = { - 'Out': np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10), - self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9)) - } + result = np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10), + self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9)) + result = result.reshape(15, 4, 8, 2, 9) + self.outputs = {'Out': result} def test_check_output(self): self.check_output() diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/framework/tests/test_pad_op.py index 9052e63b5683801da7c73be4de23013c949add98..55f1774e5755c846f60a2f1df3e705444a81192b 100644 --- a/python/paddle/v2/framework/tests/test_pad_op.py +++ b/python/paddle/v2/framework/tests/test_pad_op.py @@ -27,7 +27,7 @@ class TestPadOp(OpTest): def initTestCase(self): self.shape = (16, 16) self.paddings = [(0, 1), (2, 3)] - self.pad_value = 0 + self.pad_value = 0.0 class TestCase1(TestPadOp): @@ -41,7 +41,7 @@ class TestCase2(TestPadOp): def initTestCase(self): self.shape = (2, 2, 2) self.paddings = [(0, 0), (0, 0), (1, 2)] - self.pad_value = 1 + self.pad_value = 1.0 class TestCase3(TestPadOp): diff --git a/python/paddle/v2/framework/tests/test_reduce_op.py b/python/paddle/v2/framework/tests/test_reduce_op.py index 0fec31c2e22e1eda2c62aa9b38487d703815f414..70359d60cbe656150877673c63e81eae92d8ab9a 100644 --- a/python/paddle/v2/framework/tests/test_reduce_op.py +++ b/python/paddle/v2/framework/tests/test_reduce_op.py @@ -85,33 +85,5 @@ class Test1DReduce(OpTest): self.check_grad(['X'], 'Out') -class TestNorm(OpTest): - def setUp(self): - # use x away from 0 to avoid errors of numerical gradient when gradient near 0 - x = np.random.random((5, 6, 10)).astype("float32") + 0.2 - p = 2 - dim = 1 - keep_dim = False - abs_out = np.absolute(x) - pow_out = np.power(x, p) - sum_out = np.sum(pow_out, axis=dim, keepdims=keep_dim) - out = np.power(sum_out, 1.
/ p) - self.op_type = "norm" - self.inputs = {'X': x} - self.attrs = {"p": p, "dim": dim, "keep_dim": keep_dim} - self.outputs = { - "AbsOut": abs_out, - "PowOut": pow_out, - "SumOut": sum_out, - "Out": out - } - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out', max_relative_error=0.01) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_rnn_helpers.py b/python/paddle/v2/framework/tests/test_rnn_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..be0ecfb129aa181229bc42d8d6818ad860991965 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_rnn_helpers.py @@ -0,0 +1,38 @@ +import unittest +from paddle.v2.framework.layers import * +from paddle.v2.framework.framework import g_program + + +class TestRNN(unittest.TestCase): + def test_rnn(self): + img = data( + shape=[ + 80, # sequence length + 22, # image height + 22 + ], # image width + data_type='float32', + name='image') + hidden = fc(input=img, size=100, act='sigmoid', num_flatten_dims=2) + self.assertEqual((-1, 80, 100), hidden.shape) + hidden = fc(input=hidden, size=100, act='sigmoid', num_flatten_dims=2) + self.assertEqual((-1, 80, 100), hidden.shape) + + rnn = StaticRNN() + with rnn.step(): + hidden = rnn.step_input(hidden) + self.assertEqual((-1, 100), hidden.shape) + memory = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0) + + rnn_out = fc(input=[hidden, memory], size=32, act='sigmoid') + self.assertEqual((-1, 32), rnn_out.shape) + rnn.update_memory(memory, rnn_out) + rnn.output(rnn_out) + + out = rnn() + self.assertEqual((-1, 80, 32), out.shape) + print g_program + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_word2vec.py b/python/paddle/v2/framework/tests/test_word2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..b5d98035156c425ab97d2bf75f8f09c71884368f --- /dev/null +++ b/python/paddle/v2/framework/tests/test_word2vec.py @@ -0,0 +1,165 @@ +import paddle.v2 as paddle +import paddle.v2.framework.layers as layers +import paddle.v2.framework.core as core +import paddle.v2.framework.optimizer as optimizer + +from paddle.v2.framework.framework import Program, g_program +from paddle.v2.framework.executor import Executor + +import numpy as np + +init_program = Program() +program = Program() + +embed_size = 32 +hidden_size = 256 +N = 5 +batch_size = 32 + +word_dict = paddle.dataset.imikolov.build_dict() +dict_size = len(word_dict) + +first_word = layers.data( + name='firstw', + shape=[1], + data_type='int32', + program=program, + init_program=init_program) +second_word = layers.data( + name='secondw', + shape=[1], + data_type='int32', + program=program, + init_program=init_program) +third_word = layers.data( + name='thirdw', + shape=[1], + data_type='int32', + program=program, + init_program=init_program) +forth_word = layers.data( + name='forthw', + shape=[1], + data_type='int32', + program=program, + init_program=init_program) +next_word = layers.data( + name='nextw', + shape=[1], + data_type='int32', + program=program, + init_program=init_program) + +embed_param_attr_1 = { + 'name': 'shared_w', + 'init_attr': { + 'max': 1.0, + 'type': 'uniform_random', + 'min': -1.0 + } +} +embed_param_attr_2 = {'name': 'shared_w'} + +embed_first = layers.embedding( + input=first_word, + size=[dict_size, embed_size], + data_type='float32', + param_attr=embed_param_attr_1, + program=program, + init_program=init_program) 
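[Editor's note] In test_word2vec.py, the three embeddings that follow reuse `param_attr={'name': 'shared_w'}`, so all four context words index the single [dict_size, embed_size] table created for `embed_first`; only the first declaration carries `init_attr`. The training loop further down splits each mini-batch column-wise into the five feed slots; schematically:

    # Sketch, not part of the commit: each imikolov sample is a 5-tuple of
    # word ids, and column i feeds the i-th data layer.
    data = [(1, 2, 3, 4, 5), (6, 7, 8, 9, 10)]  # a toy mini-batch
    cols = [[sample[i] for sample in data] for i in xrange(5)]
    # cols[0] -> 'firstw', cols[1] -> 'secondw', ..., cols[4] -> 'nextw'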
+embed_second = layers.embedding( + input=second_word, + size=[dict_size, embed_size], + data_type='float32', + param_attr=embed_param_attr_2, + program=program, + init_program=init_program) + +embed_third = layers.embedding( + input=third_word, + size=[dict_size, embed_size], + data_type='float32', + param_attr=embed_param_attr_2, + program=program, + init_program=init_program) +embed_forth = layers.embedding( + input=forth_word, + size=[dict_size, embed_size], + data_type='float32', + param_attr=embed_param_attr_2, + program=program, + init_program=init_program) + +concat_embed = layers.concat( + input=[embed_first, embed_second, embed_third, embed_forth], + axis=1, + program=program, + init_program=init_program) + +hidden1 = layers.fc(input=concat_embed, + size=hidden_size, + act='sigmoid', + program=program, + init_program=init_program) +predict_word = layers.fc(input=hidden1, + size=dict_size, + act='softmax', + program=program, + init_program=init_program) +cost = layers.cross_entropy( + input=predict_word, + label=next_word, + program=program, + init_program=init_program) +avg_cost = layers.mean(x=cost, program=program, init_program=init_program) + +sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001) +opts = sgd_optimizer.minimize(avg_cost) + +train_reader = paddle.batch( + paddle.dataset.imikolov.train(word_dict, N), batch_size) + +place = core.CPUPlace() +exe = Executor(place) + +exe.run(init_program, feed={}, fetch_list=[]) +PASS_NUM = 100 +for pass_id in range(PASS_NUM): + for data in train_reader(): + input_data = [[data_idx[idx] for data_idx in data] for idx in xrange(5)] + input_data = map(lambda x: np.array(x).astype("int32"), input_data) + input_data = map(lambda x: np.expand_dims(x, axis=1), input_data) + + first_data = input_data[0] + first_tensor = core.LoDTensor() + first_tensor.set(first_data, place) + + second_data = input_data[1] + second_tensor = core.LoDTensor() + second_tensor.set(second_data, place) + + third_data = input_data[2] + third_tensor = core.LoDTensor() + third_tensor.set(third_data, place) + + forth_data = input_data[3] + forth_tensor = core.LoDTensor() + forth_tensor.set(forth_data, place) + + next_data = input_data[4] + next_tensor = core.LoDTensor() + next_tensor.set(next_data, place) + + outs = exe.run(program, + feed={ + 'firstw': first_tensor, + 'secondw': second_tensor, + 'thirdw': third_tensor, + 'forthw': forth_tensor, + 'nextw': next_tensor + }, + fetch_list=[avg_cost]) + out = np.array(outs[0]) + if out[0] < 10.0: + exit(0) # if the avg cost is less than 10.0, we consider the code correct. +exit(1) diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index 4cfd91882e2d5f0098d27b8897359152ddd94dda..bd97dc1199fedc8ac91c1c6086957e8cce88bdc4 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -101,6 +101,10 @@ class Parameters(object): self.__param_conf__[param_conf.name] = param_conf + def update_param_conf(self, model_config): + for p in model_config.parameters: + self.__param_conf__[p.name] = p + def keys(self): """ keys are the names of each parameter. @@ -322,6 +326,17 @@ class Parameters(object): self.set(name, arr.reshape(self.get_shape(name))) def to_tar(self, f): + """ + Save parameters to a tar file. + + WARNING: You should use `paddle.v2.trainer.SGD.save_parameter_to_tar(f)` + to save parameters most of the time. Otherwise, some settings such + as model average will not take effect.
+ + :param f: the file object to which the parameters are written + :type f: file + :return: None + """ tar = tarfile.TarFile(fileobj=f, mode='w') for nm in self.names(): buf = cStringIO.StringIO() diff --git a/python/paddle/v2/tests/CMakeLists.txt b/python/paddle/v2/tests/CMakeLists.txt index b7791559594321a85f41b508b69efeb077d69595..b4333ed530ce464095ec38d72706949cc464fbe4 100644 --- a/python/paddle/v2/tests/CMakeLists.txt +++ b/python/paddle/v2/tests/CMakeLists.txt @@ -5,3 +5,4 @@ py_test(test_topology SRCS test_topology.py) py_test(test_rnn_layer SRCS test_rnn_layer.py) py_test(test_parameters SRCS test_parameters.py) py_test(test_data_feeder SRCS test_data_feeder.py) +py_test(test_paramconf_order SRCS test_paramconf_order.py) diff --git a/python/paddle/v2/tests/test_data_feeder.py b/python/paddle/v2/tests/test_data_feeder.py index 83da678da387ed1c86868847f140c6c09fbec3b5..63905c04cf737d0f1d226a4a5a27777351dbf5a3 100644 --- a/python/paddle/v2/tests/test_data_feeder.py +++ b/python/paddle/v2/tests/test_data_feeder.py @@ -97,7 +97,7 @@ class DataFeederTest(unittest.TestCase): each_sample.append(zip(a, b)) data.append(each_sample) - feeder = DataFeeder([('input', data_type.sparse_vector(dim))], + feeder = DataFeeder([('input', data_type.sparse_float_vector(dim))], {'input': 0}) arg = feeder(data) output = arg.getSlotValue(0) diff --git a/python/paddle/v2/tests/test_paramconf_order.py b/python/paddle/v2/tests/test_paramconf_order.py new file mode 100644 index 0000000000000000000000000000000000000000..41fea64122b81948d57cce07f00d764e4889da66 --- /dev/null +++ b/python/paddle/v2/tests/test_paramconf_order.py @@ -0,0 +1,85 @@ +# Copyright PaddlePaddle contributors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
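[Editor's note] The new test below pins down regularization precedence: an l2_rate set per parameter through paddle.attr.Param must survive, while the optimizer-wide L2Regularization rate only backfills parameters that never set their own decay_rate. In sketch form:

    # Sketch, not part of the commit: expected decay rates after
    # Topology.update_from_default() runs.
    adagrad = paddle.optimizer.AdaGrad(
        learning_rate=3e-3,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4))
    # "_fc1.w0" keeps its layer-level l2_rate of 6e-4;
    # every other parameter inherits the optimizer-wide 8e-4.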
+import unittest +import math +import paddle.v2 as paddle + + +def wordemb(inlayer): + wordemb = paddle.layer.table_projection( + input=inlayer, + size=5, + param_attr=paddle.attr.Param( + name="_proj", initial_std=0.001, learning_rate=1, l2_rate=0)) + return wordemb + + +def train(): + word_dict = paddle.dataset.imikolov.build_dict() + dict_size = len(word_dict) + # Every layer takes an integer value in the range [0, dict_size) + firstword = paddle.layer.data( + name="firstw", type=paddle.data_type.integer_value(dict_size)) + secondword = paddle.layer.data( + name="secondw", type=paddle.data_type.integer_value(dict_size)) + thirdword = paddle.layer.data( + name="thirdw", type=paddle.data_type.integer_value(dict_size)) + fourthword = paddle.layer.data( + name="fourthw", type=paddle.data_type.integer_value(dict_size)) + nextword = paddle.layer.data( + name="fifthw", type=paddle.data_type.integer_value(dict_size)) + + Efirst = wordemb(firstword) + Esecond = wordemb(secondword) + Ethird = wordemb(thirdword) + Efourth = wordemb(fourthword) + + contextemb = paddle.layer.concat(input=[Efirst, Esecond, Ethird, Efourth]) + hidden1 = paddle.layer.fc(name="fc1", + input=contextemb, + size=128, + act=paddle.activation.Sigmoid(), + layer_attr=paddle.attr.Extra(drop_rate=0.5), + bias_attr=paddle.attr.Param(learning_rate=2), + param_attr=paddle.attr.Param( + initial_std=1. / math.sqrt(5 * 8), + learning_rate=1, + l2_rate=6e-4)) + predictword = paddle.layer.fc(input=hidden1, + size=dict_size, + bias_attr=paddle.attr.Param(learning_rate=2), + act=paddle.activation.Softmax()) + + return paddle.layer.classification_cost(input=predictword, label=nextword) + + +class TestParamConfOrder(unittest.TestCase): + def test_param_conf_order(self): + paddle.init() + cost = train() + parameters = paddle.parameters.create(cost) + adagrad = paddle.optimizer.AdaGrad( + learning_rate=3e-3, + regularization=paddle.optimizer.L2Regularization(rate=8e-4)) + + trainer = paddle.trainer.SGD(cost, parameters, adagrad) + for p in trainer.get_topology_proto().parameters: + if p.name == "_fc1.w0": + self.assertEqual(p.decay_rate, 6e-4) + else: + self.assertEqual(p.decay_rate, 8e-4) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index 2db66be2505dde38a501edf45984e1f36beb351d..923ccecb0bf1236b4a3768fdc07dc3027e2863b7 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -19,6 +19,7 @@ import paddle.trainer_config_helpers as conf_helps import layer as v2_layer import config_base import cPickle +from paddle.trainer import config_parser as cp __all__ = ['Topology'] @@ -50,6 +51,35 @@ class Topology(object): assert isinstance(self.__model_config__, ModelConfig) + def update_from_default(self): + # HACK(typhoonzero): update ParameterConfig(proto) in case optimizers + # are defined after layers, or between layers.
+ # Must be called from trainer.__init__() + for parameter in self.__model_config__.parameters: + if parameter.momentum == 0.0 and cp.g_default_momentum: + parameter.momentum = cp.g_default_momentum + if parameter.decay_rate == 0.0 and cp.g_default_decay_rate: + parameter.decay_rate = cp.g_default_decay_rate + if parameter.initial_mean == 0.0: + parameter.initial_mean = cp.g_default_initial_mean + if parameter.initial_std == 0.01: + parameter.initial_std = cp.g_default_initial_std + if parameter.initial_strategy == 0: + parameter.initial_strategy = cp.g_default_initial_strategy + if parameter.initial_smart == False: + parameter.initial_smart = cp.g_default_initial_smart + if parameter.num_batches_regularization == 1 and \ + cp.g_default_num_batches_regularization: + parameter.num_batches_regularization = \ + cp.g_default_num_batches_regularization + if parameter.gradient_clipping_threshold == 0.0 and \ + cp.g_default_gradient_clipping_threshold: + parameter.gradient_clipping_threshold = \ + cp.g_default_gradient_clipping_threshold + if parameter.device == -1 and cp.g_default_device: + parameter.device = cp.g_default_device + # FIXME(typhoonzero): ignored: update_hooks, g_default_compact_func + def use_sparse_updater(self): """ check if any parameter requires sparse_update diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 076e75593991415bc3fbcbd36a108c8c7de66932..b68fd0d5a97a7993ddd0a1d947304fa5428c01b8 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -64,6 +64,11 @@ class SGD(object): "paddle.v2.optimizer.Optimizer") import py_paddle.swig_paddle as api topology = Topology(cost, extra_layers=extra_layers) + # HACK(typhoonzero): update ParameterConfig(proto) in case optimizers + # are defined after layers, or between layers. + topology.update_from_default() + parameters.update_param_conf(topology.proto()) + self.__optimizer__ = update_equation self.__topology__ = topology self.__parameters__ = parameters @@ -91,6 +96,9 @@ class SGD(object): self.__parameters__.append_gradient_machine(gm) self.__parameter_updater__ = None + def get_topology_proto(self): + return self.__topology_in_proto__ + def __use_remote_sparse_updater__(self): return self.__use_sparse_updater__ and not self.__is_local__
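[Editor's note] The trainer change above fixes an ordering problem: optimizer-supplied defaults are recorded by `config_parser` as globals, so `Topology.update_from_default()` must backfill the still-default `ParameterConfig` fields before the refreshed protos are pushed into the `Parameters` pool. A minimal sketch of the handshake performed inside `paddle.v2.trainer.SGD.__init__` (names as in the diff):

    # Sketch, not part of the commit: cost is a cost layer and parameters
    # a paddle.v2.parameters.Parameters pool.
    topology = Topology(cost, extra_layers=extra_layers)
    topology.update_from_default()                  # backfill proto defaults
    parameters.update_param_conf(topology.proto())  # sync ParameterConfig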