Unverified commit 8c4573a3, authored by H hong, committed by GitHub

GradMaker for dygraph (#19706)

* refactor dygraph,test=develop

* fix failed unittest,test=develop

* polish code,test=develop

* check windows ci error,test=develop
try to fix windows ci error by np.allclose,test=develop

* polish vlog and profiler, test=develop

* try to fix preceding ops order,test=develop

* test transformer in windows ci, test=develop

* use python c-api to speed up tracer.trace,test=develop

* test=develop, fix docker with paddle nccl problem

* test=develop, add ut for debug string and gradient_accumulator

* test=develop, add tests for layer/gradient_accumulator/prepared_op

* test=develop, fix compile error for test_prepared_op

* test=develop, add more ut for dygraph

* test=develop, create API.spec for dygraph api change

* optimize grad maker; test=develop

* optimize grad maker

* test

* grad maker optim; test=develop

* fix unittest bugs; test=develop

* add dygraph grad op maker and split_op

* grad op maker refactor; test=develop

* add dygraph grad maker; test=develop

* fix deformable_conv_v1_op bug; test=develop

* fix deformable_conv prroi pool bugs;

* fix new op grad op maker bug; test=develop

* fix split by ref bug; test=develop

* fix dygraph auto prune bug; test=develop

* fix test_trace bug; test=develop

* fix fused emb seq pool bug; test=develop

* remove useless code in op_desc file; test=develop

* remove useless code, StrVarBaseNode; test=develop

* fix review issues; test=develop

* fix rank_loss grad maker; test=develop

* remove flag in VarBase; test=develop

* fix distributed_notify_op compile bug ; test=develop

* fix reshape op double grad; test=develop

* fix expand as op; test=develop

* add imperative type_defs.h for demo_train; test=develop

* fix inference lib cmake; test=develop

* fix inference lib; test=develop

* fix inference_lib; test=develop

* fix inference cmake; test=develop

* fix inference lib; test=develop

* fix inference lib; test=develop

* remove condition dygraph grad maker, modify local name; test=develop

* fix split grad maker bug; test=develop

* fix pyramid_op bug; test=develop

* change travis time out limit; test=develop

* restore travis; test=develop

* change timeout limit; test=develop
Parent b7417610
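The heart of this change: grad-op makers become templates over the op container type, so one maker definition can emit either static-graph framework::OpDesc grad ops or dygraph imperative::OpBase grad ops. A minimal sketch of the resulting pattern (the op name my_relu and the MyReluGradMaker class are illustrative placeholders, not part of this diff):

// Illustrative sketch only: one grad maker written against the templated
// base introduced by this PR; T is framework::OpDesc (static graph) or
// imperative::OpBase (dygraph).
template <typename T>
class MyReluGradMaker : public paddle::framework::SingleGradOpMaker<T> {
 public:
  using paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  std::unique_ptr<T> Apply() const override {
    auto* grad = new T();  // both container types share this protocol
    grad->SetType("my_relu_grad");
    grad->SetInput("Out", this->Output("Out"));
    grad->SetInput(paddle::framework::GradVarName("Out"),
                   this->OutputGrad("Out"));
    grad->SetOutput(paddle::framework::GradVarName("X"),
                    this->InputGrad("X"));
    grad->SetAttrMap(this->Attrs());
    return std::unique_ptr<T>(grad);
  }
};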
@@ -20,6 +20,7 @@ before_install:
   - |
     function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
 script:
+  - "travis_wait 30 sleep 1800 &"
   - |
     # 43min timeout
     paddle/scripts/paddle_docker_build.sh ${JOB}
......
@@ -191,6 +191,13 @@ copy(fluid_lib_dist
         ${src_dir}/${module}/ir/*.h ${src_dir}/${module}/fleet/*.h
     DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}/ir/memory_optimize_pass ${dst_dir}/${module}/ir ${dst_dir}/${module}/fleet)

+set(module "imperative")
+copy(fluid_lib_dist
+    SRCS ${src_dir}/${module}/type_defs.h ${src_dir}/${module}/dygraph_grad_maker.h ${src_dir}/${module}/layer.h ${src_dir}/${module}/flags.h
+    DSTS ${dst_dir}/${module}/ ${dst_dir}/${module}/ ${dst_dir}/${module}/ ${dst_dir}/${module}/
+)
+
 set(module "operators")
 copy(fluid_lib_dist
     SRCS ${src_dir}/${module}/reader/blocking_queue.h
......
@@ -27,6 +27,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_type_inference.h"
+#include "paddle/fluid/imperative/dygraph_grad_maker.h"
+#include "paddle/fluid/imperative/type_defs.h"

 namespace paddle {
 namespace framework {
@@ -40,6 +42,7 @@ enum OpInfoFillType {
   kShapeInference = 4,
   kInplaceOpInference = 5,
   kNoNeedBufferVarsInference = 6,
+  kGradOpBaseMaker = 7,
   kUnknown = -1
 };
@@ -54,6 +57,7 @@ using OpRegistryClasses = std::tuple<                            // NOLINT
     TypePair<OperatorBase, kOperator>,                            // NOLINT
     TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>,    // NOLINT
     TypePair<GradOpDescMakerBase, kGradOpDescMaker>,              // NOLINT
+    TypePair<imperative::GradOpBaseMakerBase, kGradOpBaseMaker>,  // NOLINT
     TypePair<VarTypeInference, kVarTypeInference>,                // NOLINT
     TypePair<InferShapeBase, kShapeInference>,                    // NOLINT
     TypePair<InplaceOpInference, kInplaceOpInference>,            // NOLINT
@@ -186,8 +190,21 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
     };
     info->use_default_grad_op_desc_maker_ =
-        std::is_base_of<DefaultGradOpDescMaker<true>, T>::value ||
-        std::is_base_of<DefaultGradOpDescMaker<false>, T>::value;
+        std::is_base_of<DefaultGradOpMaker<OpDesc, true>, T>::value ||
+        std::is_base_of<DefaultGradOpMaker<OpDesc, false>, T>::value;
   }
 };

+template <typename T>
+struct OpInfoFiller<T, kGradOpBaseMaker> {
+  void operator()(const char* op_type, OpInfo* info) const {
+    info->dygraph_grad_op_maker_ = [](
+        const imperative::OpBase* fw_op_base,
+        const imperative::NameVarBaseMap& var_base_map_in,
+        const imperative::NameVarBaseMap& var_base_map_out) {
+      T maker(fw_op_base, var_base_map_in, var_base_map_out);
+      return maker();
+    };
+  }
+};
......
@@ -21,6 +21,9 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/imperative/dygraph_grad_maker.h"
+#include "paddle/fluid/imperative/layer.h"
+#include "paddle/fluid/imperative/type_defs.h"

 namespace paddle {
 namespace framework {
@@ -97,6 +100,8 @@ class GradOpDescMakerBase {
     return ret_val;
   }

+  std::vector<std::string> Empty() const { return {}; }
+
   std::vector<std::string> InputNames() const {
     return this->fwd_op_.InputNames();
   }
@@ -132,7 +137,9 @@ class GradOpDescMakerBase {
   std::string ForwardOpType() const { return this->fwd_op_.Type(); }

  protected:
-  const OpDesc& ForwardOp() const { return fwd_op_; }
+  bool HasInput(const std::string& name) const {
+    return (fwd_op_.Inputs().count(name) > 0);
+  }

  private:
   const OpDesc& fwd_op_;
@@ -143,11 +150,24 @@ class GradOpDescMakerBase {
   std::vector<BlockDesc*> grad_block_;
 };

-class SingleGradOpDescMaker : public GradOpDescMakerBase {
+template <typename T>
+class SingleGradOpMaker {
+ public:
+  std::vector<std::unique_ptr<T>> operator()() const {
+    PADDLE_ENFORCE(false, "should not call this function");
+    return {};
+  }
+
+ protected:
+  virtual std::unique_ptr<T> Apply() const = 0;
+};
+
+template <>
+class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
  public:
   using GradOpDescMakerBase::GradOpDescMakerBase;

-  std::vector<std::unique_ptr<OpDesc>> operator()() const final {
+  std::vector<std::unique_ptr<OpDesc>> operator()() const {
     std::vector<std::unique_ptr<OpDesc>> retv;
     retv.emplace_back(this->Apply());
     return retv;
@@ -157,14 +177,32 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
   virtual std::unique_ptr<OpDesc> Apply() const = 0;
 };

-template <bool DropEmptyIG = true>
-class DefaultGradOpDescMaker final : public SingleGradOpDescMaker {
+template <>
+class SingleGradOpMaker<imperative::OpBase>
+    : public imperative::GradOpBaseMakerBase {
+ public:
+  using GradOpBaseMakerBase::GradOpBaseMakerBase;
+
  public:
-  using SingleGradOpDescMaker::SingleGradOpDescMaker;
+  std::vector<std::unique_ptr<imperative::OpBase>> operator()() const {
+    std::vector<std::unique_ptr<imperative::OpBase>> retv;
+    retv.emplace_back(this->Apply());
+    return retv;
+  }

  protected:
-  std::unique_ptr<OpDesc> Apply() const final {
-    auto* grad = new OpDesc();
+  virtual std::unique_ptr<imperative::OpBase> Apply() const = 0;
+};
+
+template <typename T, bool DropEmptyIG = true>
+class DefaultGradOpMaker final : public SingleGradOpMaker<T> {
+ public:
+  using SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const final {
+    auto* grad = new T();
     grad->SetType(this->ForwardOpType() + "_grad");

     for (auto& input_param : this->InputNames()) {
@@ -180,15 +218,35 @@ class DefaultGradOpDescMaker final : public SingleGradOpDescMaker {
     grad->SetAttrMap(this->Attrs());

-    return std::unique_ptr<OpDesc>(grad);
+    return std::unique_ptr<T>(grad);
   }
 };

-class EmptyGradOpMaker final : public GradOpDescMakerBase {
+template <typename T>
+class EmptyGradOpMaker {
+ public:
+  virtual std::vector<std::unique_ptr<T>> operator()()
+      const final { /* NOLINT */
+    return {};
+  }
+};
+
+template <>
+class EmptyGradOpMaker<OpDesc> final : public GradOpDescMakerBase {
  public:
   using GradOpDescMakerBase::GradOpDescMakerBase;
   std::vector<std::unique_ptr<OpDesc>> operator()() const final { return {}; }
 };

+template <>
+class EmptyGradOpMaker<imperative::OpBase> final
+    : public imperative::GradOpBaseMakerBase {
+ public:
+  using GradOpBaseMakerBase::GradOpBaseMakerBase;
+  std::vector<std::unique_ptr<imperative::OpBase>> operator()() const final {
+    return {};
+  }
+};
+
 }  // namespace framework
 }  // namespace paddle
@@ -18,6 +18,7 @@
 #include "paddle/fluid/platform/place.h"

 #include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/imperative/type_defs.h"

 namespace paddle {
 namespace framework {
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
 #include <gtest/gtest.h>
 #include "paddle/fluid/framework/naive_executor.h"
+#include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/platform/place.h"

 namespace paddle {
@@ -29,6 +30,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
   op->SetType(type);
   op->SetAttr("use_mkldnn", use_mkldnn);
   op->SetAttr("name", name);
+
   if (type == "conv2d") {
     op->SetInput("Input", {inputs[0]});
     op->SetInput("Filter", {inputs[1]});
......
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>

 #include "paddle/fluid/framework/attribute.h"
 #include "paddle/fluid/framework/type_defs.h"
......
@@ -42,6 +42,7 @@ struct OpInfo {
   InferShapeFN infer_shape_;
   InferInplaceOpFN infer_inplace_;
   InferNoNeedBufferVarsFN infer_no_need_buffer_vars_;
+  DygraphGradOpMakerFN dygraph_grad_op_maker_;

   // NOTE(zjl): this flag is added to check whether
   // the grad maker is the default one.
@@ -81,6 +82,24 @@ struct OpInfo {
   // some op has no grad_op_maker, add check before use GradOpMaker()
   bool HasGradOpMaker() const { return grad_op_maker_ != nullptr; }

+  const DygraphGradOpMakerFN& DygraphGradOpMaker() const {
+    // Normally, proto_ should not be null, except for some special operators,
+    // such as the LeakyReluDoubleGrad op.
+    std::string type = proto_ ? proto_->type() : "unknown";
+    PADDLE_ENFORCE_NOT_NULL(
+        dygraph_grad_op_maker_,
+        "Operator %s's DygraphGradOpMaker has not been "
+        "registered.\nPlease check whether %s_op has "
+        "grad_op.\nIf not, please set stop_gradient to True "
+        "for its input and output variables using var.stop_gradient=True.",
+        type.c_str(), type.c_str());
+    return dygraph_grad_op_maker_;
+  }
+
+  bool HasDygraphGradOpMaker() const {
+    return dygraph_grad_op_maker_ != nullptr;
+  }
+
   bool HasInferInplace() const { return infer_inplace_ != nullptr; }

   const OpAttrChecker* Checker() const { return checker_; }
......
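Roughly how a registered dygraph maker is looked up and invoked at trace time (a hedged sketch following the accessors above and CreateGradOpBases in tracer.cc further down; fwd_op, ins, and outs are assumed to be in scope):

// Sketch, not a verbatim excerpt: consuming OpInfo::dygraph_grad_op_maker_.
const paddle::framework::OpInfo& info = fwd_op->Info();
if (info.HasDygraphGradOpMaker()) {
  std::vector<std::unique_ptr<paddle::imperative::OpBase>> grad_ops =
      info.DygraphGradOpMaker()(fwd_op.get(), ins, outs);
}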
@@ -209,7 +209,8 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
 #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
   REGISTER_OPERATOR(op_type, op_class, op_maker_class,                  \
-                    paddle::framework::EmptyGradOpMaker)
+                    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
+                    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)

 /**
  * Macro to register OperatorKernel.
......
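With the macro change above, registration call sites now supply a grad maker per container type. An illustrative expansion, reusing the hypothetical my_relu placeholders from the sketch near the top (ops::MyReluOp and ops::MyReluOpMaker are likewise assumptions, not names from this diff):

// Illustrative: register one op with both grad-maker instantiations.
REGISTER_OPERATOR(my_relu, ops::MyReluOp, ops::MyReluOpMaker,
                  MyReluGradMaker<paddle::framework::OpDesc>,
                  MyReluGradMaker<paddle::imperative::OpBase>);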
@@ -1171,7 +1171,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
   proto::VarType::Type dafault_data_type =
       static_cast<proto::VarType::Type>(-1);
   proto::VarType::Type data_type = dafault_data_type;
-  for (auto& input : this->inputs_) {
+  for (auto& input : ctx.Context().inputs) {
     ParseInputDataType(ctx, input.first, &data_type);
   }
   PADDLE_ENFORCE_NE(data_type, dafault_data_type,
......
@@ -385,6 +385,8 @@ class ExecutionContext {
     return *boost::get<std::shared_ptr<T>>((*kernel_configs_)[idx]);
   }

+  const RuntimeContext& Context() const { return ctx_; }
+
  private:
   const OperatorBase& op_;
   const Scope& scope_;
......
@@ -20,6 +20,7 @@ limitations under the License. */
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
+#include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/platform/variant.h"

 namespace paddle {
@@ -54,6 +55,12 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
     std::unordered_map<std::string, std::string>* /*grad_to_var*/,
     const std::vector<BlockDesc*>& grad_block)>;

+using DygraphGradOpMakerFN =
+    std::function<std::vector<std::unique_ptr<imperative::OpBase>>(
+        const imperative::OpBase* fw_op_base,
+        const imperative::NameVarBaseMap& var_base_map_in,
+        const imperative::NameVarBaseMap& var_base_map_out)>;
+
 using InferVarTypeFN =
     std::function<void(framework::InferVarTypeContext* /*context*/)>;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace imperative {
class GradOpBaseMakerBase {
public:
explicit GradOpBaseMakerBase(const OpBase* fw_op_base,
const NameVarBaseMap& var_base_map_in,
const NameVarBaseMap& var_base_map_out)
: fw_op_base_(fw_op_base),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out) {}
virtual ~GradOpBaseMakerBase() = default;
virtual std::vector<std::unique_ptr<OpBase>> operator()() const = 0;
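  // NOTE: drop_empty_grad apparently mirrors the static-graph
  // GradOpDescMakerBase::InputGrad signature; this dygraph implementation
  // does not consult it.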
std::vector<std::shared_ptr<VarBase>> InputGrad(
const std::string& name, bool drop_empty_grad = true) const {
return GetVarBaseList(name, true, true);
}
std::vector<std::shared_ptr<VarBase>> OutputGrad(
const std::string& name) const {
return GetVarBaseList(name, true, false);
}
std::vector<std::shared_ptr<VarBase>> Input(const std::string& name) const {
return GetVarBaseList(name, false, true);
}
std::vector<std::shared_ptr<VarBase>> Output(const std::string& name) const {
return GetVarBaseList(name, false, false);
}
std::vector<std::shared_ptr<VarBase>> Empty() const { return {}; }
std::vector<std::string> InputNames() const {
std::vector<std::string> vec_temp;
vec_temp.reserve(var_base_map_in_.size());
for (auto& it : var_base_map_in_) {
vec_temp.emplace_back(it.first);
}
return vec_temp;
}
std::vector<std::string> OutputNames() const {
std::vector<std::string> vec_temp;
vec_temp.reserve(var_base_map_out_.size());
for (auto& it : var_base_map_out_) {
vec_temp.emplace_back(it.first);
}
return vec_temp;
}
const std::unordered_map<std::string, framework::Attribute>& Attrs() const {
return fw_op_base_->Attrs();
}
const framework::Attribute& GetAttr(const std::string& name) const {
auto& map = fw_op_base_->Attrs();
auto it = map.find(name);
PADDLE_ENFORCE(it != map.end(),
"Cannot find attribute [%s] in operator [%s]", name,
fw_op_base_->Type());
return it->second;
}
template <typename T>
inline const T& Attr(const std::string& name) const {
return boost::get<T>(GetAttr(name));
}
std::string ForwardOpType() const { return fw_op_base_->Type(); }
protected:
bool HasInput(const std::string& name) const {
auto it = var_base_map_in_.find(name);
return it != var_base_map_in_.end();
}
private:
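  // Shared helper behind Input/Output/InputGrad/OutputGrad: looks up `name`
  // in the input or output map and returns either the forward VarBases or
  // their grad VarBases (resized to the forward tensor's dims).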
std::vector<std::shared_ptr<VarBase>> GetVarBaseList(const std::string& name,
bool is_grad,
bool is_input) const {
const NameVarBaseMap& data_map =
is_input ? var_base_map_in_ : var_base_map_out_;
auto iterator = data_map.find(name);
std::vector<std::shared_ptr<imperative::VarBase>> vec_temp;
if (iterator != data_map.end()) {
vec_temp.reserve(iterator->second.size());
for (auto& var_base_temp : iterator->second) {
if (is_grad) {
PADDLE_ENFORCE_NOT_NULL(var_base_temp->GradVarBase(),
"VarBase grad of OP [%s] should not be null",
fw_op_base_->Type());
auto grad_var_base_tmp = var_base_temp->GradVarBase();
auto* tensor = grad_var_base_tmp->MutableVar()
->GetMutable<framework::LoDTensor>();
tensor->Resize(
var_base_temp->Var().Get<framework::LoDTensor>().dims());
vec_temp.emplace_back(grad_var_base_tmp);
} else {
vec_temp.emplace_back(var_base_temp);
}
}
}
return vec_temp;
}
private:
const OpBase* fw_op_base_;
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
protected:
std::vector<framework::BlockDesc*> grad_block_;
};
} // namespace imperative
} // namespace paddle
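For ops that need full control (several grad ops, conditional wiring), a maker can derive from GradOpBaseMakerBase directly rather than going through SingleGradOpMaker. A hedged sketch under that assumption; MySplitGradMaker and the op names are hypothetical:

// Sketch: a direct GradOpBaseMakerBase subclass returning one grad op;
// real makers may emit several OpBase instances from operator().
class MySplitGradMaker : public paddle::imperative::GradOpBaseMakerBase {
 public:
  using GradOpBaseMakerBase::GradOpBaseMakerBase;

  std::vector<std::unique_ptr<paddle::imperative::OpBase>> operator()()
      const override {
    auto* grad_op = new paddle::imperative::OpBase();
    grad_op->SetType("my_split_grad");
    grad_op->SetInput(paddle::framework::GradVarName("Out"),
                      this->OutputGrad("Out"));
    grad_op->SetOutput(paddle::framework::GradVarName("X"),
                       this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
    std::vector<std::unique_ptr<paddle::imperative::OpBase>> ret;
    ret.emplace_back(grad_op);
    return ret;
  }
};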
@@ -173,6 +173,7 @@ void BasicEngine::PrepareDeps() {
 void BasicEngine::SumGradient(OpBase* op, std::shared_ptr<VarBase> src,
                               VarBase* dst) {
   auto iter = accumulators_.find(dst);
+
   PADDLE_ENFORCE_EQ(iter != accumulators_.end(), true,
                     "Cannot find gradient of variable %s", dst->Name());
   iter->second->Add(std::move(src), op->id());
@@ -195,16 +196,16 @@ void BasicEngine::Execute() {
   NameVarBaseMap tmp_outs;
   // A var may correspond to several grad vars in one op
   std::unordered_map<VarBase*, std::vector<std::shared_ptr<VarBase>>> var_map;
-  size_t counter = 0;
+
   for (auto& bwd_out : bwd_outs) {
     auto& tmp_var_list = tmp_outs[bwd_out.first];
     tmp_var_list.reserve(bwd_out.second.size());
     for (auto& var : bwd_out.second) {
-      auto tmp_var = std::make_shared<VarBase>(
-          false, "Gtmp@" + std::to_string(counter++));  // Do not need grad
+      auto tmp_var =
+          std::make_shared<VarBase>(false, "Gtmp@");  // Do not need grad
       tmp_var_list.emplace_back(tmp_var);
       if (var) {
         var_map[var.get()].emplace_back(std::move(tmp_var));
         var->ClearGradOps();
       }
     }
   }
@@ -227,6 +228,7 @@ void BasicEngine::Execute() {
   }

   // Step 3: Collect ready ops
   for (auto* grad_pending_op : cur_op->GradPendingOps()) {
+
     PADDLE_ENFORCE_NOT_NULL(grad_pending_op);
     auto iter = op_deps_.find(grad_pending_op);
......
@@ -53,7 +53,18 @@ static framework::VariableNameMap CreateVarNameMap(
     const framework::OpInfo& op_info, const std::string& op_type,
     const NameVarBaseMap& varbase_map, bool is_input) {
   if (op_info.proto_ == nullptr) {
-    return {};
+    framework::VariableNameMap result;
+
+    for (auto& it : varbase_map) {
+      auto& var_vector = it.second;
+      std::vector<std::string> args;
+      args.reserve(var_vector.size());
+      for (auto& var_base : var_vector) {
+        args.emplace_back(var_base->Name());
+      }
+      result[it.first] = std::move(args);
+    }
+    return result;
   }

   framework::VariableNameMap result;
@@ -220,21 +231,20 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
 }

 // create OpBase from optype
 OpBase::OpBase(size_t id, const std::string& type, const NameVarBaseMap& ins,
-               const NameVarBaseMap& outs, framework::AttributeMap attrs,
+               const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
                const platform::Place& place)
-    : id_(id), place_(place) {
+    : id_(id), place_(place), attrs_(attrs) {
   const auto& info = framework::OpInfoMap::Instance().Get(type);

   // Step 1: Run forward
   if (info.Checker() != nullptr) {
-    info.Checker()->Check(&attrs);
+    info.Checker()->Check(&attrs_);
   }

   auto input_name_map = CreateVarNameMap(info, type, ins, true);
   auto output_name_map = CreateVarNameMap(info, type, outs, false);
   op_ = framework::OpRegistry::CreateOp(type, std::move(input_name_map),
-                                        std::move(output_name_map),
-                                        std::move(attrs));
+                                        std::move(output_name_map), attrs);
   VLOG(3) << "Construct Op: " << type << std::endl;
 }
@@ -245,6 +255,18 @@ OpBase::OpBase(size_t id, const framework::OpDesc& op_desc,
   VLOG(3) << "Construct Op: " << op_desc.Type() << std::endl;
 }

+void OpBase::CreateOperatorBase() {
+  const auto& info = framework::OpInfoMap::Instance().Get(type_);
+  if (info.Checker() != nullptr) {
+    info.Checker()->Check(&attrs_);
+  }
+
+  auto input_name_map = CreateVarNameMap(info, type_, ins_, true);
+  auto output_name_map = CreateVarNameMap(info, type_, outs_, false);
+  op_ = framework::OpRegistry::CreateOp(type_, std::move(input_name_map),
+                                        std::move(output_name_map), attrs_);
+}
+
 void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
   auto* op_kernel = dynamic_cast<framework::OperatorWithKernel*>(op_.get());
   PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
......
@@ -182,6 +182,7 @@ class VarBase {
   framework::Variable var_;
   std::string name_;
   std::shared_ptr<VarBase> grad_var_;
+
   mutable size_t copied_counter_ = 0;

   // grad_op indicates which grad_op will this var be used as input
@@ -271,6 +272,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
   const std::vector<std::string>& Output(
       const std::string& name) const override {
     auto iter = output_names_.find(name);
+
     PADDLE_ENFORCE_EQ(iter != output_names_.end(), true,
                       "Cannot find output %s", name);
     return iter->second;
@@ -279,6 +281,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
   framework::proto::VarType::Type GetType(
       const std::string& name) const override {
     auto iter = var_set_.find(name);
+
     PADDLE_ENFORCE_EQ(iter != var_set_.end(), true,
                       "Cannot find var %s in GetType", name);
     return iter->second->Type();
@@ -296,6 +299,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
   framework::proto::VarType::Type GetDataType(
       const std::string& name) const override {
     auto iter = var_set_.find(name);
+
     PADDLE_ENFORCE_EQ(iter != var_set_.end(), true,
                       "Cannot find var %s in GetDataType", name);
     return iter->second->DataType();
@@ -380,6 +384,10 @@ class OpBase : public std::enable_shared_from_this<OpBase> {
     return grad_pending_ops_;
   }

+  void SetGradPendingOps(std::vector<OpBase*> vec_temp) {
+    grad_pending_ops_.swap(vec_temp);
+  }
+
   void InsertGradPendingOps(OpBase* op) { grad_pending_ops_.emplace_back(op); }

   void SortGradPendingOps() {
@@ -406,12 +414,56 @@ class OpBase : public std::enable_shared_from_this<OpBase> {
  private:
   OpBase(size_t id, const std::string& type, const NameVarBaseMap& ins,
-         const NameVarBaseMap& outs, framework::AttributeMap attrs,
+         const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
          const platform::Place& place);

   OpBase(size_t id, const framework::OpDesc& op_desc,
          const platform::Place& place);

+ public:
+  OpBase() {}
+
+  void SetType(const std::string& type) { type_ = type; }
+  void SetInput(const std::string& name,
+                std::vector<std::shared_ptr<VarBase>> vec_var_base) {
+    ins_[name] = std::move(vec_var_base);
+  }
+  void SetOutput(const std::string& name,
+                 std::vector<std::shared_ptr<VarBase>> vec_var_base) {
+    outs_[name] = std::move(vec_var_base);
+  }
+  void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }
+  void SetAttr(const std::string& name, const framework::Attribute& v) {
+    attrs_[name] = v;
+  }
+  void SetBlockAttr(const std::string& name, framework::BlockDesc* block) {
+    PADDLE_THROW("SetBlockAttr is not supported in dygraph OpBase");
+  }
+
+  const framework::AttributeMap& Attrs() { return attrs_; }
+
+  void CreateOperatorBase();
+
+  void SetId(size_t id) { id_ = id; }
+  void SetPlace(platform::Place place) { place_ = place; }
+
+  bool HasAttr(const std::string& name) const {
+    return attrs_.find(name) != attrs_.end();
+  }
+
+  const framework::Attribute& GetAttr(const std::string& name) const {
+    auto it = attrs_.find(name);
+    PADDLE_ENFORCE(it != attrs_.end(), "can not find attribute [%s]", name);
+    return it->second;
+  }
+
+  template <typename T>
+  inline const T& Attr(const std::string& name) const {
+    return boost::get<T>(GetAttr(name));
+  }
+
+ private:
   size_t id_;

   std::unique_ptr<framework::OperatorBase> op_;
@@ -421,11 +473,14 @@ class OpBase : public std::enable_shared_from_this<OpBase> {
   // Need not be std::weak_ptr, because op is bound to a certain Tracer,
   // and would not be used by a Tracer that does not create itself.
   std::vector<OpBase*> grad_pending_ops_;

   // This part is only used for backward
   NameVarBaseMap ins_;
   NameVarBaseMap outs_;
+
+  std::string type_;
+  framework::AttributeMap attrs_;
 };

 }  // namespace imperative
......
@@ -37,6 +37,7 @@ void PreparedOp::PrepareData(
       const auto* tensor = GetTensorFromVar(var_base->Var());
       if (tensor && tensor->IsInitialized()) {
         auto tmp_place = tensor->place();
+
         // TODO(jiabin): Support transform data layout when we verify it on more
         // tests
         if (!(tmp_place == place)) {
......
@@ -2,4 +2,4 @@ cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context)
 cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS gradient_accumulator memcpy)
 cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy)
 cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split assign_op place)
-cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op memcpy)
+cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
@@ -74,6 +74,45 @@ TEST(test_tracer, test_trace_op) {
   }
 }

+TEST(test_tracer, test_trace_op_with_backward) {
+  // Do a mul
+  imperative::Tracer tracer;
+  std::shared_ptr<imperative::VarBase> x_in(
+      new imperative::VarBase(true, "x_in"));
+  std::shared_ptr<imperative::VarBase> y_in(
+      new imperative::VarBase(true, "y_in"));
+  std::shared_ptr<imperative::VarBase> vout(
+      new imperative::VarBase(true, "vout"));
+  platform::CPUPlace place;
+  std::vector<float> src_data(10, 2.0);
+  std::vector<int64_t> dims1 = {2, 5};
+  std::vector<int64_t> dims2 = {5, 2};
+
+  auto* x_in_tensor = x_in->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto* y_in_tensor = y_in->MutableVar()->GetMutable<framework::LoDTensor>();
+  x_in_tensor->Resize(framework::make_ddim(dims1));
+  auto* mutable_x = x_in_tensor->mutable_data<float>(place);
+  paddle::memory::Copy(place, mutable_x, place, src_data.data(),
+                       sizeof(float) * src_data.size());
+  y_in_tensor->Resize(framework::make_ddim(dims2));
+  auto* mutable_y = y_in_tensor->mutable_data<float>(place);
+  paddle::memory::Copy(place, mutable_y, place, src_data.data(),
+                       sizeof(float) * src_data.size());
+
+  var_pair x_pair = var_pair("X", vb_vector(1, x_in));
+  var_pair y_pair = var_pair("Y", vb_vector(1, y_in));
+  var_pair out_pair = var_pair("Out", vb_vector(1, vout));
+  imperative::NameVarBaseMap ins = {x_pair, y_pair};
+  imperative::NameVarBaseMap outs = {out_pair};
+  framework::AttributeMap mul_attr_map;
+  mul_attr_map["use_mkldnn"] = false;
+  tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
+  const auto& out_tensor = vout->Var().Get<framework::LoDTensor>();
+  for (size_t i = 0; i < vout->Var().Get<framework::LoDTensor>().numel();
+       i++) {
+    ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
+  }
+}
+
 TEST(test_tracer, test_track_backward_output) {
   // Do a mul
   imperative::Tracer tracer;
@@ -151,15 +190,17 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
   imperative::Tracer tracer;
   std::shared_ptr<imperative::VarBase> x_in(
       new imperative::VarBase(true, "x_in"));
+  x_in->SetOverridedStopGradient(false);  // force to run backward
   std::shared_ptr<imperative::VarBase> y_in(
       new imperative::VarBase(true, "y_in"));
+  y_in->SetOverridedStopGradient(false);
   std::shared_ptr<imperative::VarBase> vout(
       new imperative::VarBase(true, "vout"));
   platform::CPUPlace place;
   platform::CUDAPlace gpu_place(0);
   std::vector<float> src_data(10, 2.0);
   std::vector<int64_t> dims1 = {2, 5};
-  std::vector<int64_t> dims2 = {5, 2};
+  std::vector<int64_t> dims2 = {2, 5};

   auto* x_in_tensor = x_in->MutableVar()->GetMutable<framework::LoDTensor>();
   auto* y_in_tensor = y_in->MutableVar()->GetMutable<framework::LoDTensor>();
@@ -178,14 +219,54 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
   imperative::NameVarBaseMap outs = {out_pair};
   framework::AttributeMap mul_attr_map;
   mul_attr_map["use_mkldnn"] = false;
-  tracer.TraceOp("mul", ins, outs, mul_attr_map, gpu_place, true);
+  tracer.TraceOp("elementwise_add", ins, outs, mul_attr_map, gpu_place, true);
+
+  // run reduce_sum
+  std::shared_ptr<imperative::VarBase> reduce_sum_out(
+      new imperative::VarBase(true, "reduce_sum_out"));
+  var_pair reduce_sum_in_pair = var_pair("X", vb_vector(1, vout));
+  var_pair reduce_sum_out_pair = var_pair("Out", vb_vector(1, reduce_sum_out));
+  imperative::NameVarBaseMap reduce_in = {reduce_sum_in_pair};
+  imperative::NameVarBaseMap reduce_out = {reduce_sum_out_pair};
+  framework::AttributeMap reduce_attr_map;
+  tracer.TraceOp("reduce_sum", reduce_in, reduce_out, reduce_attr_map,
+                 gpu_place, true);
+
+  detail::BackwardStrategy back_st;
+  imperative::Engine* engine = tracer.GetDefaultEngine();
+  engine->Init(reduce_sum_out.get(), back_st);
+  engine->Execute();
+
   framework::LoDTensor rlt;
   framework::TensorCopySync(vout->Var().Get<framework::LoDTensor>(), place,
                             &rlt);
   for (size_t i = 0; i < rlt.numel(); i++) {
-    ASSERT_EQ(rlt.data<float>()[i], 20.0);
+    ASSERT_EQ(rlt.data<float>()[i], 4.0);
   }
+
+  framework::LoDTensor out_grad;
+  framework::TensorCopySync(vout->GradVar().Get<framework::LoDTensor>(), place,
+                            &out_grad);
+  for (size_t i = 0; i < out_grad.numel(); ++i) {
+    ASSERT_EQ(out_grad.data<float>()[i], 1.0);
+  }
+
+  framework::LoDTensor x_grad;
+  framework::TensorCopySync(x_in->GradVar().Get<framework::LoDTensor>(), place,
+                            &x_grad);
+  for (size_t i = 0; i < x_grad.numel(); ++i) {
+    ASSERT_EQ(x_grad.data<float>()[i], 1.0);
+  }
+
+  framework::LoDTensor y_grad;
+  framework::TensorCopySync(y_in->GradVar().Get<framework::LoDTensor>(), place,
+                            &y_grad);
+  for (size_t i = 0; i < y_grad.numel(); ++i) {
+    ASSERT_EQ(y_grad.data<float>()[i], 1.0);
+  }
 }
 #endif

 TEST(test_tracer, test_unique_name_generator) {
@@ -201,3 +282,6 @@ TEST(test_tracer, test_unique_name_generator) {
 }  // namespace paddle

 USE_OP(mul);
+USE_OP(reduce_sum);
+USE_OP(reduce_sum_grad);
+USE_OP(elementwise_add);
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/imperative/tracer.h"
+#include <set>
 #include <unordered_set>
 #include <utility>
 #include "paddle/fluid/platform/profiler.h"
@@ -19,14 +20,17 @@
 namespace paddle {
 namespace imperative {

-static std::vector<std::unique_ptr<framework::OpDesc>> CreateGradOpDescs(
-    const framework::OpInfo& op_info, const framework::OpDesc& op_desc,
-    const std::unordered_set<std::string>& no_grad_set,
-    const std::vector<framework::BlockDesc*>& grad_sub_block,
-    std::unordered_map<std::string, std::string>* grad_to_var) {
-  if (op_info.grad_op_maker_) {
-    return op_info.grad_op_maker_(op_desc, no_grad_set, grad_to_var,
-                                  grad_sub_block);
+struct OpBaseCmp {
+  bool operator()(OpBase* first, OpBase* second) {
+    return first->id() > second->id();
+  }
+};
+
+static std::vector<std::unique_ptr<OpBase>> CreateGradOpBases(
+    const OpBase* fw_op_base, const NameVarBaseMap& in,
+    const NameVarBaseMap& out) {
+  if (fw_op_base->Info().dygraph_grad_op_maker_) {
+    return fw_op_base->Info().dygraph_grad_op_maker_(fw_op_base, in, out);
   } else {
     return {};
   }
@@ -57,9 +61,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
   }

   if (ComputeRequiredGrad(ins, outs, trace_backward)) {
-    TraceBackward(op, framework::OpDesc(op->Type(), op->InputNameMap(),
-                                        op->OutputNameMap(), op->Attrs()),
-                  ins, outs);
+    TraceBackward(op, ins, outs);
   } else {
     VLOG(3) << "No Grad to track for Op: " << type;
   }
@@ -84,7 +86,6 @@ bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
 }

 void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
-                           const framework::OpDesc& fwd_op_desc,
                            const NameVarBaseMap& ins,
                            const NameVarBaseMap& outs) {
   // grad_to_var is a map of framework::GradVarName(in_var_name/out_var_name) ->
@@ -92,168 +93,70 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
   std::unordered_map<std::string, std::string> grad_to_var;

   // Get grad_op_desc using fwd_op_desc
-  std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs_ =
-      CreateGradOpDescs(fwd_op->Info(), fwd_op_desc, {}, {}, &grad_to_var);
-
-  // Create grad_ops using grad_op_descs
-  size_t grad_op_num = grad_op_descs_.size();
-
-  VLOG(3) << "Create " << grad_op_num << " grad op desc(s) to op "
-          << fwd_op->Type();
-
-  if (grad_op_num == 0) {
-    return;
-  }
-
-  // Build a map to record var_name -> std::shared_ptr<VarBase>*,
-  // so that we can find suitable var in grad op descs
-  std::unordered_map<std::string, const std::shared_ptr<VarBase>*> name_to_var;
-  for (auto& pair : ins) {
-    for (auto& var : pair.second) {
-      auto& var_ptr = name_to_var[var->Name()];
-      PADDLE_ENFORCE_EQ(var_ptr == nullptr || var_ptr->get() == var.get(), true,
-                        "There are different variables with same name %s",
-                        var->Name());
-      var_ptr = &var;
-    }
-  }
-
-  for (auto& pair : outs) {
-    for (auto& var : pair.second) {
-      auto& var_ptr = name_to_var[var->Name()];
-      PADDLE_ENFORCE_EQ(var_ptr == nullptr || var_ptr->get() == var.get(), true,
-                        "There are different variables with same name %s",
-                        var->Name());
-      var_ptr = &var;
-    }
-  }
-
-  // Build backward ins and outs
-  for (size_t i = 0; i < grad_op_num; i++) {
-    // Step1: build grad op and add them to engine
-    // Use trace id to decide the order of gradient sum in sorted sum mode
-    size_t trace_id = fwd_op->id();
-    std::shared_ptr<OpBase> grad_op =
-        OpBase::Create(trace_id, (*(grad_op_descs_[i].get())), fwd_op->place());
-
-    // this OpBase* is just used to manage op's life time
-    engine_->InsertOp(grad_op.get(), grad_op);
-
-    std::unordered_set<OpBase*> visited_preceding_ops;
-    // Step2 : prepare grad_in vars and bind them with grad_op,
-    // set inputs' grad_op as current grad_op
-    for (const auto& grad_ins : grad_op_descs_[i]->Inputs()) {
-      if (grad_ins.second.empty()) continue;
-      auto& bwd_in = (*grad_op->GetMutableInsMap())[grad_ins.first];
-      bwd_in.reserve(grad_ins.second.size());
-      for (auto& grad_in_var_name : grad_ins.second) {
-        auto iter = grad_to_var.find(grad_in_var_name);
-        if (iter != grad_to_var.end()) {
-          // If it is a grad var, find its corresponding forward var
-          auto& fwd_var_name = iter->second;
-          auto fwd_var_iter = name_to_var.find(fwd_var_name);
-          PADDLE_ENFORCE_EQ(fwd_var_iter != name_to_var.end(), true,
-                            "Cannot find forward variable named %s",
-                            fwd_var_name);
-          const auto& tmp = (*(fwd_var_iter->second))->GradVarBase();
-          PADDLE_ENFORCE_NOT_NULL(
-              tmp.get(),
-              "Grad of %s should "
-              "not be NULL when we Track_Backward Input of %s",
-              (*(fwd_var_iter->second))->Name(), grad_op->Type());
-          // Create grad_in's dim in tensor for Grad Dependency compute
-          auto* tensor = tmp->MutableVar()->GetMutable<framework::LoDTensor>();
-          tensor->Resize((*(fwd_var_iter->second))
-                             ->Var()
-                             .Get<framework::LoDTensor>()
-                             .dims());
-          // Add Grad Op for grad_in
-          tmp->AddGradOps(grad_op);
-          VLOG(3) << "Add Grad Op " << grad_op->Type() << " for :"
-                  << (*(fwd_var_iter->second))->GradVarBase()->Name();
-          // Add Grad var input to engine set
-          engine_->InsertGradVar(tmp.get());
-          VLOG(3) << "Add Grad: " << tmp->Name() << " in to Engine";
-          bwd_in.emplace_back((*(fwd_var_iter->second))->GradVarBase());
-        } else {
-          // If it is a forward var, just add it
-          auto fwd_var_iter = name_to_var.find(grad_in_var_name);
-          PADDLE_ENFORCE_EQ(fwd_var_iter != name_to_var.end(), true,
-                            "Cannot find forward variable named %s",
-                            grad_in_var_name);
-          bwd_in.emplace_back(*(fwd_var_iter->second));
-        }
-
-        VLOG(3) << "Set backward input from fwd var" << grad_ins.first << " of "
-                << grad_op->Type() << " to be "
-                << (bwd_in.back() ? bwd_in.back()->Name() : "nullptr");
-      }
-    }
-
-    // Step3: prepare grad_out vars and using their grad_ops to set current
-    // grad_op's preceding op
-    for (auto& grad_outs : grad_op_descs_[i]->Outputs()) {
-      if (grad_outs.second.empty()) continue;
-      auto& bwd_out = (*grad_op->GetMutableOutsMap())[grad_outs.first];
-      bwd_out.reserve(grad_outs.second.size());
-
-      for (auto& grad_out_var_name : grad_outs.second) {
-        auto iter = grad_to_var.find(grad_out_var_name);
-        PADDLE_ENFORCE_EQ(iter != grad_to_var.end(), true,
-                          "Cannot find output of input grad %s in op %s",
-                          grad_out_var_name, fwd_op->Type());
-        auto fwd_var_iter = name_to_var.find(iter->second);
-        PADDLE_ENFORCE_EQ(fwd_var_iter != name_to_var.end(), true,
-                          "Cannot find forward variable named %s",
-                          iter->second);
-        const auto& tmp = (*(fwd_var_iter->second))->GradVarBase();
-        PADDLE_ENFORCE_NOT_NULL(tmp.get(),
-                                "Grad output: %s of op: %s should not be NULL",
-                                (tmp->Name(), grad_op->Type()));
-
-        if ((!tmp->OverridedStopGradient()) || (grad_outs.second.size() > 1)) {
-          VLOG(3) << "Set backward output " << grad_outs.first << " of "
-                  << grad_op->Type() << " to be " << tmp->Name()
-                  << ". Its Overrided Stop_Gradient is: False";
-          bwd_out.emplace_back(tmp);
-          auto grad_pending_ops =
-              (*(fwd_var_iter->second))->GradVarBase()->GradOps();
-          if (VLOG_IS_ON(3) && !grad_pending_ops.empty()) {
-            VLOG(3) << "Add grad_pending Op of :"
-                    << (*(fwd_var_iter->second))->GradVarBase()->Name()
-                    << " It's grad_pending Op are: ";
-            for (const auto& op : grad_pending_ops) {
-              VLOG(3) << op->Type();
-            }
-          }
-          auto grad_name = (*(fwd_var_iter->second))->GradVarBase()->Name();
-          if (!grad_pending_ops.empty()) {
-            for (const auto& op : grad_pending_ops) {
-              PADDLE_ENFORCE_NOT_NULL(
-                  op, "No nullptr should be grad_pending op for variable %s ",
-                  grad_name);
-              if (visited_preceding_ops.count(op) == 0) {
-                visited_preceding_ops.insert(op);
-                grad_op->InsertGradPendingOps(op);
-              }
-            }
-          } else {
-            VLOG(5) << "Hit leaf VarBase"
-                    << (*(fwd_var_iter->second))->GradVarBase()->Name();
-          }
-        } else {
-          VLOG(3) << "Skip backward output " << grad_outs.first << " of "
-                  << grad_op->Type() << " Named: " << tmp->Name()
-                  << ", since its Overrided Stop_Gradient is: True";
-        }
-      }
-    }
-
-    // To ensure numeric stability as static graph
-    grad_op->SortGradPendingOps();
+  std::vector<std::unique_ptr<OpBase>> grad_op_bases_ =
+      CreateGradOpBases(fwd_op.get(), ins, outs);
+
+  size_t grad_op_num = grad_op_bases_.size();
+
+  std::set<VarBase*> set_input_vars;
+  for (auto& fwd_in_it : ins) {
+    for (auto& var_base_it : fwd_in_it.second) {
+      set_input_vars.insert(var_base_it.get());
+    }
+  }
+
+  for (auto& fwd_out_it : outs) {
+    for (auto& var_base_it : fwd_out_it.second) {
+      set_input_vars.insert(var_base_it.get());
+    }
+  }
+
+  for (size_t i = 0; i < grad_op_num; ++i) {
+    size_t trace_id = fwd_op->id();
+
+    std::shared_ptr<OpBase> grad_op = std::move(grad_op_bases_[i]);
+    grad_op->SetId(trace_id);
+    grad_op->SetPlace(fwd_op->place());
+    grad_op->CreateOperatorBase();
+
+    auto& grad_in = *(grad_op->GetMutableInsMap());
+    auto& grad_out = *(grad_op->GetMutableOutsMap());
+    for (auto& grad_in_it : grad_in) {
+      for (auto& var_base_it : grad_in_it.second) {
+        if (set_input_vars.count(var_base_it.get()) == 0) {
+          var_base_it->AddGradOps(grad_op);
+          engine_->InsertGradVar(var_base_it.get());
+        }
+      }
+    }
+
+    std::set<OpBase*, OpBaseCmp> visited_preceding_ops;
+    for (auto& grad_out_it : grad_out) {
+      bool flag_clear_list = false;
+      for (auto& var_base_it : grad_out_it.second) {
+        if ((!var_base_it->OverridedStopGradient()) ||
+            (grad_out_it.second.size() > 1)) {
+          auto preceding_ops = var_base_it->GradOps();
+          if (!preceding_ops.empty()) {
+            for (const auto& op : preceding_ops) {
+              visited_preceding_ops.insert(op);
+            }
+          }
+        } else {
+          flag_clear_list = true;
+        }
+      }
+      if (flag_clear_list) {
+        grad_out_it.second.clear();
+      }
+    }
+    std::vector<OpBase*> vec_preceding_ops(visited_preceding_ops.begin(),
+                                           visited_preceding_ops.end());
+
+    grad_op->SetGradPendingOps(std::move(vec_preceding_ops));
+
+    // this OpBase* is just used to manage op's life time
+    engine_->InsertOp(grad_op.get(), grad_op);
   }
 }
......
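The OpBaseCmp set above is what preserves reverse-execution order: preceding grad ops are deduplicated in a std::set keyed by descending trace id, so gradient accumulation visits ops from latest to earliest, playing the role SortGradPendingOps served in the old path. A standalone sketch of the idea (the Op struct and ids are illustrative, not Paddle types):

// Self-contained sketch: ordering pending grad ops by descending id.
#include <cstddef>
#include <iostream>
#include <set>

struct Op {
  std::size_t id;  // trace id assigned in forward order
};

struct OpCmp {
  bool operator()(const Op* a, const Op* b) const { return a->id > b->id; }
};

int main() {
  Op a{1}, b{3}, c{2};
  std::set<const Op*, OpCmp> pending{&a, &b, &c};
  for (const Op* op : pending) std::cout << op->id << ' ';  // prints: 3 2 1
  return 0;
}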
@@ -60,7 +60,6 @@ class Tracer {
                            const NameVarBaseMap& outs, bool trace_backward);

   void TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
-                     const framework::OpDesc& fwd_op_desc,
                      const NameVarBaseMap& ins, const NameVarBaseMap& outs);

   Engine* GetDefaultEngine() const { return engine_.get(); }
......
@@ -62,29 +62,29 @@ static constexpr bool CanInplaceAct() {
   }                                                            \
 }

-template <ActBwdOpFwdDeps kDepValue>
-class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <ActBwdOpFwdDeps kDepValue, typename T>
+class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
-    op->SetType(ForwardOpType() + "_grad");
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+    op->SetType(this->ForwardOpType() + "_grad");
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());

     if ((static_cast<int>(kDepValue) &
          static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
         FLAGS_use_mkldnn || (op->HasAttr("use_mkldnn") &&
                              boost::get<bool>(op->GetAttr("use_mkldnn")))) {
-      op->SetInput("X", Input("X"));
+      op->SetInput("X", this->Input("X"));
     }

     if (static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
-      op->SetInput("Out", Output("Out"));
+      op->SetInput("Out", this->Output("Out"));
     }

     return op;
@@ -721,91 +721,94 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
 // ReluGrad: dx = dy if y >= 0 else 0
 // ReluGradGrad: ddy = ddx if y >= 0 else 0
 //
-class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
+template <typename T>
+class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
  public:
-  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  protected:
-  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
-    auto* op = new ::paddle::framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType("relu_grad_grad");
     // input1: Out
-    op->SetInput("Out", Input("Out"));
+    op->SetInput("Out", this->Input("Out"));
     // input2: ddx
-    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
-    op->SetAttrMap(Attrs());
+    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
+    op->SetAttrMap(this->Attrs());
     // output: ddy
-    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
-    return std::unique_ptr<::paddle::framework::OpDesc>(op);
+    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
+    return std::unique_ptr<T>(op);
   }
 };

 // leaky_relu Grad: dx=dy if y>=0 else alpha * dy
 // leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
+template <typename T>
 class LeakyReluDoubleGradMaker
-    : public ::paddle::framework::SingleGradOpDescMaker {
+    : public ::paddle::framework::SingleGradOpMaker<T> {
  public:
-  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  protected:
-  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
-    auto* op = new ::paddle::framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType("leaky_relu_grad_grad");
     // input1: Out
-    op->SetInput("Out", Input("Out"));
+    op->SetInput("Out", this->Input("Out"));
     // X@GRAD@GRAD: ddx
-    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
-    op->SetAttrMap(Attrs());
+    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
+    op->SetAttrMap(this->Attrs());
     // Out@GRAD@GRAD: ddy
-    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
-    return std::unique_ptr<::paddle::framework::OpDesc>(op);
+    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
+    return std::unique_ptr<T>(op);
   }
 };

 // sqrt Grad: dx = 0.5 * dy / y
 // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
-class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
+template <typename T>
+class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
  public:
-  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  protected:
-  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
-    auto* op = new ::paddle::framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType("sqrt_grad_grad");
-    op->SetInput("Out", Input("Out"));
-    op->SetInput("DX", Output(framework::GradVarName("X")));
-    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
-    op->SetAttrMap(Attrs());
-    op->SetOutput("DOut", InputGrad("Out"));
-    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
-    return std::unique_ptr<::paddle::framework::OpDesc>(op);
+    op->SetInput("Out", this->Input("Out"));
+    op->SetInput("DX", this->Output(framework::GradVarName("X")));
+    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
+    op->SetAttrMap(this->Attrs());
+    op->SetOutput("DOut", this->InputGrad("Out"));
+    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
+    return std::unique_ptr<T>(op);
   }
 };

 // square Grad: dx=2x*dy
 // square GradGrad: ddy=2x*ddx, dx=2dy*ddx
-class SquareDoubleGradMaker
-    : public ::paddle::framework::SingleGradOpDescMaker {
+template <typename T>
+class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
  public:
-  using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  protected:
-  std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
-    auto* op = new ::paddle::framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType("square_grad_grad");
-    op->SetInput("X", Input("X"));
+    op->SetInput("X", this->Input("X"));
     // Out@GRAD: dy
-    op->SetInput("DOut", Input(framework::GradVarName("Out")));
+    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
     // X@GRAD@GRAD: ddx
-    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
-    op->SetAttrMap(Attrs());
+    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
+    op->SetAttrMap(this->Attrs());
     // X@GRAD: dx
-    op->SetOutput("DX", InputGrad("X"));
+    op->SetOutput("DX", this->InputGrad("X"));
     // Out@GRAD@GRAD: ddy
-    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
-    return std::unique_ptr<::paddle::framework::OpDesc>(op);
+    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
+    return std::unique_ptr<T>(op);
} }
}; };
...@@ -815,19 +818,20 @@ DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference, ...@@ -815,19 +818,20 @@ DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference, DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
{"DDX", "DDOut"}); {"DDX", "DDOut"});
class PowGradOpDescMaker : public framework::SingleGradOpDescMaker { template <typename T>
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<T> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc()); std::unique_ptr<T> op(new T());
op->SetType("pow_grad"); op->SetType("pow_grad");
op->SetInput("X", Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetInput("FactorTensor", Input("FactorTensor")); op->SetInput("FactorTensor", this->Input("FactorTensor"));
op->SetAttrMap(Attrs()); op->SetAttrMap(this->Attrs());
return op; return op;
} }
...@@ -894,7 +898,10 @@ namespace plat = paddle::platform; ...@@ -894,7 +898,10 @@ namespace plat = paddle::platform;
REGISTER_OPERATOR( \ REGISTER_OPERATOR( \
KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \ KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \
ops::ActivationOpInferVarType, \ ops::ActivationOpInferVarType, \
ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>, \ ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(), \
paddle::framework::OpDesc>, \
ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(), \
paddle::imperative::OpBase>, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \ std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \ ::paddle::framework::SingleOpInplaceInToOut, \
void>::type); \ void>::type); \
...@@ -921,11 +928,15 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL); ...@@ -921,11 +928,15 @@ FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
/* ========================== relu register ============================= */ /* ========================== relu register ============================= */
REGISTER_OPERATOR( REGISTER_OPERATOR(
relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType, relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>, ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::ReluGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut); paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInference,
ops::ReluDoubleGradMaker); ops::ReluDoubleGradMaker<paddle::framework::OpDesc>,
ops::ReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
relu_grad_grad, relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
...@@ -947,11 +958,15 @@ REGISTER_OP_CPU_KERNEL( ...@@ -947,11 +958,15 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OPERATOR( REGISTER_OPERATOR(
leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker, leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
ops::ActivationOpInferVarType, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>, ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::LeakyReluGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut); paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInference,
ops::LeakyReluDoubleGradMaker); ops::LeakyReluDoubleGradMaker<paddle::framework::OpDesc>,
ops::LeakyReluDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
leaky_relu_grad_grad, leaky_relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
...@@ -972,11 +987,15 @@ REGISTER_OP_CPU_KERNEL( ...@@ -972,11 +987,15 @@ REGISTER_OP_CPU_KERNEL(
/* =========================== sqrt register ============================= */ /* =========================== sqrt register ============================= */
REGISTER_OPERATOR( REGISTER_OPERATOR(
sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType, sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>, ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::SqrtGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut); paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInference,
ops::SqrtDoubleGradMaker); ops::SqrtDoubleGradMaker<paddle::framework::OpDesc>,
ops::SqrtDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
sqrt_grad_grad, sqrt_grad_grad,
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
...@@ -996,11 +1015,15 @@ REGISTER_OP_CPU_KERNEL( ...@@ -996,11 +1015,15 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OPERATOR( REGISTER_OPERATOR(
square, ops::ActivationOp, ops::SquareOpMaker, square, ops::ActivationOp, ops::SquareOpMaker,
ops::ActivationOpInferVarType, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>, ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::SquareGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
paddle::framework::SingleOpInplaceInToOut); paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference, ops::ActivationGradOpInplaceInference,
ops::SquareDoubleGradMaker); ops::SquareDoubleGradMaker<paddle::framework::OpDesc>,
ops::SquareDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
square_grad_grad, square_grad_grad,
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
...@@ -1023,7 +1046,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -1023,7 +1046,8 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OPERATOR( REGISTER_OPERATOR(
pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType, pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType,
ops::PowGradOpDescMaker, ops::PowGradOpMaker<paddle::framework::OpDesc>,
ops::PowGradOpMaker<paddle::imperative::OpBase>,
std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(), std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
::paddle::framework::SingleOpInplaceInToOut, void>::type); ::paddle::framework::SingleOpInplaceInToOut, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad, REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
......
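Every grad maker in this commit follows the same mechanical rewrite: `SingleGradOpDescMaker` becomes the class template `SingleGradOpMaker<T>`, `Apply()` builds a `T` instead of a `framework::OpDesc`, and each operator is registered twice, once per node type. A minimal sketch of the pattern, using a hypothetical `my_op` rather than any of the operators above:

// Hypothetical "my_op" grad maker, illustrating the pattern used
// throughout this commit; it is not one of the operators registered above.
template <typename T>
class MyOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  std::unique_ptr<T> Apply() const override {
    // T is framework::OpDesc when building a static graph and
    // imperative::OpBase when tracing under dygraph.
    std::unique_ptr<T> op(new T());
    op->SetType("my_op_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
    return op;
  }
};

// Registration instantiates the maker once per graph mode:
REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker,
                  ops::MyOpGradMaker<paddle::framework::OpDesc>,
                  ops::MyOpGradMaker<paddle::imperative::OpBase>);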
@@ -87,18 +87,18 @@ class AddPositionEncodingOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
-class AddPositionEncodingGradOpDescMaker
-    : public framework::SingleGradOpDescMaker {
+template <typename T>
+class AddPositionEncodingGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("add_position_encoding_grad");
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -109,9 +109,11 @@ class AddPositionEncodingGradOpDescMaker
 namespace ops = paddle::operators;
 namespace plt = paddle::platform;
-REGISTER_OPERATOR(add_position_encoding, ops::AddPositionEncodingOp,
-                  ops::AddPositionEncodingOpMaker,
-                  ops::AddPositionEncodingGradOpDescMaker);
+REGISTER_OPERATOR(
+    add_position_encoding, ops::AddPositionEncodingOp,
+    ops::AddPositionEncodingOpMaker,
+    ops::AddPositionEncodingGradOpMaker<paddle::framework::OpDesc>,
+    ops::AddPositionEncodingGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(add_position_encoding_grad, ops::AddPositionEncodingOpGrad);
 REGISTER_OP_CPU_KERNEL(
...
@@ -127,24 +127,25 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel {
   }
 };
-class AffineChannelGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class AffineChannelGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto* op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType("affine_channel_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetInput("Scale", Input("Scale"));
-    op->SetAttrMap(Attrs());
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
-    op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
-    return std::unique_ptr<framework::OpDesc>(op);
+    op->SetInput("X", this->Input("X"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("Scale", this->Input("Scale"));
+    op->SetAttrMap(this->Attrs());
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
+    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
+    return std::unique_ptr<T>(op);
   }
 };
@@ -331,7 +332,9 @@ namespace ops = paddle::operators;
 using CPU = paddle::platform::CPUDeviceContext;
 REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp,
-                  ops::AffineChannelOpMaker, ops::AffineChannelGradMaker,
+                  ops::AffineChannelOpMaker,
+                  ops::AffineChannelGradMaker<paddle::framework::OpDesc>,
+                  ops::AffineChannelGradMaker<paddle::imperative::OpBase>,
                   ops::AffineChannelInplaceInferer);
 REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad,
                   ops::AffineChannelNoNeedBufferVarsInference,
...
@@ -197,22 +197,23 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
   }
 };
-class AffineGridGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class AffineGridGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto* op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType("affine_grid_grad");
-    op->SetInput("Theta", Input("Theta"));
-    op->SetInput("OutputShape", Input("OutputShape"));
-    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
-    op->SetAttrMap(Attrs());
-    op->SetOutput(framework::GradVarName("Theta"), InputGrad("Theta"));
-    return std::unique_ptr<framework::OpDesc>(op);
+    op->SetInput("Theta", this->Input("Theta"));
+    op->SetInput("OutputShape", this->Input("OutputShape"));
+    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
+    op->SetAttrMap(this->Attrs());
+    op->SetOutput(framework::GradVarName("Theta"), this->InputGrad("Theta"));
+    return std::unique_ptr<T>(op);
   }
 };
@@ -221,7 +222,8 @@ class AffineGridGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(affine_grid, ops::AffineGridOp, ops::AffineGridOpMaker,
-                  ops::AffineGridGradMaker);
+                  ops::AffineGridGradMaker<paddle::framework::OpDesc>,
+                  ops::AffineGridGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(affine_grid_grad, ops::AffineGridOpGrad);
 REGISTER_OP_CPU_KERNEL(
...
@@ -14,9 +14,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/arg_min_max_op_base.h"
-REGISTER_OPERATOR(arg_max, paddle::operators::ArgMinMaxOp,
-                  paddle::operators::ArgMaxOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    arg_max, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMaxOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     arg_max,
...
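`EmptyGradOpMaker` is likewise a template now: operators with no gradient, such as arg_max here and several ops below, register one empty maker per node type, so both the static-graph registry and the dygraph registry see an explicit "no grad" entry instead of inheriting one implicitly.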
@@ -14,9 +14,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/arg_min_max_op_base.h"
-REGISTER_OPERATOR(arg_min, paddle::operators::ArgMinMaxOp,
-                  paddle::operators::ArgMinOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    arg_min, paddle::operators::ArgMinMaxOp, paddle::operators::ArgMinOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     arg_min,
...
@@ -80,8 +80,10 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(argsort,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, double>);
@@ -210,19 +210,20 @@ class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
   }
 };
-class ArrayToLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class ArrayToLoDTensorGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto *grad_op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto *grad_op = new T();
     grad_op->SetType("lod_tensor_to_array");
-    grad_op->SetInput("X", OutputGrad("Out"));
-    grad_op->SetInput("RankTable", Input("RankTable"));
-    grad_op->SetOutput("Out", InputGrad("X"));
-    grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDesc>(grad_op);
+    grad_op->SetInput("X", this->OutputGrad("Out"));
+    grad_op->SetInput("RankTable", this->Input("RankTable"));
+    grad_op->SetOutput("Out", this->InputGrad("X"));
+    grad_op->SetAttrMap(this->Attrs());
+    return std::unique_ptr<T>(grad_op);
   }
 };
@@ -233,4 +234,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(array_to_lod_tensor, ops::ArrayToLoDTensorOp,
                   ops::ArrayToLoDTensorOpProtoMaker,
                   ops::ArrayToLoDTensorInferShape,
-                  ops::ArrayToLoDTensorGradMaker);
+                  ops::ArrayToLoDTensorGradMaker<paddle::framework::OpDesc>,
+                  ops::ArrayToLoDTensorGradMaker<paddle::imperative::OpBase>);
@@ -131,17 +131,18 @@ raise error if the type is not listed above.
   }
 };
-class AssignGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class AssignGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto *op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto *op = new T();
     op->SetType("assign");
-    op->SetInput("X", OutputGrad("Out"));
-    op->SetOutput("Out", InputGrad("X"));
-    return std::unique_ptr<framework::OpDesc>(op);
+    op->SetInput("X", this->OutputGrad("Out"));
+    op->SetOutput("Out", this->InputGrad("X"));
+    return std::unique_ptr<T>(op);
   }
 };
@@ -151,8 +152,11 @@ DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
+REGISTER_OPERATOR(assign, ops::AssignOp,
+                  ops::AssignGradMaker<paddle::framework::OpDesc>,
+                  ops::AssignGradMaker<paddle::imperative::OpBase>,
                   ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer);
 REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
                                ops::AssignKernel, int, ops::AssignKernel,
                                int64_t, ops::AssignKernel, bool,
...
@@ -70,7 +70,9 @@ $$Out = values$$
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    assign_value, ops::AssignValueOp, ops::AssignValueOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<int>,
                        ops::AssignValueKernel<float>);
@@ -205,9 +205,11 @@ And for a mini-batch in training, accumulators were computed as below steps:
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(average_accumulates, ops::AverageAccumulatesOp,
-                  ops::AverageAccumulatesOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    average_accumulates, ops::AverageAccumulatesOp,
+    ops::AverageAccumulatesOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     average_accumulates,
     ops::AverageAccumulatesKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -613,38 +613,47 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
   }
 };
-std::unique_ptr<framework::OpDesc> BatchNormGradMaker::Apply() const {
-  auto *op = new framework::OpDesc();
-  op->SetType(GradOpType());
-  op->SetInput("X", Input("X"));
-  op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
-  op->SetInput("Scale", Input("Scale"));
-  op->SetInput("Bias", Input("Bias"));
-  op->SetInput("SavedMean", Output("SavedMean"));
-  op->SetInput("SavedVariance", Output("SavedVariance"));
-  // used when setting use_global_stats True during training
-  if (boost::get<bool>(GetAttr("use_global_stats"))) {
-    op->SetInput("Mean", Output("MeanOut"));
-    op->SetInput("Variance", Output("VarianceOut"));
-  }
-  op->SetAttrMap(Attrs());
-  op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-  op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
-  op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
-  return std::unique_ptr<framework::OpDesc>(op);
-}
+template <typename T>
+class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    auto *op = new T();
+    op->SetType(this->ForwardOpType() + "_grad");
+    op->SetInput("X", this->Input("X"));
+    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
+    op->SetInput("Scale", this->Input("Scale"));
+    op->SetInput("Bias", this->Input("Bias"));
+    op->SetInput("SavedMean", this->Output("SavedMean"));
+    op->SetInput("SavedVariance", this->Output("SavedVariance"));
+    // used when setting use_global_stats True during training
+    if (boost::get<bool>(this->GetAttr("use_global_stats"))) {
+      op->SetInput("Mean", this->Output("MeanOut"));
+      op->SetInput("Variance", this->Output("VarianceOut"));
+    }
+    op->SetAttrMap(this->Attrs());
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
+    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
+    return std::unique_ptr<T>(op);
+  }
+};
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
-                  ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
+                  ops::BatchNormOpInferVarType,
+                  ops::BatchNormGradMaker<paddle::framework::OpDesc>,
+                  ops::BatchNormGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
 REGISTER_OP_CPU_KERNEL(
...
@@ -64,18 +64,6 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override;
 };
-class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
- public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
-
- protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override;
-
-  virtual std::string GradOpType() const {
-    return this->ForwardOpType() + "_grad";
-  }
-};
 class BatchNormOpInferVarType
     : public framework::PassInDtypeAndVarTypeToOutput {
  protected:
...
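The batch_norm maker above also shows why every call in this commit gained an explicit `this->`: once the base class is the template `SingleGradOpMaker<T>`, it becomes a dependent base, and C++ two-phase lookup does not search dependent bases for unqualified names. A reduced, self-contained illustration (hypothetical `Base`/`Derived`, not Paddle code):

// Unqualified calls into a dependent base do not compile; qualifying
// with this-> defers the lookup to template instantiation time.
template <typename T>
struct Base {
  int Value() const { return 42; }
};

template <typename T>
struct Derived : Base<T> {
  int Get() const {
    // return Value();      // error: 'Value' was not declared in this scope
    return this->Value();   // OK: found in Base<T> at instantiation
  }
};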
@@ -216,8 +216,10 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
 }  // namespace operators
 }  // namespace paddle
-REGISTER_OPERATOR(beam_search_decode, paddle::operators::BeamSearchDecodeOp,
-                  paddle::operators::BeamSearchDecodeOpProtoMaker,
-                  paddle::operators::BeamSearchDecodeInferShape,
-                  paddle::operators::BeamSearchDecodeInferVarType,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    beam_search_decode, paddle::operators::BeamSearchDecodeOp,
+    paddle::operators::BeamSearchDecodeOpProtoMaker,
+    paddle::operators::BeamSearchDecodeInferShape,
+    paddle::operators::BeamSearchDecodeInferVarType,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
@@ -147,27 +147,28 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
   }
 };
-class BilinearTensorProductGradOpDescMaker
-    : public framework::SingleGradOpDescMaker {
+template <typename T>
+class BilinearTensorProductGradOpMaker
+    : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("bilinear_tensor_product_grad");
-    op->SetAttrMap(Attrs());
-    op->SetInput("X", Input("X"));
-    op->SetInput("Y", Input("Y"));
-    op->SetInput("Weight", Input("Weight"));
-    if (ForwardOp().Inputs().count("Bias") > 0) {
-      op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
+    op->SetAttrMap(this->Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Y", this->Input("Y"));
+    op->SetInput("Weight", this->Input("Weight"));
+    if (this->HasInput("Bias")) {
+      op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
     }
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
-    op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+    op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     return op;
   }
@@ -177,9 +178,11 @@ class BilinearTensorProductGradOpDescMaker
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(bilinear_tensor_product, ops::BilinearTensorProductOp,
-                  ops::BilinearTensorProductOpMaker,
-                  ops::BilinearTensorProductGradOpDescMaker);
+REGISTER_OPERATOR(
+    bilinear_tensor_product, ops::BilinearTensorProductOp,
+    ops::BilinearTensorProductOpMaker,
+    ops::BilinearTensorProductGradOpMaker<paddle::framework::OpDesc>,
+    ops::BilinearTensorProductGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(bilinear_tensor_product_grad,
                   ops::BilinearTensorProductOpGrad);
 REGISTER_OP_CPU_KERNEL(
...
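Note the interface change buried in this maker: the old code probed the forward op's descriptor directly (`ForwardOp().Inputs().count("Bias")`), which presumes an OpDesc exists, while the new `this->HasInput("Bias")` goes through the maker's own interface, so the same check can work when the grad node is built for dygraph, where there is no forward OpDesc to inspect.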
@@ -135,19 +135,20 @@ neural networks>(https://arxiv.org/abs/1511.06939)
   }
 };
-class BprLossGradDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class BprLossGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("bpr_loss_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput("Label", Input("Label"));
-    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Label", this->Input("Label"));
+    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -158,7 +159,8 @@ namespace ops = paddle::operators;
 using CPUCtx = paddle::platform::CPUDeviceContext;
 REGISTER_OPERATOR(bpr_loss, ops::BprLossOp, ops::BprLossOpMaker,
-                  ops::BprLossGradDescMaker);
+                  ops::BprLossGradMaker<paddle::framework::OpDesc>,
+                  ops::BprLossGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp);
 REGISTER_OP_CPU_KERNEL(bpr_loss, ops::BprLossOpKernel<CPUCtx, float>,
                        ops::BprLossOpKernel<CPUCtx, double>);
...
@@ -49,19 +49,20 @@ class CastOpInferShape : public framework::InferShapeBase {
   }
 };
-class CastOpGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto grad = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto grad = new T();
     grad->SetType("cast");
-    grad->SetInput("X", OutputGrad("Out"));
-    grad->SetOutput("Out", InputGrad("X"));
-    grad->SetAttr("out_dtype", GetAttr("in_dtype"));
-    grad->SetAttr("in_dtype", GetAttr("out_dtype"));
-    return std::unique_ptr<framework::OpDesc>(grad);
+    grad->SetInput("X", this->OutputGrad("Out"));
+    grad->SetOutput("Out", this->InputGrad("X"));
+    grad->SetAttr("out_dtype", this->GetAttr("in_dtype"));
+    grad->SetAttr("in_dtype", this->GetAttr("out_dtype"));
+    return std::unique_ptr<T>(grad);
   }
 };
@@ -84,7 +85,9 @@ class CastOp : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 using CPU = paddle::platform::CPUDeviceContext;
-REGISTER_OPERATOR(cast, ops::CastOp, ops::CastOpGradMaker,
+REGISTER_OPERATOR(cast, ops::CastOp,
+                  ops::CastOpGradMaker<paddle::framework::OpDesc>,
+                  ops::CastOpGradMaker<paddle::imperative::OpBase>,
                   ops::CastOpInferShape, ops::CastOpProtoMaker);
 REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
                        ops::CastOpKernel<CPU, double>,
...
@@ -123,20 +123,21 @@ class CenterLossGradOp : public framework::OperatorWithKernel {
   }
 };
-class CenterLossOpGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CenterLossOpGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> retv(new T());
     retv->SetType("center_loss_grad");
-    retv->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
-    retv->SetInput("SampleCenterDiff", Output("SampleCenterDiff"));
-    retv->SetInput("X", Input("X"));
-    retv->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    retv->SetAttrMap(Attrs());
+    retv->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
+    retv->SetInput("SampleCenterDiff", this->Output("SampleCenterDiff"));
+    retv->SetInput("X", this->Input("X"));
+    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    retv->SetAttrMap(this->Attrs());
     return retv;
   }
 };
@@ -147,7 +148,8 @@ namespace ops = paddle::operators;
 using CPUCtx = paddle::platform::CPUDeviceContext;
 REGISTER_OPERATOR(center_loss, ops::CenterLossOp, ops::CenterLossOpMaker,
-                  ops::CenterLossOpGradMaker);
+                  ops::CenterLossOpGradMaker<paddle::framework::OpDesc>,
+                  ops::CenterLossOpGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(center_loss_grad, ops::CenterLossGradOp);
...
@@ -78,18 +78,19 @@ class ClipOpGrad : public framework::OperatorWithKernel {
   }
 };
-class ClipGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class ClipGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("clip_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -104,7 +105,9 @@ DECLARE_INPLACE_OP_INFERER(ClipGradInplaceInferer,
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
-                  ops::ClipGradOpDescMaker, ops::ClipInplaceInferer);
+                  ops::ClipGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ClipGradOpMaker<paddle::imperative::OpBase>,
+                  ops::ClipInplaceInferer);
 REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer);
 REGISTER_OP_CPU_KERNEL(
     clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -55,17 +55,18 @@ reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
   }
 };
-class CAllGatherOpGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CAllGatherOpGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> retv(new T());
     retv->SetType("c_reducescatter");
-    retv->SetInput("X", OutputGrad("Out"));
-    retv->SetOutput("Out", InputGrad("X"));
-    retv->SetAttrMap(Attrs());
+    retv->SetInput("X", this->OutputGrad("Out"));
+    retv->SetOutput("Out", this->InputGrad("X"));
+    retv->SetAttrMap(this->Attrs());
     return retv;
   }
 };
@@ -76,7 +77,9 @@ class CAllGatherOpGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
-REGISTER_OPERATOR(c_allgather, ops::CAllGatherOp, ops::CAllGatherOpGradMaker,
+REGISTER_OPERATOR(c_allgather, ops::CAllGatherOp,
+                  ops::CAllGatherOpGradMaker<paddle::framework::OpDesc>,
+                  ops::CAllGatherOpGradMaker<paddle::imperative::OpBase>,
                   ops::CAllGatherOpMaker);
 REGISTER_OP_CPU_KERNEL(c_allgather, ops::CAllGatherOpCPUKernel<float>,
...
@@ -17,17 +17,18 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-class CAllReduceSumOpGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> retv(new T());
     retv->SetType("c_allreduce_sum");
-    retv->SetInput("X", OutputGrad("Out"));
-    retv->SetOutput("Out", InputGrad("X"));
-    retv->SetAttrMap(Attrs());
+    retv->SetInput("X", this->OutputGrad("Out"));
+    retv->SetOutput("Out", this->InputGrad("X"));
+    retv->SetAttrMap(this->Attrs());
     return retv;
   }
 };
@@ -44,7 +45,9 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OPERATOR(c_allreduce_sum, ops::CAllReduceOp,
-                  ops::CAllReduceSumOpGradMaker, ops::CAllReduceSumOpMaker);
+                  ops::CAllReduceSumOpGradMaker<paddle::framework::OpDesc>,
+                  ops::CAllReduceSumOpGradMaker<paddle::imperative::OpBase>,
+                  ops::CAllReduceSumOpMaker);
 REGISTER_OP_CPU_KERNEL(c_allreduce_sum,
                        ops::CAllReduceOpCPUKernel<ops::kRedSum, float>,
...
@@ -60,17 +60,18 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
   }
 };
-class CReduceScatterOpGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CReduceScatterOpGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> retv(new T());
     retv->SetType("c_allgather");
-    retv->SetInput("X", OutputGrad("Out"));
-    retv->SetOutput("Out", InputGrad("X"));
-    retv->SetAttrMap(Attrs());
+    retv->SetInput("X", this->OutputGrad("Out"));
+    retv->SetOutput("Out", this->InputGrad("X"));
+    retv->SetAttrMap(this->Attrs());
     return retv;
   }
 };
...
@@ -187,19 +187,20 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
 DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ConcatOpGradNoNeedBufferVarInference,
                                       "X");
-class ConcatGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("concat_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput("AxisTensor", Input("AxisTensor"));
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
-    op->SetAttrMap(Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("AxisTensor", this->Input("AxisTensor"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -209,7 +210,8 @@ class ConcatGradOpDescMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
-                  ops::ConcatGradOpDescMaker);
+                  ops::ConcatGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ConcatGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
                   ops::ConcatOpGradNoNeedBufferVarInference);
 REGISTER_OP_CPU_KERNEL(
...
@@ -110,18 +110,19 @@ class CompareOp : public framework::OperatorWithKernel {
 }  // namespace operators
 }  // namespace paddle
 #define REGISTER_COMPARE_OP(op_type, _equation)                             \
   struct _##op_type##Comment {                                              \
     static char type[];                                                     \
     static char equation[];                                                 \
   };                                                                        \
   char _##op_type##Comment::type[]{#op_type};                               \
   char _##op_type##Comment::equation[]{_equation};                          \
   REGISTER_OPERATOR(                                                        \
       op_type, ::paddle::operators::CompareOp,                              \
       ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>,        \
       ::paddle::operators::CompareOpInferShape<_##op_type##Comment>,        \
-      ::paddle::framework::EmptyGradOpMaker);
+      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,     \
+      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_COMPARE_OP(less_than, "Out = X < Y");
 REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
...
@@ -69,6 +69,8 @@ class ConditionalBlockInferOp : public ConditionalOp {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(conditional_block_infer, ops::ConditionalBlockInferOp,
-                  ops::ConditionalBlockOpProtoMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    conditional_block_infer, ops::ConditionalBlockInferOp,
+    ops::ConditionalBlockOpProtoMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
@@ -168,28 +168,33 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
   }
 };
-class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto grad_op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto grad_op = new T();
     grad_op->SetType("conditional_block_grad");
     grad_op->SetInput(ConditionalOp::kCondition,
-                      Input(ConditionalOp::kCondition));
-    grad_op->SetInput(ConditionalOp::kInputs, Input(ConditionalOp::kInputs));
-    grad_op->SetInput(ConditionalOp::kOutputs, Output(ConditionalOp::kOutputs));
+                      this->Input(ConditionalOp::kCondition));
+    grad_op->SetInput(ConditionalOp::kInputs,
+                      this->Input(ConditionalOp::kInputs));
+    grad_op->SetInput(ConditionalOp::kOutputs,
+                      this->Output(ConditionalOp::kOutputs));
     grad_op->SetInput(framework::GradVarName(ConditionalOp::kOutputs),
-                      OutputGrad(ConditionalOp::kOutputs));
-    grad_op->SetInput(ConditionalOp::kScope, Output(ConditionalOp::kScope));
+                      this->OutputGrad(ConditionalOp::kOutputs));
+    grad_op->SetInput(ConditionalOp::kScope,
+                      this->Output(ConditionalOp::kScope));
     grad_op->SetOutput(framework::GradVarName(ConditionalOp::kCondition),
-                       InputGrad(ConditionalOp::kCondition, false));
+                       this->InputGrad(ConditionalOp::kCondition, false));
     grad_op->SetOutput(framework::GradVarName(ConditionalOp::kInputs),
-                       InputGrad(ConditionalOp::kInputs, false));
+                       this->InputGrad(ConditionalOp::kInputs, false));
     grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
-    grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
-    return std::unique_ptr<framework::OpDesc>(grad_op);
+    grad_op->SetAttr("is_scalar_condition",
+                     this->GetAttr("is_scalar_condition"));
+    return std::unique_ptr<T>(grad_op);
   }
 };
@@ -199,6 +204,6 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
                   ops::ConditionalBlockOpProtoMaker,
-                  ops::ConditionalBlockGradMaker);
+                  ops::ConditionalBlockGradMaker<paddle::framework::OpDesc>);
 REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp,
                   ops::ConditionalBlockGradInferShape);
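Unlike the other operators in this commit, conditional_block registers only the `OpDesc` instantiation of its grad maker. Its `Apply()` writes a block attribute (`SetBlockAttr("sub_block", this->grad_block_[0])`), and sub-blocks are a static-graph construct, which is presumably why no `imperative::OpBase` instantiation is registered here.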
@@ -81,6 +81,8 @@ It should not be configured by users directly.
 }  // namespace operators
 }  // namespace paddle
-REGISTER_OPERATOR(feed, paddle::operators::FeedOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  paddle::operators::FeedOpInfoMaker);
+REGISTER_OPERATOR(
+    feed, paddle::operators::FeedOp,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    paddle::operators::FeedOpInfoMaker);
@@ -94,6 +94,8 @@ It should not be configured by users directly.
 }  // namespace operators
 }  // namespace paddle
-REGISTER_OPERATOR(fetch, paddle::operators::FetchOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  paddle::operators::FetchOpInfoMaker);
+REGISTER_OPERATOR(
+    fetch, paddle::operators::FetchOp,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    paddle::operators::FetchOpInfoMaker);
@@ -111,6 +111,8 @@ class GetPlacesInferShape : public framework::InferShapeBase {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker,
-                  ops::GetPlacesInferVarType, ops::GetPlacesInferShape,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker,
+    ops::GetPlacesInferVarType, ops::GetPlacesInferShape,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
@@ -124,7 +124,8 @@ class LogicalOp : public framework::OperatorWithKernel {
       op_type, ::paddle::operators::LogicalOp,                              \
       ::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>,  \
       ::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>,  \
-      ::paddle::framework::EmptyGradOpMaker);
+      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,     \
+      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 #define REGISTER_UNARY_LOGICAL_OP(op_type, _equation)                       \
   struct _##op_type##Comment {                                              \
@@ -137,7 +138,8 @@ class LogicalOp : public framework::OperatorWithKernel {
       op_type, ::paddle::operators::LogicalOp,                              \
       ::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>,   \
       ::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>,   \
-      ::paddle::framework::EmptyGradOpMaker);
+      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,     \
+      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_BINARY_LOGICAL_OP(logical_and, "$$Out = X \\&\\& Y$$");
 REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CPU,
...
@@ -188,35 +188,37 @@ class ReadFromArrayInferShape : public WriteToArrayInferShape {
   }
 };
-class WriteToArrayGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class WriteToArrayGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto *grad_op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto *grad_op = new T();
     grad_op->SetType("read_from_array");
-    grad_op->SetInput("I", Input("I"));
-    grad_op->SetInput("X", OutputGrad("Out"));
-    grad_op->SetOutput("Out", InputGrad("X"));
-    grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDesc>(grad_op);
+    grad_op->SetInput("I", this->Input("I"));
+    grad_op->SetInput("X", this->OutputGrad("Out"));
+    grad_op->SetOutput("Out", this->InputGrad("X"));
+    grad_op->SetAttrMap(this->Attrs());
+    return std::unique_ptr<T>(grad_op);
   }
 };
-class ReadFromArrayGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class ReadFromArrayGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto *grad_op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto *grad_op = new T();
     grad_op->SetType("write_to_array");
-    grad_op->SetInput("I", Input("I"));
-    grad_op->SetInput("X", OutputGrad("Out"));
-    grad_op->SetOutput("Out", InputGrad("X"));
-    grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDesc>(grad_op);
+    grad_op->SetInput("I", this->Input("I"));
+    grad_op->SetInput("X", this->OutputGrad("Out"));
+    grad_op->SetOutput("Out", this->InputGrad("X"));
+    grad_op->SetAttrMap(this->Attrs());
+    return std::unique_ptr<T>(grad_op);
   }
 };
@@ -226,7 +228,10 @@ class ReadFromArrayGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(write_to_array, ops::WriteToArrayOp,
                   ops::WriteToArrayInferShape, ops::WriteToArrayOpProtoMaker,
-                  ops::WriteToArrayGradMaker, ops::WriteToArrayInferVarType);
+                  ops::WriteToArrayGradMaker<paddle::framework::OpDesc>,
+                  ops::WriteToArrayGradMaker<paddle::imperative::OpBase>,
+                  ops::WriteToArrayInferVarType);
 REGISTER_OPERATOR(read_from_array, ops::ReadFromArrayOp,
                   ops::ReadFromArrayInferShape, ops::ReadFromArrayProtoMaker,
-                  ops::ReadFromArrayGradMaker);
+                  ops::ReadFromArrayGradMaker<paddle::framework::OpDesc>,
+                  ops::ReadFromArrayGradMaker<paddle::imperative::OpBase>);
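The class-level change repeated across the rest of this commit is visible above in its simplest form: `SingleGradOpDescMaker` becomes the class template `SingleGradOpMaker<T>`, `Apply()` builds a `T` instead of a concrete `framework::OpDesc`, and every inherited helper (`Input`, `Output`, `InputGrad`, `OutputGrad`, `Attrs`) needs an explicit `this->` because the base class is now a dependent type. A condensed sketch for a hypothetical op (all names illustrative):

    // Sketch, not part of the patch: the post-refactor shape of a grad maker.
    template <typename T>  // T is framework::OpDesc or imperative::OpBase
    class MyOpGradMaker : public framework::SingleGradOpMaker<T> {
     public:
      using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

     protected:
      std::unique_ptr<T> Apply() const override {
        std::unique_ptr<T> op(new T());
        op->SetType("my_op_grad");
        // this-> is mandatory: unqualified lookup does not search a
        // dependent base class.
        op->SetInput("X", this->Input("X"));
        op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
        op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
        op->SetAttrMap(this->Attrs());
        return op;
      }
    };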
@@ -320,17 +320,18 @@ class WhileGradOp : public framework::OperatorBase {
   }
 };
-class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto *while_grad = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto *while_grad = new T();
     while_grad->SetType("while_grad");
-    while_grad->SetInput(kX, Input(kX));
-    while_grad->SetInput(kOutputs, Output(kOutputs));
-    while_grad->SetInput(kStepScopes, Output(kStepScopes));
+    while_grad->SetInput(kX, this->Input(kX));
+    while_grad->SetInput(kOutputs, this->Output(kOutputs));
+    while_grad->SetInput(kStepScopes, this->Output(kStepScopes));
     auto *grad_block = this->grad_block_[0];
     auto *fwd_block = grad_block->ForwardBlock();
@@ -344,7 +345,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
         inner_op_outputs.insert(oname);
       }
     }
-    auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
+    auto igs = this->InputGrad(kX, /*do not drop empty gradient*/ false);
     for (auto &each_ig : igs) {
       if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
         VLOG(8) << "Ignore " << each_ig;
@@ -356,11 +358,11 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     // OG should be re-calculated by step blocks, since many outputs of while op
     // do not need to calculate gradients.
     std::unordered_set<std::string> block_ins;
-    block_ins.reserve(Input(kX).size() + Output(kOutputs).size());
-    for (auto &p : Input(kX)) {
+    block_ins.reserve(this->Input(kX).size() + this->Output(kOutputs).size());
+    for (auto &p : this->Input(kX)) {
       block_ins.insert(p);
     }
-    for (auto &o : Output(kOutputs)) {
+    for (auto &o : this->Output(kOutputs)) {
       block_ins.insert(o);
     }
     std::unordered_set<std::string> output_grads;
@@ -398,7 +400,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     while_grad->SetAttr(kSkipEagerDeletionVars, std::vector<std::string>());
-    return std::unique_ptr<framework::OpDesc>(while_grad);
+    return std::unique_ptr<T>(while_grad);
   }
 };
@@ -468,9 +470,9 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
 }  // namespace operators
 }  // namespace paddle
-REGISTER_OPERATOR(while, paddle::operators::WhileOp,
-                  paddle::operators::WhileOpMaker,
-                  paddle::operators::WhileGradOpDescMaker);
+REGISTER_OPERATOR(
+    while, paddle::operators::WhileOp, paddle::operators::WhileOpMaker,
+    paddle::operators::WhileGradOpMaker<paddle::framework::OpDesc>);
 REGISTER_OPERATOR(while_grad, paddle::operators::WhileGradOp,
                   paddle::operators::WhileGradOpShapeInference,
                   paddle::operators::WhileGradOpVarTypeInference);
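One asymmetry is worth flagging: `while` above is registered with only the `OpDesc` instantiation of `WhileGradOpMaker`, so this control-flow op still has no dygraph grad maker — presumably because `while` is not yet traced imperatively. Every other operator in this commit carries both instantiations, e.g. (sketch with placeholder names):

    REGISTER_OPERATOR(some_op, ops::SomeOp, ops::SomeOpMaker,
                      ops::SomeGradOpMaker<paddle::framework::OpDesc>,
                      ops::SomeGradOpMaker<paddle::imperative::OpBase>);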
@@ -548,48 +548,50 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   return type;
 }
-class Conv2DGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto* op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType(this->ForwardOpType() + "_grad");
-    op->SetInput("Input", Input("Input"));
-    op->SetInput("Filter", Input("Filter"));
-    op->SetInput("Bias", Input("Bias"));
-    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("Filter", this->Input("Filter"));
+    op->SetInput("Bias", this->Input("Bias"));
+    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
-    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
-    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
-    op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
-    op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDesc>(op);
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
+    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
+    op->SetAttrMap(this->Attrs());
+    return std::unique_ptr<T>(op);
   }
 };
-class Conv3DGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class Conv3DGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto* op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType(this->ForwardOpType() + "_grad");
-    op->SetInput("Input", Input("Input"));
-    op->SetInput("Filter", Input("Filter"));
-    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("Filter", this->Input("Filter"));
+    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
-    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
-    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
-    if (ForwardOp().Inputs().count("ResidualData") != 0) {
-      op->SetInput("ResidualData", Input("ResidualData"));
+    if (this->HasInput("ResidualData")) {
+      op->SetInput("ResidualData", this->Input("ResidualData"));
     }
-    op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDesc>(op);
+    op->SetAttrMap(this->Attrs());
+    return std::unique_ptr<T>(op);
   }
 };
@@ -597,37 +599,40 @@ class Conv3DGradMaker : public framework::SingleGradOpDescMaker {
  * Inputs: I, W, dO, ddI, ddW
  * Outputs: ddO, dW, dI
  */
-class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class Conv2DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto* op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType(this->ForwardOpType() + "_grad");
     // I, W, dO, ddI, ddW
-    op->SetInput("Input", Input("Input"));
-    op->SetInput("Filter", Input("Filter"));
-    op->SetInput("DOutput", Input(framework::GradVarName("Output")));
-    op->SetInput("DDInput", OutputGrad(framework::GradVarName("Input")));
-    op->SetInput("DDFilter", OutputGrad(framework::GradVarName("Filter")));
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("Filter", this->Input("Filter"));
+    op->SetInput("DOutput", this->Input(framework::GradVarName("Output")));
+    op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input")));
+    op->SetInput("DDFilter",
+                 this->OutputGrad(framework::GradVarName("Filter")));
     // ddO, dI, dW
     // Unlike grad op, double grad op does not use name@GRAD@GRAD
     // as key of ops' inputs and outputs.
-    auto ddx = OutputGrad(framework::GradVarName("Input"));
-    auto ddw = OutputGrad(framework::GradVarName("Filter"));
-    std::vector<std::string> empty_str = {};
+    auto ddx = this->OutputGrad(framework::GradVarName("Input"));
+    auto ddw = this->OutputGrad(framework::GradVarName("Filter"));
     op->SetOutput("DDOutput",
-                  (ddx.empty() && ddw.empty())
-                      ? empty_str
-                      : InputGrad(framework::GradVarName("Output")));
-    op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
-    op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
+                  ddx.empty()
+                      ? this->Empty()
+                      : this->InputGrad(framework::GradVarName("Output")));
+    op->SetOutput("DFilter",
+                  ddx.empty() ? this->Empty() : this->InputGrad("Filter"));
+    op->SetOutput("DInput",
+                  ddw.empty() ? this->Empty() : this->InputGrad("Input"));
-    op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDesc>(op);
+    op->SetAttrMap(this->Attrs());
+    return std::unique_ptr<T>(op);
   }
 };
@@ -635,34 +640,37 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker {
  * Inputs: I, W, dO, ddI, ddW
  * Outputs: ddO, dW, dI
  */
-class Conv3DDoubleGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class Conv3DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto* op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType(this->ForwardOpType() + "_grad");
     // I, W, dO, ddI, ddW
-    op->SetInput("Input", Input("Input"));
-    op->SetInput("Filter", Input("Filter"));
-    op->SetInput("DOutput", Input(framework::GradVarName("Output")));
-    op->SetInput("DDInput", OutputGrad(framework::GradVarName("Input")));
-    op->SetInput("DDFilter", OutputGrad(framework::GradVarName("Filter")));
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("Filter", this->Input("Filter"));
+    op->SetInput("DOutput", this->Input(framework::GradVarName("Output")));
+    op->SetInput("DDInput", this->OutputGrad(framework::GradVarName("Input")));
+    op->SetInput("DDFilter",
+                 this->OutputGrad(framework::GradVarName("Filter")));
-    auto ddx = OutputGrad(framework::GradVarName("Input"));
-    auto ddw = OutputGrad(framework::GradVarName("Filter"));
-    std::vector<std::string> empty_str = {};
+    auto ddx = this->OutputGrad(framework::GradVarName("Input"));
+    auto ddw = this->OutputGrad(framework::GradVarName("Filter"));
     op->SetOutput("DDOutput",
-                  (ddx.empty() && ddw.empty())
-                      ? empty_str
-                      : InputGrad(framework::GradVarName("Output")));
-    op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
-    op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
+                  ddx.empty()
+                      ? this->Empty()
+                      : this->InputGrad(framework::GradVarName("Output")));
+    op->SetOutput("DFilter",
+                  ddx.empty() ? this->Empty() : this->InputGrad("Filter"));
+    op->SetOutput("DInput",
+                  ddw.empty() ? this->Empty() : this->InputGrad("Input"));
-    op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDesc>(op);
+    op->SetAttrMap(this->Attrs());
+    return std::unique_ptr<T>(op);
   }
 };
@@ -734,18 +742,28 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
-                  ops::ConvOpInferVarType, ops::Conv2DGradMaker);
-REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad, ops::Conv2DDoubleGradMaker);
+                  ops::ConvOpInferVarType,
+                  ops::Conv2DGradMaker<paddle::framework::OpDesc>,
+                  ops::Conv2DGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad,
+                  ops::Conv2DDoubleGradMaker<paddle::framework::OpDesc>,
+                  ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(conv2d_grad_grad, ops::ConvOpDoubleGrad);
 // depthwise convolution op
 REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
-                  ops::ConvOpInferVarType, ops::Conv2DGradMaker);
+                  ops::ConvOpInferVarType,
+                  ops::Conv2DGradMaker<paddle::framework::OpDesc>,
+                  ops::Conv2DGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
 REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
-                  ops::ConvOpInferVarType, ops::Conv3DGradMaker);
-REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad, ops::Conv3DDoubleGradMaker);
+                  ops::ConvOpInferVarType,
+                  ops::Conv3DGradMaker<paddle::framework::OpDesc>,
+                  ops::Conv3DGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad,
+                  ops::Conv3DDoubleGradMaker<paddle::framework::OpDesc>,
+                  ops::Conv3DDoubleGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(conv3d_grad_grad, ops::ConvOpDoubleGrad);
 // depthwise conv kernel
......
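Two smaller API changes ride along in the conv double-grad makers above: the local `std::vector<std::string> empty_str = {}` is replaced by the inherited `this->Empty()` helper, and optional forward inputs are now probed with `this->HasInput(...)` rather than `ForwardOp().Inputs().count(...)`, since in dygraph mode there is no forward `OpDesc` to inspect. Note also that the guard for `DDOutput` changes from `(ddx.empty() && ddw.empty())` to `ddx.empty()` alone, a behavioral tweak rather than a pure rename. A condensed sketch of the new idiom (the wiring shown is illustrative):

    // Sketch: suppress an output whose driving gradient is absent, and wire
    // an optional forward input only when it exists.
    auto ddx = this->OutputGrad(framework::GradVarName("Input"));
    op->SetOutput("DDOutput",
                  ddx.empty()
                      ? this->Empty()
                      : this->InputGrad(framework::GradVarName("Output")));
    if (this->HasInput("ResidualData")) {
      op->SetInput("ResidualData", this->Input("ResidualData"));
    }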
@@ -193,20 +193,21 @@ class ConvShiftGradKernel<platform::CPUPlace, T>
   }
 };
-class ConvShiftGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class ConvShiftGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("conv_shift_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput("Y", Input("Y"));
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
-    op->SetAttrMap(Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Y", this->Input("Y"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -216,7 +217,8 @@ class ConvShiftGradOpDescMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(conv_shift, ops::ConvShiftOp, ops::ConvShiftOpMaker,
-                  ops::ConvShiftGradOpDescMaker);
+                  ops::ConvShiftGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ConvShiftGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(conv_shift_grad, ops::ConvShiftGradOp);
 REGISTER_OP_CPU_KERNEL(conv_shift,
                        ops::ConvShiftKernel<paddle::platform::CPUPlace, float>);
......
@@ -390,24 +390,25 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
                                  layout_, library_);
 }
-class ConvTransposeGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class ConvTransposeGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
-    op->SetType(ForwardOp().Type() + "_grad");
-    op->SetInput("Input", Input("Input"));
-    op->SetInput("Filter", Input("Filter"));
-    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
-    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
-    if (ForwardOp().Inputs().count("Bias") > 0) {
-      op->SetInput("Bias", Input("Bias"));
-      op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+    op->SetType(this->ForwardOpType() + "_grad");
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("Filter", this->Input("Filter"));
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
+    if (this->HasInput("Bias")) {
+      op->SetInput("Bias", this->Input("Bias"));
+      op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
     }
-    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
-    op->SetAttrMap(Attrs());
+    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -420,7 +421,8 @@ namespace ops = paddle::operators;
 // conv2d_transpose
 REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
                   ops::Conv2DTransposeOpMaker,
-                  ops::ConvTransposeGradOpDescMaker);
+                  ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(conv2d_transpose_grad, ops::ConvTransposeOpGrad);
 REGISTER_OP_CPU_KERNEL(
@@ -436,7 +438,8 @@ REGISTER_OP_CPU_KERNEL(
 // conv3d_transpose
 REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
                   ops::Conv3DTransposeOpMaker,
-                  ops::ConvTransposeGradOpDescMaker);
+                  ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad);
 REGISTER_OP_CPU_KERNEL(
@@ -452,7 +455,8 @@ REGISTER_OP_CPU_KERNEL(
 // depthwise conv2d_transpose
 REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp,
                   ops::Conv2DTransposeOpMaker,
-                  ops::ConvTransposeGradOpDescMaker);
+                  ops::ConvTransposeGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ConvTransposeGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);
 REGISTER_OP_CPU_KERNEL(
......
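The conv-transpose maker above also swaps `ForwardOp().Type()` for the virtual `this->ForwardOpType()`, again avoiding any reference to a forward `OpDesc`. Because the grad type is derived from the forward type string, one maker class serves conv2d_transpose, conv3d_transpose, and the depthwise variant alike. A condensed sketch of that fragment:

    // Sketch: deriving the grad op type from the forward op's type string.
    op->SetType(this->ForwardOpType() + "_grad");
    if (this->HasInput("Bias")) {  // optional input, probed on the maker itself
      op->SetInput("Bias", this->Input("Bias"));
      op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
    }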
@@ -169,8 +169,10 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(cos_sim, ops::CosSimOp, ops::CosSimOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(
+    cos_sim, ops::CosSimOp, ops::CosSimOpMaker,
+    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
+    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
 REGISTER_OPERATOR(cos_sim_grad, ops::CosSimOpGrad);
 REGISTER_OP_CPU_KERNEL(
     cos_sim, ops::CosSimKernel<paddle::platform::CPUDeviceContext, float>);
......
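The default maker follows the same split: `DefaultGradOpDescMaker<true>` becomes `DefaultGradOpMaker<T, true>`, with the target type prepended to the pre-existing boolean parameter (the boolean, which controls dropping of empty input gradients, is carried over unchanged from the old maker). Sketch for a hypothetical op that relies on the autogenerated `<type>_grad` wiring:

    REGISTER_OPERATOR(
        my_op, ops::MyOp, ops::MyOpMaker,
        paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
        paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);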
@@ -181,21 +181,22 @@ class CropOpGrad : public framework::OperatorWithKernel {
   }
 };
-class CropGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CropGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("crop_grad");
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetInput("X", Input("X"));
-    if (ForwardOp().Inputs().count("Offsets") > 0) {
-      op->SetInput("Offsets", Input("Offsets"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("X", this->Input("X"));
+    if (this->HasInput("Offsets")) {
+      op->SetInput("Offsets", this->Input("Offsets"));
     }
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -205,7 +206,8 @@ class CropGradOpDescMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker,
-                  ops::CropGradOpDescMaker);
+                  ops::CropGradOpMaker<paddle::framework::OpDesc>,
+                  ops::CropGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(crop_grad, ops::CropOpGrad);
 REGISTER_OP_CPU_KERNEL(
     crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -273,24 +273,25 @@ class CropTensorOpGrad : public framework::OperatorWithKernel {
   }
 };
-class CropTensorGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CropTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("crop_tensor_grad");
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetInput("X", Input("X"));
-    if (ForwardOp().Inputs().count("OffsetsTensor") > 0) {
-      op->SetInput("OffsetsTensor", Input("OffsetsTensor"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("X", this->Input("X"));
+    if (this->HasInput("OffsetsTensor")) {
+      op->SetInput("OffsetsTensor", this->Input("OffsetsTensor"));
     }
-    if (ForwardOp().Inputs().count("Offsets") > 0) {
-      op->SetInput("Offsets", Input("Offsets"));
+    if (this->HasInput("Offsets")) {
+      op->SetInput("Offsets", this->Input("Offsets"));
     }
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -300,7 +301,8 @@ class CropTensorGradOpDescMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(crop_tensor, ops::CropTensorOp, ops::CropTensorOpMaker,
-                  ops::CropTensorGradOpDescMaker);
+                  ops::CropTensorGradOpMaker<paddle::framework::OpDesc>,
+                  ops::CropTensorGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(crop_tensor_grad, ops::CropTensorOpGrad);
 REGISTER_OP_CPU_KERNEL(
     crop_tensor,
......
@@ -257,19 +257,20 @@ class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
   }
 };
-class CrossEntropyGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CrossEntropyGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("cross_entropy_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput("Label", Input("Label"));
-    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Label", this->Input("Label"));
+    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -365,20 +366,21 @@ or not. But the output only shares the LoD information with input X.
   }
 };
-class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CrossEntropyGradOpMaker2 : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("cross_entropy_grad2");
-    op->SetInput("Label", Input("Label"));
-    op->SetInput("MatchX", Output("MatchX"));
-    op->SetInput("XShape", Output("XShape"));
-    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetInput("Label", this->Input("Label"));
+    op->SetInput("MatchX", this->Output("MatchX"));
+    op->SetInput("XShape", this->Output("XShape"));
+    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -391,7 +393,8 @@ using CPUCtx = paddle::platform::CPUDeviceContext;
 REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase,
                   ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType,
-                  ops::CrossEntropyGradOpDescMaker);
+                  ops::CrossEntropyGradOpMaker<paddle::framework::OpDesc>,
+                  ops::CrossEntropyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
                        ops::CrossEntropyOpKernel<CPUCtx, double>);
@@ -401,7 +404,8 @@ REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
 REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
                   ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
-                  ops::CrossEntropyGradOpDescMaker2);
+                  ops::CrossEntropyGradOpMaker2<paddle::framework::OpDesc>,
+                  ops::CrossEntropyGradOpMaker2<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
 REGISTER_OP_CPU_KERNEL(cross_entropy2,
                        ops::CrossEntropyOpKernel2<CPUCtx, float>,
......
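Worth noting in `CrossEntropyGradOpMaker2` above: the grad op consumes forward outputs (`MatchX`, `XShape`) via `this->Output(...)`, not `this->Input(...)`; data_norm below does the same with `Scales`/`Means`. The maker API thus exposes four accessor families, sketched here on a hypothetical op:

    // Sketch: the four accessor families available inside Apply().
    op->SetInput("X", this->Input("X"));             // a forward input
    op->SetInput("XShape", this->Output("XShape"));  // a forward output
    op->SetInput(framework::GradVarName("Y"),
                 this->OutputGrad("Y"));             // grad of a forward output
    op->SetOutput(framework::GradVarName("X"),
                  this->InputGrad("X"));             // grad of a forward input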
@@ -125,8 +125,10 @@ Then:
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     ctc_align, ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int>,
     ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int64_t>);
@@ -197,31 +197,32 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
   }
 };
-class CudnnLSTMGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("cudnn_lstm_grad");
-    op->SetInput("Input", Input("Input"));
-    op->SetInput("InitH", Input("InitH"));
-    op->SetInput("InitC", Input("InitC"));
-    op->SetInput("W", Input("W"));
-    if (ForwardOp().Inputs().count("Cache") > 0) {
-      op->SetInput("Cache", Input("Cache"));
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("InitH", this->Input("InitH"));
+    op->SetInput("InitC", this->Input("InitC"));
+    op->SetInput("W", this->Input("W"));
+    if (this->HasInput("Cache")) {
+      op->SetInput("Cache", this->Input("Cache"));
     }
-    op->SetInput("Out", Output("Out"));
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetInput(framework::GradVarName("last_c"), OutputGrad("last_c"));
-    op->SetInput(framework::GradVarName("last_h"), OutputGrad("last_h"));
+    op->SetInput("Out", this->Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput(framework::GradVarName("last_c"), this->OutputGrad("last_c"));
+    op->SetInput(framework::GradVarName("last_h"), this->OutputGrad("last_h"));
-    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
-    op->SetOutput(framework::GradVarName("W"), InputGrad("W"));
-    op->SetOutput(framework::GradVarName("InitH"), InputGrad("InitH"));
-    op->SetOutput(framework::GradVarName("InitC"), InputGrad("InitC"));
-    op->SetAttrMap(Attrs());
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
+    op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
+    op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -240,7 +241,8 @@ class NotImpleKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
-                  ops::CudnnLSTMGradOpDescMaker);
+                  ops::CudnnLSTMGradOpMaker<paddle::framework::OpDesc>,
+                  ops::CudnnLSTMGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
 REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);
......
@@ -52,20 +52,21 @@ the input. If exlusive is true, the first element of the result is 0.
   }
 };
-class CumsumGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto *grad_op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto *grad_op = new T();
     grad_op->SetType("cumsum");
-    grad_op->SetInput("X", OutputGrad("Out"));
-    grad_op->SetOutput("Out", InputGrad("X"));
-    grad_op->SetAttr("axis", Attr<int>("axis"));
-    grad_op->SetAttr("reverse", !Attr<bool>("reverse"));
-    grad_op->SetAttr("exclusive", Attr<bool>("exclusive"));
-    return std::unique_ptr<framework::OpDesc>(grad_op);
+    grad_op->SetInput("X", this->OutputGrad("Out"));
+    grad_op->SetOutput("Out", this->InputGrad("X"));
+    grad_op->SetAttr("axis", boost::get<int>(this->GetAttr("axis")));
+    grad_op->SetAttr("reverse", !boost::get<bool>(this->GetAttr("reverse")));
+    grad_op->SetAttr("exclusive", boost::get<bool>(this->GetAttr("exclusive")));
+    return std::unique_ptr<T>(grad_op);
   }
 };
@@ -75,7 +76,9 @@ class CumsumGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 using CPU = paddle::platform::CPUDeviceContext;
-REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker, ops::CumsumGradMaker);
+REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
+                  ops::CumsumGradMaker<paddle::framework::OpDesc>,
+                  ops::CumsumGradMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(cumsum, ops::CumKernel<CPU, ops::CumsumFunctor<float>>,
                        ops::CumKernel<CPU, ops::CumsumFunctor<double>>,
                        ops::CumKernel<CPU, ops::CumsumFunctor<int>>);
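Cumsum's maker shows the attribute-access change: `Attr<int>("axis")` becomes `boost::get<int>(this->GetAttr("axis"))`, i.e. the templated base exposes a type-erased attribute accessor that is unpacked with `boost::get`. That is what lets the grad maker reuse the forward kernel with flipped flags:

    // Fragment from the maker above: read attributes through the type-erased
    // accessor, negating "reverse" so the backward pass runs the other way.
    grad_op->SetAttr("axis", boost::get<int>(this->GetAttr("axis")));
    grad_op->SetAttr("reverse", !boost::get<bool>(this->GetAttr("reverse")));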
@@ -125,19 +125,20 @@ CVM Operator.
   }
 };
-class CVMGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class CVMGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("cvm_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput("CVM", Input("CVM"));
-    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("CVM", this->Input("CVM"));
+    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -146,7 +147,9 @@ class CVMGradOpDescMaker : public framework::SingleGradOpDescMaker {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker, ops::CVMGradOpDescMaker);
+REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker,
+                  ops::CVMGradOpMaker<paddle::framework::OpDesc>,
+                  ops::CVMGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp);
......
@@ -375,32 +375,35 @@ class DataNormGradKernel<platform::CPUDeviceContext, T>
   }
 };
-class DataNormGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto *op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto *op = new T();
     op->SetType("data_norm_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
-    op->SetInput("BatchSize", Input("BatchSize"));
-    op->SetInput("BatchSum", Input("BatchSum"));
-    op->SetInput("BatchSquareSum", Input("BatchSquareSum"));
-    op->SetInput("Scales", Output("Scales"));
-    op->SetInput("Means", Output("Means"));
-    op->SetAttrMap(Attrs());
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetOutput(framework::GradVarName("BatchSize"), InputGrad("BatchSize"));
-    op->SetOutput(framework::GradVarName("BatchSum"), InputGrad("BatchSum"));
+    op->SetInput("X", this->Input("X"));
+    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
+    op->SetInput("BatchSize", this->Input("BatchSize"));
+    op->SetInput("BatchSum", this->Input("BatchSum"));
+    op->SetInput("BatchSquareSum", this->Input("BatchSquareSum"));
+    op->SetInput("Scales", this->Output("Scales"));
+    op->SetInput("Means", this->Output("Means"));
+    op->SetAttrMap(this->Attrs());
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("BatchSize"),
+                  this->InputGrad("BatchSize"));
+    op->SetOutput(framework::GradVarName("BatchSum"),
+                  this->InputGrad("BatchSum"));
     op->SetOutput(framework::GradVarName("BatchSquareSum"),
-                  InputGrad("BatchSquareSum"));
-    return std::unique_ptr<framework::OpDesc>(op);
+                  this->InputGrad("BatchSquareSum"));
+    return std::unique_ptr<T>(op);
   }
 };
@@ -409,7 +412,8 @@ class DataNormGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(data_norm, ops::DataNormOp, ops::DataNormOpMaker,
-                  ops::DataNormGradMaker);
+                  ops::DataNormGradMaker<paddle::framework::OpDesc>,
+                  ops::DataNormGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(data_norm_grad, ops::DataNormGradOp);
 REGISTER_OP_CPU_KERNEL(
......
@@ -222,27 +222,28 @@ class DeformableConvOp : public framework::OperatorWithKernel {
   }
 };
-class DeformableConvGradOpDescMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class DeformableConvGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("deformable_conv_grad");
-    op->SetInput("Input", Input("Input"));
-    op->SetInput("Filter", Input("Filter"));
-    op->SetInput("Offset", Input("Offset"));
-    op->SetInput("Mask", Input("Mask"));
-    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("Filter", this->Input("Filter"));
+    op->SetInput("Offset", this->Input("Offset"));
+    op->SetInput("Mask", this->Input("Mask"));
+    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
-    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
-    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
-    op->SetOutput(framework::GradVarName("Offset"), InputGrad("Offset"));
-    op->SetOutput(framework::GradVarName("Mask"), InputGrad("Mask"));
-    op->SetAttrMap(Attrs());
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
+    op->SetOutput(framework::GradVarName("Offset"), this->InputGrad("Offset"));
+    op->SetOutput(framework::GradVarName("Mask"), this->InputGrad("Mask"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -287,7 +288,9 @@ class DeformableConvGradOp : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp,
                   ops::DeformableConvOpMaker,
-                  ops::DeformableConvGradOpDescMaker);
+                  ops::DeformableConvGradOpMaker<paddle::framework::OpDesc>,
+                  ops::DeformableConvGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp);
 REGISTER_OP_CPU_KERNEL(deformable_conv, ops::DeformableConvCPUKernel<float>,
......
@@ -205,26 +205,26 @@ class DeformableConvV1Op : public framework::OperatorWithKernel {
   }
 };
-class DeformableConvV1GradOpDescMaker
-    : public framework::SingleGradOpDescMaker {
+template <typename T>
+class DeformableConvV1GradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("deformable_conv_v1_grad");
-    op->SetInput("Input", Input("Input"));
-    op->SetInput("Filter", Input("Filter"));
-    op->SetInput("Offset", Input("Offset"));
-    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("Filter", this->Input("Filter"));
+    op->SetInput("Offset", this->Input("Offset"));
+    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
-    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
-    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
-    op->SetOutput(framework::GradVarName("Offset"), InputGrad("Offset"));
-    op->SetAttrMap(Attrs());
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
+    op->SetOutput(framework::GradVarName("Offset"), this->InputGrad("Offset"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -265,7 +265,8 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(deformable_conv_v1, ops::DeformableConvV1Op,
                   ops::DeformableConvV1OpMaker,
-                  ops::DeformableConvV1GradOpDescMaker);
+                  ops::DeformableConvV1GradOpMaker<paddle::framework::OpDesc>,
+                  ops::DeformableConvV1GradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(deformable_conv_v1_grad, ops::DeformableConvV1GradOp);
 REGISTER_OP_CPU_KERNEL(deformable_conv_v1,
......
@@ -205,26 +205,26 @@ class DeformablePSROIPoolOp : public framework::OperatorWithKernel {
   }
 };
-class DeformablePSROIPoolGradOpDescMaker
-    : public framework::SingleGradOpDescMaker {
+template <typename T>
+class DeformablePSROIPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("deformable_psroi_pooling_grad");
-    op->SetInput("Input", Input("Input"));
-    op->SetInput("Trans", Input("Trans"));
-    op->SetInput("ROIs", Input("ROIs"));
-    op->SetInput("TopCount", Output("TopCount"));
-    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
+    op->SetInput("Input", this->Input("Input"));
+    op->SetInput("Trans", this->Input("Trans"));
+    op->SetInput("ROIs", this->Input("ROIs"));
+    op->SetInput("TopCount", this->Output("TopCount"));
+    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
-    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
-    op->SetOutput(framework::GradVarName("Trans"), InputGrad("Trans"));
-    op->SetAttrMap(Attrs());
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("Trans"), this->InputGrad("Trans"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -259,9 +259,11 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 using CPU = paddle::platform::CPUDeviceContext;
-REGISTER_OPERATOR(deformable_psroi_pooling, ops::DeformablePSROIPoolOp,
-                  ops::DeformablePSROIPoolOpMaker,
-                  ops::DeformablePSROIPoolGradOpDescMaker);
+REGISTER_OPERATOR(
+    deformable_psroi_pooling, ops::DeformablePSROIPoolOp,
+    ops::DeformablePSROIPoolOpMaker,
+    ops::DeformablePSROIPoolGradOpMaker<paddle::framework::OpDesc>,
+    ops::DeformablePSROIPoolGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(deformable_psroi_pooling_grad,
                   ops::DeformablePSROIPoolGradOp);
 REGISTER_OP_CPU_KERNEL(deformable_psroi_pooling,
......
@@ -51,7 +51,9 @@ It should not be configured by users directly.
 }  // namespace operators
 }  // namespace paddle
-REGISTER_OPERATOR(delete_var, paddle::operators::DeleteVarOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  paddle::operators::DeleteVarOpInfoMaker,
-                  paddle::operators::DeleteVarOpShapeInference);
+REGISTER_OPERATOR(
+    delete_var, paddle::operators::DeleteVarOp,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    paddle::operators::DeleteVarOpInfoMaker,
+    paddle::operators::DeleteVarOpShapeInference);
@@ -146,9 +146,10 @@ https://arxiv.org/abs/1506.01497.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(anchor_generator, ops::AnchorGeneratorOp,
-                  ops::AnchorGeneratorOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    anchor_generator, ops::AnchorGeneratorOp, ops::AnchorGeneratorOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(anchor_generator, ops::AnchorGeneratorOpKernel<float>,
                        ops::AnchorGeneratorOpKernel<double>);
@@ -284,8 +284,9 @@ If Tensor, the height of ColToRowMatchIndices is 1.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(bipartite_match, ops::BipartiteMatchOp,
-                  ops::BipartiteMatchOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    bipartite_match, ops::BipartiteMatchOp, ops::BipartiteMatchOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(bipartite_match, ops::BipartiteMatchKernel<float>,
                        ops::BipartiteMatchKernel<double>);
@@ -79,8 +79,10 @@ where im_w and im_h are computed from ImInfo, the formula is given as follows:
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(box_clip, ops::BoxClipOp, ops::BoxClipOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    box_clip, ops::BoxClipOp, ops::BoxClipOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     box_clip, ops::BoxClipKernel<paddle::platform::CPUDeviceContext, float>,
     ops::BoxClipKernel<paddle::platform::CPUDeviceContext, double>);
@@ -185,8 +185,10 @@ box will broadcast to target box along the assigned axis.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    box_coder, ops::BoxCoderOp, ops::BoxCoderOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     box_coder, ops::BoxCoderKernel<paddle::platform::CPUDeviceContext, float>,
     ops::BoxCoderKernel<paddle::platform::CPUDeviceContext, double>);
@@ -162,9 +162,11 @@ output_assign_box is the same as PriorBox.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(box_decoder_and_assign, ops::BoxDecoderAndAssignOp,
-                  ops::BoxDecoderAndAssignOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    box_decoder_and_assign, ops::BoxDecoderAndAssignOp,
+    ops::BoxDecoderAndAssignOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     box_decoder_and_assign,
     ops::BoxDecoderAndAssignKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -100,9 +100,11 @@ by objectness confidence. Select the post_nms_topN RoIs in
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(collect_fpn_proposals, ops::CollectFpnProposalsOp,
-                  ops::CollectFpnProposalsOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    collect_fpn_proposals, ops::CollectFpnProposalsOp,
+    ops::CollectFpnProposalsOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(collect_fpn_proposals,
                        ops::CollectFpnProposalsOpKernel<float>,
                        ops::CollectFpnProposalsOpKernel<double>);
@@ -172,9 +172,10 @@ class DensityPriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(density_prior_box, ops::DensityPriorBoxOp,
-                  ops::DensityPriorBoxOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    density_prior_box, ops::DensityPriorBoxOp, ops::DensityPriorBoxOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(density_prior_box, ops::DensityPriorBoxOpKernel<float>,
                        ops::DensityPriorBoxOpKernel<double>);
@@ -85,9 +85,11 @@ we return an array which indicate the original index of rois in
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(distribute_fpn_proposals, ops::DistributeFpnProposalsOp,
-                  ops::DistributeFpnProposalsOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    distribute_fpn_proposals, ops::DistributeFpnProposalsOp,
+    ops::DistributeFpnProposalsOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals,
                        ops::DistributeFpnProposalsOpKernel<float>,
                        ops::DistributeFpnProposalsOpKernel<double>);
@@ -434,8 +434,10 @@ K classes. This mask targets are used to compute loss of mask branch.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(generate_mask_labels, ops::GenerateMaskLabelsOp,
-                  ops::GenerateMaskLabelsOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    generate_mask_labels, ops::GenerateMaskLabelsOp,
+    ops::GenerateMaskLabelsOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(generate_mask_labels,
                        ops::GenerateMaskLabelsKernel<float>);
@@ -583,9 +583,11 @@ Finally BboxInsideWeights and BboxOutsideWeights are used to specify whether it
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(generate_proposal_labels, ops::GenerateProposalLabelsOp,
-                  ops::GenerateProposalLabelsOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    generate_proposal_labels, ops::GenerateProposalLabelsOp,
+    ops::GenerateProposalLabelsOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(generate_proposal_labels,
                        ops::GenerateProposalLabelsKernel<float>,
                        ops::GenerateProposalLabelsKernel<double>);
@@ -494,8 +494,9 @@ boxes.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp,
-                  ops::GenerateProposalsOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    generate_proposals, ops::GenerateProposalsOp, ops::GenerateProposalsOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel<float>,
                        ops::GenerateProposalsKernel<double>);
@@ -87,9 +87,10 @@ $$
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(iou_similarity, ops::IOUSimilarityOp,
-                  ops::IOUSimilarityOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    iou_similarity, ops::IOUSimilarityOp, ops::IOUSimilarityOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     iou_similarity,
...
@@ -332,9 +332,10 @@ MatchIndices elements with value -1.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(mine_hard_examples, ops::MineHardExamplesOp,
-                  ops::MineHardExamplesOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    mine_hard_examples, ops::MineHardExamplesOp, ops::MineHardExamplesOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     mine_hard_examples,
...
@@ -590,13 +590,15 @@ class MultiClassNMS2OpMaker : public MultiClassNMSOpMaker {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(multiclass_nms, ops::MultiClassNMSOp,
-                  ops::MultiClassNMSOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    multiclass_nms, ops::MultiClassNMSOp, ops::MultiClassNMSOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MultiClassNMSKernel<float>,
                        ops::MultiClassNMSKernel<double>);
-REGISTER_OPERATOR(multiclass_nms2, ops::MultiClassNMS2Op,
-                  ops::MultiClassNMS2OpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    multiclass_nms2, ops::MultiClassNMS2Op, ops::MultiClassNMS2OpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(multiclass_nms2, ops::MultiClassNMSKernel<float>,
                        ops::MultiClassNMSKernel<double>);
@@ -98,9 +98,11 @@ the geometry output contains 2*n channels.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(polygon_box_transform, ops::PolygonBoxTransformOp,
-                  ops::PolygonBoxTransformOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    polygon_box_transform, ops::PolygonBoxTransformOp,
+    ops::PolygonBoxTransformOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     polygon_box_transform,
     ops::PolygonBoxTransformCPUKernel<paddle::platform::CPUPlace, float>,
...
@@ -203,8 +203,10 @@ https://arxiv.org/abs/1512.02325.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(prior_box, ops::PriorBoxOpKernel<float, float>,
                        ops::PriorBoxOpKernel<double, double>);
...
@@ -557,9 +557,11 @@ empty (None).
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(retinanet_detection_output, ops::RetinanetDetectionOutputOp,
-                  ops::RetinanetDetectionOutputOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    retinanet_detection_output, ops::RetinanetDetectionOutputOp,
+    ops::RetinanetDetectionOutputOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(retinanet_detection_output,
                        ops::RetinanetDetectionOutputKernel<float>,
                        ops::RetinanetDetectionOutputKernel<double>);
@@ -620,22 +620,23 @@ class ROIPerspectiveTransformOpMaker
   }
 };
-class ROIPerspectiveTransformGradDescMaker
-    : public framework::SingleGradOpDescMaker {
+template <typename T>
+class ROIPerspectiveTransformGradMaker
+    : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("roi_perspective_transform_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput("ROIs", Input("ROIs"));
-    op->SetInput("Out2InIdx", Output("Out2InIdx"));
-    op->SetInput("Out2InWeights", Output("Out2InWeights"));
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("ROIs", this->Input("ROIs"));
+    op->SetInput("Out2InIdx", this->Output("Out2InIdx"));
+    op->SetInput("Out2InWeights", this->Output("Out2InWeights"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -644,9 +645,11 @@ class ROIPerspectiveTransformGradDescMaker
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(roi_perspective_transform, ops::ROIPerspectiveTransformOp,
-                  ops::ROIPerspectiveTransformOpMaker,
-                  ops::ROIPerspectiveTransformGradDescMaker);
+REGISTER_OPERATOR(
+    roi_perspective_transform, ops::ROIPerspectiveTransformOp,
+    ops::ROIPerspectiveTransformOpMaker,
+    ops::ROIPerspectiveTransformGradMaker<paddle::framework::OpDesc>,
+    ops::ROIPerspectiveTransformGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(roi_perspective_transform_grad,
                   ops::ROIPerspectiveTransformGradOp);
 REGISTER_OP_CPU_KERNEL(roi_perspective_transform,
...
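Operators that do have a gradient follow the companion pattern shown above: the old SingleGradOpDescMaker, which could only emit framework::OpDesc nodes, becomes a SingleGradOpMaker<T> template whose Apply() builds a T, and the member accessors gain an explicit this->. A condensed sketch of the pattern; MyOpGradMaker, my_op, and "my_op_grad" are hypothetical names, not operators in this diff:

// T is instantiated as framework::OpDesc (static graph) or
// imperative::OpBase (dygraph).
template <typename T>
class MyOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  std::unique_ptr<T> Apply() const override {
    std::unique_ptr<T> op(new T());
    op->SetType("my_op_grad");
    // this-> is now required: the members come from a dependent base
    // class, so unqualified Input()/OutputGrad() would not be found
    // during two-phase name lookup.
    op->SetInput("X", this->Input("X"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
    return op;
  }
};

// One registration then wires up both graph modes:
REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker,
                  ops::MyOpGradMaker<paddle::framework::OpDesc>,
                  ops::MyOpGradMaker<paddle::imperative::OpBase>);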
@@ -1022,14 +1022,17 @@ class RetinanetTargetAssignKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(rpn_target_assign, ops::RpnTargetAssignOp,
-                  ops::RpnTargetAssignOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    rpn_target_assign, ops::RpnTargetAssignOp, ops::RpnTargetAssignOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(rpn_target_assign, ops::RpnTargetAssignKernel<float>,
                        ops::RpnTargetAssignKernel<double>);
-REGISTER_OPERATOR(retinanet_target_assign, ops::RetinanetTargetAssignOp,
-                  ops::RetinanetTargetAssignOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    retinanet_target_assign, ops::RetinanetTargetAssignOp,
+    ops::RetinanetTargetAssignOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(retinanet_target_assign,
                        ops::RetinanetTargetAssignKernel<float>,
                        ops::RetinanetTargetAssignKernel<double>);
@@ -172,21 +172,21 @@ We know that $$\sigma(X_j) = \\frac{1}{1 + \exp(-X_j)}$$.
   }
 };
-class SigmoidFocalLossGradOpDescMaker
-    : public framework::SingleGradOpDescMaker {
+template <typename T>
+class SigmoidFocalLossGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
     op->SetType("sigmoid_focal_loss_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput("Label", Input("Label"));
-    op->SetInput("FgNum", Input("FgNum"));
-    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Label", this->Input("Label"));
+    op->SetInput("FgNum", this->Input("FgNum"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
     return op;
   }
 };
@@ -197,7 +197,8 @@ class SigmoidFocalLossGradOpDescMaker
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(sigmoid_focal_loss, ops::SigmoidFocalLossOp,
                   ops::SigmoidFocalLossOpMaker,
-                  ops::SigmoidFocalLossGradOpDescMaker);
+                  ops::SigmoidFocalLossGradOpMaker<paddle::framework::OpDesc>,
+                  ops::SigmoidFocalLossGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(sigmoid_focal_loss_grad, ops::SigmoidFocalLossGradOp);
 REGISTER_OP_CPU_KERNEL(
     sigmoid_focal_loss,
...
@@ -152,8 +152,10 @@ template struct NegTargetAssignFunctor<platform::CPUDeviceContext, float,
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(target_assign, ops::TargetAssignOp, ops::TargetAssignOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    target_assign, ops::TargetAssignOp, ops::TargetAssignOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     target_assign,
     ops::TargetAssignKernel<paddle::platform::CPUDeviceContext, int, float>,
...
@@ -161,7 +161,9 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel<float>,
                        ops::YoloBoxKernel<double>);
@@ -12,6 +12,7 @@
 #include "paddle/fluid/operators/detection/yolov3_loss_op.h"
 #include <memory>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/imperative/type_defs.h"
 namespace paddle {
 namespace operators {
@@ -262,29 +263,30 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
   }
 };
-class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
+template <typename T>
+class Yolov3LossGradMaker : public framework::SingleGradOpMaker<T> {
  public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto* op = new framework::OpDesc();
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
     op->SetType("yolov3_loss_grad");
-    op->SetInput("X", Input("X"));
-    op->SetInput("GTBox", Input("GTBox"));
-    op->SetInput("GTLabel", Input("GTLabel"));
-    op->SetInput("GTScore", Input("GTScore"));
-    op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
-    op->SetInput("ObjectnessMask", Output("ObjectnessMask"));
-    op->SetInput("GTMatchMask", Output("GTMatchMask"));
-    op->SetAttrMap(Attrs());
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("GTBox", this->Input("GTBox"));
+    op->SetInput("GTLabel", this->Input("GTLabel"));
+    op->SetInput("GTScore", this->Input("GTScore"));
+    op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
+    op->SetInput("ObjectnessMask", this->Output("ObjectnessMask"));
+    op->SetInput("GTMatchMask", this->Output("GTMatchMask"));
+    op->SetAttrMap(this->Attrs());
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetOutput(framework::GradVarName("GTBox"), {});
     op->SetOutput(framework::GradVarName("GTLabel"), {});
     op->SetOutput(framework::GradVarName("GTScore"), {});
-    return std::unique_ptr<framework::OpDesc>(op);
+    return std::unique_ptr<T>(op);
   }
 };
@@ -293,7 +295,8 @@ class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
-                  ops::Yolov3LossGradMaker);
+                  ops::Yolov3LossGradMaker<paddle::framework::OpDesc>,
+                  ops::Yolov3LossGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
 REGISTER_OP_CPU_KERNEL(yolov3_loss, ops::Yolov3LossKernel<float>,
                        ops::Yolov3LossKernel<double>);
...
@@ -191,8 +191,10 @@ https://arxiv.org/abs/1512.02325
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(detection_map, ops::DetectionMAPOp, ops::DetectionMAPOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    detection_map, ops::DetectionMAPOp, ops::DetectionMAPOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     detection_map, ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, float>,
     ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, double>);
@@ -51,8 +51,10 @@ class DiagOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(diag, ops::DiagOp, ops::DiagOpMaker,
-                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OPERATOR(
+    diag, ops::DiagOp, ops::DiagOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     diag, ops::DiagKernel<paddle::platform::CPUDeviceContext, int>,
     ops::DiagKernel<paddle::platform::CPUDeviceContext, float>,
...
@@ -84,7 +84,8 @@ class CheckpointNotifyOpShapeInference : public framework::InferShapeBase {
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  ops::CheckpointNotifyOpMaker,
-                  ops::CheckpointNotifyOpShapeInference);
+REGISTER_OPERATOR(
+    checkpoint_notify, ops::CheckpointNotifyOp,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::CheckpointNotifyOpMaker, ops::CheckpointNotifyOpShapeInference);
@@ -78,7 +78,8 @@ class DistributedNotifyOpShapeInference : public framework::InferShapeBase {
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(distributed_notify, ops::DistributedNotifyOp,
-                  paddle::framework::EmptyGradOpMaker,
-                  ops::DistributedNotifyOpMaker,
-                  ops::DistributedNotifyOpShapeInference);
+REGISTER_OPERATOR(
+    distributed_notify, ops::DistributedNotifyOp,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::DistributedNotifyOpMaker, ops::DistributedNotifyOpShapeInference);
@@ -80,6 +80,8 @@ table parameter at trainer side in distributed lookup table.
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(fake_init, ops::FakeInitOp, ops::FakeInitInferShape,
-                  ops::FakeInitOpMaker, paddle::framework::EmptyGradOpMaker,
-                  ops::FakeInitOpVarTypeInference);
+REGISTER_OPERATOR(
+    fake_init, ops::FakeInitOp, ops::FakeInitInferShape, ops::FakeInitOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::FakeInitOpVarTypeInference);
@@ -85,6 +85,8 @@ class FetchBarrierOpShapeInference : public framework::InferShapeBase {
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
-                  paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
-                  ops::FetchBarrierOpShapeInference);
+REGISTER_OPERATOR(
+    fetch_barrier, ops::FetchBarrierOp,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::FetchBarrierOpMaker, ops::FetchBarrierOpShapeInference);
@@ -95,6 +95,8 @@ class PrefetchOpShapeInference : public framework::InferShapeBase {
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(prefetch, ops::PrefetchOp,
-                  paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker,
-                  ops::PrefetchOpShapeInference);
+REGISTER_OPERATOR(
+    prefetch, ops::PrefetchOp,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::PrefetchOpMaker, ops::PrefetchOpShapeInference);
@@ -138,5 +138,8 @@ class RecvOpShapeInference : public framework::InferShapeBase {
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(recv, ops::RecvOp, paddle::framework::EmptyGradOpMaker,
-                  ops::RecvOpMaker, ops::RecvOpShapeInference);
+REGISTER_OPERATOR(
+    recv, ops::RecvOp,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::RecvOpMaker, ops::RecvOpShapeInference);
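For orientation, the contract of EmptyGradOpMaker itself does not change with the templating: it still declares that an operator contributes nothing to the backward graph. A simplified sketch of the idea only, not the actual framework implementation (the real maker derives from the framework's grad-maker base classes):

#include <memory>
#include <vector>

// Simplified illustration: an "empty" grad maker parameterized on the
// graph-node type T (framework::OpDesc for static graphs,
// imperative::OpBase for dygraph) emits no backward ops at all.
template <typename T>
struct EmptyGradMakerSketch {
  std::vector<std::unique_ptr<T>> operator()() const { return {}; }
};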
[Diffs for the remaining changed files are collapsed in this view.]