提交 fb8d007f 编写于 作者: Q Qiao Longfei 提交者: qiaolongfei

Scale support selectedrows (#12960)

* add ScaleOpVarTypeInference for scale op

* scale op support scale selected rows

* optimize code

* use FindVar

* use FindVarRecursive in ScaleOpVarTypeInference
上级 ec9eb220
...@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h" #include "paddle/fluid/operators/scale_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -52,6 +55,21 @@ $$Out = scale*X$$ ...@@ -52,6 +55,21 @@ $$Out = scale*X$$
} }
}; };
class ScaleOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto &in_var_name = op_desc.Input("X").front();
auto &in_var = detail::Ref(block->FindVarRecursive(in_var_name));
auto out_var_name = op_desc.Output("Out").front();
auto *out_var = block->FindVarRecursive(out_var_name);
out_var->SetType(in_var.GetType());
out_var->SetDataType(in_var.GetDataType());
}
};
class ScaleGradMaker : public framework::SingleGradOpDescMaker { class ScaleGradMaker : public framework::SingleGradOpDescMaker {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
...@@ -71,7 +89,8 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker { ...@@ -71,7 +89,8 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker); REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker,
ops::ScaleOpVarTypeInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>, scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>, ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
......
...@@ -22,17 +22,29 @@ namespace operators { ...@@ -22,17 +22,29 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ScaleKernel : public framework::OpKernel<T> { class ScaleKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& context) const { virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* in_var = ctx.InputVar("X");
auto* in = context.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
tensor->mutable_data<T>(in->place());
auto scale = static_cast<T>(context.Attr<float>("scale")); auto* out_var = ctx.OutputVar("Out");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(in->place());
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor); PADDLE_ENFORCE_EQ(in->dims(), out->dims(),
"in and out should have the same dim");
auto scale = static_cast<T>(ctx.Attr<float>("scale"));
if (in_var->IsType<framework::SelectedRows>() && in_var != out_var) {
auto& in_slr = in_var->Get<framework::SelectedRows>();
auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->set_rows(in_slr.rows());
out_slr->set_height(in_slr.height());
}
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in); auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& dev = auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
*context.template device_context<DeviceContext>().eigen_device();
eigen_out.device(dev) = scale * eigen_in; eigen_out.device(dev) = scale * eigen_in;
} }
}; };
......
...@@ -17,6 +17,8 @@ from __future__ import print_function ...@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestScaleOp(OpTest): class TestScaleOp(OpTest):
...@@ -33,5 +35,57 @@ class TestScaleOp(OpTest): ...@@ -33,5 +35,57 @@ class TestScaleOp(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestScaleOpSelectedRows(unittest.TestCase):
def check_with_place(self, place, in_name, out_name):
scope = core.Scope()
# create and initialize Grad Variable
in_height = 10
in_rows = [0, 4, 7]
in_row_numel = 12
scale = 2.0
in_selected_rows = scope.var(in_name).get_selected_rows()
in_selected_rows.set_height(in_height)
in_selected_rows.set_rows(in_rows)
in_array = np.random.random(
(len(in_rows), in_row_numel)).astype("float32")
in_tensor = in_selected_rows.get_tensor()
in_tensor.set(in_array, place)
# create and initialize Param Variable
out_selected_rows = scope.var(out_name).get_selected_rows()
out_tensor = out_selected_rows.get_tensor()
out_tensor._set_dims(in_tensor._get_dims())
# create and run sgd operator
scale_op = Operator("scale", X=in_name, Out=out_name, scale=scale)
scale_op.run(scope, place)
# get and compare result
out_height = out_selected_rows.height()
out_rows = out_selected_rows.rows()
result_array = np.array(out_tensor)
assert (in_array * scale == result_array).all()
assert in_height == out_height
assert in_rows == out_rows
def test_scale_selected_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place, 'in', 'out')
def test_scale_selected_rows_inplace(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place, 'in', 'in')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册