提交 3e036956 编写于 作者: S sneaxiy

fix numeric error

test=develop
上级 5a92e4c0
......@@ -248,10 +248,14 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
"Output(MatchX) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto x_dims_vec = framework::vectorize(x_dims);
x_dims_vec.push_back(0);
ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
x_dims[x_dims.size() - 1] = 1;
ctx->SetOutputDim("MatchX", x_dims);
ctx->ShareLoD("X", /*->*/ "XShape");
}
......@@ -264,6 +268,10 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
public:
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
CrossEntropyGradientOpBase::InferShape(ctx);
}
protected:
virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
......@@ -295,6 +303,8 @@ class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
"with 'X' except that the last dimension size is 1. It "
"represents the cross entropy loss.");
AddOutput("XShape", "Temporaily variable to save shape and LoD of X.");
AddOutput("MatchX",
"X value that matches label, used for gradient computation.");
AddAttr<int>("ignore_index",
"(int, default -100), Specifies a target value that is"
"ignored and does not contribute to the input gradient."
......@@ -327,7 +337,7 @@ class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("cross_entropy_grad2");
op->SetInput("Label", Input("Label"));
op->SetInput("Y", Output("Y"));
op->SetInput("MatchX", Output("MatchX"));
op->SetInput("XShape", Output("XShape"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
......
......@@ -138,15 +138,48 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
}
};
template <typename T>
struct HardLabelCrossEntropyForwardFunctor {
HardLabelCrossEntropyForwardFunctor(const T* x, T* y, T* match_x,
const int64_t* label,
int64_t ignore_index,
int64_t feature_size)
: x_(x),
y_(y),
match_x_(match_x),
label_(label),
ignore_index_(ignore_index),
feature_size_(feature_size) {}
HOSTDEVICE void operator()(int64_t idx) const {
auto label = label_[idx];
if (label != ignore_index_) {
auto match_x = x_[idx * feature_size_ + label];
y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
match_x_[idx] = match_x;
} else {
y_[idx] = 0;
match_x_[idx] = 0; // any value is ok
}
}
const T* x_;
T* y_;
T* match_x_;
const int64_t* label_;
int64_t ignore_index_;
int64_t feature_size_;
};
template <typename T>
struct HardLabelCrossEntropyBackwardFunctor {
HardLabelCrossEntropyBackwardFunctor(T* dx, const T* y, const T* dy,
HardLabelCrossEntropyBackwardFunctor(T* dx, const T* dy, const T* match_x,
const int64_t* label,
int64_t ignore_index,
int64_t feature_size)
: dx_(dx),
y_(y),
dy_(dy),
match_x_(match_x),
label_(label),
ignore_index_(ignore_index),
feature_size_(feature_size) {}
......@@ -156,15 +189,15 @@ struct HardLabelCrossEntropyBackwardFunctor {
auto col_idx = idx % feature_size_;
auto label = label_[row_idx];
if (label == col_idx && label != ignore_index_) {
dx_[idx] = -dy_[row_idx] * real_exp(y_[row_idx]);
dx_[idx] = -dy_[row_idx] / match_x_[row_idx];
} else {
dx_[idx] = 0;
}
}
T* dx_;
const T* y_;
const T* dy_;
const T* match_x_;
const int64_t* label_;
int64_t ignore_index_;
int64_t feature_size_;
......@@ -174,20 +207,26 @@ template <typename DeviceContext, typename T>
class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x_original = ctx.Input<Tensor>("X");
int rank = x_original->dims().size();
auto x = framework::ReshapeToMatrix(*x_original, rank - 1);
auto label =
framework::ReshapeToMatrix(*ctx.Input<Tensor>("Label"), rank - 1);
auto* x = ctx.Input<Tensor>("X");
auto* label = ctx.Input<Tensor>("Label");
auto* y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
auto* match_x = ctx.Output<Tensor>("MatchX");
auto& x_dims = x->dims();
auto feature_size = x_dims[x_dims.size() - 1];
auto batch_size = framework::product(x->dims()) / feature_size;
auto* p_x = x->data<T>();
auto* p_label = label->data<int64_t>();
auto* p_y = y->mutable_data<T>(ctx.GetPlace());
auto* p_match_x = match_x->mutable_data<T>(ctx.GetPlace());
auto ignore_index = ctx.Attr<int>("ignore_index");
math::CrossEntropyFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), y, &x, &label, false,
ignore_index);
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), batch_size);
for_range(HardLabelCrossEntropyForwardFunctor<T>(
p_x, p_y, p_match_x, p_label, ignore_index, feature_size));
}
};
......@@ -196,13 +235,13 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* y = ctx.Input<Tensor>("Y");
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* match_x = ctx.Input<Tensor>("MatchX");
auto* label = ctx.Input<Tensor>("Label");
auto* p_dx = dx->mutable_data<T>(ctx.GetPlace());
auto* p_y = y->data<T>();
auto* p_dy = dy->data<T>();
auto* p_match_x = match_x->data<T>();
auto* p_label = label->data<int64_t>();
int64_t ignore_index = ctx.Attr<int>("ignore_index");
......@@ -214,7 +253,7 @@ class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
ctx.template device_context<DeviceContext>(),
batch_size * feature_size);
for_range(HardLabelCrossEntropyBackwardFunctor<T>(
p_dx, p_y, p_dy, p_label, ignore_index, feature_size));
p_dx, p_dy, p_match_x, p_label, ignore_index, feature_size));
}
};
......
......@@ -1450,11 +1450,13 @@ def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
helper = LayerHelper('cross_entropy2', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
match_x = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='cross_entropy2',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out],
'MatchX': [match_x],
'XShape': [xshape]},
attrs={'ignore_index': ignore_index})
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册