未验证 提交 c7f1f3ed 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #16214 from velconia/imperative_infer_var_type

Implement imperative infer var type
......@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
class DummyVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
auto type = block->Var(inputs.front())->GetType();
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(type);
void operator()(framework::InferVarTypeContext* ctx) const override {
auto& inputs = ctx->Input("X");
auto type = ctx->GetType(inputs.front());
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, type);
}
};
......
......@@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
......@@ -127,9 +129,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
template <typename T>
struct OpInfoFiller<T, kVarTypeInference> {
void operator()(const char* op_type, OpInfo* info) const {
info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) {
info->infer_var_type_ = [](InferVarTypeContext* context) {
T inference;
inference(fwd_op, block);
inference(context);
};
}
};
......
......@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X");
void operator()(InferVarTypeContext *ctx) const override {
auto &inputs = ctx->Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, default_var_type);
}
};
......@@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
class DummyOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {}
void operator()(framework::InferVarTypeContext *ctx) const override {}
};
} // namespace framework
} // namespace paddle
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
namespace framework {
......@@ -677,7 +678,8 @@ void OpDesc::InferVarType(BlockDesc *block) const {
// var type inference. Hence, we don't do any "default" setting here.
auto &info = OpInfoMap::Instance().Get(this->Type());
if (info.infer_var_type_) {
info.infer_var_type_(*this, block);
InferVarTypeContext context(this, block);
info.infer_var_type_(&context);
}
}
......
......@@ -27,6 +27,7 @@ namespace framework {
class OperatorBase;
class OpDesc;
class InferShapeContext;
class InferVarTypeContext;
class BlockDesc;
class Variable;
......@@ -53,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
const std::vector<BlockDesc*>& grad_block)>;
using InferVarTypeFN =
std::function<void(const OpDesc& /*op_desc*/, BlockDesc* /*block*/)>;
std::function<void(framework::InferVarTypeContext* /*context*/)>;
using InferShapeFN = std::function<void(InferShapeContext*)>;
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"
......@@ -21,26 +23,123 @@ limitations under the License. */
namespace paddle {
namespace framework {
class OpDesc;
class BlockDesc;
// default infer var type context
class InferVarTypeContext {
public:
InferVarTypeContext(const OpDesc* op, BlockDesc* block)
: op_(op), block_(block) {}
virtual ~InferVarTypeContext() {}
virtual Attribute GetAttr(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->GetAttr(name);
}
virtual bool HasVar(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindVarRecursive(name) != nullptr;
}
virtual bool HasInput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Inputs().count(name) > 0;
}
virtual bool HasOutput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Outputs().count(name) > 0;
}
virtual const std::vector<std::string>& Input(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Input(name);
}
virtual const std::vector<std::string>& Output(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
return op_->Output(name);
}
virtual proto::VarType::Type GetType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetType();
}
virtual void SetType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetType(type);
}
virtual proto::VarType::Type GetDataType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetDataType();
}
virtual void SetDataType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetDataType(type);
}
virtual std::vector<proto::VarType::Type> GetDataTypes(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
}
virtual void SetDataTypes(
const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
}
virtual std::vector<int64_t> GetShape(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetShape();
}
virtual void SetShape(const std::string& name,
const std::vector<int64_t>& dims) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetShape(dims);
}
virtual int32_t GetLoDLevel(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
}
virtual void SetLoDLevel(const std::string& name, int32_t lod_level) {
PADDLE_ENFORCE_NOT_NULL(block_);
block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
}
protected:
const OpDesc* op_;
BlockDesc* block_;
};
class VarTypeInference {
public:
virtual ~VarTypeInference() {}
virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
virtual void operator()(InferVarTypeContext* context) const = 0; // NOLINT
};
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const final {
void operator()(framework::InferVarTypeContext* ctx) const final { // NOLINT
auto in_out_var_names = this->GetInputOutputWithSameType();
for (auto& i_o_n : in_out_var_names) {
auto& x_name = op_desc.Input(i_o_n.first).at(0);
auto& out_name = op_desc.Output(i_o_n.second).at(0);
auto& x_name = ctx->Input(i_o_n.first).at(0);
auto& out_name = ctx->Output(i_o_n.second).at(0);
auto& x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType());
out.SetDataType(x.GetDataType());
ctx->SetType(out_name, ctx->GetType(x_name));
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
}
}
......
......@@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
auto &inputs = op_desc.Input("X");
void operator()(framework::InferVarTypeContext *ctx) const override {
auto &inputs = ctx->Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetType(default_var_type);
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, default_var_type);
}
};
} // namespace framework
......
......@@ -218,7 +218,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
"%s has no backward implementation", Type());
VLOG(3) << "apply op grad: " << Type();
std::vector<framework::VariableValueMap> tmp_grad_outputs;
std::vector<VarBasePtrMap> tmp_grad_outputs;
if (backward_id_ > 0) {
VLOG(3) << "py_layer_grad";
tmp_grad_outputs.resize(1);
......@@ -241,26 +241,62 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
auto& outputs = tmp_grad_outputs[k][it.first];
outputs.reserve(it.second.size());
for (size_t i = 0; i < it.second.size(); ++i) {
VarBase* origin_grad_var_base = it.second[i];
// Allocate a new variable
Variable* tmp_var = new framework::Variable();
tmp_var->GetMutable<framework::LoDTensor>();
outputs.emplace_back(tmp_var);
VarBase* tmp_grad_var_base = new VarBase(
string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
place_, true, false);
outputs.emplace_back(tmp_grad_var_base);
}
}
// Run grad op
framework::RuntimeContext ctx(grad_input_vars_[k], tmp_grad_outputs[k]);
// No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_);
// grad_op_desc->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase =
framework::OpRegistry::CreateOp(*grad_op_desc);
auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx(
&grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
info.infer_var_type_(&infer_var_type_ctx);
}
framework::OperatorWithKernel* op_kernel =
dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
// Run grad op
framework::VariableValueMap grad_invars_map;
framework::VariableValueMap grad_outvars_map;
for (const auto& it : grad_input_vars_[k]) {
auto& grad_invars = grad_invars_map[it.first];
grad_invars.reserve(it.second.size());
for (const VarBase* grad_inp : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
grad_op_desc->Type(), grad_inp->Name());
grad_invars.emplace_back(grad_inp->var_);
}
}
for (const auto& it : tmp_grad_outputs[k]) {
auto& grad_outvars = grad_outvars_map[it.first];
grad_outvars.reserve(it.second.size());
for (VarBase* grad_out : it.second) {
PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
grad_op_desc->Type(), grad_out->Name());
grad_outvars.emplace_back(grad_out->var_);
}
}
framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
framework::Scope scope;
PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
p.op.RuntimeInferShape(scope, place_, ctx);
......@@ -277,8 +313,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* grad = outputs[i];
framework::Variable* orig_grad = origin_outputs[i];
framework::Variable* grad = outputs[i]->var_;
framework::Variable* orig_grad = origin_outputs[i]->var_;
AddTo(grad, orig_grad, place_);
delete grad;
}
......@@ -326,28 +362,35 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
int PyLayer::NumFuncs() { return py_funcs_.size(); }
std::vector<Variable*> PyLayer::Apply(int func_id,
const std::vector<VarBase*>& inputs) {
std::vector<framework::Variable*> invars;
for (const VarBase* in : inputs) {
invars.push_back(in->var_);
}
std::vector<framework::Variable*> PyLayer::Apply(
int func_id, const std::vector<VarBase*>& inputs) {
PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
return CallPythonFunc(py_funcs_[func_id], invars);
return CallPythonFunc(py_funcs_[func_id], inputs);
}
std::vector<Variable*> PyLayer::ApplyGrad(
int func_id, const std::vector<framework::Variable*>& inputs) {
std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
const std::vector<VarBase*>& inputs) {
PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
return CallPythonFunc(py_funcs_[func_id], inputs);
auto rets = CallPythonFunc(py_funcs_[func_id], inputs);
std::vector<VarBase*> outs;
outs.reserve(rets.size());
for (size_t i = 0U; i != rets.size(); ++i) {
outs.emplace_back(new VarBase(
string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
i),
rets[i], nullptr, true));
}
return outs;
}
std::vector<framework::Variable*> PyLayer::CallPythonFunc(
const py::object& callable, const std::vector<framework::Variable*>& ins) {
const py::object& callable, const std::vector<VarBase*>& ins) {
py::gil_scoped_acquire guard;
py::tuple in_args(ins.size());
for (size_t i = 0; i < ins.size(); ++i) {
const framework::LoDTensor& t = ins[i]->Get<framework::LoDTensor>();
const framework::LoDTensor& t = ins[i]->var_->Get<framework::LoDTensor>();
in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr);
}
VLOG(3) << "pyfunc in " << py::len(in_args);
......@@ -357,6 +400,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
auto ret_tuple = py::cast<py::tuple>(ret);
size_t ret_num = py::len(ret_tuple);
std::vector<framework::Variable*> outs;
outs.reserve(ret_num);
VLOG(3) << "pyfunc out " << ret_num;
for (size_t i = 0; i < ret_num; ++i) {
try {
......@@ -367,7 +411,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(*py_out_tensor);
tensor->set_lod(py_out_tensor->lod());
outs.push_back(var);
outs.emplace_back(var);
} catch (py::cast_error&) {
PADDLE_THROW("The %d-th output must be LoDTensor", i);
}
......
......@@ -22,10 +22,12 @@
#include <string> // NOLINT
#include <vector> // NOLINT
#include <memory> // NOLINT
#include <unordered_map> // NOLINT
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -135,13 +137,13 @@ class VarBase {
persistable) {}
private:
// TODO(minqiyang): need support SelectedRows
VarBase(const std::string& name, framework::proto::VarType::Type dtype,
const framework::DDim& shape, const platform::Place& place,
framework::Variable* var, VarBase* grad, bool stop_gradient,
bool persistable)
: name_(name),
dtype_(dtype),
place_(place),
type_(framework::proto::VarType::LOD_TENSOR),
var_(var),
grads_(grad),
stop_gradient_(stop_gradient),
......@@ -151,10 +153,12 @@ class VarBase {
pre_op_out_idx_(-1) {
if (!var_) {
var_ = new framework::Variable();
}
auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->Resize(shape);
tensor->mutable_data(place_, dtype_);
}
tensor->mutable_data(place, dtype);
VLOG(10) << "create varbase: " << name_ << " type: " << dtype
<< " place: " << place;
}
public:
......@@ -184,7 +188,23 @@ class VarBase {
}
}
inline framework::proto::VarType::Type DType() const { return dtype_; }
inline framework::DDim Dims() const {
return var_->Get<framework::LoDTensor>().dims();
}
// data type. e.g.. FP32
inline void SetDataType(framework::proto::VarType::Type type) {
auto tensor = var_->GetMutable<framework::LoDTensor>();
tensor->mutable_data(tensor->place(), type);
}
inline framework::proto::VarType::Type DataType() const {
auto tensor = var_->Get<framework::LoDTensor>();
return tensor.type();
}
// tensor type. e.g.. LoDTensor
inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
inline framework::proto::VarType::Type Type() const { return type_; }
inline void SetStopGradient(bool stop_gradient) {
stop_gradient_ = stop_gradient;
......@@ -238,7 +258,7 @@ class VarBase {
}
std::string name_;
framework::proto::VarType::Type dtype_;
framework::proto::VarType::Type type_;
platform::Place place_;
framework::Variable* var_;
......@@ -334,11 +354,13 @@ class PYBIND11_HIDDEN OpBase {
std::map<std::string, std::vector<int>> pre_ops_out_idx_;
// Inputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_input_vars_;
std::vector<VarBasePtrMap> grad_input_vars_;
// Outputs to a vector of bwd ops.
std::vector<framework::VariableValueMap> grad_output_vars_;
std::vector<VarBasePtrMap> grad_output_vars_;
std::vector<py::object> backward_hooks_;
framework::AttributeMap attrs_;
};
class Layer {
......@@ -365,12 +387,131 @@ class PyLayer {
static std::vector<framework::Variable*> Apply(
int func_id, const std::vector<VarBase*>& inputs);
static std::vector<framework::Variable*> ApplyGrad(
int func_id, const std::vector<framework::Variable*>& inputs);
static std::vector<VarBase*> ApplyGrad(int func_id,
const std::vector<VarBase*>& inputs);
private:
static std::vector<framework::Variable*> CallPythonFunc(
const py::object& callable, const std::vector<framework::Variable*>& ins);
const py::object& callable, const std::vector<VarBase*>& ins);
};
// infer var type context for imperative mode
class PYBIND11_HIDDEN RuntimeInferVarTypeContext
: public framework::InferVarTypeContext {
public:
RuntimeInferVarTypeContext(const imperative::VarBasePtrMap* inputs,
imperative::VarBasePtrMap* outputs,
const framework::AttributeMap* attrs_map)
: InferVarTypeContext(nullptr, nullptr),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs_map),
input_names_(),
output_names_(),
var_set_() {
input_names_.reserve(inputs_->size());
for (auto& it : *inputs_) {
for (imperative::VarBase* var : it.second) {
input_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var;
}
}
output_names_.reserve(outputs_->size());
for (auto& it : *outputs_) {
for (imperative::VarBase* var : it.second) {
output_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var;
}
}
}
virtual ~RuntimeInferVarTypeContext() {}
framework::Attribute GetAttr(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(attrs_);
return attrs_->at(name);
}
bool HasVar(const std::string& name) const override {
return var_set_.count(name) > 0;
}
bool HasInput(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(inputs_);
return inputs_->count(name) > 0;
}
bool HasOutput(const std::string& name) const override {
PADDLE_ENFORCE_NOT_NULL(outputs_);
return outputs_->count(name) > 0;
}
const std::vector<std::string>& Input(
const std::string& name) const override {
return input_names_.at(name);
}
const std::vector<std::string>& Output(
const std::string& name) const override {
return output_names_.at(name);
}
framework::proto::VarType::Type GetType(
const std::string& name) const override {
return var_set_.at(name)->Type();
}
void SetType(const std::string& name,
framework::proto::VarType::Type type) override {
var_set_[name]->SetType(type);
}
framework::proto::VarType::Type GetDataType(
const std::string& name) const override {
return var_set_.at(name)->DataType();
}
void SetDataType(const std::string& name,
framework::proto::VarType::Type type) override {
var_set_[name]->SetDataType(type);
}
std::vector<framework::proto::VarType::Type> GetDataTypes(
const std::string& name) const override {
PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
}
void SetDataTypes(const std::string& name,
const std::vector<framework::proto::VarType::Type>&
multiple_data_type) override {
PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
}
std::vector<int64_t> GetShape(const std::string& name) const override {
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
}
void SetShape(const std::string& name,
const std::vector<int64_t>& dims) override {
PADDLE_THROW("Do not handle Shape in runtime InferVarType");
}
int32_t GetLoDLevel(const std::string& name) const override {
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
}
void SetLoDLevel(const std::string& name, int32_t lod_level) override {
PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
}
private:
const imperative::VarBasePtrMap* inputs_;
imperative::VarBasePtrMap* outputs_;
const framework::AttributeMap* attrs_;
std::unordered_map<std::string, std::vector<std::string>> input_names_;
std::unordered_map<std::string, std::vector<std::string>> output_names_;
std::unordered_map<std::string, imperative::VarBase*> var_set_;
};
} // namespace imperative
......
......@@ -19,6 +19,7 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -135,7 +136,7 @@ framework::VariableNameMap CreateOutputVarNameMap(
Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs,
VarBasePtrMap* outputs,
framework::AttributeMap attrs_map,
const platform::Place expected_place,
const bool stop_gradient) {
......@@ -163,7 +164,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->TrackPreOp(it.first, it.second);
}
op->output_vars_ = outputs;
op->output_vars_ = *outputs;
for (auto it : op->output_vars_) {
auto& outvars = outvars_map[it.first];
const std::vector<VarBase*>& outputs = it.second;
......@@ -186,7 +187,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::VariableNameMap invars_name_map =
CreateInputVarNameMap(op, inputs);
framework::VariableNameMap outvars_name_map =
CreateOutputVarNameMap(op, outputs);
CreateOutputVarNameMap(op, *outputs);
auto& info = framework::OpInfoMap::Instance().Get(op->Type());
if (info.Checker() != nullptr) {
......@@ -197,6 +198,11 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
outvars_name_map, attrs_map);
if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, outputs, &attrs_map);
info.infer_var_type_(&infer_var_type_ctx);
}
// TODO(minqiyang): Support infer var type in imperative mode
// Run forward op
VLOG(3) << "tracer running " << op->Type();
......@@ -221,6 +227,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
VLOG(5) << "start construct backward op";
// construct grad op descs
op->attrs_ = attrs_map;
std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
op->Type(), invars_name_map, outvars_name_map, attrs_map));
std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
......@@ -247,12 +254,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
auto fwd_var_it = current_vars_map.find(grad_invar);
PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
// Forward inputs or outputs.
grad_in_vars.emplace_back(fwd_var_it->second->var_);
grad_in_vars.emplace_back(fwd_var_it->second);
} else {
VarBase* var = current_vars_map[var_it->second];
InitGrad(var, prepared_op.GetDeviceContext());
// Douts.
grad_in_vars.emplace_back(var->grads_->var_);
grad_in_vars.emplace_back(var->grads_);
}
vars_saved_for_backward.insert(it.first);
......@@ -269,7 +276,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op->Type());
VarBase* var = current_vars_map[var_it->second];
InitGrad(var, prepared_op.GetDeviceContext());
grad_out_vars.push_back(var->grads_->var_);
grad_out_vars.push_back(var->grads_);
}
}
}
......@@ -309,23 +316,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
auto& grad_output_vars =
op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
for (const VarBase* inp : inputs) {
grad_input_vars.push_back(inp->var_);
for (VarBase* inp : inputs) {
grad_input_vars.push_back(inp);
}
for (VarBase* out : outputs) {
grad_input_vars.push_back(out->var_);
grad_input_vars.push_back(out);
}
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
platform::CPUPlace place;
for (VarBase* out : outputs) {
InitGrad(out, platform::DeviceContextPool::Instance().Get(place));
grad_input_vars.push_back(out->grads_->var_);
grad_input_vars.push_back(out->grads_);
}
for (VarBase* inp : inputs) {
InitGrad(inp, platform::DeviceContextPool::Instance().Get(place));
grad_output_vars.push_back(inp->grads_->var_);
grad_output_vars.push_back(inp->grads_);
}
}
return outputs;
......
......@@ -48,7 +48,7 @@ class Tracer {
virtual ~Tracer() {}
std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
const VarBasePtrMap& outputs,
VarBasePtrMap* outputs, // NOLINT
framework::AttributeMap attrs_map,
const platform::Place expected_place,
const bool stop_gradient = false);
......
......@@ -25,6 +25,7 @@ class VarBase;
class OpBase;
typedef std::map<std::string, std::vector<VarBase*>> VarBasePtrMap;
typedef std::map<std::string, std::vector<const VarBase*>> ConstVarBasePtrMap;
typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
} // namespace imperative
......
......@@ -203,15 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) {
auto& sentence_ids = block->FindRecursiveOrCreateVar(o);
sentence_ids.SetType(framework::proto::VarType::LOD_TENSOR);
void operator()(framework::InferVarTypeContext* ctx) const override {
for (auto& o : ctx->Output("SentenceIds")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
for (auto& o : op_desc.Output("SentenceScores")) {
auto& sentence_scores = block->FindRecursiveOrCreateVar(o);
sentence_scores.SetType(framework::proto::VarType::LOD_TENSOR);
for (auto& o : ctx->Output("SentenceScores")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
}
};
......
......@@ -120,15 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel {
class BeamSearchInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("selected_ids")) {
auto &selected_ids = block->FindRecursiveOrCreateVar(o);
selected_ids.SetType(framework::proto::VarType::LOD_TENSOR);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o : ctx->Output("selected_ids")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
for (auto &o : op_desc.Output("selected_scores")) {
auto &selected_scores = block->FindRecursiveOrCreateVar(o);
selected_scores.SetType(framework::proto::VarType::LOD_TENSOR);
for (auto &o : ctx->Output("selected_scores")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
}
};
......
......@@ -93,11 +93,9 @@ execution.
class GetPlacesInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o_name : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o_name).SetType(
framework::proto::VarType::PLACE_LIST);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o_name : ctx->Output("Out")) {
ctx->SetType(o_name, framework::proto::VarType::PLACE_LIST);
}
}
};
......
......@@ -100,16 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
class WriteToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto x_name = op_desc.Input("X")[0];
auto out_name = op_desc.Output("Out")[0];
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = ctx->Input("X")[0];
auto out_name = ctx->Output("Out")[0];
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
auto *x = block->FindVarRecursive(x_name);
if (x != nullptr) {
out.SetDataType(x->GetDataType());
ctx->SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
if (ctx->HasVar(x_name)) {
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
}
}
};
......
......@@ -365,19 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
class WhileGradOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto p_names = op_desc.Input(kX);
auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));
void operator()(framework::InferVarTypeContext *ctx) const override {
auto p_names = ctx->Input(kX);
auto pg_ig_names = ctx->Output(framework::GradVarName(kX));
for (size_t i = 0; i < p_names.size(); ++i) {
auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
if (g_var != nullptr) { // Gradient could be @EMPTY@
if (ctx->HasVar(pg_ig_names[i])) {
VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
<< " type: " << p_var.GetType();
g_var->SetType(p_var.GetType());
g_var->SetDataType(p_var.GetDataType());
<< " type: " << ctx->GetType(p_names[i]);
ctx->SetType(pg_ig_names[i], ctx->GetType(p_names[i]));
ctx->SetDataType(pg_ig_names[i], ctx->GetDataType(p_names[i]));
}
}
}
......
......@@ -56,8 +56,7 @@ class FakeInitOp : public framework::OperatorBase {
class FakeInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {}
void operator()(framework::InferVarTypeContext *ctx) const override {}
};
class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -114,11 +114,10 @@ class MergeIdsOp : public framework::OperatorWithKernel {
class MergeIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(input_var->GetType());
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, input_type);
}
}
};
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -71,11 +73,10 @@ class SplitIdsOp : public framework::OperatorWithKernel {
class SplitIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(input_var->GetType());
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, input_type);
}
}
};
......
......@@ -39,12 +39,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
class FillConstantOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(op_desc.GetAttr("dtype")));
auto& out_var_name = op_desc.Output("Out").front();
block->Var(out_var_name)->SetDataType(data_type);
boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
}
};
......
......@@ -138,22 +138,20 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
class FusedEmbeddingSeqPoolOpGradVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
auto attr = op_desc.GetAttr("is_sparse");
void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to SelectedRows";
block->Var(out_var_name)
->SetType(framework::proto::VarType::SELECTED_ROWS);
ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
}
block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
}
};
......
......@@ -81,15 +81,12 @@ GetTensorFromSelectedRows is used to get the tensor from SelectedRows.
class GetTensorFromSelectedRowsOpVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const final {
auto out_var_name = op_desc.Output("Out").front();
auto in_var_name = op_desc.Input("X").front();
auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto in_var = block->FindRecursiveOrCreateVar(in_var_name);
out_var.SetType(framework::proto::VarType::LOD_TENSOR);
out_var.SetDataType(in_var.GetDataType());
void operator()(framework::InferVarTypeContext *ctx) const { // NOLINT
auto out_var_name = ctx->Output("Out").front();
auto in_var_name = ctx->Input("X").front();
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
}
};
......
......@@ -197,38 +197,32 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
class HierarchicalSigmoidGradOpGradVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front();
auto bias_grad_var_name_vec =
op_desc.Output(framework::GradVarName("Bias"));
void operator()(framework::InferVarTypeContext* ctx) const override {
auto w_grad_var_name = ctx->Output(framework::GradVarName("W")).front();
auto bias_grad_var_name_vec = ctx->Output(framework::GradVarName("Bias"));
std::string bias_grad_var_name;
bool hasBias = false;
if (bias_grad_var_name_vec.size()) {
hasBias = true;
bias_grad_var_name =
op_desc.Output(framework::GradVarName("Bias")).front();
bias_grad_var_name = ctx->Output(framework::GradVarName("Bias")).front();
}
auto attr = op_desc.GetAttr("is_sparse");
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
block->Var(w_grad_var_name)
->SetType(framework::proto::VarType::SELECTED_ROWS);
ctx->SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
block->Var(w_grad_var_name)
->SetType(framework::proto::VarType::LOD_TENSOR);
ctx->SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
}
if (hasBias) {
VLOG(30) << "hierarchical_sigmoid_grad op "
<< framework::GradVarName("Bias") << " is set to LoDTensor";
block->Var(bias_grad_var_name)
->SetType(framework::proto::VarType::LOD_TENSOR);
ctx->SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
}
block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
ctx->SetDataType(w_grad_var_name, ctx->GetDataType(ctx->Input("W")[0]));
}
};
......
......@@ -64,11 +64,9 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
class LoDRankTableInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o).SetType(
framework::proto::VarType::LOD_RANK_TABLE);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o : ctx->Output("Out")) {
ctx->SetType(o, framework::proto::VarType::LOD_RANK_TABLE);
}
}
};
......
......@@ -201,10 +201,9 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
}
}
};
......
......@@ -147,22 +147,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
auto attr = op_desc.GetAttr("is_sparse");
void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
block->Var(out_var_name)
->SetType(framework::proto::VarType::SELECTED_ROWS);
ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
}
block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
}
};
......
......@@ -60,12 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Communicator").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Communicator").front();
ctx->SetType(out_var_name, framework::proto::VarType::RAW);
}
};
......
......@@ -237,23 +237,21 @@ class NCEOpGrad : public framework::OperatorWithKernel {
class NCEOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front();
void operator()(framework::InferVarTypeContext *ctx) const override {
auto weight_grad = ctx->Output(framework::GradVarName("Weight")).front();
auto attr = op_desc.GetAttr("is_sparse");
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "nce_op_grad op " << weight_grad << " and "
<< " is set to SelectedRows";
block->Var(weight_grad)
->SetType(framework::proto::VarType::SELECTED_ROWS);
ctx->SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "nce_op_grad op " << weight_grad << " and "
<< " is set to LoDTensor";
block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
ctx->SetType(weight_grad, framework::proto::VarType::LOD_TENSOR);
}
block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType());
ctx->SetDataType(weight_grad, ctx->GetDataType(ctx->Input("Input")[0]));
}
};
......
......@@ -37,8 +37,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
class NgraphEngineInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {}
void operator()(framework::InferVarTypeContext *ctx) const override {}
};
} // namespace operators
......
......@@ -72,8 +72,7 @@ use L2 regularizers in case of using LARS.
class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {}
void operator()(framework::InferVarTypeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
......
......@@ -21,18 +21,14 @@ using Tensor = framework::Tensor;
class MomentumOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto input_var = op_desc.Input("Param")[0];
for (auto& out_var : op_desc.Output("ParamOut")) {
if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::SELECTED_ROWS);
} else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
void operator()(framework::InferVarTypeContext* ctx) const override {
auto& input_var = ctx->Input("Param")[0];
for (auto& out_var : ctx->Output("ParamOut")) {
if (ctx->GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) {
ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
} else if (ctx->GetType(input_var) ==
framework::proto::VarType::LOD_TENSOR) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::LOD_TENSOR);
ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR);
} else {
PADDLE_THROW(
"Only support LodTensor and SelectedRows, Unexpected Input Type.");
......
......@@ -50,20 +50,18 @@ class SGDOp : public framework::OperatorWithKernel {
class SGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto input_var_n = op_desc.Input("Param")[0];
auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType();
void operator()(framework::InferVarTypeContext *ctx) const override {
auto &input_var_n = ctx->Input("Param")[0];
auto in_var_type = ctx->GetType(input_var_n);
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s",
input_var_n, in_var_type);
for (auto &out_var_n : op_desc.Output("ParamOut")) {
auto &out_var = block->FindRecursiveOrCreateVar(out_var_n);
if (out_var.GetType() != in_var_type) {
out_var.SetType(in_var_type);
for (auto &out_var_n : ctx->Output("ParamOut")) {
if (ctx->GetType(out_var_n) != in_var_type) {
ctx->SetType(out_var_n, in_var_type);
}
}
}
......
......@@ -14,8 +14,11 @@
#include "paddle/fluid/operators/py_func_op.h"
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......@@ -91,15 +94,12 @@ static void CallPythonFunc(py::object *callable,
}
}
class PyFuncOpVarTypInference : public framework::VarTypeInference {
class PyFuncOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op,
framework::BlockDesc *block) const override {
auto &outs = op.Outputs();
bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty());
void operator()(framework::InferVarTypeContext *ctx) const override {
bool has_out = (ctx->HasOutput("Out") && !ctx->Output("Out").empty());
auto &ins = op.Inputs();
bool has_in = (ins.count("X") > 0 && !ins.at("X").empty());
bool has_in = (ctx->HasInput("X") && !ctx->Input("X").empty());
/**
* X or Out can be empty, so that py_func can be more flexible
......@@ -107,8 +107,8 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference {
*/
PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");
PADDLE_ENFORCE_GE(boost::get<int>(op.GetAttr(kForwardPythonCallableId)), 0,
"Function id cannot be less than 0");
PADDLE_ENFORCE_GE(boost::get<int>(ctx->GetAttr(kForwardPythonCallableId)),
0, "Function id cannot be less than 0");
if (!has_out) return;
......@@ -118,7 +118,7 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference {
* the corresponding forward variable
*/
const std::string kGradVarSuffix = framework::kGradVarSuffix;
auto &out_var_names = outs.at("Out");
auto &out_var_names = ctx->Output("Out");
for (auto &out_var_name : out_var_names) {
if (out_var_name == framework::kEmptyVarName ||
out_var_name.size() < kGradVarSuffix.size()) {
......@@ -128,18 +128,17 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference {
size_t len = out_var_name.size() - kGradVarSuffix.size();
if (out_var_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_var_name.substr(0, len);
auto *out_var_desc = block->FindVarRecursive(out_var_name);
auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name);
PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found",
out_var_name);
PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found",
fwd_var_name);
PADDLE_ENFORCE(ctx->HasVar(out_var_name),
"Backward variable %s not found", out_var_name);
PADDLE_ENFORCE(ctx->HasVar(fwd_var_name),
"Backward variable %s not found", fwd_var_name);
VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
<< fwd_var_name << ")";
out_var_desc->SetShape(fwd_var_desc->GetShape());
out_var_desc->SetDataType(fwd_var_desc->GetDataType());
out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel());
out_var_desc->SetType(fwd_var_desc->GetType());
ctx->SetShape(out_var_name, ctx->GetShape(fwd_var_name));
ctx->SetDataType(out_var_name, ctx->GetDataType(fwd_var_name));
ctx->SetLoDLevel(out_var_name, ctx->GetLoDLevel(fwd_var_name));
ctx->SetType(out_var_name, ctx->GetType(fwd_var_name));
}
}
}
......@@ -309,5 +308,5 @@ class PyFuncOp : public framework::OperatorBase {
namespace ops = paddle::operators;
REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference,
ops::PyFuncOpVarTypeInference, ops::PyFuncOpShapeInference,
ops::PyFuncOpGradDescMaker);
......@@ -123,23 +123,22 @@ class CustomReaderInferShape : public framework::InferShapeBase {
class CustomReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]);
PADDLE_ENFORCE_NOT_NULL(out_reader);
out_reader->SetType(framework::proto::VarType::READER);
void operator()(framework::InferVarTypeContext* ctx) const override {
auto& out_var_name = ctx->Output("Out")[0];
PADDLE_ENFORCE(ctx->HasVar(out_var_name));
ctx->SetType(out_var_name, framework::proto::VarType::READER);
auto sink_var_names =
boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
boost::get<std::vector<std::string>>(ctx->GetAttr("sink_var_names"));
const auto* sub_block =
boost::get<framework::BlockDesc*>(op_desc.GetAttr("sub_block"));
boost::get<framework::BlockDesc*>(ctx->GetAttr("sub_block"));
std::vector<framework::proto::VarType::Type> res_data_types;
for (const std::string& var_name : sink_var_names) {
framework::VarDesc* var = sub_block->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
res_data_types.emplace_back(var->GetDataType());
}
out_reader->SetDataTypes(res_data_types);
ctx->SetDataTypes(out_var_name, res_data_types);
}
};
......
......@@ -51,19 +51,16 @@ class ReadInferShape : public framework::InferShapeBase {
class ReadInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
bool infer_out = boost::get<bool>(op_desc.GetAttr("infer_out"));
void operator()(framework::InferVarTypeContext* ctx) const override {
bool infer_out = boost::get<bool>(ctx->GetAttr("infer_out"));
if (infer_out) {
std::string reader_name = op_desc.Input("Reader")[0];
std::vector<std::string> out_names = op_desc.Output("Out");
framework::VarDesc* reader = block->FindVarRecursive(reader_name);
auto dtypes = reader->GetDataTypes();
std::string reader_name = ctx->Input("Reader")[0];
std::vector<std::string> out_names = ctx->Output("Out");
auto dtypes = ctx->GetDataTypes(reader_name);
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
for (size_t i = 0; i < dtypes.size(); ++i) {
framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetDataType(dtypes[i]);
ctx->SetType(out_names[i], framework::proto::VarType::LOD_TENSOR);
ctx->SetDataType(out_names[i], dtypes[i]);
}
}
}
......
......@@ -98,11 +98,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
}
}
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const {
std::string reader_name = op_desc.Output("Out")[0];
framework::VarDesc* reader = block->FindVarRecursive(reader_name);
reader->SetType(framework::proto::VarType::READER);
void FileReaderInferVarType::operator()(
framework::InferVarTypeContext* ctx) const {
std::string reader_name = ctx->Output("Out")[0];
ctx->SetType(reader_name, framework::proto::VarType::READER);
}
void DecoratedReaderInferShape::operator()(
......@@ -125,13 +124,11 @@ void DecoratedReaderInferShape::operator()(
}
void DecoratedReaderInferVarType::operator()(
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
std::string out_reader_name = op_desc.Output("Out")[0];
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
out_reader->SetType(framework::proto::VarType::READER);
out_reader->SetDataTypes(in_reader->GetDataTypes());
framework::InferVarTypeContext* ctx) const {
const std::string& in_reader_name = ctx->Input("UnderlyingReader")[0];
const std::string& out_reader_name = ctx->Output("Out")[0];
ctx->SetType(out_reader_name, framework::proto::VarType::READER);
ctx->SetDataTypes(out_reader_name, ctx->GetDataTypes(in_reader_name));
}
void DecoratedReaderMakerBase::Make() {
......
......@@ -14,7 +14,9 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
......@@ -59,8 +61,7 @@ class FileReaderInferShape : public framework::InferShapeBase {
class FileReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override;
void operator()(framework::InferVarTypeContext* ctx) const override;
};
// general infershape for decorated reader
......@@ -72,8 +73,7 @@ class DecoratedReaderInferShape : public framework::InferShapeBase {
// general var type inference for decorated reader
class DecoratedReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override;
void operator()(framework::InferVarTypeContext* ctx) const override;
};
class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
......
......@@ -159,12 +159,9 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
class SaveOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output(LOOKUP_TABLE_PATH).front();
ctx->SetType(out_var_name, framework::proto::VarType::RAW);
}
};
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/detail/safe_ref.h"
......@@ -69,17 +70,13 @@ $$Out = scale*(X + bias)$$
class ScaleOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto &in_var_name = op_desc.Input("X").front();
auto &in_var = detail::Ref(block->FindVarRecursive(in_var_name));
auto out_var_name = op_desc.Output("Out").front();
auto *out_var = block->FindVarRecursive(out_var_name);
void operator()(framework::InferVarTypeContext *ctx) const override {
auto &in_var_name = ctx->Input("X").front();
auto out_var_name = ctx->Output("Out").front();
if (in_var_name != out_var_name) {
out_var->SetType(in_var.GetType());
out_var->SetDataType(in_var.GetDataType());
ctx->SetType(out_var_name, ctx->GetType(in_var_name));
ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
}
}
};
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/split_selected_rows_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -60,10 +62,9 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
}
}
};
......
......@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
......@@ -159,24 +160,20 @@ the LoD information with the first input.
class SumOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto& inputs = op_desc.Input("X");
void operator()(framework::InferVarTypeContext* ctx) const override {
auto& inputs = ctx->Input("X");
auto var_type = framework::proto::VarType::SELECTED_ROWS;
for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " "
<< block->FindRecursiveOrCreateVar(name).GetType();
for (auto& name : ctx->Input("X")) {
VLOG(10) << name << " " << ctx->GetType(name);
}
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarType::LOD_TENSOR;
inputs.begin(), inputs.end(), [ctx](const std::string& name) {
return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
});
auto is_tensor_array = [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarType::LOD_TENSOR_ARRAY;
auto is_tensor_array = [ctx](const std::string& name) {
return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
};
bool any_input_is_tensor_array =
......@@ -188,8 +185,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
if (!all_inputs_are_tensor_array) {
std::ostringstream os;
for (auto& each : inputs) {
os << " " << each << " type is "
<< block->FindRecursiveOrCreateVar(each).GetType() << "\n";
os << " " << each << " type is " << ctx->GetType(each) << "\n";
}
PADDLE_ENFORCE(all_inputs_are_tensor_array,
"Not all inputs are tensor array:\n%s", os.str());
......@@ -199,11 +195,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
var_type = framework::proto::VarType::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
out_var.SetType(var_type);
auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front()));
out_var.SetDataType(in_var.GetDataType());
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, var_type);
ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
}
};
......
......@@ -177,10 +177,9 @@ class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
class LoDTensorArray2TensorGradInferVarType
: public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &out_var : op_desc.Output(framework::GradVarName("X"))) {
block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &out_var : ctx->Output(framework::GradVarName("X"))) {
ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
}
}
};
......
......@@ -46,8 +46,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
class TensorRTEngineInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {}
void operator()(framework::InferVarTypeContext *ctx) const override {}
};
} // namespace operators
......
......@@ -112,17 +112,16 @@ uniform distribution. The random result is in set [min, max].
class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Out").front();
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out").front();
auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(op_desc.GetAttr("dtype")));
boost::get<int>(ctx->GetAttr("dtype")));
auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
out_var.SetType(framework::proto::VarType::LOD_TENSOR);
if (ctx->GetType(out_var_name) !=
framework::proto::VarType::SELECTED_ROWS) {
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
}
out_var.SetDataType(var_data_type);
ctx->SetDataType(out_var_name, var_data_type);
}
};
......
......@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
.def("trace",
[](imperative::Tracer& self, imperative::OpBase* op,
const imperative::VarBasePtrMap& inputs,
const imperative::VarBasePtrMap& outputs,
imperative::VarBasePtrMap* outputs,
framework::AttributeMap attrs_map,
const platform::CPUPlace expected_place,
const bool stop_gradient = false) {
......@@ -49,7 +49,7 @@ void BindTracer(pybind11::module* m) {
.def("trace",
[](imperative::Tracer& self, imperative::OpBase* op,
const imperative::VarBasePtrMap& inputs,
const imperative::VarBasePtrMap& outputs,
imperative::VarBasePtrMap* outputs,
framework::AttributeMap attrs_map,
const platform::CUDAPlace expected_place,
const bool stop_gradient = false) {
......
......@@ -200,7 +200,7 @@ PYBIND11_MODULE(core, m) {
.def_property("name", &imperative::VarBase::Name,
&imperative::VarBase::SetName)
.def_property_readonly("shape", &imperative::VarBase::Shape)
.def_property_readonly("dtype", &imperative::VarBase::DType)
.def_property_readonly("dtype", &imperative::VarBase::DataType)
.def_property("persistable", &imperative::VarBase::IsPersistable,
&imperative::VarBase::SetPersistable)
.def_property("stop_gradient", &imperative::VarBase::IsStopGradient,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册