提交 d953611e 编写于 作者: Q Qiao Longfei 提交者: GitHub

Softmax grad op (#3164)

* init softmax grad op

* add compute code

* export Backward to python

* update test ,export op.type to python

* update python test, fix compute bug

* update unit test

* use eigen

* optimize eigen code

* add gpu test

* register softmax_grad GPU kernel and fix test bug

* typo

* follow comments
上级 809793c9
......@@ -55,6 +55,10 @@ class OperatorBase {
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
static std::string GRAD_VAR_NAME(const std::string& name) {
return name + GRAD_VAR_SUFFIX();
}
/// Variables with this suffix are supposed to be filled up with zeros.
static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/softmax_op.h"
namespace paddle {
......@@ -19,12 +20,13 @@ namespace operators {
class SoftmaxOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
PADDLE_ENFORCE(ctx.InputSize() == 1UL,
"Only one input is need for softmax");
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
"The input of softmax op must be matrix");
PADDLE_ENFORCE(ctx.OutputSize() == 1,
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Only one output is need for softmax");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
}
};
......@@ -40,10 +42,19 @@ public:
class SoftmaxOpGrad : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "SoftmaxOpGrad";
return "";
void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 3UL,
"Input of SoftmaxOpGrad should be 3, X, Y, YG");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Output of SoftmaxOpGrad should be 1");
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
PADDLE_ENFORCE(ctx.InputVar(GRAD_VAR_NAME("Y")) != nullptr,
"Input(Y@GRAD) should not be null");
PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
ctx.Input<Tensor>(GRAD_VAR_NAME("Y"))->dims(),
"the shape of Input(0) and Input(1) should be the same");
ctx.Output<Tensor>(GRAD_VAR_NAME("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
}
};
......@@ -51,5 +62,7 @@ protected:
} // namespace paddle
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<ops::CPUPlace, float>);
REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_grad,
ops::SoftmaxGradKernel<ops::CPUPlace, float>);
......@@ -3,3 +3,4 @@
#include "paddle/operators/softmax_op.h"
REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel<ops::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(softmax_grad, ops::SoftmaxGradKernel<ops::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/ddim.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
......@@ -23,8 +26,8 @@ template <typename Place, typename T>
class SoftmaxKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto input = context.Input<Tensor>(0);
auto output = context.Output<Tensor>(0);
auto input = context.Input<Tensor>("X");
auto output = context.Output<Tensor>("Y");
output->mutable_data<T>(context.GetPlace());
auto logits = EigenMatrix<T>::From(*input);
......@@ -57,5 +60,38 @@ public:
.broadcast(one_by_class));
}
};
template <typename Place, typename T>
class SoftmaxGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
auto Y = context.Input<Tensor>("Y");
auto dY = context.Input<Tensor>(OperatorBase::GRAD_VAR_NAME("Y"));
auto dX = context.Output<Tensor>(OperatorBase::GRAD_VAR_NAME("X"));
dX->mutable_data<T>(context.GetPlace());
const int batch_size = Y->dims()[0];
const int class_num = Y->dims()[1];
Eigen::DSizes<int, 1> along_class(1);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, class_num);
auto Y_eigen = EigenMatrix<T>::From(*Y);
auto dY_eigen = EigenMatrix<T>::From(*dY);
auto dX_eigen = EigenMatrix<T>::From(*dX);
auto place = context.GetEigenDevice<Place>();
auto dot = (Y_eigen * dY_eigen)
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen;
}
};
} // namespace operators
} // namespace paddle
......@@ -22,6 +22,7 @@ namespace paddle {
namespace operators {
using OpKernel = framework::OpKernel;
using OperatorBase = framework::OperatorBase;
using InferShapeContext = framework::InferShapeContext;
using ExecutionContext = framework::ExecutionContext;
using Variable = framework::Variable;
......
import unittest
from op_test_util import OpTestMeta
import numpy as np
import paddle.v2.framework.core as core
import paddle.v2.framework.create_op_creation_methods as creation
from op_test_util import OpTestMeta
def stable_softmax(x):
......@@ -19,5 +23,63 @@ class TestSoftmaxOp(unittest.TestCase):
self.Y = np.apply_along_axis(stable_softmax, 1, self.X)
class TestSoftmaxGradOp(unittest.TestCase):
def test_softmax_grad(self):
op = creation.op_creations.softmax(X="X", Y="Y")
backward_op = core.Operator.backward(op, set())
self.assertEqual(backward_op.type(), "softmax_grad")
expected = '''Op(softmax_grad), inputs:(X, Y, Y@GRAD), outputs:(X@GRAD).'''
self.assertEqual(expected, str(backward_op))
batch_size = 3
class_num = 5
# Initialize X and add 1e-2 for numerical stability
Y = np.random.rand(batch_size, class_num).astype(np.float32)
Y = Y + 1e-2
dY = np.random.rand(batch_size, class_num).astype(np.float32)
# Reference implementation of cross entropy with soft labels
def label_softmax_grad(Y, dY):
dX = Y * 0.0
for i in range(batch_size):
d = np.dot(Y[i, :], dY[i, :])
dX[i, :] = Y[i, :] * (dY[i, :] - d)
return dX
expected = label_softmax_grad(Y, dY)
scope = core.Scope()
places = []
places.append(core.CPUPlace())
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for place in places:
y = scope.new_var("Y")
y_tensor = y.get_tensor()
y_tensor.set_dims([batch_size, class_num])
y_tensor.alloc_float(place)
y_tensor.set(Y, place)
dy = scope.new_var("Y@GRAD")
dy_tensor = dy.get_tensor()
dy_tensor.set_dims([batch_size, class_num])
dy_tensor.alloc_float(place)
dy_tensor.set(dY, place)
x = scope.new_var("X")
dx = scope.new_var("X@GRAD")
tensor = scope.find_var("X@GRAD").get_tensor()
backward_op.infer_shape(scope)
self.assertEqual([batch_size, class_num], tensor.shape())
ctx = core.DeviceContext.create(place)
backward_op.run(scope, ctx)
actual = np.array(tensor)
np.testing.assert_almost_equal(actual, expected, decimal=3)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册