Commit 4760ac44 authored by chengduoZH

Check whether generated_op_ is null before waiting on it, and add DEPS for broadcast_op_handle and gather_op_handle

Parent d24ef931
@@ -21,11 +21,10 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framewor
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
         simple_threadpool device_context)
-cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
-cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
+cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
+cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)
 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
         device_context broadcast_op_handle)
...
@@ -61,8 +61,9 @@ void BroadcastOpHandle::RunImpl() {
                       "Places must be all on CPU or all on CUDA.");

     VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
-                                                            in_tensor.type());
+    VariableVisitor::GetMutableTensor(out_var)
+        .Resize(in_tensor.dims())
+        .mutable_data(out_p, in_tensor.type());

     auto dev_ctx = dev_ctxes_[out_p];
     RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
@@ -74,9 +75,11 @@ void BroadcastOpHandle::RunImpl() {
 }

 void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
+  if (in_var.generated_op_) {
     for (auto &pair : dev_ctxes_) {
       in_var.generated_op_->Wait(pair.second);
     }
+  }
 }

 std::string BroadcastOpHandle::Name() const { return "broadcast"; }
...
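Below is a minimal, self-contained C++ sketch of the guarded-wait pattern this commit adds to BroadcastOpHandle::WaitInputVarGenerated. The OpHandle, VarHandle, and DeviceContext types (and the string-keyed map) are simplified stand-ins for illustration only, not PaddlePaddle's real classes: an input variable that is fed from outside the graph has no generating op, so generated_op_ is null and the wait must be skipped.

#include <map>
#include <string>

// Simplified stand-ins for PaddlePaddle's device context and op handle types.
struct DeviceContext {};

struct OpHandle {
  // Blocks until this op's work queued on the given context has finished.
  void Wait(DeviceContext *ctx) { (void)ctx; /* e.g. synchronize the context's stream */ }
};

struct VarHandle {
  // Null when the variable was fed externally and no op produced it.
  OpHandle *generated_op_ = nullptr;
};

// Mirrors the guarded wait added in this commit: only wait on the generating
// op if the input variable actually has one.
void WaitInputVarGenerated(const VarHandle &in_var,
                           const std::map<std::string, DeviceContext *> &dev_ctxes) {
  if (in_var.generated_op_) {
    for (auto &pair : dev_ctxes) {
      in_var.generated_op_->Wait(pair.second);
    }
  }
}

int main() {
  DeviceContext cpu_ctx;
  std::map<std::string, DeviceContext *> dev_ctxes{{"cpu", &cpu_ctx}};
  VarHandle fed_var;  // generated_op_ stays null: the wait is skipped safely.
  WaitInputVarGenerated(fed_var, dev_ctxes);
  return 0;
}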