diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 0a6d762bc8be5201ac196b4bc6107c06d07a31d7..ac60be572419b62f4beb644ff192d413c35e19bb 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -2,7 +2,7 @@
## Motivation
-In Neural Network, many model is solved by the the backpropagation algorithm(known as BP) at present. Technically it caculates the gradient of the loss function, then distributed back through the networks. Follows the chain rule, so we need a module chains the gradient operators/expressions together with to construct the backward pass. Every forward network needs a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
+In neural networks, most models are currently solved by the backpropagation algorithm (known as **BP**). Technically, BP calculates the gradient of the loss function and propagates it back through the network following the chain rule. Hence we need a module that chains the gradient operators/expressions together to construct the backward pass. Every forward network needs a backward network to construct the full computation graph; the operator/expression's backward pass will be generated with respect to the forward pass.
## Implementation
@@ -24,9 +24,9 @@ A backward network is built up with several backward operators. Backward operato
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
- In most cases, there is a one-to-one correspondence between the forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
+ In most cases, there is a one-to-one relation between the forward and backward operators. These relations are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and to make operators pluggable, the registry mechanism is introduced.
-For example, we have got a `mul_op`, and we can register its information and corresponding backward operator by the following macro:
+For example, we have `mul_op`, and we can register its information and corresponding backward operator by the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
@@ -48,7 +48,7 @@ The function `BuildGradOp` will sequentially execute following processes:
1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing.
+2. Build two maps named `inputs` and `outputs` to temporarily store the backward operator's inputs and outputs. Copy the forward operator's `inputs_` and `outputs_` to map `inputs`, except those that are not necessary for gradient computing.
3. Add the forward inputs' gradient variables into map `outputs`, and add the forward outputs' gradient variables into map `inputs`.
@@ -56,11 +56,11 @@ The function `BuildGradOp` will sequentially execute following processes:
### Backward Network Building
-A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and append them together one by one. There is some corner case need to process specially.
+A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and appending them together one by one. There are some corner cases that need special processing.
1. Op
- When the input forward network is an Op, return its gradient Operator Immediately. If all of its outputs are in no gradient set, then return a special `NOP`.
+ When the input forward network is an Op, return its gradient Operator immediately. If all of its outputs are in the no-gradient set, then return a special `NOP`.
2. NetOp
@@ -68,33 +68,33 @@ A backward network is a series of backward operators. The main idea of building
3. RnnOp
- RnnOp is a nested stepnet operator. Backward module need to recusively call `Backward` for every stepnet.
+ RnnOp is a nested stepnet operator. The backward module needs to recursively call `Backward` for every stepnet.
4. Sharing Variables
- **sharing variables**. As illustrated in the pictures, two operator's share the same variable name of W@GRAD, which will overwrite their sharing input variable.
+ As illustrated in Figure 1 and Figure 2, two operators share the same variable name **W@GRAD**, which will overwrite their shared input variable.
- pic 1. Sharing variables in operators.
+ Figure 1. Sharing variables in operators.
- Sharing variable between operators or same input variable used in multiple operators leads to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively and add a generic add operator to replace the overwrite links.
+ Sharing a variable between operators, or using the same input variable in multiple operators, can lead to duplicate gradient variables. As illustrated in Figure 2, we need to rename the gradient names recursively and add a generic add operator to prevent overwriting.
- pic 2. Replace sharing variable's gradient with `Add` operator.
+ Figure 2. Replace sharing variable's gradient with `Add` operator.
- Because our framework finds variables accord to their names, we need to rename the output links. We add a suffix of number to represent its position in clockwise.
+ Because the framework finds variables according to their names, we need to rename the output links. We add an integer suffix to represent each link's position in the clockwise direction.
-5. Part of Gradient is Zero.
+5. Part of the Gradient is Zero.
- In the whole graph, there is some case of that one operator's gradient is not needed, but its input's gradient is a dependency link of other operator, we need to fill a same shape gradient matrix in the position. In our implement, we insert a special `fillZeroLike` operator.
+ In the whole graph, there are cases where one operator's gradient is not needed, but its input's gradient is a dependency link of another operator. In such cases, we need to fill a gradient matrix of the same shape in that position. In our implementation, we insert a special `fillZeroLike` operator.
Following the rules above, collect the subgraph's `OutputGradients`/`InputGradients` as the NetOp's, and return it.
diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto
index cf83d4cec312ac16366d84f897e7dc4784596ae8..951c7afbc14e2d9119169c1351d38ff0b67bdc5b 100644
--- a/paddle/framework/framework.proto
+++ b/paddle/framework/framework.proto
@@ -22,17 +22,11 @@ enum AttrType {
INTS = 3;
FLOATS = 4;
STRINGS = 5;
- INT_PAIRS = 6;
- BOOLEAN = 7;
- BOOLEANS = 8;
- BLOCK = 9;
+ BOOLEAN = 6;
+ BOOLEANS = 7;
+ BLOCK = 8;
}
-message IntPair {
- required int32 first = 1;
- required int32 second = 2;
-};
-
// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message OpDesc {
@@ -46,7 +40,6 @@ message OpDesc {
repeated int32 ints = 6;
repeated float floats = 7;
repeated string strings = 8;
- repeated IntPair int_pairs = 9;
optional bool b = 10;
repeated bool bools = 11;
optional int32 block_idx = 12;
@@ -106,7 +99,7 @@ enum DataType {
message LoDTensorDesc {
required DataType data_type = 1;
- repeated int32 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
+ repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
optional int32 lod_level = 3 [ default = 0 ];
}
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index 513e63657b1d8ee2e3ac6ca7668e5e53de1ae0e8..5b7badf89c1714331bae9fc8cf94c8da2c66dbad 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -88,20 +88,16 @@ size_t LoDTensor::NumElements(size_t level, size_t idx) const {
return tmp[1].size() - 1;
}
-void LoDTensor::SliceLevels(size_t level_begin, size_t level_end) {
+void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) {
auto new_lod = framework::SliceLevels(lod_, level_begin, level_end);
lod_ = new_lod;
}
-void LoDTensor::SliceInLevel(size_t level, size_t elem_begin, size_t elem_end) {
- PADDLE_ENFORCE_LT(level, NumLevels(), "level [%d] out of range [%d]", level,
- NumLevels());
- PADDLE_ENFORCE_LT(elem_begin, NumElements(level),
- "element begin [%d] out of range [%d]", elem_begin,
- NumElements(level));
- PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1,
- "element end [%d] out of range [%d]", elem_end,
- NumElements(level));
+void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
+ size_t elem_end) {
+ PADDLE_ENFORCE_LT(level, NumLevels());
+ PADDLE_ENFORCE_LT(elem_begin, NumElements(level));
+ PADDLE_ENFORCE_LT(elem_end, NumElements(level) + 1);
auto new_lod = framework::SliceInLevel(lod_, level, elem_begin, elem_end);
lod_ = new_lod;
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 176e1de4d4b661735ddaaacaa680f1cac1b2f286..49786a4a6635f1b39356dbf9633c4e7da443f04e 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -109,15 +109,15 @@ class LoDTensor : public Tensor {
size_t NumElements(size_t level, size_t idx) const;
/*
- * Slice of levels[level_begin:level_end]
+ * Shrink levels[level_begin:level_end]
*/
- void SliceLevels(size_t level_begin, size_t level_end);
+ void ShrinkLevels(size_t level_begin, size_t level_end);
/*
- * Slice of elements of a level, [elem_begin: elem_end]
+ * Shrink elements of a level, [elem_begin: elem_end]
* @note: low performance in slice lod_.
*/
- void SliceInLevel(size_t level, size_t elem_begin, size_t elem_end);
+ void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end);
private:
LoD lod_;
diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc
index 2b35c0bcd27aa9936ef0e861168d862cbb32a255..44f09f584fb752d7003baa804979f3bb5cd9d651 100644
--- a/paddle/framework/lod_tensor_test.cc
+++ b/paddle/framework/lod_tensor_test.cc
@@ -62,11 +62,11 @@ TEST_F(LoDTensorTester, NumElements2) {
ASSERT_EQ(lod_tensor_.NumElements(1, 1), 2UL);
}
-TEST_F(LoDTensorTester, SliceLevels) {
- // shrink 1 level
+TEST_F(LoDTensorTester, ShrinkLevels) {
+ // slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
LoDTensor new_lod_tensor = lod_tensor_;
- new_lod_tensor.SliceLevels(level, level + 1);
+ new_lod_tensor.ShrinkLevels(level, level + 1);
ASSERT_EQ(new_lod_tensor.NumLevels(), 1UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor_.NumElements(level));
ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data());
@@ -74,7 +74,7 @@ TEST_F(LoDTensorTester, SliceLevels) {
// shrink 2 level
for (size_t level = 0; level < 2UL; ++level) {
LoDTensor new_lod_tensor = lod_tensor_;
- new_lod_tensor.SliceLevels(level, level + 2);
+ new_lod_tensor.ShrinkLevels(level, level + 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor_.NumElements(level));
ASSERT_EQ(new_lod_tensor.NumElements(1),
@@ -83,10 +83,10 @@ TEST_F(LoDTensorTester, SliceLevels) {
}
}
-TEST_F(LoDTensorTester, SliceInLevel) {
+TEST_F(LoDTensorTester, ShrinkInLevel) {
size_t level = 0;
LoDTensor new_lod_tensor = lod_tensor_;
- new_lod_tensor.SliceInLevel(level, 0, 2);
+ new_lod_tensor.ShrinkInLevel(level, 0, 2);
EXPECT_EQ(new_lod_tensor.NumLevels(), 3UL);
EXPECT_EQ(new_lod_tensor.NumElements(0), 2UL);
EXPECT_EQ(new_lod_tensor.NumElements(1), 4UL);
@@ -95,7 +95,7 @@ TEST_F(LoDTensorTester, SliceInLevel) {
level = 1;
new_lod_tensor = lod_tensor_;
- new_lod_tensor.SliceInLevel(level, 0, 2);
+ new_lod_tensor.ShrinkInLevel(level, 0, 2);
ASSERT_EQ(new_lod_tensor.NumLevels(), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(0), 2UL);
ASSERT_EQ(new_lod_tensor.NumElements(1), 4UL);
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index e00c6e8d904508ec9985537fc703c7c61a14e0de..b8fdf69683e645d991cf8dc2297b486680445a00 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) {
op->Run(scope, dev_ctx);
int test_attr = op->Attr("test_attr");
ASSERT_EQ(test_attr, 4);
-}
\ No newline at end of file
+}
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index fcbfc3e4377edd0ea55c8d4328c325fa18663001..a3f28339aa64c6bde3fcefdae8b0973a5bbdd585 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include
-#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
@@ -33,6 +32,24 @@ ExecutionContext::GetEigenDevice() const {
}
#endif
+const Tensor* GetTensorFromVar(const Variable* var) {
+  if (var->IsType<LoDTensor>()) {
+    return &var->Get<LoDTensor>();
+  }
+  PADDLE_ENFORCE(var->IsType<Tensor>(),
+                 "The Input must be LoDTensor or Tensor.");
+  return &var->Get<Tensor>();
+}
+
+Tensor* GetTensorFromVar(Variable* var) {
+  if (var->IsType<LoDTensor>()) {
+    return var->GetMutable<LoDTensor>();
+  }
+  PADDLE_ENFORCE(var->IsType<Tensor>(),
+                 "The Input must be LoDTensor or Tensor.");
+  return var->GetMutable<Tensor>();
+}
+
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 2d6d5510ef6dc83f1a016be6ff123f0b9bcaf230..77c7c855c0ffed5032e639237b01037a990652c4 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
+#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
@@ -56,6 +57,9 @@ class OperatorBase;
class InferShapeContext;
class ExecutionContext;
+extern const Tensor* GetTensorFromVar(const Variable* var);
+extern Tensor* GetTensorFromVar(Variable* var);
+
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@@ -262,15 +266,6 @@ class InferShapeContext {
return res;
}
-  const Tensor* GetTensorFromVar(const Variable* var) const {
-    if (var->IsType<LoDTensor>()) {
-      return &var->Get<LoDTensor>();
-    }
-    PADDLE_ENFORCE(var->IsType<Tensor>(),
-                   "The Input(%s) must be LoDTensor or Tensor.");
-    return &var->Get<Tensor>();
-  }
-
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, InputSize(in));
@@ -340,6 +335,78 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext& device_context_;
};
+class RuntimeInferShapeContext : public InferShapeContextBase {
+ public:
+ RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
+ : op_(op), scope_(scope) {}
+
+ bool HasInput(const std::string& name) const {
+ auto ipt = op_.Input(name);
+ auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
+ return var != nullptr;
+ }
+
+ bool HasOutput(const std::string& name) const {
+ auto ipt = op_.Output(name);
+ auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
+ return var != nullptr;
+ }
+
+ DDim GetInputDim(const std::string& name) const {
+ return GetDim(op_.Input(name));
+ }
+
+ void SetInputDim(const std::string& name, const DDim& dim) {
+ SetDim(op_.Input(name), dim);
+ }
+
+ DDim GetOutputDim(const std::string& name) const {
+ return GetDim(op_.Output(name));
+ }
+
+ void SetOutputDim(const std::string& name, const DDim& dim) {
+ SetDim(op_.Output(name), dim);
+ }
+
+ AttrReader Attrs() const { return AttrReader(op_.Attrs()); }
+
+  const std::vector<std::string>& Inputs(const std::string& name) const {
+    return op_.Inputs(name);
+  }
+
+  const std::vector<std::string>& Outputs(const std::string& name) const {
+    return op_.Outputs(name);
+  }
+
+ private:
+  template <bool Allocate>
+  Tensor* GetTensor(const std::string& name) const {
+    Tensor* t = nullptr;
+    auto* var = scope_.FindVar(name);
+    if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
+      if (Allocate) {
+        t = var->GetMutable<LoDTensor>();
+      } else {
+        PADDLE_THROW("Variable(%s) should be tensor", name);
+      }
+    } else {
+      t = GetTensorFromVar(scope_.FindVar(name));
+    }
+    return t;
+  }
+
+  DDim GetDim(const std::string& name) const {
+    return GetTensor<false>(name)->dims();
+  }
+
+  void SetDim(const std::string& name, const DDim& dim) {
+    GetTensor<true>(name)->Resize(dim);
+  }
+
+ const OperatorBase& op_;
+ const Scope& scope_;
+};
+
class OpKernel {
public:
/**
@@ -383,8 +450,10 @@ class OperatorWithKernel : public OperatorBase {
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
+ // runtime infershape
void InferShape(const Scope& scope) const override {
- InferShape(InferShapeContext(*this, scope));
+ auto c = RuntimeInferShapeContext(*this, scope);
+ InferShape(&c);
}
void Run(const Scope& scope,
@@ -406,7 +475,7 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
- virtual void InferShape(const InferShapeContext& ctx) const = 0;
+ virtual void InferShape(InferShapeContextBase* ctx) const = 0;
};
} // namespace framework
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index 0beab0fac5b94c78121261d2661a6f969289afc4..8b4bb01a7bb80eaccee40f14fa82617505b1e2e5 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
+#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
@@ -114,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext& ctx) const override {}
+ void InferShape(framework::InferShapeContextBase* ctx) const override {}
};
template
diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h
new file mode 100644
index 0000000000000000000000000000000000000000..b07fc788124413f728c713027609d9d2d1c39538
--- /dev/null
+++ b/paddle/framework/shape_inference.h
@@ -0,0 +1,82 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/ddim.h"
+
+namespace paddle {
+namespace framework {
+
+class InferShapeContextBase {
+ public:
+ virtual ~InferShapeContextBase() {}
+ virtual bool HasInput(const std::string &name) const = 0;
+ virtual bool HasOutput(const std::string &name) const = 0;
+ virtual framework::DDim GetInputDim(const std::string &name) const = 0;
+  std::vector<framework::DDim> GetInputsDim(const std::string &name) const {
+    const std::vector<std::string> &names = Inputs(name);
+    return GetDims(names);
+  }
+  virtual void SetInputDim(const std::string &name,
+                           const framework::DDim &dim) = 0;
+  void SetInputsDim(const std::string &name,
+                    const std::vector<framework::DDim> &dims) {
+    auto &names = Inputs(name);
+    SetDims(names, dims);
+  }
+  virtual framework::DDim GetOutputDim(const std::string &name) const = 0;
+  std::vector<framework::DDim> GetOutputsDim(const std::string &name) const {
+    const std::vector<std::string> &names = Outputs(name);
+    return GetDims(names);
+  }
+  virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
+  void SetOutputsDim(const std::string &name,
+                     const std::vector<framework::DDim> &dims) {
+    auto &names = Outputs(name);
+    SetDims(names, dims);
+  }
+  virtual AttrReader Attrs() const = 0;
+  virtual const std::vector<std::string> &Inputs(
+      const std::string &name) const = 0;
+  virtual const std::vector<std::string> &Outputs(
+      const std::string &name) const = 0;
+ // TODO(qiao) implement this function
+ void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
+ size_t j = 0) const {}
+
+ protected:
+ virtual framework::DDim GetDim(const std::string &name) const = 0;
+ virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
+  std::vector<framework::DDim> GetDims(
+      const std::vector<std::string> &names) const {
+    std::vector<framework::DDim> ret;
+    ret.reserve(names.size());
+    std::transform(
+        names.begin(), names.end(), std::back_inserter(ret),
+        [this](const std::string &name) { return this->GetDim(name); });
+    return ret;
+  }
+  void SetDims(const std::vector<std::string> &names,
+               const std::vector<framework::DDim> &dims) {
+    size_t length = names.size();
+    PADDLE_ENFORCE_EQ(length, dims.size());
+    for (size_t i = 0; i < length; ++i) {
+      SetDim(names[i], dims[i]);
+    }
+  }
+};
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/variable.md b/paddle/framework/variable.md
index f44d5ea46e7ce98dd443d684ad42308496bc4179..442ef6b718b227d79ca73031efcbb55817558252 100644
--- a/paddle/framework/variable.md
+++ b/paddle/framework/variable.md
@@ -7,7 +7,7 @@ Variable is also known as *blob* in MxNet and Caffe2. It is the input and outpu
For the flexibility of a DL system, a variable should be able to contain any typed value -- a tensor in most cases, but could also be some integer IDs or a scope of other variables in the case of RNN.
-To use the minimum amount of memory, we'd like that a variable to allocate memory when it has to, or, lazy memory allocation. Let's take the following example:
+To use the minimum amount of memory, we would like a variable to allocate memory only when it has to, i.e., lazy memory allocation. Let's take the following example:
```cpp
Variable vr, v1, v2;
@@ -38,7 +38,7 @@ This syntax for lazy memory allocation when we call `Randomize` and `Mult`, thos
To make memory allocation lazy, we cannot assume that we know the type held by a variable at definition time. In other words, `class Variable` cannot be a template `template class Variable`.
-Because we don't know the type `T`, we cannot save a `T*` as `Variable's` data member. Instead, we save an interface object `Placeholder`, who can return the pointer to the saved object via `Placeholder::Ptr()` as `void*`.
+Because we don't know the type `T`, we cannot save a `T*` as `Variable's` data member. Instead, we save an interface object `Placeholder`, which can return the pointer to the saved object via `Placeholder::Ptr()` as `void*`.
But anyway, Variable needs to know `T` so that it can `delete(ptr)` and so that `Variable::Get` can check the expected type against the saved object's type.
@@ -49,4 +49,4 @@ Because `PlaceholderImpl` knows `T`, it can save and return `typeid(T)` for the
## Conclusion
-The technique type hiding utilizes C++ class templates, interface and derivation, and C++ RTTI (typeid). This combination saves us from definition something like `caffe2::TypeMata`, which takes hundreds of lines of C++ code.
+The technique type hiding utilizes C++ class templates, interface and derivation, and C++ RTTI (typeid). This combination saves us from defining something like `caffe2::TypeMeta`, which takes hundreds of lines of C++ code.
diff --git a/paddle/gserver/activations/MKLDNNActivation.cpp b/paddle/gserver/activations/MKLDNNActivation.cpp
index ac50937ef3e28c1ac5aae651f9cf266ad07abcc4..18c5638100065109fba1f0647a1c5f91256f7b9d 100644
--- a/paddle/gserver/activations/MKLDNNActivation.cpp
+++ b/paddle/gserver/activations/MKLDNNActivation.cpp
@@ -27,31 +27,53 @@ static ClassRegistrar gMKLDNNActivationRegistrar;
#define MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) mkldnn_##ACT_TYPE##Activation
/**
- * @def DEFINE_MKLDNN_ELTWISE_ACTIVATION
+ * @def BEGIN_MKLDNN_ACTIVATION
+ */
+#define BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
+ class MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) : public BASE_CLASS {
+/**
+ * @def END_MKLDNN_ACTIVATION
*/
-#define DEFINE_MKLDNN_ELTWISE_ACTIVATION(ACT_TYPE, ALPHA, BWD_ALPHA) \
- class MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE) \
- : public MKLDNNEltwiseActivation { \
- private: \
- static const std::string name; \
- static const float alpha; \
- static const float bwdAlpha; \
- \
- public: \
- const std::string& getName() const { return name; } \
- float getAlpha() const { return alpha; } \
- float getBwdAlpha() const { return bwdAlpha; } \
- }; \
- const std::string MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::name = \
- "mkldnn_" #ACT_TYPE; \
- const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::alpha = ALPHA; \
- const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::bwdAlpha = BWD_ALPHA; \
- static InitFunction __reg_activation__mkldnn_##ACT_TYPE([] { \
- gMKLDNNActivationRegistrar \
- .registerClass( \
- "mkldnn_" #ACT_TYPE); \
+#define END_MKLDNN_ACTIVATION(ACT_TYPE) \
+private: \
+ static const std::string name; \
+ \
+public: \
+ const std::string& getName() const { return name; } \
+ } \
+ ; \
+ const std::string MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::name = \
+ "mkldnn_" #ACT_TYPE; \
+ static InitFunction __reg_activation__mkldnn_##ACT_TYPE([] { \
+ gMKLDNNActivationRegistrar \
+        .registerClass<MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)>(           \
+ "mkldnn_" #ACT_TYPE); \
});
+/**
+ * @def DEFINE_MKLDNN_ACTIVATION
+ */
+#define DEFINE_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
+ BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
+ END_MKLDNN_ACTIVATION(ACT_TYPE)
+
+/**
+ * @def DEFINE_MKLDNN_ELTWISE_ACTIVATION
+ */
+#define DEFINE_MKLDNN_ELTWISE_ACTIVATION( \
+ ACT_TYPE, BASE_CLASS, ALPHA, BWD_ALPHA) \
+ BEGIN_MKLDNN_ACTIVATION(ACT_TYPE, BASE_CLASS) \
+private: \
+ static const float alpha; \
+ static const float bwdAlpha; \
+ \
+public: \
+ float getAlpha() const { return alpha; } \
+ float getBwdAlpha() const { return bwdAlpha; } \
+ END_MKLDNN_ACTIVATION(ACT_TYPE) \
+ const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::alpha = ALPHA; \
+ const float MKLDNN_ACTIVATION_CLASS_NAME(ACT_TYPE)::bwdAlpha = BWD_ALPHA;
+
/**
* @brief MKLDNN Relu Activation.
* Actually mkldnn_relu is Leaky Relu.
@@ -59,19 +81,129 @@ static ClassRegistrar gMKLDNNActivationRegistrar;
* f(x) = negative_slope * x (x < 0)
* @note the negative_slope should be -0.f in forward
*/
-DEFINE_MKLDNN_ELTWISE_ACTIVATION(relu, -0.f, 0.f)
+DEFINE_MKLDNN_ELTWISE_ACTIVATION(relu, MKLDNNEltwiseActivation, -0.f, 0.f)
/**
* @brief MKLDNN Tanh Activation.
*/
-DEFINE_MKLDNN_ELTWISE_ACTIVATION(tanh, 0.f, 0.f)
+DEFINE_MKLDNN_ELTWISE_ACTIVATION(tanh, MKLDNNEltwiseActivation, 0.f, 0.f)
/**
* @brief MKLDNN ELU(Exponential Linear Unit) Activation.
* f(x) = x (x >= 0)
* f(x) = negative_slope * (exp(x) - 1) (x < 0)
*/
-DEFINE_MKLDNN_ELTWISE_ACTIVATION(elu, 0.f, 0.f)
+DEFINE_MKLDNN_ELTWISE_ACTIVATION(elu, MKLDNNEltwiseActivation, 0.f, 0.f)
+
+mkldnn::algorithm MKLDNNEltwiseActivation::getAlgo(std::string type) const {
+  const std::map<std::string, algorithm> algoMap = {
+ {"relu", algorithm::eltwise_relu},
+ {"tanh", algorithm::eltwise_tanh},
+ {"elu", algorithm::eltwise_elu}};
+ type.erase(0, 7); // remove mkldnn_
+ algorithm algo = (algorithm)0;
+ mapGet(type, algoMap, &algo);
+ return algo;
+}
+
+void MKLDNNEltwiseActivation::resetFwd(Argument& act) {
+ if (cnt_ == act.value->getElementCnt()) {
+ return;
+ }
+ MKLDNNActivation::resetFwd(act);
+ // note: alpha represents the NegativeSlope when used in relu.
+ float alpha = getAlpha();
+ float beta = getBeta();
+ algorithm algo = getAlgo(this->getName());
+ auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training,
+ algo,
+ val_->getMemoryDesc(),
+ alpha,
+ beta);
+ fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, *engine_));
+ // use inplace for forward but save input value before submit
+ inVal_ = val_;
+ copyInVal_ = nullptr;
+ if (act.grad && algo == algorithm::eltwise_tanh) {
+ // tanh need save src input for backward
+ inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc());
+    copyInVal_ = std::make_shared<mkldnn::reorder>(*val_, *inVal_);
+    CHECK(copyInVal_) << "should not be empty";
+ pipelineFwd_.push_back(*copyInVal_);
+ }
+ fwd_.reset(new eltwise_fwd(*fwdPD_, *val_, *val_));
+ pipelineFwd_.push_back(*fwd_);
+ needResetBwd_ = true;
+}
+
+void MKLDNNEltwiseActivation::resetBwd(Argument& act) {
+ if (!needResetBwd_) {
+ return;
+ }
+ VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
+ needResetBwd_ = false;
+ algorithm algo = getAlgo(this->getName());
+ float alpha = getBwdAlpha();
+ float beta = getBeta();
+ grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc());
+ auto eng = CPUEngine::Instance().getEngine();
+ auto bwdDesc = eltwise_bwd::desc(
+ algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta);
+ auto bwdPD = eltwise_bwd::primitive_desc(bwdDesc, eng, *fwdPD_);
+ CHECK(inVal_);
+ bwd_.reset(new eltwise_bwd(bwdPD, *inVal_, *grad_, *grad_));
+ pipelineBwd_.clear();
+ pipelineBwd_.push_back(*bwd_);
+}
+
+/**
+ * @brief MKLDNN Softmax Activation
+ */
+DEFINE_MKLDNN_ACTIVATION(softmax, MKLDNNSoftmaxActivation)
+
+void MKLDNNSoftmaxActivation::resetFwd(Argument& act) {
+ if (cnt_ == act.value->getElementCnt()) {
+ return;
+ }
+ MKLDNNActivation::resetFwd(act);
+ int axis = 1;
+ auto fwdDesc = softmax_fwd::desc(
+ mkldnn::prop_kind::forward_scoring, val_->getMemoryDesc(), axis);
+ auto fwdPD = softmax_fwd::primitive_desc(fwdDesc, *engine_);
+ fwd_.reset(new softmax_fwd(fwdPD, *val_, *val_));
+ pipelineFwd_.push_back(*fwd_);
+}
+
+Error __must_check MKLDNNSoftmaxActivation::forward(Argument& act) {
+ resetFwd(act);
+ stream_->submit(pipelineFwd_);
+ real* v = act.value->getData();
+ real threshold = exp(-64);
+#pragma omp parallel for
+ for (size_t i = 0; i < act.value->getElementCnt(); ++i) {
+ v[i] = v[i] < threshold ? threshold : v[i];
+ }
+ return Error();
+}
+
+Error __must_check MKLDNNSoftmaxActivation::backward(Argument& act) {
+ MatrixPtr outputV = act.value;
+ MatrixPtr outputG = act.grad;
+ Matrix::resizeOrCreate(sftMaxDot_,
+ outputG->getHeight(),
+ outputG->getWidth(),
+ /* trans */ false,
+ /* useGpu */ false);
+ Matrix::resizeOrCreate(sftMaxSum_,
+ outputG->getHeight(),
+ 1,
+ /* trans */ false,
+ /* useGpu */ false);
+ sftMaxDot_->dotMul(*outputG, *outputV);
+ sftMaxSum_->colMerge(*sftMaxDot_);
+ act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
+ return Error();
+}
ActivationFunction* MKLDNNActivation::create(const std::string& type) {
return gMKLDNNActivationRegistrar.createByType(type);
@@ -84,4 +216,34 @@ std::vector MKLDNNActivation::getAllRegisteredTypes() {
return types;
}
+void MKLDNNActivation::resetFwd(Argument& act) {
+ VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
+ cnt_ = act.value->getElementCnt();
+ pipelineFwd_.clear();
+ stream_.reset(new MKLDNNStream());
+ engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
+  val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
+ if (val_ == nullptr) {
+ int bs = act.getBatchSize();
+ int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1;
+ int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1;
+ int ic = cnt_ / bs / ih / iw;
+ CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
+ val_ = MKLDNNMatrix::create(
+ act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_);
+ CHECK(val_);
+ val_->downSpatial();
+ }
+}
+
+Error __must_check MKLDNNActivation::forward(Argument& act) {
+ resetFwd(act);
+ stream_->submit(pipelineFwd_);
+ return Error();
+}
+Error __must_check MKLDNNActivation::backward(Argument& act) {
+ resetBwd(act);
+ stream_->submit(pipelineBwd_);
+ return Error();
+}
} // namespace paddle
diff --git a/paddle/gserver/activations/MKLDNNActivation.h b/paddle/gserver/activations/MKLDNNActivation.h
index 86ffe387366409d81a91740cc8cea886e618f7e2..dd16421fd6e93b49c30b1d3b601f95980afec57b 100644
--- a/paddle/gserver/activations/MKLDNNActivation.h
+++ b/paddle/gserver/activations/MKLDNNActivation.h
@@ -36,6 +36,7 @@ protected:
// mkldnn matrix, primitive, stream and pipeline
MKLDNNMatrixPtr val_;
MKLDNNMatrixPtr grad_;
+ std::shared_ptr<mkldnn::engine> engine_;
std::shared_ptr<MKLDNNStream> stream_;
std::shared_ptr<mkldnn::primitive> fwd_;
std::shared_ptr<mkldnn::primitive> bwd_;
@@ -48,8 +49,18 @@ public:
static ActivationFunction* create(const std::string& type);
static std::vector<std::string> getAllRegisteredTypes();
virtual const std::string& getName() const = 0;
- virtual Error __must_check forward(Argument& act) = 0;
- virtual Error __must_check backward(Argument& act) = 0;
+ /**
+ * reset the forward primitives
+ */
+ virtual void resetFwd(Argument& act);
+ /**
+ * reset the backward primitives;
+ * these can not be merged into resetFwd, since the grad data
+ * would change before backward.
+ */
+ virtual void resetBwd(Argument& act) {}
+ virtual Error __must_check forward(Argument& act);
+ virtual Error __must_check backward(Argument& act);
};
/**
@@ -59,6 +70,7 @@ public:
class MKLDNNEltwiseActivation : public MKLDNNActivation {
typedef mkldnn::eltwise_forward eltwise_fwd;
typedef mkldnn::eltwise_backward eltwise_bwd;
+ typedef mkldnn::algorithm algorithm;
protected:
// save the forward primitive desc, which can be used backward
@@ -70,9 +82,7 @@ protected:
public:
MKLDNNEltwiseActivation() {}
-
~MKLDNNEltwiseActivation() {}
-
virtual const std::string& getName() const = 0;
// in common, the alpha of forward and backward should be equal.
@@ -80,104 +90,30 @@ public:
virtual float getAlpha() const = 0;
virtual float getBwdAlpha() const = 0;
virtual float getBeta() const { return 0.f; }
- virtual mkldnn::algorithm getAlgo(const std::string& type) const {
- if (type == "mkldnn_relu") {
- return mkldnn::algorithm::eltwise_relu;
- } else if (type == "mkldnn_tanh") {
- return mkldnn::algorithm::eltwise_tanh;
- } else if (type == "mkldnn_elu") {
- return mkldnn::algorithm::eltwise_elu;
- } else {
- LOG(FATAL) << "Unkown eltwise activation type: " << type;
- }
- return (mkldnn::algorithm)0;
- }
-
- /**
- * reshape and reset the forward primitives
- */
- void resetFwd(Argument& act) {
- if (cnt_ == act.value->getElementCnt()) {
- return;
- }
- cnt_ = act.value->getElementCnt();
- stream_.reset(new MKLDNNStream());
- auto eng = CPUEngine::Instance().getEngine();
-
- // get algo setting
- mkldnn::algorithm algo = getAlgo(this->getName());
- // note: alpha represents the NegativeSlope when used in relu.
- float alpha = getAlpha();
- float beta = getBeta();
-
- /// forward
- pipelineFwd_.clear();
- val_ = std::dynamic_pointer_cast<MKLDNNMatrix>(act.value);
- if (val_ == nullptr) {
- int bs = act.getBatchSize();
- int ih = act.getFrameHeight() > 0 ? act.getFrameHeight() : 1;
- int iw = act.getFrameWidth() > 0 ? act.getFrameWidth() : 1;
- int ic = cnt_ / bs / ih / iw;
- CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
- val_ = MKLDNNMatrix::create(
- act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, eng);
- CHECK(val_);
- }
- auto fwdDesc = eltwise_fwd::desc(mkldnn::prop_kind::forward_training,
- algo,
- val_->getMemoryDesc(),
- alpha,
- beta);
- fwdPD_.reset(new eltwise_fwd::primitive_desc(fwdDesc, eng));
- // use inplace for forward but save input value before submit
- inVal_ = val_;
- copyInVal_ = nullptr;
- if (act.grad && algo == mkldnn::algorithm::eltwise_tanh) {
- // tanh need save src input for backward
- inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc());
- copyInVal_ = std::make_shared<mkldnn::reorder>(*val_, *inVal_);
- CHECK(copyInVal_) << "should not be emptry";
- pipelineFwd_.push_back(*copyInVal_);
- }
- fwd_.reset(new eltwise_fwd(*fwdPD_, *val_, *val_));
- pipelineFwd_.push_back(*fwd_);
- needResetBwd_ = true;
- }
+ virtual algorithm getAlgo(std::string type) const;
+ void resetFwd(Argument& act) override;
+ void resetBwd(Argument& act) override;
+};
- /**
- * reset the backward primitives, can not merge into resetFwd as the grad data
- * would be changing before backward.
- */
- void resetBwd(Argument& act) {
- if (!needResetBwd_) {
- return;
- }
- needResetBwd_ = false;
- mkldnn::algorithm algo = getAlgo(this->getName());
- float alpha = getBwdAlpha();
- float beta = getBeta();
- grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc());
- auto eng = CPUEngine::Instance().getEngine();
- auto bwdDesc = eltwise_bwd::desc(
- algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta);
- auto bwdPD = eltwise_bwd::primitive_desc(bwdDesc, eng, *fwdPD_);
- CHECK(inVal_);
- bwd_.reset(new eltwise_bwd(bwdPD, *inVal_, *grad_, *grad_));
- pipelineBwd_.clear();
- pipelineBwd_.push_back(*bwd_);
- }
+/**
+ * @brief Base class of MKLDNN softmax Activation;
+ * only the forward pass uses MKLDNN, the backward pass uses the CPU implementation.
+ */
+class MKLDNNSoftmaxActivation : public MKLDNNActivation {
+ typedef mkldnn::softmax_forward softmax_fwd;
- Error __must_check forward(Argument& act) {
- resetFwd(act);
- stream_->submit(pipelineFwd_);
- return Error();
- }
+private:
+ // for backward
+ MatrixPtr sftMaxSum_;
+ MatrixPtr sftMaxDot_;
- Error __must_check backward(Argument& act) {
- resetBwd(act);
- stream_->submit(pipelineBwd_);
- return Error();
- }
+public:
+ MKLDNNSoftmaxActivation() {}
+ ~MKLDNNSoftmaxActivation() {}
+ virtual const std::string& getName() const = 0;
+ void resetFwd(Argument& act) override;
+ Error __must_check forward(Argument& act) override;
+ Error __must_check backward(Argument& act) override;
};
} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp
index 88b047c89bd40aba1afc456c22a2870c62989c1c..9a0abd291ae8fae43b0e95c7371f3ce35d1261ec 100644
--- a/paddle/gserver/layers/MKLDNNConvLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp
@@ -64,7 +64,7 @@ bool MKLDNNConvLayer::init(const LayerMap& layerMap,
// create biases
if (biasParameter_.get() != NULL) {
- biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
+ biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_, 0));
}
return true;
}
@@ -251,22 +251,31 @@ void MKLDNNConvLayer::resetInValue(
// create buffer and reorder if input value do not match
cpuInVal_ = nullptr;
cvtInVal_ = nullptr;
- if (inputIsOnlyMKLDNN()) {
- MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
- CHECK(dnnIn) << "Input should be MKLDNNMatrix";
- if (dnnIn->getPrimitiveDesc() != in->getPrimitiveDesc()) {
- CHECK_EQ(dnnIn->getFormat(), format::nc);
+
+ MKLDNNMatrixPtr dnnIn = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
+ CHECK_EQ(inputIsOnlyMKLDNN(), dnnIn != nullptr);
+ if (dnnIn != nullptr && dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) {
+ in = dnnIn;
+ return;
+ }
+ if (dnnIn) {
+ if (dnnIn->getFormat() == format::nc) {
CHECK(ih_ == 1 && iw_ == 1) << "when input is nc format";
// create a new one with nchw format and same data
memory::dims inDims = memory::dims{bs_, ic_, 1, 1};
dnnIn = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_);
- CHECK(dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc());
}
- in = dnnIn;
+ if (dnnIn->getPrimitiveDesc() == in->getPrimitiveDesc()) {
+ in = dnnIn;
+ return;
+ }
+ cpuInVal_ = dnnIn;
+ in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc());
+ cvtInVal_ = MKLDNNMatrix::createReorder(cpuInVal_, in);
+ CHECK(cvtInVal_) << "should not be empty";
} else {
- const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE);
memory::dims inDims = memory::dims{bs_, ic_, ih_, iw_};
- cpuInVal_ = MKLDNNMatrix::create(cpuIn, inDims, format::nchw, engine_);
+ cpuInVal_ = MKLDNNMatrix::create(inMat, inDims, format::nchw, engine_);
if (cpuInVal_->getPrimitiveDesc() != in->getPrimitiveDesc()) {
// create new mkldnn matrix
in = MKLDNNMatrix::create(nullptr, pd->src_primitive_desc());
@@ -535,7 +544,7 @@ void MKLDNNConvLayer::resetWgtValBwdData(
} else {
wgtValBwdData_ = wgtVal_;
}
- VLOG(MKLDNN_FMTS) << "weight value format for backward data"
+ VLOG(MKLDNN_FMTS) << "weight value format for backward data: "
<< wgtValBwdData_->getFormat();
}
diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp
index afd092666bf8b8a3389b36aa1f0edb256a9968e6..8cbfbd0d2b9f2149f7c959aec5c4ae1de952f903 100644
--- a/paddle/gserver/layers/MKLDNNFcLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp
@@ -49,7 +49,7 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap,
// create biases
if (biasParameter_.get() != NULL) {
- biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_));
+ biases_ = std::unique_ptr<Weight>(new Weight(1, oc_, biasParameter_, 0));
}
return true;
}
@@ -161,9 +161,16 @@ void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) {
void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias) {
+ format wgtFmt = format::oihw;
+ if (inVal_->getFormat() == format::nChw8c) {
+ wgtFmt = format::oIhw8i;
+ } else if (inVal_->getFormat() == format::nChw16c) {
+ wgtFmt = format::oIhw16i;
+ }
wgt = MKLDNNMatrix::create(
- weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_);
+ weight_->getW(), {oc_, ic_, ih_, iw_}, wgtFmt, engine_);
wgt->downSpatial();
+ VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat();
bias = (biases_ && biases_->getW())
? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_)
diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h
index d8555a833187ddf64b096135e920e5be2b3a8c2f..c09fd89462ef4fdaeaae3e122f96b0cc6ce373ea 100644
--- a/paddle/gserver/layers/MKLDNNLayer.h
+++ b/paddle/gserver/layers/MKLDNNLayer.h
@@ -115,6 +115,7 @@ public:
copySeqInfoToOutputs();
size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt();
if (inputElemenCnt_ != elemenCnt) {
+ VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
// reset when input total sizes changed, not only the batchsize
inputElemenCnt_ = elemenCnt;
reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_);
@@ -142,6 +143,7 @@ public:
void backward(const UpdateCallback& callback) override {
if (needResetBwd_) {
+ VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_);
needResetBwd_ = false;
}
diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp
index 1bfbbde4246a10eaf86693a6a2f237f390966db3..857d07df3e3088be28943d9e2fe58017e9e57f4a 100644
--- a/paddle/gserver/tests/test_MKLDNN.cpp
+++ b/paddle/gserver/tests/test_MKLDNN.cpp
@@ -222,8 +222,8 @@ static void getAddtoConfig(TestConfig& cfg, const testActDesc& pm) {
}
void testActivation(std::string& actType, const testActDesc& pm) {
- // TODO(TJ): mkldnn_softmax not implemented, paddle do not have elu activation
- if (actType == "mkldnn_softmax" || actType == "mkldnn_elu") {
+ // TODO(TJ): remove me when paddle supports elu activation
+ if (actType == "mkldnn_elu") {
return;
}
const std::string compareTypes[] = {actType, actType.erase(0, 7)};
diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc
index 70e4f9da1221ab300e2b507a3da2f7c5da93f2e4..82010bfb53e58a0836c99c353590f4e32e25ac4a 100644
--- a/paddle/operators/accuracy_op.cc
+++ b/paddle/operators/accuracy_op.cc
@@ -22,25 +22,23 @@ class AccuracyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(
- ctx.InputVar("Inference"),
- "Input(Inference) of AccuracyOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
- "Input(Label) of AccuracyOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(
- ctx.OutputVar("Accuracy"),
- "Output(Accuracy) of AccuracyOp should not be null.");
+ void InferShape(framework::InferShapeContextBase *ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("Inference"),
+ "Input(Inference) of AccuracyOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Label"),
+ "Input(Label) of AccuracyOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
+ "Output(Accuracy) of AccuracyOp should not be null.");
- auto *inference = ctx.Input("Inference");
- auto *label = ctx.Input("Label");
+ auto inference_dim = ctx->GetInputDim("Inference");
+ auto label_dim = ctx->GetInputDim("Label");
- PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector");
- PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0],
+ PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector");
+ PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
"inference size must be the same as label size");
- ctx.Output("Accuracy")->Resize({1});
- ctx.ShareLoD("Inference", /*->*/ "Accuracy");
+ ctx->SetOutputDim("Accuracy", {1});
+ ctx->ShareLoD("Inference", /*->*/ "Accuracy");
}
};
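The refactor above moves `AccuracyOp::InferShape` from dereferencing tensor pointers to querying a dims-only context. A toy sketch of that style of shape inference; `ShapeContext` here is a hypothetical stand-in for `InferShapeContextBase`, not the real Paddle class:

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Minimal dims-only context: it knows shapes by name, never tensor data.
class ShapeContext {
public:
  bool HasInput(const std::string& n) const { return dims_.count(n) > 0; }
  std::vector<int> GetInputDim(const std::string& n) const { return dims_.at(n); }
  void SetOutputDim(const std::string& n, std::vector<int> d) { dims_[n] = d; }
  std::map<std::string, std::vector<int>> dims_;
};

// Mirrors the checks in AccuracyOp::InferShape: validate, then fix output dims.
void AccuracyInferShape(ShapeContext* ctx) {
  if (!ctx->HasInput("Inference") || !ctx->HasInput("Label"))
    throw std::invalid_argument("Inference and Label must not be null");
  auto inference_dim = ctx->GetInputDim("Inference");
  auto label_dim = ctx->GetInputDim("Label");
  if (label_dim.size() != 1)
    throw std::invalid_argument("label must be a vector");
  if (inference_dim[0] != label_dim[0])
    throw std::invalid_argument("inference size must be the same as label size");
  ctx->SetOutputDim("Accuracy", {1});
}
```

Working only on dims decouples shape inference from tensor allocation, which is the point of the `InferShapeContextBase` interface change repeated across the operators below.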
diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu
index 0a6a0fd15c73330902552f7a9aa6339de24c1a18..75e8a989036f0b818687e1fec3e600bb90e86b22 100644
--- a/paddle/operators/accuracy_op.cu
+++ b/paddle/operators/accuracy_op.cu
@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return;
}
- AccuracyCudaKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(
- num_samples, infer_width, inference_data, label_data, accuracy_data);
+ AccuracyCudaKernel<<<
+ 1, PADDLE_CUDA_NUM_THREADS, 0,
+ reinterpret_cast<const platform::CUDADeviceContext&>(
+ ctx.device_context())
+ .stream()>>>(num_samples, infer_width, inference_data, label_data,
+ accuracy_data);
}
};
diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index 06654702bc42cc7cf4917b00693334b1d36ce371..f77e1c572e33533ac672e3d476a7e6dad122031f 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -22,10 +22,9 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- ctx.Output("Y")->Resize(
- ctx.Input("X")->dims());
- ctx.ShareLoD("X", /*->*/ "Y");
+ void InferShape(framework::InferShapeContextBase *ctx) const override {
+ ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
+ ctx->ShareLoD("X", /*->*/ "Y");
}
};
@@ -34,9 +33,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- ctx.Output(framework::GradVarName("X"))
- ->Resize(ctx.Input("Y")->dims());
+ void InferShape(framework::InferShapeContextBase *ctx) const override {
+ ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
}
};
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index ed11d096974341022637676537793645f46738f0..3914d1323083ede6a7ea07e7b4ef76b9e4afd26d 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -22,25 +22,23 @@ class AddOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
- "Input(X) of AddOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
- "Input(Y) of AddOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of AddOp should not be null.");
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of AddOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of AddOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of AddOp should not be null.");
- PADDLE_ENFORCE_EQ(ctx.Input("X")->dims(),
- ctx.Input("Y")->dims(),
+ auto x_dims = ctx->GetInputDim("X");
+ auto y_dims = ctx->GetInputDim("Y");
+ PADDLE_ENFORCE_EQ(x_dims, y_dims,
"Two input of Add Op's dimension must be same.");
- ctx.Output("Out")->Resize(
- ctx.Input("X")->dims());
+ ctx->SetOutputDim("Out", x_dims);
}
};
class AddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- AddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ AddOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of add op");
AddInput("Y", "The second input of add op");
@@ -58,7 +56,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {}
+ void InferShape(framework::InferShapeContextBase* ctx) const override {}
};
} // namespace operators
diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc
index e5a54bc4b226fd24337050fdd84b2de9c49f7949..316d28f174658de0e20ed9512f315da305bbb6d0 100644
--- a/paddle/operators/clip_op.cc
+++ b/paddle/operators/clip_op.cc
@@ -22,24 +22,24 @@ class ClipOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
- "Input(X) of ClipOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of ClipOp should not be null.");
- auto x_dims = ctx.Input("X")->dims();
- auto max = Attr("max");
- auto min = Attr("min");
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of ClipOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of ClipOp should not be null.");
+ auto x_dims = ctx->GetInputDim("X");
+ auto max = ctx->Attrs().Get<float>("max");
+ auto min = ctx->Attrs().Get<float>("min");
PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
- ctx.Output("Out")->Resize(x_dims);
- ctx.ShareLoD("X", /*->*/ "Out");
+ ctx->SetOutputDim("Out", x_dims);
+ ctx->ShareLoD("X", /*->*/ "Out");
}
};
template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ ClipOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor)The input of clip op."
@@ -61,14 +61,13 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
- "Input(Out@GRAD) should not be null");
- auto x_dims = ctx.Input("X")->dims();
- auto *x_grad = ctx.Output(framework::GradVarName("X"));
- if (x_grad != nullptr) {
- x_grad->Resize(x_dims);
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+ "Input(Out@GRAD) should not be null");
+ auto x_dims = ctx->GetInputDim("X");
+ if (ctx->HasOutput(framework::GradVarName("X"))) {
+ ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
}
};
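For reference, the clip semantics that the `min`/`max` attributes above describe (the kernel itself is not part of this diff) can be sketched element-wise as:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Element-wise clip: out[i] = min(max(x[i], lo), hi), with lo < hi enforced
// by ClipOp::InferShape above (PADDLE_ENFORCE_LT(min, max)).
std::vector<float> Clip(const std::vector<float>& x, float lo, float hi) {
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::min(std::max(x[i], lo), hi);
  }
  return out;
}
```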
diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc
index 07f847079e834716904dcc038d2097efd268bd3e..01cbfc33efcb4042438fbb398fbcca9457f1334f 100644
--- a/paddle/operators/concat_op.cc
+++ b/paddle/operators/concat_op.cc
@@ -24,31 +24,30 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of ConcatOp should not be null.");
+ void InferShape(framework::InferShapeContextBase *ctx) const override {
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of ConcatOp should not be null.");
- auto ins = ctx.MultiInput("X");
- auto *out = ctx.Output("Out");
- size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
+ auto ins = ctx->GetInputsDim("X");
+ size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
- auto out_dims = ins[0]->dims();
+ auto out_dims = ins[0];
size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) {
- out_dims[axis] += ins[i]->dims()[j];
+ out_dims[axis] += ins[i][j];
continue;
}
- PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j],
+ PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.")
}
}
- out->Resize(out_dims);
+ ctx->SetOutputDim("Out", out_dims);
}
};
diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc
index 8262a7a5c8c13c86c5f6c123a14fa89696358c57..1d44782b210bc0c40fd68dba29a16fa6959d6076 100644
--- a/paddle/operators/cond_op.cc
+++ b/paddle/operators/cond_op.cc
@@ -215,7 +215,7 @@ class CondOpProtoAndCheckerMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Sample dependent Cond Operator:
Given Cond[i] as a 1/0 vector to indicate true/false
-The equation is:
+The equation is:
Out[i] = subnet_t[i], if Cond[i] == true
Out[i] = subnet_f[i], if Cond[i] == false
)DOC");
diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc
index c3281db0964de6d7dd6be629fbcc55cabb9fef9d..5cc82944bb6b9a4fc5cd94cf2233ab84fc105fe7 100644
--- a/paddle/operators/conv2d_op.cc
+++ b/paddle/operators/conv2d_op.cc
@@ -27,27 +27,25 @@ class Conv2DOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
- "Input(Input) of Conv2DOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"),
- "Input(Filter) of Conv2DOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
- "Output(Output) of Conv2DOp should not be null.");
-
- auto in = ctx.Input("Input");
- auto filter = ctx.Input("Filter");
- auto out = ctx.Output("Output");
- std::vector<int> strides = Attr<std::vector<int>>("strides");
- std::vector<int> paddings = Attr<std::vector<int>>("paddings");
- int groups = Attr<int>("groups");
- int input_channels = in->dims()[1];
- int output_channels = filter->dims()[0];
-
- PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D.");
- PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
- "Conv2DOp filter should be 4-D.");
- PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("Input"),
+ "Input(Input) of Conv2DOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Filter"),
+ "Input(Filter) of Conv2DOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Output"),
+ "Output(Output) of Conv2DOp should not be null.");
+
+ auto in_dims = ctx->GetInputDim("Input");
+ auto filter_dims = ctx->GetInputDim("Filter");
+ std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
+ std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
+ int groups = ctx->Attrs().Get<int>("groups");
+ int input_channels = in_dims[1];
+ int output_channels = filter_dims[0];
+
+ PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
+ PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
+ PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
PADDLE_ENFORCE_EQ(
@@ -55,17 +53,17 @@ class Conv2DOp : public framework::OperatorWithKernel {
"The number of output channels should be divided by groups.");
auto output_height =
- outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
+ outputSize(in_dims[2], filter_dims[2], paddings[0], strides[0]);
auto output_width =
- outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
- out->Resize(
- {in->dims()[0], filter->dims()[0], output_height, output_width});
+ outputSize(in_dims[3], filter_dims[3], paddings[1], strides[1]);
+ ctx->SetOutputDim(
+ "Output", {in_dims[0], filter_dims[0], output_height, output_width});
}
};
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ Conv2DOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
@@ -108,14 +106,15 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- auto in = ctx.Input("Input");
- auto filter = ctx.Input("Filter");
- auto d_in = ctx.Output(framework::GradVarName("Input"));
- auto d_filter =
- ctx.Output(framework::GradVarName("Filter"));
- if (d_in) d_in->Resize(in->dims());
- if (d_filter) d_filter->Resize(filter->dims());
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ auto in_dims = ctx->GetInputDim("Input");
+ auto filter_dims = ctx->GetInputDim("Filter");
+ if (ctx->HasOutput(framework::GradVarName("Input"))) {
+ ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
+ }
+ if (ctx->HasOutput(framework::GradVarName("Filter"))) {
+ ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
+ }
}
};
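`outputSize` itself is not shown in this diff; assuming the usual convolution output arithmetic (an assumption, since the helper's body is elsewhere), it would compute:

```cpp
#include <cassert>

// Assumed conv output-size formula used by Conv2DOp::InferShape above:
// out = (in - filter + 2 * pad) / stride + 1  (integer division).
int OutputSize(int input_size, int filter_size, int padding, int stride) {
  return (input_size - filter_size + 2 * padding) / stride + 1;
}
```

Applied per spatial dimension, this yields the `{N, out_channels, out_h, out_w}` shape set by `SetOutputDim("Output", ...)`.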
diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc
index b56ee2047b811e212b4bf74bf7fbba753a6bcb11..040546f1a6fe1af6d17a5e363a11d27de88d03c2 100644
--- a/paddle/operators/cos_sim_op.cc
+++ b/paddle/operators/cos_sim_op.cc
@@ -24,22 +24,22 @@ class CosSimOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
// notnull check
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
- "Input(X) of CosSimOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
- "Input(Y) of CosSimOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of CosSimOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("XNorm"),
- "Output(XNorm) of CosSimOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("YNorm"),
- "Output(YNorm) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Y"),
+ "Input(Y) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("XNorm"),
+ "Output(XNorm) of CosSimOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("YNorm"),
+ "Output(YNorm) of CosSimOp should not be null.");
// shape check
- auto x_dims = ctx.Input("X")->dims();
- auto y_dims = ctx.Input("Y")->dims();
+ auto x_dims = ctx->GetInputDim("X");
+ auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
@@ -54,16 +54,16 @@ class CosSimOp : public framework::OperatorWithKernel {
" just 1 (which will be broadcasted to match Input(X)).");
// resize tensor
- ctx.Output("Out")->Resize({x_dims[0], 1});
- ctx.Output("XNorm")->Resize({x_dims[0], 1});
- ctx.Output("YNorm")->Resize({y_dims[0], 1});
- ctx.ShareLoD("X", /*->*/ "Out");
+ ctx->SetOutputDim("Out", {x_dims[0], 1});
+ ctx->SetOutputDim("XNorm", {x_dims[0], 1});
+ ctx->SetOutputDim("YNorm", {y_dims[0], 1});
+ ctx->ShareLoD("X", /*->*/ "Out");
}
};
class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ CosSimOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The 1st input of cos_sim op.");
AddInput("Y", "The 2nd input of cos_sim op.");
@@ -98,27 +98,23 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
// notnull check
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"),
- "Input(XNorm) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"),
- "Input(YNorm) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"),
- "Input(Out) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
- "Input(Out@GRAD) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("XNorm"), "Input(XNorm) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("YNorm"), "Input(YNorm) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) must not be null.");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+ "Input(Out@GRAD) must not be null.");
// shape check
- auto x_dims = ctx.Input("X")->dims();
- auto y_dims = ctx.Input("Y")->dims();
- auto xnorm_dims = ctx.Input("XNorm")->dims();
- auto ynorm_dims = ctx.Input("YNorm")->dims();
- auto out_dims = ctx.Input("Out")->dims();
- auto out_grad_dims =
- ctx.Input(framework::GradVarName("Out"))->dims();
+ auto x_dims = ctx->GetInputDim("X");
+ auto y_dims = ctx->GetInputDim("Y");
+ auto xnorm_dims = ctx->GetInputDim("XNorm");
+ auto ynorm_dims = ctx->GetInputDim("YNorm");
+ auto out_dims = ctx->GetInputDim("Out");
+ auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Ranks of Input(X) and Input(Y) must be equal.");
@@ -143,10 +139,14 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
"Shape of Input(Out@Grad) must be [X.Dim(0), 1].");
// resize tensor
- auto *x_grad = ctx.Output(framework::GradVarName("X"));
- auto *y_grad = ctx.Output(framework::GradVarName("Y"));
- if (x_grad) x_grad->Resize(x_dims);
- if (y_grad) y_grad->Resize(y_dims);
+ auto x_grad_name = framework::GradVarName("X");
+ auto y_grad_name = framework::GradVarName("Y");
+ if (ctx->HasOutput(x_grad_name)) {
+ ctx->SetOutputDim(x_grad_name, x_dims);
+ }
+ if (ctx->HasOutput(y_grad_name)) {
+ ctx->SetOutputDim(y_grad_name, y_dims);
+ }
}
};
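`CosSimOp` computes a per-row cosine similarity, with `XNorm`/`YNorm` holding the two vector norms. A scalar sketch of the formula the shape checks above are protecting:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Cosine similarity of two equal-length vectors:
// out = (x . y) / (|x| * |y|), where |x| = XNorm and |y| = YNorm.
float CosSim(const std::vector<float>& x, const std::vector<float>& y) {
  float dot = 0.f, x_norm_sq = 0.f, y_norm_sq = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    dot += x[i] * y[i];
    x_norm_sq += x[i] * x[i];
    y_norm_sq += y[i] * y[i];
  }
  return dot / (std::sqrt(x_norm_sq) * std::sqrt(y_norm_sq));
}
```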
diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc
index 52a1123348b10e39bcfa1ba062c893e5f20ed862..9b2305e90e85a6f39d4c584a3251b25f67e81aca 100644
--- a/paddle/operators/crop_op.cc
+++ b/paddle/operators/crop_op.cc
@@ -25,16 +25,14 @@ class CropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
- "Input(X) of CropOp should not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
- "Output(Out) of CropOp should not be null.");
- auto x_dim = ctx.Input("X")->dims();
- auto *y = ctx.Input("Y");
- auto *out = ctx.Output("Out");
- if (y == nullptr) {
- auto shape = Attr<std::vector<int>>("shape");
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of CropOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of CropOp should not be null.");
+ auto x_dim = ctx->GetInputDim("X");
+ if (!ctx->HasInput("Y")) {
+ auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(
int64_t(shape.size()), x_dim.size(),
"Shape size should be equal to dimention size of input tensor.");
@@ -42,19 +40,20 @@ class CropOp : public framework::OperatorWithKernel {
for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]);
}
- out->Resize(framework::make_ddim(tensor_shape));
+ ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
} else {
- PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()),
+ auto y_dim = ctx->GetInputDim("Y");
+ PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y_dim),
"Tensor rank of both CropOp's "
"inputs must be same.");
- out->Resize(y->dims());
+ ctx->SetOutputDim("Out", y_dim);
}
}
};
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ CropOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input of pad op. "
@@ -78,12 +77,12 @@ class CropOpMaker : public framework::OpProtoAndCheckerMaker {
Crop Operator.
Crop input into output, as specified by offsets and shape.
-There are two ways to set shape:
+There are two ways to set shape:
1. reference input: crop input X to the same shape as the reference input.
- The dimension of reference input should
+ The dimension of reference input should
be the same as input X.
2. shape list: crop input X by shape described by a list.
- The size of shape list should be as same as
+ The size of shape list should be as same as
dimension size of input X.
The input should be a k-D tensor(k > 0 and k < 7). As an example:
@@ -94,15 +93,15 @@ Given:
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]]
-and
+and
offsets = [0, 1]
and
-
+
shape = [2, 2]
-then we get
+then we get
Out = [[1, 2],
[3, 4]]
@@ -116,14 +115,14 @@ class CropOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
- "Input(Out@GRAD) should not be null");
- auto x_dims = ctx.Input<Tensor>("X")->dims();
- auto *x_grad = ctx.Output(framework::GradVarName("X"));
- if (x_grad != nullptr) {
- x_grad->Resize(x_dims);
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+ "Input(Out@GRAD) should not be null");
+ auto x_dims = ctx->GetInputDim("X");
+ auto x_grad_name = framework::GradVarName("X");
+ if (ctx->HasOutput(x_grad_name)) {
+ ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
diff --git a/paddle/operators/crop_op.h b/paddle/operators/crop_op.h
index 2f40c059033ec649b29f6ecdee4fcedd128a63a6..ac3aeaf41e206c1deb74c7022c36f02c4777a84b 100644
--- a/paddle/operators/crop_op.h
+++ b/paddle/operators/crop_op.h
@@ -38,10 +38,10 @@ class CropKernel : public framework::OpKernel {
auto out_stride = framework::stride(out->dims());
auto offsets = context.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
- x->dims().size(), offsets.size(),
+ x->dims().size(), static_cast<size_t>(offsets.size()),
"Offsets size should be equal to dimension size of input tensor.");
int64_t offset = 0;
- for (int i = 0; i < offsets.size(); ++i) {
+ for (size_t i = 0; i < offsets.size(); ++i) {
offset += (x_stride[i] * offsets[i]);
}
StridedMemcpy(context.device_context(), x_data + offset, x_stride,
@@ -57,7 +57,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
d_x->mutable_data<T>(context.GetPlace());
auto offsets = context.Attr<std::vector<int>>("offsets");
Eigen::array<std::pair<int, int>, D> paddings;
- for (int i = 0; i < D; ++i) {
+ for (size_t i = 0; i < D; ++i) {
paddings[i].first = offsets[i];
paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i];
}
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index b11dc1472d153dd188a0b3553d6950774216a3fd..26fc9b51c44d21d92851030449e116538f937846 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -22,32 +22,30 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
- "Input(Label) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
-
- auto x = ctx.Input<Tensor>("X");
- auto label = ctx.Input<Tensor>("Label");
- PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(label->dims().size(), 2,
- "Input(Label)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
- "The 1st dimension of Input(X) and Input(Label) must "
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should not be null.");
+
+ auto x_dims = ctx->GetInputDim("X");
+ auto label_dims = ctx->GetInputDim("Label");
+ PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
+ "The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
- if (ctx.Attr<bool>("soft_label")) {
- PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
- "If Attr(soft_label) == true, The 2nd dimension of "
- "Input(X) and Input(Label) must be equal.");
+ if (ctx->Attrs().Get<bool>("softLabel")) {
+ PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
+ "If Attr(softLabel) == true, the 2nd dimension of "
+ "Input(X) and Input(Label) should be equal.");
} else {
- PADDLE_ENFORCE_EQ(label->dims()[1], 1,
- "If Attr(soft_label) == false, The 2nd dimension of "
- "Input(Label) must be 1.");
+ PADDLE_ENFORCE_EQ(label_dims[1], 1,
+ "If Attr(softLabel) == false, the 2nd dimension of "
+ "Input(Label) should be 1.");
}
- ctx.Output("Y")->Resize({x->dims()[0], 1});
- ctx.ShareLoD("X", /*->*/ "Y");
+ ctx->SetOutputDim("Y", {x_dims[0], 1});
+ ctx->ShareLoD("X", /*->*/ "Y");
}
};
@@ -56,66 +54,79 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
- "Input(Label) must not be null.");
- PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
- "Input(Y@GRAD) must not be null.");
-
- auto x = ctx.Input<Tensor>("X");
- auto label = ctx.Input<Tensor>("Label");
- auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
- PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(label->dims().size(), 2,
- "Input(Label)'s rank must be 2.");
- PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
- "The 1st dimension of Input(X) and Input(Label) must "
+ void InferShape(framework::InferShapeContextBase* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+ "Input(Y@GRAD) should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+ "Output(X@GRAD) should not be null.");
+
+ auto x_dims = ctx->GetInputDim("X");
+ auto label_dims = ctx->GetInputDim("Label");
+ auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
+ PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
+ PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
+ "The 1st dimension of Input(X) and Input(Label) should "
"be equal.");
- PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
- "The 1st dimension of Input(X) and Input(Y@Grad) must "
+ PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
+ "The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal.");
- PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
- "The 2nd dimension of Input(Y@Grad) must be 1.");
- if (ctx.Attr<bool>("soft_label")) {
- PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
- "If Attr(soft_label) == true, The 2nd dimension of "
- "Input(X) and Input(Label) must be equal.");
+ PADDLE_ENFORCE_EQ(dy_dims[1], 1,
+ "The 2nd dimension of Input(Y@Grad) should be 1.");
+ if (ctx->Attrs().Get<bool>("softLabel")) {
+ PADDLE_ENFORCE_EQ(x_dims[1], label_dims[1],
+ "When Attr(softLabel) == true, the 2nd dimension of "
+ "Input(X) and Input(Label) should be equal.");
} else {
- PADDLE_ENFORCE_EQ(label->dims()[1], 1,
- "If Attr(soft_label) == false, The 2nd dimension of "
- "Input(Label) must be 1.");
+ PADDLE_ENFORCE_EQ(label_dims[1], 1,
+ "When Attr(softLabel) == false, the 2nd dimension of "
+ "Input(Label) should be 1.");
}
-
- auto dx = ctx.Output(framework::GradVarName("X"));
- dx->Resize(x->dims());
+ ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
};
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
- CrossEntropyOpMaker(framework::OpProto *proto,
- framework::OpAttrChecker *op_checker)
+ CrossEntropyOpMaker(framework::OpProto* proto,
+ framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
- AddInput("X", "The first input of CrossEntropyOp");
- AddInput("Label", "The second input of CrossEntropyOp");
- AddOutput("Y", "The output of CrossEntropyOp");
- AddAttr<bool>("soft_label", "Is soft label. Default zero.")
+ AddInput("X",
+ "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
+ "where N is the batch size and D is the number of classes. "
+ "This input is a probability computed by the previous operator, "
+ "which is almost always the result of a softmax operator.");
+ AddInput(
+ "Label",
+ "(Tensor, default Tensor<int>), the ground truth which is "
+ "a 2-D tensor. "
+ "When softLabel is set to false, `Label` is a Tensor<int> with shape "
+ "[N x 1]. "
+ "When softLabel is set to true, `Label` is a Tensor<float/double> "
+ "with shape [N x K].");
+ AddOutput("Y",
+ "(Tensor, default Tensor<float>), a 2-D tensor "
+ "with shape [N x 1]. The cross entropy loss.");
+ AddAttr<bool>(
+ "softLabel",
+ "(bool, default false), a flag to indicate whether to interpret "
+ "the given labels as soft labels.")
.SetDefault(false);
-
AddComment(R"DOC(
CrossEntropy Operator.
It supports both standard cross-entropy and soft-label cross-entropy loss
computation.
1) One-hot cross-entropy:
- soft_label = False, Label[i, 0] indicates the class index for sample i:
+ softLabel = false, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
2) Soft-label cross-entropy:
- soft_label = True, Label[i, j] indicates the soft label of class j
+ softLabel = true, Label[i, j] indicates the soft label of class j
for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 1d6361a81472a49729958120c52060b1dff803f2..18e44d77c9f62b296dc57952e546f844670c7d57 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
- Y[i] = -tolerable_value(log(X[i * D + label[i]]));
+ Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
}
}
+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+ val += __shfl_down(val, 16);
+ val += __shfl_down(val, 8);
+ val += __shfl_down(val, 4);
+ val += __shfl_down(val, 2);
+ val += __shfl_down(val, 1);
+ return val;
+}
+
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
- const int N, const int D) {
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
- i += blockDim.x * gridDim.x) {
- T sum = static_cast(0);
- for (int j = 0; j < D; j++) {
- sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
- }
- Y[i] = -sum;
+ const int class_num) {
+ int tid = threadIdx.x;
+ extern __shared__ T d_sum[];
+ d_sum[tid] = 0;
+
+ int cur_idx = tid;
+ int next_idx = blockIdx.x * class_num + tid;
+ while (cur_idx < class_num) {
+ d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+ next_idx += blockDim.x;
+ cur_idx += blockDim.x;
+ }
+ __syncthreads();
+
+ for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+ if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+ __syncthreads();
}
+
+ T val = d_sum[tid];
+ val = sum_single_warp(val);
+ if (tid == 0) Y[blockIdx.x] = -val;
}
-// TODO(qingqing): make zero setting an common function.
+// TODO(qingqing): make zero setting a common function.
template <typename T>
-__global__ void zero(T* X, const int N) {
+__global__ void Zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
X[i] = 0.0;
@@ -71,13 +94,10 @@ template
__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const T* label, const int N,
const int D) {
- // TOOD(qingqing): optimize for this kernel
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
- i += blockDim.x * gridDim.x) {
- for (int j = 0; j < D; ++j) {
- int idx = i * D + j;
- dX[idx] = -label[idx] * dY[i] / X[idx];
- }
+ int ids = blockIdx.x * blockDim.x + threadIdx.x;
+ if (ids < N * D) {
+ int row_ids = ids / D;
+ dX[ids] = -label[ids] * dY[row_ids] / X[ids];
}
}
@@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
- "It must use GPUPlace.");
+ "This kernel only runs on GPU device.");
- auto x = ctx.Input<Tensor>("X");
- auto y = ctx.Output<Tensor>("Y");
- auto label = ctx.Input<Tensor>("Label");
+ const Tensor* x = ctx.Input<Tensor>("X");
+ const Tensor* label = ctx.Input<Tensor>("Label");
+ Tensor* y = ctx.Output<Tensor>("Y");
- auto* x_data = x->data<T>();
- y->mutable_data<T>(ctx.GetPlace());
- auto* y_data = y->data<T>();
+ const T* x_data = x->data<T>();
+ T* y_data = y->mutable_data<T>(ctx.GetPlace());
- int n = x->dims()[0];
- int d = x->dims()[1];
- int block = 512;
- int grid = (n + block - 1) / block;
- // TODO(qingqing) launch kernel on specified stream
- // base on ExecutionContext.
- if (ctx.Attr<bool>("soft_label")) {
+ int batch_size = x->dims()[0];
+ int class_num = x->dims()[1];
+
+ if (ctx.Attr<bool>("softLabel")) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
- SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
- d);
+ int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+ SoftCrossEntropyKernel<
+ T><<<batch_size, block, block * sizeof(T),
+ reinterpret_cast<const platform::CUDADeviceContext&>(
+ ctx.device_context())
+ .stream()>>>(y_data, x_data, label_data, class_num);
} else {
auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
- CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
+ int block = 512;
+ int grid = (batch_size + block - 1) / block;
+ CrossEntropyKernel<T><<<
+ grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+ ctx.device_context())
+ .stream()>>>(y_data, x_data, label_data,
+ batch_size, class_num);
}
}
};
@@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
- "It must use GPUPlace.");
+ "This kernel only runs on GPU device.");
+
+ const Tensor* x = ctx.Input<Tensor>("X");
+ const Tensor* label = ctx.Input<Tensor>("Label");
+ Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
- auto x = ctx.Input<Tensor>("X");
- auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
- auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
- auto label = ctx.Input<Tensor>("Label");
+ const T* dy_data =
+ ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
+ T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+ const T* x_data = x->data<T>();
- auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
- auto* dy_data = dy->data<T>();
- auto* x_data = x->data<T>();
+ int batch_size = x->dims()[0];
+ int class_num = x->dims()[1];
- int n = x->dims()[0];
- int d = x->dims()[1];
int block = 512;
- int grid = (n * d + block - 1) / block;
- zero<T><<<grid, block>>>(dx_data, n * d);
- grid = (n + block - 1) / block;
- // TODO(qingqing): launch kernel on specified stream
- // base on ExecutionContext.
- if (ctx.Attr<bool>("soft_label")) {
+ int grid = (batch_size * class_num + block - 1) / block;
+
+ if (ctx.Attr<bool>("softLabel")) {
auto* label_data = label->data<T>();
- SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
- dx_data, dy_data, x_data, label_data, n, d);
+ SoftCrossEntropyGradientKernel<T><<<
+ grid, block, 0, reinterpret_cast