提交 5a92e4c0 编写于 作者: S sneaxiy

revert revert 16144

test=develop
上级 ad5f0e60
...@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and ...@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/cross_entropy_op.h" #include "paddle/fluid/operators/cross_entropy_op.h"
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class CrossEntropyOp : public framework::OperatorWithKernel { class CrossEntropyOpBase : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
...@@ -44,7 +46,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -44,7 +46,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
"Input(X) and Input(Label) shall have the same shape " "Input(X) and Input(Label) shall have the same shape "
"except the last dimension."); "except the last dimension.");
} }
if (ctx->Attrs().Get<bool>("soft_label")) {
if (IsSoftLabel(ctx)) {
if (check) { if (check) {
PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1], PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of " "If Attr(soft_label) == true, the last dimension of "
...@@ -70,21 +73,24 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -70,21 +73,24 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context()); ctx.device_context());
} }
virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
return ctx->Attrs().Get<bool>("soft_label");
}
}; };
class CrossEntropyGradientOp : public framework::OperatorWithKernel { class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) shoudl be not null."); "Input(Y@GRAD) shoudl be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null."); "Output(X@GRAD) should be not null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = GetXDim(ctx);
auto label_dims = ctx->GetInputDim("Label"); auto label_dims = ctx->GetInputDim("Label");
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
int rank = x_dims.size(); int rank = x_dims.size();
...@@ -109,9 +115,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -109,9 +115,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"The Input(X) and Input(Y@Grad) should have the same " "The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension."); "shape except the last dimension.");
} }
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, if (IsSoftLabel(ctx)) {
"The last dimension of Input(Y@Grad) should be 1.");
if (ctx->Attrs().Get<bool>("soft_label")) {
if (check) { if (check) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[rank - 1], label_dims[rank - 1], x_dims[rank - 1], label_dims[rank - 1],
...@@ -124,7 +128,10 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -124,7 +128,10 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"Input(Label) should be 1."); "Input(Label) should be 1.");
} }
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("X", framework::GradVarName("X")); PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
"The last dimension of Input(Y@Grad) should be 1.");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
} }
protected: protected:
...@@ -132,9 +139,29 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -132,9 +139,29 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
ctx.device_context()); ctx.device_context());
} }
virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
return ctx->GetInputDim("X");
}
virtual const char* VarNameWithXLoD() const { return "X"; }
virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
return ctx->Attrs().Get<bool>("soft_label");
}
};
class CrossEntropyOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
}
}; };
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -201,22 +228,122 @@ or not. But the output only shares the LoD information with input X. ...@@ -201,22 +228,122 @@ or not. But the output only shares the LoD information with input X.
} }
}; };
class CrossEntropyOpInferVarType class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
: public framework::PassInDtypeAndVarTypeToOutput { public:
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
CrossEntropyGradientOpBase::InferShape(ctx);
}
};
class CrossEntropyOp2 : public CrossEntropyOpBase {
public:
using CrossEntropyOpBase::CrossEntropyOpBase;
void InferShape(framework::InferShapeContext* ctx) const override {
CrossEntropyOpBase::InferShape(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto x_dims_vec = framework::vectorize(x_dims);
x_dims_vec.push_back(0);
ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
ctx->ShareLoD("X", /*->*/ "XShape");
}
protected: protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType() bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
const override { return false;
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
} }
}; };
class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
public:
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
protected:
virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
auto x_shape = ctx->GetInputDim("XShape");
return framework::DDim(x_shape.Get(), x_shape.size() - 1);
}
virtual const char* VarNameWithXLoD() const { return "XShape"; }
virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
return false;
}
};
class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a tensor whose last dimension "
"size is equal to the number of classes. This input is a "
"probability computed by the previous operator, which is almost "
"always the result of a softmax operator.");
AddInput(
"Label",
"(Tensor), the tensor which represents the ground truth. It has the "
"same shape with 'X' except the last dimension. One hot Tensor.");
AddOutput("Y",
"(Tensor, default Tensor<float>), a tensor whose shape is same "
"with 'X' except that the last dimension size is 1. It "
"represents the cross entropy loss.");
AddOutput("XShape", "Temporaily variable to save shape and LoD of X.");
AddAttr<int>("ignore_index",
"(int, default -100), Specifies a target value that is"
"ignored and does not contribute to the input gradient."
"Only valid if soft_label is set to False")
.SetDefault(-100);
AddComment(R"DOC(
Hard-label CrossEntropy Operator.
The input 'X' and 'Label' will first be logically flattened to 2-D matrixs.
The matrix's second dimension(row length) is as same as the original last
dimension, and the first dimension(column length) is the product of all other
original dimensions. Then the softmax computation will take palce on each raw
of flattened matrixs.
Only support hard label.
Both the input X and Label can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input X.
)DOC");
}
};
class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("cross_entropy_grad2");
op->SetInput("Label", Input("Label"));
op->SetInput("Y", Output("Y"));
op->SetInput("XShape", Output("XShape"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext; using CPUCtx = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker, REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase,
ops::CrossEntropyOpInferVarType, ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp); REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>, REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
...@@ -224,3 +351,14 @@ REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>, ...@@ -224,3 +351,14 @@ REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
REGISTER_OP_CPU_KERNEL(cross_entropy_grad, REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<CPUCtx, float>, ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
ops::CrossEntropyGradientOpKernel<CPUCtx, double>); ops::CrossEntropyGradientOpKernel<CPUCtx, double>);
REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
ops::CrossEntropyGradOpDescMaker2);
REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
REGISTER_OP_CPU_KERNEL(cross_entropy2,
ops::CrossEntropyOpKernel2<CPUCtx, float>,
ops::CrossEntropyOpKernel2<CPUCtx, double>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad2,
ops::CrossEntropyGradientOpKernel2<CPUCtx, float>,
ops::CrossEntropyGradientOpKernel2<CPUCtx, double>);
...@@ -27,3 +27,13 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -27,3 +27,13 @@ REGISTER_OP_CUDA_KERNEL(
cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>, cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
ops::CrossEntropyGradientOpKernel<CUDACtx, double>, ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>); ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
REGISTER_OP_CUDA_KERNEL(cross_entropy2,
ops::CrossEntropyOpKernel2<CUDACtx, float>,
ops::CrossEntropyOpKernel2<CUDACtx, double>,
ops::CrossEntropyOpKernel2<CUDACtx, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
cross_entropy_grad2, ops::CrossEntropyGradientOpKernel2<CUDACtx, float>,
ops::CrossEntropyGradientOpKernel2<CUDACtx, double>,
ops::CrossEntropyGradientOpKernel2<CUDACtx, plat::float16>);
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
...@@ -137,5 +138,85 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> { ...@@ -137,5 +138,85 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T>
struct HardLabelCrossEntropyBackwardFunctor {
HardLabelCrossEntropyBackwardFunctor(T* dx, const T* y, const T* dy,
const int64_t* label,
int64_t ignore_index,
int64_t feature_size)
: dx_(dx),
y_(y),
dy_(dy),
label_(label),
ignore_index_(ignore_index),
feature_size_(feature_size) {}
HOSTDEVICE void operator()(int64_t idx) const {
auto row_idx = idx / feature_size_;
auto col_idx = idx % feature_size_;
auto label = label_[row_idx];
if (label == col_idx && label != ignore_index_) {
dx_[idx] = -dy_[row_idx] * real_exp(y_[row_idx]);
} else {
dx_[idx] = 0;
}
}
T* dx_;
const T* y_;
const T* dy_;
const int64_t* label_;
int64_t ignore_index_;
int64_t feature_size_;
};
template <typename DeviceContext, typename T>
class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x_original = ctx.Input<Tensor>("X");
int rank = x_original->dims().size();
auto x = framework::ReshapeToMatrix(*x_original, rank - 1);
auto label =
framework::ReshapeToMatrix(*ctx.Input<Tensor>("Label"), rank - 1);
auto* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
auto ignore_index = ctx.Attr<int>("ignore_index");
math::CrossEntropyFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), y, &x, &label, false,
ignore_index);
}
};
template <typename DeviceContext, typename T>
class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* y = ctx.Input<Tensor>("Y");
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* label = ctx.Input<Tensor>("Label");
auto* p_dx = dx->mutable_data<T>(ctx.GetPlace());
auto* p_y = y->data<T>();
auto* p_dy = dy->data<T>();
auto* p_label = label->data<int64_t>();
int64_t ignore_index = ctx.Attr<int>("ignore_index");
int rank = dx->dims().size();
int64_t feature_size = dx->dims()[rank - 1];
int64_t batch_size = framework::product(dx->dims()) / feature_size;
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(),
batch_size * feature_size);
for_range(HardLabelCrossEntropyBackwardFunctor<T>(
p_dx, p_y, p_dy, p_label, ignore_index, feature_size));
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/expand_op.h" #include "paddle/fluid/operators/expand_op.h"
#include <memory>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
...@@ -138,12 +139,28 @@ class ExpandGradOp : public framework::OperatorWithKernel { ...@@ -138,12 +139,28 @@ class ExpandGradOp : public framework::OperatorWithKernel {
} }
}; };
class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("expand_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker, REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::ExpandGradOpDescMaker);
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp); REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>, expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "math.h" // NOLINT
namespace paddle {
namespace operators {
inline HOSTDEVICE platform::float16 real_exp(platform::float16 x) {
return static_cast<platform::float16>(::expf(static_cast<float>(x)));
}
inline HOSTDEVICE float real_exp(float x) { return ::expf(x); }
inline HOSTDEVICE double real_exp(double x) { return ::exp(x); }
inline HOSTDEVICE platform::float16 real_log(platform::float16 x) {
return static_cast<platform::float16>(::logf(static_cast<float>(x)));
}
inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
inline HOSTDEVICE double real_log(double x) { return ::log(x); }
} // namespace operators
} // namespace paddle
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
...@@ -20,17 +21,6 @@ namespace paddle { ...@@ -20,17 +21,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
namespace {
__device__ __forceinline__ float real_log(float x) { return logf(x); }
__device__ __forceinline__ double real_log(double x) { return log(x); }
__device__ __forceinline__ platform::float16 real_log(
const platform::float16& val) {
return static_cast<platform::float16>(logf(static_cast<float>(val)));
}
template <typename T> template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
const int N, const int D, const int N, const int D,
...@@ -61,7 +51,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label, ...@@ -61,7 +51,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
Y[blockIdx.x] = -val; Y[blockIdx.x] = -val;
} }
} }
} // namespace
template <typename T> template <typename T>
class CrossEntropyFunctor<platform::CUDADeviceContext, T> { class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
......
...@@ -15,13 +15,12 @@ limitations under the License. */ ...@@ -15,13 +15,12 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static HOSTDEVICE float real_exp(float x) { return expf(x); }
static HOSTDEVICE float real_exp(double x) { return exp(x); }
template <typename T> template <typename T>
struct SeluFunctor { struct SeluFunctor {
SeluFunctor(const T* x_data_ptr, float alpha, float scale, T* y_data_ptr) SeluFunctor(const T* x_data_ptr, float alpha, float scale, T* y_data_ptr)
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <cub/cub.cuh> // NOLINT #include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
namespace paddle { namespace paddle {
...@@ -21,9 +22,6 @@ namespace operators { ...@@ -21,9 +22,6 @@ namespace operators {
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
__device__ __forceinline__ float real_exp(float x) { return expf(x); }
__device__ __forceinline__ double real_exp(double x) { return exp(x); }
template <typename T, int BlockDim> template <typename T, int BlockDim>
using BlockReduce = cub::BlockReduce<T, BlockDim>; using BlockReduce = cub::BlockReduce<T, BlockDim>;
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "cub/cub.cuh" #include "cub/cub.cuh"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
...@@ -21,11 +22,6 @@ namespace operators { ...@@ -21,11 +22,6 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
static HOSTDEVICE float real_exp(float x) { return expf(x); }
static HOSTDEVICE float real_exp(double x) { return exp(x); }
static HOSTDEVICE float real_log(float x) { return logf(x); }
static HOSTDEVICE float real_log(double x) { return log(x); }
static constexpr int kNumCUDAThreads = 512; static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096; static constexpr int kNumMaxinumNumBlocks = 4096;
......
...@@ -1432,6 +1432,8 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): ...@@ -1432,6 +1432,8 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
predict = fluid.layers.fc(input=net, size=classdim, act='softmax') predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
""" """
if not soft_label:
return cross_entropy2(input, label, ignore_index)
helper = LayerHelper('cross_entropy', **locals()) helper = LayerHelper('cross_entropy', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
...@@ -1444,6 +1446,20 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): ...@@ -1444,6 +1446,20 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
return out return out
def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
helper = LayerHelper('cross_entropy2', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='cross_entropy2',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out],
'XShape': [xshape]},
attrs={'ignore_index': ignore_index})
return out
def bpr_loss(input, label, name=None): def bpr_loss(input, label, name=None):
""" """
Bayesian Personalized Ranking Loss Operator. Bayesian Personalized Ranking Loss Operator.
......
...@@ -524,8 +524,8 @@ class TestLocalLookupTable(TestDistLookupTableBase): ...@@ -524,8 +524,8 @@ class TestLocalLookupTable(TestDistLookupTableBase):
ops = [ ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad', 'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'split_selected_rows', 'send', 'sequence_pool_grad', 'split_selected_rows', 'send', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
...@@ -564,8 +564,8 @@ class TestDistLookupTable(TestDistLookupTableBase): ...@@ -564,8 +564,8 @@ class TestDistLookupTable(TestDistLookupTableBase):
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', 'elementwise_add', 'cross_entropy2', 'mean', 'fill_constant',
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mean_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send',
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
'lookup_table_grad', 'split_selected_rows', 'send', 'lookup_table_grad', 'split_selected_rows', 'send',
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
...@@ -612,8 +612,8 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase): ...@@ -612,8 +612,8 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
ops = [ ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad', 'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'split_selected_rows', 'send', 'sequence_pool_grad', 'split_selected_rows', 'send', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
...@@ -652,8 +652,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -652,8 +652,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', 'elementwise_add', 'cross_entropy2', 'mean', 'fill_constant',
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mean_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send',
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
'lookup_table_grad', 'split_selected_rows', 'send', 'lookup_table_grad', 'split_selected_rows', 'send',
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
...@@ -841,8 +841,8 @@ class TestRemoteLookupTable(TestDistLookupTableBase): ...@@ -841,8 +841,8 @@ class TestRemoteLookupTable(TestDistLookupTableBase):
ops = [ ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add', 'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad', 'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'split_selected_rows', 'send', 'sequence_pool_grad', 'split_selected_rows', 'send', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册