Unverified commit c7f1f3ed, authored by Qiyang Min, committed by GitHub

Merge pull request #16214 from velconia/imperative_infer_var_type

Implement imperative infer var type
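For context, the change migrates VarTypeInference from operating on an (OpDesc, BlockDesc) pair to a single InferVarTypeContext, so the same inference code can run both at compile time (backed by OpDesc/BlockDesc) and under the imperative tracer (backed by VarBase maps). Below is a minimal sketch of an operator-side implementation written against the new interface; the class and op are illustrative only and not part of this PR, and only context methods that appear in the diff are used.

// Sketch only: a hypothetical VarTypeInference using the new context API.
class CopyTypeVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    // Read input/output names through the context instead of OpDesc.
    auto& in_name = ctx->Input("X").front();
    auto& out_name = ctx->Output("Out").front();
    // Propagate variable type and data type to the output.
    ctx->SetType(out_name, ctx->GetType(in_name));
    ctx->SetDataType(out_name, ctx->GetDataType(in_name));
  }
};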
@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
 class DummyVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
-    auto& inputs = op_desc.Input("X");
-    auto type = block->Var(inputs.front())->GetType();
-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(type);
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto& inputs = ctx->Input("X");
+    auto type = ctx->GetType(inputs.front());
+    auto out_var_name = ctx->Output("Out").front();
+    ctx->SetType(out_var_name, type);
   }
 };
...
@@ -16,6 +16,8 @@ limitations under the License. */
 #include <string>
 #include <tuple>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/grad_op_desc_maker.h"
 #include "paddle/fluid/framework/inplace_op_inference.h"
@@ -127,9 +129,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
 template <typename T>
 struct OpInfoFiller<T, kVarTypeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) {
+    info->infer_var_type_ = [](InferVarTypeContext* context) {
       T inference;
-      inference(fwd_op, block);
+      inference(context);
     };
   }
 };
...
@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
-    auto &inputs = op_desc.Input("X");
+  void operator()(InferVarTypeContext *ctx) const override {
+    auto &inputs = ctx->Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;
 
     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [block](const std::string &name) {
-          return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
+          return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
 
-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(default_var_type);
+    auto out_var_name = ctx->Output("Out").front();
+    ctx->SetType(out_var_name, default_var_type);
   }
 };
@@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
 class DummyOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext *ctx) const override {}
 };
 }  // namespace framework
 }  // namespace paddle
...
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/shape_inference.h"
+#include "paddle/fluid/framework/var_type_inference.h"
 
 namespace paddle {
 namespace framework {
@@ -677,7 +678,8 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   // var type inference. Hence, we don't do any "default" setting here.
   auto &info = OpInfoMap::Instance().Get(this->Type());
   if (info.infer_var_type_) {
-    info.infer_var_type_(*this, block);
+    InferVarTypeContext context(this, block);
+    info.infer_var_type_(&context);
   }
 }
...
@@ -27,6 +27,7 @@ namespace framework {
 class OperatorBase;
 class OpDesc;
 class InferShapeContext;
+class InferVarTypeContext;
 class BlockDesc;
 class Variable;
@@ -53,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
     const std::vector<BlockDesc*>& grad_block)>;
 
 using InferVarTypeFN =
-    std::function<void(const OpDesc& /*op_desc*/, BlockDesc* /*block*/)>;
+    std::function<void(framework::InferVarTypeContext* /*context*/)>;
 
 using InferShapeFN = std::function<void(InferShapeContext*)>;
...
@@ -14,6 +14,8 @@ limitations under the License. */
 #pragma once
 #include <string>
+#include <unordered_map>
+#include <vector>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/type_defs.h"
@@ -21,26 +23,123 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+class OpDesc;
+class BlockDesc;
+
+// default infer var type context
+class InferVarTypeContext {
+ public:
+  InferVarTypeContext(const OpDesc* op, BlockDesc* block)
+      : op_(op), block_(block) {}
+
+  virtual ~InferVarTypeContext() {}
+
+  virtual Attribute GetAttr(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->GetAttr(name);
+  }
+
+  virtual bool HasVar(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindVarRecursive(name) != nullptr;
+  }
+
+  virtual bool HasInput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Inputs().count(name) > 0;
+  }
+
+  virtual bool HasOutput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Outputs().count(name) > 0;
+  }
+
+  virtual const std::vector<std::string>& Input(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Input(name);
+  }
+
+  virtual const std::vector<std::string>& Output(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Output(name);
+  }
+
+  virtual proto::VarType::Type GetType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetType();
+  }
+
+  virtual void SetType(const std::string& name, proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetType(type);
+  }
+
+  virtual proto::VarType::Type GetDataType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetDataType();
+  }
+
+  virtual void SetDataType(const std::string& name, proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetDataType(type);
+  }
+
+  virtual std::vector<proto::VarType::Type> GetDataTypes(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
+  }
+
+  virtual void SetDataTypes(
+      const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
+  }
+
+  virtual std::vector<int64_t> GetShape(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetShape();
+  }
+
+  virtual void SetShape(const std::string& name,
+                        const std::vector<int64_t>& dims) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetShape(dims);
+  }
+
+  virtual int32_t GetLoDLevel(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
+  }
+
+  virtual void SetLoDLevel(const std::string& name, int32_t lod_level) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
+  }
+
+ protected:
+  const OpDesc* op_;
+  BlockDesc* block_;
+};
+
 class VarTypeInference {
  public:
   virtual ~VarTypeInference() {}
-  virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
+  virtual void operator()(InferVarTypeContext* context) const = 0;  // NOLINT
 };
 
 class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const final {
+  void operator()(framework::InferVarTypeContext* ctx) const final {  // NOLINT
     auto in_out_var_names = this->GetInputOutputWithSameType();
 
     for (auto& i_o_n : in_out_var_names) {
-      auto& x_name = op_desc.Input(i_o_n.first).at(0);
-      auto& out_name = op_desc.Output(i_o_n.second).at(0);
-
-      auto& x = block->FindRecursiveOrCreateVar(x_name);
-      auto& out = block->FindRecursiveOrCreateVar(out_name);
-      out.SetType(x.GetType());
-      out.SetDataType(x.GetDataType());
+      auto& x_name = ctx->Input(i_o_n.first).at(0);
+      auto& out_name = ctx->Output(i_o_n.second).at(0);
+
+      ctx->SetType(out_name, ctx->GetType(x_name));
+      ctx->SetDataType(out_name, ctx->GetDataType(x_name));
     }
   }
...
@@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
-    auto &inputs = op_desc.Input("X");
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto &inputs = ctx->Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;
 
     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [block](const std::string &name) {
-          return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
+          return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }
 
-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(default_var_type);
+    auto out_var_name = ctx->Output("Out").front();
+    ctx->SetType(out_var_name, default_var_type);
   }
 };
 }  // namespace framework
...
@@ -218,7 +218,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
                  "%s has no backward implementation", Type());
 
   VLOG(3) << "apply op grad: " << Type();
-  std::vector<framework::VariableValueMap> tmp_grad_outputs;
+  std::vector<VarBasePtrMap> tmp_grad_outputs;
   if (backward_id_ > 0) {
     VLOG(3) << "py_layer_grad";
     tmp_grad_outputs.resize(1);
@@ -241,26 +241,62 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
       auto& outputs = tmp_grad_outputs[k][it.first];
       outputs.reserve(it.second.size());
       for (size_t i = 0; i < it.second.size(); ++i) {
+        VarBase* origin_grad_var_base = it.second[i];
+
         // Allocate a new variable
-        Variable* tmp_var = new framework::Variable();
-        tmp_var->GetMutable<framework::LoDTensor>();
-        outputs.emplace_back(tmp_var);
+        VarBase* tmp_grad_var_base = new VarBase(
+            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
+            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
+            place_, true, false);
+        outputs.emplace_back(tmp_grad_var_base);
       }
     }
 
-    // Run grad op
-    framework::RuntimeContext ctx(grad_input_vars_[k], tmp_grad_outputs[k]);
-
     // No need to do compile time infer shape here.
     // grad_op_desc_->InferShape(*block_);
     // grad_op_desc->InferVarType(block_);
 
     std::unique_ptr<framework::OperatorBase> opbase =
         framework::OpRegistry::CreateOp(*grad_op_desc);
 
+    auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
+    if (info.infer_var_type_) {
+      RuntimeInferVarTypeContext infer_var_type_ctx(
+          &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
+      info.infer_var_type_(&infer_var_type_ctx);
+    }
+
     framework::OperatorWithKernel* op_kernel =
         dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
     PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
 
+    // Run grad op
+    framework::VariableValueMap grad_invars_map;
+    framework::VariableValueMap grad_outvars_map;
+
+    for (const auto& it : grad_input_vars_[k]) {
+      auto& grad_invars = grad_invars_map[it.first];
+      grad_invars.reserve(it.second.size());
+      for (const VarBase* grad_inp : it.second) {
+        PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
+                                grad_op_desc->Type(), grad_inp->Name());
+        grad_invars.emplace_back(grad_inp->var_);
+      }
+    }
+
+    for (const auto& it : tmp_grad_outputs[k]) {
+      auto& grad_outvars = grad_outvars_map[it.first];
+      grad_outvars.reserve(it.second.size());
+      for (VarBase* grad_out : it.second) {
+        PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
+                                grad_op_desc->Type(), grad_out->Name());
+        grad_outvars.emplace_back(grad_out->var_);
+      }
+    }
+
+    framework::RuntimeContext ctx(grad_invars_map, grad_outvars_map);
     framework::Scope scope;
     PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
     p.op.RuntimeInferShape(scope, place_, ctx);
@@ -277,8 +313,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
     PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
 
     for (size_t i = 0; i < outputs.size(); ++i) {
-      framework::Variable* grad = outputs[i];
-      framework::Variable* orig_grad = origin_outputs[i];
+      framework::Variable* grad = outputs[i]->var_;
+      framework::Variable* orig_grad = origin_outputs[i]->var_;
       AddTo(grad, orig_grad, place_);
       delete grad;
     }
@@ -326,28 +362,35 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
 
 int PyLayer::NumFuncs() { return py_funcs_.size(); }
 
-std::vector<Variable*> PyLayer::Apply(int func_id,
-                                      const std::vector<VarBase*>& inputs) {
-  std::vector<framework::Variable*> invars;
-  for (const VarBase* in : inputs) {
-    invars.push_back(in->var_);
-  }
+std::vector<framework::Variable*> PyLayer::Apply(
+    int func_id, const std::vector<VarBase*>& inputs) {
   PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
-  return CallPythonFunc(py_funcs_[func_id], invars);
+  return CallPythonFunc(py_funcs_[func_id], inputs);
 }
 
-std::vector<Variable*> PyLayer::ApplyGrad(
-    int func_id, const std::vector<framework::Variable*>& inputs) {
+std::vector<VarBase*> PyLayer::ApplyGrad(int func_id,
                                         const std::vector<VarBase*>& inputs) {
   PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
-  return CallPythonFunc(py_funcs_[func_id], inputs);
+  auto rets = CallPythonFunc(py_funcs_[func_id], inputs);
+
+  std::vector<VarBase*> outs;
+  outs.reserve(rets.size());
+  for (size_t i = 0U; i != rets.size(); ++i) {
+    outs.emplace_back(new VarBase(
+        string::Sprintf("%s_out_%d", framework::GradVarName(PyLayer::kFwdOut),
+                        i),
+        rets[i], nullptr, true));
+  }
+
+  return outs;
 }
 
 std::vector<framework::Variable*> PyLayer::CallPythonFunc(
-    const py::object& callable, const std::vector<framework::Variable*>& ins) {
+    const py::object& callable, const std::vector<VarBase*>& ins) {
   py::gil_scoped_acquire guard;
   py::tuple in_args(ins.size());
   for (size_t i = 0; i < ins.size(); ++i) {
-    const framework::LoDTensor& t = ins[i]->Get<framework::LoDTensor>();
+    const framework::LoDTensor& t = ins[i]->var_->Get<framework::LoDTensor>();
     in_args[i] = t.IsInitialized() ? py::cast(t) : py::cast(nullptr);
   }
   VLOG(3) << "pyfunc in " << py::len(in_args);
@@ -357,6 +400,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
   auto ret_tuple = py::cast<py::tuple>(ret);
   size_t ret_num = py::len(ret_tuple);
   std::vector<framework::Variable*> outs;
+  outs.reserve(ret_num);
   VLOG(3) << "pyfunc out " << ret_num;
   for (size_t i = 0; i < ret_num; ++i) {
     try {
@@ -367,7 +411,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
       auto* tensor = var->GetMutable<framework::LoDTensor>();
       tensor->ShareDataWith(*py_out_tensor);
       tensor->set_lod(py_out_tensor->lod());
-      outs.push_back(var);
+      outs.emplace_back(var);
     } catch (py::cast_error&) {
       PADDLE_THROW("The %d-th output must be LoDTensor", i);
     }
...
@@ -18,14 +18,16 @@
 #include "paddle/fluid/framework/python_headers.h"
 // clang-format on
 
 #include <map>     // NOLINT
 #include <string>  // NOLINT
 #include <vector>  // NOLINT
 #include <memory>  // NOLINT
+#include <unordered_map>  // NOLINT
 
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/framework/var_type_inference.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -135,13 +137,13 @@ class VarBase {
         persistable) {}
 
  private:
+  // TODO(minqiyang): need support SelectedRows
   VarBase(const std::string& name, framework::proto::VarType::Type dtype,
           const framework::DDim& shape, const platform::Place& place,
           framework::Variable* var, VarBase* grad, bool stop_gradient,
           bool persistable)
       : name_(name),
-        dtype_(dtype),
-        place_(place),
+        type_(framework::proto::VarType::LOD_TENSOR),
         var_(var),
         grads_(grad),
         stop_gradient_(stop_gradient),
@@ -151,10 +153,12 @@ class VarBase {
         pre_op_out_idx_(-1) {
     if (!var_) {
       var_ = new framework::Variable();
-      auto tensor = var_->GetMutable<framework::LoDTensor>();
-      tensor->Resize(shape);
-      tensor->mutable_data(place_, dtype_);
     }
+    auto tensor = var_->GetMutable<framework::LoDTensor>();
+    tensor->Resize(shape);
+    tensor->mutable_data(place, dtype);
+    VLOG(10) << "create varbase: " << name_ << " type: " << dtype
+             << " place: " << place;
   }
 
  public:
@@ -184,7 +188,23 @@ class VarBase {
     }
   }
 
-  inline framework::proto::VarType::Type DType() const { return dtype_; }
+  inline framework::DDim Dims() const {
+    return var_->Get<framework::LoDTensor>().dims();
+  }
+
+  // data type. e.g.. FP32
+  inline void SetDataType(framework::proto::VarType::Type type) {
+    auto tensor = var_->GetMutable<framework::LoDTensor>();
+    tensor->mutable_data(tensor->place(), type);
+  }
+  inline framework::proto::VarType::Type DataType() const {
+    auto tensor = var_->Get<framework::LoDTensor>();
+    return tensor.type();
+  }
+
+  // tensor type. e.g.. LoDTensor
+  inline void SetType(framework::proto::VarType::Type type) { type_ = type; }
+  inline framework::proto::VarType::Type Type() const { return type_; }
 
   inline void SetStopGradient(bool stop_gradient) {
     stop_gradient_ = stop_gradient;
@@ -238,7 +258,7 @@ class VarBase {
   }
 
   std::string name_;
-  framework::proto::VarType::Type dtype_;
+  framework::proto::VarType::Type type_;
   platform::Place place_;
 
   framework::Variable* var_;
@@ -334,11 +354,13 @@ class PYBIND11_HIDDEN OpBase {
   std::map<std::string, std::vector<int>> pre_ops_out_idx_;
 
   // Inputs to a vector of bwd ops.
-  std::vector<framework::VariableValueMap> grad_input_vars_;
+  std::vector<VarBasePtrMap> grad_input_vars_;
   // Outputs to a vector of bwd ops.
-  std::vector<framework::VariableValueMap> grad_output_vars_;
+  std::vector<VarBasePtrMap> grad_output_vars_;
 
   std::vector<py::object> backward_hooks_;
+
+  framework::AttributeMap attrs_;
 };
 
 class Layer {
@@ -365,12 +387,131 @@ class PyLayer {
   static std::vector<framework::Variable*> Apply(
       int func_id, const std::vector<VarBase*>& inputs);
 
-  static std::vector<framework::Variable*> ApplyGrad(
-      int func_id, const std::vector<framework::Variable*>& inputs);
+  static std::vector<VarBase*> ApplyGrad(int func_id,
+                                         const std::vector<VarBase*>& inputs);
 
  private:
   static std::vector<framework::Variable*> CallPythonFunc(
-      const py::object& callable, const std::vector<framework::Variable*>& ins);
+      const py::object& callable, const std::vector<VarBase*>& ins);
+};
+
+// infer var type context for imperative mode
+class PYBIND11_HIDDEN RuntimeInferVarTypeContext
+    : public framework::InferVarTypeContext {
+ public:
+  RuntimeInferVarTypeContext(const imperative::VarBasePtrMap* inputs,
+                             imperative::VarBasePtrMap* outputs,
+                             const framework::AttributeMap* attrs_map)
+      : InferVarTypeContext(nullptr, nullptr),
+        inputs_(inputs),
+        outputs_(outputs),
+        attrs_(attrs_map),
+        input_names_(),
+        output_names_(),
+        var_set_() {
+    input_names_.reserve(inputs_->size());
+    for (auto& it : *inputs_) {
+      for (imperative::VarBase* var : it.second) {
+        input_names_[it.first].emplace_back(var->Name());
+        var_set_[var->Name()] = var;
+      }
+    }
+
+    output_names_.reserve(outputs_->size());
+    for (auto& it : *outputs_) {
+      for (imperative::VarBase* var : it.second) {
+        output_names_[it.first].emplace_back(var->Name());
+        var_set_[var->Name()] = var;
+      }
+    }
+  }
+
+  virtual ~RuntimeInferVarTypeContext() {}
+
+  framework::Attribute GetAttr(const std::string& name) const override {
+    PADDLE_ENFORCE_NOT_NULL(attrs_);
+    return attrs_->at(name);
+  }
+
+  bool HasVar(const std::string& name) const override {
+    return var_set_.count(name) > 0;
+  }
+
+  bool HasInput(const std::string& name) const override {
+    PADDLE_ENFORCE_NOT_NULL(inputs_);
+    return inputs_->count(name) > 0;
+  }
+
+  bool HasOutput(const std::string& name) const override {
+    PADDLE_ENFORCE_NOT_NULL(outputs_);
+    return outputs_->count(name) > 0;
+  }
+
+  const std::vector<std::string>& Input(
+      const std::string& name) const override {
+    return input_names_.at(name);
+  }
+
+  const std::vector<std::string>& Output(
+      const std::string& name) const override {
+    return output_names_.at(name);
+  }
+
+  framework::proto::VarType::Type GetType(
+      const std::string& name) const override {
+    return var_set_.at(name)->Type();
+  }
+
+  void SetType(const std::string& name,
+               framework::proto::VarType::Type type) override {
+    var_set_[name]->SetType(type);
+  }
+
+  framework::proto::VarType::Type GetDataType(
+      const std::string& name) const override {
+    return var_set_.at(name)->DataType();
+  }
+
+  void SetDataType(const std::string& name,
+                   framework::proto::VarType::Type type) override {
+    var_set_[name]->SetDataType(type);
+  }
+
+  std::vector<framework::proto::VarType::Type> GetDataTypes(
+      const std::string& name) const override {
+    PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType");
+  }
+
+  void SetDataTypes(const std::string& name,
+                    const std::vector<framework::proto::VarType::Type>&
+                        multiple_data_type) override {
+    PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType");
+  }
+
+  std::vector<int64_t> GetShape(const std::string& name) const override {
+    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
+  }
+
+  void SetShape(const std::string& name,
+                const std::vector<int64_t>& dims) override {
+    PADDLE_THROW("Do not handle Shape in runtime InferVarType");
+  }
+
+  int32_t GetLoDLevel(const std::string& name) const override {
+    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
+  }
+
+  void SetLoDLevel(const std::string& name, int32_t lod_level) override {
+    PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType");
+  }
+
+ private:
+  const imperative::VarBasePtrMap* inputs_;
+  imperative::VarBasePtrMap* outputs_;
+  const framework::AttributeMap* attrs_;
+  std::unordered_map<std::string, std::vector<std::string>> input_names_;
+  std::unordered_map<std::string, std::vector<std::string>> output_names_;
+  std::unordered_map<std::string, imperative::VarBase*> var_set_;
 };
 
 }  // namespace imperative
...
@@ -19,6 +19,7 @@
 #include <unordered_map>
 #include <unordered_set>
 
+#include "paddle/fluid/framework/var_type_inference.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -135,7 +136,7 @@ framework::VariableNameMap CreateOutputVarNameMap(
 Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}
 
 std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
-                                    const VarBasePtrMap& outputs,
+                                    VarBasePtrMap* outputs,
                                     framework::AttributeMap attrs_map,
                                     const platform::Place expected_place,
                                     const bool stop_gradient) {
@@ -163,7 +164,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
     op->TrackPreOp(it.first, it.second);
   }
 
-  op->output_vars_ = outputs;
+  op->output_vars_ = *outputs;
   for (auto it : op->output_vars_) {
     auto& outvars = outvars_map[it.first];
     const std::vector<VarBase*>& outputs = it.second;
@@ -186,7 +187,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   framework::VariableNameMap invars_name_map =
       CreateInputVarNameMap(op, inputs);
   framework::VariableNameMap outvars_name_map =
-      CreateOutputVarNameMap(op, outputs);
+      CreateOutputVarNameMap(op, *outputs);
 
   auto& info = framework::OpInfoMap::Instance().Get(op->Type());
   if (info.Checker() != nullptr) {
@@ -197,6 +198,11 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
                                       outvars_name_map, attrs_map);
 
+  if (info.infer_var_type_) {
+    RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, outputs, &attrs_map);
+    info.infer_var_type_(&infer_var_type_ctx);
+  }
+
-  // TODO(minqiyang): Support infer var type in imperative mode
   // Run forward op
   VLOG(3) << "tracer running " << op->Type();
@@ -221,6 +227,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
     VLOG(5) << "start construct backward op";
 
     // construct grad op descs
+    op->attrs_ = attrs_map;
     std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
         op->Type(), invars_name_map, outvars_name_map, attrs_map));
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
@@ -247,12 +254,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
           auto fwd_var_it = current_vars_map.find(grad_invar);
           PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
           // Forward inputs or outputs.
-          grad_in_vars.emplace_back(fwd_var_it->second->var_);
+          grad_in_vars.emplace_back(fwd_var_it->second);
         } else {
           VarBase* var = current_vars_map[var_it->second];
           InitGrad(var, prepared_op.GetDeviceContext());
           // Douts.
-          grad_in_vars.emplace_back(var->grads_->var_);
+          grad_in_vars.emplace_back(var->grads_);
         }
 
         vars_saved_for_backward.insert(it.first);
@@ -269,7 +276,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                        op->Type());
         VarBase* var = current_vars_map[var_it->second];
         InitGrad(var, prepared_op.GetDeviceContext());
-        grad_out_vars.push_back(var->grads_->var_);
+        grad_out_vars.push_back(var->grads_);
       }
     }
   }
@@ -309,23 +316,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
     auto& grad_output_vars =
         op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
 
-    for (const VarBase* inp : inputs) {
-      grad_input_vars.push_back(inp->var_);
+    for (VarBase* inp : inputs) {
+      grad_input_vars.push_back(inp);
     }
 
     for (VarBase* out : outputs) {
-      grad_input_vars.push_back(out->var_);
+      grad_input_vars.push_back(out);
     }
 
     // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
     platform::CPUPlace place;
     for (VarBase* out : outputs) {
       InitGrad(out, platform::DeviceContextPool::Instance().Get(place));
-      grad_input_vars.push_back(out->grads_->var_);
+      grad_input_vars.push_back(out->grads_);
     }
 
     for (VarBase* inp : inputs) {
       InitGrad(inp, platform::DeviceContextPool::Instance().Get(place));
-      grad_output_vars.push_back(inp->grads_->var_);
+      grad_output_vars.push_back(inp->grads_);
     }
   }
 
   return outputs;
...
@@ -48,7 +48,7 @@ class Tracer {
   virtual ~Tracer() {}
 
   std::set<std::string> Trace(OpBase* op, const VarBasePtrMap& inputs,
-                              const VarBasePtrMap& outputs,
+                              VarBasePtrMap* outputs,  // NOLINT
                               framework::AttributeMap attrs_map,
                               const platform::Place expected_place,
                               const bool stop_gradient = false);
...
@@ -25,6 +25,7 @@ class VarBase;
 class OpBase;
 
 typedef std::map<std::string, std::vector<VarBase*>> VarBasePtrMap;
+typedef std::map<std::string, std::vector<const VarBase*>> ConstVarBasePtrMap;
 typedef std::map<std::string, std::vector<OpBase*>> OpBasePtrMap;
 
 }  // namespace imperative
...
@@ -178,10 +178,10 @@ Beam Search Decode Operator. This Operator constructs the full hypotheses for
 each source sentence by walking back along the LoDTensorArray Input(ids)
 whose lods can be used to restore the path in the beam search tree.
 
 The Output(SentenceIds) and Output(SentenceScores) separately contain the
 generated id sequences and the corresponding scores. The shapes and lods of the
 two LodTensor are same. The lod level is 2 and the two levels separately
 indicate how many hypotheses each source sentence has and how many ids each
 hypothesis has.
 )DOC");
   }
@@ -203,15 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    for (auto& o : op_desc.Output("SentenceIds")) {
-      auto& sentence_ids = block->FindRecursiveOrCreateVar(o);
-      sentence_ids.SetType(framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    for (auto& o : ctx->Output("SentenceIds")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
-    for (auto& o : op_desc.Output("SentenceScores")) {
-      auto& sentence_scores = block->FindRecursiveOrCreateVar(o);
-      sentence_scores.SetType(framework::proto::VarType::LOD_TENSOR);
+    for (auto& o : ctx->Output("SentenceScores")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
...
@@ -65,7 +65,7 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(true);
 
     AddComment(R"DOC(
 This operator does the search in beams for one time step.
 Specifically, it selects the top-K candidate word ids of current step from
 Input(ids) according to their Input(scores) for all source sentences,
 where K is Attr(beam_size) and Input(ids), Input(scores) are predicted results
@@ -120,15 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel {
 class BeamSearchInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &o : op_desc.Output("selected_ids")) {
-      auto &selected_ids = block->FindRecursiveOrCreateVar(o);
-      selected_ids.SetType(framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &o : ctx->Output("selected_ids")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
-    for (auto &o : op_desc.Output("selected_scores")) {
-      auto &selected_scores = block->FindRecursiveOrCreateVar(o);
-      selected_scores.SetType(framework::proto::VarType::LOD_TENSOR);
+    for (auto &o : ctx->Output("selected_scores")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
...
@@ -93,11 +93,9 @@ execution.
 class GetPlacesInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &o_name : op_desc.Output("Out")) {
-      block->FindRecursiveOrCreateVar(o_name).SetType(
-          framework::proto::VarType::PLACE_LIST);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &o_name : ctx->Output("Out")) {
+      ctx->SetType(o_name, framework::proto::VarType::PLACE_LIST);
     }
   }
 };
...
@@ -100,16 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
 class WriteToArrayInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto x_name = op_desc.Input("X")[0];
-    auto out_name = op_desc.Output("Out")[0];
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto x_name = ctx->Input("X")[0];
+    auto out_name = ctx->Output("Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
-    auto &out = block->FindRecursiveOrCreateVar(out_name);
-    out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
-    auto *x = block->FindVarRecursive(x_name);
-    if (x != nullptr) {
-      out.SetDataType(x->GetDataType());
+    ctx->SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
+    if (ctx->HasVar(x_name)) {
+      ctx->SetDataType(out_name, ctx->GetDataType(x_name));
     }
   }
 };
...
@@ -365,19 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
 class WhileGradOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto p_names = op_desc.Input(kX);
-    auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto p_names = ctx->Input(kX);
+    auto pg_ig_names = ctx->Output(framework::GradVarName(kX));
 
     for (size_t i = 0; i < p_names.size(); ++i) {
-      auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
-      auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
-      if (g_var != nullptr) {  // Gradient could be @EMPTY@
+      if (ctx->HasVar(pg_ig_names[i])) {
         VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
-                << " type: " << p_var.GetType();
-        g_var->SetType(p_var.GetType());
-        g_var->SetDataType(p_var.GetDataType());
+                << " type: " << ctx->GetType(p_names[i]);
+        ctx->SetType(pg_ig_names[i], ctx->GetType(p_names[i]));
+        ctx->SetDataType(pg_ig_names[i], ctx->GetDataType(p_names[i]));
       }
     }
   }
...
@@ -56,8 +56,7 @@ class FakeInitOp : public framework::OperatorBase {
 
 class FakeInitOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext *ctx) const override {}
 };
 
 class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -114,11 +114,10 @@ class MergeIdsOp : public framework::OperatorWithKernel {
 class MergeIdsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto *input_var = block->Var(op_desc.Input("Ids")[0]);
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(input_var->GetType());
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
+    for (auto &out_var : ctx->Output("Out")) {
+      ctx->SetType(out_var, input_type);
     }
   }
 };
...
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
 
+#include <memory>
+
 namespace paddle {
 namespace operators {
@@ -71,11 +73,10 @@ class SplitIdsOp : public framework::OperatorWithKernel {
 class SplitIdsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto *input_var = block->Var(op_desc.Input("Ids")[0]);
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(input_var->GetType());
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
+    for (auto &out_var : ctx->Output("Out")) {
+      ctx->SetType(out_var, input_type);
     }
   }
 };
...
@@ -39,12 +39,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
 class FillConstantOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
+  void operator()(framework::InferVarTypeContext* ctx) const override {
     auto data_type = static_cast<framework::proto::VarType::Type>(
-        boost::get<int>(op_desc.GetAttr("dtype")));
-    auto& out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetDataType(data_type);
+        boost::get<int>(ctx->GetAttr("dtype")));
+    auto& out_var_name = ctx->Output("Out").front();
+    ctx->SetDataType(out_var_name, data_type);
   }
 };
...
@@ -138,22 +138,20 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
 class FusedEmbeddingSeqPoolOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
-    auto attr = op_desc.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
+    auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
              << framework::GradVarName("W") << " is set to SelectedRows";
-      block->Var(out_var_name)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
              << framework::GradVarName("W") << " is set to LoDTensor";
-      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
+    ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
   }
 };
...
@@ -81,15 +81,12 @@ GetTensorFromSelectedRows is used to get the tensor from SelectedRows.
 class GetTensorFromSelectedRowsOpVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const final {
-    auto out_var_name = op_desc.Output("Out").front();
-    auto in_var_name = op_desc.Input("X").front();
-
-    auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto in_var = block->FindRecursiveOrCreateVar(in_var_name);
-    out_var.SetType(framework::proto::VarType::LOD_TENSOR);
-    out_var.SetDataType(in_var.GetDataType());
+  void operator()(framework::InferVarTypeContext *ctx) const {  // NOLINT
+    auto out_var_name = ctx->Output("Out").front();
+    auto in_var_name = ctx->Input("X").front();
+
+    ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+    ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
   }
 };
...
@@ -197,38 +197,32 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
 class HierarchicalSigmoidGradOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front();
-    auto bias_grad_var_name_vec =
-        op_desc.Output(framework::GradVarName("Bias"));
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto w_grad_var_name = ctx->Output(framework::GradVarName("W")).front();
+    auto bias_grad_var_name_vec = ctx->Output(framework::GradVarName("Bias"));
     std::string bias_grad_var_name;
     bool hasBias = false;
     if (bias_grad_var_name_vec.size()) {
       hasBias = true;
-      bias_grad_var_name =
-          op_desc.Output(framework::GradVarName("Bias")).front();
+      bias_grad_var_name = ctx->Output(framework::GradVarName("Bias")).front();
     }
-    auto attr = op_desc.GetAttr("is_sparse");
+    auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
                << " is set to SelectedRows";
-      block->Var(w_grad_var_name)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
                << " is set to LoDTensor";
-      block->Var(w_grad_var_name)
-          ->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
     }
     if (hasBias) {
       VLOG(30) << "hierarchical_sigmoid_grad op "
               << framework::GradVarName("Bias") << " is set to LoDTensor";
-      block->Var(bias_grad_var_name)
-          ->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
    }
-    block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
+    ctx->SetDataType(w_grad_var_name, ctx->GetDataType(ctx->Input("W")[0]));
   }
 };
...
@@ -64,11 +64,9 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
 class LoDRankTableInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &o : op_desc.Output("Out")) {
-      block->FindRecursiveOrCreateVar(o).SetType(
-          framework::proto::VarType::LOD_RANK_TABLE);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &o : ctx->Output("Out")) {
+      ctx->SetType(o, framework::proto::VarType::LOD_RANK_TABLE);
     }
   }
 };
...
@@ -201,10 +201,9 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
 class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &out_var : ctx->Output("Out")) {
+      ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
     }
   }
 };
...
@@ -147,22 +147,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
 class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
-    auto attr = op_desc.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
+    auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
              << " is set to SelectedRows";
-      block->Var(out_var_name)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
              << " is set to LoDTensor";
-      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
+    ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
   }
 };
...
@@ -60,12 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
 class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto out_var_name = op_desc.Output("Communicator").front();
-    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto out_var_name = ctx->Output("Communicator").front();
+    ctx->SetType(out_var_name, framework::proto::VarType::RAW);
   }
 };
...
@@ -237,23 +237,21 @@ class NCEOpGrad : public framework::OperatorWithKernel {
 class NCEOpGradVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front();
-    auto attr = op_desc.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto weight_grad = ctx->Output(framework::GradVarName("Weight")).front();
+    auto attr = ctx->GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
              << " is set to SelectedRows";
-      block->Var(weight_grad)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx->SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
              << " is set to LoDTensor";
-      block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx->SetType(weight_grad, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType());
+    ctx->SetDataType(weight_grad, ctx->GetDataType(ctx->Input("Input")[0]));
   }
 };
...
@@ -37,8 +37,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
 class NgraphEngineInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext *ctx) const override {}
 };
 }  // namespace operators
...
@@ -56,9 +56,9 @@ This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each
 weight using a local learning rate:
 $$
 local\_lr = \eta *
     \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\
 velocity = mu * velocity +
     local\_lr * (grad + \beta * param) \\
 param = param - velocity. \\
 $$
@@ -72,8 +72,7 @@ use L2 regularizers in case of using LARS.
 class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext* ctx) const override {}
 };
 }  // namespace operators
 }  // namespace paddle
...
@@ -21,18 +21,14 @@ using Tensor = framework::Tensor;
 class MomentumOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto input_var = op_desc.Input("Param")[0];
-    for (auto& out_var : op_desc.Output("ParamOut")) {
-      if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
-          framework::proto::VarType::SELECTED_ROWS) {
-        block->FindRecursiveOrCreateVar(out_var).SetType(
-            framework::proto::VarType::SELECTED_ROWS);
-      } else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
-                 framework::proto::VarType::LOD_TENSOR) {
-        block->FindRecursiveOrCreateVar(out_var).SetType(
-            framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto& input_var = ctx->Input("Param")[0];
+    for (auto& out_var : ctx->Output("ParamOut")) {
+      if (ctx->GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) {
+        ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
+      } else if (ctx->GetType(input_var) ==
                 framework::proto::VarType::LOD_TENSOR) {
+        ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR);
       } else {
         PADDLE_THROW(
             "Only support LodTensor and SelectedRows, Unexpected Input Type.");
...
@@ -50,20 +50,18 @@ class SGDOp : public framework::OperatorWithKernel {
 class SGDOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto input_var_n = op_desc.Input("Param")[0];
-    auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType();
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto &input_var_n = ctx->Input("Param")[0];
+    auto in_var_type = ctx->GetType(input_var_n);
     PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
                        in_var_type == framework::proto::VarType::LOD_TENSOR,
                    "The input Var's type should be LoDtensor or SelectedRows,"
                    " but the received var(%s)'s type is %s",
                    input_var_n, in_var_type);
-    for (auto &out_var_n : op_desc.Output("ParamOut")) {
-      auto &out_var = block->FindRecursiveOrCreateVar(out_var_n);
-      if (out_var.GetType() != in_var_type) {
-        out_var.SetType(in_var_type);
+    for (auto &out_var_n : ctx->Output("ParamOut")) {
+      if (ctx->GetType(out_var_n) != in_var_type) {
+        ctx->SetType(out_var_n, in_var_type);
       }
     }
   }
...
@@ -14,8 +14,11 @@
 #include "paddle/fluid/operators/py_func_op.h"
+#include <memory>
 #include <set>
 #include <string>
+#include <unordered_set>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
@@ -91,15 +94,12 @@ static void CallPythonFunc(py::object *callable,
   }
 }
-class PyFuncOpVarTypInference : public framework::VarTypeInference {
+class PyFuncOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op,
-                  framework::BlockDesc *block) const override {
-    auto &outs = op.Outputs();
-    bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty());
-    auto &ins = op.Inputs();
-    bool has_in = (ins.count("X") > 0 && !ins.at("X").empty());
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    bool has_out = (ctx->HasOutput("Out") && !ctx->Output("Out").empty());
+    bool has_in = (ctx->HasInput("X") && !ctx->Input("X").empty());
     /**
      * X or Out can be empty, so that py_func can be more flexible
@@ -107,8 +107,8 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference {
      */
     PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");
-    PADDLE_ENFORCE_GE(boost::get<int>(op.GetAttr(kForwardPythonCallableId)), 0,
-                      "Function id cannot be less than 0");
+    PADDLE_ENFORCE_GE(boost::get<int>(ctx->GetAttr(kForwardPythonCallableId)),
+                      0, "Function id cannot be less than 0");
     if (!has_out) return;
@@ -118,7 +118,7 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference {
      * the corresponding forward variable
      */
     const std::string kGradVarSuffix = framework::kGradVarSuffix;
-    auto &out_var_names = outs.at("Out");
+    auto &out_var_names = ctx->Output("Out");
     for (auto &out_var_name : out_var_names) {
       if (out_var_name == framework::kEmptyVarName ||
           out_var_name.size() < kGradVarSuffix.size()) {
@@ -128,18 +128,17 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference {
       size_t len = out_var_name.size() - kGradVarSuffix.size();
       if (out_var_name.substr(len) == kGradVarSuffix) {
         auto fwd_var_name = out_var_name.substr(0, len);
-        auto *out_var_desc = block->FindVarRecursive(out_var_name);
-        auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name);
-        PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found",
-                                out_var_name);
-        PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found",
-                                fwd_var_name);
+        PADDLE_ENFORCE(ctx->HasVar(out_var_name),
+                       "Backward variable %s not found", out_var_name);
+        PADDLE_ENFORCE(ctx->HasVar(fwd_var_name),
+                       "Backward variable %s not found", fwd_var_name);
         VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
                  << fwd_var_name << ")";
-        out_var_desc->SetShape(fwd_var_desc->GetShape());
-        out_var_desc->SetDataType(fwd_var_desc->GetDataType());
-        out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel());
-        out_var_desc->SetType(fwd_var_desc->GetType());
+        ctx->SetShape(out_var_name, ctx->GetShape(fwd_var_name));
+        ctx->SetDataType(out_var_name, ctx->GetDataType(fwd_var_name));
+        ctx->SetLoDLevel(out_var_name, ctx->GetLoDLevel(fwd_var_name));
+        ctx->SetType(out_var_name, ctx->GetType(fwd_var_name));
       }
     }
   }
@@ -309,5 +308,5 @@ class PyFuncOp : public framework::OperatorBase {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
-                  ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference,
+                  ops::PyFuncOpVarTypeInference, ops::PyFuncOpShapeInference,
                   ops::PyFuncOpGradDescMaker);
@@ -85,10 +85,10 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
     AddComment(R"DOC(
       CreateCustomReader Operator
       A custom reader can be used for input data preprocessing.
       A custom reader holds its own sub-block, which will be executed in CPU
       in its 'ReadNext()' function. Users can configurate their own
       preprocessing pipelines by inserting operators into custom reader's
       sub-block.
     )DOC");
   }
@@ -123,23 +123,22 @@ class CustomReaderInferShape : public framework::InferShapeBase {
 class CustomReaderInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]);
-    PADDLE_ENFORCE_NOT_NULL(out_reader);
-    out_reader->SetType(framework::proto::VarType::READER);
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto& out_var_name = ctx->Output("Out")[0];
+    PADDLE_ENFORCE(ctx->HasVar(out_var_name));
+    ctx->SetType(out_var_name, framework::proto::VarType::READER);
     auto sink_var_names =
-        boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
+        boost::get<std::vector<std::string>>(ctx->GetAttr("sink_var_names"));
     const auto* sub_block =
-        boost::get<framework::BlockDesc*>(op_desc.GetAttr("sub_block"));
+        boost::get<framework::BlockDesc*>(ctx->GetAttr("sub_block"));
     std::vector<framework::proto::VarType::Type> res_data_types;
     for (const std::string& var_name : sink_var_names) {
       framework::VarDesc* var = sub_block->FindVar(var_name);
       PADDLE_ENFORCE_NOT_NULL(var);
       res_data_types.emplace_back(var->GetDataType());
     }
-    out_reader->SetDataTypes(res_data_types);
+    ctx->SetDataTypes(out_var_name, res_data_types);
   }
 };
...
@@ -51,19 +51,16 @@ class ReadInferShape : public framework::InferShapeBase {
 class ReadInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    bool infer_out = boost::get<bool>(op_desc.GetAttr("infer_out"));
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    bool infer_out = boost::get<bool>(ctx->GetAttr("infer_out"));
     if (infer_out) {
-      std::string reader_name = op_desc.Input("Reader")[0];
-      std::vector<std::string> out_names = op_desc.Output("Out");
-      framework::VarDesc* reader = block->FindVarRecursive(reader_name);
-      auto dtypes = reader->GetDataTypes();
+      std::string reader_name = ctx->Input("Reader")[0];
+      std::vector<std::string> out_names = ctx->Output("Out");
+      auto dtypes = ctx->GetDataTypes(reader_name);
       PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
       for (size_t i = 0; i < dtypes.size(); ++i) {
-        framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
-        out.SetType(framework::proto::VarType::LOD_TENSOR);
-        out.SetDataType(dtypes[i]);
+        ctx->SetType(out_names[i], framework::proto::VarType::LOD_TENSOR);
+        ctx->SetDataType(out_names[i], dtypes[i]);
       }
     }
   }
...
@@ -98,11 +98,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
   }
 }
-void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
-                                        framework::BlockDesc* block) const {
-  std::string reader_name = op_desc.Output("Out")[0];
-  framework::VarDesc* reader = block->FindVarRecursive(reader_name);
-  reader->SetType(framework::proto::VarType::READER);
+void FileReaderInferVarType::operator()(
+    framework::InferVarTypeContext* ctx) const {
+  std::string reader_name = ctx->Output("Out")[0];
+  ctx->SetType(reader_name, framework::proto::VarType::READER);
 }
 void DecoratedReaderInferShape::operator()(
@@ -125,13 +124,11 @@ void DecoratedReaderInferShape::operator()(
 }
 void DecoratedReaderInferVarType::operator()(
-    const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
-  std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
-  framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
-  std::string out_reader_name = op_desc.Output("Out")[0];
-  framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
-  out_reader->SetType(framework::proto::VarType::READER);
-  out_reader->SetDataTypes(in_reader->GetDataTypes());
+    framework::InferVarTypeContext* ctx) const {
+  const std::string& in_reader_name = ctx->Input("UnderlyingReader")[0];
+  const std::string& out_reader_name = ctx->Output("Out")[0];
+  ctx->SetType(out_reader_name, framework::proto::VarType::READER);
+  ctx->SetDataTypes(out_reader_name, ctx->GetDataTypes(in_reader_name));
 }
 void DecoratedReaderMakerBase::Make() {
...
@@ -14,7 +14,9 @@
 #pragma once
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/reader.h"
@@ -59,8 +61,7 @@ class FileReaderInferShape : public framework::InferShapeBase {
 class FileReaderInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override;
+  void operator()(framework::InferVarTypeContext* ctx) const override;
 };
 // general infershape for decorated reader
@@ -72,8 +73,7 @@
 // general var type inference for decorated reader
 class DecoratedReaderInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override;
+  void operator()(framework::InferVarTypeContext* ctx) const override;
 };
 class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
...
@@ -159,12 +159,9 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
 class SaveOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front();
-    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto out_var_name = ctx->Output(LOOKUP_TABLE_PATH).front();
+    ctx->SetType(out_var_name, framework::proto::VarType::RAW);
   }
 };
...
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/scale_op.h"
+#include <memory>
 #include <string>
 #include "paddle/fluid/operators/detail/safe_ref.h"
@@ -69,17 +70,13 @@ $$Out = scale*(X + bias)$$
 class ScaleOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto &in_var_name = op_desc.Input("X").front();
-    auto &in_var = detail::Ref(block->FindVarRecursive(in_var_name));
-    auto out_var_name = op_desc.Output("Out").front();
-    auto *out_var = block->FindVarRecursive(out_var_name);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto &in_var_name = ctx->Input("X").front();
+    auto out_var_name = ctx->Output("Out").front();
     if (in_var_name != out_var_name) {
-      out_var->SetType(in_var.GetType());
-      out_var->SetDataType(in_var.GetDataType());
+      ctx->SetType(out_var_name, ctx->GetType(in_var_name));
+      ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
     }
   }
 };
...
@@ -14,6 +14,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/split_selected_rows_op.h"
+#include <memory>
 namespace paddle {
 namespace operators {
@@ -60,10 +62,9 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
 class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &out_var : ctx->Output("Out")) {
+      ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
     }
   }
 };
...
@@ -12,6 +12,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/sum_op.h"
 #include <algorithm>
+#include <memory>
 #include <string>
 #include <vector>
@@ -159,24 +160,20 @@ the LoD information with the first input.
 class SumOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto& inputs = op_desc.Input("X");
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    auto& inputs = ctx->Input("X");
     auto var_type = framework::proto::VarType::SELECTED_ROWS;
-    for (auto& name : op_desc.Input("X")) {
-      VLOG(10) << name << " "
-               << block->FindRecursiveOrCreateVar(name).GetType();
+    for (auto& name : ctx->Input("X")) {
+      VLOG(10) << name << " " << ctx->GetType(name);
     }
     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [block](const std::string& name) {
-          return block->FindRecursiveOrCreateVar(name).GetType() ==
-                 framework::proto::VarType::LOD_TENSOR;
+        inputs.begin(), inputs.end(), [ctx](const std::string& name) {
+          return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
         });
-    auto is_tensor_array = [block](const std::string& name) {
-      return block->FindRecursiveOrCreateVar(name).GetType() ==
-             framework::proto::VarType::LOD_TENSOR_ARRAY;
+    auto is_tensor_array = [ctx](const std::string& name) {
+      return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
     };
     bool any_input_is_tensor_array =
@@ -188,8 +185,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     if (!all_inputs_are_tensor_array) {
       std::ostringstream os;
       for (auto& each : inputs) {
-        os << " " << each << " type is "
-           << block->FindRecursiveOrCreateVar(each).GetType() << "\n";
+        os << " " << each << " type is " << ctx->GetType(each) << "\n";
       }
       PADDLE_ENFORCE(all_inputs_are_tensor_array,
                      "Not all inputs are tensor array:\n%s", os.str());
@@ -199,11 +195,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
       var_type = framework::proto::VarType::LOD_TENSOR;
     }
-    auto out_var_name = op_desc.Output("Out").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    out_var.SetType(var_type);
-    auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front()));
-    out_var.SetDataType(in_var.GetDataType());
+    auto out_var_name = ctx->Output("Out").front();
+    ctx->SetType(out_var_name, var_type);
+    ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
   }
 };
...
@@ -177,10 +177,9 @@ class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
 class LoDTensorArray2TensorGradInferVarType
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &out_var : op_desc.Output(framework::GradVarName("X"))) {
-      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    for (auto &out_var : ctx->Output(framework::GradVarName("X"))) {
+      ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
     }
   }
 };
...
@@ -46,8 +46,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
 class TensorRTEngineInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext *ctx) const override {}
 };
 }  // namespace operators
...
@@ -112,17 +112,16 @@ uniform distribution. The random result is in set [min, max].
 class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto out_var_name = op_desc.Output("Out").front();
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto out_var_name = ctx->Output("Out").front();
     auto var_data_type = static_cast<framework::proto::VarType::Type>(
-        boost::get<int>(op_desc.GetAttr("dtype")));
-    auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
-      out_var.SetType(framework::proto::VarType::LOD_TENSOR);
+        boost::get<int>(ctx->GetAttr("dtype")));
+    if (ctx->GetType(out_var_name) !=
+        framework::proto::VarType::SELECTED_ROWS) {
+      ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    out_var.SetDataType(var_data_type);
+    ctx->SetDataType(out_var_name, var_data_type);
   }
 };
...
@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
       .def("trace",
            [](imperative::Tracer& self, imperative::OpBase* op,
               const imperative::VarBasePtrMap& inputs,
-              const imperative::VarBasePtrMap& outputs,
+              imperative::VarBasePtrMap* outputs,
               framework::AttributeMap attrs_map,
               const platform::CPUPlace expected_place,
               const bool stop_gradient = false) {
@@ -49,7 +49,7 @@ void BindTracer(pybind11::module* m) {
      .def("trace",
           [](imperative::Tracer& self, imperative::OpBase* op,
              const imperative::VarBasePtrMap& inputs,
-             const imperative::VarBasePtrMap& outputs,
+             imperative::VarBasePtrMap* outputs,
             framework::AttributeMap attrs_map,
             const platform::CUDAPlace expected_place,
             const bool stop_gradient = false) {
...
@@ -200,7 +200,7 @@ PYBIND11_MODULE(core, m) {
       .def_property("name", &imperative::VarBase::Name,
                     &imperative::VarBase::SetName)
       .def_property_readonly("shape", &imperative::VarBase::Shape)
-      .def_property_readonly("dtype", &imperative::VarBase::DType)
+      .def_property_readonly("dtype", &imperative::VarBase::DataType)
       .def_property("persistable", &imperative::VarBase::IsPersistable,
                     &imperative::VarBase::SetPersistable)
       .def_property("stop_gradient", &imperative::VarBase::IsStopGradient,
...