Unverified commit 17fcc4f5, authored by Yu Yang, committed by GitHub

Merge pull request #12864 from reyoung/feature/process_lod_grad

Feature/process lod grad
@@ -62,9 +62,21 @@ class ConcatGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
     auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
-    auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
+    auto outs =
+        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
+    {
+      auto dx = outs;
+      auto x = ins;
+      for (size_t i = 0; i < dx.size(); ++i) {
+        if (dx[i] != nullptr) {
+          dx[i]->set_lod(x[i]->lod());
+        }
+      }
+    }
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     // get output tensor that the name is not kEmptyVarName
......
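The loop added above is the core of this hunk: before the concat gradient is computed, every non-null gradient output inherits the LoD (sequence-offset) information of its corresponding forward input. Below is a minimal, self-contained sketch of that idea; FakeLoDTensor and ForwardLoD are hypothetical stand-ins for illustration only, not the framework classes used in the diff.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for framework::LoDTensor: only the LoD is modeled.
struct FakeLoDTensor {
  std::vector<std::vector<size_t>> lod;  // level-of-detail: sequence offsets
  void set_lod(const std::vector<std::vector<size_t>>& l) { lod = l; }
};

// Copy the LoD of each forward input x[i] onto the matching gradient dx[i],
// skipping gradient outputs that were not requested (nullptr).
void ForwardLoD(const std::vector<const FakeLoDTensor*>& x,
                const std::vector<FakeLoDTensor*>& dx) {
  for (size_t i = 0; i < dx.size(); ++i) {
    if (dx[i] != nullptr) {
      dx[i]->set_lod(x[i]->lod);
    }
  }
}

int main() {
  FakeLoDTensor x0, x1, dx0, dx1;
  x0.lod = {{0, 2, 5}};  // two sequences of lengths 2 and 3
  x1.lod = {{0, 1, 4}};  // two sequences of lengths 1 and 3
  ForwardLoD({&x0, &x1}, {&dx0, &dx1});
  std::cout << dx0.lod[0].back() << " " << dx1.lod[0].back() << "\n";  // 5 4
  return 0;
}

Without a step like this, the gradient tensors produced by the concat grad kernel would carry no sequence boundaries, and downstream sequence-aware operators could not interpret them.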
@@ -137,9 +137,10 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
 };
 
 template <typename T>
-class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
+class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
@@ -136,9 +137,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
 }
 
 template <typename DeviceContext, typename T>
-class ElementwiseAddGradKernel : public framework::OpKernel<T> {
+class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
@@ -14,8 +14,8 @@ limitations under the License. */
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
@@ -53,9 +53,10 @@ struct DivGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseDivGradKernel : public framework::OpKernel<T> {
+class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
@@ -55,9 +56,10 @@ struct MaxGradDy {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
+class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
......
@@ -14,8 +14,8 @@ limitations under the License. */
 #pragma once
 
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
@@ -55,9 +55,10 @@ struct MinGradDy {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMinGradKernel : public framework::OpKernel<T> {
+class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
@@ -84,9 +85,10 @@ struct MulGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseMulGradKernel : public framework::OpKernel<T> {
+class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* x = ctx.Input<Tensor>("X");
......
@@ -205,6 +205,20 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
   }
 };
 
+template <typename T>
+class ElemwiseGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* dx =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    if (dx != nullptr) {
+      auto& dout =
+          *context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+      dx->set_lod(dout.lod());
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
......
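The new ElemwiseGradKernel is the shared base for every elementwise gradient kernel touched in this commit: each derived Compute first invokes the base Compute, which copies dOut's LoD onto dX, and only then runs its own gradient math. A minimal, self-contained sketch of the pattern follows; MockLoDTensor, MockContext, MockElemwiseGradKernel, and MockAddGradKernel are hypothetical names for illustration, not the framework types in the diff.

#include <cstddef>
#include <iostream>
#include <vector>

struct MockLoDTensor {  // hypothetical stand-in for framework::LoDTensor
  std::vector<std::vector<size_t>> lod;
  std::vector<float> data;
};

struct MockContext {  // hypothetical stand-in for the execution context
  const MockLoDTensor* dout;
  MockLoDTensor* dx;
};

// Base kernel: shared behavior that copies dOut's LoD onto dX when dX exists.
struct MockElemwiseGradKernel {
  virtual ~MockElemwiseGradKernel() = default;
  virtual void Compute(const MockContext& ctx) const {
    if (ctx.dx != nullptr && ctx.dout != nullptr) {
      ctx.dx->lod = ctx.dout->lod;
    }
  }
};

// Derived kernel: call the base Compute first, then do the gradient itself.
struct MockAddGradKernel : MockElemwiseGradKernel {
  void Compute(const MockContext& ctx) const override {
    MockElemwiseGradKernel::Compute(ctx);  // LoD is set before any math
    if (ctx.dx != nullptr) {
      ctx.dx->data = ctx.dout->data;  // d(x + y)/dx is the identity
    }
  }
};

int main() {
  MockLoDTensor dout{{{0, 2, 3}}, {1.f, 2.f, 3.f}};
  MockLoDTensor dx;
  MockContext ctx{&dout, &dx};
  MockAddGradKernel{}.Compute(ctx);
  std::cout << dx.lod[0].back() << " " << dx.data.size() << "\n";  // prints: 3 3
  return 0;
}

Folding the LoD copy into one base class keeps the fix in a single place instead of repeating the same guard in the add, sub, mul, div, max, and min gradient kernels.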
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/fluid/operators/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise_op_function.h"
 
 namespace paddle {
@@ -50,9 +51,10 @@ struct SubGradDY {
 };
 
 template <typename DeviceContext, typename T>
-class ElementwiseSubGradKernel : public framework::OpKernel<T> {
+class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;
 
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const std::vector<const framework::LoDTensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     size_t num = outputs->size();
......
@@ -189,7 +189,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const std::vector<const framework::LoDTensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     int o_num = outputs->size();
......
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/lod_tensor.h"
 
 namespace paddle {
 namespace operators {
@@ -57,7 +57,7 @@ template <typename DeviceContext, typename T>
 class ConcatGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
-                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const std::vector<const framework::LoDTensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs);
 };
......
@@ -62,23 +62,31 @@ class MulGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
     int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* y = ctx.Input<Tensor>("Y");
-    const Tensor x_matrix = x->dims().size() > 2
-                                ? framework::ReshapeToMatrix(*x, x_num_col_dims)
-                                : *x;
-    const Tensor y_matrix = y->dims().size() > 2
-                                ? framework::ReshapeToMatrix(*y, y_num_col_dims)
-                                : *y;
-    const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<framework::LoDTensor>("X");
+    auto* y = ctx.Input<framework::LoDTensor>("Y");
+    auto x_matrix = x->dims().size() > 2
+                        ? framework::ReshapeToMatrix(*x, x_num_col_dims)
+                        : static_cast<const Tensor&>(*x);
+    auto y_matrix = y->dims().size() > 2
+                        ? framework::ReshapeToMatrix(*y, y_num_col_dims)
+                        : static_cast<const Tensor&>(*y);
+    auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
     Tensor dout_mat;
     dout_mat.ShareDataWith(*dout);
     dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0],
                      framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
-    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
+    if (dx != nullptr) {
+      dx->set_lod(x->lod());
+    }
+    if (dy != nullptr) {
+      dy->set_lod(y->lod());
+    }
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     if (dx) {
......
@@ -68,7 +68,9 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
     auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* x = ctx.Input<LoDTensor>("X");
     auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+    if (x_grad) {
+      x_grad->set_lod(x->lod());
+    }
     auto lod = x->lod();
     const size_t level = lod.size() - 1;
......
@@ -66,6 +66,9 @@ class SequenceSoftmaxGradKernel : public framework::OpKernel<T> {
     auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* x = ctx.Input<LoDTensor>("X");
     auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+    if (x_grad) {
+      x_grad->set_lod(x->lod());
+    }
     auto lod = x->lod();
     const size_t level = lod.size() - 1;
......