Commit 8edf60ce authored by Yibing Liu

Merge branch 'develop' of upstream into fix_seq_pad

@@ -16,7 +16,9 @@ find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a
       DOC "Path to TensorRT library.")
 
 if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY)
+  if(WITH_DSO)
     set(TENSORRT_FOUND ON)
+  endif(WITH_DSO)
 else()
   set(TENSORRT_FOUND OFF)
 endif()
...
@@ -429,7 +429,7 @@ struct LSTM : public PatternBase {
 struct GRU : public PatternBase {
   GRU(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "lstm") {}
+      : PatternBase(pattern, name_scope, "gru") {}
 
   PDNode* operator()(PDNode* x);
...
@@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <glog/logging.h>
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -64,13 +64,15 @@ PaddleBuf& PaddleBuf::operator=(PaddleBuf&& other) {
 
 void PaddleBuf::Resize(size_t length) {
   // Only the owned memory can be reset, the external memory can't be changed.
-  if (length_ == length) return;
+  if (length_ >= length) return;
   if (memory_owned_) {
     Free();
+    data_ = malloc(length);
+    length_ = length;
+    memory_owned_ = true;
+  } else {
+    PADDLE_THROW("The memory is allocated externally, can not Resized");
   }
-  data_ = new char[length];
-  length_ = length;
-  memory_owned_ = true;
 }
 
 void PaddleBuf::Reset(void* data, size_t length) {
@@ -82,8 +84,8 @@ void PaddleBuf::Reset(void* data, size_t length) {
 
 void PaddleBuf::Free() {
   if (memory_owned_ && data_) {
-    assert(length_ > 0);
-    delete[] static_cast<char*>(data_);
+    PADDLE_ENFORCE_GT(length_, 0);
+    free(static_cast<char*>(data_));
     data_ = nullptr;
     length_ = 0;
   }
...
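Reviewer note: the new Resize contract is subtle — it is a no-op while the owned buffer is already large enough, reallocates owned memory, and refuses to touch externally owned memory. Below is a minimal self-contained sketch of that contract (a toy class for illustration only, not the real PaddleBuf):

// Toy buffer mimicking the Resize/Free contract from the hunk above.
#include <cassert>
#include <cstdlib>
#include <stdexcept>

class ToyBuf {
 public:
  ToyBuf(void* data, size_t length)  // wraps external memory
      : data_(data), length_(length), memory_owned_(false) {}
  explicit ToyBuf(size_t length)  // owns its memory
      : data_(std::malloc(length)), length_(length), memory_owned_(true) {}
  ~ToyBuf() { Free(); }

  void Resize(size_t length) {
    if (length_ >= length) return;  // capacity suffices: keep the buffer
    if (memory_owned_) {
      Free();
      data_ = std::malloc(length);
      length_ = length;
      memory_owned_ = true;
    } else {
      // External memory cannot be grown on the caller's behalf.
      throw std::runtime_error("memory allocated externally, cannot Resize");
    }
  }

 private:
  void Free() {
    if (memory_owned_ && data_) {
      assert(length_ > 0);
      std::free(data_);  // pairs with malloc, unlike the old delete[]
      data_ = nullptr;
      length_ = 0;
    }
  }
  void* data_;
  size_t length_;
  bool memory_owned_;
};

Note the switch from new char[]/delete[] to malloc/free keeps allocation and deallocation paired through the same Free() path, and the >= early return avoids reallocation churn when callers repeatedly resize within capacity.
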
@@ -53,7 +53,7 @@ set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classifi
 download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz")
 inference_analysis_test(test_analyzer_text_classification SRCS analyzer_text_classification_tester.cc
   EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
-  ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta
+  ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/model
   --infer_data=${TEXT_CLASSIFICATION_INSTALL_DIR}/data.txt)
 
 # ocr
...
@@ -300,6 +300,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
+    bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
     int groups = ctx.Attr<int>("groups");
 
     // TODO: add support for dilation
@@ -366,12 +367,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
-      conv_pd =
-          ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides,
-                               paddings, mkldnn_engine, fuse_relu);
+      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
+                                     strides, paddings, mkldnn_engine,
+                                     fuse_relu, fuse_eltwise);
     } else {
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
-                                     paddings, mkldnn_engine, fuse_relu);
+      conv_pd =
+          ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+                               mkldnn_engine, fuse_relu, fuse_eltwise);
     }
 
     // Save conv_pd/src_memory/weights_memory for backward pass
     dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -421,16 +423,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   }
 
  private:
-  mkldnn::primitive_attr AddRelu() const {
-    // Fusion with ReLU layer is executed through the PostOps feature. Create a
-    // PostOps object and configure it to execute an eltwise relu operation.
+  mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
+                                       bool fuse_eltwise) const {
     mkldnn::primitive_attr conv_attr;
-    constexpr float scale = 1.0f;
-    constexpr float negative_slope = 0.0f;
-    constexpr float placeholder = 0.0f;
     mkldnn::post_ops post_operations;
-    post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
-                                   negative_slope, placeholder);
+    // Fusion with Elementwise layer relies on adding a sum post-operation with
+    // the scale parameter. It is assumed that when fuse_eltwise is true, the
+    // Output tensor contains the data coming from residual connection. The
+    // result of this post_op is: Output = scale * Output + Conv_Out.
+    if (fuse_eltwise) {
+      post_operations.append_sum(1.0f);
+    }
+    // Fusion with ReLU layer is executed through the PostOps feature. Create a
+    // PostOps object and configure it to execute an eltwise relu operation.
+    if (fuse_relu) {
+      constexpr float scale = 1.0f;
+      constexpr float negative_slope = 0.0f;
+      constexpr float placeholder = 0.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                     negative_slope, placeholder);
+    }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
   }
@@ -439,8 +451,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
                        const memory::desc& dst, const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine,
-                       const bool fuse_relu) const {
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
@@ -449,10 +461,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
 
-    mkldnn::primitive_attr conv_attr;
-    if (fuse_relu) {
-      conv_attr = AddRelu();
-    }
+    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
 
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
@@ -466,8 +475,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                        const memory::desc& bias, const memory::desc& dst,
                        const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine,
-                       const bool fuse_relu) const {
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
@@ -476,10 +485,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         bias, dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
 
-    mkldnn::primitive_attr conv_attr;
-    if (fuse_relu) {
-      conv_attr = AddRelu();
-    }
+    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
 
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
...
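Reviewer note: with both flags set, the fused primitive effectively computes Output = ReLU(Output + Conv_Out), where Output initially holds the skip-connection data (per the comment in CreatePostOps above). A scalar reference model of the post-op chain, for illustration only — this is not MKL-DNN code:

// Reference semantics of the fused post-ops, on scalars for clarity.
float FusedConvPostOps(float conv_out, float dst_prev,
                       bool fuse_eltwise, bool fuse_relu) {
  float out = conv_out;
  if (fuse_eltwise) {
    // append_sum(1.0f): Output = 1.0f * Output + Conv_Out
    out = dst_prev + out;
  }
  if (fuse_relu) {
    // eltwise_relu with negative_slope = 0.0f is a plain ReLU
    out = out > 0.0f ? out : 0.0f;
  }
  return out;
}

Folding both steps into convolution post-ops lets the kernel write the residual sum and activation in one pass instead of launching separate elementwise_add and relu primitives.
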
@@ -164,6 +164,11 @@ void Conv2DOpMaker::Make() {
       .SetDefault(false);
   AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<bool>("fuse_eltwise",
+                "(bool, default false) Only used in mkldnn kernel. Used "
+                "whenever convolution output is connected via skip connection "
+                "to a previous layer.")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
...
@@ -125,7 +125,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
+  framework::AsyncIO([var_name_val, s, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
@@ -166,7 +166,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
   s->Prepare(h, time_out);
 
   framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
-                      time_out, s, this] {
+                      s, this] {
     auto* var = p_scope->FindVar(in_var_name_val);
 
     ::grpc::ByteBuffer req;
...
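Reviewer note: the dropped captures were never read inside the lambda bodies; capturing a raw pointer by value does not extend any object's lifetime, so trimming the list simply documents exactly what each async task depends on. A minimal illustration (the MakeTask helper is hypothetical, not part of GRPCClient):

#include <functional>
#include <string>

std::function<void()> MakeTask(const std::string& var_name) {
  // Capture by value so the task owns its copy of the name; nothing else
  // is captured because nothing else is read inside the body.
  return [var_name] {
    // ... serialize a request for `var_name` and send it ...
  };
}
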
@@ -82,8 +82,10 @@ class ProtoEncodeHelper {
       : base_(buf), p_(buf), limit_(base_ + max_size) {}
 
   ~ProtoEncodeHelper() {
+#define REPLACE_ENFORCE_GLOG 1
     // Make sure callers didn't do operations that went over max_size promised
-    PADDLE_ENFORCE_LE(p_, limit_);
+    paddle::platform::throw_on_error(p_ <= limit_);
+#undef REPLACE_ENFORCE_GLOG
   }
   const char* data() const { return base_; }
...
@@ -59,17 +59,16 @@ static void ParallelExecuteBlocks(
     framework::ProgramDesc *program, framework::Scope *scope) {
   std::vector<std::future<void>> fs;
   for (size_t idx : parallel_blkids) {
-    fs.push_back(
-        framework::Async([&executor, &prepared, &program, &scope, idx]() {
-          int run_block = idx;  // thread local
-          try {
-            VLOG(3) << "running server block: " << run_block
-                    << "pointer: " << prepared[run_block].get();
-            executor->RunPreparedContext(prepared[run_block].get(), scope);
-          } catch (const std::exception &e) {
-            LOG(ERROR) << "run sub program error " << e.what();
-          }
-        }));
+    fs.push_back(framework::Async([&executor, &prepared, &scope, idx]() {
+      int run_block = idx;  // thread local
+      try {
+        VLOG(3) << "running server block: " << run_block
+                << "pointer: " << prepared[run_block].get();
+        executor->RunPreparedContext(prepared[run_block].get(), scope);
+      } catch (const std::exception &e) {
+        LOG(ERROR) << "run sub program error " << e.what();
+      }
+    }));
   }
   for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
 }
...
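Reviewer note: this is a fork-join pattern — launch one task per block id, collect the futures, then wait on all of them. A standalone sketch with std::async standing in for framework::Async (an assumption for illustration; Paddle's thread pool behaves differently):

#include <future>
#include <iostream>
#include <vector>

void ParallelRun(const std::vector<size_t>& block_ids) {
  std::vector<std::future<void>> fs;
  for (size_t idx : block_ids) {
    fs.push_back(std::async(std::launch::async, [idx]() {
      try {
        std::cout << "running block " << idx << "\n";
        // ... run the prepared context for block `idx` here ...
      } catch (const std::exception& e) {
        // Exceptions must not escape the task; log and continue.
        std::cerr << "run sub program error " << e.what() << "\n";
      }
    }));
  }
  for (auto& f : fs) f.wait();  // join: all blocks finish before returning
}
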
@@ -103,6 +103,58 @@ class MaxSeqPoolGradFunctor {
   }
 };
 
+template <typename T>
+class LastSeqPoolFunctor {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::LoDTensor& input,
+                  framework::Tensor* output) {
+    // Create pointers to input and output data
+    auto* in_data = input.data<T>();
+    auto* out_data = output->data<T>();
+
+    // Calculate the size of each item in sequence
+    int64_t item_size = input.numel() / input.dims()[0];
+    auto lod = input.lod()[0];
+    int seq_num = static_cast<int>(lod.size()) - 1;
+    for (int i = 0; i < seq_num; ++i) {
+      // Calculate the length of each sequence
+      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      // Point to the begin of next sequence
+      in_data += seq_len * item_size;
+      // Copy the last item of sequence to output
+      std::memcpy(out_data, (in_data - item_size), item_size * sizeof(T));
+      out_data += item_size;
+    }
+  }
+};
+
+template <typename T>
+class FirstSeqPoolFunctor {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::LoDTensor& input,
+                  framework::Tensor* output) {
+    // Create pointers to input and output data
+    auto* in_data = input.data<T>();
+    auto* out_data = output->data<T>();
+
+    // Calculate the size of each item in sequence
+    int64_t item_size = input.numel() / input.dims()[0];
+    auto lod = input.lod()[0];
+    int seq_num = static_cast<int>(lod.size()) - 1;
+    for (int i = 0; i < seq_num; ++i) {
+      // Calculate the length of each sequence
+      int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
+      // Copy the first item of sequence to output
+      std::memcpy(out_data, in_data, item_size * sizeof(T));
+      // Point to the next sequence
+      in_data += seq_len * item_size;
+      out_data += item_size;
+    }
+  }
+};
+
 template <typename T>
 class SequencePoolFunctor<platform::CPUDeviceContext, T> {
  public:
@@ -116,6 +168,16 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       max_pool(context, input, output, index);
       return;
     }
+    if (pooltype == "LAST") {
+      math::LastSeqPoolFunctor<T> last_pool;
+      last_pool(context, input, output);
+      return;
+    }
+    if (pooltype == "FIRST") {
+      math::FirstSeqPoolFunctor<T> first_pool;
+      first_pool(context, input, output);
+      return;
+    }
     auto lod = input.lod()[0];
     auto& place = *context.eigen_device();
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
@@ -133,10 +195,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       } else if (pooltype == "SQRT") {
         out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
                               std::sqrt(static_cast<T>(h));
-      } else if (pooltype == "LAST") {
-        out_e.device(place) = in_e.chip(h - 1, 0);
-      } else if (pooltype == "FIRST") {
-        out_e.device(place) = in_e.chip(0, 0);
       } else {
         PADDLE_THROW("unsupported pooling pooltype");
       }
...
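Reviewer note: FIRST/LAST pooling copy exactly one row per sequence, so a memcpy walk over the LoD offsets replaces the per-sequence Eigen chip() path. With offsets lod = {0, 2, 5}, rows [0, 2) form sequence 0 and rows [2, 5) form sequence 1. A sketch of the LAST case over plain arrays, mirroring LastSeqPoolFunctor above (illustrative only):

#include <cstdint>
#include <cstring>
#include <vector>

// in_data: flattened rows of all sequences; lod: sequence offsets into rows;
// item_size: elements per row; out_data: one row per sequence.
template <typename T>
void LastPool(const T* in_data, const std::vector<size_t>& lod,
              int64_t item_size, T* out_data) {
  int seq_num = static_cast<int>(lod.size()) - 1;
  for (int i = 0; i < seq_num; ++i) {
    int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
    in_data += seq_len * item_size;             // jump past this sequence
    std::memcpy(out_data, in_data - item_size,  // copy its last row
                item_size * sizeof(T));
    out_data += item_size;
  }
}

For lod = {0, 2, 5} and item_size = 3, the output receives input rows 1 and 4 — the last row of each sequence.
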
@@ -26,10 +26,13 @@ class PReluOp : public framework::OperatorWithKernel {
     std::string mode = ctx->Attrs().Get<std::string>("mode");
 
     auto x_dim = ctx->GetInputDim("X");
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of PreluOp should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("Alpha"),
+                   "Input(Alpha) of PreluOp should not be null");
 
-    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of PreluOp should not be null");
     if (mode == "all") {
       PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
                      "For mode 'all', size of weight Alpha must be one.");
...
@@ -33,6 +33,7 @@ function print_usage() {
     ${BLUE}single_test${NONE}: run a single unit test
     ${BLUE}bind_test${NONE}: parallel tests bind to different GPU
     ${BLUE}doc${NONE}: generate paddle documents
+    ${BLUE}gen_doc_lib${NONE}: generate paddle documents library
     ${BLUE}html${NONE}: convert C++ source code into HTML
     ${BLUE}dockerfile${NONE}: generate paddle release dockerfile
     ${BLUE}capi${NONE}: generate paddle CAPI package
@@ -431,24 +432,60 @@ EOF
     linkchecker doc/v2/cn/html/index.html
     linkchecker doc/v2/api/en/html/index.html
 
-    if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit 0; fi;
-
-    # Deploy to the the content server if its a "develop" or "release/version" branch
-    # The "develop_doc" branch is reserved to test full deploy process without impacting the real content.
-    if [ "$TRAVIS_BRANCH" == "develop_doc" ]; then
-        PPO_SCRIPT_BRANCH=develop
-    elif [[ "$TRAVIS_BRANCH" == "develop" || "$TRAVIS_BRANCH" =~ ^v|release/[[:digit:]]+\.[[:digit:]]+(\.[[:digit:]]+)?(-\S*)?$ ]]; then
-        PPO_SCRIPT_BRANCH=master
-    else
-        # Early exit, this branch doesn't require documentation build
-        return 0;
-    fi
-    # Fetch the paddlepaddle.org deploy_docs.sh from the appopriate branch
-    export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/$PPO_SCRIPT_BRANCH/scripts/deploy/deploy_docs.sh
-    export PYTHONPATH=$PYTHONPATH:${PADDLE_ROOT}/build/python:/paddle/build/python
-    cd ..
-    curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH ${PADDLE_ROOT} ${PADDLE_ROOT}/build/doc/ ${PPO_SCRIPT_BRANCH}
-    cd -
+    # if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit 0; fi;
+    #
+    # # Deploy to the the content server if its a "develop" or "release/version" branch
+    # # The "develop_doc" branch is reserved to test full deploy process without impacting the real content.
+    # if [ "$TRAVIS_BRANCH" == "develop_doc" ]; then
+    #    PPO_SCRIPT_BRANCH=develop
+    # elif [[ "$TRAVIS_BRANCH" == "develop" || "$TRAVIS_BRANCH" =~ ^v|release/[[:digit:]]+\.[[:digit:]]+(\.[[:digit:]]+)?(-\S*)?$ ]]; then
+    #    PPO_SCRIPT_BRANCH=master
+    # else
+    #    # Early exit, this branch doesn't require documentation build
+    #    return 0;
+    # fi
+    # # Fetch the paddlepaddle.org deploy_docs.sh from the appopriate branch
+    # export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/$PPO_SCRIPT_BRANCH/scripts/deploy/deploy_docs.sh
+    # export PYTHONPATH=$PYTHONPATH:${PADDLE_ROOT}/build/python:/paddle/build/python
+    # cd ..
+    # curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH ${PADDLE_ROOT} ${PADDLE_ROOT}/build/doc/ ${PPO_SCRIPT_BRANCH}
+    # cd -
 }
 
+function gen_doc_lib() {
+    mkdir -p ${PADDLE_ROOT}/build
+    cd ${PADDLE_ROOT}/build
+    cat <<EOF
+    ========================================
+    Building documentation library ...
+    In /paddle/build
+    ========================================
+EOF
+    cmake .. \
+        -DCMAKE_BUILD_TYPE=Release \
+        -DWITH_DOC=ON \
+        -DWITH_GPU=OFF \
+        -DWITH_MKL=OFF \
+        -DWITH_FLUID_ONLY=ON
+
+    local LIB_TYPE=$1
+    case $LIB_TYPE in
+      full)
+        # Build full Paddle Python module. Will timeout without caching 'copy_paddle_pybind' first
+        make -j `nproc` gen_proto_py framework_py_proto copy_paddle_pybind paddle_python
+        ;;
+      pybind)
+        # Build paddle pybind library. Takes 49 minutes to build. Might timeout
+        make -j `nproc` copy_paddle_pybind
+        ;;
+      proto)
+        # Even smaller library.
+        make -j `nproc` framework_py_proto
+        ;;
+      *)
+        exit 0
+        ;;
+    esac
+}
+
 function gen_html() {
@@ -608,6 +645,9 @@ function main() {
       doc)
         gen_docs
        ;;
+      gen_doc_lib)
+        gen_doc_lib $2
+        ;;
       html)
         gen_html
        ;;
...
@@ -92,7 +92,7 @@ class TrainTaskConfig(object):
     src_vocab_fpath = data_path + "vocab.bpe.32000"
     trg_vocab_fpath = data_path + "vocab.bpe.32000"
     train_file_pattern = data_path + "train.tok.clean.bpe.32000.en-de"
-    val_file_pattern = data_path + "newstest2013.tok.bpe.32000.en-de"
+    val_file_pattern = data_path + "newstest2013.tok.bpe.32000.en-de.cut"
     pool_size = 2000
     sort_type = None
     local = True
@@ -624,11 +624,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                 init = True
 
         # Validate and save the model for inference.
-        if TrainTaskConfig.val_file_pattern is not None:
-            val_avg_cost, val_ppl = test()
-            print("[%f]" % val_avg_cost)
-        else:
-            assert (False)
+        if batch_id == 0 or batch_id == 4:
+            if TrainTaskConfig.val_file_pattern is not None:
+                val_avg_cost, val_ppl = test()
+                print("[%f]" % val_avg_cost)
+            else:
+                assert (False)
 
 
 #import transformer_reader as reader
@@ -1701,8 +1702,9 @@ class DistTransformer2x2(TestDistRunnerBase):
         exe.run(startup_prog)
         exe.run(pserver_prog)
 
-    def run_trainer(self, place, args):
+    def run_trainer(self, use_cuda, args):
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+        TrainTaskConfig.use_gpu = use_cuda
         sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
             args.is_dist, not args.sync_mode)
...
@@ -61,9 +61,10 @@ class TestDistRunnerBase(object):
         exe.run(startup_prog)
         exe.run(pserver_prog)
 
-    def run_trainer(self, place, args):
+    def run_trainer(self, use_cuda, args):
         import paddle
         import paddle.fluid as fluid
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
         test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
             self.get_model(batch_size=2)
         if args.mem_opt:
@@ -91,7 +92,7 @@ class TestDistRunnerBase(object):
             build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
 
         exe = fluid.ParallelExecutor(
-            True,
+            use_cuda,
             loss_name=avg_cost.name,
             exec_strategy=strategy,
             build_strategy=build_stra)
@@ -142,9 +143,8 @@ def runtime_main(test_class):
     if args.role == "pserver" and args.is_dist:
         model.run_pserver(args)
     else:
-        p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        model.run_trainer(p, args)
+        use_cuda = True if core.is_compiled_with_cuda() else False
+        model.run_trainer(use_cuda, args)
 
 
 import paddle.compat as cpt
@@ -225,11 +225,12 @@ class TestDistBase(unittest.TestCase):
     def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
         # TODO(typhoonzero): should auto adapt GPU count on the machine.
         required_envs = {
-            "PATH": os.getenv("PATH"),
-            "PYTHONPATH": os.getenv("PYTHONPATH"),
-            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH"),
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
-            "FLAGS_cudnn_deterministic": "1"
+            "FLAGS_cudnn_deterministic": "1",
+            "CPU_NUM": "1"
         }
 
         if check_error_log:
...
@@ -14,6 +14,7 @@
 
 from __future__ import print_function
 
+import os
 import unittest
 import paddle
 from test_dist_base import TestDistBase
@@ -44,6 +45,14 @@ def download_files():
     test_url = url_prefix + 'newstest2013.tok.bpe.32000.en-de'
     test_md5 = '9dd74a266dbdb25314183899f269b4a2'
     paddle.dataset.common.download(test_url, 'test_dist_transformer', test_md5)
+    # cut test data for faster CI
+    orig_path = os.path.join(paddle.dataset.common.DATA_HOME,
+                             "test_dist_transformer",
+                             "newstest2013.tok.bpe.32000.en-de")
+    head_path = os.path.join(paddle.dataset.common.DATA_HOME,
+                             "test_dist_transformer",
+                             "newstest2013.tok.bpe.32000.en-de.cut")
+    os.system("head -n10 %s > %s" % (orig_path, head_path))
 
 
 class TestDistTransformer2x2Sync(TestDistBase):
...
@@ -65,8 +65,43 @@ class InferenceTranspiler(object):
         if use_mkldnn:
             self._fuse_conv_bias_mkldnn(program)
-            self._fuse_conv_relu_mkldnn(program)
+            self._fuse_conv_eltwise_mkldnn(program)
+            self._fuse_conv_relu_mkldnn(
+                program)  # ResNet residual block merging
             self._fuse_bn_relu_mkldnn(program)
 
+    def _fuse_conv_eltwise_mkldnn(self, program):
+        '''
+        Transpile the program fusing elementwise_add into conv for MKLDNN
+        program. Elementwise add following convolution OP can be fused by adding
+        'fuse_eltwise' attribute to convolution OP and replacing its output
+        Tensor with second parameter of elementwise_add.
+        The result of fuse is:
+        - before:
+          - conv->elementwise_add->any_other_op
+        - after:
+          - conv->any_other_op
+        :param program: program to transpile
+        :type program: Program
+        '''
+        self.block = program.block(0)
+
+        i = 0
+        while i < len(self.block.ops):
+            current_op = self.block.ops[i]
+            if current_op.type in ['conv2d']:
+                next_op = self.block.ops[i + 1]
+                if next_op.type == 'elementwise_add':
+                    self._fuse_conv_eltwise(current_op, next_op)
+                    self.block._remove_op(i + 1)  # Remove elementwise_add
+            i = i + 1
+        self._adjust_input()
+        self._remove_unused_var()
+        # TODO(luotao): use clone() method to flush the program.desc in force,
+        # since some large program.desc will not be flushed immediately.
+        # And a better solution will be considered later.
+        program = program.clone()
+
     def _fuse_conv_relu_mkldnn(self, program):
         '''
         Transpile the program by fused relu activation for MKLDNN program.
@@ -88,9 +123,9 @@ class InferenceTranspiler(object):
             if current_op.type in ['conv2d']:
                 next_op = self.block.ops[i + 1]
                 if next_op.type == 'relu':
-                    # modify conv OP to include relu
+                    # modify bnorm OP to include relu
                     current_op.set_attr("fuse_relu", True)
-                    # remove conv OP
+                    # remove relu OP
                     self.block._remove_op(i + 1)
             i = i + 1
@@ -409,6 +444,20 @@ class InferenceTranspiler(object):
                 outputs={"Output": out_var},
                 attrs=attrs)
 
+    def _fuse_conv_eltwise(self, conv_op, eltwise_op):
+        '''
+        fuse the conv op with elementwise_add
+
+        :param conv_op: convolution operator
+        :type conv_op: Operator
+        :param eltwise_op: operator adding data from skip connection
+        :type eltwise_op: Operator
+        '''
+        conv_op.set_attr("fuse_eltwise", True)
+        self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
+        self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
+
     def _adjust_input(self):
         for i in range(len(self.block.ops)):
             current_op = self.block.ops[i]
...
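Reviewer note: the transpiler pass is a peephole rewrite — find a conv2d immediately followed by an elementwise_add, set fuse_eltwise on the conv, remove the add, and remap tensors so the conv feeds the add's consumers. A compact sketch of the same rewrite over a toy op list, in C++ for consistency with the kernel code above (the real pass is the Python shown here; ToyOp and the single-output simplification are purely illustrative — the real pass remaps through the skip-connection tensor Y):

#include <string>
#include <vector>

struct ToyOp {
  std::string type;
  std::string output;
  bool fuse_eltwise = false;
};

void FuseConvEltwise(std::vector<ToyOp>* ops) {
  for (size_t i = 0; i + 1 < ops->size(); ++i) {
    ToyOp& cur = (*ops)[i];
    const ToyOp& next = (*ops)[i + 1];
    if (cur.type == "conv2d" && next.type == "elementwise_add") {
      cur.fuse_eltwise = true;           // kernel will add the skip data
      cur.output = next.output;          // conv now feeds the add's consumers
      ops->erase(ops->begin() + i + 1);  // remove the elementwise_add
    }
  }
}

Combined with the fuse_eltwise handling in the MKLDNN kernel above, this turns conv->elementwise_add->any_other_op into conv->any_other_op at inference time.
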