提交 6e7666f1 编写于 作者: Y Yang Yang

before backward

上级 c83ea1cd
...@@ -44,7 +44,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD ...@@ -44,7 +44,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
cc_library(backward SRCS backward.cc DEPS net_op) cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto ${GLOB_OP_LIB}) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward ${GLOB_OP_LIB})
if(WITH_GPU) if(WITH_GPU)
nv_test(executor_test SRCS executor_test.cc DEPS executor) nv_test(executor_test SRCS executor_test.cc DEPS executor)
else() else()
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/backward.h"
#include "paddle/framework/block_desc.h" #include "paddle/framework/block_desc.h"
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
...@@ -27,6 +28,7 @@ USE_OP(gaussian_random); ...@@ -27,6 +28,7 @@ USE_OP(gaussian_random);
USE_OP(feed); USE_OP(feed);
USE_OP(fetch); USE_OP(fetch);
USE_OP(mul); USE_OP(mul);
USE_OP(squared_l2_distance);
using std::string; using std::string;
using namespace paddle::platform; using namespace paddle::platform;
...@@ -170,10 +172,16 @@ class ExecutorTesterRandom : public ::testing::Test { ...@@ -170,10 +172,16 @@ class ExecutorTesterRandom : public ::testing::Test {
root_block); root_block);
AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {}, AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {},
root_block); root_block);
AddOp("squared_l2_distance", {{"X", {"a"}}, {"Y", {"a_out"}}},
AddOp("fetch", {{"Input", {"a_out"}}}, {}, {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {},
{{"dims", std::vector<int>{input_dim, batch_size}}, {"col", 1}},
root_block); root_block);
AppendBackward(pdesc_, {});
// AddOp("fetch", {{"Input", {"sub_result"}}}, {},
// {{"dims", std::vector<int>{input_dim, batch_size}}, {"col", 0}},
// root_block);
AddOp("fetch", {{"Input", {"l2_distance"}}}, {},
{{"dims", std::vector<int>{batch_size}}, {"col", 1}}, root_block);
} }
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册