未验证 提交 91ae7848 编写于 作者: L liuwei1031 提交者: GitHub

improve efficiency of runtime InferVarType (#22778) (#24181)

 * cherry pick #22778
上级 57b062e1
......@@ -45,19 +45,13 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(InferVarTypeContext *ctx) const override {
auto &inputs = ctx->Input("X");
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, default_var_type);
ctx->SetOutputType("Out", default_var_type);
}
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -25,8 +26,14 @@ namespace framework {
class OpDesc;
class BlockDesc;
class StaticGraphVarTypeInference;
// default infer var type context
static const int ALL_ELEMENTS = -1;
class InferVarTypeContext {
friend class StaticGraphVarTypeInference;
public:
InferVarTypeContext(const OpDesc* op, BlockDesc* block)
: op_(op), block_(block) {}
......@@ -34,91 +41,267 @@ class InferVarTypeContext {
virtual ~InferVarTypeContext() {}
virtual Attribute GetAttr(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return op_->GetAttr(name);
}
virtual bool HasVar(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
return block_->FindVarRecursive(name) != nullptr;
}
virtual bool HasInput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& inputs = op_->Inputs();
auto input = inputs.find(name);
return input != inputs.end() && !input->second.empty();
}
virtual bool HasOutput(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& outputs = op_->Outputs();
auto output = outputs.find(name);
return output != outputs.end() && !output->second.empty();
}
virtual const std::vector<std::string>& Input(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
virtual size_t InputSize(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return op_->Inputs().at(name).size();
}
virtual const std::string& InputVarName(const std::string& name,
const int index = 0) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return op_->Inputs().at(name)[index];
}
virtual bool InputTypeAnyOf(const std::string& name,
proto::VarType::Type type) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& inputs = op_->Input(name);
return std::any_of(inputs.begin(), inputs.end(),
[this, &type](const std::string& name) {
return this->GetVarType(name) == type;
});
}
virtual bool InputTypeAllOf(const std::string& name,
proto::VarType::Type type) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& inputs = op_->Input(name);
return std::all_of(inputs.begin(), inputs.end(),
[this, &type](const std::string& name) {
return this->GetVarType(name) == type;
});
}
virtual void SyncTypeAndDataType(const std::string& input_name,
const std::string& output_name,
int index = 0) {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& x_name = op_->Input(input_name).at(index);
auto& out_name = op_->Output(output_name).at(index);
if (x_name != out_name) {
this->SetVarType(out_name, this->GetVarType(x_name));
this->SetVarDataType(out_name, this->GetVarDataType(x_name));
}
}
virtual void SetOutputType(const std::string& name, proto::VarType::Type type,
int index = 0) {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
if (ALL_ELEMENTS == index) {
for (const auto& var_name : op_->Output(name)) {
this->SetVarType(var_name, type);
}
} else {
auto& var_name = op_->Output(name).at(index);
this->SetVarType(var_name, type);
}
}
virtual proto::VarType::Type GetInputType(const std::string& name,
const int& index = 0) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return this->GetVarType(op_->Input(name).at(index));
}
virtual proto::VarType::Type GetOutputType(const std::string& name,
const int& index = 0) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return this->GetVarType(op_->Output(name).at(index));
}
virtual proto::VarType::Type GetInputDataType(const std::string& name,
const int& index = 0) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return this->GetVarDataType(op_->Input(name).at(index));
}
virtual void SetOutputDataType(const std::string& name,
proto::VarType::Type type, int index = 0) {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
if (ALL_ELEMENTS == index) {
for (const auto& var_name : op_->Output(name)) {
this->SetVarDataType(var_name, type);
}
} else {
auto& var_name = op_->Output(name).at(index);
this->SetVarDataType(var_name, type);
}
}
virtual std::vector<proto::VarType::Type> GetInputDataTypes(
const std::string& name, const int& index = 0) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return this->GetVarDataTypes(op_->Input(name).at(index));
}
virtual void SetOutputDataTypes(
const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type,
const int& index = 0) {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& var_name = op_->Output(name).at(index);
this->SetVarDataTypes(var_name, multiple_data_type);
}
virtual std::vector<int64_t> GetInputShape(const std::string& name,
const int& index = 0) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& var_name = op_->Input(name).at(index);
return this->GetVarShape(var_name);
}
virtual void SetOutputShape(const std::string& name,
const std::vector<int64_t>& dims,
const int& index = 0) {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& var_name = op_->Output(name).at(index);
this->SetVarShape(var_name, dims);
}
virtual int32_t GetInputLoDLevel(const std::string& name,
const int& index = 0) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& var_name = op_->Input(name).at(index);
return this->GetVarLoDLevel(var_name);
}
virtual void SetOutputLoDLevel(const std::string& name, int32_t lod_level,
const int& index = 0) {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
auto& var_name = op_->Output(name).at(index);
this->SetVarLoDLevel(var_name, lod_level);
}
// add a speical API for save_op
// avoid use this API for common logic
virtual void InsertVar(const std::string& var_name,
proto::VarType::Type var_type) {
if (!IsDygraph()) this->SetVarType(var_name, var_type);
}
virtual bool IsDygraph() const { return false; }
protected:
virtual bool HasVar(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
return block_->FindVarRecursive(name) != nullptr;
}
virtual const std::vector<std::string>& InputVars(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return op_->Input(name);
}
virtual const std::vector<std::string>& Output(
virtual const std::vector<std::string>& OutputVars(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(op_);
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return op_->Output(name);
}
virtual proto::VarType::Type GetType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
virtual proto::VarType::Type GetVarType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
return block_->FindRecursiveOrCreateVar(name).GetType();
}
virtual void SetType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
virtual void SetVarType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(
block_, platform::errors::PreconditionNotMet("op_ should not be null"));
block_->FindRecursiveOrCreateVar(name).SetType(type);
}
virtual proto::VarType::Type GetDataType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
virtual proto::VarType::Type GetVarDataType(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
return block_->FindRecursiveOrCreateVar(name).GetDataType();
}
virtual void SetDataType(const std::string& name, proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_);
virtual void SetVarDataType(const std::string& name,
proto::VarType::Type type) {
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
block_->FindRecursiveOrCreateVar(name).SetDataType(type);
}
virtual std::vector<proto::VarType::Type> GetDataTypes(
virtual std::vector<proto::VarType::Type> GetVarDataTypes(
const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
return block_->FindRecursiveOrCreateVar(name).GetDataTypes();
}
virtual void SetDataTypes(
virtual void SetVarDataTypes(
const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type) {
PADDLE_ENFORCE_NOT_NULL(block_);
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
block_->FindRecursiveOrCreateVar(name).SetDataTypes(multiple_data_type);
}
virtual std::vector<int64_t> GetShape(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
virtual std::vector<int64_t> GetVarShape(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
return block_->FindRecursiveOrCreateVar(name).GetShape();
}
virtual void SetShape(const std::string& name,
const std::vector<int64_t>& dims) {
PADDLE_ENFORCE_NOT_NULL(block_);
virtual void SetVarShape(const std::string& name,
const std::vector<int64_t>& dims) {
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
block_->FindRecursiveOrCreateVar(name).SetShape(dims);
}
virtual int32_t GetLoDLevel(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_);
virtual int32_t GetVarLoDLevel(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
return block_->FindRecursiveOrCreateVar(name).GetLoDLevel();
}
virtual void SetLoDLevel(const std::string& name, int32_t lod_level) {
PADDLE_ENFORCE_NOT_NULL(block_);
virtual void SetVarLoDLevel(const std::string& name, int32_t lod_level) {
PADDLE_ENFORCE_NOT_NULL(block_, platform::errors::PreconditionNotMet(
"block_ should not be null"));
block_->FindRecursiveOrCreateVar(name).SetLoDLevel(lod_level);
}
......@@ -133,22 +316,85 @@ class VarTypeInference {
virtual void operator()(InferVarTypeContext* context) const = 0; // NOLINT
};
class StaticGraphVarTypeInference : public VarTypeInference {
protected:
bool HasVar(InferVarTypeContext* ctx, const std::string& name) const {
return ctx->HasVar(name);
}
const std::vector<std::string>& Input(InferVarTypeContext* ctx,
const std::string& name) const {
return ctx->InputVars(name);
}
const std::vector<std::string>& Output(InferVarTypeContext* ctx,
const std::string& name) const {
return ctx->OutputVars(name);
}
proto::VarType::Type GetType(InferVarTypeContext* ctx,
const std::string& name) const {
return ctx->GetVarType(name);
}
void SetType(InferVarTypeContext* ctx, const std::string& name,
proto::VarType::Type type) const {
ctx->SetVarType(name, type);
}
proto::VarType::Type GetDataType(InferVarTypeContext* ctx,
const std::string& name) const {
return ctx->GetVarDataType(name);
}
void SetDataType(InferVarTypeContext* ctx, const std::string& name,
proto::VarType::Type type) const {
ctx->SetVarDataType(name, type);
}
std::vector<proto::VarType::Type> GetDataTypes(
InferVarTypeContext* ctx, const std::string& name) const {
return ctx->GetVarDataTypes(name);
}
void SetDataTypes(
InferVarTypeContext* ctx, const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type) {
return ctx->SetVarDataTypes(name, multiple_data_type);
}
std::vector<int64_t> GetShape(InferVarTypeContext* ctx,
const std::string& name) const {
return ctx->GetVarShape(name);
}
void SetShape(InferVarTypeContext* ctx, const std::string& name,
const std::vector<int64_t>& dims) const {
ctx->SetVarShape(name, dims);
}
int32_t GetLoDLevel(InferVarTypeContext* ctx, const std::string& name) const {
return ctx->GetVarLoDLevel(name);
}
void SetLoDLevel(InferVarTypeContext* ctx, const std::string& name,
int32_t lod_level) const {
ctx->SetVarLoDLevel(name, lod_level);
}
};
class PassInDtypeAndVarTypeToOutput : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const final { // NOLINT
auto in_out_var_names = this->GetInputOutputWithSameType();
auto& in_out_var_names = this->GetInputOutputWithSameType();
for (auto& i_o_n : in_out_var_names) {
auto& x_name = ctx->Input(i_o_n.first).at(0);
auto& out_name = ctx->Output(i_o_n.second).at(0);
ctx->SetType(out_name, ctx->GetType(x_name));
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
ctx->SyncTypeAndDataType(i_o_n.first, i_o_n.second);
}
}
protected:
virtual std::unordered_map<std::string, std::string>
virtual std::unordered_map<std::string, std::string>&
GetInputOutputWithSameType() const = 0;
};
......
......@@ -24,13 +24,13 @@ namespace framework {
class NOP : public OperatorBase {
public:
NOP(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs)
NOP(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope &scope,
const platform::Place &place) const override {}
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
......@@ -44,20 +44,14 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto &inputs = ctx->Input("X");
void operator()(framework::InferVarTypeContext* ctx) const override {
auto default_var_type = proto::VarType::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [&ctx](const std::string &name) {
return ctx->GetType(name) == proto::VarType::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
if (ctx->InputTypeAnyOf("X", proto::VarType::LOD_TENSOR)) {
default_var_type = proto::VarType::LOD_TENSOR;
}
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, default_var_type);
ctx->SetOutputType("Out", default_var_type);
}
};
} // namespace framework
......@@ -71,9 +65,79 @@ REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP,
namespace paddle {
namespace framework {
class TestStaticGraphVarTypeInference : public StaticGraphVarTypeInference {
public:
void operator()(InferVarTypeContext* context) const override {}
bool HasVar(InferVarTypeContext* ctx, const std::string& name) const {
return StaticGraphVarTypeInference::HasVar(ctx, name);
}
const std::vector<std::string>& Input(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::Input(ctx, name);
}
const std::vector<std::string>& Output(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::Output(ctx, name);
}
proto::VarType::Type GetType(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::GetType(ctx, name);
}
void SetType(InferVarTypeContext* ctx, const std::string& name,
proto::VarType::Type type) const {
StaticGraphVarTypeInference::SetType(ctx, name, type);
}
proto::VarType::Type GetDataType(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::GetDataType(ctx, name);
}
void SetDataType(InferVarTypeContext* ctx, const std::string& name,
proto::VarType::Type type) const {
StaticGraphVarTypeInference::SetDataType(ctx, name, type);
}
std::vector<proto::VarType::Type> GetDataTypes(
InferVarTypeContext* ctx, const std::string& name) const {
return StaticGraphVarTypeInference::GetDataTypes(ctx, name);
}
void SetDataTypes(
InferVarTypeContext* ctx, const std::string& name,
const std::vector<proto::VarType::Type>& multiple_data_type) {
return StaticGraphVarTypeInference::SetDataTypes(ctx, name,
multiple_data_type);
}
std::vector<int64_t> GetShape(InferVarTypeContext* ctx,
const std::string& name) const {
return StaticGraphVarTypeInference::GetShape(ctx, name);
}
void SetShape(InferVarTypeContext* ctx, const std::string& name,
const std::vector<int64_t>& dims) const {
StaticGraphVarTypeInference::SetShape(ctx, name, dims);
}
int32_t GetLoDLevel(InferVarTypeContext* ctx, const std::string& name) const {
return StaticGraphVarTypeInference::GetLoDLevel(ctx, name);
}
void SetLoDLevel(InferVarTypeContext* ctx, const std::string& name,
int32_t lod_level) const {
StaticGraphVarTypeInference::SetLoDLevel(ctx, name, lod_level);
}
};
TEST(InferVarType, sum_op) {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
......@@ -96,7 +160,7 @@ TEST(InferVarType, sum_op) {
TEST(InferVarType, sum_op_without_infer_var_type) {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
auto* op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"});
......@@ -112,5 +176,112 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
prog.MutableBlock(0)->Var("test2_out")->GetType());
}
TEST(InferVarType, multiple_api) {
ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b"});
op->SetOutput("Out", {"test2_a_out", "test2_b_out"});
block->Var("test2_a")->SetType(proto::VarType::SELECTED_ROWS);
block->Var("test2_b")->SetType(proto::VarType::SELECTED_ROWS);
block->Var("test2_a_out");
block->Var("test2_b_out");
InferVarTypeContext ctx(op, block);
ASSERT_TRUE(ctx.HasInput("X"));
ASSERT_TRUE(ctx.HasOutput("Out"));
ASSERT_EQ(2u, ctx.InputSize("X"));
ASSERT_EQ("test2_a", ctx.InputVarName("X", 0));
ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetInputType("X"));
ASSERT_TRUE(ctx.InputTypeAllOf("X", proto::VarType::SELECTED_ROWS));
ASSERT_FALSE(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
ctx.SyncTypeAndDataType("X", "Out");
ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
ctx.SetOutputType("Out", proto::VarType::SELECTED_ROWS, ALL_ELEMENTS);
ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR, 1);
ASSERT_EQ(proto::VarType::SELECTED_ROWS, ctx.GetOutputType("Out"));
ASSERT_EQ(proto::VarType::LOD_TENSOR, ctx.GetOutputType("Out", 1));
ASSERT_EQ(0, ctx.GetInputDataType("X"));
ctx.SetOutputDataType("Out", proto::VarType::FP32, ALL_ELEMENTS);
ctx.SetOutputDataType("Out", proto::VarType::INT8, 1);
ASSERT_EQ(proto::VarType::FP32,
prog.MutableBlock(0)->Var("test2_a_out")->GetDataType());
ASSERT_EQ(proto::VarType::INT8,
prog.MutableBlock(0)->Var("test2_b_out")->GetDataType());
ASSERT_FALSE(ctx.IsDygraph());
// test StaticGraphVarTypeInference
TestStaticGraphVarTypeInference infer;
ASSERT_TRUE(infer.HasVar(&ctx, "test2_a"));
ASSERT_EQ(infer.Input(&ctx, "X").size(), infer.Output(&ctx, "Out").size());
ASSERT_EQ(proto::VarType::FP32, infer.GetDataType(&ctx, "test2_a_out"));
infer.SetDataType(&ctx, "test2_a_out", proto::VarType::FP64);
ASSERT_EQ(proto::VarType::FP64, infer.GetDataType(&ctx, "test2_a_out"));
ASSERT_EQ(proto::VarType::SELECTED_ROWS, infer.GetType(&ctx, "test2_a_out"));
infer.SetType(&ctx, "test2_a_out", proto::VarType::LOD_TENSOR);
ASSERT_EQ(proto::VarType::LOD_TENSOR, infer.GetType(&ctx, "test2_a_out"));
ASSERT_ANY_THROW(infer.GetDataTypes(&ctx, "test2_a_out"));
ASSERT_ANY_THROW(infer.SetDataTypes(&ctx, "test2_a_out", {}));
ASSERT_EQ(0u, infer.GetShape(&ctx, "test2_a_out").size());
infer.SetShape(&ctx, "test2_a_out", {
1, 3, 3,
});
ASSERT_EQ(3u, infer.GetShape(&ctx, "test2_a_out").size());
ASSERT_EQ(0, infer.GetLoDLevel(&ctx, "test2_a_out"));
infer.SetLoDLevel(&ctx, "test2_a_out", 2);
ASSERT_EQ(2, infer.GetLoDLevel(&ctx, "test2_a_out"));
}
TEST(InferVarType, test_enforce_check) {
InferVarTypeContext ctx(nullptr, nullptr);
ASSERT_ANY_THROW(ctx.HasInput("X"));
ASSERT_ANY_THROW(ctx.HasOutput("Out"));
ASSERT_ANY_THROW(ctx.InputSize("X"));
ASSERT_ANY_THROW(ctx.InputVarName("X"));
ASSERT_ANY_THROW(ctx.InputTypeAnyOf("X", proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx.InputTypeAllOf("X", proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx.SyncTypeAndDataType("X", "Out"));
ASSERT_ANY_THROW(ctx.SetOutputType("Out", proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx.GetInputType("X"));
ASSERT_ANY_THROW(ctx.GetOutputType("Out"));
ASSERT_ANY_THROW(ctx.GetInputDataType("X"));
ASSERT_ANY_THROW(ctx.SetOutputDataType("Out", proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx.GetInputDataTypes("X"));
ASSERT_ANY_THROW(ctx.SetOutputDataTypes("Out", {}));
ASSERT_ANY_THROW(ctx.GetInputShape("X"));
ASSERT_ANY_THROW(ctx.SetOutputShape("Out", {}));
ASSERT_ANY_THROW(ctx.GetInputLoDLevel("X"));
ASSERT_ANY_THROW(ctx.SetOutputLoDLevel("Out", 1));
ASSERT_ANY_THROW(ctx.InsertVar("var", proto::VarType::LOD_TENSOR));
}
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -35,30 +36,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
: InferVarTypeContext(nullptr, nullptr),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs_map),
input_names_(),
output_names_(),
var_set_() {
input_names_.reserve(inputs_.size());
for (auto& it : inputs_) {
for (auto& var : it.second) {
if (var) {
input_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var.get();
}
}
}
output_names_.reserve(outputs_.size());
for (auto& it : outputs_) {
for (auto& var : it.second) {
if (var) {
output_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var.get();
}
}
}
}
attrs_(attrs_map) {}
virtual ~RuntimeInferVarTypeContext() {}
......@@ -70,10 +48,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
return iter->second;
}
bool HasVar(const std::string& name) const override {
return var_set_.count(name) > 0;
}
bool HasInput(const std::string& name) const override {
auto it = inputs_.find(name);
return (it != inputs_.end() && it->second.size() > 0);
......@@ -84,93 +58,173 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
return (it != outputs_.end() && it->second.size() > 0);
}
const std::vector<std::string>& Input(
const std::string& name) const override {
auto iter = input_names_.find(name);
PADDLE_ENFORCE_EQ(
iter != input_names_.end(), true,
platform::errors::NotFound("Cannot find input var %s", name));
return iter->second;
size_t InputSize(const std::string& name) const {
return inputs_.at(name).size();
}
const std::vector<std::string>& Output(
const std::string& name) const override {
auto iter = output_names_.find(name);
const std::string& InputVarName(const std::string& name,
const int index = 0) const {
return inputs_.at(name)[index]->Name();
}
PADDLE_ENFORCE_EQ(
iter != output_names_.end(), true,
platform::errors::NotFound("Cannot find output var %s", name));
return iter->second;
bool InputTypeAnyOf(const std::string& name,
framework::proto::VarType::Type type) const override {
auto& inputs = inputs_.at(name);
return std::any_of(inputs.begin(), inputs.end(),
[&type](const std::shared_ptr<VarType>& var) {
return var->Type() == type;
});
}
framework::proto::VarType::Type GetType(
const std::string& name) const override {
auto iter = var_set_.find(name);
bool InputTypeAllOf(const std::string& name,
framework::proto::VarType::Type type) const override {
auto& inputs = inputs_.at(name);
return std::all_of(inputs.begin(), inputs.end(),
[&type](const std::shared_ptr<VarType>& var) {
return var->Type() == type;
});
}
PADDLE_ENFORCE_EQ(
iter != var_set_.end(), true,
platform::errors::NotFound("Cannot find var %s in GetType", name));
return iter->second->Type();
void SyncTypeAndDataType(const std::string& input_name,
const std::string& output_name,
int index = 0) override {
auto in_var = inputs_.at(input_name)[index];
auto out_var = outputs_.at(output_name)[index];
if (in_var != out_var) {
this->SetVarBaseType(out_var, in_var->Type());
this->SetVarBaseDataType(out_var, in_var->DataType());
}
}
void SetType(const std::string& name,
framework::proto::VarType::Type type) override {
if (name == "kLookupTablePath") {
VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
void SetOutputType(const std::string& name,
framework::proto::VarType::Type type,
int index = 0) override {
if (index == framework::ALL_ELEMENTS) {
for (auto& item : outputs_.at(name)) {
this->SetVarBaseType(item, type);
}
} else {
var_set_[name]->SetType(type);
if ((var_set_[name]->MutableVar()->IsInitialized() == true) &&
(var_set_[name]->MutableVar()->Type() != type)) {
var_set_[name]->MutableVar()->Clear();
auto& var = outputs_.at(name)[index];
this->SetVarBaseType(var, type);
}
}
void SetVarBaseType(std::shared_ptr<VarType> out,
framework::proto::VarType::Type type) {
out->SetType(type);
if ((out->MutableVar()->IsInitialized() == true) &&
(out->MutableVar()->Type() != type)) {
out->MutableVar()->Clear();
}
}
void SetVarBaseDataType(std::shared_ptr<VarType> out,
framework::proto::VarType::Type type) {
out->SetDataType(type);
}
framework::proto::VarType::Type GetInputType(
const std::string& name, const int& index = 0) const override {
return inputs_.at(name)[index]->Type();
}
framework::proto::VarType::Type GetOutputType(
const std::string& name, const int& index = 0) const override {
return outputs_.at(name)[index]->Type();
}
framework::proto::VarType::Type GetInputDataType(
const std::string& name, const int& index = 0) const override {
return inputs_.at(name)[index]->DataType();
}
void SetOutputDataType(const std::string& name,
framework::proto::VarType::Type type,
int index = 0) override {
if (framework::ALL_ELEMENTS == index) {
for (auto& item : outputs_.at(name)) {
this->SetVarBaseDataType(item, type);
}
} else {
auto& var = outputs_.at(name)[index];
this->SetVarBaseDataType(var, type);
}
}
framework::proto::VarType::Type GetDataType(
bool IsDygraph() const override { return true; }
protected:
bool HasVar(const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"HasVar is not supported in runtime InferVarType"));
}
const std::vector<std::string>& InputVars(
const std::string& name) const override {
auto iter = var_set_.find(name);
PADDLE_THROW(platform::errors::PermissionDenied(
"InputVars is not supported in runtime InferVarType"));
}
PADDLE_ENFORCE_EQ(
iter != var_set_.end(), true,
platform::errors::NotFound("Cannot find var %s in GetDataType", name));
return iter->second->DataType();
const std::vector<std::string>& OutputVars(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"OutputVars is not supported in runtime InferVarType"));
}
framework::proto::VarType::Type GetVarType(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not manipulate var in runtime InferVarType"));
}
void SetVarType(const std::string& name,
framework::proto::VarType::Type type) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not manipulate var in runtime InferVarType"));
}
void SetDataType(const std::string& name,
framework::proto::VarType::Type type) override {
var_set_[name]->SetDataType(type);
framework::proto::VarType::Type GetVarDataType(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not manipulate var in runtime InferVarType"));
}
void SetVarDataType(const std::string& name,
framework::proto::VarType::Type type) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not manipulate var in runtime InferVarType"));
}
std::vector<framework::proto::VarType::Type> GetDataTypes(
std::vector<framework::proto::VarType::Type> GetVarDataTypes(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetDataTypes is not supported in runtime InferVarType"));
"GetVarDataTypes is not supported in runtime InferVarType"));
}
void SetDataTypes(const std::string& name,
const std::vector<framework::proto::VarType::Type>&
multiple_data_type) override {
void SetVarDataTypes(const std::string& name,
const std::vector<framework::proto::VarType::Type>&
multiple_data_type) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"SetDataTypes is not supported in runtime InferVarType"));
"SetVarDataTypes is not supported in runtime InferVarType"));
}
std::vector<int64_t> GetShape(const std::string& name) const override {
std::vector<int64_t> GetVarShape(const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not handle Shape in runtime InferVarType"));
}
void SetShape(const std::string& name,
const std::vector<int64_t>& dims) override {
void SetVarShape(const std::string& name,
const std::vector<int64_t>& dims) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not handle Shape in runtime InferVarType"));
}
int32_t GetLoDLevel(const std::string& name) const override {
int32_t GetVarLoDLevel(const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not handle LoDLevel in runtime InferVarType"));
}
void SetLoDLevel(const std::string& name, int32_t lod_level) override {
void SetVarLoDLevel(const std::string& name, int32_t lod_level) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not handle LoDLevel in runtime InferVarType"));
}
......@@ -179,9 +233,6 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
const NameVarMap<VarType>& inputs_;
const NameVarMap<VarType>& outputs_;
const framework::AttributeMap& attrs_;
std::unordered_map<std::string, std::vector<std::string>> input_names_;
std::unordered_map<std::string, std::vector<std::string>> output_names_;
std::unordered_map<std::string, VarType*> var_set_;
};
} // namespace imperative
......
......@@ -37,33 +37,154 @@ using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
using var_pair = std::pair<std::string, vb_vector>;
template <typename VarType>
class TestRuntimeInferVarTypeContext
: public RuntimeInferVarTypeContext<VarType> {
public:
TestRuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
const NameVarMap<VarType>& outputs,
const framework::AttributeMap& attrs_map)
: RuntimeInferVarTypeContext<VarType>(inputs, outputs, attrs_map) {}
bool HasVar(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::HasVar(name);
}
const std::vector<std::string>& InputVars(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::InputVars(name);
}
const std::vector<std::string>& OutputVars(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::OutputVars(name);
}
framework::proto::VarType::Type GetVarType(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarType(name);
}
void SetVarType(const std::string& name,
framework::proto::VarType::Type type) {
RuntimeInferVarTypeContext<VarType>::SetVarType(name, type);
}
framework::proto::VarType::Type GetVarDataType(
const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarDataType(name);
}
void SetVarDataType(const std::string& name,
framework::proto::VarType::Type type) {
RuntimeInferVarTypeContext<VarType>::SetVarDataType(name, type);
}
std::vector<framework::proto::VarType::Type> GetVarDataTypes(
const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarDataTypes(name);
}
void SetVarDataTypes(
const std::string& name,
const std::vector<framework::proto::VarType::Type>& multiple_data_type) {
RuntimeInferVarTypeContext<VarType>::SetVarDataTypes(name,
multiple_data_type);
}
std::vector<int64_t> GetVarShape(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarShape(name);
}
void SetVarShape(const std::string& name, const std::vector<int64_t>& dims) {
RuntimeInferVarTypeContext<VarType>::SetVarShape(name, dims);
}
int32_t GetVarLoDLevel(const std::string& name) const {
return RuntimeInferVarTypeContext<VarType>::GetVarLoDLevel(name);
}
void SetVarLoDLevel(const std::string& name, int32_t lod_level) {
RuntimeInferVarTypeContext<VarType>::SetVarLoDLevel(name, lod_level);
}
};
TEST(test_layer, test_runtime_context) {
std::shared_ptr<imperative::VarBase> vin(
new imperative::VarBase(false, "vin"));
std::shared_ptr<imperative::VarBase> vin_b(
new imperative::VarBase(false, "vin_b"));
std::shared_ptr<imperative::VarBase> vout(
new imperative::VarBase(false, "vout"));
var_pair in_pair = var_pair("X", vb_vector(1, vin));
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
std::shared_ptr<imperative::VarBase> vout_b(
new imperative::VarBase(false, "vout_b"));
var_pair in_pair = var_pair("X", {vin, vin_b});
var_pair out_pair = var_pair("Out", {vout, vout_b});
imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap attrs;
auto *ctx = new imperative::RuntimeInferVarTypeContext<imperative::VarBase>(
ins, outs, attrs);
ASSERT_TRUE(ctx->HasVar("vin"));
auto* ctx =
new imperative::TestRuntimeInferVarTypeContext<imperative::VarBase>(
ins, outs, attrs);
ASSERT_TRUE(ctx->HasInput("X"));
ASSERT_TRUE(ctx->HasOutput("Out"));
ASSERT_ANY_THROW(ctx->GetDataTypes("vin"));
ASSERT_EQ(2u, ctx->InputSize("X"));
ASSERT_EQ("vin", ctx->InputVarName("X", 0));
ASSERT_TRUE(ctx->InputTypeAnyOf("X", framework::proto::VarType::LOD_TENSOR));
ASSERT_TRUE(ctx->InputTypeAllOf("X", framework::proto::VarType::LOD_TENSOR));
ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetInputType("X"));
ASSERT_EQ(framework::proto::VarType::FP32, ctx->GetInputDataType("X"));
ctx->SyncTypeAndDataType("X", "Out");
ASSERT_EQ(framework::proto::VarType::FP32, vout->DataType());
ASSERT_EQ(framework::proto::VarType::LOD_TENSOR, ctx->GetOutputType("Out"));
ctx->SetOutputType("Out", framework::proto::VarType::SELECTED_ROWS,
framework::ALL_ELEMENTS);
ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR_ARRAY);
ASSERT_EQ(framework::proto::VarType::LOD_TENSOR_ARRAY, vout->Type());
ASSERT_EQ(framework::proto::VarType::SELECTED_ROWS, vout_b->Type());
ctx->SetOutputDataType("Out", framework::proto::VarType::FP64,
framework::ALL_ELEMENTS);
ctx->SetOutputDataType("Out", framework::proto::VarType::INT8);
ASSERT_EQ(framework::proto::VarType::INT8, vout->DataType());
ASSERT_EQ(framework::proto::VarType::FP64, vout_b->DataType());
// no throw, but do nothing
ASSERT_NO_THROW(
ctx->InsertVar("vout", framework::proto::VarType::LOD_TENSOR));
ASSERT_EQ(framework::proto::VarType::LOD_TENSOR_ARRAY, vout->Type());
ASSERT_ANY_THROW(ctx->HasVar("vin"));
ASSERT_ANY_THROW(ctx->InputVars("X"));
ASSERT_ANY_THROW(ctx->OutputVars("Out"));
ASSERT_ANY_THROW(ctx->GetVarType("vin"));
ASSERT_ANY_THROW(
ctx->SetVarType("vin", framework::proto::VarType::LOD_TENSOR));
ASSERT_ANY_THROW(ctx->GetVarDataType("vin"));
ASSERT_ANY_THROW(
ctx->SetVarDataType("vout", framework::proto::VarType::FP32));
ASSERT_ANY_THROW(ctx->GetVarDataTypes("vin"));
std::vector<framework::proto::VarType::Type> NullType;
ASSERT_ANY_THROW(ctx->SetDataTypes("vin", NullType));
ASSERT_ANY_THROW(ctx->GetShape("vin"));
ASSERT_ANY_THROW(ctx->GetLoDLevel("vin"));
ASSERT_ANY_THROW(ctx->SetLoDLevel("vin", 2));
ASSERT_ANY_THROW(ctx->SetVarDataTypes("vin", NullType));
ASSERT_ANY_THROW(ctx->GetVarShape("vin"));
ASSERT_ANY_THROW(ctx->SetVarShape("vin", {}));
ASSERT_ANY_THROW(ctx->GetVarLoDLevel("vin"));
ASSERT_ANY_THROW(ctx->SetVarLoDLevel("vin", 2));
ASSERT_TRUE(ctx->IsDygraph());
}
std::string LayerDebugString(const std::string &op_type,
const NameVarBaseMap &ins,
const NameVarBaseMap &outs);
std::string LayerDebugString(const std::string& op_type,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs);
TEST(test_layer, test_debug_string) {
platform::CPUPlace place;
......@@ -71,7 +192,7 @@ TEST(test_layer, test_debug_string) {
new imperative::VarBase(false, "vin"));
var_pair in_pair = var_pair("X", vb_vector(1, vin));
auto test_func = [&](std::shared_ptr<imperative::VarBase> &vout) {
auto test_func = [&](std::shared_ptr<imperative::VarBase>& vout) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair};
......@@ -124,26 +245,26 @@ TEST(test_layer, test_debug_string) {
}
static std::shared_ptr<imperative::GradOpNode> CreateGradNode(
size_t id, const std::string &type, const imperative::NameVarBaseMap &ins,
const imperative::NameVarBaseMap &outs,
const framework::AttributeMap &attrs, const platform::Place &place) {
size_t id, const std::string& type, const imperative::NameVarBaseMap& ins,
const imperative::NameVarBaseMap& outs,
const framework::AttributeMap& attrs, const platform::Place& place) {
auto node = std::make_shared<imperative::GradOpNode>();
auto *op = &(node->emplace_back());
auto* op = &(node->emplace_back());
op->SetId(id);
op->SetPlace(place);
op->SetType(type);
op->SetAttrMap(attrs);
for (auto &pair : ins) {
for (auto& pair : ins) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
for (auto& var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetInput(pair.first, vars, false);
}
for (auto &pair : outs) {
for (auto& pair : outs) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
for (auto& var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetOutput(pair.first, vars, false);
......@@ -173,7 +294,7 @@ TEST(test_layer, test_clear_backward_info) {
node->InsertGradPendingNode(pending_node);
ASSERT_EQ(node->size(), 1UL);
auto *op = &(node->back());
auto* op = &(node->back());
ASSERT_GT(op->GetInsMap().size(), 0UL);
ASSERT_GT(op->GetOutsMap().size(), 0UL);
......@@ -196,10 +317,10 @@ TEST(test_layer, test_varbase_basic) {
std::shared_ptr<imperative::VarBase> vin_with_grad(
new imperative::VarBase(true, "vin"));
ASSERT_ANY_THROW(vin->MutableGradVar());
ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable *>(
ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable*>(
vin_with_grad->MutableGradVar()) != 0));
ASSERT_TRUE(dynamic_cast<framework::Variable *>(
vin_with_grad->MutableGradVar()) != 0);
ASSERT_TRUE(
dynamic_cast<framework::Variable*>(vin_with_grad->MutableGradVar()) != 0);
vin_with_grad->SetOverridedStopGradient(false);
ASSERT_FALSE(vin_with_grad->OverridedStopGradient());
ASSERT_NO_FATAL_FAILURE(vin_with_grad->SetPersistable(true));
......@@ -228,9 +349,9 @@ TEST(test_layer, test_dygraph_execution_context) {
auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
paddle::platform::CPUPlace cpu_place;
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(cpu_place);
auto* dev_ctx = pool.Get(cpu_place);
paddle::framework::RuntimeContext ctx({}, {});
framework::Scope scope;
......
......@@ -129,9 +129,10 @@ class ActivationOp : public framework::OperatorWithKernel {
class ActivationOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......
......@@ -103,8 +103,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
class AllcloseOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, framework::proto::VarType::BOOL);
ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
}
};
......
......@@ -60,11 +60,7 @@ class AssignOp : public framework::OperatorWithKernel {
class AssignInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out")[0];
auto input_type = ctx->GetType(ctx->Input("X")[0]);
auto input_data_type = ctx->GetDataType(ctx->Input("X")[0]);
ctx->SetType(out_var_name, input_type);
ctx->SetDataType(out_var_name, input_data_type);
ctx->SyncTypeAndDataType("X", "Out");
}
};
......
......@@ -171,9 +171,10 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
class BatchNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};
......
......@@ -204,12 +204,10 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
for (auto& o : ctx->Output("SentenceIds")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
for (auto& o : ctx->Output("SentenceScores")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetOutputType("SentenceIds", framework::proto::VarType::LOD_TENSOR,
framework::ALL_ELEMENTS);
ctx->SetOutputType("SentenceScores", framework::proto::VarType::LOD_TENSOR,
framework::ALL_ELEMENTS);
}
};
......
......@@ -122,12 +122,10 @@ class BeamSearchOp : public framework::OperatorWithKernel {
class BeamSearchInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o : ctx->Output("selected_ids")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
for (auto &o : ctx->Output("selected_scores")) {
ctx->SetType(o, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetOutputType("selected_ids", framework::proto::VarType::LOD_TENSOR,
framework::ALL_ELEMENTS);
ctx->SetOutputType("selected_scores", framework::proto::VarType::LOD_TENSOR,
framework::ALL_ELEMENTS);
}
};
......
......@@ -92,9 +92,8 @@ execution.
class GetPlacesInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o_name : ctx->Output("Out")) {
ctx->SetType(o_name, framework::proto::VarType::PLACE_LIST);
}
ctx->SetOutputType("Out", framework::proto::VarType::PLACE_LIST,
framework::ALL_ELEMENTS);
}
};
......
......@@ -111,15 +111,15 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
}
};
class WriteToArrayInferVarType : public framework::VarTypeInference {
class WriteToArrayInferVarType : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = ctx->Input("X")[0];
auto out_name = ctx->Output("Out")[0];
auto x_name = Input(ctx, "X")[0];
auto out_name = Output(ctx, "Out")[0];
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
ctx->SetType(out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
if (ctx->HasVar(x_name)) {
ctx->SetDataType(out_name, ctx->GetDataType(x_name));
SetType(ctx, out_name, framework::proto::VarType::LOD_TENSOR_ARRAY);
if (HasVar(ctx, x_name)) {
SetDataType(ctx, out_name, GetDataType(ctx, x_name));
}
}
};
......
......@@ -398,18 +398,19 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class WhileGradOpVarTypeInference : public framework::VarTypeInference {
class WhileGradOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto p_names = ctx->Input(kX);
auto pg_ig_names = ctx->Output(framework::GradVarName(kX));
auto p_names = Input(ctx, kX);
auto pg_ig_names = Output(ctx, framework::GradVarName(kX));
for (size_t i = 0; i < p_names.size(); ++i) {
if (ctx->HasVar(pg_ig_names[i])) {
if (HasVar(ctx, pg_ig_names[i])) {
VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
<< " type: " << ctx->GetType(p_names[i]);
ctx->SetType(pg_ig_names[i], ctx->GetType(p_names[i]));
ctx->SetDataType(pg_ig_names[i], ctx->GetDataType(p_names[i]));
<< " type: " << GetType(ctx, p_names[i]);
SetType(ctx, pg_ig_names[i], GetType(ctx, p_names[i]));
SetDataType(ctx, pg_ig_names[i], GetDataType(ctx, p_names[i]));
}
}
}
......
......@@ -254,10 +254,11 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{
static std::unordered_map<std::string, std::string> m{
{"Input", /*->*/ "Output"}};
return m;
}
};
......
......@@ -177,9 +177,10 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
class CrossEntropyOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};
......
......@@ -115,10 +115,8 @@ class MergeIdsOp : public framework::OperatorWithKernel {
class MergeIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, input_type);
}
auto input_type = ctx->GetInputType("Ids");
ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
}
};
......
......@@ -73,10 +73,8 @@ class SplitIdsOp : public framework::OperatorWithKernel {
class SplitIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("Ids")[0]);
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, input_type);
}
auto input_type = ctx->GetInputType("Ids");
ctx->SetOutputType("Out", input_type, framework::ALL_ELEMENTS);
}
};
......
......@@ -119,9 +119,10 @@ class ElementwiseOp : public framework::OperatorWithKernel {
class ElementwiseOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......
......@@ -49,8 +49,7 @@ class EyeOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
ctx->SetOutputDataType("Out", data_type);
}
};
......
......@@ -72,14 +72,12 @@ The output will have the same shape and dtype as the input.
class FillAnyLikeVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out").front();
auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
if (var_data_type < 0) {
const auto &input_var_name = ctx->Input("X").front();
ctx->SetDataType(out_var_name, ctx->GetDataType(input_var_name));
ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
} else {
ctx->SetDataType(out_var_name, var_data_type);
ctx->SetOutputDataType("Out", var_data_type);
}
}
};
......
......@@ -64,8 +64,7 @@ class FillConstantOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
ctx->SetOutputDataType("Out", data_type);
}
};
......
......@@ -63,8 +63,7 @@ class FillOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext* ctx) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
ctx->SetOutputDataType("Out", data_type);
}
};
......
......@@ -114,9 +114,10 @@ class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......
......@@ -85,9 +85,10 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
class FusedBatchNormActOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};
......
......@@ -146,19 +146,20 @@ class FusedEmbeddingSeqPoolOpGradVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
auto out_var_name = framework::GradVarName("W");
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to SelectedRows";
ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
ctx->SetOutputType(out_var_name,
framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "fused_embedding_seq_pool_grad op "
<< framework::GradVarName("W") << " is set to LoDTensor";
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
}
};
......
......@@ -83,11 +83,8 @@ class GetTensorFromSelectedRowsOpVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const { // NOLINT
auto out_var_name = ctx->Output("Out").front();
auto in_var_name = ctx->Input("X").front();
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR);
ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
}
};
......
......@@ -216,9 +216,10 @@ DECLARE_INPLACE_OP_INFERER(GroupNormGradInplaceInToOut,
class GroupNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
return {{"X", /*->*/ "Y"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};
......
......@@ -229,31 +229,30 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto w_grad_var_name = ctx->Output(framework::GradVarName("W")).front();
auto has_bias_grad_var = ctx->HasOutput(framework::GradVarName("Bias"));
std::string bias_grad_var_name;
bool hasBias = false;
if (has_bias_grad_var) {
hasBias = true;
bias_grad_var_name = ctx->Output(framework::GradVarName("Bias")).front();
auto w_grad_var_name = framework::GradVarName("W");
auto bias_grad_var_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_var_name)) {
VLOG(3) << "hierarchical_sigmoid_grad op "
<< framework::GradVarName("Bias") << " is set to LoDTensor";
ctx->SetOutputType(bias_grad_var_name,
framework::proto::VarType::LOD_TENSOR);
}
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
ctx->SetType(w_grad_var_name, framework::proto::VarType::SELECTED_ROWS);
ctx->SetOutputType(w_grad_var_name,
framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
ctx->SetType(w_grad_var_name, framework::proto::VarType::LOD_TENSOR);
ctx->SetOutputType(w_grad_var_name,
framework::proto::VarType::LOD_TENSOR);
}
if (hasBias) {
VLOG(3) << "hierarchical_sigmoid_grad op "
<< framework::GradVarName("Bias") << " is set to LoDTensor";
ctx->SetType(bias_grad_var_name, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetDataType(w_grad_var_name, ctx->GetDataType(ctx->Input("W")[0]));
ctx->SetOutputDataType(w_grad_var_name, ctx->GetInputDataType("W"));
}
};
......
......@@ -123,9 +123,10 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
class InstanceNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", "Y"}};
static std::unordered_map<std::string, std::string> m{{"X", "Y"}};
return m;
}
};
......
......@@ -65,9 +65,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
class LoDRankTableInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &o : ctx->Output("Out")) {
ctx->SetType(o, framework::proto::VarType::LOD_RANK_TABLE);
}
ctx->SetOutputType("Out", framework::proto::VarType::LOD_RANK_TABLE,
framework::ALL_ELEMENTS);
}
};
......
......@@ -76,24 +76,25 @@ class LoDResetOp : public framework::OperatorWithKernel {
}
};
class LoDResetOpVarTypeInference : public framework::VarTypeInference {
class LoDResetOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_var_name = ctx->Input("X").front();
auto out_var_name = ctx->Output("Out").front();
auto x_var_name = Input(ctx, "X").front();
auto out_var_name = Output(ctx, "Out").front();
bool append = boost::get<bool>(ctx->GetAttr("append"));
if (ctx->HasInput("Y")) {
auto y_var_name = ctx->Input("Y").front();
auto y_lod_level = std::max(ctx->GetLoDLevel(y_var_name), 1);
ctx->SetLoDLevel(out_var_name, y_lod_level);
auto y_var_name = Input(ctx, "Y").front();
auto y_lod_level = std::max(GetLoDLevel(ctx, y_var_name), 1);
SetLoDLevel(ctx, out_var_name, y_lod_level);
} else if (append) {
auto x_lod_level = std::max(ctx->GetLoDLevel(x_var_name), 1);
ctx->SetLoDLevel(out_var_name, x_lod_level);
auto x_lod_level = std::max(GetLoDLevel(ctx, x_var_name), 1);
SetLoDLevel(ctx, out_var_name, x_lod_level);
} else {
ctx->SetLoDLevel(out_var_name, 1);
SetLoDLevel(ctx, out_var_name, 1);
}
ctx->SetDataType(out_var_name, ctx->GetDataType(x_var_name));
ctx->SetType(out_var_name, paddle::framework::proto::VarType::LOD_TENSOR);
SetDataType(ctx, out_var_name, GetDataType(ctx, x_var_name));
SetType(ctx, out_var_name, paddle::framework::proto::VarType::LOD_TENSOR);
}
};
......
......@@ -221,9 +221,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
}
ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR_ARRAY,
framework::ALL_ELEMENTS);
}
};
......
......@@ -173,19 +173,20 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
auto out_var_name = framework::GradVarName("W");
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
ctx->SetOutputType(out_var_name,
framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
}
};
......
......@@ -160,19 +160,20 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel {
class LookupTableV2OpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto out_var_name = ctx->Output(framework::GradVarName("W")).front();
auto out_var_name = framework::GradVarName("W");
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
ctx->SetType(out_var_name, framework::proto::VarType::SELECTED_ROWS);
ctx->SetOutputType(out_var_name,
framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_v2_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
ctx->SetOutputType(out_var_name, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetDataType(out_var_name, ctx->GetDataType(ctx->Input("W")[0]));
ctx->SetOutputDataType(out_var_name, ctx->GetInputDataType("W"));
}
};
......
......@@ -45,9 +45,10 @@ Mean Operator calculates the mean of all elements in X.
class MeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/merge_selected_rows_op.h"
#include <unordered_map>
namespace paddle {
namespace operators {
......@@ -79,9 +80,10 @@ Example:
class MergeSelectedRowsOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......
......@@ -200,9 +200,10 @@ or not. But the output only shares the LoD information with input $X$.
class MulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......
......@@ -61,8 +61,7 @@ class NCCLInitOp : public framework::OperatorBase {
class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Communicator").front();
ctx->SetType(out_var_name, framework::proto::VarType::RAW);
ctx->SetOutputType("Communicator", framework::proto::VarType::RAW);
}
};
......
......@@ -280,20 +280,20 @@ class NCEOpGrad : public framework::OperatorWithKernel {
class NCEOpGradVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto weight_grad = ctx->Output(framework::GradVarName("Weight")).front();
auto weight_grad = framework::GradVarName("Weight");
auto attr = ctx->GetAttr("is_sparse");
bool is_sparse = boost::get<bool>(attr);
if (is_sparse) {
VLOG(3) << "nce_op_grad op " << weight_grad << " and "
<< " is set to SelectedRows";
ctx->SetType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
ctx->SetOutputType(weight_grad, framework::proto::VarType::SELECTED_ROWS);
} else {
VLOG(3) << "nce_op_grad op " << weight_grad << " and "
<< " is set to LoDTensor";
ctx->SetType(weight_grad, framework::proto::VarType::LOD_TENSOR);
ctx->SetOutputType(weight_grad, framework::proto::VarType::LOD_TENSOR);
}
ctx->SetDataType(weight_grad, ctx->GetDataType(ctx->Input("Input")[0]));
ctx->SetOutputDataType(weight_grad, ctx->GetInputDataType("Input"));
}
};
......
......@@ -22,18 +22,15 @@ using Tensor = framework::Tensor;
class MomentumOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto& input_var = ctx->Input("Param")[0];
for (auto& out_var : ctx->Output("ParamOut")) {
if (ctx->GetType(input_var) == framework::proto::VarType::SELECTED_ROWS) {
ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
} else if (ctx->GetType(input_var) ==
framework::proto::VarType::LOD_TENSOR) {
ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR);
} else {
PADDLE_THROW(
"Only support LodTensor and SelectedRows, Unexpected Input Type.");
}
}
auto in_var_type = ctx->GetInputType("Param");
PADDLE_ENFORCE_EQ(
in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
true,
platform::errors::InvalidArgument(
"Only support LodTensor and SelectedRows, Unexpected Input Type."));
ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
}
};
......
......@@ -75,19 +75,15 @@ class SGDOp : public framework::OperatorWithKernel {
class SGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto &input_var_n = ctx->Input("Param")[0];
auto in_var_type = ctx->GetType(input_var_n);
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s",
input_var_n, in_var_type);
for (auto &out_var_n : ctx->Output("ParamOut")) {
if (ctx->GetType(out_var_n) != in_var_type) {
ctx->SetType(out_var_n, in_var_type);
}
}
auto in_var_type = ctx->GetInputType("Param");
PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
true, platform::errors::InvalidArgument(
"The input Var's type should be LoDtensor or "
"SelectedRows, but the received type is %s",
in_var_type));
ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
}
};
......
......@@ -422,9 +422,10 @@ Example:
class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......
......@@ -260,9 +260,7 @@ class PrintOpInferShape : public framework::InferShapeBase {
class PrintOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto input_type = ctx->GetType(ctx->Input("In")[0]);
auto out_name = ctx->Output("Out").front();
ctx->SetType(out_name, input_type);
ctx->SetOutputType("Out", ctx->GetInputType("In"));
}
};
......
......@@ -116,12 +116,11 @@ static void CallPythonFunc(py::object *callable,
}
}
class PyFuncOpVarTypeInference : public framework::VarTypeInference {
class PyFuncOpVarTypeInference : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
bool has_out = (ctx->HasOutput("Out") && !ctx->Output("Out").empty());
bool has_in = (ctx->HasInput("X") && !ctx->Input("X").empty());
bool has_out = ctx->HasOutput("Out");
bool has_in = ctx->HasInput("X");
/**
* X or Out can be empty, so that py_func can be more flexible
......@@ -147,7 +146,7 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
* the corresponding forward variable
*/
const std::string kGradVarSuffix = framework::kGradVarSuffix;
auto &out_var_names = ctx->Output("Out");
auto &out_var_names = Output(ctx, "Out");
for (auto &out_var_name : out_var_names) {
if (out_var_name == framework::kEmptyVarName ||
out_var_name.size() < kGradVarSuffix.size()) {
......@@ -157,19 +156,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
size_t len = out_var_name.size() - kGradVarSuffix.size();
if (out_var_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_var_name.substr(0, len);
PADDLE_ENFORCE_EQ(ctx->HasVar(out_var_name), true,
platform::errors::InvalidArgument(
"Backward variable %s not found", out_var_name));
PADDLE_ENFORCE_EQ(ctx->HasVar(fwd_var_name), true,
platform::errors::InvalidArgument(
"Backward variable %s not found", fwd_var_name));
OP_INOUT_CHECK(HasVar(ctx, out_var_name), "Var", out_var_name,
"py_func");
OP_INOUT_CHECK(HasVar(ctx, fwd_var_name), "Var", fwd_var_name,
"py_func");
VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
<< fwd_var_name << ")";
ctx->SetShape(out_var_name, ctx->GetShape(fwd_var_name));
ctx->SetDataType(out_var_name, ctx->GetDataType(fwd_var_name));
ctx->SetLoDLevel(out_var_name, ctx->GetLoDLevel(fwd_var_name));
ctx->SetType(out_var_name, ctx->GetType(fwd_var_name));
SetShape(ctx, out_var_name, GetShape(ctx, fwd_var_name));
SetDataType(ctx, out_var_name, GetDataType(ctx, fwd_var_name));
SetLoDLevel(ctx, out_var_name, GetLoDLevel(ctx, fwd_var_name));
SetType(ctx, out_var_name, GetType(ctx, fwd_var_name));
}
}
}
......
......@@ -75,8 +75,7 @@ class RandpermOpVarTypeInference : public framework::VarTypeInference {
void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
auto out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, var_data_type);
ctx->SetOutputDataType("Out", var_data_type);
}
};
......
......@@ -70,18 +70,18 @@ class ReadInferShape : public framework::InferShapeBase {
}
};
class ReadInferVarType : public framework::VarTypeInference {
class ReadInferVarType : public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
bool infer_out = boost::get<bool>(ctx->GetAttr("infer_out"));
if (infer_out) {
std::string reader_name = ctx->Input("Reader")[0];
std::vector<std::string> out_names = ctx->Output("Out");
auto dtypes = ctx->GetDataTypes(reader_name);
std::string reader_name = Input(ctx, "Reader")[0];
auto& out_names = Output(ctx, "Out");
auto dtypes = GetDataTypes(ctx, reader_name);
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
for (size_t i = 0; i < dtypes.size(); ++i) {
ctx->SetType(out_names[i], framework::proto::VarType::LOD_TENSOR);
ctx->SetDataType(out_names[i], dtypes[i]);
SetType(ctx, out_names[i], framework::proto::VarType::LOD_TENSOR);
SetDataType(ctx, out_names[i], dtypes[i]);
}
}
}
......
......@@ -100,8 +100,7 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
void FileReaderInferVarType::operator()(
framework::InferVarTypeContext* ctx) const {
std::string reader_name = ctx->Output("Out")[0];
ctx->SetType(reader_name, framework::proto::VarType::READER);
ctx->SetOutputType("Out", framework::proto::VarType::READER);
}
void DecoratedReaderInferShape::operator()(
......@@ -125,10 +124,8 @@ void DecoratedReaderInferShape::operator()(
void DecoratedReaderInferVarType::operator()(
framework::InferVarTypeContext* ctx) const {
const std::string& in_reader_name = ctx->Input("UnderlyingReader")[0];
const std::string& out_reader_name = ctx->Output("Out")[0];
ctx->SetType(out_reader_name, framework::proto::VarType::READER);
ctx->SetDataTypes(out_reader_name, ctx->GetDataTypes(in_reader_name));
ctx->SetOutputType("Out", framework::proto::VarType::READER);
ctx->SetOutputDataTypes("Out", ctx->GetInputDataTypes("UnderlyingReader"));
}
void DecoratedReaderMakerBase::Make() {
......
......@@ -58,8 +58,7 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("out_dtype")));
if (data_type >= 0) {
auto& out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, data_type);
ctx->SetOutputDataType("Out", data_type);
}
}
};
......
......@@ -85,9 +85,8 @@ to a file on disk.
class SaveCombineOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
for (auto& o : ctx->Output("Y")) {
ctx->SetType(o, framework::proto::VarType::RAW);
}
ctx->SetOutputType("Y", framework::proto::VarType::RAW,
framework::ALL_ELEMENTS);
}
};
......
......@@ -73,7 +73,7 @@ class SaveOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_type = framework::proto::VarType::RAW;
ctx->SetType(LOOKUP_TABLE_PATH, var_type);
ctx->InsertVar(LOOKUP_TABLE_PATH, var_type);
}
};
......
......@@ -82,13 +82,7 @@ $$Out = scale*(X + bias)$$
class ScaleOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto &in_var_name = ctx->Input("X").front();
auto out_var_name = ctx->Output("Out").front();
if (in_var_name != out_var_name) {
ctx->SetType(out_var_name, ctx->GetType(in_var_name));
ctx->SetDataType(out_var_name, ctx->GetDataType(in_var_name));
}
ctx->SyncTypeAndDataType("X", "Out");
}
};
......
......@@ -45,9 +45,10 @@ class SeluOp : public framework::OperatorWithKernel {
class SeluOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......
......@@ -145,9 +145,10 @@ For each row $i$ and each column $j$ in the matrix, we have:
class SoftmaxOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......
......@@ -64,9 +64,8 @@ class SplitSelectedRowsOp : public framework::OperatorWithKernel {
class SplitSelectedRowsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &out_var : ctx->Output("Out")) {
ctx->SetType(out_var, framework::proto::VarType::SELECTED_ROWS);
}
ctx->SetOutputType("Out", framework::proto::VarType::SELECTED_ROWS,
framework::ALL_ELEMENTS);
}
};
......
......@@ -210,43 +210,36 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto& inputs = ctx->Input("X");
auto var_type = framework::proto::VarType::SELECTED_ROWS;
for (auto& name : ctx->Input("X")) {
VLOG(10) << name << " " << ctx->GetType(name);
}
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [ctx](const std::string& name) {
return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR;
});
auto is_tensor_array = [ctx](const std::string& name) {
return ctx->GetType(name) == framework::proto::VarType::LOD_TENSOR_ARRAY;
};
bool any_input_is_tensor_array =
std::any_of(inputs.begin(), inputs.end(), is_tensor_array);
bool all_inputs_are_tensor_array =
std::all_of(inputs.begin(), inputs.end(), is_tensor_array);
if (!ctx->IsDygraph()) {
auto var_type = framework::proto::VarType::SELECTED_ROWS;
if (VLOG_IS_ON(10)) {
for (size_t ind = 0; ind < ctx->InputSize("X"); ++ind) {
VLOG(10) << ctx->InputVarName("X", ind) << " "
<< ctx->GetInputType("X", ind);
}
}
if (any_input_is_tensor_array) {
if (!all_inputs_are_tensor_array) {
std::ostringstream os;
for (auto& each : inputs) {
os << " " << each << " type is " << ctx->GetType(each) << "\n";
if (ctx->InputTypeAnyOf("X",
framework::proto::VarType::LOD_TENSOR_ARRAY)) {
if (!ctx->InputTypeAllOf("X",
framework::proto::VarType::LOD_TENSOR_ARRAY)) {
std::ostringstream os;
for (size_t ind = 0; ind < ctx->InputSize("X"); ++ind) {
os << " " << ctx->InputVarName("X", ind) << " type is "
<< ctx->GetInputType("X", ind) << "\n";
}
PADDLE_THROW(platform::errors::InvalidArgument(
"Not all inputs are tensor array:\n%s", os.str()));
}
PADDLE_ENFORCE_EQ(all_inputs_are_tensor_array, true,
"Not all inputs are tensor array:\n%s", os.str());
var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
} else if (ctx->InputTypeAnyOf("X",
framework::proto::VarType::LOD_TENSOR)) {
var_type = framework::proto::VarType::LOD_TENSOR;
}
var_type = framework::proto::VarType::LOD_TENSOR_ARRAY;
} else if (any_input_is_lod_tensor) {
var_type = framework::proto::VarType::LOD_TENSOR;
}
auto out_var_name = ctx->Output("Out").front();
ctx->SetType(out_var_name, var_type);
ctx->SetDataType(out_var_name, ctx->GetDataType(inputs.front()));
ctx->SetOutputType("Out", var_type);
ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
}
}
};
......
......@@ -213,9 +213,9 @@ class LoDTensorArray2TensorGradInferVarType
: public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
for (auto &out_var : ctx->Output(framework::GradVarName("X"))) {
ctx->SetType(out_var, framework::proto::VarType::LOD_TENSOR_ARRAY);
}
ctx->SetOutputType(framework::GradVarName("X"),
framework::proto::VarType::LOD_TENSOR_ARRAY,
framework::ALL_ELEMENTS);
}
};
......
......@@ -232,15 +232,13 @@ uniform distribution. The random result is in set [min, max).
class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out").front();
auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
if (ctx->GetType(out_var_name) !=
framework::proto::VarType::SELECTED_ROWS) {
ctx->SetType(out_var_name, framework::proto::VarType::LOD_TENSOR);
if (ctx->GetOutputType("Out") != framework::proto::VarType::SELECTED_ROWS) {
ctx->SetOutputType("Out", framework::proto::VarType::LOD_TENSOR);
}
ctx->SetDataType(out_var_name, var_data_type);
ctx->SetOutputDataType("Out", var_data_type);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册