diff --git a/paddle/fluid/framework/details/graph_test_base.h b/paddle/fluid/framework/details/graph_test_base.h
index 2fae68451610aba7dc063cb6be2f3cff7fb5e1c1..d139f8488309eecf89c924a346ab0e574edc86dc 100644
--- a/paddle/fluid/framework/details/graph_test_base.h
+++ b/paddle/fluid/framework/details/graph_test_base.h
@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
 
 class DummyVarTypeInference : public VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override {
-    auto& inputs = ctx.Input("X");
-    auto type = ctx.GetType(inputs.front());
-    auto out_var_name = ctx.Output("Out").front();
-    ctx.SetType(out_var_name, type);
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto& inputs = ctx->Input("X");
+    auto type = ctx->GetType(inputs.front());
+    auto out_var_name = ctx->Output("Out").front();
+    ctx->SetType(out_var_name, type);
   }
 };
 
diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h
index 420d4da8d5143f161ac730c316102267e4e2e8e7..e13ff99f3fdb564141531b401565c932fa1f3dab 100644
--- a/paddle/fluid/framework/details/op_registry.h
+++ b/paddle/fluid/framework/details/op_registry.h
@@ -131,7 +131,7 @@ struct OpInfoFiller<T, kVarTypeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
     info->infer_var_type_ = [](InferVarTypeContext* context) {
       T inference;
-      inference(*context);
+      inference(context);
     };
   }
 };
diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc
index 851c1b80a85b534e648dd8e6417fb3a66ad3e12d..a95588a57b434763fb0f01e33528ef15fd1aa42b 100644
--- a/paddle/fluid/framework/ir/graph_test.cc
+++ b/paddle/fluid/framework/ir/graph_test.cc
@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(InferVarTypeContext &ctx) const override {
-    auto &inputs = ctx.Input("X");
+  void operator()(InferVarTypeContext *ctx) const override {
+    auto &inputs = ctx->Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;
 
     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
-          return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
+          return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
 
-    auto out_var_name = ctx.Output("Out").front();
-    ctx.SetType(out_var_name, default_var_type);
+    auto out_var_name = ctx->Output("Out").front();
+    ctx->SetType(out_var_name, default_var_type);
   }
 };
 
@@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
 
 class DummyOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {}
+  void operator()(framework::InferVarTypeContext *ctx) const override {}
 };
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/var_type_inference.h b/paddle/fluid/framework/var_type_inference.h
index 5dd08442c20322cef9a4b17f9e6cae98fa3e8066..2e9c64d3e6854bf70c0aee06128b9f1b7c8c7439 100644
--- a/paddle/fluid/framework/var_type_inference.h
+++ b/paddle/fluid/framework/var_type_inference.h
@@ -126,20 +126,20 @@ class InferVarTypeContext {
 
 class VarTypeInference {
  public:
   virtual ~VarTypeInference() {}
-  virtual void operator()(InferVarTypeContext& context) const = 0;  // NOLINT
+  virtual void operator()(InferVarTypeContext* context) const = 0;  // NOLINT
 };
 
 class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const final {  // NOLINT
+  void operator()(framework::InferVarTypeContext* ctx) const final {  // NOLINT
     auto in_out_var_names = this->GetInputOutputWithSameType();
 
     for (auto& i_o_n : in_out_var_names) {
-      auto& x_name = ctx.Input(i_o_n.first).at(0);
-      auto& out_name = ctx.Output(i_o_n.second).at(0);
+      auto& x_name = ctx->Input(i_o_n.first).at(0);
+      auto& out_name = ctx->Output(i_o_n.second).at(0);
 
-      ctx.SetType(out_name, ctx.GetType(x_name));
-      ctx.SetDataType(out_name, ctx.GetDataType(x_name));
+      ctx->SetType(out_name, ctx->GetType(x_name));
+      ctx->SetDataType(out_name, ctx->GetDataType(x_name));
     }
   }
diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/paddle/fluid/framework/var_type_inference_test.cc
index 60e1d610daf12948dd0d864dcd16d1c9d8990aa3..6bbb25a573d076d5ec6d6fd960a304639e9e3d49 100644
--- a/paddle/fluid/framework/var_type_inference_test.cc
+++ b/paddle/fluid/framework/var_type_inference_test.cc
@@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    auto &inputs = ctx.Input("X");
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto &inputs = ctx->Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;
 
     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
-          return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
+          return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
 
-    auto out_var_name = ctx.Output("Out").front();
-    ctx.SetType(out_var_name, default_var_type);
+    auto out_var_name = ctx->Output("Out").front();
+    ctx->SetType(out_var_name, default_var_type);
   }
 };
 }  // namespace framework
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 0f7a2415372d938885235b35b00cd683519c797b..18bd1d19383cd4aa6336f38bcc0f64af1d97bfb7 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -161,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
 }
 
 std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
-                                    VarBasePtrMap& outputs,
+                                    VarBasePtrMap* outputs,
                                     framework::AttributeMap attrs_map,
                                     const platform::Place expected_place,
                                     const bool stop_gradient) {
@@ -195,7 +195,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
     }
   }
 
-  op->output_vars_ = outputs;
+  op->output_vars_ = *outputs;
   for (auto it : op->output_vars_) {
     auto& outvars = outvars_map[it.first];
     const std::vector<VarBase*>& outputs = it.second;
@@ -218,7 +218,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   framework::VariableNameMap invars_name_map =
       CreateInputVarNameMap(op, inputs);
   framework::VariableNameMap outvars_name_map =
-      CreateOutputVarNameMap(op, outputs);
+      CreateOutputVarNameMap(op, *outputs);
 
   auto& info = framework::OpInfoMap::Instance().Get(op->Type());
   if (info.Checker() != nullptr) {
@@ -230,8 +230,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                                                  outvars_name_map, attrs_map);
 
   if (info.infer_var_type_) {
-    RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, &outputs,
-                                                  &attrs_map);
+    RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, outputs, &attrs_map);
     info.infer_var_type_(&infer_var_type_ctx);
   }
 
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index ae3b16727dbbacf928dacdf97d983b4f6687c8bb..a87f3b8009dd552626c6c03fba3b0bbf3a78bb83 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -48,7 +48,7 @@ class Tracer {
   virtual ~Tracer() {}
 
   std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
-                              VarBasePtrMap& outputs,  // NOLINT
+                              VarBasePtrMap* outputs,  // NOLINT
                               framework::AttributeMap attrs_map,
                               const platform::Place expected_place,
                               const bool stop_gradient = false);
diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc
index 703edcad1183389f1986b3d045e9e1222f78d51f..4cef49280dfb5207a9d94df42d94657f03ec838f 100644
--- a/paddle/fluid/operators/beam_search_decode_op.cc
+++ b/paddle/fluid/operators/beam_search_decode_op.cc
@@ -203,12 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 
 class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override {
-    for (auto& o : ctx.Output("SentenceIds")) {
-      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    for (auto& o : ctx->Output("SentenceIds")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
-    for (auto& o : ctx.Output("SentenceScores")) {
-      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
+    for (auto& o : ctx->Output("SentenceScores")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc
index 8958d00a6870af52c60f90125dc7871a1810a4cd..a6aa35e0569364d79c15aea6e6dbc6ca670d49f0 100644
--- a/paddle/fluid/operators/beam_search_op.cc
+++ b/paddle/fluid/operators/beam_search_op.cc
@@ -120,12 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel {
 
 class BeamSearchInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    for (auto &o : ctx.Output("selected_ids")) {
-      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &o : ctx->Output("selected_ids")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
-    for (auto &o : ctx.Output("selected_scores")) {
-      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
+    for (auto &o : ctx->Output("selected_scores")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
diff --git a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
index 041eef602ecd77538597ac3dd178f013601612e8..45f18ac9255bdd75d8cbb5e1dd30ebba52260850 100644
--- a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
+++ b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
@@ -100,13 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
 
 class WriteToArrayInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    auto x_name = ctx.Input("X")[0];
-    auto out_name = ctx.Output("Out")[0];
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto x_name = ctx->Input("X")[0];
+    auto out_name = ctx->Output("Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
out_name << " as LOD_TENSOR_ARRAY"; - ctx.SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY); - if (ctx.HasVar(x_name)) { - ctx.SetDataType(out_name, ctx.GetDataType(x_name)); + ctx->SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY); + if (ctx->HasVar(x_name)) { + ctx->SetDataType(out_name, ctx->GetDataType(x_name)); } } }; diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 90c30678680b2dc741f68e335367dd2cbe50412f..deb8ec3bb2d5682e8733365fb865daebbf8405e0 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -365,16 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { class WhileGradOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(framework::InferVarTypeContext &ctx) const override { - auto p_names = ctx.Input(kX); - auto pg_ig_names = ctx.Output(framework::GradVarName(kX)); + void operator()(framework::InferVarTypeContext *ctx) const override { + auto p_names = ctx->Input(kX); + auto pg_ig_names = ctx->Output(framework::GradVarName(kX)); for (size_t i = 0; i < p_names.size(); ++i) { - if (ctx.HasVar(pg_ig_names[i])) { + if (ctx->HasVar(pg_ig_names[i])) { VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i] - << " type: " << ctx.GetType(p_names[i]); - ctx.SetType(pg_ig_names[i], ctx.GetType(p_names[i])); - ctx.SetDataType(pg_ig_names[i], ctx.GetDataType(p_names[i])); + << " type: " << ctx->GetType(p_names[i]); + ctx->SetType(pg_ig_names[i], ctx->GetType(p_names[i])); + ctx->SetDataType(pg_ig_names[i], ctx->GetDataType(p_names[i])); } } } diff --git a/paddle/fluid/operators/distributed_ops/fake_init_op.cc b/paddle/fluid/operators/distributed_ops/fake_init_op.cc index 89228c72432701cbbddf7574d40f823b653d19f3..5ee35e0458a64dacc1c469a435edd28de1b78e6b 100644 --- a/paddle/fluid/operators/distributed_ops/fake_init_op.cc +++ b/paddle/fluid/operators/distributed_ops/fake_init_op.cc @@ -56,7 +56,7 @@ class FakeInitOp : public framework::OperatorBase { class FakeInitOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(framework::InferVarTypeContext &ctx) const override {} + void operator()(framework::InferVarTypeContext *ctx) const override {} }; class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc index 0a269c7575cfade16fd3cb9b395f3d11a16b30a6..1b0b4dd31693340bc39c0da8995a2a2d40b13e00 100644 --- a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc @@ -114,10 +114,10 @@ class MergeIdsOp : public framework::OperatorWithKernel { class MergeIdsOpInferVarType : public framework::VarTypeInference { public: - void operator()(framework::InferVarTypeContext &ctx) const override { - auto input_type = ctx.GetType(ctx.Input("Ids")[0]); - for (auto &out_var : ctx.Output("Out")) { - ctx.SetType(out_var, input_type); + void operator()(framework::InferVarTypeContext *ctx) const override { + auto input_type = ctx->GetType(ctx->Input("Ids")[0]); + for (auto &out_var : ctx->Output("Out")) { + ctx->SetType(out_var, input_type); } } }; diff --git a/paddle/fluid/operators/distributed_ops/split_ids_op.cc b/paddle/fluid/operators/distributed_ops/split_ids_op.cc index e9f3f89c6e2e4fa24aa65cebbb6207cedf8df088..191ca1efe8ca5798ddbd38968eafde349af8a7d1 100644 --- 
+++ b/paddle/fluid/operators/distributed_ops/split_ids_op.cc
@@ -73,10 +73,10 @@ class SplitIdsOp : public framework::OperatorWithKernel {
 
 class SplitIdsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    auto input_type = ctx.GetType(ctx.Input("Ids")[0]);
-    for (auto &out_var : ctx.Output("Out")) {
-      ctx.SetType(out_var, input_type);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
+    for (auto &out_var : ctx->Output("Out")) {
+      ctx->SetType(out_var, input_type);
     }
   }
 };
diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc
index eb5996d50ebaedddee5c5a7973a4cf34842c64f6..cf2f4776cf2ae9a707d3b841c2a41b7f82ca7833 100644
--- a/paddle/fluid/operators/fill_constant_op.cc
+++ b/paddle/fluid/operators/fill_constant_op.cc
@@ -39,11 +39,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
 
 class FillConstantOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override {
+  void operator()(framework::InferVarTypeContext* ctx) const override {
     auto data_type = static_cast<framework::proto::VarType::Type>(
-        boost::get<int>(ctx.GetAttr("dtype")));
-    auto& out_var_name = ctx.Output("Out").front();
-    ctx.SetDataType(out_var_name, data_type);
+        boost::get<int>(ctx->GetAttr("dtype")));
+    auto& out_var_name = ctx->Output("Out").front();
+    ctx->SetDataType(out_var_name, data_type);
   }
 };
diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
index 27a761c29f68a142b209dc2bd8a716e8b4ce4fdb..5edeeae14e96fe383fa3ebc026e25ddd9ade1ef3 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
@@ -137,20 +137,20 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
 
 class FusedEmbeddingSeqPoolOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override {
-    auto out_var_name = ctx.Output(framework::GradVarName("W")).front();
-    auto attr = ctx.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
+    auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
              << framework::GradVarName("W") << " is set to SelectedRows";
-      ctx.SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
              << framework::GradVarName("W") << " is set to LoDTensor";
-      ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    ctx.SetDataType(out_var_name, ctx.GetDataType(ctx.Input("W")[0]));
+    ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
   }
 };
diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
index 5388e65497ff84787c6b8cd0653188cddeb31e81..c0893359af2f4de4ed8fd88ebff122447e8d84c7 100644
--- a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
+++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc
@@ -81,12 +81,12 @@ GetTensorFromSelectedRows is used to get the tensor from SelectedRows.
 
 class GetTensorFromSelectedRowsOpVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const {  // NOLINT
-    auto out_var_name = ctx.Output("Out").front();
-    auto in_var_name = ctx.Input("X").front();
+  void operator()(framework::InferVarTypeContext *ctx) const {  // NOLINT
+    auto out_var_name = ctx->Output("Out").front();
+    auto in_var_name = ctx->Input("X").front();
 
-    ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
-    ctx.SetDataType(out_var_name, ctx.GetDataType(in_var_name));
+    ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+    ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
   }
 };
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
index 508c99b953b5dda171d3e657f564ba862a30c83a..d0e1057c4357e372d3ab396841de7b2d0577d365 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -197,32 +197,32 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
 
 class HierarchicalSigmoidGradOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override {
-    auto w_grad_var_name = ctx.Output(framework::GradVarName("W")).front();
-    auto bias_grad_var_name_vec = ctx.Output(framework::GradVarName("Bias"));
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto w_grad_var_name = ctx->Output(framework::GradVarName("W")).front();
+    auto bias_grad_var_name_vec = ctx->Output(framework::GradVarName("Bias"));
     std::string bias_grad_var_name;
     bool hasBias = false;
     if (bias_grad_var_name_vec.size()) {
       hasBias = true;
-      bias_grad_var_name = ctx.Output(framework::GradVarName("Bias")).front();
+      bias_grad_var_name = ctx->Output(framework::GradVarName("Bias")).front();
     }
-    auto attr = ctx.GetAttr("is_sparse");
+    auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(30) << "hierarchical_sigmoid_grad op "
               << framework::GradVarName("W") << " is set to SelectedRows";
-      ctx.SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(30) << "hierarchical_sigmoid_grad op "
               << framework::GradVarName("W") << " is set to LoDTensor";
-      ctx.SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
     }
     if (hasBias) {
       VLOG(30) << "hierarchical_sigmoid_grad op "
               << framework::GradVarName("Bias") << " is set to LoDTensor";
-      ctx.SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    ctx.SetDataType(w_grad_var_name, ctx.GetDataType(ctx.Input("W")[0]));
+    ctx->SetDataType(w_grad_var_name, ctx->GetDataType(ctx->Input("W")[0]));
   }
 };
diff --git a/paddle/fluid/operators/lod_rank_table_op.cc b/paddle/fluid/operators/lod_rank_table_op.cc
index a7bbb49827ffcd4b3df13f35ed0cc749884def9c..0a43ac0c52f9bc98eacf743480166682482cc3c0 100644
--- a/paddle/fluid/operators/lod_rank_table_op.cc
+++ b/paddle/fluid/operators/lod_rank_table_op.cc
@@ -64,9 +64,9 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
 
 class LoDRankTableInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    for (auto &o : ctx.Output("Out")) {
-      ctx.SetType(o, framework::proto::VarType::LOD_RANK_TABLE);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &o : ctx->Output("Out")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_RANK_TABLE);
     }
   }
 };
diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc
index 4fd45db67bf0808976398cab30ec7a6019e48daa..61e342737045616112d51b7753939286a31dc6cd 100644
--- a/paddle/fluid/operators/lod_tensor_to_array_op.cc
+++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc
@@ -201,9 +201,9 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
 
 class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    for (auto &out_var : ctx.Output("Out")) {
-      ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &out_var : ctx->Output("Out")) {
+      ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
     }
   }
 };
diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc
index a59ff23f932a4d7df63788baefbcc5b38d484263..8d1ebe6b1ce3374d21f0cdfff21ca27929398e8e 100644
--- a/paddle/fluid/operators/lookup_table_op.cc
+++ b/paddle/fluid/operators/lookup_table_op.cc
@@ -147,20 +147,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
 
 class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override {
-    auto out_var_name = ctx.Output(framework::GradVarName("W")).front();
-    auto attr = ctx.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
+    auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
              << " is set to SelectedRows";
-      ctx.SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
             << " is set to LoDTensor";
-      ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    ctx.SetDataType(out_var_name, ctx.GetDataType(ctx.Input("W")[0]));
+    ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
   }
 };
diff --git a/paddle/fluid/operators/nccl/nccl_op.cc b/paddle/fluid/operators/nccl/nccl_op.cc
index 7df5a881f5739fc355b9a7bf90a8d999ccb8a85f..6a0ae0dede695d80508bcc92a7a13ae9f73c3c57 100644
--- a/paddle/fluid/operators/nccl/nccl_op.cc
+++ b/paddle/fluid/operators/nccl/nccl_op.cc
@@ -60,9 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
 
 class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    auto out_var_name = ctx.Output("Communicator").front();
-    ctx.SetType(out_var_name, framework::proto::VarType::RAW);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto out_var_name = ctx->Output("Communicator").front();
+    ctx->SetType(out_var_name, framework::proto::VarType::RAW);
   }
 };
diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc
index 3c3d79cc7b3a74c96f9a7a0b21c35250a0da8d11..fa7cc58c08455457dd129afd130067704ec72c7c 100644
--- a/paddle/fluid/operators/nce_op.cc
+++ b/paddle/fluid/operators/nce_op.cc
@@ -237,21 +237,21 @@ class NCEOpGrad : public framework::OperatorWithKernel {
 
 class NCEOpGradVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    auto weight_grad = ctx.Output(framework::GradVarName("Weight")).front();
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto weight_grad = ctx->Output(framework::GradVarName("Weight")).front();
 
-    auto attr = ctx.GetAttr("is_sparse");
+    auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
              << " is set to SelectedRows";
-      ctx.SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
              << " is set to LoDTensor";
-      ctx.SetType(weight_grad, framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(weight_grad, framework::proto::VarType::LOD_TENSOR);
     }
-    ctx.SetDataType(weight_grad, ctx.GetDataType(ctx.Input("Input")[0]));
+    ctx->SetDataType(weight_grad, ctx->GetDataType(ctx->Input("Input")[0]));
   }
 };
diff --git a/paddle/fluid/operators/ngraph/ngraph_engine_op.cc b/paddle/fluid/operators/ngraph/ngraph_engine_op.cc
index a88ddf33a0540b9e97ae43dc59516868b70e82df..479c95ba08c316be3d1d983ea736fcc505332d6e 100644
--- a/paddle/fluid/operators/ngraph/ngraph_engine_op.cc
+++ b/paddle/fluid/operators/ngraph/ngraph_engine_op.cc
@@ -37,7 +37,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
 
 class NgraphEngineInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {}
+  void operator()(framework::InferVarTypeContext *ctx) const override {}
 };
 
 }  // namespace operators
diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc
index 668fa889acd927a510b8f7112ac6f79447467c04..126b665dd4d9301ae67346afa45a250accfec656 100644
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc
@@ -72,7 +72,7 @@ use L2 regularizers in case of using LARS.
 
 class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {}
+  void operator()(framework::InferVarTypeContext* ctx) const override {}
 };
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/optimizers/momentum_op.cc b/paddle/fluid/operators/optimizers/momentum_op.cc
index 1be423da5b93ba3073509ff4b5fe7358ff0ac21d..7cf218c20f4c8a22aefc8cd8ce8e1cca36dee3bf 100644
--- a/paddle/fluid/operators/optimizers/momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/momentum_op.cc
@@ -21,14 +21,14 @@ using Tensor = framework::Tensor;
 
 class MomentumOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override {
-    auto& input_var = ctx.Input("Param")[0];
-    for (auto& out_var : ctx.Output("ParamOut")) {
-      if (ctx.GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) {
-        ctx.SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
-      } else if (ctx.GetType(input_var) ==
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto& input_var = ctx->Input("Param")[0];
+    for (auto& out_var : ctx->Output("ParamOut")) {
+      if (ctx->GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) {
+        ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
+      } else if (ctx->GetType(input_var) ==
                  framework::proto::VarType::LOD_TENSOR) {
-        ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR);
+        ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR);
       } else {
         PADDLE_THROW(
             "Only support LodTensor and SelectedRows, Unexpected Input Type.");
diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc
index cac3d9b68fd7ea0f4c8ff7326990820665af0a3e..34e99a14ff77cf8aa7d7f58529140f21d864b596 100644
--- a/paddle/fluid/operators/optimizers/sgd_op.cc
+++ b/paddle/fluid/operators/optimizers/sgd_op.cc
@@ -50,18 +50,18 @@ class SGDOp : public framework::OperatorWithKernel {
 
 class SGDOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    auto &input_var_n = ctx.Input("Param")[0];
-    auto in_var_type = ctx.GetType(input_var_n);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto &input_var_n = ctx->Input("Param")[0];
+    auto in_var_type = ctx->GetType(input_var_n);
     PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
                        in_var_type == framework::proto::VarType::LOD_TENSOR,
                    "The input Var's type should be LoDtensor or SelectedRows,"
                    " but the received var(%s)'s type is %s",
                    input_var_n, in_var_type);
 
-    for (auto &out_var_n : ctx.Output("ParamOut")) {
-      if (ctx.GetType(out_var_n) != in_var_type) {
-        ctx.SetType(out_var_n, in_var_type);
+    for (auto &out_var_n : ctx->Output("ParamOut")) {
+      if (ctx->GetType(out_var_n) != in_var_type) {
+        ctx->SetType(out_var_n, in_var_type);
       }
     }
   }
diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc
index 6472b9c16318e030d18b67aa208d56670592c87f..67202c7f9d6ea4718e6b308326271376f9305a92 100644
--- a/paddle/fluid/operators/py_func_op.cc
+++ b/paddle/fluid/operators/py_func_op.cc
@@ -96,10 +96,10 @@ static void CallPythonFunc(py::object *callable,
 
 class PyFuncOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    bool has_out = (ctx.HasOutput("Out") && !ctx.Output("Out").empty());
!ctx.Output("Out").empty()); + void operator()(framework::InferVarTypeContext *ctx) const override { + bool has_out = (ctx->HasOutput("Out") && !ctx->Output("Out").empty()); - bool has_in = (ctx.HasInput("X") && !ctx.Input("Out").empty()); + bool has_in = (ctx->HasInput("X") && !ctx->Input("Out").empty()); /** * X or Out can be empty, so that py_func can be more flexible @@ -107,8 +107,8 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { */ PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist"); - PADDLE_ENFORCE_GE(boost::get(ctx.GetAttr(kForwardPythonCallableId)), 0, - "Function id cannot be less than 0"); + PADDLE_ENFORCE_GE(boost::get(ctx->GetAttr(kForwardPythonCallableId)), + 0, "Function id cannot be less than 0"); if (!has_out) return; @@ -118,7 +118,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { * the corresponding forward variable */ const std::string kGradVarSuffix = framework::kGradVarSuffix; - auto &out_var_names = ctx.Output("Out"); + auto &out_var_names = ctx->Output("Out"); for (auto &out_var_name : out_var_names) { if (out_var_name == framework::kEmptyVarName || out_var_name.size() < kGradVarSuffix.size()) { @@ -128,17 +128,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference { size_t len = out_var_name.size() - kGradVarSuffix.size(); if (out_var_name.substr(len) == kGradVarSuffix) { auto fwd_var_name = out_var_name.substr(0, len); - PADDLE_ENFORCE(ctx.HasVar(out_var_name), + PADDLE_ENFORCE(ctx->HasVar(out_var_name), "Backward variable %s not found", out_var_name); - PADDLE_ENFORCE(ctx.HasVar(fwd_var_name), + PADDLE_ENFORCE(ctx->HasVar(fwd_var_name), "Backward variable %s not found", fwd_var_name); VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input(" << fwd_var_name << ")"; - ctx.SetShape(out_var_name, ctx.GetShape(fwd_var_name)); - ctx.SetDataType(out_var_name, ctx.GetDataType(fwd_var_name)); - ctx.SetLoDLevel(out_var_name, ctx.GetLoDLevel(fwd_var_name)); - ctx.SetType(out_var_name, ctx.GetType(fwd_var_name)); + ctx->SetShape(out_var_name, ctx->GetShape(fwd_var_name)); + ctx->SetDataType(out_var_name, ctx->GetDataType(fwd_var_name)); + ctx->SetLoDLevel(out_var_name, ctx->GetLoDLevel(fwd_var_name)); + ctx->SetType(out_var_name, ctx->GetType(fwd_var_name)); } } } diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index b65e23685681fd4179a4cf223658440c2550f2d4..fdc7b0f6a0e8de232865adb70677af80eb08a174 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -123,22 +123,22 @@ class CustomReaderInferShape : public framework::InferShapeBase { class CustomReaderInferVarType : public framework::VarTypeInference { public: - void operator()(framework::InferVarTypeContext& ctx) const override { - auto& out_var_name = ctx.Output("Out")[0]; - PADDLE_ENFORCE(ctx.HasVar(out_var_name)); - ctx.SetType(out_var_name, framework::proto::VarType::READER); + void operator()(framework::InferVarTypeContext* ctx) const override { + auto& out_var_name = ctx->Output("Out")[0]; + PADDLE_ENFORCE(ctx->HasVar(out_var_name)); + ctx->SetType(out_var_name, framework::proto::VarType::READER); auto sink_var_names = - boost::get>(ctx.GetAttr("sink_var_names")); + boost::get>(ctx->GetAttr("sink_var_names")); const auto* sub_block = - boost::get(ctx.GetAttr("sub_block")); + boost::get(ctx->GetAttr("sub_block")); std::vector res_data_types; for (const 
       framework::VarDesc* var = sub_block->FindVar(var_name);
       PADDLE_ENFORCE_NOT_NULL(var);
       res_data_types.emplace_back(var->GetDataType());
     }
-    ctx.SetDataTypes(out_var_name, res_data_types);
+    ctx->SetDataTypes(out_var_name, res_data_types);
   }
 };
diff --git a/paddle/fluid/operators/reader/read_op.cc b/paddle/fluid/operators/reader/read_op.cc
index 40549ce54d4db3fe802d68a85a3a3ffdccb12b16..33a69ad5fec2b850cae070ca3f113f12c4e835f9 100644
--- a/paddle/fluid/operators/reader/read_op.cc
+++ b/paddle/fluid/operators/reader/read_op.cc
@@ -51,16 +51,16 @@ class ReadInferShape : public framework::InferShapeBase {
 
 class ReadInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override {
-    bool infer_out = boost::get<bool>(ctx.GetAttr("infer_out"));
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    bool infer_out = boost::get<bool>(ctx->GetAttr("infer_out"));
     if (infer_out) {
-      std::string reader_name = ctx.Input("Reader")[0];
-      std::vector<std::string> out_names = ctx.Output("Out");
-      auto dtypes = ctx.GetDataTypes(reader_name);
+      std::string reader_name = ctx->Input("Reader")[0];
+      std::vector<std::string> out_names = ctx->Output("Out");
+      auto dtypes = ctx->GetDataTypes(reader_name);
       PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
       for (size_t i = 0; i < dtypes.size(); ++i) {
-        ctx.SetType(out_names[i], framework::proto::VarType::LOD_TENSOR);
-        ctx.SetDataType(out_names[i], dtypes[i]);
+        ctx->SetType(out_names[i], framework::proto::VarType::LOD_TENSOR);
+        ctx->SetDataType(out_names[i], dtypes[i]);
       }
     }
   }
diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc
index 44772281be41b88f3fb07a0acd394b6f5f9dbc38..64a1f6b68702f33ec72d901cf6621b674b331030 100644
--- a/paddle/fluid/operators/reader/reader_op_registry.cc
+++ b/paddle/fluid/operators/reader/reader_op_registry.cc
@@ -99,9 +99,9 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
 }
 
 void FileReaderInferVarType::operator()(
-    framework::InferVarTypeContext& ctx) const {
-  std::string reader_name = ctx.Output("Out")[0];
-  ctx.SetType(reader_name, framework::proto::VarType::READER);
+    framework::InferVarTypeContext* ctx) const {
+  std::string reader_name = ctx->Output("Out")[0];
+  ctx->SetType(reader_name, framework::proto::VarType::READER);
 }
 
 void DecoratedReaderInferShape::operator()(
@@ -124,11 +124,11 @@ void DecoratedReaderInferShape::operator()(
 }
 
 void DecoratedReaderInferVarType::operator()(
-    framework::InferVarTypeContext& ctx) const {
-  const std::string& in_reader_name = ctx.Input("UnderlyingReader")[0];
-  const std::string& out_reader_name = ctx.Output("Out")[0];
-  ctx.SetType(out_reader_name, framework::proto::VarType::READER);
-  ctx.SetDataTypes(out_reader_name, ctx.GetDataTypes(in_reader_name));
+    framework::InferVarTypeContext* ctx) const {
+  const std::string& in_reader_name = ctx->Input("UnderlyingReader")[0];
+  const std::string& out_reader_name = ctx->Output("Out")[0];
+  ctx->SetType(out_reader_name, framework::proto::VarType::READER);
+  ctx->SetDataTypes(out_reader_name, ctx->GetDataTypes(in_reader_name));
 }
 
 void DecoratedReaderMakerBase::Make() {
diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h
index 5a775b82f5bd91178a4a2dc2b572efd5ad074b60..795a5806050efe6469732004125e4a80b08e5304 100644
--- a/paddle/fluid/operators/reader/reader_op_registry.h
+++ b/paddle/fluid/operators/reader/reader_op_registry.h
@@ -61,7 +61,7 @@ class FileReaderInferShape : public framework::InferShapeBase {
 
 class FileReaderInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override;
+  void operator()(framework::InferVarTypeContext* ctx) const override;
 };
 
 // general infershape for decorated reader
@@ -73,7 +73,7 @@ class DecoratedReaderInferShape : public framework::InferShapeBase {
 // general var type inference for decorated reader
 class DecoratedReaderInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override;
+  void operator()(framework::InferVarTypeContext* ctx) const override;
 };
 
 class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc
index 45da2ac4c6cab7acd697c1890b1f98f8af7daf6e..b02c098099625ca544fd889d5bb1c13ef2374450 100644
--- a/paddle/fluid/operators/save_op.cc
+++ b/paddle/fluid/operators/save_op.cc
@@ -159,9 +159,9 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
 
 class SaveOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    auto out_var_name = ctx.Output(LOOKUP_TABLE_PATH).front();
-    ctx.SetType(out_var_name, framework::proto::VarType::RAW);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto out_var_name = ctx->Output(LOOKUP_TABLE_PATH).front();
+    ctx->SetType(out_var_name, framework::proto::VarType::RAW);
   }
 };
diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc
index 208a6f8009c6961752bdcc64be97b3eefe28b937..4e4a015e18305cd7aad71722056b15216f44782e 100644
--- a/paddle/fluid/operators/scale_op.cc
+++ b/paddle/fluid/operators/scale_op.cc
@@ -70,13 +70,13 @@ $$Out = scale*(X + bias)$$
 
 class ScaleOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    auto &in_var_name = ctx.Input("X").front();
-    auto out_var_name = ctx.Output("Out").front();
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto &in_var_name = ctx->Input("X").front();
+    auto out_var_name = ctx->Output("Out").front();
 
     if (in_var_name != out_var_name) {
-      ctx.SetType(out_var_name, ctx.GetType(in_var_name));
-      ctx.SetDataType(out_var_name, ctx.GetDataType(in_var_name));
+      ctx->SetType(out_var_name, ctx->GetType(in_var_name));
+      ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
     }
   }
 };
diff --git a/paddle/fluid/operators/split_selected_rows_op.cc b/paddle/fluid/operators/split_selected_rows_op.cc
index f102b911b58a3dc52e66976c2c89e07b30b5b98c..88dfebc0cff0d0f7752c372780f1d952667ec630 100644
--- a/paddle/fluid/operators/split_selected_rows_op.cc
+++ b/paddle/fluid/operators/split_selected_rows_op.cc
@@ -62,9 +62,9 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
 
 class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    for (auto &out_var : ctx.Output("Out")) {
-      ctx.SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &out_var : ctx->Output("Out")) {
+      ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
     }
   }
 };
diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc
index 2405a74d2bd8f74affe0401f7f5397f72ce2f017..1391148ccf5d13082cb31ef2e143249e8ef95bfc 100644
--- a/paddle/fluid/operators/sum_op.cc
+++ b/paddle/fluid/operators/sum_op.cc
@@ -160,20 +160,20 @@ the LoD information with the first input.
 
 class SumOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext& ctx) const override {
-    auto& inputs = ctx.Input("X");
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto& inputs = ctx->Input("X");
     auto var_type = framework::proto::VarType::SELECTED_ROWS;
-    for (auto& name : ctx.Input("X")) {
-      VLOG(10) << name << " " << ctx.GetType(name);
+    for (auto& name : ctx->Input("X")) {
+      VLOG(10) << name << " " << ctx->GetType(name);
     }
 
     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [&ctx](const std::string& name) {
-          return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR;
+        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
+          return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
         });
 
-    auto is_tensor_array = [&ctx](const std::string& name) {
-      return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
+    auto is_tensor_array = [ctx](const std::string& name) {
+      return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
     };
 
     bool any_input_is_tensor_array =
@@ -185,7 +185,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
       if (!all_inputs_are_tensor_array) {
         std::ostringstream os;
         for (auto& each : inputs) {
-          os << " " << each << " type is " << ctx.GetType(each) << "\n";
+          os << " " << each << " type is " << ctx->GetType(each) << "\n";
         }
         PADDLE_ENFORCE(all_inputs_are_tensor_array,
                        "Not all inputs are tensor array:\n%s", os.str());
@@ -195,9 +195,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
       var_type = framework::proto::VarType::LOD_TENSOR;
     }
 
-    auto out_var_name = ctx.Output("Out").front();
-    ctx.SetType(out_var_name, var_type);
-    ctx.SetDataType(out_var_name, ctx.GetDataType(inputs.front()));
+    auto out_var_name = ctx->Output("Out").front();
+    ctx->SetType(out_var_name, var_type);
+    ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
   }
 };
diff --git a/paddle/fluid/operators/tensor_array_to_tensor_op.cc b/paddle/fluid/operators/tensor_array_to_tensor_op.cc
index d7f67ccb2fa88670003e60541cd698112d9c3176..2b83c42f205c6ec0c14305586e179a003ce2619f 100644
--- a/paddle/fluid/operators/tensor_array_to_tensor_op.cc
+++ b/paddle/fluid/operators/tensor_array_to_tensor_op.cc
@@ -177,9 +177,9 @@ class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
 
 class LoDTensorArray2TensorGradInferVarType
    : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    for (auto &out_var : ctx.Output(framework::GradVarName("X"))) {
-      ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &out_var : ctx->Output(framework::GradVarName("X"))) {
+      ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
     }
   }
 };
diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
index 845629d40f6620cfc728cda5862684a194ac595c..6cf3e65e00ff6dd6a87d2b699ae89b9bde5d5462 100644
--- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
+++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.cc
@@ -46,7 +46,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
 
 class TensorRTEngineInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {}
+  void operator()(framework::InferVarTypeContext *ctx) const override {}
 };
 
 }  // namespace operators
diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc
index b3a8b6a141e744a616f033271eb797803790d586..bb6a1c5b165693df4199fe0794daffc2cff789a4 100644
--- a/paddle/fluid/operators/uniform_random_op.cc
+++ b/paddle/fluid/operators/uniform_random_op.cc
@@ -112,15 +112,16 @@ uniform distribution. The random result is in set [min, max].
 
 class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext &ctx) const override {
-    auto out_var_name = ctx.Output("Out").front();
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto out_var_name = ctx->Output("Out").front();
     auto var_data_type = static_cast<framework::proto::VarType::Type>(
-        boost::get<int>(ctx.GetAttr("dtype")));
+        boost::get<int>(ctx->GetAttr("dtype")));
 
-    if (ctx.GetType(out_var_name) != framework::proto::VarType::SELECTED_ROWS) {
-      ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+    if (ctx->GetType(out_var_name) !=
+        framework::proto::VarType::SELECTED_ROWS) {
+      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    ctx.SetDataType(out_var_name, var_data_type);
+    ctx->SetDataType(out_var_name, var_data_type);
   }
 };
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 21e7793e0a51dbd1001c0c0591ff1826bd5e5d6c..e7d078d03a1f8e3aad5d86c9ed63bfa3cf73f546 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
       .def("trace",
            [](imperative::Tracer& self, imperative::OpBase* op,
               const imperative::VarBasePtrMap& inputs,
-              imperative::VarBasePtrMap& outputs,
+              imperative::VarBasePtrMap* outputs,
               framework::AttributeMap attrs_map,
               const platform::CPUPlace expected_place,
               const bool stop_gradient = false) {
@@ -48,7 +48,7 @@ void BindTracer(pybind11::module* m) {
       .def("trace",
           [](imperative::Tracer& self, imperative::OpBase* op,
              const imperative::VarBasePtrMap& inputs,
-              imperative::VarBasePtrMap& outputs,
+              imperative::VarBasePtrMap* outputs,
              framework::AttributeMap attrs_map,
              const platform::CUDAPlace expected_place,
             const bool stop_gradient = false) {
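
For reference, after this change every var-type inference functor receives its
context by pointer. Below is a minimal sketch of an implementation written
against the new contract; it is not part of this patch, and the operator name
CopyTypeInference is hypothetical, but the ctx accessors are exactly the ones
exercised by the hunks above (mirroring ScaleOpVarTypeInference):

    // Hypothetical example: forward the variable type and data type of
    // input "X" to output "Out" through the pointer-based context.
    #include "paddle/fluid/framework/var_type_inference.h"

    namespace paddle {
    namespace operators {

    class CopyTypeInference : public framework::VarTypeInference {
     public:
      void operator()(framework::InferVarTypeContext* ctx) const override {
        // Input()/Output() return the variable names bound to each slot.
        auto& in_var_name = ctx->Input("X").front();
        auto out_var_name = ctx->Output("Out").front();
        // Propagate both the variable type (LOD_TENSOR, SELECTED_ROWS, ...)
        // and the data type from the input to the output.
        ctx->SetType(out_var_name, ctx->GetType(in_var_name));
        ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
      }
    };

    }  // namespace operators
    }  // namespace paddle

Passing the mutable context as a pointer rather than a non-const reference is
the point of the patch: it removes the need for the `// NOLINT` suppressions,
since cpplint (following the Google C++ style guide) flags non-const reference
parameters.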