提交 f4df0cb1 编写于 作者: Q Qiao Longfei

update the type of shape to int64, format code

上级 fad42fe7
...@@ -24,7 +24,7 @@ class FakeInitInferShape : public framework::InferShapeBase { ...@@ -24,7 +24,7 @@ class FakeInitInferShape : public framework::InferShapeBase {
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeInitOp should not be null."); "Output(Out) of FakeInitOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
ctx->SetOutputDim("Out", framework::make_ddim(shape)); ctx->SetOutputDim("Out", framework::make_ddim(shape));
} }
}; };
...@@ -42,10 +42,10 @@ class FakeInitOp : public framework::OperatorBase { ...@@ -42,10 +42,10 @@ class FakeInitOp : public framework::OperatorBase {
if (out_var.IsType<framework::LoDTensor>()) { if (out_var.IsType<framework::LoDTensor>()) {
tensor = out_var.GetMutable<framework::LoDTensor>(); tensor = out_var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape"))); tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else if (out_var.IsType<framework::SelectedRows>()) { } else if (out_var.IsType<framework::SelectedRows>()) {
tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value(); tensor = out_var.GetMutable<framework::SelectedRows>()->mutable_value();
tensor->Resize(framework::make_ddim(Attr<std::vector<int>>("shape"))); tensor->Resize(framework::make_ddim(Attr<std::vector<int64_t>>("shape")));
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"fake init op's output only" "fake init op's output only"
...@@ -63,7 +63,8 @@ class FakeInitOpVarTypeInference : public framework::VarTypeInference { ...@@ -63,7 +63,8 @@ class FakeInitOpVarTypeInference : public framework::VarTypeInference {
class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker { class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output"); AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output");
AddOutput("Out", AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled " "(Tensor) Tensor of specified shape will be filled "
"with the specified value"); "with the specified value");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册