提交 6db96ec2 编写于 作者: C chengduoZH

follow comments

上级 8eaec5dd
......@@ -5,7 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
if(WITH_GPU)
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda)
nv_library(broad_cast_op_handle SRCS broad_cast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
endif()
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
......@@ -15,8 +15,8 @@ cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
if(WITH_GPU)
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
nv_test(broad_cast_op_test SRCS broad_cast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
device_context broad_cast_op_handle)
nv_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope lod_tensor ddim memory
device_context broadcast_op_handle)
else()
set(multi_devices_graph_builder_deps)
endif()
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/broad_cast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
namespace paddle {
namespace framework {
......@@ -28,7 +28,7 @@ Tensor *GetTensorFromVar(Variable *in_var) {
}
return nullptr;
}
BCastOpHandle::BCastOpHandle(const std::vector<Scope *> &local_scopes,
BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::ContextMap &ctxs)
: local_scopes_(local_scopes), places_(places), ctxs_(ctxs) {
......@@ -37,7 +37,7 @@ BCastOpHandle::BCastOpHandle(const std::vector<Scope *> &local_scopes,
}
}
void BCastOpHandle::RunImpl() {
void BroadcastOpHandle::RunImpl() {
PADDLE_ENFORCE_EQ(this->inputs_.size(), 1);
PADDLE_ENFORCE_EQ(this->outputs_.size(), places_.size());
......@@ -97,7 +97,7 @@ void BCastOpHandle::RunImpl() {
}
}
std::string BCastOpHandle::Name() const { return "broadcast"; }
std::string BroadcastOpHandle::Name() const { return "broadcast"; }
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -29,15 +29,15 @@ namespace framework {
namespace details {
/*
* BroadCast the input to all scope.
* Broadcast the input to all scope.
*
*/
struct BCastOpHandle : public OpHandleBase {
struct BroadcastOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
const platform::ContextMap &ctxs_;
BCastOpHandle(const std::vector<Scope *> &local_scopes,
BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::ContextMap &ctxs);
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/broad_cast_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "gtest/gtest.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -23,12 +23,12 @@ namespace p = paddle::platform;
// test data amount
const f::DDim kDims = {20, 20};
class BroadCastTester : public ::testing::Test {
class BroadcastTester : public ::testing::Test {
public:
void SetUp() override {
int count = p::GetCUDADeviceCount();
if (count <= 1) {
LOG(WARNING) << "Cannot test multi-gpu BroadCast, because the CUDA "
LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA "
"device count is "
<< count;
exit(0);
......@@ -40,7 +40,7 @@ class BroadCastTester : public ::testing::Test {
}
template <class T>
void BroadCastInitOp(int gpu_id = 0) {
void BroadcastInitOp(int gpu_id = 0) {
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scope_.push_back(&g_scope_.NewScope());
auto* out_var = local_scope_[j]->Var("out");
......@@ -50,7 +50,7 @@ class BroadCastTester : public ::testing::Test {
in_var->GetMutable<T>();
bc_op_handle_ =
new f::details::BCastOpHandle(local_scope_, gpu_list_, *ctxs_);
new f::details::BroadcastOpHandle(local_scope_, gpu_list_, *ctxs_);
f::details::VarHandle* in_var_handle = new f::details::VarHandle();
in_var_handle->place_ = gpu_list_[gpu_id];
......@@ -68,7 +68,7 @@ class BroadCastTester : public ::testing::Test {
bc_op_handle_->AddOutput(out_var_handle);
}
}
void BroadCastDestroy() {
void BroadcastDestroy() {
delete ctxs_;
for (auto in : bc_op_handle_->inputs_) {
delete in;
......@@ -84,12 +84,12 @@ class BroadCastTester : public ::testing::Test {
p::ContextMap* ctxs_;
std::vector<f::Scope*> local_scope_;
std::vector<p::Place> gpu_list_;
f::details::BCastOpHandle* bc_op_handle_;
f::details::BroadcastOpHandle* bc_op_handle_;
};
TEST_F(BroadCastTester, BroadCastTestLodTensor) {
TEST_F(BroadcastTester, BroadcastTestLodTensor) {
int gpu_id = 0;
BroadCastInitOp<f::LoDTensor>(gpu_id);
BroadcastInitOp<f::LoDTensor>(gpu_id);
auto in_var = local_scope_[gpu_id]->Var("input");
auto in_lod_tensor = in_var->GetMutable<f::LoDTensor>();
......@@ -122,12 +122,12 @@ TEST_F(BroadCastTester, BroadCastTestLodTensor) {
}
}
BroadCastDestroy();
BroadcastDestroy();
}
TEST_F(BroadCastTester, BroadCastTestSelectedRows) {
TEST_F(BroadcastTester, BroadcastTestSelectedRows) {
int gpu_id = 0;
BroadCastInitOp<f::SelectedRows>(gpu_id);
BroadcastInitOp<f::SelectedRows>(gpu_id);
auto in_var = local_scope_[gpu_id]->Var("input");
auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
......@@ -170,5 +170,5 @@ TEST_F(BroadCastTester, BroadCastTestSelectedRows) {
}
}
BroadCastDestroy();
BroadcastDestroy();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册