Commit cbe128bb authored by Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into optimize-sum-seq-pooling-op

...@@ -28,3 +28,4 @@ third_party/ ...@@ -28,3 +28,4 @@ third_party/
build_* build_*
# clion workspace. # clion workspace.
cmake-build-* cmake-build-*
model_test
...@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \ ...@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \
pip3 install -U docopt PyYAML sphinx==1.5.6 && \ pip3 install -U docopt PyYAML sphinx==1.5.6 && \
pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \ pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \
easy_install -U pip && \ easy_install -U pip && \
pip install -U wheel && \ pip install -U pip setuptools wheel && \
pip install -U docopt PyYAML sphinx==1.5.6 && \ pip install -U docopt PyYAML sphinx==1.5.6 && \
pip install sphinx-rtd-theme==0.1.9 recommonmark pip install sphinx-rtd-theme==0.1.9 recommonmark
RUN pip3 install pre-commit 'ipython==5.3.0' && \ RUN pip3 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip3 install opencv-python && \ pip3 install opencv-python && \
pip install pre-commit 'ipython==5.3.0' && \ pip install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install opencv-python pip install opencv-python
......
...@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name' ...@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None)) paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'))
paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)) paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None)) paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
...@@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', ...@@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label',
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)) paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None)) paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None))
paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None))
...@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None ...@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
......
...@@ -252,9 +252,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { ...@@ -252,9 +252,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std::vector<ir::Node *> sorted_ret; std::vector<ir::Node *> sorted_ret;
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
if (i < last_backward) { if (i < last_backward) {
if (boost::get<int>(ret[i]->Op()->GetAttr( if (static_cast<bool>(boost::get<int>(ret[i]->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) == OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kOptimize)) { static_cast<int>(OpRole::kOptimize))) {
optimize_ops.push_back(ret[i]); optimize_ops.push_back(ret[i]);
} else { } else {
sorted_ret.push_back(ret[i]); sorted_ret.push_back(ret[i]);
......
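The switch above from an equality test to a bitwise test is needed because an op's role attribute may now carry several role bits at once (for example kOptimize combined with kLRSched, added to the attribute checker later in this commit). A self-contained sketch of that check, with hypothetical role values standing in for OpRole:

#include <cstdio>

// Hypothetical role bits modeled after paddle::framework::OpRole
// (only kForward/kBackward values are shown in this commit; the rest
// are illustrative assumptions).
enum class OpRole : int {
  kForward = 0x0000,
  kBackward = 0x0001,
  kOptimize = 0x0002,
  kLRSched = 0x0010,
};

// True if the attribute value contains the kOptimize bit, even when it
// is combined with other roles; an == test would miss combined values.
bool IsOptimizeOp(int op_role_attr) {
  return static_cast<bool>(op_role_attr & static_cast<int>(OpRole::kOptimize));
}

int main() {
  int combined = static_cast<int>(OpRole::kOptimize) |
                 static_cast<int>(OpRole::kLRSched);
  std::printf("%d %d\n", IsOptimizeOp(static_cast<int>(OpRole::kOptimize)),
              IsOptimizeOp(combined));  // prints: 1 1
  return 0;
}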
...@@ -542,6 +542,33 @@ class CPUVector : public std::vector<T, std::allocator<T>> { ...@@ -542,6 +542,33 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
this->reserve(this->size() + size_t(end - begin)); this->reserve(this->size() + size_t(end - begin));
this->insert(this->end(), begin, end); this->insert(this->end(), begin, end);
} }
const T *CUDAData(platform::Place place) const {
PADDLE_THROW(
"Vector::CUDAData() method is not supported in CPU-only version");
}
T *CUDAMutableData(platform::Place place) {
PADDLE_THROW(
"Vector::CUDAMutableData() method is not supported in CPU-only "
"version");
}
const T *Data(platform::Place place) const {
PADDLE_ENFORCE(
platform::is_cpu_place(place),
"Vector::Data() method is not supported when not in CPUPlace");
return this->data();
}
T *MutableData(platform::Place place) {
PADDLE_ENFORCE(
platform::is_cpu_place(place),
"Vector::MutableData() method is not supported when not in CPUPlace");
return this->data();
}
const void *Handle() const { return static_cast<const void *>(this); }
}; };
template <typename T> template <typename T>
......
...@@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() { ...@@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() {
ops_.swap(ops); ops_.swap(ops);
} }
void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) {
#ifdef PADDLE_WITH_MKLDNN
VLOG(3) << "use_mkldnn=True";
for (size_t block_id = 0; block_id < program.Size(); ++block_id) {
auto *block = const_cast<ProgramDesc &>(program).MutableBlock(block_id);
for (auto *op : block->AllOps()) {
if (op->HasAttr("use_mkldnn")) {
op->SetAttr("use_mkldnn", true);
}
}
}
#else
LOG(WARNING)
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -48,8 +48,6 @@ class NaiveExecutor { ...@@ -48,8 +48,6 @@ class NaiveExecutor {
void CleanFeedFetchOps(); void CleanFeedFetchOps();
void EnableMKLDNN(const ProgramDesc& program);
protected: protected:
void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id); void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id);
......
...@@ -71,6 +71,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -71,6 +71,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward), static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize) |
static_cast<int>(OpRole::kLRSched),
static_cast<int>(OpRole::kNotSpecified)}) static_cast<int>(OpRole::kNotSpecified)})
.SetDefault(static_cast<int>(OpRole::kNotSpecified)); .SetDefault(static_cast<int>(OpRole::kNotSpecified));
AddAttr<std::vector<std::string>>(OpRoleVarAttrName(), AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
......
...@@ -20,6 +20,9 @@ limitations under the License. */ ...@@ -20,6 +20,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
//////////////////////////
// Don't add more roles; it would make this too complicated!
//////////////////////////
enum class OpRole { enum class OpRole {
kForward = 0x0000, kForward = 0x0000,
kBackward = 0x0001, kBackward = 0x0001,
......
...@@ -156,12 +156,6 @@ ParallelExecutor::ParallelExecutor( ...@@ -156,12 +156,6 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_); params, member_->local_scopes_, member_->use_cuda_);
#endif #endif
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
"The number of graph should be only one");
}
if (exec_strategy.type_ == ExecutionStrategy::kDefault) { if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
......
...@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0, ...@@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0,
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr); std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
std::once_flag ThreadPool::init_flag_; std::once_flag ThreadPool::init_flag_;
...@@ -47,8 +46,7 @@ void ThreadPool::Init() { ...@@ -47,8 +46,7 @@ void ThreadPool::Init() {
} }
} }
ThreadPool::ThreadPool(int num_threads) ThreadPool::ThreadPool(int num_threads) : running_(true) {
: total_threads_(num_threads), idle_threads_(num_threads), running_(true) {
threads_.resize(num_threads); threads_.resize(num_threads);
for (auto& thread : threads_) { for (auto& thread : threads_) {
// TODO(Yancey1989): binding the thread on the specify CPU number // TODO(Yancey1989): binding the thread on the specify CPU number
...@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads) ...@@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads)
ThreadPool::~ThreadPool() { ThreadPool::~ThreadPool() {
{ {
// notify all threads to stop running // notify all threads to stop running
std::lock_guard<std::mutex> l(mutex_);
running_ = false; running_ = false;
scheduled_.notify_all(); scheduled_.notify_all();
} }
...@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() { ...@@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() {
} }
} }
void ThreadPool::Wait() {
std::unique_lock<std::mutex> lock(mutex_);
completed_.wait(lock, [=] { return Done() == true; });
}
void ThreadPool::TaskLoop() { void ThreadPool::TaskLoop() {
while (running_) { while (true) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
if (!running_) { scheduled_.wait(
break; lock, [this] { return !this->tasks_.empty() || !this->running_; });
if (!running_ || tasks_.empty()) {
return;
} }
// pop a task from the task queue // pop a task from the task queue
auto task = std::move(tasks_.front()); auto task = std::move(tasks_.front());
tasks_.pop(); tasks_.pop();
--idle_threads_;
lock.unlock(); lock.unlock();
// run the task // run the task
task(); task();
{
std::unique_lock<std::mutex> lock(mutex_);
++idle_threads_;
if (Done()) {
completed_.notify_all();
}
}
} }
} }
......
...@@ -57,15 +57,6 @@ class ThreadPool { ...@@ -57,15 +57,6 @@ class ThreadPool {
~ThreadPool(); ~ThreadPool();
// Returns the number of threads created by the constructor.
size_t Threads() const { return total_threads_; }
// Returns the number of currently idle threads.
size_t IdleThreads() {
std::unique_lock<std::mutex> lock(mutex_);
return idle_threads_;
}
// Run pushes a function to the task queue and returns a std::future // Run pushes a function to the task queue and returns a std::future
// object. To wait for the completion of the task, call // object. To wait for the completion of the task, call
// std::future::wait(). // std::future::wait().
...@@ -94,25 +85,13 @@ class ThreadPool { ...@@ -94,25 +85,13 @@ class ThreadPool {
}); });
std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future(); std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
tasks_.push(std::move(task)); tasks_.push(std::move(task));
lock.unlock();
scheduled_.notify_one(); scheduled_.notify_one();
return f; return f;
} }
// Wait until all the tasks are completed.
void Wait();
private: private:
DISABLE_COPY_AND_ASSIGN(ThreadPool); DISABLE_COPY_AND_ASSIGN(ThreadPool);
// If the task queue is empty and the number of idle threads equals the
// total number of threads, all tasks are completed. Note: this function
// is not thread-safe. Returns true if all tasks are completed.
// Note: don't delete the data member total_threads_ and use
// threads_.size() instead; because you'd need to lock the mutex
// before accessing threads_.
bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
// The constructor starts threads to run TaskLoop, which retrieves // The constructor starts threads to run TaskLoop, which retrieves
// and runs tasks from the queue. // and runs tasks from the queue.
void TaskLoop(); void TaskLoop();
...@@ -125,14 +104,11 @@ class ThreadPool { ...@@ -125,14 +104,11 @@ class ThreadPool {
static std::once_flag init_flag_; static std::once_flag init_flag_;
std::vector<std::unique_ptr<std::thread>> threads_; std::vector<std::unique_ptr<std::thread>> threads_;
const size_t total_threads_;
size_t idle_threads_;
std::queue<Task> tasks_; std::queue<Task> tasks_;
std::mutex mutex_; std::mutex mutex_;
bool running_; bool running_;
std::condition_variable scheduled_; std::condition_variable scheduled_;
std::condition_variable completed_;
}; };
class ThreadPoolIO : ThreadPool { class ThreadPoolIO : ThreadPool {
......
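With Wait(), Done() and the idle-thread counters removed, callers now block on the std::future returned for each task instead, which is exactly what the updated thread_pool_test below does. A minimal sketch of that pattern, using std::async as a stand-in for ThreadPool::Run / framework::Async:

#include <atomic>
#include <future>
#include <vector>

int main() {
  std::atomic<int> sum(0);
  std::vector<std::future<void>> futures;
  // Submit tasks and keep the future of each one.
  for (int i = 0; i < 100; ++i) {
    futures.push_back(
        std::async(std::launch::async, [&sum] { sum.fetch_add(1); }));
  }
  // Instead of a pool-wide Wait(), wait on every task's future.
  for (auto& f : futures) {
    f.wait();
  }
  return sum.load() == 100 ? 0 : 1;
}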
...@@ -19,10 +19,11 @@ limitations under the License. */ ...@@ -19,10 +19,11 @@ limitations under the License. */
namespace framework = paddle::framework; namespace framework = paddle::framework;
void do_sum(framework::ThreadPool* pool, std::atomic<int>* sum, int cnt) { void do_sum(std::vector<std::future<void>>* fs, std::mutex* mu,
std::vector<std::future<void>> fs; std::atomic<int>* sum, int cnt) {
for (int i = 0; i < cnt; ++i) { for (int i = 0; i < cnt; ++i) {
fs.push_back(framework::Async([sum]() { sum->fetch_add(1); })); std::lock_guard<std::mutex> l(*mu);
fs->push_back(framework::Async([sum]() { sum->fetch_add(1); }));
} }
} }
...@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) { ...@@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) {
} }
TEST(ThreadPool, ConcurrentRun) { TEST(ThreadPool, ConcurrentRun) {
framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
std::atomic<int> sum(0); std::atomic<int> sum(0);
std::vector<std::thread> threads; std::vector<std::thread> threads;
std::vector<std::future<void>> fs;
std::mutex fs_mu;
int n = 50; int n = 50;
// sum = (n * (n + 1)) / 2 // sum = (n * (n + 1)) / 2
for (int i = 1; i <= n; ++i) { for (int i = 1; i <= n; ++i) {
std::thread t(do_sum, pool, &sum, i); std::thread t(do_sum, &fs, &fs_mu, &sum, i);
threads.push_back(std::move(t)); threads.push_back(std::move(t));
} }
for (auto& t : threads) { for (auto& t : threads) {
t.join(); t.join();
} }
pool->Wait(); for (auto& t : fs) {
t.wait();
}
EXPECT_EQ(sum, ((n + 1) * n) / 2); EXPECT_EQ(sum, ((n + 1) * n) / 2);
} }
...@@ -107,6 +107,9 @@ void Analyzer::Run(Argument* argument) { ...@@ -107,6 +107,9 @@ void Analyzer::Run(Argument* argument) {
passes.push_back("mkldnn_placement_pass"); passes.push_back("mkldnn_placement_pass");
} }
#endif #endif
// infer_clean_graph_pass should be the first default pass
// after mkldnn_placement_pass.
passes.push_back("infer_clean_graph_pass");
for (auto& pass : ir_passes_) { for (auto& pass : ir_passes_) {
if (!disabled_ir_passes_.count(pass)) { if (!disabled_ir_passes_.count(pass)) {
passes.push_back(pass); passes.push_back(pass);
......
...@@ -67,7 +67,6 @@ class Analyzer : public OrderedRegistry<PassManager> { ...@@ -67,7 +67,6 @@ class Analyzer : public OrderedRegistry<PassManager> {
// larger fusion. // larger fusion.
const std::vector<std::string> all_ir_passes_{{ const std::vector<std::string> all_ir_passes_{{
// Manual update the passes here. // Manual update the passes here.
"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", // "attention_lstm_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", //
"embedding_fc_lstm_fuse_pass", // "embedding_fc_lstm_fuse_pass", //
......
...@@ -124,7 +124,7 @@ class ZeroCopyTensor { ...@@ -124,7 +124,7 @@ class ZeroCopyTensor {
std::vector<std::vector<size_t>> lod() const; std::vector<std::vector<size_t>> lod() const;
protected: protected:
ZeroCopyTensor(void* scope) : scope_{scope} {} explicit ZeroCopyTensor(void* scope) : scope_{scope} {}
void SetName(const std::string& name) { name_ = name; } void SetName(const std::string& name) { name_ = name; }
void* FindTensor() const; void* FindTensor() const;
...@@ -259,12 +259,6 @@ struct AnalysisConfig : public NativeConfig { ...@@ -259,12 +259,6 @@ struct AnalysisConfig : public NativeConfig {
kExclude // Specify the disabled passes in `ir_passes`. kExclude // Specify the disabled passes in `ir_passes`.
}; };
void SetIncludeMode() {
ir_mode = IrPassMode::kInclude;
// this pass has to be run at the beginning of all fuse passes
ir_passes = {"infer_clean_graph_pass"};
}
// Determine whether to perform graph optimization. // Determine whether to perform graph optimization.
bool enable_ir_optim = true; bool enable_ir_optim = true;
// Manually determine the IR passes to run. // Manually determine the IR passes to run.
......
...@@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, ...@@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
selected_indices.push_back(idx); selected_indices.push_back(idx);
++selected_num; ++selected_num;
} }
sorted_indices.erase(sorted_indices.end()); sorted_indices.erase(sorted_indices.end() - 1);
if (flag && eta < 1 && adaptive_threshold > 0.5) { if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta; adaptive_threshold *= eta;
} }
......
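The one-character fix above matters: erase(sorted_indices.end()) passes a non-dereferenceable iterator, which is undefined behavior, while erase(sorted_indices.end() - 1) removes the last element as intended. A tiny illustration:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> sorted_indices = {7, 3, 5, 1};
  // Remove the last element; calling erase(sorted_indices.end()) here
  // would be undefined behavior rather than a no-op.
  sorted_indices.erase(sorted_indices.end() - 1);
  assert(sorted_indices.size() == 3 && sorted_indices.back() == 5);
  return 0;
}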
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/dropout_op.h"
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
"will be dropped.") "will be dropped.")
.SetDefault(false); .SetDefault(false);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0); AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddAttr<std::string>(
"dropout_implementation",
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
"There are two kinds of ways to implement dropout"
"(the mask below is a tensor have the same shape with input"
"the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
"1. downgrade_in_infer(default), downgrade the outcome at inference "
"time"
" train: out = input * mask"
" inference: out = input * dropout_prob"
"2. upscale_in_train, upscale the outcome at training time, do nothing "
"in inference"
" train: out = input * mask / ( 1.0 - dropout_prob )"
" inference: out = input"
" dropout op can be removed from the program. the program will be "
"efficient")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string& type) {
PADDLE_ENFORCE(
type == "downgrade_in_infer" || type == "upscale_in_train",
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train");
});
AddComment(R"DOC( AddComment(R"DOC(
Dropout Operator. Dropout Operator.
...@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker, ...@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad); REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>); dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
dropout_grad, dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>); ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, double>);
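The new attribute chooses between two formulations of dropout that differ only in where the (1 - dropout_prob) scaling is applied, as the CPU and GPU kernels below show. A small sketch of the per-element arithmetic for an element the mask keeps, using made-up values:

#include <cassert>

int main() {
  const float dropout_prob = 0.25f;
  const float x = 4.0f;

  // "downgrade_in_infer" (default): mask is 0/1 at training time,
  // and inference scales the activation down by (1 - dropout_prob).
  float train_downgrade = x * 1.0f;                   // kept element
  float infer_downgrade = x * (1.0f - dropout_prob);  // 3.0f

  // "upscale_in_train": kept elements are scaled up during training,
  // and inference is a plain copy, so the op can be dropped at inference.
  float train_upscale = x * (1.0f / (1.0f - dropout_prob));  // ~5.33f
  float infer_upscale = x;                                   // 4.0f

  assert(infer_downgrade == 3.0f && infer_upscale == 4.0f);
  (void)train_downgrade;
  (void)train_upscale;
  return 0;
}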
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h> #include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h> #include <thrust/random.h>
#include <thrust/transform.h> #include <thrust/transform.h>
#include <string>
#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
...@@ -26,7 +27,8 @@ namespace operators { ...@@ -26,7 +27,8 @@ namespace operators {
template <typename T> template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed, __global__ void RandomGenerator(const size_t n, const int seed,
const float dropout_prob, const T* src, const float dropout_prob, const T* src,
T* mask_data, T* dst) { T* mask_data, T* dst,
bool is_upscale_in_train) {
thrust::minstd_rand rng; thrust::minstd_rand rng;
rng.seed(seed); rng.seed(seed);
thrust::uniform_real_distribution<float> dist(0, 1); thrust::uniform_real_distribution<float> dist(0, 1);
...@@ -46,9 +48,13 @@ __global__ void RandomGenerator(const size_t n, const int seed, ...@@ -46,9 +48,13 @@ __global__ void RandomGenerator(const size_t n, const int seed,
} }
if (dist(rng) < dropout_prob) { if (dist(rng) < dropout_prob) {
mask = static_cast<T>(0); mask = static_cast<T>(0);
} else {
if (is_upscale_in_train) {
mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
} else { } else {
mask = static_cast<T>(1); mask = static_cast<T>(1);
} }
}
dest = s * mask; dest = s * mask;
mask_data[idx] = mask; mask_data[idx] = mask;
dst[idx] = dest; dst[idx] = dest;
...@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> { ...@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
y->mutable_data<T>(context.GetPlace()); y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob"); float dropout_prob = context.Attr<float>("dropout_prob");
auto dropout_implementation =
context.Attr<std::string>("dropout_implementation");
auto& place = *context.template device_context<Place>().eigen_device(); auto& place = *context.template device_context<Place>().eigen_device();
if (!context.Attr<bool>("is_test")) { if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
...@@ -83,13 +91,18 @@ class GPUDropoutKernel : public framework::OpKernel<T> { ...@@ -83,13 +91,18 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int grid = (x->numel() + threads - 1) / threads; int grid = (x->numel() + threads - 1) / threads;
RandomGenerator< RandomGenerator<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>( T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, seed, dropout_prob, x_data, mask_data, y_data); size, seed, dropout_prob, x_data, mask_data, y_data,
(dropout_implementation == "upscale_in_train"));
} else { } else {
auto X = EigenMatrix<T>::Reshape(*x, 1); auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1); auto Y = EigenMatrix<T>::Reshape(*y, 1);
if (dropout_implementation == "upscale_in_train") {
Y.device(place) = X;
} else {
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob); Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
} }
} }
}
}; };
} // namespace operators } // namespace operators
...@@ -99,6 +112,8 @@ namespace ops = paddle::operators; ...@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>, dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>); ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
REGISTER_OP_CUDA_KERNEL(dropout_grad, ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
ops::DropoutGradKernel<plat::CUDADeviceContext, float>); REGISTER_OP_CUDA_KERNEL(
dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
ops::DropoutGradKernel<plat::CUDADeviceContext, double>);
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <random> #include <random>
#include <string>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> { ...@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto* y_data = y->mutable_data<T>(context.GetPlace()); auto* y_data = y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob"); float dropout_prob = context.Attr<float>("dropout_prob");
auto dropout_implementation =
context.Attr<std::string>("dropout_implementation");
if (!context.Attr<bool>("is_test")) { if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace()); auto* mask_data = mask->mutable_data<T>(context.GetPlace());
...@@ -49,22 +52,32 @@ class CPUDropoutKernel : public framework::OpKernel<T> { ...@@ -49,22 +52,32 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<float> dist(0, 1); std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(mask->dims()); size_t size = framework::product(mask->dims());
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) { if (dist(engine) < dropout_prob) {
mask_data[i] = 0; mask_data[i] = 0;
y_data[i] = 0; y_data[i] = 0;
} else {
if (dropout_implementation == "upscale_in_train") {
mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
} else { } else {
mask_data[i] = 1; mask_data[i] = 1;
y_data[i] = x_data[i]; y_data[i] = x_data[i];
} }
} }
}
} else { } else {
auto X = EigenMatrix<T>::Reshape(*x, 1); auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1); auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place = auto& place =
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * (1.0f - dropout_prob); if (dropout_implementation == "upscale_in_train") {
Y.device(place) = X;
} else {
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
} }
} }
}; };
......
...@@ -16,10 +16,9 @@ limitations under the License. */ ...@@ -16,10 +16,9 @@ limitations under the License. */
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} }
} }
#define INIT_VEC_FUNC \ #define INIT_BASE_DEFINES \
std::function<void(const int, const T *, T *)> act_gate, act_state; \ auto* x = ctx.Input<LoDTensor>("X"); \
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto& act_state_str = ctx.Attr<std::string>("activation"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
math::VecActivations<T, platform::jit::avx> act_functor; \
act_gate = act_functor(act_gate_str); \
act_state = act_functor(act_state_str); \
cross = math::vec_cross<T, platform::jit::avx>; \
} else { \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \
act_state = act_functor(act_state_str); \
cross = math::vec_cross<T, platform::jit::isa_any>; \
}
#define INIT_BASE_INPUT_OUTPUT \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \ auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \ auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ auto x_lod = x->lod(); \
bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \
auto x_dims = x->dims(); /* T x M*/ \ auto x_dims = x->dims(); /* T x M*/ \
auto wh_dims = wh->dims(); /* D x 3D*/ \ auto wh_dims = wh->dims(); /* D x 3D*/ \
const int total_T = x_dims[0]; \ const int total_T = x_dims[0]; \
const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \ const int M = x_dims[1]; \
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D3 = wh_dims[1]; \ const int D2 = D * 2; \
const int D2 = D * 2; const auto& ker = math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::GRUKernel<T>, \
const std::string&, const std::string&>( \
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("activation"), D); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT INIT_OTHER_DEFINES;
INIT_BASE_SIZES
INIT_VEC_FUNC
auto x_lod = x->lod();
const int N = x_lod[0].size() - 1; const int N = x_lod[0].size() - 1;
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr; const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
const T* wh_state_data = wh_data + D * D2; const T* wh_state_data = wh_data + D * D2;
T* xx_data = xx->mutable_data<T>(ctx.GetPlace()); T* hidden_out_data = hidden_out->mutable_data<T>(place);
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data, math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data, xx_data,
...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if (h0_data) { if (h0_data) {
prev_hidden_data = h0_data + bid * D; prev_hidden_data = h0_data + bid * D;
} else { } else {
// W: {W_update, W_reset; W_state} ker->ComputeH1(xx_data, hidden_out_data);
// update gate
act_gate(D, xx_data, xx_data);
// state gate
act_state(D, xx_data + D2, xx_data + D2);
// out = a*b
blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
// save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data, prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3); D3);
act_gate(D2, xx_data, xx_data); ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
// rt = rt*ht_1 inplace result
blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
act_state(D, xx_data + D2, xx_data + D2); ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT if (x_lod[0].size() == 2) {
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) {
xx->Resize({total_T, D3}); xx->Resize({total_T, D3});
SeqCompute(ctx); SeqCompute(ctx);
return; return;
} }
INIT_VEC_FUNC INIT_OTHER_DEFINES;
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0"); auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput"); auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
auto* batched_out = ctx.Output<LoDTensor>("BatchedOut"); auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
T* batched_input_data = batched_input->mutable_data<T>(place);
const T* x_data = x->data<T>(); T* batched_out_data = batched_out->mutable_data<T>(place);
const T* wx_data = wx->data<T>(); hidden_out->mutable_data<T>(place);
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_input_data = batched_input->mutable_data<T>(ctx.GetPlace());
T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* prev_hidden_data = nullptr; T* prev_hidden_data = nullptr;
if (h0) { if (h0) {
// reorder h0 // reorder h0
T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace()); T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
const T* h0_data = h0->data<T>(); const T* h0_data = h0->data<T>();
prev_hidden_data = reordered_h0_data; prev_hidden_data = reordered_h0_data;
size_t sz = sizeof(T) * D; size_t sz = sizeof(T) * D;
...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
// update gate ker->ComputeH1(cur_in_data, cur_out_data);
act_gate(D, cur_in_data, cur_in_data);
// state gate
act_state(D, cur_in_data + D2, cur_in_data + D2);
// out = a*b
blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data; T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
act_gate(D2, cur_batched_data, cur_batched_data); ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
// rt = rt*ht_1 inplace result cur_out_data);
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data = prev_hidden_data; cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
// ht~ = act_state(...) ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
act_state(D, cur_batched_data + D2, cur_batched_data + D2);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
cur_out_data); cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
batched_out->set_lod(batched_lod); batched_out->set_lod(batched_lod);
to_seq(dev_ctx, *batched_out, hidden_out); to_seq(dev_ctx, *batched_out, hidden_out);
} }
#undef INIT_VEC_FUNC #undef INIT_OTHER_DEFINES
#undef INIT_BASE_SIZES #undef INIT_BASE_DEFINES
#undef INIT_BASE_INPUT_OUTPUT
}; };
} // namespace operators } // namespace operators
......
...@@ -68,6 +68,7 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec ...@@ -68,6 +68,7 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
cc_test(im2col_test SRCS im2col_test.cc DEPS im2col) cc_test(im2col_test SRCS im2col_test.cc DEPS im2col)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
cc_test(sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling)
if(WITH_GPU) if(WITH_GPU)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function) nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function)
nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function) nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function)
...@@ -75,6 +76,6 @@ endif() ...@@ -75,6 +76,6 @@ endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc
DEPS cpu_info cblas) DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -39,6 +39,52 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { ...@@ -39,6 +39,52 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
return -1; return -1;
} }
template <typename T>
HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/lower_bound
auto *first = x;
int64_t count = static_cast<int64_t>(num);
while (count > 0) {
int64_t step = (count >> 1);
auto *it = first + step;
if (*it < val) {
first = ++it;
count -= (step + 1);
} else {
count = step;
}
}
return static_cast<size_t>(first - x);
#else
return static_cast<size_t>(std::lower_bound(x, x + num, val) - x);
#endif
}
template <typename T>
HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) {
#ifdef __CUDA_ARCH__
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/upper_bound
auto *first = x;
int64_t count = static_cast<int64_t>(num);
while (count > 0) {
auto step = (count >> 1);
auto *it = first + step;
if (val < *it) {
count = step;
} else {
first = ++it;
count -= (step + 1);
}
}
return static_cast<size_t>(first - x);
#else
return static_cast<size_t>(std::upper_bound(x, x + num, val) - x);
#endif
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
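These HOSTDEVICE helpers mirror std::lower_bound / std::upper_bound so the same binary search can run inside CUDA kernels; a typical use is mapping an element index back to its segment in a sorted offset array (the LoD use-case here is an assumption for illustration). A host-side sketch with a made-up offset array:

#include <algorithm>
#include <cassert>
#include <cstddef>

// Host-side stand-ins with the same contract as the HOSTDEVICE helpers:
// the index of the first element not less than / greater than val.
template <typename T>
size_t LowerBound(const T* x, size_t num, const T& val) {
  return static_cast<size_t>(std::lower_bound(x, x + num, val) - x);
}
template <typename T>
size_t UpperBound(const T* x, size_t num, const T& val) {
  return static_cast<size_t>(std::upper_bound(x, x + num, val) - x);
}

int main() {
  // Hypothetical LoD offsets: three sequences covering [0,2), [2,7), [7,10).
  const size_t lod[] = {0, 2, 7, 10};
  // Element 4 belongs to the second sequence: UpperBound - 1 == 1.
  assert(UpperBound(lod, 4, static_cast<size_t>(4)) - 1 == 1);
  assert(LowerBound(lod, 4, static_cast<size_t>(7)) == 2);
  return 0;
}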
...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel { ...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel {
const T *wp_data = nullptr) const = 0; const T *wp_data = nullptr) const = 0;
}; };
template <typename T>
class GRUKernel : public Kernel {
public:
// compute h1 without h0
virtual void ComputeH1(T *gates, T *ht) const = 0;
virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0;
virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0;
};
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -136,6 +136,23 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel( ...@@ -136,6 +136,23 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
return nullptr; return nullptr;
} }
#ifdef __AVX__
template <jit::cpu_isa_t isa>
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
#endif
/* LSTM JitKernel */ /* LSTM JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
class LSTMKernelImpl : public LSTMKernel<T> { class LSTMKernelImpl : public LSTMKernel<T> {
...@@ -198,21 +215,9 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -198,21 +215,9 @@ class LSTMKernelImpl : public LSTMKernel<T> {
const std::string& act_gate, const std::string& act_cand, \ const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \ const std::string& act_cell, int d) \
: LSTMKernel<float>() { \ : LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \ avx_act_gate_ = GetAVXAct<isa>(act_gate); \
if (type == "sigmoid") { \ avx_act_cand_ = GetAVXAct<isa>(act_cand); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \ avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} else if (type == "relu") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
} else if (type == "tanh") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
} else if (type == "identity" || type == "") { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
} \
PADDLE_THROW("Not support type: %s", type); \
}; \
avx_act_gate_ = GetAVXAct(act_gate); \
avx_act_cand_ = GetAVXAct(act_cand); \
avx_act_cell_ = GetAVXAct(act_cell); \
} \ } \
template <> \ template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
...@@ -354,6 +359,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, ...@@ -354,6 +359,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
#undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM #undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL #undef JITKERNEL_NEW_LSTM_IMPL
/* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class GRUKernelImpl : public GRUKernel<T> {
public:
explicit GRUKernelImpl(const std::string& act_gate,
const std::string& act_state, int d)
: GRUKernel<T>() {
d_ = d;
d2_ = d * 2;
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_state_d_ = GetActKernel<T>(act_state, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
}
void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->Compute(gates, gates);
act_state_d_->Compute(gates + d2_, gates + d2_);
vmul_d_->Compute(gates, gates + d2_, ht);
}
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state}
act_gate_d2_->Compute(gates, gates);
vmul_d_->Compute(ht_1, gates + d_, ht);
}
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_;
act_state_d_->Compute(y, y);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
private:
int d_, d2_;
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \
const std::string& act_gate, const std::string& act_state, int d) \
: GRUKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_state_ = GetAVXAct<isa>(act_state); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
const { \
__m256 u, s; \
/* W: {W_update, W_reset; W_state} */ \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
_mm256_storeu_ps(ht, s); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \
float* gates, const float* ht_1, float* ht) const { \
/* note: not exactly equal to the generic (non-AVX) implementation */ \
__m256 r, ht0; \
r = _mm256_loadu_ps(gates + 8); \
ht0 = _mm256_loadu_ps(ht_1); \
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \
_mm256_storeu_ps(ht, r); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
float* gates, const float* ht_1, float* ht) const { \
/* note: not exactly equal to the generic (non-AVX) implementation */ \
__m256 u, s, ht0; \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
ht0 = _mm256_loadu_ps(ht_1); \
u = avx_act_gate_->Compute(u); \
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
u = _mm256_mul_ps(u, ht0); \
u = _mm256_add_ps(s, u); \
_mm256_storeu_ps(ht, u); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \
const std::string& act_gate, const std::string& act_state, int d)
#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_state
#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
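For orientation, the new GRUKernel entry points carve one GRU step into the pieces the fused op needs between its GEMM calls: ComputeH1 for the first step without h0, then ComputeHtPart1 (activate gates, r * ht_1) and ComputeHtPart2 (candidate state and blend). A plain scalar sketch of the same math, assuming sigmoid gates, tanh state and the {update, reset, state} gate layout of width d each (illustrative only, not the jit kernel itself):

#include <cmath>
#include <cstdio>
#include <vector>

static float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

void GruStep(std::vector<float>* gates, const std::vector<float>& ht_1,
             std::vector<float>* ht, int d) {
  float* g = gates->data();
  // Part 1: activate update/reset gates, then ht = r * ht_1
  // (in the fused op this product feeds the second GEMM with W_state,
  // which is skipped in this scalar sketch).
  for (int i = 0; i < 2 * d; ++i) g[i] = Sigmoid(g[i]);
  for (int i = 0; i < d; ++i) (*ht)[i] = g[d + i] * ht_1[i];
  // Part 2: activate the candidate state and blend with the previous hidden:
  // ht = u * s + (1 - u) * ht_1.
  for (int i = 0; i < d; ++i) {
    float s = std::tanh(g[2 * d + i]);
    (*ht)[i] = g[i] * s + (1.0f - g[i]) * ht_1[i];
  }
}

int main() {
  const int d = 2;
  std::vector<float> gates = {0.f, 0.f, 0.f, 0.f, 1.f, -1.f};  // u, r, s
  std::vector<float> ht_1 = {0.5f, -0.5f}, ht(d);
  GruStep(&gates, ht_1, &ht, d);
  std::printf("%f %f\n", ht[0], ht[1]);
  return 0;
}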
...@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor { ...@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor {
} }
}; };
template <typename T>
class SumSeqPoolGradFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& out_grad,
framework::LoDTensor* in_grad) {
auto lod = in_grad->lod()[0];
int64_t out_w = out_grad.numel() / out_grad.dims()[0];
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
PADDLE_ENFORCE(in_w == out_w);
const T* out_g_data = out_grad.data<T>();
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t in_offset = lod[i] * in_w;
const T* out_pos = out_g_data + i * out_w;
T* in_pos = in_g_data + in_offset;
for (int r = 0; r != h; ++r) {
blas.VCOPY(in_w, out_pos, in_pos + r * in_w);
}
}
}
};
template <typename T> template <typename T>
class SequencePoolFunctor<platform::CPUDeviceContext, T> { class SequencePoolFunctor<platform::CPUDeviceContext, T> {
public: public:
...@@ -231,9 +256,15 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> { ...@@ -231,9 +256,15 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
math::SetConstant<platform::CPUDeviceContext, T> functor; math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(context, in_grad, 0); functor(context, in_grad, 0);
} }
if (pooltype == "SUM") {
math::SumSeqPoolGradFunctor<T> sum_pool_grad;
sum_pool_grad(context, out_grad, in_grad);
return;
}
auto lod = in_grad->lod()[0]; auto lod = in_grad->lod()[0];
auto& place = *context.eigen_device(); auto& place = *context.eigen_device();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]), auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
static_cast<int>(lod[i + 1])); static_cast<int>(lod[i + 1]));
...@@ -247,12 +278,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> { ...@@ -247,12 +278,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
if (pooltype == "AVERAGE") { if (pooltype == "AVERAGE") {
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast); in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
} else if (pooltype == "SUM") {
const T* out_g_data = out_g_t.data<T>();
T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
for (int r = 0; r != h; ++r) {
blas.VCOPY(w, out_g_data, in_g_data + r * w);
}
} else if (pooltype == "SQRT") { } else if (pooltype == "SQRT") {
in_g_e.device(place) = in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast); (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include <gtest/gtest.h>
#include <vector>
template <typename DeviceContext, typename Place, typename T>
void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
paddle::framework::LoDTensor cpu_out_grad;
paddle::framework::LoDTensor cpu_in_grad;
paddle::framework::LoDTensor out_grad;
paddle::framework::LoDTensor in_grad;
const size_t second_dim = 128u;
// construct out_grad's tensor in cpu
const size_t out_first_dim = lod[0].size() - 1;
auto out_dims = paddle::framework::make_ddim(
{static_cast<int64_t>(out_first_dim), static_cast<int64_t>(second_dim)});
cpu_out_grad.mutable_data<T>(out_dims, paddle::platform::CPUPlace());
for (int64_t i = 0; i < cpu_out_grad.numel(); ++i) {
cpu_out_grad.data<T>()[i] = static_cast<T>(i);
}
// copy to dst out_grad
auto* place = new Place();
DeviceContext* context = new DeviceContext(*place);
if (paddle::platform::is_cpu_place(*place)) {
out_grad = cpu_out_grad;
} else {
TensorCopySync(cpu_out_grad, *place, &out_grad);
}
// construct in_grad
in_grad.set_lod(lod);
auto in_dims = paddle::framework::make_ddim(
{static_cast<int64_t>(lod[0].back()), static_cast<int64_t>(second_dim)});
in_grad.mutable_data<T>(in_dims, context->GetPlace());
// check tensor construction result
PADDLE_ENFORCE_EQ(in_grad.dims().size(), out_grad.dims().size());
for (int64_t i = 1; i < out_grad.dims().size(); ++i) {
PADDLE_ENFORCE_EQ(in_grad.dims()[i], out_grad.dims()[i]);
}
// call functor
paddle::operators::math::SequencePoolGradFunctor<DeviceContext, T>()(
*context, "SUM", out_grad, &in_grad);
if (paddle::platform::is_cpu_place(*place)) {
cpu_in_grad = in_grad;
} else {
TensorCopySync(in_grad, paddle::platform::CPUPlace(), &cpu_in_grad);
cpu_in_grad.set_lod(in_grad.lod());
}
EXPECT_EQ(in_grad.numel(), lod[0].back() * second_dim);
EXPECT_EQ(in_grad.lod(), lod);
if (paddle::platform::is_cpu_place(*place)) {
for (int64_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) {
int64_t begin = in_grad.lod()[0][i];
int64_t end = in_grad.lod()[0][i + 1];
paddle::framework::Tensor tmp = in_grad.Slice(begin, end);
for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (int64_t m = 0; m != second_dim; ++m) {
EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
out_grad.data<T>()[m + i * second_dim]);
}
}
}
} else {
for (int64_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) {
int64_t begin = cpu_in_grad.lod()[0][i];
int64_t end = cpu_in_grad.lod()[0][i + 1];
paddle::framework::Tensor tmp = cpu_in_grad.Slice(begin, end);
for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (int64_t m = 0; m != second_dim; ++m) {
EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
cpu_out_grad.data<T>()[m + i * second_dim]);
}
}
}
}
delete place;
delete context;
}
TEST(SequencePoolingGrad, CPU_SUM) {
paddle::framework::LoD lod1;
lod1.push_back(std::vector<size_t>{0, 10});
TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace, float>(lod1);
paddle::framework::LoD lod2;
lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace, float>(lod2);
}
#ifdef PADDLE_WITH_CUDA
TEST(SequencePoolingGrad, CUDA_SUM) {
paddle::framework::LoD lod1;
lod1.push_back(std::vector<size_t>{0, 10});
TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace, float>(lod1);
paddle::framework::LoD lod2;
lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace, float>(lod2);
}
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_reverse_op.h"
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_reverse, ops::SequenceReverseOp,
ops::SequenceReverseOpMaker,
ops::SequenceReverseGradOpDescMaker);
REGISTER_OP_CPU_KERNEL(
sequence_reverse,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceReverseOpKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_reverse_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_reverse,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceReverseOpKernel<paddle::platform::CUDADeviceContext, double>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
class SequenceReverseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dim.size(), 2,
"Rank of Input(X) must be not less than 2.");
ctx->SetOutputDim("Y", x_dim);
ctx->ShareLoD("X", "Y");
}
};
class SequenceReverseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input LoDTensor of sequence_reverse op.");
AddOutput("Y", "The output LoDTensor of sequence_reverse op.");
AddComment(R"DOC(
SequenceReverse Operator.
Reverse each sequence in input X along dim 0.
Assuming X is a LoDTensor with dims [5, 4] and lod [[0, 2, 5]], where:
X.data() = [
[1, 2, 3, 4],
[5, 6, 7, 8], # the 0-th sequence with length 2
[9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20] # the 1-st sequence with length 3
]
The output Y would be a LoDTensor sharing the same dims and lod with input X,
and:
Y.data() = [
[5, 6, 7, 8],
[1, 2, 3, 4], # the reversed 0-th sequence with length 2
[17, 18, 19, 20],
[13, 14, 15, 16],
[9, 10, 11, 12] # the reversed 1-st sequence with length 3
]
This Operator is useful to build a reverse dynamic RNN network.
This Operator only supports one-level lod currently.
)DOC");
}
};
template <typename T>
struct SequenceReverseFunctor {
SequenceReverseFunctor(const T *x, T *y, const size_t *lod, size_t lod_count,
size_t row_numel)
: x_(x), y_(y), lod_(lod), lod_count_(lod_count), row_numel_(row_numel) {}
HOSTDEVICE void operator()(size_t idx_x) const {
auto row_idx_x = idx_x / row_numel_;
auto lod_idx = math::UpperBound(lod_, lod_count_, row_idx_x);
auto row_idx_y = lod_[lod_idx - 1] + (lod_[lod_idx] - 1 - row_idx_x);
auto idx_y = row_idx_y * row_numel_ + idx_x % row_numel_;
y_[idx_y] = x_[idx_x];
}
const T *x_;
T *y_;
const size_t *lod_;
size_t lod_count_;
size_t row_numel_;
};
template <typename DeviceContext, typename T>
class SequenceReverseOpKernel : public framework::OpKernel<T> {
using LoDTensor = framework::LoDTensor;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto &x = *ctx.Input<LoDTensor>("X");
auto *y = ctx.Output<LoDTensor>("Y");
PADDLE_ENFORCE_EQ(x.lod().size(), 1,
"SequenceReverse Op only support one level lod.");
auto &dev_ctx = ctx.template device_context<DeviceContext>();
const size_t *lod;
size_t lod_count = x.lod()[0].size();
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
lod = x.lod()[0].CUDAData(ctx.GetPlace());
} else {
#endif
lod = x.lod()[0].data();
#ifdef PADDLE_WITH_CUDA
}
#endif
size_t limit = static_cast<size_t>(x.numel());
size_t row_numel = static_cast<size_t>(limit / x.dims()[0]);
auto *x_data = x.data<T>();
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NE(x_data, y_data,
"SequenceReverse Op does not support in-place operation");
SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count,
row_numel);
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
for_range(functor);
}
};
class SequenceReverseGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_reverse");
op->SetInput("X", OutputGrad("Y"));
op->SetOutput("Y", InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
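For reference when reviewing SequenceReverseFunctor, a NumPy sketch of its index mapping: each source row is sent to the mirrored position inside its own sequence, found by an upper-bound search on the lod array (math::UpperBound above). The lod below is illustrative and matches the example in the op comment.

```python
import numpy as np
from bisect import bisect_right

def sequence_reverse(x, lod):
    """Reverse the rows of x within each [lod[i], lod[i+1]) segment."""
    y = np.empty_like(x)
    for row in range(x.shape[0]):
        seg = bisect_right(lod, row)                    # math::UpperBound in the kernel
        mirrored = lod[seg - 1] + (lod[seg] - 1 - row)  # row_idx_y in SequenceReverseFunctor
        y[mirrored] = x[row]
    return y

x = np.arange(1, 21).reshape(5, 4)
print(sequence_reverse(x, lod=[0, 2, 5]))  # reproduces Y.data() from the op comment above
```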
...@@ -76,6 +76,8 @@ namespace ops = paddle::operators; ...@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
ops::SoftmaxCUDNNKernel<float>, ops::SoftmaxCUDNNKernel<float>,
ops::SoftmaxCUDNNKernel<double>,
ops::SoftmaxCUDNNKernel<plat::float16>); ops::SoftmaxCUDNNKernel<plat::float16>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
ops::SoftmaxGradCUDNNKernel<float>); ops::SoftmaxGradCUDNNKernel<float>,
ops::SoftmaxGradCUDNNKernel<double>);
...@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker, ...@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad); REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>); transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose_grad, transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>); ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker, REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
ops::Transpose2GradMaker); ops::Transpose2GradMaker);
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad); REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose2, transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>); ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
transpose2_grad, transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>); ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -16,15 +16,18 @@ limitations under the License. */ ...@@ -16,15 +16,18 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose, transpose, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>); ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose_grad, transpose_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>); ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose2, transpose2,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>); ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
transpose2_grad, transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>); ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
...@@ -78,7 +78,7 @@ def __build_dict(tar_file, dict_size, save_path, lang): ...@@ -78,7 +78,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
six.iteritems(word_dict), key=lambda x: x[1], six.iteritems(word_dict), key=lambda x: x[1],
reverse=True)): reverse=True)):
if idx + 3 == dict_size: break if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0])) fout.write("%s\n" % (cpt.to_bytes(word[0])))
def __load_dict(tar_file, dict_size, lang, reverse=False): def __load_dict(tar_file, dict_size, lang, reverse=False):
......
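The fout.write change above exists for Python 3, where a file opened in binary mode only accepts bytes; paddle.compat.to_bytes encodes str and passes bytes through. A minimal sketch of the equivalent behaviour, assuming the dictionary file is opened in binary mode; the word list and the local to_bytes helper are stand-ins, not the real code:

```python
def to_bytes(s, encoding='utf-8'):
    # stand-in for paddle.compat.to_bytes: encode str, pass bytes through unchanged
    return s if isinstance(s, bytes) else s.encode(encoding)

with open('dict.txt', 'wb') as fout:         # binary mode
    for word in [u'<s>', u'<e>', u'<unk>']:  # hypothetical vocabulary entries
        fout.write(to_bytes(word) + b"\n")
```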
...@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
) )
square = grad * grad square = grad * grad
local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64') local_norm_var = layers.reduce_sum(input=square)
context[self.group_name].append(local_norm_var) context[self.group_name].append(local_norm_var)
self.context = context self.context = context
...@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
if group_scale_name not in self.context: if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name]) group_norm_var = layers.sums(input=self.context[self.group_name])
group_norm_var = layers.sqrt(x=group_norm_var) group_norm_var = layers.sqrt(x=group_norm_var)
group_norm_var = layers.cast(group_norm_var, 'float32')
clip_var = self.context[self.group_name + "_clip"] clip_var = self.context[self.group_name + "_clip"]
group_scale_var = layers.elementwise_div( group_scale_var = layers.elementwise_div(
x=clip_var, x=clip_var,
...@@ -333,7 +332,8 @@ def append_gradient_clip_ops(param_grads): ...@@ -333,7 +332,8 @@ def append_gradient_clip_ops(param_grads):
for p, g in param_grads: for p, g in param_grads:
if g is None: if g is None:
continue continue
with p.block.program._optimized_guard([p, g]): with p.block.program._optimized_guard(
[p, g]), framework.name_scope('append_clip'):
clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr()) clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
if clip_attr is None: if clip_attr is None:
clip_attr = NullGradientClipAttr() clip_attr = NullGradientClipAttr()
...@@ -348,7 +348,8 @@ def append_gradient_clip_ops(param_grads): ...@@ -348,7 +348,8 @@ def append_gradient_clip_ops(param_grads):
for p, g in param_grads: for p, g in param_grads:
if g is None: if g is None:
continue continue
with p.block.program._optimized_guard([p, g]): with p.block.program._optimized_guard(
[p, g]), framework.name_scope('append_gradient_clip'):
res.append(clip_attr._create_operators(param=p, grad=g)) res.append(clip_attr._create_operators(param=p, grad=g))
return res return res
......
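With the float64 cast removed above, the clip-by-global-norm computation stays in the gradients' native dtype end to end. As a review aid, a NumPy sketch of the math this class assembles into the program (per-gradient squared sums, one global sqrt, then one shared scale); the max() in the scale follows the usual global-norm clipping rule, since the elided part of the hunk does not show it:

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    global_norm = np.sqrt(sum(np.sum(g * g) for g in grads))  # sums + sqrt above
    scale = clip_norm / max(global_norm, clip_norm)           # elementwise_div(clip_var, ...)
    return [g * scale for g in grads]

grads = [np.full((2, 3), 3.0, np.float32), np.ones((4,), np.float32)]
print([float(g.max()) for g in clip_by_global_norm(grads, clip_norm=1.0)])
```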
...@@ -1496,6 +1496,9 @@ class Program(object): ...@@ -1496,6 +1496,9 @@ class Program(object):
>>> with program._optimized_guard([p,g]): >>> with program._optimized_guard([p,g]):
>>> p = p - 0.001 * g >>> p = p - 0.001 * g
""" """
tmp_role = self._current_role
tmp_var = self._op_role_var
OpRole = core.op_proto_and_checker_maker.OpRole OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.Optimize self._current_role = OpRole.Optimize
self._op_role_var = [ self._op_role_var = [
...@@ -1503,11 +1506,11 @@ class Program(object): ...@@ -1503,11 +1506,11 @@ class Program(object):
for var in param_and_grads for var in param_and_grads
] ]
yield yield
self._op_role_var = [] self._op_role_var = tmp_var
self._current_role = OpRole.Forward self._current_role = tmp_role
@contextlib.contextmanager @contextlib.contextmanager
def _lr_schedule_guard(self): def _lr_schedule_guard(self, is_with_opt=False):
""" """
A with guard to set :code:`LRSched` :code:`OpRole` and A with guard to set :code:`LRSched` :code:`OpRole` and
:code:`OpRoleVar` automatically. The :code:`OpRoleVar` is :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is
...@@ -1515,6 +1518,10 @@ class Program(object): ...@@ -1515,6 +1518,10 @@ class Program(object):
Notes: This is a very low level API. Users should not use it directly. Notes: This is a very low level API. Users should not use it directly.
Args:
is_with_opt: Only set to true if these ops are in the middle
of a bunch of optimize ops so that they can be treated
correctly. For example, sgd->lr_op->sgd->lr_op->sgd.
Examples: Examples:
...@@ -1528,6 +1535,8 @@ class Program(object): ...@@ -1528,6 +1535,8 @@ class Program(object):
OpRole = core.op_proto_and_checker_maker.OpRole OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.LRSched self._current_role = OpRole.LRSched
if is_with_opt:
self._current_role = int(OpRole.LRSched) | int(OpRole.Optimize)
# TODO(typhoonzero): how to set target learning rate var # TODO(typhoonzero): how to set target learning rate var
self._op_role_var = [] self._op_role_var = []
yield yield
......
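The _optimized_guard change above makes the guard restore whatever role and role-var were active on entry instead of unconditionally resetting them to Forward, so guards can nest. A minimal standalone sketch of that save-and-restore pattern (the class and names are illustrative, not Fluid API):

```python
from contextlib import contextmanager

class ProgramLike(object):
    def __init__(self):
        self.role = 'Forward'

    @contextmanager
    def optimized_guard(self):
        saved = self.role        # remember the caller's role ...
        self.role = 'Optimize'
        try:
            yield
        finally:
            self.role = saved    # ... and put it back, even when guards are nested

prog = ProgramLike()
with prog.optimized_guard():
    with prog.optimized_guard():
        pass
    assert prog.role == 'Optimize'  # the old code would have reset this to 'Forward'
assert prog.role == 'Forward'
```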
...@@ -154,6 +154,7 @@ __all__ = [ ...@@ -154,6 +154,7 @@ __all__ = [
'mul', 'mul',
'sigmoid_cross_entropy_with_logits', 'sigmoid_cross_entropy_with_logits',
'maxout', 'maxout',
'sequence_reverse',
'affine_channel', 'affine_channel',
] ]
...@@ -980,7 +981,12 @@ def cos_sim(X, Y): ...@@ -980,7 +981,12 @@ def cos_sim(X, Y):
return out return out
def dropout(x, dropout_prob, is_test=False, seed=None, name=None): def dropout(x,
dropout_prob,
is_test=False,
seed=None,
name=None,
dropout_implementation="downgrade_in_infer"):
""" """
Computes dropout. Computes dropout.
...@@ -1000,6 +1006,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): ...@@ -1000,6 +1006,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
units will be dropped. DO NOT use a fixed seed in training. units will be dropped. DO NOT use a fixed seed in training.
name (str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
dropout_implementation(string): ['downgrade_in_infer'(default)|'upscale_in_train']
1. downgrade_in_infer(default), downgrade the outcome at inference
train: out = input * mask
inference: out = input * (1.0 - dropout_prob)
(mask is a tensor with the same shape as the input; its values are 0 or 1,
and the ratio of 0s is dropout_prob)
2. upscale_in_train, upscale the outcome at training time
train: out = input * mask / (1.0 - dropout_prob)
inference: out = input
(mask is a tensor with the same shape as the input; its values are 0 or 1,
and the ratio of 0s is dropout_prob)
With this implementation the dropout op is an identity at inference time,
so it can be removed from the inference program, making the program more efficient.
Returns: Returns:
Variable: A tensor variable is the shape with `x`. Variable: A tensor variable is the shape with `x`.
...@@ -1029,7 +1050,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): ...@@ -1029,7 +1050,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
'dropout_prob': dropout_prob, 'dropout_prob': dropout_prob,
'is_test': is_test, 'is_test': is_test,
'fix_seed': seed is not None, 'fix_seed': seed is not None,
'seed': seed if seed is not None else 0 'seed': seed if seed is not None else 0,
'dropout_implementation': dropout_implementation,
}) })
return out return out
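A short usage sketch for the new dropout_implementation switch (shapes and names below are illustrative). With 'upscale_in_train' the activations are scaled by 1/(1 - dropout_prob) during training and the op becomes an identity at inference, so it can be dropped from the inference program:

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[32], dtype='float32')

# New option: upscale while training, pass through at inference time.
y = fluid.layers.dropout(
    x, dropout_prob=0.5, dropout_implementation='upscale_in_train')

# Default is unchanged: mask only while training,
# downscale by (1.0 - dropout_prob) at inference.
z = fluid.layers.dropout(x, dropout_prob=0.5)  # 'downgrade_in_infer'
```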
...@@ -4844,7 +4866,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): ...@@ -4844,7 +4866,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
return counter return counter
def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
""" """
Gives a new shape to the input Tensor without changing its data. Gives a new shape to the input Tensor without changing its data.
...@@ -4892,15 +4914,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4892,15 +4914,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
:attr:`shape` specifying shape. That is to :attr:`shape` specifying shape. That is to
say :attr:`actual_shape` has a higher priority say :attr:`actual_shape` has a higher priority
than :attr:`shape`. than :attr:`shape`.
act (str): The non-linear activation to be applied to output variable. act (str): The non-linear activation to be applied to the reshaped tensor
inplace(bool): If this flag is set true, the output variable.
shares data with input without copying, otherwise inplace(bool): Must use :attr:`False` if :attr:`x` is used in multiple
a new output tensor is created operators. If this flag is set :attr:`True`, reuse input
whose data is copied from input x. :attr:`x` to reshape, which will change the shape of
tensor variable :attr:`x` and might cause errors when
:attr:`x` is used in multiple operators. If :attr:`False`,
preserve the shape :attr:`x` and create a new output tensor
variable whose data is copied from input x but reshaped.
name (str): The name of this layer. It is optional. name (str): The name of this layer. It is optional.
Returns: Returns:
Variable: The output tensor. Variable: The reshaped tensor variable if :attr:`act` is None. It is a \
new tensor variable if :attr:`inplace` is :attr:`False`, \
otherwise it is :attr:`x`. If :attr:`act` is not None, return \
the activated tensor variable.
Raises: Raises:
TypeError: if actual_shape is neither Variable nor None. TypeError: if actual_shape is neither Variable nor None.
...@@ -4911,7 +4940,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4911,7 +4940,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
data = fluid.layers.data( data = fluid.layers.data(
name='data', shape=[2, 4, 6], dtype='float32') name='data', shape=[2, 4, 6], dtype='float32')
reshaped = fluid.layers.reshape( reshaped = fluid.layers.reshape(
x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True) x=data, shape=[-1, 0, 3, 2], inplace=True)
""" """
if not (isinstance(shape, list) or isinstance(shape, tuple)): if not (isinstance(shape, list) or isinstance(shape, tuple)):
...@@ -4938,7 +4967,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4938,7 +4967,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
"except one unknown dimension.") "except one unknown dimension.")
helper = LayerHelper("reshape2", **locals()) helper = LayerHelper("reshape2", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = x if inplace else helper.create_variable_for_type_inference(
dtype=x.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=x.dtype) x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type="reshape2", type="reshape2",
...@@ -7455,6 +7485,33 @@ def maxout(x, groups, name=None): ...@@ -7455,6 +7485,33 @@ def maxout(x, groups, name=None):
return out return out
@templatedoc()
def sequence_reverse(x, name=None):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
name(basestring|None): Name of the output.
Returns:
out(${y_type}): ${y_comment}
"""
helper = LayerHelper("sequence_reverse", **locals())
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="sequence_reverse",
inputs={"X": x},
outputs={"Y": out},
attrs=dict())
return out
def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
""" """
Applies a separate affine transformation to each channel of the input. Applies a separate affine transformation to each channel of the input.
......
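A small usage sketch for the new sequence_reverse layer (LoD level and shape are illustrative). Each sequence in the input LoDTensor is reversed along dim 0, which is the usual building block for the backward direction of a bi-directional dynamic RNN:

```python
import paddle.fluid as fluid

# lod_level=1: a batch of variable-length sequences of 4-dim features
x = fluid.layers.data(name='seq', shape=[4], dtype='float32', lod_level=1)
x_rev = fluid.layers.sequence_reverse(x)
# e.g. feed x_rev into a second dynamic GRU to build the backward encoder direction
```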
...@@ -111,7 +111,9 @@ class Optimizer(object): ...@@ -111,7 +111,9 @@ class Optimizer(object):
if param_lr == 1.0: if param_lr == 1.0:
return self._global_learning_rate() return self._global_learning_rate()
else: else:
with default_main_program()._lr_schedule_guard(): with default_main_program()._lr_schedule_guard(
is_with_opt=True), framework.name_scope(
'scale_with_param_lr'):
return self._global_learning_rate() * param_lr return self._global_learning_rate() * param_lr
def _create_accumulators(self, block, parameters): def _create_accumulators(self, block, parameters):
...@@ -602,7 +604,8 @@ class AdamOptimizer(Optimizer): ...@@ -602,7 +604,8 @@ class AdamOptimizer(Optimizer):
for param, grad in param_and_grads: for param, grad in param_and_grads:
if grad is None: if grad is None:
continue continue
with param.block.program._optimized_guard([param, grad]): with param.block.program._optimized_guard(
[param, grad]), name_scope("optimizer"):
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param) param)
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
...@@ -740,7 +743,8 @@ class AdamaxOptimizer(Optimizer): ...@@ -740,7 +743,8 @@ class AdamaxOptimizer(Optimizer):
for param, grad in parameters_and_grads: for param, grad in parameters_and_grads:
if grad is None: if grad is None:
continue continue
with param.block.program._optimized_guard([param, grad]): with param.block.program._optimized_guard(
[param, grad]), name_scope('adamax'):
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param) param)
main_block.append_op( main_block.append_op(
...@@ -1279,7 +1283,8 @@ class ModelAverage(Optimizer): ...@@ -1279,7 +1283,8 @@ class ModelAverage(Optimizer):
for param, grad in self.params_grads: for param, grad in self.params_grads:
if grad is None: if grad is None:
continue continue
with param.block.program._optimized_guard([param, grad]): with param.block.program._optimized_guard(
[param, grad]), name_scope('move_average'):
self._append_average_accumulate_op(param) self._append_average_accumulate_op(param)
self.apply_program = Program() self.apply_program = Program()
......
...@@ -47,7 +47,8 @@ def append_regularization_ops(parameters_and_grads, regularization=None): ...@@ -47,7 +47,8 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
if grad is None: if grad is None:
params_and_grads.append((param, grad)) params_and_grads.append((param, grad))
continue continue
with param.block.program._optimized_guard([param, grad]): with param.block.program._optimized_guard(
[param, grad]), framework.name_scope('regularization'):
regularization_term = None regularization_term = None
if param.regularizer is not None: if param.regularizer is not None:
# Add variable for regularization term in grad block # Add variable for regularization term in grad block
......
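The optimizer and regularizer changes above only wrap the generated ops in name_scope guards; the scope string is recorded on each op for grouping in visualization and debugging tools and does not change execution. A minimal sketch of the same pattern at the user level (the layer and scope names are illustrative):

```python
import paddle.fluid as fluid
from paddle.fluid.framework import name_scope

img = fluid.layers.data(name='img', shape=[784], dtype='float32')
with name_scope('encoder'):
    hidden = fluid.layers.fc(input=img, size=128, act='relu')
# Ops created inside the guard are tagged with the 'encoder' scope in the program desc,
# just as the optimizer now tags its ops with 'optimizer', 'regularization', etc.
```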
...@@ -78,7 +78,7 @@ if(WITH_DISTRIBUTE) ...@@ -78,7 +78,7 @@ if(WITH_DISTRIBUTE)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000) set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
# TODO: fix this test # FIXME(typhoonzero): add this back
#py_test_modules(test_dist_transformer MODULES test_dist_transformer) #py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) #set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
endif(NOT APPLE) endif(NOT APPLE)
......
...@@ -35,7 +35,7 @@ import paddle ...@@ -35,7 +35,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid import core from paddle.fluid import core
from test_dist_base import TestDistRunnerBase, runtime_main from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP
import paddle.compat as cpt import paddle.compat as cpt
from paddle.compat import long_type from paddle.compat import long_type
...@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num): for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time() pass_start_time = time.time()
for batch_id, data in enumerate(train_data()): for batch_id, data in enumerate(train_data()):
if batch_id >= 5: if batch_id >= RUN_STEP:
break break
feed_list = [] feed_list = []
total_num_token = 0 total_num_token = 0
#if TrainTaskConfig.local:
# lr_rate = lr_scheduler.update_learning_rate()
#for place_id, data_buffer in enumerate(
# split_data(
# data, num_part=dev_count)):
if TrainTaskConfig.local: if TrainTaskConfig.local:
lr_rate = lr_scheduler.update_learning_rate() lr_rate = lr_scheduler.update_learning_rate()
...@@ -619,7 +613,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -619,7 +613,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
init = True init = True
# Validate and save the model for inference. # Validate and save the model for inference.
if batch_id == 0 or batch_id == 4:
if TrainTaskConfig.val_file_pattern is not None: if TrainTaskConfig.val_file_pattern is not None:
val_avg_cost, val_ppl = test() val_avg_cost, val_ppl = test()
print("[%f]" % val_avg_cost) print("[%f]" % val_avg_cost)
...@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase): ...@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase):
def run_trainer(self, args): def run_trainer(self, args):
TrainTaskConfig.use_gpu = args.use_cuda TrainTaskConfig.use_gpu = args.use_cuda
sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model( sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
args.is_dist, not args.sync_mode) args.is_dist, not args.sync_mode)
if args.is_dist: if args.is_dist:
......
...@@ -40,7 +40,8 @@ class TestDistMnistAsync(TestDistBase): ...@@ -40,7 +40,8 @@ class TestDistMnistAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reduce = False self._use_reduce = False
def test_dist_train(self): # FIXME(typhoonzero): fix async mode test later
def no_test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=200) self.check_with_place("dist_mnist.py", delta=200)
......
...@@ -40,7 +40,8 @@ class TestDistSeResneXt2x2Async(TestDistBase): ...@@ -40,7 +40,8 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reader_alloc = False self._use_reader_alloc = False
def test_dist_train(self): #FIXME(typhoonzero): fix async mode later
def no_test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100) self.check_with_place("dist_se_resnext.py", delta=100)
......
...@@ -42,7 +42,8 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase): ...@@ -42,7 +42,8 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
def test_simnet_bow(self): #FIXME(typhoonzero): fix async tests later
def no_test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '0', "IS_SPARSE": '0',
...@@ -78,7 +79,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): ...@@ -78,7 +79,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
def test_simnet_bow(self): #FIXME(typhoonzero): fix async tests later
def no_test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '1', "IS_SPARSE": '1',
......
...@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase): ...@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase):
def test_dist_train(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1e-5) self.check_with_place(
"dist_transformer.py", delta=1e-5, check_error_log=False)
class TestDistTransformer2x2Async(TestDistBase): class TestDistTransformer2x2Async(TestDistBase):
...@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase): ...@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase):
def test_dist_train(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1.0) self.check_with_place(
"dist_transformer.py", delta=1.0, check_error_log=False)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest): ...@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest):
self.check_output() self.check_output()
class TestDropoutOp6(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {
'dropout_prob': 1.0,
'fix_seed': True,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {
'Out': np.zeros((32, 64)).astype('float32'),
'Mask': np.zeros((32, 64)).astype('float32')
}
class TestDropoutOp7(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {
'dropout_prob': 0.0,
'fix_seed': True,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64, 2)).astype('float32')
}
class TestDropoutOp8(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {
'dropout_prob': 0.35,
'fix_seed': True,
'is_test': True,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {'Out': self.inputs['X']}
def test_check_output(self):
self.check_output()
class TestDropoutOp9(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
self.attrs = {
'dropout_prob': 0.75,
'is_test': True,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {'Out': self.inputs['X']}
def test_check_output(self):
self.check_output()
class TestFP16DropoutOp(OpTest): class TestFP16DropoutOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
......
...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp): ...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp):
self.D = 8 self.D = 8
class TestFusionGRUOpMD3(TestFusionGRUOp):
def set_confs(self):
self.M = 17
self.D = 15
class TestFusionGRUOpBS1(TestFusionGRUOp): class TestFusionGRUOpBS1(TestFusionGRUOp):
def set_confs(self): def set_confs(self):
self.lod = [[3]] self.lod = [[3]]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest
import numpy as np
class TestSequenceReverseBase(OpTest):
def initParameters(self):
pass
def setUp(self):
self.size = (10, 3, 4)
self.lod = [2, 3, 5]
self.dtype = 'float32'
self.initParameters()
self.op_type = 'sequence_reverse'
self.x = np.random.random(self.size).astype(self.dtype)
self.y = self.get_output()
self.inputs = {'X': (self.x, [self.lod, ]), }
self.outputs = {'Y': (self.y, [self.lod, ]), }
def get_output(self):
tmp_x = np.reshape(self.x, newshape=[self.x.shape[0], -1])
tmp_y = np.ndarray(tmp_x.shape).astype(self.dtype)
prev_idx = 0
for cur_len in self.lod:
idx_range = range(prev_idx, prev_idx + cur_len)
tmp_y[idx_range, :] = np.flip(tmp_x[idx_range, :], 0)
prev_idx += cur_len
return np.reshape(tmp_y, newshape=self.x.shape).astype(self.dtype)
def test_output(self):
self.check_output(0)
def test_grad(self):
self.check_grad(['X'], 'Y')
class TestSequenceReserve1(TestSequenceReverseBase):
def initParameters(self):
self.size = (12, 10)
self.lod = [4, 5, 3]
class TestSequenceReverse2(TestSequenceReverseBase):
def initParameters(self):
self.size = (12, 10)
self.lod = [12]
if __name__ == '__main__':
unittest.main()
...@@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" ...@@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
) )
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
...@@ -1717,8 +1718,10 @@ to transpile() call.") ...@@ -1717,8 +1718,10 @@ to transpile() call.")
lr_ops = [] lr_ops = []
block = self.origin_program.global_block() block = self.origin_program.global_block()
for op in block.ops: for op in block.ops:
if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int( role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
LR_SCHED_OP_ROLE_ATTR_VALUE): if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
int(OPT_OP_ROLE_ATTR_VALUE):
lr_ops.append(op) lr_ops.append(op)
log("append lr op: ", op.type) log("append lr op: ", op.type)
return lr_ops return lr_ops
......
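The transpiler change above is needed because _lr_schedule_guard can now mark ops with the bitwise OR of the LRSched and Optimize roles; matching only the plain LRSched value would drop those ops from the learning-rate block on the pserver. A small sketch of the check, with illustrative flag values (the real ones come from core.op_proto_and_checker_maker.OpRole):

```python
# illustrative values only
LR_SCHED = 0x0010
OPTIMIZE = 0x0002

def is_lr_sched_op(role_id):
    # accept both "pure" LR-schedule ops and ops also marked as Optimize
    return role_id == LR_SCHED or role_id == (LR_SCHED | OPTIMIZE)

assert is_lr_sched_op(LR_SCHED)
assert is_lr_sched_op(LR_SCHED | OPTIMIZE)
assert not is_lr_sched_op(OPTIMIZE)
```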