Unverified commit b3cf28f8 authored by Ruibiao Chen, committed by GitHub

Remove boost::variant (#43100)

* boost::variant -> paddle::variant

* boost::variant.apply_visit -> paddle::visit

* Update pybind_boost_headers.h

* Fix CINN compilation errors

* Revert FetchResultType
Parent 369b2b1b
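The migration is largely mechanical: paddle::variant is a vendored mpark::variant (see paddle/utils/variant.h at the end of this diff) whose API mirrors std::variant, so variant::which() becomes variant::index(), boost::get and boost::bad_get become paddle::get and paddle::bad_variant_access, and boost::apply_visitor becomes the free function paddle::visit. A minimal sketch of the new API shape, written against std::variant as a stand-in since paddle::variant mirrors it:

#include <iostream>
#include <string>
#include <variant>  // stand-in for "paddle/utils/variant.h"

// Hypothetical reduced attribute type standing in for framework::Attribute.
using Attr = std::variant<int, float, std::string>;

// Any callable works as a visitor; no boost::static_visitor base is needed.
struct Printer {
  void operator()(int v) const { std::cout << "int: " << v << "\n"; }
  void operator()(float v) const { std::cout << "float: " << v << "\n"; }
  void operator()(const std::string& v) const { std::cout << "str: " << v << "\n"; }
};

int main() {
  Attr attr = 3.14f;
  std::cout << attr.index() << "\n";  // was attr.which() with boost::variant
  std::visit(Printer{}, attr);        // was boost::apply_visitor(Printer{}, attr)
  try {
    std::get<int>(attr);              // was boost::get<int>(attr)
  } catch (const std::bad_variant_access&) {  // was boost::bad_get
    std::cout << "wrong alternative\n";
  }
  return 0;
}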
......@@ -18,7 +18,6 @@
#include <string>
#include "boost/variant.hpp"
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"
......@@ -27,6 +26,7 @@
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/utils/variant.h"
namespace paddle {
namespace distributed {
......
......@@ -25,7 +25,6 @@
#include <string>
#include "boost/variant.hpp"
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"
......@@ -43,6 +42,7 @@
#endif
#include "paddle/fluid/platform/enforce.h"
#include "paddle/utils/variant.h"
namespace paddle {
namespace distributed {
......
......@@ -336,7 +336,7 @@ static std::string AttrTypeToString(const proto::AttrType& type) {
}
default: {
PADDLE_THROW(platform::errors::Fatal(
"AttrType of type boost::variant only supports specific data types."
"AttrType of type paddle::variant only supports specific data types."
"However, detected unrecognized AttrType: %d",
type));
}
......@@ -344,37 +344,39 @@ static std::string AttrTypeToString(const proto::AttrType& type) {
return ret;
}
template <typename T>
static std::string GetAttrValue(const framework::Attribute& attr,
bool is_vector) {
template <typename T, bool IsVector>
static typename std::enable_if<IsVector, std::string>::type GetAttrValue(
const framework::Attribute& attr) {
std::string val = "";
if (is_vector) {
val += "{";
for (auto x : BOOST_GET_CONST(std::vector<T>, attr)) {
val += std::to_string(x) + ",";
}
if (val.size() > 1) val.pop_back();
val += "}";
} else {
val = std::to_string(BOOST_GET_CONST(T, attr));
val += "{";
for (auto x : BOOST_GET_CONST(std::vector<T>, attr)) {
val += std::to_string(x) + ",";
}
if (val.size() > 1) val.pop_back();
val += "}";
return val;
}
template <typename T, bool IsVector>
static typename std::enable_if<!IsVector, std::string>::type GetAttrValue(
const framework::Attribute& attr) {
return std::to_string(BOOST_GET_CONST(T, attr));
}
static std::pair<std::string, std::string> GetAttrType(
const framework::Attribute& attr, bool is_arg) {
std::string ret = "";
std::string val = "";
size_t variant_pos = attr.which();
size_t variant_pos = attr.index();
switch (variant_pos) {
case (1): {
ret = "int";
val = GetAttrValue<int>(attr, false);
val = GetAttrValue<int, false>(attr);
break;
}
case (2): {
ret = "float";
val = GetAttrValue<float>(attr, false);
val = GetAttrValue<float, false>(attr);
break;
}
case (3): {
......@@ -386,13 +388,13 @@ static std::pair<std::string, std::string> GetAttrType(
case (4): {
ret = "std::vector<int>";
if (is_arg) ret += "&";
val = GetAttrValue<int>(attr, true);
val = GetAttrValue<int, true>(attr);
break;
}
case (5): {
ret = "std::vector<float>";
if (is_arg) ret += "&";
val = GetAttrValue<float>(attr, true);
val = GetAttrValue<float, true>(attr);
break;
}
case (6): {
......@@ -408,13 +410,13 @@ static std::pair<std::string, std::string> GetAttrType(
}
case (7): {
ret = "bool";
val = GetAttrValue<bool>(attr, false);
val = GetAttrValue<bool, false>(attr);
break;
}
case (8): {
ret = "std::vector<bool>";
if (is_arg) ret += "&";
val = GetAttrValue<bool>(attr, true);
val = GetAttrValue<bool, true>(attr);
break;
}
case (9): {
......@@ -423,7 +425,7 @@ static std::pair<std::string, std::string> GetAttrType(
}
case (10): {
ret = "int64_t";
val = GetAttrValue<int64_t>(attr, false);
val = GetAttrValue<int64_t, false>(attr);
break;
}
case (11): {
......@@ -434,18 +436,18 @@ static std::pair<std::string, std::string> GetAttrType(
case (12): {
ret = "std::vector<int64_t>";
if (is_arg) ret += "&";
val = GetAttrValue<int64_t>(attr, true);
val = GetAttrValue<int64_t, true>(attr);
break;
}
case (13): {
ret = "std::vector<double>";
if (is_arg) ret += "&";
val = GetAttrValue<double>(attr, true);
val = GetAttrValue<double, true>(attr);
break;
}
default: {
PADDLE_THROW(platform::errors::Fatal(
"AttrType of type boost::variant only supports specific data types."
"AttrType of type paddle::variant only supports specific data types."
"However, detected unrecognized AttrType: %d",
variant_pos));
}
......
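The GetAttrValue rewrite above also replaces a runtime is_vector flag with a compile-time bool template parameter dispatched through std::enable_if, so each instantiation contains only the code path that is valid for T. A self-contained sketch of the same dispatch, with a hypothetical two-alternative Attr standing in for framework::Attribute:

#include <string>
#include <type_traits>
#include <variant>
#include <vector>

using Attr = std::variant<int, std::vector<int>>;  // hypothetical stand-in

// Chosen when IsVector is true: format the vector alternative as "{a,b,c}".
template <typename T, bool IsVector>
static typename std::enable_if<IsVector, std::string>::type GetValue(
    const Attr& attr) {
  std::string val = "{";
  for (auto x : std::get<std::vector<T>>(attr)) {
    val += std::to_string(x) + ",";
  }
  if (val.size() > 1) val.pop_back();
  return val + "}";
}

// Chosen when IsVector is false: format the scalar alternative.
template <typename T, bool IsVector>
static typename std::enable_if<!IsVector, std::string>::type GetValue(
    const Attr& attr) {
  return std::to_string(std::get<T>(attr));
}

// Usage: GetValue<int, true>(attr) vs. GetValue<int, false>(attr); the
// branch taken is fixed at compile time instead of tested on every call.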
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/attribute.h"
#include "boost/blank.hpp"
namespace paddle {
namespace framework {
......
......@@ -23,12 +23,12 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "boost/variant/get.hpp"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/utils/any.h"
#include "paddle/utils/variant.h"
namespace paddle {
namespace framework {
......@@ -45,8 +45,8 @@ struct ExtractAttribute {
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
attr_value = &paddle::get<T>(attr);
} catch (paddle::bad_variant_access const& bad_get) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type %s, its type is %s.",
attr_name_,
......@@ -80,8 +80,8 @@ struct ExtractAttribute<bool> {
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
attr_value = &paddle::get<bool>(attr);
} catch (paddle::bad_variant_access const& bad_get) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type bool, its type is %s.",
attr_name_,
......@@ -108,8 +108,8 @@ struct ExtractAttribute<int64_t> {
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
attr_value = &paddle::get<int64_t>(attr);
} catch (paddle::bad_variant_access const& bad_get) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type int64_t, its type is %s.",
attr_name_,
......@@ -138,8 +138,8 @@ struct ExtractAttribute<std::vector<int64_t>> {
}
std::vector<int64_t>* attr_value = nullptr;
try {
attr_value = &boost::get<std::vector<int64_t>>(attr);
} catch (boost::bad_get& bad_get) {
attr_value = &paddle::get<std::vector<int64_t>>(attr);
} catch (paddle::bad_variant_access const& bad_get) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type std::vector<int64_t>, its type is "
"%s.",
......@@ -167,8 +167,8 @@ struct ExtractAttribute<float> {
}
float* attr_value = nullptr;
try {
attr_value = &boost::get<float>(attr);
} catch (boost::bad_get& bad_get) {
attr_value = &paddle::get<float>(attr);
} catch (paddle::bad_variant_access const& bad_get) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type float, its type is %s.",
attr_name_,
......@@ -197,8 +197,8 @@ struct ExtractAttribute<std::vector<double>> {
}
std::vector<double>* attr_value = nullptr;
try {
attr_value = &boost::get<std::vector<double>>(attr);
} catch (boost::bad_get& bad_get) {
attr_value = &paddle::get<std::vector<double>>(attr);
} catch (paddle::bad_variant_access const& bad_get) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type std::vector<double>, its type is "
"%s.",
......@@ -214,11 +214,11 @@ struct ExtractAttribute<std::vector<double>> {
template <typename T>
inline proto::AttrType AttrTypeID() {
Attribute tmp = T();
return static_cast<proto::AttrType>(tmp.which() - 1);
return static_cast<proto::AttrType>(tmp.index() - 1);
}
inline proto::AttrType AttrTypeID(const Attribute& attr) {
return static_cast<proto::AttrType>(attr.which() - 1);
return static_cast<proto::AttrType>(attr.index() - 1);
}
class AttrReader {
......
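Every ExtractAttribute specialization above follows the same pattern: attempt paddle::get<T> and convert paddle::bad_variant_access into an InvalidArgument error naming the attribute. (AttrTypeID subtracts one because alternative 0 of Attribute is the boost::blank placeholder, so the proto enum starts at the second alternative.) A reduced sketch of the extraction pattern, with error reporting simplified to std::runtime_error:

#include <stdexcept>
#include <string>
#include <variant>

struct Blank {};  // stand-in for the boost::blank placeholder alternative
using Attr = std::variant<Blank, int, float, std::string>;

template <typename T>
T& ExtractAttr(const std::string& name, Attr& attr) {
  try {
    return std::get<T>(attr);  // paddle::get in the real code
  } catch (const std::bad_variant_access&) {  // was boost::bad_get
    // The real code throws platform::errors::InvalidArgument and also
    // reports the type currently held; a plain exception suffices here.
    throw std::runtime_error("Cannot get attribute (" + name +
                             ") by the requested type.");
  }
}

// Usage: int& v = ExtractAttr<int>("axis", attr);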
......@@ -272,7 +272,7 @@ void BlockDesc::MoveFrom(BlockDesc *block) {
for (const auto &pair : src_op->GetAttrMap()) {
const auto &attr_name = pair.first;
const auto &attr_value = pair.second;
auto attr_type = static_cast<proto::AttrType>(attr_value.which() - 1);
auto attr_type = static_cast<proto::AttrType>(attr_value.index() - 1);
if (attr_type == proto::AttrType::BLOCK) {
auto block_id = BOOST_GET_CONST(BlockDesc *, attr_value)->ID();
dst_op->SetBlockAttr(attr_name, prog_->MutableBlock(block_id));
......
......@@ -174,7 +174,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
HandleException();
FetchList ret;
auto &val = BOOST_GET(FetchList, fetch_data);
auto &val = boost::get<FetchList>(fetch_data);
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
if (data_is_lod_tensor(val.at(fetch_idx))) {
std::vector<const LoDTensor *> lodtensor_ptrs;
......
......@@ -228,7 +228,7 @@ void FetchAsyncOpHandle::RunImpl() {
}
if (return_merged_) {
auto &val = BOOST_GET(FetchList, *data_);
auto &val = boost::get<FetchList>(*data_);
if (src_vars[0]->IsType<LoDTensor>()) {
// to lodtensor type
std::vector<const LoDTensor *> src_lodtensors;
......@@ -263,7 +263,7 @@ void FetchAsyncOpHandle::RunImpl() {
val.at(offset_) = std::move(dst_lodtensor_array);
}
} else {
auto &val = BOOST_GET(FetchUnmergedList, *data_);
auto &val = boost::get<FetchUnmergedList>(*data_);
auto &dst_tensors = val.at(offset_);
dst_tensors.reserve(src_vars.size());
......
......@@ -84,7 +84,7 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
for (auto &t : tensors_) {
tensors_ptr.emplace_back(&BOOST_GET_CONST(LoDTensor, t));
}
auto &val = BOOST_GET(FetchList, *data_);
auto &val = boost::get<FetchList>(*data_);
LoDTensor var;
MergeLoDTensor(&var, tensors_ptr, platform::CPUPlace());
val.at(offset_) = std::move(var);
......@@ -106,11 +106,11 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
tmp_array.emplace_back();
MergeLoDTensor(&(tmp_array.back()), tensors_ptr, platform::CPUPlace());
}
auto &val = BOOST_GET(FetchList, *data_);
auto &val = boost::get<FetchList>(*data_);
val.at(offset_) = std::move(tmp_array);
}
} else {
auto &val = BOOST_GET(FetchUnmergedList, *data_);
auto &val = boost::get<FetchUnmergedList>(*data_);
val.at(offset_) = std::move(tensors_);
}
}
......
......@@ -278,8 +278,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
if (!is_valid[scope_idx]) {
continue;
}
const auto &fetch_list =
BOOST_GET_CONST(FetchList, fetch_data[scope_idx]);
const auto &fetch_list = boost::get<FetchList>(fetch_data[scope_idx]);
if (data_is_lod_tensor(fetch_list[fetch_idx])) {
lodtensor_ptrs.push_back(
&(BOOST_GET_CONST(LoDTensor, fetch_list[fetch_idx])));
......@@ -318,7 +317,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
continue;
}
const auto &fetch_list =
BOOST_GET_CONST(FetchUnmergedList, fetch_data[scope_idx]);
boost::get<FetchUnmergedList>(fetch_data[scope_idx]);
PADDLE_ENFORCE_EQ(
fetch_list[fetch_idx].size(),
1,
......
......@@ -23,10 +23,10 @@ limitations under the License. */
namespace paddle {
namespace framework {
using FeedType = boost::variant<LoDTensor, Strings>;
using FeedType = paddle::variant<LoDTensor, Strings>;
using FeedList = std::vector<FeedType>;
using FetchType = boost::variant<LoDTensor, LoDTensorArray>;
using FetchType = paddle::variant<LoDTensor, LoDTensorArray, framework::Vocab>;
using FetchList = std::vector<FetchType>;
using FetchUnmergedList = std::vector<std::vector<FetchType>>;
......
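FetchType also gains a third alternative here, framework::Vocab. Call sites that branch on the held alternative, such as data_is_lod_tensor in the executors above, are unaffected because they test the type rather than a hard-coded index. A stand-in sketch of that kind of dispatch:

#include <map>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for LoDTensor, LoDTensorArray, and Vocab.
struct Tensor { std::vector<float> data; };
using TensorArray = std::vector<Tensor>;
using Vocab = std::map<std::wstring, int>;

using FetchType = std::variant<Tensor, TensorArray, Vocab>;

// holds_alternative inspects the active member without throwing, unlike
// get<T>, which raises bad_variant_access on a mismatch.
bool data_is_tensor(const FetchType& f) {
  return std::holds_alternative<Tensor>(f);
}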
......@@ -121,7 +121,7 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
elementwise_add_op_desc->GetNullableAttr("out_threshold");
// set the out_threshold of the elementwise add op to be the out_threshold
// of the conv2d_fusion
if (out_threshold_attr.which()) {
if (out_threshold_attr.index()) {
new_op_desc.SetAttr("out_threshold", out_threshold_attr);
}
new_op_desc.Flush();
......
......@@ -261,7 +261,7 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
// out_threshold of fc
auto out_threshold_attr =
elementwise_add_op_desc->GetNullableAttr("out_threshold");
if (out_threshold_attr.which()) {
if (out_threshold_attr.index()) {
VLOG(4) << "setting out_threshold: "
<< BOOST_GET_CONST(float, out_threshold_attr);
desc.SetAttr("out_threshold", out_threshold_attr);
......
......@@ -78,7 +78,7 @@ static std::string RefineTemplateWithAttr(const std::string& op_type,
}
Attribute attr = it->second;
proto::AttrType attr_type =
static_cast<proto::AttrType>(it->second.which() - 1);
static_cast<proto::AttrType>(it->second.index() - 1);
if (attr_type == proto::AttrType::BOOLEAN) {
bool result = BOOST_GET(bool, attr);
if (result) {
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/generate_pass.h"
#include "boost/blank.hpp"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
......@@ -105,7 +106,7 @@ Attribute GetOpAttrValue(const OpDesc* desc,
const proto::PassDesc::Attr& attr) {
Attribute value = desc->GetAttr(attr.name());
if (attr.has_element_index()) {
value = boost::apply_visitor(element_visitor(attr.element_index()), value);
value = paddle::visit(element_visitor(attr.element_index()), value);
}
return value;
}
......@@ -203,7 +204,7 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
Attribute attr = GetVarAttrValue(x->Var(), condition.attr());
if (condition.has_operation()) {
Attribute operation = GetAttrValue(condition.operation().value());
attr = boost::apply_visitor(
attr = paddle::visit(
operation_visitor(condition.operation().type()), attr, operation);
}
switch (condition.type()) {
......@@ -388,7 +389,7 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
if (attr_map.has_operation()) {
Attribute operation =
GetAttrValue(attr_map.operation().value());
attr = boost::apply_visitor(
attr = paddle::visit(
operation_visitor(attr_map.operation().type()),
attr,
operation);
......
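Note that paddle::visit is called here with two variants at once (the attribute and the operand); like std::visit, it dispatches on the combination of active alternatives. A small sketch of multi-variant visitation:

#include <iostream>
#include <variant>

using Value = std::variant<int, float>;

// One overload per alternative combination, provided here by a template;
// visit selects it from the pair of active alternatives at runtime.
struct Add {
  template <typename A, typename B>
  Value operator()(A a, B b) const { return a + b; }
};

int main() {
  Value x = 2, y = 1.5f;
  Value r = std::visit(Add{}, x, y);
  std::cout << std::get<float>(r) << "\n";  // prints 3.5
  return 0;
}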
......@@ -320,7 +320,7 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
"should have requantize output as input.",
requant_out->Name()));
float requant_scale_in =
boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
paddle::get<float>(requant_op->Op()->GetAttr("Scale_in"));
auto scale_name = "Scale_in";
if (any_op->Op()->Type() == "matmul")
......
......@@ -118,9 +118,9 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, rtm_pattern);
auto reshape_shape =
boost::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
paddle::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
auto transpose_axis =
boost::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
paddle::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
OpDesc *matmul_desc = matmul_op->Op();
std::string input_var_name = transpose_out->Name();
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include "boost/blank.hpp"
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_call_stack.h"
......@@ -563,7 +564,7 @@ proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
it,
attrs_.end(),
platform::errors::NotFound("Attribute %s is not found.", name));
return static_cast<proto::AttrType>(it->second.which() - 1);
return static_cast<proto::AttrType>(it->second.index() - 1);
}
std::vector<std::string> OpDesc::AttrNames() const {
......@@ -584,7 +585,7 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
// NOTICE(minqiyang): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type
// here if we meet this issue
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
if (attr_type == proto::AttrType::INTS &&
BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
// Find current attr via attr name and set the correct attribute value
......@@ -837,9 +838,9 @@ void OpDesc::Flush() {
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<proto::AttrType>(attr.second.which() - 1));
static_cast<proto::AttrType>(attr.second.index() - 1));
SetAttrDescVisitor visitor(attr_desc);
boost::apply_visitor(visitor, attr.second);
paddle::visit(visitor, attr.second);
}
need_update_ = false;
......
......@@ -33,18 +33,18 @@ class OpVersionMap;
} // namespace pb
using OpAttrVariantT =
boost::variant<bool, /* AttrType::BOOL */
float, /* AttrType::FLOAT */
int32_t, /* AttrType::INT */
int64_t, /* AttrType::LONG*/
std::string, /* AttrType::STRING */
std::vector<bool>, /* AttrType::BOOLS */
std::vector<float>, /* AttrType::FLOATS */
std::vector<int32_t>, /* AttrType::INTS */
std::vector<int64_t>, /* AttrType::LONGS */
std::vector<std::string>, /* AttrType::STRINGS */
paddle::none_t /* None */
>;
paddle::variant<bool, /* AttrType::BOOL */
float, /* AttrType::FLOAT */
int32_t, /* AttrType::INT */
int64_t, /* AttrType::LONG*/
std::string, /* AttrType::STRING */
std::vector<bool>, /* AttrType::BOOLS */
std::vector<float>, /* AttrType::FLOATS */
std::vector<int32_t>, /* AttrType::INTS */
std::vector<int64_t>, /* AttrType::LONGS */
std::vector<std::string>, /* AttrType::STRINGS */
paddle::none_t /* None */
>;
struct OpUpdateInfo {
virtual ~OpUpdateInfo() = default;
......
......@@ -190,7 +190,7 @@ void OpAttrsToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) {
IMPL_ONE(LONG, int64_t);
IMPL_ONE(LONGS, std::vector<int64_t>);
case AttrType::BLOCK: {
auto i = pb_desc->GetAttrIfExists<int16_t>(name);
auto i = pb_desc->GetAttrIfExists<int32_t>(name);
cpp_desc->SetAttr<int32_t>(name, i);
break;
}
......
......@@ -56,7 +56,7 @@ namespace framework {
class OperatorBase;
using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
using InferShapeVarPtr = paddle::variant<VarDesc *, Variable *>;
class InferShapeContext {
public:
......
......@@ -27,13 +27,13 @@ limitations under the License. */
namespace paddle {
namespace framework {
typedef boost::variant<int,
int64_t,
float,
double,
std::string,
Tensor,
LoDTensor /*, ChannelHolder*/>
typedef paddle::variant<int,
int64_t,
float,
double,
std::string,
Tensor,
LoDTensor /*, ChannelHolder*/>
ElementVar;
class Tuple {
......@@ -64,8 +64,8 @@ bool Tuple::isSameType(const Tuple& t) const {
return false;
}
for (size_t j = 0; j < tuple_size; ++j) {
auto type1 = get(j).which();
auto type2 = t.get(j).which();
auto type1 = get(j).index();
auto type2 = t.get(j).index();
if (type1 != type2) return false;
}
return true;
......
......@@ -22,9 +22,11 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "boost/blank.hpp"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/utils/small_vector.h"
#include "paddle/utils/variant.h"
namespace paddle {
namespace framework {
......@@ -40,38 +42,38 @@ class InferNoNeedBufferVarsFN;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using VariableValueMap = std::map<std::string, std::vector<Variable*>>;
using Attribute = boost::variant<boost::blank,
int,
float,
std::string,
std::vector<int>,
std::vector<float>,
std::vector<std::string>,
bool,
std::vector<bool>,
BlockDesc*,
int64_t,
std::vector<BlockDesc*>,
std::vector<int64_t>,
std::vector<double>>;
using Attribute = paddle::variant<boost::blank,
int,
float,
std::string,
std::vector<int>,
std::vector<float>,
std::vector<std::string>,
bool,
std::vector<bool>,
BlockDesc*,
int64_t,
std::vector<BlockDesc*>,
std::vector<int64_t>,
std::vector<double>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;
#ifdef PADDLE_WITH_ASCEND_CL
using NPUAttribute = boost::variant<boost::blank,
int,
float,
std::string,
std::vector<int>,
std::vector<float>,
std::vector<std::string>,
bool,
std::vector<bool>,
BlockDesc*,
int64_t,
std::vector<BlockDesc*>,
std::vector<int64_t>,
std::vector<double>,
std::vector<std::vector<int64_t>>>;
using NPUAttribute = paddle::variant<boost::blank,
int,
float,
std::string,
std::vector<int>,
std::vector<float>,
std::vector<std::string>,
bool,
std::vector<bool>,
BlockDesc*,
int64_t,
std::vector<BlockDesc*>,
std::vector<int64_t>,
std::vector<double>,
std::vector<std::vector<int64_t>>>;
using NPUAttributeMap = std::unordered_map<std::string, NPUAttribute>;
#endif
......
......@@ -315,7 +315,7 @@ void VarDesc::SetAttr(const std::string &name, const Attribute &v) {
// NOTICE(sandyhouse): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type
// here if we meet this issue
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
if (attr_type == proto::AttrType::INTS &&
BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
// Find current attr via attr name and set the correct attribute value
......
......@@ -122,7 +122,7 @@ class FeedOp : public framework::OperatorBase {
auto &feed_item = feed_list.at(static_cast<size_t>(col));
FeedVariableVisitor visitor(out_var, place);
boost::apply_visitor(visitor, feed_item);
paddle::visit(visitor, feed_item);
}
};
......
......@@ -53,19 +53,19 @@ struct RawPointerVisitor : public boost::static_visitor<const void *> {
};
const framework::VariableNameMap &OpVariant::Inputs() const {
return *boost::apply_visitor(InputsVisitor(), op_);
return *paddle::visit(InputsVisitor(), op_);
}
const framework::VariableNameMap &OpVariant::Outputs() const {
return *boost::apply_visitor(OutputsVisitor(), op_);
return *paddle::visit(OutputsVisitor(), op_);
}
const framework::AttributeMap &OpVariant::Attrs() const {
return *boost::apply_visitor(AttributeMapVisitor(), op_);
return *paddle::visit(AttributeMapVisitor(), op_);
}
const void *OpVariant::RawPointer() const {
return boost::apply_visitor(RawPointerVisitor(), op_);
return paddle::visit(RawPointerVisitor(), op_);
}
void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
......
......@@ -61,7 +61,7 @@ class OpVariant {
return RawPointer() == other.RawPointer();
}
int which() const { return static_cast<int>(op_.which()); }
int index() const { return static_cast<int>(op_.index()); }
struct Hasher {
size_t operator()(const OpVariant &op) const {
......@@ -70,8 +70,8 @@ class OpVariant {
};
private:
const boost::variant<const framework::OperatorBase *,
const framework::OpDesc *>
const paddle::variant<const framework::OperatorBase *,
const framework::OpDesc *>
op_;
};
......
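Visitors such as RawPointerVisitor still inherit boost::static_visitor<const void *> in this patch, but paddle::visit, like std::visit, deduces the result type from the callable, so the base class is no longer load-bearing. A sketch of a plain functor over a two-pointer variant shaped like OpVariant's op_ member:

#include <variant>

struct OperatorBase {};  // hypothetical stand-ins for the framework types
struct OpDesc {};

using OpRef = std::variant<const OperatorBase*, const OpDesc*>;

// No boost::static_visitor<const void*> base required; the common return
// type is deduced from the operator() overloads.
struct RawPointer {
  const void* operator()(const OperatorBase* op) const { return op; }
  const void* operator()(const OpDesc* op) const { return op; }
};

const void* raw_pointer(const OpRef& op) {
  return std::visit(RawPointer{}, op);  // was boost::apply_visitor(visitor, op)
}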
......@@ -226,6 +226,7 @@ REGISTER_OPERATOR(
ops::BoxDecoderAndAssignOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
box_decoder_and_assign,
ops::BoxDecoderAndAssignKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -122,7 +122,7 @@ class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel<T> {
int grid = (roi_num * class_num + block - 1) / block;
auto& device_ctx = context.cuda_device_context();
const T box_clip = context.Attr<T>("box_clip");
const T box_clip = static_cast<T>(context.Attr<float>("box_clip"));
DecodeBoxKernel<T>
<<<grid, block, 0, device_ctx.stream()>>>(prior_box_data,
......
......@@ -41,7 +41,7 @@ class BoxDecoderAndAssignKernel : public framework::OpKernel<T> {
output_assign_box->mutable_data<T>({roi_num, 4}, context.GetPlace());
T* output_box_data = output_box->data<T>();
T* output_assign_box_data = output_assign_box->data<T>();
const T bbox_clip = context.Attr<T>("box_clip");
const T bbox_clip = static_cast<T>(context.Attr<float>("box_clip"));
for (int i = 0; i < roi_num; ++i) {
T prior_box_width = prior_box_data[i * 4 + 2] - prior_box_data[i * 4] + 1;
......
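The static_cast in the two kernels above is needed because framework::Attribute (see type_defs.h in this diff) has a float alternative but no scalar double alternative: for a kernel instantiated with T = double, Attr<T>("box_clip") cannot resolve against the variant, so the stored float is read and cast. A reduced sketch:

#include <variant>

// Reduced stand-in for framework::Attribute: note there is no double.
using Attribute = std::variant<int, float, bool>;

// Read the attribute as its stored alternative (float) and cast to the
// kernel's scalar type T; std::get<double>(attr) would not even compile,
// because double is not an alternative of the variant.
template <typename T>
T GetBoxClip(const Attribute& attr) {
  return static_cast<T>(std::get<float>(attr));
}

// Usage: double clip = GetBoxClip<double>(attr);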
......@@ -47,7 +47,7 @@ template <typename T>
void MatrixBitCodeFunctor<T>::Add(const framework::Tensor &vec,
framework::Tensor *tmat) {
MatrixBitCodeFunctorAdd<T> func(vec, tmat);
code_table_.apply_visitor(func);
paddle::visit(func, code_table_);
}
template <typename T>
......@@ -79,7 +79,7 @@ template <typename T>
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
framework::Tensor *vec) {
MatrixBitCodeFunctorAddGrad<T> func(tmat, vec);
code_table_.apply_visitor(func);
paddle::visit(func, code_table_);
}
template <typename T>
......@@ -121,7 +121,7 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor &tmat,
framework::Tensor *sum,
T scale_sum) {
MatrixBitCodeFunctorSum<T> func(tmat, sum, scale_sum);
code_table_.apply_visitor(func);
paddle::visit(func, code_table_);
}
template <typename T>
......@@ -165,7 +165,7 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor *tmat,
const framework::Tensor &weight,
const framework::Tensor &input) {
MatrixBitCodeFunctorMul<T> func(tmat, weight, input);
code_table_.apply_visitor(func);
paddle::visit(func, code_table_);
}
template <typename T, size_t N>
......@@ -222,7 +222,7 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor &tmat,
framework::Tensor *weight,
const framework::Tensor &input) {
MatrixBitCodeFunctorMulGradWeight<T> func(tmat, weight, input);
code_table_.apply_visitor(func);
paddle::visit(func, code_table_);
}
template <typename T>
......@@ -279,7 +279,7 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor &tmat,
phi::SelectedRows *weight,
const framework::Tensor &input) {
MatrixBitCodeFunctorMulGradWeightSR<T> func(tmat, weight, input);
code_table_.apply_visitor(func);
paddle::visit(func, code_table_);
}
template <typename T>
......@@ -323,7 +323,7 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor &tmat,
const framework::Tensor &weight,
framework::Tensor *input) {
MatrixBitCodeFunctorMulGradError<T> func(tmat, weight, input);
code_table_.apply_visitor(func);
paddle::visit(func, code_table_);
}
template <typename T>
......@@ -352,7 +352,7 @@ struct MatrixBitCodeFunctorSub : public boost::static_visitor<void> {
template <typename T>
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor *tmat) {
MatrixBitCodeFunctorSub<T> func(tmat);
code_table_.apply_visitor(func);
paddle::visit(func, code_table_);
}
template class MatrixBitCodeFunctor<float>;
......
......@@ -208,7 +208,7 @@ class CustomCodeTable {
const int64_t* ids_;
};
using CodeTable = boost::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;
using CodeTable = paddle::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;
template <typename T>
class MatrixBitCodeFunctor {
......
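boost::variant exposes a member apply_visitor, which paddle::variant (like std::variant) does not; hence every code_table_.apply_visitor(func) above becomes the free-function call paddle::visit(func, code_table_). The same functor-over-variant pattern with hypothetical stand-in code tables:

#include <cstdint>
#include <iostream>
#include <variant>

// Hypothetical stand-ins for SimpleCodeTable / CustomCodeTable<int64_t>.
struct SimpleTable { int64_t size() const { return 8; } };
struct CustomTable { int64_t size() const { return 42; } };

using CodeTable = std::variant<SimpleTable, CustomTable>;

struct SizeVisitor {
  template <typename Table>
  int64_t operator()(const Table& t) const {
    return t.size();
  }
};

int main() {
  CodeTable table = CustomTable{};
  // was: table.apply_visitor(visitor) with boost::variant
  std::cout << std::visit(SizeVisitor{}, table) << "\n";  // prints 42
  return 0;
}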
......@@ -201,7 +201,8 @@ class SliceOpVarTypeInference : public framework::VarTypeInference {
auto x_name = "Input";
auto out_name = "Out";
auto decrease_axis = ctx->GetAttr("decrease_axis");
auto not_decrease = boost::get<std::vector<int>>(decrease_axis).size() == 0;
auto not_decrease =
paddle::get<std::vector<int>>(decrease_axis).size() == 0;
if (not_decrease) {
// The default type of out is LoDTensor.
// However, if no axis is decreased and the type of input is not
......
......@@ -19,11 +19,11 @@
#include <string>
#include <vector>
#include "boost/variant.hpp"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/device/npu/dynload/hccl.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/utils/variant.h"
#if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/device/mlu/device_context.h"
#endif
......
......@@ -19,6 +19,8 @@
#include <popart/optimizer.hpp>
#include <popart/sgd.hpp>
#include "boost/blank.hpp"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_names.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
......@@ -390,7 +392,7 @@ void Compiler::LowerConstants(const Scope* scope) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
ConstantOpAttrVisitor visitor(tensor, dtype);
auto value = op_desc->GetAttr("value");
boost::apply_visitor(visitor, value);
paddle::visit(visitor, value);
auto ddim = phi::make_ddim(shape);
tensor->Resize(ddim);
......@@ -475,7 +477,7 @@ void Compiler::LowerBody() {
auto attributes = std::map<std::string, popart::any>{};
for (auto& attr : op_desc->GetAttrMap()) {
CustomOpAttrVisitor visitor(&attributes, attr.first);
boost::apply_visitor(visitor, attr.second);
paddle::visit(visitor, attr.second);
}
auto __op_type =
BOOST_GET_CONST(std::string, op_desc->GetAttr("__op_type"));
......
......@@ -122,20 +122,21 @@ using namespace ::phi::enforce; // NOLINT
#endif
/*
* Summary: This BOOST_GET(_**) series macros are used to call boost::get
* safely. boost::get is not a completely safe api, although it will not
* Summary: This BOOST_GET(_**) series macros are used to call paddle::get
* safely. paddle::get is not a completely safe api, although it will not
* go wrong in most cases, but in extreme cases, it may fail and directly
* throw a boost::bad_get exception, without any stack information.
* throw a paddle::bad_variant_access exception, without any stack
* information.
* This kind of problem is difficult to debug, so add these macros to
* enrich boost::get error information. At the same time, we restrict
* the direct use of boost::get by CI rule.
* enrich paddle::get error information. At the same time, we restrict
* the direct use of paddle::get by CI rule.
*
* Parameters:
*     __TYPE: the target variable type
* __VALUE: the target variable to get
*
* Examples:
* - unsafe writing: int x = boost::get<int>(y);
* - unsafe writing: int x = paddle::get<int>(y);
* - safe writing: int x = BOOST_GET(int, y);
*
* Note: GCC 4.8 cannot select right overloaded function here, so need
......@@ -155,12 +156,12 @@ using namespace phi::enforce::details; // NOLINT
__OutputTypePtr, \
__OutputType>::type { \
try { \
return boost::get<OutputType>(input); \
} catch (boost::bad_get&) { \
return paddle::get<OutputType>(input); \
} catch (paddle::bad_variant_access const&) { \
HANDLE_THE_ERROR \
throw ::phi::enforce::EnforceNotMet( \
phi::errors::InvalidArgument( \
"boost::get failed, cannot get value " \
"paddle::get failed, cannot get value " \
"(%s) by type %s, its type is %s.", \
expression, \
phi::enforce::demangle(typeid(OutputType).name()), \
......
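The macro body above routes every checked access through a helper that converts paddle::bad_variant_access into an EnforceNotMet carrying the expression text, the requested type, and the type actually held. A reduced function-template sketch of that wrapper (the real macro additionally records file/line and demangles the type names):

#include <stdexcept>
#include <string>
#include <typeinfo>
#include <variant>

// Reduced stand-in for the BOOST_GET helper: rethrow with enriched context.
template <typename OutputType, typename InputType>
OutputType& SafeGet(InputType& input, const char* expression) {
  try {
    return std::get<OutputType>(input);  // paddle::get in the real macro
  } catch (const std::bad_variant_access&) {
    throw std::runtime_error(
        std::string("paddle::get failed, cannot get value (") + expression +
        ") by type " + typeid(OutputType).name());
  }
}

// Usage: int x = SafeGet<int>(attr, "attr");  // vs. bare paddle::get<int>(attr)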
......@@ -22,13 +22,14 @@
#include "gflags/gflags.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/utils/variant.h"
namespace paddle {
namespace platform {
struct FlagInfo {
using ValueType =
boost::variant<bool, int32_t, int64_t, uint64_t, double, std::string>;
paddle::variant<bool, int32_t, int64_t, uint64_t, double, std::string>;
std::string name;
mutable void *value_ptr;
ValueType default_value;
......
......@@ -259,7 +259,7 @@ static void RegisterGlobalVarGetterSetter() {
const auto &default_value = pair.second.default_value;
RegisterGetterSetterVisitor visitor(
"FLAGS_" + name, is_writable, value_ptr);
boost::apply_visitor(visitor, default_value);
paddle::visit(visitor, default_value);
}
}
......
......@@ -3338,7 +3338,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::take_ownership);
py::class_<FetchList>(m, "FetchList", R"DOC( FetchList is a
vector of boost::variant<LoDTensor, LoDTensorArray>.
vector of paddle::variant<LoDTensor, LoDTensorArray>.
)DOC")
.def(
"_move_to_list",
......@@ -3385,7 +3385,7 @@ All parameter, weight, gradient are variables in Paddle.
py::arg("var"));
py::class_<FetchUnmergedList>(m, "FetchUnmergedList", R"DOC(
FetchUnmergedList is 2-D array of FetchType(boost::variant(LoDTensor, LoDTensorArray)).
FetchUnmergedList is 2-D array of FetchType(paddle::variant(LoDTensor, LoDTensorArray)).
)DOC")
.def(
"_move_to_list",
......@@ -4606,12 +4606,15 @@ All parameter, weight, gradient are variables in Paddle.
pybind11::gil_scoped_release release;
ret = self.Run(fetch_tensors, return_merged);
}
// TODO(Ruibiao): Refactor the run interface of PE to avoid use
// boost::get here
if (return_merged) {
return py::cast(
std::move(BOOST_GET(paddle::framework::FetchList, ret)));
std::move(boost::get<paddle::framework::FetchList>(ret)));
} else {
return py::cast(std::move(
BOOST_GET(paddle::framework::FetchUnmergedList, ret)));
boost::get<paddle::framework::FetchUnmergedList>(ret)));
}
})
.def("device_count", &ParallelExecutor::DeviceCount);
......
......@@ -19,11 +19,13 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/utils/variant.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
// Cast boost::variant for PyBind.
// Cast paddle::variant for PyBind.
// Copy from
// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
namespace pybind11 {
namespace detail {
......@@ -37,8 +39,7 @@ namespace detail {
#endif
// Can be replaced by a generic lambda in C++14
struct PYBIND11_HIDDEN paddle_variant_caster_visitor
: public boost::static_visitor<handle> {
struct PYBIND11_HIDDEN paddle_variant_caster_visitor {
return_value_policy policy;
handle parent;
......@@ -127,8 +128,13 @@ struct paddle_variant_caster<V<Ts...>> {
static handle cast(Type const& src,
return_value_policy policy,
handle parent) {
/*
auto paddle_variant_caster_visitor = [&](Type const& src)->handle {
return make_caster<Type>::cast(src, policy, parent);
}
*/
paddle_variant_caster_visitor visitor(policy, parent);
return boost::apply_visitor(visitor, src);
return paddle::visit(visitor, src);
}
PYBIND11_TYPE_CASTER(Type, _("Variant"));
......@@ -137,8 +143,8 @@ struct paddle_variant_caster<V<Ts...>> {
// Add specialization for concrete variant type
template <class... Args>
struct type_caster<boost::variant<Args...>>
: paddle_variant_caster<boost::variant<Args...>> {};
struct type_caster<paddle::variant<Args...>>
: paddle_variant_caster<paddle::variant<Args...>> {};
} // namespace detail
} // namespace pybind11
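The commented-out block inside cast records where this could go next: since paddle::visit accepts any callable, the hand-written paddle_variant_caster_visitor struct could eventually be replaced by a C++14 generic lambda. A self-contained sketch of a generic lambda used as a visitor:

#include <iostream>
#include <variant>

int main() {
  std::variant<int, double> v = 2.5;
  // A generic lambda is a valid visitor for std::visit / paddle::visit;
  // it replaces a functor struct with one operator() per alternative.
  std::visit([](const auto& alt) { std::cout << alt << "\n"; }, v);
  return 0;
}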
......@@ -36,7 +36,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool HasAttr(const std::string& name) const override;
// now we can't use Attribute here, it will cause phi to rely on
// boost::variant and BlockDesc
// paddle::variant and BlockDesc
paddle::any Attr(const std::string& name) const override;
size_t InputSize(const std::string& name) const override;
......
......@@ -99,7 +99,7 @@ class ArgumentMappingContext {
virtual bool HasAttr(const std::string& name) const = 0;
// now we can't use Attribute here, it will cause phi to rely on
// boost::variant and BlockDesc
// paddle::variant and BlockDesc
virtual paddle::any Attr(const std::string& name) const = 0;
virtual size_t InputSize(const std::string& name) const = 0;
......
......@@ -113,7 +113,7 @@ void* DenseTensor::mutable_data(const Place& place,
size = requested_size;
}
/* some versions of boost::variant don't have operator!= */
/* some versions of paddle::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + meta_.offset) {
holder_.reset();
......@@ -142,7 +142,7 @@ void* DenseTensor::mutable_data(const Place& place,
"] now"));
size_t size = numel() * SizeOf(dtype());
/* some versions of boost::variant don't have operator!= */
/* some versions of paddle::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + meta_.offset ||
!(place.GetType() == phi::AllocationType::GPU &&
......
......@@ -20,31 +20,29 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
// <boost/variant.hpp> is not suitable to be placed in the header file,
// it will introduce a large number of unnecessary includes, and these type
// declarations that depend on boost are also not suitable for the phi header
// file. Do some repeated forward declarations here to avoid
// <boost/variant.hpp> spreading to a large number of phi kernel files
#include "boost/blank.hpp"
#include "paddle/utils/variant.h"
namespace egr {
class EagerVariable;
}
namespace paddle {
namespace framework {
class BlockDesc;
using Attribute = boost::variant<boost::blank,
int,
float,
std::string,
std::vector<int>,
std::vector<float>,
std::vector<std::string>,
bool,
std::vector<bool>,
BlockDesc*,
int64_t,
std::vector<BlockDesc*>,
std::vector<int64_t>,
std::vector<double>>;
using Attribute = paddle::variant<boost::blank,
int,
float,
std::string,
std::vector<int>,
std::vector<float>,
std::vector<std::string>,
bool,
std::vector<bool>,
BlockDesc*,
int64_t,
std::vector<BlockDesc*>,
std::vector<int64_t>,
std::vector<double>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;
} // namespace framework
namespace imperative {
......
......@@ -179,7 +179,7 @@ dtype::pstring* StringTensor::mutable_data(const phi::Place& place,
size = requested_size;
}
/* some versions of boost::variant don't have operator!= */
/* some versions of paddle::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + meta_.offset) {
holder_.reset();
......
......@@ -2,6 +2,11 @@
// https://github.com/mpark/variant/blob/single-header/v1.4.0/variant.hpp
// Modify the following points:
// 1. modify namespace mpark to namespace paddle
// 2. add type() member function for variant class
// 3. remove the visitation implementation under the branch with
// MPARK_CPP14_CONSTEXPR defined since lib::cpp14::array could not be converted
// to std::initializer_list in Paddle's compilation
// 4. decorate PYBIND11_HIDDEN for struct value_visitor
// MPark.Variant
//
......@@ -22,6 +27,14 @@
#pragma GCC diagnostic ignored "-Wdeprecated-copy"
#endif
#if !defined(PYBIND11_HIDDEN)
#ifdef _WIN32
#define PYBIND11_HIDDEN __declspec(dllexport)
#else
#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
#endif
#endif
/*
variant synopsis
......@@ -1649,7 +1662,7 @@ struct variant {
};
template <typename Visitor>
struct value_visitor {
struct PYBIND11_HIDDEN value_visitor {
Visitor &&visitor_;
template <typename... Alts>
......@@ -2454,7 +2467,7 @@ class variant {
impl_.swap(that.impl_);
}
inline const std::type_info &type() noexcept { return impl_.type(); }
inline const std::type_info &type() const noexcept { return impl_.type(); }
private:
detail::impl<Ts...> impl_;
......@@ -2708,30 +2721,6 @@ inline constexpr bool operator!=(monostate, monostate) noexcept {
return false;
}
#ifdef MPARK_CPP14_CONSTEXPR
namespace detail {
inline constexpr bool all(std::initializer_list<bool> bs) {
for (bool b : bs) {
if (!b) {
return false;
}
}
return true;
}
} // namespace detail
template <typename Visitor, typename... Vs>
inline constexpr decltype(auto) visit(Visitor &&visitor, Vs &&...vs) {
return (detail::all(
lib::array<bool, sizeof...(Vs)>{!vs.valueless_by_exception()...})
? (void)0
: throw_bad_variant_access()),
detail::visitation::variant::visit_value(
lib::forward<Visitor>(visitor), lib::forward<Vs>(vs)...);
}
#else
namespace detail {
template <std::size_t N>
......@@ -2755,12 +2744,11 @@ inline constexpr DECLTYPE_AUTO visit(Visitor &&visitor, Vs &&...vs)
: throw_bad_variant_access()),
detail::visitation::variant::visit_value(lib::forward<Visitor>(visitor),
lib::forward<Vs>(vs)...))
#endif
template <typename... Ts>
inline auto swap(variant<Ts...> &lhs,
variant<Ts...> &rhs) noexcept(noexcept(lhs.swap(rhs)))
-> decltype(lhs.swap(rhs)) {
template <typename... Ts>
inline auto swap(variant<Ts...> &lhs,
variant<Ts...> &rhs) noexcept(noexcept(lhs.swap(rhs)))
-> decltype(lhs.swap(rhs)) {
lhs.swap(rhs);
}
......