diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h
index 346aba07d1533b607ba6a2947e4ca4a962b7b66c..420d4da8d5143f161ac730c316102267e4e2e8e7 100644
--- a/paddle/fluid/framework/details/op_registry.h
+++ b/paddle/fluid/framework/details/op_registry.h
@@ -129,9 +129,9 @@ struct OpInfoFiller {
 template <typename T>
 struct OpInfoFiller<T, kVarTypeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->infer_var_type_ = [](InferVarTypeContext& context) {
+    info->infer_var_type_ = [](InferVarTypeContext* context) {
       T inference;
-      inference(context);
+      inference(*context);
     };
   }
 };
diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc
index 2940f3ceeb3db17a91c073771a7066e693fd8e86..851c1b80a85b534e648dd8e6417fb3a66ad3e12d 100644
--- a/paddle/fluid/framework/ir/graph_test.cc
+++ b/paddle/fluid/framework/ir/graph_test.cc
@@ -48,7 +48,7 @@ class SumOpVarTypeInference : public VarTypeInference {
     auto default_var_type = proto::VarType::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string &name) {
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
           return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index aae0eafe6cbb06b4c1a9f61831ef4a11d5771d9a..8f9c6cb5e924a7f35451f67e59c2455f057188e7 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -679,7 +679,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   auto &info = OpInfoMap::Instance().Get(this->Type());
   if (info.infer_var_type_) {
     InferVarTypeContext context(this, block);
-    info.infer_var_type_(context);
+    info.infer_var_type_(&context);
   }
 }

diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h
index a774f9ff49a7f3f0f1f380e8f4eeefd188c26583..f55520901c53fcc5bea90c5758f401f021a5c723 100644
--- a/paddle/fluid/framework/type_defs.h
+++ b/paddle/fluid/framework/type_defs.h
@@ -54,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
     const std::vector<BlockDesc*>& grad_block)>;

 using InferVarTypeFN =
-    std::function<void(framework::InferVarTypeContext& /*context*/)>;
+    std::function<void(framework::InferVarTypeContext* /*context*/)>;

 using InferShapeFN = std::function<void(InferShapeContext*)>;

diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/paddle/fluid/framework/var_type_inference_test.cc
index d7d3e0a03377de402c33a0a770c7296232fe1b0e..60e1d610daf12948dd0d864dcd16d1c9d8990aa3 100644
--- a/paddle/fluid/framework/var_type_inference_test.cc
+++ b/paddle/fluid/framework/var_type_inference_test.cc
@@ -49,7 +49,7 @@ class SumOpVarTypeInference : public VarTypeInference {
     auto default_var_type = proto::VarType::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string &name) {
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
           return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc
index aee905aa41c75cfef9c369cefa739774d31983b6..28ab208f3f3f8e5e662835cea0ca06f41d7bde1e 100644
--- a/paddle/fluid/imperative/layer.cc
+++ b/paddle/fluid/imperative/layer.cc
@@ -243,12 +243,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       auto& outputs = tmp_grad_outputs[k][it.first];
       outputs.reserve(it.second.size());
       for (size_t i = 0; i < it.second.size(); ++i) {
+        VarBase* origin_grad_var_base = it.second[i];
+
         // Allocate a new variable
-        Variable* tmp_var = new framework::Variable();
-        tmp_var->GetMutable<framework::LoDTensor>();
-        VarBase* tmp_var_base =
-            new VarBase(it.second[i]->Name(), tmp_var, nullptr, true);
-        outputs.emplace_back(tmp_var_base);
+        VarBase* tmp_grad_var_base = new VarBase(
+            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
+            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
+            place_, true, false);
+        outputs.emplace_back(tmp_grad_var_base);
       }
     }

@@ -259,13 +261,12 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {

       std::unique_ptr<framework::OperatorBase> opbase =
           framework::OpRegistry::CreateOp(*grad_op_desc);
-      // auto& info =
-      //     framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
-      // if (info.infer_var_type_) {
-      //   framework::RuntimeInferVarTypeContext infer_var_type_ctx(
-      //       this, &grad_inputs, &outputs, &attrs_map);
-      //   info.infer_var_type_(infer_var_type_ctx);
-      // }
+      auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
+      if (info.infer_var_type_) {
+        RuntimeInferVarTypeContext infer_var_type_ctx(
+            &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
+        info.infer_var_type_(&infer_var_type_ctx);
+      }

       framework::OperatorWithKernel* op_kernel =
           dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
@@ -298,7 +299,6 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       }

       framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
-      framework::Scope scope;
       PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
       p.op.RuntimeInferShape(scope, place_, ctx);
diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index 4ad7d847c182b3c7f93ff0e01cf9307994aab648..f210cd174532cc240988e0f3f45be03fe73e44bf 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -137,13 +137,13 @@ class VarBase {
                 persistable) {}

  private:
+  // TODO(minqiyang): need support SelectedRows
   VarBase(const std::string& name, framework::proto::VarType::Type dtype,
           const framework::DDim& shape, const platform::Place& place,
           framework::Variable* var, VarBase* grad, bool stop_gradient,
           bool persistable)
       : name_(name),
-        dtype_(dtype),
-        place_(place),
+        type_(framework::proto::VarType::LOD_TENSOR),
         var_(var),
         grads_(grad),
         stop_gradient_(stop_gradient),
@@ -153,10 +153,12 @@ class VarBase {
         pre_op_out_idx_(-1) {
     if (!var_) {
       var_ = new framework::Variable();
-      auto tensor = var_->GetMutable<framework::LoDTensor>();
-      tensor->Resize(shape);
-      tensor->mutable_data(place_, dtype_);
     }
+    auto tensor = var_->GetMutable<framework::LoDTensor>();
+    tensor->Resize(shape);
+    tensor->mutable_data(place, dtype);
+    VLOG(10) << "create varbase: " << name_ << " type: " << dtype
+             << " place: " << place;
   }

  public:
@@ -186,11 +188,23 @@ class VarBase {
     }
   }

-  inline void SetDType(framework::proto::VarType::Type type) {
+  inline framework::DDim Dims() const {
+    return var_->Get<framework::LoDTensor>().dims();
+  }
+
+  // data type. e.g.. FP32
+  inline void SetDataType(framework::proto::VarType::Type type) {
     auto tensor = var_->GetMutable<framework::LoDTensor>();
-    tensor->mutable_data(place_, dtype_);
+    tensor->mutable_data(place_, type);
   }
-  inline framework::proto::VarType::Type DType() const { return dtype_; }
+  inline framework::proto::VarType::Type DataType() const {
+    auto tensor = var_->Get<framework::LoDTensor>();
+    return tensor.type();
+  }
+
+  // tensor type. e.g.. LoDTensor
+  inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
+  inline framework::proto::VarType::Type Type() const { return type_; }

   inline void SetStopGradient(bool stop_gradient) {
     stop_gradient_ = stop_gradient;
@@ -244,7 +258,7 @@ class VarBase {
   }

   std::string name_;
-  framework::proto::VarType::Type dtype_;
+  framework::proto::VarType::Type type_;
   platform::Place place_;

   framework::Variable* var_;
@@ -339,6 +353,8 @@ class PYBIND11_HIDDEN OpBase {
   std::vector<VarBasePtrMap> grad_output_vars_;

   std::vector<py::object> backward_hooks_;
+
+  framework::AttributeMap attrs_;
 };

 class Layer {
@@ -437,22 +453,22 @@ class PYBIND11_HIDDEN RuntimeInferVarTypeContext
   framework::proto::VarType::Type GetType(
       const std::string& name) const override {
-    return var_set_.at(name)->DType();
+    return var_set_.at(name)->Type();
   }

   void SetType(const std::string& name,
                framework::proto::VarType::Type type) override {
-    var_set_[name]->SetDType(type);
+    var_set_[name]->SetType(type);
   }

   framework::proto::VarType::Type GetDataType(
       const std::string& name) const override {
-    return var_set_.at(name)->DType();
+    return var_set_.at(name)->DataType();
   }

   void SetDataType(const std::string& name,
                    framework::proto::VarType::Type type) override {
-    var_set_[name]->SetDType(type);
+    var_set_[name]->SetDataType(type);
   }

   std::vector<framework::proto::VarType::Type> GetDataTypes(
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 166883bd6f9fe57cac40c7e0e7649ea55969abf1..0f7a2415372d938885235b35b00cd683519c797b 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -232,7 +232,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   if (info.infer_var_type_) {
     RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, &outputs,
                                                   &attrs_map);
-    info.infer_var_type_(infer_var_type_ctx);
+    info.infer_var_type_(&infer_var_type_ctx);
   }

   // TODO(minqiyang): Support infer var type in imperative mode
@@ -259,6 +259,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
     VLOG(5) << "start construct backward op";

     // construct grad op descs
+    op->attrs_ = attrs_map;
     std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
         op->Type(), invars_name_map, outvars_name_map, attrs_map));
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc
index 7dba00fffa6644c10079f0441a7482ec5078886e..2405a74d2bd8f74affe0401f7f5397f72ce2f017 100644
--- a/paddle/fluid/operators/sum_op.cc
+++ b/paddle/fluid/operators/sum_op.cc
@@ -168,11 +168,11 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     }

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
+        inputs.begin(), inputs.end(), [&ctx](const std::string& name) {
          return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR;
        });

-    auto is_tensor_array = [ctx](const std::string& name) {
+    auto is_tensor_array = [&ctx](const std::string& name) {
      return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
    };

diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 552a5e0c3289b022041c6ea4f26694ed24aa858d..5a80b785e8100f39a5bf16522e11dd6facc8d427 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -194,7 +194,7 @@ PYBIND11_MODULE(core, m) {
       .def_property("name", &imperative::VarBase::Name,
                     &imperative::VarBase::SetName)
       .def_property_readonly("shape", &imperative::VarBase::Shape)
-      .def_property_readonly("dtype", &imperative::VarBase::DType)
+      .def_property_readonly("dtype", &imperative::VarBase::DataType)
      .def_property("persistable", &imperative::VarBase::IsPersistable,
                    &imperative::VarBase::SetPersistable)
      .def_property("stop_gradient", &imperative::VarBase::IsStopGradient,
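
Note on the signature change (not part of the patch): below is a minimal, standalone C++ sketch of the pattern the diff converges on -- the registered infer-var-type callback now receives its context by pointer and call sites pass &context, while the concrete inference functor still works on a reference and its lambdas capture that reference with [&ctx]. All names here (VarTypeContext, InferVarTypeFn, MakeInferVarTypeFn, SumLikeVarTypeInference) are illustrative stand-ins, not Paddle APIs.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for framework::InferVarTypeContext.
struct VarTypeContext {
  std::unordered_map<std::string, std::string> types;
  const std::string& GetType(const std::string& name) const {
    return types.at(name);
  }
  void SetType(const std::string& name, const std::string& type) {
    types[name] = type;
  }
};

// Mirrors the new InferVarTypeFN alias: the callback takes a raw pointer.
using InferVarTypeFn = std::function<void(VarTypeContext*)>;

// Stand-in for a VarTypeInference functor such as SumOpVarTypeInference.
struct SumLikeVarTypeInference {
  void operator()(VarTypeContext& ctx) const {
    // Capture the context by reference ([&ctx]) rather than by value,
    // matching the lambda fixes in graph_test.cc and sum_op.cc; a [ctx]
    // capture would copy the whole context object.
    auto is_lod_tensor = [&ctx](const std::string& name) {
      return ctx.GetType(name) == "LOD_TENSOR";
    };
    ctx.SetType("Out", is_lod_tensor("X") ? "LOD_TENSOR" : "SELECTED_ROWS");
  }
};

// Mirrors OpInfoFiller<T, kVarTypeInference>: adapt the functor to the
// pointer-based signature by dereferencing, as in inference(*context).
template <typename T>
InferVarTypeFn MakeInferVarTypeFn() {
  return [](VarTypeContext* context) {
    T inference;
    inference(*context);
  };
}

int main() {
  VarTypeContext ctx;
  ctx.SetType("X", "LOD_TENSOR");

  InferVarTypeFn fn = MakeInferVarTypeFn<SumLikeVarTypeInference>();
  fn(&ctx);  // call sites now pass &context, as in OpDesc::InferVarType

  std::cout << "Out: " << ctx.GetType("Out") << "\n";  // prints LOD_TENSOR
  return 0;
}

Storing the erased callback as std::function<void(Context*)> keeps the signature uniform across static (OpDesc::InferVarType) and imperative (Tracer::Trace, OpBase::ApplyGrad) call sites and makes the non-owning pointer hand-off explicit, which is what the type_defs.h hunk above does for InferVarTypeFN.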