From ee2028a110cf0c7c6df1eba744effaca1754139b Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Sun, 5 May 2019 22:21:02 -0500
Subject: [PATCH] Add use_cuda to inplace pass (#17205)

* add use_cuda to inplace pass,test=develop

* add test softmax_with_xe_inplace test,test=develop
---
 .../fluid/framework/details/build_strategy.cc |  3 +
 .../framework/details/eager_deletion_pass.cc  | 21 +++++
 .../framework/details/inplace_op_pass.cc      | 68 +++++++++++----
 .../details/memory_optimize_helper.h          |  2 +
 paddle/fluid/framework/details/op_registry.h  |  4 +-
 paddle/fluid/framework/inplace_op_inference.h |  6 +-
 .../framework/inplace_op_inference_test.cc    |  8 +-
 paddle/fluid/framework/type_defs.h            |  2 +-
 paddle/fluid/operators/batch_norm_op.cc       | 14 ++-
 .../operators/elementwise/elementwise_op.h    | 12 +--
 paddle/fluid/operators/flatten_op.cc          | 14 +--
 paddle/fluid/operators/group_norm_op.cc       |  4 +-
 paddle/fluid/operators/reshape_op.cc          | 14 +--
 .../softmax_with_cross_entropy_op.cc          | 18 +++-
 .../softmax_with_cross_entropy_op.cu          | 52 +++++------
 paddle/fluid/operators/sum_op.cc              | 11 ++-
 ...test_inplace_softmax_with_cross_entropy.py | 86 +++++++++++++++++++
 17 files changed, 239 insertions(+), 100 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_inplace_softmax_with_cross_entropy.py

diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 26680eeb298..8aa4a9645dd 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -311,6 +311,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                         "GPU, skipped.";
         continue;
       }
+    } else if (pass->Type() == "inplace_pass") {
+      pass->Erase(kUseCuda);
+      pass->Set(kUseCuda, new bool(use_cuda));
     }
     VLOG(3) << "Start Apply Pass " << pass->Type();
     graph = pass->Apply(graph);
diff --git a/paddle/fluid/framework/details/eager_deletion_pass.cc b/paddle/fluid/framework/details/eager_deletion_pass.cc
index 622a59b4c2e..5ea18efe5c3 100644
--- a/paddle/fluid/framework/details/eager_deletion_pass.cc
+++ b/paddle/fluid/framework/details/eager_deletion_pass.cc
@@ -33,6 +33,19 @@ namespace details {
 using OpToVarNameSetMap =
     std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>;
 
+static std::map<size_t, std::unordered_set<std::string>> VarsGroupByScopeIdx(
+    const OpToVarNameSetMap &map) {
+  std::map<size_t, std::unordered_set<std::string>> result;
+  for (auto &pair : map) {
+    size_t scope_idx = pair.first->GetScopeIdx();
+    auto &var_set = result[scope_idx];
+    for (auto &var : pair.second) {
+      var_set.insert(var);
+    }
+  }
+  return result;
+}
+
 // Check whether the variable is LoDTensor based on static VarDesc info
 static bool IsLoDTensor(VarDesc *var) {
   return var->Proto()->type().type() == proto::VarType::LOD_TENSOR;
@@ -236,6 +249,14 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
   VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " << memory_fraction;
   VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
 
+  if (VLOG_IS_ON(10)) {
+    auto vars_group_by_scope_idx = VarsGroupByScopeIdx(op_vars_map);
+    for (auto &pair : vars_group_by_scope_idx) {
+      VLOG(10) << "Scope " << pair.first << " has " << pair.second.size()
+               << " vars";
+    }
+  }
+
   auto while_op_eager_deletion_pass =
       ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
   while_op_eager_deletion_pass->Apply(graph);
diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc
index d4812a01d8a..9313d9958dd 100644
--- a/paddle/fluid/framework/details/inplace_op_pass.cc
+++ b/paddle/fluid/framework/details/inplace_op_pass.cc @@ -111,9 +111,9 @@ class InplacePass : public ir::Pass { // Check whether all `ops` is the preceding ops of `op` bool CheckOpDeps(ir::Node *op, const std::vector &ops) const; - // Find node whose name is equal to the given name - static ir::Node *FindNodeByName(const std::string &name, - const std::vector &nodes); + // Find nodes whose name are equal to the given name + static std::unordered_set FindNodesByName( + const std::string &name, const std::vector &nodes); // Get all versions vars named var_name std::vector *AllVersionVars(const std::string &var_name) const; @@ -290,17 +290,15 @@ void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var, op->Op()->Flush(); } -ir::Node *InplacePass::FindNodeByName(const std::string &name, - const std::vector &nodes) { - ir::Node *found_node = nullptr; +std::unordered_set InplacePass::FindNodesByName( + const std::string &name, const std::vector &nodes) { + std::unordered_set ret; for (auto *node : nodes) { if (node->Name() == name) { - PADDLE_ENFORCE(found_node == nullptr, "Find duplicate input nodes %s", - name); - found_node = node; + ret.insert(node); } } - return found_node; + return ret; } void InplacePass::ApplyImpl(ir::Graph *graph) const { @@ -326,6 +324,10 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const { } // Step 3: traverse ops and try inplace if possible + bool use_cuda = Get(kUseCuda); + VLOG(4) << "Inplace pass is applied when use_cuda = " + << (use_cuda ? "true" : "false"); + for (auto *op_node : ops) { PADDLE_ENFORCE_NOT_NULL(op_node->Op(), "op_desc is nullptr"); @@ -343,7 +345,7 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const { continue; } - auto in_to_outs = infer_inplace(*op_desc); + auto in_to_outs = infer_inplace(*op_desc, use_cuda); for (auto &pair : in_to_outs) { auto &in_param = pair.first; auto &out_param = pair.second; @@ -385,9 +387,17 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const { continue; } - auto *in_node = FindNodeByName(in_arg, op_node->inputs); - PADDLE_ENFORCE_NOT_NULL(in_node, "Input(%s)=%s cannot be found in op %s", - in_param, in_arg, op_type); + auto in_nodes = FindNodesByName(in_arg, op_node->inputs); + PADDLE_ENFORCE(!in_nodes.empty(), "Input(%s)=%s cannot be found in op %s", + in_param, in_arg, op_type); + + if (in_nodes.size() > 1) { + VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg + << " occurs in other inputs of " << op_type; + continue; + } + + auto *in_node = *in_nodes.begin(); if (!NodeCanReused(in_node)) { VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg @@ -410,10 +420,29 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const { continue; } - auto *out_node = FindNodeByName(out_arg, op_node->outputs); - PADDLE_ENFORCE_NOT_NULL(out_node, - "Output(%s)=%s cannot be found in op %s", - out_param, out_arg, op_type); + auto out_nodes = FindNodesByName(out_arg, op_node->outputs); + PADDLE_ENFORCE(!out_nodes.empty(), + "Output(%s)=%s cannot be found in op %s", out_param, + out_arg, op_type); + + PADDLE_ENFORCE_EQ( + out_nodes.size(), 1, + "Wrong graph: Output(%s)=%s occurs in other outputs of op %s", + out_param, out_arg, op_type); + + if (!FindNodesByName(in_arg, op_node->outputs).empty()) { + VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg + << " occurs in output of op " << op_type; + continue; + } + + if (!FindNodesByName(out_arg, op_node->inputs).empty()) { + VLOG(4) << "Cannot inplace because Output(" << in_param + << ")=" << out_arg << " occurs 
in input of op " << op_type; + continue; + } + + auto *out_node = *out_nodes.begin(); if (!NodeCanReused(out_node)) { VLOG(4) << "Cannot inplace because Output(" << out_param @@ -457,4 +486,5 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const { } // namespace framework } // namespace paddle -REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass); +REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass) + .RequirePassAttr(paddle::framework::details::kUseCuda); diff --git a/paddle/fluid/framework/details/memory_optimize_helper.h b/paddle/fluid/framework/details/memory_optimize_helper.h index 0a65ec051df..3ef407e4e9c 100644 --- a/paddle/fluid/framework/details/memory_optimize_helper.h +++ b/paddle/fluid/framework/details/memory_optimize_helper.h @@ -36,6 +36,8 @@ namespace details { constexpr char kMemOptSkipVars[] = "@MEM_OPT_SKIP_VARS@"; typedef std::unordered_set MemOptSkipVars; +constexpr char kUseCuda[] = "use_cuda"; + std::vector SortOpLikeDescOrder(const ir::Graph& graph); // NOTE(dzh): A ordered set for node reuse in memory optimize. diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index 18de595983f..0f03ca51da7 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -214,9 +214,9 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { - info->infer_inplace_ = [](const OpDesc& op_desc) { + info->infer_inplace_ = [](const OpDesc& op_desc, bool use_cuda) { T infer; - return infer(op_desc); + return infer(op_desc, use_cuda); }; } }; diff --git a/paddle/fluid/framework/inplace_op_inference.h b/paddle/fluid/framework/inplace_op_inference.h index df46d4f9a80..fddcbaf596d 100644 --- a/paddle/fluid/framework/inplace_op_inference.h +++ b/paddle/fluid/framework/inplace_op_inference.h @@ -37,7 +37,7 @@ class InplaceOpInference { public: virtual ~InplaceOpInference() {} virtual std::unordered_map operator()( - const OpDesc& op_desc) const = 0; + const OpDesc& op_desc, bool use_cuda) const = 0; }; /* @@ -47,7 +47,7 @@ class InplaceOpInference { class SingleOpInplaceInToOut : public InplaceOpInference { public: std::unordered_map operator()( - const OpDesc& op_desc) const override { + const OpDesc& op_desc, bool use_cuda) const override { PADDLE_ENFORCE(!op_desc.InputNames().empty(), "Op inputs must not be empty"); PADDLE_ENFORCE(!op_desc.OutputNames().empty(), @@ -65,7 +65,7 @@ class SingleOpInplaceInToOut : public InplaceOpInference { class GradOpInplaceInToOut : public InplaceOpInference { public: std::unordered_map operator()( - const OpDesc& op_desc) const override { + const OpDesc& op_desc, bool use_cuda) const override { std::unordered_map ret; std::unordered_set output_names(op_desc.OutputNames().begin(), op_desc.OutputNames().end()); diff --git a/paddle/fluid/framework/inplace_op_inference_test.cc b/paddle/fluid/framework/inplace_op_inference_test.cc index b2141628d2b..cebca9207a3 100644 --- a/paddle/fluid/framework/inplace_op_inference_test.cc +++ b/paddle/fluid/framework/inplace_op_inference_test.cc @@ -32,7 +32,9 @@ namespace paddle { namespace framework { std::unique_ptr CreateInplacePass() { - return ir::PassRegistry::Instance().Get("inplace_pass"); + auto pass = ir::PassRegistry::Instance().Get("inplace_pass"); + pass->Set(details::kUseCuda, new bool(true)); + return pass; } class NOP : public OperatorBase { @@ -141,7 +143,7 @@ class MultiOutGradShapeInference : public 
framework::InferShapeBase { class MultiOutInplaceInToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const OpDesc& op_desc) const override { + const OpDesc& op_desc, bool use_cuda) const override { return std::unordered_map{ {"X", "Out"}, {"Y", "YOut"}, {"Z", "ZOut"}, }; @@ -151,7 +153,7 @@ class MultiOutInplaceInToOut : public framework::InplaceOpInference { class MultiOutGradInplaceInToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const OpDesc& op_desc) const override { + const OpDesc& op_desc, bool use_cuda) const override { return std::unordered_map{ {framework::GradVarName("YOut"), framework::GradVarName("Y")}, {framework::GradVarName("Out"), framework::GradVarName("X")}, diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 4ae6a272d5b..7f1bfb5d9a8 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -60,7 +60,7 @@ using InferVarTypeFN = using InferShapeFN = std::function; using InplacePair = std::unordered_map; -using InferInplaceOpFN = std::function; +using InferInplaceOpFN = std::function; using InferNoNeedBufferVarsFN = std::function( const VariableNameMap& /*inputs*/, const VariableNameMap& /*outputs*/, diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 0cc3e1f2b83..d583909a666 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -600,25 +600,21 @@ std::unique_ptr BatchNormGradMaker::Apply() const { class BatchNormInplaceInToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { - std::unordered_map inplace_in_to_out = { - {"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"}, - }; - return inplace_in_to_out; + const framework::OpDesc &op_desc, bool use_cuda) const override { + return {{"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"}}; } }; class BatchNormGradInplaceInToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { - std::unordered_map inplace_in_to_out = { - // Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C] + const framework::OpDesc &op_desc, bool use_cuda) const override { + // Scale, Bias, SavedMean, SavedVariance shape is [batch_size, C] + return { {framework::GradVarName("Y"), framework::GradVarName("X")}, {"SavedMean", framework::GradVarName("Scale")}, {"SavedVariance", framework::GradVarName("Bias")}, }; - return inplace_in_to_out; } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 95246b38f53..22d1d0dfbe4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -255,20 +255,16 @@ class ElemwiseGradKernel : public framework::OpKernel { class ElementwiseOpInplace : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { - return std::unordered_map{ - {"X", "Out"}, - }; + const framework::OpDesc &op_desc, bool use_cuda) const override { + return {{"X", "Out"}}; } }; class ElementwiseGradOpInplace : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { - return std::unordered_map{ - {framework::GradVarName("Out"), 
framework::GradVarName("X")}, - }; + const framework::OpDesc &op_desc, bool use_cuda) const override { + return {{framework::GradVarName("Out"), framework::GradVarName("X")}}; } }; diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 7f43a1cfe97..f4085daa106 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -270,22 +270,16 @@ class Flatten2GradOp : public framework::OperatorBase { class FlattenOpInplaceInToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { - std::unordered_map inplace_in_to_out = { - {"X", "Out"}, - }; - return inplace_in_to_out; + const framework::OpDesc &op_desc, bool use_cuda) const override { + return {{"X", "Out"}}; } }; class FlattenGradInplaceinToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { - std::unordered_map inplace_in_to_out = { - {framework::GradVarName("Out"), framework::GradVarName("X")}, - }; - return inplace_in_to_out; + const framework::OpDesc &op_desc, bool use_cuda) const override { + return {{framework::GradVarName("Out"), framework::GradVarName("X")}}; } }; diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index 09fd6a25d18..2b1e8038fc4 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -173,7 +173,7 @@ class GroupNormGradMaker : public framework::SingleGradOpDescMaker { class GroupNormInplaceInToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { + const framework::OpDesc &op_desc, bool use_cuda) const override { return {{"X", "Y"}}; } }; @@ -181,7 +181,7 @@ class GroupNormInplaceInToOut : public framework::InplaceOpInference { class GroupNormGradInplaceInToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { + const framework::OpDesc &op_desc, bool use_cuda) const override { return {{framework::GradVarName("Y"), framework::GradVarName("X")}}; } }; diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 5165af6a253..f3719e8f438 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -325,22 +325,16 @@ class Reshape2GradOp : public framework::OperatorWithKernel { class ReshapeOpInplaceInToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { - std::unordered_map inplace_in_to_out = { - {"X", "Out"}, - }; - return inplace_in_to_out; + const framework::OpDesc &op_desc, bool use_cuda) const override { + return {{"X", "Out"}}; } }; class ReshapeGradInplaceInToOut : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc &op_desc) const override { - std::unordered_map inplace_in_to_out = { - {framework::GradVarName("Out"), framework::GradVarName("X")}, - }; - return inplace_in_to_out; + const framework::OpDesc &op_desc, bool use_cuda) const override { + return {{framework::GradVarName("Out"), framework::GradVarName("X")}}; } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index d7718bda5cc..371ab0384a3 100644 --- 
a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -228,11 +228,24 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker { } }; +class SoftmaxWithCrossEntropyInplaceInference + : public framework::InplaceOpInference { + public: + std::unordered_map operator()( + const framework::OpDesc& op_desc, bool use_cuda) const { + if (use_cuda && !boost::get(op_desc.GetAttr("soft_label"))) { + return {{"Logits", "Softmax"}}; + } else { + return {}; + } + } +}; + class SoftmaxWithCrossEntropyGradInplaceInference : public framework::InplaceOpInference { public: std::unordered_map operator()( - const framework::OpDesc& op_desc) const { + const framework::OpDesc& op_desc, bool use_cuda) const { return {{"Softmax", framework::GradVarName("Logits")}}; } }; @@ -243,7 +256,8 @@ class SoftmaxWithCrossEntropyGradInplaceInference namespace ops = paddle::operators; REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, - ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker); + ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker, + ops::SoftmaxWithCrossEntropyInplaceInference); REGISTER_OPERATOR(softmax_with_cross_entropy_grad, ops::SoftmaxWithCrossEntropyOpGrad, ops::SoftmaxWithCrossEntropyGradInplaceInference); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index ed61fb38b5a..dc5ec7bc38c 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -183,8 +183,7 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data, // Make sure that BlockDim <= feature_size template static __global__ void RowReductionForSoftmaxAndCrossEntropy( - const T* logits_data, const T* labels_data, T* loss_data, T* softmax, - int feature_size) { + const T* labels_data, T* loss_data, T* softmax, int feature_size) { __shared__ BlockReduceTempStorage temp_storage; auto beg_idx = feature_size * blockIdx.x + threadIdx.x; @@ -210,11 +209,9 @@ static __global__ void RowReductionForSoftmaxAndCrossEntropy( template struct HardLabelSoftmaxWithCrossEntropyFunctor { public: - HardLabelSoftmaxWithCrossEntropyFunctor(const T* logits, - const int64_t* labels, T* loss, + HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss, T* log_softmax, int feature_size) - : logits_(logits), - labels_(labels), + : labels_(labels), loss_(loss), log_softmax_(log_softmax), feature_size_(feature_size) {} @@ -232,7 +229,6 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { } private: - const T* logits_; const int64_t* labels_; T* loss_; T* log_softmax_; @@ -242,13 +238,11 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { template struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { public: - HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const T* logits, - const int64_t* labels, + HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, T* loss, T* log_softmax, int feature_size, int ignore_idx) - : logits_(logits), - labels_(labels), + : labels_(labels), loss_(loss), log_softmax_(log_softmax), feature_size_(feature_size), @@ -267,7 +261,6 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { } private: - const T* logits_; const int64_t* labels_; T* loss_; T* log_softmax_; @@ -293,23 +286,22 @@ static void HardLabelSoftmaxWithCrossEntropy( : (1 << static_cast(std::log2(feature_size))); auto stream = ctx.stream(); 
-#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)      \
-  case BlockDim: {                                                             \
-    RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(      \
-        logits_data, loss_data, feature_size);                                 \
-    RowReductionForDiffMaxSum<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
-        logits_data, loss_data, softmax_data, feature_size);                   \
-    platform::ForRange<platform::CUDADeviceContext> for_range(                 \
-        ctx, batch_size* feature_size);                                        \
-    if (ignore_idx >= 0 && ignore_idx < feature_size) {                        \
-      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(       \
-          logits_data, labels_data, loss_data, softmax_data, feature_size,     \
-          ignore_idx));                                                         \
-    } else {                                                                   \
-      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                    \
-          logits_data, labels_data, loss_data, softmax_data, feature_size));   \
-    }                                                                          \
+#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim)     \
+  case BlockDim: {                                                            \
+    RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(     \
+        logits_data, loss_data, feature_size);                                \
+    RowReductionForDiffMaxSum<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
+        logits_data, loss_data, softmax_data, feature_size);                  \
+    platform::ForRange<platform::CUDADeviceContext> for_range(                \
+        ctx, batch_size* feature_size);                                       \
+    if (ignore_idx >= 0 && ignore_idx < feature_size) {                       \
+      for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>(      \
+          labels_data, loss_data, softmax_data, feature_size, ignore_idx));   \
+    } else {                                                                  \
+      for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>(                   \
+          labels_data, loss_data, softmax_data, feature_size));               \
+    }                                                                         \
   } break
 
   switch (block_dim) {
@@ -356,7 +348,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
         logits_data, loss_data, softmax_data, feature_size);                  \
     RowReductionForSoftmaxAndCrossEntropy<                                    \
         T, BlockDim><<<batch_size, BlockDim, 0, stream>>>(                    \
-        logits_data, labels_data, loss_data, softmax_data, feature_size);     \
+        labels_data, loss_data, softmax_data, feature_size);                  \
     break
 
   switch (block_dim) {
diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc
index 67f7510e874..1eb4076d64d 100644
--- a/paddle/fluid/operators/sum_op.cc
+++ b/paddle/fluid/operators/sum_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <algorithm>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "paddle/fluid/framework/var_type_inference.h"
@@ -237,13 +238,21 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
   }
 };
 
+class SumInplace : public framework::InplaceOpInference {
+ public:
+  std::unordered_map<std::string, std::string> operator()(
+      const framework::OpDesc& op_desc, bool use_cuda) const override {
+    return {{"X", "Out"}};
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
-                  ops::SumOpVarTypeInference);
+                  ops::SumOpVarTypeInference, ops::SumInplace);
 REGISTER_OP_CPU_KERNEL(
     sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/python/paddle/fluid/tests/unittests/test_inplace_softmax_with_cross_entropy.py b/python/paddle/fluid/tests/unittests/test_inplace_softmax_with_cross_entropy.py
new file mode 100644
index 00000000000..a19626297a6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_inplace_softmax_with_cross_entropy.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import layers
+import numpy as np
+import unittest
+
+
+class TestSoftmaxWithXe(unittest.TestCase):
+    def setUp(self):
+        self.m, self.n = np.random.random_integers(
+            low=100, high=2000, size=[2]).astype('int64')
+
+    def softmax_with_xe(self, x, y, place, inplace=True):
+        m, n = x.shape
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            with fluid.scope_guard(fluid.Scope()):
+                x_d = fluid.layers.data(
+                    name='x',
+                    shape=[m, n],
+                    dtype='float32',
+                    append_batch_size=False)
+                y_d = fluid.layers.data(
+                    name='y',
+                    shape=[m, 1],
+                    dtype='int64',
+                    append_batch_size=False)
+                z_d, s_d = fluid.layers.softmax_with_cross_entropy(
+                    x_d, y_d, return_softmax=True)
+
+                exe = fluid.Executor(place)
+
+                exe.run(fluid.default_startup_program())
+
+                build_strategy = fluid.BuildStrategy()
+                build_strategy.enable_inplace = inplace
+                prog = fluid.CompiledProgram(fluid.default_main_program(
+                )).with_data_parallel(
+                    build_strategy=build_strategy, places=place)
+
+                if inplace and isinstance(place, fluid.CUDAPlace):
+                    fetch_list = [z_d.name, x_d.name]
+                else:
+                    fetch_list = [z_d.name, s_d.name]
+
+                z, s = exe.run(prog,
+                               feed={x_d.name: x,
+                                     y_d.name: y},
+                               fetch_list=fetch_list)
+                return z, s
+
+    def main_with_place(self, place):
+        x = np.random.random(size=[self.m, self.n]).astype('float32')
+        x_range = [(-30, 30), (10, 20), (-1, 1), (2, 3), (0, 0.3), (-200, -100)]
+
+        for a, b in x_range:
+            x = ((b - a) * x + a).astype('float32')
+            y = np.random.random_integers(
+                size=[self.m, 1], low=0, high=self.n - 1).astype('int64')
+            z1, s1 = self.softmax_with_xe(x, y, place, False)
+            z2, s2 = self.softmax_with_xe(x, y, place, True)
+
+            self.assertTrue((z1 == z2).all())
+            self.assertTrue((s1 == s2).all())
+
+    def test_main(self):
+        self.main_with_place(fluid.CPUPlace())
+        if fluid.core.is_compiled_with_cuda():
+            self.main_with_place(fluid.CUDAPlace(0))
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab