Commit b26e9bd2 authored by: S sneaxiy

refine code

test=develop

Parent: cfd012e2
@@ -16,46 +16,22 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include "paddle/fluid/operators/cross_entropy_op_base.h"

 namespace paddle {
 namespace operators {

-class CrossEntropyOp2 : public framework::OperatorWithKernel {
+class CrossEntropyOp2 : public CrossEntropyOpBase {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
+  using CrossEntropyOpBase::CrossEntropyOpBase;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
+    CrossEntropyOpBase::InferShape(ctx);
     PADDLE_ENFORCE(ctx->HasOutput("XShape"),
                    "Output(XShape) should be not null.");
     auto x_dims = ctx->GetInputDim("X");
-    auto label_dims = ctx->GetInputDim("Label");
-    int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
-                      "Input(X) and Input(Label) shall have the same rank.");
-    bool check = true;
-    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
-                                framework::product(label_dims) <= 0)) {
-      check = false;
-    }
-    if (check) {
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(label_dims, 0, rank - 1),
-                        "Input(X) and Input(Label) shall have the same shape "
-                        "except the last dimension.");
-    }
-    PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
-                      "Last dimension of Input(Label) should be 1.");
-    auto y_dims = x_dims;
-    y_dims[rank - 1] = 1;
-    ctx->SetOutputDim("Y", y_dims);
-    ctx->ShareLoD("X", /*->*/ "Y");
     auto x_dims_vec = framework::vectorize(x_dims);
     x_dims_vec.push_back(0);
     ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
@@ -63,73 +39,25 @@ class CrossEntropyOp2 : public framework::OperatorWithKernel {
   }

  protected:
-  // Explicitly set that the data type of computation kernel of cross_entropy
-  // is determined by its input "X".
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
+  bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
+    return false;
   }
 };

-class CrossEntropyGradientOp2 : public framework::OperatorWithKernel {
+class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
+  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;

-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("XShape"),
-                   "Input(XShape) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input(Y@GRAD) shoudl be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output(X@GRAD) should be not null.");
-    auto x_shapes = ctx->GetInputDim("XShape");
-    framework::DDim x_dims(x_shapes.Get(), x_shapes.size() - 1);
-    auto label_dims = ctx->GetInputDim("Label");
-    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
-    int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
-                      "Input(Y@Grad) and Input(X) should have the same rank.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
-                      "Input(Label) and Input(X) should have the same rank.");
-    bool check = true;
-    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
-                                framework::product(label_dims) <= 0)) {
-      check = false;
-    }
-    if (check) {
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(label_dims, 0, rank - 1),
-                        "The Input(X) and Input(Label) should have the same "
-                        "shape except the last dimension.");
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(dy_dims, 0, rank - 1),
-                        "The Input(X) and Input(Y@Grad) should have the same "
-                        "shape except the last dimension.");
-    }
-    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
-                      "The last dimension of Input(Y@Grad) should be 1.");
-    PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
-                      "Last dimension of Input(Label) should be 1.");
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("XShape", framework::GradVarName("X"));
-  }
-
  protected:
-  // Explicitly set that the data type of computation kernel of cross_entropy
-  // is determined by its input "X".
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
-        ctx.device_context());
+  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
+    auto x_shape = ctx->GetInputDim("XShape");
+    return framework::DDim(x_shape.Get(), x_shape.size() - 1);
   }
+
+  virtual const char* VarNameWithXLoD() const { return "XShape"; }
+
+  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
+    return false;
+  }
 };
@@ -156,7 +84,7 @@ class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
                "Only valid if soft_label is set to False")
       .SetDefault(-100);
   AddComment(R"DOC(
-CrossEntropy Operator.
+Hard-label CrossEntropy Operator.

 The input 'X' and 'Label' will first be logically flattened to 2-D matrixs.
 The matrix's second dimension(row length) is as same as the original last
@@ -173,15 +101,6 @@ or not. But the output only shares the LoD information with input X.
   }
 };

-class CrossEntropyOpInferVarType2
-    : public framework::PassInDtypeAndVarTypeToOutput {
- protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
-      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
-  }
-};
-
 class CrossEntropyGradOpMaker2 : public framework::SingleGradOpDescMaker {
  public:
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -207,7 +126,7 @@ namespace ops = paddle::operators;
 using CPUCtx = paddle::platform::CPUDeviceContext;

 REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
-                  ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType2,
+                  ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
                   ops::CrossEntropyGradOpMaker2);
 REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
 REGISTER_OP_CPU_KERNEL(cross_entropy2,
......
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <cmath>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -26,81 +27,6 @@ namespace operators {
 using Tensor = framework::Tensor;

-HOSTDEVICE inline platform::float16 RealLog(platform::float16 x) {
-#ifdef __NVCC__
-  return static_cast<platform::float16>(logf(static_cast<float>(x)));
-#else
-  return static_cast<platform::float16>(std::log(static_cast<float>(x)));
-#endif
-}
-
-HOSTDEVICE inline float RealLog(float x) {
-#ifdef __NVCC__
-  return logf(x);
-#else
-  return std::log(x);
-#endif
-}
-
-HOSTDEVICE inline double RealLog(double x) {
-#ifdef __NVCC__
-  return log(x);
-#else
-  return std::log(x);
-#endif
-}
-
-HOSTDEVICE inline platform::float16 RealExp(platform::float16 x) {
-#ifdef __NVCC__
-  return static_cast<platform::float16>(expf(static_cast<float>(x)));
-#else
-  return static_cast<platform::float16>(std::exp(static_cast<float>(x)));
-#endif
-}
-
-HOSTDEVICE inline float RealExp(float x) {
-#ifdef __NVCC__
-  return expf(x);
-#else
-  return std::exp(x);
-#endif
-}
-
-HOSTDEVICE inline double RealExp(double x) {
-#ifdef __NVCC__
-  return exp(x);
-#else
-  return std::exp(x);
-#endif
-}
-
-template <typename T>
-struct CrossEntropyForwardFunctor {
-  CrossEntropyForwardFunctor(const T *x, T *y, const int64_t *label,
-                             int64_t ignore_index, int64_t feature_size)
-      : x_(x),
-        y_(y),
-        label_(label),
-        ignore_index_(ignore_index),
-        feature_size_(feature_size) {}
-
-  HOSTDEVICE void operator()(int64_t row_idx) const {
-    auto col_idx = label_[row_idx];
-    if (col_idx != ignore_index_) {
-      y_[row_idx] = -math::TolerableValue<T>()(
-          RealLog(x_[row_idx * feature_size_ + col_idx]));
-    } else {
-      y_[row_idx] = 0;
-    }
-  }
-
-  const T *x_;
-  T *y_;
-  const int64_t *label_;
-  int64_t ignore_index_;
-  int64_t feature_size_;
-};
-
 template <typename T>
 struct CrossEntropyBackwardFunctor {
   CrossEntropyBackwardFunctor(T *dx, const T *y, const T *dy,
@@ -118,7 +44,7 @@ struct CrossEntropyBackwardFunctor {
     auto col_idx = idx % feature_size_;
     auto label = label_[row_idx];
     if (label == col_idx && label != ignore_index_) {
-      dx_[idx] = -dy_[row_idx] * RealExp(y_[row_idx]);
+      dx_[idx] = -dy_[row_idx] * real_exp(y_[row_idx]);
     } else {
       dx_[idx] = 0;
     }
@@ -136,24 +62,20 @@ template <typename DeviceContext, typename T>
 class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *x = ctx.Input<Tensor>("X");
-    auto *label = ctx.Input<Tensor>("Label");
-    auto *y = ctx.Output<Tensor>("Y");
-    auto *p_y = y->mutable_data<T>(ctx.GetPlace());
-    auto *p_x = x->data<T>();
-    auto *p_label = label->data<int64_t>();
+    auto *x_original = ctx.Input<Tensor>("X");
+    int rank = x_original->dims().size();

-    int rank = x->dims().size();
-    int64_t feature_size = x->dims()[rank - 1];
-    int64_t batch_size = framework::product(x->dims()) / feature_size;
-    int64_t ignore_index = ctx.Attr<int>("ignore_index");
+    auto x = framework::ReshapeToMatrix(*x_original, rank - 1);
+    auto label =
+        framework::ReshapeToMatrix(*ctx.Input<Tensor>("Label"), rank - 1);
+    auto *y = ctx.Output<Tensor>("Y");
+    y->mutable_data<T>(ctx.GetPlace());

-    platform::ForRange<DeviceContext> for_range(
-        ctx.template device_context<DeviceContext>(), batch_size);
-    for_range(CrossEntropyForwardFunctor<T>(p_x, p_y, p_label, ignore_index,
-                                            feature_size));
+    auto ignore_index = ctx.Attr<int>("ignore_index");
+
+    math::CrossEntropyFunctor<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), y, &x, &label, false,
+        ignore_index);
   }
 };
......
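Note: the rewritten Compute above now flattens its inputs to 2-D and reuses the shared math::CrossEntropyFunctor instead of a hand-rolled per-row functor. A sketch of the flattening semantics implied by the call sites (the actual helper lives in the Paddle framework headers):

// Illustration only, based on the call ReshapeToMatrix(*x_original, rank - 1)
// above: the leading rank - 1 dimensions are folded into rows, keeping the
// last (class) dimension as columns.
//
//   x:                          [d0, d1, ..., d_{n-2}, d_{n-1}]
//   ReshapeToMatrix(x, n - 1):  [d0 * d1 * ... * d_{n-2}, d_{n-1}]
//
// CrossEntropyFunctor therefore always receives a [batch, num_classes]
// matrix, with the `false` argument selecting the hard-label path.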
@@ -14,128 +14,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/cross_entropy_op.h"
 #include <string>
+#include "paddle/fluid/operators/cross_entropy_op_base.h"

 namespace paddle {
 namespace operators {
-class CrossEntropyOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto label_dims = ctx->GetInputDim("Label");
-    int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
-                      "Input(X) and Input(Label) shall have the same rank.");
-    bool check = true;
-    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
-                                framework::product(label_dims) <= 0)) {
-      check = false;
-    }
-    if (check) {
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(label_dims, 0, rank - 1),
-                        "Input(X) and Input(Label) shall have the same shape "
-                        "except the last dimension.");
-    }
-    if (ctx->Attrs().Get<bool>("soft_label")) {
-      if (check) {
-        PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
-                          "If Attr(soft_label) == true, the last dimension of "
-                          "Input(X) and Input(Label) should be equal.");
-      }
-    } else {
-      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
-                        "If Attr(softLabel) == false, the last dimension of "
-                        "Input(Label) should be 1.");
-    }
-
-    auto y_dims = x_dims;
-    y_dims[rank - 1] = 1;
-    ctx->SetOutputDim("Y", y_dims);
-    ctx->ShareLoD("X", /*->*/ "Y");
-  }
-
- protected:
-  // Explicitly set that the data type of computation kernel of cross_entropy
-  // is determined by its input "X".
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
-  }
-};
-
-class CrossEntropyGradientOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input(Y@GRAD) shoudl be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output(X@GRAD) should be not null.");
-
-    auto x_dims = ctx->GetInputDim("X");
-    auto label_dims = ctx->GetInputDim("Label");
-    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
-    int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
-                      "Input(Y@Grad) and Input(X) should have the same rank.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
-                      "Input(Label) and Input(X) should have the same rank.");
-    bool check = true;
-    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
-                                framework::product(label_dims) <= 0)) {
-      check = false;
-    }
-    if (check) {
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(label_dims, 0, rank - 1),
-                        "The Input(X) and Input(Label) should have the same "
-                        "shape except the last dimension.");
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(dy_dims, 0, rank - 1),
-                        "The Input(X) and Input(Y@Grad) should have the same "
-                        "shape except the last dimension.");
-    }
-    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
-                      "The last dimension of Input(Y@Grad) should be 1.");
-    if (ctx->Attrs().Get<bool>("soft_label")) {
-      if (check) {
-        PADDLE_ENFORCE_EQ(
-            x_dims[rank - 1], label_dims[rank - 1],
-            "When Attr(soft_label) == true, the last dimension of "
-            "Input(X) and Input(Label) should be equal.");
-      }
-    } else {
-      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
-                        "When Attr(soft_label) == false, the last dimension of "
-                        "Input(Label) should be 1.");
-    }
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("X", framework::GradVarName("X"));
-  }
-
- protected:
-  // Explicitly set that the data type of computation kernel of cross_entropy
-  // is determined by its input "X".
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
-  }
-};
-
 class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -200,22 +83,24 @@ or not. But the output only shares the LoD information with input X.
   }
 };
-class CrossEntropyOpInferVarType
-    : public framework::PassInDtypeAndVarTypeToOutput {
- protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
-      const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
+class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
+ public:
+  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    CrossEntropyGradientOpBase::InferShape(ctx);
   }
 };

 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 using CPUCtx = paddle::platform::CPUDeviceContext;

-REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
-                  ops::CrossEntropyOpInferVarType,
+REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase,
+                  ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
 REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class CrossEntropyOpBase : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");

    auto x_dims = ctx->GetInputDim("X");
    auto label_dims = ctx->GetInputDim("Label");
    int rank = x_dims.size();
    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
                      "Input(X) and Input(Label) shall have the same rank.");
    bool check = true;
    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
                                framework::product(label_dims) <= 0)) {
      check = false;
    }
    if (check) {
      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                        framework::slice_ddim(label_dims, 0, rank - 1),
                        "Input(X) and Input(Label) shall have the same shape "
                        "except the last dimension.");
    }

    if (IsSoftLabel(ctx)) {
      if (check) {
        PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
                          "If Attr(soft_label) == true, the last dimension of "
                          "Input(X) and Input(Label) should be equal.");
      }
    } else {
      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
                        "If Attr(soft_label) == false, the last dimension of "
                        "Input(Label) should be 1.");
    }

    auto y_dims = x_dims;
    y_dims[rank - 1] = 1;
    ctx->SetOutputDim("Y", y_dims);
    ctx->ShareLoD("X", /*->*/ "Y");
  }

 protected:
  // Explicitly set that the data type of computation kernel of cross_entropy
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.device_context());
  }

  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
    return ctx->Attrs().Get<bool>("soft_label");
  }
};
class CrossEntropyOpInferVarType
    : public framework::PassInDtypeAndVarTypeToOutput {
 protected:
  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
      const override {
    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
  }
};
class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                   "Input(Y@GRAD) should be not null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should be not null.");

    auto x_dims = GetXDim(ctx);
    auto label_dims = ctx->GetInputDim("Label");
    auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
    int rank = x_dims.size();
    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
                      "Input(Y@Grad) and Input(X) should have the same rank.");
    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
                      "Input(Label) and Input(X) should have the same rank.");
    bool check = true;
    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
                                framework::product(label_dims) <= 0)) {
      check = false;
    }
    if (check) {
      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                        framework::slice_ddim(label_dims, 0, rank - 1),
                        "The Input(X) and Input(Label) should have the same "
                        "shape except the last dimension.");
      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                        framework::slice_ddim(dy_dims, 0, rank - 1),
                        "The Input(X) and Input(Y@Grad) should have the same "
                        "shape except the last dimension.");
    }
    if (IsSoftLabel(ctx)) {
      if (check) {
        PADDLE_ENFORCE_EQ(
            x_dims[rank - 1], label_dims[rank - 1],
            "When Attr(soft_label) == true, the last dimension of "
            "Input(X) and Input(Label) should be equal.");
      }
    } else {
      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
                        "When Attr(soft_label) == false, the last dimension of "
                        "Input(Label) should be 1.");
    }
    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
                      "The last dimension of Input(Y@Grad) should be 1.");
    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
    ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
  }

 protected:
  // Explicitly set that the data type of computation kernel of cross_entropy
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
        ctx.device_context());
  }

  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
    return ctx->GetInputDim("X");
  }

  virtual const char* VarNameWithXLoD() const { return "X"; }

  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
    return ctx->Attrs().Get<bool>("soft_label");
  }
};
} // namespace operators
} // namespace paddle
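Note: the new base classes follow a template-method layout: CrossEntropyOpBase::InferShape and CrossEntropyGradientOpBase::InferShape hold all shared shape checks, and concrete operators only override the virtual hooks (IsSoftLabel, GetXDim, VarNameWithXLoD), as CrossEntropyOp2 and CrossEntropyGradientOp2 do above. A minimal sketch of a hypothetical further subclass (not part of this commit) showing the intended extension surface:

// Hypothetical example for illustration: a gradient op that always takes the
// soft-label path. Only the hook is overridden; every shape check stays in
// CrossEntropyGradientOpBase::InferShape.
class AlwaysSoftCrossEntropyGradOp : public CrossEntropyGradientOpBase {
 public:
  using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;

 protected:
  bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
    return true;
  }
};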
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/expand_op.h"
+#include <memory>
 #include <vector>

 namespace paddle {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "math.h" // NOLINT
namespace paddle {
namespace operators {
inline HOSTDEVICE platform::float16 real_exp(platform::float16 x) {
  return static_cast<platform::float16>(::expf(static_cast<float>(x)));
}

inline HOSTDEVICE float real_exp(float x) { return ::expf(x); }

inline HOSTDEVICE double real_exp(double x) { return ::exp(x); }

inline HOSTDEVICE platform::float16 real_log(platform::float16 x) {
  return static_cast<platform::float16>(::logf(static_cast<float>(x)));
}

inline HOSTDEVICE float real_log(float x) { return ::logf(x); }

inline HOSTDEVICE double real_log(double x) { return ::log(x); }
} // namespace operators
} // namespace paddle
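Note: this new header centralizes the real_exp/real_log overloads that several kernels below previously redefined locally; each overload preserves its argument type, with float16 going through a float round-trip. A small usage sketch with an assumed functor name, not taken from the commit:

#include "paddle/fluid/operators/math.h"

namespace paddle {
namespace operators {

// Compiles for both CPU and CUDA via HOSTDEVICE; overload resolution picks
// ::logf for float, ::log for double, and the float16 wrapper otherwise.
template <typename T>
struct NegLogFunctor {
  HOSTDEVICE T operator()(T prob) const { return -real_log(prob); }
};

}  // namespace operators
}  // namespace paddle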
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -20,17 +21,6 @@ namespace paddle {
 namespace operators {
 namespace math {

-namespace {
-__device__ __forceinline__ float real_log(float x) { return logf(x); }
-__device__ __forceinline__ double real_log(double x) { return log(x); }
-__device__ __forceinline__ platform::float16 real_log(
-    const platform::float16& val) {
-  return static_cast<platform::float16>(logf(static_cast<float>(val)));
-}
-
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
                                    const int N, const int D,
@@ -61,7 +51,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
     Y[blockIdx.x] = -val;
   }
 }
-}  // namespace

 template <typename T>
 class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
......
@@ -15,13 +15,12 @@ limitations under the License. */
 #pragma once
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/platform/for_range.h"

 namespace paddle {
 namespace operators {

-static HOSTDEVICE float real_exp(float x) { return expf(x); }
-static HOSTDEVICE float real_exp(double x) { return exp(x); }
-
 template <typename T>
 struct SeluFunctor {
   SeluFunctor(const T* x_data_ptr, float alpha, float scale, T* y_data_ptr)
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <algorithm>
 #include <cub/cub.cuh>  // NOLINT
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"

 namespace paddle {
@@ -21,9 +22,6 @@ namespace operators {
 using LoDTensor = framework::LoDTensor;

-__device__ __forceinline__ float real_exp(float x) { return expf(x); }
-__device__ __forceinline__ double real_exp(double x) { return exp(x); }
-
 template <typename T, int BlockDim>
 using BlockReduce = cub::BlockReduce<T, BlockDim>;
......
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "cub/cub.cuh"
+#include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/hostdevice.h"
@@ -21,11 +22,6 @@ namespace operators {
 using Tensor = framework::Tensor;

-static HOSTDEVICE float real_exp(float x) { return expf(x); }
-static HOSTDEVICE float real_exp(double x) { return exp(x); }
-static HOSTDEVICE float real_log(float x) { return logf(x); }
-static HOSTDEVICE float real_log(double x) { return log(x); }
-
 static constexpr int kNumCUDAThreads = 512;
 static constexpr int kNumMaxinumNumBlocks = 4096;
......
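Note: besides removing duplication, the switch to the shared header fixes a latent precision bug visible in the deleted helpers above: several of them declared a float return type for a double argument, silently truncating double-precision results.

// Removed per-file helper (note the float return for a double argument):
static HOSTDEVICE float real_exp(double x) { return exp(x); }

// Shared replacement in paddle/fluid/operators/math.h keeps the type:
inline HOSTDEVICE double real_exp(double x) { return ::exp(x); }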