Commit 2a36ad1a authored by Yu Yang

Handle LoD for concat & seq_softmax ops

Parent 211d8186
@@ -62,9 +62,21 @@ class ConcatGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
     auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
-    auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
+    auto outs =
+        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
+    {
+      auto dx = outs;
+      auto x = ins;
+      for (size_t i = 0; i < dx.size(); ++i) {
+        if (dx[i] != nullptr) {
+          dx[i]->set_lod(x[i]->lod());
+        }
+      }
+    }
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
     // get output tensor that the name is not kEmptyVarName
...
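Note: in Paddle Fluid, a LoDTensor is a Tensor plus level-of-detail (LoD) information, i.e. nested offset vectors marking where each variable-length sequence starts and ends inside the flattened batch. The hunk above makes the concat gradient kernel copy each forward input's LoD onto the corresponding gradient output (outputs bound to kEmptyVarName are null and are skipped), so sequence boundaries survive the backward pass. Below is a minimal standalone sketch of the same pattern, using simplified stand-in types rather than Paddle's real framework classes:

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for framework::LoD: one offset vector per nesting level.
using LoD = std::vector<std::vector<size_t>>;

// Stand-in for framework::LoDTensor: data plus sequence offsets.
struct FakeLoDTensor {
  std::vector<float> data;
  LoD lod;
  void set_lod(const LoD& l) { lod = l; }
};

// The propagation pattern from the hunk: every gradient output that
// actually exists inherits the LoD of the corresponding forward input.
void PropagateLoD(const std::vector<const FakeLoDTensor*>& x,
                  std::vector<FakeLoDTensor*>* dx) {
  for (size_t i = 0; i < dx->size(); ++i) {
    if ((*dx)[i] != nullptr) {  // null means "gradient not requested"
      (*dx)[i]->set_lod(x[i]->lod);
    }
  }
}

int main() {
  FakeLoDTensor x0{{1, 2, 3, 4, 5}, {{0, 2, 5}}};  // two sequences: [0,2) and [2,5)
  FakeLoDTensor dx0;                               // gradient buffer, LoD not yet set
  std::vector<const FakeLoDTensor*> xs{&x0};
  std::vector<FakeLoDTensor*> dxs{&dx0};
  PropagateLoD(xs, &dxs);
  std::cout << "dx0 sequences: " << dx0.lod[0].size() - 1 << "\n";  // prints 2
}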
@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const std::vector<const framework::LoDTensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     size_t num = outputs->size();
...
@@ -189,7 +189,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  const std::vector<const framework::LoDTensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     int o_num = outputs->size();
...
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/lod_tensor.h"
 namespace paddle {
 namespace operators {
@@ -57,7 +57,7 @@ template <typename DeviceContext, typename T>
 class ConcatGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
                   const std::vector<const framework::LoDTensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs);
 };
...
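The signature change compiles without touching the functor bodies because framework::LoDTensor derives from framework::Tensor: the CPU and CUDA implementations keep reading ref_inputs through the Tensor interface, while callers such as ConcatGradKernel can now pass the LoD-carrying inputs directly. A tiny standalone sketch of that substitutability (stand-in types, not Paddle's real classes):

#include <cstddef>
#include <iostream>
#include <vector>

struct Tensor {              // stand-in base class: shape only
  std::vector<int> dims;
};

struct LoDTensor : Tensor {  // stand-in derived class: adds offsets
  std::vector<std::vector<size_t>> lod;
};

// Like ConcatGradFunctor after this commit: accepts LoDTensor pointers,
// but only ever uses the Tensor part of them.
int TotalRows(const std::vector<const LoDTensor*>& ref_inputs) {
  int rows = 0;
  for (const LoDTensor* t : ref_inputs) {
    const Tensor* base = t;  // implicit derived-to-base conversion
    rows += base->dims[0];
  }
  return rows;
}

int main() {
  LoDTensor a, b;
  a.dims = {3, 4};
  b.dims = {2, 4};
  std::cout << TotalRows({&a, &b}) << "\n";  // prints 5
}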
@@ -66,6 +66,9 @@ class SequenceSoftmaxGradKernel : public framework::OpKernel<T> {
     auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* x = ctx.Input<LoDTensor>("X");
     auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+    if (x_grad) {
+      x_grad->set_lod(x->lod());
+    }
     auto lod = x->lod();
     const size_t level = lod.size() - 1;
...
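For context, sequence_softmax normalizes each sequence independently, slicing the flat input with the offsets of the last LoD level (hence lod.size() - 1 above); the gradient must carry the same LoD so those slices stay aligned. A standalone sketch of per-sequence softmax over a flat buffer, illustrative only and not Paddle's actual kernel:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Softmax over each [offsets[i], offsets[i+1]) slice of a flat buffer,
// mirroring how the kernel uses the last LoD level as sequence boundaries.
std::vector<float> SequenceSoftmax(const std::vector<float>& x,
                                   const std::vector<size_t>& offsets) {
  std::vector<float> out(x.size());
  for (size_t i = 0; i + 1 < offsets.size(); ++i) {
    size_t begin = offsets[i], end = offsets[i + 1];
    float max_v = x[begin];
    for (size_t j = begin; j < end; ++j) max_v = std::max(max_v, x[j]);
    float sum = 0.f;
    for (size_t j = begin; j < end; ++j) {
      out[j] = std::exp(x[j] - max_v);  // subtract max for numerical stability
      sum += out[j];
    }
    for (size_t j = begin; j < end; ++j) out[j] /= sum;
  }
  return out;
}

int main() {
  // Two sequences of lengths 2 and 3, encoded as the LoD level {0, 2, 5}.
  std::vector<float> x{1, 2, 3, 4, 5};
  for (float v : SequenceSoftmax(x, {0, 2, 5})) std::cout << v << " ";
  std::cout << "\n";
}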