diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h
index 79281863f6995ee0e2fc55860087aebbcdd34206..346aba07d1533b607ba6a2947e4ca4a962b7b66c 100644
--- a/paddle/fluid/framework/details/op_registry.h
+++ b/paddle/fluid/framework/details/op_registry.h
@@ -16,6 +16,8 @@ limitations under the License. */
 
 #include <string>
 #include <tuple>
+#include <type_traits>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/grad_op_desc_maker.h"
 #include "paddle/fluid/framework/inplace_op_inference.h"
diff --git a/paddle/fluid/framework/var_type_inference.h b/paddle/fluid/framework/var_type_inference.h
index ed52e1ad81a0d3cf71fd4978b5b73d6c1c92b41b..b4b7be619abd9ccf854405b5168041bc947401c6 100644
--- a/paddle/fluid/framework/var_type_inference.h
+++ b/paddle/fluid/framework/var_type_inference.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include <string>
+#include <vector>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_desc.h"
@@ -80,6 +81,19 @@ class InferVarTypeContext {
     block_->FindRecursiveOrCreateVar(name).SetDataType(type);
   }
 
+  inline std::vector<proto::VarType::Type> GetDataTypes(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
+  }
+
+  inline void SetDataTypes(
+      const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
+  }
+
   inline std::vector<int64_t> GetShape(const std::string& name) const {
     PADDLE_ENFORCE_NOT_NULL(block_);
     return block_->FindRecursiveOrCreateVar(name).GetShape();
@@ -101,17 +115,11 @@ class InferVarTypeContext {
     block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
   }
 
- private:
+ protected:
   const OpDesc* op_;
   BlockDesc* block_;
 };
 
-// infer var type context for imperative mode
-class RuntimeInferVarTypeContext : public InferVarTypeContext {
- public:
-  RuntimeInferVarTypeContext() : InferVarTypeContext(nullptr, nullptr) {}
-};
-
 class VarTypeInference {
  public:
   virtual ~VarTypeInference() {}
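
Context for the hunk above: the two new accessors mirror VarDesc::GetDataTypes()/SetDataTypes(), which READER variables use to carry one data type per tensor they produce; the DecoratedReaderInferVarType change further down in this patch is their first caller. A minimal sketch of an operator-side VarTypeInference written purely against the extended context — the op and its "X"/"Out" slots are hypothetical, not part of this patch:

    #include <string>
    #include "paddle/fluid/framework/var_type_inference.h"

    namespace paddle {
    namespace operators {

    // Sketch only: forward the variable type and data type(s) from input
    // "X" to output "Out" without touching OpDesc/BlockDesc directly.
    class ForwardTypeInferVarType : public framework::VarTypeInference {
     public:
      void operator()(framework::InferVarTypeContext& ctx) const override {
        const std::string& in_name = ctx.Input("X")[0];
        const std::string& out_name = ctx.Output("Out")[0];
        ctx.SetType(out_name, ctx.GetType(in_name));
        ctx.SetDataType(out_name, ctx.GetDataType(in_name));
        if (ctx.GetType(in_name) == framework::proto::VarType::READER) {
          // Readers carry a list of data types, one per underlying tensor.
          ctx.SetDataTypes(out_name, ctx.GetDataTypes(in_name));
        }
      }
    };

    }  // namespace operators
    }  // namespace paddle

The private-to-protected change makes op_/block_ visible to subclasses; the imperative RuntimeInferVarTypeContext, which this patch moves into layer.h below, passes nullptr for both and keeps its own name-to-VarBase maps instead.
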
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index 5530823b90f6580692456253b0eb9d0af4e3240b..aee905aa41c75cfef9c369cefa739774d31983b6 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -220,7 +220,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
   }
 
   VLOG(3) << "apply op grad: " << Type();
-  std::vector<framework::VariableValueMap> tmp_grad_outputs;
+  std::vector<VarBasePtrMap> tmp_grad_outputs;
   if (backward_id_ > 0) {
     VLOG(3) << "py_layer_grad";
     tmp_grad_outputs.resize(1);
@@ -246,23 +246,59 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
         // Allocate a new variable
         Variable* tmp_var = new framework::Variable();
         tmp_var->GetMutable<framework::LoDTensor>();
-        outputs.emplace_back(tmp_var);
+        VarBase* tmp_var_base =
+            new VarBase(it.second[i]->Name(), tmp_var, nullptr, true);
+        outputs.emplace_back(tmp_var_base);
       }
     }
 
-    // Run grad op
-    framework::RuntimeContext ctx(grad_input_vars_[k], tmp_grad_outputs[k]);
-
     // No need to do compile time infer shape here.
     // grad_op_desc_->InferShape(*block_);
     // grad_op_desc->InferVarType(block_);
 
     std::unique_ptr<framework::OperatorBase> opbase =
         framework::OpRegistry::CreateOp(*grad_op_desc);
+
+    // auto& info =
+    //     framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
+    // if (info.infer_var_type_) {
+    //   framework::RuntimeInferVarTypeContext infer_var_type_ctx(
+    //       this, &grad_inputs, &outputs, &attrs_map);
+    //   info.infer_var_type_(infer_var_type_ctx);
+    // }
+
     framework::OperatorWithKernel* op_kernel =
         dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
     PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
 
+    // Run grad op
+    framework::VariableValueMap grad_invars_map;
+    framework::VariableValueMap grad_outvars_map;
+
+    for (const auto& it : grad_input_vars_[k]) {
+      auto& grad_invars = grad_invars_map[it.first];
+      grad_invars.reserve(it.second.size());
+      for (const VarBase* grad_inp : it.second) {
+        PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
+                                grad_op_desc->Type(), grad_inp->Name());
+
+        grad_invars.emplace_back(grad_inp->var_);
+      }
+    }
+
+    for (const auto& it : tmp_grad_outputs[k]) {
+      auto& grad_outvars = grad_outvars_map[it.first];
+      grad_outvars.reserve(it.second.size());
+      for (VarBase* grad_out : it.second) {
+        PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
+                                grad_op_desc->Type(), grad_out->Name());
+
+        grad_outvars.emplace_back(grad_out->var_);
+      }
+    }
+
+    framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
+
     framework::Scope scope;
     PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
     p.op.RuntimeInferShape(scope, place_, ctx);
@@ -279,8 +315,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
 
       for (size_t i = 0; i < outputs.size(); ++i) {
-        framework::Variable* grad = outputs[i];
-        framework::Variable* orig_grad = origin_outputs[i];
+        framework::Variable* grad = outputs[i]->var_;
+        framework::Variable* orig_grad = origin_outputs[i]->var_;
         AddTo(grad, orig_grad, place_);
         delete grad;
       }
@@ -328,28 +364,35 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
 
 int PyLayer::NumFuncs() { return py_funcs_.size(); }
 
-std::vector<framework::Variable*> PyLayer::Apply(int func_id,
-                                                 const std::vector<VarBase*>& inputs) {
-  std::vector<framework::Variable*> invars;
-  for (const VarBase* in : inputs) {
-    invars.push_back(in->var_);
-  }
+std::vector<framework::Variable*> PyLayer::Apply(
+    int func_id, const std::vector<VarBase*>& inputs) {
   PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
-  return CallPythonFunc(py_funcs_[func_id], invars);
+  return CallPythonFunc(py_funcs_[func_id], inputs);
 }
 
-std::vector<framework::Variable*> PyLayer::ApplyGrad(
-    int func_id, const std::vector<framework::Variable*>& inputs) {
+std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
+                                         const std::vector<VarBase*>& inputs) {
   PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
-  return CallPythonFunc(py_funcs_[func_id], inputs);
+  auto rets = CallPythonFunc(py_funcs_[func_id], inputs);
+
+  std::vector<VarBase*> outs;
+  outs.reserve(rets.size());
+  for (size_t i = 0U; i != rets.size(); ++i) {
+    outs.emplace_back(new VarBase(
+        string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
+                        i),
+        rets[i], nullptr, true));
+  }
+
+  return outs;
 }
 
 std::vector<framework::Variable*> PyLayer::CallPythonFunc(
-    const py::object& callable, const std::vector<framework::Variable*>& ins) {
+    const py::object& callable, const std::vector<VarBase*>& ins) {
   py::gil_scoped_acquire guard;
   py::tuple in_args(ins.size());
   for (size_t i = 0; i < ins.size(); ++i) {
-    const framework::LoDTensor& t = ins[i]->Get<framework::LoDTensor>();
+    const framework::LoDTensor& t = ins[i]->var_->Get<framework::LoDTensor>();
     in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr);
   }
   VLOG(3) << "pyfunc in " << py::len(in_args);
@@ -359,6 +402,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
   auto ret_tuple = py::cast<py::tuple>(ret);
   size_t ret_num = py::len(ret_tuple);
   std::vector<framework::Variable*> outs;
+  outs.reserve(ret_num);
   VLOG(3) << "pyfunc out " << ret_num;
   for (size_t i = 0; i < ret_num; ++i) {
     try {
@@ -369,7 +413,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
       auto* tensor = var->GetMutable<framework::LoDTensor>();
       tensor->ShareDataWith(*py_out_tensor);
       tensor->set_lod(py_out_tensor->lod());
-      outs.push_back(var);
+      outs.emplace_back(var);
     } catch (py::cast_error&) {
       PADDLE_THROW("The %d-th output must be LoDTensor", i);
     }
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 618a5b7a03295ce679dc6a88e0eac57069e78b8b..494988608e73b4eedd30103067e9d47281fd5af2 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -18,14 +18,16 @@
 #include "paddle/fluid/framework/python_headers.h"
 // clang-format on
 
-#include <map>     // NOLINT
-#include <memory>  // NOLINT
-#include <string>  // NOLINT
-#include <vector>  // NOLINT
+#include <map>            // NOLINT
+#include <memory>         // NOLINT
+#include <string>         // NOLINT
+#include <unordered_map>  // NOLINT
+#include <vector>         // NOLINT
 
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/framework/var_type_inference.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -184,6 +186,10 @@ class VarBase {
     }
   }
 
+  inline void SetDType(framework::proto::VarType::Type type) {
+    auto tensor = var_->GetMutable<framework::LoDTensor>();
+    tensor->mutable_data(place_, type);
+  }
   inline framework::proto::VarType::Type DType() const { return dtype_; }
 
   inline void SetStopGradient(bool stop_gradient) {
@@ -328,9 +334,9 @@ class PYBIND11_HIDDEN OpBase {
   std::map<std::string, std::vector<int>> pre_ops_out_idx_;
 
   // Inputs to a vector of bwd ops.
-  std::vector<framework::VariableValueMap> grad_input_vars_;
+  std::vector<VarBasePtrMap> grad_input_vars_;
   // Outputs to a vector of bwd ops.
-  std::vector<framework::VariableValueMap> grad_output_vars_;
+  std::vector<VarBasePtrMap> grad_output_vars_;
 
   std::vector<py::object> backward_hooks_;
 };
@@ -359,12 +365,130 @@ class PyLayer {
   static std::vector<framework::Variable*> Apply(
       int func_id, const std::vector<VarBase*>& inputs);
 
-  static std::vector<framework::Variable*> ApplyGrad(
-      int func_id, const std::vector<framework::Variable*>& inputs);
+  static std::vector<VarBase*> ApplyGrad(int func_id,
+                                         const std::vector<VarBase*>& inputs);
 
  private:
   static std::vector<framework::Variable*> CallPythonFunc(
-      const py::object& callable, const std::vector<framework::Variable*>& ins);
+      const py::object& callable, const std::vector<VarBase*>& ins);
+};
+
+// infer var type context for imperative mode
+class PYBIND11_HIDDEN RuntimeInferVarTypeContext
+    : public framework::InferVarTypeContext {
+ public:
+  RuntimeInferVarTypeContext(imperative::OpBase* op,
+                             const imperative::VarBasePtrMap* inputs,
+                             imperative::VarBasePtrMap* outputs,
+                             const framework::AttributeMap* attrs_map)
+      : InferVarTypeContext(nullptr, nullptr),
+        op_(op),
+        inputs_(inputs),
+        outputs_(outputs),
+        attrs_(attrs_map),
+        input_names_(),
+        output_names_(),
+        var_set_() {
+    input_names_.reserve(inputs_->size());
+    for (auto& it : *inputs_) {
+      for (imperative::VarBase* var : it.second) {
+        input_names_[it.first].emplace_back(var->Name());
+        var_set_[var->Name()] = var;
+      }
+    }
+
+    output_names_.reserve(outputs_->size());
+    for (auto& it : *outputs_) {
+      for (imperative::VarBase* var : it.second) {
+        output_names_[it.first].emplace_back(var->Name());
+        var_set_[var->Name()] = var;
+      }
+    }
+  }
+
+  framework::Attribute GetAttr(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(attrs_);
+    return attrs_->at(name);
+  }
+
+  inline bool HasVar(const std::string& name) const {
+    return var_set_.count(name) > 0;
+  }
+
+  inline bool HasInput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(inputs_);
+    return inputs_->count(name) > 0;
+  }
+
+  inline bool HasOutput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(outputs_);
+    return outputs_->count(name) > 0;
+  }
+
+  inline const std::vector<std::string>& Input(const std::string& name) const {
+    return input_names_.at(name);
+  }
+
+  inline const std::vector<std::string>& Output(
+      const std::string& name) const {
+    return output_names_.at(name);
+  }
+
+  inline framework::proto::VarType::Type GetType(
+      const std::string& name) const {
+    return var_set_.at(name)->DType();
+  }
+
+  inline void SetType(const std::string& name,
+                      framework::proto::VarType::Type type) {
+    var_set_[name]->SetDType(type);
+  }
+
+  inline framework::proto::VarType::Type GetDataType(
+      const std::string& name) const {
+    return var_set_.at(name)->DType();
+  }
+
+  inline void SetDataType(const std::string& name,
+                          framework::proto::VarType::Type type) {
+    var_set_[name]->SetDType(type);
+  }
+
+  inline std::vector<framework::proto::VarType::Type> GetDataTypes(
+      const std::string& name) const {
+    PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
+  }
+
+  inline void SetDataTypes(
+      const std::string& name,
+      const std::vector<framework::proto::VarType::Type>& multiple_data_type) {
+    PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
+  }
+
+  inline std::vector<int64_t> GetShape(const std::string& name) const {
+    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
+  }
+
+  inline void SetShape(const std::string& name,
+                       const std::vector<int64_t>& dims) {
+    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
+  }
+
+  inline int32_t GetLoDLevel(const std::string& name) const {
+    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
+  }
+
+  inline void SetLoDLevel(const std::string& name, int32_t lod_level) {
+    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
+  }
+
+ private:
+  imperative::OpBase* op_;
+  const imperative::VarBasePtrMap* inputs_;
+  imperative::VarBasePtrMap* outputs_;
+  const framework::AttributeMap* attrs_;
+  std::unordered_map<std::string, std::vector<std::string>> input_names_;
+  std::unordered_map<std::string, std::vector<std::string>> output_names_;
+  std::unordered_map<std::string, imperative::VarBase*> var_set_;
 };
 
 }  // namespace imperative
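
How the new runtime context is meant to be driven — a minimal sketch of the call pattern the tracer uses (see the tracer.cc hunks below); the free-function name and local variable names here are illustrative only:

    // Sketch: run an op's registered infer_var_type_ against live VarBases.
    void InferVarTypeImperative(imperative::OpBase* op,
                                const imperative::VarBasePtrMap& inputs,
                                imperative::VarBasePtrMap* outputs,
                                const framework::AttributeMap& attrs) {
      auto& info = framework::OpInfoMap::Instance().Get(op->Type());
      if (info.infer_var_type_) {
        imperative::RuntimeInferVarTypeContext ctx(op, &inputs, outputs,
                                                   &attrs);
        info.infer_var_type_(ctx);  // may retype output VarBases in place
      }
    }

Note that the shape and LoD-level accessors deliberately PADDLE_THROW in this subclass: in imperative mode both are carried by the live tensors, so only the type and data-type channels of InferVarType are meaningful here.
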
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 7ee92b4d8c46d8814400dbc02847d701005f3d5b..7a07ec358dceea46917014ae1457a800ce52f3f1 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -19,6 +19,7 @@
 #include <set>
 #include <unordered_set>
 
+#include "paddle/fluid/framework/var_type_inference.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -160,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
 }
 
 std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
-                                    const VarBasePtrMap& outputs,
+                                    VarBasePtrMap& outputs,
                                     framework::AttributeMap attrs_map,
                                     const platform::Place expected_place,
                                     const bool stop_gradient) {
@@ -228,6 +229,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
                                       outvars_name_map, attrs_map);
 
+  if (info.infer_var_type_) {
+    RuntimeInferVarTypeContext infer_var_type_ctx(op, &inputs, &outputs,
+                                                  &attrs_map);
+    info.infer_var_type_(infer_var_type_ctx);
+  }
+
   // TODO(minqiyang): Support infer var type in imperative mode
   // Run forward op
   VLOG(3) << "tracer running " << op->Type();
@@ -278,12 +285,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
         auto fwd_var_it = current_vars_map.find(grad_invar);
         PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
         // Forward inputs or outputs.
-        grad_in_vars.emplace_back(fwd_var_it->second->var_);
+        grad_in_vars.emplace_back(fwd_var_it->second);
       } else {
         VarBase* var = current_vars_map[var_it->second];
         InitGrad(var, prepared_op.GetDeviceContext());
         // Douts.
-        grad_in_vars.emplace_back(var->grads_->var_);
+        grad_in_vars.emplace_back(var->grads_);
       }
 
       vars_saved_for_backward.insert(it.first);
@@ -300,7 +307,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                        op->Type());
         VarBase* var = current_vars_map[var_it->second];
         InitGrad(var, prepared_op.GetDeviceContext());
-        grad_out_vars.push_back(var->grads_->var_);
+        grad_out_vars.push_back(var->grads_);
       }
     }
   }
@@ -342,23 +349,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
     auto& grad_output_vars =
         op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
 
-    for (const VarBase* inp : inputs) {
-      grad_input_vars.push_back(inp->var_);
+    for (VarBase* inp : inputs) {
+      grad_input_vars.push_back(inp);
     }
 
     for (VarBase* out : outputs) {
-      grad_input_vars.push_back(out->var_);
+      grad_input_vars.push_back(out);
     }
 
     // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
     platform::CPUPlace place;
    for (VarBase* out : outputs) {
       InitGrad(out, platform::DeviceContextPool::Instance().Get(place));
-      grad_input_vars.push_back(out->grads_->var_);
+      grad_input_vars.push_back(out->grads_);
     }
 
     for (VarBase* inp : inputs) {
       InitGrad(inp, platform::DeviceContextPool::Instance().Get(place));
-      grad_output_vars.push_back(inp->grads_->var_);
+      grad_output_vars.push_back(inp->grads_);
     }
   }
   return outputs;
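
Two signature changes in tracer.cc are deliberate: Trace() now takes `outputs` by non-const reference because the infer_var_type_ hook may retype the freshly created output VarBases in place, and grad_input_vars_/grad_output_vars_ now collect VarBase* rather than raw framework::Variable*, so the type information survives until OpBase::ApplyGrad() unpacks them (see the layer.cc hunks above). What a hook can now do to a traced output, sketched with an invented op whose kernel expects SELECTED_ROWS — "Out" and the type choice are illustrative, not from this patch:

    // Sketch: a VarTypeInference body running inside Tracer::Trace().
    class SparseOutInferVarType : public framework::VarTypeInference {
     public:
      void operator()(framework::InferVarTypeContext& ctx) const override {
        // Retype the traced output before the kernel is prepared.
        ctx.SetType(ctx.Output("Out")[0],
                    framework::proto::VarType::SELECTED_ROWS);
      }
    };
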
diff --git a/paddle/fluid/imperative/type_defs.h b/paddle/fluid/imperative/type_defs.h
index fc9e42f8d0e9996176a5cbab7d8c7cf08ddce1af..c51ce931defbc87231a2f8c6c07f99d9853fb283 100644
--- a/paddle/fluid/imperative/type_defs.h
+++ b/paddle/fluid/imperative/type_defs.h
@@ -25,6 +25,7 @@ class VarBase;
 class OpBase;
 
 typedef std::map<std::string, std::vector<VarBase*>> VarBasePtrMap;
+typedef std::map<std::string, std::vector<const VarBase*>> ConstVarBasePtrMap;
 typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
 
 }  // namespace imperative
diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc
index 8352ba4f2b846af58d2d041ebf5201ee15f8481c..90c30678680b2dc741f68e335367dd2cbe50412f 100644
--- a/paddle/fluid/operators/controlflow/while_op.cc
+++ b/paddle/fluid/operators/controlflow/while_op.cc
@@ -365,19 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
 
 class WhileGradOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto p_names = op_desc.Input(kX);
-    auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto p_names = ctx.Input(kX);
+    auto pg_ig_names = ctx.Output(framework::GradVarName(kX));
     for (size_t i = 0; i < p_names.size(); ++i) {
-      auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
-      auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
-      if (g_var != nullptr) {  // Gradient could be @EMPTY@
+      if (ctx.HasVar(pg_ig_names[i])) {
         VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
-                << " type: " << p_var.GetType();
-        g_var->SetType(p_var.GetType());
-        g_var->SetDataType(p_var.GetDataType());
+                << " type: " << ctx.GetType(p_names[i]);
+        ctx.SetType(pg_ig_names[i], ctx.GetType(p_names[i]));
+        ctx.SetDataType(pg_ig_names[i], ctx.GetDataType(p_names[i]));
       }
     }
   }
diff --git a/paddle/fluid/operators/distributed_ops/split_ids_op.cc b/paddle/fluid/operators/distributed_ops/split_ids_op.cc
index 2932a202a5b04b4bc7ac5ad89aa59c101bec1a94..e9f3f89c6e2e4fa24aa65cebbb6207cedf8df088 100644
--- a/paddle/fluid/operators/distributed_ops/split_ids_op.cc
+++ b/paddle/fluid/operators/distributed_ops/split_ids_op.cc
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
 
+#include <memory>
+
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/nccl/nccl_op.cc b/paddle/fluid/operators/nccl/nccl_op.cc
index 0018139cb06fe0573565c920849843e674df6f4c..7df5a881f5739fc355b9a7bf90a8d999ccb8a85f 100644
--- a/paddle/fluid/operators/nccl/nccl_op.cc
+++ b/paddle/fluid/operators/nccl/nccl_op.cc
@@ -60,12 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
 
 class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto out_var_name = op_desc.Output("Communicator").front();
-    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto out_var_name = ctx.Output("Communicator").front();
+    ctx.SetType(out_var_name, framework::proto::VarType::RAW);
   }
 };
diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc
index f630ad678ff8e067e142e5adb0a84b6f70ccf6d8..6472b9c16318e030d18b67aa208d56670592c87f 100644
--- a/paddle/fluid/operators/py_func_op.cc
+++ b/paddle/fluid/operators/py_func_op.cc
@@ -14,8 +14,11 @@
 
 #include "paddle/fluid/operators/py_func_op.h"
 
+#include <memory>
 #include <set>
 #include <string>
+#include <unordered_set>
+#include <utility>
 #include <vector>
 
 #include "paddle/fluid/framework/op_registry.h"
diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc
index 915325905b67aaf92618dddd9552b9f3623d5c98..b65e23685681fd4179a4cf223658440c2550f2d4 100644
--- a/paddle/fluid/operators/reader/create_custom_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc
@@ -123,7 +123,7 @@ class CustomReaderInferShape : public framework::InferShapeBase {
 
 class CustomReaderInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::InferVarTypeContext& ctx) const override {
+  void operator()(framework::InferVarTypeContext& ctx) const override {
     auto& out_var_name = ctx.Output("Out")[0];
     PADDLE_ENFORCE(ctx.HasVar(out_var_name));
     ctx.SetType(out_var_name, framework::proto::VarType::READER);
diff --git a/paddle/fluid/operators/reader/read_op.cc b/paddle/fluid/operators/reader/read_op.cc
index 9a98d68e13336a8af6d3bec38525ebcd1dad59f0..40549ce54d4db3fe802d68a85a3a3ffdccb12b16 100644
--- a/paddle/fluid/operators/reader/read_op.cc
+++ b/paddle/fluid/operators/reader/read_op.cc
@@ -51,7 +51,7 @@ class ReadInferShape : public framework::InferShapeBase {
 
 class ReadInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::InferVarTypeContext& ctx) const override {
+  void operator()(framework::InferVarTypeContext& ctx) const override {
     bool infer_out = boost::get<bool>(ctx.GetAttr("infer_out"));
     if (infer_out) {
       std::string reader_name = ctx.Input("Reader")[0];
diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc
index 3921eedf94abbe68bed035940913f830a6c16e48..44772281be41b88f3fb07a0acd394b6f5f9dbc38 100644
--- a/paddle/fluid/operators/reader/reader_op_registry.cc
+++ b/paddle/fluid/operators/reader/reader_op_registry.cc
@@ -98,11 +98,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
   }
 }
 
-void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
-                                        framework::BlockDesc* block) const {
-  std::string reader_name = op_desc.Output("Out")[0];
-  framework::VarDesc* reader = block->FindVarRecursive(reader_name);
-  reader->SetType(framework::proto::VarType::READER);
+void FileReaderInferVarType::operator()(
+    framework::InferVarTypeContext& ctx) const {
+  std::string reader_name = ctx.Output("Out")[0];
+  ctx.SetType(reader_name, framework::proto::VarType::READER);
 }
 
 void DecoratedReaderInferShape::operator()(
@@ -125,13 +124,11 @@ void DecoratedReaderInferShape::operator()(
 }
 
 void DecoratedReaderInferVarType::operator()(
-    const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
-  std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
-  framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
-  std::string out_reader_name = op_desc.Output("Out")[0];
-  framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
-  out_reader->SetType(framework::proto::VarType::READER);
-  out_reader->SetDataTypes(in_reader->GetDataTypes());
+    framework::InferVarTypeContext& ctx) const {
+  const std::string& in_reader_name = ctx.Input("UnderlyingReader")[0];
+  const std::string& out_reader_name = ctx.Output("Out")[0];
+  ctx.SetType(out_reader_name, framework::proto::VarType::READER);
+  ctx.SetDataTypes(out_reader_name, ctx.GetDataTypes(in_reader_name));
 }
 
 void DecoratedReaderMakerBase::Make() {
diff --git a/paddle/fluid/operators/reader/reader_op_registry.h b/paddle/fluid/operators/reader/reader_op_registry.h
index 58b0dfd555158d897d9dd216c3abd770305dadc6..5a775b82f5bd91178a4a2dc2b572efd5ad074b60 100644
--- a/paddle/fluid/operators/reader/reader_op_registry.h
+++ b/paddle/fluid/operators/reader/reader_op_registry.h
@@ -14,7 +14,9 @@
 
 #pragma once
 
+#include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/reader.h"
diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc
index d2f05c42a7949edb0b911f119a0df5a9abac4aa5..208a6f8009c6961752bdcc64be97b3eefe28b937 100644
--- a/paddle/fluid/operators/scale_op.cc
+++ b/paddle/fluid/operators/scale_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/scale_op.h"
 
+#include <memory>
 #include <string>
 
 #include "paddle/fluid/operators/detail/safe_ref.h"
diff --git a/paddle/fluid/operators/split_selected_rows_op.cc b/paddle/fluid/operators/split_selected_rows_op.cc
index e950f30a42ed8d3990e68da56323f396b6839215..f102b911b58a3dc52e66976c2c89e07b30b5b98c 100644
--- a/paddle/fluid/operators/split_selected_rows_op.cc
+++ b/paddle/fluid/operators/split_selected_rows_op.cc
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/split_selected_rows_op.h"
 
+#include <memory>
+
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc
index d674711392188c81b448389febbace8d450bd876..7dba00fffa6644c10079f0441a7482ec5078886e 100644
--- a/paddle/fluid/operators/sum_op.cc
+++ b/paddle/fluid/operators/sum_op.cc
@@ -12,6 +12,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/sum_op.h"
 #include <algorithm>
+#include <memory>
 #include <string>
 #include <vector>
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 6bbda69297a48ce27ce23282c4e08d49ee3cce6c..21e7793e0a51dbd1001c0c0591ff1826bd5e5d6c 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
       .def("trace",
           [](imperative::Tracer& self, imperative::OpBase* op,
              const imperative::VarBasePtrMap& inputs,
-              const imperative::VarBasePtrMap& outputs,
+              imperative::VarBasePtrMap& outputs,
              framework::AttributeMap attrs_map,
              const platform::CPUPlace expected_place,
              const bool stop_gradient = false) {
@@ -48,7 +48,7 @@ void BindTracer(pybind11::module* m) {
       .def("trace",
           [](imperative::Tracer& self, imperative::OpBase* op,
              const imperative::VarBasePtrMap& inputs,
-              const imperative::VarBasePtrMap& outputs,
+              imperative::VarBasePtrMap& outputs,
              framework::AttributeMap attrs_map,
              const platform::CUDAPlace expected_place,
              const bool stop_gradient = false) {
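
For operator code outside this patch that still implements the old interface, the migration is mechanical and follows the while_op/nccl_op changes above. A hypothetical before/after — the op name and its "X"/"Out" slots are invented for illustration:

    // Before: direct OpDesc/BlockDesc access.
    // void MyOpInferVarType::operator()(const framework::OpDesc& op_desc,
    //                                   framework::BlockDesc* block) const {
    //   auto& out =
    //       block->FindRecursiveOrCreateVar(op_desc.Output("Out").front());
    //   out.SetType(framework::proto::VarType::LOD_TENSOR);
    // }

    // After: the same logic through the unified context, which also runs
    // under the imperative tracer.
    class MyOpInferVarType : public framework::VarTypeInference {
     public:
      void operator()(framework::InferVarTypeContext& ctx) const override {
        const std::string& out_name = ctx.Output("Out").front();
        ctx.SetType(out_name, framework::proto::VarType::LOD_TENSOR);
        ctx.SetDataType(out_name, ctx.GetDataType(ctx.Input("X").front()));
      }
    };

On the graph-construction path the context wraps the same FindRecursiveOrCreateVar lookups the old code performed directly, so compile-time behavior is unchanged; the same functor additionally becomes usable from Tracer::Trace().
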