提交 755927d2 编写于 作者: T tangwei12

shape type to int64_t, test=develop

上级 8b7f45a8
...@@ -25,17 +25,17 @@ message Version { optional int64 version = 1 [ default = 0 ]; } ...@@ -25,17 +25,17 @@ message Version { optional int64 version = 1 [ default = 0 ]; }
enum AttrType { enum AttrType {
INT = 0; INT = 0;
FLOAT = 1; LONG = 1;
STRING = 2; FLOAT = 2;
INTS = 3; STRING = 3;
FLOATS = 4; INTS = 4;
STRINGS = 5; LONGS = 5;
BOOLEAN = 6; FLOATS = 6;
BOOLEANS = 7; STRINGS = 7;
BLOCK = 8; BOOLEAN = 8;
LONG = 9; BOOLEANS = 9;
BLOCKS = 10; BLOCK = 10;
LONGS = 11; BLOCKS = 11;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
...@@ -46,17 +46,17 @@ message OpDesc { ...@@ -46,17 +46,17 @@ message OpDesc {
required string name = 1; required string name = 1;
required AttrType type = 2; required AttrType type = 2;
optional int32 i = 3; optional int32 i = 3;
optional float f = 4; optional int64 l = 4;
optional string s = 5; optional float f = 5;
repeated int32 ints = 6; optional string s = 6;
repeated float floats = 7; repeated int32 ints = 7;
repeated string strings = 8; repeated int64 longs = 8;
optional bool b = 10; repeated float floats = 9;
repeated bool bools = 11; repeated string strings = 10;
optional int32 block_idx = 12; optional bool b = 11;
optional int64 l = 13; repeated bool bools = 12;
optional int32 block_idx = 13;
repeated int32 blocks_idx = 14; repeated int32 blocks_idx = 14;
repeated int64 longs = 15;
}; };
message Var { message Var {
......
...@@ -415,11 +415,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -415,11 +415,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(const std::vector<BlockDesc *> &v) const { void operator()(const std::vector<BlockDesc *> &v) const {
std::vector<int> blocks_idx; std::vector<int> blocks_idx;
for (auto blk : v) { for (auto blk : v) {
blocks_idx.push_back(blk->ID()); blocks_idx.push_back(blk->ID());
} }
VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx()); VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
} }
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(BlockDesc *desc) const {
attr_->set_block_idx(desc->ID());
}
void operator()(int64_t v) const { attr_->set_l(v); } void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(const std::vector<int64_t> &v) const { void operator()(const std::vector<int64_t> &v) const {
......
...@@ -33,10 +33,10 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>; ...@@ -33,10 +33,10 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto // The order should be as same as framework.proto
using Attribute = using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, int64_t, float, std::string,
std::vector<float>, std::vector<std::string>, bool, std::vector<int>, std::vector<int64_t>, std::vector<float>,
std::vector<bool>, BlockDesc*, int64_t, std::vector<std::string>, bool, std::vector<bool>,
std::vector<BlockDesc*>, std::vector<int64_t>>; BlockDesc*, std::vector<BlockDesc*>>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
......
...@@ -24,7 +24,7 @@ class FillConstantInferShape : public framework::InferShapeBase { ...@@ -24,7 +24,7 @@ class FillConstantInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillConstantOp should not be null."); "Output(Out) of FillConstantOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
ctx->SetOutputDim("Out", framework::make_ddim(shape)); ctx->SetOutputDim("Out", framework::make_ddim(shape));
} }
}; };
...@@ -47,10 +47,10 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -47,10 +47,10 @@ class FillConstantOp : public framework::OperatorBase {
if (out_var.IsType<framework::LoDTensor>()) { if (out_var.IsType<framework::LoDTensor>()) {
tensor = out_var.GetMutable<framework::LoDTensor>(); tensor = out_var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape"))); tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else if (out_var.IsType<framework::SelectedRows>()) { } else if (out_var.IsType<framework::SelectedRows>()) {
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value(); tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape"))); tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"fill constant op's output only" "fill constant op's output only"
...@@ -83,7 +83,8 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -83,7 +83,8 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::proto::VarType::FP32); .SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output"); AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output");
AddAttr<float>("value", "(float, default 0) The value to be filled") AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f); .SetDefault(0.0f);
AddAttr<bool>("force_cpu", AddAttr<bool>("force_cpu",
......
...@@ -52,7 +52,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -52,7 +52,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GaussianRandomOp should not be null."); "Output(Out) of GaussianRandomOp should not be null.");
auto shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
std::vector<int64_t> temp; std::vector<int64_t> temp;
temp.reserve(shape.size()); temp.reserve(shape.size());
for (auto dim : shape) { for (auto dim : shape) {
...@@ -88,8 +88,8 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -88,8 +88,8 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddOutput("Out", "Output matrix of gaussian random op"); AddOutput("Out", "Output matrix of gaussian random op");
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int64_t>>("shape",
"(vector<int>) " "(vector<int64_t>) "
"The dimension of random tensor."); "The dimension of random tensor.");
AddAttr<float>("mean", AddAttr<float>("mean",
"(float, default 0.0) " "(float, default 0.0) "
......
...@@ -29,7 +29,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -29,7 +29,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
auto shape = ctx.Attr<std::vector<int>>("shape"); auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto *selected_rows = out_var->GetMutable<framework::SelectedRows>(); auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
tensor = selected_rows->mutable_value(); tensor = selected_rows->mutable_value();
tensor->Resize(framework::make_ddim(shape)); tensor->Resize(framework::make_ddim(shape));
...@@ -67,7 +67,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -67,7 +67,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"), ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
"uniform_random's min must less then max"); "uniform_random's min must less then max");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
std::vector<int64_t> temp; std::vector<int64_t> temp;
temp.reserve(shape.size()); temp.reserve(shape.size());
for (auto dim : shape) { for (auto dim : shape) {
...@@ -94,7 +94,7 @@ This operator initializes a tensor with random values sampled from a ...@@ -94,7 +94,7 @@ This operator initializes a tensor with random values sampled from a
uniform distribution. The random result is in set [min, max]. uniform distribution. The random result is in set [min, max].
)DOC"); )DOC");
AddAttr<std::vector<int>>("shape", "The shape of the output tensor"); AddAttr<std::vector<int64_t>>("shape", "The shape of the output tensor");
AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].") AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].")
.SetDefault(-1.0f); .SetDefault(-1.0f);
AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].") AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].")
......
...@@ -259,8 +259,8 @@ void BindOpDesc(pybind11::module *m) { ...@@ -259,8 +259,8 @@ void BindOpDesc(pybind11::module *m) {
pybind11::enum_<pd::proto::AttrType>(*m, "AttrType", "") pybind11::enum_<pd::proto::AttrType>(*m, "AttrType", "")
.value("INT", pd::proto::AttrType::INT) .value("INT", pd::proto::AttrType::INT)
.value("INTS", pd::proto::AttrType::INTS) .value("INTS", pd::proto::AttrType::INTS)
.value("LONG", pd::proto::AttrType::FLOAT) .value("LONG", pd::proto::AttrType::LONG)
.value("LONGS", pd::proto::AttrType::FLOAT) .value("LONGS", pd::proto::AttrType::LONGS)
.value("FLOAT", pd::proto::AttrType::FLOAT) .value("FLOAT", pd::proto::AttrType::FLOAT)
.value("FLOATS", pd::proto::AttrType::FLOATS) .value("FLOATS", pd::proto::AttrType::FLOATS)
.value("STRING", pd::proto::AttrType::STRING) .value("STRING", pd::proto::AttrType::STRING)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册