提交 ebfc7a5a 编写于 作者: H hjchen2

Merge branch 'dev_fix_broadcast_binary_ops_grad_func' of...

Merge branch 'dev_fix_broadcast_binary_ops_grad_func' of https://github.com/Oneflow-Inc/oneflow into dev_fix_broadcast_binary_ops_grad_func
......@@ -18,13 +18,13 @@ function write_to_file_and_print {
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.01 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.05 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.05 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.09 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.95 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.01 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.06 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.94 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 0.99 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 0.99 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.93 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.91 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.83 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.82 | write_to_file_and_print
......
......@@ -24,7 +24,11 @@ limitations under the License.
namespace oneflow {
namespace one {
class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
struct ReshapeCaptureState : public AutoGradCaptureState {
DimVector input_shape_vec;
};
class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
......@@ -32,17 +36,18 @@ class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
return Maybe<void>::Ok();
}
Maybe<void> Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->SaveTensorForBackward(inputs.at(0));
ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& saved_tensors = ctx->SavedTensors();
in_grads->resize(1);
in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), saved_tensors.at(0)));
Shape shape(ctx->input_shape_vec);
in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));
return Maybe<void>::Ok();
}
};
......
......@@ -90,62 +90,62 @@ typename std::remove_const<typename std::remove_reference<T>::type>::type&& Remo
#if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)
#define JUST(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \
__FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \
} \
std::forward<decltype(value_to_check_)>(value_to_check_); \
#define JUST(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
__FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \
} \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST(...) \
([&](const char* func_name) { \
auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \
func_name, OF_PP_STRINGIZE(__VA_ARGS__))); \
} \
return std::forward<decltype(value_to_check_)>(value_to_check_); \
})(__FUNCTION__) \
#define CHECK_JUST(...) \
([&](const char* _just_closure_func_name_) { \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
_just_closure_func_name_, OF_PP_STRINGIZE(__VA_ARGS__))); \
} \
return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_MSG(value, ...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \
} \
std::forward<decltype(value_to_check_)>(value_to_check_); \
#define JUST_MSG(value, ...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \
} \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST_MSG(value, ...) \
([&](const char* func_name) { \
auto&& value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, func_name), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \
.error_proto()); \
} \
return std::forward<decltype(value_to_check_)>(value_to_check_); \
})(__FUNCTION__) \
#define CHECK_JUST_MSG(value, ...) \
([&](const char* _just_closure_func_name_) { \
auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, _just_closure_func_name_), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \
.error_proto()); \
} \
return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_OPT(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!value_to_check_.has_value()) { return NullOpt; } \
std::forward<decltype(value_to_check_)>(value_to_check_); \
#define JUST_OPT(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!_just_value_to_check_.has_value()) { return NullOpt; } \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#else
......
......@@ -173,6 +173,12 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
.Attr<float>("beta1", user_op_conf.attr<float>("beta1"))
.Attr<float>("beta2", user_op_conf.attr<float>("beta2"))
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"));
if (user_op_conf.has_input("bias_correction1", 0)) {
fused_op_builder.Input("bias_correction1", user_op_conf.input("bias_correction1", 0));
}
if (user_op_conf.has_input("bias_correction2", 0)) {
fused_op_builder.Input("bias_correction2", user_op_conf.input("bias_correction2", 0));
}
} else if (user_op_conf.op_type_name() == "rmsprop_update") {
const bool centered = user_op_conf.attr<bool>("centered");
fused_op_builder.Input("mean_square", user_op_conf.input("mean_square", 0.f))
......
......@@ -156,7 +156,6 @@ import oneflow.framework.register_python_callback
INVALID_SPLIT_AXIS = oneflow._oneflow_internal.INVALID_SPLIT_AXIS
register_class_method_util.RegisterMethod4Class()
oneflow._oneflow_internal.RegisterGILForeignLockHelper()
import oneflow.framework.env_util as env_util
import oneflow.framework.scope_util as scope_util
import oneflow.framework.session_context as session_ctx
......@@ -166,6 +165,7 @@ if not env_util.HasAllMultiClientEnvVars():
env_util.SetDefaultMultiClientEnvVars()
oneflow._oneflow_internal.SetIsMultiClient(True)
env_util.api_env_init()
oneflow._oneflow_internal.RegisterGILForeignLockHelper()
oneflow._oneflow_internal.InitDefaultConsistentTransportTokenScope()
session_ctx.OpenDefaultSession(
MultiClientSession(oneflow._oneflow_internal.NewSessionId())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册