Commit 8496b2e4 authored by Yang Yu

Refine parallel_do

Parent 60e27d11
@@ -270,10 +270,10 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
                  "Batch size should be divided by places size");
   std::vector<LoDTensor> lods;
-  for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
-    int begin = place_idx * dims()[0] / places.size();
-    int end = (place_idx + 1) * dims()[0] / places.size();
-    auto src = Slice(begin, end);
+  for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
+    size_t begin = place_idx * dims()[0] / places.size();
+    size_t end = (place_idx + 1) * dims()[0] / places.size();
+    auto src = Slice(static_cast<int>(begin), static_cast<int>(end));
     LoDTensor dst;
     dst.Resize(src.dims());
...
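The loop above splits the batch dimension evenly and contiguously across the available places, now using unsigned indices and casting back to int only at the Slice call. A minimal standalone sketch of the same begin/end arithmetic (not Paddle code; `batch_size` and `num_places` are illustrative stand-ins for `dims()[0]` and `places.size()`):

// Minimal sketch (not Paddle code): how the per-place slice boundaries come
// out when a batch of `batch_size` rows is divided evenly across
// `num_places` devices.
#include <cstddef>
#include <iostream>

int main() {
  const std::size_t batch_size = 8;   // stands in for dims()[0]
  const std::size_t num_places = 4;   // stands in for places.size()
  for (std::size_t place_idx = 0; place_idx < num_places; ++place_idx) {
    // Same arithmetic as the patched loop, done in unsigned types.
    const std::size_t begin = place_idx * batch_size / num_places;
    const std::size_t end = (place_idx + 1) * batch_size / num_places;
    std::cout << "place " << place_idx << " gets rows [" << begin << ", "
              << end << ")\n";
  }
  return 0;
}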
@@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <thread>
 #include <vector>
 #include "paddle/framework/executor.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/framework/threadpool.h"
 namespace paddle {
 namespace operators {
-constexpr char kInputs[] = "inputs";
-constexpr char kParameters[] = "parameters";
-constexpr char kPlaces[] = "places";
-constexpr char kOutputs[] = "outputs";
-constexpr char kParallelScopes[] = "parallel_scopes";
-constexpr char kParallelBlock[] = "sub_block";
+static constexpr char kInputs[] = "inputs";
+static constexpr char kParameters[] = "parameters";
+static constexpr char kPlaces[] = "places";
+static constexpr char kOutputs[] = "outputs";
+static constexpr char kParallelScopes[] = "parallel_scopes";
+static constexpr char kParallelBlock[] = "sub_block";
 // using ParallelScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;
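A side note on the `static constexpr` change: at namespace scope a `constexpr` variable is already `const` and therefore already has internal linkage, so the added `static` documents the intent rather than changing behavior. A tiny sketch (not Paddle code) of the pattern:

// Sketch (not Paddle code): attribute-name constants kept private to one
// translation unit.  `constexpr` already implies `const`, which gives the
// array internal linkage; `static` spells that out explicitly.
#include <cstdio>

static constexpr char kInputs[] = "inputs";  // visible in this .cc file only

int main() {
  std::printf("attribute key: %s\n", kInputs);
  return 0;
}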
@@ -85,7 +85,8 @@ class ParallelDoOp : public framework::OperatorBase {
     SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
                                      Inputs(kInputs));
-    std::vector<std::thread> workers;
+    std::vector<std::future<void>> workers;
+    workers.reserve(places.size());
     for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
       VLOG(3) << "Run " << place_idx;
@@ -93,26 +94,27 @@ class ParallelDoOp : public framework::OperatorBase {
       auto *cur_scope = sub_scopes[place_idx];
       // copy parameter
-      if (dev_ctx.GetPlace() != place) {
+      // some version of boost lacks != for boost::variant
+      if (!(dev_ctx.GetPlace() == place)) {
         PADDLE_THROW("Not Implemented");
       }
-      // execute
-      workers.push_back(std::thread([program, cur_scope, place, block] {
-        auto executor = framework::Executor(place);
+      workers.emplace_back(framework::Async([program, cur_scope, place, block] {
+        framework::Executor executor(place);
         executor.Run(*program, cur_scope, block->ID(),
                      false /*create_local_scope*/);
       }));
     }
     for (auto &worker : workers) {
-      worker.join();
+      worker.wait();
     }
     // merge output
     for (auto &o_name : Outputs(kOutputs)) {
       std::vector<const framework::LoDTensor *> lod_tensors;
+      lod_tensors.reserve(sub_scopes.size());
       for (auto *sub_scope : sub_scopes) {
-        lod_tensors.push_back(&sub_scope->FindVar(o_name)->Get<LoDTensor>());
+        lod_tensors.emplace_back(&sub_scope->FindVar(o_name)->Get<LoDTensor>());
       }
       auto *lod_tensor_to_be_merged =
@@ -177,7 +179,7 @@ class ParallelDoGradOp : public OperatorBase {
     }
     // exe run
-    std::vector<std::thread> workers;
+    std::vector<std::future<void>> workers;
     for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
       VLOG(3) << "Run " << place_idx;
@@ -185,14 +187,14 @@ class ParallelDoGradOp : public OperatorBase {
       auto *cur_scope = sub_scopes[place_idx];
       // execute
-      workers.push_back(std::thread([program, cur_scope, place, block] {
-        auto executor = framework::Executor(place);
+      workers.emplace_back(framework::Async([program, cur_scope, place, block] {
+        framework::Executor executor(place);
         executor.Run(*program, cur_scope, block->ID(),
                      false /*create_local_scope*/);
       }));
     }
     for (auto &worker : workers) {
-      worker.join();
+      worker.wait();
     }
     // merge grad
...
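Both the forward and gradient operators now hand their per-place work to `framework::Async` (presumably declared in the newly included threadpool.h) and keep `std::future<void>` handles instead of raw `std::thread` objects, so waiting becomes `wait()` rather than `join()`. A minimal standalone sketch of that pattern, with `std::async` standing in for `framework::Async` (whose exact signature is assumed, not shown in this diff):

// Sketch (not Paddle code): launch one task per place, keep the futures,
// then block until every task has finished.  std::async stands in here for
// the thread-pool-backed framework::Async used by the real operator.
#include <future>
#include <iostream>
#include <vector>

int main() {
  const int num_places = 4;
  std::vector<std::future<void>> workers;
  workers.reserve(num_places);
  for (int place_idx = 0; place_idx < num_places; ++place_idx) {
    workers.emplace_back(std::async(std::launch::async, [place_idx] {
      // The real lambda runs an Executor over the sub-block on this place.
      std::cout << "running sub-block on place " << place_idx << "\n";
    }));
  }
  for (auto &worker : workers) {
    worker.wait();  // replaces std::thread::join()
  }
  return 0;
}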
@@ -205,6 +205,7 @@ def _append_backward_ops_(target,
         # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, no_grad_dict[block.idx], grad_sub_block_list)
+
         grad_op_descs.extend(grad_op_desc)
         grad_to_var.update(op_grad_to_var)
...