提交 384d6ee8 编写于 作者: C chengduoZH

follow comments

上级 02842cfc
......@@ -61,33 +61,24 @@ void BroadcastOpHandle::RunImpl() {
// Wait input done, this Wait is asynchronous operation
auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) {
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[in_place]);
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
if (platform::is_same_place(in_place, out_p)) continue;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
}
}
//
auto in_scope_idx = in_var_handle[0]->scope_idx_;
PADDLE_ENFORCE_LT(in_scope_idx, local_scopes_.size(),
"The input(%s) is not in the local_scopes.",
in_var_handle[0]->name_);
auto in_var = local_scopes_[in_scope_idx]->FindVar(in_var_handle[0]->name_);
auto in_var =
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
Tensor *in_tensor = GetTensorFromVar(in_var);
for (auto *out : out_var_handles) {
auto &out_p = out->place_;
auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
auto out_scope_idx = out->scope_idx_;
PADDLE_ENFORCE_LT(out_scope_idx, local_scopes_.size(),
"%s is not in the local_scopes ", out->name_);
auto *s = local_scopes_[out_scope_idx];
auto out_var = s->FindVar(out->name_);
PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
"The place of input and output should be the same.");
"Places must be all on CPU or all on CUDA.");
if (in_var->IsType<framework::SelectedRows>()) {
auto &in_sr = in_var->Get<framework::SelectedRows>();
......@@ -109,24 +100,8 @@ void BroadcastOpHandle::RunImpl() {
}
Tensor *out_tensor = GetTensorFromVar(out_var);
if (platform::is_cpu_place(in_place)) {
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
out_tensor);
} else if (platform::is_gpu_place(in_place)) {
#ifdef PADDLE_WITH_CUDA
auto src_gpu_place = boost::get<platform::CUDAPlace>(in_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(out_p);
void *dst_ptr = out_tensor->mutable_data(out_p);
void *src_ptr = in_tensor->data<void>();
int64_t size = in_tensor->numel() * SizeOfType(in_tensor->type());
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<platform::CUDADeviceContext *>(dev_ctxes_[out_p])
->stream());
#else
PADDLE_THROW("CUDAPlace is not supported in CPU device.");
#endif
}
}
}
......
......@@ -74,13 +74,10 @@ void GatherOpHandle::RunImpl() {
auto in_handle = static_cast<VarHandle *>(in);
auto in_p = in_handle->place_;
in_places.push_back(in_p);
PADDLE_ENFORCE_LT(in_handle->scope_idx_, local_scopes_.size(),
"%s is not the the local_scopes ", in_handle->name_);
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"The place of input should be the same.");
auto *s = local_scopes_[in_handle->scope_idx_];
auto in_var = s->FindVar(in_handle->name_);
"Places must be all on CPU or all on CUDA.");
auto in_var =
local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::SelectedRows>();
PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
......
......@@ -11,8 +11,10 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include <algorithm>
#include <limits>
#include <vector>
namespace paddle {
namespace framework {
......@@ -65,8 +67,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册