diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index aa8b4628c5f39d49a89ac18b8b11154ae19dfecc..b1cfb23f3a839fe6ae56db4485288a56518d9834 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -131,11 +131,13 @@ cc_test(version_test SRCS version_test.cc DEPS version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) -if(NOT WIN32) -cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph) -cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog - shape_inference data_transform lod_tensor profiler) -endif(NOT WIN32) +if(WITH_NGRAPH) + if(NOT WIN32) + cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph) + cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog + shape_inference data_transform lod_tensor profiler ngraph) + endif(NOT WIN32) +endif(WITH_NGRAPH) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) @@ -171,14 +173,20 @@ if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() - if(NOT WIN32) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper garbage_collector) - else(NOT WIN32) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper garbage_collector) - endif(NOT WIN32) + if(WITH_NGRAPH) + if(NOT WIN32) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph ngraph_operator variable_helper) + else(NOT WIN32) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) + endif(NOT WIN32) + else(WITH_NGRAPH) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) + endif(WITH_NGRAPH) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() +target_link_libraries(executor garbage_collector) + cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph build_strategy diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 7eab87601594e4405b66479a6d390659c153ba79..16c4552a5f05e6a25c131942b8aa60b10d975c1c 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -18,7 +18,6 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor_array.h" -#include "paddle/fluid/framework/ngraph_operator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/transfer_scope_cache.h" @@ -27,6 +26,10 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_NGRAPH +#include "paddle/fluid/framework/ngraph_operator.h" +#endif + DECLARE_bool(benchmark); DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); DEFINE_bool(use_ngraph, false, "Use NGRAPH to run"); @@ -131,11 +134,11 @@ static void DeleteUnusedTensors( static void EnableFusedOp(ExecutorPrepareContext* ctx) { #ifdef PADDLE_WITH_NGRAPH VLOG(3) << "use_ngraph=True"; - auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_); + auto intervals = NgraphOperator::NgraphOpIntervals(&ctx->ops_); for (auto& interval : intervals) { - auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_, - interval.at(0), interval.at(1)); - *interval[0] = std::unique_ptr(fused_op); + auto* ng_op = new NgraphOperator(ctx->prog_, ctx->block_id_, interval.at(0), + interval.at(1)); + *interval[0] = std::unique_ptr(ng_op); } for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) { ctx->ops_.erase(it->at(0) + 1, it->at(1)); diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/framework/ngraph_bridge.cc index e22c29037718a60ff7f24404d7749600e2edb80b..a5acfd70449e92663cb66ef90a141c087ff6ec88 100644 --- a/paddle/fluid/framework/ngraph_bridge.cc +++ b/paddle/fluid/framework/ngraph_bridge.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #include #include #include @@ -27,14 +26,15 @@ namespace paddle { namespace framework { static std::shared_ptr GetNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, const VariableNameMap& var_map, std::shared_ptr< std::unordered_map>> ngb_node_map) { - auto& var_names = var_map.at(prm); + auto& var_names = var_map.at(name); PADDLE_ENFORCE_EQ(var_names.size(), 1, - "op %s prm %s expects one associated var", op->Type(), prm); + "op %s name %s expects one associated var", op->Type(), + name); if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) { return (*ngb_node_map)[var_names[0]]; } else { @@ -43,42 +43,42 @@ static std::shared_ptr GetNode( } static std::shared_ptr GetInputNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, std::shared_ptr< std::unordered_map>> ngb_node_map) { - return GetNode(op, prm, op->Inputs(), ngb_node_map); + return GetNode(op, name, op->Inputs(), ngb_node_map); } static std::shared_ptr GetOutputNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, std::shared_ptr< std::unordered_map>> ngb_node_map) { - return GetNode(op, prm, op->Outputs(), ngb_node_map); + return GetNode(op, name, op->Outputs(), ngb_node_map); } static void SetOutputNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, std::shared_ptr node, std::shared_ptr< std::unordered_map>> ngb_node_map) { - auto& var_names = op->Outputs().at(prm); + auto& var_names = op->Outputs().at(name); if (var_names.size() == 1) { (*ngb_node_map)[var_names[0]] = node; } else if (var_names.size() == 0) { (*ngb_node_map)[""] = node; } else { - PADDLE_THROW("prm %s has more than 1 var_names.", prm); + PADDLE_THROW("name %s has more than 1 var_names.", name); } } static bool HasOutput(const std::shared_ptr& op, - const std::string prm) { + const std::string name) { auto& outputs = op->Outputs(); - if (outputs.find(prm) == outputs.end()) return false; - return outputs.at(prm).size() > 0; + if (outputs.find(name) == outputs.end()) return false; + return outputs.at(name).size() > 0; } template @@ -118,4 +118,3 @@ void NgraphBridge::BuildNgNode(const std::shared_ptr& op) { } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/ngraph_bridge.h b/paddle/fluid/framework/ngraph_bridge.h index 9ed6b9510942136a61faa5755fd8fa74286939a8..5ad7b8daeb6a782515e50fc87ca7188b46308390 100644 --- a/paddle/fluid/framework/ngraph_bridge.h +++ b/paddle/fluid/framework/ngraph_bridge.h @@ -14,8 +14,6 @@ limitations under the License. */ #pragma once -#ifdef PADDLE_WITH_NGRAPH - #include #include #include @@ -53,4 +51,3 @@ class NgraphBridge { } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index 3fea753f0659019395c9b214e52a7912058c501c..253de4c61160e52202a0192215a93284f27e5896 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #include #include @@ -58,16 +57,16 @@ typedef enum { /* nGraph support state on ops */ } op_state; // perform graph build through bridge and execute computation -class NgraphOperator { +class NgraphEngine { public: - explicit NgraphOperator(const Scope& scope, const platform::Place& place, - const std::vector>& ops, - const std::unordered_map< - std::string, ngraph::element::Type>& var_type_map, - const std::unordered_set& persist, - const std::unordered_set& fetches, - const std::unordered_set& post_op_inputs, - op_state ng_op_state) + explicit NgraphEngine(const Scope& scope, const platform::Place& place, + const std::vector>& ops, + const std::unordered_map< + std::string, ngraph::element::Type>& var_type_map, + const std::unordered_set& persist, + const std::unordered_set& fetches, + const std::unordered_set& post_op_inputs, + op_state ng_op_state) : scope_(scope), place_(place), fused_ops_(ops), @@ -132,7 +131,7 @@ class NgraphOperator { }; std::vector>::iterator>> -FusedOperator::FusedOpIntervals( +NgraphOperator::NgraphOpIntervals( std::vector>* ops) { std::vector>::iterator>> intervals; @@ -185,7 +184,7 @@ FusedOperator::FusedOpIntervals( return intervals; } -FusedOperator::FusedOperator( +NgraphOperator::NgraphOperator( const ProgramDesc& prog, size_t block_id, std::vector>::iterator start, std::vector>::iterator end, @@ -215,7 +214,7 @@ FusedOperator::FusedOperator( Process(); } -void FusedOperator::Process() { +void NgraphOperator::Process() { auto& bdesc = pdesc_.Block(block_); for (auto& var : bdesc.AllVars()) { if (!(var->GetType() == proto::VarType::SELECTED_ROWS || @@ -251,8 +250,8 @@ void FusedOperator::Process() { } } -void FusedOperator::RunImpl(const Scope& scope, - const platform::Place& place) const { +void NgraphOperator::RunImpl(const Scope& scope, + const platform::Place& place) const { op_state ng_op_state = PARTIAL_TEST; auto& bdesc = pdesc_.Block(block_); for (auto* op : bdesc.AllOps()) { @@ -266,19 +265,19 @@ void FusedOperator::RunImpl(const Scope& scope, ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN; } - NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_, - persistables_, fetches_, post_op_inputs_, - ng_op_state); - ngraph_op.Run(scope, place); + NgraphEngine ngraph_engine(scope, place, fused_ops_, var_type_map_, + persistables_, fetches_, post_op_inputs_, + ng_op_state); + ngraph_engine.Run(scope, place); } std::unordered_map> - NgraphOperator::func_cache_ = {}; + NgraphEngine::func_cache_ = {}; -std::shared_ptr NgraphOperator::backend_ = +std::shared_ptr NgraphEngine::backend_ = ngraph::runtime::Backend::create("CPU"); -void NgraphOperator::GetNgInputShape(std::shared_ptr op) { +void NgraphEngine::GetNgInputShape(std::shared_ptr op) { op->RuntimeInferShape(scope_, place_); for (auto& var_name_item : op->Inputs()) { for (auto& var_name : var_name_item.second) { @@ -301,7 +300,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr op) { } } -void NgraphOperator::BuildNgNodes() { +void NgraphEngine::BuildNgNodes() { for (auto& var_name : var_out_) { if (var_node_map_->find(var_name) == var_node_map_->end()) { auto* var = scope_.FindVar(var_name); @@ -323,7 +322,7 @@ void NgraphOperator::BuildNgNodes() { } } -void NgraphOperator::BuildNgIO() { +void NgraphEngine::BuildNgIO() { std::unordered_set inputs; std::unordered_set outputs; @@ -395,7 +394,7 @@ void NgraphOperator::BuildNgIO() { } } -void NgraphOperator::BuildNgFunction() { +void NgraphEngine::BuildNgFunction() { BuildNgNodes(); ngraph_function_ = nullptr; ngraph::NodeVector func_outputs; @@ -416,7 +415,7 @@ void NgraphOperator::BuildNgFunction() { std::make_shared(func_outputs, func_inputs); } -std::shared_ptr NgraphOperator::GetCacheKey() { +std::shared_ptr NgraphEngine::GetCacheKey() { auto cache_key = std::make_shared(""); *cache_key += std::to_string(fused_ops_.size()); for (auto& op : fused_ops_) { @@ -444,7 +443,7 @@ std::shared_ptr NgraphOperator::GetCacheKey() { return cache_key; } -void NgraphOperator::GetNgFunction() { +void NgraphEngine::GetNgFunction() { bool cache_on = true; if (cache_on) { std::string cache_key_val = *GetCacheKey(); @@ -459,8 +458,7 @@ void NgraphOperator::GetNgFunction() { } } -void NgraphOperator::Run(const Scope& scope, - const platform::Place& place) const { +void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const { std::vector> t_in; std::vector> t_out; @@ -545,7 +543,6 @@ void NgraphOperator::Run(const Scope& scope, } backend_->call(ngraph_function_, t_out, t_in); -} // NgraphOperator::RunImpl +} // NgraphEngine::RunImpl } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/ngraph_operator.h b/paddle/fluid/framework/ngraph_operator.h index 3ca023e11111c5b447b2cabbfb8bb29877297f65..ede80f44bea208b66acc3b3f4bc0f4adee4fb860 100644 --- a/paddle/fluid/framework/ngraph_operator.h +++ b/paddle/fluid/framework/ngraph_operator.h @@ -14,8 +14,6 @@ limitations under the License. */ #pragma once -#ifdef PADDLE_WITH_NGRAPH - #include #include #include @@ -34,14 +32,14 @@ limitations under the License. */ namespace paddle { namespace framework { -class FusedOperator : public OperatorBase { +class NgraphOperator : public OperatorBase { public: static std::vector< std::vector>::iterator>> - FusedOpIntervals( + NgraphOpIntervals( std::vector>* ops); - explicit FusedOperator( + explicit NgraphOperator( const ProgramDesc& prog, size_t block_id, std::vector>::iterator start, std::vector>::iterator end, @@ -64,4 +62,3 @@ class FusedOperator : public OperatorBase { }; } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc index b8a045c18fab54581b4d2b902be373f55ad09e8a..c6e923c00484f01f17550ae2926dabcadc0c3ac6 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc @@ -44,9 +44,10 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { argument->SetMainProgram(program.release()); } else if (argument->model_program_path_valid() && argument->model_params_path_valid()) { - auto program = - LoadModel(argument->model_program_path(), argument->model_params_path(), - argument->scope_ptr(), place, argument->model_from_memory()); + auto program = LoadModel( + argument->model_program_path(), argument->model_params_path(), + argument->scope_ptr(), place, + argument->model_from_memory_valid() && argument->model_from_memory()); argument->SetMainProgram(program.release()); } else { PADDLE_THROW( diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index a07626a10315a6206f8c1ebc9a19df90663a88ee..8a4bc04b67879918c6ac8d1b40dae68a107034d4 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -1,4 +1,4 @@ -set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor) +set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor benchmark) if(WITH_GPU AND TENSORRT_FOUND) set(INFERENCE_EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor) diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index d572ea0177c1e398229a02718ca31cc78a7059ef..8209a049f4614fe31c22c4e83c1968411b749b49 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -30,8 +30,10 @@ #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/tests/api/config_printer.h" #include "paddle/fluid/inference/tests/test_helper.h" +#include "paddle/fluid/inference/utils/benchmark.h" #include "paddle/fluid/platform/profiler.h" +DEFINE_string(model_name, "", "model name"); DEFINE_string(infer_model, "", "model path"); DEFINE_string(infer_data, "", "data file"); DEFINE_int32(batch_size, 1, "batch size."); @@ -40,6 +42,8 @@ DEFINE_bool(test_all_data, false, "Test the all dataset in data file."); DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads."); DEFINE_bool(use_analysis, true, "Running the inference program in analysis mode."); +DEFINE_bool(record_benchmark, false, + "Record benchmark after profiling the model"); DECLARE_bool(profile); DECLARE_int32(paddle_num_threads); @@ -192,8 +196,16 @@ void TestOneThreadPrediction( predictor->Run(inputs[j], outputs, batch_size); } } - PrintTime(batch_size, num_times, 1, 0, run_timer.toc() / num_times, - inputs.size()); + + double latency = run_timer.toc() / num_times; + PrintTime(batch_size, num_times, 1, 0, latency, inputs.size()); + if (FLAGS_record_benchmark) { + Benchmark benchmark; + benchmark.SetName(FLAGS_model_name); + benchmark.SetBatchSize(batch_size); + benchmark.SetLatency(latency); + benchmark.PersistToFile("benchmark_record.txt"); + } } } diff --git a/paddle/fluid/inference/tests/api/trt_models_tester.cc b/paddle/fluid/inference/tests/api/trt_models_tester.cc index ef612ce6148329c33f194842945bb5438afcf645..9eb3fb5da1065f14d9ad1c3520f6415fbadfdca1 100644 --- a/paddle/fluid/inference/tests/api/trt_models_tester.cc +++ b/paddle/fluid/inference/tests/api/trt_models_tester.cc @@ -135,6 +135,9 @@ TEST(TensorRT_resnext50, compare) { TEST(TensorRT_resnext50, profile) { std::string model_dir = FLAGS_infer_model + "/resnext50"; + // Set FLAGS_record_benchmark to true to record benchmark to file. + // FLAGS_record_benchmark=true; + FLAGS_model_name = "resnext50"; profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt); } diff --git a/paddle/fluid/inference/utils/benchmark.cc b/paddle/fluid/inference/utils/benchmark.cc index d03aa11b75ee58524746212e43a5796773f47932..0bd526bcac2d9ceda95730dc3c5210aed8ccfb5c 100644 --- a/paddle/fluid/inference/utils/benchmark.cc +++ b/paddle/fluid/inference/utils/benchmark.cc @@ -30,7 +30,7 @@ std::string Benchmark::SerializeToString() const { ss << '\n'; ss << name_ << "\t"; - ss << batch_size_ << "\t"; + ss << batch_size_ << "\t\t"; ss << num_threads_ << "\t"; ss << latency_ << "\t"; ss << 1000.0 / latency_; diff --git a/paddle/fluid/inference/utils/visualizer.cc b/paddle/fluid/inference/utils/visualizer.cc index 040b6476fb4febc5ca1912c8db72dc63c3bced08..7c0dd64dea88e51b24c4bc04818d633ee0d2f722 100644 --- a/paddle/fluid/inference/utils/visualizer.cc +++ b/paddle/fluid/inference/utils/visualizer.cc @@ -26,9 +26,6 @@ DEFINE_string(model_dir, "", "model directory"); DEFINE_string(model_program_path, "", "model program path"); DEFINE_string(model_params_path, "", "model params path"); -USE_PASS(graph_viz_pass); -USE_PASS(graph_to_program_pass); - using paddle::inference::analysis::Argument; namespace paddle { @@ -40,7 +37,6 @@ void Visualizer::SetArgument(Argument *argument) { argument_ = argument; } bool Visualizer::Run() { paddle::framework::InitDevices(false); paddle::inference::analysis::Analyzer().Run(argument_); - return true; } @@ -77,7 +73,7 @@ int main(int argc, char *argv[]) { // Only 1 pass, default filename is 0_ir_origin.dot // For more details, looking for paddle::inference::analysis::IRPassManager - argument.SetIrAnalysisPasses({"graph_viz_pass"}); + argument.SetIrAnalysisPasses({"infer_clean_graph_pass", "graph_viz_pass"}); std::unique_ptr scope{ new paddle::framework::Scope()}; @@ -90,3 +86,7 @@ int main(int argc, char *argv[]) { return 0; } + +USE_PASS(infer_clean_graph_pass); +USE_PASS(graph_viz_pass); +USE_PASS(graph_to_program_pass); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 87d549678a0e6c183aac89539cf1f6331729de2c..c7df3ea58a91579e35ff0d486516271a6daf054f 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -301,23 +301,22 @@ template struct GeluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - auto temp = - ((x * static_cast(M_SQRT1_2)).erf()).template cast().eval(); + auto temp = (x * static_cast(M_SQRT1_2)).erf(); out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); } }; template struct GeluGradFunctor : BaseActivationFunctor { - bool Inplace() const { return IsInplace("gelu"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - auto temp = (static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * - ((-static_cast(0.5) * x.square()).exp())) - .template cast() - .eval(); - dx.device(d) = dout * (out / x + temp); + auto first = static_cast(0.5) * + (static_cast(1) + ((x * static_cast(M_SQRT1_2)).erf())); + + auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * + (-static_cast(0.5) * x.square()).exp(); + dx.device(d) = dout * (first + second); } }; diff --git a/paddle/fluid/operators/distributed/brpc_client.cc b/paddle/fluid/operators/distributed/brpc_client.cc index b394c678fb6503eb73a1e11e6feb814251e9e940..350969f74be258ffbfef687b56083a9c6508bc81 100644 --- a/paddle/fluid/operators/distributed/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc_client.cc @@ -158,7 +158,7 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) { for (int i = 0; i < FLAGS_brpc_channel_num; ++i) { std::shared_ptr c(new ChannelContext()); if (c->channel.Init(ep.c_str(), &options) != 0) { - LOG(ERROR) << "Fail to initialize channel"; + LOG(FATAL) << "Fail to initialize channel"; return nullptr; } diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 857214aa211aee0251571e46049c66c084b470f1..f14dfcdb238a9580affde96e4d5a0093743eb6c8 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -390,8 +390,7 @@ void GRPCClient::Proceed() { VLOG(3) << c->GetVarHandlePtr()->String() << " process"; c->Process(); } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { - // FIXME(gongwb): parse error_details? - LOG(ERROR) << c->GetVarHandlePtr()->String() + LOG(FATAL) << c->GetVarHandlePtr()->String() << " meets grpc error, error_code:" << c->status_.error_code() << " error_message:" << c->status_.error_message() << " error_details:" << c->status_.error_details();