提交 2336d5ca 编写于 作者: Z zhoukunsheng

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into rank

......@@ -134,6 +134,11 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
out_layout =
out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(
pool.Get(expected_kernel_type.place_));
auto& cpu_engine = dev_ctx->GetEngine();
std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
std::vector<int> out_tz = in_tz;
......@@ -142,25 +147,29 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format =
platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
// output tensor has the same dims as input. Reorder don't change dims
out->Resize(in.dims());
// tempory mem pd fr out , to make reorder
auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(out->dims()),
mkldnn::memory::format::blocked, out_type);
if (in.get_mkldnn_prim_desc() != out_mem_pd) {
if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type);
auto out_data = out->mutable_data(expected_kernel_type.place_, in.type());
auto in_memory = memory(in.get_mkldnn_prim_desc(), in_data);
auto out_memory = memory(out_mem_pd, out_data);
auto in_memory =
memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data);
auto out_memory =
memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data);
platform::Reorder(in_memory, out_memory);
} else {
out->ShareDataWith(in);
}
out->set_layout(out_layout);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out->set_format(memory::format::format_undef);
#endif
}
......
......@@ -51,31 +51,13 @@ void TransformData(const OpKernelType &expected_kernel_type,
#ifdef PADDLE_WITH_MKLDNN
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur
auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
ToMKLDNNFormat(lin));
out.ShareDataWith(input_tensor);
// TODO(jczaja): Remove that once all mkldnn ops
// are modified to work with mkldnn_blocked
auto mkldnn_fmt = [&](int rank) {
switch (rank) {
case 5:
return mkldnn::memory::format::ncdhw;
case 4:
return mkldnn::memory::format::nchw;
case 3:
return mkldnn::memory::format::ncw;
case 2:
return mkldnn::memory::format::nc;
case 1:
return mkldnn::memory::format::x;
default:
return mkldnn::memory::format::blocked;
}
};
auto out_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(out.dims()),
mkldnn_fmt(out.dims().size()));
out.set_mkldnn_prim_desc(out_mem_pd);
out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format);
#endif
} else {
// Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
......
......@@ -10,7 +10,10 @@ cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framewor
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
......@@ -104,5 +107,7 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass)
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass
alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass
fuse_adam_op_pass fuse_sgd_op_pass)
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint32(fuse_parameter_memory_size, 0, // 0 KB
"fuse_parameter_memory_size is up limited memory size "
"of one group parameters' gradient which is the input "
......@@ -105,20 +106,29 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
auto ele_dtype = iter->second->Var()->GetDataType();
if (dtype == kDefaultDtype) {
dtype = ele_dtype;
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype);
PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
"The data type should not be bool.");
}
PADDLE_ENFORCE_EQ(ele_dtype, dtype);
PADDLE_ENFORCE_EQ(ele_dtype, dtype,
"The data type of input is not consistent.");
}
// Create the fused variable name.
// Create a FusedVarsSet to avoid duplicating names for fused_var in other
// pass.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
}
const std::string prefix(kFusedVarNamePrefix);
// The fused_var_name should be unique.
auto fused_var_name = prefix + "GRAD@" + params_grads[0].second;
// the kFusedGrads is used be fuse_optimizer_op_pass.
result.Set(kFusedGrads, new FusedGrads);
// the fused_var_name should be unique, so it appends
// params_grads.begin()->second.
auto fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
params_grads.begin()->second;
result.Get<FusedGrads>(kFusedGrads) = fused_var_name;
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
"%s is duplicate in FusedVars.", fused_var_name);
fused_var_set.insert(fused_var_name);
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
......@@ -295,17 +305,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
return type == proto::VarType::LOD_TENSOR;
}
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
void RecordParamsAndGrads(ir::Node *node,
ParamsAndGrads *params_grads) const {
try {
......@@ -358,6 +357,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
}
// Alloc continuous space for vars.
std::vector<std::string> grads_name;
std::vector<std::string> params_name;
grads_name.reserve(params_grads.size());
......@@ -370,7 +370,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
program_desc.MutableBlock(0));
// Run Only Once Programs
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
......@@ -378,6 +377,17 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
}
}
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
const std::vector<std::string> &grads_name,
const std::string &fused_var_name,
BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
}
};
} // namespace details
......
......@@ -27,20 +27,17 @@ void BroadcastOpHandle::RunImpl() {
if (places_.size() == 1) return;
// The input and output may have dummy vars.
VarHandle *in_var_handle;
{
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1UL,
"The number of input should be one.");
in_var_handle = in_var_handles[0];
}
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
PADDLE_ENFORCE_EQ(in_var_handles.size(), 1UL,
"The number of input should be one.");
PADDLE_ENFORCE_EQ(
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");
VarHandle *in_var_handle = in_var_handles[0];
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
......@@ -82,23 +81,43 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("inplace_pass");
}
if (strategy.fuse_elewise_add_act_ops_) {
if (strategy_.fuse_elewise_add_act_ops_) {
VLOG(10) << "Add fuse_elewise_add_act_pass";
AppendPass("fuse_elewise_add_act_pass");
}
// for single card training, fuse_all_reduce_ops is unnecessary.
// alloc_continuous_space_for_grad_pass should be before of MultiDevPass.
if (strategy.fuse_all_reduce_ops_) {
if (strategy_.fuse_all_reduce_ops_) {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
}
if (strategy_.fuse_all_optimizer_ops_) {
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce ||
strategy_.is_distribution_) {
VLOG(3)
<< "Currently, fuse_all_optimizer_ops only works under AllReduce "
"mode.";
strategy_.fuse_all_optimizer_ops_ = false;
} else {
VLOG(10) << "Add alloc_continuous_space_for_grad_pass";
AppendPass("alloc_continuous_space_for_grad_pass");
// NOTE: fuse_all_xx_ops will count the number of xx operator first,
// if the number is zero, fuse_all_reduce_ops will do nothing.
// Currently, only one type of optimization algorithm can be fused.
VLOG(10) << "Add fuse_adam_op_pass";
AppendPass("fuse_adam_op_pass");
VLOG(10) << "Add fuse_sgd_op_pass";
AppendPass("fuse_sgd_op_pass");
}
}
// Add a graph viz pass to record a graph.
if (!strategy.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
const std::string graph_path = string::Sprintf(
"%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph");
"%s%s", strategy_.debug_graphviz_path_.c_str(), "_fused_graph");
viz_pass->Set<std::string>("graph_viz_path", new std::string(graph_path));
}
......@@ -118,14 +137,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// the de-fact IR, any reuse on Graph is meaningless.
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
if (strategy.memory_optimize_) {
if (strategy_.memory_optimize_) {
VLOG(10) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
AppendMultiDevPass(strategy);
AppendMultiDevPass(strategy_);
if (strategy.fuse_all_reduce_ops_) {
if (strategy_.fuse_all_reduce_ops_) {
// NOTE: fuse_all_reduce_ops will count the number of all_reduce operator
// first, if the number is zero, fuse_all_reduce_ops will do nothing.
VLOG(10) << "Add fuse_all_reduce_op_pass";
......@@ -151,7 +170,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("all_reduce_deps_pass");
}
if (SeqOnlyAllReduceOps(strategy)) {
if (SeqOnlyAllReduceOps(strategy_)) {
VLOG(10) << "Add all_reduce_deps_pass";
AppendPass("all_reduce_deps_pass");
}
......@@ -165,7 +184,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass = nullptr;
if (strategy_.is_distribution_) {
if (strategy.is_distribution_) {
VLOG(10) << "Add dist_multi_devices_pass";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
......@@ -235,17 +254,22 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
} else if (pass->Type() == "fuse_all_reduce_op_pass") {
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
pass->Type() == "fuse_adam_op_pass" ||
pass->Type() == "fuse_sgd_op_pass" ||
pass->Type() == "fuse_all_reduce_op_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
if (pass->Type() == "fuse_all_reduce_op_pass") {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
pass->Erase(kNCCLCtxs);
pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
#endif
}
} else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
......@@ -294,4 +318,6 @@ USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass);
USE_PASS(alloc_continuous_space_for_grad_pass);
USE_PASS(graph_to_program_pass);
USE_PASS(fuse_adam_op_pass);
USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
......@@ -18,7 +18,6 @@
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
......@@ -76,6 +75,8 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false};
bool fuse_all_optimizer_ops_{false};
bool fuse_all_reduce_ops_{false};
bool fuse_relu_depthwise_conv_{false};
......
......@@ -31,9 +31,10 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
local_scopes_(local_scopes),
places_(places),
graph_(graph),
fetch_ctxs_(places),
pool_(strategy.num_threads_),
prepare_pool_(1), // add one more thread for generate op_deps
fetch_ctxs_(places) {
// add one more thread for generate op_deps
prepare_pool_(1) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep);
......
......@@ -14,7 +14,9 @@
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
......@@ -37,6 +39,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
const ir::Graph &Graph() const override;
private:
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
// be destroyed first.
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
......@@ -45,21 +49,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unordered_map<OpHandleBase *, int> op_deps_;
std::vector<OpHandleBase *> bootstrap_ops_;
::ThreadPool pool_;
::ThreadPool prepare_pool_;
platform::DeviceContextPool fetch_ctxs_;
std::atomic<int> remaining_;
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
atomic_op_deps_;
ExceptionHolder exception_;
::ThreadPool pool_;
::ThreadPool prepare_pool_;
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps();
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
atomic_op_deps_;
ExceptionHolder exception_;
};
} // namespace details
} // namespace framework
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_adam_op_pass.h"
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
const std::string FuseAdamOpPass::GetOpType() const { return "adam"; }
const std::vector<std::string> FuseAdamOpPass::GetAuxiliaryVarNames() const {
return {"Param", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow"};
}
void FuseAdamOpPass::FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
FuseAdamOps(aux_var_set, fused_vars_name, adam_ops, graph);
FuseScaleOps(aux_var_set.at("Beta1Pow"), fused_vars_name.at("Beta1Pow"),
adam_ops, graph);
FuseScaleOps(aux_var_set.at("Beta2Pow"), fused_vars_name.at("Beta2Pow"),
adam_ops, graph);
}
void FuseAdamOpPass::FuseAdamOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(adam_ops.size(), static_cast<size_t>(0));
// Check attributions
// NOTE: If new attribution is added, the following code maybe need change.
int op_role = boost::get<int>(
adam_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
float beta1 = boost::get<float>(adam_ops[0]->Op()->GetAttr("beta1"));
float beta2 = boost::get<float>(adam_ops[0]->Op()->GetAttr("beta2"));
float epsilon = boost::get<float>(adam_ops[0]->Op()->GetAttr("epsilon"));
bool lazy_mode = boost::get<bool>(adam_ops[0]->Op()->GetAttr("lazy_mode"));
int64_t min_row_size_to_use_multithread = boost::get<int64_t>(
adam_ops[0]->Op()->GetAttr("min_row_size_to_use_multithread"));
for (auto &adam_op : adam_ops) {
PADDLE_ENFORCE_EQ(beta1,
boost::get<float>(adam_op->Op()->GetAttr("beta1")));
PADDLE_ENFORCE_EQ(beta2,
boost::get<float>(adam_op->Op()->GetAttr("beta2")));
PADDLE_ENFORCE_EQ(epsilon,
boost::get<float>(adam_op->Op()->GetAttr("epsilon")));
PADDLE_ENFORCE_EQ(lazy_mode,
boost::get<bool>(adam_op->Op()->GetAttr("lazy_mode")));
PADDLE_ENFORCE_EQ(min_row_size_to_use_multithread,
boost::get<int64_t>(adam_op->Op()->GetAttr(
"min_row_size_to_use_multithread")));
PADDLE_ENFORCE_EQ(op_role, boost::get<int>(adam_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())));
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// node.
VLOG(10) << "Insert adam to graph ";
OpDesc adam_desc(adam_ops[0]->Op()->Block());
adam_desc.SetType("adam");
adam_desc.SetInput("Param", {fused_vars_name.at("Param")});
adam_desc.SetInput("Grad", {fused_vars_name.at("Grad")});
adam_desc.SetInput("Moment1", {fused_vars_name.at("Moment1")});
adam_desc.SetInput("Moment2", {fused_vars_name.at("Moment2")});
// TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow should be equal.
adam_desc.SetInput("LearningRate", adam_ops[0]->Op()->Input("LearningRate"));
adam_desc.SetInput("Beta1Pow", adam_ops[0]->Op()->Input("Beta1Pow"));
adam_desc.SetInput("Beta2Pow", adam_ops[0]->Op()->Input("Beta2Pow"));
adam_desc.SetOutput("ParamOut", {fused_vars_name.at("Param")});
adam_desc.SetOutput("Moment1Out", {fused_vars_name.at("Moment1")});
adam_desc.SetOutput("Moment2Out", {fused_vars_name.at("Moment2")});
adam_desc.SetAttr("beta1", beta1);
adam_desc.SetAttr("beta2", beta2);
adam_desc.SetAttr("epsilon", epsilon);
adam_desc.SetAttr("lazy_mode", lazy_mode);
adam_desc.SetAttr("min_row_size_to_use_multithread",
min_row_size_to_use_multithread);
adam_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto adam_node = graph->CreateOpNode(&adam_desc);
InserInputAndOutputForOptOps(adam_ops, adam_node);
}
void FuseAdamOpPass::FuseScaleOps(const std::vector<std::string> &beta_name,
const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const {
PADDLE_ENFORCE_EQ(beta_name.size(), adam_ops.size());
const std::string scale_op_name = "scale";
// Get the scale_ops of dealing the adam's beta var.
std::vector<ir::Node *> scale_ops;
scale_ops.reserve(beta_name.size());
for (size_t i = 0; i < adam_ops.size(); ++i) {
auto &beta_1_pow_name = beta_name[i];
auto beta_pow_iter = std::find_if(
adam_ops[i]->inputs.begin(), adam_ops[i]->inputs.end(),
[&beta_name, &beta_1_pow_name](ir::Node *var_node) -> bool {
return var_node->Var() && var_node->Var()->Name() == beta_1_pow_name;
});
PADDLE_ENFORCE(beta_pow_iter != adam_ops[i]->inputs.end());
auto beta_pow_node = *beta_pow_iter;
auto scale_op_iter = std::find_if(
beta_pow_node->outputs.begin(), beta_pow_node->outputs.end(),
[&scale_op_name](ir::Node *op_node) -> bool {
return op_node->Op() && op_node->Op()->Type() == scale_op_name;
});
PADDLE_ENFORCE(scale_op_iter != beta_pow_node->outputs.end());
scale_ops.emplace_back(*scale_op_iter);
}
PADDLE_ENFORCE_EQ(scale_ops.size(), beta_name.size());
// Check attributions
// NOTE: If new attribution is added, the following code maybe need change.
int op_role = boost::get<int>(
scale_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
float scale = boost::get<float>(scale_ops[0]->Op()->GetAttr("scale"));
float bias = boost::get<float>(scale_ops[0]->Op()->GetAttr("bias"));
bool bias_after_scale =
boost::get<bool>(scale_ops[0]->Op()->GetAttr("bias_after_scale"));
for (auto &scale_op : scale_ops) {
PADDLE_ENFORCE_EQ(scale,
boost::get<float>(scale_op->Op()->GetAttr("scale")));
PADDLE_ENFORCE_EQ(bias, boost::get<float>(scale_op->Op()->GetAttr("bias")));
PADDLE_ENFORCE_EQ(
bias_after_scale,
boost::get<bool>(scale_op->Op()->GetAttr("bias_after_scale")));
PADDLE_ENFORCE_EQ(op_role, boost::get<int>(scale_op->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())));
}
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// node.
VLOG(10) << "Insert fused scale to graph.";
OpDesc scale_desc(scale_ops[0]->Op()->Block());
scale_desc.SetType("scale");
scale_desc.SetInput("X", {fused_var_name});
scale_desc.SetOutput("Out", {fused_var_name});
scale_desc.SetAttr("scale", scale);
scale_desc.SetAttr("bias", bias);
scale_desc.SetAttr("bias_after_scale", bias_after_scale);
scale_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto scale_node = graph->CreateOpNode(&scale_desc);
for (auto scale_op : scale_ops) {
// set inputs
scale_node->inputs.insert(scale_node->inputs.begin(),
scale_op->inputs.begin(), scale_op->inputs.end());
for (auto &input : scale_op->inputs) {
std::replace(input->outputs.begin(), input->outputs.end(), scale_op,
scale_node);
}
// set outputs
scale_node->outputs.insert(scale_node->outputs.begin(),
scale_op->outputs.begin(),
scale_op->outputs.end());
for (auto &output : scale_op->outputs) {
std::replace(output->inputs.begin(), output->inputs.end(), scale_op,
scale_node);
}
}
// Delete scale_ops
for (auto &scale_op : scale_ops) {
graph->RemoveNode(scale_op);
}
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_adam_op_pass, paddle::framework::details::FuseAdamOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class FuseAdamOpPass : public FuseOptimizerOpPass {
private:
virtual const std::string GetOpType() const;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const;
// Fuse Adam Ops and Scale Ops which are used to update "Beta1Pow", "Beta2Pow"
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const;
void FuseAdamOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const;
void FuseScaleOps(const std::vector<std::string> &aux_var_set,
const std::string &fused_var_name,
const std::vector<ir::Node *> &adam_ops,
ir::Graph *graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
ir::Graph &result = *graph;
auto &places = Get<const std::vector<platform::Place>>(kPlaces);
auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
const std::string fuse_op_type = GetOpType();
const std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();
// Step 1: Get the specified op and auxiliary variables.
std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
std::unordered_map<std::string, std::vector<std::string>> aux_var_set;
std::vector<ir::Node *> opt_ops;
for (auto &node : topo_nodes) {
GetSpecifiedOpsAndVars(fuse_op_type, aux_var_names, node, &opt_ops,
&aux_var_set);
}
VLOG(10) << "Find " << fuse_op_type << " operators: " << opt_ops.size();
if (opt_ops.size() == 0) {
return;
}
if (result.Has(kFusedOptType)) {
VLOG(10)
<< "Currently only support fusing one type optimizer op. Has fused "
<< result.Get<FusedOptType>(kFusedOptType);
return;
} else {
result.Set(kFusedOptType, new FusedOptType);
}
result.Get<FusedOptType>(kFusedOptType) = fuse_op_type;
// Step 2: Insert fused_var_name to FusedVars, and the FusedVars need be
// initialized in scopes before execution.
if (!result.Has(kFusedVars)) {
result.Set(kFusedVars, new FusedVars);
}
std::unordered_map<std::string, std::string> fused_vars_name;
fused_vars_name.reserve(aux_var_names.size() + 1);
auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
const std::string prefix(kFusedVarNamePrefix);
// NOTE: the fused_var_name should be unique.
for (auto &var_name : aux_var_names) {
auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
aux_var_set[var_name][0];
VLOG(10) << fused_var_name;
fused_vars_name.emplace(var_name, fused_var_name);
PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0);
fused_var_set.insert(fused_var_name);
}
// Step 3: Get the fused Gradient's name
auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
if (!result.Has(kFusedGrads)) {
PADDLE_THROW(
"The alloc_continuous_space_for_grad_pass should be called before this "
"pass.");
}
auto &fused_grad = result.Get<FusedGrads>(kFusedGrads);
auto &fused_vars = result.Get<FusedVars>(kFusedVars);
auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
fused_vars_name.emplace("Grad", fused_grad);
// Step 4: Sort the parameters and auxiliary variables according
// to parameters' name to make variables' name correspond correctly.
PADDLE_ENFORCE(result.Has(kParamsAndGrads), "Does't find kParamsAndGrads.");
PADDLE_ENFORCE_EQ(params_grads.size(), aux_var_set.begin()->second.size(),
"The size of params_grads and aux_var_set are not equal.");
SortParametersAndAuxVars(params_grads, &aux_var_set, &opt_ops);
// Step 5: Alloc continuous space for Parameters and AuxiliaryVar(e.g.
// Moment1, Moment2, Beta1Pow, Beta2Pow) of all the optimizer ops separately.
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, aux_var_names,
aux_var_set, fused_vars_name);
// Step 6: Fuse optimizer Ops and Scale Ops
FuseOptimizerOps(aux_var_set, fused_vars_name, opt_ops, &result);
// Step 7: Remove optimizer Ops
for (auto &opt_op : opt_ops) {
graph->RemoveNode(opt_op);
}
}
void FuseOptimizerOpPass::InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::vector<std::string> &aux_var_names,
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name) const {
VLOG(10) << "Init FusedVars.";
// Alloc parameters and auxiliary vars in the respective scope.
size_t idx = local_scopes.size();
for (auto iter = local_scopes.rbegin(); iter != local_scopes.rend();
++iter, --idx) {
auto &scope = *iter;
for (auto &var_name : aux_var_names) {
auto fused_var_name = fused_vars_name.at(var_name);
VLOG(10) << "Init " << fused_var_name;
PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
"%s has exist in scope[%d]", fused_var_name, idx);
scope->Var(fused_var_name)->GetMutable<LoDTensor>();
}
}
ProgramDesc program_desc;
auto *global_block = program_desc.MutableBlock(0);
for (auto &var_name : aux_var_names) {
AppendAllocContinuousSpace(aux_var_set.at(var_name),
fused_vars_name.at(var_name), true,
global_block);
}
for (size_t i = 0; i < local_scopes.size(); ++i) {
for (auto &op_desc : global_block->AllOps()) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_scopes[i], places[i]);
}
}
}
void FuseOptimizerOpPass::SortParametersAndAuxVars(
const std::vector<std::pair<std::string, std::string>> &params_grads,
std::unordered_map<std::string, std::vector<std::string>> *aux_vars_set,
std::vector<ir::Node *> *ops) const {
PADDLE_ENFORCE_NE(aux_vars_set->count("Param"), static_cast<size_t>(0));
auto &param_vec = aux_vars_set->at("Param");
std::vector<size_t> param_sort_idx;
param_sort_idx.reserve(param_vec.size());
for (auto &p_g : params_grads) {
auto iter = std::find(param_vec.begin(), param_vec.end(), p_g.first);
PADDLE_ENFORCE(iter != param_vec.end());
auto idx = std::distance(param_vec.begin(), iter);
param_sort_idx.emplace_back(idx);
}
for (auto &aux_vars : *aux_vars_set) {
std::vector<std::string> sorted_vars;
sorted_vars.reserve(aux_vars.second.size());
for (size_t i = 0; i < aux_vars.second.size(); ++i) {
sorted_vars.emplace_back(aux_vars.second.at(param_sort_idx[i]));
}
std::swap(aux_vars.second, sorted_vars);
std::stringstream out;
for (auto &var_name : aux_vars.second) {
out << var_name << " ";
}
VLOG(10) << aux_vars.first << ": " << out.str();
}
std::vector<ir::Node *> sorted_ops;
sorted_ops.reserve(ops->size());
for (size_t i = 0; i < ops->size(); ++i) {
sorted_ops.emplace_back(ops->at(param_sort_idx[i]));
}
std::swap(*ops, sorted_ops);
}
void FuseOptimizerOpPass::GetSpecifiedOpsAndVars(
const std::string &op_type, const std::vector<std::string> &aux_vars_name,
ir::Node *node, std::vector<ir::Node *> *ops,
std::unordered_map<std::string, std::vector<std::string>> *aux_args_name)
const {
if (node->Op()->Type() != op_type) return;
for (auto &var_n : aux_vars_name) {
auto arg_names = node->Op()->Input(var_n);
PADDLE_ENFORCE_EQ(arg_names.size(), static_cast<size_t>(1));
(*aux_args_name)[var_n].emplace_back(arg_names[0]);
VLOG(10) << var_n << ", " << arg_names[0];
}
ops->emplace_back(node);
}
void FuseOptimizerOpPass::AppendAllocContinuousSpace(
const std::vector<std::string> &args, const std::string &out_arg,
bool copy_data, BlockDesc *global_block) const {
auto op_desc = global_block->AppendOp();
op_desc->SetType("alloc_continuous_space");
op_desc->SetInput("Input", args);
op_desc->SetOutput("Output", args);
op_desc->SetOutput("FusedOutput", {out_arg});
op_desc->SetAttr("copy_data", copy_data);
op_desc->SetAttr("check_name", true);
}
void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
const std::vector<ir::Node *> &opt_ops, ir::Node *opt_node) const {
std::unordered_set<ir::Node *> inputs;
std::unordered_set<ir::Node *> outputs;
for (auto opt_op : opt_ops) {
// set inputs
inputs.insert(opt_op->inputs.begin(), opt_op->inputs.end());
for (auto &input : opt_op->inputs) {
replace(input->outputs.begin(), input->outputs.end(), opt_op, opt_node);
}
// set outputs
outputs.insert(opt_op->outputs.begin(), opt_op->outputs.end());
for (auto &output : opt_op->outputs) {
replace(output->inputs.begin(), output->inputs.end(), opt_op, opt_node);
}
}
opt_node->inputs.insert(opt_node->inputs.begin(), inputs.begin(),
inputs.end());
opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
outputs.end());
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class FuseOptimizerOpPass : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
protected:
virtual void SortParametersAndAuxVars(
const std::vector<std::pair<std::string, std::string>> &params_grads,
std::unordered_map<std::string, std::vector<std::string>> *aux_var_set,
std::vector<ir::Node *> *ops) const;
void InserInputAndOutputForOptOps(const std::vector<ir::Node *> &opt_ops,
ir::Node *opt_node) const;
private:
virtual const std::string GetOpType() const = 0;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const = 0;
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &adam_ops, ir::Graph *graph) const = 0;
void GetSpecifiedOpsAndVars(
const std::string &op_type, const std::vector<std::string> &aux_vars_name,
ir::Node *node, std::vector<ir::Node *> *ops,
std::unordered_map<std::string, std::vector<std::string>> *aux_args_name)
const;
void AppendAllocContinuousSpace(const std::vector<std::string> &args,
const std::string &out_arg, bool copy_data,
BlockDesc *global_block) const;
void InitFusedVarsAndAllocSpaceForVars(
const std::vector<platform::Place> &places,
const std::vector<Scope *> &local_scopes,
const std::vector<std::string> &aux_var_names,
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name)
const;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fuse_sgd_op_pass.h"
#include <algorithm>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace details {
const std::string FuseSgdOpPass::GetOpType() const { return "sgd"; }
const std::vector<std::string> FuseSgdOpPass::GetAuxiliaryVarNames() const {
return {"Param"};
}
void FuseSgdOpPass::FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>>
&aux_var_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
FuseSgdOps(aux_var_set, fused_vars_name, sgd_ops, graph);
}
void FuseSgdOpPass::FuseSgdOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const {
PADDLE_ENFORCE_GT(sgd_ops.size(), static_cast<size_t>(0));
// NOTE: fused_var is only exist in scope, so the graph doesn't have fused_var
// node.
int op_role = boost::get<int>(
sgd_ops[0]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
VLOG(10) << "Insert sgd to graph ";
// Add fused scale
OpDesc Sgd_desc(sgd_ops[0]->Op()->Block());
Sgd_desc.SetType("sgd");
Sgd_desc.SetInput("Param", {fused_vars_name.at("Param")});
Sgd_desc.SetInput("Grad", {fused_vars_name.at("Grad")});
Sgd_desc.SetOutput("ParamOut", {fused_vars_name.at("Param")});
// TODO(zcd): The LearningRate, Beta1Pow, Beta2Pow should be equal.
Sgd_desc.SetInput("LearningRate", sgd_ops[0]->Op()->Input("LearningRate"));
// NOTE: multi_devices_pass requires that every op should have a role.
Sgd_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op_role);
auto sgd_node = graph->CreateOpNode(&Sgd_desc);
InserInputAndOutputForOptOps(sgd_ops, sgd_node);
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::details::FuseSgdOpPass)
.RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kLocalScopes);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace details {
class FuseSgdOpPass : public FuseOptimizerOpPass {
private:
virtual const std::string GetOpType() const;
virtual const std::vector<std::string> GetAuxiliaryVarNames() const;
// Fuse Sgd Ops
virtual void FuseOptimizerOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const;
void FuseSgdOps(
const std::unordered_map<std::string, std::vector<std::string>> &vars_set,
const std::unordered_map<std::string, std::string> &fused_vars_name,
const std::vector<ir::Node *> &sgd_ops, ir::Graph *graph) const;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -24,6 +24,19 @@ namespace paddle {
namespace framework {
namespace details {
// Note(zcd): Addresses should be aligned, otherwise, the results may have
// diff.
static size_t Alignment(size_t size, const platform::Place &place) {
// Allow to allocate the minimum chunk size is 4 KB.
size_t alignment = 1 << 12;
if (platform::is_gpu_place(place)) {
// Allow to allocate the minimum chunk size is 256 B.
alignment = 1 << 8;
}
size_t remaining = size % alignment;
return remaining == 0 ? size : size + (alignment - remaining);
}
typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
GradientAndLoDTensor;
......@@ -111,10 +124,11 @@ void FusedAllReduceOpHandle::RunImpl() {
return grad1.second->data<void>() < grad2.second->data<void>();
});
size_t size_of_dtype = framework::SizeOfType(dtype);
for (size_t k = 1; k < g_tensor.size(); ++k) {
const void *cur_address = g_tensor.at(k - 1).second->data<void>();
int64_t len = g_tensor.at(k - 1).second->numel();
auto offset = len * framework::SizeOfType(dtype);
auto offset = Alignment(len * size_of_dtype, places_[0]);
void *infer_next_address = reinterpret_cast<void *>(
reinterpret_cast<uintptr_t>(cur_address) + offset);
const void *next_address = g_tensor.at(k).second->data<void>();
......@@ -228,18 +242,21 @@ void FusedAllReduceOpHandle::GetDTypeAndNumel(
const std::vector<std::pair<std::string, const LoDTensor *>> &grad_tensor,
proto::VarType::Type *dtype, int64_t *numel) const {
*numel = 0;
size_t size_of_dtype = 0;
for (size_t i = 0; i < grad_tensor.size(); ++i) {
// Get element number
int64_t len = grad_tensor.at(i).second->numel();
PADDLE_ENFORCE_GT(len, 0);
*numel += len;
// Get dtype
auto ele_type = grad_tensor.at(i).second->type();
if (i == 0) {
*dtype = ele_type;
size_of_dtype = framework::SizeOfType(ele_type);
}
PADDLE_ENFORCE_EQ(ele_type, *dtype);
// Get element number
int64_t len = grad_tensor.at(i).second->numel();
PADDLE_ENFORCE_GT(len, 0);
// Alignment(len)
*numel += Alignment(len * size_of_dtype, places_[0]) / size_of_dtype;
}
}
......
......@@ -20,7 +20,6 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -34,6 +33,10 @@ namespace framework {
class Scope;
namespace details {
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
class MultiDevSSAGraphBuilderBase : public ir::Pass {
protected:
void ApplyImpl(ir::Graph *graph) const override;
......
......@@ -20,7 +20,6 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
......@@ -41,22 +40,25 @@ namespace details {
// `std::vector<VarHandle*>` is the version of varaibles.
typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
constexpr char kGraphVars[] = "vars";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kLossVarName[] = "loss_var_name";
constexpr char kPlaces[] = "places";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kStrategy[] = "strategy";
constexpr char kNRanks[] = "nranks";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<VarHandleBase *> GraphDepVars;
constexpr char kGraphDepVars[] = "dep_vars";
typedef std::unordered_set<std::string> FusedVars;
constexpr char kFusedVars[] = "fused_vars";
constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@";
typedef std::string FusedOptType;
constexpr char kFusedOptType[] = "fused_opt_type";
typedef std::string FusedGrads;
constexpr char kFusedGrads[] = "fused_gradients";
typedef std::vector<std::pair<std::string, std::string>> ParamsAndGrads;
constexpr char kParamsAndGrads[] = "params_grads";
......@@ -65,8 +67,6 @@ typedef std::vector<std::vector<std::pair<std::string, std::string>>>
GroupGradsAndParams;
constexpr char kGroupGradsAndParams[] = "group_grads_params";
constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@";
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -24,13 +24,13 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
: graph_(graph),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr),
prepare_pool_(1),
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
strategy_(strategy) {
strategy_(strategy),
prepare_pool_(1),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr) {
PrepareOpDeps();
CopyOpDeps();
}
......
......@@ -63,13 +63,20 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op);
private:
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
// be destroyed first.
ir::Graph *graph_;
std::unique_ptr<::ThreadPool> pool_;
::ThreadPool prepare_pool_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_;
ExceptionHolder exception_holder_;
std::unique_ptr<OpDependentData> op_deps_;
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
ExecutionStrategy strategy_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
::ThreadPool prepare_pool_;
std::unique_ptr<::ThreadPool> pool_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const;
......@@ -88,14 +95,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void PrepareOpDeps();
void CopyOpDeps();
private:
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
ExecutionStrategy strategy_;
std::unique_ptr<OpDependentData> op_deps_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
};
} // namespace details
......
......@@ -70,7 +70,7 @@ Tensor& Tensor::ShareDataWith(const Tensor& src) {
return *this;
}
Tensor Tensor::Slice(int begin_idx, int end_idx) const {
Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const {
check_memory_size();
PADDLE_ENFORCE_GE(begin_idx, 0,
"The start row index must be greater than 0.");
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <cstring>
#include <memory>
#include <typeindex>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ddim.h"
......@@ -27,10 +28,6 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_utils.h"
#endif
namespace paddle {
namespace framework {
......@@ -41,34 +38,10 @@ class Tensor {
#ifdef PADDLE_WITH_MKLDNN
public:
// TODO(jczaja): This is depracted and will be removed
inline mkldnn::memory::format format() const {
if (layout_ == DataLayout::kMKLDNN) {
return static_cast<mkldnn::memory::format>(mem_pd_.desc().data.format);
} else {
return mkldnn::memory::format::format_undef;
}
}
inline mkldnn::memory::format format() const { return format_; }
// TODO(jczaja): This is depracted and will be removed
inline void set_format(
const mkldnn::memory::format fmt,
mkldnn::memory::data_type data_type = mkldnn::memory::f32) {
mem_pd_ = paddle::platform::create_prim_desc_from_format(
paddle::framework::vectorize2int(dims()), fmt, data_type);
layout_ = DataLayout::kMKLDNN;
}
inline mkldnn::memory::primitive_desc get_mkldnn_prim_desc() const {
return mem_pd_;
}
inline void set_mkldnn_prim_desc(
const mkldnn::memory::primitive_desc& mem_pd) {
// Internally MKL-DNN is just copying (increasing reference counter)
// to shared_ptr. So asignment should be quite cheap
mem_pd_ = mem_pd;
layout_ = DataLayout::kMKLDNN;
inline void set_format(const mkldnn::memory::format format) {
format_ = format;
}
protected:
......@@ -76,9 +49,12 @@ class Tensor {
* @brief the detail format of memory block which have layout as kMKLDNN
*
* @note MKLDNN lib support various memory format like nchw, nhwc, nChw8C,
* nChw16c, etc. For a MKLDNN memory block, we store memory descriptor
* nChw16c, etc. For a MKLDNN memory block, layout will be set as
* DataLayout::kMKLDNN meanwhile detail memory format will be kept in
* this field.
*/
mutable mkldnn::memory::primitive_desc mem_pd_;
mkldnn::memory::format format_ = mkldnn::memory::format::format_undef;
#endif
public:
......@@ -157,7 +133,7 @@ class Tensor {
* @param[in] end_idx The index of the end row(exclusive) to slice.
* The index number begins from 0.
*/
Tensor Slice(int begin_idx, int end_idx) const;
Tensor Slice(int64_t begin_idx, int64_t end_idx) const;
platform::Place place() const {
PADDLE_ENFORCE_NOT_NULL(
......
......@@ -44,11 +44,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
<< dst_place;
return;
}
#ifdef PADDLE_WITH_MKLDNN
if (src.layout() == DataLayout::kMKLDNN) {
dst->set_mkldnn_prim_desc(src.get_mkldnn_prim_desc());
}
#endif
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
......
......@@ -65,7 +65,8 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
// Get numel and dtype
size_t numel = 0;
auto dtype = kDefaultDtype;
GetMemSizeAndDtype(in_tensors, in_var_names, &numel, &dtype);
GetMemSizeAndDtype(in_tensors, in_var_names, &numel, &dtype,
context.GetPlace());
// Alloc the continuous space
auto fused_tensor = context.Output<framework::LoDTensor>("FusedOutput");
......@@ -74,14 +75,18 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
// Init the continuous space
auto out_tensors = context.MultiOutput<framework::LoDTensor>("Output");
int64_t offset = 0;
size_t offset = 0;
size_t size_of_dtype = framework::SizeOfType(dtype);
if (context.Attr<bool>("copy_data")) {
for (size_t i = 0; i < in_var_names.size(); ++i) {
int64_t len = out_tensors[i]->numel();
auto sub_tensor = fused_tensor->Slice(offset, offset + len);
offset += len;
framework::TensorCopy(*out_tensors[i], context.GetPlace(), dev_ctx,
size_t len = static_cast<size_t>(in_tensors[i]->numel());
auto sub_tensor = fused_tensor->Slice(
static_cast<int64_t>(offset), static_cast<int64_t>(offset + len));
framework::TensorCopy(*in_tensors[i], context.GetPlace(), dev_ctx,
&sub_tensor);
offset +=
Alignment(len * size_of_dtype, context.GetPlace()) / size_of_dtype;
}
} else if (context.Attr<bool>("set_constant")) {
math::SetConstant<DeviceContext, T> set_constant;
......@@ -92,11 +97,13 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
// Make the outputs point to the continuous space.
offset = 0;
for (size_t i = 0; i < out_tensors.size(); ++i) {
int64_t len = out_tensors[i]->numel();
size_t len = static_cast<size_t>(out_tensors[i]->numel());
auto dim = out_tensors[i]->dims();
out_tensors[i]
->ShareDataWith(fused_tensor->Slice(offset, offset + len))
->ShareDataWith(fused_tensor->Slice(
static_cast<int64_t>(offset), static_cast<int64_t>(offset + len)))
.Resize(dim);
len = Alignment(len * size_of_dtype, context.GetPlace()) / size_of_dtype;
offset += len;
VLOG(10) << "alloc_space_for_vars: output(" << out_var_names[i]
<< ") ,dim:(" << dim << ")"
......@@ -104,12 +111,28 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
}
}
private:
// Note(zcd): Addresses should be aligned, otherwise, the results may have
// diff.
size_t Alignment(size_t size, const platform::Place &place) const {
// Allow to allocate the minimum chunk size is 4 KB.
size_t alignment = 1 << 12;
if (platform::is_gpu_place(place)) {
// Allow to allocate the minimum chunk size is 256 B.
alignment = 1 << 8;
}
size_t remaining = size % alignment;
return remaining == 0 ? size : size + (alignment - remaining);
}
void GetMemSizeAndDtype(
const std::vector<const framework::LoDTensor *> &lod_tensors,
const std::vector<std::string> var_names, size_t *numel,
framework::proto::VarType::Type *dtype) const {
framework::proto::VarType::Type *dtype,
const platform::Place &place) const {
PADDLE_ENFORCE_EQ(lod_tensors.size(), var_names.size());
*numel = 0;
size_t size_of_dtype = 0;
for (size_t i = 0; i < var_names.size(); ++i) {
PADDLE_ENFORCE(lod_tensors[i]->IsInitialized(), "%s is not initialized.",
var_names[i]);
......@@ -119,6 +142,7 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NE(p_dtype, kDefaultDtype, "%s's type should not be %s.",
var_names[i], kDefaultDtype);
*dtype = p_dtype;
size_of_dtype = framework::SizeOfType(p_dtype);
}
PADDLE_ENFORCE_EQ(p_dtype, *dtype, "Input vars is not equal.");
......@@ -126,7 +150,8 @@ class AllocContinuousSpaceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GT(size, 0);
VLOG(10) << "alloc_space_for_vars: input(" << var_names[i] << ") ,dim:("
<< lod_tensors[i]->dims() << ")";
*numel += size;
*numel += Alignment(static_cast<size_t>(size) * size_of_dtype, place) /
size_of_dtype;
}
}
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bpr_loss_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -127,6 +128,23 @@ neural networks>(https://arxiv.org/abs/1511.06939)
)DOC");
}
};
class BprLossGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("bpr_loss_grad");
op->SetInput("X", Input("X"));
op->SetInput("Label", Input("Label"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
......@@ -134,7 +152,7 @@ namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(bpr_loss, ops::BprLossOp, ops::BprLossOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::BprLossGradDescMaker);
REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp);
REGISTER_OP_CPU_KERNEL(bpr_loss, ops::BprLossOpKernel<CPUCtx, float>,
ops::BprLossOpKernel<CPUCtx, double>);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -568,13 +569,31 @@ class ROIPerspectiveTransformOpMaker
}
};
class ROIPerspectiveTransformGradDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("roi_perspective_transform_grad");
op->SetInput("X", Input("X"));
op->SetInput("ROIs", Input("ROIs"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(roi_perspective_transform, ops::ROIPerspectiveTransformOp,
ops::ROIPerspectiveTransformOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::ROIPerspectiveTransformGradDescMaker);
REGISTER_OPERATOR(roi_perspective_transform_grad,
ops::ROIPerspectiveTransformGradOp);
REGISTER_OP_CPU_KERNEL(roi_perspective_transform,
......
......@@ -77,7 +77,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
} else {
functor.RunMidWise(n, pre, post);
}
z->set_mkldnn_prim_desc(x->get_mkldnn_prim_desc());
z->set_layout(DataLayout::kMKLDNN);
z->set_format(x->format());
} else {
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
x->format() != memory::format::format_undef,
......@@ -115,8 +116,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd);
// create mkldnn memory for dst
auto dst_mem_pd = sum_pd.dst_primitive_desc();
memory dst_memory = memory(dst_mem_pd, z_data);
memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data);
std::vector<primitive::at> inputs;
inputs.push_back(srcs[0]);
......@@ -129,7 +129,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
pipeline.push_back(sum_prim);
stream(stream::kind::eager).submit(pipeline).wait();
z->set_mkldnn_prim_desc(dst_mem_pd);
z->set_layout(DataLayout::kMKLDNN);
z->set_format(
(memory::format)dst_memory.get_primitive_desc().desc().data.format);
}
}
};
......@@ -150,19 +152,24 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
auto* out = dout;
auto *x = dout, *y = dout;
auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
in->set_layout(DataLayout::kMKLDNN);
in->set_format(out->format());
};
if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
if (dx->dims() == dy->dims()) {
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dx->mutable_data<T>(ctx.GetPlace()));
dx->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
set_mkldnn_format(dx, dout);
}
if (dy) {
blas.VCOPY(dout->numel(), dout->data<T>(),
dy->mutable_data<T>(ctx.GetPlace()));
dy->set_mkldnn_prim_desc(dout->get_mkldnn_prim_desc());
set_mkldnn_format(dy, dout);
}
}
} else {
......
......@@ -65,11 +65,17 @@ by input arguments.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
GaussianRandomBatchSizeLikeNoNeedBufferVarsInference, "Input");
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
REGISTER_OPERATOR(
gaussian_random_batch_size_like,
paddle::operators::GaussianRandomBatchSizeLikeOp,
paddle::operators::GaussianRandomBatchSizeLikeOpMaker);
paddle::operators::GaussianRandomBatchSizeLikeOpMaker,
paddle::framework::EmptyGradOpMaker,
paddle::operators::GaussianRandomBatchSizeLikeNoNeedBufferVarsInference);
// Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/im2sequence_op.h"
#include <memory>
#include <string>
#include <vector>
......@@ -146,12 +147,28 @@ class Im2SequenceGradOp : public framework::OperatorWithKernel {
}
};
class Im2SequenceGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("im2sequence_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(im2sequence, ops::Im2SequenceOp, ops::Im2SequenceOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::Im2SequenceGradDescMaker);
REGISTER_OPERATOR(im2sequence_grad, ops::Im2SequenceGradOp);
REGISTER_OP_CPU_KERNEL(
im2sequence,
......
......@@ -10,6 +10,7 @@
limitations under the License. */
#include "paddle/fluid/operators/interpolate_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......@@ -194,21 +195,46 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType(ForwardOp().Type() + "_grad");
op->SetInput("X", Input("X"));
if (ForwardOp().Inputs().count("OutSize") > 0) {
op->SetInput("OutSize", Input("OutSize"));
}
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(InterpolateGradNoNeedBufferVarsInference,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad);
ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad);
ops::InterpolateGradDescMaker);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/l1_norm_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -62,12 +63,28 @@ $$Out = \sum{|X|}$$
}
};
class L1NormGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("l1_norm_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(l1_norm, ops::L1NormOp, ops::L1NormOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::L1NormGradDescMaker);
REGISTER_OPERATOR(l1_norm_grad, ops::L1NormGradOp);
REGISTER_OP_CPU_KERNEL(
l1_norm, ops::L1NormKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/label_smooth_op.h"
#include <memory>
#include <string>
namespace paddle {
......@@ -105,10 +106,23 @@ class LabelSmoothGradOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
};
class LabelSmoothGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("label_smooth_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
......@@ -117,7 +131,7 @@ class LabelSmoothGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OPERATOR(label_smooth, ops::LabelSmoothOp, ops::LabelSmoothOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::LabelSmoothGradDescMaker);
REGISTER_OPERATOR(label_smooth_grad, ops::LabelSmoothGradOp);
REGISTER_OP_CPU_KERNEL(
label_smooth,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/linear_chain_crf_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -250,14 +251,46 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
}
};
class LinearChainCRFGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("linear_chain_crf_grad");
op->SetAttrMap(Attrs());
op->SetInput("Emission", Input("Emission"));
op->SetInput("Transition", Input("Transition"));
op->SetInput("Label", Input("Label"));
op->SetInput("Alpha", Output("Alpha"));
op->SetInput("EmissionExps", Output("EmissionExps"));
op->SetInput("TransitionExps", Output("TransitionExps"));
op->SetInput(framework::GradVarName("LogLikelihood"),
OutputGrad("LogLikelihood"));
op->SetOutput(framework::GradVarName("Emission"), InputGrad("Emission"));
op->SetOutput(framework::GradVarName("Transition"),
InputGrad("Transition"));
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
LinearChainCRFGradNoNeedBufferVarsInference, "Transition", "Emission");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(linear_chain_crf, ops::LinearChainCRFOp,
ops::LinearChainCRFOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(linear_chain_crf_grad, ops::LinearChainCRFGradOp);
ops::LinearChainCRFOpMaker, ops::LinearChainCRFGradDescMaker);
REGISTER_OPERATOR(linear_chain_crf_grad, ops::LinearChainCRFGradOp,
ops::LinearChainCRFGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
linear_chain_crf,
ops::LinearChainCRFOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/log_loss_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -100,12 +101,29 @@ class LogLossGradOp : public framework::OperatorWithKernel {
}
};
class LogLossGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("log_loss_grad");
op->SetInput("Predicted", Input("Predicted"));
op->SetInput("Labels", Input("Labels"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetOutput(framework::GradVarName("Predicted"), InputGrad("Predicted"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(log_loss, ops::LogLossOp, ops::LogLossOpMaker<float>,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::LogLossGradDescMaker);
REGISTER_OPERATOR(log_loss_grad, ops::LogLossGradOp);
REGISTER_OP_CPU_KERNEL(
log_loss, ops::LogLossKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lstm_op.h"
#include <memory>
#include <string>
namespace paddle {
......@@ -264,12 +265,51 @@ class LSTMGradOp : public framework::OperatorWithKernel {
}
};
class LSTMGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("lstm_grad");
op->SetAttrMap(Attrs());
op->SetInput("Input", Input("Input"));
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
if (ForwardOp().Inputs().count("H0") > 0) {
op->SetInput("H0", Input("H0"));
op->SetOutput(framework::GradVarName("H0"), InputGrad("H0"));
}
if (ForwardOp().Inputs().count("C0") > 0) {
op->SetInput("C0", Input("C0"));
op->SetOutput(framework::GradVarName("C0"), InputGrad("C0"));
}
op->SetInput("Weight", Input("Weight"));
op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
op->SetInput("Bias", Input("Bias"));
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
op->SetInput("Cell", Output("Cell"));
op->SetInput("Hidden", Output("Hidden"));
op->SetInput(framework::GradVarName("Hidden"), OutputGrad("Hidden"));
op->SetInput("BatchGate", Output("BatchGate"));
op->SetInput("BatchCellPreAct", Output("BatchCellPreAct"));
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(lstm, ops::LSTMOp, ops::LSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::LSTMGradOpDescMaker);
REGISTER_OPERATOR(lstm_grad, ops::LSTMGradOp);
REGISTER_OP_CPU_KERNEL(
lstm, ops::LSTMKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/margin_rank_loss_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -94,8 +95,6 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Activated"),
......@@ -106,13 +105,31 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel {
}
};
class MarginRankLossGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("margin_rank_loss_grad");
op->SetInput("Activated", Output("Activated"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("Label", Input("Label"));
op->SetOutput(framework::GradVarName("X1"), InputGrad("X1"));
op->SetOutput(framework::GradVarName("X2"), InputGrad("X2"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(margin_rank_loss, ops::MarginRankLossOp,
ops::MarginRankLossOpMaker<float>,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::MarginRankLossGradDescMaker);
REGISTER_OPERATOR(margin_rank_loss_grad, ops::MarginRankLossGradOp);
REGISTER_OP_CPU_KERNEL(
margin_rank_loss,
......
......@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mean_op.h"
#include <memory>
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
......@@ -61,7 +64,8 @@ class MeanGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = ctx.Input<Tensor>("X")->type();
auto input_data_type =
ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -81,13 +85,16 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(MeanGradNoNeedBufferVarsInference, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
ops::MeanGradMaker);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp,
ops::MeanGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -96,7 +96,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
std::vector<int> src_tz = framework::vectorize2int(x->dims());
auto src_format = x->format();
auto src_format =
src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
const std::string key = gethash(src_tz, algorithm);
const std::string key_src_data =
......@@ -126,8 +127,10 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
if (p_fwd == nullptr) {
// create mkldnn memory for input X
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), src_format);
auto src_memory = std::shared_ptr<memory>(
new memory(x->get_mkldnn_prim_desc(), to_void_cast(x_data)));
new memory({src_md, mkldnn_engine}, to_void_cast(x_data)));
// save src_memory to be referred in backward path
dev_ctx.SetBlob(key_src_mem, src_memory);
......@@ -174,7 +177,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
pipeline.push_back(*p_fwd);
stream(stream::kind::eager).submit(pipeline).wait();
y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
y->set_layout(DataLayout::kMKLDNN);
y->set_format(GetMKLDNNFormat(*dst_memory));
}
template <typename T>
......@@ -192,6 +196,9 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());
auto diff_y_format =
diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
const std::string key = gethash(diff_dst_tz, algorithm);
const std::string key_src_data =
key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
......@@ -203,8 +210,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
const std::string key_fwd_pd =
key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
const std::string key_with_layouts = key + std::to_string(*p_src_layout) +
"-" + std::to_string(diff_y->format());
const std::string key_with_layouts =
key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
const std::string key_diff_src_mem =
key_with_layouts + "@eltwise_diff_src_mem";
const std::string key_diff_dst_mem =
......@@ -227,8 +234,10 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
if (p_grad == nullptr) {
// create mkldnn memory for input diff_y
auto diff_dst_md = platform::MKLDNNMemDesc(
diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
auto diff_dst_memory = std::shared_ptr<memory>(
new memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data)));
new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data)));
dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
// retrieve eltwise primitive desc from device context
......@@ -272,7 +281,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
pipeline.push_back(*p_grad);
stream(stream::kind::eager).submit(pipeline).wait();
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
}
template <typename T, mkldnn::algorithm algorithm>
......
......@@ -206,14 +206,17 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
// keys for backward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, global_stats, x->format(),
src_tz, epsilon, flags, global_stats, input_format,
ctx.op().Output("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
auto user_src_md = x->get_mkldnn_prim_desc().desc();
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
// create primitive descriptor for batch norm forward
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
......@@ -227,8 +230,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
key);
auto src_memory = handler.AcquireSrcMemory(x->get_mkldnn_prim_desc(),
to_void_cast(x_data));
auto src_memory =
handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
// crate mkldnn memory for weights(scale/shift)
auto scaleshift_memory =
......@@ -262,7 +265,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
variance_memory, false);
}
y->set_mkldnn_prim_desc(dst_memory->get_primitive_desc());
y->set_layout(DataLayout::kMKLDNN);
y->set_format(platform::GetMKLDNNFormat(*dst_memory));
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*batch_norm_p);
......@@ -332,6 +336,9 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
mkldnn::memory::format dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
......@@ -339,14 +346,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// keys from forward pass
const std::string key = BatchNormMKLDNNHandler::GetHash(
src_tz, epsilon, flags, false, x->format(),
src_tz, epsilon, flags, false, input_format,
ctx.op().Input("SavedMean"));
const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
// keys for primitives reuse
const std::string key_with_hash =
key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false,
x->format());
input_format);
const std::string key_batch_norm_bwd_p =
key_with_hash + "@batch_norm_bwd_p";
const std::string key_batch_norm_src_mem_p =
......@@ -366,8 +373,9 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
primitive reorder_diff_dst;
bool is_diff_dst_reordered = false;
auto user_diff_dst_memory =
memory(diff_y->get_mkldnn_prim_desc(), to_void_cast(diff_y_data));
auto user_diff_dst_memory = memory(
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
......@@ -451,7 +459,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
// set layout/format of output tensors
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
} else {
// primitives already exist
UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p, to_void_cast(x_data));
......@@ -476,7 +487,10 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
}
// set layout/format of output tensors
diff_x->set_mkldnn_prim_desc(diff_src_memory->get_primitive_desc());
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format);
}
// execute optional reorder and batch_norm backward primitive
......
......@@ -210,7 +210,8 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
stream(stream::kind::eager).submit({*concat_p}).wait();
output->set_mkldnn_prim_desc(concat_pd->dst_primitive_desc());
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetDstMemFormat(*concat_pd));
}
};
} // namespace operators
......
......@@ -96,8 +96,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN);
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN);
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
input->format() != memory::format::format_undef,
"Wrong layout/format set for Input tensor");
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
"Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
......@@ -144,19 +148,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<primitive> pipeline;
// For convolution with groups we need to recreate primitive descriptor
// as Paddle tensor is not having group dims while mkldnn treats
// group as another dimensions
mkldnn::memory::primitive_desc user_weights_mpd =
filter->get_mkldnn_prim_desc();
if (g > 1) {
mkldnn::memory::format weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d);
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
user_weights_mpd =
mkldnn::memory::primitive_desc(user_weights_md, mkldnn_engine);
}
auto src_format = input->format();
mkldnn::memory::format weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d);
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
......@@ -166,7 +165,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
mkldnn::memory::format weights_format = mkldnn::memory::format::any;
weights_format = mkldnn::memory::format::any;
// Check the format for user's special output
if (chosen_memory_format != mkldnn::memory::format::any) {
if (is_conv3d) {
......@@ -206,10 +205,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);
// create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = handler.AcquireSrcMemory(
input->get_mkldnn_prim_desc(), to_void_cast<T>(input_data));
auto user_src_memory_p =
handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_mpd, to_void_cast<T>(filter_data));
user_weights_md, to_void_cast<T>(filter_data));
// create reorder primitive if the input format is not the preferred one
auto src_memory_p =
......@@ -282,7 +281,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait();
output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc());
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p));
}
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test");
......@@ -948,8 +948,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// push primitive to stream and wait until it's executed
pipeline.push_back(*conv_bwd_weights_p);
auto filter_grad_mpd = diff_weights_memory_p->get_primitive_desc();
filter_grad->set_mkldnn_prim_desc(filter_grad_mpd);
filter_grad->set_layout(DataLayout::kMKLDNN);
filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p));
}
if (input_grad) {
......@@ -972,7 +972,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_bwd_data_p);
input_grad->set_mkldnn_prim_desc(diff_src_memory_p->get_primitive_desc());
input_grad->set_layout(DataLayout::kMKLDNN);
input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}
stream(stream::kind::eager).submit(pipeline).wait();
}
......
......@@ -221,7 +221,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*conv_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
output->set_mkldnn_prim_desc(dst_memory_p->get_primitive_desc());
output->set_layout(DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
}
private:
......
......@@ -42,12 +42,8 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
// The format of output is set as the mkldnn's format
// TODO(@mozga-intel) The format of matrix sets inside the another layers.
// TODO(jczaja): Remove this hack after checking performance on block layout
auto tensor_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(tensor->dims()),
mkldnn::memory::format::oihw);
tensor->set_mkldnn_prim_desc(tensor_mem_pd);
tensor->set_layout(DataLayout::kMKLDNN);
tensor->set_format(mkldnn::memory::format::oihw);
}
};
} // namespace operators
......
......@@ -81,7 +81,10 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
auto src_md = x->get_mkldnn_prim_desc().desc();
auto dims = paddle::framework::vectorize2int(x->dims());
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, x->format());
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
mkldnn::lrn_across_channels,
......@@ -91,7 +94,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
beta,
k};
auto src_memory_pd = x->get_mkldnn_prim_desc();
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
if (!is_test) {
const std::string key = ctx.op().Output("Out");
......@@ -108,15 +111,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory->set_data_handle(
static_cast<void*>(const_cast<T*>(input_data)));
auto dst_memory_pd = forward_pd->dst_primitive_desc();
auto dst_memory =
mkldnn::memory(dst_memory_pd, static_cast<void*>(output_data));
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
static_cast<void*>(output_data));
auto workspace_memory = insert_to_context<mkldnn::memory>(
key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc());
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
out->set_mkldnn_prim_desc(dst_memory_pd);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
} else {
auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
......@@ -124,12 +128,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()};
auto dst_memory_pd = forward_pd.dst_primitive_desc();
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
static_cast<void*>(output_data));
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
out->set_mkldnn_prim_desc(dst_memory_pd);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
}
}
};
......
......@@ -158,14 +158,6 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto softmax_p =
handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
// We cannot use softmax_dst_memory_p to get prim desc as
// it contains flattened dims (2D) while output tensor can
// have 2,3,4+ dims
auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(output->dims()),
mkldnn::memory::format::blocked);
output->set_mkldnn_prim_desc(output_mem_pd);
std::vector<primitive> pipeline{
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
stream(stream::kind::eager).submit(pipeline).wait();
......
......@@ -106,12 +106,12 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory::desc(dst_tz, memory::data_type::f32, memory::format::any);
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
auto dst_mem_pd = sum_pd.dst_primitive_desc();
std::shared_ptr<memory> dst_mem;
if (in_place) {
dst_mem.reset(new memory(dst_mem_pd));
dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
} else {
dst_mem.reset(new memory(dst_mem_pd, output_data));
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
}
std::vector<mkldnn::primitive::at> inputs;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
......@@ -136,7 +136,8 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (in_place) pipeline.push_back(reorder_prim);
stream(stream::kind::eager).submit(pipeline).wait();
output->set_mkldnn_prim_desc(dst_mem_pd);
output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format);
} else { // Fallback to naive version
// TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support
SumKernel<CPUDeviceContext, T> reference_kernel;
......
......@@ -52,7 +52,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn_engine, key);
auto transpose_src_memory_p = handler.AcquireSrcMemory(
input->get_mkldnn_prim_desc(), platform::to_void_cast<T>(input_data));
input->format(), platform::to_void_cast<T>(input_data));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(output, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
......@@ -62,14 +62,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*transpose_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
// Transpose did change logical dimensions of Tensor, but reorder does not.
// Reorder does change only physical layout eg. format , strides
// so we need to create new primitive descriptor with changed logical layout
// so it match output shape
auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(output->dims()),
mkldnn::memory::format::blocked);
output->set_mkldnn_prim_desc(output_mem_pd);
output->set_layout(DataLayout::kNCHW);
output->set_format(mkldnn::memory::format::format_undef);
}
};
......@@ -134,9 +128,8 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx,
mkldnn_engine, key);
auto transpose_src_memory_p =
handler.AcquireSrcMemory(out_grad->get_mkldnn_prim_desc(),
platform::to_void_cast<T>(out_grad_data));
auto transpose_src_memory_p = handler.AcquireSrcMemory(
out_grad->format(), platform::to_void_cast<T>(out_grad_data));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(x_grad, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
......@@ -145,15 +138,6 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*transpose_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
// Transpose did change logical dimensions of Tensor, but reorder does not.
// Reorder does change only physical layout eg. format , strides
// so we need to create new primitive descriptor with changed logical layout
// so it match output shape
auto x_grad_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(x_grad->dims()),
mkldnn::memory::format::blocked);
x_grad->set_mkldnn_prim_desc(x_grad_mem_pd);
}
};
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/multiplex_op.h"
#include <memory>
#include <vector>
namespace paddle {
namespace operators {
......@@ -111,28 +113,47 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "Input(X) should not be null.");
PADDLE_ENFORCE(!ctx->Outputs(framework::GradVarName("X")).empty(),
"Output(X@Grad) should not be null.");
auto& dxs = ctx->Outputs(framework::GradVarName("X"));
PADDLE_ENFORCE(!dxs.empty(), "Output(X@Grad) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputsDim(framework::GradVarName("X"),
std::vector<framework::DDim>(dxs.size(), dout_dim));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
ctx.device_context());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
class MultiplexGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("multiplex_grad");
op->SetInput("Ids", Input("Ids"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
paddle::framework::DefaultGradOpDescMaker<false>);
ops::MultiplexGradDescMaker);
REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
REGISTER_OP_CPU_KERNEL(
multiplex,
......
......@@ -53,20 +53,25 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<Tensor>("X");
auto* ids = ctx.Input<Tensor>("Ids");
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
size_t idx = -1UL;
for (size_t i = 0; i < d_ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
t.device(*ctx.template device_context<Place>().eigen_device()) =
t.constant(static_cast<T>(0));
idx = i;
}
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows;
if (idx == -1UL) return;
auto rows = d_ins[idx]->dims()[0];
auto cols = d_ins[idx]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
......
......@@ -52,20 +52,25 @@ class MultiplexGradCPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* ids = ctx.Input<framework::Tensor>("Ids");
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto d_ins =
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
size_t idx = -1UL;
for (size_t i = 0; i < d_ins.size(); i++) {
if (d_ins[i]) {
d_ins[i]->mutable_data<T>(ctx.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
t.device(*ctx.template device_context<DeviceContext>().eigen_device()) =
t.constant(static_cast<T>(0));
idx = i;
}
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->numel() / rows;
if (idx == -1UL) return;
auto rows = d_ins[idx]->dims()[0];
auto cols = d_ins[idx]->numel() / rows;
auto* index = ids->data<int32_t>();
platform::CPUPlace place = boost::get<platform::CPUPlace>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pad_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -29,7 +30,7 @@ class PadOp : public framework::OperatorWithKernel {
"Output(Out) of PadOp should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()),
"Size of paddings should be equal to 2 * dimension size "
"of input tensor.");
......@@ -99,13 +100,20 @@ class PadOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (int i = 0; i < dout_dims.size(); ++i) {
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
}
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto& paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
for (int i = 0; i < dout_dims.size(); ++i) {
dout_dims[i] -= (paddings[i * 2] + paddings[i * 2 + 1]);
}
ctx->SetOutputDim(x_grad_name, dout_dims);
}
}
};
......@@ -117,7 +125,6 @@ class PadOpGradMaker : public framework::SingleGradOpDescMaker {
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* bind = new framework::OpDesc();
bind->SetInput("X", Input("X"));
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind->SetAttrMap(Attrs());
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/psroi_pool_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -154,12 +155,29 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
}
};
class PSROIPoolGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("psroi_pool_grad");
op->SetInput("X", Input("X"));
op->SetInput("ROIs", Input("ROIs"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(psroi_pool, ops::PSROIPoolOp, ops::PSROIPoolOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::PSROIPoolGradDescMaker);
REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
psroi_pool,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/rank_loss_op.h"
#include <memory>
#include <string>
namespace paddle {
......@@ -116,6 +117,25 @@ class RankLossGradOp : public framework::OperatorWithKernel {
}
};
class RankLossGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("rank_loss_grad");
op->SetInput("Label", Input("Label"));
op->SetInput("Left", Input("Left"));
op->SetInput("Right", Input("Right"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("Left"), InputGrad("Left"));
op->SetOutput(framework::GradVarName("Right"), InputGrad("Right"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
......
......@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/roi_align_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -147,12 +148,29 @@ Thus avoid the misaligned problem.
}
};
class ROIAlignGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("roi_align_grad");
op->SetInput("X", Input("X"));
op->SetInput("ROIs", Input("ROIs"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::ROIAlignGradDescMaker);
REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp);
REGISTER_OP_CPU_KERNEL(
roi_align,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/roi_pool_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -158,12 +159,30 @@ https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
}
};
class ROIPoolGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("roi_pool_grad");
op->SetInput("X", Input("X"));
op->SetInput("ROIs", Input("ROIs"));
op->SetInput("Argmax", Output("Argmax"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::ROIPoolGradDescMaker);
REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
roi_pool,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/scatter_op.h"
#include <memory>
#include "paddle/fluid/framework/ddim.h"
namespace paddle {
......@@ -63,14 +64,16 @@ class ScatterGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
......@@ -95,12 +98,34 @@ $$
}
};
class ScatterGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("scatter_grad");
op->SetInput("Ids", Input("Ids"));
op->SetInput("Updates", Input("Updates"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Updates"), InputGrad("Updates"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ScatterGradNoNeedBufferVarsInference,
"Updates");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(scatter, ops::ScatterOp, ops::ScatterOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(scatter_grad, ops::ScatterGradOp);
ops::ScatterGradDescMaker);
REGISTER_OPERATOR(scatter_grad, ops::ScatterGradOp,
ops::ScatterGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel<float>);
REGISTER_OP_CPU_KERNEL(scatter_grad, ops::ScatterGradientOpKernel<float>);
......@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/shuffle_channel_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -91,13 +92,28 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
}
};
class ShuffleChannelGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("shuffle_channel_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp,
ops::ShuffleChannelOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::ShuffleChannelOpMaker, ops::ShuffleChannelGradDescMaker);
REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_layout_transform.h"
......@@ -39,45 +40,6 @@ class MKLDNNHandler {
return this->AcquireMemory(md, ptr, "@user_src_mem_p");
}
// TODO(jczaja): extract common part and make AcquireMemory
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::primitive_desc& mpd, void* ptr) {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mpd, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::primitive_desc& mpd, void* ptr) {
auto local_key = key_ + "@user_weights_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mpd, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr,
user_function custom_func = {}) {
......@@ -315,7 +277,37 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
axis_(axis) {}
axis_(axis),
logical_axis_(dims.size(), 0) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < logical_axis_.size(); ++i) {
logical_axis_[i] = i;
}
auto src_md = fmt != mkldnn::memory::format::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<float>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
platform::Place place) {
......@@ -400,6 +392,7 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
private:
std::vector<int> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
};
template <class forward_t, class backward_data_t, class backward_weights_t>
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mkldnn.h>
#include <string>
namespace paddle {
namespace platform {
inline mkldnn::memory::primitive_desc create_prim_desc_from_dims(
const std::vector<int>& ltz, mkldnn::memory::format fmt,
mkldnn::memory::data_type data_type = mkldnn::memory::data_type::f32) {
mkldnn_memory_desc_t mem_fmt;
mem_fmt.primitive_kind = mkldnn_memory;
mem_fmt.ndims = ltz.size();
for (unsigned int i = 0; i < ltz.size(); ++i) {
mem_fmt.dims[i] = ltz[i]; // logical dimensions (nchw format,
// regardless physical layout)
}
mem_fmt.data_type = static_cast<mkldnn_data_type_t>(data_type);
mem_fmt.format = static_cast<mkldnn_memory_format_t>(fmt);
unsigned int total_stride = 1;
for (int i = ltz.size() - 1; i >= 0; --i) {
mem_fmt.layout_desc.blocking.padding_dims[i] =
ltz[i]; // logical dimensions (nchw format, regardless physical
// layout)
mem_fmt.layout_desc.blocking.block_dims[i] = 1;
mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset
mem_fmt.layout_desc.blocking.strides[0][i] = total_stride;
mem_fmt.layout_desc.blocking.strides[1][i] = 1;
total_stride *= ltz[i];
}
mem_fmt.layout_desc.blocking.offset_padding = 0; // no initial offset
auto& pool = platform::DeviceContextPool::Instance();
auto place = paddle::platform::CPUPlace();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
auto& cpu_engine = dev_ctx->GetEngine();
return mkldnn::memory::primitive_desc(mem_fmt, cpu_engine);
}
inline mkldnn::memory::primitive_desc create_prim_desc_from_format(
const std::vector<int>& ltz, const mkldnn::memory::format format,
const mkldnn::memory::data_type data_type) {
auto md = mkldnn::memory::desc({ltz}, data_type, format);
auto& pool = platform::DeviceContextPool::Instance();
auto place = paddle::platform::CPUPlace();
auto dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
PADDLE_ENFORCE_NOT_NULL(dev_ctx, "Could not get valid device");
auto& cpu_engine = dev_ctx->GetEngine();
return mkldnn::memory::primitive_desc(md, cpu_engine);
}
} // namespace platform
} // namespace paddle
......@@ -1282,6 +1282,15 @@ All parameter, weight, gradient are variables in Paddle.
it will save GPU memory and may make the execution faster.
This options is only available in GPU devices.
Default False)DOC")
.def_property("fuse_all_optimizer_ops",
[](const BuildStrategy &self) {
return self.fuse_all_optimizer_ops_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(),
"BuildStrategy is finlaized.");
self.fuse_all_optimizer_ops_ = b;
})
.def_property(
"sync_batch_norm",
[](const BuildStrategy &self) { return self.sync_batch_norm_; },
......
......@@ -26,6 +26,17 @@ __all__ = [
]
def _init_var_node(var_node, value, scope, place):
assert isinstance(value,
np.ndarray), 'The type of value should be numpy array.'
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
'The place cannot be set None.'
tensor = scope.var(var_node.name()).get_tensor()
tensor.set(value, place)
class QuantizationTransformPass(object):
def __init__(self,
scope=None,
......@@ -88,14 +99,14 @@ class QuantizationTransformPass(object):
assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'."
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be ",
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(activation_quantize_type))
"Unknown activation_quantize_type : '%s'. It can only be "
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'." %
(str(activation_quantize_type)))
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be ",
"'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(weight_quantize_type))
"Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
% (str(weight_quantize_type)))
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
......@@ -121,8 +132,6 @@ class QuantizationTransformPass(object):
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
#sequential_execution = core.get_pass('sequential_execution_pass')
#sequential_execution.apply(graph.graph)
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
......@@ -203,9 +212,12 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=core.VarDesc.VarType.INT64)
self._init_var_node(
global_step_in, np.zeros(
[1], dtype='int64'))
_init_var_node(
global_step_in,
np.zeros(
[1], dtype='int64'),
self._scope,
self._place)
global_step_out = graph.create_var_node_from_desc(
global_step_in.var())
# The attribute of `op_role` is needed by ParallelExecutor.
......@@ -284,7 +296,12 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
_init_var_node(
scale_in_node,
np.array(
[0.001], dtype=data_type),
self._scope,
self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node}
......@@ -299,9 +316,13 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(
scales_node, np.zeros(
[self._window_size], dtype=data_type))
_init_var_node(
scales_node,
np.zeros(
[self._window_size], dtype=data_type),
self._scope,
self._place)
inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node
attrs = {
......@@ -343,7 +364,12 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
_init_var_node(
scale_in_node,
np.array(
[0.001], dtype=data_type),
self._scope,
self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
ins = {'X': var_node, 'InScale': scale_in_node}
......@@ -356,13 +382,23 @@ class QuantizationTransformPass(object):
shape=[1])
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.ones([1], dtype=data_type))
_init_var_node(
scale_in_node,
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
accum_in_node = graph.create_persistable_node(
name=unique_name.generate('accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
self._init_var_node(accum_in_node, np.ones([1], dtype=data_type))
_init_var_node(
accum_in_node,
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
state_out_node = graph.create_var_node_from_desc(state_in_node.var(
))
accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
......@@ -482,16 +518,6 @@ class QuantizationTransformPass(object):
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node
def _init_var_node(self, var_node, value):
assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
def _quantized_var_name(self, var_name):
"""
Return quantized variable name for the input `var_name`.
......@@ -594,8 +620,8 @@ class QuantizationFreezePass(object):
self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v)
else:
scale_v = self._to_node(op_node.outputs,
op_node.output('OutScale')[0])
scale_v = graph._find_node_by_name(
op_node.outputs, op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
ops = graph.all_op_nodes()
......@@ -627,8 +653,8 @@ class QuantizationFreezePass(object):
return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = self._to_node(op_node.outputs, op_node.output('Out')[0])
v = self._to_node(op_node.inputs, op_node.input('X')[0])
k = graph._find_node_by_name(op_node.outputs, op_node.output('Out')[0])
v = graph._find_node_by_name(op_node.inputs, op_node.input('X')[0])
if v.node not in self._op_input_rename_map:
self._op_input_rename_map[k.node] = v
else:
......@@ -663,8 +689,8 @@ class QuantizationFreezePass(object):
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))
output_var_node = self._to_node(op_node.outputs,
op_node.output_arg_names()[0])
output_var_node = graph._find_node_by_name(
op_node.outputs, op_node.output_arg_names()[0])
weight_scale_node = graph.create_persistable_node(
name=unique_name.generate('channel_scale'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
......@@ -672,7 +698,9 @@ class QuantizationFreezePass(object):
var_dtype=output_var_node.dtype())
data_type = 'float64' if output_var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(weight_scale_node, channel_scale.astype(data_type))
_init_var_node(weight_scale_node,
channel_scale.astype(data_type), self._scope,
self._place)
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(),
......@@ -724,8 +752,8 @@ class QuantizationFreezePass(object):
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))
output_var_node = self._to_node(op_node.outputs,
op_node.output_arg_names()[0])
output_var_node = graph._find_node_by_name(
op_node.outputs, op_node.output_arg_names()[0])
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(),
......@@ -746,24 +774,6 @@ class QuantizationFreezePass(object):
self._op_output_rename_map[output_var_node.node] = dequant_var_node
return dequant_var_node
def _init_var_node(self, var_node, value):
assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
def _to_node(self, nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor())
......
......@@ -45,13 +45,14 @@ class QuantizationStrategy(Strategy):
activation_bits=8,
weight_bits=8,
activation_quantize_type='abs_max',
weight_quantize_type='abs_max',
save_in_nodes=None,
save_out_nodes=None):
"""
Args:
start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0
float_model_save_path(str): The path to save model with float weights.
float_model_save_path(str): The path to save model with float weights.
None means it doesn't save float model. defalut: None.
mobile_model_save_path(str): The path to save model for paddle-mobile execution.
None means it doesn't save mobile model. defalut: None.
......@@ -66,9 +67,11 @@ class QuantizationStrategy(Strategy):
dynamically each step in both training and testing period. If use
'range_abs_max', a static quantization scale will be calculated
during training and used in inference.
save_in_nodes(list<str>): A list of variable names used to prune graph
weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained.
save_in_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
save_out_nodes(list<str>): A list of variable names used to prune graph
save_out_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
"""
......@@ -81,6 +84,7 @@ class QuantizationStrategy(Strategy):
self.activation_bits = activation_bits
self.weight_bits = weight_bits
self.activation_quantize_type = activation_quantize_type
self.weight_quantize_type = weight_quantize_type
self.save_out_nodes = save_out_nodes
self.save_in_nodes = save_in_nodes
......@@ -100,7 +104,8 @@ class QuantizationStrategy(Strategy):
place=context.place,
weight_bits=self.weight_bits,
activation_bits=self.activation_bits,
activation_quantize_type=self.activation_quantize_type)
activation_quantize_type=self.activation_quantize_type,
weight_quantize_type=self.weight_quantize_type)
transform_pass.apply(train_ir_graph)
transform_pass.apply(test_ir_graph)
......@@ -134,7 +139,8 @@ class QuantizationStrategy(Strategy):
scope=context.scope,
place=context.place,
weight_bits=self.weight_bits,
activation_bits=self.activation_bits)
activation_bits=self.activation_bits,
weight_quantize_type=self.weight_quantize_type)
freeze_pass.apply(test_ir_graph)
# for other strategies
......
......@@ -35,6 +35,8 @@ strategies:
start_epoch: 0
end_epoch: 0
float_model_save_path: './output/float'
mobile_model_save_path: './output/mobile'
int8_model_save_path: './output/int8'
weight_bits: 8
activation_bits: 8
weight_quantize_type: 'abs_max'
......
......@@ -256,8 +256,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quant_type)
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=activation_quant_type)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
......@@ -315,7 +313,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(
scope=scope, place=place, weight_quantize_type=weight_quant_type)
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
if not for_ci:
marked_nodes = set()
......
......@@ -2347,40 +2347,6 @@ class IrGraph(object):
"""
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
def _find_var_node(self, key):
"""
Get a variable node by the `key` from this graph. The key
can be a node name or a node id.
WARNS:
There are some nodes may have the same name. So, be
cautious about using this method when you find the
target var node by its name.
Args:
key(str|int): The str type denotes that the target variable node's name.
And the int type denotes that the target variable node's id.
Raises:
ValueError: If this graph doesn't have a variable with the giving name or id.
Returns:
IrVarNode: the variable node with the giving name or id.
"""
target_var_node = None
var_nodes = self.all_var_nodes()
if isinstance(key, six.string_types):
for var_node in var_nodes:
if var_node.name() == key:
target_var_node = var_node
elif isinstance(key, int):
for var_node in var_nodes:
if var_node.id() == key:
target_var_node = var_node
if target_var_node is None:
raise ValueError("var_node %s not in this graph" % key)
return target_var_node
def create_persistable_node(self, name, var_type, shape, var_dtype):
"""
Create a persistable variable node in the graph. In IrGraph,
......@@ -2525,14 +2491,6 @@ class IrGraph(object):
core.graph_safe_remove_nodes(self.graph, original_nodes)
def resolve_hazard(self):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = core.topology_sort(self.graph)
var_nodes = dict()
for node in ordered_nodes:
......@@ -2540,16 +2498,17 @@ class IrGraph(object):
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
self._find_node_by_name(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
self._find_node_by_name(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
self._find_node_by_name(node.outputs,
each_var_name))
self.graph.resolve_hazard(var_nodes)
def has_circle(self):
......@@ -2662,6 +2621,17 @@ class IrGraph(object):
program = Program._construct_from_desc(desc)
return program
def _find_node_by_name(self, nodes, node_name):
"""
Find a node in the giving nodes set by the name.
"""
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
def _update_desc_attr(self, desc, name, val):
"""
Update the value of desc's attribute by attribute's name.
......
......@@ -43,6 +43,7 @@ class TestParallelExecutorBase(unittest.TestCase):
use_ir_memory_optimize=True,
enable_inplace=True,
fuse_elewise_add_act_ops=False,
fuse_all_optimizer_ops=False,
fuse_all_reduce_ops=False,
fuse_relu_depthwise_conv=False,
optimizer=fluid.optimizer.Adam,
......@@ -81,6 +82,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.memory_optimize = False if memory_opt else use_ir_memory_optimize
build_strategy.fuse_all_optimizer_ops = fuse_all_optimizer_ops
build_strategy.fuse_all_reduce_ops = fuse_all_reduce_ops
# python memory optimization is conflict with inplace pass.
# Use ir graph memory optimization after inplace pass is the correct way.
......
......@@ -16,8 +16,10 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
alignment = 256
class TestAllocContinuousSpace(OpTest):
......@@ -29,11 +31,11 @@ class TestAllocContinuousSpace(OpTest):
self.constant = attrs["constant"]
self.set_constant = attrs["set_constant"]
self.Inputs = self.init_input()
self.FusedOutput = self.init_output(self.Inputs, self.set_constant,
self.constant)
self.Outputs, self.FusedOutput = self.init_output(
self.Inputs, self.set_constant, self.constant)
self.inputs = {'Input': self.Inputs}
self.attrs = attrs
self.outputs = {'Output': self.Inputs, 'FusedOutput': self.FusedOutput}
self.outputs = {'Output': self.Outputs, 'FusedOutput': self.FusedOutput}
def init_dtype(self):
self.dtype = np.float32
......@@ -52,14 +54,31 @@ class TestAllocContinuousSpace(OpTest):
return {"copy_data": True, "set_constant": False, "constant": 0.0}
def init_output(self, input_list, set_constant, constant):
inputs = [input[1].flatten() for input in input_list]
output = np.concatenate(inputs)
inputs = []
outputs = input_list
for input in input_list:
length = len(input[1].flatten())
aligned_len = (length + alignment) / alignment * alignment
out = np.zeros(int(aligned_len))
out[0:length] = input[1].flatten()
inputs.append(out)
alloc_continuous_space_var = np.concatenate([input for input in inputs])
if set_constant:
output = np.ones((len(output))) * constant
return output
alloc_continuous_space_var = np.ones(
(len(alloc_continuous_space_var))) * constant
outputs = [(out[0],
np.ones(out[1].shape).astype(self.dtype) * constant)
for out in outputs]
return outputs, alloc_continuous_space_var
def test_check_output(self):
self.check_output()
if core.is_compiled_with_cuda():
self.check_output_with_place(
place=core.CUDAPlace(0),
no_check_set=["FusedOutput"],
atol=1e-5)
class TestAllocContinuousSpace2(TestAllocContinuousSpace):
......@@ -67,7 +86,11 @@ class TestAllocContinuousSpace2(TestAllocContinuousSpace):
return {"copy_data": False, "set_constant": True, "constant": 0.5}
def test_check_output(self):
self.check_output(no_check_set=["Output"])
if core.is_compiled_with_cuda():
self.check_output_with_place(
place=core.CUDAPlace(0),
no_check_set=["FusedOutput"],
atol=1e-5)
if __name__ == '__main__':
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import paddle
import paddle.dataset.mnist as mnist
import unittest
import os
def simple_fc_net(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = img
for _ in range(4):
hidden = fluid.layers.fc(
hidden,
size=200,
act='relu',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
def fc_with_batchnorm(use_feed):
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = img
for _ in range(2):
hidden = fluid.layers.fc(
hidden,
size=200,
act='relu',
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
hidden = fluid.layers.batch_norm(input=hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestFuseAdamOps(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
def _init_data(self, random=True):
np.random.seed(5)
if random:
img = np.random.random(size=[32, 784]).astype(np.float32)
else:
img = np.ones(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
return img, label
def _compare_fused_optimizer_ops(self,
model,
use_cuda,
random_data=True,
optimizer=fluid.optimizer.Adam):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = self._init_data(random_data)
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_all_optimizer_ops=False,
memory_opt=False, # avoid the gradient's name changed in Python side.
optimizer=optimizer)
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_all_optimizer_ops=True,
memory_opt=False, # avoid the gradient's name changed in Python side.
optimizer=optimizer)
for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
def test_simple_fc_with_fuse_op(self):
self._compare_fused_optimizer_ops(simple_fc_net, True)
self._compare_fused_optimizer_ops(simple_fc_net, False)
def test_batchnorm_fc_with_fuse_op(self):
self._compare_fused_optimizer_ops(fc_with_batchnorm, True)
# self._compare_fused_optimizer_ops(fc_with_batchnorm, False)
class TestFuseSGDOps(TestFuseAdamOps):
def sgd_optimizer(self, learning_rate=1e-4):
return fluid.optimizer.SGD(learning_rate=learning_rate)
def test_simple_fc_with_fuse_op(self):
self._compare_fused_optimizer_ops(
simple_fc_net, True, optimizer=self.sgd_optimizer)
self._compare_fused_optimizer_ops(
simple_fc_net, False, optimizer=self.sgd_optimizer)
def test_batchnorm_fc_with_fuse_op(self):
self._compare_fused_optimizer_ops(
fc_with_batchnorm, True, optimizer=self.sgd_optimizer)
self._compare_fused_optimizer_ops(
fc_with_batchnorm, False, optimizer=self.sgd_optimizer)
if __name__ == '__main__':
unittest.main()
......@@ -61,6 +61,11 @@ def db_lstm(word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark,
param_attr=fluid.ParamAttr(
name=embedding_name, trainable=False)) for x in word_input
]
# TODO(zcd): if the parameter is not trainable, the
# parameter's gradient should not generated.
for emb_layer in emb_layers:
emb_layer.stop_gradient = True
emb_layers.append(predicate_embedding)
emb_layers.append(mark_embedding)
......@@ -113,60 +118,62 @@ class TestCRFModel(unittest.TestCase):
os.environ['CPU_NUM'] = str(4)
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
word = fluid.layers.data(
name='word_data', shape=[1], dtype='int64', lod_level=1)
predicate = fluid.layers.data(
name='verb_data', shape=[1], dtype='int64', lod_level=1)
ctx_n2 = fluid.layers.data(
name='ctx_n2_data', shape=[1], dtype='int64', lod_level=1)
ctx_n1 = fluid.layers.data(
name='ctx_n1_data', shape=[1], dtype='int64', lod_level=1)
ctx_0 = fluid.layers.data(
name='ctx_0_data', shape=[1], dtype='int64', lod_level=1)
ctx_p1 = fluid.layers.data(
name='ctx_p1_data', shape=[1], dtype='int64', lod_level=1)
ctx_p2 = fluid.layers.data(
name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1)
mark = fluid.layers.data(
name='mark_data', shape=[1], dtype='int64', lod_level=1)
feature_out = db_lstm(**locals())
target = fluid.layers.data(
name='target', shape=[1], dtype='int64', lod_level=1)
crf_cost = fluid.layers.linear_chain_crf(
input=feature_out,
label=target,
param_attr=fluid.ParamAttr(
name='crfw', learning_rate=1e-1))
avg_cost = fluid.layers.mean(crf_cost)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.exponential_decay(
learning_rate=0.01,
decay_steps=100000,
decay_rate=0.5,
staircase=True))
sgd_optimizer.minimize(avg_cost)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.conll05.test(), buf_size=8192),
batch_size=16)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
train_cp = compiler.CompiledProgram(main).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
feeder = fluid.DataFeeder(
feed_list=[
word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate,
mark, target
],
place=fluid.CPUPlace())
scope = fluid.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(main, startup):
word = fluid.layers.data(
name='word_data', shape=[1], dtype='int64', lod_level=1)
predicate = fluid.layers.data(
name='verb_data', shape=[1], dtype='int64', lod_level=1)
ctx_n2 = fluid.layers.data(
name='ctx_n2_data', shape=[1], dtype='int64', lod_level=1)
ctx_n1 = fluid.layers.data(
name='ctx_n1_data', shape=[1], dtype='int64', lod_level=1)
ctx_0 = fluid.layers.data(
name='ctx_0_data', shape=[1], dtype='int64', lod_level=1)
ctx_p1 = fluid.layers.data(
name='ctx_p1_data', shape=[1], dtype='int64', lod_level=1)
ctx_p2 = fluid.layers.data(
name='ctx_p2_data', shape=[1], dtype='int64', lod_level=1)
mark = fluid.layers.data(
name='mark_data', shape=[1], dtype='int64', lod_level=1)
feature_out = db_lstm(**locals())
target = fluid.layers.data(
name='target', shape=[1], dtype='int64', lod_level=1)
crf_cost = fluid.layers.linear_chain_crf(
input=feature_out,
label=target,
param_attr=fluid.ParamAttr(
name='crfw', learning_rate=1e-1))
avg_cost = fluid.layers.mean(crf_cost)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.exponential_decay(
learning_rate=0.01,
decay_steps=100000,
decay_rate=0.5,
staircase=True))
sgd_optimizer.minimize(avg_cost)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.conll05.test(), buf_size=8192),
batch_size=16)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
train_cp = compiler.CompiledProgram(main).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
feeder = fluid.DataFeeder(
feed_list=[
word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate,
mark, target
],
place=fluid.CPUPlace())
data = train_data()
for i in range(10):
......
......@@ -41,14 +41,15 @@ class TestBase(unittest.TestCase):
fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace())
exe.run(startup_prog)
for _ in six.moves.xrange(iter):
exe_strategy = fluid.ExecutionStrategy()
exe_strategy._dry_run = True
exe_strategy.use_experimental_executor = use_experimental_executor
train_cp = compiler.CompiledProgram(main_prog).with_data_parallel(
loss_name=loss.name, exec_strategy=exe_strategy)
for _ in six.moves.xrange(iter_per_pe):
exe.run(train_cp)
exe_strategy = fluid.ExecutionStrategy()
exe_strategy._dry_run = True
exe_strategy.use_experimental_executor = use_experimental_executor
train_cp = compiler.CompiledProgram(
main_prog).with_data_parallel(
loss_name=loss.name, exec_strategy=exe_strategy)
for _ in six.moves.xrange(iter):
for _ in six.moves.xrange(iter_per_pe):
exe.run(train_cp)
class TestMNISTDryRun(TestBase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册