线程池管理的pipeline设计模式(用了“精进C++”里的内容)

2022-12-04 16:29:30 浏览数 (1)

记录最近在算法工程中开发的pipeline设计模式。相比上一版本,做了以下优化:

1,增加了线程池管理,每个node可以异步处理;

2,增加了callback,将最后一个node的结果通过回调返回给主程序,避免了参数传递的冗余实现;

3,去掉了模板类设计,避免了只能在头文件中实现的弊端;

4,不再要求前一个node的输出就是后一个node的输入,避免了函数返回值带来的复制开销;

/** @ 带有线程池的pipeline:pipeline里的Node可以异步执行,加快处理速度 */

task_queue.h

/** @ 线程池的任务队列 @ 入队和出队 */

代码语言:cpp
/**
 * @brief Thread-safe FIFO used as the channel between pipeline nodes.
 *
 * A queue can be connect()ed to a downstream queue. Once connected, every
 * enqueue() is forwarded to the downstream queue instead of being stored
 * locally. This is how a node's output queue feeds the next node's input queue.
 */
template<class T>
class TaskQueue
{
    public:
        TaskQueue() = default;
        ~TaskQueue() = default;

        /**
         * @brief Push a task, or forward it to the connected downstream queue.
         * @param t task to enqueue (copied).
         */
        void enqueue(const T& t)
        {
            TaskQueue<T>* next = nullptr;
            {
                std::lock_guard<std::mutex> lock(m_mutex);
                next = m_pNextQueue;
                if (next == nullptr)
                {
                    m_queue.push(t);
                }
            }

            if (next != nullptr)
            {
                // Forward outside our own lock so that two queue mutexes are
                // never held at the same time.
                next->enqueue(t);
                return;
            }

            // Wake a consumer blocked in dequeue_wait().
            m_cond.notify_one();
        }

        /**
         * @brief Pop the oldest task without blocking.
         * @param t receives the task on success.
         * @return false if the queue was empty.
         */
        bool dequeue(T& t)
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            if (m_queue.empty())
            {
                return false;
            }

            t = std::move(m_queue.front());
            m_queue.pop();
            return true;
        }

        /// @return number of tasks stored locally (forwarded tasks are not counted).
        int32_t size() const
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            return static_cast<int32_t>(m_queue.size());
        }

        /// @return true if no task is stored locally.
        bool empty() const
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            return m_queue.empty();
        }

        /**
         * @brief Pop the oldest task, waiting up to @p timeout milliseconds for one.
         * @param t receives the task on success.
         * @param timeout maximum wait in milliseconds.
         * @return false if the queue was still empty when the timeout expired.
         */
        bool dequeue_wait(T& t, uint32_t timeout)
        {
            std::unique_lock<std::mutex> lock(m_mutex);
            // The predicate form re-checks after spurious wake-ups.
            if (!m_cond.wait_for(lock, std::chrono::milliseconds(timeout),
                                 [this]() { return !m_queue.empty(); }))
            {
                return false;
            }

            t = std::move(m_queue.front());
            m_queue.pop();
            return true;
        }

        /**
         * @brief Route every future enqueue() to @p pQueue.
         * @param pQueue downstream queue (not owned), or nullptr to store locally again.
         */
        void connect(TaskQueue<T>* pQueue)
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_pNextQueue = pQueue;
        }

    private:
        std::queue<T> m_queue;
        mutable std::mutex m_mutex;
        std::condition_variable m_cond;
        TaskQueue<T>* m_pNextQueue = nullptr;
};

thread_manager.h

/** @ 线程管理 */

代码语言:cpp
static const uint32_t MaxThreadNums = 8;

/**
 * @brief Fixed-size thread pool.
 *
 * Construct it, call init() to launch the workers, submit work with
 * postJobs(), and call shutdown() (or destroy the object) to stop.
 * shutdown() lets the workers run every job that is already queued, then
 * joins them.
 */
class ThreadManager
{
    public:
        /// @param threadCount number of worker threads that init() launches.
        ThreadManager(const int threadCount = MaxThreadNums)
            : m_threads(threadCount)
        {
        }

        ~ThreadManager()
        {
            this->shutdown();
        }

        ThreadManager(ThreadManager &&) = delete;
        ThreadManager(const ThreadManager &) = delete;
        ThreadManager &operator=(ThreadManager &&) = delete;
        ThreadManager &operator=(const ThreadManager &) = delete;

        /// @brief Launch the worker threads. Slots that already run a worker are left alone.
        void init()
        {
            for (uint32_t i = 0; i < m_threads.size(); ++i)
            {
                if (!m_threads[i].joinable())
                {
                    m_threads[i] = std::thread(ThreadWorker(this, static_cast<int32_t>(i)));
                }
            }
        }

        /// @brief Ask the workers to finish the queued jobs and exit, then join them. Idempotent.
        void shutdown()
        {
            {
                std::lock_guard<std::mutex> lock(m_mutex);
                m_shutdown = true;
            }
            m_cond.notify_all();

            for (std::thread& worker : m_threads)
            {
                if (worker.joinable())
                {
                    worker.join();
                }
            }
        }

        /**
         * @brief Queue f(args...) for execution on a worker thread.
         * @return future that receives the result, or the exception thrown by the job.
         * Jobs posted after shutdown() are never run. Their futures report
         * broken_promise once the pool is destroyed.
         */
        template<typename F, typename... Args>
        auto postJobs(F&& f, Args&&... args) -> std::future<decltype(f(args...))>
        {
            using ReturnType = decltype(f(args...));

            auto task_ptr = std::make_shared<std::packaged_task<ReturnType()>>(
                std::bind(std::forward<F>(f), std::forward<Args>(args)...));
            std::future<ReturnType> result = task_ptr->get_future();

            {
                std::lock_guard<std::mutex> lock(m_mutex);
                // The wrapper must actually *invoke* the packaged task.
                m_task_queue.emplace([task_ptr]() { (*task_ptr)(); });
            }
            m_cond.notify_one();

            return result;
        }

    private:

        /// Body of one worker thread: pop jobs until shutdown is requested and the queue is drained.
        class ThreadWorker
        {
            public:
                ThreadWorker(ThreadManager *pThreadManager, const int32_t tid)
                    : m_pThreadManager(pThreadManager), m_tid(tid)
                {
                }

                void operator()()
                {
                    for (;;)
                    {
                        std::function<void()> task;
                        {
                            std::unique_lock<std::mutex> lock(m_pThreadManager->m_mutex);
                            // Wake for new work *or* for shutdown; otherwise shutdown() would hang.
                            m_pThreadManager->m_cond.wait(lock, [this]() {
                                return m_pThreadManager->m_shutdown ||
                                       !m_pThreadManager->m_task_queue.empty();
                            });

                            if (m_pThreadManager->m_task_queue.empty())
                            {
                                return;  // shutdown requested and nothing left to run
                            }

                            task = std::move(m_pThreadManager->m_task_queue.front());
                            m_pThreadManager->m_task_queue.pop();
                        }

                        // Run outside the lock so other workers can keep dequeuing.
                        task();
                    }
                }

            private:
                ThreadManager *m_pThreadManager;
                int32_t m_tid;
        };

    private:
        bool m_shutdown = false;  // guarded by m_mutex
        std::mutex m_mutex;
        std::condition_variable m_cond;
        std::vector<std::thread> m_threads;
        std::queue<std::function<void()>> m_task_queue;  // guarded by m_mutex

};

common_struct.h

/** @ pipeline的入参结构体 */

代码语言:cpp
/// Role of a node inside a pipeline.
enum NodeType
{
    Source,   ///< first node: the pipeline submits requests to it
    Channel,  ///< presumably an intermediate node (not used in this demo)
    Sink      ///< last node: hands finished requests to the result callback
};

/// Static description of one node in a PipelineDescriptor.
struct NodeNeedInfo
{
    std::string name;                   ///< factory key, e.g. "NodeA"
    NodeType type = NodeType::Source;   ///< explicit default instead of an indeterminate value
};

/// Per-request data that travels through every node of the pipeline.
struct InputRequestInfo
{
    bool isOK = false;
    uint32_t requestId = 0;

    // Per-node input info.
    NodeNeedInfo nodeInfo[8];
};
using NodeNeedInfoPtr = std::shared_ptr<NodeNeedInfo>;
using InputRequestInfoPtr = std::shared_ptr<InputRequestInfo>;
using ResultCallback = std::function<void(const InputRequestInfoPtr&)>;

/// Everything needed to build a pipeline.
struct PipelineDescriptor
{
    uint32_t nums = 0;        ///< number of valid entries in nodes[] (at most 8)
    std::string name;

    // Node descriptions, in pipeline order.
    NodeNeedInfo nodes[8];
    ResultCallback callback;  ///< registered on the Sink node
};
using PipelineDescriptorPtr = std::shared_ptr<PipelineDescriptor>;

node.h

//node.h : base Node /*** @ 1, 去掉了类模板 @ 2, 不需要上一级的输出是下一级的输入 @ 3, 通过callback的方式将最后一级的结果输出给主程序 */

代码语言:cpp
class Node
{
    public:
        Node(): m_stop(false),m_is_sink(false){};
        virtual ~Node() = default;

        virtual int32_t initialize(const std::string& conf) = 0;
        virtual int32_t process(InputRequestInfoPtr pRequestInfo) = 0;

        virtual  std::string getNodeName() const= 0;

        virtual NodeType Type()const =0;
    

    public:

        void start()
        {
            //起线程处理
            m_thread = std::thread([this](){
                executeRequest();
            });
        }

        void stop()
        {
            m_stop = true;
            if(m_thread.joinable())
            {
                m_thread.join();
            }
        }

        // inline std::string getNodeName() const
        // {
        //     return m_node_name;
        // }

        void executeRequest()
        {   
            int count = 0;
            while(!m_stop)
            {
                InputRequestInfoPtr pRequest;
                if(m_input_queue.dequeue(pRequest))
                {
                    int32_t ret = process(pRequest);

                    if(ret != 0)
                    {
                        ///////////
                    }

                    //set request for next node
                    if(m_type != NodeType::Sink)//bug to do
                    {
                        count  ;
                        m_output_queue.enqueue(pRequest);
                    }
                    else
                    {
                        m_result_callback(pRequest);//回到main: publishResult
                    }
                 
                    
                }
                else
                {
                    ////////////
                }
            }
        }

        TaskQueue<InputRequestInfoPtr> &input_queue()
        {
            return m_input_queue;
        }

        TaskQueue<InputRequestInfoPtr> &output_queue()
        {
            return m_output_queue;
        }

        // inline NodeType Type()const
        // {
        //     return m_type;
        // }

        void callbackRegister(ResultCallback callback)
        {
            m_result_callback = std::move(callback);
        }

    private:
        bool m_stop;
        bool m_is_sink;
        bool m_source;
        TaskQueue<InputRequestInfoPtr> m_input_queue;
        TaskQueue<InputRequestInfoPtr> m_output_queue;
        std::thread  m_thread;
        std::string m_node_name;
        ResultCallback m_result_callback;
        NodeType m_type;

};

nodeA/B

/** 示例pipeline: NodeA -> NodeB */

代码语言:cpp
class Node_A :public Node
{
    public:
        Node_A() = default;
        ~Node_A() =default;

        int32_t initialize(const std::string& conf)override{
            std::cout<<"I am NodeA initialize"<<std::endl;
            return 0;
        }

        int32_t process(InputRequestInfoPtr pRequestInfo)override{
            std::cout<<"I am NodeA process"<<std::endl;
            pRequestInfo->requestId = 100;
            return 0;
        }

        std::string getNodeName()const override
        {
            return "Node_A";
        }

        NodeType Type()const override
        {
            return NodeType::Source;
        }

};

//NodeB
class Node_B :public Node
{
    public:
        Node_B() = default;
        ~Node_B() =default;

        int32_t initialize(const std::string& conf)override{
            std::cout<<"I am NodeB initialize"<<std::endl;
            return 0;
        }

        int32_t process(InputRequestInfoPtr pRequestInfo)override{
            std::cout<<"I am NodeB process"<<std::endl;
            return 0;
        }

        std::string getNodeName()const override
        {
            return "Node_B";
        }

        NodeType Type()const override
        {
            return NodeType::Sink;
        }

};

perceptionPipeline.h

/** @ 一个具体的pipeline */

代码语言:cpp
class PerceptionPipeline
{

    public:
        PerceptionPipeline()=default;
        ~PerceptionPipeline()=default;

        /**
        @ submit request to source node
        */
        void submit(InputRequestInfoPtr& pRequest)
        {
            m_pNodes[0]->input_queue().enqueue(pRequest);
        }

        /**
        @ initialize an pipline
        */
        int32_t initialize(const PipelineDescriptorPtr& pPipelineDesc)
        {
            int32_t result = 0;
            result = createNodes(pPipelineDesc);

            return result;
        }
        
        int32_t createNodes(const PipelineDescriptorPtr& pPipelineInfo)
        {
            int32_t result = 0;
            for(uint32_t i=0 ; i < pPipelineInfo->nums; i  )
            {
                //todo factory create nodes
                std::shared_ptr<Node> pNode = std::move(CreateNode(pPipelineInfo->nodes[i]));

                result = pNode->initialize("lxk");

                if(0!=result)
                {
                    //////////////
                    break;
                }

                if(pNode->Type() == NodeType::Sink)
                {   
                    std::cout<<"------------callbackRegister-----------"<<std::endl;
                    pNode->callbackRegister(pPipelineInfo->callback);
                }

                this->addNode(pNode);
            }

            return result;
        }
        static std::shared_ptr<Node> CreateNode(const NodeNeedInfo& node_desc)
        {
            if(node_desc.name == "NodeA")
                return (std::make_shared<Node_A>());
            if(node_desc.name == "NodeB")
                return (std::make_shared<Node_B>());

            return nullptr;
        }

        void start()
        {
            for(auto i:m_pNodes)
            {
                i->start();
            }
        }

        void stop()
        {
            for(auto i:m_pNodes)
            {
                i->stop();
            }
        }

        std::string PipelineInfo()
        {
            std::stringstream sstr;
            sstr<<"n";
            sstr<<"-------Pipeline info  start----------n";
            sstr<<"number of nodes: "<<m_pNodes.size()<<"n";
            for(uint32_t i =0; i <m_pNodes.size(); i  )
            {
                if(i == m_pNodes.size() -1)
                {
                    sstr<<m_pNodes[i]->getNodeName()<<"n";
                }
                else
                {
                    sstr<<m_pNodes[i]->getNodeName()<<"->";
                }
            }

            sstr<<"----------Pipeline info end----------n";

            return sstr.str();
        }

    private:

        void addNode(std::shared_ptr<Node>& pNode)
        {
            std::shared_ptr<Node> pTail = nullptr;
            if(!m_pNodes.empty())
            {
                pTail = m_pNodes.back();
            }
            m_pNodes.push_back(pNode);

            //connect output queue node with input queue of next node
            if(pTail)
            {
                pTail->output_queue().connect(&pNode->input_queue());
            }
        }

    private:
        std::vector<std::shared_ptr<Node>> m_pNodes;
};

CameraPerception

/** @ 实际测试案例 */

代码语言:cpp
/**
 * @brief Demo owner of a PerceptionPipeline and a ThreadManager.
 *
 * init() builds a NodeA -> NodeB pipeline, starts it and submits one request.
 * Results come back through publishResult().
 */
class CameraPerception
{
    public:
        CameraPerception();
        ~CameraPerception();

        /// Create the thread pool and the pipeline, start the pipeline, and submit a demo request.
        bool init();

    private:
        /// Build one demo request and submit it to the pipeline.
        void cameraPerceptionCallback();

        /// Sink-node callback: print the result and schedule MarkObstacleOnImage() on the pool.
        void publishResult(const InputRequestInfoPtr& pInferResult);

        /// Job run on the thread pool for one finished request.
        void MarkObstacleOnImage(uint64_t request_id);

        // Members are destroyed in reverse declaration order (m_thread_manager first).
        std::unique_ptr<PerceptionPipeline> m_perception_pipeline;
        std::unique_ptr<ThreadManager> m_thread_manager;
};

// Both owners start out empty; all real set-up happens in init().
CameraPerception::CameraPerception() = default;

// Stop the pipeline (joining its node threads) before the members are
// destroyed, so no sink callback can reach m_thread_manager during teardown.
// init() may never have run, so the pipeline pointer can be null.
CameraPerception::~CameraPerception()
{
    if (m_perception_pipeline)
    {
        m_perception_pipeline->stop();
    }
}

bool CameraPerception::init()
{
    m_thread_manager.reset(new ThreadManager());
    m_thread_manager->init();

    m_perception_pipeline.reset(new PerceptionPipeline);
    
    PipelineDescriptorPtr pPipeline(new PipelineDescriptor);
    pPipeline->name = "perception pipeline";

    int count = 2;

    pPipeline->nodes[0].name = "NodeA";
    pPipeline->nodes[0].type = NodeType::Source;

    pPipeline->nodes[1].name = "NodeB";
    pPipeline->nodes[1].type = NodeType::Sink;
    pPipeline->nums = count;
    
    pPipeline->callback = std::bind(&CameraPerception::publishResult, this, std::placeholders::_1);
    int32_t ret = m_perception_pipeline->initialize(pPipeline);

    if(ret != 0)
    {
        std::cout<<"pipeline init error";
        return false;
    }
  
    m_perception_pipeline->start();

    std::cout<<"pipeline info: "<<m_perception_pipeline->PipelineInfo()<<std::endl;

    cameraPerceptionCallback();


    return ret;
}


/// Build one demo request and submit it to the pipeline's source node.
void CameraPerception::cameraPerceptionCallback()
{
    // make_shared value-initializes the request (no indeterminate fields).
    InputRequestInfoPtr input_info = std::make_shared<InputRequestInfo>();
    input_info->requestId = 1;
    input_info->isOK = true;

    // Fill the first three nodeInfo slots. The loop previously had no
    // increment and never terminated.
    for (size_t i = 0; i < 3; ++i)
    {
        input_info->nodeInfo[i].name = "lxkkk";
    }

    m_perception_pipeline->submit(input_info);

}

// Result callback invoked by the Sink node: report the request id, then hand
// the follow-up work to the thread pool (the returned future is discarded).
void CameraPerception::publishResult(const InputRequestInfoPtr& pInferResult)
{
    const auto id = pInferResult->requestId;
    std::cout << "publishResult:  ID: " << id << std::endl;

    m_thread_manager->postJobs(std::bind(&CameraPerception::MarkObstacleOnImage, this, id));
}

/// Thread-pool job scheduled by publishResult(). It currently only logs the request id.
void CameraPerception::MarkObstacleOnImage(uint64_t request_id)
{
    std::ostream& log = std::cout;
    log << "MarkObstacleOnImage:  ID: " << request_id << std::endl;
}
int main()
{
    std::unique_ptr<CameraPerception> pCameraPerceptionHandle(new CameraPerception());
    pCameraPerceptionHandle->init();

    // Demo only: give the asynchronous nodes and the thread pool a moment to
    // process the submitted request. Otherwise the destructor stops the
    // pipeline immediately and the result may never be printed.
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    return 0;
}

0 人点赞