Commit 6db96ec2 authored by chengduoZH

follow comments

Parent 8eaec5dd
paddle/fluid/framework/details/CMakeLists.txt
@@ -5,7 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
 if(WITH_GPU)
   nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
     dynload_cuda)
-  nv_library(broad_cast_op_handle SRCS broad_cast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
+  nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 endif()
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
@@ -15,8 +15,8 @@ cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
 if(WITH_GPU)
   set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
-  nv_test(broad_cast_op_test SRCS broad_cast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
-          device_context broad_cast_op_handle)
+  nv_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
+          device_context broadcast_op_handle)
 else()
   set(multi_devices_graph_builder_deps)
 endif()
paddle/fluid/framework/details/{broad_cast_op_handle.cc → broadcast_op_handle.cc}
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/framework/details/broad_cast_op_handle.h"
+#include "paddle/fluid/framework/details/broadcast_op_handle.h"
 namespace paddle {
 namespace framework {
@@ -28,16 +28,16 @@ Tensor *GetTensorFromVar(Variable *in_var) {
   }
   return nullptr;
 }
-BCastOpHandle::BCastOpHandle(const std::vector<Scope *> &local_scopes,
+BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                   const std::vector<platform::Place> &places,
                   const platform::ContextMap &ctxs)
     : local_scopes_(local_scopes), places_(places), ctxs_(ctxs) {
   for (auto &p : places_) {
     this->dev_ctxes_[p] = ctxs_.DevCtx(p);
   }
 }
-void BCastOpHandle::RunImpl() {
+void BroadcastOpHandle::RunImpl() {
   PADDLE_ENFORCE_EQ(this->inputs_.size(), 1);
   PADDLE_ENFORCE_EQ(this->outputs_.size(), places_.size());
@@ -97,7 +97,7 @@ void BCastOpHandle::RunImpl() {
   }
 }
-std::string BCastOpHandle::Name() const { return "broadcast"; }
+std::string BroadcastOpHandle::Name() const { return "broadcast"; }
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/{broad_cast_op_handle.h → broadcast_op_handle.h}
@@ -29,17 +29,17 @@ namespace framework {
 namespace details {
 /*
- * BroadCast the input to all scope.
+ * Broadcast the input to all scope.
  *
  */
-struct BCastOpHandle : public OpHandleBase {
+struct BroadcastOpHandle : public OpHandleBase {
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
   const platform::ContextMap &ctxs_;
-  BCastOpHandle(const std::vector<Scope *> &local_scopes,
+  BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                 const std::vector<platform::Place> &places,
                 const platform::ContextMap &ctxs);
   std::string Name() const override;
paddle/fluid/framework/details/{broad_cast_op_handle_test.cc → broadcast_op_handle_test.cc}
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/framework/details/broad_cast_op_handle.h"
+#include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "gtest/gtest.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -23,12 +23,12 @@ namespace p = paddle::platform;
 // test data amount
 const f::DDim kDims = {20, 20};
-class BroadCastTester : public ::testing::Test {
+class BroadcastTester : public ::testing::Test {
  public:
   void SetUp() override {
     int count = p::GetCUDADeviceCount();
     if (count <= 1) {
-      LOG(WARNING) << "Cannot test multi-gpu BroadCast, because the CUDA "
+      LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA "
                       "device count is "
                    << count;
       exit(0);
@@ -40,7 +40,7 @@ class BroadCastTester : public ::testing::Test {
   }
   template <class T>
-  void BroadCastInitOp(int gpu_id = 0) {
+  void BroadcastInitOp(int gpu_id = 0) {
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       local_scope_.push_back(&g_scope_.NewScope());
       auto* out_var = local_scope_[j]->Var("out");
@@ -50,7 +50,7 @@ class BroadCastTester : public ::testing::Test {
     in_var->GetMutable<T>();
     bc_op_handle_ =
-        new f::details::BCastOpHandle(local_scope_, gpu_list_, *ctxs_);
+        new f::details::BroadcastOpHandle(local_scope_, gpu_list_, *ctxs_);
     f::details::VarHandle* in_var_handle = new f::details::VarHandle();
     in_var_handle->place_ = gpu_list_[gpu_id];
@@ -68,7 +68,7 @@ class BroadCastTester : public ::testing::Test {
       bc_op_handle_->AddOutput(out_var_handle);
     }
   }
-  void BroadCastDestroy() {
+  void BroadcastDestroy() {
     delete ctxs_;
     for (auto in : bc_op_handle_->inputs_) {
       delete in;
@@ -84,12 +84,12 @@ class BroadCastTester : public ::testing::Test {
   p::ContextMap* ctxs_;
   std::vector<f::Scope*> local_scope_;
   std::vector<p::Place> gpu_list_;
-  f::details::BCastOpHandle* bc_op_handle_;
+  f::details::BroadcastOpHandle* bc_op_handle_;
 };
-TEST_F(BroadCastTester, BroadCastTestLodTensor) {
+TEST_F(BroadcastTester, BroadcastTestLodTensor) {
   int gpu_id = 0;
-  BroadCastInitOp<f::LoDTensor>(gpu_id);
+  BroadcastInitOp<f::LoDTensor>(gpu_id);
   auto in_var = local_scope_[gpu_id]->Var("input");
   auto in_lod_tensor = in_var->GetMutable<f::LoDTensor>();
@@ -122,12 +122,12 @@ TEST_F(BroadCastTester, BroadCastTestLodTensor) {
     }
   }
-  BroadCastDestroy();
+  BroadcastDestroy();
 }
-TEST_F(BroadCastTester, BroadCastTestSelectedRows) {
+TEST_F(BroadcastTester, BroadcastTestSelectedRows) {
   int gpu_id = 0;
-  BroadCastInitOp<f::SelectedRows>(gpu_id);
+  BroadcastInitOp<f::SelectedRows>(gpu_id);
   auto in_var = local_scope_[gpu_id]->Var("input");
   auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
@@ -170,5 +170,5 @@ TEST_F(BroadCastTester, BroadCastTestSelectedRows) {
     }
   }
-  BroadCastDestroy();
+  BroadcastDestroy();
 }