Commit 8496b2e4 authored by Yang Yu

Refine parallel_do

Parent 60e27d11
......
@@ -270,10 +270,10 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
                  "Batch size should be divided by places size");
   std::vector<LoDTensor> lods;
-  for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
-    int begin = place_idx * dims()[0] / places.size();
-    int end = (place_idx + 1) * dims()[0] / places.size();
-    auto src = Slice(begin, end);
+  for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
+    size_t begin = place_idx * dims()[0] / places.size();
+    size_t end = (place_idx + 1) * dims()[0] / places.size();
+    auto src = Slice(static_cast<int>(begin), static_cast<int>(end));
     LoDTensor dst;
     dst.Resize(src.dims());
......
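The slice arithmetic above distributes the batch dimension dims()[0] evenly across the places with integer division, which is why the surrounding enforce requires the batch size to be divisible by the number of places. A minimal standalone sketch (not PaddlePaddle code; batch_size and num_places are made-up example values) of how the per-place row ranges fall out:

```cpp
// Standalone illustration of the begin/end computation used in
// SplitLoDTensor: each place receives a contiguous, equally sized
// slice of the batch dimension.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t batch_size = 8;  // stands in for dims()[0]
  const std::size_t num_places = 4;  // stands in for places.size()
  for (std::size_t place_idx = 0; place_idx < num_places; ++place_idx) {
    std::size_t begin = place_idx * batch_size / num_places;
    std::size_t end = (place_idx + 1) * batch_size / num_places;
    std::printf("place %zu: rows [%zu, %zu)\n", place_idx, begin, end);
  }
  return 0;
}
```

With a batch of 8 on 4 places, each worker gets a contiguous two-row slice [0,2), [2,4), [4,6), [6,8).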
......
@@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <thread>
 #include <vector>
 #include "paddle/framework/executor.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/framework/threadpool.h"
 namespace paddle {
 namespace operators {
-constexpr char kInputs[] = "inputs";
-constexpr char kParameters[] = "parameters";
-constexpr char kPlaces[] = "places";
+static constexpr char kInputs[] = "inputs";
+static constexpr char kParameters[] = "parameters";
+static constexpr char kPlaces[] = "places";
-constexpr char kOutputs[] = "outputs";
-constexpr char kParallelScopes[] = "parallel_scopes";
+static constexpr char kOutputs[] = "outputs";
+static constexpr char kParallelScopes[] = "parallel_scopes";
-constexpr char kParallelBlock[] = "sub_block";
+static constexpr char kParallelBlock[] = "sub_block";
 // using ParallelScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;
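A side note on the constant changes above: a namespace-scope constexpr char array is already const and therefore already has internal linkage, so the added static mainly makes that intent explicit. A small hedged sketch of how such name constants are typically used as lookup keys (the io map below is hypothetical, standing in for an operator's input/output table):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Explicit internal linkage, matching the style adopted in the diff above.
static constexpr char kInputs[] = "inputs";
static constexpr char kOutputs[] = "outputs";

int main() {
  // Hypothetical name -> variable-list table, keyed by the constants.
  std::unordered_map<std::string, std::vector<std::string>> io;
  io[kInputs] = {"X", "W"};
  io[kOutputs] = {"Out"};
  std::cout << io[kInputs].size() << " inputs, " << io[kOutputs].size()
            << " outputs\n";
  return 0;
}
```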
......
@@ -85,7 +85,8 @@ class ParallelDoOp : public framework::OperatorBase {
     SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
                                      Inputs(kInputs));
-    std::vector<std::thread> workers;
+    std::vector<std::future<void>> workers;
+    workers.reserve(places.size());
     for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
       VLOG(3) << "Run " << place_idx;
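Switching the workers container from std::vector<std::thread> to std::vector<std::future<void>> goes together with framework::Async from threadpool.h: each iteration submits a task and keeps a future to block on, instead of spawning a fresh thread. A hedged sketch of that pattern, using std::async only as a stand-in because the internals of framework::Async are not part of this diff (it presumably dispatches to a shared thread pool):

```cpp
#include <future>
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for framework::Async: run the callback off the calling thread
// and return a future that completes when the callback finishes.
template <typename Callback>
std::future<void> Async(Callback callback) {
  return std::async(std::launch::async, std::move(callback));
}

int main() {
  std::vector<std::future<void>> workers;
  workers.reserve(4);  // mirrors workers.reserve(places.size()) above
  for (int place_idx = 0; place_idx < 4; ++place_idx) {
    workers.emplace_back(Async([place_idx] {
      std::cout << "Run " << place_idx << "\n";  // per-place work goes here
    }));
  }
  for (auto &worker : workers) {
    worker.wait();  // block until every submitted task has finished
  }
  return 0;
}
```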
......
@@ -93,26 +94,27 @@ class ParallelDoOp : public framework::OperatorBase {
       auto *cur_scope = sub_scopes[place_idx];
       // copy parameter
-      if (dev_ctx.GetPlace() != place) {
+      // some version of boost lacks != for boost::variant
+      if (!(dev_ctx.GetPlace() == place)) {
         PADDLE_THROW("Not Implemented");
       }
       // execute
-      workers.push_back(std::thread([program, cur_scope, place, block] {
-        auto executor = framework::Executor(place);
+      workers.emplace_back(framework::Async([program, cur_scope, place, block] {
+        framework::Executor executor(place);
         executor.Run(*program, cur_scope, block->ID(),
                      false /*create_local_scope*/);
       }));
     }
     for (auto &worker : workers) {
-      worker.join();
+      worker.wait();
     }
     // merge output
     for (auto &o_name : Outputs(kOutputs)) {
       std::vector<const framework::LoDTensor *> lod_tensors;
       lod_tensors.reserve(sub_scopes.size());
       for (auto *sub_scope : sub_scopes) {
-        lod_tensors.push_back(&sub_scope->FindVar(o_name)->Get<LoDTensor>());
+        lod_tensors.emplace_back(&sub_scope->FindVar(o_name)->Get<LoDTensor>());
       }
       auto *lod_tensor_to_be_merged =
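The rewritten place check is a workaround for boost builds in which boost::variant provides operator== but no operator!=; spelling the test as !(a == b) keeps the code compiling on those versions. A minimal sketch of the same pattern with a hypothetical Place type that deliberately defines only operator==:

```cpp
#include <iostream>

// Hypothetical stand-in for a place/variant type that only has operator==.
struct Place {
  int device;
  bool operator==(const Place &other) const { return device == other.device; }
  // No operator!= on purpose, mimicking the affected boost::variant builds.
};

int main() {
  Place dev_ctx_place{-1};  // e.g. CPU
  Place place{0};           // e.g. GPU 0
  if (!(dev_ctx_place == place)) {  // works without requiring operator!=
    std::cout << "places differ, parameters would need to be copied\n";
  }
  return 0;
}
```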
......
@@ -177,7 +179,7 @@ class ParallelDoGradOp : public OperatorBase {
     }
     // exe run
-    std::vector<std::thread> workers;
+    std::vector<std::future<void>> workers;
     for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
       VLOG(3) << "Run " << place_idx;
......
@@ -185,14 +187,14 @@ class ParallelDoGradOp : public OperatorBase {
       auto *cur_scope = sub_scopes[place_idx];
       // execute
-      workers.push_back(std::thread([program, cur_scope, place, block] {
-        auto executor = framework::Executor(place);
+      workers.emplace_back(framework::Async([program, cur_scope, place, block] {
+        framework::Executor executor(place);
         executor.Run(*program, cur_scope, block->ID(),
                      false /*create_local_scope*/);
       }));
     }
     for (auto &worker : workers) {
-      worker.join();
+      worker.wait();
     }
     // merge grad
......
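One behavioural note on replacing worker.join() with worker.wait() (standard C++ semantics, not specific to Paddle): std::future::wait() only blocks until the task completes, while an exception thrown inside the task remains stored in the future and is rethrown only by get(). A minimal sketch:

```cpp
#include <future>
#include <iostream>
#include <stdexcept>

int main() {
  auto fut = std::async(std::launch::async, [] {
    throw std::runtime_error("worker failed");
  });
  fut.wait();  // returns normally even though the task threw
  try {
    fut.get();  // rethrows the exception stored by the task
  } catch (const std::exception &e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}
```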
......
@@ -205,6 +205,7 @@ def _append_backward_ops_(target,
         # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, no_grad_dict[block.idx], grad_sub_block_list)
+        grad_op_descs.extend(grad_op_desc)
         grad_to_var.update(op_grad_to_var)
......