...
 
Commits (3)
    https://gitcode.net/Oneflow-Inc/oneflow/-/commit/21caffd9d94e70538a9035abdcf215405d045167 just macro: rename local variables to prevent shadowing (#6667) 2021-11-01T21:37:51+08:00 Twice i@twice.moe Co-authored-by: <span data-trailer="Co-authored-by:"><a href="mailto:69100618+oneflow-ci-bot@users.noreply.github.com" title="69100618+oneflow-ci-bot@users.noreply.github.com"></a><a href="javascript:void(0)" class="avatar s16 avatar-inline identicon bg6" style="text-decoration: none">N</a><a href="mailto:69100618+oneflow-ci-bot@users.noreply.github.com" title="69100618+oneflow-ci-bot@users.noreply.github.com">oneflow-ci-bot</a> &lt;<a href="mailto:69100618+oneflow-ci-bot@users.noreply.github.com" title="69100618+oneflow-ci-bot@users.noreply.github.com">69100618+oneflow-ci-bot@users.noreply.github.com</a>&gt;</span> https://gitcode.net/Oneflow-Inc/oneflow/-/commit/8b94ac9b8fd0578aeed91a85c955a7f2a400b6aa restruct reshape gradient funcs (#6634) 2021-11-02T04:15:23+00:00 Luyang flowingsun007@163.com * restruct * refine https://gitcode.net/Oneflow-Inc/oneflow/-/commit/8edb7e5f97775a0bf793660df22d9ad715c63034 Merge branch 'master' into fea/graph_op_debug 2021-11-02T17:00:26+08:00 Xiaoyu Xu xiaoyulink@gmail.com
...@@ -24,7 +24,11 @@ limitations under the License. ...@@ -24,7 +24,11 @@ limitations under the License.
namespace oneflow { namespace oneflow {
namespace one { namespace one {
class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> { struct ReshapeCaptureState : public AutoGradCaptureState {
DimVector input_shape_vec;
};
class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
public: public:
Maybe<void> Init(const OpExpr& op) override { Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op); const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
...@@ -32,17 +36,18 @@ class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> { ...@@ -32,17 +36,18 @@ class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
Maybe<void> Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs, Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override { const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->SaveTensorForBackward(inputs.at(0)); ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads, Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override { TensorTuple* in_grads) const override {
const auto& saved_tensors = ctx->SavedTensors(); const auto& saved_tensors = ctx->SavedTensors();
in_grads->resize(1); in_grads->resize(1);
in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), saved_tensors.at(0))); Shape shape(ctx->input_shape_vec);
in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));
return Maybe<void>::Ok(); return Maybe<void>::Ok();
} }
}; };
......
...@@ -90,62 +90,62 @@ typename std::remove_const<typename std::remove_reference<T>::type>::type&& Remo ...@@ -90,62 +90,62 @@ typename std::remove_const<typename std::remove_reference<T>::type>::type&& Remo
#if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__) #if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)
#define JUST(...) \ #define JUST(...) \
::oneflow::private_details::RemoveRValConst(({ \ ::oneflow::private_details::RemoveRValConst(({ \
auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddStackFrame( \ return ::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \ ::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
__FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \ __FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \
} \ } \
std::forward<decltype(value_to_check_)>(value_to_check_); \ std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST(...) \ #define CHECK_JUST(...) \
([&](const char* func_name) { \ ([&](const char* _just_closure_func_name_) { \
auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \ LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddStackFrame( \ ::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \ ::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
func_name, OF_PP_STRINGIZE(__VA_ARGS__))); \ _just_closure_func_name_, OF_PP_STRINGIZE(__VA_ARGS__))); \
} \ } \
return std::forward<decltype(value_to_check_)>(value_to_check_); \ return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \ })(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_MSG(value, ...) \ #define JUST_MSG(value, ...) \
::oneflow::private_details::RemoveRValConst(({ \ ::oneflow::private_details::RemoveRValConst(({ \
auto&& value_to_check_ = (value); \ auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddMessage( \ return ::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \ ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \ .AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \ OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \
} \ } \
std::forward<decltype(value_to_check_)>(value_to_check_); \ std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST_MSG(value, ...) \ #define CHECK_JUST_MSG(value, ...) \
([&](const char* func_name) { \ ([&](const char* _just_closure_func_name_) { \
auto&& value_to_check_ = (value); \ auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \ LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddMessage( \ ::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \ ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, func_name), \ .AddStackFrame(__FILE__, __LINE__, _just_closure_func_name_), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \ OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \
.error_proto()); \ .error_proto()); \
} \ } \
return std::forward<decltype(value_to_check_)>(value_to_check_); \ return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \ })(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_OPT(...) \ #define JUST_OPT(...) \
::oneflow::private_details::RemoveRValConst(({ \ ::oneflow::private_details::RemoveRValConst(({ \
auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!value_to_check_.has_value()) { return NullOpt; } \ if (!_just_value_to_check_.has_value()) { return NullOpt; } \
std::forward<decltype(value_to_check_)>(value_to_check_); \ std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#else #else
......