利用自定义智能指针析构函数实现自动回收的内存池

2024-01-20 12:24:39 浏览数 (1)

mem_pool.h

代码语言：C++
#pragma once

#include <atomic>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <vector>

// Memory pool that recycles blocks automatically through a custom
// std::shared_ptr deleter.
//
// Requests of up to kBlockSize (1024) bytes are served from a pool of at most
// MaxCount() fixed-size blocks. Each block is handed out as a shared_ptr whose
// deleter puts the block back on the pool's free list when the last owner
// releases it. Larger requests bypass the pool: they are plain malloc'd
// buffers released with free.
//
// The pool's bookkeeping lives in a State object co-owned by the pool and by
// every outstanding pooled block. A block may therefore safely outlive the
// MemoryPool object. All memory is freed once the pool and all of its blocks
// are gone. Every public member function takes the internal mutex where it
// touches mutable state.
class MemoryPool {
public:
    // Creates a pool that can hand out at most kMinCount (1000) blocks.
    MemoryPool()
        : MemoryPool(kMinCount)
    {
    }

    // Creates a pool that can hand out at most `count` blocks. The pool always
    // allows at least kMinCount (1000) blocks; smaller values are raised.
    explicit MemoryPool(size_t count)
        : state_(std::make_shared<State>(count > kMinCount ? count : size_t{kMinCount}))
    {
    }

    // Copying or moving would share or empty the bookkeeping. Keep the pool
    // non-copyable and non-movable, as the original mutex-owning class was.
    MemoryPool(const MemoryPool&) = delete;
    MemoryPool& operator=(const MemoryPool&) = delete;

    // Returns a buffer of at least `size` bytes whose memory is reclaimed
    // automatically when the last shared_ptr owner goes away.
    //
    // Returns nullptr when malloc fails. It also returns nullptr when
    // size <= kBlockSize, the free list is empty, and MaxCount() blocks have
    // already been created.
    std::shared_ptr<unsigned char> Malloc(size_t size)
    {
        if (size > kBlockSize) {
            // Oversized requests are not pooled: allocate and free directly.
            unsigned char* ptr = static_cast<unsigned char*>(std::malloc(size));
            if (ptr == nullptr) {
                return nullptr;
            }
            return std::shared_ptr<unsigned char>(ptr, [](unsigned char* p) { std::free(p); });
        }

        unsigned char* block = state_->Acquire();
        if (block == nullptr) {
            return nullptr;
        }
        // The deleter keeps its own reference to the shared state, so
        // releasing a block after the pool is destroyed never touches freed
        // memory. No lock is held at this point. If constructing the
        // shared_ptr throws, it calls the deleter, which must be able to
        // acquire the lock itself.
        std::shared_ptr<State> state = state_;
        return std::shared_ptr<unsigned char>(block, [state](unsigned char* p) { state->Release(p); });
    }

    // Maximum number of pooled blocks this pool will ever create.
    size_t MaxCount() const
    {
        return state_->maxCount_;
    }

    // Number of pooled blocks created so far (in use plus on the free list).
    size_t CurCount() const
    {
        return state_->Created();
    }

    // Number of pooled blocks currently on the free list.
    size_t FreeCount() const
    {
        return state_->Available();
    }

private:
    // Size in bytes of every pooled block.
    static constexpr size_t kBlockSize = 1024;
    // Lower bound on (and default value of) the pool capacity.
    static constexpr size_t kMinCount = 1000;

    // Bookkeeping shared between the pool and every outstanding pooled block.
    struct State {
        explicit State(size_t maxCount)
            : maxCount_(maxCount)
        {
        }

        // Runs only after the pool and every pooled block are gone, so every
        // block ever created is on the free list at this point.
        ~State()
        {
            for (unsigned char* p : freeList_) {
                std::free(p);
            }
        }

        State(const State&) = delete;
        State& operator=(const State&) = delete;

        // Pops a recycled block, or creates a new one while under the limit.
        // Returns nullptr when the pool is exhausted or malloc fails.
        unsigned char* Acquire()
        {
            std::lock_guard<std::mutex> guard(mutex_);
            if (!freeList_.empty()) {
                unsigned char* p = freeList_.back();
                freeList_.pop_back();
                return p;
            }
            if (curCount_ >= maxCount_) {
                return nullptr;
            }
            // Grow the free list before creating a block. Release() can then
            // always push every created block back without allocating,
            // because a shared_ptr deleter must not throw.
            if (freeList_.capacity() <= curCount_) {
                freeList_.reserve(2 * curCount_ + 1);
            }
            unsigned char* p = static_cast<unsigned char*>(std::malloc(kBlockSize));
            if (p == nullptr) {
                return nullptr;
            }
            ++curCount_;
            return p;
        }

        // Puts a pooled block back on the free list.
        void Release(unsigned char* p) noexcept
        {
            std::lock_guard<std::mutex> guard(mutex_);
            // Capacity was reserved in Acquire(), so this never reallocates.
            freeList_.push_back(p);
        }

        size_t Created() const
        {
            std::lock_guard<std::mutex> guard(mutex_);
            return curCount_;
        }

        size_t Available() const
        {
            std::lock_guard<std::mutex> guard(mutex_);
            return freeList_.size();
        }

        const size_t maxCount_;
        mutable std::mutex mutex_;
        size_t curCount_ = 0;
        std::vector<unsigned char*> freeList_;
    };

    std::shared_ptr<State> state_;
};

gtest测试代码

代码语言：C++
#include <gtest/gtest.h>
#include "mem_pool.h"

TEST(MemPoolTest, TestAlloc1)
{
    MemoryPool pool;
    for (int i = 0; i < 2000; ++i) {
        // The returned shared_ptr is discarded at once, so its block goes back
        // to the pool before the next iteration asks for one.
        pool.Malloc(100);
    }
    // Every block is returned immediately, so 2000 allocations are served by
    // a single pooled block.
    EXPECT_EQ(1u, pool.CurCount());
    EXPECT_EQ(1u, pool.FreeCount());
}

TEST(MemPoolTest, TestAlloc2)
{
    MemoryPool pool;
    std::vector<std::shared_ptr<unsigned char>> vecs;
    // Keep every block alive. Only the first 1000 requests (the default pool
    // capacity) succeed; the rest return nullptr.
    for (int i = 0; i < 2000; ++i) {
        auto p = pool.Malloc(100);
        if (p != nullptr) {
            vecs.emplace_back(p);
        }
    }
    // The pool is exhausted for pooled sizes (up to 1024 bytes)...
    EXPECT_TRUE(nullptr == pool.Malloc(1024));
    // ...but larger requests bypass the pool entirely.
    EXPECT_TRUE(nullptr != pool.Malloc(1025));
    EXPECT_EQ(1000u, pool.CurCount());
    EXPECT_EQ(0u, pool.FreeCount());
    EXPECT_EQ(1000u, vecs.size());

    // Dropping the last owner returns the block to the free list immediately.
    vecs.pop_back();
    EXPECT_EQ(1u, pool.FreeCount());

    vecs.pop_back();
    EXPECT_EQ(2u, pool.FreeCount());

    vecs.clear();
    EXPECT_EQ(1000u, pool.CurCount());
    EXPECT_EQ(1000u, pool.FreeCount());
}


TEST(MemPoolTest, TestAlloc3)
{
    MemoryPool pool;
    unsigned char* p = nullptr;
    {
        auto ptr = pool.Malloc(100);
        // Without this check, a failed allocation would make the pointer
        // comparison below trivially pass (nullptr == nullptr).
        ASSERT_NE(nullptr, ptr.get());
        p = ptr.get();
    }  // ptr is destroyed here and its block returns to the pool.
    {
        // Any pooled size (up to 1024 bytes) reuses the block just released.
        auto ptr = pool.Malloc(200);
        EXPECT_EQ(p, ptr.get());
    }
    EXPECT_EQ(1u, pool.CurCount());
    EXPECT_EQ(1u, pool.FreeCount());
}

0 人点赞