提交 30b50dcf 编写于 作者: T tangwei12

fix Serial output type

上级 2e25e739
......@@ -87,7 +87,7 @@ class CheckpointSaveOp : public framework::OperatorBase {
std::string *serial_num = serial_var->GetMutable<std::string>();
serial_num->append("0");
dir.append("/");
dir.append(serial_num);
dir.append(serial_num->c_str());
MkDirRecursively(dir.c_str());
auto inp_var_names = Inputs("X");
......@@ -159,10 +159,29 @@ to a file on disk.
}
};
class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Serial").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class CheckpointSaveOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
ops::CheckpointSaveOpProtoMaker);
REGISTER_OPERATOR(send_vars, ops::CheckpointSaveOp,
paddle::framework::EmptyGradOpMaker,
ops::CheckpointSaveOpProtoMaker,
ops::CheckpointSaveOpVarTypeInference,
ops::CheckpointSaveOpShapeInference);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册