未验证 提交 c5b6573a 编写于 作者: C chengduo 提交者: GitHub

Fix input<tensor> (#14208)

* fix input<tensor>
test=develop

* fix split_ids
test=develop

* ElementwiseMul should not support SelectedRows

* fix scale op
test=develop

* change GetTensorFromVar() method to GetTensorOrSelectedRowsFromVar()

* fix operator

* refine MultiOutput

* fix MultiOutput
test=develop

* disable test_dist_save_load
test=develop

* fix elementwise_op
test=develop

* add get_sparse_as_op
test=develop

* add info for check
test=develop

* rename get_sparse_as_op with extract_rows_as_op.
test=develop

* elementwise doesn't support selected_rows

* fix regularizer

* remove extract_rows_as
test=develop

* fix ci
test=develop

* add test for sum_op

* fix regularizer
test=develop

*  test=develop

* fix pserver weight decay multi inputs test=develop
上级 813e54ef
...@@ -648,6 +648,12 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID( ...@@ -648,6 +648,12 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
const ir::Graph &graph, const std::string &varname, const ir::Graph &graph, const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const { const std::unordered_map<std::string, int> &sharded_var_device) const {
auto got = sharded_var_device.find(varname); auto got = sharded_var_device.find(varname);
if (got == sharded_var_device.end()) {
auto pos = varname.find(framework::kNewGradSuffix);
if (pos != std::string::npos) {
got = sharded_var_device.find(varname.substr(0, pos));
}
}
return got == sharded_var_device.end() ? -1 : got->second; return got == sharded_var_device.end() ? -1 : got->second;
} }
......
...@@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable& var) { ...@@ -358,7 +358,7 @@ static bool VarIsTensor(const Variable& var) {
return var.IsType<LoDTensor>() || var.IsType<SelectedRows>(); return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
} }
const Tensor* GetTensorFromVar(const Variable& var) { const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
if (var.IsType<LoDTensor>()) { if (var.IsType<LoDTensor>()) {
return static_cast<const Tensor*>(&(var.Get<LoDTensor>())); return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
} else if (var.IsType<SelectedRows>()) { } else if (var.IsType<SelectedRows>()) {
...@@ -369,7 +369,7 @@ const Tensor* GetTensorFromVar(const Variable& var) { ...@@ -369,7 +369,7 @@ const Tensor* GetTensorFromVar(const Variable& var) {
} }
} }
static Tensor* GetMutableTensorFromVar(Variable* var) { Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>(); return var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
...@@ -414,8 +414,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const { ...@@ -414,8 +414,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
template <> template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const { const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name); return Input<LoDTensor>(name);
return var == nullptr ? nullptr : GetTensorFromVar(*var);
} }
template <> template <>
...@@ -425,17 +424,21 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( ...@@ -425,17 +424,21 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
std::vector<const Tensor*> res; std::vector<const Tensor*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) -> const Tensor* {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : GetTensorFromVar(*var); if (var == nullptr) return nullptr;
PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"%s should be LoDTensor, but the received type is %s",
sub_name, var->Type().name());
return &(var->Get<LoDTensor>());
}); });
return res; return res;
} }
template <> template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const { Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto var = OutputVar(name); return Output<LoDTensor>(name);
return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
} }
template <> template <>
...@@ -445,10 +448,14 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>( ...@@ -445,10 +448,14 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
std::vector<Tensor*> res; std::vector<Tensor*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) -> Tensor* {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr if (var == nullptr) return nullptr;
: GetMutableTensorFromVar(var); PADDLE_ENFORCE(
var->IsType<LoDTensor>(),
"%s should be LoDTensor, but the received type is %s",
sub_name, var->Type().name());
return var->GetMutable<LoDTensor>();
}); });
return res; return res;
} }
...@@ -768,11 +775,12 @@ void OperatorWithKernel::TransferInplaceVarsBack( ...@@ -768,11 +775,12 @@ void OperatorWithKernel::TransferInplaceVarsBack(
const Scope& transfer_scope) const { const Scope& transfer_scope) const {
for (auto& var_name : inplace_vars) { for (auto& var_name : inplace_vars) {
VLOG(3) << "share inplace var " + var_name + " back to it's original scope"; VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name)); auto* original_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(scope.FindVar(var_name));
auto* var = transfer_scope.FindVar(var_name); auto* var = transfer_scope.FindVar(var_name);
PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr", PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
var_name); var_name);
auto* transformed_tensor = GetTensorFromVar(*var); auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
original_tensor->ShareDataWith(*transformed_tensor); original_tensor->ShareDataWith(*transformed_tensor);
} }
} }
...@@ -789,7 +797,7 @@ Scope* OperatorWithKernel::TryTransferData( ...@@ -789,7 +797,7 @@ Scope* OperatorWithKernel::TryTransferData(
continue; continue;
} }
auto* tensor_in = GetTensorFromVar(*var); auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
if (!tensor_in->IsInitialized()) { if (!tensor_in->IsInitialized()) {
continue; continue;
} }
......
...@@ -54,6 +54,9 @@ constexpr char kGradVarSuffix[] = "@GRAD"; ...@@ -54,6 +54,9 @@ constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros. /// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO"; constexpr char kZeroVarSuffix[] = "@ZERO";
/// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@";
// define some kernel priority // define some kernel priority
/* Define multiple kernel type fallback order*/ /* Define multiple kernel type fallback order*/
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority; extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
...@@ -63,7 +66,8 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -63,7 +66,8 @@ inline std::string GradVarName(const std::string& var_name) {
} }
proto::VarType::Type GetDataTypeOfVar(const Variable* var); proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetTensorFromVar(const Variable& var); const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
class OperatorBase; class OperatorBase;
class ExecutionContext; class ExecutionContext;
...@@ -224,7 +228,7 @@ class ExecutionContext { ...@@ -224,7 +228,7 @@ class ExecutionContext {
std::vector<const T*> res; std::vector<const T*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) -> const T* {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : &var->Get<T>(); return var == nullptr ? nullptr : &var->Get<T>();
}); });
...@@ -237,7 +241,7 @@ class ExecutionContext { ...@@ -237,7 +241,7 @@ class ExecutionContext {
std::vector<T*> res; std::vector<T*> res;
res.reserve(names.size()); res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) -> T* {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : var->GetMutable<T>(); return var == nullptr ? nullptr : var->GetMutable<T>();
}); });
......
...@@ -296,7 +296,6 @@ op_library(cos_sim_op DEPS cos_sim_functor) ...@@ -296,7 +296,6 @@ op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op) op_library(unsqueeze_op DEPS reshape_op)
op_library(squeeze_op DEPS reshape_op) op_library(squeeze_op DEPS reshape_op)
op_library(extract_rows_op DEPS memory)
op_library(flatten_op DEPS reshape_op) op_library(flatten_op DEPS reshape_op)
op_library(sequence_pad_op DEPS sequence_padding) op_library(sequence_pad_op DEPS sequence_padding)
op_library(unstack_op DEPS stack_op) op_library(unstack_op DEPS stack_op)
......
...@@ -28,9 +28,9 @@ struct AddFunctor { ...@@ -28,9 +28,9 @@ struct AddFunctor {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void default_elementwise_add(const framework::ExecutionContext& ctx, void default_elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor *x,
const framework::Tensor* y, framework::Tensor* z) { const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z); AddFunctor<T>(), z);
...@@ -40,9 +40,9 @@ template <typename DeviceContext, typename T> ...@@ -40,9 +40,9 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
std::is_floating_point<T>::value && std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx, elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor* z) { framework::Tensor *z) {
auto eigen_x = framework::EigenVector<T>::Flatten(*x); auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_y = framework::EigenVector<T>::Flatten(*y); auto eigen_y = framework::EigenVector<T>::Flatten(*y);
auto eigen_z = framework::EigenVector<T>::Flatten(*z); auto eigen_z = framework::EigenVector<T>::Flatten(*z);
...@@ -55,21 +55,20 @@ template <typename DeviceContext, typename T> ...@@ -55,21 +55,20 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
!std::is_floating_point<T>::value || !std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx, elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor* z) { framework::Tensor *z) {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z); default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseAddKernel : public framework::OpKernel<T> { class ElementwiseAddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
using Tensor = framework::Tensor; auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
const auto x = ctx.Input<Tensor>("X");
const auto y = ctx.Input<Tensor>("Y");
auto z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
auto dims_equal = x->dims() == y->dims(); auto dims_equal = x->dims() == y->dims();
...@@ -87,13 +86,13 @@ struct IdentityGrad { ...@@ -87,13 +86,13 @@ struct IdentityGrad {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
void default_elementwise_add_grad(const framework::ExecutionContext& ctx, void default_elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor *x,
const framework::Tensor* y, const framework::Tensor *y,
const framework::Tensor* out, const framework::Tensor *out,
const framework::Tensor* dout, const framework::Tensor *dout,
framework::Tensor* dx, framework::Tensor *dx,
framework::Tensor* dy) { framework::Tensor *dy) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>, ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
...@@ -106,11 +105,11 @@ template <typename DeviceContext, typename T> ...@@ -106,11 +105,11 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
std::is_floating_point<T>::value && std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx, elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor* out, const framework::Tensor *out,
const framework::Tensor* dout, framework::Tensor* dx, const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor* dy) { framework::Tensor *dy) {
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
if (dx) { if (dx) {
...@@ -128,27 +127,27 @@ template <typename DeviceContext, typename T> ...@@ -128,27 +127,27 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
!std::is_floating_point<T>::value || !std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_grad(const framework::ExecutionContext& ctx, elementwise_add_grad(const framework::ExecutionContext &ctx,
const framework::Tensor* x, const framework::Tensor* y, const framework::Tensor *x, const framework::Tensor *y,
const framework::Tensor* out, const framework::Tensor *out,
const framework::Tensor* dout, framework::Tensor* dx, const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor* dy) { framework::Tensor *dy) {
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> { class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
// skip out, x, y // skip out, x, y
auto* out = dout; auto *out = dout;
auto *x = dout, *y = dout; auto *x = dout, *y = dout;
if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr && if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
......
...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T> ...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T>
class ElementwiseDivKernel : public framework::OpKernel<T> { class ElementwiseDivKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
......
...@@ -29,11 +29,10 @@ template <typename DeviceContext, typename T> ...@@ -29,11 +29,10 @@ template <typename DeviceContext, typename T>
class ElementwiseMaxKernel : public framework::OpKernel<T> { class ElementwiseMaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
......
...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T> ...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T>
class ElementwiseMinKernel : public framework::OpKernel<T> { class ElementwiseMinKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
......
...@@ -60,11 +60,10 @@ template <typename DeviceContext, typename T> ...@@ -60,11 +60,10 @@ template <typename DeviceContext, typename T>
class ElementwiseMulKernel : public framework::OpKernel<T> { class ElementwiseMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
if (x->numel() == y->numel()) { if (x->numel() == y->numel()) {
elementwise_mul<DeviceContext, T>(ctx, x, y, z); elementwise_mul<DeviceContext, T>(ctx, x, y, z);
......
...@@ -13,10 +13,12 @@ See the License for the specific language governing permissions and ...@@ -13,10 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
...@@ -29,7 +31,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -29,7 +31,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null."); "Input(X) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), PADDLE_ENFORCE(ctx->HasInput("Y"),
...@@ -37,6 +40,17 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -37,6 +40,17 @@ class ElementwiseOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of elementwise op should not be null."); "Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Y").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Y").front(), ctx->GetInputsVarType("Y").front());
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y"); auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
...@@ -47,9 +61,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -47,9 +61,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
framework::ToDataType(ctx.Input<Tensor>("X")->type());
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (platform::CanMKLDNNBeUsed(ctx)) {
...@@ -64,12 +77,12 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -64,12 +77,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
class ElementwiseOpInferVarType : public framework::VarTypeInference { class ElementwiseOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDesc& op_desc, void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc* block) const override { framework::BlockDesc *block) const override {
auto x_name = op_desc.Input("X")[0]; auto x_name = op_desc.Input("X")[0];
auto out_name = op_desc.Output("Out")[0]; auto out_name = op_desc.Output("Out")[0];
auto& x = block->FindRecursiveOrCreateVar(x_name); auto &x = block->FindRecursiveOrCreateVar(x_name);
auto& out = block->FindRecursiveOrCreateVar(out_name); auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(x.GetType()); out.SetType(x.GetType());
out.SetDataType(x.GetDataType()); out.SetDataType(x.GetDataType());
} }
...@@ -131,6 +144,7 @@ But the output only shares the LoD information with the input $X$. ...@@ -131,6 +144,7 @@ But the output only shares the LoD information with the input $X$.
protected: protected:
virtual std::string GetName() const = 0; virtual std::string GetName() const = 0;
virtual std::string GetEquation() const = 0; virtual std::string GetEquation() const = 0;
}; };
...@@ -139,7 +153,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -139,7 +153,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
...@@ -165,7 +179,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -165,7 +179,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::ToDataType( auto input_data_type = framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type()); ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
...@@ -187,7 +201,7 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { ...@@ -187,7 +201,7 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
using operators::ElementwiseOpGrad::GetExpectedKernelType; using operators::ElementwiseOpGrad::GetExpectedKernelType;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
...@@ -209,11 +223,11 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { ...@@ -209,11 +223,11 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
template <typename T> template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> { class ElemwiseGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* dx = auto *dx =
context.Output<framework::LoDTensor>(framework::GradVarName("X")); context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (dx != nullptr) { if (dx != nullptr) {
auto& dout = auto &dout =
*context.Input<framework::LoDTensor>(framework::GradVarName("Out")); *context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
dx->set_lod(dout.lod()); dx->set_lod(dout.lod());
} }
...@@ -234,7 +248,7 @@ class ElemwiseGradKernel : public framework::OpKernel<T> { ...@@ -234,7 +248,7 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
\ \
protected: \ protected: \
std::unique_ptr<paddle::framework::OpDesc> Apply() const override { \ std::unique_ptr<paddle::framework::OpDesc> Apply() const override { \
auto* op = new paddle::framework::OpDesc(); \ auto *op = new paddle::framework::OpDesc(); \
op->SetType(#kernel_type "_grad"); \ op->SetType(#kernel_type "_grad"); \
op->SetInput("Y", Input("Y")); \ op->SetInput("Y", Input("Y")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \ op->SetInput(::paddle::framework::GradVarName("Out"), \
......
...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T> ...@@ -28,11 +28,10 @@ template <typename DeviceContext, typename T>
class ElementwiseSubKernel : public framework::OpKernel<T> { class ElementwiseSubKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor; auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class ExtractRowsOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ExtractRowsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ExtractRowsOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X")[0],
framework::proto::VarType::SELECTED_ROWS,
"The type of input(X) must be SelectedRows.");
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(
"Out", framework::make_ddim(std::vector<int64_t>{in_dims[0], 1}));
}
};
class ExtractRowsOp : public framework::OperatorBase {
public:
ExtractRowsOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &in = scope.FindVar(Input("X"))->Get<framework::SelectedRows>();
auto out = scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto &in_rows = in.rows();
auto out_dim = framework::make_ddim(
std::vector<int64_t>{static_cast<int64_t>(in_rows.size()), 1});
auto dst_ptr = out->mutable_data<int64_t>(out_dim, in.place());
if (paddle::platform::is_gpu_place(in.place())) {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = pool.Get(in.place());
auto src_ptr = in_rows.Data(in.place());
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(*dev_ctx)
.stream();
memory::Copy(boost::get<platform::CUDAPlace>(out->place()), dst_ptr,
boost::get<platform::CUDAPlace>(in.place()), src_ptr,
in_rows.size() * sizeof(int64_t), stream);
#else
PADDLE_THROW("Not compiled with CUDA.");
#endif
} else {
memory::Copy(platform::CPUPlace(), dst_ptr, platform::CPUPlace(),
in_rows.data(), in_rows.size() * sizeof(int64_t));
}
}
};
class ExtractRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(SelectedRows). The input tensor of extract_rows operator,"
" and its type is SelectedRows.");
AddOutput("Out", "(Tensor). The the rows of input(X).");
AddComment(R"DOC(
ExtractRows Operator.
The function of extract_rows_op is extracting the rows from the input(X)
whose type is SelectedRows.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(extract_rows, ops::ExtractRowsOp, ops::ExtractRowsOpMaker,
ops::ExtractRowsOpInferShape);
...@@ -64,6 +64,8 @@ struct SelectedRowsSumTo { ...@@ -64,6 +64,8 @@ struct SelectedRowsSumTo {
framework::SelectedRows* input2); framework::SelectedRows* input2);
}; };
// FIXME: The result of SelectedRowsAddToTensor maybe non deterministic,
// because it uses CudaAtomicAdd.
// input2 = input1 + input2 // input2 = input1 + input2
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct SelectedRowsAddToTensor { struct SelectedRowsAddToTensor {
......
...@@ -24,19 +24,13 @@ class ScaleKernel : public framework::OpKernel<T> { ...@@ -24,19 +24,13 @@ class ScaleKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& ctx) const { virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in_var = ctx.InputVar("X"); auto* in_var = ctx.InputVar("X");
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
auto* out_var = ctx.OutputVar("Out");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(in->place());
PADDLE_ENFORCE_EQ(in->dims(), out->dims(),
"in and out should have the same dim");
auto scale = static_cast<T>(ctx.Attr<float>("scale")); auto scale = static_cast<T>(ctx.Attr<float>("scale"));
auto bias = static_cast<T>(ctx.Attr<float>("bias")); auto bias = static_cast<T>(ctx.Attr<float>("bias"));
auto bias_after_scale = ctx.Attr<bool>("bias_after_scale"); auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* out_var = ctx.OutputVar("Out");
if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) { if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) {
auto& in_slr = in_var->Get<framework::SelectedRows>(); auto& in_slr = in_var->Get<framework::SelectedRows>();
auto* out_slr = out_var->GetMutable<framework::SelectedRows>(); auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
...@@ -44,6 +38,13 @@ class ScaleKernel : public framework::OpKernel<T> { ...@@ -44,6 +38,13 @@ class ScaleKernel : public framework::OpKernel<T> {
out_slr->set_height(in_slr.height()); out_slr->set_height(in_slr.height());
} }
auto* out =
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
out->mutable_data<T>(in->place());
PADDLE_ENFORCE_EQ(in->dims(), out->dims(),
"in and out should have the same dim");
auto eigen_out = framework::EigenVector<T>::Flatten(*out); auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in); auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device(); auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
......
...@@ -64,8 +64,7 @@ class SplitIdsOp : public framework::OperatorWithKernel { ...@@ -64,8 +64,7 @@ class SplitIdsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( framework::GetDataTypeOfVar(ctx.MultiInputVar("Ids").front()),
ctx.MultiInput<framework::Tensor>("Ids").front()->type()),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -113,6 +113,10 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -113,6 +113,10 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
row_width * sizeof(T)); row_width * sizeof(T));
} }
} }
} else {
PADDLE_THROW(
"% should be LoDTensor or SelectedRows, but the received type is %s",
ctx.Inputs("Ids")[0], ids_var->Type().name());
} }
} }
}; };
......
...@@ -85,8 +85,8 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -85,8 +85,8 @@ class SumOp : public framework::OperatorWithKernel {
for (size_t idx = 0; idx < x_vars.size(); ++idx) { for (size_t idx = 0; idx < x_vars.size(); ++idx) {
PADDLE_ENFORCE(x_vars[idx] != nullptr, PADDLE_ENFORCE(x_vars[idx] != nullptr,
"Input var[%s] should not be nullptr", x_vars_name[idx]); "Input var[%s] should not be nullptr", x_vars_name[idx]);
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor. auto tensor =
auto tensor = framework::GetTensorFromVar(*x_vars[idx]); framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]);
if (tensor->numel() == 0) { if (tensor->numel() == 0) {
continue; continue;
} }
......
...@@ -27,6 +27,7 @@ void BindConstValue(pybind11::module* m) { ...@@ -27,6 +27,7 @@ void BindConstValue(pybind11::module* m) {
m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; }); m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
m->def("kControlDepVarName", m->def("kControlDepVarName",
[] { return framework::ir::Node::kControlDepVarName; }); [] { return framework::ir::Node::kControlDepVarName; });
m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
auto op_proto_and_checker_maker = auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker"); m->def_submodule("op_proto_and_checker_maker");
......
...@@ -61,14 +61,25 @@ def append_regularization_ops(parameters_and_grads, regularization=None): ...@@ -61,14 +61,25 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
params_and_grads.append((param, grad)) params_and_grads.append((param, grad))
continue continue
assert grad.shape == regularization_term.shape new_grad = grad
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
# FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
# the grad's type and name will be changed. But the gradient's name
# is used in ParallelExecutor Reduce mode, so I add a flag for
# the new_grad here.
new_grad = grad.block.create_var(
name=grad.name + core.kNewGradSuffix(),
dtype=param.dtype,
shape=param.shape,
lod_level=param.lod_level,
type=core.VarDesc.VarType.LOD_TENSOR)
grad.block.append_op( grad.block.append_op(
type='elementwise_add', type='sum',
inputs={"X": grad, inputs={"X": [grad, regularization_term]},
"Y": regularization_term}, outputs={"Out": new_grad})
outputs={"Out": grad})
params_and_grads.append((param, grad)) params_and_grads.append((param, new_grad))
return params_and_grads return params_and_grads
...@@ -142,26 +153,7 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -142,26 +153,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
decay = block.create_var( decay = block.create_var(
dtype="float32", shape=param.shape, lod_level=param.lod_level) dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
idx = block.create_var(
dtype="int64",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op(
type='lookup_table',
inputs={'W': param,
'Ids': idx},
outputs={'Out': decay},
attrs={'is_sparse': True})
param = decay
# Append Op to calculate decay # Append Op to calculate decay
block.append_op( block.append_op(
...@@ -218,27 +210,9 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -218,27 +210,9 @@ class L1DecayRegularizer(WeightDecayRegularizer):
""" """
assert isinstance(param, framework.Parameter) assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
decay = block.create_var( decay = block.create_var(
dtype="float32", shape=param.shape, lod_level=param.lod_level) dtype=param.dtype, shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
idx = block.create_var(
dtype="int64",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
decay = block.create_var(
dtype="float32",
shape=param.shape,
type=core.VarDesc.VarType.LOD_TENSOR)
block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op(
type='lookup_table',
inputs={'W': param,
'Ids': idx},
outputs={'Out': decay},
attrs={'is_sparse': True})
param = decay
# Append sign op # Append sign op
block.append_op( block.append_op(
......
...@@ -373,9 +373,8 @@ class TestL2Decay(TranspilerTest): ...@@ -373,9 +373,8 @@ class TestL2Decay(TranspilerTest):
self.assertEqual(len(pserver.blocks), 3) self.assertEqual(len(pserver.blocks), 3)
self.assertEqual([op.type for op in pserver.blocks[1].ops], self.assertEqual([op.type for op in pserver.blocks[1].ops],
["sum", "scale", "clip", "sgd"]) ["sum", "scale", "clip", "sgd"])
self.assertEqual( self.assertEqual([op.type for op in pserver.blocks[2].ops],
[op.type for op in pserver.blocks[2].ops], ["sum", "scale", "clip", "scale", "sum", "sgd"])
["sum", "scale", "clip", "scale", "elementwise_add", "sgd"])
# TODO(typhoonzero): test clipping and L2Decay ops are removed from trainer # TODO(typhoonzero): test clipping and L2Decay ops are removed from trainer
...@@ -416,12 +415,10 @@ class TestL2DecayWithPiecewise(TranspilerTest): ...@@ -416,12 +415,10 @@ class TestL2DecayWithPiecewise(TranspilerTest):
"logical_and", "conditional_block", "fill_constant", "logical_and", "conditional_block", "fill_constant",
"conditional_block" "conditional_block"
]) ])
self.assertEqual( self.assertEqual([op.type for op in pserver.blocks[7].ops],
[op.type for op in pserver.blocks[7].ops], ["sum", "scale", "scale", "sum", "momentum"])
["sum", "scale", "scale", "elementwise_add", "momentum"]) self.assertEqual([op.type for op in pserver.blocks[8].ops],
self.assertEqual( ["sum", "scale", "scale", "sum", "momentum"])
[op.type for op in pserver.blocks[8].ops],
["sum", "scale", "scale", "elementwise_add", "momentum"])
class TestEmptyPserverOptimizeBlocks(TranspilerTest): class TestEmptyPserverOptimizeBlocks(TranspilerTest):
......
...@@ -117,56 +117,5 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): ...@@ -117,56 +117,5 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
} }
class TestElementWiseMulSelectedRows(OpTest):
def setUp(self):
self.rows = [0, 1, 2, 3, 4, 5, 6]
self.feature = 12
self.height = 100
self.input_shape = (len(self.rows), self.feature)
def prepare_input(self, scope, place):
self.input = {
"X": np.random.random(self.input_shape).astype("float32"),
"Y": np.random.random(self.input_shape).astype("float32")
}
def init_input(in_name):
x_selected_rows = scope.var(in_name).get_selected_rows()
x_selected_rows.set_height(self.height)
x_selected_rows.set_rows(self.rows)
x_array = self.input[in_name]
x_tensor = x_selected_rows.get_tensor()
x_tensor.set(x_array, place)
init_input("X")
init_input("Y")
def create_out_selected_row(self, scope):
return scope.var('Out').get_selected_rows()
def check_result(self, out_selected_rows):
assert out_selected_rows.height() == self.height
assert out_selected_rows.rows() == self.rows
out_tensor = np.array(out_selected_rows.get_tensor())
assert out_tensor.shape == self.input_shape
def check_with_place(self, place):
scope = core.Scope()
self.prepare_input(scope, place)
out_selected_rows = self.create_out_selected_row(scope)
out_selected_rows.set_height(0)
out_selected_rows.set_rows([])
elementwise_mul = Operator("elementwise_mul", X='X', Y='Y', Out='Out')
elementwise_mul.run(scope, place)
self.check_result(out_selected_rows)
def test_elewisemul_with_selected_rows_input(self):
places = [core.CPUPlace()]
for place in places:
self.check_with_place(place)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
class TestExtractRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Variable
feature_len = 12
rows = [0, 4, 4, 7]
np_array = np.ones((len(rows), feature_len)).astype("float32")
in_x = scope.var('X').get_selected_rows()
in_x.set_height(len(rows))
in_x.set_rows(rows)
in_x_tensor = in_x.get_tensor()
in_x_tensor.set(np_array, place)
# create Out Variable
out_tensor = scope.var('Out').get_tensor()
# create and run lookup_table operator
extract_rows_op = Operator("extract_rows", X='X', Out='Out')
extract_rows_op.run(scope, place)
# get result from Out
result_array = np.array(out_tensor)
result_array = [ele[0] for ele in result_array]
assert result_array == rows
def test_concat_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == '__main__':
unittest.main()
...@@ -55,7 +55,7 @@ class TestL2DecayRegularizer(unittest.TestCase): ...@@ -55,7 +55,7 @@ class TestL2DecayRegularizer(unittest.TestCase):
params_grads = optimizer.append_regularization_ops(params_grads) params_grads = optimizer.append_regularization_ops(params_grads)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(block.ops), count_ops + 2) self.assertEqual(len(block.ops), count_ops + 2)
self.assertEqual(block.ops[-1].type, 'elementwise_add') self.assertEqual(block.ops[-1].type, 'sum')
self.assertEqual(block.ops[-2].type, 'scale') self.assertEqual(block.ops[-2].type, 'scale')
...@@ -92,7 +92,7 @@ class TestL1DecayRegularizer(unittest.TestCase): ...@@ -92,7 +92,7 @@ class TestL1DecayRegularizer(unittest.TestCase):
params_grads = optimizer.append_regularization_ops(params_grads) params_grads = optimizer.append_regularization_ops(params_grads)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(block.ops), count_ops + 3) self.assertEqual(len(block.ops), count_ops + 3)
self.assertEqual(block.ops[-1].type, 'elementwise_add') self.assertEqual(block.ops[-1].type, 'sum')
self.assertEqual(block.ops[-2].type, 'scale') self.assertEqual(block.ops[-2].type, 'scale')
self.assertEqual(block.ops[-3].type, 'sign') self.assertEqual(block.ops[-3].type, 'sign')
......
...@@ -49,11 +49,14 @@ class TestSumOp(OpTest): ...@@ -49,11 +49,14 @@ class TestSumOp(OpTest):
class TestSelectedRowsSumOp(OpTest): class TestSelectedRowsSumOp(OpTest):
def check_with_place(self, place, inplace): def setUp(self):
self.height = 10 self.height = 10
self.row_numel = 12 self.row_numel = 12
self.rows = [0, 1, 2, 3, 4, 5, 6] self.rows = [0, 1, 2, 3, 4, 5, 6]
self.dtype = np.float32
self.init_kernel_type()
def check_with_place(self, place, inplace):
self.check_input_and_optput(core.Scope(), place, inplace, True, True, self.check_input_and_optput(core.Scope(), place, inplace, True, True,
True) True)
self.check_input_and_optput(core.Scope(), place, inplace, False, True, self.check_input_and_optput(core.Scope(), place, inplace, False, True,
...@@ -64,12 +67,12 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -64,12 +67,12 @@ class TestSelectedRowsSumOp(OpTest):
False) False)
def init_kernel_type(self): def init_kernel_type(self):
self.dtype = np.float32 pass
def _get_array(self, row_num, row_numel): def _get_array(self, rows, row_numel):
array = np.ones((row_num, row_numel)).astype(self.dtype) array = np.ones((len(rows), row_numel)).astype(self.dtype)
for i in range(row_num): for i in range(len(rows)):
array[i] *= i array[i] *= rows[i]
return array return array
def check_input_and_optput(self, def check_input_and_optput(self,
...@@ -105,7 +108,7 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -105,7 +108,7 @@ class TestSelectedRowsSumOp(OpTest):
self.assertTrue( self.assertTrue(
np.array_equal( np.array_equal(
np.array(out.get_tensor()), np.array(out.get_tensor()),
self._get_array(len(self.rows), self.row_numel) * self._get_array(self.rows, self.row_numel) *
has_data_w_num)) has_data_w_num))
else: else:
self.assertEqual(len(out.rows()), 0) self.assertEqual(len(out.rows()), 0)
...@@ -121,7 +124,7 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -121,7 +124,7 @@ class TestSelectedRowsSumOp(OpTest):
w_selected_rows = var.get_selected_rows() w_selected_rows = var.get_selected_rows()
w_selected_rows.set_height(self.height) w_selected_rows.set_height(self.height)
w_selected_rows.set_rows(rows) w_selected_rows.set_rows(rows)
w_array = self._get_array(len(rows), self.row_numel) w_array = self._get_array(self.rows, self.row_numel)
w_tensor = w_selected_rows.get_tensor() w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place) w_tensor.set(w_array, place)
...@@ -136,36 +139,91 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -136,36 +139,91 @@ class TestSelectedRowsSumOp(OpTest):
self.check_with_place(place, inplace) self.check_with_place(place, inplace)
class TestLoDTensorAndSelectedRowsOp(TestSelectedRowsSumOp):
def setUp(self):
self.height = 10
self.row_numel = 12
self.rows = [0, 1, 2, 2, 4, 5, 6]
def check_with_place(self, place, inplace):
scope = core.Scope()
if inplace:
self.create_lod_tensor(scope, place, "x1")
self.create_selected_rows(scope, place, "x2", True)
out = scope.var("x1").get_tensor()
out_name = "x1"
else:
self.create_selected_rows(scope, place, "x1", True)
self.create_lod_tensor(scope, place, "x2")
out = scope.var("out").get_tensor()
out_name = "out"
# create and run sum operator
sum_op = Operator("sum", X=["x1", "x2"], Out=out_name)
sum_op.run(scope, place)
result = np.ones((1, self.height)).astype(np.int32).tolist()[0]
for ele in self.rows:
result[ele] += 1
out_t = np.array(out)
self.assertEqual(out_t.shape[0], self.height)
self.assertTrue(
np.array_equal(out_t,
self._get_array([i for i in range(
self.height)], self.row_numel) * np.tile(
np.array(result).reshape(self.height, 1),
self.row_numel)))
def create_lod_tensor(self, scope, place, var_name):
var = scope.var(var_name)
w_tensor = var.get_tensor()
w_array = self._get_array([i for i in range(self.height)],
self.row_numel)
w_tensor.set(w_array, place)
return var
#----------- test fp16 -----------
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFP16SumOp(TestSumOp): class TestFP16SumOp(TestSumOp):
def init_kernel_type(self): def init_kernel_type(self):
self.dtype = np.float16 self.dtype = np.float16
def test_check_output(self): def test_check_output(self):
if core.is_compiled_with_cuda(): place = core.CUDAPlace(0)
place = core.CUDAPlace(0) if core.is_float16_supported(place):
if core.is_float16_supported(place): self.check_output_with_place(place, atol=2e-2)
self.check_output_with_place(place, atol=2e-2)
# FIXME: Because of the precision fp16, max_relative_error # FIXME: Because of the precision fp16, max_relative_error
# should be 0.15 here. # should be 0.15 here.
def test_check_grad(self): def test_check_grad(self):
if core.is_compiled_with_cuda(): place = core.CUDAPlace(0)
place = core.CUDAPlace(0) if core.is_float16_supported(place):
if core.is_float16_supported(place): self.check_grad(['x0'], 'Out', max_relative_error=0.15)
self.check_grad(['x0'], 'Out', max_relative_error=0.15)
class TestFP16SelectedRowsSumOp(TestSelectedRowsSumOp): def create_test_sum_fp16_class(parent):
def init_kernel_type(self): @unittest.skipIf(not core.is_compiled_with_cuda(),
self.dtype = np.float16 "core is not compiled with CUDA")
class TestSumFp16Case(parent):
def init_kernel_type(self):
self.dtype = np.float16
def test_w_is_selected_rows(self): def test_w_is_selected_rows(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
if core.is_float16_supported(place): if core.is_float16_supported(place):
for inplace in [True, False]: for inplace in [True, False]:
self.check_with_place(place, inplace) self.check_with_place(place, inplace)
cls_name = "{0}_{1}".format(parent.__name__, "SumFp16Test")
TestSumFp16Case.__name__ = cls_name
globals()[cls_name] = TestSumFp16Case
create_test_sum_fp16_class(TestSelectedRowsSumOp)
create_test_sum_fp16_class(TestLoDTensorAndSelectedRowsOp)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册