提交 874cac0c 编写于 作者: F fengjiayi

Change softmax

上级 f8391545
......@@ -24,13 +24,13 @@ class SoftmaxOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of SoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SoftmaxOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(x_dims.size() == 2UL,
"The input of softmax op must be a matrix.");
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("Out", x_dims);
}
};
......@@ -41,7 +41,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X",
"The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Y", "The normalized values with the same shape as X.");
AddOutput("Out", "The normalized values with the same shape as X.");
AddComment(R"DOC(
Softmax Operator.
......@@ -59,7 +59,7 @@ exponential values of all the other dimensions is the output of the softmax
operator.
For each row $i$ and each column $j$ in Input(X), we have:
$$Y[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
$$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
)DOC");
}
......@@ -70,12 +70,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"),
ctx->GetInputDim(framework::GradVarName("Y")),
"Input(Y) and its gradients should have a same shape.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
"Input(Out) and its gradients should have a same shape.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
......
......@@ -26,13 +26,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Y = context.Output<Tensor>("Y");
auto* Out = context.Output<Tensor>("Out");
// allocate memory on device.
Y->mutable_data<T>(context.GetPlace());
Out->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Y);
context.template device_context<DeviceContext>(), X, Out);
}
};
......@@ -40,15 +40,15 @@ template <typename DeviceContext, typename T>
class SoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Y = context.Input<Tensor>("Y");
auto* dY = context.Input<Tensor>(framework::GradVarName("Y"));
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Y, dY, dX);
context.template device_context<DeviceContext>(), Out, dOut, dX);
}
};
......
......@@ -184,7 +184,7 @@ class LayerHelper(object):
self.append_op(
type=act_type,
inputs={"X": [input_var]},
outputs={"Y": [tmp]},
outputs={"Out": [tmp]},
attrs=act)
return tmp
......
......@@ -17,14 +17,14 @@ class TestSoftmaxOp(OpTest):
'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
}
self.outputs = {
'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
self.check_grad(['X'], 'Out')
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册