提交 3e036956 编写于 作者: S sneaxiy

fix numeric error

test=develop
上级 5a92e4c0
...@@ -248,10 +248,14 @@ class CrossEntropyOp2 : public CrossEntropyOpBase { ...@@ -248,10 +248,14 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
PADDLE_ENFORCE(ctx->HasOutput("XShape"), PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) should be not null."); "Output(XShape) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
"Output(MatchX) should be not null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto x_dims_vec = framework::vectorize(x_dims); auto x_dims_vec = framework::vectorize(x_dims);
x_dims_vec.push_back(0); x_dims_vec.push_back(0);
ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec)); ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
x_dims[x_dims.size() - 1] = 1;
ctx->SetOutputDim("MatchX", x_dims);
ctx->ShareLoD("X", /*->*/ "XShape"); ctx->ShareLoD("X", /*->*/ "XShape");
} }
...@@ -264,6 +268,10 @@ class CrossEntropyOp2 : public CrossEntropyOpBase { ...@@ -264,6 +268,10 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase { class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
public: public:
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase; using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
CrossEntropyGradientOpBase::InferShape(ctx);
}
protected: protected:
virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
...@@ -295,6 +303,8 @@ class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker { ...@@ -295,6 +303,8 @@ class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
"with 'X' except that the last dimension size is 1. It " "with 'X' except that the last dimension size is 1. It "
"represents the cross entropy loss."); "represents the cross entropy loss.");
AddOutput("XShape", "Temporaily variable to save shape and LoD of X."); AddOutput("XShape", "Temporaily variable to save shape and LoD of X.");
AddOutput("MatchX",
"X value that matches label, used for gradient computation.");
AddAttr<int>("ignore_index", AddAttr<int>("ignore_index",
"(int, default -100), Specifies a target value that is" "(int, default -100), Specifies a target value that is"
"ignored and does not contribute to the input gradient." "ignored and does not contribute to the input gradient."
...@@ -327,7 +337,7 @@ class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker { ...@@ -327,7 +337,7 @@ class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc()); std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("cross_entropy_grad2"); op->SetType("cross_entropy_grad2");
op->SetInput("Label", Input("Label")); op->SetInput("Label", Input("Label"));
op->SetInput("Y", Output("Y")); op->SetInput("MatchX", Output("MatchX"));
op->SetInput("XShape", Output("XShape")); op->SetInput("XShape", Output("XShape"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
......
...@@ -138,15 +138,48 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> { ...@@ -138,15 +138,48 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T>
struct HardLabelCrossEntropyForwardFunctor {
HardLabelCrossEntropyForwardFunctor(const T* x, T* y, T* match_x,
const int64_t* label,
int64_t ignore_index,
int64_t feature_size)
: x_(x),
y_(y),
match_x_(match_x),
label_(label),
ignore_index_(ignore_index),
feature_size_(feature_size) {}
HOSTDEVICE void operator()(int64_t idx) const {
auto label = label_[idx];
if (label != ignore_index_) {
auto match_x = x_[idx * feature_size_ + label];
y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
match_x_[idx] = match_x;
} else {
y_[idx] = 0;
match_x_[idx] = 0; // any value is ok
}
}
const T* x_;
T* y_;
T* match_x_;
const int64_t* label_;
int64_t ignore_index_;
int64_t feature_size_;
};
template <typename T> template <typename T>
struct HardLabelCrossEntropyBackwardFunctor { struct HardLabelCrossEntropyBackwardFunctor {
HardLabelCrossEntropyBackwardFunctor(T* dx, const T* y, const T* dy, HardLabelCrossEntropyBackwardFunctor(T* dx, const T* dy, const T* match_x,
const int64_t* label, const int64_t* label,
int64_t ignore_index, int64_t ignore_index,
int64_t feature_size) int64_t feature_size)
: dx_(dx), : dx_(dx),
y_(y),
dy_(dy), dy_(dy),
match_x_(match_x),
label_(label), label_(label),
ignore_index_(ignore_index), ignore_index_(ignore_index),
feature_size_(feature_size) {} feature_size_(feature_size) {}
...@@ -156,15 +189,15 @@ struct HardLabelCrossEntropyBackwardFunctor { ...@@ -156,15 +189,15 @@ struct HardLabelCrossEntropyBackwardFunctor {
auto col_idx = idx % feature_size_; auto col_idx = idx % feature_size_;
auto label = label_[row_idx]; auto label = label_[row_idx];
if (label == col_idx && label != ignore_index_) { if (label == col_idx && label != ignore_index_) {
dx_[idx] = -dy_[row_idx] * real_exp(y_[row_idx]); dx_[idx] = -dy_[row_idx] / match_x_[row_idx];
} else { } else {
dx_[idx] = 0; dx_[idx] = 0;
} }
} }
T* dx_; T* dx_;
const T* y_;
const T* dy_; const T* dy_;
const T* match_x_;
const int64_t* label_; const int64_t* label_;
int64_t ignore_index_; int64_t ignore_index_;
int64_t feature_size_; int64_t feature_size_;
...@@ -174,20 +207,26 @@ template <typename DeviceContext, typename T> ...@@ -174,20 +207,26 @@ template <typename DeviceContext, typename T>
class CrossEntropyOpKernel2 : public framework::OpKernel<T> { class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* x_original = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
int rank = x_original->dims().size(); auto* label = ctx.Input<Tensor>("Label");
auto x = framework::ReshapeToMatrix(*x_original, rank - 1);
auto label =
framework::ReshapeToMatrix(*ctx.Input<Tensor>("Label"), rank - 1);
auto* y = ctx.Output<Tensor>("Y"); auto* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace()); auto* match_x = ctx.Output<Tensor>("MatchX");
auto& x_dims = x->dims();
auto feature_size = x_dims[x_dims.size() - 1];
auto batch_size = framework::product(x->dims()) / feature_size;
auto* p_x = x->data<T>();
auto* p_label = label->data<int64_t>();
auto* p_y = y->mutable_data<T>(ctx.GetPlace());
auto* p_match_x = match_x->mutable_data<T>(ctx.GetPlace());
auto ignore_index = ctx.Attr<int>("ignore_index"); auto ignore_index = ctx.Attr<int>("ignore_index");
math::CrossEntropyFunctor<DeviceContext, T>()( platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), y, &x, &label, false, ctx.template device_context<DeviceContext>(), batch_size);
ignore_index); for_range(HardLabelCrossEntropyForwardFunctor<T>(
p_x, p_y, p_match_x, p_label, ignore_index, feature_size));
} }
}; };
...@@ -196,13 +235,13 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> { ...@@ -196,13 +235,13 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* y = ctx.Input<Tensor>("Y");
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* match_x = ctx.Input<Tensor>("MatchX");
auto* label = ctx.Input<Tensor>("Label"); auto* label = ctx.Input<Tensor>("Label");
auto* p_dx = dx->mutable_data<T>(ctx.GetPlace()); auto* p_dx = dx->mutable_data<T>(ctx.GetPlace());
auto* p_y = y->data<T>();
auto* p_dy = dy->data<T>(); auto* p_dy = dy->data<T>();
auto* p_match_x = match_x->data<T>();
auto* p_label = label->data<int64_t>(); auto* p_label = label->data<int64_t>();
int64_t ignore_index = ctx.Attr<int>("ignore_index"); int64_t ignore_index = ctx.Attr<int>("ignore_index");
...@@ -214,7 +253,7 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> { ...@@ -214,7 +253,7 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
ctx.template device_context<DeviceContext>(), ctx.template device_context<DeviceContext>(),
batch_size * feature_size); batch_size * feature_size);
for_range(HardLabelCrossEntropyBackwardFunctor<T>( for_range(HardLabelCrossEntropyBackwardFunctor<T>(
p_dx, p_y, p_dy, p_label, ignore_index, feature_size)); p_dx, p_dy, p_match_x, p_label, ignore_index, feature_size));
} }
}; };
......
...@@ -1450,11 +1450,13 @@ def cross_entropy2(input, label, ignore_index=kIgnoreIndex): ...@@ -1450,11 +1450,13 @@ def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
helper = LayerHelper('cross_entropy2', **locals()) helper = LayerHelper('cross_entropy2', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=input.dtype)
xshape = helper.create_variable_for_type_inference(dtype=input.dtype) xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
match_x = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type='cross_entropy2', type='cross_entropy2',
inputs={'X': [input], inputs={'X': [input],
'Label': [label]}, 'Label': [label]},
outputs={'Y': [out], outputs={'Y': [out],
'MatchX': [match_x],
'XShape': [xshape]}, 'XShape': [xshape]},
attrs={'ignore_index': ignore_index}) attrs={'ignore_index': ignore_index})
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册