diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 913cd0f81eaef37014f38c71e7c3d23bfeec1466..b3b9c45ded95ce2e735b8898d47760956dcacdce 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -377,6 +377,12 @@ std::vector> MakeOpGrad( return grad_op_descs; } +static BlockDescBind* CreateStepBlock( + ProgramDescBind& program_desc, + std::unordered_set* no_grad_vars, + std::unordered_map* grad_to_var, + int step_block_idx); + std::vector> MakeBlockBackward( ProgramDescBind& program_desc, int block_idx, std::unordered_set* no_grad_vars, @@ -392,13 +398,13 @@ std::vector> MakeBlockBackward( if ((*it)->Type() == "recurrent") { int step_block_idx = (*it)->GetBlockAttr("step_block"); - auto backward_block_op_descs = MakeBlockBackward( - program_desc, step_block_idx, no_grad_vars, grad_to_var); + BlockDescBind* backward_block = CreateStepBlock( + program_desc, no_grad_vars, grad_to_var, step_block_idx); + op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); + } else if ((*it)->Type() == "conditional_block") { BlockDescBind* backward_block = - program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx)); - for (auto& ptr : backward_block_op_descs) { - backward_block->AppendAllocatedOp(std::move(ptr)); - } + CreateStepBlock(program_desc, no_grad_vars, grad_to_var, + (*it)->GetBlockAttr("block")); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); } else { op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var); @@ -449,6 +455,21 @@ std::vector> MakeBlockBackward( return backward_descs; } +static BlockDescBind* CreateStepBlock( + ProgramDescBind& program_desc, + std::unordered_set* no_grad_vars, + std::unordered_map* grad_to_var, + int step_block_idx) { + auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx, + no_grad_vars, grad_to_var); + BlockDescBind* backward_block = + program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx)); + for (auto& ptr : backward_block_op_descs) { + backward_block->AppendAllocatedOp(move(ptr)); + } + return backward_block; +} + ParamGradInfoMap AppendBackward( ProgramDescBind& program_desc, const VarDescBind& target, const std::unordered_set& no_grad_vars) { diff --git a/paddle/framework/var_type.h b/paddle/framework/var_type.h index d060196bb2c478b776851288cb71a1880d60660d..0f19870bec3e69d07278507cc556a86bbd25d12d 100644 --- a/paddle/framework/var_type.h +++ b/paddle/framework/var_type.h @@ -27,10 +27,32 @@ inline VarDesc::VarType ToVarType(std::type_index type) { return VarDesc_VarType_LOD_RANK_TABLE; } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { return VarDesc_VarType_LOD_TENSOR_ARRAY; + } else if (type.hash_code() == typeid(SelectedRows).hash_code()) { + return VarDesc_VarType_SELECTED_ROWS; } else { PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); } } +template +inline void VisitVarType(const Variable& var, Visitor visitor) { + switch (ToVarType(var.Type())) { + case VarDesc_VarType_LOD_TENSOR: + visitor(var.Get()); + return; + case VarDesc_VarType_LOD_RANK_TABLE: + visitor(var.Get()); + return; + case VarDesc_VarType_LOD_TENSOR_ARRAY: + visitor(var.Get()); + return; + case VarDesc_VarType_SELECTED_ROWS: + visitor(var.Get()); + return; + default: + PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type())); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/operators/assign_op.cc b/paddle/operators/assign_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..609e915b932e2bc4d5abee1e5f868cc07a7619d3 --- /dev/null +++ b/paddle/operators/assign_op.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/data_type.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/var_type.h" + +namespace paddle { +namespace operators { +class AssignFunctor { + public: + AssignFunctor(framework::Variable *out, + const platform::DeviceContext &dev_ctx) + : out_(out), dev_ctx_(dev_ctx) {} + + void operator()(const framework::LoDTensor &lod_tensor) const { + auto &out_tensor = *out_->GetMutable(); + copy_tensor(lod_tensor, &out_tensor); + } + + void operator()(const framework::LoDTensorArray &array) const { + auto &out_array = *out_->GetMutable(); + out_array.resize(array.size()); + for (size_t i = 0; i < array.size(); ++i) { + copy_tensor(array[i], &out_array[i]); + } + } + + void operator()(const framework::SelectedRows &rows) const { + framework::SelectedRows &out_rows = + *out_->GetMutable(); + out_rows.set_rows(rows.rows()); + out_rows.set_height(rows.height()); + auto &t = rows.value(); + out_rows.mutable_value()->CopyFrom(t, t.place(), dev_ctx_); + } + + template + void operator()(const T &v) const { + PADDLE_THROW("Not support type for assign op %s", typeid(T).name()); + } + + private: + void copy_tensor(const framework::LoDTensor &lod_tensor, + framework::LoDTensor *out) const { + auto &out_tensor = *out; + out_tensor.CopyFrom(lod_tensor, lod_tensor.place(), dev_ctx_); + out_tensor.set_lod(lod_tensor.lod()); + } + + framework::Variable *out_; + const platform::DeviceContext &dev_ctx_; +}; + +class AssignOp : public framework::OperatorBase { + public: + AssignOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto *x = scope.FindVar(Input("X")); + if (x == nullptr) { + return; + } + auto *out = scope.FindVar(Output("Out")); + PADDLE_ENFORCE( + out != nullptr, + "The Output(Out) should not be null if the Input(X) is set."); + framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); + } +}; + +class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + AssignOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor, SelectedRows or LoDTensorArray) The input variable " + "could be LoDTensor, SelectedRows or LoDTensorArray.") + .AsDispensable(); + AddOutput("Out", + "(LoDTensor, SelectedRows or LoDTensorArray) The type of output " + "is the same as input X."); + AddComment(R"DOC(Assign Operator + +Out = X, when type in [LoDTensor/SelectedRows/LoDTensorArray] +raise error if the type is not listed above. +)DOC"); + } +}; + +class AssignInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + if (context->HasInput("X")) { + auto type = context->GetInputsVarType("X")[0]; + if (type == framework::VarDesc_VarType_SELECTED_ROWS || + type == framework::VarDesc_VarType_LOD_TENSOR) { + context->SetOutputDim("Out", context->GetInputDim("X")); + } + } + } +}; + +class AssignGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *op = new framework::OpDescBind(); + op->SetType("assign"); + op->SetInput("X", OutputGrad("Out")); + op->SetOutput("Out", InputGrad("X")); + return std::unique_ptr(op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker, + ops::AssignInferShape, ops::AssignOpProtoMaker); diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c65ba7eb262f3aabe2c00837b79806c0b40b60fd --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -0,0 +1,159 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class BilinearTensorProductOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input(X) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input(Y) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL, + "The input(Weight) must be a 3D tensor."); + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The first dimension(batch_size) of input(X) must be " + "equal to the first dimension of the input(Y)."); + PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], + "The second dimension of input(X) must be equal to " + "the second dimension of the input(Weight)."); + PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], + "The second dimension of input(Y) must be equal to " + "the third dimension of the input(Weight)."); + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE(bias_dims.size() == 2UL && bias_dims[0] == 1UL, + "The Input(Bias) must be a 2-D tensor with " + "the 2nd dimension fixed to 1 (a row vector)."); + PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], + "The second dimension of input(Bias) must be equal " + "to the first dimension of the input(Weight)."); + } + + ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]}); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BilinearTensorProductOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The first input of bilinear_tensor_product operator."); + AddInput("Y", "The second input of bilinear_tensor_product operator."); + AddInput("Weight", + "The learnable parameters of bilinear_tensor_product operator."); + AddInput("Bias", "The learnable bias of bilinear_tensor_product operator.") + .AsDispensable(); + AddOutput("Out", "The output of bilinear_tensor_product operator."); + AddComment(R"DOC( +Bilinear Tensor Product operator. +Given input X and Y, a 3D tensor weight, and bias. Each column of the +output is computed by one slice i = 1, . . . , k of the tensor: + + M = (X W_i) \cdot Y + Out_i = \sum_i {M_i} + Bias_i + +)DOC"); + } +}; + +class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(out_dims.size(), 2UL, + "The input(Out@GRAD) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ( + x_dims[0], out_dims[0], + "The first dimension(batch_size) of input(Out@GRAD) must be " + "equal to the first dimension of the Input(X)."); + PADDLE_ENFORCE_EQ( + weight_dims[0], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the third dimension of the Input(Weight)."); + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ( + bias_dims[1], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the second dimension of the Input(Bias)."); + auto bias_grad_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(bias_grad_name)) + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + auto weight_grad_name = framework::GradVarName("Weight"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + if (ctx->HasOutput(weight_grad_name)) { + ctx->SetOutputDim(weight_grad_name, weight_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp, + ops::BilinearTensorProductOpMaker, bilinear_tensor_product_grad, + ops::BilinearTensorProductOpGrad); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.cu b/paddle/operators/bilinear_tensor_product_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..858d2668d01379afe8082cd1eda32a2a5d09bd18 --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ffa4f43a327418498c1f110504127e7d2878409d --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -0,0 +1,184 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +class BilinearTensorProductKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto y_mat = EigenMatrix::From(*y); + auto output_mat = EigenMatrix::From(*out); + + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + auto place = ctx.GetEigenDevice(); + + // Create the intermediate variable to caculate the result of + // Input(X) multiplied by Input(Weight_i), the formula is: + // left_mul = X Weight_i. + Tensor left_mul; + left_mul.mutable_data(framework::make_ddim({batch_size, y_dim}), + ctx.GetPlace()); + auto left_mul_mat = EigenMatrix::From(left_mul); + + for (int i = 0; i < out_dim; ++i) { + auto output_col_vec = output_mat.chip(i, 1); + Tensor weight_mat = + weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim})); + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, y_dim, x_dim, 1, x->data(), + weight_mat.data(), 0, left_mul.data()); + output_col_vec.device(place) = + (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); + } + if (bias) { + auto bias_vec = EigenMatrix::From(*bias); + Eigen::DSizes bcast(batch_size, 1); + output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; + } + } +}; + +template +class BilinearTensorProductGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* x = ctx.Input("X"); + const Tensor* y = ctx.Input("Y"); + const Tensor* weight = ctx.Input("Weight"); + Tensor* d_x = ctx.Output(framework::GradVarName("X")); + Tensor* d_y = ctx.Output(framework::GradVarName("Y")); + Tensor* d_weight = ctx.Output(framework::GradVarName("Weight")); + Tensor* d_bias = ctx.Output(framework::GradVarName("Bias")); + const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); + + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + + auto x_mat = EigenMatrix::From(*x); + auto y_mat = EigenMatrix::From(*y); + auto d_out_mat = EigenMatrix::From(*d_out); + auto place = ctx.GetEigenDevice(); + + // Create the intermediate variable to caculate the Output(Y@Grad). + Tensor x_scale; + x_scale.mutable_data(framework::make_ddim({batch_size, x_dim}), + ctx.GetPlace()); + auto x_scale_mat = EigenMatrix::From(x_scale); + + // Create the intermediate variable to caculate the Output(X@Grad). + Tensor y_scale; + y_scale.mutable_data(framework::make_ddim({batch_size, y_dim}), + ctx.GetPlace()); + auto y_scale_mat = EigenMatrix::From(y_scale); + + math::SetConstant set_zero; + + // Set Output(X@Grad) be zero. + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + set_zero(ctx.device_context(), d_x, static_cast(0)); + } + + // Set Output(Y@Grad) be zero. + if (d_y) { + d_y->mutable_data(ctx.GetPlace()); + set_zero(ctx.device_context(), d_y, static_cast(0)); + } + + // Caculate the Output(X@Grad) and Output(Y@Grad). + if (d_x || d_y) { + Eigen::DSizes bcast_for_x(1, y_dim); + Eigen::DSizes bcast_for_y(1, x_dim); + for (int i = 0; i < out_dim; ++i) { + Tensor weight_i = weight->Slice(i, i + 1).Resize( + framework::make_ddim({x_dim, y_dim})); + auto output_vec = d_out_mat.chip(i, 1); + if (d_x) { + y_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_x) * + y_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, + batch_size, x_dim, y_dim, 1, y_scale.data(), + weight_i.data(), 1, d_x->data()); + } + if (d_y) { + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_y) * + x_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, y_dim, x_dim, 1, x_scale.data(), + weight_i.data(), 1, d_y->data()); + } + } + } + + // Caculate the gradient of Input(Weight). + if (d_weight) { + d_weight->mutable_data(ctx.GetPlace()); + Eigen::DSizes bcast_for_weight(1, x_dim); + for (int i = 0; i < out_dim; ++i) { + Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( + framework::make_ddim({x_dim, y_dim})); + auto output_vec = d_out_mat.chip(i, 1); + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_weight) * + x_mat; + math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, + x_dim, y_dim, batch_size, 1, x_scale.data(), + y->data(), 0, d_weight_i.data()); + } + } + + // Caculate the gradient of Input(Bias). + if (d_bias) { + d_bias->mutable_data(ctx.GetPlace()); + auto d_bias_mat = EigenMatrix::From(*d_bias); + d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes(0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/conditional_block_op.cc b/paddle/operators/conditional_block_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d5b124682d755ffb39f32c9f001a3cf113a01a2c --- /dev/null +++ b/paddle/operators/conditional_block_op.cc @@ -0,0 +1,197 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include +#include "paddle/framework/executor.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class ConditionalOp : public framework::OperatorBase { + public: + ConditionalOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + protected: + std::vector InputTensors( + const framework::Scope &scope) const { + std::vector retv; + auto xs = Inputs("X"); + retv.resize(xs.size(), nullptr); + std::transform( + xs.begin(), xs.end(), retv.begin(), + [&scope](const std::string &var_name) -> const framework::LoDTensor * { + auto *var = scope.FindVar(var_name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", var_name); + return &var->Get(); + }); + return retv; + } +}; + +class ConditionalBlockOp : public ConditionalOp { + public: + ConditionalBlockOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : ConditionalOp(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto xs = InputTensors(scope); + bool need_run = std::all_of( + xs.begin(), xs.end(), + [](const framework::LoDTensor *t) { return t->numel() != 0; }); + + if (need_run) { + auto *scope_var = scope.FindVar(Output("Scope")); + PADDLE_ENFORCE(scope_var != nullptr, "Must set scope"); + auto *scopes = scope_var->GetMutable>(); + scopes->resize(1); + scopes->front() = &scope.NewScope(); + auto &cur_scope = *scopes->front(); + + auto *block = Attr("block"); + framework::Executor exec(dev_ctx); + exec.Run(*block->Program(), &cur_scope, block->ID(), false); + } + } +}; + +class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + ConditionalBlockOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The conditional variable of this operator. If X is empty, the " + "whole sub-block will not be executed.") + .AsDuplicable(); + AddInput("Params", "The input variables of the sub-block.").AsDuplicable(); + AddOutput("Out", "The output variables of the sub-block.").AsDuplicable(); + AddOutput("Scope", + "(std::vector) The step scope of conditional block. To " + "unify the conditional block, rnn and while op, the type of " + "scope is std::vector"); + AddAttr( + "block", "The step block of conditional block operator"); + AddComment(R"DOC(Conditional block operator + +Run the sub-block if X is not empty. Params is the other inputs and Out is the +outputs of the sub-block. +)DOC"); + } +}; + +class ConditionalBlockGradOp : public ConditionalOp { + public: + ConditionalBlockGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : ConditionalOp(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto xs = this->InputTensors(scope); + bool need_run = std::all_of( + xs.begin(), xs.end(), + [](const framework::LoDTensor *t) { return t->numel() != 0; }); + + if (need_run) { + auto *scope_var = scope.FindVar(Input("Scope")); + PADDLE_ENFORCE(scope_var != nullptr, "Must set scope"); + auto &scopes = scope_var->Get>(); + framework::Scope &cur_scope = *scopes[0]; + + auto *block = Attr("block"); + framework::Executor exec(dev_ctx); + exec.Run(*block->Program(), &cur_scope, block->ID(), false); + + AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("Params"), + Outputs(framework::GradVarName("Params"))); + + AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("X"), + Outputs(framework::GradVarName("X"))); + } + } + + private: + void AssignLocalGradientToGlobal( + const platform::DeviceContext &dev_ctx, const framework::Scope &cur_scope, + const std::vector &p_names, + const std::vector &pg_names) const { + for (size_t i = 0; i < p_names.size(); ++i) { + auto out_grad_name = pg_names[i]; + auto in_grad_name = framework::GradVarName(p_names[i]); + auto *in_var = cur_scope.FindVar(in_grad_name); + if (in_var == nullptr) { + continue; + } + auto new_in_grad_name = cur_scope.Rename(in_grad_name); + auto assign = + framework::OpRegistry::CreateOp("assign", {{"X", {new_in_grad_name}}}, + {{"Out", {out_grad_name}}}, {}); + assign->Run(cur_scope, dev_ctx); + cur_scope.Rename(new_in_grad_name, in_grad_name); + } + } +}; + +class ConditionalBlockGradInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE(context->HasInputs("X")); + if (context->HasInputs("Params")) { + PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Params"))); + context->SetOutputsDim(framework::GradVarName("Params"), + context->GetInputsDim("Params")); + } + PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("X"))); + context->SetOutputsDim(framework::GradVarName("X"), + context->GetInputsDim("X")); + } +}; + +class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto grad_op = new framework::OpDescBind(); + grad_op->SetType("conditional_block_grad"); + grad_op->SetInput("X", Input("X")); + grad_op->SetInput("Params", Input("Params")); + grad_op->SetInput("Out", Output("Out")); + grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + grad_op->SetInput("Scope", Output("Scope")); + grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params")); + grad_op->SetBlockAttr("block", *this->grad_block_[0]); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp, + ops::ConditionalBlockOpProtoMaker, + ops::ConditionalBlockGradMaker); +REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp, + ops::ConditionalBlockGradInferShape); diff --git a/paddle/operators/lod_reset_op.cc b/paddle/operators/lod_reset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..32831cb1e2cf188a507773ef1e00b22de98d82ab --- /dev/null +++ b/paddle/operators/lod_reset_op.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lod_reset_op.h" + +namespace paddle { +namespace operators { + +class LoDResetOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + // input check + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of LoDResetOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of LoDResetOp should not be null."); + // If target LoD is not set form Input(), then it must be set from Attr(). + if (!ctx->HasInput("TargetLoD")) { + auto level0 = ctx->Attrs().Get>("target_lod"); + PADDLE_ENFORCE(level0.size() > 1, + "Target LoD is not found, should be set to be a valid one " + "through Input() or Attr()."); + } + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LoDResetOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "(LoDTensor) The input tensor of lod_reset operator."); + AddInput("TargetLoD", + "(Tensor, optional) The target level 0 LoD from Input().") + .AsDispensable(); + AddOutput("Out", "(LoDTensor) The output tensor of lod_reset operator."); + AddAttr>("target_lod", + "The target level 0 LoD from Attr().") + .SetDefault(std::vector{}); + AddComment(R"DOC(LoDReset operator + +Reset LoD of Input(X) into a new one specified by Input(TargetLoD) or +Attr(target_lod), or set LoD for Input(X) if it doesn't have one. +Currently the lod_reset operator only supports the reset of level 0 LoD. +At least one of Input(TargetLoD) and Attr(target_lod) must be set, +and if both of them are set, Input(TargetLoD) will be chosen as the +target LoD. + +An example: +Given a float LoDTensor X with shape (6, 1), its transpose form represents + + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + +with LoD = [[0, 2, 5, 6]] and the three (transposed) sequences look like + + [1.0, 2.0], [3.0, 4.0, 5.0], [6.0]. + +If target LoD = [0, 4, 6], the lod_reset operator will reset the LoD and +the sequences that the LoDTensor Output(Out) contains becomes: + + [1.0, 2.0, 3.0, 4.0], [5.0, 6.0]. + +)DOC"); + } +}; + +class LoDResetGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, lod_reset_grad, + ops::LoDResetGradOp); +REGISTER_OP_CPU_KERNEL(lod_reset, + ops::LoDResetKernel, + ops::LoDResetKernel); +REGISTER_OP_CPU_KERNEL( + lod_reset_grad, ops::LoDResetGradKernel, + ops::LoDResetGradKernel); diff --git a/paddle/operators/lod_reset_op.cu b/paddle/operators/lod_reset_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5244a17c3aad01909e3b8cf5f4d5abf8a44edc7f --- /dev/null +++ b/paddle/operators/lod_reset_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lod_reset_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(lod_reset, + ops::LoDResetKernel, + ops::LoDResetKernel); +REGISTER_OP_GPU_KERNEL( + lod_reset_grad, ops::LoDResetGradKernel, + ops::LoDResetGradKernel); diff --git a/paddle/operators/lod_reset_op.h b/paddle/operators/lod_reset_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2bb916ccee80c83a02ea429fe95f5fafc86ccfa6 --- /dev/null +++ b/paddle/operators/lod_reset_op.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class LoDResetKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out = ctx.Output("Out"); + auto* in = ctx.Input("X"); + auto* lod_t = ctx.Input("TargetLoD"); + + std::vector level0; + if (lod_t) { + auto* lod = lod_t->data(); + if (platform::is_gpu_place(ctx.GetPlace())) { + framework::Tensor lod_cpu; + lod_cpu.CopyFrom(*lod_t, platform::CPUPlace(), ctx.device_context()); + lod = lod_cpu.data(); + } + level0 = std::vector(lod, lod + lod_t->numel()); + } else { + level0 = ctx.Attr>("target_lod"); + } + + PADDLE_ENFORCE(level0.size() > 1UL, + "The size of target LoD should be greater than 1."); + PADDLE_ENFORCE(level0[0] == 0, + "Target LoD should be a vector starting from 0."); + PADDLE_ENFORCE(level0.back() == in->dims()[0], + "Target LoD should be a vector end with the " + "first dimension of Input(X)."); + for (size_t i = 0; i < level0.size() - 1; ++i) { + PADDLE_ENFORCE(level0[i + 1] > level0[i], + "Target LoD should be an ascending vector."); + } + + out->ShareDataWith(*in); + // cast level0 to size_t + std::vector ulevel0(level0.size(), 0); + std::transform(level0.begin(), level0.end(), ulevel0.begin(), + [](int a) { return static_cast(a); }); + framework::LoD target_lod; + target_lod.push_back(ulevel0); + out->set_lod(target_lod); + } +}; + +template +class LoDResetGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_out = ctx.Input(framework::GradVarName("Out")); + auto* d_x = ctx.Output(framework::GradVarName("X")); + + d_x->ShareDataWith(*d_out); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/matmul_op.h b/paddle/operators/matmul_op.h index 5ce30740c90b5cd0bd4f8ab183cf985ed5d827c1..4f565946d596b5e5fbf90f16c0c13c780c36886c 100644 --- a/paddle/operators/matmul_op.h +++ b/paddle/operators/matmul_op.h @@ -74,11 +74,10 @@ Tensor CombineBatchAndN(const framework::ExecutionContext& context, Tensor output; auto in_dims = input.dims(); if (in_dims.size() == 3) { - output.Resize(in_dims); + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); output.mutable_data(context.GetPlace()); EigenTranspose(context, input, output, {1, 0, 2}); - std::vector out_dims = {in_dims[1], in_dims[0] * in_dims[2]}; - output.Resize(make_ddim(out_dims)); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); } else { output.ShareDataWith(input); } diff --git a/paddle/operators/merge_lod_tensor_op.cc b/paddle/operators/merge_lod_tensor_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..80460c476921b63ec5228a9780880c7db3c85217 --- /dev/null +++ b/paddle/operators/merge_lod_tensor_op.cc @@ -0,0 +1,182 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" + +namespace paddle { +namespace operators { + +using LoD = framework::LoD; + +class MergeLoDTensorOp : public framework::OperatorBase { + public: + MergeLoDTensorOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto &x = scope.FindVar(Input("X"))->Get(); + auto &mask = scope.FindVar(Input("Mask"))->Get(); + auto &in_true = scope.FindVar(Input("InTrue"))->Get(); + auto &in_false = + scope.FindVar(Input("InFalse"))->Get(); + auto *out = + scope.FindVar(Output("Out"))->GetMutable(); + auto level = static_cast(Attr("level")); + + auto &mask_dim = mask.dims(); + + std::unique_ptr cpu_mask{new framework::LoDTensor()}; + if (platform::is_cpu_place(mask.place())) { + cpu_mask->ShareDataWith(mask); + } else if (platform::is_gpu_place(mask.place())) { +#ifdef PADDLE_WITH_CUDA + cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx); +#else + PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); +#endif + } + auto *mask_data = cpu_mask->data(); + + int rank = in_true.dims().size(); + platform::Place place = in_true.place(); + std::type_index data_type = in_true.type(); + framework::DDim in_true_dims = + framework::slice_ddim(in_true.dims(), 1, rank); + + int64_t batch_size = in_true.dims()[0] + in_false.dims()[0]; + + auto in_true_dim_vec = framework::vectorize(in_true_dims); + in_true_dim_vec.insert(in_true_dim_vec.begin(), batch_size); + + framework::DDim out_dims = framework::make_ddim(in_true_dim_vec); + out->Resize(out_dims); + out->mutable_data(place, data_type); + + auto *out_lod = out->mutable_lod(); + out_lod->clear(); + size_t out_offset = 0; + + // Build LoDTensor `out` + + size_t in_true_idx = 0; + size_t in_false_idx = 0; + for (size_t i = 0; i < static_cast(mask_dim[0]); i++) { + const framework::LoDTensor *input = nullptr; + size_t *in_idx = nullptr; + if (static_cast(mask_data[i]) == 0) { + input = &in_false; + in_idx = &in_false_idx; + } else { + input = &in_true; + in_idx = &in_true_idx; + } + auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset( + input->lod(), *in_idx, (*in_idx) + 1, 0); + auto &lod_length = lod_and_offset.first; + + framework::AppendLoD(out_lod, lod_length); + + size_t start_offset = lod_and_offset.second.first; + size_t end_offset = lod_and_offset.second.second; + + PADDLE_ENFORCE_GE(end_offset, start_offset); + size_t len = end_offset - start_offset; + if (len == 0) { + continue; + } + out->Slice(out_offset, out_offset + len) + .CopyFrom(input->Slice(start_offset, end_offset), place, dev_ctx); + out_offset += len; + (*in_idx) += 1; + } + + for (size_t i = 0; i < level; i++) { + out_lod->insert(out_lod->begin(), x.lod()[i]); + } + } +}; + +class MergeLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + MergeLoDTensorOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input LoDTensor, contains complete lod information to " + "construct the output"); + AddInput("Mask", "A bool column vector which mask the input"); + AddInput("InTrue", "The True branch to be merged"); + AddInput("InFalse", "The False branch to be merged"); + AddOutput("Out", "The merged output LoDTensor"); + AddAttr("level", "(int) the specific lod level to rank.") + .SetDefault(0) + .EqualGreaterThan(0); + AddComment( + R"DOC( + Merge True and False branches of LoDTensor into a single Output, + with a mask at certain lod level. X is used to obtain complete + lod information. Please refer to SplitLoDTensorOp.)DOC"); + } +}; + +class MergeLoDTensorInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE(context->HasInput("X"), + "MergeLoDTensorOp must has input X."); + PADDLE_ENFORCE(context->HasInput("Mask"), + "MergeLoDTensorOp must has input Mask."); + PADDLE_ENFORCE(context->HasInput("InTrue"), + "MergeLoDTensorOp must has input InTrue."); + PADDLE_ENFORCE(context->HasInput("InFalse"), + "MergeLoDTensorOp must has input InFalse."); + PADDLE_ENFORCE(context->HasOutput("Out"), + "MergeLoDTensorOp must has output Out"); + + auto mask_dim = context->GetInputDim("Mask"); + PADDLE_ENFORCE_EQ(mask_dim.size(), 2); + PADDLE_ENFORCE_EQ(mask_dim[1], 1); + + context->SetOutputDim("Out", context->GetInputDim("InTrue")); + } +}; + +class MergeLoDTensorGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *grad_op = new framework::OpDescBind(); + grad_op->SetType("split_lod_tensor"); + grad_op->SetInput("X", OutputGrad("Out")); + grad_op->SetInput("Mask", Input("Mask")); + grad_op->SetOutput("OutTrue", InputGrad("InTrue")); + grad_op->SetOutput("OutFalse", InputGrad("InFalse")); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(merge_lod_tensor, ops::MergeLoDTensorOp, + ops::MergeLoDTensorOpProtoMaker, + ops::MergeLoDTensorInferShape, ops::MergeLoDTensorGradMaker); diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 2b8a25c2414c20efaffedfc8603697b3a104634f..7f136d8cf0e1eaae7b4de32988b60ae8a5034cc6 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -126,6 +126,7 @@ class SequencePoolGradKernel : public framework::OpKernel { int64_t h = static_cast(lod[i + 1] - lod[i]); auto in_g_e = EigenMatrix::From(in_g_t, {h, w}); auto out_g_e = EigenMatrix::From(out_g_t, {1, w}); + auto out_g_e_v = EigenVector::Flatten(out_g_t); Eigen::DSizes bcast(h, 1); if (pooltype == "AVERAGE") { @@ -136,9 +137,9 @@ class SequencePoolGradKernel : public framework::OpKernel { in_g_e.device(place) = (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); } else if (pooltype == "LAST") { - in_g_e.chip(h - 1, 0).device(place) = out_g_e; + in_g_e.chip(h - 1, 0).device(place) = out_g_e_v; } else if (pooltype == "FIRST") { - in_g_e.chip(0, 0).device(place) = out_g_e; + in_g_e.chip(0, 0).device(place) = out_g_e_v; } else { PADDLE_THROW("unsupported pooling pooltype"); } diff --git a/paddle/operators/split_lod_tensor_op.cc b/paddle/operators/split_lod_tensor_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..db635f2ba0804143c9a2e04ff006dfbc8744f3fc --- /dev/null +++ b/paddle/operators/split_lod_tensor_op.cc @@ -0,0 +1,186 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/memory/memcpy.h" + +namespace paddle { +namespace operators { + +struct CopyRange { + size_t begin; + size_t end; +}; + +using LoD = framework::LoD; + +class SplitLoDTensorOp : public framework::OperatorBase { + public: + SplitLoDTensorOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + auto &x = scope.FindVar(Input("X"))->Get(); + auto &mask = scope.FindVar(Input("Mask"))->Get(); + auto *out_true = + scope.FindVar(Output("OutTrue"))->GetMutable(); + auto *out_false = + scope.FindVar(Output("OutFalse"))->GetMutable(); + auto level = static_cast(Attr("level")); + auto &x_lod = x.lod(); + auto &mask_dim = mask.dims(); + + std::unique_ptr cpu_mask{new framework::LoDTensor()}; + if (platform::is_cpu_place(mask.place())) { + cpu_mask->ShareDataWith(mask); + } else if (platform::is_gpu_place(mask.place())) { +#ifdef PADDLE_WITH_CUDA + cpu_mask->CopyFrom(mask, platform::CPUPlace(), dev_ctx); +#else + PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option"); +#endif + } + auto *mask_data = cpu_mask->data(); + + std::vector> copy_ranges(mask_dim[0]); + + // set out_true/out_false lod + for (size_t t = 0; t < 2; t++) { + LoD *lod = nullptr; + if (t == 0) { + lod = out_false->mutable_lod(); + } else { + lod = out_true->mutable_lod(); + } + lod->clear(); + for (size_t i = 0; i < static_cast(mask_dim[0]); i++) { + if (static_cast(mask_data[i]) == t) { + size_t start_idx = i; + auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset( + x_lod, start_idx, start_idx + 1, level); + + auto &lod_length = lod_and_offset.first; + framework::AppendLoD(lod, lod_length); + + size_t start_offset = lod_and_offset.second.first; + size_t end_offset = lod_and_offset.second.second; + copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset}); + } + } + } + + for (size_t t = 0; t < 2; ++t) { + framework::LoDTensor *out; + if (t == 0) { + out = out_false; + } else { + out = out_true; + } + auto &ranges = copy_ranges[t]; + size_t height = std::accumulate( + ranges.begin(), ranges.end(), 0UL, + [](size_t a, const CopyRange &b) { return a + b.end - b.begin; }); + auto x_dim = x.dims(); + x_dim[0] = static_cast(height); + out->Resize(x_dim); + out->mutable_data(x.place(), x.type()); + size_t offset = 0; + for (auto &each_range : ranges) { + size_t len = each_range.end - each_range.begin; + if (len == 0) { + continue; + } + // out[offset: offset+len] = x[each_range.begin: each_range.end] + out->Slice(static_cast(offset), static_cast(offset + len)) + .CopyFrom(x.Slice(static_cast(each_range.begin), + static_cast(each_range.end)), + x.place(), dev_ctx); + offset += len; + } + } + } +}; + +class SplitLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + SplitLoDTensorOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input LoDTensor"); + AddInput("Mask", "A bool column vector which mask the input"); + AddOutput("OutTrue", "True branch of input LoDTensor"); + AddOutput("OutFalse", "False branch of input LoDTensor"); + AddAttr("level", "(int) the specific lod level to split.") + .SetDefault(0) + .EqualGreaterThan(0); + AddComment( + R"DOC( + Split a LoDTensor with a Mask at certain level. The input LoDTensor + has 3 sequence at certain lod level. The Mask is a bool column vector, + such as [0, 1, 0] at the same level. The first and third sequence will + be send to False Output LoDTensor; whereas the second sequence will + be send to True Output LoDTensor. Please refer to MergeLoDTensorOp.)DOC"); + } +}; + +class SplitLoDTensorInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + PADDLE_ENFORCE(context->HasInput("X"), + "SplitLoDTensorOp must has input X."); + PADDLE_ENFORCE(context->HasInput("Mask"), + "SplitLoDTensorOp must has input Mask."); + PADDLE_ENFORCE(context->HasOutput("OutTrue"), + "SplitLoDTensorOp must has output OutTrue."); + PADDLE_ENFORCE(context->HasOutput("OutFalse"), + "SplitLoDTensorOp must has output OutFalse."); + + auto mask_dim = context->GetInputDim("Mask"); + PADDLE_ENFORCE_EQ(mask_dim.size(), 2); + PADDLE_ENFORCE_EQ(mask_dim[1], 1); + + context->SetOutputDim("OutTrue", context->GetInputDim("X")); + context->SetOutputDim("OutFalse", context->GetInputDim("X")); + } +}; + +class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *grad_op = new framework::OpDescBind(); + grad_op->SetType("merge_lod_tensor"); + grad_op->SetInput("InTrue", OutputGrad("OutTrue")); + grad_op->SetInput("InFalse", OutputGrad("OutFalse")); + grad_op->SetInput("Mask", Input("Mask")); + grad_op->SetInput("X", Input("X")); + grad_op->SetOutput("Out", InputGrad("X")); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(split_lod_tensor, ops::SplitLoDTensorOp, + ops::SplitLoDTensorOpProtoMaker, + ops::SplitLoDTensorInferShape, + ops::SplitLoDTensorArrayGradMaker); diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index b9db2707c0705659260c04ab3412f429058a1316..0e6f083e5ba25f4c64fe6988965de93be1d1688b 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -285,7 +285,7 @@ class Operator(object): self.desc.check_attrs() no_kernel_op_set = { 'feed', 'fetch', 'save', 'load', 'recurrent', - 'rnn_memory_helper_grad', 'while' + 'rnn_memory_helper_grad', 'conditional_block', 'while' } if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index fe3c86febb00a748c908b8bac853e86af9b6dac3..ae85f460f78c9a661d490e0fe673882bcbacd19f 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -11,7 +11,7 @@ import cStringIO __all__ = [ 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', 'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums', 'cos_sim', - 'batch_norm', 'accuracy' + 'batch_norm', 'accuracy', 'split_lod_tensor' ] @@ -226,6 +226,11 @@ def data(name, stop_gradient=stop_gradient) +def create_tensor(dtype, name=None, main_program=None): + helper = LayerHelper("create_tensor", **locals()) + return helper.create_variable(name=helper.name, dtype=dtype) + + def _convert_(name): """ Formatting. @@ -451,6 +456,56 @@ def sums(input, main_program=None, startup_program=None): return out +def assign(input, output, main_program=None): + helper = LayerHelper('assign', **locals()) + helper.append_op( + type='scale', + inputs={'X': [input]}, + outputs={'Out': [output]}, + attrs={'scale': 1.0}) + return output + + +def split_lod_tensor(input, + mask, + level, + main_program=None, + startup_program=None): + helper = LayerHelper('split_lod_tensor', **locals()) + out_true = helper.create_tmp_variable(dtype=input.data_type) + out_false = helper.create_tmp_variable(dtype=input.data_type) + helper.append_op( + type='split_lod_tensor', + inputs={ + 'X': input, + 'Mask': mask, + }, + outputs={'OutTrue': out_true, + 'OutFalse': out_false}, + attrs={'level': level}) + return out_true, out_false + + +def merge_lod_tensor(in_true, + in_false, + x, + mask, + level, + main_program=None, + startup_program=None): + helper = LayerHelper('merge_lod_tensor', **locals()) + out = helper.create_tmp_variable(dtype=x.data_type) + helper.append_op( + type='merge_lod_tensor', + inputs={'X': x, + 'Mask': mask, + 'InTrue': in_true, + 'InFalse': in_false}, + outputs={'Out': out}, + attrs={'level': level}) + return out + + def cos_sim(X, Y, **kwargs): """ This function performs the cosine similarity between two tensors @@ -1375,3 +1430,73 @@ def array_length(array, main_program=None): helper.append_op( type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]}) return tmp + + +class ConditionalBlockGuard(BlockGuard): + def __init__(self, block): + if not isinstance(block, ConditionalBlock): + raise TypeError("block should be conditional block") + super(ConditionalBlockGuard, self).__init__(block.helper.main_program) + self.block = block + + def __enter__(self): + return super(ConditionalBlockGuard, self).__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.block.complete() + return super(ConditionalBlockGuard, self).__exit__(exc_type, exc_val, + exc_tb) + + +class ConditionalBlock(object): + def __init__(self, inputs, name=None, main_program=None): + for each_input in inputs: + if not isinstance(each_input, Variable): + raise TypeError("Each input should be variable") + self.inputs = inputs + self.helper = LayerHelper( + 'conditional_block', name=name, main_program=main_program) + + def block(self): + return ConditionalBlockGuard(self) + + def complete(self): + inside_block = self.helper.main_program.current_block() + parent_block = self.helper.main_program.block(inside_block.parent_idx) + + intermediate = set() + params = set() + + for each_op in inside_block.ops: + assert isinstance(each_op, Operator) + for iname in each_op.input_names: + for in_var_name in each_op.input(iname): + if in_var_name not in intermediate: + params.add(in_var_name) + + for oname in each_op.output_names: + for out_var_name in each_op.output(oname): + intermediate.add(out_var_name) + input_set = set([ipt.name for ipt in self.inputs]) + + param_list = [ + parent_block.var(each_name) for each_name in params + if each_name not in input_set + ] + + out_list = [ + parent_block.var(var_name) for var_name in parent_block.vars + if var_name not in intermediate + ] + + step_scope = parent_block.create_var( + type=core.VarDesc.VarType.STEP_SCOPES) + parent_block.append_op( + type='conditional_block', + inputs={ + 'X': self.inputs, + 'Params': param_list, + }, + outputs={'Out': out_list, + 'Scope': [step_scope]}, + attrs={'block': inside_block}) diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 4d7664469e481344cf9eea84688f068b4fb99dee..e795627bfe9e8ad0c196349a332e62e975f20aa3 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -3,3 +3,5 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() + +add_subdirectory(book) diff --git a/python/paddle/v2/framework/tests/book/CMakeLists.txt b/python/paddle/v2/framework/tests/book/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4d7664469e481344cf9eea84688f068b4fb99dee --- /dev/null +++ b/python/paddle/v2/framework/tests/book/CMakeLists.txt @@ -0,0 +1,5 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) +endforeach() diff --git a/python/paddle/v2/framework/tests/test_fit_a_line.py b/python/paddle/v2/framework/tests/book/test_fit_a_line.py similarity index 100% rename from python/paddle/v2/framework/tests/test_fit_a_line.py rename to python/paddle/v2/framework/tests/book/test_fit_a_line.py diff --git a/python/paddle/v2/framework/tests/test_image_classification_train.py b/python/paddle/v2/framework/tests/book/test_image_classification_train.py similarity index 100% rename from python/paddle/v2/framework/tests/test_image_classification_train.py rename to python/paddle/v2/framework/tests/book/test_image_classification_train.py diff --git a/python/paddle/v2/framework/tests/test_recognize_digits_conv.py b/python/paddle/v2/framework/tests/book/test_recognize_digits_conv.py similarity index 100% rename from python/paddle/v2/framework/tests/test_recognize_digits_conv.py rename to python/paddle/v2/framework/tests/book/test_recognize_digits_conv.py diff --git a/python/paddle/v2/framework/tests/test_recognize_digits_mlp.py b/python/paddle/v2/framework/tests/book/test_recognize_digits_mlp.py similarity index 100% rename from python/paddle/v2/framework/tests/test_recognize_digits_mlp.py rename to python/paddle/v2/framework/tests/book/test_recognize_digits_mlp.py diff --git a/python/paddle/v2/framework/tests/test_recommender_system.py b/python/paddle/v2/framework/tests/book/test_recommender_system.py similarity index 100% rename from python/paddle/v2/framework/tests/test_recommender_system.py rename to python/paddle/v2/framework/tests/book/test_recommender_system.py diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_conv.py b/python/paddle/v2/framework/tests/book/test_understand_sentiment_conv.py similarity index 100% rename from python/paddle/v2/framework/tests/test_understand_sentiment_conv.py rename to python/paddle/v2/framework/tests/book/test_understand_sentiment_conv.py diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_dynamic_lstm.py b/python/paddle/v2/framework/tests/book/test_understand_sentiment_dynamic_lstm.py similarity index 100% rename from python/paddle/v2/framework/tests/test_understand_sentiment_dynamic_lstm.py rename to python/paddle/v2/framework/tests/book/test_understand_sentiment_dynamic_lstm.py diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py b/python/paddle/v2/framework/tests/book/test_understand_sentiment_lstm.py similarity index 100% rename from python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py rename to python/paddle/v2/framework/tests/book/test_understand_sentiment_lstm.py diff --git a/python/paddle/v2/framework/tests/test_word2vec.py b/python/paddle/v2/framework/tests/book/test_word2vec.py similarity index 100% rename from python/paddle/v2/framework/tests/test_word2vec.py rename to python/paddle/v2/framework/tests/book/test_word2vec.py diff --git a/python/paddle/v2/framework/tests/test_assign_op.py b/python/paddle/v2/framework/tests/test_assign_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0c145f1a69678b228bc70e4e4e273f5bcf9888 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_assign_op.py @@ -0,0 +1,21 @@ +import op_test +import numpy +import unittest + + +class TestAssignOp(op_test.OpTest): + def setUp(self): + self.op_type = "assign" + x = numpy.random.random(size=(100, 10)) + self.inputs = {'X': x} + self.outputs = {'Out': x} + + def test_forward(self): + self.check_output() + + def test_backward(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py new file mode 100644 index 0000000000000000000000000000000000000000..080ca43b8269e0f6a9f4d0ce3973f4d4a07a8e2a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py @@ -0,0 +1,37 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestBilinearTensorProductOp(OpTest): + def setUp(self): + self.op_type = "bilinear_tensor_product" + batch_size = 6 + size0 = 3 + size1 = 4 + size2 = 5 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + bias = np.random.random((1, size2)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) + self.inputs = { + 'X': a, + 'Y': b, + 'Weight': w, + 'Bias': bias, + } + self.outputs = {'Out': output + bias} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_conditional_block.py b/python/paddle/v2/framework/tests/test_conditional_block.py new file mode 100644 index 0000000000000000000000000000000000000000..9b96ff306c37a6f96bd78951752d32274cca6282 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_conditional_block.py @@ -0,0 +1,40 @@ +import unittest +import paddle.v2.framework.layers as layers +import paddle.v2.framework.core as core +from paddle.v2.framework.framework import g_startup_program, g_main_program +from paddle.v2.framework.executor import Executor +from paddle.v2.framework.backward import append_backward_ops +import numpy + + +class ConditionalBlock(unittest.TestCase): + def test_forward(self): + data = layers.data(name='X', shape=[1], data_type='float32') + data.stop_gradient = False + cond = layers.ConditionalBlock(inputs=[data]) + out = layers.create_tensor(dtype='float32') + with cond.block(): + hidden = layers.fc(input=data, size=10) + layers.assign(hidden, out) + + cpu = core.CPUPlace() + exe = Executor(cpu) + exe.run(g_startup_program) + + x = core.LoDTensor() + x.set(numpy.random.random(size=(10, 1)).astype('float32'), cpu) + + outs = map(numpy.array, exe.run(feed={'X': x}, fetch_list=[out]))[0] + print outs + loss = layers.mean(x=out) + append_backward_ops(loss=loss) + outs = map(numpy.array, + exe.run(feed={'X': x}, + fetch_list=[ + g_main_program.block(0).var(data.name + "@GRAD") + ]))[0] + print outs + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_lod_reset_op.py b/python/paddle/v2/framework/tests/test_lod_reset_op.py new file mode 100644 index 0000000000000000000000000000000000000000..652ccecfa443fc95f08f52df766709cb550f4049 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_lod_reset_op.py @@ -0,0 +1,64 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestLodResetOpByAttr(OpTest): + def setUp(self): + self.op_type = "lod_reset" + x = np.random.random((10, 20)).astype("float32") + lod = [[0, 3, 5, 10]] + target_lod_0 = [0, 7, 10] + self.inputs = {'X': (x, lod)} + self.attrs = {'target_lod': target_lod_0} + self.outputs = {'Out': (x, [target_lod_0])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestLodResetOpByInput(OpTest): + def setUp(self): + self.op_type = "lod_reset" + x = np.random.random((10, 20)).astype("float32") + lod = [[0, 3, 5, 10]] + target_lod_0 = [0, 4, 7, 10] + self.inputs = { + 'X': (x, lod), + 'TargetLoD': np.array([target_lod_0]).astype('int32') + } + self.outputs = {'Out': (x, [target_lod_0])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD")) + + +class TestLodResetOpBoth(OpTest): + def setUp(self): + self.op_type = "lod_reset" + x = np.random.random((10, 20)).astype("float32") + lod = [[0, 3, 5, 10]] + target_lod_0_attr = [0, 7, 10] + target_lod_0_in = [0, 4, 7, 10] + self.inputs = { + 'X': (x, lod), + 'TargetLoD': np.array(target_lod_0_in).astype('int32') + } + self.attrs = {'target_lod': target_lod_0_attr} + self.outputs = {'Out': (x, [target_lod_0_in])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out", no_grad_set=set("TargetLoD")) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_split_and_merge_lod_tensor_op.py b/python/paddle/v2/framework/tests/test_split_and_merge_lod_tensor_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba1e568249d4a72820cc26193a8e0e030ae5f7c --- /dev/null +++ b/python/paddle/v2/framework/tests/test_split_and_merge_lod_tensor_op.py @@ -0,0 +1,181 @@ +import unittest +import paddle.v2.framework.core as core +import numpy as np +import paddle.v2.framework.layers as layers +from paddle.v2.framework.framework import Program +from paddle.v2.framework.executor import Executor +from paddle.v2.framework.backward import append_backward_ops + + +class TestCPULoDTensorArrayOps(unittest.TestCase): + def place(self): + return core.CPUPlace() + + def test_split_and_merge_lod_tensor_no_lod(self): + tensor = core.LoDTensor() + tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place()) + + mask_np = np.array([0, 0, 1, 1, 1, 1, 0, 0, 0, 0]).astype('bool') + mask_np = np.expand_dims(mask_np, axis=1) + + mask = core.LoDTensor() + mask.set(mask_np, self.place()) + + expect_true_tensor = np.array([2, 3, 4, 5]).astype('int32') + expect_true_tensor = np.expand_dims(expect_true_tensor, axis=1) + expect_true = core.LoDTensor() + expect_true.set(expect_true_tensor, self.place()) + + expect_false_tensor = np.array([0, 1, 6, 7, 8, 9]).astype('int32') + expect_false_tensor = np.expand_dims(expect_false_tensor, axis=1) + + expect_false = core.LoDTensor() + expect_false.set(expect_false_tensor, self.place()) + + self.main( + tensor=tensor, + mask=mask, + expect_true=expect_true, + expect_false=expect_false, + expect_out=tensor) + + def test_split_and_merge_lod_tensor_level_0(self): + tensor = core.LoDTensor() + tensor.set(np.arange(10).reshape(10, 1).astype('int32'), self.place()) + tensor.set_lod([[0, 3, 9, 10]]) + + mask_np = np.array([0, 1, 0]).astype('bool') + mask_np = np.expand_dims(mask_np, axis=1) + + mask = core.LoDTensor() + mask.set(mask_np, self.place()) + + expect_true_tensor = np.array([3, 4, 5, 6, 7, 8]).astype('int32') + expect_true_tensor = np.expand_dims(expect_true_tensor, axis=1) + expect_true = core.LoDTensor() + expect_true.set(expect_true_tensor, self.place()) + expect_true.set_lod([[0, 6]]) + + expect_false_tensor = np.array([0, 1, 2, 9]).astype('int32') + expect_false_tensor = np.expand_dims(expect_false_tensor, axis=1) + expect_false_lod = [[0, 3, 4]] + + expect_false = core.LoDTensor() + expect_false.set(expect_false_tensor, self.place()) + expect_false.set_lod(expect_false_lod) + + self.main( + tensor=tensor, + mask=mask, + expect_true=expect_true, + expect_false=expect_false, + expect_out=tensor) + + def main(self, tensor, mask, expect_true, expect_false, expect_out, + level=0): + place = self.place() + program = Program() + x = layers.data(name='x', shape=[1], main_program=program) + x.persistable = True + + y = layers.data(name='y', shape=[1], main_program=program) + y.persistable = True + + out_true, out_false = layers.split_lod_tensor( + input=x, mask=y, level=level, main_program=program) + out_true.persistable = True + out_false.persistable = True + + out = layers.merge_lod_tensor( + in_true=out_true, + in_false=out_false, + mask=y, + x=x, + level=level, + main_program=program) + + out.persistable = True + + exe = Executor(place) + scope = core.Scope() + exe.run(program, feed={'x': tensor, 'y': mask}, scope=scope) + + var_true = scope.find_var(out_true.name).get_tensor() + + var_false = scope.find_var(out_false.name).get_tensor() + + var_out = scope.find_var(out.name).get_tensor() + + self.check_tensor_same(var_true, expect_true) + self.check_tensor_same(var_false, expect_false) + self.check_tensor_same(var_out, expect_out) + + def check_tensor_same(self, actual, expect): + self.assertTrue(np.allclose(np.array(actual), np.array(expect))) + self.assertEqual(actual.lod(), expect.lod()) + + +class TestCPUSplitMergeLoDTensorGrad(unittest.TestCase): + def test_grad(self): + place = core.CPUPlace() + program = Program() + + x = layers.data( + name='x', + shape=[1], + data_type='float32', + main_program=program, + stop_gradient=False) + y = layers.data( + name='y', + shape=[1], + data_type='bool', + main_program=program, + stop_gradient=False) + + level = 0 + + out_true, out_false = layers.split_lod_tensor( + input=x, mask=y, level=level, main_program=program) + out = layers.merge_lod_tensor( + in_true=out_true, + in_false=out_false, + mask=y, + x=x, + level=level, + main_program=program) + mean = layers.mean(x=out, main_program=program) + + append_backward_ops(mean) + + tensor = core.LoDTensor() + tensor.set(np.arange(10).reshape(10, 1).astype('float32'), place) + tensor.set_lod([[0, 3, 9, 10]]) + + mask_np = np.array([0, 1, 0]).astype('bool') + mask_np = np.expand_dims(mask_np, axis=1) + + mask = core.LoDTensor() + mask.set(mask_np, place) + + exe = Executor(place) + scope = core.Scope() + + g_vars = program.global_block().var(x.name + "@GRAD") + g_out = [ + item.sum() + for item in map(np.array, + exe.run(program, + feed={'x': tensor, + 'y': mask}, + fetch_list=[g_vars], + scope=scope)) + ] + + g_out_sum = np.array(g_out).sum() + + self.assertAlmostEqual(1.0, g_out_sum, delta=0.1) + + +if __name__ == '__main__': + unittest.main()