提交 87da1542 编写于 作者: Y Yu Yang

FIx sigmoid_xe_with_logits_op compile

上级 63469dae
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X))) // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel { class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X"); const framework::Tensor *X = context.Input<framework::Tensor>("X");
...@@ -48,7 +48,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel { ...@@ -48,7 +48,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
// dX = sigmoid(X) - labels // dX = sigmoid(X) - labels
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel { class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X"); const framework::Tensor *X = context.Input<framework::Tensor>("X");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册