提交 b567ce7a 编写于 作者: A Abhinav Arora

Merge branch 'refine_pod' of github.com:abhinavarora/Paddle into refine_pod

...@@ -36,7 +36,7 @@ class AssignValueOp : public framework::OperatorWithKernel { ...@@ -36,7 +36,7 @@ class AssignValueOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::proto::DataType(ctx.Attr<int>("dtype")), ctx.GetPlace()); framework::proto::VarType::Type(ctx.Attr<int>("dtype")), ctx.GetPlace());
} }
}; };
...@@ -49,8 +49,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,8 +49,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int>) " "(vector<int>) "
"Shape of values."); "Shape of values.");
AddAttr<int>("dtype", "data type of values") AddAttr<int>("dtype", "data type of values")
.InEnum({framework::proto::DataType::INT32, .InEnum({framework::proto::VarType::Type::INT32,
framework::proto::DataType::FP32}); framework::proto::VarType::Type::FP32});
AddAttr<std::vector<float>>("fp32_values", "store the float values") AddAttr<std::vector<float>>("fp32_values", "store the float values")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int>>("int32_values", "store the int values") AddAttr<std::vector<int>>("int32_values", "store the int values")
......
...@@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel<T> { ...@@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel<T> {
int dtype = ctx.Attr<int>("dtype"); int dtype = ctx.Attr<int>("dtype");
const char* value_name = nullptr; const char* value_name = nullptr;
switch (dtype) { switch (dtype) {
case framework::proto::DataType::INT32: case framework::proto::VarType::Type::INT32:
value_name = "int32_values"; value_name = "int32_values";
break; break;
case framework::proto::DataType::FP32: case framework::proto::VarType::Type::FP32:
value_name = "fp32_values"; value_name = "fp32_values";
break; break;
default: default:
......
...@@ -195,14 +195,14 @@ void BindBlockDesc(py::module &m) { ...@@ -195,14 +195,14 @@ void BindBlockDesc(py::module &m) {
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) {
py::enum_<proto::DataType>(m, "DataType", "") py::enum_<proto::VarType::Type>(m, "DataType", "")
.value("BOOL", proto::DataType::BOOL) .value("BOOL", proto::VarType::Type::BOOL)
.value("INT16", proto::DataType::INT16) .value("INT16", proto::VarType::Type::INT16)
.value("INT32", proto::DataType::INT32) .value("INT32", proto::VarType::Type::INT32)
.value("INT64", proto::DataType::INT64) .value("INT64", proto::VarType::Type::INT64)
.value("FP16", proto::DataType::FP16) .value("FP16", proto::VarType::Type::FP16)
.value("FP32", proto::DataType::FP32) .value("FP32", proto::VarType::Type::FP32)
.value("FP64", proto::DataType::FP64); .value("FP64", proto::VarType::Type::FP64);
py::class_<VarDesc> var_desc(m, "VarDesc", ""); py::class_<VarDesc> var_desc(m, "VarDesc", "");
var_desc var_desc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册