diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 0a6d762bc8be5201ac196b4bc6107c06d07a31d7..ac60be572419b62f4beb644ff192d413c35e19bb 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -2,7 +2,7 @@
## Motivation
-In Neural Network, many model is solved by the the backpropagation algorithm(known as BP) at present. Technically it caculates the gradient of the loss function, then distributed back through the networks. Follows the chain rule, so we need a module chains the gradient operators/expressions together with to construct the backward pass. Every forward network needs a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
+In neural networks, most models are currently solved by the backpropagation algorithm (known as **BP**). Technically, BP calculates the gradient of the loss function, then propagates it back through the network following the chain rule. Hence we need a module that chains the gradient operators/expressions together to construct the backward pass. Every forward network needs a backward network to construct the full computation graph; the operator/expression's backward pass will be generated with respect to the forward pass.
## Implementation
@@ -24,9 +24,9 @@ A backward network is built up with several backward operators. Backward operato
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
- In most cases, there is a one-to-one correspondence between the forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
+ In most cases, there is a one-to-one relation between the forward and backward operators. These relations are recorded by a global hash map (`OpInfoMap`). To follow the philosophy of minimum core and to make operators pluggable, the registry mechanism is introduced.
-For example, we have got a `mul_op`, and we can register its information and corresponding backward operator by the following macro:
+For example, we have `mul_op`, and we can register its information and corresponding backward operator by the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
@@ -48,7 +48,7 @@ The function `BuildGradOp` will sequentially execute following processes:
1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing.
+2. Build two maps named `inputs` and `outputs` to temporarily store the backward operator's inputs and outputs. Copy the forward operator's `inputs_` and `outputs_` to map `inputs`, except those that are not necessary for gradient computing.
3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`.
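The steps above can be condensed into a toy sketch. `FakeOp`, `LookupGradOpType`, and `GradVarName` below are hypothetical stand-ins for the real `OperatorBase`/`OpInfoMap` machinery, and the filtering of inputs that are unnecessary for gradient computing is omitted:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Simplified stand-in for an operator: a type name plus named input/output slots.
struct FakeOp {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs, outputs;
};

// Step 1 stand-in: a tiny "OpInfoMap" from forward op type to grad op type.
inline std::string LookupGradOpType(const std::string& fwd_type) {
  static const std::map<std::string, std::string> info = {{"mul", "mul_grad"}};
  return info.at(fwd_type);
}

inline std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

inline FakeOp BuildGradOp(const FakeOp& fwd) {
  FakeOp grad;
  grad.type = LookupGradOpType(fwd.type);            // step 1: look up grad type
  grad.inputs = fwd.inputs;                          // step 2: copy forward I/O
  for (const auto& kv : fwd.outputs) grad.inputs[kv.first] = kv.second;
  for (const auto& kv : fwd.outputs) {               // step 3: OutputGradients -> inputs
    auto& names = grad.inputs[GradVarName(kv.first)];
    for (const auto& n : kv.second) names.push_back(GradVarName(n));
  }
  for (const auto& kv : fwd.inputs) {                // step 3: InputGradients -> outputs
    auto& names = grad.outputs[GradVarName(kv.first)];
    for (const auto& n : kv.second) names.push_back(GradVarName(n));
  }
  return grad;
}
```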
@@ -56,11 +56,11 @@ The function `BuildGradOp` will sequentially execute following processes:
### Backward Network Building
-A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and append them together one by one. There is some corner case need to process specially.
+A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and appending them together one by one. There are some corner cases that need special processing.
1. Op
- When the input forward network is an Op, return its gradient Operator Immediately. If all of its outputs are in no gradient set, then return a special `NOP`.
+ When the input forward network is an Op, return its gradient Operator immediately. If all of its outputs are in the no-gradient set, then return a special `NOP`.
2. NetOp
@@ -68,33 +68,33 @@ A backward network is a series of backward operators. The main idea of building
3. RnnOp
- RnnOp is a nested stepnet operator. Backward module need to recusively call `Backward` for every stepnet.
+ RnnOp is a nested stepnet operator. The backward module needs to recursively call `Backward` for every stepnet.
4. Sharing Variables
- **sharing variables**. As illustrated in the pictures, two operator's share the same variable name of W@GRAD, which will overwrite their sharing input variable.
+ As illustrated in Figure 1 and Figure 2, two operators share the same variable name **W@GRAD**, which will overwrite their shared input variable.
- pic 1. Sharing variables in operators.
+ Figure 1. Sharing variables in operators.
- Sharing variable between operators or same input variable used in multiple operators leads to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively and add a generic add operator to replace the overwrite links.
+ Sharing a variable between operators, or using the same input variable in multiple operators, can lead to duplicate gradient variables. As illustrated in Figure 2, we need to rename the gradient names recursively and add a generic add operator to prevent overwriting.
- pic 2. Replace sharing variable's gradient with `Add` operator.
+ Figure 2. Replace sharing variable's gradient with `Add` operator.
- Because our framework finds variables accord to their names, we need to rename the output links. We add a suffix of number to represent its position in clockwise.
+ Because the framework finds variables according to their names, we need to rename the output links. We add an integer suffix to represent its position in the clockwise direction.
-5. Part of Gradient is Zero.
+5. Part of the Gradient is Zero.
- In the whole graph, there is some case of that one operator's gradient is not needed, but its input's gradient is a dependency link of other operator, we need to fill a same shape gradient matrix in the position. In our implement, we insert a special `fillZeroLike` operator.
+ In the whole graph, there are cases where one operator's gradient is not needed, but its input's gradient is a dependency link of another operator. In that position we need to fill in a gradient matrix of the same shape; in our implementation, we insert a special `fillZeroLike` operator.
Follow these rules above, then collect the sub graph `OutputGradients`/`InputGradients` as the NetOp's and return it.
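The corner-case rules above amount to a recursive traversal: a plain Op yields its gradient op (or a `NOP` if all of its outputs are in the no-gradient set), while a NetOp recurses over its children in inverted sequence. A toy sketch with hypothetical `Node` and `Backward` names (not Paddle's real types):

```cpp
#include <string>
#include <vector>

// Simplified stand-in for an operator tree: empty children => plain Op,
// non-empty children => NetOp containing sub-operators.
struct Node {
  std::string name;
  std::vector<Node> children;
};

inline Node Backward(const Node& fwd, const std::vector<std::string>& no_grad) {
  if (fwd.children.empty()) {
    for (const auto& n : no_grad)
      if (n == fwd.name) return {"NOP", {}};  // outputs all in the no-grad set
    return {fwd.name + "_grad", {}};          // one-to-one grad operator
  }
  Node net{fwd.name + "_grad_net", {}};
  // build backward operators in the inverted sequence and append one by one
  for (auto it = fwd.children.rbegin(); it != fwd.children.rend(); ++it)
    net.children.push_back(Backward(*it, no_grad));
  return net;
}
```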
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index e00c6e8d904508ec9985537fc703c7c61a14e0de..b8fdf69683e645d991cf8dc2297b486680445a00 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) {
op->Run(scope, dev_ctx);
int test_attr = op->Attr("test_attr");
ASSERT_EQ(test_attr, 4);
-}
\ No newline at end of file
+}
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index fcbfc3e4377edd0ea55c8d4328c325fa18663001..a3f28339aa64c6bde3fcefdae8b0973a5bbdd585 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include
-#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
@@ -33,6 +32,24 @@ ExecutionContext::GetEigenDevice() const {
}
#endif
+const Tensor* GetTensorFromVar(const Variable* var) {
+  if (var->IsType<LoDTensor>()) {
+    return &var->Get<LoDTensor>();
+  }
+  PADDLE_ENFORCE(var->IsType<Tensor>(),
+                 "The Input must be LoDTensor or Tensor.");
+  return &var->Get<Tensor>();
+}
+
+Tensor* GetTensorFromVar(Variable* var) {
+  if (var->IsType<LoDTensor>()) {
+    return var->GetMutable<LoDTensor>();
+  }
+  PADDLE_ENFORCE(var->IsType<Tensor>(),
+                 "The Input must be LoDTensor or Tensor.");
+  return var->GetMutable<Tensor>();
+}
+
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 2d6d5510ef6dc83f1a016be6ff123f0b9bcaf230..77c7c855c0ffed5032e639237b01037a990652c4 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
+#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
@@ -56,6 +57,9 @@ class OperatorBase;
class InferShapeContext;
class ExecutionContext;
+extern const Tensor* GetTensorFromVar(const Variable* var);
+extern Tensor* GetTensorFromVar(Variable* var);
+
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@@ -262,15 +266,6 @@ class InferShapeContext {
return res;
}
-  const Tensor* GetTensorFromVar(const Variable* var) const {
-    if (var->IsType<LoDTensor>()) {
-      return &var->Get<LoDTensor>();
-    }
-    PADDLE_ENFORCE(var->IsType<Tensor>(),
-                   "The Input(%s) must be LoDTensor or Tensor.");
-    return &var->Get<Tensor>();
-  }
-
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, InputSize(in));
@@ -340,6 +335,78 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext& device_context_;
};
+class RuntimeInferShapeContext : public InferShapeContextBase {
+ public:
+ RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
+ : op_(op), scope_(scope) {}
+
+ bool HasInput(const std::string& name) const {
+ auto ipt = op_.Input(name);
+ auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
+ return var != nullptr;
+ }
+
+ bool HasOutput(const std::string& name) const {
+ auto ipt = op_.Output(name);
+ auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
+ return var != nullptr;
+ }
+
+ DDim GetInputDim(const std::string& name) const {
+ return GetDim(op_.Input(name));
+ }
+
+ void SetInputDim(const std::string& name, const DDim& dim) {
+ SetDim(op_.Input(name), dim);
+ }
+
+ DDim GetOutputDim(const std::string& name) const {
+ return GetDim(op_.Output(name));
+ }
+
+ void SetOutputDim(const std::string& name, const DDim& dim) {
+ SetDim(op_.Output(name), dim);
+ }
+
+ AttrReader Attrs() const { return AttrReader(op_.Attrs()); }
+
+  const std::vector<std::string>& Inputs(const std::string& name) const {
+    return op_.Inputs(name);
+  }
+
+  const std::vector<std::string>& Outputs(const std::string& name) const {
+    return op_.Outputs(name);
+  }
+
+ private:
+  template <bool Allocate>
+ Tensor* GetTensor(const std::string& name) const {
+ Tensor* t = nullptr;
+ auto* var = scope_.FindVar(name);
+    if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
+ if (Allocate) {
+        t = var->GetMutable<LoDTensor>();
+ } else {
+ PADDLE_THROW("Variable(%s) should be tensor", name);
+ }
+ } else {
+ t = GetTensorFromVar(scope_.FindVar(name));
+ }
+ return t;
+ }
+
+ DDim GetDim(const std::string& name) const {
+ return GetTensor(name)->dims();
+ }
+
+ void SetDim(const std::string& name, const DDim& dim) {
+ GetTensor(name)->Resize(dim);
+ }
+
+ const OperatorBase& op_;
+ const Scope& scope_;
+};
+
class OpKernel {
public:
/**
@@ -383,8 +450,10 @@ class OperatorWithKernel : public OperatorBase {
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
+ // runtime infershape
void InferShape(const Scope& scope) const override {
- InferShape(InferShapeContext(*this, scope));
+ auto c = RuntimeInferShapeContext(*this, scope);
+ InferShape(&c);
}
void Run(const Scope& scope,
@@ -406,7 +475,7 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
- virtual void InferShape(const InferShapeContext& ctx) const = 0;
+ virtual void InferShape(InferShapeContextBase* ctx) const = 0;
};
} // namespace framework
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index 0beab0fac5b94c78121261d2661a6f969289afc4..8b4bb01a7bb80eaccee40f14fa82617505b1e2e5 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
+#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
@@ -114,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext& ctx) const override {}
+ void InferShape(framework::InferShapeContextBase* ctx) const override {}
};
template <typename T>
diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h
new file mode 100644
index 0000000000000000000000000000000000000000..b07fc788124413f728c713027609d9d2d1c39538
--- /dev/null
+++ b/paddle/framework/shape_inference.h
@@ -0,0 +1,82 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/ddim.h"
+
+namespace paddle {
+namespace framework {
+
+class InferShapeContextBase {
+ public:
+ virtual ~InferShapeContextBase() {}
+ virtual bool HasInput(const std::string &name) const = 0;
+ virtual bool HasOutput(const std::string &name) const = 0;
+ virtual framework::DDim GetInputDim(const std::string &name) const = 0;
+  std::vector<framework::DDim> GetInputsDim(const std::string &name) const {
+    const std::vector<std::string> &names = Inputs(name);
+    return GetDims(names);
+  }
+  virtual void SetInputDim(const std::string &name,
+                           const framework::DDim &dim) = 0;
+  void SetInputsDim(const std::string &name,
+                    const std::vector<framework::DDim> &dims) {
+    auto &names = Inputs(name);
+    SetDims(names, dims);
+  }
+  virtual framework::DDim GetOutputDim(const std::string &name) const = 0;
+  std::vector<framework::DDim> GetOutputsDim(const std::string &name) const {
+    const std::vector<std::string> &names = Outputs(name);
+    return GetDims(names);
+  }
+  virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
+  void SetOutputsDim(const std::string &name,
+                     const std::vector<framework::DDim> &dims) {
+    auto &names = Outputs(name);
+    SetDims(names, dims);
+  }
+  virtual AttrReader Attrs() const = 0;
+  virtual const std::vector<std::string> &Inputs(
+      const std::string &name) const = 0;
+  virtual const std::vector<std::string> &Outputs(
+      const std::string &name) const = 0;
+  // TODO(qiao) implement this function
+  void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
+                size_t j = 0) const {}
+
+ protected:
+  virtual framework::DDim GetDim(const std::string &name) const = 0;
+  virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
+  std::vector<framework::DDim> GetDims(
+      const std::vector<std::string> &names) const {
+    std::vector<framework::DDim> ret;
+    ret.reserve(names.size());
+    std::transform(
+        names.begin(), names.end(), std::back_inserter(ret),
+        [this](const std::string &name) { return this->GetDim(name); });
+    return ret;
+  }
+  void SetDims(const std::vector<std::string> &names,
+               const std::vector<framework::DDim> &dims) {
+    size_t length = names.size();
+    PADDLE_ENFORCE_EQ(length, dims.size());
+    for (size_t i = 0; i < length; ++i) {
+      SetDim(names[i], dims[i]);
+    }
+  }
+};
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/variable.md b/paddle/framework/variable.md
index f44d5ea46e7ce98dd443d684ad42308496bc4179..442ef6b718b227d79ca73031efcbb55817558252 100644
--- a/paddle/framework/variable.md
+++ b/paddle/framework/variable.md
@@ -7,7 +7,7 @@ Variable is also known as *blob* in MxNet and Caffe2. It is the input and outpu
For the flexibility of a DL system, a variable should be able to contain any typed value -- a tensor in most cases, but could also be some integer IDs or a scope of other variables in the case of RNN.
-To use the minimum amount of memory, we'd like that a variable to allocate memory when it has to, or, lazy memory allocation. Let's take the following example:
+To use the minimum amount of memory, we would like a variable to allocate memory only when it has to, that is, lazy memory allocation. Let's take the following example:
```cpp
Variable vr, v1, v2;
@@ -38,7 +38,7 @@ This syntax for lazy memory allocation when we call `Randomize` and `Mult`, thos
To make memory allocation lazy, we cannot assume that we know the type held by a variable at definition time. In other words, `class Variable` cannot be a template `template <T> class Variable`.
-Because we don't know the type `T`, we cannot save a `T*` as `Variable's` data member. Instead, we save an interface object `Placeholder`, who can return the pointer to the saved object via `Placeholder::Ptr()` as `void*`.
+Because we don't know the type `T`, we cannot save a `T*` as `Variable's` data member. Instead, we save an interface object `Placeholder`, which can return the pointer to the saved object via `Placeholder::Ptr()` as `void*`.
But anyway, Variable needs to know `T` so could it `delete(ptr)` and so could `Variable::Get` checks the expected type and the saved object's type.
@@ -49,4 +49,4 @@ Because `PlaceholderImpl` knows `T`, it can save and return `typeid(T)` for the
## Conclusion
-The technique type hiding utilizes C++ class templates, interface and derivation, and C++ RTTI (typeid). This combination saves us from definition something like `caffe2::TypeMata`, which takes hundreds of lines of C++ code.
+The technique type hiding utilizes C++ class templates, interface and derivation, and C++ RTTI (typeid). This combination saves us from defining something like `caffe2::TypeMeta`, which takes hundreds of lines of C++ code.
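A minimal sketch of this type-hiding technique, assuming simplified semantics (lazy allocation plus RTTI checks); this is an illustration of the design, not the actual Paddle `Variable`:

```cpp
#include <cassert>
#include <memory>
#include <typeinfo>

// The non-templated Variable saves a type-erased Placeholder; only the
// templated PlaceholderImpl<T> knows T, so it can delete the object and
// report typeid(T) for runtime type checks.
class Variable {
 public:
  template <typename T>
  T* GetMutable() {  // lazy allocation: the object is created on first use
    if (holder_ == nullptr || holder_->Type() != typeid(T))
      holder_.reset(new PlaceholderImpl<T>());
    return static_cast<T*>(holder_->Ptr());
  }

  template <typename T>
  const T& Get() const {  // RTTI check that the saved object really is a T
    assert(holder_ != nullptr && holder_->Type() == typeid(T));
    return *static_cast<const T*>(holder_->Ptr());
  }

  template <typename T>
  bool IsType() const {
    return holder_ != nullptr && holder_->Type() == typeid(T);
  }

 private:
  struct Placeholder {  // type-erased interface
    virtual ~Placeholder() {}
    virtual void* Ptr() = 0;
    virtual const std::type_info& Type() const = 0;
  };
  template <typename T>
  struct PlaceholderImpl : Placeholder {  // knows T: can delete and typeid it
    void* Ptr() override { return &obj_; }
    const std::type_info& Type() const override { return typeid(T); }
    T obj_{};
  };
  std::unique_ptr<Placeholder> holder_;
};
```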
diff --git a/paddle/gserver/activations/MKLDNNActivation.cpp b/paddle/gserver/activations/MKLDNNActivation.cpp
index ac50937ef3e28c1ac5aae651f9cf266ad07abcc4..18c5638100065109fba1f0647a1c5f91256f7b9d 100644
--- a/paddle/gserver/activations/MKLDNNActivation.cpp
+++ b/paddle/gserver/activations/MKLDNNActivation.cpp
@@ -27,31 +27,53 @@ static ClassRegistrar gMKLDNNActivationRegistrar;
#define MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) mkldnn_##ACT_TYPE##Activation
/**
- * @def DEFINE_MKLDNN_ELTWISE_ACTIVATION
+ * @def BEGIN_MKLDNN_ACTIVATION
+ */
+#define BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
+ class MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) : public BASE_CLASS {
+/**
+ * @def END_MKLDNN_ACTIVATION
*/
-#define DEFINE_MKLDNN_ELTWISE_ACTIVATION(ACT_TYPE, ALPHA, BWD_ALPHA) \
- class MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) \
- : public MKLDNNEltwiseActivation { \
- private: \
- static const std::string name; \
- static const float alpha; \
- static const float bwdAlpha; \
- \
- public: \
- const std::string& getName() const { return name; } \
- float getAlpha() const { return alpha; } \
- float getBwdAlpha() const { return bwdAlpha; } \
- }; \
- const std::string MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::name = \
- "mkldnn_" #ACT_TYPE; \
- const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::alpha = ALPHA; \
- const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::bwdAlpha = BWD_ALPHA; \
- static InitFunction __reg_activation__mkldnn_##ACT_TYPE([] { \
- gMKLDNNActivationRegistrar \
-        .registerClass<MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)>(     \
- "mkldnn_" #ACT_TYPE); \
+#define END_MKLDNN_ACTIVATION(ACT_TYPE) \
+private: \
+ static const std::string name; \
+ \
+public: \
+ const std::string& getName() const { return name; } \
+ } \
+ ; \
+ const std::string MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::name = \
+ "mkldnn_" #ACT_TYPE; \
+ static InitFunction __reg_activation__mkldnn_##ACT_TYPE([] { \
+ gMKLDNNActivationRegistrar \
+        .registerClass<MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)>(     \
+ "mkldnn_" #ACT_TYPE); \
});
+/**
+ * @def DEFINE_MKLDNN_ACTIVATION
+ */
+#define DEFINE_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
+ BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
+ END_MKLDNN_ACTIVATION(ACT_TYPE)
+
+/**
+ * @def DEFINE_MKLDNN_ELTWISE_ACTIVATION
+ */
+#define DEFINE_MKLDNN_ELTWISE_ACTIVATION( \
+ ACT_TYPE, BASE_CLASS, ALPHA, BWD_ALPHA) \
+ BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
+private: \
+ static const float alpha; \
+ static const float bwdAlpha; \
+ \
+public: \
+ float getAlpha() const { return alpha; } \
+ float getBwdAlpha() const { return bwdAlpha; } \
+ END_MKLDNN_ACTIVATION(ACT_TYPE) \
+ const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::alpha = ALPHA; \
+ const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::bwdAlpha = BWD_ALPHA;
+
/**
* @brief MKLDNN Relu Activation.
* Actually mkldnn_relu is Leaky Relu.
@@ -59,19 +81,129 @@ static ClassRegistrar gMKLDNNActivationRegistrar;
* f(x) = negative_slope * x (x < 0)
* @note the negative_slope should be -0.f in forward
*/
-DEFINE_MKLDNN_ELTWISE_ACTIVATION(relu, -0.f, 0.f)
+DEFINE_MKLDNN_ELTWISE_ACTIVATION(relu, MKLDNNEltwiseActivation, -0.f, 0.f)
/**
* @brief MKLDNN Tanh Activation.
*/
-DEFINE_MKLDNN_ELTWISE_ACTIVATION(tanh, 0.f, 0.f)
+DEFINE_MKLDNN_ELTWISE_ACTIVATION(tanh, MKLDNNEltwiseActivation, 0.f, 0.f)
/**
* @brief MKLDNN ELU(Exponential Linear Unit) Activation.
* f(x) = x (x >= 0)
* f(x) = negative_slope * (exp(x) - 1) (x < 0)
*/
-DEFINE_MKLDNN_ELTWISE_ACTIVATION(elu, 0.f, 0.f)
+DEFINE_MKLDNN_ELTWISE_ACTIVATION(elu, MKLDNNEltwiseActivation, 0.f, 0.f)
+
+mkldnn::algorithm MKLDNNEltwiseActivation::getAlgo(std::string type) const {
+  const std::map<std::string, algorithm> algoMap = {
+ {"relu", algorithm::eltwise_relu},
+ {"tanh", algorithm::eltwise_tanh},
+ {"elu", algorithm::eltwise_elu}};
+ type.erase(0, 7); // remove mkldnn_
+ algorithm algo = (algorithm)0;
+ mapGet(type, algoMap, &algo);
+ return algo;
+}
+
+void MKLDNNEltwiseActivation::resetFwd(Argument& act) {
+ if (cnt_ == act.value->getElementCnt()) {
+ return;
+ }
+ MKLDNNActivation::resetFwd(act);
+ // note: alpha represents the NegativeSlope when used in relu.
+ float alpha = getAlpha();
+ float beta = getBeta();
+ algorithm algo = getAlgo(this->getName());
+ auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training,
+ algo,
+ val_->getMemoryDesc(),
+ alpha,
+ beta);
+ fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, *engine_));
+ // use inplace for forward but save input value before submit
+ inVal_ = val_;
+ copyInVal_ = nullptr;
+ if (act.grad && algo == algorithm::eltwise_tanh) {
+ // tanh need save src input for backward
+ inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc());
+    copyInVal_ = std::make_shared<mkldnn::reorder>(*val_, *inVal_);
+    CHECK(copyInVal_) << "should not be empty";
+ pipelineFwd_.push_back(*copyInVal_);
+ }
+ fwd_.reset(new eltwise_fwd(*fwdPD_, *val_, *val_));
+ pipelineFwd_.push_back(*fwd_);
+ needResetBwd_ = true;
+}
+
+void MKLDNNEltwiseActivation::resetBwd(Argument& act) {
+ if (!needResetBwd_) {
+ return;
+ }
+ VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
+ needResetBwd_ = false;
+ algorithm algo = getAlgo(this->getName());
+ float alpha = getBwdAlpha();
+ float beta = getBeta();
+ grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc());
+ auto eng = CPUEngine::Instance().getEngine();
+ auto bwdDesc = eltwise_bwd::desc(
+ algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta);
+ auto bwdPD = eltwise_bwd::primitive_desc(bwdDesc, eng, *fwdPD_);
+ CHECK(inVal_);
+ bwd_.reset(new eltwise_bwd(bwdPD, *inVal_, *grad_, *grad_));
+ pipelineBwd_.clear();
+ pipelineBwd_.push_back(*bwd_);
+}
+
+/**
+ * @brief MKLDNN Softmax Activation
+ */
+DEFINE_MKLDNN_ACTIVATION(softmax, MKLDNNSoftmaxActivation)
+
+void MKLDNNSoftmaxActivation::resetFwd(Argument& act) {
+ if (cnt_ == act.value->getElementCnt()) {
+ return;
+ }
+ MKLDNNActivation::resetFwd(act);
+ int axis = 1;
+ auto fwdDesc = softmax_fwd::desc(
+ mkldnn::prop_kind::forward_scoring, val_->getMemoryDesc(), axis);
+ auto fwdPD = softmax_fwd::primitive_desc(fwdDesc, *engine_);
+ fwd_.reset(new softmax_fwd(fwdPD, *val_, *val_));
+ pipelineFwd_.push_back(*fwd_);
+}
+
+Error __must_check MKLDNNSoftmaxActivation::forward(Argument& act) {
+ resetFwd(act);
+ stream_->submit(pipelineFwd_);
+ real* v = act.value->getData();
+ real threshold = exp(-64);
+#pragma omp parallel for
+ for (size_t i = 0; i < act.value->getElementCnt(); ++i) {
+ v[i] = v[i] < threshold ? threshold : v[i];
+ }
+ return Error();
+}
+
+Error __must_check MKLDNNSoftmaxActivation::backward(Argument& act) {
+ MatrixPtr outputV = act.value;
+ MatrixPtr outputG = act.grad;
+ Matrix::resizeOrCreate(sftMaxDot_,
+ outputG->getHeight(),
+ outputG->getWidth(),
+ /* trans */ false,
+ /* useGpu */ false);
+ Matrix::resizeOrCreate(sftMaxSum_,
+ outputG->getHeight(),
+ 1,
+ /* trans */ false,
+ /* useGpu */ false);
+ sftMaxDot_->dotMul(*outputG, *outputV);
+ sftMaxSum_->colMerge(*sftMaxDot_);
+ act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
+ return Error();
+}
ActivationFunction* MKLDNNActivation::create(const std::string& type) {
return gMKLDNNActivationRegistrar.createByType(type);
@@ -84,4 +216,34 @@ std::vector MKLDNNActivation::getAllRegisteredTypes() {
return types;
}
+void MKLDNNActivation::resetFwd(Argument& act) {
+ VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
+ cnt_ = act.value->getElementCnt();
+ pipelineFwd_.clear();
+ stream_.reset(new MKLDNNStream());
+ engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
+  val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
+ if (val_ == nullptr) {
+ int bs = act.getBatchSize();
+ int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1;
+ int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1;
+ int ic = cnt_ / bs / ih / iw;
+ CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
+ val_ = MKLDNNMatrix::create(
+ act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_);
+ CHECK(val_);
+ val_->downSpatial();
+ }
+}
+
+Error __must_check MKLDNNActivation::forward(Argument& act) {
+ resetFwd(act);
+ stream_->submit(pipelineFwd_);
+ return Error();
+}
+Error __must_check MKLDNNActivation::backward(Argument& act) {
+ resetBwd(act);
+ stream_->submit(pipelineBwd_);
+ return Error();
+}
} // namespace paddle
diff --git a/paddle/gserver/activations/MKLDNNActivation.h b/paddle/gserver/activations/MKLDNNActivation.h
index 40dd8c618aa2b70d410130e12efc54520218afea..dd16421fd6e93b49c30b1d3b601f95980afec57b 100644
--- a/paddle/gserver/activations/MKLDNNActivation.h
+++ b/paddle/gserver/activations/MKLDNNActivation.h
@@ -36,6 +36,7 @@ protected:
// mkldnn matrix, primitive, stream and pipeline
MKLDNNMatrixPtr val_;
MKLDNNMatrixPtr grad_;
+  std::shared_ptr<mkldnn::engine> engine_;
   std::shared_ptr<MKLDNNStream> stream_;
   std::shared_ptr<mkldnn::primitive> fwd_;
   std::shared_ptr<mkldnn::primitive> bwd_;
@@ -48,8 +49,18 @@ public:
static ActivationFunction* create(const std::string& type);
static std::vector getAllRegisteredTypes();
virtual const std::string& getName() const = 0;
- virtual Error __must_check forward(Argument& act) = 0;
- virtual Error __must_check backward(Argument& act) = 0;
+ /**
+ * reset the forward primitives
+ */
+ virtual void resetFwd(Argument& act);
+ /**
+ * reset the backward primitives,
+ * can not merge this functions into resetFwd as the grad data
+ * would be changing before backward.
+ */
+ virtual void resetBwd(Argument& act) {}
+ virtual Error __must_check forward(Argument& act);
+ virtual Error __must_check backward(Argument& act);
};
/**
@@ -59,6 +70,7 @@ public:
class MKLDNNEltwiseActivation : public MKLDNNActivation {
typedef mkldnn::eltwise_forward eltwise_fwd;
typedef mkldnn::eltwise_backward eltwise_bwd;
+ typedef mkldnn::algorithm algorithm;
protected:
// save the forward primitive desc, which can be used backward
@@ -70,9 +82,7 @@ protected:
public:
MKLDNNEltwiseActivation() {}
-
~MKLDNNEltwiseActivation() {}
-
virtual const std::string& getName() const = 0;
// in common, the alpha of forward and backward should be equal.
@@ -80,105 +90,30 @@ public:
virtual float getAlpha() const = 0;
virtual float getBwdAlpha() const = 0;
virtual float getBeta() const { return 0.f; }
- virtual mkldnn::algorithm getAlgo(const std::string& type) const {
- if (type == "mkldnn_relu") {
- return mkldnn::algorithm::eltwise_relu;
- } else if (type == "mkldnn_tanh") {
- return mkldnn::algorithm::eltwise_tanh;
- } else if (type == "mkldnn_elu") {
- return mkldnn::algorithm::eltwise_elu;
- } else {
- LOG(FATAL) << "Unkown eltwise activation type: " << type;
- }
- return (mkldnn::algorithm)0;
- }
-
- /**
- * reshape and reset the forward primitives
- */
- void resetFwd(Argument& act) {
- if (cnt_ == act.value->getElementCnt()) {
- return;
- }
- VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
- cnt_ = act.value->getElementCnt();
- stream_.reset(new MKLDNNStream());
- auto eng = CPUEngine::Instance().getEngine();
-
- // get algo setting
- mkldnn::algorithm algo = getAlgo(this->getName());
- // note: alpha represents the NegativeSlope when used in relu.
- float alpha = getAlpha();
- float beta = getBeta();
-
- pipelineFwd_.clear();
- val_ = std::dynamic_pointer_cast(act.value);
- if (val_ == nullptr) {
- int bs = act.getBatchSize();
- int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1;
- int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1;
- int ic = cnt_ / bs / ih / iw;
- CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
- val_ = MKLDNNMatrix::create(
- act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, eng);
- CHECK(val_);
- }
- auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training,
- algo,
- val_->getMemoryDesc(),
- alpha,
- beta);
- fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, eng));
- // use inplace for forward but save input value before submit
- inVal_ = val_;
- copyInVal_ = nullptr;
- if (act.grad && algo == mkldnn::algorithm::eltwise_tanh) {
- // tanh need save src input for backward
- inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc());
- copyInVal_ = std::make_shared(*val_, *inVal_);
- CHECK(copyInVal_) << "should not be emptry";
- pipelineFwd_.push_back(*copyInVal_);
- }
- fwd_.reset(new eltwise_fwd(*fwdPD_, *val_, *val_));
- pipelineFwd_.push_back(*fwd_);
- needResetBwd_ = true;
- }
+ virtual algorithm getAlgo(std::string type) const;
+ void resetFwd(Argument& act) override;
+ void resetBwd(Argument& act) override;
+};
- /**
- * reset the backward primitives, can not merge into resetFwd as the grad data
- * would be changing before backward.
- */
- void resetBwd(Argument& act) {
- if (!needResetBwd_) {
- return;
- }
- VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
- needResetBwd_ = false;
- mkldnn::algorithm algo = getAlgo(this->getName());
- float alpha = getBwdAlpha();
- float beta = getBeta();
- grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc());
- auto eng = CPUEngine::Instance().getEngine();
- auto bwdDesc = eltwise_bwd::desc(
- algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta);
- auto bwdPD = eltwise_bwd::primitive_desc(bwdDesc, eng, *fwdPD_);
- CHECK(inVal_);
- bwd_.reset(new eltwise_bwd(bwdPD, *inVal_, *grad_, *grad_));
- pipelineBwd_.clear();
- pipelineBwd_.push_back(*bwd_);
- }
+/**
+ * @brief Base class of MKLDNN softmax Activation,
+ * only have mkldnn forward, use cpu implement for backward.
+ */
+class MKLDNNSoftmaxActivation : public MKLDNNActivation {
+ typedef mkldnn::softmax_forward softmax_fwd;
- Error __must_check forward(Argument& act) {
- resetFwd(act);
- stream_->submit(pipelineFwd_);
- return Error();
- }
+private:
+ // for backward
+ MatrixPtr sftMaxSum_;
+ MatrixPtr sftMaxDot_;
- Error __must_check backward(Argument& act) {
- resetBwd(act);
- stream_->submit(pipelineBwd_);
- return Error();
- }
+public:
+ MKLDNNSoftmaxActivation() {}
+ ~MKLDNNSoftmaxActivation() {}
+ virtual const std::string& getName() const = 0;
+ void resetFwd(Argument& act) override;
+ Error __must_check forward(Argument& act) override;
+ Error __must_check backward(Argument& act) override;
};
} // namespace paddle
diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp
index 1bfbbde4246a10eaf86693a6a2f237f390966db3..857d07df3e3088be28943d9e2fe58017e9e57f4a 100644
--- a/paddle/gserver/tests/test_MKLDNN.cpp
+++ b/paddle/gserver/tests/test_MKLDNN.cpp
@@ -222,8 +222,8 @@ static void getAddtoConfig(TestConfig& cfg, const testActDesc& pm) {
}
void testActivation(std::string& actType, const testActDesc& pm) {
- // TODO(TJ): mkldnn_softmax not implemented, paddle do not have elu activation
- if (actType == "mkldnn_softmax" || actType == "mkldnn_elu") {
+ // TODO(TJ): remove me when paddle support elu activation
+ if (actType == "mkldnn_elu") {
return;
}
const std::string compareTypes[] = {actType, actType.erase(0, 7)};
diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc
index 70e4f9da1221ab300e2b507a3da2f7c5da93f2e4..82010bfb53e58a0836c99c353590f4e32e25ac4a 100644
--- a/paddle/operators/accuracy_op.cc
+++ b/paddle/operators/accuracy_op.cc
@@ -22,25 +22,23 @@ class AccuracyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(
- ctx.InputVar("Inference"),
- "Input(Inference) of AccuracyOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
- "Input(Label) of AccuracyOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(
- ctx.OutputVar("Accuracy"),
- "Output(Accuracy) of AccuracyOp should not be null.");
+ void InferShape(framework::InferShapeContextBase *ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("Inference"),
+ "Input(Inference) of AccuracyOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Label"),
+ "Input(Label) of AccuracyOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
+ "Output(Accuracy) of AccuracyOp should not be null.");
- auto *inference = ctx.Input("Inference");
- auto *label = ctx.Input("Label");
+ auto inference_dim = ctx->GetInputDim("Inference");
+ auto label_dim = ctx->GetInputDim("Label");
- PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector");
- PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0],
+ PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector");
+ PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
"inference size must be the same as label size");
- ctx.Output("Accuracy")->Resize({1});
- ctx.ShareLoD("Inference", /*->*/ "Accuracy");
+ ctx->SetOutputDim("Accuracy", {1});
+ ctx->ShareLoD("Inference", /*->*/ "Accuracy");
}
};
diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu
index 0a6a0fd15c73330902552f7a9aa6339de24c1a18..75e8a989036f0b818687e1fec3e600bb90e86b22 100644
--- a/paddle/operators/accuracy_op.cu
+++ b/paddle/operators/accuracy_op.cu
@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return;
}
- AccuracyCudaKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(
- num_samples, infer_width, inference_data, label_data, accuracy_data);
+ AccuracyCudaKernel<<<
+ 1, PADDLE_CUDA_NUM_THREADS, 0,
+ reinterpret_cast<const platform::CUDADeviceContext&>(
+ ctx.device_context())
+ .stream()>>>(num_samples, infer_width, inference_data, label_data,
+ accuracy_data);
}
};
diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index 06654702bc42cc7cf4917b00693334b1d36ce371..f77e1c572e33533ac672e3d476a7e6dad122031f 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -22,10 +22,9 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- ctx.Output("Y")->Resize(
- ctx.Input("X")->dims());
- ctx.ShareLoD("X", /*->*/ "Y");
+ void InferShape(framework::InferShapeContextBase *ctx) const override {
+ ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
+ ctx->ShareLoD("X", /*->*/ "Y");
}
};
@@ -34,9 +33,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- ctx.Output(framework::GradVarName("X"))
- ->Resize(ctx.Input("Y")->dims());
+ void InferShape(framework::InferShapeContextBase *ctx) const override {
+ ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
}
};
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index ed11d096974341022637676537793645f46738f0..3914d1323083ede6a7ea07e7b4ef76b9e4afd26d 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -22,25 +22,23 @@ class AddOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
- "Input(X) of AddOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
- "Input(Y) of AddOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of AddOp should not be null.");
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of AddOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of AddOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of AddOp should not be null.");
- PADDLE_ENFORCE_EQ(ctx.Input("X")->dims(),
- ctx.Input("Y")->dims(),
+ auto x_dims = ctx->GetInputDim("X");
+ auto y_dims = ctx->GetInputDim("Y");
+ PADDLE_ENFORCE_EQ(x_dims, y_dims,
"Two input of Add Op's dimension must be same.");
- ctx.Output("Out")->Resize(
- ctx.Input("X")->dims());
+ ctx->SetOutputDim("Out", x_dims);
}
};
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ AddOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
@@ -58,7 +56,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {}
+ void InferShape(framework::InferShapeContextBase* ctx) const override {}
};
} // namespace operators
diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc
index e5a54bc4b226fd24337050fdd84b2de9c49f7949..316d28f174658de0e20ed9512f315da305bbb6d0 100644
--- a/paddle/operators/clip_op.cc
+++ b/paddle/operators/clip_op.cc
@@ -22,24 +22,24 @@ class ClipOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
- "Input(X) of ClipOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of ClipOp should not be null.");
- auto x_dims = ctx.Input("X")->dims();
- auto max = Attr<float>("max");
- auto min = Attr<float>("min");
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of ClipOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of ClipOp should not be null.");
+ auto x_dims = ctx->GetInputDim("X");
+ auto max = ctx->Attrs().Get<float>("max");
+ auto min = ctx->Attrs().Get<float>("min");
PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
- ctx.Output("Out")->Resize(x_dims);
- ctx.ShareLoD("X", /*->*/ "Out");
+ ctx->SetOutputDim("Out", x_dims);
+ ctx->ShareLoD("X", /*->*/ "Out");
}
};
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ ClipOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor)The input of clip op."
@@ -61,14 +61,13 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
- "Input(Out@GRAD) should not be null");
- auto x_dims = ctx.Input("X")->dims();
- auto *x_grad = ctx.Output(framework::GradVarName("X"));
- if (x_grad != nullptr) {
- x_grad->Resize(x_dims);
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+ "Input(Out@GRAD) should not be null");
+ auto x_dims = ctx->GetInputDim("X");
+ if (ctx->HasOutput(framework::GradVarName("X"))) {
+ ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
};
diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc
index 07f847079e834716904dcc038d2097efd268bd3e..01cbfc33efcb4042438fbb398fbcca9457f1334f 100644
--- a/paddle/operators/concat_op.cc
+++ b/paddle/operators/concat_op.cc
@@ -24,31 +24,30 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of ConcatOp should not be null.");
+ void InferShape(framework::InferShapeContextBase *ctx) const override {
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of ConcatOp should not be null.");
- auto ins = ctx.MultiInput("X");
- auto *out = ctx.Output("Out");
- size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
+ auto ins = ctx->GetInputsDim("X");
+ size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
- auto out_dims = ins[0]->dims();
+ auto out_dims = ins[0];
size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) {
- out_dims[axis] += ins[i]->dims()[j];
+ out_dims[axis] += ins[i][j];
continue;
}
- PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j],
+ PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.")
}
}
- out->Resize(out_dims);
+ ctx->SetOutputDim("Out", out_dims);
}
};
diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc
index 8262a7a5c8c13c86c5f6c123a14fa89696358c57..1d44782b210bc0c40fd68dba29a16fa6959d6076 100644
--- a/paddle/operators/cond_op.cc
+++ b/paddle/operators/cond_op.cc
@@ -215,7 +215,7 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Sample dependent Cond Operator:
Given Cond[i] as a 1/0 vector to indicate true/false
-The equation is:
+The equation is:
Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_f[i], if Cond[i] == false
)DOC");
diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc
index c3281db0964de6d7dd6be629fbcc55cabb9fef9d..5cc82944bb6b9a4fc5cd94cf2233ab84fc105fe7 100644
--- a/paddle/operators/conv2d_op.cc
+++ b/paddle/operators/conv2d_op.cc
@@ -27,27 +27,25 @@ class Conv2DOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
- "Input(Input) of Conv2DOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"),
- "Input(Filter) of Conv2DOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
- "Output(Output) of Conv2DOp should not be null.");
-
- auto in = ctx.Input("Input");
- auto filter = ctx.Input("Filter");
- auto out = ctx.Output("Output");
- std::vector<int> strides = Attr<std::vector<int>>("strides");
- std::vector<int> paddings = Attr<std::vector<int>>("paddings");
- int groups = Attr<int>("groups");
- int input_channels = in->dims()[1];
- int output_channels = filter->dims()[0];
-
- PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D.");
- PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
- "Conv2DOp filter should be 4-D.");
- PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("Input"),
+ "Input(Input) of Conv2DOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Filter"),
+ "Input(Filter) of Conv2DOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Output"),
+ "Output(Output) of Conv2DOp should not be null.");
+
+ auto in_dims = ctx->GetInputDim("Input");
+ auto filter_dims = ctx->GetInputDim("Filter");
+ std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
+ std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
+ int groups = ctx->Attrs().Get<int>("groups");
+ int input_channels = in_dims[1];
+ int output_channels = filter_dims[0];
+
+ PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
+ PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
+ PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
PADDLE_ENFORCE_EQ(
@@ -55,17 +53,17 @@ class Conv2DOp : public framework::OperatorWithKernel {
"The number of output channels should be divided by groups.");
auto output_height =
- outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
+ outputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]);
auto output_width =
- outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
- out->Resize(
- {in->dims()[0], filter->dims()[0], output_height, output_width});
+ outputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]);
+ ctx->SetOutputDim(
+ "Output", {in_dims[0], filter_dims[0], output_height, output_width});
}
};
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ Conv2DOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
@@ -108,14 +106,15 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- auto in = ctx.Input("Input");
- auto filter = ctx.Input("Filter");
- auto d_in = ctx.Output(framework::GradVarName("Input"));
- auto d_filter =
- ctx.Output(framework::GradVarName("Filter"));
- if (d_in) d_in->Resize(in->dims());
- if (d_filter) d_filter->Resize(filter->dims());
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ auto in_dims = ctx->GetInputDim("Input");
+ auto filter_dims = ctx->GetInputDim("Filter");
+ if (ctx->HasOutput(framework::GradVarName("Input"))) {
+ ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
+ }
+ if (ctx->HasOutput(framework::GradVarName("Filter"))) {
+ ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
+ }
}
};
diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc
index b56ee2047b811e212b4bf74bf7fbba753a6bcb11..040546f1a6fe1af6d17a5e363a11d27de88d03c2 100644
--- a/paddle/operators/cos_sim_op.cc
+++ b/paddle/operators/cos_sim_op.cc
@@ -24,22 +24,22 @@ class CosSimOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
// notnull check
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
- "Input(X) of CosSimOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
- "Input(Y) of CosSimOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of CosSimOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("XNorm"),
- "Output(XNorm) of CosSimOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("YNorm"),
- "Output(YNorm) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Y"),
+ "Input(Y) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("XNorm"),
+ "Output(XNorm) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("YNorm"),
+ "Output(YNorm) of CosSimOp should not be null.");
// shape check
- auto x_dims = ctx.Input("X")->dims();
- auto y_dims = ctx.Input("Y")->dims();
+ auto x_dims = ctx->GetInputDim("X");
+ auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
@@ -54,16 +54,16 @@ class CosSimOp : public framework::OperatorWithKernel {
" just 1 (which will be broadcasted to match Input(X)).");
// resize tensor
- ctx.Output("Out")->Resize({x_dims[0], 1});
- ctx.Output("XNorm")->Resize({x_dims[0], 1});
- ctx.Output("YNorm")->Resize({y_dims[0], 1});
- ctx.ShareLoD("X", /*->*/ "Out");
+ ctx->SetOutputDim("Out", {x_dims[0], 1});
+ ctx->SetOutputDim("XNorm", {x_dims[0], 1});
+ ctx->SetOutputDim("YNorm", {y_dims[0], 1});
+ ctx->ShareLoD("X", /*->*/ "Out");
}
};
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ CosSimOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op.");
@@ -98,27 +98,23 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
// notnull check
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"),
- "Input(XNorm) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"),
- "Input(YNorm) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"),
- "Input(Out) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
- "Input(Out@GRAD) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+ "Input(Out@GRAD) must not be null.");
// shape check
- auto x_dims = ctx.Input("X")->dims();
- auto y_dims = ctx.Input("Y")->dims();
- auto xnorm_dims = ctx.Input("XNorm")->dims();
- auto ynorm_dims = ctx.Input("YNorm")->dims();
- auto out_dims = ctx.Input("Out")->dims();
- auto out_grad_dims =
- ctx.Input(framework::GradVarName("Out"))->dims();
+ auto x_dims = ctx->GetInputDim("X");
+ auto y_dims = ctx->GetInputDim("Y");
+ auto xnorm_dims = ctx->GetInputDim("XNorm");
+ auto ynorm_dims = ctx->GetInputDim("YNorm");
+ auto out_dims = ctx->GetInputDim("Out");
+ auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
@@ -143,10 +139,14 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
"Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
// resize tensor
- auto *x_grad = ctx.Output(framework::GradVarName("X"));
- auto *y_grad = ctx.Output(framework::GradVarName("Y"));
- if (x_grad) x_grad->Resize(x_dims);
- if (y_grad) y_grad->Resize(y_dims);
+ auto x_grad_name = framework::GradVarName("X");
+ auto y_grad_name = framework::GradVarName("Y");
+ if (ctx->HasOutput(x_grad_name)) {
+ ctx->SetOutputDim(x_grad_name, x_dims);
+ }
+ if (ctx->HasOutput(y_grad_name)) {
+ ctx->SetOutputDim(y_grad_name, y_dims);
+ }
}
};
diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc
index 52a1123348b10e39bcfa1ba062c893e5f20ed862..9b2305e90e85a6f39d4c584a3251b25f67e81aca 100644
--- a/paddle/operators/crop_op.cc
+++ b/paddle/operators/crop_op.cc
@@ -25,16 +25,14 @@ class CropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
- "Input(X) of CropOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of CropOp should not be null.");
- auto x_dim = ctx.Input("X")->dims();
- auto *y = ctx.Input("Y");
- auto *out = ctx.Output("Out");
- if (y == nullptr) {
- auto shape = Attr<std::vector<int>>("shape");
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of CropOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of CropOp should not be null.");
+ auto x_dim = ctx->GetInputDim("X");
+ if (!ctx->HasInput("Y")) {
+ auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(
int64_t(shape.size()), x_dim.size(),
"Shape size should be equal to dimention size of input tensor.");
@@ -42,19 +40,20 @@ class CropOp : public framework::OperatorWithKernel {
for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]);
}
- out->Resize(framework::make_ddim(tensor_shape));
+ ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
} else {
- PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()),
+ auto y_dim = ctx->GetInputDim("Y");
+ PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y_dim),
"Tensor rank of both CropOp's "
"inputs must be same.");
- out->Resize(y->dims());
+ ctx->SetOutputDim("Out", y_dim);
}
}
};
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ CropOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input of pad op. "
@@ -78,12 +77,12 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
Crop Operator.
Crop input into output, as specified by offsets and shape.
-There are two ways to set shape:
+There are two ways to set shape:
1. reference input: crop input X to the same shape as the reference input.
- The dimension of reference input should
+ The dimension of reference input should
be the same as input X.
2. shape list: crop input X by shape described by a list.
- The size of shape list should be as same as
+ The size of shape list should be the same as
dimension size of input X.
The input should be a k-D tensor(k > 0 and k < 7). As an example:
@@ -94,15 +93,15 @@ Given:
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]]
-and
+and
offsets = [0, 1]
and
-
+
shape = [2, 2]
-then we get
+then we get
Out = [[1, 2],
[3, 4]]
@@ -116,14 +115,14 @@ class CropOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
- "Input(Out@GRAD) should not be null");
- auto x_dims = ctx.Input("X")->dims();
- auto *x_grad = ctx.Output(framework::GradVarName("X"));
- if (x_grad != nullptr) {
- x_grad->Resize(x_dims);
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+ "Input(Out@GRAD) should not be null");
+ auto x_dims = ctx->GetInputDim("X");
+ auto x_grad_name = framework::GradVarName("X");
+ if (ctx->HasOutput(x_grad_name)) {
+ ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index b11dc1472d153dd188a0b3553d6950774216a3fd..26fc9b51c44d21d92851030449e116538f937846 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -22,32 +22,30 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
- "Input(Label) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
-
- auto x = ctx.Input("X");
- auto label = ctx.Input("Label");
- PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(label->dims().size(), 2,
- "Input(Label)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
- "The 1st dimension of Input(X) and Input(Label) must "
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+ PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
+
+ auto x_dims = ctx->GetInputDim("X");
+ auto label_dims = ctx->GetInputDim("Label");
+ PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
+ "The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
- if (ctx.Attr("soft_label")) {
- PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
- "If Attr(soft_label) == true, The 2nd dimension of "
- "Input(X) and Input(Label) must be equal.");
+ if (ctx->Attrs().Get("softLabel")) {
+ PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
+ "If Attr(softLabel) == true, the 2nd dimension of "
+ "Input(X) and Input(Label) should be equal.");
} else {
- PADDLE_ENFORCE_EQ(label->dims()[1], 1,
- "If Attr(soft_label) == false, The 2nd dimension of "
- "Input(Label) must be 1.");
+ PADDLE_ENFORCE_EQ(label_dims[1], 1,
+ "If Attr(softLabel) == false, the 2nd dimension of "
+ "Input(Label) should be 1.");
}
- ctx.Output("Y")->Resize({x->dims()[0], 1});
- ctx.ShareLoD("X", /*->*/ "Y");
+ ctx->SetOutputDim("Y", {x_dims[0], 1});
+ ctx->ShareLoD("X", /*->*/ "Y");
}
};
@@ -56,66 +54,79 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
- "Input(Label) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
- "Input(Y@GRAD) must not be null.");
-
- auto x = ctx.Input("X");
- auto label = ctx.Input("Label");
- auto dy = ctx.Input(framework::GradVarName("Y"));
- PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(label->dims().size(), 2,
- "Input(Label)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
- "The 1st dimension of Input(X) and Input(Label) must "
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+ PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+ "Input(Y@GRAD) shoudl be not null.");
+ PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+ "Output(X@GRAD) should be not null.");
+
+ auto x_dims = ctx->GetInputDim("X");
+ auto label_dims = ctx->GetInputDim("Label");
+ auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
+ PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
+ "The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
- PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
- "The 1st dimension of Input(X) and Input(Y@Grad) must "
+ PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
+ "The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal.");
- PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
- "The 2nd dimension of Input(Y@Grad) must be 1.");
- if (ctx.Attr("soft_label")) {
- PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
- "If Attr(soft_label) == true, The 2nd dimension of "
- "Input(X) and Input(Label) must be equal.");
+ PADDLE_ENFORCE_EQ(dy_dims[1], 1,
+ "The 2nd dimension of Input(Y@Grad) should be 1.");
+ if (ctx->Attrs().Get("softLabel")) {
+ PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
+ "When Attr(softLabel) == true, the 2nd dimension of "
+ "Input(X) and Input(Label) should be equal.");
} else {
- PADDLE_ENFORCE_EQ(label->dims()[1], 1,
- "If Attr(soft_label) == false, The 2nd dimension of "
- "Input(Label) must be 1.");
+ PADDLE_ENFORCE_EQ(label_dims[1], 1,
+ "When Attr(softLabel) == false, the 2nd dimension of "
+ "Input(Label) should be 1.");
}
-
- auto dx = ctx.Output(framework::GradVarName("X"));
- dx->Resize(x->dims());
+ ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
};
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- CrossEntropyOpMaker(framework::OpProto *proto,
- framework::OpAttrChecker *op_checker)
+ CrossEntropyOpMaker(framework::OpProto* proto,
+ framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
- AddInput("X", "The first input of CrossEntropyOp");
- AddInput("Label", "The second input of CrossEntropyOp");
- AddOutput("Y", "The output of CrossEntropyOp");
- AddAttr("soft_label", "Is soft label. Default zero.")
+ AddInput("X",
+ "(Tensor, default Tensor), a 2-D tensor with shape N x D, "
+ "where N is the batch size and D is the number of classes. "
+ "This input is a probability computed by the previous operator, "
+ "which is almost always the result of a softmax operator.");
+ AddInput(
+ "Label",
+ "(Tensor, default Tensor), the ground truth which is "
+ "a 2-D tensor. "
+ "When softLabel is set to false, `Label` is a Tensor with shape "
+ "[N x 1]. "
+ "When softLabel is set to true, `Label` is a Tensor "
+ "with shape [N x K].");
+ AddOutput("Y",
+ "(Tensor, default Tensor), a 2-D tensor "
+ "with shape [N x 1]. The cross entropy loss.");
+ AddAttr<bool>(
+ "softLabel",
+ "(bool, default false), a flag to indicate whether to interpret "
+ "the given labels as soft labels.")
.SetDefault(false);
-
AddComment(R"DOC(
CrossEntropy Operator.
It supports both standard cross-entropy and soft-label cross-entropy loss
computation.
1) One-hot cross-entropy:
- soft_label = False, Label[i, 0] indicates the class index for sample i:
+ softLabel = false, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
2) Soft-label cross-entropy:
- soft_label = True, Label[i, j] indicates the soft label of class j
+ softLabel = true, Label[i, j] indicates the soft label of class j
for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
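For reference, the two formulas in the DOC comment above can be sketched in plain host-side C++ (a minimal illustration, not part of the patch; the function names are hypothetical):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// One-hot cross-entropy: Y[i] = -log(X[i, Label[i]])
std::vector<double> HardCrossEntropy(const std::vector<double>& x,
                                     const std::vector<int>& label,
                                     int batch_size, int class_num) {
  std::vector<double> y(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    y[i] = -std::log(x[i * class_num + label[i]]);
  }
  return y;
}

// Soft-label cross-entropy: Y[i] = -sum_j Label[i, j] * log(X[i, j])
std::vector<double> SoftCrossEntropy(const std::vector<double>& x,
                                     const std::vector<double>& label,
                                     int batch_size, int class_num) {
  std::vector<double> y(batch_size, 0.0);
  for (int i = 0; i < batch_size; ++i) {
    double sum = 0.0;
    for (int j = 0; j < class_num; ++j) {
      sum += label[i * class_num + j] * std::log(x[i * class_num + j]);
    }
    y[i] = -sum;
  }
  return y;
}
```

Note that a one-hot vector passed to the soft-label form gives the same result as the hard-label form, which is why both share one operator.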
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 1d6361a81472a49729958120c52060b1dff803f2..18e44d77c9f62b296dc57952e546f844670c7d57 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
- Y[i] = -tolerable_value(log(X[i * D + label[i]]));
+    Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
}
}
+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+ val += __shfl_down(val, 16);
+ val += __shfl_down(val, 8);
+ val += __shfl_down(val, 4);
+ val += __shfl_down(val, 2);
+ val += __shfl_down(val, 1);
+ return val;
+}
+
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
- const int N, const int D) {
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
- i += blockDim.x * gridDim.x) {
-    T sum = static_cast<T>(0);
- for (int j = 0; j < D; j++) {
- sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
- }
- Y[i] = -sum;
+ const int class_num) {
+ int tid = threadIdx.x;
+ extern __shared__ T d_sum[];
+ d_sum[tid] = 0;
+
+ int cur_idx = tid;
+ int next_idx = blockIdx.x * class_num + tid;
+ while (cur_idx < class_num) {
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+ next_idx += blockDim.x;
+ cur_idx += blockDim.x;
+ }
+ __syncthreads();
+
+ for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+ if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+ __syncthreads();
}
+
+ T val = d_sum[tid];
+ val = sum_single_warp(val);
+ if (tid == 0) Y[blockIdx.x] = -val;
}
-// TODO(qingqing): make zero setting an common function.
+// TODO(qingqing): make zero setting a common function.
template <typename T>
-__global__ void zero(T* X, const int N) {
+__global__ void Zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
X[i] = 0.0;
@@ -71,13 +94,10 @@ template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const T* label, const int N,
const int D) {
- // TOOD(qingqing): optimize for this kernel
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
- i += blockDim.x * gridDim.x) {
- for (int j = 0; j < D; ++j) {
- int idx = i * D + j;
- dX[idx] = -label[idx] * dY[i] / X[idx];
- }
+ int ids = blockIdx.x * blockDim.x + threadIdx.x;
+ if (ids < N * D) {
+ int row_ids = ids / D;
+ dX[ids] = -label[ids] * dY[row_ids] / X[ids];
}
}
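The rewritten gradient kernel above assigns one thread per element of `dX`, computing `dX[i, j] = -Label[i, j] * dY[i] / X[i, j]`. The same arithmetic, serialized as a CPU sketch (illustrative only; the function name is not from the patch):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// dX[i, j] = -Label[i, j] * dY[i] / X[i, j], flattened over ids = i * D + j,
// mirroring the one-thread-per-element layout of the CUDA kernel.
std::vector<double> SoftCrossEntropyGrad(const std::vector<double>& x,
                                         const std::vector<double>& label,
                                         const std::vector<double>& dy,
                                         int N, int D) {
  std::vector<double> dx(N * D);
  for (int ids = 0; ids < N * D; ++ids) {
    int row_ids = ids / D;  // sample index i
    dx[ids] = -label[ids] * dy[row_ids] / x[ids];
  }
  return dx;
}
```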
@@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
- "It must use GPUPlace.");
+ "This kernel only runs on GPU device.");
-    auto x = ctx.Input<Tensor>("X");
-    auto y = ctx.Output<Tensor>("Y");
-    auto label = ctx.Input<Tensor>("Label");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* y = ctx.Output<Tensor>("Y");
-    auto* x_data = x->data<T>();
-    y->mutable_data<T>(ctx.GetPlace());
-    auto* y_data = y->data<T>();
+    const T* x_data = x->data<T>();
+    T* y_data = y->mutable_data<T>(ctx.GetPlace());
- int n = x->dims()[0];
- int d = x->dims()[1];
- int block = 512;
- int grid = (n + block - 1) / block;
- // TODO(qingqing) launch kernel on specified stream
- // base on ExecutionContext.
-    if (ctx.Attr<bool>("soft_label")) {
+ int batch_size = x->dims()[0];
+ int class_num = x->dims()[1];
+
+    if (ctx.Attr<bool>("softLabel")) {
      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
-                                                 d);
+      int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+      SoftCrossEntropyKernel<T><<<
+          batch_size, block, block * sizeof(T),
+          reinterpret_cast<const platform::CUDADeviceContext&>(
+              ctx.device_context()).stream()>>>(y_data, x_data, label_data, class_num);
} else {
      auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
-      CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
+      int block = 512;
+      int grid = (batch_size + block - 1) / block;
+      CrossEntropyKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(y_data, x_data, label_data,
+                                           batch_size, class_num);
}
}
};
@@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
- "It must use GPUPlace.");
+ "This kernel only runs on GPU device.");
+
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto x = ctx.Input<Tensor>("X");
-    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto label = ctx.Input<Tensor>("Label");
+    const T* dy_data =
+        ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    const T* x_data = x->data<T>();
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dy_data = dy->data<T>();
-    auto* x_data = x->data<T>();
+ int batch_size = x->dims()[0];
+ int class_num = x->dims()[1];
- int n = x->dims()[0];
- int d = x->dims()[1];
int block = 512;
- int grid = (n * d + block - 1) / block;
-    zero<T><<<grid, block>>>(dx_data, n * d);
- grid = (n + block - 1) / block;
- // TODO(qingqing): launch kernel on specified stream
- // base on ExecutionContext.
-    if (ctx.Attr<bool>("soft_label")) {
+ int grid = (batch_size * class_num + block - 1) / block;
+
+    if (ctx.Attr<bool>("softLabel")) {
      auto* label_data = label->data<T>();
-      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
-          dx_data, dy_data, x_data, label_data, n, d);
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           batch_size, class_num);
} else {
+      Zero<T><<<grid, block, 0,
+                reinterpret_cast<const platform::CUDADeviceContext&>(
+                    ctx.device_context()).stream()>>>(dx_data, batch_size * class_num);
+
      auto* label_data = label->data<int>();
-      CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
-                                                     label_data, n, d);
+ grid = (batch_size + block - 1) / block;
+      CrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(dx_data, dy_data, x_data, label_data,
+                                           batch_size, class_num);
}
}
};
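The `SoftCrossEntropyKernel` in the CUDA file above launches one block per sample: each thread strides over the class dimension accumulating a partial sum, then the partials are combined by a shared-memory tree reduction followed by a warp shuffle. A serial CPU model of that reduction pattern, assuming `block` is a power of two as the launch code guarantees (illustrative code, not from the patch):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// CPU model of the per-block reduction for one sample: `block` "threads" each
// strided-accumulate label[j] * log(x[j]) over the class dimension, then a
// tree reduction halves the stride until one partial sum remains.
double BlockReduceLoss(const std::vector<double>& x,
                       const std::vector<double>& label, int block) {
  int class_num = static_cast<int>(x.size());
  std::vector<double> d_sum(block, 0.0);
  for (int tid = 0; tid < block; ++tid) {
    for (int j = tid; j < class_num; j += block) {
      d_sum[tid] += std::log(x[j]) * label[j];
    }
  }
  // Tree reduction; the CUDA kernel switches to warp shuffles below stride 32.
  for (int stride = block >> 1; stride >= 1; stride >>= 1) {
    for (int tid = 0; tid < stride; ++tid) d_sum[tid] += d_sum[tid + stride];
  }
  return -d_sum[0];  // Y[blockIdx.x] = -sum
}
```

The dynamic shared-memory size in the launch (`block * sizeof(T)`) corresponds to the `d_sum` scratch array here.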
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index 69caba5ff31f60df2c24cef0e6331f058f6ba8d6..255b2e9f5ea7566cca7fd3914e38da804b7c7006 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
@@ -20,53 +21,51 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
-HOSTDEVICE T tolerable_value(const T x) {
-  PADDLE_ASSERT(std::is_floating_point<T>::value);
- const T kApproInf = 1e20;
- if (x == INFINITY) {
- return kApproInf;
+struct TolerableValue {
+ HOSTDEVICE T operator()(const T& x) const {
+    PADDLE_ASSERT(std::is_floating_point<T>::value);
+ const T kApproInf = 1e20;
+
+ if (x == INFINITY) return kApproInf;
+ if (x == -INFINITY) return -kApproInf;
+ return x;
}
- if (x == -INFINITY) {
- return -kApproInf;
- }
- return x;
-}
+};
template <typename T>
class CrossEntropyOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
- "It must use CPUPlace.");
-
-    auto x = ctx.Input<Tensor>("X");
-    auto y = ctx.Output<Tensor>("Y");
-
-    auto* x_data = x->data<T>();
-    y->mutable_data<T>(ctx.GetPlace());
-    auto* y_data = y->data<T>();
-
- int batch_size = x->dims()[0];
- int class_num = x->dims()[1];
-
-    if (ctx.Attr<bool>("soft_label")) {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
- int index = 0;
- for (int i = 0; i < batch_size; ++i) {
-        T sum = static_cast<T>(0);
- for (int j = 0; j < class_num; ++j) {
- sum += label_data[index] * tolerable_value(std::log(x_data[index]));
- y_data[i] = -sum;
- index++;
- }
- }
+ "This kernel only runs on CPU.");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* labels = ctx.Input<Tensor>("Label");
+    Tensor* y = ctx.Output<Tensor>("Y");
+    T* y_data = y->mutable_data<T>(ctx.GetPlace());
+
+ const int batch_size = x->dims()[0];
+    if (ctx.Attr<bool>("softLabel")) {
+      auto prob = EigenMatrix<T>::From(*x);
+      auto lbl_mat = EigenMatrix<T>::From(*labels);
+      auto loss = EigenMatrix<T>::From(*y);
+
+      loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
+          -((lbl_mat * prob.log().unaryExpr(TolerableValue<T>()))
+                .sum(Eigen::DSizes<int, 1>(1))
+                .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
} else {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
+ const int class_num = x->dims()[1];
+      const T* x_data = x->data<T>();
+
+      const int* label_data = labels->data<int>();
for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i];
- y_data[i] = -tolerable_value(std::log(x_data[index]));
+        y_data[i] = -TolerableValue<T>()(std::log(x_data[index]));
}
}
}
@@ -77,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
- "It must use CPUPlace.");
-
-    auto x = ctx.Input<Tensor>("X");
-    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto label = ctx.Input<Tensor>("Label");
+                   "This kernel only runs on CPU.");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dy_data = dy->data<T>();
-    auto* x_data = x->data<T>();
-
- int batch_size = x->dims()[0];
int class_num = x->dims()[1];
-
- // TODO(qingqing): make zero setting an common function.
-    if (ctx.Attr<bool>("soft_label")) {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
- int index = 0;
- for (int i = 0; i < batch_size; ++i) {
- for (int j = 0; j < class_num; ++j) {
- dx_data[index] = -label_data[index] * dy_data[i] / x_data[index];
- index++;
- }
- }
+ if (ctx.Attr