Unverified · Commit b3cf28f8 · Authored by Ruibiao Chen, committed by GitHub

Remove boost::variant (#43100)

* boost::variant -> paddle::variant

* boost::apply_visitor -> paddle::visit

* Update pybind_boost_headers.h

* Fix CINN compilation errors

* Revert FetchResultType
Parent 369b2b1b
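The change is mechanical across the codebase; a minimal sketch of the migration pattern this commit applies (the alias IntOrString and function Demo are hypothetical, for illustration only):

#include <string>
#include "paddle/utils/variant.h"  // replaces "boost/variant.hpp"

using IntOrString = paddle::variant<int, std::string>;  // hypothetical alias

int Demo(IntOrString v) {
  // boost::get<T>(v)            ->  paddle::get<T>(v)
  // v.which()                   ->  v.index()
  // boost::apply_visitor(f, v)  ->  paddle::visit(f, v)
  try {
    return paddle::get<int>(v);
  } catch (paddle::bad_variant_access const&) {  // was boost::bad_get
    return static_cast<int>(v.index());
  }
}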
@@ -18,7 +18,6 @@
 #include <string>
-#include "boost/variant.hpp"
 #include "paddle/fluid/distributed/collective/Types.h"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/variable.h"
@@ -27,6 +26,7 @@
 #include "paddle/fluid/platform/device/npu/npu_info.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/utils/variant.h"
 namespace paddle {
 namespace distributed {
......
@@ -25,7 +25,6 @@
 #include <string>
-#include "boost/variant.hpp"
 #include "paddle/fluid/distributed/collective/Types.h"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/variable.h"
@@ -43,6 +42,7 @@
 #endif
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/utils/variant.h"
 namespace paddle {
 namespace distributed {
......
@@ -336,7 +336,7 @@ static std::string AttrTypeToString(const proto::AttrType& type) {
     }
     default: {
       PADDLE_THROW(platform::errors::Fatal(
-          "AttrType of type boost::variant only supports specific data types."
+          "AttrType of type paddle::variant only supports specific data types."
          "However, detected unrecognized AttrType: %d",
          type));
     }
@@ -344,37 +344,39 @@ static std::string AttrTypeToString(const proto::AttrType& type) {
   return ret;
 }
-template <typename T>
-static std::string GetAttrValue(const framework::Attribute& attr,
-                                bool is_vector) {
+template <typename T, bool IsVector>
+static typename std::enable_if<IsVector, std::string>::type GetAttrValue(
+    const framework::Attribute& attr) {
   std::string val = "";
-  if (is_vector) {
-    val += "{";
-    for (auto x : BOOST_GET_CONST(std::vector<T>, attr)) {
-      val += std::to_string(x) + ",";
-    }
-    if (val.size() > 1) val.pop_back();
-    val += "}";
-  } else {
-    val = std::to_string(BOOST_GET_CONST(T, attr));
+  val += "{";
+  for (auto x : BOOST_GET_CONST(std::vector<T>, attr)) {
+    val += std::to_string(x) + ",";
   }
+  if (val.size() > 1) val.pop_back();
+  val += "}";
   return val;
 }
+template <typename T, bool IsVector>
+static typename std::enable_if<!IsVector, std::string>::type GetAttrValue(
+    const framework::Attribute& attr) {
+  return std::to_string(BOOST_GET_CONST(T, attr));
+}
 static std::pair<std::string, std::string> GetAttrType(
     const framework::Attribute& attr, bool is_arg) {
   std::string ret = "";
   std::string val = "";
-  size_t variant_pos = attr.which();
+  size_t variant_pos = attr.index();
   switch (variant_pos) {
     case (1): {
       ret = "int";
-      val = GetAttrValue<int>(attr, false);
+      val = GetAttrValue<int, false>(attr);
       break;
     }
     case (2): {
       ret = "float";
-      val = GetAttrValue<float>(attr, false);
+      val = GetAttrValue<float, false>(attr);
       break;
     }
     case (3): {
@@ -386,13 +388,13 @@ static std::pair<std::string, std::string> GetAttrType(
     case (4): {
       ret = "std::vector<int>";
       if (is_arg) ret += "&";
-      val = GetAttrValue<int>(attr, true);
+      val = GetAttrValue<int, true>(attr);
       break;
     }
     case (5): {
       ret = "std::vector<float>";
       if (is_arg) ret += "&";
-      val = GetAttrValue<float>(attr, true);
+      val = GetAttrValue<float, true>(attr);
       break;
     }
     case (6): {
@@ -408,13 +410,13 @@ static std::pair<std::string, std::string> GetAttrType(
     }
     case (7): {
       ret = "bool";
-      val = GetAttrValue<bool>(attr, false);
+      val = GetAttrValue<bool, false>(attr);
       break;
     }
     case (8): {
       ret = "std::vector<bool>";
       if (is_arg) ret += "&";
-      val = GetAttrValue<bool>(attr, true);
+      val = GetAttrValue<bool, true>(attr);
       break;
     }
     case (9): {
@@ -423,7 +425,7 @@ static std::pair<std::string, std::string> GetAttrType(
     }
     case (10): {
       ret = "int64_t";
-      val = GetAttrValue<int64_t>(attr, false);
+      val = GetAttrValue<int64_t, false>(attr);
       break;
     }
     case (11): {
@@ -434,18 +436,18 @@ static std::pair<std::string, std::string> GetAttrType(
     case (12): {
       ret = "std::vector<int64_t>";
       if (is_arg) ret += "&";
-      val = GetAttrValue<int64_t>(attr, true);
+      val = GetAttrValue<int64_t, true>(attr);
       break;
     }
     case (13): {
       ret = "std::vector<double>";
       if (is_arg) ret += "&";
-      val = GetAttrValue<double>(attr, true);
+      val = GetAttrValue<double, true>(attr);
       break;
     }
     default: {
       PADDLE_THROW(platform::errors::Fatal(
-          "AttrType of type boost::variant only supports specific data types."
+          "AttrType of type paddle::variant only supports specific data types."
          "However, detected unrecognized AttrType: %d",
          variant_pos));
    }
......
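The GetAttrValue rewrite above replaces a runtime bool is_vector argument with a compile-time IsVector template parameter, split into two std::enable_if-gated overloads. A self-contained sketch of the same dispatch pattern (Stringify is a hypothetical helper, not part of the diff):

#include <string>
#include <type_traits>
#include <vector>

// Only one overload survives overload resolution for a given IsVector
// value (SFINAE), so the vector and scalar paths compile independently.
template <typename T, bool IsVector>
typename std::enable_if<IsVector, std::string>::type Stringify(
    const std::vector<T>& v) {
  std::string out = "{";
  for (auto x : v) out += std::to_string(x) + ",";
  if (out.size() > 1) out.pop_back();
  return out + "}";
}

template <typename T, bool IsVector>
typename std::enable_if<!IsVector, std::string>::type Stringify(const T& x) {
  return std::to_string(x);
}

// Usage: Stringify<int, true>(std::vector<int>{1, 2})  ->  "{1,2}"
//        Stringify<int, false>(3)                      ->  "3"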
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/framework/attribute.h"
+#include "boost/blank.hpp"
 namespace paddle {
 namespace framework {
......
@@ -23,12 +23,12 @@ limitations under the License. */
 #include <unordered_set>
 #include <vector>
-#include "boost/variant/get.hpp"
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/type_defs.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
 #include "paddle/utils/any.h"
+#include "paddle/utils/variant.h"
 namespace paddle {
 namespace framework {
@@ -45,8 +45,8 @@ struct ExtractAttribute {
   T* operator()(Attribute& attr) const {
     T* attr_value = nullptr;
     try {
-      attr_value = &boost::get<T>(attr);
-    } catch (boost::bad_get& bad_get) {
+      attr_value = &paddle::get<T>(attr);
+    } catch (paddle::bad_variant_access const& bad_get) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Cannot get attribute (%s) by type %s, its type is %s.",
           attr_name_,
@@ -80,8 +80,8 @@ struct ExtractAttribute<bool> {
     }
     bool* attr_value = nullptr;
     try {
-      attr_value = &boost::get<bool>(attr);
-    } catch (boost::bad_get& bad_get) {
+      attr_value = &paddle::get<bool>(attr);
+    } catch (paddle::bad_variant_access const& bad_get) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Cannot get attribute (%s) by type bool, its type is %s.",
           attr_name_,
@@ -108,8 +108,8 @@ struct ExtractAttribute<int64_t> {
     }
     int64_t* attr_value = nullptr;
     try {
-      attr_value = &boost::get<int64_t>(attr);
-    } catch (boost::bad_get& bad_get) {
+      attr_value = &paddle::get<int64_t>(attr);
+    } catch (paddle::bad_variant_access const& bad_get) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Cannot get attribute (%s) by type int64_t, its type is %s.",
           attr_name_,
@@ -138,8 +138,8 @@ struct ExtractAttribute<std::vector<int64_t>> {
     }
     std::vector<int64_t>* attr_value = nullptr;
     try {
-      attr_value = &boost::get<std::vector<int64_t>>(attr);
-    } catch (boost::bad_get& bad_get) {
+      attr_value = &paddle::get<std::vector<int64_t>>(attr);
+    } catch (paddle::bad_variant_access const& bad_get) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Cannot get attribute (%s) by type std::vector<int64_t>, its type is "
          "%s.",
@@ -167,8 +167,8 @@ struct ExtractAttribute<float> {
     }
     float* attr_value = nullptr;
     try {
-      attr_value = &boost::get<float>(attr);
-    } catch (boost::bad_get& bad_get) {
+      attr_value = &paddle::get<float>(attr);
+    } catch (paddle::bad_variant_access const& bad_get) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Cannot get attribute (%s) by type float, its type is %s.",
           attr_name_,
@@ -197,8 +197,8 @@ struct ExtractAttribute<std::vector<double>> {
     }
     std::vector<double>* attr_value = nullptr;
     try {
-      attr_value = &boost::get<std::vector<double>>(attr);
-    } catch (boost::bad_get& bad_get) {
+      attr_value = &paddle::get<std::vector<double>>(attr);
+    } catch (paddle::bad_variant_access const& bad_get) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Cannot get attribute (%s) by type std::vector<double>, its type is "
          "%s.",
@@ -214,11 +214,11 @@ struct ExtractAttribute<std::vector<double>> {
 template <typename T>
 inline proto::AttrType AttrTypeID() {
   Attribute tmp = T();
-  return static_cast<proto::AttrType>(tmp.which() - 1);
+  return static_cast<proto::AttrType>(tmp.index() - 1);
 }
 inline proto::AttrType AttrTypeID(const Attribute& attr) {
-  return static_cast<proto::AttrType>(attr.which() - 1);
+  return static_cast<proto::AttrType>(attr.index() - 1);
 }
 class AttrReader {
......
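Every ExtractAttribute specialization above follows one shape: attempt paddle::get<T> and convert the library's paddle::bad_variant_access into an error that carries the attribute name. A condensed sketch, assuming a hypothetical three-alternative variant and a plain std::runtime_error in place of PADDLE_THROW:

#include <stdexcept>
#include <string>
#include "paddle/utils/variant.h"

using Attr = paddle::variant<int, float, std::string>;  // hypothetical

template <typename T>
T& ExtractOrThrow(Attr& attr, const std::string& name) {
  try {
    return paddle::get<T>(attr);  // throws on an active-alternative mismatch
  } catch (paddle::bad_variant_access const&) {
    // Re-throw with context so the failure is debuggable, mirroring the
    // PADDLE_THROW calls in the diff.
    throw std::runtime_error("Cannot get attribute (" + name +
                             ") by the requested type.");
  }
}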
@@ -272,7 +272,7 @@ void BlockDesc::MoveFrom(BlockDesc *block) {
   for (const auto &pair : src_op->GetAttrMap()) {
     const auto &attr_name = pair.first;
     const auto &attr_value = pair.second;
-    auto attr_type = static_cast<proto::AttrType>(attr_value.which() - 1);
+    auto attr_type = static_cast<proto::AttrType>(attr_value.index() - 1);
     if (attr_type == proto::AttrType::BLOCK) {
       auto block_id = BOOST_GET_CONST(BlockDesc *, attr_value)->ID();
       dst_op->SetBlockAttr(attr_name, prog_->MutableBlock(block_id));
......
@@ -174,7 +174,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
   HandleException();
   FetchList ret;
-  auto &val = BOOST_GET(FetchList, fetch_data);
+  auto &val = boost::get<FetchList>(fetch_data);
   for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
     if (data_is_lod_tensor(val.at(fetch_idx))) {
       std::vector<const LoDTensor *> lodtensor_ptrs;
......
@@ -228,7 +228,7 @@ void FetchAsyncOpHandle::RunImpl() {
   }
   if (return_merged_) {
-    auto &val = BOOST_GET(FetchList, *data_);
+    auto &val = boost::get<FetchList>(*data_);
     if (src_vars[0]->IsType<LoDTensor>()) {
       // to lodtensor type
       std::vector<const LoDTensor *> src_lodtensors;
@@ -263,7 +263,7 @@ void FetchAsyncOpHandle::RunImpl() {
       val.at(offset_) = std::move(dst_lodtensor_array);
     }
   } else {
-    auto &val = BOOST_GET(FetchUnmergedList, *data_);
+    auto &val = boost::get<FetchUnmergedList>(*data_);
     auto &dst_tensors = val.at(offset_);
     dst_tensors.reserve(src_vars.size());
......
@@ -84,7 +84,7 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
     for (auto &t : tensors_) {
       tensors_ptr.emplace_back(&BOOST_GET_CONST(LoDTensor, t));
     }
-    auto &val = BOOST_GET(FetchList, *data_);
+    auto &val = boost::get<FetchList>(*data_);
     LoDTensor var;
     MergeLoDTensor(&var, tensors_ptr, platform::CPUPlace());
     val.at(offset_) = std::move(var);
@@ -106,11 +106,11 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const {
         tmp_array.emplace_back();
         MergeLoDTensor(&(tmp_array.back()), tensors_ptr, platform::CPUPlace());
       }
-      auto &val = BOOST_GET(FetchList, *data_);
+      auto &val = boost::get<FetchList>(*data_);
       val.at(offset_) = std::move(tmp_array);
     }
   } else {
-    auto &val = BOOST_GET(FetchUnmergedList, *data_);
+    auto &val = boost::get<FetchUnmergedList>(*data_);
     val.at(offset_) = std::move(tensors_);
   }
 }
......
@@ -278,8 +278,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
       if (!is_valid[scope_idx]) {
         continue;
       }
-      const auto &fetch_list =
-          BOOST_GET_CONST(FetchList, fetch_data[scope_idx]);
+      const auto &fetch_list = boost::get<FetchList>(fetch_data[scope_idx]);
       if (data_is_lod_tensor(fetch_list[fetch_idx])) {
         lodtensor_ptrs.push_back(
             &(BOOST_GET_CONST(LoDTensor, fetch_list[fetch_idx])));
@@ -318,7 +317,7 @@ FetchResultType ParallelSSAGraphExecutor::Run(
         continue;
       }
       const auto &fetch_list =
-          BOOST_GET_CONST(FetchUnmergedList, fetch_data[scope_idx]);
+          boost::get<FetchUnmergedList>(fetch_data[scope_idx]);
       PADDLE_ENFORCE_EQ(
           fetch_list[fetch_idx].size(),
           1,
......
@@ -23,10 +23,10 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-using FeedType = boost::variant<LoDTensor, Strings>;
+using FeedType = paddle::variant<LoDTensor, Strings>;
 using FeedList = std::vector<FeedType>;
-using FetchType = boost::variant<LoDTensor, LoDTensorArray>;
+using FetchType = paddle::variant<LoDTensor, LoDTensorArray, framework::Vocab>;
 using FetchList = std::vector<FetchType>;
 using FetchUnmergedList = std::vector<std::vector<FetchType>>;
......
@@ -121,7 +121,7 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
       elementwise_add_op_desc->GetNullableAttr("out_threshold");
   // set the out_threshold of the elementwise add op to be the out_threshold
   // of the conv2d_fusion
-  if (out_threshold_attr.which()) {
+  if (out_threshold_attr.index()) {
     new_op_desc.SetAttr("out_threshold", out_threshold_attr);
   }
   new_op_desc.Flush();
......
@@ -261,7 +261,7 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
     // out_thrshold of fc
     auto out_threshold_attr =
         elementwise_add_op_desc->GetNullableAttr("out_threshold");
-    if (out_threshold_attr.which()) {
+    if (out_threshold_attr.index()) {
       VLOG(4) << "setting out_threshold: "
               << BOOST_GET_CONST(float, out_threshold_attr);
       desc.SetAttr("out_threshold", out_threshold_attr);
......
@@ -78,7 +78,7 @@ static std::string RefineTemplateWithAttr(const std::string& op_type,
   }
   Attribute attr = it->second;
   proto::AttrType attr_type =
-      static_cast<proto::AttrType>(it->second.which() - 1);
+      static_cast<proto::AttrType>(it->second.index() - 1);
   if (attr_type == proto::AttrType::BOOLEAN) {
     bool result = BOOST_GET(bool, attr);
     if (result) {
......
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/ir/generate_pass.h"
+#include "boost/blank.hpp"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 namespace paddle {
@@ -105,7 +106,7 @@ Attribute GetOpAttrValue(const OpDesc* desc,
                          const proto::PassDesc::Attr& attr) {
   Attribute value = desc->GetAttr(attr.name());
   if (attr.has_element_index()) {
-    value = boost::apply_visitor(element_visitor(attr.element_index()), value);
+    value = paddle::visit(element_visitor(attr.element_index()), value);
   }
   return value;
 }
@@ -203,7 +204,7 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
       Attribute attr = GetVarAttrValue(x->Var(), condition.attr());
       if (condition.has_operation()) {
         Attribute operation = GetAttrValue(condition.operation().value());
-        attr = boost::apply_visitor(
+        attr = paddle::visit(
             operation_visitor(condition.operation().type()), attr, operation);
       }
       switch (condition.type()) {
@@ -388,7 +389,7 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
           if (attr_map.has_operation()) {
             Attribute operation =
                 GetAttrValue(attr_map.operation().value());
-            attr = boost::apply_visitor(
+            attr = paddle::visit(
                 operation_visitor(attr_map.operation().type()),
                 attr,
                 operation);
......
@@ -320,7 +320,7 @@ void CPUQuantizeSquashPass::RequantOpSquash(Graph* graph) const {
                           "should have requantize output as input.",
                           requant_out->Name()));
     float requant_scale_in =
-        boost::get<float>(requant_op->Op()->GetAttr("Scale_in"));
+        paddle::get<float>(requant_op->Op()->GetAttr("Scale_in"));
     auto scale_name = "Scale_in";
     if (any_op->Op()->Type() == "matmul")
......
@@ -118,9 +118,9 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
     GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, rtm_pattern);
     auto reshape_shape =
-        boost::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
+        paddle::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
     auto transpose_axis =
-        boost::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
+        paddle::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
     OpDesc *matmul_desc = matmul_op->Op();
     std::string input_var_name = transpose_out->Name();
......
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <string>
+#include "boost/blank.hpp"
 #include "glog/logging.h"
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_call_stack.h"
@@ -563,7 +564,7 @@ proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
       it,
       attrs_.end(),
       platform::errors::NotFound("Attribute %s is not found.", name));
-  return static_cast<proto::AttrType>(it->second.which() - 1);
+  return static_cast<proto::AttrType>(it->second.index() - 1);
 }
 std::vector<std::string> OpDesc::AttrNames() const {
@@ -584,7 +585,7 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
   // NOTICE(minqiyang): pybind11 will take the empty list in python as
   // the std::vector<int> type in C++; so we have to change the attr's type
   // here if we meet this issue
-  proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
+  proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
   if (attr_type == proto::AttrType::INTS &&
       BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
     // Find current attr via attr name and set the correct attribute value
@@ -837,9 +838,9 @@ void OpDesc::Flush() {
       auto *attr_desc = desc_.add_attrs();
       attr_desc->set_name(attr.first);
       attr_desc->set_type(
-          static_cast<proto::AttrType>(attr.second.which() - 1));
+          static_cast<proto::AttrType>(attr.second.index() - 1));
       SetAttrDescVisitor visitor(attr_desc);
-      boost::apply_visitor(visitor, attr.second);
+      paddle::visit(visitor, attr.second);
     }
     need_update_ = false;
......
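The recurring static_cast<proto::AttrType>(v.index() - 1) works because alternative 0 of the Attribute variant is the boost::blank placeholder, so the proto enum values line up with alternatives starting at index 1. A toy illustration (ToyAttrType is a hypothetical enum, not Paddle's proto):

#include "boost/blank.hpp"
#include "paddle/utils/variant.h"

enum class ToyAttrType { INT = 0, FLOAT = 1 };  // hypothetical

using ToyAttr = paddle::variant<boost::blank, int, float>;

ToyAttrType TypeOf(const ToyAttr& a) {
  // index() is 0 for blank, 1 for int, 2 for float;
  // subtracting 1 recovers the enum value.
  return static_cast<ToyAttrType>(a.index() - 1);
}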
@@ -33,18 +33,18 @@ class OpVersionMap;
 }  // namespace pb
 using OpAttrVariantT =
-    boost::variant<bool,                     /* AttrType::BOOL */
-                   float,                    /* AttrType::FLOAT */
-                   int32_t,                  /* AttrType::INT */
-                   int64_t,                  /* AttrType::LONG*/
-                   std::string,              /* AttrType::STRING */
-                   std::vector<bool>,        /* AttrType::BOOLS */
-                   std::vector<float>,       /* AttrType::FLOATS */
-                   std::vector<int32_t>,     /* AttrType::INTS */
-                   std::vector<int64_t>,     /* AttrType::LONGS */
-                   std::vector<std::string>, /* AttrType::STRINGS */
-                   paddle::none_t            /* None */
-                   >;
+    paddle::variant<bool,                     /* AttrType::BOOL */
+                    float,                    /* AttrType::FLOAT */
+                    int32_t,                  /* AttrType::INT */
+                    int64_t,                  /* AttrType::LONG*/
+                    std::string,              /* AttrType::STRING */
+                    std::vector<bool>,        /* AttrType::BOOLS */
+                    std::vector<float>,       /* AttrType::FLOATS */
+                    std::vector<int32_t>,     /* AttrType::INTS */
+                    std::vector<int64_t>,     /* AttrType::LONGS */
+                    std::vector<std::string>, /* AttrType::STRINGS */
+                    paddle::none_t            /* None */
+                    >;
 struct OpUpdateInfo {
   virtual ~OpUpdateInfo() = default;
......
@@ -190,7 +190,7 @@ void OpAttrsToCinn(framework::OpDesc *pb_desc, cpp::OpDesc *cpp_desc) {
     IMPL_ONE(LONG, int64_t);
     IMPL_ONE(LONGS, std::vector<int64_t>);
     case AttrType::BLOCK: {
-      auto i = pb_desc->GetAttrIfExists<int16_t>(name);
+      auto i = pb_desc->GetAttrIfExists<int32_t>(name);
       cpp_desc->SetAttr<int32_t>(name, i);
       break;
     }
......
@@ -56,7 +56,7 @@ namespace framework {
 class OperatorBase;
-using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
+using InferShapeVarPtr = paddle::variant<VarDesc *, Variable *>;
 class InferShapeContext {
  public:
......
@@ -27,13 +27,13 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-typedef boost::variant<int,
-                       int64_t,
-                       float,
-                       double,
-                       std::string,
-                       Tensor,
-                       LoDTensor /*, ChannelHolder*/>
+typedef paddle::variant<int,
+                        int64_t,
+                        float,
+                        double,
+                        std::string,
+                        Tensor,
+                        LoDTensor /*, ChannelHolder*/>
     ElementVar;
 class Tuple {
@@ -64,8 +64,8 @@ bool Tuple::isSameType(const Tuple& t) const {
     return false;
   }
   for (size_t j = 0; j < tuple_size; ++j) {
-    auto type1 = get(j).which();
-    auto type2 = t.get(j).which();
+    auto type1 = get(j).index();
+    auto type2 = t.get(j).index();
     if (type1 != type2) return false;
   }
   return true;
......
@@ -22,9 +22,11 @@ limitations under the License. */
 #include <unordered_set>
 #include <vector>
+#include "boost/blank.hpp"
 #include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/platform/variant.h"
 #include "paddle/utils/small_vector.h"
+#include "paddle/utils/variant.h"
 namespace paddle {
 namespace framework {
@@ -40,38 +42,38 @@ class InferNoNeedBufferVarsFN;
 using VariableNameMap = std::map<std::string, std::vector<std::string>>;
 using VariableValueMap = std::map<std::string, std::vector<Variable*>>;
-using Attribute = boost::variant<boost::blank,
-                                 int,
-                                 float,
-                                 std::string,
-                                 std::vector<int>,
-                                 std::vector<float>,
-                                 std::vector<std::string>,
-                                 bool,
-                                 std::vector<bool>,
-                                 BlockDesc*,
-                                 int64_t,
-                                 std::vector<BlockDesc*>,
-                                 std::vector<int64_t>,
-                                 std::vector<double>>;
+using Attribute = paddle::variant<boost::blank,
+                                  int,
+                                  float,
+                                  std::string,
+                                  std::vector<int>,
+                                  std::vector<float>,
+                                  std::vector<std::string>,
+                                  bool,
+                                  std::vector<bool>,
+                                  BlockDesc*,
+                                  int64_t,
+                                  std::vector<BlockDesc*>,
+                                  std::vector<int64_t>,
+                                  std::vector<double>>;
 using AttributeMap = std::unordered_map<std::string, Attribute>;
 #ifdef PADDLE_WITH_ASCEND_CL
-using NPUAttribute = boost::variant<boost::blank,
-                                    int,
-                                    float,
-                                    std::string,
-                                    std::vector<int>,
-                                    std::vector<float>,
-                                    std::vector<std::string>,
-                                    bool,
-                                    std::vector<bool>,
-                                    BlockDesc*,
-                                    int64_t,
-                                    std::vector<BlockDesc*>,
-                                    std::vector<int64_t>,
-                                    std::vector<double>,
-                                    std::vector<std::vector<int64_t>>>;
+using NPUAttribute = paddle::variant<boost::blank,
+                                     int,
+                                     float,
+                                     std::string,
+                                     std::vector<int>,
+                                     std::vector<float>,
+                                     std::vector<std::string>,
+                                     bool,
+                                     std::vector<bool>,
+                                     BlockDesc*,
+                                     int64_t,
+                                     std::vector<BlockDesc*>,
+                                     std::vector<int64_t>,
+                                     std::vector<double>,
+                                     std::vector<std::vector<int64_t>>>;
 using NPUAttributeMap = std::unordered_map<std::string, NPUAttribute>;
 #endif
......
@@ -315,7 +315,7 @@ void VarDesc::SetAttr(const std::string &name, const Attribute &v) {
   // NOTICE(sandyhouse): pybind11 will take the empty list in python as
   // the std::vector<int> type in C++; so we have to change the attr's type
   // here if we meet this issue
-  proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
+  proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
   if (attr_type == proto::AttrType::INTS &&
       BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
     // Find current attr via attr name and set the correct attribute value
......
@@ -122,7 +122,7 @@ class FeedOp : public framework::OperatorBase {
     auto &feed_item = feed_list.at(static_cast<size_t>(col));
     FeedVariableVisitor visitor(out_var, place);
-    boost::apply_visitor(visitor, feed_item);
+    paddle::visit(visitor, feed_item);
   }
 };
......
@@ -53,19 +53,19 @@ struct RawPointerVisitor : public boost::static_visitor<const void *> {
 };
 const framework::VariableNameMap &OpVariant::Inputs() const {
-  return *boost::apply_visitor(InputsVisitor(), op_);
+  return *paddle::visit(InputsVisitor(), op_);
 }
 const framework::VariableNameMap &OpVariant::Outputs() const {
-  return *boost::apply_visitor(OutputsVisitor(), op_);
+  return *paddle::visit(OutputsVisitor(), op_);
 }
 const framework::AttributeMap &OpVariant::Attrs() const {
-  return *boost::apply_visitor(AttributeMapVisitor(), op_);
+  return *paddle::visit(AttributeMapVisitor(), op_);
 }
 const void *OpVariant::RawPointer() const {
-  return boost::apply_visitor(RawPointerVisitor(), op_);
+  return paddle::visit(RawPointerVisitor(), op_);
 }
 void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
......
@@ -61,7 +61,7 @@ class OpVariant {
     return RawPointer() == other.RawPointer();
   }
-  int which() const { return static_cast<int>(op_.which()); }
+  int index() const { return static_cast<int>(op_.index()); }
   struct Hasher {
     size_t operator()(const OpVariant &op) const {
@@ -70,8 +70,8 @@ class OpVariant {
   };
  private:
-  const boost::variant<const framework::OperatorBase *,
-                       const framework::OpDesc *>
+  const paddle::variant<const framework::OperatorBase *,
+                        const framework::OpDesc *>
       op_;
 };
......
@@ -226,6 +226,7 @@ REGISTER_OPERATOR(
     ops::BoxDecoderAndAssignOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
 REGISTER_OP_CPU_KERNEL(
     box_decoder_and_assign,
     ops::BoxDecoderAndAssignKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -122,7 +122,7 @@ class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel<T> {
     int grid = (roi_num * class_num + block - 1) / block;
     auto& device_ctx = context.cuda_device_context();
-    const T box_clip = context.Attr<T>("box_clip");
+    const T box_clip = static_cast<T>(context.Attr<float>("box_clip"));
     DecodeBoxKernel<T>
         <<<grid, block, 0, device_ctx.stream()>>>(prior_box_data,
......
@@ -41,7 +41,7 @@ class BoxDecoderAndAssignKernel : public framework::OpKernel<T> {
     output_assign_box->mutable_data<T>({roi_num, 4}, context.GetPlace());
     T* output_box_data = output_box->data<T>();
     T* output_assign_box_data = output_assign_box->data<T>();
-    const T bbox_clip = context.Attr<T>("box_clip");
+    const T bbox_clip = static_cast<T>(context.Attr<float>("box_clip"));
     for (int i = 0; i < roi_num; ++i) {
       T prior_box_width = prior_box_data[i * 4 + 2] - prior_box_data[i * 4] + 1;
......
@@ -47,7 +47,7 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::Add(const framework::Tensor &vec,
                                   framework::Tensor *tmat) {
   MatrixBitCodeFunctorAdd<T> func(vec, tmat);
-  code_table_.apply_visitor(func);
+  paddle::visit(func, code_table_);
 }
 template <typename T>
@@ -79,7 +79,7 @@ template <typename T>
 void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
                                       framework::Tensor *vec) {
   MatrixBitCodeFunctorAddGrad<T> func(tmat, vec);
-  code_table_.apply_visitor(func);
+  paddle::visit(func, code_table_);
 }
 template <typename T>
@@ -121,7 +121,7 @@ void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor &tmat,
                                   framework::Tensor *sum,
                                   T scale_sum) {
   MatrixBitCodeFunctorSum<T> func(tmat, sum, scale_sum);
-  code_table_.apply_visitor(func);
+  paddle::visit(func, code_table_);
 }
 template <typename T>
@@ -165,7 +165,7 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor *tmat,
                                   const framework::Tensor &weight,
                                   const framework::Tensor &input) {
   MatrixBitCodeFunctorMul<T> func(tmat, weight, input);
-  code_table_.apply_visitor(func);
+  paddle::visit(func, code_table_);
 }
 template <typename T, size_t N>
@@ -222,7 +222,7 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor &tmat,
                                             framework::Tensor *weight,
                                             const framework::Tensor &input) {
   MatrixBitCodeFunctorMulGradWeight<T> func(tmat, weight, input);
-  code_table_.apply_visitor(func);
+  paddle::visit(func, code_table_);
 }
 template <typename T>
@@ -279,7 +279,7 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor &tmat,
                                             phi::SelectedRows *weight,
                                             const framework::Tensor &input) {
   MatrixBitCodeFunctorMulGradWeightSR<T> func(tmat, weight, input);
-  code_table_.apply_visitor(func);
+  paddle::visit(func, code_table_);
 }
 template <typename T>
@@ -323,7 +323,7 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor &tmat,
                                            const framework::Tensor &weight,
                                            framework::Tensor *input) {
   MatrixBitCodeFunctorMulGradError<T> func(tmat, weight, input);
-  code_table_.apply_visitor(func);
+  paddle::visit(func, code_table_);
 }
 template <typename T>
@@ -352,7 +352,7 @@ struct MatrixBitCodeFunctorSub : public boost::static_visitor<void> {
 template <typename T>
 void MatrixBitCodeFunctor<T>::Sub(framework::Tensor *tmat) {
   MatrixBitCodeFunctorSub<T> func(tmat);
-  code_table_.apply_visitor(func);
+  paddle::visit(func, code_table_);
 }
 template class MatrixBitCodeFunctor<float>;
......
@@ -208,7 +208,7 @@ class CustomCodeTable {
   const int64_t* ids_;
 };
-using CodeTable = boost::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;
+using CodeTable = paddle::variant<SimpleCodeTable, CustomCodeTable<int64_t>>;
 template <typename T>
 class MatrixBitCodeFunctor {
......
@@ -201,7 +201,8 @@ class SliceOpVarTypeInference : public framework::VarTypeInference {
     auto x_name = "Input";
     auto out_name = "Out";
     auto decrease_axis = ctx->GetAttr("decrease_axis");
-    auto not_decrease = boost::get<std::vector<int>>(decrease_axis).size() == 0;
+    auto not_decrease =
+        paddle::get<std::vector<int>>(decrease_axis).size() == 0;
     if (not_decrease) {
       // The default type of out is LoDTensor.
       // However, if no axis is decreased and the type of input is not
......
@@ -19,11 +19,11 @@
 #include <string>
 #include <vector>
-#include "boost/variant.hpp"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/platform/device/npu/dynload/hccl.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/utils/variant.h"
 #if defined(PADDLE_WITH_CNCL)
 #include "paddle/fluid/platform/device/mlu/device_context.h"
 #endif
......
@@ -19,6 +19,8 @@
 #include <popart/optimizer.hpp>
 #include <popart/sgd.hpp>
+#include "boost/blank.hpp"
+
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/device/ipu/ipu_names.h"
 #include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
@@ -390,7 +392,7 @@ void Compiler::LowerConstants(const Scope* scope) {
     auto* tensor = var->GetMutable<framework::LoDTensor>();
     ConstantOpAttrVisitor visitor(tensor, dtype);
     auto value = op_desc->GetAttr("value");
-    boost::apply_visitor(visitor, value);
+    paddle::visit(visitor, value);
     auto ddim = phi::make_ddim(shape);
     tensor->Resize(ddim);
@@ -475,7 +477,7 @@ void Compiler::LowerBody() {
       auto attributes = std::map<std::string, popart::any>{};
       for (auto& attr : op_desc->GetAttrMap()) {
         CustomOpAttrVisitor visitor(&attributes, attr.first);
-        boost::apply_visitor(visitor, attr.second);
+        paddle::visit(visitor, attr.second);
       }
       auto __op_type =
          BOOST_GET_CONST(std::string, op_desc->GetAttr("__op_type"));
......
@@ -122,20 +122,21 @@ using namespace ::phi::enforce;  // NOLINT
 #endif
 /*
- * Summary: This BOOST_GET(_**) series macros are used to call boost::get
- *   safely. boost::get is not a completely safe api, although it will not
+ * Summary: This BOOST_GET(_**) series macros are used to call paddle::get
+ *   safely. paddle::get is not a completely safe api, although it will not
  *   go wrong in most cases, but in extreme cases, it may fail and directly
- *   throw a boost::bad_get exception, without any stack information.
+ *   throw a paddle::bad_variant_access const exception, without any stack
+ *   information.
  *   This kind of problems is difficult to debug, so add these macros to
- *   enrich boost::get error information. At the same time, we restrict
- *   the direct use of boost::get by CI rule.
+ *   enrich paddle::get error information. At the same time, we restrict
+ *   the direct use of paddle::get by CI rule.
  *
 * Parameters:
 *     __TYPE: the target variable type
 *     __VALUE: the target variable to get
 *
 * Examples:
- *   - unsafe writing: int x = boost::get<int>(y);
+ *   - unsafe writing: int x = paddle::get<int>(y);
 *   - safe writing: int x = BOOST_GET(int, y);
 *
 * Note: GCC 4.8 cannot select right overloaded function here, so need
@@ -155,12 +156,12 @@ using namespace phi::enforce::details;  // NOLINT
                          __OutputTypePtr,                        \
                          __OutputType>::type {                   \
     try {                                                        \
-      return boost::get<OutputType>(input);                      \
-    } catch (boost::bad_get&) {                                  \
+      return paddle::get<OutputType>(input);                     \
+    } catch (paddle::bad_variant_access const&) {                \
       HANDLE_THE_ERROR                                           \
       throw ::phi::enforce::EnforceNotMet(                       \
           phi::errors::InvalidArgument(                          \
-              "boost::get failed, cannot get value "             \
+              "paddle::get failed, cannot get value "            \
              "(%s) by type %s, its type is %s.",                 \
              expression,                                         \
              phi::enforce::demangle(typeid(OutputType).name()),  \
......
@@ -22,13 +22,14 @@
 #include "gflags/gflags.h"
 #include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/variant.h"
+#include "paddle/utils/variant.h"
 namespace paddle {
 namespace platform {
 struct FlagInfo {
   using ValueType =
-      boost::variant<bool, int32_t, int64_t, uint64_t, double, std::string>;
+      paddle::variant<bool, int32_t, int64_t, uint64_t, double, std::string>;
   std::string name;
   mutable void *value_ptr;
   ValueType default_value;
......
@@ -259,7 +259,7 @@ static void RegisterGlobalVarGetterSetter() {
     const auto &default_value = pair.second.default_value;
     RegisterGetterSetterVisitor visitor(
         "FLAGS_" + name, is_writable, value_ptr);
-    boost::apply_visitor(visitor, default_value);
+    paddle::visit(visitor, default_value);
   }
 }
......
@@ -3338,7 +3338,7 @@ All parameter, weight, gradient are variables in Paddle.
            py::return_value_policy::take_ownership);
   py::class_<FetchList>(m, "FetchList", R"DOC( FetchList is a
-        vector of boost::variant<LoDTensor, LoDTensorArray>.
+        vector of paddle::variant<LoDTensor, LoDTensorArray>.
         )DOC")
       .def(
           "_move_to_list",
@@ -3385,7 +3385,7 @@ All parameter, weight, gradient are variables in Paddle.
           py::arg("var"));
   py::class_<FetchUnmergedList>(m, "FetchUnmergedList", R"DOC(
-        FetchUnmergedList is 2-D array of FetchType(boost::variant(LoDTensor, LoDTensorArray)).
+        FetchUnmergedList is 2-D array of FetchType(paddle::variant(LoDTensor, LoDTensorArray)).
         )DOC")
       .def(
           "_move_to_list",
@@ -4606,12 +4606,15 @@ All parameter, weight, gradient are variables in Paddle.
               pybind11::gil_scoped_release release;
               ret = self.Run(fetch_tensors, return_merged);
             }
+            // TODO(Ruibiao): Refactor the run interface of PE to avoid use
+            // boost::get here
             if (return_merged) {
               return py::cast(
-                  std::move(BOOST_GET(paddle::framework::FetchList, ret)));
+                  std::move(boost::get<paddle::framework::FetchList>(ret)));
             } else {
               return py::cast(std::move(
-                  BOOST_GET(paddle::framework::FetchUnmergedList, ret)));
+                  boost::get<paddle::framework::FetchUnmergedList>(ret)));
             }
           })
       .def("device_count", &ParallelExecutor::DeviceCount);
......
@@ -19,11 +19,13 @@ limitations under the License. */
 #include "glog/logging.h"
 #include "paddle/fluid/platform/variant.h"
+#include "paddle/utils/variant.h"
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/stl.h"
-// Cast boost::variant for PyBind.
+// Cast paddle::variant for PyBind.
 // Copy from
 // https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
 namespace pybind11 {
 namespace detail {
@@ -37,8 +39,7 @@ namespace detail {
 #endif
 // Can be replaced by a generic lambda in C++14
-struct PYBIND11_HIDDEN paddle_variant_caster_visitor
-    : public boost::static_visitor<handle> {
+struct PYBIND11_HIDDEN paddle_variant_caster_visitor {
   return_value_policy policy;
   handle parent;
@@ -127,8 +128,13 @@ struct paddle_variant_caster<V<Ts...>> {
   static handle cast(Type const& src,
                      return_value_policy policy,
                      handle parent) {
+    /*
+    auto paddle_variant_caster_visitor = [&](Type const& src)->handle {
+      return make_caster<Type>::cast(src, policy, parent);
+    }
+    */
     paddle_variant_caster_visitor visitor(policy, parent);
-    return boost::apply_visitor(visitor, src);
+    return paddle::visit(visitor, src);
   }
   PYBIND11_TYPE_CASTER(Type, _("Variant"));
@@ -137,8 +143,8 @@ struct paddle_variant_caster<V<Ts...>> {
 // Add specialization for concrete variant type
 template <class... Args>
-struct type_caster<boost::variant<Args...>>
-    : paddle_variant_caster<boost::variant<Args...>> {};
+struct type_caster<paddle::variant<Args...>>
+    : paddle_variant_caster<paddle::variant<Args...>> {};
 }  // namespace detail
 }  // namespace pybind11
@@ -36,7 +36,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
   bool HasAttr(const std::string& name) const override;
   // now we can't use Attribute here, it will cause phi relay on
-  // boost::variant and BlockDesc
+  // paddle::variant and BlockDesc
   paddle::any Attr(const std::string& name) const override;
   size_t InputSize(const std::string& name) const override;
......
@@ -99,7 +99,7 @@ class ArgumentMappingContext {
   virtual bool HasAttr(const std::string& name) const = 0;
   // now we can't use Attribute here, it will cause phi relay on
-  // boost::variant and BlockDesc
+  // paddle::variant and BlockDesc
   virtual paddle::any Attr(const std::string& name) const = 0;
   virtual size_t InputSize(const std::string& name) const = 0;
......
@@ -113,7 +113,7 @@ void* DenseTensor::mutable_data(const Place& place,
     size = requested_size;
   }
-  /* some versions of boost::variant don't have operator!= */
+  /* some versions of paddle::variant don't have operator!= */
   if (holder_ == nullptr || !(holder_->place() == place) ||
       holder_->size() < size + meta_.offset) {
     holder_.reset();
@@ -142,7 +142,7 @@ void* DenseTensor::mutable_data(const Place& place,
                                    "] now"));
   size_t size = numel() * SizeOf(dtype());
-  /* some versions of boost::variant don't have operator!= */
+  /* some versions of paddle::variant don't have operator!= */
   if (holder_ == nullptr || !(holder_->place() == place) ||
       holder_->size() < size + meta_.offset ||
       !(place.GetType() == phi::AllocationType::GPU &&
......
@@ -20,31 +20,29 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>
-// <boost/variant.hpp> is not suitable to be placed in the header file,
-// it will introduce a large number of unnecessary includes, and these type
-// declarations that depend on boost are also not suitable for the phi header
-// file. Do some repeated forward declarations here to avoid
-// <boost/variant.hpp> spreading to a large number of phi kernel files
+#include "boost/blank.hpp"
+#include "paddle/utils/variant.h"
 namespace egr {
 class EagerVariable;
 }
 namespace paddle {
 namespace framework {
 class BlockDesc;
-using Attribute = boost::variant<boost::blank,
-                                 int,
-                                 float,
-                                 std::string,
-                                 std::vector<int>,
-                                 std::vector<float>,
-                                 std::vector<std::string>,
-                                 bool,
-                                 std::vector<bool>,
-                                 BlockDesc*,
-                                 int64_t,
-                                 std::vector<BlockDesc*>,
-                                 std::vector<int64_t>,
-                                 std::vector<double>>;
+using Attribute = paddle::variant<boost::blank,
+                                  int,
+                                  float,
+                                  std::string,
+                                  std::vector<int>,
+                                  std::vector<float>,
+                                  std::vector<std::string>,
+                                  bool,
+                                  std::vector<bool>,
+                                  BlockDesc*,
+                                  int64_t,
+                                  std::vector<BlockDesc*>,
+                                  std::vector<int64_t>,
+                                  std::vector<double>>;
 using AttributeMap = std::unordered_map<std::string, Attribute>;
 }  // namespace framework
 namespace imperative {
......
@@ -179,7 +179,7 @@ dtype::pstring* StringTensor::mutable_data(const phi::Place& place,
     size = requested_size;
   }
-  /* some versions of boost::variant don't have operator!= */
+  /* some versions of paddle::variant don't have operator!= */
   if (holder_ == nullptr || !(holder_->place() == place) ||
       holder_->size() < size + meta_.offset) {
     holder_.reset();
......
@@ -2,6 +2,11 @@
 // https://github.com/mpark/variant/blob/single-header/v1.4.0/variant.hpp
 // Modify the following points:
 // 1. modify namespace mpark to namespace paddle
+// 2. add type() member function for variant class
+// 3. remove the visitation implementation under the branhch with
+// MPARK_CPP14_CONSTEXPR defined since lib::cpp14::array could not be converted
+// to std::initializer_list in Paddle's compilation
+// 4. decorate PYBIND11_HIDDEN for struct value_visitor
 // MPark.Variant
 //
@@ -22,6 +27,14 @@
 #pragma GCC diagnostic ignored "-Wdeprecated-copy"
 #endif
+#if !defined(PYBIND11_HIDDEN)
+#ifdef _WIN32
+#define PYBIND11_HIDDEN __declspec(dllexport)
+#else
+#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
+#endif
+#endif
 /*
 variant synopsis
@@ -1649,7 +1662,7 @@ struct variant {
 };
 template <typename Visitor>
-struct value_visitor {
+struct PYBIND11_HIDDEN value_visitor {
   Visitor &&visitor_;
   template <typename... Alts>
@@ -2454,7 +2467,7 @@ class variant {
     impl_.swap(that.impl_);
   }
-  inline const std::type_info &type() noexcept { return impl_.type(); }
+  inline const std::type_info &type() const noexcept { return impl_.type(); }
  private:
   detail::impl<Ts...> impl_;
@@ -2708,30 +2721,6 @@ inline constexpr bool operator!=(monostate, monostate) noexcept {
   return false;
 }
-#ifdef MPARK_CPP14_CONSTEXPR
-namespace detail {
-inline constexpr bool all(std::initializer_list<bool> bs) {
-  for (bool b : bs) {
-    if (!b) {
-      return false;
-    }
-  }
-  return true;
-}
-}  // namespace detail
-template <typename Visitor, typename... Vs>
-inline constexpr decltype(auto) visit(Visitor &&visitor, Vs &&...vs) {
-  return (detail::all(
-              lib::array<bool, sizeof...(Vs)>{!vs.valueless_by_exception()...})
-              ? (void)0
-              : throw_bad_variant_access()),
-         detail::visitation::variant::visit_value(
-             lib::forward<Visitor>(visitor), lib::forward<Vs>(vs)...);
-}
-#else
 namespace detail {
 template <std::size_t N>
@@ -2755,12 +2744,11 @@ inline constexpr DECLTYPE_AUTO visit(Visitor &&visitor, Vs &&...vs)
               : throw_bad_variant_access()),
          detail::visitation::variant::visit_value(lib::forward<Visitor>(visitor),
                                                   lib::forward<Vs>(vs)...))
-#endif
 template <typename... Ts>
 inline auto swap(variant<Ts...> &lhs,
                  variant<Ts...> &rhs) noexcept(noexcept(lhs.swap(rhs)))
     -> decltype(lhs.swap(rhs)) {
   lhs.swap(rhs);
 }
......
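Modification 2 in the header's change list adds a type() accessor that std::variant does not provide; a minimal usage sketch, assuming only what this header defines:

#include <iostream>
#include "paddle/utils/variant.h"

int main() {
  paddle::variant<int, double> v = 3.14;
  // type() returns the std::type_info of the active alternative;
  // index() reports which alternative is active (1 here).
  std::cout << v.type().name() << " at index " << v.index() << "\n";
  return 0;
}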