Unverified commit 595a2c83, authored by dzhwinter, committed by GitHub

explicit gradient of elementwise_add/elementwise_sub (#11970)

* "add gradient register"

* "make some enhance"

* "better format"

* "fix typo"

* "fix reuse"

* "fix get expected kernel"

* "change the mkldnn code"

* "fix mkldnn"

* "fix mkldnn failed test"

* "add comment"
Parent f37f875f
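For context on why the forward tensors can be dropped (a sketch of the underlying math, not part of the diff): for Out = X ± Y the gradients depend only on the upstream gradient,

\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Out}, \qquad
\frac{\partial L}{\partial Y} = \pm \sum_{\text{broadcast axes}} \frac{\partial L}{\partial Out},

so the backward ops of elementwise_add and elementwise_sub only need Out@GRAD (plus Y's shape, to know which axes to reduce when Y was broadcast); X, Y, and Out themselves are never read.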
@@ -40,6 +40,40 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
  return OpProtoAndCheckerMaker::VariableBuilder{output};
}
void OpProtoAndCheckerMaker::Reuse(const std::string& name,
const std::string& reused_name) {
bool found = false;
proto::OpProto::Var* var;
for (auto& var : proto_->inputs()) {
if (var.name() == reused_name) {
found = true;
break;
}
}
PADDLE_ENFORCE(found == true,
"Input/Output name: %s reused_name: %s, one of them is not "
"exists or not matched.",
name, reused_name);
found = false;
for (int i = 0; i < proto_->outputs().size(); ++i) {
var = proto_->mutable_outputs()->Mutable(i);
if (var->name() == name) {
PADDLE_ENFORCE(!var->has_reuse(),
"Output(%s) has been set reused var of %s", name,
var->reuse());
found = true;
var->set_reuse(reused_name);
break;
}
}
PADDLE_ENFORCE(found == true,
"Input/Output name: %s reused_name: %s, one of them is not "
"exists or not matched.",
name, reused_name);
}
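For illustration only, a minimal maker using the new call might look like the sketch below (hypothetical op, not part of this commit); Reuse("Out", "X") records in the OpProto that the output may share the input's memory:

class MyInplaceOpMaker : public paddle::framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "input tensor");
    AddOutput("Out", "output tensor that may share X's buffer");
    Reuse("Out", "X");  // mark Out as an in-place reuse of X in the proto
  }
};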
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
  std::unordered_set<std::string> names;
  auto checker = [&](const std::string& name) {
......
@@ -78,6 +78,8 @@ class OpProtoAndCheckerMaker {
  VariableBuilder AddOutput(const std::string &name,
                            const std::string &comment);
void Reuse(const std::string &name, const std::string &reused_name);
  template <typename T>
  TypedAttrChecker<T> &AddAttr(const std::string &name,
                               const std::string &comment,
......
@@ -49,6 +49,15 @@ TEST(ProtoMaker, DuplicatedInOut) {
}

class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "input of test op");
AddOutput("XOut", "output of test op").Reuse("X");
}
};
class TestInplaceProtoMaker2
: public paddle::framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "input of test op");
@@ -58,12 +67,100 @@ class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
};
TEST(ProtoMaker, InplaceOutput) {
  paddle::framework::proto::OpProto op_proto, op_proto2;
  paddle::framework::OpAttrChecker op_checker;
  TestInplaceProtoMaker proto_maker;
  TestInplaceProtoMaker2 proto_maker2;
  proto_maker(&op_proto, &op_checker);
  ASSERT_THROW(proto_maker2(&op_proto2, &op_checker),
               paddle::platform::EnforceNotMet);
// proto_maker(&op_proto, &op_checker);
// proto_maker.Make();
// ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
// normal reuse
class TestReuseProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "input of test op");
AddInput("Y", "input of test op");
AddOutput("Out", "output of test op");
AddOutput("XOut", "output of test op");
// avoid destructor exception.
// Validate();
TestReuse();
}
virtual void TestReuse() {}
};
// test duplicate reuse error
class TestReuseProtoMaker2 : public TestReuseProtoMaker {
public:
void TestReuse() {
Reuse("Out", "X");
Reuse("Out", "Y");
}
};
// NotExists Input
class TestReuseProtoMaker3 : public TestReuseProtoMaker {
public:
void TestReuse() {
Reuse("Out", "NotExists");
Reuse("XOut", "X");
}
};
// NotExists Output
class TestReuseProtoMaker4 : public TestReuseProtoMaker {
public:
void TestReuse() { Reuse("NotExists", "X"); }
};
TEST(ProtoMaker, Reuse) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
TestReuseProtoMaker proto_maker;
proto_maker(&op_proto, &op_checker);
}
// NOTE(dzhwinter):
// There is a fatal CHECK in the base class destructor, which calls abort()
// instead of throwing an exception. If we throw an exception in Make(), we
// will trigger that CHECK and terminate the tests.
//
// I tried to replace the default CHECK with an exception; however, this is
// still not supported by glog.
// the details:
// https://github.com/google/glog/issues/249
// https://github.com/facebookresearch/TensorComprehensions/issues/351
/*
TEST(ProtoMaker, ReuseWithException) {
paddle::framework::proto::OpProto op_proto2, op_proto3, op_proto4;
paddle::framework::OpAttrChecker op_checker;
TestReuseProtoMaker2 proto_maker2;
TestReuseProtoMaker3 proto_maker3;
TestReuseProtoMaker4 proto_maker4;
EXPECT_THROW(proto_maker2(&op_proto2, &op_checker),
paddle::platform::EnforceNotMet);
EXPECT_THROW(proto_maker3(&op_proto3, &op_checker),
paddle::platform::EnforceNotMet);
EXPECT_THROW(proto_maker4(&op_proto4, &op_checker),
paddle::platform::EnforceNotMet);
}
void FailureFunction() {
throw std::runtime_error("Check failed in destructor.");
// return 0;
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
google::InstallFailureFunction(&FailureFunction);
return RUN_ALL_TESTS();
}
*/
@@ -47,12 +47,12 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
    int axis = ctx.Attr<int>("axis");

    auto x_dims = x->dims();
    auto y_dims_untrimed = y->dims();
    auto z_dims = z->dims();

    // Execute default elementwise_add operator when
    // broadcast operations need to performed.
    if (x_dims != y_dims_untrimed) {
      auto sum_func = [](T a, T b) -> T { return a + b; };

      TransformFunctor<decltype(sum_func), T,
@@ -62,11 +62,11 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
          ctx.template device_context<paddle::platform::CPUDeviceContext>(),
          sum_func);

      axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
      PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                     "Axis should be in range [0, x_dims)");

      auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
      axis = (y_dims.size() == 0) ? x_dims.size() : axis;

      int pre, n, post;
@@ -88,7 +88,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
                     "Wrong layout/format set for Y tensor");

    std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
    std::vector<int> src_y_tz = framework::vectorize2int(y_dims_untrimed);
    std::vector<int> dst_tz = framework::vectorize2int(z_dims);

    std::vector<memory::primitive_desc> srcs_pd;
@@ -142,36 +142,39 @@ class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    int axis = ctx.Attr<int>("axis");
    // skip out, x, y,
    // dout length is larger or equal than dx, dy.
    auto* out = dout;
    auto *x = dout, *y = dout;

    auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
      in->set_layout(DataLayout::kMKLDNN);
      in->set_format(out->format());
    };

    if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
      if (dx->dims() == dy->dims()) {
        auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
        if (dx) {
          blas.VCOPY(dout->numel(), dout->data<T>(),
                     dx->mutable_data<T>(ctx.GetPlace()));
          set_mkldnn_format(dx, dout);
        }

        if (dy) {
          blas.VCOPY(dout->numel(), dout->data<T>(),
                     dy->mutable_data<T>(ctx.GetPlace()));
          set_mkldnn_format(dy, dout);
        }
      }
    } else {
      // Execute default kernel when broadcast is needed
      ElemwiseExplicitGradCompute<paddle::platform::CPUDeviceContext, T,
                                  IdentityGrad<T>, IdentityGrad<T>>(
          ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
          IdentityGrad<T>());
    }
......
@@ -15,7 +15,9 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise_op.h"

namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out",
                              "X");

REGISTER_OP_CPU_KERNEL(
    elementwise_add,
    ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -95,9 +95,10 @@ void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
                                  framework::Tensor* dy) {
  int axis = ctx.Attr<int>("axis");

  ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
                              IdentityGrad<T>>(ctx, *x, *y, *out, *dout, axis,
                                               dx, dy, IdentityGrad<T>(),
                                               IdentityGrad<T>());
}

template <typename DeviceContext, typename T>
@@ -140,14 +141,15 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    // skip out, x, y
    auto* out = dout;
    auto *x = dout, *y = dout;

    if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
        dy != nullptr && (dx->dims() == dy->dims())) {
      elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
    } else {
      default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
......
@@ -15,7 +15,9 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise_op.h"

namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");

REGISTER_OP_CPU_KERNEL(
    elementwise_div,
    ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -78,7 +78,9 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
  void Make() final {
    AddInput("X", "(Tensor), The first input tensor of elementwise op.");
    AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
    // AddOutput("SavedShape", "(Tensor), save X, Y shape for grad to save
    // memory.").AsIntermediate();
    AddOutput("Out", "The output of elementwise op.");
    AddAttr<int>("axis",
                 "(int, default -1). The start dimension index "
                 "for broadcasting Y onto X.")
@@ -125,11 +127,13 @@ But the output only shares the LoD information with the input $X$.
)DOC",
                           GetName(), GetEquation()));
    SetReuse();
  }

 protected:
  virtual std::string GetName() const = 0;
  virtual std::string GetEquation() const = 0;
  virtual void SetReuse() {}
};

class ElementwiseOpGrad : public framework::OperatorWithKernel {
@@ -162,8 +166,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type = framework::ToDataType(
        ctx.Input<Tensor>(framework::GradVarName("Out"))->type());

#ifdef PADDLE_WITH_MKLDNN
    if (platform::CanMKLDNNBeUsed(ctx)) {
@@ -175,9 +179,58 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};
// For Add, Sub op, the X, Out is not needed.
class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
public:
using operators::ElementwiseOpGrad::ElementwiseOpGrad;
using operators::ElementwiseOpGrad::GetExpectedKernelType;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
ctx->SetOutputDim(x_grad_name, out_dims);
}
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(y_grad_name)) {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
auto y_dims = ctx->GetInputDim("Y");
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
}  // namespace operators
}  // namespace paddle
/*
*/
#define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name) \
class kernel_type##GradMaker \
: public paddle::framework::SingleGradOpDescMaker { \
public: \
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
\
protected: \
std::unique_ptr<paddle::framework::OpDesc> Apply() const override { \
auto* op = new paddle::framework::OpDesc(); \
op->SetType(#kernel_type "_grad"); \
op->SetInput("Y", Input("Y")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \
OutputGrad("Out")); \
op->SetAttrMap(Attrs()); \
op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X")); \
op->SetOutput(::paddle::framework::GradVarName("Y"), InputGrad("Y")); \
return std::unique_ptr<::paddle::framework::OpDesc>(op); \
} \
}
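As a sketch of what this maker produces at backward-graph construction time (illustrative, not generated output):

// For elementwise_add, the generated AddGradMaker emits an OpDesc roughly like:
//   type:    "elementwise_add_grad"
//   inputs:  Y, Out@GRAD              (X and Out are intentionally omitted)
//   outputs: X@GRAD, Y@GRAD
//   attrs:   copied from the forward op (e.g. axis)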
#define REGISTER_ELEMWISE_OP(op_type, op_name, equation)                \
  class __ElemwiseOp##op_type##Maker__                                  \
      : public ::paddle::operators::ElementwiseOpMaker {                \
@@ -190,3 +243,18 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
                    ::paddle::operators::ElementwiseOpInferVarType,     \
                    ::paddle::framework::DefaultGradOpDescMaker<true>); \
  REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
#define REGISTER_ELEMWISE_EXPLICIT_OP(op_type, op_name, equation, ...) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
virtual void SetReuse() { Reuse(__VA_ARGS__); } \
}; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
op_type##GradMaker); \
REGISTER_OPERATOR(op_type##_grad, \
::paddle::operators::ElementwiseOpExplicitGrad)
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <glog/logging.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
@@ -65,17 +67,21 @@ inline void get_mid_dims(const framework::DDim& x_dims,
  }
}

inline framework::DDim trim_trailing_singular_dims(
    const framework::DDim& dims) {
  // Remove trailing dimensions of size 1 for y
  auto actual_dims_size = dims.size();
  for (; actual_dims_size != 0; --actual_dims_size) {
    if (dims[actual_dims_size - 1] != 1) break;
  }

  std::vector<int> trim_dims;
  trim_dims.resize(actual_dims_size);
  for (int i = 0; i < actual_dims_size; ++i) {
    trim_dims[i] = dims[i];
  }
  framework::DDim actual_dims = framework::make_ddim(trim_dims);
  return actual_dims;
}
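A quick usage sketch of the new return-by-value signature (hypothetical shapes, not part of the diff):

// y originally has shape {3, 4, 1, 1}; the trailing 1s are dropped and the
// argument itself is left untouched.
framework::DDim y_dims_untrimed = framework::make_ddim({3, 4, 1, 1});
framework::DDim y_dims = trim_trailing_singular_dims(y_dims_untrimed);  // {3, 4}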
template <typename T, typename DeviceContext>
@@ -456,6 +462,71 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
#endif
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
const framework::DDim& y_dim, const framework::Tensor& x,
const framework::Tensor& y, const framework::Tensor& out,
const framework::Tensor& dout, int axis, framework::Tensor* dx,
framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
size_t N = static_cast<size_t>(framework::product(x_dim));
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
const framework::DDim& y_dim_untrimed, const framework::Tensor& x,
const framework::Tensor& y, const framework::Tensor& out,
const framework::Tensor& dout, int axis, framework::Tensor* dx,
framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
if (post == 1) {
int h = pre;
int w = n;
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
ElemwiseGradBroadcast1CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
ElemwiseGradBroadcast1CPU(
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op,
dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
ElemwiseGradBroadcast2CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
ElemwiseGradBroadcast2CPU(
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
}
}
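For reference, the pre/n/post decomposition computed by get_mid_dims can be read off a worked example (hypothetical shapes): with x_dim = {2, 3, 4, 5}, trimmed y_dim = {3, 4} and axis = 1, we get pre = 2 (the product of x's dimensions before axis), n = 12 (the product of y's dimensions) and post = 5 (the product of x's dimensions after the broadcast block); the post == 1 branch above is the special case where y aligns with the trailing dimensions of x.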
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                         const framework::Tensor& x, const framework::Tensor& y,
@@ -463,63 +534,50 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                         const framework::Tensor& dout, int axis,
                         framework::Tensor* dx, framework::Tensor* dy,
                         DX_OP dx_op, DY_OP dy_op) {
  const framework::DDim x_dim = x.dims();
  const framework::DDim y_dim = y.dims();
  if (x.dims() == y.dims()) {
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {  // Y is a scalar
    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  }
}

// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
// explicit gradient can cut off X, Y, Out from gradient op
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
                                 const framework::Tensor& x,
                                 const framework::Tensor& y,
                                 const framework::Tensor& out,
                                 const framework::Tensor& dout, int axis,
                                 framework::Tensor* dx, framework::Tensor* dy,
                                 DX_OP dx_op, DY_OP dy_op) {
  if (dy == nullptr) {
    const framework::DDim dx_dims = dout.dims();
    auto dy_dims = dx_dims;
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {
    if (dout.dims() == dy->dims()) {
      const framework::DDim dx_dims = dout.dims();
      const framework::DDim dy_dims = dy->dims();
      ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    } else {  // Y is a scalar
      auto dx_dims = dout.dims();
      const framework::DDim dy_dims = dy->dims();
      ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    }
  }
}
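To make the dout-as-placeholder trick concrete (restating the math, not the diff): with the IdentityGrad functors used by elementwise_add, both dx_op and dy_op simply forward dout, so the call reduces to

dx_i = dout_i, \qquad dy_j = \sum_{i \in B(j)} dout_i,

where B(j) is the set of output positions that y_j was broadcast to; elementwise_sub only flips the sign of the Y term via SubGradDY. Neither x, y nor out is ever dereferenced, which is why passing dout in their place is safe.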
// Deprecated
template <typename DeviceContext, typename T, typename functor,
          typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
@@ -547,7 +605,7 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  trim_trailing_singular_dims(y_dims);
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;

  int pre, n, post;
@@ -574,19 +632,19 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
      x, y, z, ctx.template device_context<DeviceContext>(), func);

  auto x_dims = x->dims();
  auto y_dims_untrimed = y->dims();
  PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
                    "Rank of first input must >= rank of second input.");

  if (x_dims == y_dims_untrimed) {
    functor.Run();
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                 "Axis should be in range [0, x_dims)");
  auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;

  int pre, n, post;
......
@@ -15,7 +15,10 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise_sub_op.h"
#include "paddle/fluid/operators/elementwise_op.h"

namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub);
REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_sub, "Sub", "Out = X - Y", "Out",
                              "X");

REGISTER_OP_CPU_KERNEL(
    elementwise_sub,
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@@ -55,14 +55,15 @@ class ElementwiseSubGradKernel : public framework::OpKernel<T> {
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    int axis = ctx.Attr<int>("axis");
    // skip out, x, y
    auto* out = dout;
    auto *x = dout, *y = dout;
    ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
        ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
  }
};
......
@@ -137,7 +137,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
                      ctx->GetInputDim(framework::GradVarName("Out")),
                      "Input(Out) and its gradients should have a same shape.");

    ctx->SetOutputDim(framework::GradVarName("X"),
                      ctx->GetInputDim(framework::GradVarName("Out")));
  }

 protected:
@@ -160,8 +161,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
      layout_ = framework::DataLayout::kMKLDNN;
    }
#endif
    auto input_data_type = framework::ToDataType(
        ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
    if (input_data_type == framework::proto::VarType::FP16) {
      PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                     "float16 can only be used on GPU place");
@@ -172,13 +173,31 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
  }
};
class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("softmax_grad");
op->SetInput("Out", Output("Out"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
                  ops::SoftmaxOpGradMaker);
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(
    softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
......
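The softmax registration change above relies on the standard identity that the softmax gradient needs only the forward output and the upstream gradient (restated here for reference, not part of the diff):

\frac{\partial L}{\partial x_i} = y_i \left( \frac{\partial L}{\partial y_i} - \sum_j y_j \frac{\partial L}{\partial y_j} \right), \qquad y = \mathrm{softmax}(x),

which is why SoftmaxOpGradMaker can feed the grad op Out and Out@GRAD instead of X.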
@@ -20,8 +20,8 @@ class TestElementwiseOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_sub"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [2, 3]).astype("float32")
        }
        self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
......