Commit 438bca9c authored by minqiyang

Implement Runtime Var Type Inference

test=develop
Parent ca392c7e
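Summary of the change: VarTypeInference implementations now receive an abstract framework::InferVarTypeContext instead of a raw (OpDesc, BlockDesc) pair, so the same type-inference rule can run against the static-graph BlockDesc at compile time and, in imperative mode, against the VarBase objects held by the tracer. The standalone sketch below only illustrates that double-backend pattern; the names Context, DescContext, RuntimeContext, Var and InferMyOpVarType are illustrative stand-ins, not Paddle APIs.

#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

// Simplified stand-in for framework::proto::VarType::Type.
enum class VarType { LOD_TENSOR, SELECTED_ROWS, READER };

// Stand-in for framework::InferVarTypeContext: inference rules talk only to
// this interface, never to OpDesc/BlockDesc directly.
class Context {
 public:
  virtual ~Context() = default;
  virtual VarType GetType(const std::string& name) const = 0;
  virtual void SetType(const std::string& name, VarType type) = 0;
};

// Compile-time backend: variable types live in a descriptor table
// (playing the role of BlockDesc/VarDesc).
class DescContext : public Context {
 public:
  VarType GetType(const std::string& name) const override {
    return table_.at(name);
  }
  void SetType(const std::string& name, VarType type) override {
    table_[name] = type;
  }

 private:
  std::unordered_map<std::string, VarType> table_{{"X", VarType::SELECTED_ROWS}};
};

// Runtime backend: variable types live on concrete objects
// (playing the role of imperative VarBase).
struct Var {
  VarType type = VarType::LOD_TENSOR;
};

class RuntimeContext : public Context {
 public:
  explicit RuntimeContext(std::unordered_map<std::string, Var*> vars)
      : vars_(std::move(vars)) {}
  VarType GetType(const std::string& name) const override {
    return vars_.at(name)->type;
  }
  void SetType(const std::string& name, VarType type) override {
    vars_.at(name)->type = type;
  }

 private:
  std::unordered_map<std::string, Var*> vars_;
};

// The op's inference rule is written once against the interface,
// e.g. "Out gets whatever type X has".
void InferMyOpVarType(Context* ctx) { ctx->SetType("Out", ctx->GetType("X")); }

int main() {
  DescContext compile_time;
  InferMyOpVarType(&compile_time);  // static-graph path

  Var x;
  x.type = VarType::SELECTED_ROWS;
  Var out;
  RuntimeContext runtime({{"X", &x}, {"Out", &out}});
  InferMyOpVarType(&runtime);  // imperative path
  std::cout << (out.type == VarType::SELECTED_ROWS) << "\n";  // prints 1
  return 0;
}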
@@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
......
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
+#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
@@ -80,6 +81,19 @@ class InferVarTypeContext {
    block_->FindRecursiveOrCreateVar(name).SetDataType(type);
  }
+  inline std::vector<proto::VarType::Type> GetDataTypes(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
+  }
+  inline void SetDataTypes(
+      const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
+  }
  inline std::vector<int64_t> GetShape(const std::string& name) const {
    PADDLE_ENFORCE_NOT_NULL(block_);
    return block_->FindRecursiveOrCreateVar(name).GetShape();
@@ -101,17 +115,11 @@ class InferVarTypeContext {
    block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
  }
- private:
+ protected:
  const OpDesc* op_;
  BlockDesc* block_;
};
-// infer var type context for imperative mode
-class RuntimeInferVarTypeContext : public InferVarTypeContext {
- public:
-  RuntimeInferVarTypeContext() : InferVarTypeContext(nullptr, nullptr) {}
-};
class VarTypeInference {
 public:
  virtual ~VarTypeInference() {}
......
@@ -220,7 +220,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
  }
  VLOG(3) << "apply op grad: " << Type();
-  std::vector<framework::VariableValueMap> tmp_grad_outputs;
+  std::vector<VarBasePtrMap> tmp_grad_outputs;
  if (backward_id_ > 0) {
    VLOG(3) << "py_layer_grad";
    tmp_grad_outputs.resize(1);
@@ -246,23 +246,59 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
        // Allocate a new variable
        Variable* tmp_var = new framework::Variable();
        tmp_var->GetMutable<framework::LoDTensor>();
-        outputs.emplace_back(tmp_var);
+        VarBase* tmp_var_base =
+            new VarBase(it.second[i]->Name(), tmp_var, nullptr, true);
+        outputs.emplace_back(tmp_var_base);
      }
    }
-    // Run grad op
-    framework::RuntimeContext ctx(grad_input_vars_[k], tmp_grad_outputs[k]);
    // No need to do compile time infer shape here.
    // grad_op_desc_->InferShape(*block_);
    // grad_op_desc->InferVarType(block_);
    std::unique_ptr<framework::OperatorBase> opbase =
        framework::OpRegistry::CreateOp(*grad_op_desc);
+    // auto& info =
+    //     framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
+    // if (info.infer_var_type_) {
+    //   framework::RuntimeInferVarTypeContext infer_var_type_ctx(
+    //       this, &grad_inputs, &outputs, &attrs_map);
+    //   info.infer_var_type_(infer_var_type_ctx);
+    // }
    framework::OperatorWithKernel* op_kernel =
        dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
    PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
+    // Run grad op
+    framework::VariableValueMap grad_invars_map;
+    framework::VariableValueMap grad_outvars_map;
+    for (const auto& it : grad_input_vars_[k]) {
+      auto& grad_invars = grad_invars_map[it.first];
+      grad_invars.reserve(it.second.size());
+      for (const VarBase* grad_inp : it.second) {
+        PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
+                                grad_op_desc->Type(), grad_inp->Name());
+        grad_invars.emplace_back(grad_inp->var_);
+      }
+    }
+    for (const auto& it : tmp_grad_outputs[k]) {
+      auto& grad_outvars = grad_outvars_map[it.first];
+      grad_outvars.reserve(it.second.size());
+      for (VarBase* grad_out : it.second) {
+        PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
+                                grad_op_desc->Type(), grad_out->Name());
+        grad_outvars.emplace_back(grad_out->var_);
+      }
+    }
+    framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
    framework::Scope scope;
    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
    p.op.RuntimeInferShape(scope, place_, ctx);
@@ -279,8 +315,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
      for (size_t i = 0; i < outputs.size(); ++i) {
-        framework::Variable* grad = outputs[i];
-        framework::Variable* orig_grad = origin_outputs[i];
+        framework::Variable* grad = outputs[i]->var_;
+        framework::Variable* orig_grad = origin_outputs[i]->var_;
        AddTo(grad, orig_grad, place_);
        delete grad;
      }
@@ -328,28 +364,35 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
int PyLayer::NumFuncs() { return py_funcs_.size(); }
-std::vector<Variable*> PyLayer::Apply(int func_id,
-                                      const std::vector<VarBase*>& inputs) {
-  std::vector<framework::Variable*> invars;
-  for (const VarBase* in : inputs) {
-    invars.push_back(in->var_);
-  }
+std::vector<framework::Variable*> PyLayer::Apply(
+    int func_id, const std::vector<VarBase*>& inputs) {
  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
-  return CallPythonFunc(py_funcs_[func_id], invars);
+  return CallPythonFunc(py_funcs_[func_id], inputs);
}
-std::vector<Variable*> PyLayer::ApplyGrad(
-    int func_id, const std::vector<framework::Variable*>& inputs) {
+std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
+                                         const std::vector<VarBase*>& inputs) {
  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
-  return CallPythonFunc(py_funcs_[func_id], inputs);
+  auto rets = CallPythonFunc(py_funcs_[func_id], inputs);
+  std::vector<VarBase*> outs;
+  outs.reserve(rets.size());
+  for (size_t i = 0U; i != rets.size(); ++i) {
+    outs.emplace_back(new VarBase(
+        string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
+                        i),
+        rets[i], nullptr, true));
+  }
+  return outs;
}
std::vector<framework::Variable*> PyLayer::CallPythonFunc(
-    const py::object& callable, const std::vector<framework::Variable*>& ins) {
+    const py::object& callable, const std::vector<VarBase*>& ins) {
  py::gil_scoped_acquire guard;
  py::tuple in_args(ins.size());
  for (size_t i = 0; i < ins.size(); ++i) {
-    const framework::LoDTensor& t = ins[i]->Get<framework::LoDTensor>();
+    const framework::LoDTensor& t = ins[i]->var_->Get<framework::LoDTensor>();
    in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr);
  }
  VLOG(3) << "pyfunc in " << py::len(in_args);
@@ -359,6 +402,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
  auto ret_tuple = py::cast<py::tuple>(ret);
  size_t ret_num = py::len(ret_tuple);
  std::vector<framework::Variable*> outs;
+  outs.reserve(ret_num);
  VLOG(3) << "pyfunc out " << ret_num;
  for (size_t i = 0; i < ret_num; ++i) {
    try {
@@ -369,7 +413,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
      auto* tensor = var->GetMutable<framework::LoDTensor>();
      tensor->ShareDataWith(*py_out_tensor);
      tensor->set_lod(py_out_tensor->lod());
-      outs.push_back(var);
+      outs.emplace_back(var);
    } catch (py::cast_error&) {
      PADDLE_THROW("The %d-th output must be LoDTensor", i);
    }
......
@@ -18,14 +18,16 @@
#include "paddle/fluid/framework/python_headers.h"
// clang-format on
#include <map>     // NOLINT
#include <string>  // NOLINT
#include <vector>  // NOLINT
#include <memory>  // NOLINT
+#include <unordered_map>  // NOLINT
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/operators/math/math_function.h"
@@ -184,6 +186,10 @@ class VarBase {
    }
  }
+  inline void SetDType(framework::proto::VarType::Type type) {
+    auto tensor = var_->GetMutable<framework::LoDTensor>();
+    tensor->mutable_data(place_, dtype_);
+  }
  inline framework::proto::VarType::Type DType() const { return dtype_; }
  inline void SetStopGradient(bool stop_gradient) {
@@ -328,9 +334,9 @@ class PYBIND11_HIDDEN OpBase {
  std::map<std::string, std::vector<int>> pre_ops_out_idx_;
  // Inputs to a vector of bwd ops.
-  std::vector<framework::VariableValueMap> grad_input_vars_;
+  std::vector<VarBasePtrMap> grad_input_vars_;
  // Outputs to a vector of bwd ops.
-  std::vector<framework::VariableValueMap> grad_output_vars_;
+  std::vector<VarBasePtrMap> grad_output_vars_;
  std::vector<py::object> backward_hooks_;
};
@@ -359,12 +365,130 @@ class PyLayer {
  static std::vector<framework::Variable*> Apply(
      int func_id, const std::vector<VarBase*>& inputs);
-  static std::vector<framework::Variable*> ApplyGrad(
-      int func_id, const std::vector<framework::Variable*>& inputs);
+  static std::vector<VarBase*> ApplyGrad(int func_id,
+                                         const std::vector<VarBase*>& inputs);
 private:
  static std::vector<framework::Variable*> CallPythonFunc(
-      const py::object& callable, const std::vector<framework::Variable*>& ins);
+      const py::object& callable, const std::vector<VarBase*>& ins);
};
+// infer var type context for imperative mode
+class PYBIND11_HIDDEN RuntimeInferVarTypeContext
+    : public framework::InferVarTypeContext {
+ public:
+  RuntimeInferVarTypeContext(imperative::OpBase* op,
+                             const imperative::VarBasePtrMap* inputs,
+                             imperative::VarBasePtrMap* outputs,
+                             const framework::AttributeMap* attrs_map)
+      : InferVarTypeContext(nullptr, nullptr),
+        op_(op),
+        inputs_(inputs),
+        outputs_(outputs),
+        attrs_(attrs_map),
+        input_names_(),
+        output_names_(),
+        var_set_() {
+    input_names_.reserve(inputs_->size());
+    for (auto& it : *inputs_) {
+      for (imperative::VarBase* var : it.second) {
+        input_names_[it.first].emplace_back(var->Name());
+        var_set_[var->Name()] = var;
+      }
+    }
+    output_names_.reserve(outputs_->size());
+    for (auto& it : *outputs_) {
+      for (imperative::VarBase* var : it.second) {
+        output_names_[it.first].emplace_back(var->Name());
+        var_set_[var->Name()] = var;
+      }
+    }
+  }
+  framework::Attribute GetAttr(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(attrs_);
+    return attrs_->at(name);
+  }
+  inline bool HasVar(const std::string& name) const {
+    return var_set_.count(name) > 0;
+  }
+  inline bool HasInput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(inputs_);
+    return inputs_->count(name) > 0;
+  }
+  inline bool HasOutput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(outputs_);
+    return outputs_->count(name) > 0;
+  }
+  inline const std::vector<std::string>& Input(const std::string& name) const {
+    return input_names_.at(name);
+  }
+  inline const std::vector<std::string>& Output(const std::string& name) const {
+    return output_names_.at(name);
+  }
+  inline framework::proto::VarType::Type GetType(
+      const std::string& name) const {
+    return var_set_.at(name)->DType();
+  }
+  inline void SetType(const std::string& name,
+                      framework::proto::VarType::Type type) {
+    var_set_[name]->SetDType(type);
+  }
+  inline framework::proto::VarType::Type GetDataType(
+      const std::string& name) const {
+    return var_set_.at(name)->DType();
+  }
+  inline void SetDataType(const std::string& name,
+                          framework::proto::VarType::Type type) {
+    var_set_[name]->SetDType(type);
+  }
+  inline std::vector<framework::proto::VarType::Type> GetDataTypes(
+      const std::string& name) const {
+    PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
+  }
+  inline void SetDataTypes(
+      const std::string& name,
+      const std::vector<framework::proto::VarType::Type>& multiple_data_type) {
+    PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
+  }
+  inline std::vector<int64_t> GetShape(const std::string& name) const {
+    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
+  }
+  inline void SetShape(const std::string& name,
+                       const std::vector<int64_t>& dims) {
+    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
+  }
+  inline int32_t GetLoDLevel(const std::string& name) const {
+    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
+  }
+  inline void SetLoDLevel(const std::string& name, int32_t lod_level) {
+    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
+  }
+ private:
+  imperative::OpBase* op_;
+  const imperative::VarBasePtrMap* inputs_;
+  imperative::VarBasePtrMap* outputs_;
+  const framework::AttributeMap* attrs_;
+  std::unordered_map<std::string, std::vector<std::string>> input_names_;
+  std::unordered_map<std::string, std::vector<std::string>> output_names_;
+  std::unordered_map<std::string, imperative::VarBase*> var_set_;
+};
}  // namespace imperative
......
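A note on the new RuntimeInferVarTypeContext above: its constructor flattens the imperative VarBasePtrMap inputs/outputs into per-slot name lists plus a single name-to-VarBase lookup table, which is what lets the name-based Input/Output/GetType/SetType interface work without any BlockDesc; queries that only make sense on a VarDesc (shape, LoD level, multiple data types) deliberately throw. The standalone sketch below shows that flattening step in isolation, using simplified stand-in types rather than the actual Paddle classes.

#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Simplified stand-in for imperative::VarBase: just a named variable.
struct VarBase {
  explicit VarBase(std::string name) : name_(std::move(name)) {}
  const std::string& Name() const { return name_; }
  std::string name_;
};

using VarBasePtrMap = std::map<std::string, std::vector<VarBase*>>;

int main() {
  VarBase x("x0"), y("y0"), out("out0");
  VarBasePtrMap inputs{{"X", {&x}}, {"Y", {&y}}};
  VarBasePtrMap outputs{{"Out", {&out}}};

  // What the RuntimeInferVarTypeContext constructor does, in miniature:
  // flatten slot -> VarBase* maps into slot -> name lists plus one
  // name -> VarBase* table, so name-based context queries can resolve.
  std::unordered_map<std::string, std::vector<std::string>> input_names;
  std::unordered_map<std::string, std::vector<std::string>> output_names;
  std::unordered_map<std::string, VarBase*> var_set;
  for (auto& it : inputs) {
    for (VarBase* var : it.second) {
      input_names[it.first].emplace_back(var->Name());
      var_set[var->Name()] = var;
    }
  }
  for (auto& it : outputs) {
    for (VarBase* var : it.second) {
      output_names[it.first].emplace_back(var->Name());
      var_set[var->Name()] = var;
    }
  }

  // Input("X") / HasVar("out0") style lookups now resolve through these maps.
  std::cout << input_names.at("X")[0] << " " << var_set.count("out0") << "\n";
  return 0;
}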
@@ -19,6 +19,7 @@
#include <unordered_map>
#include <unordered_set>
+#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
@@ -160,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
}
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
-                                    const VarBasePtrMap& outputs,
+                                    VarBasePtrMap& outputs,
                                    framework::AttributeMap attrs_map,
                                    const platform::Place expected_place,
                                    const bool stop_gradient) {
@@ -228,6 +229,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
      framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
                                      outvars_name_map, attrs_map);
+  if (info.infer_var_type_) {
+    RuntimeInferVarTypeContext infer_var_type_ctx(op, &inputs, &outputs,
+                                                  &attrs_map);
+    info.infer_var_type_(infer_var_type_ctx);
+  }
-  // TODO(minqiyang): Support infer var type in imperative mode
  // Run forward op
  VLOG(3) << "tracer running " << op->Type();
@@ -278,12 +285,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
        auto fwd_var_it = current_vars_map.find(grad_invar);
        PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
        // Forward inputs or outputs.
-        grad_in_vars.emplace_back(fwd_var_it->second->var_);
+        grad_in_vars.emplace_back(fwd_var_it->second);
      } else {
        VarBase* var = current_vars_map[var_it->second];
        InitGrad(var, prepared_op.GetDeviceContext());
        // Douts.
-        grad_in_vars.emplace_back(var->grads_->var_);
+        grad_in_vars.emplace_back(var->grads_);
      }
      vars_saved_for_backward.insert(it.first);
@@ -300,7 +307,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                       op->Type());
        VarBase* var = current_vars_map[var_it->second];
        InitGrad(var, prepared_op.GetDeviceContext());
-        grad_out_vars.push_back(var->grads_->var_);
+        grad_out_vars.push_back(var->grads_);
      }
    }
  }
@@ -342,23 +349,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
    auto& grad_output_vars =
        op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
-    for (const VarBase* inp : inputs) {
-      grad_input_vars.push_back(inp->var_);
+    for (VarBase* inp : inputs) {
+      grad_input_vars.push_back(inp);
    }
    for (VarBase* out : outputs) {
-      grad_input_vars.push_back(out->var_);
+      grad_input_vars.push_back(out);
    }
    // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
    platform::CPUPlace place;
    for (VarBase* out : outputs) {
      InitGrad(out, platform::DeviceContextPool::Instance().Get(place));
-      grad_input_vars.push_back(out->grads_->var_);
+      grad_input_vars.push_back(out->grads_);
    }
    for (VarBase* inp : inputs) {
      InitGrad(inp, platform::DeviceContextPool::Instance().Get(place));
-      grad_output_vars.push_back(inp->grads_->var_);
+      grad_output_vars.push_back(inp->grads_);
    }
  }
  return outputs;
......
@@ -25,6 +25,7 @@ class VarBase;
class OpBase;
typedef std::map<std::string, std::vector<VarBase*>> VarBasePtrMap;
+typedef std::map<std::string, std::vector<const VarBase*>> ConstVarBasePtrMap;
typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
}  // namespace imperative
......
@@ -365,19 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
class WhileGradOpVarTypeInference : public framework::VarTypeInference {
 public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto p_names = op_desc.Input(kX);
-    auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto p_names = ctx.Input(kX);
+    auto pg_ig_names = ctx.Output(framework::GradVarName(kX));
    for (size_t i = 0; i < p_names.size(); ++i) {
-      auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
-      auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
-      if (g_var != nullptr) {  // Gradient could be @EMPTY@
+      if (ctx.HasVar(pg_ig_names[i])) {
        VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
-                << " type: " << p_var.GetType();
-        g_var->SetType(p_var.GetType());
-        g_var->SetDataType(p_var.GetDataType());
+                << " type: " << ctx.GetType(p_names[i]);
+        ctx.SetType(pg_ig_names[i], ctx.GetType(p_names[i]));
+        ctx.SetDataType(pg_ig_names[i], ctx.GetDataType(p_names[i]));
      }
    }
  }
......
@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
+#include <memory>
namespace paddle {
namespace operators {
......
@@ -60,12 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
 public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto out_var_name = op_desc.Output("Communicator").front();
-    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto out_var_name = ctx.Output("Communicator").front();
+    ctx.SetType(out_var_name, framework::proto::VarType::RAW);
  }
};
......
@@ -14,8 +14,11 @@
#include "paddle/fluid/operators/py_func_op.h"
+#include <memory>
#include <set>
#include <string>
+#include <unordered_set>
+#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......
@@ -123,7 +123,7 @@ class CustomReaderInferShape : public framework::InferShapeBase {
class CustomReaderInferVarType : public framework::VarTypeInference {
 public:
-  void operator()(const framework::InferVarTypeContext& ctx) const override {
+  void operator()(framework::InferVarTypeContext& ctx) const override {
    auto& out_var_name = ctx.Output("Out")[0];
    PADDLE_ENFORCE(ctx.HasVar(out_var_name));
    ctx.SetType(out_var_name, framework::proto::VarType::READER);
......
@@ -51,7 +51,7 @@ class ReadInferShape : public framework::InferShapeBase {
class ReadInferVarType : public framework::VarTypeInference {
 public:
-  void operator()(const framework::InferVarTypeContext& ctx) const override {
+  void operator()(framework::InferVarTypeContext& ctx) const override {
    bool infer_out = boost::get<bool>(ctx.GetAttr("infer_out"));
    if (infer_out) {
      std::string reader_name = ctx.Input("Reader")[0];
......
@@ -98,11 +98,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
  }
}
-void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
-                                        framework::BlockDesc* block) const {
-  std::string reader_name = op_desc.Output("Out")[0];
-  framework::VarDesc* reader = block->FindVarRecursive(reader_name);
-  reader->SetType(framework::proto::VarType::READER);
+void FileReaderInferVarType::operator()(
+    framework::InferVarTypeContext& ctx) const {
+  std::string reader_name = ctx.Output("Out")[0];
+  ctx.SetType(reader_name, framework::proto::VarType::READER);
}
void DecoratedReaderInferShape::operator()(
@@ -125,13 +124,11 @@ void DecoratedReaderInferShape::operator()(
}
void DecoratedReaderInferVarType::operator()(
-    const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
-  std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
-  framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
-  std::string out_reader_name = op_desc.Output("Out")[0];
-  framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
-  out_reader->SetType(framework::proto::VarType::READER);
-  out_reader->SetDataTypes(in_reader->GetDataTypes());
+    framework::InferVarTypeContext& ctx) const {
+  const std::string& in_reader_name = ctx.Input("UnderlyingReader")[0];
+  const std::string& out_reader_name = ctx.Output("Out")[0];
+  ctx.SetType(out_reader_name, framework::proto::VarType::READER);
+  ctx.SetDataTypes(out_reader_name, ctx.GetDataTypes(in_reader_name));
}
void DecoratedReaderMakerBase::Make() {
......
@@ -14,7 +14,9 @@
#pragma once
+#include <memory>
#include <string>
+#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
......
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h"
+#include <memory>
#include <string>
#include "paddle/fluid/operators/detail/safe_ref.h"
......
@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/split_selected_rows_op.h"
+#include <memory>
namespace paddle {
namespace operators {
......
@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include <algorithm>
+#include <memory>
#include <string>
#include <vector>
......
@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
      .def("trace",
           [](imperative::Tracer& self, imperative::OpBase* op,
              const imperative::VarBasePtrMap& inputs,
-             const imperative::VarBasePtrMap& outputs,
+             imperative::VarBasePtrMap& outputs,
              framework::AttributeMap attrs_map,
              const platform::CPUPlace expected_place,
              const bool stop_gradient = false) {
@@ -48,7 +48,7 @@ void BindTracer(pybind11::module* m) {
      .def("trace",
           [](imperative::Tracer& self, imperative::OpBase* op,
              const imperative::VarBasePtrMap& inputs,
-             const imperative::VarBasePtrMap& outputs,
+             imperative::VarBasePtrMap& outputs,
              framework::AttributeMap attrs_map,
              const platform::CUDAPlace expected_place,
              const bool stop_gradient = false) {
......