Commit ca392c7e authored by minqiyang

Implement infer var type context

Parent 0b49e43d
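This commit threads a single `InferVarTypeContext` through every `VarTypeInference` functor in place of the old `(const OpDesc&, BlockDesc*)` pair. The context wraps the same two objects behind accessor methods (`Input`, `Output`, `GetType`, `SetType`, ...), so operator code no longer touches `OpDesc`/`BlockDesc` directly, and a `RuntimeInferVarTypeContext` stub is introduced so imperative mode can later supply its own implementation. A minimal sketch of the pattern after this change (the op and class names below are illustrative, not part of the commit):

// Hypothetical operator-side inference written against the new context API.
class MyOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext& ctx) const override {
    // Argument names come from the context instead of the OpDesc.
    auto& in_name = ctx.Input("X").front();
    auto out_name = ctx.Output("Out").front();
    // Var metadata is read and written through the context instead of the
    // BlockDesc: forward the input's type and dtype to the output.
    ctx.SetType(out_name, ctx.GetType(in_name));
    ctx.SetDataType(out_name, ctx.GetDataType(in_name));
  }
};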
@@ -68,11 +68,11 @@ class SplitOpMaker : public OpProtoAndCheckerMaker {
 class DummyVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc& op_desc, BlockDesc* block) const override {
-    auto& inputs = op_desc.Input("X");
-    auto type = block->Var(inputs.front())->GetType();
-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(type);
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto& inputs = ctx.Input("X");
+    auto type = ctx.GetType(inputs.front());
+    auto out_var_name = ctx.Output("Out").front();
+    ctx.SetType(out_var_name, type);
   }
 };
...
@@ -127,9 +127,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
 template <typename T>
 struct OpInfoFiller<T, kVarTypeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) {
+    info->infer_var_type_ = [](InferVarTypeContext& context) {
       T inference;
-      inference(fwd_op, block);
+      inference(context);
     };
   }
 };
...
@@ -43,20 +43,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
-    auto &inputs = op_desc.Input("X");
+  void operator()(InferVarTypeContext &ctx) const override {
+    auto &inputs = ctx.Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [block](const std::string &name) {
-          return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
+          return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }

-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(default_var_type);
+    auto out_var_name = ctx.Output("Out").front();
+    ctx.SetType(out_var_name, default_var_type);
   }
 };
@@ -71,7 +71,7 @@ class DummyOpMaker : public OpProtoAndCheckerMaker {
 class DummyOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext &ctx) const override {}
 };
 }  // namespace framework
 }  // namespace paddle
...
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/shape_inference.h"
+#include "paddle/fluid/framework/var_type_inference.h"

 namespace paddle {
 namespace framework {
@@ -677,7 +678,8 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   // var type inference. Hence, we don't do any "default" setting here.
   auto &info = OpInfoMap::Instance().Get(this->Type());
   if (info.infer_var_type_) {
-    info.infer_var_type_(*this, block);
+    InferVarTypeContext context(this, block);
+    info.infer_var_type_(context);
   }
 }
...
@@ -27,6 +27,7 @@ namespace framework {
 class OperatorBase;
 class OpDesc;
 class InferShapeContext;
+class InferVarTypeContext;
 class BlockDesc;
 class Variable;
@@ -53,7 +54,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
     const std::vector<BlockDesc*>& grad_block)>;

 using InferVarTypeFN =
-    std::function<void(const OpDesc& /*op_desc*/, BlockDesc* /*block*/)>;
+    std::function<void(framework::InferVarTypeContext& /*context*/)>;

 using InferShapeFN = std::function<void(InferShapeContext*)>;
...
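Taken together, the three framework hunks above compose as follows: `OpInfoFiller` registers a lambda matching the new `InferVarTypeFN` signature, and `OpDesc::InferVarType` builds the context on the stack and passes it through. A simplified sketch of the resulting call path (`RunVarTypeInference` is a hypothetical free function, not framework code):

// Sketch of the dispatch after this commit, assuming OpInfoMap/OpInfo keep
// their existing interfaces.
void RunVarTypeInference(const framework::OpDesc& op,
                         framework::BlockDesc* block) {
  auto& info = framework::OpInfoMap::Instance().Get(op.Type());
  if (info.infer_var_type_) {
    // The context holds only the two pointers, so a fresh stack instance
    // per call is cheap.
    framework::InferVarTypeContext ctx(&op, block);
    info.infer_var_type_(ctx);  // invokes the lambda installed by
                                // OpInfoFiller<T, kVarTypeInference>
  }
}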
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #include <string>
+#include <vector>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/type_defs.h"
@@ -21,26 +22,113 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+class OpDesc;
+class BlockDesc;
+
+// default infer var type context
+class InferVarTypeContext {
+ public:
+  InferVarTypeContext(const OpDesc* op, BlockDesc* block)
+      : op_(op), block_(block) {}
+
+  Attribute GetAttr(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->GetAttr(name);
+  }
+
+  inline bool HasVar(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindVarRecursive(name) != nullptr;
+  }
+
+  inline bool HasInput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Inputs().count(name) > 0;
+  }
+
+  inline bool HasOutput(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Outputs().count(name) > 0;
+  }
+
+  inline const std::vector<std::string>& Input(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Input(name);
+  }
+
+  inline const std::vector<std::string>& Output(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(op_);
+    return op_->Output(name);
+  }
+
+  inline proto::VarType::Type GetType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetType();
+  }
+
+  inline void SetType(const std::string& name, proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetType(type);
+  }
+
+  inline proto::VarType::Type GetDataType(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetDataType();
+  }
+
+  inline void SetDataType(const std::string& name, proto::VarType::Type type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetDataType(type);
+  }
+
+  // Reader variables hold multiple data types; the reader op hunks below
+  // (custom_reader, read_op) call these, so they are assumed to mirror
+  // VarDesc's GetDataTypes/SetDataTypes.
+  inline std::vector<proto::VarType::Type> GetDataTypes(
+      const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
+  }
+
+  inline void SetDataTypes(
+      const std::string& name,
+      const std::vector<proto::VarType::Type>& multiple_data_type) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
+  }
+
+  inline std::vector<int64_t> GetShape(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetShape();
+  }
+
+  inline void SetShape(const std::string& name,
+                       const std::vector<int64_t>& dims) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetShape(dims);
+  }
+
+  inline int32_t GetLoDLevel(const std::string& name) const {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
+  }
+
+  inline void SetLoDLevel(const std::string& name, int32_t lod_level) {
+    PADDLE_ENFORCE_NOT_NULL(block_);
+    block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
+  }
+
+ private:
+  const OpDesc* op_;
+  BlockDesc* block_;
+};
+
+// infer var type context for imperative mode
+class RuntimeInferVarTypeContext : public InferVarTypeContext {
+ public:
+  RuntimeInferVarTypeContext() : InferVarTypeContext(nullptr, nullptr) {}
+};
+
 class VarTypeInference {
  public:
   virtual ~VarTypeInference() {}
-  virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
+  virtual void operator()(InferVarTypeContext& context) const = 0;  // NOLINT
 };

 class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const final {
+  void operator()(framework::InferVarTypeContext& ctx) const final {  // NOLINT
     auto in_out_var_names = this->GetInputOutputWithSameType();

     for (auto& i_o_n : in_out_var_names) {
-      auto& x_name = op_desc.Input(i_o_n.first).at(0);
-      auto& out_name = op_desc.Output(i_o_n.second).at(0);
-
-      auto& x = block->FindRecursiveOrCreateVar(x_name);
-      auto& out = block->FindRecursiveOrCreateVar(out_name);
-      out.SetType(x.GetType());
-      out.SetDataType(x.GetDataType());
+      auto& x_name = ctx.Input(i_o_n.first).at(0);
+      auto& out_name = ctx.Output(i_o_n.second).at(0);
+
+      ctx.SetType(out_name, ctx.GetType(x_name));
+      ctx.SetDataType(out_name, ctx.GetDataType(x_name));
     }
   }
...
@@ -44,20 +44,20 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
 class SumOpVarTypeInference : public VarTypeInference {
  public:
-  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
-    auto &inputs = op_desc.Input("X");
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto &inputs = ctx.Input("X");
     auto default_var_type = proto::VarType::SELECTED_ROWS;

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [block](const std::string &name) {
-          return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR;
+        inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
+          return ctx.GetType(name) == proto::VarType::LOD_TENSOR;
         });
     if (any_input_is_lod_tensor) {
       default_var_type = proto::VarType::LOD_TENSOR;
     }

-    auto out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetType(default_var_type);
+    auto out_var_name = ctx.Output("Out").front();
+    ctx.SetType(out_var_name, default_var_type);
   }
 };
 }  // namespace framework
...
@@ -203,15 +203,12 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
 class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    for (auto& o : op_desc.Output("SentenceIds")) {
-      auto& sentence_ids = block->FindRecursiveOrCreateVar(o);
-      sentence_ids.SetType(framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    for (auto& o : ctx.Output("SentenceIds")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
-    for (auto& o : op_desc.Output("SentenceScores")) {
-      auto& sentence_scores = block->FindRecursiveOrCreateVar(o);
-      sentence_scores.SetType(framework::proto::VarType::LOD_TENSOR);
+    for (auto& o : ctx.Output("SentenceScores")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
...
@@ -120,15 +120,12 @@ class BeamSearchOp : public framework::OperatorWithKernel {
 class BeamSearchInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &o : op_desc.Output("selected_ids")) {
-      auto &selected_ids = block->FindRecursiveOrCreateVar(o);
-      selected_ids.SetType(framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &o : ctx.Output("selected_ids")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
-    for (auto &o : op_desc.Output("selected_scores")) {
-      auto &selected_scores = block->FindRecursiveOrCreateVar(o);
-      selected_scores.SetType(framework::proto::VarType::LOD_TENSOR);
+    for (auto &o : ctx.Output("selected_scores")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_TENSOR);
     }
   }
 };
...
@@ -93,11 +93,9 @@ execution.
 class GetPlacesInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &o_name : op_desc.Output("Out")) {
-      block->FindRecursiveOrCreateVar(o_name).SetType(
-          framework::proto::VarType::PLACE_LIST);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &o_name : ctx.Output("Out")) {
+      ctx.SetType(o_name, framework::proto::VarType::PLACE_LIST);
     }
   }
 };
...
@@ -100,16 +100,13 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
 class WriteToArrayInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto x_name = op_desc.Input("X")[0];
-    auto out_name = op_desc.Output("Out")[0];
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto x_name = ctx.Input("X")[0];
+    auto out_name = ctx.Output("Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
-    auto &out = block->FindRecursiveOrCreateVar(out_name);
-    out.SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
-    auto *x = block->FindVarRecursive(x_name);
-    if (x != nullptr) {
-      out.SetDataType(x->GetDataType());
+    ctx.SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
+    if (ctx.HasVar(x_name)) {
+      ctx.SetDataType(out_name, ctx.GetDataType(x_name));
     }
   }
 };
...
@@ -114,11 +114,10 @@ class MergeIdsOp : public framework::OperatorWithKernel {
 class MergeIdsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto *input_var = block->Var(op_desc.Input("Ids")[0]);
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(input_var->GetType());
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto input_type = ctx.GetType(ctx.Input("Ids")[0]);
+    for (auto &out_var : ctx.Output("Out")) {
+      ctx.SetType(out_var, input_type);
     }
   }
 };
...
@@ -71,11 +71,10 @@ class SplitIdsOp : public framework::OperatorWithKernel {
 class SplitIdsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto *input_var = block->Var(op_desc.Input("Ids")[0]);
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(input_var->GetType());
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto input_type = ctx.GetType(ctx.Input("Ids")[0]);
+    for (auto &out_var : ctx.Output("Out")) {
+      ctx.SetType(out_var, input_type);
     }
   }
 };
...
@@ -39,12 +39,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
 class FillConstantOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
+  void operator()(framework::InferVarTypeContext& ctx) const override {
     auto data_type = static_cast<framework::proto::VarType::Type>(
-        boost::get<int>(op_desc.GetAttr("dtype")));
-    auto& out_var_name = op_desc.Output("Out").front();
-    block->Var(out_var_name)->SetDataType(data_type);
+        boost::get<int>(ctx.GetAttr("dtype")));
+    auto& out_var_name = ctx.Output("Out").front();
+    ctx.SetDataType(out_var_name, data_type);
   }
 };
...
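Several of the grad-op hunks below (fused_embedding_seq_pool, hierarchical_sigmoid, lookup_table, nce) repeat one pattern: choose SELECTED_ROWS or LOD_TENSOR for the weight gradient from an `is_sparse` attribute, then forward the parameter's data type. Condensed into one sketch (hypothetical class name; it mirrors, rather than reproduces, the hunks that follow):

// The recurring sparse-gradient pattern, written once against the new API.
class SparseWeightGradVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext& ctx) const override {
    auto grad_name = ctx.Output(framework::GradVarName("W")).front();
    bool is_sparse = boost::get<bool>(ctx.GetAttr("is_sparse"));
    // Sparse updates use SELECTED_ROWS; dense updates use LOD_TENSOR.
    ctx.SetType(grad_name, is_sparse
                               ? framework::proto::VarType::SELECTED_ROWS
                               : framework::proto::VarType::LOD_TENSOR);
    // The gradient inherits the parameter's data type.
    ctx.SetDataType(grad_name, ctx.GetDataType(ctx.Input("W")[0]));
  }
};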
@@ -137,22 +137,20 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
 class FusedEmbeddingSeqPoolOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
-    auto attr = op_desc.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto out_var_name = ctx.Output(framework::GradVarName("W")).front();
+    auto attr = ctx.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
               << framework::GradVarName("W") << " is set to SelectedRows";
-      block->Var(out_var_name)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx.SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "fused_embedding_seq_pool_grad op "
               << framework::GradVarName("W") << " is set to LoDTensor";
-      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
+    ctx.SetDataType(out_var_name, ctx.GetDataType(ctx.Input("W")[0]));
   }
 };
...
@@ -81,15 +81,12 @@ GetTensorFromSelectedRows is used to get the tensor from SelectedRows.
 class GetTensorFromSelectedRowsOpVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const final {
-    auto out_var_name = op_desc.Output("Out").front();
-    auto in_var_name = op_desc.Input("X").front();
-
-    auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto in_var = block->FindRecursiveOrCreateVar(in_var_name);
-    out_var.SetType(framework::proto::VarType::LOD_TENSOR);
-    out_var.SetDataType(in_var.GetDataType());
+  void operator()(framework::InferVarTypeContext &ctx) const final {  // NOLINT
+    auto out_var_name = ctx.Output("Out").front();
+    auto in_var_name = ctx.Input("X").front();
+
+    ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
+    ctx.SetDataType(out_var_name, ctx.GetDataType(in_var_name));
   }
 };
...
@@ -197,38 +197,32 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
 class HierarchicalSigmoidGradOpGradVarTypeInference
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto w_grad_var_name = op_desc.Output(framework::GradVarName("W")).front();
-    auto bias_grad_var_name_vec =
-        op_desc.Output(framework::GradVarName("Bias"));
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto w_grad_var_name = ctx.Output(framework::GradVarName("W")).front();
+    auto bias_grad_var_name_vec = ctx.Output(framework::GradVarName("Bias"));
     std::string bias_grad_var_name;
     bool hasBias = false;
     if (bias_grad_var_name_vec.size()) {
       hasBias = true;
-      bias_grad_var_name =
-          op_desc.Output(framework::GradVarName("Bias")).front();
+      bias_grad_var_name = ctx.Output(framework::GradVarName("Bias")).front();
     }
-    auto attr = op_desc.GetAttr("is_sparse");
+    auto attr = ctx.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
-      block->Var(w_grad_var_name)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx.SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
-      block->Var(w_grad_var_name)
-          ->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
     }
     if (hasBias) {
       VLOG(30) << "hierarchical_sigmoid_grad op "
              << framework::GradVarName("Bias") << " is set to LoDTensor";
-      block->Var(bias_grad_var_name)
-          ->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
+    ctx.SetDataType(w_grad_var_name, ctx.GetDataType(ctx.Input("W")[0]));
   }
 };
...
@@ -64,11 +64,9 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
 class LoDRankTableInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &o : op_desc.Output("Out")) {
-      block->FindRecursiveOrCreateVar(o).SetType(
-          framework::proto::VarType::LOD_RANK_TABLE);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &o : ctx.Output("Out")) {
+      ctx.SetType(o, framework::proto::VarType::LOD_RANK_TABLE);
     }
   }
 };
...
@@ -201,10 +201,9 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
 class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &out_var : ctx.Output("Out")) {
+      ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
     }
   }
 };
...
@@ -147,22 +147,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
 class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
-    auto attr = op_desc.GetAttr("is_sparse");
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto out_var_name = ctx.Output(framework::GradVarName("W")).front();
+    auto attr = ctx.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to SelectedRows";
-      block->Var(out_var_name)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx.SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
               << " is set to LoDTensor";
-      block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType());
+    ctx.SetDataType(out_var_name, ctx.GetDataType(ctx.Input("W")[0]));
   }
 };
...
@@ -237,23 +237,21 @@ class NCEOpGrad : public framework::OperatorWithKernel {
 class NCEOpGradVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front();
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto weight_grad = ctx.Output(framework::GradVarName("Weight")).front();

-    auto attr = op_desc.GetAttr("is_sparse");
+    auto attr = ctx.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
               << " is set to SelectedRows";
-      block->Var(weight_grad)
-          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      ctx.SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
     } else {
       VLOG(3) << "nce_op_grad op " << weight_grad << " and "
               << " is set to LoDTensor";
-      block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
+      ctx.SetType(weight_grad, framework::proto::VarType::LOD_TENSOR);
     }
-    block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType());
+    ctx.SetDataType(weight_grad, ctx.GetDataType(ctx.Input("Input")[0]));
   }
 };
...
@@ -37,8 +37,7 @@ class NgraphEngineOpMaker : public framework::OpProtoAndCheckerMaker {
 class NgraphEngineInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext &ctx) const override {}
 };
 }  // namespace operators
...
@@ -72,8 +72,7 @@ use L2 regularizers in case of using LARS.
 class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext &ctx) const override {}
 };
 }  // namespace operators
 }  // namespace paddle
...
@@ -21,18 +21,14 @@ using Tensor = framework::Tensor;
 class MomentumOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto input_var = op_desc.Input("Param")[0];
-    for (auto& out_var : op_desc.Output("ParamOut")) {
-      if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
-          framework::proto::VarType::SELECTED_ROWS) {
-        block->FindRecursiveOrCreateVar(out_var).SetType(
-            framework::proto::VarType::SELECTED_ROWS);
-      } else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
-                 framework::proto::VarType::LOD_TENSOR) {
-        block->FindRecursiveOrCreateVar(out_var).SetType(
-            framework::proto::VarType::LOD_TENSOR);
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto& input_var = ctx.Input("Param")[0];
+    for (auto& out_var : ctx.Output("ParamOut")) {
+      if (ctx.GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) {
+        ctx.SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
+      } else if (ctx.GetType(input_var) ==
+                 framework::proto::VarType::LOD_TENSOR) {
+        ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR);
       } else {
         PADDLE_THROW(
             "Only support LodTensor and SelectedRows, Unexpected Input Type.");
...
@@ -50,20 +50,18 @@ class SGDOp : public framework::OperatorWithKernel {
 class SGDOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto input_var_n = op_desc.Input("Param")[0];
-    auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType();
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto &input_var_n = ctx.Input("Param")[0];
+    auto in_var_type = ctx.GetType(input_var_n);
     PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
                        in_var_type == framework::proto::VarType::LOD_TENSOR,
                    "The input Var's type should be LoDtensor or SelectedRows,"
                    " but the received var(%s)'s type is %s",
                    input_var_n, in_var_type);

-    for (auto &out_var_n : op_desc.Output("ParamOut")) {
-      auto &out_var = block->FindRecursiveOrCreateVar(out_var_n);
-      if (out_var.GetType() != in_var_type) {
-        out_var.SetType(in_var_type);
+    for (auto &out_var_n : ctx.Output("ParamOut")) {
+      if (ctx.GetType(out_var_n) != in_var_type) {
+        ctx.SetType(out_var_n, in_var_type);
       }
     }
   }
...
@@ -91,15 +91,12 @@ static void CallPythonFunc(py::object *callable,
   }
 }

-class PyFuncOpVarTypInference : public framework::VarTypeInference {
+class PyFuncOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op,
-                  framework::BlockDesc *block) const override {
-    auto &outs = op.Outputs();
-    bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty());
-
-    auto &ins = op.Inputs();
-    bool has_in = (ins.count("X") > 0 && !ins.at("X").empty());
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    bool has_out = (ctx.HasOutput("Out") && !ctx.Output("Out").empty());
+    bool has_in = (ctx.HasInput("X") && !ctx.Input("X").empty());

     /**
      * X or Out can be empty, so that py_func can be more flexible
@@ -107,7 +104,7 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference {
      */
     PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");

-    PADDLE_ENFORCE_GE(boost::get<int>(op.GetAttr(kForwardPythonCallableId)), 0,
+    PADDLE_ENFORCE_GE(boost::get<int>(ctx.GetAttr(kForwardPythonCallableId)), 0,
                       "Function id cannot be less than 0");

     if (!has_out) return;
@@ -118,7 +115,7 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference {
      * the corresponding forward variable
      */
     const std::string kGradVarSuffix = framework::kGradVarSuffix;
-    auto &out_var_names = outs.at("Out");
+    auto &out_var_names = ctx.Output("Out");
     for (auto &out_var_name : out_var_names) {
       if (out_var_name == framework::kEmptyVarName ||
           out_var_name.size() < kGradVarSuffix.size()) {
@@ -128,18 +125,17 @@ class PyFuncOpVarTypInference : public framework::VarTypeInference {
       size_t len = out_var_name.size() - kGradVarSuffix.size();
       if (out_var_name.substr(len) == kGradVarSuffix) {
         auto fwd_var_name = out_var_name.substr(0, len);
-        auto *out_var_desc = block->FindVarRecursive(out_var_name);
-        auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name);
-        PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found",
-                                out_var_name);
-        PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found",
-                                fwd_var_name);
+        PADDLE_ENFORCE(ctx.HasVar(out_var_name),
+                       "Backward variable %s not found", out_var_name);
+        PADDLE_ENFORCE(ctx.HasVar(fwd_var_name),
+                       "Forward variable %s not found", fwd_var_name);
         VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
                  << fwd_var_name << ")";
-        out_var_desc->SetShape(fwd_var_desc->GetShape());
-        out_var_desc->SetDataType(fwd_var_desc->GetDataType());
-        out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel());
-        out_var_desc->SetType(fwd_var_desc->GetType());
+        ctx.SetShape(out_var_name, ctx.GetShape(fwd_var_name));
+        ctx.SetDataType(out_var_name, ctx.GetDataType(fwd_var_name));
+        ctx.SetLoDLevel(out_var_name, ctx.GetLoDLevel(fwd_var_name));
+        ctx.SetType(out_var_name, ctx.GetType(fwd_var_name));
       }
     }
   }
@@ -309,5 +305,5 @@ class PyFuncOp : public framework::OperatorBase {
 namespace ops = paddle::operators;

 REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
-                  ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference,
+                  ops::PyFuncOpVarTypeInference, ops::PyFuncOpShapeInference,
                   ops::PyFuncOpGradDescMaker);
@@ -123,23 +123,22 @@ class CustomReaderInferShape : public framework::InferShapeBase {
 class CustomReaderInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]);
-    PADDLE_ENFORCE_NOT_NULL(out_reader);
-    out_reader->SetType(framework::proto::VarType::READER);
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto& out_var_name = ctx.Output("Out")[0];
+    PADDLE_ENFORCE(ctx.HasVar(out_var_name));
+    ctx.SetType(out_var_name, framework::proto::VarType::READER);

     auto sink_var_names =
-        boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
+        boost::get<std::vector<std::string>>(ctx.GetAttr("sink_var_names"));
     const auto* sub_block =
-        boost::get<framework::BlockDesc*>(op_desc.GetAttr("sub_block"));
+        boost::get<framework::BlockDesc*>(ctx.GetAttr("sub_block"));
     std::vector<framework::proto::VarType::Type> res_data_types;
     for (const std::string& var_name : sink_var_names) {
       framework::VarDesc* var = sub_block->FindVar(var_name);
       PADDLE_ENFORCE_NOT_NULL(var);
       res_data_types.emplace_back(var->GetDataType());
     }
-    out_reader->SetDataTypes(res_data_types);
+    ctx.SetDataTypes(out_var_name, res_data_types);
   }
 };
...
@@ -51,19 +51,16 @@ class ReadInferShape : public framework::InferShapeBase {
 class ReadInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    bool infer_out = boost::get<bool>(op_desc.GetAttr("infer_out"));
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    bool infer_out = boost::get<bool>(ctx.GetAttr("infer_out"));
     if (infer_out) {
-      std::string reader_name = op_desc.Input("Reader")[0];
-      std::vector<std::string> out_names = op_desc.Output("Out");
-      framework::VarDesc* reader = block->FindVarRecursive(reader_name);
-      auto dtypes = reader->GetDataTypes();
+      std::string reader_name = ctx.Input("Reader")[0];
+      std::vector<std::string> out_names = ctx.Output("Out");
+      auto dtypes = ctx.GetDataTypes(reader_name);
       PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
       for (size_t i = 0; i < dtypes.size(); ++i) {
-        framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
-        out.SetType(framework::proto::VarType::LOD_TENSOR);
-        out.SetDataType(dtypes[i]);
+        ctx.SetType(out_names[i], framework::proto::VarType::LOD_TENSOR);
+        ctx.SetDataType(out_names[i], dtypes[i]);
       }
     }
   }
...
@@ -59,8 +59,7 @@ class FileReaderInferShape : public framework::InferShapeBase {
 class FileReaderInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override;
+  void operator()(framework::InferVarTypeContext& ctx) const override;
 };

 // general infershape for decorated reader
@@ -72,8 +71,7 @@ class DecoratedReaderInferShape : public framework::InferShapeBase {
 // general var type inference for decorated reader
 class DecoratedReaderInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override;
+  void operator()(framework::InferVarTypeContext& ctx) const override;
 };

 class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
...
@@ -159,12 +159,9 @@ This operator will serialize and write LoDTensor / SelectedRows variable to file
 class SaveOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front();
-    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto out_var_name = ctx.Output(LOOKUP_TABLE_PATH).front();
+    ctx.SetType(out_var_name, framework::proto::VarType::RAW);
   }
 };
...
@@ -69,17 +69,13 @@ $$Out = scale*(X + bias)$$
 class ScaleOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto &in_var_name = op_desc.Input("X").front();
-    auto &in_var = detail::Ref(block->FindVarRecursive(in_var_name));
-
-    auto out_var_name = op_desc.Output("Out").front();
-    auto *out_var = block->FindVarRecursive(out_var_name);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto &in_var_name = ctx.Input("X").front();
+    auto out_var_name = ctx.Output("Out").front();

     if (in_var_name != out_var_name) {
-      out_var->SetType(in_var.GetType());
-      out_var->SetDataType(in_var.GetDataType());
+      ctx.SetType(out_var_name, ctx.GetType(in_var_name));
+      ctx.SetDataType(out_var_name, ctx.GetDataType(in_var_name));
     }
   }
 };
...
@@ -60,10 +60,9 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
 class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &out_var : op_desc.Output("Out")) {
-      block->Var(out_var)->SetType(framework::proto::VarType::SELECTED_ROWS);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &out_var : ctx.Output("Out")) {
+      ctx.SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
     }
   }
 };
...
@@ -159,24 +159,20 @@ the LoD information with the first input.
 class SumOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto& inputs = op_desc.Input("X");
+  void operator()(framework::InferVarTypeContext& ctx) const override {
+    auto& inputs = ctx.Input("X");
     auto var_type = framework::proto::VarType::SELECTED_ROWS;
-    for (auto& name : op_desc.Input("X")) {
-      VLOG(10) << name << " "
-               << block->FindRecursiveOrCreateVar(name).GetType();
+    for (auto& name : ctx.Input("X")) {
+      VLOG(10) << name << " " << ctx.GetType(name);
     }

     bool any_input_is_lod_tensor = std::any_of(
-        inputs.begin(), inputs.end(), [block](const std::string& name) {
-          return block->FindRecursiveOrCreateVar(name).GetType() ==
-                 framework::proto::VarType::LOD_TENSOR;
+        inputs.begin(), inputs.end(), [&ctx](const std::string& name) {
+          return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR;
         });

-    auto is_tensor_array = [block](const std::string& name) {
-      return block->FindRecursiveOrCreateVar(name).GetType() ==
-             framework::proto::VarType::LOD_TENSOR_ARRAY;
+    auto is_tensor_array = [&ctx](const std::string& name) {
+      return ctx.GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
     };

     bool any_input_is_tensor_array =
@@ -188,8 +184,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     if (!all_inputs_are_tensor_array) {
       std::ostringstream os;
       for (auto& each : inputs) {
-        os << "    " << each << " type is "
-           << block->FindRecursiveOrCreateVar(each).GetType() << "\n";
+        os << "    " << each << " type is " << ctx.GetType(each) << "\n";
       }
       PADDLE_ENFORCE(all_inputs_are_tensor_array,
                      "Not all inputs are tensor array:\n%s", os.str());
@@ -199,11 +194,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
       var_type = framework::proto::VarType::LOD_TENSOR;
     }

-    auto out_var_name = op_desc.Output("Out").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    out_var.SetType(var_type);
-    auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front()));
-    out_var.SetDataType(in_var.GetDataType());
+    auto out_var_name = ctx.Output("Out").front();
+    ctx.SetType(out_var_name, var_type);
+    ctx.SetDataType(out_var_name, ctx.GetDataType(inputs.front()));
   }
 };
...
@@ -177,10 +177,9 @@ class LoDTensorArray2TensorGradInferShape : public framework::InferShapeBase {
 class LoDTensorArray2TensorGradInferVarType
     : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    for (auto &out_var : op_desc.Output(framework::GradVarName("X"))) {
-      block->Var(out_var)->SetType(framework::proto::VarType::LOD_TENSOR_ARRAY);
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    for (auto &out_var : ctx.Output(framework::GradVarName("X"))) {
+      ctx.SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
     }
   }
 };
...
@@ -46,8 +46,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
 class TensorRTEngineInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {}
+  void operator()(framework::InferVarTypeContext &ctx) const override {}
 };
 }  // namespace operators
...
@@ -112,17 +112,15 @@ uniform distribution. The random result is in set [min, max].
 class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDesc &op_desc,
-                  framework::BlockDesc *block) const override {
-    auto out_var_name = op_desc.Output("Out").front();
+  void operator()(framework::InferVarTypeContext &ctx) const override {
+    auto out_var_name = ctx.Output("Out").front();
     auto var_data_type = static_cast<framework::proto::VarType::Type>(
-        boost::get<int>(op_desc.GetAttr("dtype")));
+        boost::get<int>(ctx.GetAttr("dtype")));

-    auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
-      out_var.SetType(framework::proto::VarType::LOD_TENSOR);
+    if (ctx.GetType(out_var_name) != framework::proto::VarType::SELECTED_ROWS) {
+      ctx.SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
     }
-    out_var.SetDataType(var_data_type);
+    ctx.SetDataType(out_var_name, var_data_type);
   }
 };
...