Commit a32c6ffa authored by lujun

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move-api-to-root

......@@ -13,125 +13,186 @@
// limitations under the License.
#include <algorithm>
#include <memory>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
VarHandle* GetValidInput(const OpHandleBase* a) {
for (auto p : a->Inputs()) {
VarHandle* b = dynamic_cast<VarHandle*>(p);
if (b) {
return b;
class AllReduceDepsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override {
std::vector<AllReduceOpHandle*> all_reduce_op_handles =
GetSortedAllReduceOps(*graph);
for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
all_reduce_op_handles[i - 1]->AddOutput(dep_var);
all_reduce_op_handles[i]->AddInput(dep_var);
}
}
return nullptr;
}
void AllReduceDepsPass::ApplyImpl(ir::Graph* graph) const {
auto graph_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
// get vars order
int order = 0;
std::unordered_map<std::string, int> vars;
// TODO(gongwb): use graph topology sort to find the order of operators.
// Note that must assert topology sort is stable
auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
for (auto* op_desc : ops) {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(op_desc->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
auto backward_vars =
boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) {
for (auto& v : o_it.second) { // values
vars[v] = order;
VLOG(10) << "in all_reduce_deps_pass:" << v;
}
}
order++;
} catch (boost::bad_get e) {
if (VLOG_IS_ON(10)) {
DebugString(*graph, all_reduce_op_handles);
}
}
std::vector<OpHandleBase*> dist_ops;
// get allreduce ops.
for (auto& op : graph_ops) {
// FIXME(gongwb):add broad cast.
if (op->Name() == "all_reduce" || op->Name() == "reduce") {
dist_ops.push_back(op);
std::vector<AllReduceOpHandle*> GetSortedAllReduceOps(
const ir::Graph& graph) const {
std::vector<AllReduceOpHandle*> all_reduce_op_handles;
std::unordered_map<OpHandleBase*, size_t> pending_ops;
std::unordered_set<OpHandleBase*> ready_ops;
std::unordered_set<OpHandleBase*> next_ready_ops;
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(graph);
size_t num_of_ops = op_handles.size();
for (OpHandleBase* op : op_handles) {
size_t not_ready_vars = op->NotReadyInputSize();
if (not_ready_vars) {
pending_ops.insert({op, not_ready_vars});
} else {
ready_ops.insert(op);
}
}
}
VLOG(10) << "dist_ops size:" << dist_ops.size()
<< ", outputs size:" << vars.size() << ", ops size:" << ops.size();
std::sort(dist_ops.begin(), dist_ops.end(), [&](OpHandleBase* op1,
OpHandleBase* op2) {
VarHandle* i0 = dynamic_cast<VarHandle*>(GetValidInput(op1));
VarHandle* i1 = dynamic_cast<VarHandle*>(GetValidInput(op2));
PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error",
op1->DebugString(), op2->DebugString());
auto l_it = vars.find(i0->name());
auto r_it = vars.find(i1->name());
PADDLE_ENFORCE(l_it != vars.end() && r_it != vars.end(),
"can't find var's name %s and %s in opdesc", i0->name(),
i1->name());
if (l_it->second < r_it->second) return true;
GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
size_t has_run_ops = ready_ops.size();
while (has_run_ops != num_of_ops) {
for (auto* op : ready_ops) {
for (auto& ready_var : op->Outputs()) {
for (auto* pend_op : ready_var->PendingOps()) {
auto& deps = --pending_ops[pend_op];
if (deps == 0) {
next_ready_ops.insert(pend_op);
}
}
}
}
if (l_it->second == r_it->second) {
return i0->name() < i1->name();
PADDLE_ENFORCE_NE(next_ready_ops.size(), 0, "There maybe have a cycle.");
ready_ops.clear();
std::swap(ready_ops, next_ready_ops);
GetSortedAllReduceOps(ready_ops, &all_reduce_op_handles);
has_run_ops += ready_ops.size();
}
return all_reduce_op_handles;
}
return false;
});
// add dependency.
auto& sorted_ops = dist_ops;
for (size_t i = 1; i < sorted_ops.size(); ++i) {
auto* dep_var = new DummyVarHandle(graph->CreateControlDepVar());
auto* pre_op = sorted_ops[i - 1];
auto* op = sorted_ops[i];
pre_op->AddOutput(dep_var);
op->AddInput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
void GetSortedAllReduceOps(
const std::unordered_set<OpHandleBase*>& ready_ops,
std::vector<AllReduceOpHandle*>* all_reduce_op_handles) const {
std::vector<AllReduceOpHandle*> current_all_reduce_op_handles;
for (auto& op_handle : ready_ops) {
auto all_reduce_op_handle = dynamic_cast<AllReduceOpHandle*>(op_handle);
if (all_reduce_op_handle) {
current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
}
}
VLOG(10) << "add all_reduce sequential dependencies between " << pre_op
<< " and " << op;
// NOTE(zcd): For distributed training, it is important to keep the order of
// allReduce on each node consistent. Otherwise, hang may occur.
// Sort the current_all_reduce_op_handles according to the name of input.
sort(current_all_reduce_op_handles.begin(),
current_all_reduce_op_handles.end(),
[](const AllReduceOpHandle* left,
const AllReduceOpHandle* right) -> bool {
auto left_in_vars = DynamicCast<VarHandle>(left->Inputs());
auto right_in_vars = DynamicCast<VarHandle>(right->Inputs());
PADDLE_ENFORCE_GT(left_in_vars.size(), 0);
PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size());
return left_in_vars[0]->Name() > right_in_vars[0]->Name();
});
all_reduce_op_handles->insert(all_reduce_op_handles->end(),
current_all_reduce_op_handles.begin(),
current_all_reduce_op_handles.end());
}
VLOG(10) << "pre_op:" << pre_op->DebugString()
<< ", op:" << op->DebugString();
void DebugString(
const ir::Graph& graph,
const std::vector<AllReduceOpHandle*>& all_reduce_op_handles) const {
// get vars order
std::map<int, std::vector<std::string>> vars =
GetSoredGradientsFromStaleProgram(graph);
std::stringstream out;
size_t grads_of_stale_program = 0;
out << "Get Order From kStaleProgramOpDescs: ";
for (auto& var : vars) {
out << "Order " << var.first << " [";
for (auto& var_name : var.second) {
out << var_name << ", ";
++grads_of_stale_program;
}
out << "], ";
}
VLOG(10) << out.str();
std::stringstream out2;
out2 << "Get Order From Topological order: ";
for (auto& op : all_reduce_op_handles) {
bool find_valid_input = false;
for (auto& in_var : op->Inputs()) {
if (dynamic_cast<VarHandle*>(in_var)) {
out2 << in_var->Name() << ", ";
find_valid_input = true;
break;
}
}
PADDLE_ENFORCE(find_valid_input, "Doesn't find valid input.");
}
VLOG(10) << out2.str();
if (grads_of_stale_program != all_reduce_op_handles.size()) {
VLOG(10)
<< "The gradients number of stale program and graph is not equal.";
}
}
}
std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram(
const ir::Graph& graph) const {
std::map<int, std::vector<std::string>> vars;
auto ops = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
int order = 0;
for (auto* op_desc : ops) {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(op_desc->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
auto backward_vars =
boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
if (backward_vars.empty()) continue;
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
for (size_t i = 1; i < backward_vars.size(); i += 2) {
vars[order].emplace_back(backward_vars[i]);
VLOG(1) << "get parameter and gradient: " << backward_vars[i - 1]
<< ", " << backward_vars[i];
}
order++;
} catch (boost::bad_get e) {
}
}
return vars;
}
};
} // namespace details
} // namespace framework
} // namespace paddle
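The rewritten pass above works in two steps: GetSortedAllReduceOps walks the op handles in topological levels and, inside each level, sorts the AllReduceOpHandles by the name of their first VarHandle input so every trainer agrees on the same sequence; ApplyImpl then chains consecutive all-reduce ops through dummy control-dependency variables so they execute one after another. A minimal, self-contained sketch of that sort-then-chain idea (FakeAllReduce and its fields are illustrative stand-ins, not Paddle types):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct FakeAllReduce {        // stand-in for AllReduceOpHandle
  std::string first_input;    // name of the first gradient it reduces
  int runs_after;             // index of the op it is chained behind, -1 if none
};

int main() {
  std::vector<FakeAllReduce> ops = {
      {"w2@GRAD", -1}, {"w0@GRAD", -1}, {"w1@GRAD", -1}};
  // Deterministic order: every trainer sorts by the same key (">" as in the pass).
  std::sort(ops.begin(), ops.end(),
            [](const FakeAllReduce& a, const FakeAllReduce& b) {
              return a.first_input > b.first_input;
            });
  // Serialize: op i may only start after op i-1, mirroring the dummy
  // control-dependency variables added in ApplyImpl.
  for (size_t i = 1; i < ops.size(); ++i) ops[i].runs_after = static_cast<int>(i - 1);
  for (size_t i = 0; i < ops.size(); ++i)
    std::cout << i << ": " << ops[i].first_input
              << " runs after " << ops[i].runs_after << "\n";
}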
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
// TODO(gongwb): overlap allreduce with backward computation.
class AllReduceDepsPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -28,7 +28,7 @@
// asynchronous nccl allreduce or synchronous issue:
// https://github.com/PaddlePaddle/Paddle/issues/15049
DEFINE_bool(
sync_nccl_allreduce, false,
sync_nccl_allreduce, true,
"If set true, will call `cudaStreamSynchronize(nccl_stream)`"
"after allreduce, this mode can get better performance in some scenarios.");
......
......@@ -163,14 +163,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"graph_printer", new details::GraphvizSSAGraphPrinter);
}
// Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass");
if (VLOG_IS_ON(2)) {
AppendPass("all_reduce_deps_pass");
}
if (SeqOnlyAllReduceOps(strategy_)) {
// experimental shows that the program will be faster if append
// all_reduce_deps_pass here.
if (!strategy_.enable_parallel_graph_ &&
(SeqOnlyAllReduceOps(strategy_) ||
strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce)) {
VLOG(10) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}
......@@ -179,6 +176,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
VLOG(10) << "Add modify_op_lock_and_record_event_pass";
AppendPass("modify_op_lock_and_record_event_pass");
}
// Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass");
}
// Convert graph to run on multi-devices.
......
......@@ -91,7 +91,11 @@ struct BuildStrategy {
bool enable_sequential_execution_{false};
bool fuse_broadcast_op_{false};
// NOTE(zcd): In reduce mode, fusing broadcast ops may make the program
// faster. Because fusing broadcast OP equals delaying the execution of all
// broadcast Ops, in this case, all nccl streams are used only for reduce
// operations for a period of time.
bool fuse_broadcast_ops_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
......
......@@ -56,6 +56,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<FetchOpHandle *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
......@@ -70,8 +71,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
"Cannot find fetched variable.(Perhaps the main_program "
"is not set to ParallelExecutor)");
"Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)",
var_name);
auto &vars = fetched_var_it->second;
......@@ -88,7 +90,11 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
op->AddInput(var);
}
(*op_deps)[op] = static_cast<int>(op->NotReadyInputSize());
int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
ready_fetch_ops.emplace_back(op);
}
}
size_t num_complete = 0;
......@@ -97,7 +103,9 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q);
}
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) {
......
......@@ -13,9 +13,9 @@
// limitations under the License.
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
......@@ -44,6 +44,7 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
}
void FetchOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated(platform::CPUPlace());
tensors_.resize(inputs_.size());
......@@ -62,7 +63,8 @@ void FetchOpHandle::RunImpl() {
auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA
TensorCopySync(t, cpu, &tensors_[i]);
TensorCopy(t, cpu, *dev_ctxes_.at(t.place()), &tensors_[i]);
dev_ctxes_.at(t.place())->Wait();
#endif
} else {
tensors_[i].ShareDataWith(t);
......
......@@ -658,7 +658,7 @@ bool ReduceSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
void ReduceSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
if (UseGPU()) {
if (strategy_.fuse_broadcast_op_) {
if (strategy_.fuse_broadcast_ops_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
......@@ -1021,7 +1021,7 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
return;
}
if (strategy_.fuse_broadcast_op_) {
if (strategy_.fuse_broadcast_ops_) {
CreateFusedBroadcastOp(result, bcast_var_name_set_);
} else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
......
......@@ -68,7 +68,7 @@ void OpHandleBase::Run(bool use_cuda) {
if (out_var_handle) {
PADDLE_ENFORCE(
platform::is_same_place(place, out_var_handle->place()),
"The place of input(%s) is not consistent with the "
"The place of output(%s) is not consistent with the "
"place of current op(%s).",
out_var_handle->Name(), Name());
out_var_handle->SetGenerateEvent(events_.at(dev_id));
......
......@@ -80,7 +80,6 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
}
set.clear();
};
auto run_all_op = [&](OpHandleBase *op) { RunOp(ready_vars, op); };
// Clean run context
run_op_futures_.clear();
exception_holder_.Clear();
......@@ -116,7 +115,7 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
run_all_op(op);
ready_ops.insert(op);
}
}
}
......
......@@ -617,6 +617,25 @@ void OpDesc::Flush() {
static std::once_flag init_infer_shape_funcs;
/**
* NOTE(paddle-dev): Very tricky code here. Maybe we should find a
* better way to register the compile-time InferShape method gently.
*
* Normally, we can register a class derived from InferShapeBase, so that
* we can set the field of `infer_shape_` inside OpInfo when registering op.
*
* However, there is another way we can set the field of `infer_shape_` inside
* OpInfo. Usually, we overload InferShape method of OperatorWithKernel. After
* running the following method InitInferShapeFuncs, `infer_shape_` would be set
* to be the InferShape method of OperatorWithKernel. That is to say, we borrow
* the run-time InferShape method of OperatorWithKernel to be the compile-time
* InferShape method.
*
* However, during compiling time, we may not know inputs, outputs and attrs of
* run-time OperatorWithKernel. So the following code creates a fake
* OperatorWithKernel object. That is why the field info_ of OperatorBase
* would be null.
*/
static void InitInferShapeFuncs() {
std::call_once(init_infer_shape_funcs, [] {
auto &map = OpInfoMap::Instance();
......@@ -628,11 +647,16 @@ static void InitInferShapeFuncs() {
PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
op_type);
auto &op_info = it->second;
auto op = static_cast<OperatorWithKernel *>(op_info.Creator()(
"", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
if (op_info.infer_shape_) { // infer_shape has been registered.
continue;
}
auto op = dynamic_cast<OperatorWithKernel *>(op_info.Creator()(
"", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
PADDLE_ENFORCE_NOT_NULL(
op, "InferShapeBase is not registered to Operator %s", op_type);
op_info.infer_shape_ = [op](InferShapeContext *ctx) {
op->InferShape(ctx);
};
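The registration trick described in the comment above boils down to capturing a throwaway operator object in a std::function and forwarding calls to its virtual InferShape. A self-contained sketch of that capture-and-forward pattern (ShapeContext, FakeOp, and infer_shape are illustrative names, not Paddle's):

#include <functional>
#include <iostream>
#include <vector>

struct ShapeContext { std::vector<int> in, out; };

struct FakeOp {
  // Stand-in for the run-time OperatorWithKernel::InferShape.
  void InferShape(ShapeContext* ctx) const { ctx->out = ctx->in; }
};

int main() {
  std::function<void(ShapeContext*)> infer_shape;  // like OpInfo::infer_shape_
  auto* op = new FakeOp();  // fake object, kept alive for the program's lifetime
  infer_shape = [op](ShapeContext* ctx) { op->InferShape(ctx); };

  ShapeContext ctx{{4, 3, 224, 224}, {}};
  infer_shape(&ctx);  // the run-time method reused as the "compile-time" one
  std::cout << ctx.out.size() << " output dims\n";  // prints: 4 output dims
}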
......
......@@ -19,11 +19,6 @@ limitations under the License. */
#include <tuple>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
......@@ -31,6 +26,8 @@ limitations under the License. */
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef WITH_GPERFTOOLS
......
......@@ -142,7 +142,6 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
void AnalysisConfig::EnableMKLDNN() {
#ifdef PADDLE_WITH_MKLDNN
pass_builder()->EnableMKLDNN();
use_mkldnn_ = true;
#else
LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
......@@ -235,16 +234,13 @@ void AnalysisConfig::Update() {
}
if (use_mkldnn_) {
#ifdef PADDLE_WITH_MKLDNN
if (!enable_ir_optim_) {
LOG(ERROR)
<< "EnableMKLDNN() only works when IR optimization is enabled.";
} else {
pass_builder()->EnableMKLDNN();
}
#ifdef PADDLE_WITH_MKLDNN
pass_builder()->EnableMKLDNN();
use_mkldnn_ = true;
#else
LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
use_mkldnn_ = false;
#endif
}
......@@ -256,9 +252,6 @@ void AnalysisConfig::Update() {
}
#ifdef PADDLE_WITH_MKLDNN
pass_builder()->EnableMkldnnQuantizer();
#else
LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
use_mkldnn_quantizer_ = false;
#endif
}
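For reference, the user-facing side of the options touched in this file is a pair of AnalysisConfig calls plus the quantizer warmup settings used later in the tests. A hedged sketch, assuming the caller has already pointed cfg at a model (EnableInt8OnCpu is an illustrative helper name, not part of the API):

#include <memory>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

void EnableInt8OnCpu(paddle::AnalysisConfig* cfg,
                     std::shared_ptr<std::vector<paddle::PaddleTensor>> warmup,
                     int warmup_batch_size) {
  cfg->EnableMKLDNN();           // registers the MKLDNN passes in Update() above
  cfg->EnableMkldnnQuantizer();  // turns on the INT8 quantization passes
  cfg->mkldnn_quantizer_config()->SetWarmupData(warmup);
  cfg->mkldnn_quantizer_config()->SetWarmupBatchSize(warmup_batch_size);
}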
......
......@@ -27,6 +27,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/printf.h"
......@@ -266,17 +267,17 @@ static std::string DescribeZeroCopyTensor(const ZeroCopyTensor &tensor) {
}
static void PrintTime(int batch_size, int repeat, int num_threads, int tid,
double latency, int epoch = 1) {
LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat
<< ", threads: " << num_threads << ", thread id: " << tid
<< ", latency: " << latency << "ms, fps: " << 1 / (latency / 1000.f)
double batch_latency, int epoch = 1) {
PADDLE_ENFORCE(batch_size > 0, "Non-positive batch size.");
double sample_latency = batch_latency / batch_size;
LOG(INFO) << "====== threads: " << num_threads << ", thread id: " << tid
<< " ======";
if (epoch > 1) {
int samples = batch_size * epoch;
LOG(INFO) << "====== sample number: " << samples
<< ", average latency of each sample: " << latency / samples
<< "ms ======";
}
LOG(INFO) << "====== batch_size: " << batch_size << ", iterations: " << epoch
<< ", repetitions: " << repeat << " ======";
LOG(INFO) << "====== batch latency: " << batch_latency
<< "ms, number of samples: " << batch_size * epoch
<< ", sample latency: " << sample_latency
<< "ms, fps: " << 1000.f / sample_latency << " ======";
}
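As a worked example of the new report (illustrative numbers only): with batch_size = 50 and a measured batch latency of 25 ms, sample_latency = 25 / 50 = 0.5 ms and fps = 1000 / 0.5 = 2000; with epoch = 4 processed batches, the reported number of samples is 50 * 4 = 200.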
static bool IsFileExists(const std::string &path) {
......
......@@ -64,10 +64,12 @@ void PaddlePassBuilder::DeletePass(size_t idx) {
passes_.erase(std::begin(passes_) + idx);
}
void GpuPassStrategy::EnableMKLDNN() {
LOG(ERROR) << "GPU not support MKLDNN yet";
void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
analysis_passes_.push_back(pass);
}
void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
// The following passes works for Anakin sub-graph engine.
const std::vector<std::string> kAnakinSubgraphPasses({
"infer_clean_graph_pass", //
......@@ -102,12 +104,12 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
use_gpu_ = true;
}
void GpuPassStrategy::EnableMkldnnQuantizer() {
LOG(ERROR) << "GPU not support MKL-DNN quantization";
void GpuPassStrategy::EnableMKLDNN() {
LOG(ERROR) << "GPU not support MKLDNN yet";
}
void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
analysis_passes_.push_back(pass);
void GpuPassStrategy::EnableMkldnnQuantizer() {
LOG(ERROR) << "GPU not support MKL-DNN quantization";
}
CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
......@@ -135,5 +137,39 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
});
use_gpu_ = false;
}
void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
void CpuPassStrategy::EnableMKLDNN() {
// TODO(Superjomn) Consider the way to mix CPU with GPU.
#ifdef PADDLE_WITH_MKLDNN
if (!use_mkldnn_) {
passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : std::vector<std::string>(
{"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_bias_mkldnn_fuse_pass", //
"conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass"})) {
passes_.push_back(pass);
}
}
use_mkldnn_ = true;
#else
use_mkldnn_ = false;
#endif
}
void CpuPassStrategy::EnableMkldnnQuantizer() {
#ifdef PADDLE_WITH_MKLDNN
if (!use_mkldnn_quantizer_) {
passes_.push_back("cpu_quantize_placement_pass");
}
use_mkldnn_quantizer_ = true;
#else
use_mkldnn_quantizer_ = false;
#endif
}
} // namespace paddle
......@@ -109,43 +109,16 @@ class CpuPassStrategy : public PassStrategy {
CpuPassStrategy();
explicit CpuPassStrategy(const CpuPassStrategy &other)
: PassStrategy(other.AllPasses()) {}
: PassStrategy(other.AllPasses()) {
use_gpu_ = other.use_gpu_;
use_mkldnn_ = other.use_mkldnn_;
use_mkldnn_quantizer_ = other.use_mkldnn_quantizer_;
}
virtual ~CpuPassStrategy() = default;
void EnableMKLDNN() override {
// TODO(Superjomn) Consider the way to mix CPU with GPU.
#ifdef PADDLE_WITH_MKLDNN
if (!use_mkldnn_) {
passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : std::vector<std::string>(
{"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_bias_mkldnn_fuse_pass", //
"conv3d_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass"})) {
passes_.push_back(pass);
}
}
use_mkldnn_ = true;
#else
use_mkldnn_ = false;
#endif
}
void EnableMkldnnQuantizer() override {
#ifdef PADDLE_WITH_MKLDNN
if (!use_mkldnn_quantizer_) {
passes_.push_back("cpu_quantize_placement_pass");
}
use_mkldnn_quantizer_ = true;
#else
use_mkldnn_quantizer_ = false;
#endif
}
void EnableMKLDNN() override;
void EnableMkldnnQuantizer() override;
protected:
bool use_mkldnn_quantizer_{false};
......
......@@ -26,7 +26,11 @@ endfunction()
function(inference_analysis_api_int8_test target model_dir data_dir filename)
inference_analysis_test(${target} SRCS ${filename}
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark
ARGS --infer_model=${model_dir}/model --infer_data=${data_dir}/data.bin --batch_size=100)
ARGS --infer_model=${model_dir}/model
--infer_data=${data_dir}/data.bin
--warmup_batch_size=100
--batch_size=50
--iterations=2)
endfunction()
function(inference_analysis_api_test_with_fake_data target install_dir filename model_name)
......@@ -146,22 +150,22 @@ inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_con
# int8 image classification tests
if(WITH_MKLDNN)
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8")
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
if (NOT EXISTS ${INT8_DATA_DIR})
inference_download_and_uncompress(${INT8_DATA_DIR} ${INFERENCE_URL}"/int8" "imagenet_val_100.tar.gz")
inference_download_and_uncompress(${INT8_DATA_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz")
endif()
#resnet50 int8
set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
if (NOT EXISTS ${INT8_RESNET50_MODEL_DIR})
inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} ${INFERENCE_URL}"/int8" "resnet50_int8_model.tar.gz" )
inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} "${INFERENCE_URL}/int8" "resnet50_int8_model.tar.gz" )
endif()
inference_analysis_api_int8_test(test_analyzer_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL)
#mobilenet int8
set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet")
if (NOT EXISTS ${INT8_MOBILENET_MODEL_DIR})
inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} ${INFERENCE_URL}"/int8" "mobilenetv1_int8_model.tar.gz" )
inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenetv1_int8_model.tar.gz" )
endif()
inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL)
endif()
......
......@@ -154,7 +154,7 @@ void profile(bool use_mkldnn = false) {
config.EnableMKLDNN();
}
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&config),
......
......@@ -197,7 +197,7 @@ void profile(bool use_mkldnn = false) {
cfg.SetMKLDNNOp(op_list);
}
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -206,9 +206,11 @@ void profile(bool use_mkldnn = false) {
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
PADDLE_ENFORCE_GT(outputs.size(), 0);
size_t size = GetSize(outputs[0]);
auto output = outputs.back();
PADDLE_ENFORCE_GT(output.size(), 0);
size_t size = GetSize(output[0]);
PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data());
float *result = static_cast<float *>(output[0].data.data());
for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(result[i], result_data[i], 1e-3);
}
......
......@@ -17,8 +17,6 @@ limitations under the License. */
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
DEFINE_int32(iterations, 0, "Number of iterations");
namespace paddle {
namespace inference {
namespace analysis {
......@@ -30,8 +28,13 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->SwitchIrOptim();
cfg->SwitchSpecifyInputNames(false);
cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
cfg->EnableMKLDNN();
cfg->pass_builder()->SetPasses(
{"infer_clean_graph_pass", "mkldnn_placement_pass",
"depthwise_conv_mkldnn_pass", "conv_bn_fuse_pass",
"conv_eltwiseadd_bn_fuse_pass", "conv_bias_mkldnn_fuse_pass",
"conv_elementwise_add_mkldnn_fuse_pass", "conv_relu_mkldnn_fuse_pass",
"fc_fuse_pass", "is_test_pass"});
}
template <typename T>
......@@ -40,8 +43,8 @@ class TensorReader {
TensorReader(std::ifstream &file, size_t beginning_offset,
std::vector<int> shape, std::string name)
: file_(file), position(beginning_offset), shape_(shape), name_(name) {
numel =
std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<T>());
numel = std::accumulate(shape_.begin(), shape_.end(), size_t{1},
std::multiplies<size_t>());
}
PaddleTensor NextBatch() {
......@@ -71,10 +74,14 @@ class TensorReader {
};
std::shared_ptr<std::vector<PaddleTensor>> GetWarmupData(
const std::vector<std::vector<PaddleTensor>> &test_data, int num_images) {
const std::vector<std::vector<PaddleTensor>> &test_data,
int num_images = FLAGS_warmup_batch_size) {
int test_data_batch_size = test_data[0][0].shape[0];
CHECK_LE(static_cast<size_t>(num_images),
test_data.size() * test_data_batch_size);
auto iterations_max = test_data.size();
PADDLE_ENFORCE(
static_cast<size_t>(num_images) <= iterations_max * test_data_batch_size,
"The requested quantization warmup data size " +
std::to_string(num_images) + " is bigger than all test data size.");
PaddleTensor images;
images.name = "input";
......@@ -120,20 +127,17 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
std::vector<int> image_batch_shape{batch_size, 3, 224, 224};
std::vector<int> label_batch_shape{batch_size, 1};
auto images_offset_in_file = static_cast<size_t>(file.tellg());
auto labels_offset_in_file =
static_cast<size_t>(file.tellg()) +
sizeof(float) * total_images *
std::accumulate(image_batch_shape.begin() + 1,
image_batch_shape.end(), 1, std::multiplies<int>());
images_offset_in_file + sizeof(float) * total_images * 3 * 224 * 224;
TensorReader<float> image_reader(file, 0, image_batch_shape, "input");
TensorReader<float> image_reader(file, images_offset_in_file,
image_batch_shape, "input");
TensorReader<int64_t> label_reader(file, labels_offset_in_file,
label_batch_shape, "label");
auto iterations = total_images / batch_size;
if (FLAGS_iterations > 0 && FLAGS_iterations < iterations)
iterations = FLAGS_iterations;
for (auto i = 0; i < iterations; i++) {
auto iterations_max = total_images / batch_size;
for (auto i = 0; i < iterations_max; i++) {
auto images = image_reader.NextBatch();
auto labels = label_reader.NextBatch();
inputs->emplace_back(
......@@ -148,20 +152,21 @@ TEST(Analyzer_int8_resnet50, quantization) {
AnalysisConfig q_cfg;
SetConfig(&q_cfg);
// read data from file and prepare batches with test data
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all, 100);
SetInput(&input_slots_all);
// prepare warmup batch from input data read earlier
// warmup batch size can be different than batch size
std::shared_ptr<std::vector<PaddleTensor>> warmup_data =
GetWarmupData(input_slots_all, 100);
GetWarmupData(input_slots_all);
// configure quantizer
q_cfg.EnableMkldnnQuantizer();
q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(100);
q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_warmup_batch_size);
CompareQuantizedAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
reinterpret_cast<const PaddlePredictor::Config *>(&q_cfg),
input_slots_all);
CompareQuantizedAndAnalysis(&cfg, &q_cfg, input_slots_all);
}
} // namespace analysis
......
......@@ -124,7 +124,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
TEST(Analyzer_LAC, profile) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -137,11 +137,13 @@ TEST(Analyzer_LAC, profile) {
24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25, 25, 25, 25, 25,
44, 24, 25, 25, 25, 36, 42, 43, 44, 14, 15, 44, 14, 15, 44, 14,
15, 44, 38, 39, 14, 15, 44, 22, 23, 23, 23, 23, 23, 23, 23};
PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(outputs.size(), 0);
auto output = outputs.back();
PADDLE_ENFORCE_EQ(output.size(), 1UL);
size_t size = GetSize(output[0]);
size_t batch1_size = sizeof(lac_ref_data) / sizeof(int64_t);
PADDLE_ENFORCE_GE(size, batch1_size);
int64_t *pdata = static_cast<int64_t *>(outputs[0].data.data());
int64_t *pdata = static_cast<int64_t *>(output[0].data.data());
for (size_t i = 0; i < batch1_size; ++i) {
EXPECT_EQ(pdata[i], lac_ref_data[i]);
}
......
......@@ -96,7 +96,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
void profile(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
if (use_mkldnn) {
cfg.EnableMKLDNN();
......@@ -108,8 +108,9 @@ void profile(bool use_mkldnn = false) {
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
PADDLE_ENFORCE_EQ(outputs.size(), 2UL);
for (auto &output : outputs) {
PADDLE_ENFORCE_GT(outputs.size(), 0);
PADDLE_ENFORCE_EQ(outputs.back().size(), 2UL);
for (auto &output : outputs.back()) {
size_t size = GetSize(output);
PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(output.data.data());
......
......@@ -106,7 +106,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
void profile(bool memory_load = false) {
AnalysisConfig cfg;
SetConfig(&cfg, memory_load);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -117,10 +117,12 @@ void profile(bool memory_load = false) {
// the first inference result
const int chinese_ner_result_data[] = {30, 45, 41, 48, 17, 26,
48, 39, 38, 16, 25};
PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(outputs.size(), 0);
auto output = outputs.back();
PADDLE_ENFORCE_EQ(output.size(), 1UL);
size_t size = GetSize(output[0]);
PADDLE_ENFORCE_GT(size, 0);
int64_t *result = static_cast<int64_t *>(outputs[0].data.data());
int64_t *result = static_cast<int64_t *>(output[0].data.data());
for (size_t i = 0; i < std::min(11UL, size); i++) {
EXPECT_EQ(result[i], chinese_ner_result_data[i]);
}
......
......@@ -127,7 +127,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
TEST(Analyzer_Pyramid_DNN, profile) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -135,10 +135,12 @@ TEST(Analyzer_Pyramid_DNN, profile) {
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data && !FLAGS_zero_copy) {
PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(outputs.size(), 0);
auto output = outputs.back();
PADDLE_ENFORCE_EQ(output.size(), 1UL);
size_t size = GetSize(output[0]);
PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data());
float *result = static_cast<float *>(output[0].data.data());
// output is probability, which is in (0, 1).
for (size_t i = 0; i < size; i++) {
EXPECT_GT(result[i], 0);
......
......@@ -40,7 +40,7 @@ void profile(bool use_mkldnn = false) {
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......
......@@ -229,7 +229,7 @@ TEST(Analyzer_rnn1, profile) {
SetConfig(&cfg);
cfg.DisableGpu();
cfg.SwitchIrDebug();
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -280,7 +280,7 @@ TEST(Analyzer_rnn1, compare_determine) {
TEST(Analyzer_rnn1, multi_thread) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......
......@@ -126,7 +126,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
TEST(Analyzer_rnn2, profile) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -136,9 +136,11 @@ TEST(Analyzer_rnn2, profile) {
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result
PADDLE_ENFORCE_GT(outputs.size(), 0);
size_t size = GetSize(outputs[0]);
auto output = outputs.back();
PADDLE_ENFORCE_GT(output.size(), 0);
size_t size = GetSize(output[0]);
PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data());
float *result = static_cast<float *>(output[0].data.data());
for (size_t i = 0; i < size; i++) {
EXPECT_NEAR(result[i], result_data[i], 1e-3);
}
......
......@@ -110,7 +110,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
TEST(Analyzer_seq_conv1, profile) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -119,10 +119,12 @@ TEST(Analyzer_seq_conv1, profile) {
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result
PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
size_t size = GetSize(outputs[0]);
PADDLE_ENFORCE_GT(outputs.size(), 0);
auto output = outputs.back();
PADDLE_ENFORCE_EQ(output.size(), 1UL);
size_t size = GetSize(output[0]);
PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(outputs[0].data.data());
float *result = static_cast<float *>(output[0].data.data());
// output is probability, which is in (0, 1).
for (size_t i = 0; i < size; i++) {
EXPECT_GT(result[i], 0);
......
......@@ -156,7 +156,7 @@ void profile(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg, use_mkldnn);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
......
......@@ -70,7 +70,7 @@ TEST(Analyzer_Text_Classification, profile) {
AnalysisConfig cfg;
SetConfig(&cfg);
cfg.SwitchIrDebug();
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -79,8 +79,9 @@ TEST(Analyzer_Text_Classification, profile) {
if (FLAGS_num_threads == 1) {
// Get output
LOG(INFO) << "get outputs " << outputs.size();
for (auto &output : outputs) {
PADDLE_ENFORCE_GT(outputs.size(), 0);
LOG(INFO) << "get outputs " << outputs.back().size();
for (auto &output : outputs.back()) {
LOG(INFO) << "output.shape: " << to_string(output.shape);
// no lod ?
CHECK_EQ(output.lod.size(), 0UL);
......
......@@ -186,7 +186,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
void profile(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
......
......@@ -87,7 +87,7 @@ void profile(bool use_mkldnn = false) {
cfg.EnableMKLDNN();
}
// cfg.pass_builder()->TurnOnDebug();
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
......@@ -100,7 +100,8 @@ void profile(bool use_mkldnn = false) {
auto refer = ProcessALine(line);
file.close();
auto &output = outputs.front();
PADDLE_ENFORCE_GT(outputs.size(), 0);
auto &output = outputs.back().front();
size_t numel = output.data.length() / PaddleDtypeSize(output.dtype);
CHECK_EQ(numel, refer.data.size());
for (size_t i = 0; i < numel; ++i) {
......
......@@ -41,7 +41,10 @@ DEFINE_string(model_name, "", "model name");
DEFINE_string(infer_model, "", "model path");
DEFINE_string(infer_data, "", "data file");
DEFINE_string(refer_result, "", "reference result for comparison");
DEFINE_int32(batch_size, 1, "batch size.");
DEFINE_int32(batch_size, 1, "batch size");
DEFINE_int32(warmup_batch_size, 100, "batch size for quantization warmup");
// setting iterations to 0 means processing the whole dataset
DEFINE_int32(iterations, 0, "number of batches to process");
DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
DEFINE_bool(test_all_data, false, "Test the all dataset in data file.");
DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
......@@ -239,7 +242,7 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
}
input.shape = shape;
input.dtype = PaddleDType::FLOAT32;
size_t len = std::accumulate(shape.begin(), shape.end(), 1,
size_t len = std::accumulate(shape.begin(), shape.end(), size_t{1},
[](int a, int b) { return a * b; });
input.data.Resize(len * sizeof(float));
input.lod.assign({{0, static_cast<size_t>(FLAGS_batch_size)}});
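The switch of the init value to size_t{1} in the accumulate calls in this and the surrounding hunks matters because std::accumulate deduces its accumulator type from that argument: with a plain int 1, the product of a large shape is computed in int and can overflow. A minimal sketch:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> shape = {1024, 1024, 1024, 4};  // 2^32 elements
  // With an int init value the multiplication would be done in int and
  // overflow (undefined behaviour); seeding with size_t{1} keeps it in size_t.
  auto numel = std::accumulate(shape.begin(), shape.end(), size_t{1},
                               std::multiplies<size_t>());
  std::cout << numel << "\n";  // prints 4294967296
}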
......@@ -286,17 +289,18 @@ void ConvertPaddleTensorToZeroCopyTensor(
void PredictionWarmUp(PaddlePredictor *predictor,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, int num_threads,
int tid) {
std::vector<std::vector<PaddleTensor>> *outputs,
int num_threads, int tid) {
int batch_size = FLAGS_batch_size;
LOG(INFO) << "Running thread " << tid << ", warm up run...";
if (FLAGS_zero_copy) {
ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[0]);
}
outputs->resize(1);
Timer warmup_timer;
warmup_timer.tic();
if (!FLAGS_zero_copy) {
predictor->Run(inputs[0], outputs, batch_size);
predictor->Run(inputs[0], &(*outputs)[0], batch_size);
} else {
predictor->ZeroCopyRun();
}
......@@ -308,11 +312,16 @@ void PredictionWarmUp(PaddlePredictor *predictor,
void PredictionRun(PaddlePredictor *predictor,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, int num_threads,
int tid) {
int batch_size = FLAGS_batch_size;
std::vector<std::vector<PaddleTensor>> *outputs,
int num_threads, int tid) {
int num_times = FLAGS_repeat;
LOG(INFO) << "Thread " << tid << " run " << num_times << " times...";
int iterations = inputs.size(); // process the whole dataset ...
if (FLAGS_iterations > 0 && FLAGS_iterations < inputs.size())
iterations =
FLAGS_iterations; // ... unless the number of iterations is set
outputs->resize(iterations);
LOG(INFO) << "Thread " << tid << ", number of threads " << num_threads
<< ", run " << num_times << " times...";
Timer run_timer;
double elapsed_time = 0;
#ifdef WITH_GPERFTOOLS
......@@ -320,14 +329,14 @@ void PredictionRun(PaddlePredictor *predictor,
#endif
if (!FLAGS_zero_copy) {
run_timer.tic();
for (size_t i = 0; i < inputs.size(); i++) {
for (size_t i = 0; i < iterations; i++) {
for (int j = 0; j < num_times; j++) {
predictor->Run(inputs[i], outputs, batch_size);
predictor->Run(inputs[i], &(*outputs)[i], FLAGS_batch_size);
}
}
elapsed_time = run_timer.toc();
} else {
for (size_t i = 0; i < inputs.size(); i++) {
for (size_t i = 0; i < iterations; i++) {
ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[i]);
run_timer.tic();
for (int j = 0; j < num_times; j++) {
......@@ -340,13 +349,14 @@ void PredictionRun(PaddlePredictor *predictor,
ProfilerStop();
#endif
PrintTime(batch_size, num_times, num_threads, tid, elapsed_time / num_times,
inputs.size());
auto batch_latency = elapsed_time / (iterations * num_times);
PrintTime(FLAGS_batch_size, num_times, num_threads, tid, batch_latency,
iterations);
if (FLAGS_record_benchmark) {
Benchmark benchmark;
benchmark.SetName(FLAGS_model_name);
benchmark.SetBatchSize(batch_size);
benchmark.SetLatency(elapsed_time / num_times);
benchmark.SetBatchSize(FLAGS_batch_size);
benchmark.SetLatency(batch_latency);
benchmark.PersistToFile("benchmark_record.txt");
}
}
......@@ -354,16 +364,17 @@ void PredictionRun(PaddlePredictor *predictor,
void TestOneThreadPrediction(
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, bool use_analysis = true) {
std::vector<std::vector<PaddleTensor>> *outputs, bool use_analysis = true) {
auto predictor = CreateTestPredictor(config, use_analysis);
PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0);
PredictionRun(predictor.get(), inputs, outputs, 1, 0);
PredictionWarmUp(predictor.get(), inputs, outputs, FLAGS_paddle_num_threads,
0);
PredictionRun(predictor.get(), inputs, outputs, FLAGS_paddle_num_threads, 0);
}
void TestMultiThreadPrediction(
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, int num_threads,
std::vector<std::vector<PaddleTensor>> *outputs, int num_threads,
bool use_analysis = true) {
std::vector<std::thread> threads;
std::vector<std::unique_ptr<PaddlePredictor>> predictors;
......@@ -376,7 +387,7 @@ void TestMultiThreadPrediction(
threads.emplace_back([&, tid]() {
// Each thread should have local inputs and outputs.
// The inputs of each thread are all the same.
std::vector<PaddleTensor> outputs_tid;
std::vector<std::vector<PaddleTensor>> outputs_tid;
auto &predictor = predictors[tid];
#ifdef PADDLE_WITH_MKLDNN
if (use_analysis) {
......@@ -384,8 +395,8 @@ void TestMultiThreadPrediction(
->SetMkldnnThreadID(static_cast<int>(tid) + 1);
}
#endif
PredictionWarmUp(predictor.get(), inputs, outputs, num_threads, tid);
PredictionRun(predictor.get(), inputs, outputs, num_threads, tid);
PredictionWarmUp(predictor.get(), inputs, &outputs_tid, num_threads, tid);
PredictionRun(predictor.get(), inputs, &outputs_tid, num_threads, tid);
});
}
for (int i = 0; i < num_threads; ++i) {
......@@ -395,8 +406,8 @@ void TestMultiThreadPrediction(
void TestPrediction(const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, int num_threads,
bool use_analysis = FLAGS_use_analysis) {
std::vector<std::vector<PaddleTensor>> *outputs,
int num_threads, bool use_analysis = FLAGS_use_analysis) {
PrintConfig(config, use_analysis);
if (num_threads == 1) {
TestOneThreadPrediction(config, inputs, outputs, use_analysis);
......@@ -406,30 +417,41 @@ void TestPrediction(const PaddlePredictor::Config *config,
}
}
void CompareTopAccuracy(const std::vector<PaddleTensor> &output_slots1,
const std::vector<PaddleTensor> &output_slots2) {
// first output: avg_cost
if (output_slots1.size() == 0 || output_slots2.size() == 0)
void CompareTopAccuracy(
const std::vector<std::vector<PaddleTensor>> &output_slots_quant,
const std::vector<std::vector<PaddleTensor>> &output_slots_ref) {
if (output_slots_quant.size() == 0 || output_slots_ref.size() == 0)
throw std::invalid_argument(
"CompareTopAccuracy: output_slots vector is empty.");
PADDLE_ENFORCE(output_slots1.size() >= 2UL);
PADDLE_ENFORCE(output_slots2.size() >= 2UL);
// second output: acc_top1
if (output_slots1[1].lod.size() > 0 || output_slots2[1].lod.size() > 0)
throw std::invalid_argument(
"CompareTopAccuracy: top1 accuracy output has nonempty LoD.");
if (output_slots1[1].dtype != paddle::PaddleDType::FLOAT32 ||
output_slots2[1].dtype != paddle::PaddleDType::FLOAT32)
throw std::invalid_argument(
"CompareTopAccuracy: top1 accuracy output is of a wrong type.");
float *top1_quantized = static_cast<float *>(output_slots1[1].data.data());
float *top1_reference = static_cast<float *>(output_slots2[1].data.data());
LOG(INFO) << "top1 INT8 accuracy: " << *top1_quantized;
LOG(INFO) << "top1 FP32 accuracy: " << *top1_reference;
float total_accs1_quant{0};
float total_accs1_ref{0};
for (size_t i = 0; i < output_slots_quant.size(); ++i) {
PADDLE_ENFORCE(output_slots_quant[i].size() >= 2UL);
PADDLE_ENFORCE(output_slots_ref[i].size() >= 2UL);
// second output: acc_top1
if (output_slots_quant[i][1].lod.size() > 0 ||
output_slots_ref[i][1].lod.size() > 0)
throw std::invalid_argument(
"CompareTopAccuracy: top1 accuracy output has nonempty LoD.");
if (output_slots_quant[i][1].dtype != paddle::PaddleDType::FLOAT32 ||
output_slots_ref[i][1].dtype != paddle::PaddleDType::FLOAT32)
throw std::invalid_argument(
"CompareTopAccuracy: top1 accuracy output is of a wrong type.");
total_accs1_quant +=
*static_cast<float *>(output_slots_quant[i][1].data.data());
total_accs1_ref +=
*static_cast<float *>(output_slots_ref[i][1].data.data());
}
float avg_acc1_quant = total_accs1_quant / output_slots_quant.size();
float avg_acc1_ref = total_accs1_ref / output_slots_ref.size();
LOG(INFO) << "Avg top1 INT8 accuracy: " << std::fixed << std::setw(6)
<< std::setprecision(4) << avg_acc1_quant;
LOG(INFO) << "Avg top1 FP32 accuracy: " << std::fixed << std::setw(6)
<< std::setprecision(4) << avg_acc1_ref;
LOG(INFO) << "Accepted accuracy drop threshold: " << FLAGS_quantized_accuracy;
CHECK_LE(std::abs(*top1_quantized - *top1_reference),
FLAGS_quantized_accuracy);
CHECK_LE(std::abs(avg_acc1_quant - avg_acc1_ref), FLAGS_quantized_accuracy);
}
void CompareDeterministic(
......@@ -455,20 +477,35 @@ void CompareNativeAndAnalysis(
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs) {
PrintConfig(config, true);
std::vector<PaddleTensor> native_outputs, analysis_outputs;
std::vector<std::vector<PaddleTensor>> native_outputs, analysis_outputs;
TestOneThreadPrediction(config, inputs, &native_outputs, false);
TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
CompareResult(analysis_outputs, native_outputs);
PADDLE_ENFORCE(native_outputs.size() > 0, "Native output is empty.");
PADDLE_ENFORCE(analysis_outputs.size() > 0, "Analysis output is empty.");
CompareResult(analysis_outputs.back(), native_outputs.back());
}
void CompareQuantizedAndAnalysis(
const PaddlePredictor::Config *config,
const PaddlePredictor::Config *qconfig,
const AnalysisConfig *config, const AnalysisConfig *qconfig,
const std::vector<std::vector<PaddleTensor>> &inputs) {
PrintConfig(config, true);
std::vector<PaddleTensor> analysis_outputs, quantized_outputs;
TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
TestOneThreadPrediction(qconfig, inputs, &quantized_outputs, true);
PADDLE_ENFORCE_EQ(inputs[0][0].shape[0], FLAGS_batch_size,
"Input data has to be packed batch by batch.");
LOG(INFO) << "FP32 & INT8 prediction run: batch_size " << FLAGS_batch_size
<< ", warmup batch size " << FLAGS_warmup_batch_size << ".";
LOG(INFO) << "--- FP32 prediction start ---";
auto *cfg = reinterpret_cast<const PaddlePredictor::Config *>(config);
PrintConfig(cfg, true);
std::vector<std::vector<PaddleTensor>> analysis_outputs;
TestOneThreadPrediction(cfg, inputs, &analysis_outputs, true);
LOG(INFO) << "--- INT8 prediction start ---";
auto *qcfg = reinterpret_cast<const PaddlePredictor::Config *>(qconfig);
PrintConfig(qcfg, true);
std::vector<std::vector<PaddleTensor>> quantized_outputs;
TestOneThreadPrediction(qcfg, inputs, &quantized_outputs, true);
LOG(INFO) << "--- comparing outputs --- ";
CompareTopAccuracy(quantized_outputs, analysis_outputs);
}
......@@ -578,9 +615,9 @@ static bool CompareTensorData(const framework::LoDTensor &a,
const framework::LoDTensor &b) {
auto a_shape = framework::vectorize(a.dims());
auto b_shape = framework::vectorize(b.dims());
size_t a_size = std::accumulate(a_shape.begin(), a_shape.end(), 1,
size_t a_size = std::accumulate(a_shape.begin(), a_shape.end(), size_t{1},
[](int a, int b) { return a * b; });
size_t b_size = std::accumulate(b_shape.begin(), b_shape.end(), 1,
size_t b_size = std::accumulate(b_shape.begin(), b_shape.end(), size_t{1},
[](int a, int b) { return a * b; });
if (a_size != b_size) {
LOG(ERROR) << string::Sprintf("tensor data size not match, %d != %d",
......
......@@ -74,7 +74,7 @@ void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) {
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
}
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> outputs;
if (use_analysis || use_tensorrt) {
AnalysisConfig config;
config.EnableUseGpu(100, 0);
......
......@@ -3,15 +3,11 @@ acos
asin
atan
attention_lstm
bilinear_tensor_product
brelu
conv_shift
cos
cos_sim
dequantize
elementwise_div
elementwise_max
elementwise_min
elu
fc
flatten
......@@ -29,8 +25,6 @@ gelu
gru
hard_shrink
hierarchical_sigmoid
hinge_loss
huber_loss
leaky_relu
log
logsigmoid
......@@ -43,7 +37,6 @@ max_pool3d_with_index
maxout
modified_huber_loss
nce
norm
pool2d
pool3d
pow
......@@ -59,17 +52,7 @@ requantize
reshape
rnn_memory_helper
round
row_conv
sequence_concat
sequence_conv
sequence_expand
sequence_expand_as
sequence_pad
sequence_scatter
sequence_slice
sequence_softmax
sequence_unpad
sigmoid_cross_entropy_with_logits
sin
softplus
softshrink
......@@ -84,7 +67,6 @@ stanh
swish
tanh_shrink
teacher_student_sigmoid_loss
temporal_shift
tensor_array_to_tensor
thresholded_relu
transpose
......
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/affine_grid_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
......@@ -173,9 +175,10 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto theta_dims = ctx->GetInputDim("Theta");
if (ctx->HasOutput(framework::GradVarName("Theta"))) {
ctx->SetOutputDim(framework::GradVarName("Theta"), theta_dims);
auto output_dims = ctx->GetInputDim(framework::GradVarName("Output"));
ctx->SetOutputDim(framework::GradVarName("Theta"),
{output_dims[0], 2, 3});
}
}
......
......@@ -74,5 +74,8 @@ class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
virtual void Apply() = 0;
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(BatchSizeLikeNoNeedBufferVarsInference,
"Input");
} // namespace operators
} // namespace paddle
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bilinear_tensor_product_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -121,15 +124,9 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
"The second dimension of input(Out@GRAD) must be equal to "
"the third dimension of the Input(Weight).");
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(
bias_dims[1], out_dims[1],
"The second dimension of input(Out@GRAD) must be equal to "
"the second dimension of the Input(Bias).");
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims);
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name)) {
ctx->SetOutputDim(bias_grad_name, {1, out_dims[1]});
}
auto x_grad_name = framework::GradVarName("X");
......@@ -148,13 +145,39 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
}
};
class BilinearTensorProductGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("bilinear_tensor_product_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput("Weight", Input("Weight"));
if (ForwardOp().Inputs().count("Bias") > 0) {
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
}
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_tensor_product, ops::BilinearTensorProductOp,
ops::BilinearTensorProductOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::BilinearTensorProductGradOpDescMaker);
REGISTER_OPERATOR(bilinear_tensor_product_grad,
ops::BilinearTensorProductOpGrad);
REGISTER_OP_CPU_KERNEL(
......
......@@ -13,10 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Div"; }
std::string GetEquation() const override { return "Out = X / Y"; }
};
class ElementwiseDivGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("elementwise_div_grad");
op->SetInput("Y", Input("Y"));
op->SetInput("Out", Output("Out"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp,
ops::ElementwiseDivOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseDivGradOpDescMaker);
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_div,
......
......@@ -47,7 +47,7 @@ struct DivGradDX {
template <typename T>
struct DivGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return -dout * x / (y * y);
return -dout * out / y;
}
};
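The rewritten DivGradDY reads Out instead of X. Assuming the forward definition Out = X / Y, the two expressions are algebraically identical, which is why the kernel below can pass dOut as a dummy X:
```latex
\frac{\partial L}{\partial Y}
= -\,d\mathrm{Out}\cdot\frac{X}{Y^{2}}
= -\,d\mathrm{Out}\cdot\frac{X/Y}{Y}
= -\,d\mathrm{Out}\cdot\frac{\mathrm{Out}}{Y}
```
Only Y, Out and dOut are read, so the X buffer no longer needs to be kept alive for the backward pass.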
......@@ -58,13 +58,15 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto* x = dout; // Fake x, not used
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
}
......
......@@ -13,9 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Max"; }
std::string GetEquation() const override { return "Out = max(X, Y)"; }
};
class ElementwiseMaxGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("elementwise_max_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_max, "Max", "Out = max(X, Y)");
REGISTER_OPERATOR(elementwise_max, ops::ElementwiseOp,
ops::ElementwiseMaxOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseMaxGradOpDescMaker);
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -63,10 +63,10 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* out = dout; // Fake out, not used
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
......
......@@ -13,9 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseMinOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Min"; }
std::string GetEquation() const override { return "Out = min(X, Y)"; }
};
class ElementwiseMinGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("elementwise_min_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_min, "Min", "Out = min(X, Y)");
REGISTER_OPERATOR(elementwise_min, ops::ElementwiseOp,
ops::ElementwiseMinOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseMinGradOpDescMaker);
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -62,10 +62,10 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* out = dout; // Fake out, not used
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MinGradDx<T>, MinGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
......
......@@ -173,12 +173,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
PADDLE_ENFORCE(ctx->HasInput(out_grad_name),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto x_dims = ctx->GetInputDim(out_grad_name);
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
......@@ -187,8 +187,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim("X", /*->*/ x_grad_name);
ctx->ShareLoD("X", /*->*/ x_grad_name);
ctx->ShareDim(out_grad_name, /*->*/ x_grad_name);
ctx->ShareLoD(out_grad_name, /*->*/ x_grad_name);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name);
......
......@@ -46,6 +46,7 @@ obtained from the `input` tensor.
)DOC");
}
};
} // namespace operators
} // namespace paddle
......@@ -53,7 +54,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOp,
paddle::framework::EmptyGradOpMaker,
ops::FillConstantBatchSizeLikeOpMaker);
ops::FillConstantBatchSizeLikeOpMaker,
ops::BatchSizeLikeNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
fill_constant_batch_size_like,
ops::FillConstantBatchSizeLikeOpKernel<paddle::platform::CPUDeviceContext,
......
......@@ -36,6 +36,7 @@ class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Out", "The variable will be filled up with zeros.");
ExtraMake();
AddComment(R"DOC(
FillZerosLike Operator.
......@@ -44,13 +45,49 @@ The output will have the same size as the input.
)DOC");
}
protected:
virtual void ExtraMake() {}
};
class FillZerosLikeOp2 : public FillZerosLikeOp {
public:
using FillZerosLikeOp::FillZerosLikeOp;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class FillZerosLikeOp2Maker : public FillZerosLikeOpMaker {
protected:
void ExtraMake() override {
this->AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::proto::VarType::FP32);
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(FillZerosLikeOp2NoNeedBufferVarsInference,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
ops::FillZerosLikeOpMaker);
REGISTER_OPERATOR(fill_zeros_like2, ops::FillZerosLikeOp2,
ops::FillZerosLikeOp2Maker,
ops::FillZerosLikeOp2NoNeedBufferVarsInference,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
......@@ -58,3 +95,11 @@ REGISTER_OP_CPU_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like2,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
......@@ -26,3 +26,13 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
fill_zeros_like2,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
......@@ -65,17 +65,13 @@ by input arguments.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
GaussianRandomBatchSizeLikeNoNeedBufferVarsInference, "Input");
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(
gaussian_random_batch_size_like,
paddle::operators::GaussianRandomBatchSizeLikeOp,
paddle::operators::GaussianRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker,
paddle::operators::GaussianRandomBatchSizeLikeNoNeedBufferVarsInference);
REGISTER_OPERATOR(gaussian_random_batch_size_like,
paddle::operators::GaussianRandomBatchSizeLikeOp,
paddle::operators::GaussianRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker,
paddle::operators::BatchSizeLikeNoNeedBufferVarsInference);
// Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace operators {
......@@ -107,8 +108,6 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
// check input
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"),
"Input(Variance) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
......@@ -159,7 +158,6 @@ class GroupNormGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput("Bias", Input("Bias"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetInput("Y", Output("Y"));
op->SetInput("Mean", Output("Mean"));
op->SetInput("Variance", Output("Variance"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
......
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/hinge_loss_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -97,12 +100,29 @@ class HingeLossGradOp : public framework::OperatorWithKernel {
}
};
class HingeLossGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("hinge_loss_grad");
op->SetInput("Logits", Input("Logits"));
op->SetInput("Labels", Input("Labels"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(hinge_loss, ops::HingeLossOp, ops::HingeLossOpMaker<float>,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::HingeLossGradOpDescMaker);
REGISTER_OPERATOR(hinge_loss_grad, ops::HingeLossGradOp);
REGISTER_OP_CPU_KERNEL(
hinge_loss,
......
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/huber_loss_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -90,38 +93,45 @@ class HuberLossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Residual"),
"Input(Residual) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto residual_dims = ctx->GetInputDim("Residual");
auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(residual_dims, x_dims);
PADDLE_ENFORCE_EQ(out_grad_dims, x_dims);
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
ctx->SetOutputDim(x_grad_name, residual_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
ctx->SetOutputDim(y_grad_name, residual_dims);
}
}
};
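Residual is saved by the forward pass with the same shape as X and Y (assuming the usual definition Residual = Y - X), so both gradient shapes can be taken from residual_dims and the grad op desc maker below only forwards Residual and Out@GRAD rather than the original X and Y buffers:
```latex
\mathrm{Residual} = Y - X
\;\Rightarrow\;
\mathrm{shape}(\nabla X) = \mathrm{shape}(\nabla Y) = \mathrm{shape}(\mathrm{Residual})
```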
class HuberLossGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("huber_loss_grad");
op->SetInput("Residual", Output("Residual"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker<float>,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::HuberLossGradOpDescMaker);
REGISTER_OPERATOR(huber_loss_grad, ops::HuberLossGradOp);
REGISTER_OP_CPU_KERNEL(
huber_loss, ops::HuberLossKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/norm_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -74,6 +78,24 @@ class NormOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class NormOpGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("norm_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("Norm", Output("Norm"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return op;
}
};
} // namespace operators
} // namespace paddle
......@@ -81,7 +103,7 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::NormOpGradOpDescMaker);
REGISTER_OPERATOR(norm_grad, ops::NormOpGrad);
REGISTER_OP_CPU_KERNEL(norm, ops::NormKernel<CPU, float>,
ops::NormKernel<CPU, double>);
......
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -612,8 +615,9 @@ class Pad2dOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
......@@ -625,7 +629,9 @@ class Pad2dOpGradMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* bind = new framework::OpDesc();
bind->SetInput("X", Input("X"));
bind->SetInput("Paddings", Input("Paddings"));
if (ForwardOp().Inputs().count("Paddings") > 0) {
bind->SetInput("Paddings", Input("Paddings"));
}
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind->SetAttrMap(Attrs());
......@@ -634,6 +640,10 @@ class Pad2dOpGradMaker : public framework::SingleGradOpDescMaker {
}
};
// TODO(zjl): Paddings can also be skipped!
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(Pad2dOpGradNoNeedBufferVarsInference,
"X");
} // namespace operators
} // namespace paddle
......@@ -641,6 +651,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker,
ops::Pad2dOpGradMaker);
REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad);
REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad,
ops::Pad2dOpGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>);
REGISTER_OP_CPU_KERNEL(pad2d_grad, ops::Pad2dGradCPUKernel<float>);
......@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/row_conv_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
namespace paddle {
......@@ -54,7 +58,6 @@ class RowConvGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
......@@ -62,8 +65,8 @@ class RowConvGradOp : public framework::OperatorWithKernel {
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(x_grad_name, x_dims);
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(x_grad_name, dout_dims);
}
auto filter_grad_name = framework::GradVarName("Filter");
......@@ -259,12 +262,31 @@ class RowConvGradKernel<platform::CPUDeviceContext, T>
}
}
};
class RowConvGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("row_conv_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput("Filter", Input("Filter"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(row_conv, ops::RowConvOp, ops::RowConvOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::RowConvGradOpDescMaker);
REGISTER_OPERATOR(row_conv_grad, ops::RowConvGradOp);
REGISTER_OP_CPU_KERNEL(
row_conv, ops::RowConvKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/operators/sequence_ops/sequence_concat_op.h"
#include <memory>
#include <vector>
namespace paddle {
......@@ -73,13 +74,43 @@ class SeqConcatShapeInferer : public framework::InferShapeBase {
}
};
class SeqConcatGradShapeInferer : public framework::InferShapeBase {
class SeqConcatGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
void operator()(framework::InferShapeContext *context) const override {
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_concat_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
op->SetAttrMap(Attrs());
return op;
}
};
class SeqConcatGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
context->SetOutputsDim(framework::GradVarName("X"),
context->GetInputsDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SeqConcatGradNoNeedBufferVarsInference,
"X");
} // namespace operators
} // namespace paddle
......@@ -87,14 +118,14 @@ namespace op = paddle::operators;
REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
op::SeqConcatOpMaker, op::SeqConcatShapeInferer,
paddle::framework::DefaultGradOpDescMaker<false>);
op::SeqConcatGradOpDescMaker);
template <typename T>
using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>,
Kernel<int64_t>);
REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel,
op::SeqConcatGradShapeInferer);
REGISTER_OPERATOR(sequence_concat_grad, op::SeqConcatGradOp,
op::SeqConcatGradNoNeedBufferVarsInference);
template <typename T>
using GradKernel =
op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
......
......@@ -14,7 +14,9 @@
#pragma once
#include <utility>
#include <vector>
#include "boost/optional.hpp"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
......@@ -89,37 +91,49 @@ class SeqConcatGradKernel : public framework::OpKernel<T> {
dxs[i]->mutable_data<T>(context.GetPlace());
}
}
std::vector<framework::Tensor> sliced_x;
std::vector<boost::variant<boost::blank, framework::Tensor>> sliced_dx;
std::vector<boost::optional<framework::Tensor>> sliced_dx;
for (size_t i = 1; i < xs[0]->lod()[0].size(); ++i) {
for (size_t j = 0; j < xs.size(); ++j) {
const framework::LoDTensor *x = xs[j];
framework::DDim x_dims = x->dims();
framework::LoDTensor *dx = dxs[j];
auto &x_lod = x->lod()[0];
sliced_x.emplace_back(x->Slice(x_lod[i - 1], x_lod[i]));
if (dx != nullptr) {
sliced_dx.emplace_back(dx->Slice(x_lod[i - 1], x_lod[i]));
auto prev_lod = x_lod[i - 1];
auto next_lod = x_lod[i];
x_dims[0] = next_lod - prev_lod;
sliced_x.emplace_back();
sliced_x.back().Resize(x_dims);
if (dx) {
sliced_dx.emplace_back(dx->Slice(prev_lod, next_lod));
} else {
sliced_dx.emplace_back(boost::blank());
sliced_dx.emplace_back(boost::none);
}
}
}
math::SplitFunctor<DeviceContext, T> functor;
std::vector<const framework::Tensor *> sliced_x_ptr;
std::vector<framework::Tensor *> sliced_dx_ptr;
sliced_x_ptr.reserve(sliced_x.size());
for (auto &x : sliced_x) {
sliced_x_ptr.emplace_back(&x);
}
std::vector<framework::Tensor *> sliced_dx_ptr;
sliced_dx_ptr.reserve(sliced_dx.size());
for (auto &dx : sliced_dx) {
try {
sliced_dx_ptr.emplace_back(&boost::get<framework::Tensor>(dx));
} catch (boost::bad_get &) {
sliced_dx_ptr.emplace_back(nullptr);
if (dx) {
sliced_dx_ptr.emplace_back(&dx.get());
}
}
math::SplitFunctor<DeviceContext, T> functor;
functor(context.template device_context<DeviceContext>(),
detail::Ref(
context.Input<framework::Tensor>(framework::GradVarName("Out")),
......
......@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_conv_op.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
namespace paddle {
namespace operators {
......@@ -171,13 +174,57 @@ context_length, context_stride and context_start.
}
};
class SequenceConvGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_conv_grad");
op->SetAttrMap(Attrs());
if (boost::get<bool>(Attrs().at("paddingTrainable")) &&
ForwardOp().Inputs().count("PaddingData") > 0) {
op->SetInput("PaddingData", Input("PaddingData"));
op->SetOutput(framework::GradVarName("PaddingData"),
InputGrad("PaddingData"));
}
op->SetInput("X", Input("X"));
op->SetInput("Filter", Input("Filter"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
return op;
}
};
class SequenceConvGradNoNeedBufferVarsInference
: public framework::NoNeedBufferVarsInference {
public:
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
std::unordered_set<std::string> operator()() const override {
if (!boost::get<bool>(Attrs().at("paddingTrainable"))) {
return {"PaddingData"};
} else {
return {};
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_conv_grad, ops::SequenceConvGradOp);
ops::SequenceConvGradOpDescMaker);
REGISTER_OPERATOR(sequence_conv_grad, ops::SequenceConvGradOp,
ops::SequenceConvGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
sequence_conv,
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
......@@ -70,6 +72,12 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("Y", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
class SequenceExpandAsOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -131,7 +139,6 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
......@@ -143,16 +150,48 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
ctx->ShareLoD("X", x_grad_name);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
class SequenceExpandAsOpGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_expand_as_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceExpandAsOpNoNeedBufferVarsInference, "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceExpandAsGradOpNoNeedBufferVarsInference, "X", "Y");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_expand_as, ops::SequenceExpandAsOp,
ops::SequenceExpandAsOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad);
ops::SequenceExpandAsOpGradOpDescMaker,
ops::SequenceExpandAsOpNoNeedBufferVarsInference);
REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad,
ops::SequenceExpandAsGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
sequence_expand_as,
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -96,6 +97,12 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -188,7 +195,6 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
......@@ -199,16 +205,47 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
class SequenceExpandOpGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_expand_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SequenceExpandOpNoNeedBufferVarsInference,
"Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceExpandGradOpNoNeedBufferVarsInference, "X", "Y");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_expand, ops::SequenceExpandOp,
ops::SequenceExpandOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad);
ops::SequenceExpandOpGradDescMaker,
ops::SequenceExpandOpNoNeedBufferVarsInference);
REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad,
ops::SequenceExpandGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_pad_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
......@@ -194,18 +196,39 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class SequencePadGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_pad_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequencePadGradOpNoNeedBufferVarsInference, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp);
ops::SequencePadGradOpDescMaker);
REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp,
ops::SequencePadGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
sequence_pad,
ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"
#include <memory>
#include <string>
namespace paddle {
......@@ -114,8 +115,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
......@@ -138,13 +140,17 @@ class SequencePoolGradOpMaker : public framework::SingleGradOpDescMaker {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequencePoolGradOpNoNeedBufferVarsInference, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
ops::SequencePoolGradOpMaker);
REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp);
REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp,
ops::SequencePoolGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
sequence_pool,
ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_scatter_op.h"
#include <memory>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
......@@ -124,25 +125,49 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
platform::CPUPlace());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
platform::CPUPlace());
}
};
class SequenceScatterGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_scatter_grad");
op->SetInput("Ids", Input("Ids"));
op->SetInput("Updates", Input("Updates"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Updates"), InputGrad("Updates"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceScatterGradNoNeedBufferVarsInference, "Updates");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_scatter, ops::SequenceScatterOp,
ops::SequenceScatterOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp);
ops::SequenceScatterGradDescMaker);
REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp,
ops::SequenceScatterGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(sequence_scatter, ops::SequenceScatterOpKernel<float>,
ops::SequenceScatterOpKernel<double>,
ops::SequenceScatterOpKernel<int>,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_slice_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -70,8 +71,9 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
......@@ -113,14 +115,35 @@ NOTE: The first dimension size of input, the size of offset and Length, should b
}
};
class SequenceSliceGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_slice_grad");
op->SetInput("X", Input("X"));
op->SetInput("Offset", Input("Offset"));
op->SetInput("Length", Input("Length"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceSliceGradNoNeedBufferVarsInference, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_slice, ops::SequenceSliceOp,
ops::SequenceSliceOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp);
ops::SequenceSliceOpMaker, ops::SequenceSliceGradOpDescMaker);
REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp,
ops::SequenceSliceGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
sequence_slice,
ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
......@@ -125,19 +127,39 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class SequenceUnpadGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_unpad_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceUnpadGradOpNoNeedBufferVarsInference, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp,
ops::SequenceUnpadOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp);
ops::SequenceUnpadOpMaker, ops::SequenceUnpadGradOpDescMaker);
REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp,
ops::SequenceUnpadGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
sequence_unpad,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -81,10 +81,9 @@ class SequenceUnpadGradOpKernel : public framework::OpKernel<T> {
auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (d_x) {
const auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
const auto* x_t = ctx.Input<LoDTensor>("X");
d_x->mutable_data<T>(ctx.GetPlace());
int padded_length = x_t->dims()[1];
int padded_length = d_x->dims()[1];
LoDTensor zero_pads;
zero_pads.Resize({1, 1});
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/fluid/operators/shuffle_channel_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
......@@ -73,12 +74,7 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@Grad) should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad) should not be null");
auto input_dims = ctx->GetInputDim("X");
auto input_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
......@@ -87,8 +83,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
......@@ -100,7 +97,6 @@ class ShuffleChannelGradDescMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("shuffle_channel_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
......
......@@ -78,10 +78,14 @@ template <typename DeviceContext, typename T>
class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
int group = ctx.Attr<int>("group");
auto input_dims = input->dims();
const auto& input_dims = input_grad->dims();
auto num = input_dims[0];
auto channel = input_dims[1];
auto height = input_dims[2];
......@@ -91,10 +95,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
int group_row = group;
int group_column = channel / group_row;
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
const T* output_grad_data = output_grad->data<T>();
......
......@@ -57,10 +57,14 @@ template <typename DeviceContext, typename T>
class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X");
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
int group = ctx.Attr<int>("group");
auto input_dims = input->dims();
const auto& input_dims = input_grad->dims();
auto num = input_dims[0];
auto channel = input_dims[1];
auto height = input_dims[2];
......@@ -71,10 +75,6 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
int group_row = group;
int group_column = channel / group_row;
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
const T* output_grad_data = output_grad->data<T>();
for (int n = 0; n < num; ++n) {
......
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -139,6 +142,24 @@ However the output only shares the LoD with input `X`.
}
};
class SigmoidCrossEntropyWithLogitsGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sigmoid_cross_entropy_with_logits_grad");
op->SetInput("X", Input("X"));
op->SetInput("Label", Input("Label"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
......@@ -146,7 +167,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsOp,
ops::SigmoidCrossEntropyWithLogitsOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker);
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsGradOp);
REGISTER_OP_CPU_KERNEL(
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/slice_op.h"
#include <algorithm>
#include <memory>
#include <vector>
namespace paddle {
......@@ -135,6 +136,13 @@ class SliceOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
......@@ -153,13 +161,17 @@ class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SliceOpGradNoNeedBufferVarsInference,
"Input");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
ops::SliceOpGradMaker);
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad);
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad,
ops::SliceOpGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
......
......@@ -10,6 +10,9 @@
limitations under the License. */
#include "paddle/fluid/operators/temporal_shift_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......@@ -125,19 +128,32 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
class TemporalShiftGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("temporal_shift_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
......@@ -146,8 +162,7 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
ops::TemporalShiftOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::TemporalShiftOpMaker, ops::TemporalShiftGradOpDescMaker);
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
ops::TemporalShiftKernel<double>);
......
......@@ -64,8 +64,9 @@ with random values sampled from a uniform distribution.
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
uniform_random_batch_size_like,
paddle::operators::UniformRandomBatchSizeLikeOp,
paddle::operators::UniformRandomBatchSizeLikeOpMaker);
REGISTER_OPERATOR(uniform_random_batch_size_like,
paddle::operators::UniformRandomBatchSizeLikeOp,
paddle::operators::UniformRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker,
paddle::operators::BatchSizeLikeNoNeedBufferVarsInference);
// Kernels are registered in uniform_random_op.cc and uniform_random_op.cu
......@@ -1299,7 +1299,20 @@ All parameter, weight, gradient are variables in Paddle.
to fuse relu and depthwise_conv2d,
it will save GPU memory and may make the execution faster.
This option is only available on GPU devices.
Default False)DOC")
Default False.)DOC")
.def_property(
"fuse_broadcast_ops",
[](const BuildStrategy &self) { return self.fuse_broadcast_ops_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finalized.");
self.fuse_broadcast_ops_ = b;
},
R"DOC(The type is BOOL, fuse_broadcast_ops indicates whether
to fuse the broadcast ops. Note that, in Reduce mode,
fusing broadcast ops may make the program faster, because
fusing them delays the execution of all broadcast ops so that,
for a period of time, all NCCL streams are used only for
NCCLReduce operations. Default False.)DOC")
.def_property("fuse_all_optimizer_ops",
[](const BuildStrategy &self) {
return self.fuse_all_optimizer_ops_;
......
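A minimal Python sketch of how the new flag would be used from user code (the toy network is illustrative; the BuildStrategy / CompiledProgram calls are the standard data-parallel workflow):
```python
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler

# A toy program, only to have something to compile.
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
loss = fluid.layers.reduce_mean(fluid.layers.fc(input=x, size=2))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
# New flag exposed above: fuse (i.e. delay) broadcast ops so that, for a
# period of time, NCCL streams are used only for NCCLReduce operations.
build_strategy.fuse_broadcast_ops = True

compiled_prog = compiler.CompiledProgram(
    fluid.default_main_program()).with_data_parallel(
        loss_name=loss.name, build_strategy=build_strategy)
```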
......@@ -231,9 +231,16 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
for idx, op_desc in enumerate(op_descs):
for arg in op_desc.input_arg_names():
if core.grad_var_suffix() in arg and arg in no_grad_set:
to_insert.append((_create_op_desc_("fill_zeros_like", {
"X": [_strip_grad_suffix_(arg)]
}, {"Out": [arg]}, {}), idx))
x_in = _strip_grad_suffix_(arg)
x_in_var_desc = op_desc.block().find_var_recursive(
cpt.to_bytes(x_in))
assert x_in_var_desc is not None, "Variable {} not found".format(
x_in)
dtype = x_in_var_desc.dtype()
to_insert.append(
(_create_op_desc_("fill_zeros_like2", {"X": [x_in]},
{"Out": [arg]}, {"dtype": dtype}), idx))
list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
......
This diff has been collapsed.
......@@ -227,7 +227,7 @@ class Precision(MetricBase):
metric.reset()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(preds=preds, labels=labels)
metric.update(preds=preds, labels=labels)
numpy_precision = metric.eval()
"""
......@@ -241,9 +241,11 @@ class Precision(MetricBase):
raise ValueError("The 'preds' must be a numpy ndarray.")
if not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray.")
sample_num = labels[0]
sample_num = labels.shape[0]
preds = np.rint(preds).astype("int32")
for i in range(sample_num):
pred = preds[i].astype("int32")
pred = preds[i]
label = labels[i]
if label == 1:
if pred == label:
......
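The two fixes above are easy to check with plain NumPy (values are illustrative): labels[0] is a label row rather than the batch size, and the old per-sample astype("int32") truncated probabilities (0.9 became 0) where np.rint rounds them to the nearest class:
```python
import numpy as np

preds = np.array([[0.3], [0.9], [0.6], [0.2]])
labels = np.array([[0], [1], [1], [1]])

# Old code: sample_num = labels[0] -> array([0]), a label row, not a count.
sample_num = labels.shape[0]            # 4
# Round once, outside the loop; astype alone would truncate 0.9 to 0.
preds = np.rint(preds).astype("int32")  # [[0], [1], [1], [0]]

tp = fp = 0
for i in range(sample_num):
    pred, label = preds[i], labels[i]
    if label == 1:
        if pred == label:
            tp += 1
        else:
            fp += 1
print(tp, fp)  # 2 1
```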
......@@ -81,6 +81,7 @@ list(REMOVE_ITEM TEST_OPS test_imperative_resnet)
list(REMOVE_ITEM TEST_OPS test_imperative_se_resnext)
list(REMOVE_ITEM TEST_OPS test_imperative_mnist)
list(REMOVE_ITEM TEST_OPS test_ir_memory_optimize_transformer)
list(REMOVE_ITEM TEST_OPS test_layers)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
......@@ -118,7 +119,7 @@ py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SE
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 450)
py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL)
py_test_modules(test_layers MODULES test_layers ENVS FLAGS_cudnn_deterministic=1)
if(NOT WIN32)
py_test_modules(test_ir_memory_optimize_transformer MODULES test_ir_memory_optimize_transformer SERIAL)
endif()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import importlib
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
from test_elementwise_add_op import *
from test_elementwise_sub_op import *
from test_concat_op import *
from test_gather_op import *
from test_gaussian_random_batch_size_like_op import *
from test_uniform_random_batch_size_like_op import *
from test_fill_constant_batch_size_like_op import *
from test_lod_reset_op import *
from test_scatter_op import *
from test_mean_op import *
from test_slice_op import *
from test_linear_chain_crf_op import *
from test_bilinear_interp_op import *
from test_nearest_interp_op import *
from test_sequence_concat import *
from test_seq_conv import *
from test_seq_pool import *
from test_sequence_expand_as import *
from test_sequence_expand import *
from test_sequence_pad_op import *
from test_sequence_unpad_op import *
from test_sequence_scatter_op import *
from test_sequence_slice_op import *
from test_pad2d_op import *
from test_fill_zeros_like2_op import *
if __name__ == '__main__':
unittest.main()
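This new aggregator simply re-runs existing op test suites with eager deletion switched on. An equivalent ad-hoc run of a single suite could look like the sketch below (the call and its values are copied verbatim from the aggregator; the chosen suite is just an example):
import unittest
import paddle.fluid as fluid

# Same eager-deletion switch and values as the aggregator above.
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)

from test_fill_zeros_like2_op import TestFillZerosLike2Op  # noqa: F401

if __name__ == '__main__':
    unittest.main()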
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from op_test import OpTest
class TestFillZerosLike2Op(OpTest):
def setUp(self):
self.op_type = "fill_zeros_like2"
self.dtype = np.float32
self.init_dtype()
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
self.outputs = {'Out': np.zeros_like(self.inputs["X"])}
self.attrs = {'dtype': convert_np_dtype_to_dtype_(self.dtype)}
def init_dtype(self):
pass
def test_check_output(self):
self.check_output()
class TestFillZerosLike2OpFp16(TestFillZerosLike2Op):
def init_dtype(self):
self.dtype = np.float16
class TestFillZerosLike2OpFp64(TestFillZerosLike2Op):
def init_dtype(self):
self.dtype = np.float64
if __name__ == "__main__":
unittest.main()
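For reference, a small check of the dtype conversion this test uses when attaching the 'dtype' attribute (assuming the fluid framework helpers imported above):
import numpy as np
from paddle.fluid import core
from paddle.fluid.framework import convert_np_dtype_to_dtype_

assert convert_np_dtype_to_dtype_(np.float16) == core.VarDesc.VarType.FP16
assert convert_np_dtype_to_dtype_(np.float32) == core.VarDesc.VarType.FP32
assert convert_np_dtype_to_dtype_(np.float64) == core.VarDesc.VarType.FP64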
......@@ -348,6 +348,55 @@ class TestImperative(unittest.TestCase):
self.assertEqual(mlp._fc2, sublayers[1])
self.assertEqual(len(sublayers), 2)
def test_dygraph_vs_static(self):
inp1 = np.random.rand(4, 3, 3)
inp2 = np.random.rand(4, 3, 3)
# dynamic graph
with fluid.dygraph.guard():
if np.sum(inp1) < np.sum(inp2):
x = fluid.layers.elementwise_add(inp1, inp2)
else:
x = fluid.layers.elementwise_sub(inp1, inp2)
dygraph_result = x._numpy()
# static graph
with new_program_scope():
inp_data1 = fluid.layers.data(
name='inp1', shape=[3, 3], dtype=np.float32)
inp_data2 = fluid.layers.data(
name='inp2', shape=[3, 3], dtype=np.float32)
a = fluid.layers.expand(
fluid.layers.reshape(
fluid.layers.reduce_sum(inp_data1), [1, 1]), [4, 1])
b = fluid.layers.expand(
fluid.layers.reshape(
fluid.layers.reduce_sum(inp_data2), [1, 1]), [4, 1])
cond = fluid.layers.less_than(x=a, y=b)
ie = fluid.layers.IfElse(cond)
with ie.true_block():
d1 = ie.input(inp_data1)
d2 = ie.input(inp_data2)
d3 = fluid.layers.elementwise_add(d1, d2)
ie.output(d3)
with ie.false_block():
d1 = ie.input(inp_data1)
d2 = ie.input(inp_data2)
d3 = fluid.layers.elementwise_sub(d1, d2)
ie.output(d3)
out = ie()
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
static_result = exe.run(fluid.default_main_program(),
feed={'inp1': inp1,
'inp2': inp2},
fetch_list=out)[0]
self.assertTrue(np.allclose(dygraph_result, static_result))
def test_rnn(self):
np_inp = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
[10.0, 11.0, 12.0]])
......
......@@ -595,6 +595,280 @@ class TestLayer(LayerTest):
self.assertTrue(np.allclose(static_rlt2, static_rlt))
self.assertTrue(np.allclose(nce_loss3.numpy(), static_rlt))
def test_conv3d(self):
with self.static_graph():
images = layers.data(
name='pixel', shape=[3, 6, 6, 6], dtype='float32')
ret = layers.conv3d(input=images, num_filters=3, filter_size=2)
static_ret = self.get_static_graph_result(
feed={'pixel': np.ones(
[2, 3, 6, 6, 6], dtype='float32')},
fetch_list=[ret])[0]
with self.static_graph():
images = layers.data(
name='pixel', shape=[3, 6, 6, 6], dtype='float32')
conv3d = nn.Conv3D('conv3d', num_filters=3, filter_size=2)
ret = conv3d(images)
static_ret2 = self.get_static_graph_result(
feed={'pixel': np.ones(
[2, 3, 6, 6, 6], dtype='float32')},
fetch_list=[ret])[0]
with self.dynamic_graph():
images = np.ones([2, 3, 6, 6, 6], dtype='float32')
conv3d = nn.Conv3D('conv3d', num_filters=3, filter_size=2)
dy_ret = conv3d(base.to_variable(images))
self.assertTrue(np.allclose(static_ret, dy_ret._numpy()))
self.assertTrue(np.allclose(static_ret, static_ret2))
def test_row_conv(self):
input = np.arange(15).reshape([3, 5]).astype('float32')
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
with self.static_graph():
x = layers.data(
name='X',
shape=[3, 5],
dtype='float32',
lod_level=1,
append_batch_size=False)
ret = layers.row_conv(input=x, future_context_size=2)
static_ret = self.get_static_graph_result(
feed={
'X': fluid.create_lod_tensor(
data=input, recursive_seq_lens=[[1, 1, 1]], place=place)
},
fetch_list=[ret],
with_lod=True)[0]
with self.static_graph():
x = layers.data(
name='X',
shape=[3, 5],
dtype='float32',
lod_level=1,
append_batch_size=False)
rowConv = nn.RowConv('RowConv', future_context_size=2)
ret = rowConv(x)
static_ret2 = self.get_static_graph_result(
feed={
'X': fluid.create_lod_tensor(
data=input, recursive_seq_lens=[[1, 1, 1]], place=place)
},
fetch_list=[ret],
with_lod=True)[0]
# TODO: dygraph can't support LODTensor
self.assertTrue(np.allclose(static_ret, static_ret2))
def test_group_norm(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
shape = (2, 4, 3, 3)
input = np.random.random(shape).astype('float32')
with self.static_graph():
X = fluid.layers.data(
name='X',
shape=shape,
dtype='float32',
lod_level=1,
append_batch_size=False)
ret = layers.group_norm(input=X, groups=2)
static_ret = self.get_static_graph_result(
feed={
'X': fluid.create_lod_tensor(
data=input, recursive_seq_lens=[[1, 1]], place=place)
},
fetch_list=[ret],
with_lod=True)[0]
with self.static_graph():
X = fluid.layers.data(
name='X',
shape=shape,
dtype='float32',
lod_level=1,
append_batch_size=False)
groupNorm = nn.GroupNorm('GroupNorm', groups=2)
ret = groupNorm(X)
static_ret2 = self.get_static_graph_result(
feed={
'X': fluid.create_lod_tensor(
data=input, recursive_seq_lens=[[1, 1]], place=place)
},
fetch_list=[ret],
with_lod=True)[0]
with self.dynamic_graph():
groupNorm = nn.GroupNorm('GroupNorm', groups=2)
dy_ret = groupNorm(base.to_variable(input))
self.assertTrue(np.allclose(static_ret, dy_ret._numpy()))
self.assertTrue(np.allclose(static_ret, static_ret2))
def test_spectral_norm(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
shape = (2, 4, 3, 3)
input = np.random.random(shape).astype('float32')
with self.static_graph():
Weight = fluid.layers.data(
name='Weight',
shape=shape,
dtype='float32',
lod_level=1,
append_batch_size=False)
ret = layers.spectral_norm(weight=Weight, dim=1, power_iters=2)
static_ret = self.get_static_graph_result(
feed={
'Weight': fluid.create_lod_tensor(
data=input, recursive_seq_lens=[[1, 1]], place=place),
},
fetch_list=[ret],
with_lod=True)[0]
with self.static_graph():
Weight = fluid.layers.data(
name='Weight',
shape=shape,
dtype='float32',
lod_level=1,
append_batch_size=False)
spectralNorm = nn.SpectralNorm('SpectralNorm', dim=1, power_iters=2)
ret = spectralNorm(Weight)
static_ret2 = self.get_static_graph_result(
feed={
'Weight': fluid.create_lod_tensor(
data=input, recursive_seq_lens=[[1, 1]], place=place)
},
fetch_list=[ret],
with_lod=True)[0]
with self.dynamic_graph():
spectralNorm = nn.SpectralNorm('SpectralNorm', dim=1, power_iters=2)
dy_ret = spectralNorm(base.to_variable(input))
self.assertTrue(np.allclose(static_ret, dy_ret._numpy()))
self.assertTrue(np.allclose(static_ret, static_ret2))
def test_tree_conv(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
adj_array = [1, 2, 1, 3, 1, 4, 1, 5, 2, 6, 2, 7, 2, 8, 4, 9, 4, 10]
adj = np.array(adj_array).reshape((1, 9, 2)).astype('int32')
adj = np.tile(adj, (1, 1, 1))
vectors = np.random.random((1, 10, 5)).astype('float32')
with self.static_graph():
NodesVector = fluid.layers.data(
name='NodesVector',
shape=(1, 10, 5),
dtype='float32',
lod_level=1,
append_batch_size=False)
EdgeSet = fluid.layers.data(
name='EdgeSet',
shape=(1, 9, 2),
dtype='int32',
lod_level=1,
append_batch_size=False)
ret = layers.tree_conv(
nodes_vector=NodesVector,
edge_set=EdgeSet,
output_size=6,
num_filters=1,
max_depth=2)
static_ret = self.get_static_graph_result(
feed={
'NodesVector': fluid.create_lod_tensor(
data=vectors, recursive_seq_lens=[[1]], place=place),
'EdgeSet': fluid.create_lod_tensor(
data=adj, recursive_seq_lens=[[1]], place=place)
},
fetch_list=[ret],
with_lod=False)[0]
with self.static_graph():
NodesVector = fluid.layers.data(
name='NodesVector',
shape=(1, 10, 5),
dtype='float32',
lod_level=1,
append_batch_size=False)
EdgeSet = fluid.layers.data(
name='EdgeSet',
shape=(1, 9, 2),
dtype='int32',
lod_level=1,
append_batch_size=False)
treeConv = nn.TreeConv(
'TreeConv', output_size=6, num_filters=1, max_depth=2)
ret = treeConv(NodesVector, EdgeSet)
static_ret2 = self.get_static_graph_result(
feed={
'NodesVector': fluid.create_lod_tensor(
data=vectors, recursive_seq_lens=[[1]], place=place),
'EdgeSet': fluid.create_lod_tensor(
data=adj, recursive_seq_lens=[[1]], place=place)
},
fetch_list=[ret],
with_lod=False)[0]
with self.dynamic_graph():
treeConv = nn.TreeConv(
'TreeConv', output_size=6, num_filters=1, max_depth=2)
dy_ret = treeConv(base.to_variable(vectors), base.to_variable(adj))
self.assertTrue(np.allclose(static_ret, static_ret2))
self.assertTrue(np.allclose(static_ret, dy_ret._numpy()))
def test_conv3d_transpose(self):
input_array = np.arange(0, 48).reshape(
[2, 3, 2, 2, 2]).astype('float32')
with self.static_graph():
img = layers.data(name='pixel', shape=[3, 2, 2, 2], dtype='float32')
out = layers.conv3d_transpose(
input=img, num_filters=12, filter_size=12, use_cudnn=False)
static_rlt = self.get_static_graph_result(
feed={'pixel': input_array}, fetch_list=[out])[0]
with self.static_graph():
img = layers.data(name='pixel', shape=[3, 2, 2, 2], dtype='float32')
conv3d_transpose = nn.Conv3DTranspose(
'Conv3DTranspose',
num_filters=12,
filter_size=12,
use_cudnn=False)
out = conv3d_transpose(img)
static_rlt2 = self.get_static_graph_result(
feed={'pixel': input_array}, fetch_list=[out])[0]
with self.dynamic_graph():
conv3d_transpose = nn.Conv3DTranspose(
'Conv3DTranspose',
num_filters=12,
filter_size=12,
use_cudnn=False)
dy_rlt = conv3d_transpose(base.to_variable(input_array))
self.assertTrue(np.allclose(static_rlt2, static_rlt))
self.assertTrue(np.allclose(dy_rlt._numpy(), static_rlt))
class TestBook(unittest.TestCase):
def test_fit_a_line(self):
......
......@@ -38,7 +38,15 @@ def Lenet(data, class_dim):
class TestFetchAndFeed(unittest.TestCase):
def parallel_exe(self, use_cuda, run_parallel_exe, seed=1):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def parallel_exe(self,
use_cuda,
run_parallel_exe,
use_experimental_executor=False,
seed=1):
main_program = fluid.Program()
startup = fluid.Program()
startup.random_seed = seed
......@@ -63,8 +71,12 @@ class TestFetchAndFeed(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = use_experimental_executor
train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
run_parallel_exe(train_cp, exe, use_cuda, data, label, loss)
......@@ -131,8 +143,7 @@ class TestFetchAndFeed(unittest.TestCase):
if batch_id == 2:
break
def test_fetch(self):
os.environ['CPU_NUM'] = str(4)
def test_fetch_with_threaded_executor(self):
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True,
......@@ -140,8 +151,18 @@ class TestFetchAndFeed(unittest.TestCase):
self.parallel_exe(
use_cuda=False, run_parallel_exe=self.run_parallel_exe_with_fetch)
def test_fetch_with_fast_threaded_executor(self):
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True,
run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True)
self.parallel_exe(
use_cuda=False,
run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True)
def test_feed(self):
os.environ['CPU_NUM'] = str(4)
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True, run_parallel_exe=self.run_parallel_exe_with_feed)
......
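A minimal sketch of the use_experimental_executor path exercised by the test above (the toy program and CPU_NUM value are illustrative assumptions):
import os
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import compiler

os.environ.setdefault('CPU_NUM', '2')

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    img = fluid.layers.data(name='img', shape=[784], dtype='float32')
    loss = fluid.layers.reduce_mean(fluid.layers.fc(input=img, size=10))

exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True  # fast threaded SSA graph executor

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
train_cp = compiler.CompiledProgram(main).with_data_parallel(
    loss_name=loss.name, exec_strategy=exec_strategy)
exe.run(train_cp,
        feed={'img': np.random.rand(4, 784).astype('float32')},
        fetch_list=[loss.name])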