diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 1db042c6fc8b6c4ea7c3854ea4b1cd016deeb0b6..d8012fba27bfca05e062e22d38d672bd395df7a6 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -43,4 +43,5 @@ cc_library(paddle_pybind SHARED add_op mean_op cross_entropy_op + fill_zeros_like_op recurrent_op) diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index b189e6f9e833b4374f493cc51ac6bfc08491bece..9ee2c6af86476ea50def237ed011fcddaa41daad 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -40,6 +40,7 @@ USE_OP(mean); USE_OP(sigmoid); USE_OP(softmax); USE_OP(rowwise_add); +USE_OP(fill_zeros_like); USE_OP_WITHOUT_KERNEL(recurrent_op); namespace paddle { namespace framework { diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc index 3d37d64c5a8c288684122f3e686262399d32ed7b..198b4576c887122fdac1a3fbff5a248a3d9fa0a3 100644 --- a/paddle/operators/fill_zeros_like_op.cc +++ b/paddle/operators/fill_zeros_like_op.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/fill_zeros_like_op.h" -#include "paddle/framework/op_registry.h" -#include "paddle/framework/tensor.h" namespace paddle { namespace operators { diff --git a/paddle/operators/fill_zeros_like_op.cu b/paddle/operators/fill_zeros_like_op.cu index ed1068219c8fee8c6e8809f450a9d38c8226f317..4f1054cf47e35572dbbc51ca742994065a027919 100644 --- a/paddle/operators/fill_zeros_like_op.cu +++ b/paddle/operators/fill_zeros_like_op.cu @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +#define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/fill_zeros_like_op.h" diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h index 4bff1fbfc15af1f4d1ce9c99fe48b0b0f11b5b3f..dfaed2c9aaf2bf5c1a9b803fc9c8b9ea0e5c5d4e 100644 --- a/paddle/operators/fill_zeros_like_op.h +++ b/paddle/operators/fill_zeros_like_op.h @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "glog/logging.h" -#include "paddle/framework/eigen.h" -#include "paddle/framework/operator.h" +#include "paddle/operators/type_alias.h" namespace paddle { namespace operators { @@ -26,7 +24,8 @@ class FillZerosLikeKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* output = context.Output(0); output->mutable_data(context.GetPlace()); - framework::EigenVector::Flatten(*output).setZero(); + auto t = framework::EigenVector::Flatten(*output); + t.device(context.GetEigenDevice()) = t.constant(T(0)); } }; diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 29dd0ded0ac75893da7e244d92725cd5e285efce..8e6b258e00c0012876cda8ffc5b340322d51e894 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -6,4 +6,5 @@ cc_library(paddle_pybind SHARED add_op mean_op cross_entropy_op - recurrent_op) + recurrent_op + fill_zeros_like_op) diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 4322781b342fa94ac77db03954b8592e750cfbe3..541639ac21661529b0b1f2cc8d8fa25605052c8c 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -13,6 +13,7 @@ py_test(test_protobuf SRCS test_protobuf.py) py_test(test_add_two_op SRCS test_add_two_op.py) py_test(test_sigmoid_op SRCS test_sigmoid_op.py) py_test(test_softmax_op SRCS test_softmax_op.py) +py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py) py_test(gradient_checker SRCS gradient_checker.py) diff --git a/python/paddle/v2/framework/tests/test_fill_zeros_like_op.py b/python/paddle/v2/framework/tests/test_fill_zeros_like_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c862605fb11a5ea1426cf8f9054589dc377ff1 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fill_zeros_like_op.py @@ -0,0 +1,16 @@ +import unittest +from op_test_util import OpTestMeta +import numpy + + +class TestFillZerosLikeOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "fill_zeros_like" + self.inputs = {'Src': numpy.random.random((219, 232)).astype("float32")} + self.outputs = {'Dst': numpy.zeros_like(self.inputs['Src'])} + + +if __name__ == '__main__': + unittest.main()