提交 fa5060a2 编写于 作者: J Juncheng 提交者: GitHub

Remove NormalModelUpdateOpConf (#3917)

* Remove NormalModelUpdateOpConf

* Remove useless import

* Remove TaskNode/Actor

* fix

Former-commit-id: bab1c62a
上级 380f538a
......@@ -67,7 +67,6 @@ void NormalForwardCompActor::AsyncInitModelAndConstBuf() {
}
REGISTER_ACTOR(TaskType::kNormalForward, NormalForwardCompActor);
REGISTER_ACTOR(TaskType::kOptimizer, NormalForwardCompActor);
REGISTER_ACTOR(TaskType::kPrint, NormalForwardCompActor);
REGISTER_ACTOR(TaskType::kForeignInput, NormalForwardCompActor);
REGISTER_ACTOR(TaskType::kForeignOutput, NormalForwardCompActor);
......
......@@ -15,7 +15,6 @@ limitations under the License.
*/
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
#include "oneflow/core/graph/optimizer_compute_task_node.h"
#include "oneflow/core/graph/print_compute_task_node.h"
#include "oneflow/core/graph/decode_compute_task_node.h"
#include "oneflow/core/graph/decode_random_compute_task_node.h"
......@@ -294,12 +293,6 @@ int64_t NormalForwardLogicalNode::GetAreaId() const {
}
}
std::string OptimizerLogicalNode::TypeName() const { return "Optimizer"; }
CompTaskNode* OptimizerLogicalNode::NewCompTaskNode() const { return new OptimizerCompTaskNode; }
int64_t OptimizerLogicalNode::GetAreaId() const { return kMdUpdtArea; }
int64_t NewAreaId() {
static int64_t next_area_id = AreaType_ARRAYSIZE;
return ++next_area_id;
......
......@@ -147,13 +147,6 @@ class NormalForwardLogicalNode final : public ForwardLogicalNode {
private:
};
class OptimizerLogicalNode final : public ForwardLogicalNode {
public:
LOGICAL_NODE_BOILERPLATE(OptimizerLogicalNode);
private:
};
int64_t NewAreaId();
#define LOGICAL_NODE_WITH_NEW_AREA_ID_BOILERPLATE(name) \
......
......@@ -16,7 +16,6 @@ limitations under the License.
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job/mirrored_sig_infer_hint.h"
#include "oneflow/core/operator/normal_model_update_op.h"
namespace oneflow {
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/graph/optimizer_compute_task_node.h"
namespace oneflow {
void OptimizerCompTaskNode::ConsumeAllRegsts() {
ForEachInDataEdge([&](TaskEdge* edge) {
for (const auto& regst : edge->GetRegsts()) { ConsumeRegst("in", regst); }
});
}
void OptimizerCompTaskNode::ProduceAllRegstsAndBindEdges() { ProduceRegst("tmp", false, 1, 1); }
void OptimizerCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
std::shared_ptr<Operator> sole_op = this->logical_node()->SoleOp();
node->mut_op() = sole_op;
const std::list<std::shared_ptr<RegstDesc>>& in_regsts = GetConsumedRegst("in");
for (const auto& ibn : node->op()->input_bns()) {
node->BindBnWithOneOfTheRegsts(ibn, in_regsts);
}
node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp"));
node->InferBlobDescs(parallel_ctx());
}
void OptimizerCompTaskNode::InferProducedDataRegstTimeShape() {
ForEachProducedDataRegst([](const std::string& name, RegstDesc* regst) {
regst->mut_data_regst_time_shape()->reset(
new Shape({GlobalJobDesc().TotalBatchNum(), static_cast<int64_t>(1)}));
});
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_OPTIMIZER_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_OPTIMIZER_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class OptimizerCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(OptimizerCompTaskNode);
OptimizerCompTaskNode() = default;
~OptimizerCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
TaskType GetTaskType() const override { return TaskType::kOptimizer; }
CudaWorkType GetCudaWorkType() const override {
#ifdef WITH_CUDA
return CudaWorkType::kCompute;
#else
UNIMPLEMENTED();
#endif
}
private:
void BuildExecGphAndRegst() override;
void InferProducedDataRegstTimeShape() override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_OPTIMIZER_COMPUTE_TASK_NODE_H_
......@@ -19,7 +19,6 @@ enum TaskType {
kUnpack = 32;
kRepeat = 34;
kAcc = 37;
kOptimizer = 38;
kSourceTick = 40;
kTick = 41;
kAccTick = 42;
......
......@@ -81,32 +81,6 @@ float GetOptimizerWeightDecayRate(const NormalModelUpdateOpUserConf& model_updat
}
}
template<typename T>
void ConstructMdUpdtOpConf(const VariableOp& op, const LogicalBlobId& diff_lbi_of_var_out,
JobBuilder* job_builder, T* mdupdt_op_conf) {
const auto& train_conf = job_builder->job().job_conf().train_conf();
*mdupdt_op_conf->mutable_user_conf() = train_conf.model_update_conf();
mdupdt_op_conf->set_model_diff(GenLogicalBlobName(diff_lbi_of_var_out));
mdupdt_op_conf->set_model(GenLogicalBlobName(op.BnInOp2Lbi("out")));
mdupdt_op_conf->set_train_step(train_conf.train_step_lbn());
const std::string& primary_lr_lbn = train_conf.primary_lr_lbn();
const std::string& secondary_lr_lbn = train_conf.secondary_lr_lbn();
if (op.op_conf().variable_conf().model_name() == "weight") {
mdupdt_op_conf->set_learning_rate(primary_lr_lbn);
} else if (op.op_conf().variable_conf().model_name() == "bias") {
mdupdt_op_conf->set_learning_rate(secondary_lr_lbn);
} else {
mdupdt_op_conf->set_learning_rate(primary_lr_lbn);
}
const float weight_decay_rate = GetOptimizerWeightDecayRate(train_conf.model_update_conf(), op);
if (weight_decay_rate != 0) { mdupdt_op_conf->set_weight_decay(weight_decay_rate); }
}
#define INSTANTIATE_CONSTRUCTOR_MDUPDT_OP_CONF(T) \
template void ConstructMdUpdtOpConf<T>(const VariableOp& op, \
const LogicalBlobId& diff_lbi_of_var_out, \
JobBuilder* job_builder, T* mdupdt_op_conf)
void SetDynamicLossScaleSkipIf(JobPassCtx* ctx, user_op::UserOpConfWrapperBuilder* builder) {
if (!ctx->job_desc().job_conf().train_conf().has_dynamic_loss_scale_policy()) { return; }
builder->Input("skip_if",
......
......@@ -31,10 +31,6 @@ float GetOptimizerWeightDecayRate(const NormalModelUpdateOpUserConf& model_updat
void SetDynamicLossScaleSkipIf(JobPassCtx* ctx, user_op::UserOpConfWrapperBuilder* builder);
template<typename T>
void ConstructMdUpdtOpConf(const VariableOp& op, const LogicalBlobId& diff_lbi_of_var_out,
JobBuilder* job_builder, T*);
class GenerateOptimizerOpConfWrapperStruct final {
public:
using Func = std::function<void(JobPassCtx*, const VariableOp&, const ParallelConf&, JobBuilder*,
......
syntax = "proto2";
package oneflow;
import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
import "oneflow/core/common/dtype_signature.proto";
import "oneflow/core/operator/op_attribute.proto";
import "oneflow/core/operator/op_conf.proto";
import "oneflow/core/job/placement.proto";
import "oneflow/core/register/tensor_slice_view.proto";
import "oneflow/core/job/sbp_parallel.proto";
......@@ -19,15 +17,6 @@ message DecodeOFRecordKernelConf {
required uint32 random_seed = 1;
}
message SliceKernelConf {
required ShapeProto in_shape = 1;
}
message ConstantKernelConf {
required InitializerConf initializer = 1;
required uint32 random_seed = 2;
}
message VariableKernelConf {
required bool is_fw_inplace = 1;
required bool is_bw_inplace = 2;
......@@ -104,7 +93,6 @@ message KernelConf {
SyncDynamicResizeKernelConf sync_dynamic_resize_conf = 360;
ArgWhereKernelConf arg_where_conf = 361;
SliceKernelConf slice_conf = 402;
VariableKernelConf variable_conf = 407;
RecordLoadKernelConf record_load_conf = 408;
ShapeElemCntKernelConf shape_elem_cnt_conf = 412;
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/normal_model_update_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void NormalMdUpdateKernel<device_type, T>::VirtualKernelInit() {
const PbMessage& op_conf = this->GetCustomizedOpConf();
weight_decay_ = static_cast<T>(GetValFromPbMessage<float>(op_conf, "weight_decay"));
if (!IsWeightDecaySupported()) { CHECK_EQ(weight_decay_, static_cast<T>(0)); }
}
template<DeviceType device_type, typename T>
void NormalMdUpdateKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const int64_t* train_step_ptr = BnInOp2Blob("train_step")->dptr<int64_t>();
const float* learning_rate_ptr = BnInOp2Blob("learning_rate")->dptr<float>();
UpdateModel(ctx.device_ctx, weight_decay_, train_step_ptr, learning_rate_ptr, BnInOp2Blob);
}
#define INSTANTIATE_KERNEL(device_type, data_type_pair) \
template class NormalMdUpdateKernel<device_type, OF_PP_PAIR_FIRST(data_type_pair)>;
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ)
#undef INSTANTIATE_KERNEL
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_KERNEL_NORMAL_MODEL_UPDATE_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_NORMAL_MODEL_UPDATE_KERNEL_H_
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class NormalMdUpdateKernel : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(NormalMdUpdateKernel);
virtual ~NormalMdUpdateKernel() = default;
protected:
NormalMdUpdateKernel() = default;
virtual void UpdateModel(DeviceCtx* ctx, T weight_decay, const int64_t* train_step,
const float* learning_rate,
std::function<Blob*(const std::string&)> BnInOp2Blob) const = 0;
virtual bool IsWeightDecaySupported() { return false; }
void Forward(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override {
ForwardDataContent(ctx, BnInOp2Blob);
}
private:
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
void VirtualKernelInit() override;
T weight_decay_;
};
#define DECLARE_MDUPDT_KERNEL_CREATOR(x) Kernel* Create##x##MdUpdtKernel(const KernelConf&);
#define DEFINE_MDUPDT_KERNEL_CREATOR(x) \
Kernel* Create##x##MdUpdtKernel(const KernelConf& kernel_conf) { \
static const HashMap<std::string, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, (x##MdUpdateKernel), \
DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ)}; \
DeviceType device_type = \
CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag())); \
return creators.at(GetHashKey(device_type, kernel_conf.data_type()))(); \
}
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_NORMAL_MODEL_UPDATE_KERNEL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/operator/normal_model_update_op.h"
#include "oneflow/core/job/sbp_signature_builder.h"
namespace oneflow {
void NormalModelUpdtOp::InitFromOpConf() {
EnrollInputBn("model_diff", false);
EnrollInputBn("model", false)->set_is_mutable(true);
EnrollInputBn("learning_rate", false);
EnrollInputBn("train_step", false);
MdUpdtVirtualInitFromOpConf();
}
Maybe<void> NormalModelUpdtOp::InferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
return MdUpdtVirtualInferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx);
}
LogicalBlobId NormalModelUpdtOp::lbi4obn(const std::string& output_bn) const {
const google::protobuf::Descriptor* desc = GetCustomizedConf().GetDescriptor();
const google::protobuf::FieldDescriptor* fd = desc->FindFieldByName(output_bn);
CHECK(fd);
return GenLogicalBlobId(GetValFromCustomizedConf<std::string>(output_bn));
}
Maybe<void> NormalModelUpdtOp::InferBatchAxis(
std::function<OptInt64*(const std::string&)> BatchAxis4BnInOp) const {
return Maybe<void>::Ok();
}
Maybe<void> NormalModelUpdtOp::GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
const auto& bns = AlwaysBroadcastParallelBns();
PbRpf<std::string> broadcast_bns = {bns.begin(), bns.end()};
*broadcast_bns.Add() = "learning_rate";
*broadcast_bns.Add() = "train_step";
FOR_RANGE(int64_t, i, 0, JUST(LogicalBlobDesc4Ibn("model")).shape().NumAxes()) {
SbpSignatureBuilder()
.Split(input_bns(), i)
.Broadcast(broadcast_bns)
.Build(sbp_sig_list->mutable_sbp_signature()->Add());
}
return Maybe<void>::Ok();
}
REGISTER_OP_CREATOR(OperatorConf::kNormalMdupdtConf, ([](const OperatorConf& op_conf) -> Operator* {
return NewObj<int32_t, NormalModelUpdtOp>(
op_conf.normal_mdupdt_conf().user_conf().normal_mdupdt_case());
}));
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_OPERATOR_NORMAL_MODEL_UPDATE_OP_H_
#define ONEFLOW_CORE_OPERATOR_NORMAL_MODEL_UPDATE_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class NormalModelUpdtOp : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(NormalModelUpdtOp);
virtual ~NormalModelUpdtOp() = default;
void InitFromOpConf() override;
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
protected:
NormalModelUpdtOp() = default;
virtual void MdUpdtVirtualInitFromOpConf() {}
virtual Maybe<void> MdUpdtVirtualInferBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext*) const {
return Maybe<void>::Ok();
}
virtual const HashSet<std::string> AlwaysBroadcastParallelBns() const = 0;
private:
Maybe<void> InferBatchAxis(
std::function<OptInt64*(const std::string&)> BatchAxis4BnInOp) const override;
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const override;
LogicalBlobId lbi4obn(const std::string& output_bn) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_NORMAL_MODEL_UPDATE_OP_H_
......@@ -351,15 +351,6 @@ message NormalModelUpdateOpUserConf {
}
}
message NormalModelUpdateOpConf {
required NormalModelUpdateOpUserConf user_conf = 1;
required string model_diff = 2;
required string model = 4;
required string train_step = 5;
required string learning_rate = 6;
optional float weight_decay = 7 [default = 0.0];
}
message AccumulateOpConf {
}
......@@ -888,7 +879,6 @@ message OperatorConf {
CopyCommNetOpConf copy_comm_net_conf = 106;
BoxingOpConf boxing_conf = 108;
AccumulateOpConf accumulate_conf = 117;
NormalModelUpdateOpConf normal_mdupdt_conf = 118;
VariableOpConf variable_conf = 122;
TickOpConf tick_conf = 124;
KeepHeaderOnlyOpConf keep_header_only_conf = 125;
......
......@@ -31,8 +31,6 @@ class XrtLaunchOp : public Operator {
const ParallelContext* parallel_ctx) const override;
LogicalNode* NewProperLogicalNode() const override {
const auto& launch_conf = op_conf().xrt_launch_conf();
if (launch_conf.model_update()) { return new OptimizerLogicalNode; }
return new NormalForwardLogicalNode;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册