diff --git a/CMakeLists.txt b/CMakeLists.txt index bdd48565edeca051f54e8fe4eb51cd1dbd5e836a..98e1ac9f2690ea6525d0807da4d21eaa3736967c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,6 +65,7 @@ option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better d option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) +option(WITH_INFERENCE "Compile fluid inference library" ON) option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 9059ae206207d1feef9e037a635c7a07500f0b25..82c958073cba92f00a341121e36ba45531b22aec 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -264,6 +264,8 @@ function(cc_test TARGET_NAME) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) if (${cc_test_SERIAL}) set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) + + set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true) endif() @@ -330,6 +332,8 @@ function(nv_test TARGET_NAME) add_test(${TARGET_NAME} ${TARGET_NAME}) if (nv_test_SERIAL) set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) + + set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true) endif() @@ -580,6 +584,7 @@ function(py_test TARGET_NAME) cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_test(NAME ${TARGET_NAME} COMMAND env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true + FLAGS_cpu_deterministic=true PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS} ${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/doc/survey/op_fusion_design.md b/doc/survey/op_fusion_design.md new file mode 100644 index 0000000000000000000000000000000000000000..d6e48f4f58269b67450cb012f6dcc59e1083abba --- /dev/null +++ b/doc/survey/op_fusion_design.md @@ -0,0 +1,20 @@ +# Operator fusion +Fusing multiple operators together is an important method to optimize the program execution, particularly for GPU or other specialized accelerators. An obvious benefit is to avoid the overhead of saving the intermediate result back into global memory. + +There are generally two ways to fuse operators, fusing directly connected operators and fusing non directly connected operators. The first method is mainly used by [NNVM Compiler](https://github.com/dmlc/tvm/) and [XLA](https://www.tensorflow.org/performance/xla/). The second method is mainly used by Dynet and TensorFlow Fold to do auto-batching. The principle of fusing operator is according to some rules to combine multiple operations into one, for example, `Y = X * W` and `Z = Y + B` can be fused to `Z = X * W + B`, and `Y1 = X1 * W` and `Y2 = X2 * W` can be fused to `[Y1;Y2] = [X1;X2] * W`. In order to get a short-term profit, we decided to try to manually specify these rules. + +## Challenge +The challenge of fusing operators is: + - how to make the rules. + - how to implement these rules efficiently. + +### How to make the rules? + +The problem of determining the best single location for a fusion operator is an NP-hard combinatorial problem. After analysis the operators of the DL model, we found there are two group of operators can be fused explicitly, one is the simple and adjacent operations, for example, `tmp = x + y` and `z = Relu(tmp)`, and the other is the operators that have the same function, for example, a serials of `SGD` or `Momentum`. They usually appear in the model in a large number. So we should think about how to fuse them separately first. + +### How to implement these rules efficiently? +#### How to fuse the adjacent operations efficiently? +Here we use a template function to represent the fused operations. The pros of using a template function are that it is simple and efficient, and the cons are that it is not easy to expand, and it can only be used to express some simple operations. So taking into account our current needs, the template function is more appropriate. + +#### How to fuse the operators that have the same function efficiently? +We take SGD operator as an example, the training model may have hundreds of parameters and correspondingly have the same number of SGD operators. The expression(`w = w - lr*w_g`) of those operators is the same, so during of training, the executor will execute this expression hundreds time in CPU or other specialized accelerators. If we can fuse them and make the address of all `w` and all `w_g` continuous respectively, we only need execute one time. For some accelerators, the time of launching kernel is not neglected, so the time of hundreds of times of launching and executing kernel may be larger than launching and executing only once. There usually are many operators that similar to `SGD` in the DL model, such as `AllReduce` and `FC`. diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 3ef317bb7a1c25c5738342f34ae7994b0184a7de..dd172ff9c97814c089ddb2e5bf729880cf0c9cdb 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -336,6 +336,7 @@ paddle.fluid.contrib.BeamSearchDecoder.decode ArgSpec(args=['self'], varargs=Non paddle.fluid.contrib.BeamSearchDecoder.early_stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.contrib.BeamSearchDecoder.read_array ArgSpec(args=['self', 'init', 'is_ids', 'is_scores'], varargs=None, keywords=None, defaults=(False, False)) paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array', 'value'], varargs=None, keywords=None, defaults=None) +paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index d274d96c29bdbf5973d568d783369c3975bdc436..2577e59d9cf24c26b7c04aa00cdde6cde17f7206 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -5,5 +5,7 @@ add_subdirectory(operators) add_subdirectory(pybind) add_subdirectory(string) add_subdirectory(recordio) -# NOTE: please add subdirectory inference at last. -add_subdirectory(inference) +if(WITH_INFERENCE) + # NOTE: please add subdirectory inference at last. + add_subdirectory(inference) +endif() diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index b2e5399e2376a86c1cd310b29c768832665af87f..8714a42162bda3d5ad12e7925fe8cc4e693f51b1 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -21,6 +21,26 @@ namespace framework { namespace details { struct BuildStrategy { + // ParallelExecutor supports two modes of ReduceStrategy, kAllReduce and + // kReduce, for CPU and GPU. If you use kAllReduce, different threads + // optimize their parameters separately. If you use kReduce, the optimizations + // of parameters are distributed to different threads. + // For example, a model has 100 parameters and is running with four threads, + // if you choose kAllReduce, every thread is to optimize 100 parameters + // separately, if you choose kReduce, every thread is to optimize 25 + // parameters. + // Of particular note is, if you use kReduce when using CPU training, + // all the parameters are shared between different threads. This feature will + // save memory. + // FIXME(zcd): The result of the two modes(kAllReduce and kReduce) maybe not + // equal for GPU. Because, the result of the different order of summing maybe + // different, for example, the result of `a+b+c+d` may be different with the + // result of `c+a+b+d`. + // For GPU, the implementation of kAllReduce and kReduce is adopted NCCL, + // so the result of kAllReduce and kReduce maybe not equal. + // For CPU, if you want to fix the order of summing to make the result + // of kAllReduce and kReduce no diff, you can add + // `FLAGS_cpu_deterministic=true` to env. enum class ReduceStrategy { kAllReduce = 0, kReduce = 1 }; enum class GradientScaleStrategy { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 5ca2ed8f96244a11925dfa6af8e48458cf334ecd..a4fdbcb26d1d0cfb05edebff5419d9559c336b3a 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -275,7 +275,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( if (strategy_.gradient_scale_ != BuildStrategy::GradientScaleStrategy::kCustomized) { // TODO(paddle-dev): Why is there no input for this op_handle? - CreateScaleLossGradOp(&result); + auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; + CreateScaleLossGradOp(&result, loss_grad_name); } // This assumes the backward generating code will ensure IsScaleLossOp // is true only for the op that scale the final scalar loss. @@ -535,7 +536,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, return got == sharded_var_device.end() ? -1 : got->second; } -void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { +void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( + ir::Graph *result, const std::string &loss_grad_name) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle #ifdef PADDLE_WITH_CUDA @@ -558,10 +560,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const { // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); - CreateOpOutput(result, op_handle, - result->CreateEmptyNode(GradVarName(loss_var_name_), - ir::Node::Type::kVariable), - places_[i], i); + CreateOpOutput( + result, op_handle, + result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable), + places_[i], i); } } diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 099dbe5abef6458c4613c9f680440734f59cb6e2..f2cb6bb1c861e07f1034f1742ad4f3cfbb0d8837 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -75,7 +75,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { void CreateComputationalOps(ir::Graph *result, ir::Node *node, size_t num_places) const; - void CreateScaleLossGradOp(ir::Graph *result) const; + void CreateScaleLossGradOp(ir::Graph *result, + const std::string &loss_grad_name) const; + VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, int dst_dev_id) const; void CreateComputationalOp(ir::Graph *result, ir::Node *node, diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 68bdfbaf52120d19d05d156529626f42adda630d..6c7e5c1fb06620b1c071b00fcfcc1b4a29bf8d62 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -18,6 +18,10 @@ #include "paddle/fluid/framework/details/variable_visitor.h" #include "paddle/fluid/platform/profiler.h" +DEFINE_bool( + cpu_deterministic, false, + "Whether to make the result of computation deterministic in CPU side."); + namespace paddle { namespace framework { namespace details { @@ -91,11 +95,33 @@ void ReduceOpHandle::RunImpl() { } else { std::vector lod_tensors = GetInputValues(in_var_handles, var_scopes); + if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) { this->RunAndRecordEvent([&] { - ReduceLoDTensor func(lod_tensors, - out_var->GetMutable()); - VisitDataType(ToDataType(lod_tensors[0]->type()), func); + // FIXME(zcd): The order of summing is important, + // especially when the type of data is float or double. + // For example, the result of `a+b+c+d` may be different + // with the result of `c+a+b+d`, so the summing order should be fixed. + if (!FLAGS_cpu_deterministic) { + ReduceLoDTensor func(lod_tensors, + out_var->GetMutable()); + VisitDataType(ToDataType(lod_tensors[0]->type()), func); + } else { + // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0 + // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0. + auto &reduce_sum_trg = *this->local_scopes_[0] + ->FindVar(kLocalExecScopeName) + ->Get() + ->FindVar(out_var_handle->name_) + ->GetMutable(); + ReduceLoDTensor func(lod_tensors, &reduce_sum_trg); + VisitDataType(ToDataType(lod_tensors[0]->type()), func); + + auto trg = out_var->GetMutable(); + if (reduce_sum_trg.data() != trg->data()) { + TensorCopy(reduce_sum_trg, platform::CPUPlace(), trg); + } + } }); } else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0c8acf71bfa0814e66560258ad6131c743ebc81b..16c7f819f35655fae1f08fa5be0d204ed98ca9c4 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -778,6 +778,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( const ExecutionContext& ctx) const { auto& scope = ctx.scope(); int data_type = -1; + std::string last_input_name; for (auto& input : this->inputs_) { for (auto& ipt_name : input.second) { auto* var = scope.FindVar(ipt_name); @@ -794,9 +795,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( int tmp = static_cast(ToDataType(t->type())); PADDLE_ENFORCE( tmp == data_type || data_type == -1, - "DataType of Paddle Op %s must be the same. Get %d != %d", Type(), - data_type, tmp); + "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)", + Type(), last_input_name, data_type, ipt_name, tmp); data_type = tmp; + last_input_name = ipt_name; } } } diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 98bdfcc00b9f0e8f40dfc92e4021b2bd6fb19313..c4ab26a2288bb9d8f3cd54a797d2062e0606b219 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -24,7 +24,7 @@ namespace paddle { -DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false, +DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, true, "Enable subgraph to TensorRT engine for acceleration"); DEFINE_string(inference_analysis_graphviz_log_root, "./", @@ -42,10 +42,19 @@ class DfgPassManagerImpl final : public DfgPassManager { // TODO(Superjomn) set the key with pass reprs. AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass); if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) { - auto trt_teller = [](const Node* node) { + auto trt_teller = [&](const Node* node) { + std::unordered_set teller_set( + {"elementwise_add", "mul", "conv2d", "pool2d", "relu"}); if (!node->IsFunction()) return false; - return static_cast(node)->func_type() == "mul"; + + const auto* func = static_cast(node); + if (teller_set.count(func->func_type())) + return true; + else { + return false; + } }; + AddPass("tensorrt-subgraph-marker", new TensorRTSubgraphNodeMarkPass(trt_teller)); AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller)); diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 2328d870422c5a31c22d7b09980aae35e01b2b25..aaf7ca67011fb7bd4a74f6d8f57317594c528ca4 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -23,7 +23,7 @@ namespace paddle { namespace inference { -DEFINE_int32(tensorrt_max_batchsize, 300, "TensorRT maximum batch size"); +DEFINE_int32(tensorrt_max_batchsize, 3, "TensorRT maximum batch size"); DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size"); namespace analysis { @@ -87,34 +87,113 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) { } void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, - const framework::proto::BlockDesc &block) { + framework::proto::BlockDesc *block) { static int counter{0}; PADDLE_ENFORCE(node->IsFunctionBlock()); framework::OpDesc desc; auto *func = static_cast(node); // collect inputs - std::vector io; + std::unordered_set input_names; for (auto *x : func->inlinks) { - io.push_back(x->name()); + input_names.insert(x->name()); } - desc.SetInput("Xs", io); + desc.SetInput( + "Xs", std::vector(input_names.begin(), input_names.end())); - // collect outputs - io.clear(); + std::unordered_set output_names; for (auto *x : func->outlinks) { - io.push_back(x->name()); + output_names.insert(x->name()); } - desc.SetOutput("Ys", io); + + std::vector output_temp(output_names.begin(), + output_names.end()); + desc.SetOutput("Ys", output_temp); desc.SetType("tensorrt_engine"); - PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc"); + std::unordered_map output_name_map; + + // The following procedure is used to rename all the intermediate + // variables and the output variables of the subgraph. + // Why we do this? + // During the transition from fluid OP to tensorrt OP, we map + // the input and output Tensor(fluid data structure) of fluid OP + // to the correspondin ITensor (trt data structure) through the + // Tensor name. When we set up ITensor for an variable, we must + // ensure that it has not been set before. + // If there is variable in the fluid graph, which is not only the + // input of a OP, but also the output of a Op, there will be problems. + // So we have to rename the variable in the subgraph to make sure + // it is either an OP's input or an OP's output. + + auto subgraph_nodes = func->subgraph; + for (int index = 0; index < block->ops_size(); index++) { + framework::proto::OpDesc *op = block->mutable_ops(index); + auto correspond_node = subgraph_nodes[index]; + PADDLE_ENFORCE_EQ(correspond_node->name(), op->type()); + + std::unordered_map var2id; + for (auto *in_var : correspond_node->inlinks) { + var2id[in_var->name()] = in_var->id(); + } + // rename for the input variables of op inside subgraph + for (int i = 0; i < op->inputs_size(); i++) { + framework::proto::OpDesc_Var *in_var = op->mutable_inputs(i); + std::vector replaced_names; + for (int k = 0; k < in_var->arguments_size(); k++) { + std::string arg_value = in_var->arguments(k); + if (input_names.count(arg_value)) { + replaced_names.push_back(arg_value); + } else { + replaced_names.push_back(arg_value + + std::to_string(var2id[arg_value])); + } + } + in_var->clear_arguments(); + for (size_t k = 0; k < replaced_names.size(); k++) { + in_var->add_arguments(replaced_names[k]); + } + } + var2id.clear(); + for (auto out_var : correspond_node->outlinks) { + var2id[out_var->name()] = out_var->id(); + } + + // rename for the output variables of op inside subgraph + for (int i = 0; i < op->outputs_size(); i++) { + framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i); + std::vector replaced_names; + for (int k = 0; k < out_var->arguments_size(); k++) { + std::string arg_value = out_var->arguments(k); + if (output_names.count(arg_value)) { + output_name_map[arg_value] = + arg_value + std::to_string(var2id[arg_value]); + } + replaced_names.push_back(arg_value + std::to_string(var2id[arg_value])); + } + out_var->clear_arguments(); + for (size_t k = 0; k < replaced_names.size(); k++) { + out_var->add_arguments(replaced_names[k]); + } + } + } + // When tensorrt engine runs at the end of the operation, + // output_mapping help us copy the data from the renamed ITensor + // to Tensor. + std::vector output_mapping; + for (auto name : output_names) { + PADDLE_ENFORCE(output_name_map.count(name) != 0); + output_mapping.push_back(output_name_map[name]); + } + + PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc"); // Set attrs - SetAttr(desc.Proto(), "subgraph", block.SerializeAsString()); + SetAttr(desc.Proto(), "subgraph", block->SerializeAsString()); SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++)); SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize); SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size); SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes())); + SetAttr(desc.Proto(), "output_name_mapping", output_mapping); node->SetPbMsg(desc.Proto()->SerializeAsString()); } @@ -146,15 +225,17 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) { LOG(INFO) << "transformed variable size: " << block_desc.Proto()->vars().size(); // copy ops. + for (auto *node : block_node->subgraph) { auto *op = block_desc.AppendOp(); PADDLE_ENFORCE(!node->pb_msg().empty()); op->Proto()->ParseFromString(node->pb_msg()); } + *block_desc.Proto()->mutable_vars() = argument_->origin_program_desc->blocks(0).vars(); PADDLE_ENFORCE(!block_desc.Proto()->vars().empty()); - CreateTrtEngineOp(node, *argument_->main_dfg, *block_desc.Proto()); + CreateTrtEngineOp(node, *argument_->main_dfg, block_desc.Proto()); auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex); auto *op = main_block->add_ops(); PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block"); diff --git a/paddle/fluid/inference/analysis/subgraph_splitter.cc b/paddle/fluid/inference/analysis/subgraph_splitter.cc index 389f9e1a9148a4daf0e5b751cce5cb6325252a4e..80809d4c43ca08298bad25cf614dcb4117d3f99a 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter.cc +++ b/paddle/fluid/inference/analysis/subgraph_splitter.cc @@ -76,7 +76,7 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { std::vector> SubGraphSplitter::ExtractSubGraphs() { std::vector marked_nodes; - for (auto &node : GraphTraits(graph_).nodes()) { + for (auto &node : GraphTraits(graph_).nodes_in_TS()) { if (node.attr(kMarkerAttrName).Bool()) { marked_nodes.push_back(&node); } diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index d86c046f2e5b08a4c00cf6cad19627e6a196c798..8f42a37cd3f8978b917b42e8f45a128b8422aa57 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,6 +1,7 @@ # Add TRT tests nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc +activation_op.cc DEPS tensorrt_engine operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 1b6a0ad82f3ceb00cec15c28c8121adc22271b7a..41faaf7212accaaec238062b1340e8da8fa6be33 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -55,7 +55,6 @@ class OpConverter { it = Registry::Lookup("fc"); } } - if (op_desc.Type().find("elementwise") != std::string::npos) { static std::unordered_set add_tensor_op_set{ "add", "mul", "sub", "div", "max", "min", "pow"}; @@ -72,6 +71,8 @@ class OpConverter { "Unsupported elementwise type" + op_type); it = Registry::Lookup("elementwise_" + op_type + "_weight"); + PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", + op_desc.Type()); } else { PADDLE_ENFORCE(add_tensor_op_set.count(op_type) > 0, "Unsupported elementwise type" + op_type); diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused_elemwise_activation_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a6fd0aeb021dce40339c32251af130d5984dccd2 --- /dev/null +++ b/paddle/fluid/operators/fused_elemwise_activation_op.cc @@ -0,0 +1,221 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/operators/fused_elemwise_activation_op.h" + +namespace paddle { +namespace operators { + +class FusedElemwiseActivationOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("X"), + "Input(X) of FusedElemwiseActivationOp op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("Y"), + "Input(Y) of FusedElemwiseActivationOp op should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of FusedElemwiseActivationOp op should not be null."); + + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), + "Rank of first input must >= rank of second input."); + + ctx->SetOutputDim("Out", x_dim); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ(ctx.Input("X")->type(), + ctx.Input("Y")->type(), + "The element's type of input should be the same."); + auto input_data_type = + framework::ToDataType(ctx.Input("X")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(vector)"); + AddInput("Y", "(vector)"); + AddOutput("Out", "vector"); + AddAttr("axis", + "axis is used by elementwise_op, the default value is -1.") + .SetDefault(-1); + AddAttr("scale", + "scale is used by scale_op, the default value is 0.0.") + .SetDefault(0.0); + AddAttr("recomputation", + "Whether to recompute the Out." + "fused_elemwise_activation_grad has two methods to get the " + "dx and dy, one " + "is to use the 'Out', and the other is not to use it. " + "The former method will save the time of recomputing the " + "'Out', but it must occupy the memory to store the 'out'. " + "While, the later method can avoid occupying the memory, " + "but it must recompute the 'Out'. The default value is true.") + .SetDefault(true); + AddAttr>("functor_list", + "The functors that should be fused.") + .AddCustomChecker([&](const std::vector &functor_list) { + PADDLE_ENFORCE(ValidCheck(functor_list)); + }); + + AddComment(R"DOC( +FusedElemwiseActivation Operator. + +At present, FusedElemwiseActivation only supports Two kinds of compound +operators (elementwise_op and activation_op): + + Z = Binary(X, Unary(Y)) + Z = Unary(Binary(X, Y)) + +The attributions of activation_op can be get from fused_elemwise_activation_op's +attributions. functor_list records the functors to be fused, for example +"scale,elementwise_add". + +)DOC"); + } + + private: + bool ValidCheck(const std::vector &functors) { + std::unordered_set unary_fun = {"scale", "relu"}; + std::unordered_set binary_fun = {"elementwise_add"}; + + std::string unary_fun_str; + if (binary_fun.count(functors[0])) { + unary_fun_str = functors[1]; + } else if (binary_fun.count(functors[1])) { + unary_fun_str = functors[0]; + } else { + PADDLE_THROW("%s and %s are not included in fused_list.", functors[0], + functors[1]); + } + PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1, + "%s is not included in fused_list.", unary_fun_str); + return true; + } +}; + +class FusedElemwiseActivationGradMaker + : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *op_desc_ptr = new framework::OpDesc(); + op_desc_ptr->SetType(this->ForwardOpType() + "_grad"); + + for (auto &input_param : this->InputNames()) { + op_desc_ptr->SetInput(input_param, this->Input(input_param)); + op_desc_ptr->SetOutput(framework::GradVarName(input_param), + this->InputGrad(input_param, true)); + } + + for (auto &output_param : this->OutputNames()) { + op_desc_ptr->SetInput(output_param, this->Output(output_param)); + op_desc_ptr->SetInput(framework::GradVarName(output_param), + this->OutputGrad(output_param)); + } + op_desc_ptr->SetAttrMap(this->Attrs()); + + std::vector functor_names = + boost::get>( + op_desc_ptr->GetAttr("functor_list")); + functor_names[0] += "_grad"; + functor_names[1] += "_grad"; + op_desc_ptr->SetAttr("functor_list", functor_names); + return std::unique_ptr(op_desc_ptr); + } +}; + +class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input."); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type_index = ctx.Input("X")->type(); + PADDLE_ENFORCE_EQ(input_data_type_index, + ctx.Input("Y")->type(), + "The element's type of input should be the same."); + PADDLE_ENFORCE_EQ( + input_data_type_index, + ctx.Input(framework::GradVarName("Out"))->type(), + "The element's type of input should be the same."); + + auto input_data_type = framework::ToDataType(input_data_type_index); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_elemwise_activation, ops::FusedElemwiseActivationOp, + ops::FusedElemwiseActivationMaker, + ops::FusedElemwiseActivationGradMaker); +REGISTER_OPERATOR(fused_elemwise_activation_grad, + ops::FusedElemwiseActivationOpGrad); + +REGISTER_OP_CPU_KERNEL( + fused_elemwise_activation, + ops::FusedElemwiseActivationKernel, + ops::FusedElemwiseActivationKernel); + +REGISTER_OP_CPU_KERNEL( + fused_elemwise_activation_grad, + ops::FusedElemwiseActivationGradKernel, + ops::FusedElemwiseActivationGradKernel); diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cu b/paddle/fluid/operators/fused_elemwise_activation_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..e1d2b16b4b5e3a480777f834c2cbeb6d00a755e4 --- /dev/null +++ b/paddle/fluid/operators/fused_elemwise_activation_op.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused_elemwise_activation_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fused_elemwise_activation, + ops::FusedElemwiseActivationKernel, + ops::FusedElemwiseActivationKernel); + +REGISTER_OP_CUDA_KERNEL( + fused_elemwise_activation_grad, + ops::FusedElemwiseActivationGradKernel, + ops::FusedElemwiseActivationGradKernel); diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused_elemwise_activation_op.h new file mode 100644 index 0000000000000000000000000000000000000000..fe0017b824532b1210d0ae3e51983d63d081f12a --- /dev/null +++ b/paddle/fluid/operators/fused_elemwise_activation_op.h @@ -0,0 +1,425 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/math/functors.h" + +namespace math = paddle::operators::math; + +namespace paddle { +namespace operators { + +// CompoundFunctors +// For example: Z = Binary(X, Unary(Y)) +template +struct BinaryCompoundFunctor { + BinaryCompoundFunctor(const BinaryFun &binary_fun, const UnaryFun &unary_fun) + : binary_fun_(binary_fun), unary_fun_(unary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y) { + return binary_fun_(x, unary_fun_(y)); + } + + private: + BinaryFun binary_fun_; + UnaryFun unary_fun_; +}; + +// For example: Z = Unary(Binary(X, Y)) +template +struct UnaryCompoundFunctor { + UnaryCompoundFunctor(const UnaryFun &unary_fun, const BinaryFun &binary_fun) + : unary_fun_(unary_fun), binary_fun_(binary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y) { + return unary_fun_(binary_fun_(x, y)); + } + + private: + UnaryFun unary_fun_; + BinaryFun binary_fun_; +}; + +// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get +// the dx, one is to use the 'out', and the other is not to use it. +// the former method will save the time of recomputing the +// 'out', but it must occupy the memory to store the 'out'. +// While the later method can avoid occupying this memory, +// but it must recompute the 'out'. + +template +struct BinaryCompoundGradDxFunctor { + BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun, + const UnaryFun &unary_fun) + : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + if (Recomputation) { + return dout * d_binary_fun_(x, unary_fun_(y)); + } else { + return dout * d_binary_fun_(x, unary_fun_(y), out); + } + } + + private: + DBinaryFun d_binary_fun_; + UnaryFun unary_fun_; +}; + +template +struct BinaryCompoundGradDyFunctor { + BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun, + const UnaryFun &unary_fun, + const DUnaryFun &d_unary_fun) + : d_binary_fun_(d_binary_fun), + unary_fun_(unary_fun), + d_unary_fun_(d_unary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + if (Recomputation) { + return dout * d_binary_fun_(unary_fun_(y), x) * d_unary_fun_(y); + } else { + return dout * d_binary_fun_(unary_fun_(y), x, out) * d_unary_fun_(y); + } + } + + private: + DBinaryFun d_binary_fun_; + UnaryFun unary_fun_; + DUnaryFun d_unary_fun_; +}; + +template +struct UnaryCompoundGradDxFunctor { + UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun, + const BinaryFun &binary_fun, + const DBinaryFun &d_binary_fun) + : d_unary_fun_(d_unary_fun), + binary_fun_(binary_fun), + d_binary_fun_(d_binary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + T base; + if (Recomputation) { + base = dout * d_unary_fun_(binary_fun_(x, y)); + } else { + base = dout * d_unary_fun_(binary_fun_(x, y), out); + } + return base * d_binary_fun_(x, y); + } + + private: + DUnaryFun d_unary_fun_; + BinaryFun binary_fun_; + DBinaryFun d_binary_fun_; +}; + +template +struct UnaryCompoundGradDyFunctor { + UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun, + const BinaryFun &binary_fun, + const DBinaryFun &d_binary_fun) + : d_unary_fun_(d_unary_fun), + binary_fun_(binary_fun), + d_binary_fun_(d_binary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + T base; + if (Recomputation) { + base = dout * d_unary_fun_(binary_fun_(x, y)); + } else { + base = dout * d_unary_fun_(binary_fun_(x, y), out); + } + return base * d_binary_fun_(y, x); + } + + private: + DUnaryFun d_unary_fun_; + BinaryFun binary_fun_; + DBinaryFun d_binary_fun_; +}; + +template +static void RunBinaryCompoundFunctor(const framework::ExecutionContext &ctx, + const BinaryFunctor &binary_functor, + const UnaryFunctor &unary_functor, + const framework::Tensor *in_x, + const framework::Tensor *in_y, + framework::Tensor *output) { + int axis = ctx.Attr("axis"); + using BinaryCompoundFunctor = + BinaryCompoundFunctor; + + ElementwiseComputeEx( + ctx, in_x, in_y, axis, + BinaryCompoundFunctor(binary_functor, unary_functor), output); +} + +template +static void RunUnaryCompoundFunctors(const framework::ExecutionContext &ctx, + const UnaryFunctor &unary_functor, + const BinaryFunctor &binary_functor, + const framework::Tensor *in_x, + const framework::Tensor *in_y, + framework::Tensor *output) { + int axis = ctx.Attr("axis"); + + using UnaryCompoundFunctor = + UnaryCompoundFunctor; + + ElementwiseComputeEx( + ctx, in_x, in_y, axis, + UnaryCompoundFunctor(unary_functor, binary_functor), output); +} + +template +static void RunBinaryCompoundGradFunctors( + const framework::ExecutionContext &ctx, + const BinaryGradFunctor &binary_grad_functor, + const UnaryFunctor &unary_functor, + const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x, + const framework::Tensor *in_y, const framework::Tensor *in_out, + const framework::Tensor *in_out_grad, framework::Tensor *x_grad, + framework::Tensor *y_grad) { + int axis = ctx.Attr("axis"); + + using BinaryCompoundDxFunctor = + BinaryCompoundGradDxFunctor; + using BinaryCompoundDyFunctor = + BinaryCompoundGradDyFunctor; + + ElemwiseGradCompute( + ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad, + BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), + BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, + unary_grad_functor)); +} + +template +static void RunUnaryCompoundGradFunctors( + const framework::ExecutionContext &ctx, + const UnaryGradFunctor &unary_grad_functor, + const BinaryFunctor &binary_functor, + const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x, + const framework::Tensor *in_y, const framework::Tensor *in_out, + const framework::Tensor *in_out_grad, framework::Tensor *x_grad, + framework::Tensor *y_grad) { + int axis = ctx.Attr("axis"); + + using UnaryCompoundDxFunctor = + UnaryCompoundGradDxFunctor; + using UnaryCompoundDyFunctor = + UnaryCompoundGradDyFunctor; + + ElemwiseGradCompute( + ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad, + UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, + binary_grad_functor), + UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, + binary_grad_functor)); +} + +template +static void RunFunctors(const framework::ExecutionContext &ctx, + const framework::Tensor *in_x, + const framework::Tensor *in_y, + framework::Tensor *output) { + auto &functors = ctx.Attr>("functor_list"); + auto funcs_str = functors[0] + "," + functors[1]; + // TODO(zcd): The following code can be refined. + if (funcs_str == "elementwise_add,scale") { + // Z = Binary(X, Unary(Y)) + T scale = static_cast(ctx.Attr("scale")); + RunBinaryCompoundFunctor, + math::ScaleFunctor>( + ctx, math::AddFunctor(), math::ScaleFunctor(scale), in_x, in_y, + output); + } else if (funcs_str == "scale,elementwise_add") { + // Z = Unary(Binary(X, Y)) + T scale = static_cast(ctx.Attr("scale")); + RunUnaryCompoundFunctors, + math::AddFunctor>( + ctx, math::ScaleFunctor(scale), math::AddFunctor(), in_x, in_y, + output); + } else if (funcs_str == "elementwise_add,relu") { + RunBinaryCompoundFunctor, + math::ReluFunctor>( + ctx, math::AddFunctor(), math::ReluFunctor(), in_x, in_y, output); + } else if (funcs_str == "relu,elementwise_add") { + RunUnaryCompoundFunctors, + math::AddFunctor>( + ctx, math::ReluFunctor(), math::AddFunctor(), in_x, in_y, output); + } else { + PADDLE_THROW("%s has not been implemented.", funcs_str); + } +} + +template +static void RunGradFunctors(const framework::ExecutionContext &ctx, + const framework::Tensor *in_x, + const framework::Tensor *in_y, + const framework::Tensor *in_out, + const framework::Tensor *in_out_grad, + framework::Tensor *x_grad, + framework::Tensor *y_grad) { + auto &functors = ctx.Attr>("functor_list"); + auto funcs_str = functors[0] + "," + functors[1]; + + bool recomputation = ctx.Attr("recomputation"); + + // TODO(zcd): The following code can be refined. for example, use registion + if (funcs_str == "elementwise_add_grad,scale_grad") { + // The backward of Z = Binary(X, Unary(Y)) + T scale = static_cast(ctx.Attr("scale")); + if (recomputation) { + RunBinaryCompoundGradFunctors, + math::ScaleFunctor, + math::ScaleGradFunctor, true>( + ctx, math::AddGradFunctor(), math::ScaleFunctor(scale), + math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_out_grad, + x_grad, y_grad); + } else { + RunBinaryCompoundGradFunctors, + math::ScaleFunctor, + math::ScaleGradFunctor, false>( + ctx, math::AddGradFunctor(), math::ScaleFunctor(scale), + math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_out_grad, + x_grad, y_grad); + } + } else if (funcs_str == "scale_grad,elementwise_add_grad") { + // The backward of Z = Unary(Binary(X, Y)) + T scale = static_cast(ctx.Attr("scale")); + if (recomputation) { + RunUnaryCompoundGradFunctors, + math::AddFunctor, math::AddGradFunctor, + true>(ctx, math::ScaleGradFunctor(scale), + math::AddFunctor(), + math::AddGradFunctor(), in_x, in_y, + in_out, in_out_grad, x_grad, y_grad); + } else { + RunUnaryCompoundGradFunctors, + math::AddFunctor, math::AddGradFunctor, + false>(ctx, math::ScaleGradFunctor(scale), + math::AddFunctor(), + math::AddGradFunctor(), in_x, in_y, + in_out, in_out_grad, x_grad, y_grad); + } + } else if (funcs_str == "elementwise_add_grad,relu_grad") { + if (recomputation) { + RunBinaryCompoundGradFunctors, + math::ReluFunctor, + math::ReluGradFunctor, true>( + ctx, math::AddGradFunctor(), math::ReluFunctor(), + math::ReluGradFunctor(), in_x, in_y, in_out, in_out_grad, x_grad, + y_grad); + } else { + RunBinaryCompoundGradFunctors, + math::ReluFunctor, + math::ReluGradFunctor, false>( + ctx, math::AddGradFunctor(), math::ReluFunctor(), + math::ReluGradFunctor(), in_x, in_y, in_out, in_out_grad, x_grad, + y_grad); + } + } else if (funcs_str == "relu_grad,elementwise_add_grad") { + if (recomputation) { + RunUnaryCompoundGradFunctors, + math::AddFunctor, math::AddGradFunctor, + true>(ctx, math::ReluGradFunctor(), + math::AddFunctor(), + math::AddGradFunctor(), in_x, in_y, + in_out, in_out_grad, x_grad, y_grad); + } else { + RunUnaryCompoundGradFunctors, + math::AddFunctor, math::AddGradFunctor, + false>(ctx, math::ReluGradFunctor(), + math::AddFunctor(), + math::AddGradFunctor(), in_x, in_y, + in_out, in_out_grad, x_grad, y_grad); + } + } else { + PADDLE_THROW("%s has not been implemented.", funcs_str); + } +} + +template +class FusedElemwiseActivationKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto &in_x = detail::Ref(ctx.Input("X"), + "Cannot get input tensor %s, variable name = %s", + "X", ctx.op().Input("X")); + auto &in_y = detail::Ref(ctx.Input("Y"), + "Cannot get input tensor %s, variable name = %s", + "Y", ctx.op().Input("Y")); + auto &output = detail::Ref(ctx.Output("Out"), + "Cannot get input tensor %s, variable name = %s", + "Out", ctx.op().Output("Out")); + + RunFunctors(ctx, &in_x, &in_y, &output); + } +}; + +template +class FusedElemwiseActivationGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto &in_x = detail::Ref(ctx.Input("X"), + "Cannot get input tensor %s, variable name = %s", + "X", ctx.op().Input("X")); + auto &in_y = detail::Ref(ctx.Input("Y"), + "Cannot get input tensor %s, variable name = %s", + "Y", ctx.op().Input("Y")); + auto &in_out = detail::Ref(ctx.Input("Out"), + "Cannot get input tensor %s, variable name = %s", + "Out", ctx.op().Input("Out")); + auto &in_out_grad = + detail::Ref(ctx.Input(framework::GradVarName("Out")), + "Cannot get input tensor %s, variable name = %s", + framework::GradVarName("Out"), + ctx.op().Input(framework::GradVarName("Out"))); + + framework::Tensor *x_grad = + ctx.Output(framework::GradVarName("X")); + framework::Tensor *y_grad = + ctx.Output(framework::GradVarName("Y")); + + RunGradFunctors(ctx, &in_x, &in_y, &in_out, &in_out_grad, + x_grad, y_grad); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/functors.h b/paddle/fluid/operators/math/functors.h new file mode 100644 index 0000000000000000000000000000000000000000..ad2f49ccbf5ff37d33cc9e71c1a683571f4f8137 --- /dev/null +++ b/paddle/fluid/operators/math/functors.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace operators { +namespace math { + +// AddFunctor +template +struct AddFunctor { + // out = x + y; + inline HOSTDEVICE T operator()(T x, T y) { return x + y; } +}; + +template +struct AddGradFunctor { + inline HOSTDEVICE T operator()(T x, T y) { return 1; } + + inline HOSTDEVICE T operator()(T x, T y, T out) const { return 1; } +}; + +template +struct ScaleFunctor { + explicit ScaleFunctor(const T coeff) : coeff_(coeff) {} + + inline HOSTDEVICE T operator()(T ele) { return ele * coeff_; } + + private: + T coeff_; +}; + +template +struct ScaleGradFunctor { + explicit ScaleGradFunctor(T coeff) : coeff_(coeff) {} + + inline HOSTDEVICE T operator()(T x) { return coeff_; } + + inline HOSTDEVICE T operator()(T x, T out) { return coeff_; } + + private: + T coeff_; +}; + +template +struct ReluFunctor { + inline HOSTDEVICE T operator()(T x) { return x * (x > 0); } +}; + +template +struct ReluGradFunctor { + inline HOSTDEVICE T operator()(T x) { return x > 0 ? 1 : 0; } + + inline HOSTDEVICE T operator()(T x, T out) { return x > 0 ? 1 : 0; } +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc index 916cdad3fd288d1f3ffb19bc769ab827dd1e9103..eb09470f37eabb5524f774bc289fc68f5884c540 100644 --- a/paddle/fluid/operators/parallel_do_op.cc +++ b/paddle/fluid/operators/parallel_do_op.cc @@ -163,12 +163,11 @@ class ParallelDoOp : public framework::OperatorBase { auto &place = places[place_idx]; auto *cur_scope = sub_scopes[place_idx]; - workers.emplace_back( - framework::Async([program, cur_scope, place, block, place_idx] { - framework::Executor executor(place); - executor.Run(*program, cur_scope, block->ID(), - false /*create_local_scope*/); - })); + workers.emplace_back(framework::Async([program, cur_scope, place, block] { + framework::Executor executor(place); + executor.Run(*program, cur_scope, block->ID(), + false /*create_local_scope*/); + })); } for (auto &worker : workers) { worker.wait(); @@ -239,12 +238,11 @@ class ParallelDoGradOp : public framework::OperatorBase { auto *cur_scope = sub_scopes[i]; // execute - workers.emplace_back( - framework::Async([program, cur_scope, place, block, i] { - framework::Executor executor(place); - executor.Run(*program, cur_scope, block->ID(), - false /*create_local_scope*/); - })); + workers.emplace_back(framework::Async([program, cur_scope, place, block] { + framework::Executor executor(place); + executor.Run(*program, cur_scope, block->ID(), + false /*create_local_scope*/); + })); } for (auto &worker : workers) { worker.wait(); diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 65fcce8bb019965a805ad09d50be0aba64e4f24e..a0d640b2020958af53a4405ae886eadb2a1e117e 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -65,6 +66,12 @@ class ReadOp : public framework::OperatorBase { .GetMutable(); std::vector out_arg_names = Outputs("Out"); std::vector ins; + + // For profiling + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(dev_place); + platform::RecordEvent record_event(Type(), &ctx); + reader->ReadNext(&ins); if (ins.empty()) { if (Attr("throw_eof_exp")) { diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index 1172822e12222ded219104e3bad2613d30f891b8..ee3078876c15b06a887064f08dc0c05d450b5f77 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -55,18 +55,8 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector &shape) { "TensorRT' tensor input requires at least 2 dimensions"); PADDLE_ENFORCE_LE(shape.size(), 4UL, "TensorRT' tensor input requires at most 4 dimensions"); - - switch (shape.size()) { - case 2: - return nvinfer1::Dims2(1, shape[1]); - case 3: - return nvinfer1::Dims3(1, shape[1], shape[2]); - case 4: - return nvinfer1::Dims4(1, shape[1], shape[2], shape[3]); - default: - return nvinfer1::Dims(); - } - return nvinfer1::Dims(); + PADDLE_ENFORCE_EQ(shape.size(), 4UL); + return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); } } // namespace @@ -86,6 +76,9 @@ void TensorRTEngineKernel::Prepare( parameters.insert(param); } + std::vector output_maps = + context.Attr>("output_name_mapping"); + // TODO(Superjomn) replace this with a different stream auto *engine = Singleton::Global().Create( max_batch, max_workspace, nullptr /*engine hold its own stream*/, @@ -97,6 +90,7 @@ void TensorRTEngineKernel::Prepare( // Add inputs VLOG(4) << "declare inputs"; for (auto &input : context.Inputs("Xs")) { + if (parameters.count(input)) continue; VLOG(4) << "declare input " << input; auto *var = block.FindVar(input); // TensorRT engine need to create parameters. The parameter's description @@ -122,7 +116,7 @@ void TensorRTEngineKernel::Prepare( block_desc, parameters, context.scope(), engine); // Add outputs - for (auto &output : context.Outputs("Ys")) { + for (auto &output : output_maps) { engine->DeclareOutput(output); } diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 32d10fd8a5687ebaae1d7d75af531cbc45ef4245..2cbe1213a2f428a3ce56b06f97636baeb4b66c26 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -66,8 +66,17 @@ class TensorRTEngineKernel : public framework::OpKernel { PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, context.Attr("max_batch")); + std::vector output_maps = + context.Attr>("output_name_mapping"); + + auto params = context.Attr>("parameters"); + std::unordered_set parameters; + for (const auto& param : params) { + parameters.insert(param); + } // Convert input tensor from fluid to engine. for (const auto& x : context.Inputs("Xs")) { + if (parameters.count(x)) continue; // convert input and copy to TRT engine's buffer auto& t = inference::analysis::GetFromScope( context.scope(), x); @@ -82,10 +91,12 @@ class TensorRTEngineKernel : public framework::OpKernel { // Execute the engine. PADDLE_ENFORCE_GT(FLAGS_tensorrt_engine_batch_size, 0); engine->Execute(FLAGS_tensorrt_engine_batch_size); + // Convert output tensor from engine to fluid + int output_index = 0; for (const auto& y : context.Outputs("Ys")) { // convert output and copy to fluid. - nvinfer1::ITensor* trt_t = engine->GetITensor(y); + nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]); auto dims = trt_t->getDimensions(); // Use the output ITensor's dims to reshape the Fluid Tensor. std::vector ddim(dims.d, dims.d + dims.nbDims); @@ -102,7 +113,7 @@ class TensorRTEngineKernel : public framework::OpKernel { // TODO(Superjomn) change this float to dtype size. auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) * FLAGS_tensorrt_engine_batch_size; - engine->GetOutputInCPU(y, + engine->GetOutputInCPU(output_maps[output_index], fluid_t->mutable_data(platform::CPUPlace()), size * sizeof(float)); //} else { @@ -110,6 +121,7 @@ class TensorRTEngineKernel : public framework::OpKernel { // y, fluid_t->mutable_data(platform::CUDAPlace()), // size * sizeof(float)); //} + output_index += 1; } cudaStreamSynchronize(*engine->stream()); diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc index 7cb1e47a1516c32fb31a7818e7203b498e31e431..37657fa0b0498986fe67027415279af1775e58b9 100644 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -103,6 +103,9 @@ TEST(TensorRTEngineOp, manual) { SetAttr(engine_op_desc.Proto(), "engine_uniq_key", "a_engine"); SetAttr>(engine_op_desc.Proto(), "parameters", std::vector({})); + SetAttr>(engine_op_desc.Proto(), + "output_name_mapping", + std::vector({"z0"})); LOG(INFO) << "create engine op"; auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto()); @@ -196,6 +199,10 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { std::vector({"y0", "y1", "y2", "y3"})); SetAttr(engine_op_desc.Proto(), "engine_uniq_key", "b_engine"); + SetAttr>(engine_op_desc.Proto(), + "output_name_mapping", + std::vector({"z3"})); + auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto()); // Execute them. diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 956e3c43485b36aaeb2d366d6145edd3d4535122..3b38c42801e0a4b503d929ca422b354f4c51bb0c 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -123,7 +123,8 @@ def __bootstrap__(): read_env_flags = [ 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'warpctc_dir', 'eager_delete_scope', 'use_mkldnn', 'initial_cpu_memory_in_mb', - 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads' + 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', + 'cpu_deterministic' ] if core.is_compiled_with_dist(): read_env_flags.append('rpc_deadline') diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py index 12cd5d918e93181c6b7e328e6aee4ad941b0a0da..9de9e9504510baec9aefb47f91793c364450795a 100644 --- a/python/paddle/fluid/contrib/__init__.py +++ b/python/paddle/fluid/contrib/__init__.py @@ -14,5 +14,7 @@ import decoder from decoder import * +import memory_usage_calc +from memory_usage_calc import * -__all__ = decoder.__all__ +__all__ = decoder.__all__ + memory_usage_calc.__all__ diff --git a/python/paddle/fluid/contrib/memory_usage_calc.py b/python/paddle/fluid/contrib/memory_usage_calc.py new file mode 100644 index 0000000000000000000000000000000000000000..5da846edb63c28efd791fdfac4046cfa56c24181 --- /dev/null +++ b/python/paddle/fluid/contrib/memory_usage_calc.py @@ -0,0 +1,102 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module privides a memory usage calculate function for user. +The purpose of this API is to allow users to estimate memory usage of +a program under a special batch size, then user can set appropriate +batch size to fully utilize a GPU. + +This API is still under active development and may change drastically. +""" + +from .. import core +from ..framework import Program, Variable + +__all__ = ['memory_usage'] + +dtype_to_size = { + core.VarDesc.VarType.FP16: 2, + core.VarDesc.VarType.FP32: 4, + core.VarDesc.VarType.FP64: 8, + core.VarDesc.VarType.INT16: 2, + core.VarDesc.VarType.INT32: 4, + core.VarDesc.VarType.INT64: 8, + core.VarDesc.VarType.BOOL: 1, + core.VarDesc.VarType.UINT8: 1, +} + +DEBUG = False + + +def memory_usage(program, batch_size): + """ + Get the estimate memory usage of program with input batch size. + + Args: + program(Program): The current Program. + batch_size(int): The current input data batch_size. + + Returns: + min_total_memory(float): the estimate memory usage lower bound. + max_total_memory(float): the estimate memory usage upper bound. + unit_str(string): the unit of estimate usage result. + + Examples: + + >>> import paddle.fluid as fluid + >>> lower_usage, upper_usage, unit = fluid.contrib.memory_usage( + fluid.default_main_program(), batch_size=10) + >>> print "memory usage is about %.3f - %.3f %s" % \ + (lower_usage, upper_usage, unit) + + """ + + # Parameters check + if not isinstance(program, Program): + raise TypeError( + "Calculating Memory Usage requires Program as its Parameter." + "But you passed in %s" % (type(prgram))) + if batch_size <= 0: + raise ValueError("The batch size need to be positive.") + + # Get the var_name list of first block and calculate + total_memory = 0.0 + for var in program.global_block().vars.itervalues(): + data_count = 1 + for x in var.shape: + if x == -1: + data_count *= batch_size + else: + data_count *= x + var_memory = data_count * dtype_to_size[var.dtype] + if DEBUG: + print "%s memory usage: %d" % (var.name, var_memory) + total_memory += var_memory + if DEBUG: + print "total memory usage: %.2f" % (total_memory) + + # Convert appropriate unit + unit_str = "B" + if total_memory > 1024: + total_memory /= 1024 + unit_str = "KB" + if total_memory > 1024: + total_memory /= 1024 + unit_str = "MB" + + # Append extra memory consumption (5% - 10%) + min_total_memory = total_memory * 1.05 + max_total_memory = total_memory * 1.1 + + return min_total_memory, max_total_memory, unit_str diff --git a/python/paddle/fluid/tests/unittests/dist_se_resnext.py b/python/paddle/fluid/tests/unittests/dist_se_resnext.py index bf7816b2466edd7db836c738da90f5f97b631843..f1f35d96f67ad5ef79ec9cb20f070a8352f0e97e 100644 --- a/python/paddle/fluid/tests/unittests/dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/dist_se_resnext.py @@ -174,6 +174,9 @@ class SE_ResNeXt(): padding=(filter_size - 1) / 2, groups=groups, act=None, + # avoid pserver CPU init differs from GPU + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant()), bias_attr=False) return fluid.layers.batch_norm(input=conv, act=act) @@ -194,10 +197,8 @@ class SE_ResNeXt(): def get_model(batch_size): # Input data - image = fluid.layers.fill_constant( - shape=[batch_size, 3, 224, 224], dtype='float32', value=0.0) - label = fluid.layers.fill_constant( - shape=[batch_size, 1], dtype='int64', value=0.0) + image = fluid.layers.data(name="data", shape=[3, 224, 224], dtype='float32') + label = fluid.layers.data(name="int64", shape=[1], dtype='int64') # Train program model = SE_ResNeXt(layers=50) @@ -222,8 +223,10 @@ def get_model(batch_size): lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] optimizer = fluid.optimizer.Momentum( - learning_rate=fluid.layers.piecewise_decay( - boundaries=bd, values=lr), + # FIXME(typhoonzero): add back LR decay once ParallelExecutor fixed. + #learning_rate=fluid.layers.piecewise_decay( + # boundaries=bd, values=lr), + learning_rate=base_lr, momentum=0.9, regularization=fluid.regularizer.L2Decay(1e-4)) optimizer.minimize(avg_cost) @@ -232,7 +235,7 @@ def get_model(batch_size): train_reader = paddle.batch( paddle.dataset.flowers.train(), batch_size=batch_size) test_reader = paddle.batch( - paddle.dataset.flowers.test(), batch_size=batch_size) + paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size) return test_program, avg_cost, train_reader, test_reader, acc_top1, out @@ -256,7 +259,6 @@ class DistSeResneXt2x2: trainers) pserver_prog = t.get_pserver_program(current_endpoint) startup_prog = t.get_startup_program(current_endpoint, pserver_prog) - place = fluid.CPUPlace() exe = fluid.Executor(place) exe.run(startup_prog) @@ -302,12 +304,19 @@ class DistSeResneXt2x2: ] feeder = fluid.DataFeeder(feed_var_list, place) - reader_generator = train_reader() - first_loss, = exe.run(fetch_list=[avg_cost.name]) + reader_generator = test_reader() + + data = next(reader_generator) + first_loss, = exe.run(fetch_list=[avg_cost.name], + feed=feeder.feed(data)) print(first_loss) + for i in xrange(5): - loss, = exe.run(fetch_list=[avg_cost.name]) - last_loss, = exe.run(fetch_list=[avg_cost.name]) + data = next(reader_generator) + loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data)) + + data = next(reader_generator) + last_loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data)) print(last_loss) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 5ed387fb1247f1a91147cb6981f1adc7c2eeb8a2..34f9cf0620fd1351111e93e16ed5f7e765d7078b 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -313,9 +313,9 @@ class TestAbs(OpTest): self.init_dtype() x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) - # Because we set delta = 0.005 in caculating numeric gradient, + # Because we set delta = 0.005 in calculating numeric gradient, # if x is too small, such as 0.002, x_neg will be -0.003 - # x_pos will be 0.007, so the numeric gradient is unaccurate. + # x_pos will be 0.007, so the numeric gradient is inaccurate. # we should avoid this x[np.abs(x) < 0.005] = 0.02 out = np.abs(x) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 1aaab6f906ef6482bc515bb3c42d82431902e1d8..58cfd4e1fd958d8d59e49c87fbbabd0182975add 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -63,7 +63,8 @@ class TestDistBase(unittest.TestCase): "PATH": os.getenv("PATH"), "PYTHONPATH": os.getenv("PYTHONPATH"), "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH"), - "FLAGS_fraction_of_gpu_memory_to_use": "0.15" + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_cudnn_deterministic": "1" } # Run local to get a base line env_local = {"CUDA_VISIBLE_DEVICES": "0"} diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py index 04671d079731ce414561b0ede6bc2b195b07d82a..f3a5fd6985bab1d04f6e1484534367548f383dfb 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py @@ -17,8 +17,7 @@ from test_dist_base import TestDistBase class TestDistSeResneXt2x2(TestDistBase): def test_se_resnext(self): - # TODO(paddle-dev): Is the delta too large? - self.check_with_place("dist_se_resnext.py", delta=0.2) + self.check_with_place("dist_se_resnext.py") if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 7107cec2bfcefda48bac01e7c868f7b4811b9be7..b24036326d51aa56220d46cba202a0d4b93cdd7c 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -359,5 +359,110 @@ class TestL2DecayWithPiecewise(TranspilerTest): ["sum", "scale", "scale", "elementwise_add", "momentum"]) +class TestDistLookupTableBase(TranspilerTest): + def network_with_table(self, is_sparse, is_distributed): + def emb_pool(ids): + table_size = 1000 + emb_size = 64 + emb = fluid.layers.embedding( + input=ids, + size=[table_size, emb_size], + dtype='float32', + param_attr='shared_w', # share parameter + is_sparse=is_sparse, + is_distributed=is_distributed) + pool = fluid.layers.sequence_pool(input=emb, pool_type='average') + return pool + + title_ids = fluid.layers.data( + name='title_ids', shape=[1], dtype='int64', lod_level=1) + brand_ids = fluid.layers.data( + name='brand_ids', shape=[1], dtype='int64', lod_level=1) + title_emb = emb_pool(title_ids) + brand_emb = emb_pool(brand_ids) + fc0 = fluid.layers.concat(input=[title_emb, brand_emb], axis=1) + predict = fluid.layers.fc(input=fc0, + size=2, + act=None, + param_attr=fluid.ParamAttr(name='fc_w'), + bias_attr=fluid.ParamAttr(name='fc_b')) + + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(cost) + optimizer = fluid.optimizer.Adam(learning_rate=0.003) + optimizer.minimize(avg_cost) + + +class TestLocalLookupTable(TestDistLookupTableBase): + def net_conf(self): + self.network_with_table(is_sparse=True, is_distributed=False) + + def transpiler_test_impl(self): + pserver1, startup1 = self.get_pserver(self.pserver1_ep) + + self.assertEqual(len(pserver1.blocks), 3) + # 0 listen_and_serv + # 1 optimize for fc_w or fc_b adam + self.assertEqual([op.type for op in pserver1.blocks[1].ops], + ["sum", "scale", "adam", "scale", "scale"]) + # 2 optimize for table adam + # NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num + self.assertEqual([op.type for op in pserver1.blocks[2].ops], + ["sum", "adam", "scale", "scale"]) + + trainer = self.get_trainer() + self.assertEqual(len(trainer.blocks), 1) + ops = [ + 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', + 'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean', + 'fill_constant', 'mean_grad', 'cross_entropy_grad', + 'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad', + 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', + 'lookup_table_grad', 'sum', 'split_selected_rows', 'send', + 'send_barrier', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat' + ] + self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) + + +class TestDistLookupTable(TestDistLookupTableBase): + def net_conf(self): + self.network_with_table(is_sparse=True, is_distributed=True) + + def transpiler_test_impl(self): + pserver1, startup1 = self.get_pserver(self.pserver1_ep) + + self.assertEqual(len(pserver1.blocks), 6) + # 0 listen_and_serv + # 1 optimize for fc_w or fc_b adam + self.assertEqual([op.type for op in pserver1.blocks[1].ops], + ["sum", "scale", "adam", "scale", "scale"]) + # 2 optimize for table sgd + self.assertEqual([op.type for op in pserver1.blocks[2].ops], + ["sum", "sgd"]) + # 3 prefetch -> lookup_sparse_table for data0 + self.assertEqual([op.type for op in pserver1.blocks[3].ops], + ["lookup_sparse_table"]) + # 4 prefetch -> lookup_sparse_table for data1 + self.assertEqual([op.type for op in pserver1.blocks[4].ops], + ["lookup_sparse_table"]) + # 5 save table + self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) + + trainer = self.get_trainer() + self.assertEqual(len(trainer.blocks), 1) + ops = [ + 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', + 'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', + 'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', + 'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', + 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', + 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', + 'sum', 'split_ids', 'send', 'send_barrier', 'recv', 'recv', + 'fetch_barrier' + ] + self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0a939e9ec21952a6657ea849bb9844bb69cc8d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py @@ -0,0 +1,818 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest + +# scale + add +# TestElementwiseAddOp +# TestFusedOperatorsOp_scalar +# TestFusedOperatorsOp_scalar2 +# TestFusedOperatorsOp_Vector +# TestFusedOperatorsOp_broadcast_0 +# TestFusedOperatorsOp_broadcast_1 +# TestFusedOperatorsOp_broadcast_2 +# TestFusedOperatorsOp_broadcast_3 +# TestFusedOperatorsOp_broadcast_4 +# TestFusedOperatorsOp_rowwise_add_0 +# TestFusedOperatorsOp_rowwise_add_1 +# TestFusedOperatorsOp_channelwise_add + + +class TestElementwiseAddOp(OpTest): + def setUp(self): + self.op_type = "fused_elemwise_activation" + self.dtype = np.float32 + self.axis = -1 + + self.init_axis() + self.init_dtype() + self.init_input() + self.init_output() + self.init_attr() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.outputs = {'Out': self.out} + + def init_input(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["scale", "elementwise_add"] + } + + def init_dtype(self): + pass + + def init_axis(self): + pass + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + + +class TestFusedOperatorsOp_scalar(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + +class TestFusedOperatorsOp_scalar2(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1, 1).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + +class TestFusedOperatorsOp_Vector(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.random((32, )).astype(self.dtype) + self.y = np.random.random((32, )).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + +class TestFusedOperatorsOp_broadcast_0(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(2).astype(self.dtype) + + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(2, 1, 1)) * self.scale + + +class TestFusedOperatorsOp_broadcast_1(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3).astype(self.dtype) + + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 3, 1)) * self.scale + + +class TestFusedOperatorsOp_broadcast_2(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(4).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 1, 4)) * self.scale + + +class TestFusedOperatorsOp_broadcast_3(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 3, 4, 1)) * self.scale + + +class TestFusedOperatorsOp_broadcast_4(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(2, 1).astype(self.dtype) + + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(2, 1, 1, 1)) * self.scale + + +class TestFusedOperatorsOp_rowwise_add_0(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 3, 4)) * self.scale + + +class TestFusedOperatorsOp_rowwise_add_1(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 1).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 1)) * self.scale + + +class TestFusedOperatorsOp_channelwise_add(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(3, 20, 20).astype(self.dtype) + self.y = np.random.rand(3, 1, 1).astype(self.dtype) + + def init_axis(self): + self.axis = -1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + +# add + scale +# TestElementwiseAddOp_f_add_scale +# TestFusedOperatorsOp_scalar_f_add_scale +# TestFusedOperatorsOp_scalar2_f_add_scale +# TestFusedOperatorsOp_Vector_f_add_scale +# TestFusedOperatorsOp_broadcast_0_f_add_scale +# TestFusedOperatorsOp_broadcast_1_f_add_scale +# TestFusedOperatorsOp_broadcast_2_f_add_scale +# TestFusedOperatorsOp_broadcast_3_f_add_scale +# TestFusedOperatorsOp_broadcast_4_f_add_scale +# TestFusedOperatorsOp_rowwise_add_0_f_add_scale +# TestFusedOperatorsOp_rowwise_add_1_f_add_scale +# TestFusedOperatorsOp_channelwise_add_f_add_scale + + +class TestFusedOperatorsOp_f_add_scale(TestElementwiseAddOp): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_scalar_f_add_scale(TestFusedOperatorsOp_scalar): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_scalar2_f_add_scale(TestFusedOperatorsOp_scalar2): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_Vector_f_add_scale(TestFusedOperatorsOp_Vector): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_0_f_add_scale( + TestFusedOperatorsOp_broadcast_0): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(2, 1, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_1_f_add_scale( + TestFusedOperatorsOp_broadcast_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(1, 3, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_2_f_add_scale( + TestFusedOperatorsOp_broadcast_2): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(1, 1, 4) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_3_f_add_scale( + TestFusedOperatorsOp_broadcast_3): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(1, 3, 4, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_4_f_add_scale( + TestFusedOperatorsOp_broadcast_4): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.scale = 0.2 + self.out = self.x + self.y.reshape(2, 1, 1, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_rowwise_add_0_f_add_scale( + TestFusedOperatorsOp_rowwise_add_0): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(1, 3, 4) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_rowwise_add_1_f_add_scale( + TestFusedOperatorsOp_rowwise_add_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.2 + self.out = self.x + self.y.reshape(1, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_channelwise_add_f_add_scale( + TestFusedOperatorsOp_channelwise_add): + def init_axis(self): + self.axis = -1 + + def init_output(self): + self.scale = 0.2 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +# add + relu +# TestElementwiseAddOp_f_add_relu +# TestFusedOperatorsOp_scalar_f_add_relu +# TestFusedOperatorsOp_scalar2_f_add_relu +# TestFusedOperatorsOp_Vector_f_add_relu +# TestFusedOperatorsOp_broadcast_0_f_add_relu +# TestFusedOperatorsOp_broadcast_1_f_add_relu +# TestFusedOperatorsOp_broadcast_2_f_add_relu +# TestFusedOperatorsOp_broadcast_3_f_add_relu +# TestFusedOperatorsOp_broadcast_4_f_add_relu +# TestFusedOperatorsOp_rowwise_add_0_f_add_relu +# TestFusedOperatorsOp_rowwise_add_1_f_add_relu +# TestFusedOperatorsOp_channelwise_add_f_add_relu + + +class TestFusedOperatorsOp_f_add_relu(TestElementwiseAddOp): + def init_output(self): + # Copy from test_activation_op.py + # Because we set delta = 0.005 in calculating numeric gradient, + # if x is too small, such as 0.002, x_neg will be -0.003 + # x_pos will be 0.007, so the numeric gradient is inaccurate. + # we should avoid this + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_scalar_f_add_relu(TestFusedOperatorsOp_scalar): + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_scalar2_f_add_relu(TestFusedOperatorsOp_scalar2): + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_Vector_f_add_relu(TestFusedOperatorsOp_Vector): + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_0_f_add_relu( + TestFusedOperatorsOp_broadcast_0): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(2, 1, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_1_f_add_relu( + TestFusedOperatorsOp_broadcast_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 3, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_2_f_add_relu( + TestFusedOperatorsOp_broadcast_2): + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 1, 4), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_3_f_add_relu( + TestFusedOperatorsOp_broadcast_3): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 3, 4, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_4_f_add_relu( + TestFusedOperatorsOp_broadcast_4): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(2, 1, 1, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_rowwise_add_0_f_add_relu( + TestFusedOperatorsOp_rowwise_add_0): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 3, 4), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_rowwise_add_1_f_add_relu( + TestFusedOperatorsOp_rowwise_add_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_channelwise_add_f_add_relu( + TestFusedOperatorsOp_channelwise_add): + def init_axis(self): + self.axis = -1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +# relu + add +# TestElementwiseAddOp_f_relu_add +# TestFusedOperatorsOp_scalar_f_relu_add +# TestFusedOperatorsOp_scalar2_f_relu_add +# TestFusedOperatorsOp_Vector_f_relu_add +# TestFusedOperatorsOp_broadcast_0_f_relu_add +# TestFusedOperatorsOp_broadcast_1_f_relu_add +# TestFusedOperatorsOp_broadcast_2_f_relu_add +# TestFusedOperatorsOp_broadcast_3_f_relu_add +# TestFusedOperatorsOp_broadcast_4_f_relu_add +# TestFusedOperatorsOp_rowwise_add_0_f_relu_add +# TestFusedOperatorsOp_rowwise_add_1_f_relu_add +# TestFusedOperatorsOp_channelwise_add_f_relu_add + + +class TestFusedOperatorsOp_f_relu_add(TestElementwiseAddOp): + def init_output(self): + # Copy from test_activation_op.py + # Because we set delta = 0.005 in calculating numeric gradient, + # if x is too small, such as 0.002, x_neg will be -0.003 + # x_pos will be 0.007, so the numeric gradient is inaccurate. + # we should avoid this + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_scalar_f_relu_add(TestFusedOperatorsOp_scalar): + def init_output(self): + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_scalar2_f_relu_add(TestFusedOperatorsOp_scalar2): + def init_output(self): + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_Vector_f_relu_add(TestFusedOperatorsOp_Vector): + def init_output(self): + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_0_f_relu_add( + TestFusedOperatorsOp_broadcast_0): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.out = self.x + self.y.reshape(2, 1, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_1_f_relu_add( + TestFusedOperatorsOp_broadcast_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.out = self.x + self.y.reshape(1, 3, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_2_f_relu_add( + TestFusedOperatorsOp_broadcast_2): + def init_output(self): + self.out = self.x + self.y.reshape(1, 1, 4) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_3_f_relu_add( + TestFusedOperatorsOp_broadcast_3): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.out = self.x + self.y.reshape(1, 3, 4, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_4_f_relu_add( + TestFusedOperatorsOp_broadcast_4): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.out = self.x + self.y.reshape(2, 1, 1, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_rowwise_add_0_f_relu_add( + TestFusedOperatorsOp_rowwise_add_0): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.out = self.x + self.y.reshape(1, 3, 4) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_rowwise_add_1_f_relu_add( + TestFusedOperatorsOp_rowwise_add_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.out = self.x + self.y.reshape(1, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_channelwise_add_f_relu_add( + TestFusedOperatorsOp_channelwise_add): + def init_axis(self): + self.axis = -1 + + def init_output(self): + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_memory_usage.py b/python/paddle/fluid/tests/unittests/test_memory_usage.py new file mode 100644 index 0000000000000000000000000000000000000000..f9daf83652e18faab0ab31402b9f5889a0beceaf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_memory_usage.py @@ -0,0 +1,69 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import paddle +import paddle.fluid as fluid +import contextlib +import unittest + + +def train_simulator(test_batch_size=10): + if test_batch_size <= 0: + raise ValueError("batch_size should be a positive integeral value, " + "but got batch_size={}".format(test_batch_size)) + + x = fluid.layers.data(name='x', shape=[13], dtype='float32') + y_predict = fluid.layers.fc(input=x, size=1, act=None) + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) + sgd_optimizer.minimize(avg_cost) + + # Calculate memory usage in current network config + lower_usage, upper_usage, unit = fluid.contrib.memory_usage( + fluid.default_main_program(), batch_size=test_batch_size) + + print("memory usage is about %.3f - %.3f %s" % + (lower_usage, upper_usage, unit)) + + +class TestMemoryUsage(unittest.TestCase): + def test_with_unit_B(self): + with self.program_scope_guard(): + train_simulator() + + def test_with_unit_KB(self): + with self.program_scope_guard(): + train_simulator(test_batch_size=1000) + + def test_with_unit_MB(self): + with self.program_scope_guard(): + train_simulator(test_batch_size=100000) + + @contextlib.contextmanager + def program_scope_guard(self): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + yield + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py index b56129a433a9b222f93525ed8fd3013c6f653148..a28428d8dee201ba105e18684c15d4b4582d989f 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py @@ -198,7 +198,7 @@ class TestResnet(TestParallelExecutorBase): model, use_cuda, iter=20, - delta2=1e-4): + delta2=1e-6): if use_cuda and not core.is_compiled_with_cuda(): return @@ -276,10 +276,10 @@ class TestResnet(TestParallelExecutorBase): model=SE_ResNeXt50Small, use_cuda=False, iter=2, delta2=1e-3) def test_seresnext_with_new_strategy(self): - # self._compare_reduce_and_allreduce( - # model=SE_ResNeXt50Small, use_cuda=True) self._compare_reduce_and_allreduce( - model=SE_ResNeXt50Small, use_cuda=False, iter=5, delta2=1e-2) + model=SE_ResNeXt50Small, use_cuda=True, delta2=1e-2) + self._compare_reduce_and_allreduce( + model=SE_ResNeXt50Small, use_cuda=False, iter=5) if __name__ == '__main__': diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index d4d19799fdb291545117f327d2b9b2c25fbfe5f5..b0a100e1db34ad2971eadabff09fa5d0ce3f51dc 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -896,8 +896,6 @@ class DistributeTranspiler(object): self.table_name ][0] table_opt_block = pserver_program.create_block(pre_block_idx) - # only support sgd now - assert table_opt_op.type == "sgd" if self.sync_mode: # create grad vars in pserver program @@ -937,11 +935,12 @@ class DistributeTranspiler(object): "LearningRate": [lr_var] } outputs = {"ParamOut": [param_var]} - table_opt_block.append_op( - type=table_opt_op.type, - inputs=inputs, - outputs=outputs, - attrs=table_opt_op.attrs) + # only support sgd now + import logging + logging.warn( + "distribute lookup table only support sgd optimizer, change it's optimizer to sgd instead of " + + table_opt_op.type) + table_opt_block.append_op(type="sgd", inputs=inputs, outputs=outputs) # add table parameter gradient and it's block id to grad_to_block_id grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx))