Unverified commit 595a2c83, authored by dzhwinter, committed by GitHub

explicit gradient of elementwise_add/elementwise_sub (#11970)

* "add gradient register"

* "make some enhance"

* "better format"

* "fix typo"

* "fix reuse"

* "fix get expected kernel"

* "change the mkldnn code"

* "fix mkldnn"

* "fix mkldnn failed test"

* "add comment"
Parent f37f875f
@@ -40,6 +40,40 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
   return OpProtoAndCheckerMaker::VariableBuilder{output};
 }
 
+void OpProtoAndCheckerMaker::Reuse(const std::string& name,
+                                   const std::string& reused_name) {
+  bool found = false;
+  proto::OpProto::Var* var;
+
+  for (auto& var : proto_->inputs()) {
+    if (var.name() == reused_name) {
+      found = true;
+      break;
+    }
+  }
+  PADDLE_ENFORCE(found == true,
+                 "Input/Output name: %s reused_name: %s, one of them is not "
+                 "exists or not matched.",
+                 name, reused_name);
+
+  found = false;
+  for (int i = 0; i < proto_->outputs().size(); ++i) {
+    var = proto_->mutable_outputs()->Mutable(i);
+    if (var->name() == name) {
+      PADDLE_ENFORCE(!var->has_reuse(),
+                     "Output(%s) has been set reused var of %s", name,
+                     var->reuse());
+      found = true;
+      var->set_reuse(reused_name);
+      break;
+    }
+  }
+  PADDLE_ENFORCE(found == true,
+                 "Input/Output name: %s reused_name: %s, one of them is not "
+                 "exists or not matched.",
+                 name, reused_name);
+}
+
 void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
   std::unordered_set<std::string> names;
   auto checker = [&](const std::string& name) {
......
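For orientation, a minimal sketch of the contract the new Reuse() hook enforces (the maker class name here is hypothetical; the pattern mirrors the test makers further down): reused_name must name an existing input, and name must name an existing output that has not already been marked reused, otherwise one of the two PADDLE_ENFORCE checks above fires.

    // Hypothetical maker illustrating the Reuse() contract added above.
    // Include path assumed from this commit's files.
    #include "paddle/fluid/framework/op_proto_maker.h"

    class MyAddMaker : public paddle::framework::OpProtoAndCheckerMaker {
     public:
      void Make() {
        AddInput("X", "first operand");
        AddInput("Y", "second operand");
        AddOutput("Out", "result");
        // Mark Out as reusing X's buffer (in-place). Both names must exist:
        // Reuse("Out", "Z") or Reuse("Res", "X") would trip PADDLE_ENFORCE,
        // and calling Reuse("Out", ...) twice trips the has_reuse() check.
        Reuse("Out", "X");
      }
    };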
@@ -78,6 +78,8 @@ class OpProtoAndCheckerMaker {
   VariableBuilder AddOutput(const std::string &name,
                             const std::string &comment);
 
+  void Reuse(const std::string &name, const std::string &reused_name);
+
   template <typename T>
   TypedAttrChecker<T> &AddAttr(const std::string &name,
                                const std::string &comment,
......
@@ -49,6 +49,15 @@ TEST(ProtoMaker, DuplicatedInOut) {
 }
 
 class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "input of test op");
+    AddOutput("XOut", "output of test op").Reuse("X");
+  }
+};
+
+class TestInplaceProtoMaker2
+    : public paddle::framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
     AddInput("X", "input of test op");
@@ -58,12 +67,100 @@ class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
 };
 
 TEST(ProtoMaker, InplaceOutput) {
-  paddle::framework::proto::OpProto op_proto;
+  paddle::framework::proto::OpProto op_proto, op_proto2;
   paddle::framework::OpAttrChecker op_checker;
   TestInplaceProtoMaker proto_maker;
-  ASSERT_THROW(proto_maker(&op_proto, &op_checker),
-               paddle::platform::EnforceNotMet);
+  TestInplaceProtoMaker2 proto_maker2;
+  proto_maker(&op_proto, &op_checker);
+  ASSERT_THROW(proto_maker2(&op_proto2, &op_checker),
+               paddle::platform::EnforceNotMet);
+  // proto_maker(&op_proto, &op_checker);
+  // proto_maker.Make();
+  // ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
 }
+
+// normal reuse
+class TestReuseProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "input of test op");
+    AddInput("Y", "input of test op");
+    AddOutput("Out", "output of test op");
+    AddOutput("XOut", "output of test op");
+    // avoid destructor exception.
+    // Validate();
+    TestReuse();
+  }
+
+  virtual void TestReuse() {}
+};
+
+// test duplicate reuse error
+class TestReuseProtoMaker2 : public TestReuseProtoMaker {
+ public:
+  void TestReuse() {
+    Reuse("Out", "X");
+    Reuse("Out", "Y");
+  }
+};
+
+// NotExists Input
+class TestReuseProtoMaker3 : public TestReuseProtoMaker {
+ public:
+  void TestReuse() {
+    Reuse("Out", "NotExists");
+    Reuse("XOut", "X");
+  }
+};
+
+// NotExists Output
+class TestReuseProtoMaker4 : public TestReuseProtoMaker {
+ public:
+  void TestReuse() { Reuse("NotExists", "X"); }
+};
+
+TEST(ProtoMaker, Reuse) {
+  paddle::framework::proto::OpProto op_proto;
+  paddle::framework::OpAttrChecker op_checker;
+  TestReuseProtoMaker proto_maker;
+  proto_maker(&op_proto, &op_checker);
+}
+
+// NOTE(dzhwinter):
+// There is a fatal CHECK in the base class destructor, which calls abort()
+// instead of throwing an exception. If we throw an exception in Make(), we
+// trigger that CHECK and terminate the tests.
+//
+// I tried to replace the default CHECK with an exception; however, glog
+// still does not support it. Details:
+// https://github.com/google/glog/issues/249
+// https://github.com/facebookresearch/TensorComprehensions/issues/351
+/*
+TEST(ProtoMaker, ReuseWithException) {
+  paddle::framework::proto::OpProto op_proto2, op_proto3, op_proto4;
+  paddle::framework::OpAttrChecker op_checker;
+  TestReuseProtoMaker2 proto_maker2;
+  TestReuseProtoMaker3 proto_maker3;
+  TestReuseProtoMaker4 proto_maker4;
+  EXPECT_THROW(proto_maker2(&op_proto2, &op_checker),
+               paddle::platform::EnforceNotMet);
+  EXPECT_THROW(proto_maker3(&op_proto3, &op_checker),
+               paddle::platform::EnforceNotMet);
+  EXPECT_THROW(proto_maker4(&op_proto4, &op_checker),
+               paddle::platform::EnforceNotMet);
+}
+
+void FailureFunction() {
+  throw std::runtime_error("Check failed in destructor.");
+  // return 0;
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  google::InstallFailureFunction(&FailureFunction);
+  return RUN_ALL_TESTS();
+}
+*/
@@ -47,12 +47,12 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
     int axis = ctx.Attr<int>("axis");
 
     auto x_dims = x->dims();
-    auto y_dims = y->dims();
+    auto y_dims_untrimed = y->dims();
     auto z_dims = z->dims();
 
     // Execute default elementwise_add operator when
     // broadcast operations need to performed.
-    if (x_dims != y_dims) {
+    if (x_dims != y_dims_untrimed) {
       auto sum_func = [](T a, T b) -> T { return a + b; };
 
       TransformFunctor<decltype(sum_func), T,
@@ -62,11 +62,11 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
           ctx.template device_context<paddle::platform::CPUDeviceContext>(),
           sum_func);
 
-      axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+      axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
       PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                      "Axis should be in range [0, x_dims)");
 
-      trim_trailing_singular_dims(&y_dims);
+      auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
       axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
       int pre, n, post;
@@ -88,7 +88,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
                      "Wrong layout/format set for Y tensor");
 
       std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
-      std::vector<int> src_y_tz = framework::vectorize2int(y_dims);
+      std::vector<int> src_y_tz = framework::vectorize2int(y_dims_untrimed);
       std::vector<int> dst_tz = framework::vectorize2int(z_dims);
 
       std::vector<memory::primitive_desc> srcs_pd;
@@ -142,20 +142,22 @@ class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     int axis = ctx.Attr<int>("axis");
+    // skip out, x, y,
+    // dout length is larger or equal than dx, dy.
+    auto* out = dout;
+    auto *x = dout, *y = dout;
 
     auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
       in->set_layout(DataLayout::kMKLDNN);
       in->set_format(out->format());
     };
 
-    if (x->dims() == y->dims()) {
-      if (dx->dims() == dy->dims()) {
+    if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
       auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
       if (dx) {
         blas.VCOPY(dout->numel(), dout->data<T>(),
@@ -168,9 +170,10 @@ class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
                    dy->mutable_data<T>(ctx.GetPlace()));
         set_mkldnn_format(dy, dout);
       }
-      }
     } else {
       // Execute default kernel when broadcast is needed
-      ElemwiseGradCompute<paddle::platform::CPUDeviceContext, T,
-                          IdentityGrad<T>, IdentityGrad<T>>(
+      ElemwiseExplicitGradCompute<paddle::platform::CPUDeviceContext, T,
+                                  IdentityGrad<T>, IdentityGrad<T>>(
           ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
           IdentityGrad<T>());
......
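For context on why dout can safely stand in for x, y, and out above: IdentityGrad ignores those arguments, so only shapes matter, and the add gradients (standard calculus, with Y broadcast onto X over the trimmed dimensions) depend on dOut alone:

    \frac{\partial L}{\partial X_i} = \frac{\partial L}{\partial \mathrm{Out}_i},
    \qquad
    \frac{\partial L}{\partial Y_j} = \sum_{i \in \mathrm{bcast}(j)} \frac{\partial L}{\partial \mathrm{Out}_i}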
@@ -15,7 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise_op.h"
 
 namespace ops = paddle::operators;
-REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
+REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
+REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out",
+                              "X");
+
 REGISTER_OP_CPU_KERNEL(
     elementwise_add,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -95,8 +95,9 @@ void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
                                   framework::Tensor* dy) {
   int axis = ctx.Attr<int>("axis");
 
-  ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
-      ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
-      IdentityGrad<T>());
+  ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
+                              IdentityGrad<T>>(ctx, *x, *y, *out, *dout, axis,
+                                               dx, dy, IdentityGrad<T>(),
+                                               IdentityGrad<T>());
 }
@@ -140,14 +141,15 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    // skip out, x, y
+    auto* out = dout;
+    auto *x = dout, *y = dout;
 
-    if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
+    if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
+        dy != nullptr && (dx->dims() == dy->dims())) {
       elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
     } else {
       default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
......
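The CPU fast path (elementwise_add_grad) reduces to two buffer copies of dout, which is why blas.VCOPY suffices. A minimal standalone sketch in plain C++ (a hypothetical helper, not the Paddle API):

    #include <vector>

    // Same-shape backward of z = x + y: dz/dx = dz/dy = 1, so dx and dy
    // are plain copies of dout. Mirrors what the kernel above does via VCOPY.
    void add_grad_same_shape(const std::vector<float>& dout,
                             std::vector<float>* dx, std::vector<float>* dy) {
      if (dx) dx->assign(dout.begin(), dout.end());  // dx = dout
      if (dy) dy->assign(dout.begin(), dout.end());  // dy = dout
    }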
@@ -15,7 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise_div_op.h"
 #include "paddle/fluid/operators/elementwise_op.h"
 
 namespace ops = paddle::operators;
 REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
 REGISTER_OP_CPU_KERNEL(
     elementwise_div,
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -78,7 +78,9 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() final {
     AddInput("X", "(Tensor), The first input tensor of elementwise op.");
     AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
-    AddOutput("Out", "The output of elementwise op.").Reuse("X");
+    // AddOutput("SavedShape", "(Tensor), save X, Y shape for grad to save
+    // memory.").AsIntermediate();
+    AddOutput("Out", "The output of elementwise op.");
     AddAttr<int>("axis",
                  "(int, default -1). The start dimension index "
                  "for broadcasting Y onto X.")
@@ -125,11 +127,13 @@ But the output only shares the LoD information with the input $X$.
 )DOC",
                             GetName(), GetEquation()));
+    SetReuse();
   }
 
  protected:
   virtual std::string GetName() const = 0;
   virtual std::string GetEquation() const = 0;
+  virtual void SetReuse() {}
 };
 class ElementwiseOpGrad : public framework::OperatorWithKernel {
@@ -162,8 +166,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    auto input_data_type = framework::ToDataType(
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
 
 #ifdef PADDLE_WITH_MKLDNN
     if (platform::CanMKLDNNBeUsed(ctx)) {
@@ -175,9 +179,58 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
+
+// For Add, Sub op, the X, Out is not needed.
+class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
+ public:
+  using operators::ElementwiseOpGrad::ElementwiseOpGrad;
+  using operators::ElementwiseOpGrad::GetExpectedKernelType;
+  using Tensor = framework::Tensor;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+      ctx->SetOutputDim(x_grad_name, out_dims);
+    }
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(y_grad_name)) {
+      PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+      auto y_dims = ctx->GetInputDim("Y");
+      ctx->SetOutputDim(y_grad_name, y_dims);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
+/*
+*/
+
+#define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name)                   \
+  class kernel_type##GradMaker                                               \
+      : public paddle::framework::SingleGradOpDescMaker {                    \
+   public:                                                                   \
+    using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
+                                                                             \
+   protected:                                                                \
+    std::unique_ptr<paddle::framework::OpDesc> Apply() const override {      \
+      auto* op = new paddle::framework::OpDesc();                            \
+      op->SetType(#kernel_type "_grad");                                     \
+      op->SetInput("Y", Input("Y"));                                         \
+      op->SetInput(::paddle::framework::GradVarName("Out"),                  \
+                   OutputGrad("Out"));                                       \
+      op->SetAttrMap(Attrs());                                               \
+      op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X"));  \
+      op->SetOutput(::paddle::framework::GradVarName("Y"), InputGrad("Y"));  \
+      return std::unique_ptr<::paddle::framework::OpDesc>(op);               \
+    }                                                                        \
+  }
+
 #define REGISTER_ELEMWISE_OP(op_type, op_name, equation)                \
   class __ElemwiseOp##op_type##Maker__                                  \
       : public ::paddle::operators::ElementwiseOpMaker {                \
@@ -190,3 +243,18 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
                     ::paddle::operators::ElementwiseOpInferVarType,     \
                     ::paddle::framework::DefaultGradOpDescMaker<true>); \
   REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
+
+#define REGISTER_ELEMWISE_EXPLICIT_OP(op_type, op_name, equation, ...)  \
+  class __ElemwiseOp##op_type##Maker__                                  \
+      : public ::paddle::operators::ElementwiseOpMaker {                \
+   protected:                                                           \
+    virtual std::string GetName() const { return op_name; }             \
+    virtual std::string GetEquation() const { return equation; }        \
+    virtual void SetReuse() { Reuse(__VA_ARGS__); }                     \
+  };                                                                    \
+  REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp,        \
+                    __ElemwiseOp##op_type##Maker__,                     \
+                    ::paddle::operators::ElementwiseOpInferVarType,     \
+                    op_type##GradMaker);                                 \
+  REGISTER_OPERATOR(op_type##_grad,                                     \
+                    ::paddle::operators::ElementwiseOpExplicitGrad)
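Putting the two macros together, wiring up an explicit-grad elementwise op takes two registration lines. A hedged sketch (elementwise_foo is a hypothetical op name; the real uses are the elementwise_add/elementwise_sub registrations in this commit):

    // The GRAD_MAKER macro must come first: REGISTER_ELEMWISE_EXPLICIT_OP
    // references the op_type##GradMaker class it defines. The trailing
    // varargs ("Out", "X") are forwarded to SetReuse() -> Reuse("Out", "X"),
    // marking Out as reusing X's buffer.
    REGISTER_ELEMWISE_GRAD_MAKER(elementwise_foo, Foo);
    REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_foo, "Foo", "Out = X foo Y",
                                  "Out", "X");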
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <glog/logging.h>
 #include <algorithm>
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -65,17 +67,21 @@ inline void get_mid_dims(const framework::DDim& x_dims,
   }
 }
 
-inline void trim_trailing_singular_dims(framework::DDim* dims) {
+inline framework::DDim trim_trailing_singular_dims(
+    const framework::DDim& dims) {
   // Remove trailing dimensions of size 1 for y
-  auto actual_dims_size = dims->size();
+  auto actual_dims_size = dims.size();
   for (; actual_dims_size != 0; --actual_dims_size) {
-    if ((*dims)[actual_dims_size - 1] != 1) break;
+    if (dims[actual_dims_size - 1] != 1) break;
   }
-  if (actual_dims_size != dims->size()) {
-    auto actual_dims = framework::vectorize(*dims);
-    actual_dims.resize(actual_dims_size);
-    *dims = framework::make_ddim(actual_dims);
-  }
+
+  std::vector<int> trim_dims;
+  trim_dims.resize(actual_dims_size);
+  for (int i = 0; i < actual_dims_size; ++i) {
+    trim_dims[i] = dims[i];
+  }
+  framework::DDim actual_dims = framework::make_ddim(trim_dims);
+  return actual_dims;
 }
 
 template <typename T, typename DeviceContext>
@@ -457,26 +463,30 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
 #endif
 
 template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
-void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
-                         const framework::Tensor& x, const framework::Tensor& y,
-                         const framework::Tensor& out,
-                         const framework::Tensor& dout, int axis,
-                         framework::Tensor* dx, framework::Tensor* dy,
-                         DX_OP dx_op, DY_OP dy_op) {
-  if (x.dims() == y.dims()) {
-    size_t N = static_cast<size_t>(framework::product(x.dims()));
-    platform::ForRange<DeviceContext> for_range(
-        ctx.template device_context<DeviceContext>(), N);
-    for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
-        x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
-        dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
-        dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
-  } else {  // Y is a scalar
-    auto x_dim = x.dims();
-    auto y_dim = y.dims();
-    axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
-    trim_trailing_singular_dims(&y_dim);
+void ElemwiseGradComputeNoBroadcast(
+    const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
+    const framework::DDim& y_dim, const framework::Tensor& x,
+    const framework::Tensor& y, const framework::Tensor& out,
+    const framework::Tensor& dout, int axis, framework::Tensor* dx,
+    framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
+  size_t N = static_cast<size_t>(framework::product(x_dim));
+  platform::ForRange<DeviceContext> for_range(
+      ctx.template device_context<DeviceContext>(), N);
+  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
+      x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
+      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
+}
+
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+void ElemwiseGradComputeWithBroadcast(
+    const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
+    const framework::DDim& y_dim_untrimed, const framework::Tensor& x,
+    const framework::Tensor& y, const framework::Tensor& out,
+    const framework::Tensor& dout, int axis, framework::Tensor* dx,
+    framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
+  axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
+  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
   axis = (y_dim.size() == 0) ? x_dim.size() : axis;
 
   int pre, n, post;
@@ -494,9 +504,8 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
 #endif
     } else {
       ElemwiseGradBroadcast1CPU(
-          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w,
-          dx_op, dy_op,
-          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op,
+          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
           dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
     }
   } else {
@@ -505,21 +514,70 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
       ElemwiseGradBroadcast2CUDA(
           ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
           y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
-          dy_op,
-          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
           dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
 #endif
     } else {
       ElemwiseGradBroadcast2CPU(
-          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
-          post, dx_op, dy_op,
+          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
+          dx_op, dy_op,
           dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
           dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
     }
   }
 }
+
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
+                         const framework::Tensor& x, const framework::Tensor& y,
+                         const framework::Tensor& out,
+                         const framework::Tensor& dout, int axis,
+                         framework::Tensor* dx, framework::Tensor* dy,
+                         DX_OP dx_op, DY_OP dy_op) {
+  const framework::DDim x_dim = x.dims();
+  const framework::DDim y_dim = y.dims();
+  if (x.dims() == y.dims()) {
+    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
+  } else {  // Y is a scalar
+    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
+  }
+}
+
+// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
+// explicit gradient can cut off X, Y, Out from gradient op
+// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
+// elementwise code.
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
+                                 const framework::Tensor& x,
+                                 const framework::Tensor& y,
+                                 const framework::Tensor& out,
+                                 const framework::Tensor& dout, int axis,
+                                 framework::Tensor* dx, framework::Tensor* dy,
+                                 DX_OP dx_op, DY_OP dy_op) {
+  if (dy == nullptr) {
+    const framework::DDim dx_dims = dout.dims();
+    auto dy_dims = dx_dims;
+    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+        ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
+  } else {
+    if (dout.dims() == dy->dims()) {
+      const framework::DDim dx_dims = dout.dims();
+      const framework::DDim dy_dims = dy->dims();
+      ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
+    } else {  // Y is a scalar
+      auto dx_dims = dout.dims();
+      const framework::DDim dy_dims = dy->dims();
+      ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
+    }
+  }
+}
+
+// Deprecated
 template <typename DeviceContext, typename T, typename functor,
           typename broadcastfunctor, typename broadcast2functor>
 void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
@@ -547,7 +605,7 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
   }
 
   axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
-  trim_trailing_singular_dims(&y_dims);
+  trim_trailing_singular_dims(y_dims);
   axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
   int pre, n, post;
@@ -574,19 +632,19 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
       x, y, z, ctx.template device_context<DeviceContext>(), func);
 
   auto x_dims = x->dims();
-  auto y_dims = y->dims();
-  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
+  auto y_dims_untrimed = y->dims();
+  PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
                     "Rank of first input must >= rank of second input.");
 
-  if (x_dims == y_dims) {
+  if (x_dims == y_dims_untrimed) {
     functor.Run();
     return;
   }
 
-  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+  axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
   PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                  "Axis should be in range [0, x_dims)");
 
-  trim_trailing_singular_dims(&y_dims);
+  auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
   axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
   int pre, n, post;
......
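As a quick sanity check on the new value-returning trim_trailing_singular_dims: only trailing 1-dims are dropped, interior ones survive, and a fully singular shape trims to size 0 (which is what triggers the axis reset above). A minimal sketch of the same loop over plain vectors (hypothetical helper, not part of the commit):

    #include <cassert>
    #include <vector>

    // Mirrors trim_trailing_singular_dims above, on std::vector instead of DDim.
    std::vector<int> trim_trailing_ones(const std::vector<int>& dims) {
      int n = static_cast<int>(dims.size());
      while (n != 0 && dims[n - 1] == 1) --n;  // drop trailing 1s only
      return std::vector<int>(dims.begin(), dims.begin() + n);
    }

    int main() {
      assert((trim_trailing_ones({2, 3, 1, 1}) == std::vector<int>{2, 3}));
      assert((trim_trailing_ones({2, 1, 3}) == std::vector<int>{2, 1, 3}));
      assert(trim_trailing_ones({1, 1}).empty());  // size() == 0 => axis reset
      return 0;
    }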
@@ -15,7 +15,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise_sub_op.h"
 #include "paddle/fluid/operators/elementwise_op.h"
 
 namespace ops = paddle::operators;
-REGISTER_ELEMWISE_OP(elementwise_sub, "Sub", "Out = X - Y");
+REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub);
+REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_sub, "Sub", "Out = X - Y", "Out",
+                              "X");
+
 REGISTER_OP_CPU_KERNEL(
     elementwise_sub,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
......
@@ -55,14 +55,15 @@ class ElementwiseSubGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     int axis = ctx.Attr<int>("axis");
-    ElemwiseGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
+    // skip out, x, y
+    auto* out = dout;
+    auto *x = dout, *y = dout;
+    ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
         ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
   }
 };
......
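Neither gradient functor reads x, y, or out, only dout, which is what makes the dout aliasing above safe. A sketch of the functors' shape for Out = X - Y (the standard identities dX = dOut, dY = -dOut; the real SubGradDX/SubGradDY referenced above live in this header):

    // Sketch of the functor contract used by ElemwiseExplicitGradCompute.
    template <typename T>
    struct SubGradDXSketch {
      T operator()(T x, T y, T out, T dout) const { return dout; }  // x, y, out unused
    };
    template <typename T>
    struct SubGradDYSketch {
      T operator()(T x, T y, T out, T dout) const { return -dout; }
    };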
@@ -137,7 +137,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
                       ctx->GetInputDim(framework::GradVarName("Out")),
                       "Input(Out) and its gradients should have a same shape.");
 
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->SetOutputDim(framework::GradVarName("X"),
+                      ctx->GetInputDim(framework::GradVarName("Out")));
   }
 
  protected:
@@ -160,8 +161,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    auto input_data_type = framework::ToDataType(
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
     if (input_data_type == framework::proto::VarType::FP16) {
       PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                      "float16 can only be used on GPU place");
@@ -172,13 +173,31 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   }
 };
 
+class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op = new framework::OpDesc();
+    op->SetType("softmax_grad");
+
+    op->SetInput("Out", Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+
+    op->SetAttrMap(Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::SoftmaxOpGradMaker);
 REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(
     softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
......
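The custom SoftmaxOpGradMaker can drop X because the standard softmax backward identity needs only the forward output and its gradient; with y = softmax(x):

    \frac{\partial L}{\partial x_i}
      = y_i \left( \frac{\partial L}{\partial y_i}
        - \sum_j y_j \, \frac{\partial L}{\partial y_j} \right)

so the grad op reads Out and Out@GRAD alone, letting X be freed early.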
@@ -20,8 +20,8 @@ class TestElementwiseOp(OpTest):
     def setUp(self):
         self.op_type = "elementwise_sub"
         self.inputs = {
-            'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
-            'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
+            'X': np.random.uniform(0.1, 1, [2, 3]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [2, 3]).astype("float32")
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
......