Commit 874cac0c authored by: fengjiayi

Change softmax: rename the softmax operator's output from "Y" to "Out" in the op definition, the forward and gradient kernels, the Python LayerHelper, and the unit test.

Parent f8391545
@@ -24,13 +24,13 @@ class SoftmaxOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of SoftmaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"),
-                   "Output(Y) of SoftmaxOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SoftmaxOp should not be null.");
     auto x_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE(x_dims.size() == 2UL,
                    "The input of softmax op must be a matrix.");
-    ctx->SetOutputDim("Y", x_dims);
+    ctx->SetOutputDim("Out", x_dims);
   }
 };
@@ -41,7 +41,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X",
              "The input tensor of softmax. "
              "2-D with shape [batch_size, input_feature_dimensions].");
-    AddOutput("Y", "The normalized values with the same shape as X.");
+    AddOutput("Out", "The normalized values with the same shape as X.");
     AddComment(R"DOC(
 Softmax Operator.
@@ -59,7 +59,7 @@ exponential values of all the other dimensions is the output of the softmax
 operator.

 For each row $i$ and each column $j$ in Input(X), we have:
-    $$Y[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$
+    $$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$

 )DOC");
   }
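
The doc comment above is the standard row-wise softmax. As a quick illustration (my own sketch, not code from this commit), the formula transcribes directly to NumPy:

import numpy as np

def softmax(x):
    # Row-wise softmax for a [batch_size, features] matrix, transcribing
    # the doc-comment formula Out[i, j] = exp(X[i, j]) / sum_j exp(X[i, j]).
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)

Each row of the result sums to 1; for instance, softmax(np.array([[0.0, 0.0]])) gives [[0.5, 0.5]].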
@@ -70,12 +70,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input(Y@GRAD) should not be null.");
-    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"),
-                      ctx->GetInputDim(framework::GradVarName("Y")),
-                      "Input(Y) and its gradient should have the same shape.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Out"),
+                      ctx->GetInputDim(framework::GradVarName("Out")),
+                      "Input(Out) and its gradient should have the same shape.");
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
......
@@ -26,13 +26,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* X = context.Input<Tensor>("X");
-    auto* Y = context.Output<Tensor>("Y");
+    auto* Out = context.Output<Tensor>("Out");
     // allocate memory on device.
-    Y->mutable_data<T>(context.GetPlace());
+    Out->mutable_data<T>(context.GetPlace());
     math::SoftmaxFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), X, Y);
+        context.template device_context<DeviceContext>(), X, Out);
   }
 };
@@ -40,15 +40,15 @@ template <typename DeviceContext, typename T>
 class SoftmaxGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* Y = context.Input<Tensor>("Y");
-    auto* dY = context.Input<Tensor>(framework::GradVarName("Y"));
+    auto* Out = context.Input<Tensor>("Out");
+    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());
     math::SoftmaxGradFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), Y, dY, dX);
+        context.template device_context<DeviceContext>(), Out, dOut, dX);
   }
 };
......
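
SoftmaxGradKernel passes only Out and dOut to math::SoftmaxGradFunctor: softmax's backward pass can be written entirely in terms of the forward output. A minimal NumPy sketch of that standard formula (an illustration of the usual softmax-gradient derivation, not the functor's actual source):

import numpy as np

def softmax_grad(out, dout):
    # dX[i, j] = Out[i, j] * (dOut[i, j] - sum_k dOut[i, k] * Out[i, k]),
    # the Jacobian-vector product of row-wise softmax. Note that X itself
    # is never needed, matching the kernel's inputs above.
    dot = (dout * out).sum(axis=1, keepdims=True)
    return out * (dout - dot)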
@@ -184,7 +184,7 @@ class LayerHelper(object):
         self.append_op(
             type=act_type,
             inputs={"X": [input_var]},
-            outputs={"Y": [tmp]},
+            outputs={"Out": [tmp]},
             attrs=act)
         return tmp
......
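
The key in the outputs dict must match the name registered by AddOutput on the C++ side, which is why LayerHelper changes in lockstep with the operator. A hypothetical sketch of the op description the helper now emits for a softmax activation (the variable names are made up; only the "X"/"Out" keys come from this commit):

# Hypothetical op description built by append_op after the rename;
# variable names are invented for illustration.
op_desc = {
    "type": "softmax",
    "inputs": {"X": ["fc_0.tmp"]},
    "outputs": {"Out": ["softmax_0.tmp"]},
}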
@@ -17,14 +17,14 @@ class TestSoftmaxOp(OpTest):
             'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
         }
         self.outputs = {
-            'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
+            'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
         }

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Y')
+        self.check_grad(['X'], 'Out')

 if __name__ == "__main__":
......
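
The test references a stable_softmax helper defined elsewhere in the test file. A likely implementation (an assumption, shown for context) shifts by the max before exponentiating so np.exp cannot overflow; since it is applied per row via np.apply_along_axis, it receives a 1-D vector:

import numpy as np

def stable_softmax(x):
    # Assumed body of the helper used above: the shift by max(x) cancels
    # in the ratio, so the result equals plain softmax but never overflows.
    shifted = x - np.max(x)
    exps = np.exp(shifted)
    return exps / np.sum(exps)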