提交 c7225312 编写于 作者: Y Yan Chunwei 提交者: GitHub

crossentropy grad op (#3186)

* init cross entropy graident

* add crossentropy grad op

* remove details

* fix static compile
上级 338d861f
......@@ -36,6 +36,17 @@ class OnehotCrossEntropyOp : public OperatorWithKernel {
}
};
class OnehotCrossEntropyGradientOp : public OperatorWithKernel {
protected:
void InferShape(const InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
// TODO(superjom) add enforce here after helper functions ready
X_grad->Resize(X->dims());
}
};
class OnehotCrossEntropyOpMaker : public OpProtoAndCheckerMaker {
public:
OnehotCrossEntropyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
......@@ -58,3 +69,7 @@ REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker);
REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<ops::CPUPlace, float>);
......@@ -18,28 +18,53 @@ limitations under the License. */
namespace paddle {
namespace operators {
static const float kCrossEntropyLogThreshold{1e-20};
template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public OpKernel {
public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }
void Compute(const ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>(0);
const T* X_data = X->data<T>();
auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>(1)->data<int>();
auto Y = ctx.Output<Tensor>(0);
auto Y = ctx.Output<Tensor>("Y");
Y->mutable_data<T>(ctx.GetPlace());
T* Y_data = Y->data<T>();
T* Ydata = Y->data<T>();
int batch_size = X->dims()[0];
int class_num = X->dims()[1];
// Y[i] = -log(X[i][j])
for (int i = 0; i < batch_size; ++i) {
Y_data[i] = -std::log(
std::max(X_data[i * class_num + label_data[i]], LOG_THRESHOLD()));
Ydata[i] = -std::log(std::max(Xdata[i * class_num + label_data[i]],
kCrossEntropyLogThreshold));
}
}
};
template <typename Place, typename T>
class OnehotCrossEntropyGradientOpKernel : public OpKernel {
public:
void Compute(const ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto label = ctx.Input<Tensor>("label");
auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
auto* dYdata = dY->template data<T>();
auto* Xdata = X->template data<T>();
auto* label_data = label->data<int>();
const int batch_size = X->dims()[0];
const int class_num = X->dims()[1];
for (int i = 0; i < batch_size; ++i) {
dXdata[i * class_num + label_data[i]] =
-dYdata[i] / std::max(Xdata[i * class_num + label_data[i]],
kCrossEntropyLogThreshold);
}
}
};
......
......@@ -18,5 +18,7 @@ class TestSGD(unittest.TestCase):
self.Y = numpy.array(Y).astype("float32")
# TODO(superjom) add gradient check
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册