记录最近算法工程里开发的pipeline设计模式。优化了上一版本:
1,增加了线程池管理,每个node可以异步处理;
2,增加了callback,将最后一个node的结果通过callback返回到主程序,避免了参数传递的冗余实现;
3,去掉了模板类设计,避免只能在头文件中去实现的弊端;
4,去掉了“前一个node的输出就是后一个node的输入”的限制,避免了函数返回值带来的复制开销;
/** @ 带有线程池的pipeline pipeline里的Node可以异步执行,加快处理速度 */
task_queue.h
/** @ 线程池的任务队列 @ 入队和出队 */
/**
 * @brief Mutex-protected FIFO queue used to pass items between threads.
 *
 * A queue can be chained to a downstream queue with connect(). Once chained,
 * enqueue() hands every item to the downstream queue instead of storing it
 * locally, so the local queue stays empty.
 *
 * @tparam T Item type; must be copy-constructible and move-assignable.
 */
template<class T>
class TaskQueue
{
public:
    TaskQueue() = default;
    ~TaskQueue() = default;

    /**
     * @brief Appends an item, or forwards it to the connected downstream queue.
     * @param t Item to enqueue (copied).
     */
    void enqueue(const T& t)
    {
        TaskQueue<T>* pNext = nullptr;
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            pNext = m_pNextQueue;
            if(!pNext)
            {
                m_queue.push(t);
            }
        }
        if(pNext)
        {
            // Forward outside our own lock so two mutexes are never held at once.
            pNext->enqueue(t);
            return;
        }
        // Wake one consumer blocked in dequeue_wait().
        m_cond.notify_one();
    }

    /**
     * @brief Removes the oldest item without blocking.
     * @param t Receives the item on success; untouched otherwise.
     * @return true if an item was removed, false if the queue was empty.
     */
    bool dequeue(T& t)
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        if(m_queue.empty())
        {
            return false;
        }
        t = std::move(m_queue.front());
        m_queue.pop();
        return true;
    }

    /** @return Number of items currently stored locally. */
    int32_t size()
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        return static_cast<int32_t>(m_queue.size());
    }

    /** @return true if no item is stored locally. */
    bool empty()
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        return m_queue.empty();
    }

    /**
     * @brief Removes the oldest item, waiting up to @p timeout milliseconds for one.
     * @param t       Receives the item on success; untouched otherwise.
     * @param timeout Maximum time to wait, in milliseconds.
     * @return true if an item was removed, false if the wait timed out.
     */
    bool dequeue_wait(T& t,uint32_t timeout)
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        // The predicate form re-checks the queue after spurious wakeups.
        const bool hasItem = m_cond.wait_for(lock,
                                             std::chrono::milliseconds(timeout),
                                             [this]() { return !m_queue.empty(); });
        if(!hasItem)
        {
            return false;
        }
        t = std::move(m_queue.front());
        m_queue.pop();
        return true;
    }

    /**
     * @brief Chains this queue to a downstream queue.
     * @param pQueue Queue that receives all later enqueue() calls; nullptr
     *               removes the chaining. Not owned; must outlive this queue's use.
     */
    void connect(TaskQueue<T>* pQueue)
    {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_pNextQueue = pQueue;
    }

private:
    std::queue<T> m_queue;
    std::mutex m_mutex;
    std::condition_variable m_cond;
    TaskQueue<T>* m_pNextQueue = nullptr;  // non-owning; nullptr means "store locally"
};
thread_manager.h
/** @ 线程管理 */
/** Default number of worker threads created by ThreadManager. */
static const uint32_t MaxThreadNums = 8;

/**
 * @brief Fixed-size thread pool that runs posted callables on worker threads.
 *
 * Usage: construct, call init() to start the workers, submit work with
 * postJobs(), and call shutdown() (also done by the destructor) to stop.
 * Jobs still queued when shutdown() is called are executed before the
 * workers exit.
 */
class ThreadManager
{
public:
    /**
     * @param threadCount Number of worker threads; values <= 0 create none.
     */
    ThreadManager(const int threadCount = MaxThreadNums)
        : m_shutdown(false),
          m_threads(threadCount > 0 ? static_cast<std::size_t>(threadCount) : std::size_t(0))
    {
    }

    ~ThreadManager()
    {
        this->shutdown();
    }

    ThreadManager(ThreadManager &&)=delete;
    ThreadManager(const ThreadManager &)=delete;
    ThreadManager &operator=(ThreadManager &&)=delete;
    ThreadManager &operator=(const ThreadManager &) =delete;

    /** @brief Starts every worker thread that is not already running. */
    void init()
    {
        for(std::size_t i = 0; i < m_threads.size(); ++i)
        {
            // Assigning to a joinable std::thread would call std::terminate.
            if(!m_threads[i].joinable())
            {
                m_threads[i] = std::thread(ThreadWorker(this, static_cast<int32_t>(i)));
            }
        }
    }

    /** @brief Requests shutdown, wakes all workers and joins them. */
    void shutdown()
    {
        {
            // The flag is written under the mutex the workers wait on, so a
            // worker cannot miss the notification between its check and its wait.
            std::unique_lock<std::mutex> lock(m_mutex);
            m_shutdown = true;
        }
        m_cond.notify_all();
        for(auto& worker : m_threads)
        {
            if(worker.joinable())
            {
                worker.join();
            }
        }
    }

    /**
     * @brief Queues a callable for execution on a worker thread.
     * @param f    Callable to run.
     * @param args Arguments bound to @p f (copied or moved into the job).
     * @return Future for the callable's result. If the pool has already been
     *         shut down the job is not queued and the future reports
     *         std::future_error (broken_promise) from get().
     */
    template<typename F,typename... Args>
    auto postJobs(F&& f, Args &&...args)->std::future<decltype(f(args...))>
    {
        using ResultType = decltype(f(args...));
        std::function<ResultType()> func = std::bind(std::forward<F>(f),std::forward<Args>(args)...);
        auto task_ptr = std::make_shared<std::packaged_task<ResultType()>>(func);
        // Take the future before a worker can possibly run the task.
        std::future<ResultType> result = task_ptr->get_future();
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            if(!m_shutdown)
            {
                m_task_queue.push([task_ptr]()
                {
                    (*task_ptr)();
                });
            }
        }
        m_cond.notify_one();
        return result;
    }

private:
    /** Functor executed by each worker thread: pops and runs queued jobs. */
    class ThreadWorker
    {
    public:
        ThreadWorker(ThreadManager *pThreadManager,const int32_t tid)
            : m_tid(tid), m_pThreadManager(pThreadManager)
        {
        }

        void operator()()
        {
            for(;;)
            {
                std::function<void()> task;
                {
                    std::unique_lock<std::mutex> lock(m_pThreadManager->m_mutex);
                    m_pThreadManager->m_cond.wait(lock,[this](){
                        return m_pThreadManager->m_shutdown
                            || !m_pThreadManager->m_task_queue.empty();
                    });
                    if(m_pThreadManager->m_task_queue.empty())
                    {
                        // Only reachable when shutdown was requested and nothing is left.
                        return;
                    }
                    task = std::move(m_pThreadManager->m_task_queue.front());
                    m_pThreadManager->m_task_queue.pop();
                }
                // Run the job without holding the pool mutex.
                task();
            }
        }

    private:
        int32_t m_tid;                   // index of this worker inside the pool
        ThreadManager *m_pThreadManager; // owning pool; not owned by the worker
    };

private:
    bool m_shutdown;                                // guarded by m_mutex
    std::mutex m_mutex;
    std::condition_variable m_cond;
    std::vector<std::thread> m_threads;
    std::queue<std::function<void()>> m_task_queue; // guarded by m_mutex
};
common_struct.h
/** @ pipeline的入参结构体 */
/**
 * Role of a node within a pipeline.
 * In this file, Sink nodes get the result callback registered and invoke it;
 * every other node forwards the request to its output queue.
 */
enum NodeType
{
    Source  = 0,
    Channel = 1,
    Sink    = 2
};
/** Describes one node of a pipeline. */
struct NodeNeedInfo
{
    std::string name;                  // node name; the pipeline's CreateNode() selects the node class by it
    NodeType type = NodeType::Source;  // defaulted so the field is never read uninitialized
};
/** Request object that travels through the pipeline nodes. */
struct InputRequestInfo
{
    bool isOK = false;        // defaulted so the field is never read uninitialized
    uint32_t requestId = 0;   // request identifier; nodes may overwrite it
    // Per-node input information.
    NodeNeedInfo nodeInfo[8];
};
// Shared-ownership handle to a NodeNeedInfo.
using NodeNeedInfoPtr = std::shared_ptr<NodeNeedInfo>;
// Shared-ownership handle to an InputRequestInfo; this is the item type held by the node queues.
using InputRequestInfoPtr = std::shared_ptr<InputRequestInfo>;
// Signature of the result callback; a Sink node invokes it with the processed request.
using ResultCallback = std::function<void(const InputRequestInfoPtr&)>;
/** Describes the pipeline to build: its nodes and the result callback. */
struct PipelineDescriptor
{
    uint32_t nums = 0;        // number of valid entries in nodes[]; defaulted so an unset descriptor builds nothing
    std::string name;         // pipeline name
    // Node descriptions, in pipeline order.
    NodeNeedInfo nodes[8];
    ResultCallback callback;  // registered on Sink nodes when the pipeline is built
};
// Shared-ownership handle to a PipelineDescriptor.
using PipelineDescriptorPtr = std::shared_ptr<PipelineDescriptor>;
node.h
//node.h : base Node /*** @ 1, 去掉了类模板 @ 2, 不需要上一级的输出是下一级的输入 @ 3, 通过callback的方式将最后一级的结果输出给前一级 */
class Node
{
public:
Node(): m_stop(false),m_is_sink(false){};
virtual ~Node() = default;
virtual int32_t initialize(const std::string& conf) = 0;
virtual int32_t process(InputRequestInfoPtr pRequestInfo) = 0;
virtual std::string getNodeName() const= 0;
virtual NodeType Type()const =0;
public:
void start()
{
//起线程处理
m_thread = std::thread([this](){
executeRequest();
});
}
void stop()
{
m_stop = true;
if(m_thread.joinable())
{
m_thread.join();
}
}
// inline std::string getNodeName() const
// {
// return m_node_name;
// }
void executeRequest()
{
int count = 0;
while(!m_stop)
{
InputRequestInfoPtr pRequest;
if(m_input_queue.dequeue(pRequest))
{
int32_t ret = process(pRequest);
if(ret != 0)
{
///////////
}
//set request for next node
if(m_type != NodeType::Sink)//bug to do
{
count ;
m_output_queue.enqueue(pRequest);
}
else
{
m_result_callback(pRequest);//回到main: publishResult
}
}
else
{
////////////
}
}
}
TaskQueue<InputRequestInfoPtr> &input_queue()
{
return m_input_queue;
}
TaskQueue<InputRequestInfoPtr> &output_queue()
{
return m_output_queue;
}
// inline NodeType Type()const
// {
// return m_type;
// }
void callbackRegister(ResultCallback callback)
{
m_result_callback = std::move(callback);
}
private:
bool m_stop;
bool m_is_sink;
bool m_source;
TaskQueue<InputRequestInfoPtr> m_input_queue;
TaskQueue<InputRequestInfoPtr> m_output_queue;
std::thread m_thread;
std::string m_node_name;
ResultCallback m_result_callback;
NodeType m_type;
};
nodeA/B
/** NodeA -> NodeB -> NodeC */
class Node_A :public Node
{
public:
Node_A() = default;
~Node_A() =default;
int32_t initialize(const std::string& conf)override{
std::cout<<"I am NodeA initialize"<<std::endl;
return 0;
}
int32_t process(InputRequestInfoPtr pRequestInfo)override{
std::cout<<"I am NodeA process"<<std::endl;
pRequestInfo->requestId = 100;
return 0;
}
std::string getNodeName()const override
{
return "Node_A";
}
NodeType Type()const override
{
return NodeType::Source;
}
};
//NodeB
class Node_B :public Node
{
public:
Node_B() = default;
~Node_B() =default;
int32_t initialize(const std::string& conf)override{
std::cout<<"I am NodeB initialize"<<std::endl;
return 0;
}
int32_t process(InputRequestInfoPtr pRequestInfo)override{
std::cout<<"I am NodeB process"<<std::endl;
return 0;
}
std::string getNodeName()const override
{
return "Node_B";
}
NodeType Type()const override
{
return NodeType::Sink;
}
};
perceptionPipeline.h
/** @ 一个具体的pipeline */
class PerceptionPipeline
{
public:
PerceptionPipeline()=default;
~PerceptionPipeline()=default;
/**
@ submit request to source node
*/
void submit(InputRequestInfoPtr& pRequest)
{
m_pNodes[0]->input_queue().enqueue(pRequest);
}
/**
@ initialize an pipline
*/
int32_t initialize(const PipelineDescriptorPtr& pPipelineDesc)
{
int32_t result = 0;
result = createNodes(pPipelineDesc);
return result;
}
int32_t createNodes(const PipelineDescriptorPtr& pPipelineInfo)
{
int32_t result = 0;
for(uint32_t i=0 ; i < pPipelineInfo->nums; i )
{
//todo factory create nodes
std::shared_ptr<Node> pNode = std::move(CreateNode(pPipelineInfo->nodes[i]));
result = pNode->initialize("lxk");
if(0!=result)
{
//////////////
break;
}
if(pNode->Type() == NodeType::Sink)
{
std::cout<<"------------callbackRegister-----------"<<std::endl;
pNode->callbackRegister(pPipelineInfo->callback);
}
this->addNode(pNode);
}
return result;
}
static std::shared_ptr<Node> CreateNode(const NodeNeedInfo& node_desc)
{
if(node_desc.name == "NodeA")
return (std::make_shared<Node_A>());
if(node_desc.name == "NodeB")
return (std::make_shared<Node_B>());
return nullptr;
}
void start()
{
for(auto i:m_pNodes)
{
i->start();
}
}
void stop()
{
for(auto i:m_pNodes)
{
i->stop();
}
}
std::string PipelineInfo()
{
std::stringstream sstr;
sstr<<"n";
sstr<<"-------Pipeline info start----------n";
sstr<<"number of nodes: "<<m_pNodes.size()<<"n";
for(uint32_t i =0; i <m_pNodes.size(); i )
{
if(i == m_pNodes.size() -1)
{
sstr<<m_pNodes[i]->getNodeName()<<"n";
}
else
{
sstr<<m_pNodes[i]->getNodeName()<<"->";
}
}
sstr<<"----------Pipeline info end----------n";
return sstr.str();
}
private:
void addNode(std::shared_ptr<Node>& pNode)
{
std::shared_ptr<Node> pTail = nullptr;
if(!m_pNodes.empty())
{
pTail = m_pNodes.back();
}
m_pNodes.push_back(pNode);
//connect output queue node with input queue of next node
if(pTail)
{
pTail->output_queue().connect(&pNode->input_queue());
}
}
private:
std::vector<std::shared_ptr<Node>> m_pNodes;
};
CameraPerception
/** @ 实际测试案例 */
/**
 * Example driver that owns a PerceptionPipeline and a ThreadManager.
 */
class CameraPerception
{
public:
CameraPerception();
// Stops the pipeline.
~CameraPerception();
// Creates the thread manager and the pipeline, starts them and submits one request.
bool init();
private:
// Builds one request and submits it to the pipeline.
void cameraPerceptionCallback();
// Registered as the pipeline's result callback in init().
void publishResult(const InputRequestInfoPtr& pInferResult);
// Job posted to the thread manager by publishResult().
void MarkObstacleOnImage(uint64_t request_id);
// NOTE: members are destroyed in reverse declaration order (thread manager first).
std::unique_ptr<PerceptionPipeline> m_perception_pipeline;
std::unique_ptr<ThreadManager> m_thread_manager;
};
// Nothing to set up here; both members start empty and are created by init().
CameraPerception::CameraPerception() = default;
// Stops the pipeline's node threads before the members are destroyed.
CameraPerception::~CameraPerception()
{
    // The pipeline only exists after init(); guard against destruction without it.
    if(m_perception_pipeline)
    {
        m_perception_pipeline->stop();
    }
}
/**
 * Creates the thread manager and a two-node pipeline (NodeA -> NodeB), starts
 * the pipeline and submits one request.
 *
 * @return true on success, false if the pipeline could not be initialized.
 */
bool CameraPerception::init()
{
    // If init() is called again, stop the old pipeline before replacing the
    // thread manager its callback posts to.
    if(m_perception_pipeline)
    {
        m_perception_pipeline->stop();
    }

    m_thread_manager.reset(new ThreadManager());
    m_thread_manager->init();
    m_perception_pipeline.reset(new PerceptionPipeline);

    PipelineDescriptorPtr pPipeline(new PipelineDescriptor);
    pPipeline->name = "perception pipeline";
    pPipeline->nodes[0].name = "NodeA";
    pPipeline->nodes[0].type = NodeType::Source;
    pPipeline->nodes[1].name = "NodeB";
    pPipeline->nodes[1].type = NodeType::Sink;
    pPipeline->nums = 2;
    pPipeline->callback = std::bind(&CameraPerception::publishResult, this, std::placeholders::_1);

    const int32_t ret = m_perception_pipeline->initialize(pPipeline);
    if(ret != 0)
    {
        std::cout<<"pipeline init error"<<std::endl;
        return false;
    }

    m_perception_pipeline->start();
    std::cout<<"pipeline info: "<<m_perception_pipeline->PipelineInfo()<<std::endl;
    cameraPerceptionCallback();
    // ret is 0 here; returning it would convert to false and report failure.
    return true;
}
// Builds one request (id 1, first three node entries named) and submits it to the pipeline.
void CameraPerception::cameraPerceptionCallback()
{
    InputRequestInfoPtr input_info = std::make_shared<InputRequestInfo>();
    input_info->requestId = 1;
    input_info->isOK = true;
    // The index must advance, otherwise the loop never terminates.
    for(size_t i = 0; i < 3; ++i)
    {
        input_info->nodeInfo[i].name = "lxkkk";
    }
    m_perception_pipeline->submit(input_info);
}
//callbacked by node
void CameraPerception::publishResult(const InputRequestInfoPtr& pInferResult)
{
std::cout<<"publishResult: ID: "<<pInferResult->requestId<<std::endl;
m_thread_manager->postJobs(std::bind(&CameraPerception::MarkObstacleOnImage, this,pInferResult->requestId));
}
// Prints the given request id to std::cout.
// publishResult() posts this function to the thread manager.
void CameraPerception::MarkObstacleOnImage(uint64_t request_id)
{
std::cout<<"MarkObstacleOnImage: ID: "<<request_id<<std::endl;
}
int main()
{
std::unique_ptr<CameraPerception> pCameraPerceptionHandle(new CameraPerception());
pCameraPerceptionHandle->init();
}