未验证 提交 e445b3ff 编写于 作者: Y Yu Yang 提交者: GitHub

Move framework.proto to proto namespace (#6718)

* Move framework.proto to proto namespace

* Fix compile

* Fix compile

* Fix Compile
上级 a87f4963
......@@ -53,7 +53,7 @@ Kernel实现 | CPU、CUDA共享Kernel实现在`.h`文件中,否则,CPU
```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), 2D tensor of size (M x K)");
AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
......@@ -82,7 +82,7 @@ The equation is: Out = X * Y
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient();
AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
......
......@@ -50,7 +50,7 @@ First, define `ProtoMaker` to describe the Operator's input, output, and additio
```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), 2D tensor of size (M x K)");
AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
......@@ -79,7 +79,7 @@ An additional example [`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/de
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient();
AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
......
......@@ -19,42 +19,42 @@ limitations under the License. */
namespace paddle {
namespace framework {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) {
case framework::AttrType::BOOLEAN: {
case proto::AttrType::BOOLEAN: {
return attr_desc.b();
}
case framework::AttrType::INT: {
case proto::AttrType::INT: {
return attr_desc.i();
}
case framework::AttrType::FLOAT: {
case proto::AttrType::FLOAT: {
return attr_desc.f();
}
case framework::AttrType::STRING: {
case proto::AttrType::STRING: {
return attr_desc.s();
}
case framework::AttrType::BOOLEANS: {
case proto::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
}
return val;
}
case framework::AttrType::INTS: {
case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
return val;
}
case framework::AttrType::FLOATS: {
case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
return val;
}
case framework::AttrType::STRINGS: {
case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
......
......@@ -27,12 +27,12 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
inline AttrType AttrTypeID() {
inline proto::AttrType AttrTypeID() {
Attribute tmp = T();
return static_cast<AttrType>(tmp.which() - 1);
return static_cast<proto::AttrType>(tmp.which() - 1);
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);
class AttrReader {
public:
......
......@@ -341,7 +341,7 @@ static void CreateGradVarInBlock(
auto* param = block_desc->FindVarRecursive(pname);
auto* grad = block_desc->FindVar(arg);
if (param == nullptr) {
grad->SetDataType(DataType::FP32);
grad->SetDataType(proto::DataType::FP32);
} else {
grad->SetDataType(param->GetDataType());
}
......
......@@ -166,7 +166,7 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.").AsDuplicable();
AddOutput("Out", "the output tensor of sum operator.");
......
......@@ -128,22 +128,22 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
}
BlockDesc *BlockDescBind::Proto() {
proto::BlockDesc *BlockDescBind::Proto() {
Flush();
return desc_;
}
BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
BlockDescBind::BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {
for (const VarDesc &var_desc : desc_->vars()) {
for (const proto::VarDesc &var_desc : desc_->vars()) {
vars_[var_desc.name()].reset(new VarDescBind(var_desc));
}
for (const OpDesc &op_desc : desc_->ops()) {
for (const proto::OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDescBind(op_desc, prog));
}
}
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
BlockDescBind::BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
ProgramDescBind *prog)
: prog_(prog), desc_(desc) {
need_update_ = true;
......
......@@ -36,9 +36,9 @@ class ProgramDescBind;
class BlockDescBind {
public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc);
BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc);
BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
ProgramDescBind *prog);
~BlockDescBind() {
......@@ -88,7 +88,7 @@ class BlockDescBind {
void Flush();
BlockDesc *Proto();
proto::BlockDesc *Proto();
ProgramDescBind *Program() { return this->prog_; }
......@@ -97,8 +97,8 @@ class BlockDescBind {
void ClearPBVars();
private:
ProgramDescBind *prog_; // not_own
BlockDesc *desc_; // not_own
ProgramDescBind *prog_; // not_own
proto::BlockDesc *desc_; // not_own
bool need_update_;
std::deque<std::unique_ptr<OpDescBind>> ops_;
......
......@@ -20,7 +20,8 @@
namespace paddle {
namespace framework {
inline DataType ToDataType(std::type_index type) {
inline proto::DataType ToDataType(std::type_index type) {
using namespace paddle::framework::proto;
if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) {
......@@ -36,7 +37,8 @@ inline DataType ToDataType(std::type_index type) {
}
}
inline std::type_index ToTypeIndex(DataType type) {
inline std::type_index ToTypeIndex(proto::DataType type) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP32:
return typeid(float);
......@@ -54,7 +56,8 @@ inline std::type_index ToTypeIndex(DataType type) {
}
template <typename Visitor>
inline void VisitDataType(DataType type, Visitor visitor) {
inline void VisitDataType(proto::DataType type, Visitor visitor) {
using namespace paddle::framework::proto;
switch (type) {
case DataType::FP32:
visitor.template operator()<float>();
......
......@@ -90,7 +90,7 @@ struct OpInfoFiller<T, kOperator> {
template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new OpProto;
info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_);
maker.Validate();
......
......@@ -41,20 +41,20 @@ Executor::Executor(const std::vector<platform::Place>& places) {
device_contexts_.swap(borrowed_contexts);
}
static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
if (var_type == VarDesc::LOD_TENSOR) {
static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
if (var_type == proto::VarDesc::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == VarDesc::SELECTED_ROWS) {
} else if (var_type == proto::VarDesc::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == VarDesc::FEED_MINIBATCH) {
} else if (var_type == proto::VarDesc::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::FETCH_LIST) {
} else if (var_type == proto::VarDesc::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == VarDesc::STEP_SCOPES) {
} else if (var_type == proto::VarDesc::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>();
} else if (var_type == VarDesc::LOD_RANK_TABLE) {
} else if (var_type == proto::VarDesc::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == VarDesc::LOD_TENSOR_ARRAY) {
} else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else {
PADDLE_THROW(
......
......@@ -14,7 +14,7 @@ limitations under the License. */
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package paddle.framework;
package paddle.framework.proto;
enum AttrType {
INT = 0;
......
......@@ -197,7 +197,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework::TensorDesc desc;
proto::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto *pb_dims = desc.mutable_dims();
......@@ -262,7 +262,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
framework::TensorDesc desc;
proto::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
......@@ -281,16 +281,16 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
void *buf;
platform::Place cpu = platform::CPUPlace();
switch (desc.data_type()) {
case framework::FP32:
case proto::FP32:
buf = tensor->mutable_data<float>(cpu);
break;
case framework::FP64:
case proto::FP64:
buf = tensor->mutable_data<double>(cpu);
break;
case framework::INT32:
case proto::INT32:
buf = tensor->mutable_data<int>(cpu);
break;
case framework::INT64:
case proto::INT64:
buf = tensor->mutable_data<int64_t>(cpu);
break;
default:
......
......@@ -58,11 +58,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
PADDLE_ENFORCE_LT(j, Outputs(out).size());
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != VarDesc::LOD_TENSOR) {
if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
VLOG(3) << "input " << in << " is not LodTensor";
return;
}
PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR,
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j,
out);
out_var->SetLoDLevel(in_var->GetLodLevel());
......@@ -70,7 +70,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override;
protected:
VarDesc::VarType GetVarType(const std::string &name) const override;
proto::VarDesc::VarType GetVarType(const std::string &name) const override;
DDim GetDim(const std::string &name) const override;
......@@ -90,12 +90,12 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
need_update_ = true;
}
OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
: desc_(desc), need_update_(false) {
// restore inputs_
int input_size = desc_.inputs_size();
for (int i = 0; i < input_size; ++i) {
const OpDesc::Var &var = desc_.inputs(i);
const proto::OpDesc::Var &var = desc_.inputs(i);
std::vector<std::string> &args = inputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
......@@ -106,7 +106,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
// restore outputs_
int output_size = desc_.outputs_size();
for (int i = 0; i < output_size; ++i) {
const OpDesc::Var &var = desc_.outputs(i);
const proto::OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
......@@ -115,9 +115,9 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
}
}
// restore attrs_
for (const OpDesc::Attr &attr : desc_.attrs()) {
for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
if (attr.type() != AttrType::BLOCK) {
if (attr.type() != proto::AttrType::BLOCK) {
attrs_[attr_name] = GetAttrValue(attr);
} else {
auto bid = attr.block_idx();
......@@ -126,7 +126,7 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
}
}
OpDesc *OpDescBind::Proto() {
proto::OpDesc *OpDescBind::Proto() {
Flush();
return &desc_;
}
......@@ -175,10 +175,10 @@ void OpDescBind::SetOutput(const std::string &param_name,
this->outputs_[param_name] = args;
}
AttrType OpDescBind::GetAttrType(const std::string &name) const {
proto::AttrType OpDescBind::GetAttrType(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return static_cast<AttrType>(it->second.which() - 1);
return static_cast<proto::AttrType>(it->second.which() - 1);
}
std::vector<std::string> OpDescBind::AttrNames() const {
......@@ -253,8 +253,8 @@ void OpDescBind::RenameInput(const std::string &old_name,
}
struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_;
explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
mutable proto::OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
......@@ -272,7 +272,9 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); }
void operator()(proto::BlockDesc *desc) const {
attr_->set_block_idx(desc->idx());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
......@@ -297,7 +299,7 @@ void OpDescBind::Flush() {
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1));
static_cast<proto::AttrType>(attr.second.which() - 1));
SetAttrDescVisitor visitor(attr_desc);
boost::apply_visitor(visitor, attr.second);
}
......@@ -375,7 +377,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name)
->SetType(VarDesc::LOD_TENSOR);
->SetType(proto::VarDesc::LOD_TENSOR);
}
}
}
......@@ -484,7 +486,7 @@ void CompileTimeInferShapeContext::SetDim(const std::string &name,
}
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }
VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
const std::string &name) const {
return block_.FindVarRecursive(name)->GetType();
}
......
......@@ -33,9 +33,9 @@ class OpDescBind {
OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs);
OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog);
OpDesc *Proto();
proto::OpDesc *Proto();
std::string Type() const { return desc_.type(); }
......@@ -59,7 +59,7 @@ class OpDescBind {
return attrs_.find(name) != attrs_.end();
}
AttrType GetAttrType(const std::string &name) const;
proto::AttrType GetAttrType(const std::string &name) const;
std::vector<std::string> AttrNames() const;
......@@ -126,7 +126,7 @@ class OpDescBind {
return ret_val;
}
OpDesc desc_;
proto::OpDesc desc_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
......
......@@ -34,7 +34,7 @@ class InferShapeBase {
struct OpInfo {
OpCreator creator_;
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
proto::OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_;
......@@ -43,7 +43,7 @@ struct OpInfo {
return proto_ != nullptr && checker_ != nullptr;
}
const OpProto& Proto() const {
const proto::OpProto& Proto() const {
PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered");
PADDLE_ENFORCE(proto_->IsInitialized(),
"Operator Proto must be initialized in op info");
......
......@@ -22,6 +22,8 @@ namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
using OpProto = proto::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
......@@ -80,7 +82,7 @@ class OpProtoAndCheckerMaker {
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
......
......@@ -18,7 +18,7 @@ limitations under the License. */
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestAttrProtoMaker(paddle::framework::OpProto* proto,
TestAttrProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("scale", "scale of test op");
......@@ -27,7 +27,7 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
};
TEST(ProtoMaker, DuplicatedAttr) {
paddle::framework::OpProto op_proto;
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
......@@ -35,7 +35,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
TestInOutProtoMaker(paddle::framework::OpProto* proto,
TestInOutProtoMaker(paddle::framework::proto::OpProto* proto,
paddle::framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of test op");
......@@ -44,7 +44,7 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
};
TEST(ProtoMaker, DuplicatedInOut) {
paddle::framework::OpProto op_proto;
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
......
......@@ -31,7 +31,8 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
}
static VariableNameMap ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
const google::protobuf::RepeatedPtrField<proto::OpDesc::Var>&
op_desc_vars) {
VariableNameMap ret_val;
for (auto& var : op_desc_vars) {
auto& var_names = ret_val[var.parameter()];
......@@ -43,7 +44,8 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
return ret_val;
}
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const proto::OpDesc& op_desc) {
VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
"used in unit tests. Use CreateOp(const OpDescBind& op_desc) "
"instead.";
......
......@@ -77,7 +77,7 @@ class OpRegistry {
const VariableNameMap& outputs,
AttributeMap attrs);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};
......
......@@ -51,7 +51,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
......@@ -63,7 +63,7 @@ REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
......@@ -71,7 +71,7 @@ TEST(OpRegistry, CreateOp) {
float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(scale);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
......@@ -83,14 +83,14 @@ TEST(OpRegistry, CreateOp) {
}
TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(-2.0);
bool caught = false;
......@@ -108,7 +108,7 @@ TEST(OpRegistry, IllegalAttr) {
}
TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("cos_sim");
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
......@@ -123,7 +123,7 @@ TEST(OpRegistry, DefaultValue) {
}
TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("my_test_op");
BuildVar("input", {"ii"}, op_desc.add_inputs());
BuildVar("output", {"oo"}, op_desc.add_outputs());
......@@ -145,7 +145,7 @@ TEST(OpRegistry, CustomChecker) {
// set 'test_attr' set to an illegal value
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(3);
caught = false;
try {
......@@ -164,7 +164,7 @@ TEST(OpRegistry, CustomChecker) {
op_desc.mutable_attrs()->Clear();
attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;
......
......@@ -377,7 +377,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
VarDesc::VarType GetVarType(const std::string& name) const override {
proto::VarDesc::VarType GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
}
......@@ -417,7 +417,7 @@ OpKernelType OperatorWithKernel::GetKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
DataType OperatorWithKernel::IndicateDataType(
proto::DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
......@@ -443,7 +443,7 @@ DataType OperatorWithKernel::IndicateDataType(
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
return static_cast<proto::DataType>(data_type);
}
} // namespace framework
......
......@@ -358,12 +358,13 @@ struct OpKernelType {
};
platform::Place place_;
DataType data_type_;
proto::DataType data_type_;
OpKernelType(DataType data_type, platform::Place place)
OpKernelType(proto::DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}
OpKernelType(DataType data_type, const platform::DeviceContext& dev_ctx)
OpKernelType(proto::DataType data_type,
const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelType& o) const {
......@@ -409,7 +410,7 @@ class OperatorWithKernel : public OperatorBase {
private:
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
DataType IndicateDataType(const ExecutionContext& ctx) const;
proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
};
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
......
......@@ -58,7 +58,7 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
paddle::framework::proto::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
......@@ -70,14 +70,14 @@ REGISTER_OP_WITHOUT_GRADIENT(
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("test_operator");
BuildVar("input", {"IN1"}, op_desc.add_inputs());
BuildVar("output", {"OUT1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext device_context;
......@@ -115,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(DataType::FP32, ctx.GetPlace());
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
}
};
......@@ -195,14 +195,14 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
// test with single input
TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc;
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
BuildVar("x", {"IN1"}, op_desc.add_inputs());
BuildVar("y", {"OUT1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context;
......@@ -224,7 +224,7 @@ REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
TEST(OpKernel, multi_inputs) {
using namespace paddle::framework;
OpDesc op_desc;
proto::OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
BuildVar("k", {"k0"}, op_desc.add_inputs());
......@@ -232,7 +232,7 @@ TEST(OpKernel, multi_inputs) {
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_type(paddle::framework::proto::AttrType::FLOAT);
attr->set_f(3.14);
paddle::platform::CPUDeviceContext cpu_device_context;
......
......@@ -26,7 +26,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
return blocks_.back().get();
}
ProgramDesc *ProgramDescBind::Proto() {
proto::ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) {
block->Flush();
}
......@@ -49,7 +49,7 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
}
}
ProgramDescBind::ProgramDescBind(const ProgramDesc &desc) {
ProgramDescBind::ProgramDescBind(const proto::ProgramDesc &desc) {
desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc));
......
......@@ -29,7 +29,7 @@ class ProgramDescBind {
public:
ProgramDescBind();
explicit ProgramDescBind(const ProgramDesc &desc);
explicit ProgramDescBind(const proto::ProgramDesc &desc);
ProgramDescBind(const ProgramDescBind &o);
......@@ -43,10 +43,10 @@ class ProgramDescBind {
size_t Size() const { return blocks_.size(); }
ProgramDesc *Proto();
proto::ProgramDesc *Proto();
private:
ProgramDesc desc_;
proto::ProgramDesc desc_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_;
};
......
......@@ -22,15 +22,15 @@ TEST(ProgramDesc, copy_ctor) {
ProgramDescBind program;
auto* global_block = program.MutableBlock(0);
auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(FP32);
x->SetDataType(proto::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(VarDesc_VarType_LOD_TENSOR);
y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(FP32);
y->SetDataType(proto::FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
......@@ -39,7 +39,7 @@ TEST(ProgramDesc, copy_ctor) {
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(VarDesc_VarType_LOD_TENSOR);
out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
ProgramDescBind program_copy(program);
......@@ -84,15 +84,15 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDescBind program_origin;
auto* global_block = program_origin.MutableBlock(0);
auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(FP32);
x->SetDataType(proto::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(VarDesc_VarType_LOD_TENSOR);
y->SetType(proto::VarDesc_VarType_LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(FP32);
y->SetDataType(proto::FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
......@@ -101,7 +101,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(VarDesc_VarType_LOD_TENSOR);
out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
std::string binary_str;
......
......@@ -29,7 +29,7 @@ const std::string kFetchOpType = "fetch";
const std::string kDropOutOpType = "dropout";
const std::string kBatchNormOpType = "batch_norm";
bool HasDependentVar(const OpDesc& op_desc,
bool HasDependentVar(const proto::OpDesc& op_desc,
const std::set<std::string>& dependent_vars) {
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
......@@ -41,14 +41,15 @@ bool HasDependentVar(const OpDesc& op_desc,
return false;
}
bool IsTarget(const OpDesc& op_desc) {
bool IsTarget(const proto::OpDesc& op_desc) {
if (op_desc.has_is_target()) {
return op_desc.is_target();
}
return false;
}
void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) {
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
int block_id) {
// TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op
......@@ -104,12 +105,12 @@ void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) {
}
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void Prune(const ProgramDesc& input, ProgramDesc* output) {
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
prune_impl(input, output, 0);
}
void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output,
int block_id) {
void inference_optimize_impl(const proto::ProgramDesc& input,
proto::ProgramDesc* output, int block_id) {
*output = input;
auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
for (auto& op_desc : *op_field) {
......@@ -125,7 +126,8 @@ void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output,
}
}
void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output) {
void InferenceOptimize(const proto::ProgramDesc& input,
proto::ProgramDesc* output) {
inference_optimize_impl(input, output, 0);
}
......
......@@ -20,9 +20,10 @@ limitations under the License. */
namespace paddle {
namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc* output);
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output);
void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output);
void InferenceOptimize(const proto::ProgramDesc& input,
proto::ProgramDesc* output);
} // namespace framework
} // namespace paddle
......@@ -34,7 +34,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->Var(v);
var->SetDataType(paddle::framework::DataType::FP32);
var->SetDataType(paddle::framework::proto::DataType::FP32);
}
}
......@@ -57,14 +57,14 @@ TEST(Prune, one_operator) {
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{},
block);
f::ProgramDesc *pdesc = program.Proto();
f::ProgramDesc pruned;
f::proto::ProgramDesc *pdesc = program.Proto();
f::proto::ProgramDesc pruned;
Prune(*pdesc, &pruned);
f::Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
Prune(*pdesc, &pruned);
f::Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
}
......@@ -81,12 +81,12 @@ TEST(Prune, forward) {
AddOp("one_one", {{"input", {"d"}}}, {{"output", {"e"}}}, f::AttributeMap{},
block);
f::ProgramDesc *pdesc = program.Proto();
f::proto::ProgramDesc *pdesc = program.Proto();
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::ProgramDesc pruned;
f::proto::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
Prune(*pdesc, &pruned);
f::Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
}
}
......@@ -104,11 +104,11 @@ TEST(Prune, multi_input_op) {
AddOp("three_one", {{"input", {"b0", "b1", "b2"}}}, {{"output", {"c"}}},
f::AttributeMap{}, block);
f::ProgramDesc *pdesc = program.Proto();
f::proto::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, &pruned);
f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
}
......@@ -123,11 +123,11 @@ TEST(Prune, multi_output_op) {
AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, f::AttributeMap{},
block);
f::ProgramDesc *pdesc = program.Proto();
f::proto::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, &pruned);
f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
}
......@@ -142,11 +142,11 @@ TEST(Prune, multi_target) {
AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, f::AttributeMap{},
block);
f::ProgramDesc *pdesc = program.Proto();
f::proto::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned;
Prune(*pdesc, &pruned);
f::proto::ProgramDesc pruned;
f::Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
}
......@@ -57,17 +57,17 @@ void InferShapeContext::SetDims(const std::vector<std::string> &names,
SetDim(names[i], dims[i]);
}
}
std::vector<VarDesc::VarType> InferShapeContext::GetInputsVarType(
std::vector<proto::VarDesc::VarType> InferShapeContext::GetInputsVarType(
const std::string &name) const {
return GetVarTypes(Inputs(name));
}
std::vector<VarDesc::VarType> InferShapeContext::GetOutputsVarType(
std::vector<proto::VarDesc::VarType> InferShapeContext::GetOutputsVarType(
const std::string &name) const {
return GetVarTypes(Outputs(name));
}
std::vector<VarDesc::VarType> InferShapeContext::GetVarTypes(
std::vector<proto::VarDesc::VarType> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<VarDesc::VarType> retv;
std::vector<proto::VarDesc::VarType> retv;
retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
......
......@@ -27,8 +27,9 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
std::vector<VarDesc::VarType> GetInputsVarType(const std::string &name) const;
std::vector<VarDesc::VarType> GetOutputsVarType(
std::vector<proto::VarDesc::VarType> GetInputsVarType(
const std::string &name) const;
std::vector<proto::VarDesc::VarType> GetOutputsVarType(
const std::string &name) const;
virtual bool HasInputs(const std::string &name) const = 0;
......@@ -65,10 +66,10 @@ class InferShapeContext {
std::vector<framework::DDim> GetDims(
const std::vector<std::string> &names) const;
std::vector<VarDesc::VarType> GetVarTypes(
std::vector<proto::VarDesc::VarType> GetVarTypes(
const std::vector<std::string> &names) const;
virtual VarDesc::VarType GetVarType(const std::string &name) const = 0;
virtual proto::VarDesc::VarType GetVarType(const std::string &name) const = 0;
};
} // namespace framework
......
......@@ -18,15 +18,17 @@ limitations under the License. */
namespace paddle {
namespace framework {
VarDesc::VarType VarDescBind::GetType() const { return desc_.type(); }
proto::VarDesc::VarType VarDescBind::GetType() const { return desc_.type(); }
void VarDescBind::SetType(VarDesc::VarType type) { desc_.set_type(type); }
void VarDescBind::SetType(proto::VarDesc::VarType type) {
desc_.set_type(type);
}
void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}
void VarDescBind::SetDataType(DataType data_type) {
void VarDescBind::SetDataType(proto::DataType data_type) {
mutable_tensor_desc()->set_data_type(data_type);
}
......@@ -34,14 +36,16 @@ std::vector<int64_t> VarDescBind::Shape() const {
return RepeatedToVector(tensor_desc().dims());
}
DataType VarDescBind::GetDataType() const { return tensor_desc().data_type(); }
proto::DataType VarDescBind::GetDataType() const {
return tensor_desc().data_type();
}
void VarDescBind::SetLoDLevel(int32_t lod_level) {
switch (desc_.type()) {
case VarDesc::LOD_TENSOR:
case proto::VarDesc::LOD_TENSOR:
desc_.mutable_lod_tensor()->set_lod_level(lod_level);
break;
case VarDesc::LOD_TENSOR_ARRAY:
case proto::VarDesc::LOD_TENSOR_ARRAY:
desc_.mutable_tensor_array()->set_lod_level(lod_level);
break;
default:
......@@ -52,9 +56,9 @@ void VarDescBind::SetLoDLevel(int32_t lod_level) {
int32_t VarDescBind::GetLodLevel() const {
switch (desc_.type()) {
case VarDesc::LOD_TENSOR:
case proto::VarDesc::LOD_TENSOR:
return desc_.lod_tensor().lod_level();
case VarDesc::LOD_TENSOR_ARRAY:
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().lod_level();
default:
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
......@@ -62,29 +66,29 @@ int32_t VarDescBind::GetLodLevel() const {
}
}
const TensorDesc &VarDescBind::tensor_desc() const {
const proto::TensorDesc &VarDescBind::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
switch (desc_.type()) {
case VarDesc::SELECTED_ROWS:
case proto::VarDesc::SELECTED_ROWS:
return desc_.selected_rows();
case VarDesc::LOD_TENSOR:
case proto::VarDesc::LOD_TENSOR:
return desc_.lod_tensor().tensor();
case VarDesc::LOD_TENSOR_ARRAY:
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().tensor();
default:
PADDLE_THROW("Unexpected branch.");
}
}
TensorDesc *VarDescBind::mutable_tensor_desc() {
proto::TensorDesc *VarDescBind::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(),
"invoke MutableTensorDesc must after set type");
switch (desc_.type()) {
case VarDesc::SELECTED_ROWS:
case proto::VarDesc::SELECTED_ROWS:
return desc_.mutable_selected_rows();
case VarDesc::LOD_TENSOR:
case proto::VarDesc::LOD_TENSOR:
return desc_.mutable_lod_tensor()->mutable_tensor();
case VarDesc::LOD_TENSOR_ARRAY:
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.mutable_tensor_array()->mutable_tensor();
default:
PADDLE_THROW("Unexpected branch.");
......
......@@ -57,40 +57,40 @@ class VarDescBind {
public:
explicit VarDescBind(const std::string &name) {
desc_.set_name(name);
desc_.set_type(VarDesc::LOD_TENSOR);
desc_.set_type(proto::VarDesc::LOD_TENSOR);
}
explicit VarDescBind(const VarDesc &desc) : desc_(desc) {}
explicit VarDescBind(const proto::VarDesc &desc) : desc_(desc) {}
VarDesc *Proto() { return &desc_; }
proto::VarDesc *Proto() { return &desc_; }
std::string Name() const { return desc_.name(); }
void SetShape(const std::vector<int64_t> &dims);
void SetDataType(DataType data_type);
void SetDataType(proto::DataType data_type);
std::vector<int64_t> Shape() const;
DataType GetDataType() const;
proto::DataType GetDataType() const;
void SetLoDLevel(int32_t lod_level);
int32_t GetLodLevel() const;
VarDesc::VarType GetType() const;
proto::VarDesc::VarType GetType() const;
void SetType(VarDesc::VarType type);
void SetType(proto::VarDesc::VarType type);
bool Persistable() const { return desc_.persistable(); }
void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
private:
const TensorDesc &tensor_desc() const;
TensorDesc *mutable_tensor_desc();
const proto::TensorDesc &tensor_desc() const;
proto::TensorDesc *mutable_tensor_desc();
VarDesc desc_;
proto::VarDesc desc_;
};
} // namespace framework
} // namespace paddle
......@@ -20,15 +20,15 @@
namespace paddle {
namespace framework {
inline VarDesc::VarType ToVarType(std::type_index type) {
inline proto::VarDesc::VarType ToVarType(std::type_index type) {
if (type.hash_code() == typeid(LoDTensor).hash_code()) {
return VarDesc_VarType_LOD_TENSOR;
return proto::VarDesc_VarType_LOD_TENSOR;
} else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
return VarDesc_VarType_LOD_RANK_TABLE;
return proto::VarDesc_VarType_LOD_RANK_TABLE;
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
return VarDesc_VarType_LOD_TENSOR_ARRAY;
return proto::VarDesc_VarType_LOD_TENSOR_ARRAY;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return VarDesc_VarType_SELECTED_ROWS;
return proto::VarDesc_VarType_SELECTED_ROWS;
} else {
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
}
......@@ -37,16 +37,16 @@ inline VarDesc::VarType ToVarType(std::type_index type) {
template <typename Visitor>
inline void VisitVarType(const Variable& var, Visitor visitor) {
switch (ToVarType(var.Type())) {
case VarDesc_VarType_LOD_TENSOR:
case proto::VarDesc_VarType_LOD_TENSOR:
visitor(var.Get<framework::LoDTensor>());
return;
case VarDesc_VarType_LOD_RANK_TABLE:
case proto::VarDesc_VarType_LOD_RANK_TABLE:
visitor(var.Get<LoDRankTable>());
return;
case VarDesc_VarType_LOD_TENSOR_ARRAY:
case proto::VarDesc_VarType_LOD_TENSOR_ARRAY:
visitor(var.Get<LoDTensorArray>());
return;
case VarDesc_VarType_SELECTED_ROWS:
case proto::VarDesc_VarType_SELECTED_ROWS:
visitor(var.Get<SelectedRows>());
return;
default:
......
......@@ -36,14 +36,14 @@ class SumOpVarTypeInference : public VarTypeInference {
void operator()(const OpDescBind &op_desc,
BlockDescBind *block) const override {
auto &inputs = op_desc.Input("X");
auto default_var_type = VarDesc::SELECTED_ROWS;
auto default_var_type = proto::VarDesc::SELECTED_ROWS;
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string &name) {
return block->Var(name)->GetType() == VarDesc::LOD_TENSOR;
return block->Var(name)->GetType() == proto::VarDesc::LOD_TENSOR;
});
if (any_input_is_lod_tensor) {
default_var_type = VarDesc::LOD_TENSOR;
default_var_type = proto::VarDesc::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
......@@ -68,19 +68,19 @@ TEST(InferVarType, sum_op) {
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
prog.MutableBlock(0)->Var("test_a")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(VarDesc::SELECTED_ROWS,
ASSERT_EQ(proto::VarDesc::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType());
prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarDesc::LOD_TENSOR);
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(VarDesc::LOD_TENSOR,
ASSERT_EQ(proto::VarDesc::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType());
}
......@@ -91,14 +91,14 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"});
prog.MutableBlock(0)->Var("test2_a")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_b")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_c")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_out");
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(VarDesc_VarType_LOD_TENSOR,
ASSERT_EQ(proto::VarDesc_VarType_LOD_TENSOR,
prog.MutableBlock(0)->Var("test2_out")->GetType());
}
......
......@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AccuracyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
AccuracyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
// TODO(typhoonzero): support both inference value and indices.
AddInput("Out", "The network output of topk (inferences)");
......
......@@ -38,9 +38,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sigmoid operator");
AddOutput("Y", "Output of Sigmoid operator");
AddComment(R"DOC(
......@@ -54,9 +53,8 @@ $$y = \frac{1}{1 + e^{-x}}$$
class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LogSigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
LogSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LogSigmoid operator");
AddOutput("Y", "Output of LogSigmoid operator");
AddComment(R"DOC(
......@@ -70,8 +68,8 @@ $$y = \log \frac{1}{1 + e^{-x}}$$
class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
ExpOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Exp operator");
AddOutput("Y", "Output of Exp operator");
AddComment(R"DOC(
......@@ -85,8 +83,8 @@ $y = e^x$
class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
ReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
AddOutput("Y", "Output of Relu operator");
AddComment(R"DOC(
......@@ -100,9 +98,8 @@ $y = \max(x, 0)$
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LeakyReluOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Y", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
......@@ -117,9 +114,8 @@ $y = \max(x, \alpha * x)$
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftShrinkOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softshrink operator");
AddOutput("Y", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
......@@ -140,8 +136,8 @@ $$
class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
TanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Tanh operator");
AddOutput("Y", "Output of Tanh operator");
AddComment(R"DOC(
......@@ -155,9 +151,8 @@ $$y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TanhShrinkOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
TanhShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of TanhShrink operator");
AddOutput("Y", "Output of TanhShrink operator");
AddComment(R"DOC(
......@@ -171,9 +166,8 @@ $$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardShrinkOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardShrink operator");
AddOutput("Y", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink")
......@@ -195,8 +189,8 @@ $$
class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
SqrtOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sqrt operator");
AddOutput("Y", "Output of Sqrt operator");
AddComment(R"DOC(
......@@ -210,8 +204,8 @@ $y = \sqrt{x}$
class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AbsOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AbsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Abs operator");
AddOutput("Y", "Output of Abs operator");
AddComment(R"DOC(
......@@ -225,8 +219,8 @@ $y = |x|$
class CeilOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CeilOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
CeilOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Ceil operator");
AddOutput("Y", "Output of Ceil operator");
AddComment(R"DOC(
......@@ -240,8 +234,8 @@ $y = ceil(x)$
class FloorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FloorOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
FloorOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Floor operator");
AddOutput("Y", "Output of Floor operator");
AddComment(R"DOC(
......@@ -255,8 +249,8 @@ $y = floor(x)$
class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RoundOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Round operator");
AddOutput("Y", "Output of Round operator");
AddComment(R"DOC(
......@@ -270,9 +264,8 @@ $y = [x]$
class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReciprocalOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
ReciprocalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Reciprocal operator");
AddOutput("Y", "Output of Reciprocal operator");
AddComment(R"DOC(
......@@ -286,8 +279,8 @@ $$y = \frac{1}{x}$$
class LogOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LogOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
LogOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Log operator");
AddOutput("Y", "Output of Log operator");
AddComment(R"DOC(
......@@ -303,8 +296,8 @@ Natural logarithm of x.
class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquareOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
SquareOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Square operator");
AddOutput("Y", "Output of Square operator");
AddComment(R"DOC(
......@@ -318,9 +311,8 @@ $y = x^2$
class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftplusOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
SoftplusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softplus operator");
AddOutput("Y", "Output of Softplus operator");
AddComment(R"DOC(
......@@ -334,9 +326,8 @@ $y = \ln(1 + e^{x})$
class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftsignOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
SoftsignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softsign operator");
AddOutput("Y", "Output of Softsign operator");
AddComment(R"DOC(
......@@ -350,8 +341,8 @@ $$y = \frac{x}{1 + |x|}$$
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of BRelu operator");
AddOutput("Y", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu")
......@@ -369,9 +360,8 @@ $y = \max(\min(x, t_{min}), t_{max})$
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftReluOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Y", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu")
......@@ -387,8 +377,8 @@ $y = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ELU operator");
AddOutput("Y", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
......@@ -406,8 +396,8 @@ $y = \max(0, x) + \min(0, \alpha * (e^x - 1))$
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
Relu6OpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu6 operator");
AddOutput("Y", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6")
......@@ -423,8 +413,8 @@ $y = \min(\max(0, x), 6)$
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PowOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Pow operator");
AddOutput("Y", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
......@@ -439,8 +429,8 @@ $y = x^{factor}$
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
STanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of STanh operator");
AddOutput("Y", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input")
......@@ -458,9 +448,8 @@ $$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ThresholdedReluOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Y", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation")
......@@ -481,9 +470,8 @@ $$
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardSigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardSigmoid operator");
AddOutput("Y", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
......@@ -508,8 +496,8 @@ It is recommended to use the defaults for this activation.
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SwishOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Swish operator");
AddOutput("Y", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
......
......@@ -59,8 +59,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdadeltaOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
AdadeltaOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
......
......@@ -59,8 +59,7 @@ class AdagradOp : public framework::OperatorWithKernel {
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdagradOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
AdagradOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
......
......@@ -73,7 +73,7 @@ class AdamOp : public framework::OperatorWithKernel {
class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
AdamOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
......
......@@ -67,7 +67,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AdamaxOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
AdamaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
......
......@@ -114,8 +114,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ArrayToLoDTensorOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ArrayToLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(std::vector<LodTensor>) A vector of tensors that is going to "
......
......@@ -86,8 +86,7 @@ class AssignOp : public framework::OperatorBase {
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
AssignOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
AssignOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor, SelectedRows or LoDTensorArray) The input variable "
......@@ -109,8 +108,8 @@ class AssignInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *context) const override {
if (context->HasInput("X")) {
auto type = context->GetInputsVarType("X")[0];
if (type == framework::VarDesc_VarType_SELECTED_ROWS ||
type == framework::VarDesc_VarType_LOD_TENSOR) {
if (type == framework::proto::VarDesc_VarType_SELECTED_ROWS ||
type == framework::proto::VarDesc_VarType_LOD_TENSOR) {
context->SetOutputDim("Out", context->GetInputDim("X"));
}
}
......
......@@ -49,7 +49,7 @@ class AucOp : public framework::OperatorWithKernel {
class AucOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
AucOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Out",
"A floating point 2D tensor, values are in the range [0, 1]."
......
......@@ -85,8 +85,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchNormOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
BatchNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9);
......
......@@ -83,9 +83,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
BeamSearchDecodeOpProtoMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
BeamSearchDecodeOpProtoMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids",
"(LodTensorArray)"
"score of the candidate words in each step");
......@@ -123,10 +122,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDescBind& op_desc,
framework::BlockDescBind* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) {
block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR);
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
}
for (auto& o : op_desc.Output("SentenceScores")) {
block->Var(o)->SetType(framework::VarDesc::LOD_TENSOR);
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
}
}
};
......
......@@ -153,8 +153,7 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
class BeamSearchProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
public:
BeamSearchProtoAndCheckerMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
BeamSearchProtoAndCheckerMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
// inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step");
......
......@@ -65,8 +65,7 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BilinearTensorProductOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
BilinearTensorProductOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of bilinear_tensor_product operator.");
AddInput("Y", "The second input of bilinear_tensor_product operator.");
......
......@@ -20,8 +20,7 @@ namespace operators {
class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CastOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
CastOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of cast op");
AddOutput("Out", "The output tensor of cast op");
......
......@@ -55,7 +55,7 @@ class CastOpKernel : public framework::OpKernel<InT> {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
framework::VisitDataType(
static_cast<framework::DataType>(context.Attr<int>("out_dtype")),
static_cast<framework::proto::DataType>(context.Attr<int>("out_dtype")),
CastOpFunctor<DeviceContext, InT>(
in, out, context.template device_context<DeviceContext>()));
}
......
......@@ -57,15 +57,14 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::DataType::FP32,
return framework::OpKernelType(framework::proto::DataType::FP32,
ctx.device_context());
}
};
class ChunkEvalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ChunkEvalOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ChunkEvalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Inference",
"(Tensor, default: Tensor<int64_t>). "
......
......@@ -37,8 +37,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipByNormOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
ClipByNormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor) The input of clip_by_norm op."
......
......@@ -38,7 +38,7 @@ class ClipOp : public framework::OperatorWithKernel {
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
ClipOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor)The input of clip op."
......
......@@ -20,8 +20,7 @@ namespace operators {
template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
CompareOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
CompareOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment;
AddInput("X",
......
......@@ -58,7 +58,7 @@ class ConcatOp : public framework::OperatorWithKernel {
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConcatOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
ConcatOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input tensors of concat operator.").AsDuplicable();
AddOutput("Out", "Output tensor of concat operator.");
......
......@@ -205,8 +205,7 @@ void CondOp::Run(const Scope& scope,
class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
public:
CondOpProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
CondOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Cond", "The condition, which is a bool vector");
AddInput("Xs", "Inputs of Subnets").AsDuplicable();
......
......@@ -74,8 +74,7 @@ class ConditionalBlockOp : public ConditionalOp {
class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ConditionalBlockOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ConditionalBlockOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The conditional variable of this operator. If X is empty, the "
......
......@@ -19,8 +19,7 @@ namespace operators {
class CudnnConv2DOpMaker : public Conv2DOpMaker {
public:
CudnnConv2DOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
CudnnConv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv2DOpMaker(proto, op_checker) {
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
......@@ -34,8 +33,7 @@ class CudnnConv2DOpMaker : public Conv2DOpMaker {
class CudnnConv3DOpMaker : public Conv3DOpMaker {
public:
CudnnConv3DOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
CudnnConv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv3DOpMaker(proto, op_checker) {
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
......
......@@ -66,8 +66,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
......@@ -138,8 +137,7 @@ $$
)DOC");
}
Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
......
......@@ -50,14 +50,12 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
// operator implementations can reuse the code.
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};
class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv3DOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};
class ConvOp : public framework::OperatorWithKernel {
......
......@@ -75,8 +75,7 @@ class ConvShiftGradOp : public framework::OperatorWithKernel {
class ConvShiftOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConvShiftOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ConvShiftOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape B x M, "
......
......@@ -19,8 +19,7 @@ namespace operators {
class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
public:
CudnnConv2DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
CudnnConv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv2DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault({1, 1});
......@@ -36,8 +35,7 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
public:
CudnnConv3DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
CudnnConv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv3DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault({1, 1, 1});
......
......@@ -53,8 +53,8 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
framework::OpProto* proto, framework::OpAttrChecker* op_checker)
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
......@@ -112,8 +112,8 @@ Example:
)DOC");
}
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
framework::OpProto* proto, framework::OpAttrChecker* op_checker)
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(Tensor) The input tensor of convolution transpose operator."
......
......@@ -30,14 +30,12 @@ using DDim = framework::DDim;
// operator implementations can reuse the code.
class Conv2DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv2DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
Conv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};
class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Conv3DTransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
Conv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};
class ConvTransposeOp : public framework::OperatorWithKernel {
......
......@@ -62,7 +62,7 @@ class CosSimOp : public framework::OperatorWithKernel {
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CosSimOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
CosSimOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op.");
......
......@@ -18,8 +18,7 @@ namespace paddle {
namespace operators {
class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CRFDecodingOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
CRFDecodingOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Emission",
"(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
......
......@@ -52,7 +52,7 @@ class CropOp : public framework::OperatorWithKernel {
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CropOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
CropOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input of pad op. "
......
......@@ -111,8 +111,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CrossEntropyOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
CrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
......
......@@ -55,8 +55,7 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DecayedAdagradOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
DecayedAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("Grad", "(Tensor) Input gradient");
......
......@@ -40,8 +40,7 @@ class DropoutOp : public framework::OperatorWithKernel {
template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op.");
......
......@@ -19,8 +19,7 @@ namespace paddle {
namespace operators {
class ElementwiseAddOpMaker : public ElementwiseOpMaker {
public:
ElementwiseAddOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
ElementwiseAddOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Add", "$Out = X + Y$");
AddComment(comment_);
......
......@@ -19,8 +19,7 @@ namespace paddle {
namespace operators {
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
public:
ElementwiseDivOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
ElementwiseDivOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Div", "$Out = X / Y$");
AddComment(comment_);
......
......@@ -20,8 +20,7 @@ namespace operators {
class ElementwiseMulOpMaker : public ElementwiseOpMaker {
public:
ElementwiseMulOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
ElementwiseMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Mul", "$Out = X \\odot\\ Y$");
AddComment(comment_);
......
......@@ -43,8 +43,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The first input tensor of elementwise op");
AddInput("Y", "(Tensor) The second input tensor of elementwise op");
......
......@@ -19,8 +19,7 @@ namespace paddle {
namespace operators {
class ElementwiseSubOpMaker : public ElementwiseOpMaker {
public:
ElementwiseSubOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
ElementwiseSubOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Sub", "$Out = X - Y$");
AddComment(comment_);
......
......@@ -55,7 +55,7 @@ class ExpandOp : public framework::OperatorWithKernel {
class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ExpandOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
ExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor, default Tensor<float>) A tensor with rank in [1, 6]."
......
......@@ -54,8 +54,7 @@ class FeedOp : public framework::OperatorBase {
class FeedOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FeedOpInfoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
FeedOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of feed op");
AddOutput("Out", "The output of feed op");
......
......@@ -61,8 +61,7 @@ class FetchOp : public framework::OperatorBase {
class FetchOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
FetchOpInfoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
FetchOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of fetch op");
AddOutput("Out", "The output of fetch op");
......
......@@ -52,7 +52,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
......@@ -60,13 +60,12 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
class FillConstantBatchSizeLikeOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
FillConstantBatchSizeLikeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
.SetDefault(framework::proto::DataType::FP32);
AddInput("Input",
"(Tensor) Tensor "
"whose dim_idx th dimension is used to specify the batch_size");
......
......@@ -34,7 +34,8 @@ class FillConstantOp : public framework::OperatorBase {
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto data_type = static_cast<framework::DataType>(Attr<int>("dtype"));
auto data_type =
static_cast<framework::proto::DataType>(Attr<int>("dtype"));
auto value = Attr<float>("value");
auto force_cpu = Attr<bool>("force_cpu");
auto &out =
......@@ -52,13 +53,12 @@ class FillConstantOp : public framework::OperatorBase {
class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillConstantOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
FillConstantOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
.SetDefault(framework::proto::DataType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f);
......
......@@ -48,7 +48,7 @@ class FillOp : public framework::OperatorBase {
"Cannot find variable %s", Output("Out"))
.GetMutable<framework::LoDTensor>());
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
auto dtype = static_cast<framework::DataType>(Attr<int>("dtype"));
auto dtype = static_cast<framework::proto::DataType>(Attr<int>("dtype"));
platform::CPUPlace cpu;
auto force_cpu = Attr<bool>("force_cpu");
out.mutable_data(force_cpu ? cpu : dev_ctx.GetPlace(),
......@@ -76,7 +76,7 @@ class FillOp : public framework::OperatorBase {
class FillOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
FillOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddComment(R"DOC(Fill operator
......@@ -88,7 +88,7 @@ Fill an tensor with `value` and `shape`. The type of the tensor is specify by
"value", "The float values of tensor, which are flatten in row major");
AddAttr<std::vector<int>>("shape", "The shape of output tensor");
AddAttr<int>("dtype", "The data type of output tensor, Default is float")
.SetDefault(framework::DataType::FP32);
.SetDefault(framework::proto::DataType::FP32);
AddAttr<bool>("force_cpu",
"Whether the output tensor must be at CPU memory or not. "
"Default is false.")
......
......@@ -33,8 +33,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FillZerosLikeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
FillZerosLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Y", "The variable will be filled up with zeros.");
......
......@@ -57,7 +57,7 @@ class FTRLOp : public framework::OperatorWithKernel {
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FTRLOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
FTRLOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
......
......@@ -67,7 +67,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GatherOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
GatherOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op");
......
......@@ -60,15 +60,14 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
ctx.device_context());
}
};
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GaussianRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
GaussianRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "Output matrix of gaussian random op");
......@@ -91,7 +90,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::DataType::FP32);
.SetDefault(framework::proto::DataType::FP32);
AddComment(R"DOC(
GaussianRandom Operator.
......
......@@ -67,7 +67,7 @@ class GRUOp : public framework::OperatorWithKernel {
class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
GRUOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(LoDTensor) The first input is a LodTensor, which supports "
......
......@@ -71,8 +71,7 @@ class GRUUnitOp : public framework::OperatorWithKernel {
class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUUnitOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
GRUUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(Tensor) Matrix with shape [batch_size, frame_size * 3] for the "
......
......@@ -46,8 +46,7 @@ class HingeLossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class HingeLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HingeLossOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
HingeLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Logits",
"The input value (Logits) of Hinge loss op."
......
......@@ -45,8 +45,7 @@ class HuberLossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HuberLossOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
HuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input value of huber loss op."
......
......@@ -70,8 +70,7 @@ class IncrementOp : public framework::OperatorBase {
class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IncrementOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
IncrementOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input tensor of increment operator");
AddOutput("Out", "(Tensor) The output tensor of increment operator.");
......
......@@ -47,8 +47,7 @@ class IsEmptyOp : public framework::OperatorBase {
class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
IsEmptyOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
IsEmptyOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kInput, "(Tensor) Tensor which is to be checked.");
AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not.");
......
......@@ -48,7 +48,7 @@ class L1NormGradOp : public framework::OperatorWithKernel {
class L1NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
L1NormOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
L1NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input of l1_norm op.");
AddOutput("Out", "(Scalar) The output of l1_norm op.");
......
......@@ -19,8 +19,7 @@ namespace operators {
class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LinearChainCRFOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
LinearChainCRFOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Emission",
"(LoDTensor, default LoDTensor<float>) "
......
......@@ -58,8 +58,7 @@ class LoadOp : public framework::OperatorBase {
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoadOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
LoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) The tensor need to be loaded");
AddAttr<std::string>("file_path",
......
......@@ -38,8 +38,7 @@ class LoDArrayLengthOp : public framework::OperatorBase {
class LoDArrayLengthProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDArrayLengthProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
LoDArrayLengthProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensorArray) The input tensor array.");
AddOutput("Out", "(Tensor) 1x1 CPU Tensor of length, int64_t");
......
......@@ -35,8 +35,7 @@ class LoDRankTableOp : public framework::OperatorBase {
class LoDRankTableOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDRankTableOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
LoDRankTableOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor) input lod tensor, must contain lod information.");
......@@ -67,7 +66,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference {
framework::BlockDescBind *block) const override {
for (auto &o : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o)->SetType(
framework::VarDesc::LOD_RANK_TABLE);
framework::proto::VarDesc::LOD_RANK_TABLE);
}
}
};
......
......@@ -48,8 +48,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDResetOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
LoDResetOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The input tensor of lod_reset operator.");
AddInput("TargetLoD",
......
......@@ -97,8 +97,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
class LoDTensorToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
LoDTensorToArrayOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
LoDTensorToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "");
AddInput("RankTable", "");
......@@ -131,7 +130,7 @@ class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDescBind &op_desc,
framework::BlockDescBind *block) const override {
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
block->Var(out_var)->SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
}
}
};
......
......@@ -46,8 +46,7 @@ class LogLossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class LogLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LogLossOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
LogLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Predicted",
"The input value (Predicted) of Log loss op."
......
......@@ -20,8 +20,7 @@ namespace operators {
template <typename OpComment>
class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
BinaryLogicalOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
BinaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment;
AddInput("X",
......@@ -45,8 +44,7 @@ Each element of Out is calculated by %s
template <typename OpComment>
class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
UnaryLogicalOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
UnaryLogicalOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
OpComment comment;
AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator",
......
......@@ -51,8 +51,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LookupTableOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W",
"An input represents embedding tensors, "
......@@ -117,11 +116,12 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
if (is_sparse) {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to SelectedRows";
block->Var(out_var_name)->SetType(framework::VarDesc::SELECTED_ROWS);
block->Var(out_var_name)
->SetType(framework::proto::VarDesc::SELECTED_ROWS);
} else {
VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
<< " is set to LoDTensor";
block->Var(out_var_name)->SetType(framework::VarDesc::LOD_TENSOR);
block->Var(out_var_name)->SetType(framework::proto::VarDesc::LOD_TENSOR);
}
}
};
......
......@@ -140,7 +140,7 @@ class LRNOp : public framework::OperatorWithKernel {
template <typename T>
class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LRNOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
LRNOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor) The input of LRN operator. "
......
......@@ -102,7 +102,7 @@ class LSTMOp : public framework::OperatorWithKernel {
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
LSTMOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
......
......@@ -48,8 +48,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LstmUnitOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
LstmUnitOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"Lstm unit only applies non-linear activations, please make sure"
......
......@@ -42,8 +42,7 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
template <typename T>
class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MarginRankLossOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
MarginRankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X1",
"(2-D tensor with shape [batch_size x 1]) The score for "
......
......@@ -130,7 +130,7 @@ class MatMulOp : public framework::OperatorWithKernel {
class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MatMulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
MatMulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of MatMul op");
AddInput("Y", "The second input of MatMul op");
......
......@@ -40,8 +40,7 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
class MaxSeqenceLenOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxSeqenceLenOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
MaxSeqenceLenOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("RankTable", "The lod_rank_table.");
AddOutput("Out", "The max sequence length.");
......
......@@ -20,7 +20,7 @@ using framework::Tensor;
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
MaxOutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
......
......@@ -32,7 +32,7 @@ class MeanOp : public framework::OperatorWithKernel {
class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
MeanOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op");
......
......@@ -114,8 +114,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
MergeLoDTensorOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
MergeLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input LoDTensor, contains complete lod information to "
......
......@@ -46,7 +46,7 @@ class MinusOp : public framework::OperatorWithKernel {
class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left tensor of minus operator.");
AddInput("Y", "The right tensor of minus operator.");
......
......@@ -39,8 +39,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ModifiedHuberLossOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
ModifiedHuberLossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input tensor of modified huber loss op. "
......
......@@ -54,8 +54,7 @@ class MomentumOp : public framework::OperatorWithKernel {
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MomentumOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
MomentumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
......
......@@ -71,7 +71,7 @@ class MulOpShapeInference : public framework::InferShapeBase {
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
MulOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
......
......@@ -61,8 +61,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MultiplexOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
MultiplexOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ids", "The index tensor of multiplex operator.");
AddInput("X", "The candidate tensors of multiplex operator.")
......
......@@ -35,8 +35,8 @@ Here we give some examples to show how these rules will be used.
```c++
class AccumulateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AccumulateOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
AccumulateOpMaker(OpProto *proto,
OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input tensor that has to be accumulated to the output tensor.
If the output size is not the same as input size,
......
......@@ -43,8 +43,7 @@ class NCCLInitOp : public framework::OperatorBase {
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLInitOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
......@@ -52,7 +51,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
.SetDefault(framework::proto::DataType::FP32);
AddComment(R"DOC(
NCCLInit Operator.
......@@ -141,8 +140,7 @@ class NCCLBcastOp : public framework::OperatorWithKernel {
// AllreduceOp
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
......@@ -163,8 +161,7 @@ AllReduce the input tensors.
// ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLReduceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
NCCLReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of Reduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
......@@ -190,8 +187,7 @@ Reduce the tensors.
// BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLBcastOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
NCCLBcastOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op");
AddInput("Communicator", "Communicator for communicating between gpus");
......
......@@ -73,7 +73,7 @@ class NCEOp : public framework::OperatorWithKernel {
class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
NCEOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
AddInput(
......
......@@ -48,7 +48,7 @@ class PadOp : public framework::OperatorWithKernel {
class PadOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PadOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
PadOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input of pad op. "
......
......@@ -67,8 +67,7 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
......@@ -136,8 +135,7 @@ Example:
)DOC");
}
Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
Pool3dOpMaker::Pool3dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
......
......@@ -40,14 +40,12 @@ class PoolOpGrad : public framework::OperatorWithKernel {
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool3dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker);
Pool3dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};
template <typename DeviceContext, typename T>
......
......@@ -100,8 +100,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool2dWithIndexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
MaxPool2dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
......@@ -178,8 +177,7 @@ Example:
class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool3dWithIndexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
MaxPool3dWithIndexOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
......
......@@ -95,8 +95,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PositiveNegativePairOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
PositiveNegativePairOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Score",
"(Tensor, float) Model Score on an item (with "
......
......@@ -90,8 +90,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PrecisionRecallOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
PrecisionRecallOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("MaxProbs",
"(Tensor, default Tensor<float>) A 2-D tensor with shape N x 1, "
......
......@@ -38,7 +38,7 @@ class PReluOp : public framework::OperatorWithKernel {
class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
PReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of prelu operator.");
AddInput("Alpha", "The alpha weight of prelu operator.");
......
......@@ -59,8 +59,7 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ProximalAdagradOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ProximalAdagradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
......
......@@ -47,8 +47,7 @@ class ProximalGDOp : public framework::OperatorWithKernel {
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ProximalGDOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ProximalGDOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
......
......@@ -45,8 +45,7 @@ class RankLossOp : public framework::OperatorWithKernel {
class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RankLossOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
RankLossOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Label",
"(2-D Tensor with shape [batch_size x 1]) "
......
......@@ -497,8 +497,7 @@ class RecurrentGradOp : public RecurrentBase {
class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
RecurrentOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
RecurrentOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kInputs, "rnn inputs").AsDuplicable();
AddInput(kInitialStates, "rnn initial states").AsDuplicable();
......
......@@ -97,7 +97,7 @@ class RecvOp : public framework::OperatorBase {
class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RecvOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
RecvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("RX", "(Tensor) Input tensor to be saved");
AddComment(R"DOC(
......
......@@ -83,7 +83,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
ReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor) The input tensor. Tensors with rank at most 6 are "
......@@ -135,8 +135,7 @@ If reduce_all is true, just reduce along all dimensions and output a scalar.
class ReduceSumOpMaker : public ReduceOpMaker {
public:
ReduceSumOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ReduceSumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceSum", "sum");
AddComment(comment_);
......@@ -145,8 +144,7 @@ class ReduceSumOpMaker : public ReduceOpMaker {
class ReduceMeanOpMaker : public ReduceOpMaker {
public:
ReduceMeanOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ReduceMeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMean", "mean");
AddComment(comment_);
......@@ -155,8 +153,7 @@ class ReduceMeanOpMaker : public ReduceOpMaker {
class ReduceMaxOpMaker : public ReduceOpMaker {
public:
ReduceMaxOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ReduceMaxOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMax", "max");
AddComment(comment_);
......@@ -165,8 +162,7 @@ class ReduceMaxOpMaker : public ReduceOpMaker {
class ReduceMinOpMaker : public ReduceOpMaker {
public:
ReduceMinOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ReduceMinOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: ReduceOpMaker(proto, op_checker) {
SetComment("ReduceMin", "min");
AddComment(comment_);
......
......@@ -77,8 +77,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReshapeOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of reshape operator.");
AddOutput("Out", "The output tensor of reshape operator.");
......
......@@ -63,8 +63,7 @@ class RmspropOp : public framework::OperatorWithKernel {
class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RmspropOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
RmspropOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
......
......@@ -57,15 +57,14 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase {
class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
RNNMemoryHelperOpInfoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
RNNMemoryHelperOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "");
AddOutput("Out", "");
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
.SetDefault(framework::proto::DataType::FP32);
AddComment("");
}
};
......@@ -114,8 +113,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
class RNNMemoryHelperGradOpInfoMaker
: public framework::OpProtoAndCheckerMaker {
public:
RNNMemoryHelperGradOpInfoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
RNNMemoryHelperGradOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(framework::GradVarName("Out"), "");
AddInput("X", "");
......@@ -124,7 +122,7 @@ class RNNMemoryHelperGradOpInfoMaker
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::DataType::FP32);
.SetDefault(framework::proto::DataType::FP32);
AddComment("");
}
};
......
......@@ -99,8 +99,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ROIPoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
ROIPoolOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor), "
......
......@@ -76,8 +76,7 @@ class RowConvGradOp : public framework::OperatorWithKernel {
class RowConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RowConvOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
RowConvOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor), the input(X) is a LodTensor, which supports "
......
......@@ -94,8 +94,7 @@ class SaveOp : public framework::OperatorBase {
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SaveOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
SaveOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor ) Input tensor to be saved");
AddComment(R"DOC(
......
......@@ -38,7 +38,7 @@ class ScaleOp : public framework::OperatorWithKernel {
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor of scale operator.");
AddOutput("Out", "(Tensor) Output tensor of scale operator.");
......
......@@ -78,8 +78,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScatterOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Ref", "The source input of scatter op");
AddInput("Index",
......
......@@ -59,7 +59,7 @@ class SendOp : public framework::OperatorBase {
class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
SendOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor to be saved");
AddOutput("Out", "(Tensor) Output fetched from server");
......
......@@ -43,8 +43,7 @@ class SequenceConcatOp : public framework::OperatorWithKernel {
class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConcatOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SequenceConcatOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LodTensorArray) Input is a vector of LoDTensor, "
......
......@@ -100,8 +100,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceConvOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SequenceConvOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
......
......@@ -37,8 +37,7 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceExpandOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor or LoDTensor) The input(X) of this operator can be a "
......
......@@ -37,8 +37,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequencePoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SequencePoolOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The variable-length input of SequencePoolOp");
AddOutput("Out",
......
......@@ -79,8 +79,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
class SequenceSliceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSliceOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SequenceSliceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor), "
......
......@@ -33,8 +33,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSoftmaxOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SequenceSoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
......
......@@ -43,7 +43,7 @@ class SGDOp : public framework::OperatorWithKernel {
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
......
......@@ -54,8 +54,7 @@ class ShrinkRNNMemoryOp : public ArrayOp {
class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ShrinkRNNMemoryOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The RNN step memory to be shrinked.");
AddInput("RankTable", "(LoDRankTable) The lod_rank_table of dynamic RNN.");
......
......@@ -86,8 +86,8 @@ class SigmoidCrossEntropyWithLogitsGradOp
class SigmoidCrossEntropyWithLogitsOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
SigmoidCrossEntropyWithLogitsOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SigmoidCrossEntropyWithLogitsOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
......
......@@ -34,7 +34,7 @@ class SignOp : public framework::OperatorWithKernel {
template <typename AttrType>
class SignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SignOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
SignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor of sign operator.");
AddOutput("Out", "(Tensor) Output tensor of sign operator.");
......
......@@ -47,8 +47,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
template <typename AttrType>
class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SmoothL1LossOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SmoothL1LossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor, default Tensor<float>) A tensor with rank at least 2. "
......
......@@ -36,8 +36,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftmaxOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input tensor of softmax. "
......
......@@ -20,8 +20,7 @@ namespace operators {
class SoftmaxWithCrossEntropyOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SoftmaxWithCrossEntropyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
......
......@@ -118,8 +118,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitLoDTensorOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
SplitLoDTensorOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input LoDTensor");
AddInput("Mask", "A bool column vector which mask the input");
......
......@@ -65,7 +65,7 @@ class SplitOp : public framework::OperatorWithKernel {
class SplitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SplitOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
SplitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor of the split operator.");
AddOutput("Out", "(Tensor) Output tensors of the split operator.")
......
......@@ -18,7 +18,7 @@ namespace operators {
class SppOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SppOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
SppOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
......
......@@ -56,8 +56,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquaredL2DistanceOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SquaredL2DistanceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input of SquaredL2DistanceOp.");
AddInput("Y", "(Tensor) Target of SquaredL2DistanceOp.");
......
......@@ -48,8 +48,7 @@ class SquaredL2NormGradOp : public framework::OperatorWithKernel {
class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquaredL2NormOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
SquaredL2NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input of squared_l2_norm op.");
AddOutput("Out", "(Scalar) The output of squared_l2_norm op.");
......
......@@ -29,7 +29,7 @@ class SumOp : public framework::OperatorWithKernel {
"Output(Out) of SumOp should not be null.");
if (ctx->IsRuntime() &&
ctx->GetOutputsVarType("Out")[0] ==
framework::VarDesc::LOD_TENSOR_ARRAY) {
framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
return; // skip runtime infershape when is tensor array;
}
......@@ -72,8 +72,8 @@ class SumOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NE(dtype, -1,
"Sum operator should have at least one tensor");
return framework::OpKernelType(static_cast<framework::DataType>(dtype),
ctx.device_context());
return framework::OpKernelType(
static_cast<framework::proto::DataType>(dtype), ctx.device_context());
} else if (x_vars[0]->IsType<framework::SelectedRows>()) {
return framework::OpKernelType(
framework::ToDataType(
......@@ -98,7 +98,7 @@ class SumOp : public framework::OperatorWithKernel {
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
SumOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable();
......@@ -118,7 +118,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
void operator()(const framework::OpDescBind& op_desc,
framework::BlockDescBind* block) const override {
auto& inputs = op_desc.Input("X");
auto var_type = framework::VarDesc::SELECTED_ROWS;
auto var_type = framework::proto::VarDesc::SELECTED_ROWS;
for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " "
......@@ -128,12 +128,12 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name)->GetType() ==
framework::VarDesc::LOD_TENSOR;
framework::proto::VarDesc::LOD_TENSOR;
});
auto is_tensor_array = [block](const std::string& name) {
return detail::Ref(block->FindRecursiveOrCreateVar(name)).GetType() ==
framework::VarDesc::LOD_TENSOR_ARRAY;
framework::proto::VarDesc::LOD_TENSOR_ARRAY;
};
bool any_input_is_tensor_array =
......@@ -152,9 +152,9 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
PADDLE_ENFORCE(all_inputs_are_tensor_array,
"Not all inputs are tensor array:\n%s", os.str());
}
var_type = framework::VarDesc::LOD_TENSOR_ARRAY;
var_type = framework::proto::VarDesc::LOD_TENSOR_ARRAY;
} else if (any_input_is_lod_tensor) {
var_type = framework::VarDesc::LOD_TENSOR;
var_type = framework::proto::VarDesc::LOD_TENSOR;
}
auto out_var_name = op_desc.Output("Out").front();
......
......@@ -51,8 +51,7 @@ class WriteToArrayOp : public ArrayOp {
class WriteToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
WriteToArrayOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
WriteToArrayOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) the tensor will be written to tensor array");
AddInput(
......@@ -104,7 +103,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
auto &out = detail::Ref(block->FindRecursiveOrCreateVar(out_name),
"Cannot found %s", out_name);
out.SetType(framework::VarDesc::LOD_TENSOR_ARRAY);
out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
auto *x = block->FindVarRecursive(x_name);
if (x != nullptr) {
out.SetDataType(x->GetDataType());
......@@ -140,8 +139,7 @@ class ReadFromArrayOp : public ArrayOp {
class ReadFromArrayProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ReadFromArrayProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
ReadFromArrayProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(TensorArray) the array will be read from.");
AddInput("I",
......
......@@ -46,7 +46,7 @@ class TopkOp : public framework::OperatorWithKernel {
class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TopkOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
TopkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) The input of Topk op");
AddOutput("Out", "(Tensor) The output tensor of Topk op");
......
......@@ -55,8 +55,7 @@ class TransposeOp : public framework::OperatorWithKernel {
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TransposeOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
TransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
......
......@@ -66,15 +66,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
UniformRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
UniformRandomOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "(Tensor) The output tensor of uniform random op");
AddComment(R"DOC(
......@@ -100,7 +99,7 @@ uniform distribution.
"0 means use a seed generated by the system.")
.SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::DataType::FP32);
.SetDefault(framework::proto::DataType::FP32);
}
};
} // namespace operators
......
......@@ -18,8 +18,7 @@ namespace operators {
class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Unpool2dOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
Unpool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
......
......@@ -64,7 +64,7 @@ class WhileOp : public framework::OperatorBase {
class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
public:
WhileOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
WhileOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kParameters,
"A set of variables, which are required by operators inside the "
......@@ -321,10 +321,10 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
continue;
}
auto dims = ctx->GetInputsElementDim(kParameters, i);
if (var_types[i] == framework::VarDesc::LOD_TENSOR) {
if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) {
names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims);
} else if (var_types[i] == framework::VarDesc::LOD_TENSOR_ARRAY) {
} else if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR_ARRAY) {
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims);
......
......@@ -31,31 +31,32 @@ std::string Escape(const std::string& s) {
return r;
}
std::string AttrType(paddle::framework::AttrType at) {
std::string AttrType(paddle::framework::proto::AttrType at) {
switch (at) {
case paddle::framework::INT:
case paddle::framework::proto::INT:
return "int";
case paddle::framework::FLOAT:
case paddle::framework::proto::FLOAT:
return "float";
case paddle::framework::STRING:
case paddle::framework::proto::STRING:
return "string";
case paddle::framework::BOOLEAN:
case paddle::framework::proto::BOOLEAN:
return "bool";
case paddle::framework::INTS:
case paddle::framework::proto::INTS:
return "int array";
case paddle::framework::FLOATS:
case paddle::framework::proto::FLOATS:
return "float array";
case paddle::framework::STRINGS:
case paddle::framework::proto::STRINGS:
return "string array";
case paddle::framework::BOOLEANS:
case paddle::framework::proto::BOOLEANS:
return "bool array";
case paddle::framework::BLOCK:
case paddle::framework::proto::BLOCK:
return "block id";
}
return "UNKNOWN"; // not possible
}
void PrintVar(const paddle::framework::OpProto::Var& v, std::stringstream& ss) {
void PrintVar(const paddle::framework::proto::OpProto::Var& v,
std::stringstream& ss) {
ss << " { "
<< "\n"
<< " \"name\" : \"" << Escape(v.name()) << "\",\n"
......@@ -65,7 +66,7 @@ void PrintVar(const paddle::framework::OpProto::Var& v, std::stringstream& ss) {
<< " },";
}
void PrintAttr(const paddle::framework::OpProto::Attr& a,
void PrintAttr(const paddle::framework::proto::OpProto::Attr& a,
std::stringstream& ss) {
ss << " { "
<< "\n"
......@@ -81,7 +82,7 @@ void PrintOpProto(const std::string& type,
std::stringstream& ss) {
std::cerr << "Processing " << type << "\n";
const paddle::framework::OpProto* p = opinfo.proto_;
const paddle::framework::proto::OpProto* p = opinfo.proto_;
if (p == nullptr) {
return; // It is possible that an operator doesn't have OpProto.
}
......
......@@ -144,7 +144,7 @@ void BindProgramDesc(py::module &m) {
.def("serialize_to_string", SerializeMessage<ProgramDescBind>)
.def("parse_from_string",
[](ProgramDescBind &program_desc, const std::string &data) {
ProgramDesc *desc = program_desc.Proto();
proto::ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->ParseFromString(data),
"Fail to parse ProgramDesc from string. This could "
"be a bug of Paddle.");
......@@ -184,14 +184,14 @@ void BindBlockDesc(py::module &m) {
}
void BindVarDsec(py::module &m) {
py::enum_<DataType>(m, "DataType", "")
.value("BOOL", DataType::BOOL)
.value("INT16", DataType::INT16)
.value("INT32", DataType::INT32)
.value("INT64", DataType::INT64)
.value("FP16", DataType::FP16)
.value("FP32", DataType::FP32)
.value("FP64", DataType::FP64);
py::enum_<proto::DataType>(m, "DataType", "")
.value("BOOL", proto::DataType::BOOL)
.value("INT16", proto::DataType::INT16)
.value("INT32", proto::DataType::INT32)
.value("INT64", proto::DataType::INT64)
.value("FP16", proto::DataType::FP16)
.value("FP32", proto::DataType::FP32)
.value("FP64", proto::DataType::FP64);
py::class_<VarDescBind> var_desc(m, "VarDesc", "");
var_desc
......@@ -213,27 +213,27 @@ void BindVarDsec(py::module &m) {
.def("persistable", &VarDescBind::Persistable)
.def("set_persistable", &VarDescBind::SetPersistable);
py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", VarDesc::LOD_TENSOR)
.value("SELECTED_ROWS", VarDesc::SELECTED_ROWS)
.value("FEED_MINIBATCH", VarDesc::FEED_MINIBATCH)
.value("FETCH_LIST", VarDesc::FETCH_LIST)
.value("STEP_SCOPES", VarDesc::STEP_SCOPES)
.value("LOD_RANK_TABLE", VarDesc::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", VarDesc::LOD_TENSOR_ARRAY);
py::enum_<proto::VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", proto::VarDesc::LOD_TENSOR)
.value("SELECTED_ROWS", proto::VarDesc::SELECTED_ROWS)
.value("FEED_MINIBATCH", proto::VarDesc::FEED_MINIBATCH)
.value("FETCH_LIST", proto::VarDesc::FETCH_LIST)
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY);
}
void BindOpDesc(py::module &m) {
py::enum_<AttrType>(m, "AttrType", "")
.value("INT", AttrType::INT)
.value("INTS", AttrType::INTS)
.value("FLOAT", AttrType::FLOAT)
.value("FLOATS", AttrType::FLOATS)
.value("STRING", AttrType::STRING)
.value("STRINGS", AttrType::STRINGS)
.value("BOOL", AttrType::BOOLEAN)
.value("BOOLS", AttrType::BOOLEANS)
.value("BLOCK", AttrType::BLOCK);
py::enum_<proto::AttrType>(m, "AttrType", "")
.value("INT", proto::AttrType::INT)
.value("INTS", proto::AttrType::INTS)
.value("FLOAT", proto::AttrType::FLOAT)
.value("FLOATS", proto::AttrType::FLOATS)
.value("STRING", proto::AttrType::STRING)
.value("STRINGS", proto::AttrType::STRINGS)
.value("BOOL", proto::AttrType::BOOLEAN)
.value("BOOLS", proto::AttrType::BOOLEANS)
.value("BLOCK", proto::AttrType::BLOCK);
py::class_<OpDescBind> op_desc(m, "OpDesc", "");
op_desc.def("type", &OpDescBind::Type)
......
......@@ -288,12 +288,12 @@ All parameter, weight, gradient are variables in Paddle.
for (const auto &t : targets) {
prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
}
ProgramDesc pruned_desc;
proto::ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), &pruned_desc);
return new ProgramDescBind(pruned_desc);
});
m.def("inference_optimize", [](ProgramDescBind &origin) {
ProgramDesc pruned_desc;
proto::ProgramDesc pruned_desc;
InferenceOptimize(*(origin.Proto()), &pruned_desc);
return new ProgramDescBind(pruned_desc);
});
......@@ -345,7 +345,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<OperatorBase>(m, "Operator")
.def_static("create",
[](py::bytes protobin) {
OpDesc desc;
proto::OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
......@@ -398,7 +398,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
.def_static("create",
[](py::bytes protobin) -> operators::CondOp * {
OpDesc desc;
proto::OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册