提交 b567ce7a 编写于 作者: A Abhinav Arora

Merge branch 'refine_pod' of github.com:abhinavarora/Paddle into refine_pod

......@@ -36,7 +36,7 @@ class AssignValueOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::DataType(ctx.Attr<int>("dtype")), ctx.GetPlace());
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), ctx.GetPlace());
}
};
......@@ -49,8 +49,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int>) "
"Shape of values.");
AddAttr<int>("dtype", "data type of values")
.InEnum({framework::proto::DataType::INT32,
framework::proto::DataType::FP32});
.InEnum({framework::proto::VarType::Type::INT32,
framework::proto::VarType::Type::FP32});
AddAttr<std::vector<float>>("fp32_values", "store the float values")
.SetDefault({});
AddAttr<std::vector<int>>("int32_values", "store the int values")
......
......@@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel<T> {
int dtype = ctx.Attr<int>("dtype");
const char* value_name = nullptr;
switch (dtype) {
case framework::proto::DataType::INT32:
case framework::proto::VarType::Type::INT32:
value_name = "int32_values";
break;
case framework::proto::DataType::FP32:
case framework::proto::VarType::Type::FP32:
value_name = "fp32_values";
break;
default:
......
......@@ -195,14 +195,14 @@ void BindBlockDesc(py::module &m) {
}
void BindVarDsec(py::module &m) {
py::enum_<proto::DataType>(m, "DataType", "")
.value("BOOL", proto::DataType::BOOL)
.value("INT16", proto::DataType::INT16)
.value("INT32", proto::DataType::INT32)
.value("INT64", proto::DataType::INT64)
.value("FP16", proto::DataType::FP16)
.value("FP32", proto::DataType::FP32)
.value("FP64", proto::DataType::FP64);
py::enum_<proto::VarType::Type>(m, "DataType", "")
.value("BOOL", proto::VarType::Type::BOOL)
.value("INT16", proto::VarType::Type::INT16)
.value("INT32", proto::VarType::Type::INT32)
.value("INT64", proto::VarType::Type::INT64)
.value("FP16", proto::VarType::Type::FP16)
.value("FP32", proto::VarType::Type::FP32)
.value("FP64", proto::VarType::Type::FP64);
py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册