提交 384d6ee8 编写于 作者: C chengduoZH

follow comments

上级 02842cfc
...@@ -61,33 +61,24 @@ void BroadcastOpHandle::RunImpl() { ...@@ -61,33 +61,24 @@ void BroadcastOpHandle::RunImpl() {
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
auto &in_place = in_var_handle[0]->place_; auto &in_place = in_var_handle[0]->place_;
if (in_var_handle[0]->generated_op_) { if (in_var_handle[0]->generated_op_) {
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[in_place]);
for (auto *out : out_var_handles) { for (auto *out : out_var_handles) {
auto &out_p = out->place_; auto &out_p = out->place_;
if (platform::is_same_place(in_place, out_p)) continue;
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]); in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
} }
} }
// //
auto in_scope_idx = in_var_handle[0]->scope_idx_; auto in_scope_idx = in_var_handle[0]->scope_idx_;
PADDLE_ENFORCE_LT(in_scope_idx, local_scopes_.size(), auto in_var =
"The input(%s) is not in the local_scopes.", local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
in_var_handle[0]->name_);
auto in_var = local_scopes_[in_scope_idx]->FindVar(in_var_handle[0]->name_);
Tensor *in_tensor = GetTensorFromVar(in_var); Tensor *in_tensor = GetTensorFromVar(in_var);
for (auto *out : out_var_handles) { for (auto *out : out_var_handles) {
auto &out_p = out->place_; auto &out_p = out->place_;
auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
auto out_scope_idx = out->scope_idx_;
PADDLE_ENFORCE_LT(out_scope_idx, local_scopes_.size(),
"%s is not in the local_scopes ", out->name_);
auto *s = local_scopes_[out_scope_idx];
auto out_var = s->FindVar(out->name_);
PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(), PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
"The place of input and output should be the same."); "Places must be all on CPU or all on CUDA.");
if (in_var->IsType<framework::SelectedRows>()) { if (in_var->IsType<framework::SelectedRows>()) {
auto &in_sr = in_var->Get<framework::SelectedRows>(); auto &in_sr = in_var->Get<framework::SelectedRows>();
...@@ -109,24 +100,8 @@ void BroadcastOpHandle::RunImpl() { ...@@ -109,24 +100,8 @@ void BroadcastOpHandle::RunImpl() {
} }
Tensor *out_tensor = GetTensorFromVar(out_var); Tensor *out_tensor = GetTensorFromVar(out_var);
if (platform::is_cpu_place(in_place)) {
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]), paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
out_tensor); out_tensor);
} else if (platform::is_gpu_place(in_place)) {
#ifdef PADDLE_WITH_CUDA
auto src_gpu_place = boost::get<platform::CUDAPlace>(in_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(out_p);
void *dst_ptr = out_tensor->mutable_data(out_p);
void *src_ptr = in_tensor->data<void>();
int64_t size = in_tensor->numel() * SizeOfType(in_tensor->type());
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<platform::CUDADeviceContext *>(dev_ctxes_[out_p])
->stream());
#else
PADDLE_THROW("CUDAPlace is not supported in CPU device.");
#endif
}
} }
} }
......
...@@ -74,13 +74,10 @@ void GatherOpHandle::RunImpl() { ...@@ -74,13 +74,10 @@ void GatherOpHandle::RunImpl() {
auto in_handle = static_cast<VarHandle *>(in); auto in_handle = static_cast<VarHandle *>(in);
auto in_p = in_handle->place_; auto in_p = in_handle->place_;
in_places.push_back(in_p); in_places.push_back(in_p);
PADDLE_ENFORCE_LT(in_handle->scope_idx_, local_scopes_.size(),
"%s is not the the local_scopes ", in_handle->name_);
PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(), PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
"The place of input should be the same."); "Places must be all on CPU or all on CUDA.");
auto *s = local_scopes_[in_handle->scope_idx_]; auto in_var =
auto in_var = s->FindVar(in_handle->name_); local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
auto &in_sr = in_var->Get<framework::SelectedRows>(); auto &in_sr = in_var->Get<framework::SelectedRows>();
PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(), PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
......
...@@ -11,8 +11,10 @@ ...@@ -11,8 +11,10 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include <algorithm>
#include <limits>
#include <vector>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -65,8 +67,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -65,8 +67,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place); auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace(); auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy( memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册