提交 f9f910a3 编写于 作者: Y Yu Yang

Complete op

上级 1cd20140
......@@ -14,8 +14,72 @@ limitations under the License. */
#include "paddle/pybind/protobuf.h"
#include <deque>
#include <iostream>
#include "paddle/framework/attribute.h"
// Cast boost::variant for PyBind.
// Copy from
// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
namespace pybind11 {
namespace detail {
// Can be replaced by a generic lambda in C++14
struct variant_caster_visitor : public boost::static_visitor<handle> {
return_value_policy policy;
handle parent;
variant_caster_visitor(return_value_policy policy, handle parent)
: policy(policy), parent(parent) {}
template <class T>
handle operator()(T const &src) const {
return make_caster<T>::cast(src, policy, parent);
}
};
template <class Variant>
struct variant_caster;
template <template <class...> class V, class... Ts>
struct variant_caster<V<Ts...>> {
using Type = V<Ts...>;
template <class T>
bool try_load(handle src, bool convert) {
auto caster = make_caster<T>();
if (!load_success_ && caster.load(src, convert)) {
load_success_ = true;
value = cast_op<T>(caster);
return true;
}
return false;
}
bool load(handle src, bool convert) {
auto unused = {false, try_load<Ts>(src, convert)...};
(void)(unused);
return load_success_;
}
static handle cast(Type const &src,
return_value_policy policy,
handle parent) {
variant_caster_visitor visitor(policy, parent);
return boost::apply_visitor(visitor, src);
}
PYBIND11_TYPE_CASTER(Type, _("Variant"));
bool load_success_{false};
};
// Add specialization for concrete variant type
template <class... Args>
struct type_caster<boost::variant<Args...>>
: variant_caster<boost::variant<Args...>> {};
} // namespace detail
} // namespace pybind11
namespace paddle {
namespace pybind {
......@@ -40,6 +104,15 @@ inline void VectorToRepeated(const std::vector<T> &vec,
}
}
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
RepeatedField *repeated_field) {
repeated_field->Reserve(vec.size());
for (auto elem : vec) {
*repeated_field->Add() = elem;
}
}
class ProgramDescBind;
class OpDescBind;
class BlockDescBind;
......@@ -146,6 +219,10 @@ public:
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const {
attr_->set_block_idx(desc->idx());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
void Sync() {
......@@ -168,13 +245,52 @@ public:
for (auto &attr : attrs_) {
auto *attr_desc = op_desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(static_cast<AttrType>(attr.second.which() - 1));
attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1));
boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second);
}
need_update_ = false;
}
}
bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
}
framework::AttrType GetAttrType(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return static_cast<framework::AttrType>(it->second.which() - 1);
}
std::vector<std::string> AttrNames() const {
std::vector<std::string> retv;
retv.reserve(attrs_.size());
for (auto &attr : attrs_) {
retv.push_back(attr.first);
}
return retv;
}
void SetAttr(const std::string &name, const Attribute &v) {
this->attrs_[name] = v;
}
void SetBlockAttr(const std::string &name, BlockDescBind &block);
int GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return boost::get<BlockDesc *>(it->second)->idx();
}
Attribute GetAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return it->second;
}
private:
OpDesc op_desc_;
std::unordered_map<std::string, std::vector<std::string>> inputs_;
......@@ -232,6 +348,8 @@ public:
}
}
BlockDesc *RawPtr() { return desc_; }
private:
ProgramDescBind *prog_; // not_own
BlockDesc *desc_; // not_own
......@@ -303,6 +421,11 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
}
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.RawPtr();
this->attrs_[name] = desc;
}
void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "")
.def_static("instance",
......@@ -351,8 +474,19 @@ void BindVarDsec(py::module &m) {
}
void BindOpDesc(py::module &m) {
py::class_<OpDescBind>(m, "OpDesc", "")
.def("type", &OpDescBind::Type)
py::enum_<framework::AttrType>(m, "AttrType", "")
.value("INT", AttrType::INT)
.value("INTS", AttrType::INTS)
.value("FLOAT", AttrType::FLOAT)
.value("FLOATS", AttrType::FLOATS)
.value("STRING", AttrType::STRING)
.value("STRINGS", AttrType::STRINGS)
.value("BOOL", AttrType::BOOLEAN)
.value("BOOLS", AttrType::BOOLEANS)
.value("BLOCK", AttrType::BLOCK);
py::class_<OpDescBind> op_desc(m, "OpDesc", "");
op_desc.def("type", &OpDescBind::Type)
.def("set_type", &OpDescBind::SetType)
.def("input", &OpDescBind::Input)
.def("input_names", &OpDescBind::InputNames)
......@@ -361,7 +495,15 @@ void BindOpDesc(py::module &m) {
.def("output_names", &OpDescBind::OutputNames)
.def("set_output", &OpDescBind::SetOutput)
.def("__str__", &OpDescBind::DebugString)
.def("__repr__", &OpDescBind::DebugString);
.def("__repr__", &OpDescBind::DebugString)
.def("has_attr", &OpDescBind::HasAttr)
.def("attr_type", &OpDescBind::GetAttrType)
.def("attr_names", &OpDescBind::AttrNames)
.def("set_attr", &OpDescBind::SetAttr)
.def("attr", &OpDescBind::GetAttr)
.def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("get_block_attr", &OpDescBind::GetBlockAttr);
}
} // namespace pybind
} // namespace paddle
......@@ -20,6 +20,40 @@ class TestOpDesc(unittest.TestCase):
self.assertEqual(['z'], op.output("Out"))
self.assertEqual(["Out"], op.output_names())
op.set_attr("int_attr", 1)
self.assertEqual(1, op.attr("int_attr"))
self.assertTrue(op.has_attr("int_attr"))
op.set_attr("float_attr", -1.32)
self.assertAlmostEqual(-1.32, op.attr("float_attr"), delta=1e-4)
self.assertTrue(op.has_attr("float_attr"))
op.set_attr("bool_attr", False)
self.assertFalse(op.attr("bool_attr"))
op.set_attr("string_attr", "abc")
self.assertEqual("abc", op.attr("string_attr"))
self.assertTrue(op.has_attr("string_attr"))
op.set_attr("ints_attr", [1, 2, 3])
self.assertEqual([1, 2, 3], op.attr("ints_attr"))
expected = [1.2, 2.3, 3.4]
op.set_attr("floats_attr", expected)
for e, a in zip(expected, op.attr("floats_attr")):
self.assertAlmostEqual(e, a, delta=1e-4)
op.set_attr("strings_attr", ["a", "b", "c"])
self.assertEqual(["a", "b", "c"], op.attr("strings_attr"))
op.set_attr("bools_attr", [True, False, True])
self.assertEqual([True, False, True], op.attr("bools_attr"))
self.assertEqual(8, len(op.attr_names()))
op.set_block_attr("block_attr", prog.block(0))
self.assertEqual(0, op.get_block_attr("block_attr"))
class TestProgramDesc(unittest.TestCase):
def test_instance(self):
......@@ -51,7 +85,7 @@ class TestProgramDesc(unittest.TestCase):
class TestVarDesc(unittest.TestCase):
def test_shape(self):
program_desc = core.ProgramDesc.instance()
block = program_desc.root_block()
block = program_desc.block(0)
var = block.new_var('my_var')
src_shape = [3, 2, 10, 8]
var.set_shape(src_shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册