diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 46e56981ea57722bbc064304761e7ab7b7aee141..c020ff45ad3f3a72bf8a88622df333c1765a3d21 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -55,7 +55,7 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) -paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None) +paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True)) paddle.fluid.InferenceTranspiler.__init__ @@ -159,6 +159,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaul paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) @@ -327,7 +328,7 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) -paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None) +paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) 
paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True)) paddle.fluid.transpiler.InferenceTranspiler.__init__ diff --git a/paddle/fluid/framework/details/exception_holder.h b/paddle/fluid/framework/details/exception_holder.h index 6e302a29233b96451df14b4685911be1cd87c1ab..c97b364de1ecae21e97351196389615187932b5e 100644 --- a/paddle/fluid/framework/details/exception_holder.h +++ b/paddle/fluid/framework/details/exception_holder.h @@ -14,6 +14,7 @@ #pragma once +#include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -22,27 +23,24 @@ namespace details { class ExceptionHolder { public: - void Catch(const platform::EnforceNotMet& exp) { - std::lock_guard lock(mu_); - exception_.reset(new platform::EnforceNotMet(exp)); - type_ = kEnforceNotMet; - } - - void Catch(const platform::EOFException& exp) { - std::lock_guard lock(mu_); - // EOFException will not cover up existing EnforceNotMet. - if (exception_.get() == nullptr) { - exception_.reset(new platform::EOFException(exp)); - type_ = kEOF; + void Catch(std::exception_ptr eptr) { + try { + std::rethrow_exception(eptr); + } catch (platform::EOFException exp) { + Catch(exp); + } catch (platform::EnforceNotMet exp) { + Catch(exp); + } catch (...) { + LOG(FATAL) << "Unknown exception caught"; } } - bool ExceptionCatched() const { + bool IsCaught() const { std::lock_guard lock(mu_); return exception_.get() != nullptr; } - void Throw() { + void ReThrow() { std::lock_guard lock(mu_); switch (type_) { case kNone: @@ -50,27 +48,41 @@ class ExceptionHolder { case kEnforceNotMet: { auto e = *static_cast(exception_.get()); throw e; - break; } case kEOF: { auto e = *static_cast(exception_.get()); throw e; - break; } - default: - LOG(FATAL) << "Unknown exception."; } - exception_.reset(); - type_ = kNone; + ClearImpl(); } void Clear() { std::lock_guard lock(mu_); + ClearImpl(); + } + + private: + void ClearImpl() { exception_.reset(); type_ = kNone; } - private: + void Catch(const platform::EnforceNotMet& exp) { + std::lock_guard lock(mu_); + exception_.reset(new platform::EnforceNotMet(exp)); + type_ = kEnforceNotMet; + } + + void Catch(const platform::EOFException& exp) { + std::lock_guard lock(mu_); + // EOFException will not cover up existing EnforceNotMet. 
+ if (exception_.get() == nullptr) { + exception_.reset(new platform::EOFException(exp)); + type_ = kEOF; + } + } + enum ExceptionType { kNone, kEnforceNotMet, kEOF }; ExceptionType type_{kNone}; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 994bb6492f685138d02971a6caf12572aecd6d6f..c9e331ef359f853263f8dad38dd0a2be4d9618ad 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -107,11 +107,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto cur_ready_vars = ready_vars.PopAll(1, &timeout); if (timeout) { - if (exception_holder_.ExceptionCatched()) { + if (exception_holder_.IsCaught()) { for (auto &run_op_future : run_op_futures_) { run_op_future.wait(); } - exception_holder_.Throw(); + exception_holder_.ReThrow(); } else { continue; } @@ -220,12 +220,8 @@ void ThreadedSSAGraphExecutor::RunOp( running_ops_--; ready_var_q->Extend(op->Outputs()); VLOG(10) << op << " " << op->Name() << "Signal posted"; - } catch (platform::EOFException ex) { - exception_holder_.Catch(ex); - } catch (platform::EnforceNotMet ex) { - exception_holder_.Catch(ex); } catch (...) { - LOG(FATAL) << "Unknown exception catched"; + exception_holder_.Catch(std::current_exception()); } }; if (pool_) { diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index c9d55fbf525a1a476ac469e8e57462169a7db2da..5736a5c4e232698085936303d1f23760649f8245 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -28,6 +28,38 @@ namespace paddle { namespace framework { namespace ir { +/* + * The graph is a Directed Acyclic Single Static Assignment Graph. + * + * In more detail, the following properties must hold: + * + * The graph shouldn't contain a cycle. Each node is a black-box to the graph + * so the node itself could be a loop operator. + * + * Each Variable-type node has only one input (thus single static assignment). + * + * The output/input of an operator is a variable and the output/input of a variable + * is an operator. + * + * The following data hazards in Program are addressed in the Graph: + * + * Write-After-Read + * a = op1(x) + * x = op2(b) + * A control-dependency connection is created between op1 and op2 such that + * op1->op2, so as to ensure correct order. + * + * Write-After-Write + * x = op1(a) + * x = op2(b) + * A control-dependency connection is created between op1 and op2 such that + * op1->op2, so as to ensure correct order. + * + * Other properties currently hold, but are not enforced yet: + * + * Variable-type nodes (not control deps) with the same variable name share + * the same underlying VarDesc.
+ */ class Graph { public: explicit Graph(const ProgramDesc &program); diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index f9e6bdf3625bdced9d1a9195a979b0f46016d8bf..b1b8d1c586c98a327a8e5b4890ced00022155e6b 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -36,7 +36,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker { public: void Make() { AddInput("X", "").AsDuplicable(); - AddOutput("Out", ""); + AddOutput("Out", "").AsDuplicable(); AddComment(""); } }; @@ -59,11 +59,27 @@ class SumOpVarTypeInference : public VarTypeInference { block->Var(out_var_name)->SetType(default_var_type); } }; + +class DummyOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", "").AsDuplicable(); + AddComment(""); + } +}; + +class DummyOpVarTypeInference : public VarTypeInference { + public: + void operator()(const OpDesc &op_desc, BlockDesc *block) const override {} +}; } // namespace framework } // namespace paddle REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker, paddle::framework::SumOpVarTypeInference); +REGISTER_OPERATOR(dummy, paddle::framework::NOP, paddle::framework::SumOpMaker, + paddle::framework::SumOpVarTypeInference); REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP, paddle::framework::SumOpMaker); @@ -110,5 +126,83 @@ TEST(GraphTest, Basic) { } ASSERT_EQ(nodes.size(), 5); } + +TEST(GraphTest, WriteAfterRead) { + // void Test() { + ProgramDesc prog; + auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(0)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"a"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + std::unique_ptr g(new ir::Graph(prog)); + ir::Node *control_dep1 = nullptr; + ir::Node *control_dep2 = nullptr; + for (ir::Node *n : g->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + control_dep1 = n->outputs[1]; + ASSERT_EQ(n->outputs.size(), 2); + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2); + } + } + ASSERT_EQ(control_dep1, control_dep2); +} + +TEST(GraphTest, WriteAfterWrite) { + // void Test() { + ProgramDesc prog; + auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"a"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + op = prog.MutableBlock(0)->AppendOp(); + op->SetType("dummy"); + op->SetInput("X", {"c"}); + op->SetOutput("Out", {"b"}); + op->SetAttr("op_role", 1); + + prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR); + + std::unique_ptr g(new ir::Graph(prog)); + ir::Node *control_dep1 = nullptr; + ir::Node *control_dep2 = nullptr; + for (ir::Node *n : g->Nodes()) { + if (n->Name() == "sum") { + ASSERT_EQ(n->outputs[0]->Name(), "b"); + 
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1])); + ASSERT_EQ(n->outputs.size(), 2); + control_dep1 = n->outputs[1]; + } + if (n->Name() == "dummy") { + ASSERT_EQ(n->inputs[0]->Name(), "c"); + ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1])); + control_dep2 = n->inputs[1]; + ASSERT_EQ(n->inputs.size(), 2); + ASSERT_EQ(control_dep1, control_dep2); + } + } +} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index c7286dacf01659f3af0927a71856e5a6496cb877..56bb9142dabe0d5546e321e675a5acba7bf4d306 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -112,5 +112,6 @@ Tensor& Tensor::Resize(const DDim& dims) { const DDim& Tensor::dims() const { return dims_; } int64_t Tensor::numel() const { return product(dims_); } + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 7f678f869aac4616c8bca440d0431f765da41dd6..b7b62eef23ec351686378c913d18fc72308fd7b2 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -59,6 +59,14 @@ inline T* Tensor::mutable_data(platform::Place place) { } inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { + int rank = src.dims().size(); + PADDLE_ENFORCE_GE( + rank, 2, + "'ReshapeToMatrix()' is only used for flattening high rank " + "tensors to matrices. It can not be used to reshape vectors."); + if (rank == 2) { + return src; + } Tensor res; res.ShareDataWith(src); res.Resize(flatten_to_2d(src.dims(), num_col_dims)); diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 08d7af6d3af7054061b15b904c69b2862c629562..e31c637e969f7a86f4f185abb4f0f01d3303db75 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -22,6 +22,9 @@ limitations under the License. */ #include #include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/platform/profiler.h" + +DEFINE_bool(profile, false, "Turn on profiler for fluid"); namespace paddle { namespace { @@ -58,6 +61,15 @@ bool NativePaddlePredictor::Init( std::shared_ptr parent_scope) { VLOG(3) << "Predictor::init()"; + if (FLAGS_profile) { + LOG(WARNING) << "Profiler is activated, which might affect the performance"; + LOG(INFO) << "You can turn it off by setting the gflag '-profile false'"; + + auto tracking_device = config_.use_gpu ?
platform::ProfilerState::kAll + : platform::ProfilerState::kCPU; + platform::EnableProfiler(tracking_device); + } + if (config_.use_gpu) { place_ = paddle::platform::CUDAPlace(config_.device); } else { @@ -102,6 +114,10 @@ bool NativePaddlePredictor::Init( } NativePaddlePredictor::~NativePaddlePredictor() { + if (FLAGS_profile) { + platform::DisableProfiler(platform::EventSortingKey::kTotal, + "./profile.log"); + } if (sub_scope_) { scope_->DeleteScope(sub_scope_); } diff --git a/paddle/fluid/operators/.flatten_op.cc.swp b/paddle/fluid/operators/.flatten_op.cc.swp deleted file mode 100644 index 3395b6074b6a4c684a97674af702ca8b91dc85e9..0000000000000000000000000000000000000000 Binary files a/paddle/fluid/operators/.flatten_op.cc.swp and /dev/null differ diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index a3bec3da45136bca5cb2763e7ffd6b67703a1813..578ab63bc380ee62d76e34b7cf3cbd590bfa2eda 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -28,23 +28,26 @@ class CrossEntropyOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto label_dims = ctx->GetInputDim("Label"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); - PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, - "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(rank, label_dims.size(), + "Input(X) and Input(Label) shall have the same rank."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); if (ctx->Attrs().Get("soft_label")) { - PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1], - "If Attr(soft_label) == true, the 2nd dimension of " + PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], + "If Attr(soft_label) == true, the last dimension of " "Input(X) and Input(Label) should be equal."); } else { - PADDLE_ENFORCE_EQ(label_dims[1], 1UL, - "If Attr(softLabel) == false, the 2nd dimension of " + PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL, + "If Attr(softLabel) == false, the last dimension of " "Input(Label) should be 1."); } - ctx->SetOutputDim("Y", {x_dims[0], 1}); + auto y_dims = x_dims; + y_dims[rank - 1] = 1; + ctx->SetOutputDim("Y", y_dims); ctx->ShareLoD("X", /*->*/ "Y"); } @@ -74,24 +77,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto label_dims = ctx->GetInputDim("Label"); auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); - PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); - PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0], - "The 1st dimension of Input(X) and Input(Label) should " - "be equal."); - PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], - "The 1st dimension of Input(X) and Input(Y@Grad) should " - "be equal."); - PADDLE_ENFORCE_EQ(dy_dims[1], 1, - "The 2nd dimension of Input(Y@Grad) should be 1."); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(dy_dims.size(), rank, + "Input(Y@Grad) and Input(X) should have the same rank."); + PADDLE_ENFORCE_EQ(label_dims.size(), rank, + "Input(Label) and Input(X) should have the same 
rank."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "The Input(X) and Input(Label) should have the same " + "shape except the last dimension."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(dy_dims, 0, rank - 1), + "The Input(X) and Input(Y@Grad) should have the same " + "shape except the last dimension."); + PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, + "The last dimension of Input(Y@Grad) should be 1."); if (ctx->Attrs().Get("soft_label")) { - PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1], - "When Attr(soft_label) == true, the 2nd dimension of " + PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], + "When Attr(soft_label) == true, the last dimension of " "Input(X) and Input(Label) should be equal."); } else { - PADDLE_ENFORCE_EQ(label_dims[1], 1, - "When Attr(soft_label) == false, the 2nd dimension of " + PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1, + "When Attr(soft_label) == false, the last dimension of " "Input(Label) should be 1."); } ctx->SetOutputDim(framework::GradVarName("X"), x_dims); @@ -113,18 +120,20 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor, default Tensor), a 2-D tensor with shape [N x D]," - " where N is the batch size and D is the number of classes. " - "This input is a probability computed by the previous operator, " - "which is almost always the result of a softmax operator."); - AddInput("Label", - "(Tensor), the ground truth which is a 2-D tensor. When " - "soft_label is set to false, Label is a Tensor with shape " - "[N x 1]. When soft_label is set to true, Label is a " - "Tensor with shape [N x D]."); + "(Tensor, default Tensor), a tensor whose last dimension " + "size is equal to the number of classes. This input is a " + "probability computed by the previous operator, which is almost " + "always the result of a softmax operator."); + AddInput( + "Label", + "(Tensor), the tensor which represents the ground truth. It has the " + "same shape with 'X' except the last dimension. When soft_label is set " + "to false, the last dimension size is 1; when soft_label is set to " + "true, the last dimension size is equal to the number of classes."); AddOutput("Y", - "(Tensor, default Tensor), a 2-D tensor with shape " - "[N x 1]. The cross entropy loss."); + "(Tensor, default Tensor), a tensor whose shape is same " + "with 'X' except that the last dimension size is 1. It " + "represents the cross entropy loss."); AddAttr("soft_label", "(bool, default false), a flag indicating whether to " "interpretate the given labels as soft labels.") @@ -132,6 +141,12 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( CrossEntropy Operator. +The input 'X' and 'Label' will first be logically flattened to 2-D matrixs. +The matrix's second dimension(row length) is as same as the original last +dimension, and the first dimension(column length) is the product of all other +original dimensions. Then the softmax computation will take palce on each raw +of flattened matrixs. + It supports both standard cross-entropy and soft-label cross-entropy loss computation. 
1) One-hot cross-entropy: diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h index 19a2aec92b267ece94685ce34604b7d1cfa5d209..36b58d80144d242277f6fc970a3a61a6721d4b50 100644 --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -33,8 +33,13 @@ class CrossEntropyOpKernel : public framework::OpKernel { auto* y = ctx.Output("Y"); y->mutable_data(ctx.GetPlace()); + int rank = x->dims().size(); + Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1); + Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1); + Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1); + math::CrossEntropyFunctor()( - ctx.template device_context(), y, x, labels, + ctx.template device_context(), &y_2d, &x_2d, &labels_2d, ctx.Attr("soft_label")); } }; @@ -98,9 +103,12 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { auto* dy = ctx.Input(framework::GradVarName("Y")); auto* label = ctx.Input("Label"); auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dx_data = dx->mutable_data(ctx.GetPlace()); + T* dx_data = dx->mutable_data(ctx.GetPlace()); - int64_t class_num = x->dims()[1]; + // Following computation only depends on the last dimension size. So it's + // unnecessary to convert tensors to 2-D views. + int rank = x->dims().size(); + int64_t class_num = x->dims()[rank - 1]; if (ctx.Attr("soft_label")) { XeSoftlabelGradFunctor functor(dx_data, dy->data(), x->data(), label->data(), diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index b44d5f898013a5d27467bd80118c29a886d5e8b3..1be9fe47af71d31ce2e0eba807ea4a43601f8aca 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -38,7 +38,7 @@ class ShapeOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Input", "(Tensor), The input tensor."); AddOutput("Out", "(Tensor), The shape of input tensor, the data type of the shape" - " is int64_t, will be on the same device with the input Tensor."); + " is int32_t, will be on the same device with the input Tensor."); AddComment(R"DOC( Shape Operator @@ -53,5 +53,5 @@ Get the shape of input tensor. Only support CPU input Tensor now. namespace ops = paddle::operators; REGISTER_OPERATOR(shape, ops::ShapeOp, ops::ShapeOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel, ops::ShapeKernel, +REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel, ops::ShapeKernel); diff --git a/paddle/fluid/operators/shape_op.cu b/paddle/fluid/operators/shape_op.cu index 7736a2a1e13cfa5d445411b3efac7669a7bf23a2..d8fa9515abf807ab4ae3c47e8e1b1cf7f30440a8 100644 --- a/paddle/fluid/operators/shape_op.cu +++ b/paddle/fluid/operators/shape_op.cu @@ -15,6 +15,6 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/shape_op.h" REGISTER_OP_CUDA_KERNEL(shape, paddle::operators::ShapeKernel, - paddle::operators::ShapeKernel, + paddle::operators::ShapeKernel, paddle::operators::ShapeKernel, paddle::operators::ShapeKernel); diff --git a/paddle/fluid/operators/shape_op.h b/paddle/fluid/operators/shape_op.h index 3be86b66a538e7b38a5d59095fee7e7636364bce..0d510a505583c55e26a26bfc6e5d6192899b3d9e 100644 --- a/paddle/fluid/operators/shape_op.h +++ b/paddle/fluid/operators/shape_op.h @@ -27,7 +27,7 @@ class ShapeKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* in_t = ctx.Input("Input"); auto* out_t = ctx.Output("Out"); - auto out_data = out_t->mutable_data(platform::CPUPlace()); + auto out_data = out_t->mutable_data(platform::CPUPlace()); auto in_dims = in_t->dims(); for (int i = 0; i < in_dims.size(); ++i) { out_data[i] = in_dims[i]; diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h index 1205bd0587f32caae04c27ecea581fc17988507f..cf1eeb017d666f605a431aa54637d8cbc99c7c46 100644 --- a/paddle/fluid/operators/softmax_op.h +++ b/paddle/fluid/operators/softmax_op.h @@ -31,16 +31,12 @@ class SoftmaxKernel : public framework::OpKernel { // allocate memory on device. Out->mutable_data(context.GetPlace()); - auto dims = X->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - framework::LoDTensor flattened_x; - framework::LoDTensor flattened_out; - flattened_x.ShareDataWith(*X).Resize(flattened_dims); - flattened_out.ShareDataWith(*Out).Resize(flattened_dims); + int rank = X->dims().size(); + Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1); + Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); math::SoftmaxFunctor()( - context.template device_context(), &flattened_x, - &flattened_out); + context.template device_context(), &X_2d, &Out_2d); } }; @@ -55,18 +51,14 @@ class SoftmaxGradKernel : public framework::OpKernel { // allocate memory on device. dX->mutable_data(context.GetPlace()); - auto dims = Out->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - framework::LoDTensor flattened_out; - framework::LoDTensor flattened_d_out; - framework::LoDTensor flattened_d_x; - flattened_out.ShareDataWith(*Out).Resize(flattened_dims); - flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims); - flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims); + int rank = Out->dims().size(); + Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); + Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1); + Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1); math::SoftmaxGradFunctor()( - context.template device_context(), &flattened_out, - &flattened_d_out, &flattened_d_x); + context.template device_context(), &Out_2d, &dOut_2d, + &dX_2d); } }; diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 9f8cb6ef0bb7f8dfc05e5647330d0644979a053f..abe479693c16d2d7de5e1e7f4a2ddbdf5ac748e1 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -20,9 +20,11 @@ from .layer_function_generator import autodoc, templatedoc from ..layer_helper import LayerHelper from . import tensor from . import nn +from . import ops from ... 
import compat as cpt import math import six +import numpy from functools import reduce __all__ = [ @@ -266,10 +268,11 @@ def detection_output(loc, prior_box_var=prior_box_var, target_box=loc, code_type='decode_center_size') - old_shape = scores.shape - scores = nn.reshape(x=scores, shape=(-1, old_shape[-1])) + compile_shape = scores.shape + run_shape = ops.shape(scores) + scores = nn.flatten(x=scores, axis=2) scores = nn.softmax(input=scores) - scores = nn.reshape(x=scores, shape=old_shape) + scores = nn.reshape(x=scores, shape=compile_shape, actual_shape=run_shape) scores = nn.transpose(scores, perm=[0, 2, 1]) scores.stop_gradient = True nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype) @@ -679,9 +682,10 @@ def ssd_loss(location, raise ValueError("Only support mining_type == max_negative now.") num, num_prior, num_class = confidence.shape + conf_shape = ops.shape(confidence) def __reshape_to_2d(var): - return nn.reshape(x=var, shape=[-1, var.shape[-1]]) + return nn.flatten(x=var, axis=2) # 1. Find matched boundding box by prior box. # 1.1 Compute IOU similarity between ground-truth boxes and prior boxes. @@ -692,7 +696,8 @@ def ssd_loss(location, # 2. Compute confidence for mining hard examples # 2.1. Get the target label based on matched indices - gt_label = nn.reshape(x=gt_label, shape=gt_label.shape + (1, )) + gt_label = nn.reshape( + x=gt_label, shape=(len(gt_label.shape) - 1) * (0, ) + (-1, 1)) gt_label.stop_gradient = True target_label, _ = target_assign( gt_label, matched_indices, mismatch_value=background_label) @@ -703,9 +708,12 @@ def ssd_loss(location, target_label = __reshape_to_2d(target_label) target_label.stop_gradient = True conf_loss = nn.softmax_with_cross_entropy(confidence, target_label) - # 3. Mining hard examples - conf_loss = nn.reshape(x=conf_loss, shape=(num, num_prior)) + conf_loss = nn.reshape( + x=conf_loss, + shape=(num, num_prior), + actual_shape=ops.slice( + conf_shape, axes=[0], starts=[0], ends=[2])) conf_loss.stop_gradient = True neg_indices = helper.create_tmp_variable(dtype='int32') dtype = matched_indices.dtype @@ -774,7 +782,11 @@ def ssd_loss(location, # 5.3 Compute overall weighted loss. loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss # reshape to [N, Np], N is the batch size and Np is the prior box number. 
- loss = nn.reshape(x=loss, shape=[-1, num_prior]) + loss = nn.reshape( + x=loss, + shape=(num, num_prior), + actual_shape=ops.slice( + conf_shape, axes=[0], starts=[0], ends=[2])) loss = nn.reduce_sum(loss, dim=1, keep_dim=True) if normalize: normalizer = nn.reduce_sum(target_loc_weight) @@ -1007,13 +1019,7 @@ def multi_box_head(inputs, """ def _reshape_with_axis_(input, axis=1): - if not (axis > 0 and axis < len(input.shape)): - raise ValueError("The axis should be smaller than " - "the arity of input and bigger than 0.") - new_shape = [ - -1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)]) - ] - out = nn.reshape(x=input, shape=new_shape) + out = nn.flatten(x=input, axis=axis) return out def _is_list_or_tuple_(data): @@ -1103,11 +1109,13 @@ def multi_box_head(inputs, stride=stride) mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1]) - new_shape = [ - mbox_loc.shape[0], mbox_loc.shape[1] * mbox_loc.shape[2] * - cpt.floor_division(mbox_loc.shape[3], 4), 4 + compile_shape = [ + mbox_loc.shape[0], cpt.floor_division( + mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3], 4), 4 ] - mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape) + run_shape = tensor.assign(numpy.array([0, -1, 4]).astype("int32")) + mbox_loc_flatten = nn.reshape( + mbox_loc, shape=compile_shape, actual_shape=run_shape) mbox_locs.append(mbox_loc_flatten) # get conf @@ -1119,11 +1127,16 @@ def multi_box_head(inputs, padding=pad, stride=stride) conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1]) - new_shape = [ - conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] * - cpt.floor_division(conf_loc.shape[3], num_classes), num_classes + new_shape = [0, -1, num_classes] + compile_shape = [ + conf_loc.shape[0], + cpt.floor_division(conf_loc.shape[1] * conf_loc.shape[2] * + conf_loc.shape[3], num_classes), num_classes ] - conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape) + run_shape = tensor.assign( + numpy.array([0, -1, num_classes]).astype("int32")) + conf_loc_flatten = nn.reshape( + conf_loc, shape=compile_shape, actual_shape=run_shape) mbox_confs.append(conf_loc_flatten) if len(box_results) == 1: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 62e0f3876c6f64696bdbbe59776a91fa2d6405f4..aed09914bb8afe16f72d2cd03603251e1d0bab64 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -112,6 +112,7 @@ __all__ = [ 'log', 'crop', 'rank_loss', + 'flatten', ] @@ -5361,3 +5362,70 @@ def rank_loss(label, left, right, name=None): "Right": right}, outputs={'Out': out}) return out + + +def flatten(x, axis=1, name=None): + """ + **Flatten layer** + Flattens the input tensor into a 2D matrix. + + Examples: + Case 1: + Given + X.shape = (3, 100, 100, 4) + and + axis = 2 + We get: + Out.shape = (3 * 100, 4 * 100) + + Case 2: + Given + X.shape = (3, 100, 100, 4) + and + axis = 0 + We get: + Out.shape = (1, 3 * 100 * 100 * 4) + + Args: + x (Variable): A tensor of rank >= axis. + axis (int): Indicate up to which input dimensions (exclusive) should + be flattened to the outer dimension of the output. + The value for axis must be in the range [0, R], where R + is the rank of the input tensor. When axis = 0, the shape + of the output tensor is (1, (d_0 X d_1 ... d_n)), where the + shape of the input tensor is (d_0, d_1, ... d_n). + name(str|None): A name for this layer (optional). If set None, the layer + will be named automatically.
+ + Returns: + Variable: A 2D tensor with the contents of the input tensor, with input + dimensions up to axis flattened to the outer dimension of + the output and remaining input dimensions flattened into the + inner dimension of the output. + + Raises: + ValueError: If x is not a variable. + ValueError: If axis is not in range [0, rank(x)]. + + Examples: + + .. code-block:: python + + x = fluid.layers.data(name="x", shape=[4, 4, 3], dtype="float32") + out = fluid.layers.flatten(x=x, axis=2) + """ + helper = LayerHelper('flatten', **locals()) + + if not (isinstance(x, Variable)): + raise ValueError("The input x should be a Variable") + + if not (isinstance(axis, int)) or axis > len(x.shape) or axis < 0: + raise ValueError("The axis should be a int, and in range [0, rank(x)]") + + out = helper.create_tmp_variable(x.dtype) + helper.append_op( + type='flatten', + inputs={"X": x}, + outputs={'Out': out}, + attrs={"axis": axis}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py index c5b9e92d69133e593a2ce223e83006eda590daa5..86ac159323a5f9f6149ce5ed4437402eb885c6bc 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py @@ -105,5 +105,107 @@ class TestCrossEntropyOp3(OpTest): ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001) +class TestCrossEntropyOp4(OpTest): + """Test high rank tensor cross-entropy with discrete one-hot labels. + """ + + def setUp(self): + self.op_type = "cross_entropy" + shape = [10, 2, 4] + ins_num = np.prod(np.array(shape)) + class_num = 10 + + X_2d = randomize_probability(ins_num, class_num, dtype='float64') + + label_2d = np.random.randint(0, class_num, (ins_num, 1), dtype="int64") + cross_entropy_2d = np.asmatrix( + [[-np.log(X_2d[i][label_2d[i][0]])] for i in range(X_2d.shape[0])], + dtype="float64") + + X = X_2d.reshape(shape + [class_num]) + label = label_2d.reshape(shape + [1]) + cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1]) + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"soft_label": False} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Y", numeric_grad_delta=0.001) + + +class TestCrossEntropyOp5(OpTest): + """Test high rank tensor cross-entropy with vectorized soft labels. + """ + + def setUp(self): + self.op_type = "cross_entropy" + shape = [4, 3] + ins_num = np.prod(np.array(shape)) + class_num = 37 + + X_2d = randomize_probability(ins_num, class_num) + label_2d = np.random.uniform(0.1, 1.0, + [ins_num, class_num]).astype("float32") + label_2d /= label_2d.sum(axis=1, keepdims=True) + cross_entropy_2d = (-label_2d * np.log(X_2d)).sum( + axis=1, keepdims=True).astype("float32") + + X = X_2d.reshape(shape + [class_num]) + label = label_2d.reshape(shape + [class_num]) + cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1]) + + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"soft_label": True} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001) + + +class TestCrossEntropyOp6(OpTest): + """Test high rank tensor cross-entropy with vectorized one-hot representation of labels. 
+ """ + + def setUp(self): + self.op_type = "cross_entropy" + shape = [4, 3, 2] + ins_num = np.prod(np.array(shape)) + class_num = 17 + + X_2d = randomize_probability(ins_num, class_num) + label_index_2d = np.random.randint( + 0, class_num, (ins_num), dtype="int32") + label_2d = np.zeros(X_2d.shape) + label_2d[np.arange(ins_num), label_index_2d] = 1 + + cross_entropy_2d = np.asmatrix( + [[-np.log(X_2d[i][label_index_2d[i]])] + for i in range(X_2d.shape[0])], + dtype="float32") + + X = X_2d.reshape(shape + [class_num]) + label = label_2d.reshape(shape + [class_num]) + cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1]) + + self.inputs = {"X": X, "Label": label.astype(np.float32)} + self.outputs = {"Y": cross_entropy} + self.attrs = {"soft_label": True} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index aae5a24f6ca69606153e5814cc8d632ce732bc8e..63234afc2018610190dd6fb2b47b837ad3d2d0d3 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -465,6 +465,17 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_flatten(self): + program = Program() + with program_guard(program): + x = layers.data( + name='x', + append_batch_size=False, + shape=[4, 4, 3], + dtype="float32") + out = layers.flatten(x, axis=1, name="flatten") + self.assertIsNotNone(out) + def test_shape(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 6d94d5745a44fe320cddaab30bc6ed8073c4126e..16db9f7b97bcc70f839d25cec4e5ad1c99aae0da 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -532,7 +532,10 @@ class DistributeTranspiler(object): pserver_program._sync_with_cpp() return pserver_program - def get_startup_program(self, endpoint, pserver_program): + def get_startup_program(self, + endpoint, + pserver_program, + startup_program=None): """ Get startup program for current parameter server. Modify operator input variables if there are variables that @@ -542,12 +545,17 @@ class DistributeTranspiler(object): endpoint (str): current pserver endpoint. pserver_program (Program): call get_pserver_program first and pass the result here. + startup_program (Program): if pass None, will use + default_startup_program Returns: Program: parameter server side startup program. """ s_prog = Program() - orig_s_prog = default_startup_program() + if not startup_program: + orig_s_prog = default_startup_program() + else: + orig_s_prog = startup_program s_prog.random_seed = orig_s_prog.random_seed params = self.param_grad_ep_mapping[endpoint]["params"]