未验证 提交 ccc3b389 编写于 作者: J Jinhui Yuan 提交者: GitHub

fix reduce_gather in case of enable_mem_sharing == false (#1186)

上级 28a6fc98
......@@ -3,7 +3,8 @@
namespace oneflow {
void ReduceGatherCompActor::SetKernelCtxOther(void** other) {
other_val_ = InBnId4RegstDescId(cur_processed_regst_desc_id());
int64_t in_bn_id = InBnId4RegstDescId(cur_processed_regst_desc_id());
other_val_ = std::make_pair(in_bn_id, EnableInplace());
*other = static_cast<void*>(&other_val_);
}
......
......@@ -15,7 +15,7 @@ class ReduceGatherCompActor final : public InputWiseCompActor {
void VirtualCompActorInit(const TaskProto& proto) override { InputWiseCompActor::Init(proto); }
void SetKernelCtxOther(void** other) override;
int64_t other_val_;
std::pair<int64_t, bool> other_val_;
};
} // namespace oneflow
......
......@@ -5,15 +5,17 @@ namespace oneflow {
template<DeviceType device_type>
void ReduceGatherKernel<device_type>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
if (device_type == DeviceType::kGPU) { return; }
int64_t in_bn_id = *static_cast<int64_t*>(ctx.other);
const auto* other_val = static_cast<std::pair<int64_t, bool>*>(ctx.other);
int64_t in_bn_id = other_val->first;
bool is_inplace = other_val->second;
if (is_inplace) { return; }
Blob* out_blob = BnInOp2Blob("out");
char* dst_cur_dptr = out_blob->mut_dptr<char>();
dst_cur_dptr += this->kernel_conf().reduce_gather_conf().data_offset().Get(in_bn_id);
Blob* in_blob = BnInOp2Blob(this->op_attribute().input_bns().Get(in_bn_id));
size_t in_byte_size = in_blob->ByteSizeOfDataContentField();
Memcpy<DeviceType::kCPU>(ctx.device_ctx, dst_cur_dptr, in_blob->dptr<char>(), in_byte_size);
Memcpy<device_type>(ctx.device_ctx, dst_cur_dptr, in_blob->dptr<char>(), in_byte_size);
}
ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReduceGatherConf, ReduceGatherKernel);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册