Unverified commit 35961404 authored by Twice, committed by GitHub

optional: refactor value_or to allow auto deref & prevent ref dangling (#6520)

* optional: refactor value_or to allow auto deref & prevent ref dangling

* use value_or

* use value_or

* fix compiler error: invalid abstract return type
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent d61b8384
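The title's two goals follow from the old signature removed further down: a single `value_or` that took and returned `const_return_type`, effectively a const reference to the stored value. Such a reference can dangle when the default is a temporary, and the old signature forced callers to supply the optional's storage type rather than a plain value. Below is a minimal, self-contained sketch (plain C++ with illustrative names, not OneFlow code) of the dangling hazard and of the overload shape used to fix it: return by value whenever the optional or the default is an rvalue.

```cpp
// Minimal sketch, illustrative names only (BadOpt/SafeOpt are not OneFlow classes).
#include <string>
#include <utility>

struct BadOpt {
  bool has = false;
  std::string val;
  // Always returns a reference; if `other` is bound to a temporary, the result dangles.
  const std::string& value_or(const std::string& other) const { return has ? val : other; }
};

struct SafeOpt {
  bool has = false;
  std::string val;
  // lvalue optional + lvalue default: handing out a reference is still fine.
  const std::string& value_or(const std::string& other) const& { return has ? val : other; }
  // rvalue default: copy/move it out instead of referencing a dying temporary.
  std::string value_or(std::string&& other) const& { return has ? val : std::move(other); }
  // rvalue optional: the stored value may die too, so return by value either way.
  std::string value_or(const std::string& other) && { return has ? std::move(val) : other; }
  std::string value_or(std::string&& other) && { return has ? std::move(val) : std::move(other); }
};

int main() {
  BadOpt a;
  const std::string& bad = a.value_or("tmp");  // binds to a destroyed temporary: UB if read
  (void)bad;

  SafeOpt b;
  std::string ok = b.value_or("tmp");  // rvalue-default overload: returned by value, no dangling
  (void)ok;
  return 0;
}
```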
......@@ -54,8 +54,7 @@ class ConsistentToConsistentGradFunction : public OpExprGradFunction<ConsistentT
const auto& out_grad = out_grads.at(0);
CHECK_OR_RETURN(out_grad->is_consistent());
in_grads->resize(1);
const auto& grad_nd_sbp =
grad_nd_sbp_.has_value() ? JUST(grad_nd_sbp_) : JUST(out_grad->nd_sbp());
const auto& grad_nd_sbp = grad_nd_sbp_.value_or(JUST(out_grad->nd_sbp()));
const auto& grad_sbp_list = JUST(GetSbpList(grad_nd_sbp));
const auto& grad_grad_sbp_list = JUST(GetSbpList(ctx->nd_sbp));
in_grads->at(0) = JUST(one::functional::ToConsistent(out_grad, ctx->parallel_desc,
......
......@@ -86,6 +86,14 @@ class OptionalBase<T, typename std::enable_if<IsScalarType<T>::value>::type> {
bool has_value() const { return init_; }
T value_or(const T& other) const {
if (has_value()) {
return value();
} else {
return other;
}
}
void reset() { init_ = false; }
private:
......@@ -126,6 +134,14 @@ class OptionalBase<T, typename std::enable_if<std::is_reference<T>::value>::type
bool has_value() const { return value_; }
const value_type& value_or(const value_type& other) const {
if (has_value()) {
return value();
} else {
return other;
}
}
void reset() { value_ = nullptr; }
private:
......@@ -199,6 +215,87 @@ class OptionalBase<
bool has_value() const { return bool(value_); }
const storage_type& value_or(const storage_type& other) const& {
if (has_value()) {
return value_;
} else {
return other;
}
}
storage_type value_or(const storage_type& other) && {
if (has_value()) {
return std::move(value_);
} else {
return other;
}
}
storage_type value_or(storage_type&& other) const& {
if (has_value()) {
return value_;
} else {
return std::move(other);
}
}
storage_type value_or(storage_type&& other) && {
if (has_value()) {
return std::move(value_);
} else {
return std::move(other);
}
}
// we introduce a dependent name `U` to delay the instantiation,
// so only the default parameter of `U` is allowed
template<typename U = value_type>
typename std::enable_if<!std::is_abstract<U>::value, const U&>::type value_or(
const value_type& other) const& {
static_assert(std::is_same<U, value_type>::value, "expected default U");
if (has_value()) {
return *value_;
} else {
return other;
}
}
template<typename U = value_type>
typename std::enable_if<!std::is_abstract<U>::value, U>::type value_or(
const value_type& other) && {
static_assert(std::is_same<U, value_type>::value, "expected default U");
if (has_value()) {
return std::move(*value_);
} else {
return other;
}
}
template<typename U = value_type>
typename std::enable_if<!std::is_abstract<U>::value, U>::type value_or(
value_type&& other) const& {
static_assert(std::is_same<U, value_type>::value, "expected default U");
if (has_value()) {
return *value_;
} else {
return std::move(other);
}
}
template<typename U = value_type>
typename std::enable_if<!std::is_abstract<U>::value, U>::type value_or(value_type&& other) && {
static_assert(std::is_same<U, value_type>::value, "expected default U");
if (has_value()) {
return std::move(*value_);
} else {
return std::move(other);
}
}
void reset() { value_.reset(); }
private:
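The `U = value_type` template overloads above provide the "auto deref" from the title: a caller can pass and receive a plain `T` instead of the `shared_ptr` the optional stores, and the dependent name `U` (as the in-code comment notes) delays instantiation so the `std::is_abstract` check is applied during overload resolution, since an abstract type cannot be returned by value. A hypothetical, non-compiled usage sketch follows; it assumes `Optional` is reachable as `oneflow::Optional` through the header used by the unit test later in this diff, and it only illustrates which overload gets selected.

```cpp
// Hypothetical usage sketch; identifiers mirror the unit test further down in this diff.
#include "oneflow/core/common/optional.h"
#include <memory>
#include <vector>

void ValueOrOverloads() {
  using oneflow::Optional;
  Optional<std::vector<int>> empty;
  auto x = std::make_shared<std::vector<int>>(1);

  auto sp = empty.value_or(x);                          // storage_type overload: shared_ptr in, shared_ptr out
  auto v1 = empty.value_or(*x);                         // "auto deref": plain value in, const reference out (copied into v1)
  auto v2 = empty.value_or(std::vector<int>{1, 2, 3});  // rvalue default: returned by value, cannot dangle
  auto v3 = Optional<std::vector<int>>().value_or(*x);  // rvalue optional: also returned by value
}
```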
......@@ -216,10 +313,6 @@ class Optional final : private internal::OptionalBase<T> {
using value_type = typename base::value_type;
using storage_type = typename base::storage_type;
using const_return_type = decltype(std::declval<const base&>().value());
using return_type = decltype(std::declval<base&>().value());
using move_return_type = decltype(std::declval<base&&>().value());
Optional() = default;
~Optional() = default;
......@@ -247,18 +340,21 @@ class Optional final : private internal::OptionalBase<T> {
Optional& operator=(const Optional& rhs) = default;
Optional& operator=(Optional&& rhs) noexcept = default;
const_return_type value_or(const_return_type default_) const {
if (has_value()) {
return base::value();
} else {
return default_;
}
template<typename U>
auto value_or(U&& other) const& -> decltype(base::value_or(std::forward<U>(other))) {
return base::value_or(std::forward<U>(other));
}
template<typename U>
auto value_or(U&& other) && -> decltype(std::move(*this).base::value_or(std::forward<U>(other))) {
return std::move(*this).base::value_or(std::forward<U>(other));
}
bool has_value() const { return base::has_value(); }
explicit operator bool() const { return has_value(); }
move_return_type Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() && {
auto Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() && -> decltype(
std::move(*this).base::value()) {
return std::move(*this).base::value();
}
......
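The forwarding overloads above make a qualified call through the privately inherited base: `std::move(*this)` selects the `&&`-qualified base overload, and the trailing `decltype` lets the return type track whichever base overload wins. A generic sketch of the same pattern (plain C++, not OneFlow code):

```cpp
// Generic sketch: a privately inheriting wrapper reaches the base overloads through an
// explicit base-class qualification, and std::move(*this) picks the &&-qualified one.
#include <utility>

struct Base {
  const int& get() const& { return v_; }  // lvalue object: safe to hand out a reference
  int get() && { return v_; }             // rvalue object: return by value instead
  int v_ = 42;
};

struct Wrapper : private Base {
  decltype(auto) get() const& { return Base::get(); }               // const int&
  decltype(auto) get() && { return std::move(*this).Base::get(); }  // int, by value
};

int main() {
  Wrapper w;
  const int& ref = w.get();   // reference into a still-living object
  int val = Wrapper{}.get();  // by-value result of the && overload; nothing dangles
  (void)ref;
  (void)val;
  return 0;
}
```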
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include <gtest/gtest.h>
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/exception.h"
......@@ -92,6 +93,14 @@ TEST(Optional, non_scalar) {
auto x = std::make_shared<std::vector<int>>(1);
ASSERT_EQ(b.value_or(x), x);
ASSERT_EQ(b.value_or(std::vector<int>{1, 2, 3}), (std::vector<int>{1, 2, 3}));
ASSERT_EQ(b.value_or(*x), *x);
ASSERT_EQ(a.value_or(*x), *CHECK_JUST(a));
ASSERT_EQ(Optional<std::vector<int>>().value_or(*x), *x);
ASSERT_EQ(Optional<std::vector<int>>().value_or(std::vector<int>{1, 2, 3}),
(std::vector<int>{1, 2, 3}));
Optional<const std::vector<int>> c(std::vector<int>{1, 2, 3});
ASSERT_EQ(CHECK_JUST(c)->at(1), 2);
......
......@@ -372,10 +372,10 @@ class EyeFunctor {
const Optional<Symbol<Device>>& device) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("n", JUST(n.As<int64_t>())));
JUST(attrs.SetAttr<int64_t>("m", m ? JUST(JUST(m)->As<int64_t>()) : JUST(n.As<int64_t>())));
JUST(attrs.SetAttr<int64_t>("m", JUST(m.value_or(n).As<int64_t>())));
JUST(attrs.SetAttr<DataType>("dtype", dtype ? JUST(dtype)->data_type() : DataType::kFloat));
OpExprInterpContext ctx(attrs);
if (device) { ctx.device = JUST(device); }
ctx.device = device;
return OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx);
}
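Two call-site effects of the refactor show up in this functor: `m.value_or(n)` uses the new value-type overload, so the default can be passed as a plain value (`n`) instead of the optional's storage type, and the `if (device)` guard collapses to `ctx.device = device;`, which presumably works because `OpExprInterpContext::device` is itself an optional field. A generic sketch of the latter (standard-library stand-ins, not OneFlow types):

```cpp
// Generic sketch: when the destination field is itself an optional, copying the whole
// optional replaces the "check then unwrap" pattern.
#include <optional>
#include <string>

struct Ctx {
  std::optional<std::string> device;  // stand-in for an optional device field
};

void SetDevice(Ctx& ctx, const std::optional<std::string>& device) {
  // before: if (device) { ctx.device = *device; }
  ctx.device = device;  // an empty optional stays empty, an engaged one copies its value
}
```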
......@@ -392,7 +392,7 @@ class ConsistentEyeFunctor {
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("n", JUST(n.As<int64_t>())));
JUST(attrs.SetAttr<int64_t>("m", m ? JUST(JUST(m)->As<int64_t>()) : JUST(n.As<int64_t>())));
JUST(attrs.SetAttr<int64_t>("m", JUST(m.value_or(n).As<int64_t>())));
JUST(attrs.SetAttr<DataType>("dtype", dtype ? JUST(dtype)->data_type() : DataType::kFloat));
if (LazyMode::is_enabled()) {
std::vector<std::string> nd_sbp(sbp_tuple.size());
......@@ -461,7 +461,7 @@ class ArangeFunctor {
JUST(attrs.SetAttr<double>("float_delta", JUST(delta.As<double>())));
}
OpExprInterpContext ctx(attrs);
if (device) { ctx.device = JUST(device); }
ctx.device = device;
return OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx);
}
......
......@@ -103,7 +103,7 @@ class RandFunctor {
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
OpExprInterpContext ctx(attrs, distribution_state);
if (device) { ctx.device = JUST(device); }
ctx.device = device;
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));
JUST(result->set_requires_grad(requires_grad));
return result;
......@@ -192,7 +192,7 @@ class RandNFunctor {
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
OpExprInterpContext ctx(attrs, distribution_state);
if (device) { ctx.device = JUST(device); }
ctx.device = device;
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));
JUST(result->set_requires_grad(requires_grad));
return result;
......@@ -276,7 +276,7 @@ class RandIntFunctor {
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
OpExprInterpContext ctx(attrs, distribution_state);
if (device) { ctx.device = JUST(device); }
ctx.device = device;
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));
JUST(result->set_requires_grad(requires_grad));
......@@ -384,7 +384,7 @@ class RandPermFunctor {
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
OpExprInterpContext ctx(attrs, distribution_state);
if (device) { ctx.device = JUST(device); }
ctx.device = device;
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*randperm_op_, {}, ctx));
JUST(result->set_requires_grad(requires_grad));
......