Commit 87648f8e authored by JiabinYang

merge develop, test=develop

...
@@ -186,8 +186,7 @@ set(module "inference")
 copy(inference_lib DEPS ${inference_deps}
   SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
         ${src_dir}/${module}/api/paddle_*.h
-        ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h
-  DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
+  DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
 )
 set(module "platform")
...
@@ -97,8 +97,8 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
 paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
-paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0))
+paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
 paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'non_leaf_num', 'ptable', 'pcode', 'is_costum', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, False, False))
 paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
...
@@ -39,11 +39,12 @@ if (WITH_GPU)
 endif()
 cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
+cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
            scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
-set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass)
 if (WITH_GPU)
   list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
 endif()
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
static constexpr char kAllOpDescs[] = "all_op_descs";
VarHandle* GetValidInput(const OpHandleBase* a) {
for (auto p : a->Inputs()) {
VarHandle* b = dynamic_cast<VarHandle*>(p);
if (b) {
return b;
}
}
return nullptr;
}
std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// get vars order
int order = 0;
std::unordered_map<std::string, int> vars;
// TODO(gongwb): use graph topology sort to find the order of operators.
// Note that must assert topology sort is stable
auto& ops = Get<const std::vector<OpDesc*>>(kAllOpDescs);
for (auto* op_desc : ops) {
auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) {
for (auto& v : o_it.second) { // values
vars[v] = order;
}
}
order++;
}
std::vector<OpHandleBase*> dist_ops;
// get allreduce ops.
for (auto& op : graph_ops) {
// FIXME(gongwb):add broad cast.
if (op->Name() == "all_reduce" || op->Name() == "reduce") {
dist_ops.push_back(op);
}
}
VLOG(10) << "dist_ops size:" << dist_ops.size() << std::endl;
std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
OpHandleBase* op2) {
VarHandle* i0 = dynamic_cast<VarHandle*>(GetValidInput(op1));
VarHandle* i1 = dynamic_cast<VarHandle*>(GetValidInput(op2));
PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error",
op1->DebugString(), op2->DebugString());
auto l_it = vars.find(i0->name_);
auto r_it = vars.find(i1->name_);
if (l_it->second < r_it->second) return true;
if (l_it->second == r_it->second) {
return i0->name_ < i1->name_;
}
return false;
});
// add dependency.
auto& sorted_ops = dist_ops;
for (size_t i = 1; i < sorted_ops.size(); ++i) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
auto* pre_op = sorted_ops[i - 1];
auto* op = sorted_ops[i];
pre_op->AddOutput(dep_var);
op->AddInput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
VLOG(10) << "add all_reduce sequential dependencies between " << pre_op
<< " and " << op;
VLOG(10) << "pre_op:" << pre_op->DebugString()
<< ", op:" << op->DebugString();
}
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(all_reduce_deps_pass,
paddle::framework::details::AllReduceDepsPass)
.RequirePassAttr(paddle::framework::details::kAllOpDescs);
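
Note: the pass above boils down to a stable ordering rule plus a chain of control-dependency edges. A minimal runnable Python sketch of that logic, with hypothetical `OpDesc`/`DistOp` stand-ins (this is an illustration, not Paddle's API):

from dataclasses import dataclass, field

@dataclass
class OpDesc:
    output_names: list

@dataclass
class DistOp:                     # stand-in for an all_reduce/reduce op handle
    input_name: str
    control_deps: list = field(default_factory=list)

def add_all_reduce_deps(all_op_descs, dist_ops):
    # Record the program-order position at which each variable is produced
    # (mirrors the "get vars order" loop in ApplyImpl above).
    var_order = {}
    for pos, op in enumerate(all_op_descs):
        for name in op.output_names:
            var_order[name] = pos
    # Sort allreduce ops by the producing position of their input, ties
    # broken by name, so every trainer agrees on a single launch order.
    dist_ops.sort(key=lambda op: (var_order[op.input_name], op.input_name))
    # Chain consecutive ops with a control dependency (the DummyVarHandle).
    for prev, cur in zip(dist_ops, dist_ops[1:]):
        cur.control_deps.append(prev)
    return dist_ops

ops = [OpDesc(["w1@GRAD"]), OpDesc(["w0@GRAD"])]
chained = add_all_reduce_deps(ops, [DistOp("w0@GRAD"), DistOp("w1@GRAD")])
print([d.input_name for d in chained])   # ['w1@GRAD', 'w0@GRAD']
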
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
...
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
+#include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/sequential_execution_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
...
@@ -24,6 +25,10 @@ namespace paddle {
 namespace framework {
 namespace details {

+static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
+  return (!strategy.enable_sequential_execution_ && strategy.num_trainers_ > 1);
+}
+
 class ParallelExecutorPassBuilder : public ir::PassBuilder {
  public:
   explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
...
@@ -70,6 +75,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     // Verify that the graph is correct for multi-device executor.
     AppendPass("multi_devices_check_pass");

+    if (SeqOnlyAllReduceOps(strategy)) {
+      AppendPass("all_reduce_deps_pass");
+    }
+
     if (strategy_.remove_unnecessary_lock_) {
       AppendPass("modify_op_lock_and_record_event_pass");
     }
...
@@ -124,6 +133,17 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
 #endif
     } else if (pass->Type() == "sequential_execution_pass") {
+      VLOG(1) << "set enable_sequential_execution:"
+              << enable_sequential_execution_;
+
+      pass->Erase(kAllOpDescs);
+      pass->Set<const std::vector<OpDesc *>>(
+          kAllOpDescs,
+          new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
+    } else if (pass->Type() == "all_reduce_deps_pass") {
+      VLOG(1) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
+              << ", num_trainers:" << num_trainers_;
+
       pass->Erase(kAllOpDescs);
       pass->Set<const std::vector<OpDesc *>>(
           kAllOpDescs,
...
@@ -144,4 +164,5 @@ USE_PASS(multi_devices_pass);
 USE_PASS(multi_devices_check_pass);
 USE_PASS(multi_devices_print_pass);
 USE_PASS(sequential_execution_pass);
+USE_PASS(all_reduce_deps_pass);
 USE_PASS(modify_op_lock_and_record_event_pass);
...
@@ -73,6 +73,7 @@ struct BuildStrategy {
   bool fuse_broadcast_op_{false};

+  int num_trainers_{1};
   bool remove_unnecessary_lock_{false};

   // NOTE:
...
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph.h"

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
...
@@ -54,7 +54,7 @@ class ParallelExecutorPrivate {
   Scope *global_scope_;  // not owned
   std::unique_ptr<details::SSAGraphExecutor> executor_;

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 #endif
   bool own_local_scope_;
...
@@ -104,7 +104,7 @@ ParallelExecutor::ParallelExecutor(
   if (member_->use_cuda_) {
 // Bcast Parameters to all GPUs
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
     ncclUniqueId *nccl_id = nullptr;
     if (nccl_id_var != nullptr) {
...
@@ -124,7 +124,7 @@ ParallelExecutor::ParallelExecutor(
 // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
 // ncclOp
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
       main_program, member_->places_, loss_var_name, params,
       member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get());
...
@@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices(
     }
     auto &dims = main_tensor.dims();
     if (paddle::platform::is_gpu_place(main_tensor.place())) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
       std::vector<void *> buffers;
       size_t numel = main_tensor.numel();
       ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
...
@@ -76,7 +76,8 @@ void TestWord2vecPrediction(const std::string& model_path) {
       0.000932706};
   const size_t num_elements = outputs.front().data.length() / sizeof(float);
   // The outputs' buffers are in CPU memory.
-  for (size_t i = 0; i < std::min((size_t)5UL, num_elements); i++) {
+  for (size_t i = 0; i < std::min(static_cast<size_t>(5UL), num_elements);
+       i++) {
     LOG(INFO) << "data: "
               << static_cast<float*>(outputs.front().data.data())[i];
     PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
...
@@ -99,9 +99,8 @@ TEST(BestFitAllocator, test_concurrent_cpu_allocation) {
   LockedAllocator locked_allocator(std::move(best_fit_allocator));

-  auto th_main = [&] {
-    std::random_device dev;
-    std::default_random_engine engine(dev());
+  auto th_main = [&](std::random_device::result_type seed) {
+    std::default_random_engine engine(seed);
     std::uniform_int_distribution<size_t> dist(1U, 1024U);

     for (size_t i = 0; i < 128; ++i) {
...
@@ -125,7 +124,8 @@ TEST(BestFitAllocator, test_concurrent_cpu_allocation) {
   {
     std::vector<std::thread> threads;
     for (size_t i = 0; i < 1024; ++i) {
-      threads.emplace_back(th_main);
+      std::random_device dev;
+      threads.emplace_back(th_main, dev());
     }
     for (auto& th : threads) {
       th.join();
...
@@ -41,9 +41,8 @@ TEST(BestFitAllocator, concurrent_cuda) {
   LockedAllocator concurrent_allocator(
       std::unique_ptr<Allocator>(new BestFitAllocator(cuda_allocation.get())));

-  auto th_main = [&] {
-    std::random_device dev;
-    std::default_random_engine engine(dev());
+  auto th_main = [&](std::random_device::result_type seed) {
+    std::default_random_engine engine(seed);
     std::uniform_int_distribution<size_t> dist(1U, 1024U);
     platform::CUDAPlace gpu(0);
     platform::CUDADeviceContext dev_ctx(gpu);
...
@@ -75,7 +74,8 @@ TEST(BestFitAllocator, concurrent_cuda) {
   {
     std::vector<std::thread> threads;
     for (size_t i = 0; i < 1024; ++i) {
-      threads.emplace_back(th_main);
+      std::random_device dev;
+      threads.emplace_back(th_main, dev());
     }
     for (auto& th : threads) {
       th.join();
...
@@ -86,7 +86,11 @@ void CPUAllocator::Free(void* p, size_t size, size_t index) {
     munlock(p, size);
 #endif
   }
+#ifdef _WIN32
+  _aligned_free(p);
+#else
   free(p);
+#endif
 }

 bool CPUAllocator::UseGpu() const { return false; }
...
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#ifndef _WIN32
 #include <unistd.h>
+#endif
 #include <string>
 #include <thread>  // NOLINT
...
@@ -60,75 +60,30 @@ float LogUniformSampler::Probability(int64_t value) const {
   return (log((value + 2.0) / (value + 1.0))) / log_range_;
 }

-CustomSampler::CustomSampler(int64_t range, const float* probabilities,
-                             unsigned int seed)
+CustomSampler::CustomSampler(int64_t range, const float *probabilities,
+                             const int *alias, const float *alias_probabilities,
+                             unsigned int seed)
     : Sampler(range, seed) {
-  random_engine_ = std::make_shared<std::mt19937_64>(seed_);
+  random_engine_ = std::make_shared<std::mt19937>(seed_);
   real_dist_ = std::make_shared<std::uniform_real_distribution<>>(0, 1);
   int_dist_ = std::make_shared<std::uniform_int_distribution<>>(0, range);
-  alias_probs_ = std::make_shared<std::vector<float>>(range + 1);
-  alias_ = std::make_shared<std::vector<int64_t>>(range + 1);
-  probs_ = std::make_shared<std::vector<float>>(range + 1);
-
-  std::queue<std::pair<int64_t, float>> bigs;
-  std::queue<std::pair<int64_t, float>> littles;
-  for (int64_t i = 0; i <= range; ++i) {
-    (*probs_)[i] = probabilities[i];
-    float normal_prob = probabilities[i] * (range + 1);
-    if (normal_prob - 1.0 > 1e-4) {
-      bigs.emplace(i, normal_prob);
-    } else if (1.0 - normal_prob > 1e-4) {
-      littles.emplace(i, normal_prob);
-    } else {
-      (*alias_probs_)[i] = normal_prob;
-      (*alias_)[i] = -1;
-    }
-  }
-
-  while ((!littles.empty()) && (!bigs.empty())) {
-    auto big = bigs.front();
-    auto little = littles.front();
-    bigs.pop();
-    littles.pop();
-    (*alias_probs_)[little.first] = little.second;
-    (*alias_)[little.first] = big.first;
-    auto big_left = big.second - (1 - little.second);
-    if (big_left - 1.0 > 1e-4) {
-      bigs.emplace(big.first, big_left);
-    } else if (1.0 - big_left > 1e-4) {
-      littles.emplace(big.first, big_left);
-    } else {
-      (*alias_probs_)[big.first] = big_left;
-      (*alias_)[big.first] = -1;
-    }
-  }
-
-  if (!littles.empty()) {  // littles.second is close to 1.0
-    auto little = littles.front();
-    (*alias_probs_)[little.first] = 1.0;
-    (*alias_)[little.first] = -1;
-  }
-
-  if (!bigs.empty()) {  // bigs.second is close to 1.0
-    auto big = bigs.front();
-    (*alias_probs_)[big.first] = 1.0;
-    (*alias_)[big.first] = -1;
-  }
+  alias_probs_ = alias_probabilities;
+  probs_ = probabilities;
+  alias_ = alias;
 }

 int64_t CustomSampler::Sample() const {
   auto index = (*int_dist_)(*random_engine_);
   auto p = (*real_dist_)(*random_engine_);
-  if (p > (*alias_probs_)[index]) {
-    return (*alias_)[index];
+  if (p > alias_probs_[index]) {
+    return alias_[index];
   } else {
     return index;
   }
 }

-float CustomSampler::Probability(int64_t value) const {
-  return (*probs_)[value];
-}
+float CustomSampler::Probability(int64_t value) const { return probs_[value]; }

 }  // namespace math
 }  // namespace operators
...
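Note: the alias-table construction that used to live in this constructor now happens on the Python side (see the nn.py hunk later in this commit); the sampler only consumes prebuilt probs/alias/alias_probs arrays. For reference, a minimal NumPy sketch of Walker's alias method that matches Sample() above (a hypothetical helper, not part of the commit):

import numpy as np

def build_alias_table(probs):
    """Walker's alias method: O(n) setup, O(1) sampling."""
    n = len(probs)
    alias = [-1] * n
    alias_probs = [0.0] * n
    scaled = [p * n for p in probs]              # mean-normalized probabilities
    smalls = [i for i, s in enumerate(scaled) if s < 1.0]
    bigs = [i for i, s in enumerate(scaled) if s >= 1.0]
    while smalls and bigs:
        small, big = smalls.pop(), bigs.pop()
        alias_probs[small] = scaled[small]       # accept `small` with this prob
        alias[small] = big                       # otherwise fall through to `big`
        scaled[big] -= 1.0 - scaled[small]       # `big` donates the leftover mass
        (smalls if scaled[big] < 1.0 else bigs).append(big)
    for i in smalls + bigs:                      # leftovers are numerically ~1.0
        alias_probs[i] = 1.0
    return alias, alias_probs

def sample(alias, alias_probs, rng=np.random):
    i = rng.randint(len(alias))                  # uniform column, like int_dist_
    # Mirrors CustomSampler::Sample(): keep i unless p exceeds alias_probs[i].
    return i if rng.random_sample() <= alias_probs[i] else alias[i]

alias, aprobs = build_alias_table([0.1, 0.2, 0.3, 0.4])
draws = [sample(alias, aprobs) for _ in range(100000)]
print(np.bincount(draws, minlength=4) / len(draws))   # ~[0.1, 0.2, 0.3, 0.4]
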
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once

 #include <cstdint>
 #include <memory>
 #include <random>
...
@@ -38,9 +39,12 @@ class Sampler {
       seed_ = seed;
     }
   }

   virtual ~Sampler();

   // Sample a single value
   virtual int64_t Sample() const = 0;

   // The probability that a single call to Sample() returns the given value.
   virtual float Probability(int64_t value) const = 0;
...
@@ -99,6 +103,7 @@ class LogUniformSampler : public Sampler {
 class CustomSampler : public Sampler {
  public:
   explicit CustomSampler(int64_t range, const float* probabilities,
+                         const int* alias, const float* alias_probabilities,
                          unsigned int seed = 0UL);

   ~CustomSampler() override {}
...
@@ -108,10 +113,10 @@ class CustomSampler : public Sampler {
   float Probability(int64_t value) const override;

  private:
-  std::shared_ptr<std::vector<float>> alias_probs_;
-  std::shared_ptr<std::vector<int64_t>> alias_;
-  std::shared_ptr<std::vector<float>> probs_;
-  std::shared_ptr<std::mt19937_64> random_engine_;
+  const float* alias_probs_;
+  const int* alias_;
+  const float* probs_;
+  std::shared_ptr<std::mt19937> random_engine_;
   std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
   std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
 };
...
@@ -16,13 +16,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence_pooling.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/macros.h"

 namespace paddle {
 namespace operators {
 namespace math {

-#define FLT_MAX __FLT_MAX__
-
 template <typename T>
 struct MaxPoolFunctor {
   HOSTDEVICE void operator()(const T* input, const size_t start,
...
@@ -14,6 +14,7 @@ limitations under the License. */

 #include "paddle/fluid/operators/nce_op.h"

+#include <string>
 #include <vector>

 namespace paddle {
...
@@ -25,7 +26,7 @@ class NCEOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext* ctx) const override {
+  void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Input"));
     PADDLE_ENFORCE(ctx->HasInput("Label"));
     PADDLE_ENFORCE(ctx->HasInput("Weight"));
...
@@ -67,7 +68,7 @@ class NCEOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
+      const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
         platform::CPUPlace());
...
@@ -101,11 +102,24 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
         .AsDispensable();

     AddInput(
-        "CustomDistribution",
+        "CustomDistProbs",
         "(Tensor) It is used in the 'CustomDist' sampler. "
         "It is a tensor with shape [num_total_classes]. "
         "The i-th element is the probability of the i-th class being sampled.")
         .AsDispensable();
+    AddInput(
+        "CustomDistAlias",
+        "(Tensor) It is used in the 'CustomDist' sampler. "
+        "It is a tensor with shape [num_total_classes]. "
+        "The i-th element is the alias class of the i-th class.")
+        .AsDispensable();
+    AddInput(
+        "CustomDistAliasProbs",
+        "(Tensor) It is used in the 'CustomDist' sampler. "
+        "It is a tensor with shape [num_total_classes]. "
+        "The i-th element is the probability of accepting the i-th class "
+        "rather than its alias.")
+        .AsDispensable();

     AddOutput("Cost",
               "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
     AddOutput("SampleLogits",
...
@@ -124,21 +138,22 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
               "kernel to compute grads."
               "")
         .AsIntermediate();

     AddAttr<int>("num_total_classes",
                  "Total number of classes in all samples.");
     AddAttr<int>("num_neg_samples",
                  "The number of negative classes. The default value is 10.")
         .SetDefault(10);

     AddAttr<int>("sampler",
                  "(int) Which sampler to be used to sample negative classes. "
                  "0: Uniform; 1: LogUniform; 2: CustomDist.")
         .SetDefault(0);

     AddAttr<int>("seed",
                  "(int) The seed used in sampler. If it is 0, "
                  "the sampler will generate a seed randomly.")
         .SetDefault(0);
+    AddAttr<bool>("is_sparse", "(boolean, default false) Sparse update.")
+        .SetDefault(false);

     AddAttr<std::vector<int>>("custom_neg_classes",
                               "This attribute is only used in unit tests. Classes "
...
@@ -156,11 +171,19 @@ By default this operator uses a uniform distribution for sampling.
   }
 };

+class NCEOpGradDescMaker : public framework::DefaultGradOpDescMaker<true> {
+  using ::paddle::framework::DefaultGradOpDescMaker<
+      true>::DefaultGradOpDescMaker;
+
+ protected:
+  virtual std::string GradOpType() const { return "nce_grad"; }
+};
+
 class NCEOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext* ctx) const override {
+  void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Input"));
     PADDLE_ENFORCE(ctx->HasInput("Weight"));
     PADDLE_ENFORCE(ctx->HasInput("Cost"));
...
@@ -190,20 +213,45 @@ class NCEOpGrad : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
+      const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
         platform::CPUPlace());
   }
 };

+class NCEOpGradVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front();
+    auto bias_grad = op_desc.Output(framework::GradVarName("Bias")).front();
+
+    auto attr = op_desc.GetAttr("is_sparse");
+    bool is_sparse = boost::get<bool>(attr);
+    if (is_sparse) {
+      VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad
+               << " are set to SelectedRows";
+      block->Var(weight_grad)
+          ->SetType(framework::proto::VarType::SELECTED_ROWS);
+      block->Var(bias_grad)->SetType(framework::proto::VarType::SELECTED_ROWS);
+    } else {
+      VLOG(30) << "nce_op_grad op " << weight_grad << " and " << bias_grad
+               << " are set to LoDTensor";
+      block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
+      block->Var(bias_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
+    }
+    block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType());
+    block->Var(bias_grad)->SetDataType(block->Var("Input")->GetDataType());
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
-REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad);
+REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpGradDescMaker, ops::NCEOpMaker);
+REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference);
 REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
                        ops::NCEKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(nce_grad,
...
@@ -16,26 +16,32 @@ limitations under the License. */

 #include <math.h>
 #include <random>
+#include <set>
 #include <vector>

 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/math/sampler.h"
 #include "unsupported/Eigen/CXX11/Tensor"

 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using SelectedRows = framework::SelectedRows;
 using Sampler = math::Sampler;
+using DDim = framework::DDim;

 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

 template <typename DeviceContext, typename T>
-void PrepareSamples(const framework::ExecutionContext& context,
-                    Sampler* sampler) {
+void PrepareSamples(const framework::ExecutionContext &context,
+                    Sampler *sampler) {
   auto label = context.Input<Tensor>("Label");
-  const int64_t* label_data = label->data<int64_t>();
+  const int64_t *label_data = label->data<int64_t>();
   auto label_dims = label->dims();
   // int num_total_classes = context.Attr<int>("num_total_classes");
   // for unit tests
...
@@ -44,7 +50,7 @@ void PrepareSamples(const framework::ExecutionContext &context,
   auto sample_labels = context.Output<Tensor>("SampleLabels");
   auto sample_labels_dims = sample_labels->dims();
-  int64_t* sample_labels_data =
+  int64_t *sample_labels_data =
       sample_labels->mutable_data<int64_t>(context.GetPlace());

   int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
...
@@ -70,13 +76,13 @@ void PrepareSamples(const framework::ExecutionContext &context,
 template <typename DeviceContext, typename T>
 class NCEKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
+  void Compute(const framework::ExecutionContext &context) const override {
     int sampler_type = context.Attr<int>("sampler");
     int seed = context.Attr<int>("seed");
     int num_total_classes = context.Attr<int>("num_total_classes");
     int num_neg_samples = context.Attr<int>("num_neg_samples");

-    Sampler* sampler;
+    Sampler *sampler;
     switch (sampler_type) {
       case 0: {
         sampler = new math::UniformSampler(num_total_classes - 1, seed);
...
@@ -87,11 +93,19 @@ class NCEKernel : public framework::OpKernel<T> {
         break;
       }
       case 2: {
-        auto custom_dist = context.Input<Tensor>("CustomDistribution");
-        const float* custom_dist_data = custom_dist->data<float>();
-        PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
-        sampler = new math::CustomSampler(num_total_classes - 1,
-                                          custom_dist_data, seed);
+        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
+        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
+        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");
+
+        PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
+        PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
+        PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);
+
+        const float *probs_data = dist_probs->data<float>();
+        const int *alias_data = dist_alias->data<int>();
+        const float *alias_probs_data = dist_alias_probs->data<float>();
+        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
+                                          alias_data, alias_probs_data, seed);
         break;
       }
       default: { PADDLE_THROW("Unsupported SamplerType."); }
...
@@ -99,17 +113,17 @@ class NCEKernel : public framework::OpKernel<T> {
     PrepareSamples<DeviceContext, T>(context, sampler);
     auto sample_labels = context.Output<Tensor>("SampleLabels");
-    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
+    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
     auto sample_out = context.Output<Tensor>("SampleLogits");
-    T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
+    T *sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
     auto label = context.Input<Tensor>("Label");
     auto sample_weight = context.Input<Tensor>("SampleWeight");
-    const T* sample_weight_data = nullptr;
+    const T *sample_weight_data = nullptr;
     if (sample_weight != nullptr) {
       sample_weight_data = sample_weight->data<T>();
     }
     auto out = context.Output<Tensor>("Cost");
-    T* out_data = out->mutable_data<T>(context.GetPlace());
+    T *out_data = out->mutable_data<T>(context.GetPlace());
     int64_t num_true_class = 1;
     if (label != nullptr) {
       num_true_class = label->dims()[1];
...
@@ -119,7 +133,7 @@ class NCEKernel : public framework::OpKernel<T> {
     // forward bias
     auto bias = context.Input<Tensor>("Bias");
     if (bias != nullptr) {
-      const T* bias_data = bias->data<T>();
+      const T *bias_data = bias->data<T>();
       for (int64_t i = 0; i < sample_labels->numel(); ++i) {
         sample_out_data[i] = bias_data[sample_labels_data[i]];
       }
...
@@ -158,16 +172,16 @@ class NCEKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 class NCEGradKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
+  void Compute(const framework::ExecutionContext &context) const override {
     auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
-    const T* d_out_data = d_out->data<T>();
+    const T *d_out_data = d_out->data<T>();
     auto label = context.Input<Tensor>("Label");
     auto sample_out = context.Input<Tensor>("SampleLogits");
-    const T* sample_out_data = sample_out->data<T>();
+    const T *sample_out_data = sample_out->data<T>();
     auto sample_labels = context.Input<Tensor>("SampleLabels");
-    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
+    const int64_t *sample_labels_data = sample_labels->data<int64_t>();
     auto sample_weight = context.Input<Tensor>("SampleWeight");
-    const T* sample_weight_data = nullptr;
+    const T *sample_weight_data = nullptr;
     if (sample_weight != nullptr) {
       sample_weight_data = sample_weight->data<T>();
     }
...
@@ -180,7 +194,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
     int sampler_type = context.Attr<int>("sampler");
     int seed = context.Attr<int>("seed");

-    Sampler* sampler;
+    Sampler *sampler;
     switch (sampler_type) {
       case 0: {
         sampler = new math::UniformSampler(num_total_classes - 1, seed);
...
@@ -191,11 +205,19 @@ class NCEGradKernel : public framework::OpKernel<T> {
         break;
       }
       case 2: {
-        auto custom_dist = context.Input<Tensor>("CustomDistribution");
-        const float* custom_dist_data = custom_dist->data<float>();
-        PADDLE_ENFORCE_EQ(custom_dist->numel(), num_total_classes);
-        sampler = new math::CustomSampler(num_total_classes - 1,
-                                          custom_dist_data, seed);
+        auto dist_probs = context.Input<Tensor>("CustomDistProbs");
+        auto dist_alias = context.Input<Tensor>("CustomDistAlias");
+        auto dist_alias_probs = context.Input<Tensor>("CustomDistAliasProbs");
+
+        PADDLE_ENFORCE_EQ(dist_probs->numel(), num_total_classes);
+        PADDLE_ENFORCE_EQ(dist_alias->numel(), num_total_classes);
+        PADDLE_ENFORCE_EQ(dist_alias_probs->numel(), num_total_classes);
+
+        const float *probs_data = dist_probs->data<float>();
+        const int *alias_data = dist_alias->data<int>();
+        const float *alias_probs_data = dist_alias_probs->data<float>();
+        sampler = new math::CustomSampler(num_total_classes - 1, probs_data,
+                                          alias_data, alias_probs_data, seed);
         break;
       }
       default: { PADDLE_THROW("Unsupported SamplerType."); }
...
@@ -203,7 +225,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
     // T b = 1. / num_total_classes * num_neg_samples;

     Tensor sample_grad;  // tmp tensor
-    T* sample_grad_data =
+    T *sample_grad_data =
         sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());

     // backward cost
     for (int64_t i = 0; i < sample_labels->numel(); ++i) {
...
@@ -217,32 +239,105 @@ class NCEGradKernel : public framework::OpKernel<T> {
                      : w * (o * (1 - o) / (o + b));
       sample_grad_data[i] *= d_out_data[sample_idx];
     }
-    // get d_bias
-    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
-    if (d_bias != nullptr) {
-      T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
-      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
-      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
-        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
-      }
-    }
-    // get d_w
-    auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
-    if (d_w != nullptr) {
-      auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
-      std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
-      auto d_w_matrix = EigenMatrix<T>::From(*d_w);
-      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
-      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
-        d_w_matrix.chip(sample_labels_data[i], 0) +=
-            x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
-            sample_grad_data[i];
-      }
-    }
+
+    bool is_sparse = context.Attr<bool>("is_sparse");
+
+    if (!is_sparse) {
+      // get d_bias
+      auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
+      if (d_bias != nullptr) {
+        T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
+        std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
+        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+          d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
+        }
+      }
+      // get d_w
+      auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
+      if (d_w != nullptr) {
+        auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
+        std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
+        auto d_w_matrix = EigenMatrix<T>::From(*d_w);
+        auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
+        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+          d_w_matrix.chip(sample_labels_data[i], 0) +=
+              x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
+              sample_grad_data[i];
+        }
+      }
+    } else {
+      // Deduplicate the sampled labels; only these rows receive a gradient.
+      std::vector<int64_t> labels;
+      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+        labels.push_back(sample_labels_data[i]);
+      }
+      std::set<int64_t> st(labels.begin(), labels.end());
+      labels.assign(st.begin(), st.end());
+
+      auto *bias_var = context.InputVar("Bias");
+      DDim bias_dim;
+      if (bias_var->IsType<LoDTensor>()) {
+        bias_dim = context.Input<LoDTensor>("Bias")->dims();
+      } else if (bias_var->IsType<SelectedRows>()) {
+        auto *table_t = context.Input<SelectedRows>("Bias");
+        bias_dim = table_t->value().dims();
+      } else {
+        PADDLE_THROW(
+            "The parameter Bias of a NCE_OP "
+            "must be either LoDTensor or SelectedRows");
+      }
+
+      auto d_bias =
+          context.Output<SelectedRows>(framework::GradVarName("Bias"));
+      d_bias->set_rows(labels);
+      d_bias->set_height(bias_dim[0]);
+
+      d_bias->mutable_value()->Resize(
+          {static_cast<int64_t>(labels.size()), bias_dim[1]});
+      T *d_bias_data =
+          d_bias->mutable_value()->mutable_data<T>(context.GetPlace());
+      std::fill(d_bias_data, d_bias_data + labels.size(), 0.0);
+      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+        d_bias_data[d_bias->Index(sample_labels_data[i])] +=
+            sample_grad_data[i];
+      }
+
+      auto *table_var = context.InputVar("Weight");
+      DDim table_dim;
+      if (table_var->IsType<LoDTensor>()) {
+        table_dim = context.Input<LoDTensor>("Weight")->dims();
+      } else if (table_var->IsType<SelectedRows>()) {
+        auto *table_t = context.Input<SelectedRows>("Weight");
+        table_dim = table_t->value().dims();
+      } else {
+        PADDLE_THROW(
+            "The parameter Weight of a NCE_OP "
+            "must be either LoDTensor or SelectedRows");
+      }
+
+      auto d_w = context.Output<SelectedRows>(framework::GradVarName("Weight"));
+
+      d_w->set_rows(labels);
+      d_w->set_height(table_dim[0]);
+
+      auto *d_table_value = d_w->mutable_value();
+      d_table_value->Resize(
+          {static_cast<int64_t>(labels.size()), table_dim[1]});
+      auto d_w_data = d_table_value->mutable_data<T>(context.GetPlace());
+      std::fill(d_w_data, d_w_data + d_table_value->numel(), 0.0);
+
+      auto d_w_matrix = EigenMatrix<T>::From(*d_table_value);
+      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
+      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+        d_w_matrix.chip(d_w->Index(sample_labels_data[i]), 0) +=
+            x_matrix.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
+            sample_grad_data[i];
+      }
+    }
+
     // get d_x
     auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
     if (d_x != nullptr) {
-      auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
+      auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());
       std::fill(d_x_data, d_x_data + d_x->numel(), 0.0);
       auto d_x_matrix = EigenMatrix<T>::From(*d_x);
       auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
...
@@ -251,6 +346,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
             w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
       }
     }
+
     delete sampler;
   }
 };
...
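Note: the sparse branch above emits the weight gradient as a SelectedRows, so only rows for labels that were actually sampled get materialized. A minimal NumPy sketch of the same bookkeeping (hypothetical shapes, illustration only, not Paddle code):

import numpy as np

def sparse_weight_grad(x, sample_labels, sample_grad):
    """Build the weight gradient as (rows, values) -- the SelectedRows layout
    used by the is_sparse branch -- instead of a dense [height, dim] tensor."""
    num_label = sample_labels.shape[1]              # samples per input row
    flat_labels = sample_labels.reshape(-1)
    flat_grad = sample_grad.reshape(-1)
    rows = sorted(set(flat_labels.tolist()))        # dedup, like std::set above
    row_index = {r: i for i, r in enumerate(rows)}  # like SelectedRows::Index
    values = np.zeros((len(rows), x.shape[1]), dtype=x.dtype)
    for i, label in enumerate(flat_labels):
        values[row_index[label]] += x[i // num_label] * flat_grad[i]
    return rows, values

x = np.ones((2, 3), dtype=np.float32)               # batch 2, dim 3
labels = np.array([[5, 9], [9, 1]])                 # sampled classes per row
grads = np.full((2, 2), 0.5, dtype=np.float32)
rows, values = sparse_weight_grad(x, labels, grads)
print(rows, values.shape)                           # [1, 5, 9] (3, 3), not (10, 3)
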
@@ -20,12 +20,12 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"

 #ifndef _WIN32
-const float fraction_of_gpu_memory_to_use = 0.92f;
+constexpr static float fraction_of_gpu_memory_to_use = 0.92f;
 #else
 // fraction_of_gpu_memory_to_use cannot be too high on windows,
 // since the win32 graphic sub-system can occupy some GPU memory
 // which may lead to insufficient memory left for paddle
-const float fraction_of_gpu_memory_to_use = 0.5f;
+constexpr static float fraction_of_gpu_memory_to_use = 0.5f;
 #endif

 DEFINE_double(fraction_of_gpu_memory_to_use, fraction_of_gpu_memory_to_use,
...
@@ -29,8 +29,16 @@ limitations under the License. */
 namespace pybind11 {
 namespace detail {

+#if !defined(PYBIND11_HIDDEN)
+#ifdef _WIN32
+#define PYBIND11_HIDDEN __declspec(dllexport)
+#else
+#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
+#endif
+#endif
+
 // Can be replaced by a generic lambda in C++14
-struct __attribute__((visibility("hidden"))) paddle_variant_caster_visitor
+struct PYBIND11_HIDDEN paddle_variant_caster_visitor
     : public boost::static_visitor<handle> {
   return_value_policy policy;
   handle parent;
...
@@ -860,6 +860,12 @@ All parameter, weight, gradient are variables in Paddle.
             self.remove_unnecessary_lock_ = b;
           },
           R"DOC(The type is BOOL. If set True, some locks in GPU ops would be released and ParallelExecutor would run faster. Default False.)DOC")
+      .def_property(
+          "num_trainers",
+          [](const BuildStrategy &self) { return self.num_trainers_; },
+          [](BuildStrategy &self, int num_trainers) {
+            self.num_trainers_ = num_trainers;
+          })
       .def_property(
           "fuse_elewise_add_act_ops",
           [](const BuildStrategy &self) {
...
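Note: with the property above exposed, multi-trainer jobs can opt into the deterministic all_reduce ordering from Python. A sketch of the intended usage (fluid API of this era, values illustrative):

import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
build_strategy.num_trainers = 2   # more than one trainer in the job
# With num_trainers > 1 and enable_sequential_execution left False,
# ParallelExecutorPassBuilder appends all_reduce_deps_pass (see
# SeqOnlyAllReduceOps above), so every trainer launches its all_reduce
# ops in one agreed order.
# pe = fluid.ParallelExecutor(use_cuda=True,
#                             loss_name=loss.name,      # requires a built program
#                             build_strategy=build_strategy)
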
...@@ -4394,7 +4394,8 @@ def nce(input, ...@@ -4394,7 +4394,8 @@ def nce(input,
name=None, name=None,
sampler="uniform", sampler="uniform",
custom_dist=None, custom_dist=None,
seed=0): seed=0,
is_sparse=False):
""" """
${comment} ${comment}
...@@ -4420,11 +4421,12 @@ def nce(input, ...@@ -4420,11 +4421,12 @@ def nce(input,
sampler (str): The sampler used to sample class from negtive classes. sampler (str): The sampler used to sample class from negtive classes.
It can be 'uniform', 'log_uniform' or 'custom_dist'. It can be 'uniform', 'log_uniform' or 'custom_dist'.
default: 'uniform'. default: 'uniform'.
custom_dist (Variable): A tensor with shape [num_total_classes]. custom_dist (float[]): A float[] with size=num_total_classes.
It is used when sampler is set to 'custom_dist'. It is used when sampler is set to 'custom_dist'.
custom_dist[i] is the probsbility of i-th class to be sampled. custom_dist[i] is the probsbility of i-th class to be sampled.
default: None. default: None.
seed (int): The seed used in sampler. default: 0. seed (int): The seed used in sampler. default: 0.
is_sparse(bool): The flag indicating whether to use sparse update, the weight@GRAD and bias@GRAD will be changed to SelectedRows.
Returns: Returns:
Variable: The output nce loss. Variable: The output nce loss.
...@@ -4476,12 +4478,7 @@ def nce(input, ...@@ -4476,12 +4478,7 @@ def nce(input,
shape=[num_total_classes, dim], shape=[num_total_classes, dim],
is_bias=False, is_bias=False,
dtype=input.dtype) dtype=input.dtype)
inputs = { inputs = {}
'Input': input,
'Label': label,
'Weight': w,
'SampleWeight': sample_weight if sample_weight is not None else []
}
if helper.bias_attr: if helper.bias_attr:
b = helper.create_parameter( b = helper.create_parameter(
attr=helper.bias_attr, attr=helper.bias_attr,
...@@ -4493,18 +4490,10 @@ def nce(input, ...@@ -4493,18 +4490,10 @@ def nce(input,
sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype) sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype)
sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype) sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype)
if num_neg_samples is None: inputs['Input'] = input
num_neg_samples = 10 inputs['Label'] = label
else: inputs['Weight'] = w
num_neg_samples = int(num_neg_samples) inputs['SampleWeight'] = sample_weight if sample_weight is not None else []
inputs = {
'Input': input,
'Label': label,
'Weight': w,
'Bias': b,
'SampleWeight': sample_weight if sample_weight is not None else []
}
if sampler == "uniform": if sampler == "uniform":
sampler = 0 sampler = 0
@@ -4512,17 +4501,73 @@ def nce(input,
         sampler = 1
     elif sampler == "custom_dist":
         assert custom_dist is not None
-        assert isinstance(custom_dist, Variable)
+        # assert isinstance(custom_dist, Variable)
+        inputs['CustomDistribution'] = custom_dist
+
+        # Build Walker's alias tables so the op can sample from custom_dist
+        # in O(1) per draw: classes whose scaled probability exceeds 1
+        # ("bigs") donate their excess mass to classes below 1 ("littles"),
+        # until every bucket holds its own class plus at most one alias.
+        custom_dist_len = len(custom_dist)
+        alias_probs_ = [0] * custom_dist_len
+        alias_ = [0] * custom_dist_len
+        bigs = []
+        littles = []
+        for i in range(custom_dist_len):
+            normal_prob = custom_dist[i] * custom_dist_len
+            if normal_prob - 1.0 > 1e-4:
+                bigs.append((i, normal_prob))
+            elif 1.0 - normal_prob > 1e-4:
+                littles.append((i, normal_prob))
+            else:
+                alias_probs_[i] = normal_prob
+                alias_[i] = -1
+
+        while len(bigs) and len(littles):
+            big = bigs.pop(0)
+            little = littles.pop(0)
+
+            big_idx = big[0]
+            big_prob = big[1]
+
+            # The little bucket keeps its own mass; the remainder is
+            # filled by the big class, which becomes its alias.
+            alias_probs_[little[0]] = little[1]
+            alias_[little[0]] = big_idx
+            big_left = big[1] + little[1] - 1
+            if big_left - 1.0 > 1e-4:
+                bigs.append((big_idx, big_left))
+            elif 1.0 - big_left > 1e-4:
+                littles.append((big_idx, big_left))
+            else:
+                alias_probs_[big_idx] = big_left
+                alias_[big_idx] = -1
+
+        if len(bigs):
+            big = bigs.pop(0)
+            alias_probs_[big[0]] = 1.0
+            alias_[big[0]] = -1
+        if len(littles):
+            little = littles.pop(0)
+            alias_probs_[little[0]] = 1.0
+            alias_[little[0]] = -1
+
+        probs = assign(input=np.array(custom_dist).astype('float32'))
+        custom_alias = assign(input=np.array(alias_).astype('int32'))
+        custom_alias_probs = assign(
+            input=np.array(alias_probs_).astype('float32'))
+
+        inputs['CustomDistProbs'] = probs
+        inputs['CustomDistAlias'] = custom_alias
+        inputs['CustomDistAliasProbs'] = custom_alias_probs
         sampler = 2
     else:
         raise Exception("Unsupported sampler type.")

+    if num_neg_samples is None:
+        num_neg_samples = 10
+    else:
+        num_neg_samples = int(num_neg_samples)
+
     attrs = {
         'num_total_classes': int(num_total_classes),
         'num_neg_samples': num_neg_samples,
         'seed': seed,
-        'sampler': sampler
+        'sampler': sampler,
+        'is_sparse': is_sparse
     }

     helper.append_op(
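To see what the alias tables buy, here is a hedged sketch of the O(1) lookup they enable (plain Python mirroring the construction above; not the op's actual kernel):

    import random

    def alias_sample(alias_probs_, alias_):
        # Pick a bucket uniformly, then stay with probability alias_probs_[k];
        # otherwise jump to the bucket's alias class. alias_[k] == -1 marks a
        # bucket that needs no alias, so we keep k in that case too.
        k = random.randrange(len(alias_probs_))
        if random.random() < alias_probs_[k] or alias_[k] == -1:
            return k
        return alias_[k]

Drawing many samples and counting hits should reproduce custom_dist up to noise; that is the invariant the big/little redistribution loop maintains.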
@@ -6525,7 +6570,7 @@ def crop(x, shape=None, offsets=None, name=None):
     helper = LayerHelper('crop', **locals())
     if not (isinstance(shape, list) or isinstance(shape, tuple) or \
             isinstance(shape, Variable)):
         raise ValueError("The shape should be a list, tuple or Variable.")
     if offsets is None:
@@ -6647,7 +6692,7 @@ def affine_grid(theta, out_shape, name=None):
     helper = LayerHelper('affine_grid')
     if not (isinstance(out_shape, list) or isinstance(out_shape, tuple) or \
             isinstance(out_shape, Variable)):
         raise ValueError("The out_shape should be a list, tuple or Variable.")
     if not isinstance(theta, Variable):
     Returns:
         output(${out_type}): ${out_comment}
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name="x", shape=[3,10,32,32], dtype="float32")
+            y = fluid.layers.elu(x, alpha=0.2)
     """
     helper = LayerHelper('elu', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -6911,6 +6963,13 @@ def relu6(x, threshold=6.0, name=None):
     Returns:
         output(${out_type}): ${out_comment}
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name="x", shape=[3,10,32,32], dtype="float32")
+            y = fluid.layers.relu6(x, threshold=6.0)
     """
     helper = LayerHelper('relu6', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -6934,6 +6993,13 @@ def pow(x, factor=1.0, name=None):
     Returns:
         output(${out_type}): ${out_comment}
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name="x", shape=[3,10,32,32], dtype="float32")
+            y = fluid.layers.pow(x, factor=2.0)
     """
     helper = LayerHelper('pow', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -6958,6 +7024,13 @@ def stanh(x, scale_a=2.0 / 3.0, scale_b=1.7159, name=None):
     Returns:
         output(${out_type}): ${out_comment}
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name="x", shape=[3,10,32,32], dtype="float32")
+            y = fluid.layers.stanh(x, scale_a=0.67, scale_b=1.72)
     """
     helper = LayerHelper('stanh', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -6983,6 +7056,13 @@ def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
     Returns:
         output(${out_type}): ${out_comment}
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name="x", shape=[3,10,32,32], dtype="float32")
+            y = fluid.layers.hard_sigmoid(x, slope=0.3, offset=0.8)
     """
     helper = LayerHelper('hard_sigmoid', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -7007,6 +7087,13 @@ def swish(x, beta=1.0, name=None):
     Returns:
         output(${out_type}): ${out_comment}
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name="x", shape=[3,10,32,32], dtype="float32")
+            y = fluid.layers.swish(x, beta=2.0)
     """
     helper = LayerHelper('swish', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
@@ -124,16 +124,11 @@ class ParallelExecutor(object):
                 os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
             exec_strategy.num_threads = cpu_num * 2

-        # Set 1 thread num under nccl2 distribute
-        # env to make sure all gpus run ops in same order.
-        if num_trainers > 1:
-            assert (use_cuda)
-            # FIXME(gongwb): avoid this set.
-            exec_strategy.num_threads = 1
-
         if build_strategy is None:
             build_strategy = BuildStrategy()
+        build_strategy.num_trainers = num_trainers

         main = main_program
         main = main if main else framework.default_main_program()
         if scope == None:
......
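The user-facing call is unchanged by this cleanup; num_trainers is now forwarded into build_strategy instead of forcing exec_strategy.num_threads = 1. A hedged sketch (loss variable assumed defined elsewhere):

    import paddle.fluid as fluid

    pe = fluid.ParallelExecutor(use_cuda=True,
                                loss_name=avg_cost.name,  # avg_cost assumed defined
                                num_trainers=2)           # now lands in build_strategy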
@@ -63,7 +63,7 @@ function(py_test_modules TARGET_NAME)
     set(multiValueArgs MODULES DEPS ENVS)
     cmake_parse_arguments(py_test_modules "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
     add_test(NAME ${TARGET_NAME}
-           COMMAND env PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_modules_ENVS}
+           COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_modules_ENVS}
            ${PYTHON_EXECUTABLE} ${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
            WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
     if (py_test_modules_SERIAL)
......
@@ -14,8 +14,12 @@
 from __future__ import print_function

-import unittest
 import numpy as np
+import unittest
+
+import paddle.fluid as fluid
+import paddle.fluid.initializer as initializer
+
 from op_test import OpTest
@@ -59,7 +63,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
 class TestNCE(OpTest):
     def generate_data(self, dim, batch_size, num_classes, num_true_class,
-                      num_neg_samples):
+                      num_neg_samples, is_sparse):
         input = np.random.randn(batch_size, dim).astype(np.float32)
         weight = np.random.randn(num_classes, dim).astype(np.float32)
         bias = np.random.randn(num_classes).astype(np.float32)
@@ -70,7 +74,8 @@ class TestNCE(OpTest):
             'num_neg_samples': num_neg_samples,
             'custom_neg_classes': list(range(num_neg_samples)),
             'seed': 0,
-            'sampler': 0
+            'sampler': 0,
+            'is_sparse': is_sparse
         }
         self.inputs = {
             'Input': input,
@@ -81,7 +86,7 @@ class TestNCE(OpTest):
         }

     def set_data(self):
-        self.generate_data(5, 5, 4, 1, 2)
+        self.generate_data(5, 5, 4, 1, 2, False)

     def compute(self):
         out = nce(self.inputs['Input'], self.inputs['Weight'],
@@ -107,9 +112,110 @@ class TestNCE(OpTest):
             ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)

-class TestNCECase1(TestNCE):
+class TestNCECase1Tensor(TestNCE):
     def set_data(self):
-        self.generate_data(10, 20, 10, 2, 5)
+        self.generate_data(10, 20, 10, 2, 5, False)
+class TestNCECase1SelectedRows(unittest.TestCase):
+    def setUp(self):
+        self.base_lr = 0.0001
+        self.batch_size = 8
+
+    @staticmethod
+    def get_place():
+        place = fluid.core.CPUPlace()
+        return place
+
+    @staticmethod
+    def get_train_data(batch_size):
+        batchs = []
+        for i in range(batch_size):
+            input = np.random.randn(batch_size, 10).astype(np.float32)
+            labels = np.random.randint(0, 20, (batch_size, 1))
+            batchs.append([input, labels])
+        return batchs
+
+    def get_optimizer(self):
+        # SGD optimizer
+        optimizer = fluid.optimizer.SGD(learning_rate=self.base_lr)
+        return optimizer
+
+    def train_network(self, num_total_classes, num_neg_samples, sampler,
+                      custom_dist, is_sparse):
+        input = fluid.layers.data(name="input", shape=[10], dtype="float32")
+        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+
+        w_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[num_total_classes, 10],
+            dtype='float32',
+            name='nce_w',
+            initializer=initializer.ConstantInitializer())
+        b_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[num_total_classes, 1],
+            dtype='float32',
+            name='nce_b',
+            initializer=initializer.ConstantInitializer())
+
+        cost = fluid.layers.nce(input=input,
+                                label=label,
+                                num_total_classes=num_total_classes,
+                                sampler=sampler,
+                                custom_dist=custom_dist,
+                                sample_weight=None,
+                                param_attr='nce_w',
+                                bias_attr='nce_b',
+                                seed=1,
+                                num_neg_samples=num_neg_samples,
+                                is_sparse=is_sparse)
+        avg_cost = fluid.layers.mean(cost)
+
+        # optimizer
+        optimizer = self.get_optimizer()
+        optimizer.minimize(avg_cost)
+
+        return [avg_cost, [input, label]]
+
+    def test_input_is_selected_rows(self):
+        place = self.get_place()
+        exe = fluid.Executor(place)
+
+        data = self.get_train_data(self.batch_size)
+        nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
+
+        rets = []
+        # for dense
+        dense_scope = fluid.core.Scope()
+        dense_startup_program = fluid.framework.Program()
+        dense_train_program = fluid.framework.Program()
+        with fluid.scope_guard(dense_scope):
+            with fluid.program_guard(dense_train_program,
+                                     dense_startup_program):
+                cost, feeds = self.train_network(20, 5, "custom_dist",
+                                                 nid_freq_arr.tolist(), False)
+                feeder = fluid.DataFeeder(feed_list=feeds, place=place)
+                exe.run(dense_startup_program)
+                loss_val = exe.run(dense_train_program,
+                                   feed=feeder.feed(data),
+                                   fetch_list=[cost.name])
+                rets.append(np.mean(loss_val))
+
+        # for sparse
+        sparse_scope = fluid.core.Scope()
+        sparse_startup_program = fluid.framework.Program()
+        sparse_train_program = fluid.framework.Program()
+        with fluid.scope_guard(sparse_scope):
+            with fluid.program_guard(sparse_train_program,
+                                     sparse_startup_program):
+                cost, feeds = self.train_network(20, 5, "custom_dist",
+                                                 nid_freq_arr.tolist(), True)
+                feeder = fluid.DataFeeder(feed_list=feeds, place=place)
+                exe.run(sparse_startup_program)
+                loss_val = exe.run(sparse_train_program,
+                                   feed=feeder.feed(data),
+                                   fetch_list=[cost.name])
+                rets.append(np.mean(loss_val))
+
+        # the dense and sparse update paths must yield the same loss
+        self.assertEqual(rets[0], rets[1])

 if __name__ == '__main__':
......
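Informally, this test pins down that the sparse path (weight@GRAD carried as a SelectedRows holding only the sampled rows) applies the same update as the dense path. A hedged way to inspect that directly, assuming the variable name 'nce_w@GRAD' produced by the network above:

    # Inside the sparse program built by train_network(..., is_sparse=True):
    grad_var = sparse_train_program.global_block().var('nce_w@GRAD')
    print(grad_var.type)  # expected: VarType.SELECTED_ROWS when is_sparse=True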