...
 
Commits (4)
    https://gitcode.net/Oneflow-Inc/oneflow/-/commit/8b94ac9b8fd0578aeed91a85c955a7f2a400b6aa restruct reshape gradient funcs (#6634) 2021-11-02T04:15:23+00:00 Luyang flowingsun007@163.com * restruct * refine https://gitcode.net/Oneflow-Inc/oneflow/-/commit/2247386074c82dd56248405b52095d9b1609bae8 Fix model update pass adam (#6673) 2021-11-02T18:11:41+08:00 ZZK 42901638+MARD1NO@users.noreply.github.com * add first version of unary primitive op * fix * remove redundant file * Revert * fix format * use has input to check https://gitcode.net/Oneflow-Inc/oneflow/-/commit/55d32c333c8a298da5307bc1d219f48967fe5490 adjust GILForeignLockHelper order to avoid glog print to stderr (#6671) 2021-11-02T18:50:45+08:00 Xiaoyu Xu xiaoyulink@gmail.com Co-authored-by: <span data-trailer="Co-authored-by:"><a href="mailto:69100618+oneflow-ci-bot@users.noreply.github.com" title="69100618+oneflow-ci-bot@users.noreply.github.com"></a><a href="javascript:void(0)" class="avatar s16 avatar-inline identicon bg1" style="text-decoration: none">N</a><a href="mailto:69100618+oneflow-ci-bot@users.noreply.github.com" title="69100618+oneflow-ci-bot@users.noreply.github.com">oneflow-ci-bot</a> &lt;<a href="mailto:69100618+oneflow-ci-bot@users.noreply.github.com" title="69100618+oneflow-ci-bot@users.noreply.github.com">69100618+oneflow-ci-bot@users.noreply.github.com</a>&gt;</span> https://gitcode.net/Oneflow-Inc/oneflow/-/commit/fe31e54ef5f7b3154be5abfc4b65e202a26904f2 Merge branch 'master' into save_load_by_pickle 2021-11-02T21:14:50+08:00 oneflow-ci-bot 69100618+oneflow-ci-bot@users.noreply.github.com
......@@ -24,7 +24,11 @@ limitations under the License.
namespace oneflow {
namespace one {
class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
struct ReshapeCaptureState : public AutoGradCaptureState {
DimVector input_shape_vec;
};
class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
......@@ -32,17 +36,18 @@ class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
return Maybe<void>::Ok();
}
Maybe<void> Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->SaveTensorForBackward(inputs.at(0));
ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& saved_tensors = ctx->SavedTensors();
in_grads->resize(1);
in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), saved_tensors.at(0)));
Shape shape(ctx->input_shape_vec);
in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));
return Maybe<void>::Ok();
}
};
......
......@@ -173,6 +173,12 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
.Attr<float>("beta1", user_op_conf.attr<float>("beta1"))
.Attr<float>("beta2", user_op_conf.attr<float>("beta2"))
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"));
if (user_op_conf.has_input("bias_correction1", 0)) {
fused_op_builder.Input("bias_correction1", user_op_conf.input("bias_correction1", 0));
}
if (user_op_conf.has_input("bias_correction2", 0)) {
fused_op_builder.Input("bias_correction2", user_op_conf.input("bias_correction2", 0));
}
} else if (user_op_conf.op_type_name() == "rmsprop_update") {
const bool centered = user_op_conf.attr<bool>("centered");
fused_op_builder.Input("mean_square", user_op_conf.input("mean_square", 0.f))
......
......@@ -156,7 +156,6 @@ import oneflow.framework.register_python_callback
INVALID_SPLIT_AXIS = oneflow._oneflow_internal.INVALID_SPLIT_AXIS
register_class_method_util.RegisterMethod4Class()
oneflow._oneflow_internal.RegisterGILForeignLockHelper()
import oneflow.framework.env_util as env_util
import oneflow.framework.scope_util as scope_util
import oneflow.framework.session_context as session_ctx
......@@ -166,6 +165,7 @@ if not env_util.HasAllMultiClientEnvVars():
env_util.SetDefaultMultiClientEnvVars()
oneflow._oneflow_internal.SetIsMultiClient(True)
env_util.api_env_init()
oneflow._oneflow_internal.RegisterGILForeignLockHelper()
oneflow._oneflow_internal.InitDefaultConsistentTransportTokenScope()
session_ctx.OpenDefaultSession(
MultiClientSession(oneflow._oneflow_internal.NewSessionId())
......