Unverified · Commit 29f64c8c authored by Zeng Jinle, committed by GitHub

refine some grad op makers, test=develop (#21629)

Parent e2d849b9
paddle/fluid/op_use_default_grad_op_maker.spec

 cos_sim
 fsp
 gru
-lstm_unit
 match_matrix_tensor
-max_pool2d_with_index
-max_pool3d_with_index
 maxout
 pool2d
 pool3d
@@ -15,6 +12,5 @@ reshape
 rnn_memory_helper
 sequence_softmax
 spp
-tensor_array_to_tensor
 transpose
 unsqueeze
paddle/fluid/operators/lstm_unit_op.cc

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/lstm_unit_op.h"
+#include <memory>

 namespace paddle {
 namespace operators {
@@ -94,14 +95,34 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
   }
 };

+template <typename T>
+class LstmUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+    op->SetType("lstm_unit_grad");
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("C_prev", this->Input("C_prev"));
+    op->SetInput("C", this->Output("C"));
+    op->SetInput(framework::GradVarName("H"), this->OutputGrad("H"));
+    op->SetInput(framework::GradVarName("C"), this->OutputGrad("C"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("C_prev"), this->InputGrad("C_prev"));
+    op->SetAttrMap(this->Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OPERATOR(
-    lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker,
-    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
-    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
+REGISTER_OPERATOR(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker,
+                  ops::LstmUnitGradOpMaker<paddle::framework::OpDesc>,
+                  ops::LstmUnitGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(lstm_unit_grad, ops::LstmUnitGradOp);
 REGISTER_OP_CPU_KERNEL(lstm_unit,
                        ops::LstmUnitKernel<paddle::platform::CPUPlace, float>,
...
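Note on the hunk above: DefaultGradOpMaker forwards every input and output of the forward op, plus all output gradients, to the generated grad op, so lstm_unit_grad used to receive the forward output H even though the gradient kernels never read its values. The hand-written maker lists only what the kernels actually use, which means the framework no longer has to keep H alive for the backward pass. Below is a standalone, simplified model of the mapping the maker builds; this is illustrative C++ only, not Paddle framework code.

// Standalone sketch (not Paddle framework code) of what LstmUnitGradOpMaker
// produces for a forward op lstm_unit(X, C_prev) -> (C, H).
#include <iostream>
#include <map>
#include <string>
#include <vector>

using VarMap = std::map<std::string, std::vector<std::string>>;
struct OpDesc { std::string type; VarMap ins, outs; };

std::string Grad(const std::string& v) { return v + "@GRAD"; }

// Mirrors LstmUnitGradOpMaker::Apply(): the forward output "H" is not an
// input of the grad op; only its gradient H@GRAD is.
OpDesc MakeLstmUnitGrad(const OpDesc& fwd) {
  OpDesc g;
  g.type = "lstm_unit_grad";
  g.ins["X"] = fwd.ins.at("X");
  g.ins["C_prev"] = fwd.ins.at("C_prev");
  g.ins["C"] = fwd.outs.at("C");
  g.ins[Grad("H")] = {Grad(fwd.outs.at("H")[0])};
  g.ins[Grad("C")] = {Grad(fwd.outs.at("C")[0])};
  g.outs[Grad("X")] = {Grad(fwd.ins.at("X")[0])};
  g.outs[Grad("C_prev")] = {Grad(fwd.ins.at("C_prev")[0])};
  return g;
}

int main() {
  OpDesc fwd{"lstm_unit",
             {{"X", {"x0"}}, {"C_prev", {"c0"}}},
             {{"C", {"c1"}}, {"H", {"h1"}}}};
  for (const auto& kv : MakeLstmUnitGrad(fwd).ins)
    std::cout << kv.first << " <- " << kv.second[0] << "\n";
}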
paddle/fluid/operators/lstm_unit_op.cu

@@ -62,9 +62,9 @@ __global__ void LSTMUnitKernel(const int nthreads, const int dim,
 template <typename T>
 __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
                                        const T* C_prev, const T* X, const T* C,
-                                       const T* H, const T* C_diff,
-                                       const T* H_diff, T* C_prev_diff,
-                                       T* X_diff, const T forget_bias) {
+                                       const T* C_diff, const T* H_diff,
+                                       T* C_prev_diff, T* X_diff,
+                                       const T forget_bias) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     const int n = index / dim;
     const int d = index % dim;
@@ -146,7 +146,6 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
     auto* X = x_tensor->data<T>();
     auto* C_prev = c_prev_tensor->data<T>();
     auto* C = c_tensor->data<T>();
-    auto* H = h_tensor->data<T>();
     auto* H_diff = hdiff_tensor->data<T>();
     auto* C_diff = cdiff_tensor->data<T>();

@@ -163,9 +162,8 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
     int n = N * D;
     int grid = (n + block - 1) / block;

-    LSTMUnitGradientKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, C_diff,
-                                               H_diff, C_prev_diff, X_diff,
-                                               forget_bias);
+    LSTMUnitGradientKernel<T><<<grid, block>>>(
+        n, D, C_prev, X, C, C_diff, H_diff, C_prev_diff, X_diff, forget_bias);
   }
 };
...
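Why the H argument can be dropped from LSTMUnitGradientKernel: as I read the forward kernel in this file, the unit computes roughly C = σ(f + forget_bias) · C_prev + σ(i) · tanh(g) and H = σ(o) · tanh(C), with the gate pre-activations i, f, o, g packed in X. The backward terms that involve H are then, in LaTeX,

\[
H = \sigma(o)\,\tanh(C)
\quad\Longrightarrow\quad
\frac{\partial H}{\partial o} = \sigma(o)\bigl(1-\sigma(o)\bigr)\tanh(C),
\qquad
\frac{\partial H}{\partial C} = \sigma(o)\bigl(1-\tanh^{2}(C)\bigr).
\]

Neither partial references H itself, only C and the o slice of X, so the saved forward output is dead in the backward kernel and both the kernel parameter and the corresponding ctx.Input fetch can go.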
paddle/fluid/operators/lstm_unit_op.h

@@ -88,7 +88,6 @@ class LstmUnitGradKernel : public framework::OpKernel<T> {
     auto x_tensor = ctx.Input<Tensor>("X");
     auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
     auto c_tensor = ctx.Input<Tensor>("C");
-    auto h_tensor = ctx.Input<Tensor>("H");
     auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
     auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));

@@ -100,7 +99,6 @@ class LstmUnitGradKernel : public framework::OpKernel<T> {
     auto* X = x_tensor->data<T>();
     auto* C_prev = c_prev_tensor->data<T>();
     auto* C = c_tensor->data<T>();
-    auto* H = h_tensor->data<T>();
     auto* H_diff = hdiff_tensor->data<T>();
     auto* C_diff = cdiff_tensor->data<T>();

@@ -138,7 +136,6 @@ class LstmUnitGradKernel : public framework::OpKernel<T> {
       C_prev += D;
       X += 4 * D;
       C += D;
-      H += D;
       C_diff += D;
       H_diff += D;
       X_diff += 4 * D;
...
paddle/fluid/operators/pool_with_index_op.cc

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/pool_with_index_op.h"
+#include <memory>

 namespace paddle {
 namespace operators {
@@ -283,16 +284,33 @@ Example:
   }
 };

+template <typename T>
+class MaxPoolWithIndexGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+    op->SetType(this->ForwardOpType() + "_grad");
+    op->SetAttrMap(this->Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Mask", this->Output("Mask"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OPERATOR(
-    max_pool2d_with_index, ops::MaxPoolWithIndexOp,
-    ops::MaxPool2dWithIndexOpMaker,
-    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
-    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
+REGISTER_OPERATOR(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
+                  ops::MaxPool2dWithIndexOpMaker,
+                  ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
+                  ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad);

 REGISTER_OP_CPU_KERNEL(
@@ -307,11 +325,10 @@ REGISTER_OP_CPU_KERNEL(
     ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUDeviceContext, double,
                                     int>);

-REGISTER_OPERATOR(
-    max_pool3d_with_index, ops::MaxPoolWithIndexOp,
-    ops::MaxPool3dWithIndexOpMaker,
-    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
-    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
+REGISTER_OPERATOR(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
+                  ops::MaxPool3dWithIndexOpMaker,
+                  ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
+                  ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad);

 REGISTER_OP_CPU_KERNEL(
...
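Two things are worth noting about the maker above. First, it derives the grad op type from this->ForwardOpType() + "_grad", so the single template serves both the max_pool2d_with_index and max_pool3d_with_index registrations. Second, it forwards "Mask" but not "Out": the backward pass of index-tracked max pooling just scatters each output gradient to the argmax position recorded in the mask ("X" is presumably still forwarded so that X@GRAD's shape can be inferred). A standalone sketch of that scatter, not the actual Paddle kernel:

// Standalone illustration (not the Paddle kernel): backward of max pooling
// with an explicit index needs only the mask and the output gradient.
#include <iostream>
#include <vector>

// mask[i] holds the flat input index that won the max for output element i,
// i.e. what the "Mask" output in the diff above stores.
void MaxPoolWithIndexGrad(const std::vector<float>& out_diff,
                          const std::vector<int>& mask,
                          std::vector<float>* in_diff) {
  for (size_t i = 0; i < out_diff.size(); ++i)
    (*in_diff)[mask[i]] += out_diff[i];
}

int main() {
  std::vector<float> out_diff = {1.f, 2.f};
  std::vector<int> mask = {3, 0};       // argmax positions from the forward pass
  std::vector<float> in_diff(6, 0.f);   // zero-initialized X@GRAD
  MaxPoolWithIndexGrad(out_diff, mask, &in_diff);
  for (float v : in_diff) std::cout << v << ' ';  // prints: 2 0 0 1 0 0
}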
paddle/fluid/operators/tensor_array_to_tensor_op.cc

@@ -273,6 +273,23 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
   }
 };

+template <typename T>
+class TensorArrayToTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+    op->SetType("tensor_array_to_tensor_grad");
+    op->SetAttrMap(this->Attrs());
+    op->SetInput("X", this->Input("X"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 USE_OP(concat);
@@ -281,8 +298,8 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(
     tensor_array_to_tensor, ops::LoDTensorArray2TensorOp,
     ops::LoDTensorArray2TensorOpMaker, ops::LoDTensorArray2TensorOpInferShape,
-    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
-    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
+    ops::TensorArrayToTensorGradOpMaker<paddle::framework::OpDesc>,
+    ops::TensorArrayToTensorGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(tensor_array_to_tensor_grad, ops::LoDTensorArray2TensorGradOp,
                   ops::LoDTensorArray2TensorGradInferShape,
                   ops::LoDTensorArray2TensorGradInferVarType);
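This replacement has the same flavor as the others: the default maker would have handed the grad op every forward input and output, but splitting Out@GRAD back into per-element gradients only needs the forward inputs ("X", for the element shapes) plus the output gradient itself. A simplified standalone sketch of that backward split, flattened to 1-D for brevity (the real op slices along its axis attribute):

// Standalone illustration (not the Paddle kernel): tensor_array_to_tensor
// concatenates the array elements, so its backward slices Out@GRAD back into
// one gradient per element, using only the forward elements' sizes.
#include <iostream>
#include <vector>

std::vector<std::vector<float>> TensorArrayToTensorGrad(
    const std::vector<float>& out_diff,
    const std::vector<size_t>& element_sizes) {  // sizes taken from "X"
  std::vector<std::vector<float>> x_diffs;
  size_t offset = 0;
  for (size_t n : element_sizes) {
    x_diffs.emplace_back(out_diff.begin() + offset,
                         out_diff.begin() + offset + n);
    offset += n;
  }
  return x_diffs;
}

int main() {
  // A two-element array of sizes 2 and 3 was concatenated in the forward pass.
  auto grads = TensorArrayToTensorGrad({1, 2, 3, 4, 5}, {2, 3});
  for (const auto& g : grads) {
    for (float v : g) std::cout << v << ' ';
    std::cout << '\n';  // prints "1 2" then "3 4 5"
  }
}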