From 7141debe38881bfc0fe146111bbba2211c1a6ddd Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Fri, 26 Oct 2018 19:43:19 +0800 Subject: [PATCH] add cudnn back. staged. --- paddle/fluid/framework/executor.cc | 80 ++++++++----- paddle/fluid/framework/op_desc.cc | 49 +++++--- paddle/fluid/framework/op_desc.h | 10 -- paddle/fluid/framework/operator.cc | 113 ++++++++++-------- paddle/fluid/framework/shape_inference.h | 3 + paddle/fluid/inference/api/api_impl.cc | 8 +- .../api/demo_ci/real_data_icnet_tester.cc | 6 +- .../api/demo_ci/thread_icnet_test.cc | 94 ++++++++------- paddle/fluid/memory/detail/buddy_allocator.cc | 3 +- paddle/fluid/memory/detail/meta_cache.cc | 2 + paddle/fluid/operators/top_k_op.cc | 2 +- paddle/fluid/operators/top_k_op.cu | 99 +++++---------- paddle/fluid/operators/top_k_op.h | 5 +- paddle/fluid/platform/CMakeLists.txt | 7 ++ paddle/fluid/platform/cudnn_helper.h | 9 +- paddle/fluid/platform/enforce.h | 41 ++++--- 16 files changed, 287 insertions(+), 244 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index bdf7e7c124..ddbcff7b39 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -299,16 +299,19 @@ std::unique_ptr Executor::Prepare( std::unique_ptr ctx( new ExecutorPrepareContext(program, block_id)); VLOG(3) << "after create prepare"; - // PADDLE_ENFORCE_LT(static_cast(block_id), program.Size()); + // PADDLE_ENFORCE_LT(static_cast(block_id), program.Size()); VLOG(3) << "before create op_desc"; auto& block = program.Block(block_id); - VLOG(3) << "create before" << ctx->ops_.size() << " " << block.AllOps().size(); + VLOG(3) << "create before" << ctx->ops_.size() << " " + << block.AllOps().size(); int counter = 0; for (auto& op_desc : block.AllOps()) { ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); - VLOG(3) << "create op " << "index " << ++counter << " type " << op_desc->Type(); + VLOG(3) << "create op " + << "index " << ++counter << " type " << op_desc->Type(); } - VLOG(3) << "create finished" << ctx->ops_.size() << " " << block.AllOps().size(); + VLOG(3) << "create finished" << ctx->ops_.size() << " " + << block.AllOps().size(); return ctx; } @@ -320,22 +323,25 @@ std::vector> Executor::Prepare( for (auto& bid : block_ids) { VLOG(3) << "block id" << bid; auto* ctx = new ExecutorPrepareContext(program, bid); - //PADDLE_ENFORCE_LT(static_cast(bid), program.Size()); + // PADDLE_ENFORCE_LT(static_cast(bid), program.Size()); auto& block = program.Block(bid); int counter = 0; - VLOG(3) << "create before" << ctx->ops_.size() << " " << block.AllOps().size(); + VLOG(3) << "create before" << ctx->ops_.size() << " " + << block.AllOps().size(); for (auto& op_desc : block.AllOps()) { - ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); - VLOG(3) << "create op " << "index " << ++counter << " type " << op_desc->Type(); + VLOG(3) << "create op " + << "index " << ++counter << " type " << op_desc->Type(); } - VLOG(3) << "create finished" << ctx->ops_.size() << " " << block.AllOps().size(); + VLOG(3) << "create finished" << ctx->ops_.size() << " " + << block.AllOps().size(); result.push_back(std::shared_ptr(ctx)); } return result; } -// void CheckResult(const std::string op_type, ExecutorPrepareContext* ctx, Scope* local_scope) { +// void CheckResult(const std::string op_type, ExecutorPrepareContext* ctx, +// Scope* local_scope) { // VLOG(3) << "before checking result"; // auto& dev_ctx = *platform::DeviceContextPool::Instance().Get(place_); // std::vector outputs; @@ -343,7 +349,8 @@ std::vector> Executor::Prepare( // bool found = false; // framework::OpDesc* myop = nullptr; // for(auto& op : block.AllOps()) { -// if(op->Type() == "load_combine" || op->Type() == "fetch" || op->Type() == "feed") return; +// if(op->Type() == "load_combine" || op->Type() == "fetch" || op->Type() == +// "feed") return; // if (op->Type() == op_type) { // found = true; // myop = op; @@ -370,7 +377,8 @@ std::vector> Executor::Prepare( // for(size_t i=0; i < check.numel(); ++i) { // sum += check.data()[i]; // } -// VLOG(3) << "op " << op->Type() << " output var " << var_name << " sum " << sum; +// VLOG(3) << "op " << op->Type() << " output var " << var_name << " sum " +// << sum; // VLOG(3) << "after checking result"; // } @@ -389,11 +397,14 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, VLOG(3) << "Scope ptr " << local_scope; for (auto& op : ctx->ops_) { op->Run(*local_scope, place_); - // CheckResult(op->Type(), ctx, local_scope); - if (FLAGS_benchmark) { - VLOG(2) << "Memory used after operator " + op->Type() + " running: " - << memory::memory_usage(place_); - } + // CheckResult(op->Type(), ctx, local_scope); + // if (FLAGS_benchmark) { + // VLOG(2) << "Memory used after operator " + op->Type() + " running: " + // << memory::memory_usage(place_); + // } + VLOG(2) << "Memory used after operator " + op->Type() + " running: " + << memory::memory_usage(place_); + // platform::DeviceContextPool::Instance().Get(place_)->Wait(); } platform::DeviceContextPool::Instance().Get(place_)->Wait(); @@ -403,13 +414,15 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, // auto& block = ctx->prog_.Block(0); // for(auto& op : block.AllOps()) { - // if(op->Type() == "load_combine" || op->Type() == "fetch" || op->Type() == "feed") continue; + // if(op->Type() == "load_combine" || op->Type() == "fetch" || op->Type() == + // "feed") continue; // // for(auto& real_op : ctx->ops_) { // // if(real_op->Type() == op->Type()) { - // // VLOG(3) << real_op->Type() << " " <DebugStringEx(local_scope); + // // VLOG(3) << real_op->Type() << " " <DebugStringEx(local_scope); // // } // // } - + // //VLOG(3) << "start op output" << op->Type(); // for(auto var_name: op->InputArgumentNames()) { // auto* var = local_scope->Var(var_name); @@ -418,19 +431,21 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, // auto* tensor = var->GetMutable(); // framework::Tensor check; // VLOG(3) << "before tensor copy"; - + // framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check); - + // VLOG(3) << "after tensor copy"; // float sum = .0; // for(size_t i=0; i < check.numel(); ++i) { - // if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) { + // if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) + // { // sum += static_cast(check.data()[i]); // } else { // sum += check.data()[i]; // } // } - // VLOG(3) << "op " << op->Type() << " input var " << var_name << " sum " << sum; + // VLOG(3) << "op " << op->Type() << " input var " << var_name << " sum " + // << sum; // } // VLOG(3) << "op " << op->Type() << "input finished"; @@ -442,23 +457,28 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, // framework::Tensor check; // VLOG(3) << "before tensor copy"; // if(op->Type() == "batch_norm" && platform::is_gpu_place(place_)) { - // VLOG(3) << "op " << op->Type() << " output var " << var_name << " " << tensor->numel(); + // VLOG(3) << "op " << op->Type() << " output var " << var_name << " " + // << tensor->numel(); // tensor->mutable_data(place_); - // framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check); + // framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, + // &check); // } else { - // framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check); + // framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, + // &check); // } - + // VLOG(3) << "after tensor copy"; // float sum = .0; // for(size_t i=0; i < check.numel(); ++i) { - // if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) { + // if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) + // { // sum += static_cast(check.data()[i]); // } else { // sum += check.data()[i]; // } // } - // VLOG(3) << "op " << op->Type() << " output var " << var_name << " sum " << sum; + // VLOG(3) << "op " << op->Type() << " output var " << var_name << " sum " + // << sum; // } // } diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 555faba962..c293cf92b4 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -50,19 +50,41 @@ class CompileTimeInferShapeContext : public InferShapeContext { const std::vector &Outputs( const std::string &name) const override; + void ShareDim(const std::string &in, const std::string &out, size_t i = 0, + size_t j = 0) override { + PADDLE_ENFORCE_LT(i, Inputs(in).size()); + PADDLE_ENFORCE_LT(j, Outputs(out).size()); + const std::string &input_n = Inputs(in)[i]; + const std::string &output_n = Outputs(out)[j]; + + PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@", + in, i); + PADDLE_ENFORCE(output_n != framework::kEmptyVarName, + "The %s[%d] is @EMPTY@", out, j); + + auto *in_var = block_.FindVarRecursive(input_n); + auto *out_var = block_.FindVarRecursive(output_n); + + PADDLE_ENFORCE(in_var->GetType() == out_var->GetType(), + "The type of %s and %s is not the same.", input_n, output_n); + + SetDim(output_n, GetDim(input_n)); + } + void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, size_t j = 0) const override { PADDLE_ENFORCE_LT(i, Inputs(in).size()); PADDLE_ENFORCE_LT(j, Outputs(out).size()); + PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName, + "The %s[%d] is @EMPTY@", in, i); + PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName, + "The %s[%d] is @EMPTY@", out, j); auto *in_var = block_.FindVarRecursive(Inputs(in)[i]); auto *out_var = block_.FindVarRecursive(Outputs(out)[j]); if (in_var->GetType() != proto::VarType::LOD_TENSOR) { VLOG(3) << "input " << in << " is not LodTensor"; return; } - PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR, - "The %d-th output of Output(%s) must be LoDTensor.", j, - out); out_var->SetLoDLevel(in_var->GetLoDLevel()); } @@ -441,7 +463,10 @@ static void InitInferShapeFuncs() { for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) { auto op_type = kern_pair.first; - auto &op_info = info_map.at(op_type); + auto it = info_map.find(op_type); + PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered", + op_type); + auto &op_info = it->second; auto op = static_cast(op_info.Creator()( "", VariableNameMap{}, VariableNameMap{}, AttributeMap{})); if (op_info.infer_shape_) { // infer_shape has been registered. @@ -490,20 +515,14 @@ void OpDesc::InferShape(const BlockDesc &block) const { } void OpDesc::InferVarType(BlockDesc *block) const { + // There are a few places that var type can be set. + // When VarDesc is created, default set to LOD_TENSOR. + // When output variable is created, default is defaut set to LOD_TENSOR. + // We limit here to be the only place that operator defines its customized + // var type inference. Hence, we don't do any "default" setting here. auto &info = OpInfoMap::Instance().Get(this->Type()); if (info.infer_var_type_) { info.infer_var_type_(*this, block); - } else { - // all output type is LoDTensor by default - VLOG(10) << this->Type() - << " has not registered InferVarType. Set output variables to " - "LOD_TENSOR"; - for (auto &out_pair : this->outputs_) { - for (auto &out_var_name : out_pair.second) { - block->FindRecursiveOrCreateVar(out_var_name) - .SetType(proto::VarType::LOD_TENSOR); - } - } } } diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index b4205aba83..440e0509be 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -100,16 +100,6 @@ class OpDesc { std::vector InputNames() const { return MapKeys(inputs_); } std::vector OutputNames() const { return MapKeys(outputs_); } - void SetInputMap(const VariableNameMap &input) { - this->inputs_ = input; - this->need_update_ = true; - } - - void SetOutputMap(const VariableNameMap &output) { - this->outputs_ = output; - this->need_update_ = true; - } - const VariableNameMap &Inputs() const { return inputs_; } const VariableNameMap &Outputs() const { return outputs_; } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 3b4a620f8c..ea060ebc60 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -62,7 +62,7 @@ static DDim GetDims(const Scope& scope, const std::string& name, if (var->IsType()) { const LoDTensor& tensor = var->Get(); - if (!tensor.IsInitialized()) { + if (UNLIKELY(!tensor.IsInitialized())) { return DDim({-1}); } return tensor.dims(); @@ -91,13 +91,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) { if (var->IsType()) { const LoDTensor& tensor = var->Get(); - if (!tensor.IsInitialized()) { + if (UNLIKELY(!tensor.IsInitialized())) { return ""; } return DataTypeToString(ToDataType(tensor.type())); } else if (var->IsType()) { auto tensor = var->Get().value(); - if (!tensor.IsInitialized()) { + if (UNLIKELY(!tensor.IsInitialized())) { return "uninited"; } else { return DataTypeToString(ToDataType(tensor.type())); @@ -130,7 +130,7 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { if (var->IsType()) { const LoDTensor& tensor = var->Get(); - if (!tensor.IsInitialized()) { + if (UNLIKELY(!tensor.IsInitialized())) { return default_lod; } return tensor.lod(); @@ -149,11 +149,13 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { platform::SetDeviceId(dev_id); #endif } - VLOG(3) << "start pool"; - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::RecordEvent record_event(Type(), pool.Get(place)); - VLOG(3) << "start RunImpl"; + + // The profile has a process-wide mutex, results in serious performance issue + // in concurrency scenerio. Here use an `if` to fix this issue. + // Please not remove the `if`, ask @Superjomn if there are any concern. + RunImpl(scope, place); + VLOG(3) << place << " " << DebugStringEx(&scope); } @@ -206,7 +208,6 @@ const std::vector& OperatorBase::Outputs( } std::string OperatorBase::DebugStringEx(const Scope* scope) const { - VLOG(3) << this->Type() << " scope ptr " << scope; std::stringstream ss; ss << "Op(" << type_ << "), inputs:{"; for (auto it = inputs_.begin(); it != inputs_.end();) { @@ -470,35 +471,35 @@ class RuntimeInferShapeContext : public InferShapeContext { : op_(op), scope_(scope) {} bool HasInput(const std::string& name) const override { - if (!op_.HasInputs(name)) { + // has only one input + const auto& ins = op_.Inputs(); + auto it = ins.find(name); + if (it == ins.end()) { return false; } - auto& ins = Inputs(name); - size_t length = ins.size(); - if (length == 0) { + const auto& in = it->second; + if (in.size() == 0 || in[0] == kEmptyVarName) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, + PADDLE_ENFORCE_EQ(in.size(), 1UL, "Input %s should not have more than one inputs", name); - auto ipt = ins[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; + return scope_.FindVar(in[0]) != nullptr; } bool HasOutput(const std::string& name) const override { - if (!op_.HasOutputs(name)) { + // has only one output + const auto& outs = op_.Outputs(); + auto it = outs.find(name); + if (it == outs.end()) { return false; } - auto& outs = Outputs(name); - size_t length = outs.size(); - if (length == 0) { + const auto& out = it->second; + if (out.size() == 0 || out[0] == kEmptyVarName) { return false; } - PADDLE_ENFORCE_EQ(length, 1UL, - "Output %s should not have more than one inputs", name); - auto ipt = outs[0]; - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; + PADDLE_ENFORCE_EQ(out.size(), 1UL, + "Output %s should not have more than one outputs", name); + return scope_.FindVar(out[0]) != nullptr; } bool HasInputs(const std::string& name) const override { @@ -545,13 +546,45 @@ class RuntimeInferShapeContext : public InferShapeContext { return op_.Outputs(name); } - void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) const override { + void ShareDim(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) override { PADDLE_ENFORCE_LT(i, Inputs(in).size()); PADDLE_ENFORCE_LT(j, Outputs(out).size()); - Variable* in_var = scope_.FindVar(Inputs(in)[i]); - Variable* out_var = scope_.FindVar(Outputs(out)[j]); + const std::string& input_n = Inputs(in)[i]; + const std::string& output_n = Outputs(out)[j]; + + Variable* in_var = scope_.FindVar(input_n); + Variable* out_var = scope_.FindVar(output_n); + PADDLE_ENFORCE(in_var->Type() == out_var->Type(), + "The type of %s and %s is not the same.", output_n, + GetDim(input_n)); + + if (in_var->IsType()) { + auto& in_sele_rows = in_var->Get(); + auto out_sele_rows = out_var->GetMutable(); + out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims()); + out_sele_rows->set_rows(in_sele_rows.rows()); + out_sele_rows->set_height(in_sele_rows.height()); + } else if (in_var->IsType()) { + auto& in_lod_tensor = in_var->Get(); + auto* out_lod_tensor = out_var->GetMutable(); + out_lod_tensor->Resize(in_lod_tensor.dims()); + } else { + PADDLE_THROW( + "Currently, the input type of ShareDim only can be LoDTensor " + "or SelectedRows."); + } + } + + void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) const override { + const std::vector& inputs = Inputs(in); + const std::vector& outputs = Outputs(out); + PADDLE_ENFORCE_LT(i, inputs.size()); + PADDLE_ENFORCE_LT(j, outputs.size()); + Variable* in_var = scope_.FindVar(inputs.at(i)); if (!in_var->IsType()) return; + Variable* out_var = scope_.FindVar(outputs.at(j)); PADDLE_ENFORCE(out_var->IsType(), "The %d-th output of Output(%s) must be LoDTensor.", j, out); auto in_tensor = in_var->Get(); @@ -579,20 +612,6 @@ class RuntimeInferShapeContext : public InferShapeContext { out_tensor->set_layout(in_tensor.layout()); } - void ShareLayout(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) const { - PADDLE_ENFORCE_LT(i, Inputs(in).size()); - PADDLE_ENFORCE_LT(j, Outputs(out).size()); - Variable* in_var = scope_.FindVar(Inputs(in)[i]); - Variable* out_var = scope_.FindVar(Outputs(out)[j]); - if (!in_var->IsType()) return; - PADDLE_ENFORCE(out_var->IsType(), - "The %d-th output of Output(%s) must be LoDTensor.", j, out); - auto in_tensor = in_var->Get(); - auto* out_tensor = out_var->GetMutable(); - out_tensor->set_layout(in_tensor.layout()); - } - bool IsRuntime() const override { return true; } protected: @@ -663,16 +682,12 @@ static void CheckTensorNANOrInf(const std::string& name, void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); - VLOG(3) << "start Infershape"; this->InferShape(&infer_shape_ctx); - VLOG(3) << "Infershape Pass"; platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); // check if op[type] has kernel registered. - VLOG(3) << "Start Kernels"; auto& all_op_kernels = AllOpKernels(); - VLOG(3) << "Kernel map finish"; auto kernels_iter = all_op_kernels.find(type_); if (kernels_iter == all_op_kernels.end()) { PADDLE_THROW( @@ -690,7 +705,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, auto expected_kernel_key = this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx)); - VLOG(3) << "expected_kernel_key: " << expected_kernel_key; + VLOG(3) << "expected_kernel_key:" << expected_kernel_key; auto kernel_iter = kernels.find(expected_kernel_key); #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 5f497cafa0..280bc19dce 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -56,6 +56,9 @@ class InferShapeContext { virtual const std::vector &Outputs( const std::string &name) const = 0; + virtual void ShareDim(const std::string &in, const std::string &out, + size_t i = 0, size_t j = 0) = 0; + virtual void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, size_t j = 0) const = 0; diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index aaf6d5a4f3..c778529cc4 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -112,11 +112,11 @@ bool NativePaddlePredictor::Init( auto &block = inference_program_->Block(0); for (auto *op_desc : block.AllOps()) { - if (op_desc->HasAttr("use_cudnn")) { - op_desc->SetAttr("use_cudnn", false); - } + // if (op_desc->HasAttr("use_cudnn")) { + // op_desc->SetAttr("use_cudnn", false); + // } if (op_desc->HasAttr("workspace_size_MB")) { - op_desc->SetAttr("workspace_size_MB", 0); + op_desc->SetAttr("workspace_size_MB", 1024); } } diff --git a/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc b/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc index 1b6463a333..c7db21d093 100644 --- a/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc +++ b/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc @@ -27,8 +27,8 @@ NativeConfig GetConfig() { NativeConfig config; // config.model_dir = FLAGS_dirname; - config.prog_file = "hs_lb_without_bn/__model__"; - config.param_file = "hs_lb_without_bn/__params__"; + config.prog_file = "hs_lb_without_bn_cudnn/__model__"; + config.param_file = "hs_lb_without_bn_cudnn/__params__"; // config.prog_file = "hs_lb_without_bn_cuda/__model__"; // config.param_file = "hs_lb_without_bn_cuda/__params__"; config.fraction_of_gpu_memory = 0.0; @@ -106,7 +106,7 @@ void test_naive(int batch_size) { std::cout << "batch: " << batch_size << " predict cost: " << time_diff(time1, time2) / steps << "ms" << std::endl; - std::cout << outputs.size() << std::endl; + std::cout << outputs.size() << std::endl; int64_t* data_o = static_cast(outputs[0].data.data()); int64_t sum_out = 0; for (size_t j = 0; j < outputs[0].data.length() / sizeof(int64_t); ++j) { diff --git a/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc b/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc index 9a018ee347..e1ce46b3bb 100644 --- a/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc +++ b/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc @@ -21,12 +21,12 @@ #include #include #include // NOLINT +#include #include "paddle/fluid/inference/api/paddle_inference_api.h" #define ASSERT_TRUE(x) x #define ASSERT_EQ(x, y) assert(x == y) - // DEFINE_string(dirname, "./LB_icnet_model", // "Directory of the inference model."); namespace paddle { @@ -34,7 +34,7 @@ NativeConfig GetConfig() { NativeConfig config; config.prog_file = "./hs_lb_without_bn_cuda/__model__"; config.param_file = "./hs_lb_without_bn_cuda/__params__"; - config.fraction_of_gpu_memory = 0.5; + config.fraction_of_gpu_memory = 0.0; config.use_gpu = true; config.device = 0; return config; @@ -54,7 +54,7 @@ void test_naive(int batch_size, std::string model_path) { int height = 449; int width = 581; std::vector data; - for(int i=0; i < 3 * height * width; ++i) { + for (int i = 0; i < 3 * height * width; ++i) { data.push_back(0.0); } @@ -86,47 +86,61 @@ void test_naive(int batch_size, std::string model_path) { // in_img.close(); // std::cout << "sum: " << sum_n << std::endl; - PaddleTensor tensor; - tensor.shape = std::vector({batch_size, 3, height, width}); - tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width); - std::copy(data.begin(), data.end(), - static_cast(tensor.data.data())); - tensor.dtype = PaddleDType::FLOAT32; - std::vector paddle_tensor_feeds(1, tensor); - - constexpr int num_jobs = 2; // each job run 1 batch - std::vector threads; - + PaddleTensor tensor; + tensor.shape = std::vector({batch_size, 3, height, width}); + tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width); + std::copy(data.begin(), data.end(), static_cast(tensor.data.data())); + tensor.dtype = PaddleDType::FLOAT32; + std::vector paddle_tensor_feeds(1, tensor); + + constexpr int num_jobs = 5; // each job run 1 batch + std::vector threads; + // using PtrPred = std::vector>; + std::vector> predictors; + for (int tid = 0; tid < num_jobs; ++tid) { + auto& pred = CreatePaddlePredictor(config); + predictors.emplace_back(std::move(pred)); + } - for (int tid = 0; tid < num_jobs; ++tid) { - threads.emplace_back([&, tid]() { + using namespace std::chrono_literals; + // std::this_thread::sleep_for(std::chrono::seconds(20)); + std::cout << "before start predict"; + + int epoches = 100000; + for (int tid = 0; tid < num_jobs; ++tid) { + threads.emplace_back([&, tid]() { + // auto predictor = CreatePaddlePredictor(config); + auto& predictor = predictors[tid]; + // auto& predictor = predictors[tid]; + // auto predictor = preds[tid]; + // std::this_thread::sleep_for(std::chrono::seconds(20)); PaddleTensor tensor_out; std::vector outputs(1, tensor_out); - auto predictor = CreatePaddlePredictor(config); - for (size_t i = 0; i < 1000; i++) { - ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); - VLOG(0) << "tid : " << tid << " run: " << i << "finished"; - //std::cout <<"tid : " << tid << " run: " << i << "finished" << std::endl; - ASSERT_EQ(outputs.size(), 1UL); - // int64_t* data_o = static_cast(outputs[0].data.data()); - // int64_t sum_out = 0; - // for (size_t j = 0; j < outputs[0].data.length() / sizeof(int64_t); - // ++j) { - // sum_out += data_o[j]; - // } - // std::cout << "tid : " << tid << "pass : " << i << " " << sum_out - // << std::endl; - } - }); - } - for (int i = 0; i < num_jobs; ++i) { - threads[i].join(); - } + for (size_t i = 0; i < epoches; i++) { + ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs)); + VLOG(0) << "tid : " << tid << " run: " << i << "finished"; + // std::cout <<"tid : " << tid << " run: " << i << "finished" << + // std::endl; + ASSERT_EQ(outputs.size(), 1UL); + // int64_t* data_o = static_cast(outputs[0].data.data()); + // int64_t sum_out = 0; + // for (size_t j = 0; j < outputs[0].data.length() / sizeof(int64_t); + // ++j) { + // sum_out += data_o[j]; + // } + // std::cout << "tid : " << tid << "pass : " << i << " " << sum_out + // << std::endl; + } + }); + } + for (int i = 0; i < num_jobs; ++i) { + threads[i].join(); } +} // } -} // namespace paddle +} // namespace paddle - int main(int argc, char** argv) { - paddle::test_naive(1 << 0, ""); - return 0; +int main(int argc, char** argv) { + paddle::test_naive(1 << 0, ""); + return 0; } diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc index c2f45fdc99..dad5c8257a 100644 --- a/paddle/fluid/memory/detail/buddy_allocator.cc +++ b/paddle/fluid/memory/detail/buddy_allocator.cc @@ -11,7 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - +#define GLOG_NO_ABBREVIATED_SEVERITIES +#define GOOGLE_GLOG_DLL_DECL #include "paddle/fluid/memory/detail/buddy_allocator.h" #include "glog/logging.h" diff --git a/paddle/fluid/memory/detail/meta_cache.cc b/paddle/fluid/memory/detail/meta_cache.cc index b86e4f38c4..2a283733f5 100644 --- a/paddle/fluid/memory/detail/meta_cache.cc +++ b/paddle/fluid/memory/detail/meta_cache.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#define GLOG_NO_ABBREVIATED_SEVERITIES +#define GOOGLE_GLOG_DLL_DECL #include "glog/logging.h" #include "paddle/fluid/memory/detail/memory_block.h" #include "paddle/fluid/platform/assert.h" diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index c17d1afc30..4a8ac441cf 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor) The input of Topk op"); - AddOutput("Out", "(Tensor) The output tensor of Topk op"); + AddOutput("Out", "(Tensor) The output tensor of Topk op").Reuse("X"); AddOutput("Indices", "(Tensor) The indices of Topk elements of input"); AddComment(R"DOC( Top K operator diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 0cad224ca8..9da8551eb2 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -256,65 +256,36 @@ __device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid, * 3. go to the second setp, until one thread's topk value is null; * 4. go to the first setp, until get the topk value. */ - template __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, - const T* src, int lds, int dim, int k, - int grid_dim, int num) { + const T* src, int lds, int dim, int k) { __shared__ Pair sh_topk[BlockSize]; + __shared__ int maxid[BlockSize / 2]; const int tid = threadIdx.x; const int warp = threadIdx.x / 32; + output += blockIdx.x * output_stride; + indices += blockIdx.x * k; - const int bid = blockIdx.x; - for (int i = bid; i < num; i += grid_dim) { - int top_num = k; - __shared__ int maxid[BlockSize / 2]; - T* out = output + i * output_stride; - int64_t* inds = indices + i * k; - Pair topk[MaxLength]; - int beam = MaxLength; - Pair max; - bool is_empty = false; - bool firststep = true; - - for (int j = 0; j < MaxLength; j++) { - topk[j].set(-INFINITY, -1); - } - while (top_num) { - ThreadGetTopK( - topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid); + Pair topk[MaxLength]; + int beam = MaxLength; + Pair max; + bool is_empty = false; + bool firststep = true; - sh_topk[tid] = topk[0]; - BlockReduce(sh_topk, maxid, topk, &out, &inds, - &beam, &top_num, tid, warp); - } + for (int k = 0; k < MaxLength; k++) { + topk[k].set(-INFINITY, -1); } -} - -inline static int GetDesiredBlockDim(int dim) { - if (dim > 128) { - return 256; - } else if (dim > 64) { - return 128; - } else if (dim > 32) { - return 64; - } else { - return 32; + while (k) { + ThreadGetTopK(topk, &beam, k, + src + blockIdx.x * lds, &firststep, + &is_empty, &max, dim, tid); + + sh_topk[tid] = topk[0]; + BlockReduce(sh_topk, maxid, topk, &output, + &indices, &beam, &k, tid, warp); } } -#define FIXED_BLOCK_DIM_BASE(dim, ...) \ - case (dim): { \ - constexpr auto kBlockDim = (dim); \ - __VA_ARGS__; \ - } break - -#define FIXED_BLOCK_DIM(...) \ - FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) - template class TopkOpCUDAKernel : public framework::OpKernel { public: @@ -327,38 +298,30 @@ class TopkOpCUDAKernel : public framework::OpKernel { size_t k = static_cast(ctx.Attr("k")); const T* input_data = input->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); // FIXME(typhoonzero): data is always converted to type T? int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); - framework::DDim inputdims = input->dims(); - const size_t input_height = framework::product( - framework::slice_ddim(inputdims, 0, inputdims.size() - 1)); - const size_t input_width = inputdims[inputdims.size() - 1]; - + size_t input_height = input->dims()[0]; + size_t input_width = input->dims()[1]; if (k > input_width) k = input_width; // NOTE: pass lds and dim same to input width. // NOTE: old matrix implementation of stride is different to eigen. // TODO(typhoonzero): refine this kernel. - const int kMaxHeight = 2048; - int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; - auto& dev_ctx = ctx.cuda_device_context(); - switch (GetDesiredBlockDim(input_width)) { - FIXED_BLOCK_DIM( - KeMatrixTopK<<>>( - output_data, k, indices_data, input_data, input_width, - input_width, static_cast(k), gridx, input_height)); - default: - PADDLE_THROW("Error"); - } + dim3 threads(256, 1); + dim3 grid(input_height, 1); + + KeMatrixTopK<<< + grid, threads, 0, reinterpret_cast( + ctx.device_context()) + .stream()>>>( + output_data, output->dims()[1], indices_data, input_data, input_width, + input_width, static_cast(k)); } }; -#undef FIXED_BLOCK_DIM_BASE -#undef FIXED_BLOCK_DIM - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h index 76ece57b39..054dd48199 100644 --- a/paddle/fluid/operators/top_k_op.h +++ b/paddle/fluid/operators/top_k_op.h @@ -34,6 +34,7 @@ class TopkKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { // Get the top k elements of each row of input tensor + // FIXME: only deal with matrix(2d tensor). auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); @@ -43,6 +44,8 @@ class TopkKernel : public framework::OpKernel { T* output_data = output->mutable_data(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); + auto eg_input = EigenMatrix::From(*input); + // reshape input to a flattern matrix(like flat_inner_dims) framework::DDim inputdims = input->dims(); const size_t row = framework::product( @@ -50,7 +53,7 @@ class TopkKernel : public framework::OpKernel { const size_t col = inputdims[inputdims.size() - 1]; Eigen::DSizes flat2dims(row, col); // NOTE: eigen shape doesn't affect paddle tensor. - auto eg_input = EigenMatrix::Reshape(*input, inputdims.size() - 1); + eg_input.reshape(flat2dims); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 4e2b3ac0e3..9ac8ae2ac7 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -27,6 +27,12 @@ ENDIF() cc_library(cpu_info SRCS cpu_info.cc DEPS ${CPU_INFO_DEPS}) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) +set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64") +set(MYDEPS ${MYDEPS} libcmt shlwapi) +set(MYDEPS ${MYDEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX}) +set(MYDEPS ${MYDEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX}) +set(MYDEPS ${MYDEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX}) + nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce) cc_library(place SRCS place.cc DEPS enforce boost) @@ -58,6 +64,7 @@ nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_ cc_test(init_test SRCS init_test.cc DEPS device_context) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) +target_link_libraries(cudnn_helper_test ${MYDEPS}) nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index b6e15862c1..8fe6c20be1 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -68,7 +68,14 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { } \ } while (false) #else -#define CUDNN_ENFORCE(condition) +// windows +#define CUDNN_ENFORCE(condition) \ + do { \ + cudnnStatus_t status = condition; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + std::cerr << ::paddle::platform::cudnnGetErrorString(status); \ + } \ + } while (false) #endif enum class DataLayout { // Not use diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index baa123fd0f..241f79d8e7 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -127,7 +127,7 @@ struct EOFException : public std::exception { #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) #else // there is no equivalent intrinsics in msvc. -#define UNLIKELY(condition) (condition == 0) +#define UNLIKELY(condition) ((condition) == 0) #endif template @@ -309,7 +309,6 @@ inline void throw_on_error(T e) { #define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) - #define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \ do { \ if (UNLIKELY(nullptr == (__VAL))) { \ @@ -330,26 +329,26 @@ inline void throw_on_error(T e) { } \ } while (0) #else -#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) ((__VAL0)==(__VAL1)) -#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) ((__VAL0)!=(__VAL1)) -#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) ((__VAL0)>(__VAL1)) -#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) ((__VAL0)>=(__VAL1)) -#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) ((__VAL0)<(__VAL1)) -#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) ((__VAL0)<=(__VAL1)) - -#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ - do { \ - if (!((__VAL0)__CMP(__VAL1))) { \ - PADDLE_THROW("Windows disable the enforce. Enforce failed."); \ - } \ - } while(0) -#define PADDLE_ENFORCE_NOT_NULL(__VAL1, ...) \ - do { \ - if (nullptr == (__VAL1)) { \ +#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) ((__VAL0) == (__VAL1)) +#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) ((__VAL0) != (__VAL1)) +#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) ((__VAL0) > (__VAL1)) +#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) ((__VAL0) >= (__VAL1)) +#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) ((__VAL0) < (__VAL1)) +#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) ((__VAL0) <= (__VAL1)) + +#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ + do { \ + if (!((__VAL0)__CMP(__VAL1))) { \ + PADDLE_THROW("Windows disable the enforce. Enforce failed."); \ + } \ + } while (0) +#define PADDLE_ENFORCE_NOT_NULL(__VAL1, ...) \ + do { \ + if (nullptr == (__VAL1)) { \ PADDLE_THROW("Windows disable the enforce. Enforce failed"); \ - } \ - } while(0) -#endif // !_WIN32 + } \ + } while (0) +#endif // !_WIN32 } // namespace platform } // namespace paddle -- GitLab