Commit 6e361542, authored by J JiabinYang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_reorg_op

@@ -28,3 +28,4 @@ third_party/
 build_*
 # clion workspace.
 cmake-build-*
+model_test
@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \
     pip3 install -U docopt PyYAML sphinx==1.5.6 && \
     pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \
     easy_install -U pip && \
-    pip install -U wheel && \
+    pip install -U pip setuptools wheel && \
     pip install -U docopt PyYAML sphinx==1.5.6 && \
     pip install sphinx-rtd-theme==0.1.9 recommonmark
-RUN pip3 install pre-commit 'ipython==5.3.0' && \
+RUN pip3 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
     pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
     pip3 install opencv-python && \
-    pip install pre-commit 'ipython==5.3.0' && \
+    pip install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
     pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
     pip install opencv-python
......
@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
 paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
-paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None))
+paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'))
 paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
 paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
@@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label',
 paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
-paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None))
+paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None))
 paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None))
......
@@ -252,9 +252,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
   std::vector<ir::Node *> sorted_ret;
   for (size_t i = 0; i < ret.size(); ++i) {
     if (i < last_backward) {
-      if (boost::get<int>(ret[i]->Op()->GetAttr(
-              OpProtoAndCheckerMaker::OpRoleAttrName())) ==
-          static_cast<int>(OpRole::kOptimize)) {
+      if (static_cast<bool>(boost::get<int>(ret[i]->Op()->GetAttr(
+              OpProtoAndCheckerMaker::OpRoleAttrName())) &
+          static_cast<int>(OpRole::kOptimize))) {
         optimize_ops.push_back(ret[i]);
       } else {
         sorted_ret.push_back(ret[i]);
......
@@ -542,6 +542,33 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
     this->reserve(this->size() + size_t(end - begin));
     this->insert(this->end(), begin, end);
   }
+  const T *CUDAData(platform::Place place) const {
+    PADDLE_THROW(
+        "Vector::CUDAData() method is not supported in CPU-only version");
+  }
+  T *CUDAMutableData(platform::Place place) {
+    PADDLE_THROW(
+        "Vector::CUDAMutableData() method is not supported in CPU-only "
+        "version");
+  }
+  const T *Data(platform::Place place) const {
+    PADDLE_ENFORCE(
+        platform::is_cpu_place(place),
+        "Vector::Data() method is not supported when not in CPUPlace");
+    return this->data();
+  }
+  T *MutableData(platform::Place place) {
+    PADDLE_ENFORCE(
+        platform::is_cpu_place(place),
+        "Vector::MutableData() method is not supported when not in CPUPlace");
+    return this->data();
+  }
+  const void *Handle() const { return static_cast<const void *>(this); }
 };
 template <typename T>
......
@@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() {
   ops_.swap(ops);
 }
-void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) {
-#ifdef PADDLE_WITH_MKLDNN
-  VLOG(3) << "use_mkldnn=True";
-  for (size_t block_id = 0; block_id < program.Size(); ++block_id) {
-    auto *block = const_cast<ProgramDesc &>(program).MutableBlock(block_id);
-    for (auto *op : block->AllOps()) {
-      if (op->HasAttr("use_mkldnn")) {
-        op->SetAttr("use_mkldnn", true);
-      }
-    }
-  }
-#else
-  LOG(WARNING)
-      << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
-#endif
-}
 }  // namespace framework
 }  // namespace paddle
@@ -48,8 +48,6 @@ class NaiveExecutor {
   void CleanFeedFetchOps();
-  void EnableMKLDNN(const ProgramDesc& program);
  protected:
   void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id);
......
@@ -71,6 +71,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
            static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
            static_cast<int>(OpRole::kLoss) |
                static_cast<int>(OpRole::kBackward),
+           static_cast<int>(OpRole::kOptimize) |
+               static_cast<int>(OpRole::kLRSched),
            static_cast<int>(OpRole::kNotSpecified)})
       .SetDefault(static_cast<int>(OpRole::kNotSpecified));
   AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
......
@@ -20,6 +20,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
+//////////////////////////
+// Don't add more roles to make this too complicated!
+//////////////////////////
 enum class OpRole {
   kForward = 0x0000,
   kBackward = 0x0001,
......
@@ -156,12 +156,6 @@ ParallelExecutor::ParallelExecutor(
         params, member_->local_scopes_, member_->use_cuda_);
 #endif
-  // If the loss_var_name is given, the number of graph should be only one.
-  if (loss_var_name.size()) {
-    PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
-                      "The number of graph should be only one");
-  }
   if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
     member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, places, std::move(graph)));
......
@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0,
 namespace paddle {
 namespace framework {
 std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
 std::once_flag ThreadPool::init_flag_;
@@ -47,8 +46,7 @@ void ThreadPool::Init() {
   }
 }
-ThreadPool::ThreadPool(int num_threads)
-    : total_threads_(num_threads), idle_threads_(num_threads), running_(true) {
+ThreadPool::ThreadPool(int num_threads) : running_(true) {
   threads_.resize(num_threads);
   for (auto& thread : threads_) {
     // TODO(Yancey1989): binding the thread on the specify CPU number
@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads)
 ThreadPool::~ThreadPool() {
   {
     // notify all threads to stop running
+    std::lock_guard<std::mutex> l(mutex_);
     running_ = false;
     scheduled_.notify_all();
   }
@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() {
   }
 }
-void ThreadPool::Wait() {
-  std::unique_lock<std::mutex> lock(mutex_);
-  completed_.wait(lock, [=] { return Done() == true; });
-}
 void ThreadPool::TaskLoop() {
-  while (running_) {
+  while (true) {
     std::unique_lock<std::mutex> lock(mutex_);
-    scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
-    if (!running_) {
-      break;
+    scheduled_.wait(
+        lock, [this] { return !this->tasks_.empty() || !this->running_; });
+    if (!running_ || tasks_.empty()) {
+      return;
     }
     // pop a task from the task queue
     auto task = std::move(tasks_.front());
     tasks_.pop();
-    --idle_threads_;
     lock.unlock();
     // run the task
     task();
-    {
-      std::unique_lock<std::mutex> lock(mutex_);
-      ++idle_threads_;
-      if (Done()) {
-        completed_.notify_all();
-      }
-    }
   }
 }
......
@@ -57,15 +57,6 @@ class ThreadPool {
   ~ThreadPool();
-  // Returns the number of threads created by the constructor.
-  size_t Threads() const { return total_threads_; }
-  // Returns the number of currently idle threads.
-  size_t IdleThreads() {
-    std::unique_lock<std::mutex> lock(mutex_);
-    return idle_threads_;
-  }
   // Run pushes a function to the task queue and returns a std::future
   // object. To wait for the completion of the task, call
   // std::future::wait().
@@ -94,25 +85,13 @@
     });
     std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
     tasks_.push(std::move(task));
+    lock.unlock();
     scheduled_.notify_one();
     return f;
   }
-  // Wait until all the tasks are completed.
-  void Wait();
  private:
   DISABLE_COPY_AND_ASSIGN(ThreadPool);
-  // If the task queue is empty and avaialbe is equal to the number of
-  // threads, means that all tasks are completed. Note: this function
-  // is not thread-safe. Returns true if all tasks are completed.
-  // Note: don't delete the data member total_threads_ and use
-  // threads_.size() instead; because you'd need to lock the mutex
-  // before accessing threads_.
-  bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
   // The constructor starts threads to run TaskLoop, which retrieves
   // and runs tasks from the queue.
   void TaskLoop();
@@ -125,14 +104,11 @@
   static std::once_flag init_flag_;
   std::vector<std::unique_ptr<std::thread>> threads_;
-  const size_t total_threads_;
-  size_t idle_threads_;
   std::queue<Task> tasks_;
   std::mutex mutex_;
   bool running_;
   std::condition_variable scheduled_;
-  std::condition_variable completed_;
 };
 class ThreadPoolIO : ThreadPool {
......
@@ -19,10 +19,11 @@ limitations under the License. */
 namespace framework = paddle::framework;
-void do_sum(framework::ThreadPool* pool, std::atomic<int>* sum, int cnt) {
-  std::vector<std::future<void>> fs;
+void do_sum(std::vector<std::future<void>>* fs, std::mutex* mu,
+            std::atomic<int>* sum, int cnt) {
   for (int i = 0; i < cnt; ++i) {
-    fs.push_back(framework::Async([sum]() { sum->fetch_add(1); }));
+    std::lock_guard<std::mutex> l(*mu);
+    fs->push_back(framework::Async([sum]() { sum->fetch_add(1); }));
   }
 }
@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) {
 }
 TEST(ThreadPool, ConcurrentRun) {
-  framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
   std::atomic<int> sum(0);
   std::vector<std::thread> threads;
+  std::vector<std::future<void>> fs;
+  std::mutex fs_mu;
   int n = 50;
   // sum = (n * (n + 1)) / 2
   for (int i = 1; i <= n; ++i) {
-    std::thread t(do_sum, pool, &sum, i);
+    std::thread t(do_sum, &fs, &fs_mu, &sum, i);
     threads.push_back(std::move(t));
   }
   for (auto& t : threads) {
     t.join();
   }
-  pool->Wait();
+  for (auto& t : fs) {
+    t.wait();
+  }
   EXPECT_EQ(sum, ((n + 1) * n) / 2);
 }
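Note: with ThreadPool::Wait() and the idle-thread accounting removed, callers that need completion now keep the futures returned by framework::Async and wait on each one, as the updated test does. A minimal Python sketch of the same pattern using the standard-library executor (illustrative only, not Paddle API):

```python
# Wait on the futures you collected instead of calling a pool-wide Wait().
import threading
from concurrent.futures import ThreadPoolExecutor

total = 0
lock = threading.Lock()

def add_one():
    global total
    with lock:
        total += 1

with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(add_one) for _ in range(50)]
    for f in futures:      # analogous to `for (auto& t : fs) t.wait();`
        f.result()         # blocks until this task has finished

assert total == 50
```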
@@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
       selected_indices.push_back(idx);
       ++selected_num;
     }
-    sorted_indices.erase(sorted_indices.end());
+    sorted_indices.erase(sorted_indices.end() - 1);
     if (flag && eta < 1 && adaptive_threshold > 0.5) {
       adaptive_threshold *= eta;
     }
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/dropout_op.h"
+#include <string>
 namespace paddle {
 namespace operators {
@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
                   "will be dropped.")
         .SetDefault(false);
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
+    AddAttr<std::string>(
+        "dropout_implementation",
+        "[\"downgrade_in_infer\"|\"upscale_in_train\"]"
+        "There are two kinds of ways to implement dropout"
+        "(the mask below is a tensor have the same shape with input"
+        "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
+        "1. downgrade_in_infer(default), downgrade the outcome at inference "
+        "time"
+        " train: out = input * mask"
+        " inference: out = input * dropout_prob"
+        "2. upscale_in_train, upscale the outcome at training time, do nothing "
+        "in inference"
+        " train: out = input * mask / ( 1.0 - dropout_prob )"
+        " inference: out = input"
+        " dropout op can be removed from the program. the program will be "
+        "efficient")
+        .SetDefault("downgrade_in_infer")
+        .AddCustomChecker([](const std::string& type) {
+          PADDLE_ENFORCE(
+              type == "downgrade_in_infer" || type == "upscale_in_train",
+              "dropout_implementation can only be downgrade_in_infer or "
+              "upscale_in_train");
+        });
     AddComment(R"DOC(
 Dropout Operator.
@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
+    dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     dropout_grad,
-    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, double>);
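For reference, a small NumPy sketch of the two dropout_implementation conventions wired up above; it follows the kernels (which scale by (1 - dropout_prob) at inference in the default mode) rather than the literal attribute string, and the array values are illustrative:

```python
import numpy as np

np.random.seed(0)
x = np.random.rand(4).astype('float32')
p = 0.5                              # dropout_prob
keep = (np.random.rand(4) >= p)      # 1 where a unit is kept

# downgrade_in_infer (default): plain 0/1 mask in training,
# output scaled by (1 - p) at inference time
train_downgrade = x * keep
infer_downgrade = x * (1.0 - p)

# upscale_in_train: kept units scaled by 1/(1 - p) in training,
# inference is the identity, so the op can be dropped from the graph
train_upscale = x * keep / (1.0 - p)
infer_upscale = x

# both conventions keep the expected value of the output equal to that of x
```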
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <thrust/iterator/counting_iterator.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
+#include <string>
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/platform/float16.h"
@@ -26,7 +27,8 @@ namespace operators {
 template <typename T>
 __global__ void RandomGenerator(const size_t n, const int seed,
                                 const float dropout_prob, const T* src,
-                                T* mask_data, T* dst) {
+                                T* mask_data, T* dst,
+                                bool is_upscale_in_train) {
   thrust::minstd_rand rng;
   rng.seed(seed);
   thrust::uniform_real_distribution<float> dist(0, 1);
@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
     if (dist(rng) < dropout_prob) {
       mask = static_cast<T>(0);
     } else {
-      mask = static_cast<T>(1);
+      if (is_upscale_in_train) {
+        mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
+      } else {
+        mask = static_cast<T>(1);
+      }
     }
     dest = s * mask;
     mask_data[idx] = mask;
@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");
+    auto dropout_implementation =
+        context.Attr<std::string>("dropout_implementation");
     auto& place = *context.template device_context<Place>().eigen_device();
     if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       int grid = (x->numel() + threads - 1) / threads;
       RandomGenerator<
           T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
-          size, seed, dropout_prob, x_data, mask_data, y_data);
+          size, seed, dropout_prob, x_data, mask_data, y_data,
+          (dropout_implementation == "upscale_in_train"));
     } else {
       auto X = EigenMatrix<T>::Reshape(*x, 1);
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
-      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      if (dropout_implementation == "upscale_in_train") {
+        Y.device(place) = X;
+      } else {
+        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      }
     }
   }
 };
@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
-    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(dropout_grad,
-                        ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
+    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #include <random>
+#include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
     auto* y_data = y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");
+    auto dropout_implementation =
+        context.Attr<std::string>("dropout_implementation");
     if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
       auto* mask_data = mask->mutable_data<T>(context.GetPlace());
@@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
       engine.seed(seed);
       std::uniform_real_distribution<float> dist(0, 1);
       size_t size = framework::product(mask->dims());
      for (size_t i = 0; i < size; ++i) {
         if (dist(engine) < dropout_prob) {
           mask_data[i] = 0;
           y_data[i] = 0;
         } else {
-          mask_data[i] = 1;
-          y_data[i] = x_data[i];
+          if (dropout_implementation == "upscale_in_train") {
+            mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
+            y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
+          } else {
+            mask_data[i] = 1;
+            y_data[i] = x_data[i];
+          }
         }
       }
     } else {
@@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
       auto& place =
           *context.template device_context<DeviceContext>().eigen_device();
-      Y.device(place) = X * (1.0f - dropout_prob);
+      if (dropout_implementation == "upscale_in_train") {
+        Y.device(place) = X;
+      } else {
+        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
+      }
     }
   }
 };
......
@@ -136,6 +136,7 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
   return nullptr;
 }
+#ifdef __AVX__
 template <jit::cpu_isa_t isa>
 static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
   if (type == "sigmoid") {
@@ -150,6 +151,7 @@ static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
   PADDLE_THROW("Not support type: %s", type);
   return nullptr;
 }
+#endif
 /* LSTM JitKernel */
 template <typename T, jit::cpu_isa_t isa, jit_block>
......
@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
                    ops::SoftmaxCUDNNKernel<float>,
+                   ops::SoftmaxCUDNNKernel<double>,
                    ops::SoftmaxCUDNNKernel<plat::float16>);
 REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
-                   ops::SoftmaxGradCUDNNKernel<float>);
+                   ops::SoftmaxGradCUDNNKernel<float>,
+                   ops::SoftmaxGradCUDNNKernel<double>);
@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
 REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
+    transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     transpose_grad,
-    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
                   ops::Transpose2GradMaker);
 REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
 REGISTER_OP_CPU_KERNEL(
-    transpose2,
-    ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
+    transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     transpose2_grad,
-    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -16,15 +16,18 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
-    transpose,
-    ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
+    transpose, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     transpose_grad,
-    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     transpose2,
-    ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
 REGISTER_OP_CUDA_KERNEL(
     transpose2_grad,
-    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
         )
         square = grad * grad
-        local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64')
+        local_norm_var = layers.reduce_sum(input=square)
         context[self.group_name].append(local_norm_var)
         self.context = context
@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
         if group_scale_name not in self.context:
             group_norm_var = layers.sums(input=self.context[self.group_name])
             group_norm_var = layers.sqrt(x=group_norm_var)
-            group_norm_var = layers.cast(group_norm_var, 'float32')
             clip_var = self.context[self.group_name + "_clip"]
             group_scale_var = layers.elementwise_div(
                 x=clip_var,
@@ -333,7 +332,8 @@ def append_gradient_clip_ops(param_grads):
     for p, g in param_grads:
         if g is None:
             continue
-        with p.block.program._optimized_guard([p, g]):
+        with p.block.program._optimized_guard(
+                [p, g]), framework.name_scope('append_clip'):
             clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
             if clip_attr is None:
                 clip_attr = NullGradientClipAttr()
@@ -348,7 +348,8 @@ def append_gradient_clip_ops(param_grads):
     for p, g in param_grads:
         if g is None:
             continue
-        with p.block.program._optimized_guard([p, g]):
+        with p.block.program._optimized_guard(
+                [p, g]), framework.name_scope('append_graident_clip'):
             res.append(clip_attr._create_operators(param=p, grad=g))
     return res
......
@@ -1496,6 +1496,9 @@ class Program(object):
         >>> with program._optimized_guard([p,g]):
         >>>     p = p - 0.001 * g
         """
+        tmp_role = self._current_role
+        tmp_var = self._op_role_var
         OpRole = core.op_proto_and_checker_maker.OpRole
         self._current_role = OpRole.Optimize
         self._op_role_var = [
@@ -1503,11 +1506,11 @@ class Program(object):
             for var in param_and_grads
         ]
         yield
-        self._op_role_var = []
-        self._current_role = OpRole.Forward
+        self._op_role_var = tmp_var
+        self._current_role = tmp_role
     @contextlib.contextmanager
-    def _lr_schedule_guard(self):
+    def _lr_schedule_guard(self, is_with_opt=False):
         """
         A with guard to set :code:`LRSched` :code:`OpRole` and
         :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
@@ -1515,6 +1518,10 @@ class Program(object):
         Notes: This is a very low level API. Users should not use it directly.
+        Args:
+            is_with_opt: Only set to true if these ops a in the middle
+                of a bunch of optimize ops so that it can be treated
+                correctly. For example, sgd->lr_op->sgd->lr_op->sgd.
         Examples:
@@ -1528,6 +1535,8 @@ class Program(object):
         OpRole = core.op_proto_and_checker_maker.OpRole
         self._current_role = OpRole.LRSched
+        if is_with_opt:
+            self._current_role = int(OpRole.LRSched) | int(OpRole.Optimize)
         # TODO(typhoonzero): how to set target learning rate var
         self._op_role_var = []
         yield
......
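The guard now remembers whatever role was active and restores it on exit, instead of resetting to Forward and an empty op-role-var list, so guards compose when nested. A simplified Python sketch of that save/restore pattern, with a stand-in class rather than the real Program:

```python
import contextlib

class FakeProgram(object):
    """Stand-in for Program, tracking only the role bookkeeping."""

    def __init__(self):
        self.current_role = 'Forward'
        self.op_role_var = []

    @contextlib.contextmanager
    def optimized_guard(self, param_and_grads):
        tmp_role = self.current_role        # remember the enclosing role
        tmp_var = self.op_role_var
        self.current_role = 'Optimize'
        self.op_role_var = list(param_and_grads)
        yield
        self.op_role_var = tmp_var          # restore instead of resetting
        self.current_role = tmp_role

prog = FakeProgram()
with prog.optimized_guard(['p', 'g']):
    assert prog.current_role == 'Optimize'
assert prog.current_role == 'Forward' and prog.op_role_var == []
```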
@@ -981,7 +981,12 @@ def cos_sim(X, Y):
     return out
-def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
+def dropout(x,
+            dropout_prob,
+            is_test=False,
+            seed=None,
+            name=None,
+            dropout_implementation="downgrade_in_infer"):
     """
     Computes dropout.
@@ -1001,6 +1006,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
                     units will be dropped. DO NOT use a fixed seed in training.
         name (str|None): A name for this layer(optional). If set None, the layer
                          will be named automatically.
+        dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train']
+            1. downgrade_in_infer(default), downgrade the outcome at inference
+               train: out = input * mask
+               inference: out = input * dropout_prob
+               (make is a tensor same shape with input, value is 0 or 1
+               ratio of 0 is dropout_prob)
+            2. upscale_in_train, upscale the outcome at training time
+               train: out = input * mask / ( 1.0 - dropout_prob )
+               inference: out = input
+               (make is a tensor same shape with input, value is 0 or 1
+               ratio of 0 is dropout_prob)
+            dropout op can be removed from the program.
+            the program will be efficient
     Returns:
         Variable: A tensor variable is the shape with `x`.
@@ -1030,7 +1050,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
             'dropout_prob': dropout_prob,
             'is_test': is_test,
             'fix_seed': seed is not None,
-            'seed': seed if seed is not None else 0
+            'seed': seed if seed is not None else 0,
+            'dropout_implementation': dropout_implementation,
         })
     return out
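A short usage sketch of the new argument (layer names and shapes here are illustrative):

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[32], dtype='float32')

# default: 0/1 mask in training, output scaled by (1 - p) at inference
y_default = fluid.layers.dropout(x, dropout_prob=0.5)

# upscale_in_train: kept units scaled by 1/(1 - p) during training,
# so at inference the op is an identity and can be elided
y_upscaled = fluid.layers.dropout(
    x, dropout_prob=0.5, dropout_implementation='upscale_in_train')
```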
@@ -4845,7 +4866,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
     return counter
-def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
+def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
     """
     Gives a new shape to the input Tensor without changing its data.
@@ -4893,15 +4914,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
                                 :attr:`shape` specifying shape. That is to
                                 say :attr:`actual_shape` has a higher priority
                                 than :attr:`shape`.
-        act (str): The non-linear activation to be applied to output variable.
-        inplace(bool): If this flag is set true, the output
-                       shares data with input without copying, otherwise
-                       a new output tensor is created
-                       whose data is copied from input x.
+        act (str): The non-linear activation to be applied to the reshaped tensor
+                   variable.
+        inplace(bool): Must use :attr:`False` if :attr:`x` is used in multiple
+                       operators. If this flag is set :attr:`True`, reuse input
+                       :attr:`x` to reshape, which will change the shape of
+                       tensor variable :attr:`x` and might cause errors when
+                       :attr:`x` is used in multiple operators. If :attr:`False`,
+                       preserve the shape :attr:`x` and create a new output tensor
+                       variable whose data is copied from input x but reshaped.
         name (str): The name of this layer. It is optional.
     Returns:
-        Variable: The output tensor.
+        Variable: The reshaped tensor variable if :attr:`act` is None. It is a \
+                  new tensor variable if :attr:`inplace` is :attr:`False`, \
+                  otherwise it is :attr:`x`. If :attr:`act` is not None, return \
+                  the activated tensor variable.
     Raises:
         TypeError: if actual_shape is neither Variable nor None.
@@ -4912,7 +4940,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
         data = fluid.layers.data(
             name='data', shape=[2, 4, 6], dtype='float32')
         reshaped = fluid.layers.reshape(
-            x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True)
+            x=data, shape=[-1, 0, 3, 2], inplace=True)
     """
     if not (isinstance(shape, list) or isinstance(shape, tuple)):
@@ -4939,7 +4967,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
                          "except one unknown dimension.")
     helper = LayerHelper("reshape2", **locals())
-    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    out = x if inplace else helper.create_variable_for_type_inference(
+        dtype=x.dtype)
     x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
     helper.append_op(
         type="reshape2",
......
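A short sketch of the changed reshape default for reference (tensor names and shapes are illustrative):

```python
import paddle.fluid as fluid

data = fluid.layers.data(name='data', shape=[2, 4, 6], dtype='float32')

# inplace=False (the new default): `data` keeps its shape and the
# reshaped result lands in a freshly created output variable
reshaped = fluid.layers.reshape(x=data, shape=[-1, 0, 3, 2])

# inplace=True reuses `data` itself; only safe when `data` is not
# consumed by any other operator
reshaped_inplace = fluid.layers.reshape(
    x=data, shape=[-1, 0, 3, 2], inplace=True)
```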
@@ -111,7 +111,9 @@ class Optimizer(object):
         if param_lr == 1.0:
             return self._global_learning_rate()
         else:
-            with default_main_program()._lr_schedule_guard():
+            with default_main_program()._lr_schedule_guard(
+                    is_with_opt=True), framework.name_scope(
+                        'scale_with_param_lr'):
                 return self._global_learning_rate() * param_lr
     def _create_accumulators(self, block, parameters):
@@ -602,7 +604,8 @@ class AdamOptimizer(Optimizer):
         for param, grad in param_and_grads:
             if grad is None:
                 continue
-            with param.block.program._optimized_guard([param, grad]):
+            with param.block.program._optimized_guard(
+                    [param, grad]), name_scope("optimizer"):
                 beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                       param)
                 beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
@@ -740,7 +743,8 @@ class AdamaxOptimizer(Optimizer):
         for param, grad in parameters_and_grads:
             if grad is None:
                 continue
-            with param.block.program._optimized_guard([param, grad]):
+            with param.block.program._optimized_guard(
+                    [param, grad]), name_scope('adamx'):
                 beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                       param)
                 main_block.append_op(
@@ -1279,7 +1283,8 @@ class ModelAverage(Optimizer):
         for param, grad in self.params_grads:
             if grad is None:
                 continue
-            with param.block.program._optimized_guard([param, grad]):
+            with param.block.program._optimized_guard(
+                    [param, grad]), name_scope('move_average'):
                 self._append_average_accumulate_op(param)
         self.apply_program = Program()
......
@@ -47,7 +47,8 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
         if grad is None:
             params_and_grads.append((param, grad))
             continue
-        with param.block.program._optimized_guard([param, grad]):
+        with param.block.program._optimized_guard(
+                [param, grad]), framework.name_scope('regularization'):
             regularization_term = None
             if param.regularizer is not None:
                 # Add variable for regularization term in grad block
......
@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE)
         set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
         py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
         set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
-        py_test_modules(test_dist_transformer MODULES test_dist_transformer)
-        set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
+        # FIXME(typhoonzero): add this back
+        #py_test_modules(test_dist_transformer MODULES test_dist_transformer)
+        #set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
     endif(NOT APPLE)
     py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
 endif()
......
@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest):
         self.check_output()
+class TestDropoutOp6(TestDropoutOp):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {
+            'dropout_prob': 1.0,
+            'fix_seed': True,
+            'is_test': False,
+            'dropout_implementation': 'upscale_in_train'
+        }
+        self.outputs = {
+            'Out': np.zeros((32, 64)).astype('float32'),
+            'Mask': np.zeros((32, 64)).astype('float32')
+        }
+class TestDropoutOp7(TestDropoutOp):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
+        self.attrs = {
+            'dropout_prob': 0.0,
+            'fix_seed': True,
+            'is_test': False,
+            'dropout_implementation': 'upscale_in_train'
+        }
+        self.outputs = {
+            'Out': self.inputs['X'],
+            'Mask': np.ones((32, 64, 2)).astype('float32')
+        }
+class TestDropoutOp8(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {
+            'dropout_prob': 0.35,
+            'fix_seed': True,
+            'is_test': True,
+            'dropout_implementation': 'upscale_in_train'
+        }
+        self.outputs = {'Out': self.inputs['X']}
+    def test_check_output(self):
+        self.check_output()
+class TestDropoutOp9(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
+        self.attrs = {
+            'dropout_prob': 0.75,
+            'is_test': True,
+            'dropout_implementation': 'upscale_in_train'
+        }
+        self.outputs = {'Out': self.inputs['X']}
+    def test_check_output(self):
+        self.check_output()
 class TestFP16DropoutOp(OpTest):
     def setUp(self):
         self.op_type = "dropout"
......
@@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
 OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
 )
+OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
 RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
 DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
 LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
@@ -1717,8 +1718,10 @@ to transpile() call.")
         lr_ops = []
         block = self.origin_program.global_block()
         for op in block.ops:
-            if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int(
-                    LR_SCHED_OP_ROLE_ATTR_VALUE):
+            role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
+            if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
+                role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
+                    int(OPT_OP_ROLE_ATTR_VALUE):
                 lr_ops.append(op)
                 log("append lr op: ", op.type)
         return lr_ops
......
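The learning-rate-op filter above accepts either the plain LRSched role or the combined LRSched | Optimize role introduced in this commit. A tiny sketch of that check with placeholder bit values (the real values come from core.op_proto_and_checker_maker.OpRole):

```python
# Placeholder role bits for illustration only.
LR_SCHED = 0x0010
OPTIMIZE = 0x0002

def is_lr_op(role_id):
    # matches the transpiler condition: LRSched alone, or LRSched | Optimize
    return role_id == LR_SCHED or role_id == (LR_SCHED | OPTIMIZE)

assert is_lr_op(LR_SCHED)
assert is_lr_op(LR_SCHED | OPTIMIZE)
assert not is_lr_op(OPTIMIZE)
```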