手写AQS-非公平锁

2021-04-08 13:34:56 浏览数 (1)

1. Unsafe工具类

代码语言:javascript复制
package com.shi.flink.unsafeTest;

import sun.misc.Unsafe;

import java.lang.reflect.Field;

/**
 * @author shiye
 * @create 2021-03-30 17:03
 */
public class UnsafeUtil {
    public static Unsafe getInstance() {
        Field field = null;
        try {
            field = Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            return (Unsafe) field.get(null);
        } catch (NoSuchFieldException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
        return null;
    }
}

2. 手写AQS抽象类

代码语言:javascript复制
package com.shi.flink.shilock;

import com.shi.flink.unsafeTest.UnsafeUtil;
import sun.misc.Unsafe;

import java.util.concurrent.locks.AbstractOwnableSynchronizer;
import java.util.concurrent.locks.LockSupport;

/**
 * 自己写抽象AQS实现
 *
 * @author shiye
 * @create 2021-03-30 14:10
 */
public abstract class ShiAQS extends AbstractOwnableSynchronizer implements java.io.Serializable {
    private static final long serialVersionUID = 7373984972572414691L;

    /**
     * 头指针
     */
    private transient volatile Node head;

    /**
     * 尾指针
     */
    private transient volatile Node tail;

    /**
     * 状态值:
     * 0:空闲,
     * 1:正在有人使用
     */
    private volatile int state;

    /**
     * 获取当前状态
     *
     * @return
     */
    protected final int getState() {
        return state;
    }

    /**
     * 设置当前锁的状态
     *
     * @param state
     */
    public void setState(int state) {
        this.state = state;
    }

    /**
     * 使用unsafe类来初始化一些参数值
     */
    private static final Unsafe unsafe = UnsafeUtil.getInstance();
    private static long stateOffset;
    private static long headOffset;
    private static long tailOffset;
    private static long waitStatusOffset;
    private static long nextOffset;

    static {
        try {
            stateOffset = unsafe.objectFieldOffset(ShiAQS.class.getDeclaredField("state"));
            headOffset = unsafe.objectFieldOffset(ShiAQS.class.getDeclaredField("head"));
            tailOffset = unsafe.objectFieldOffset(ShiAQS.class.getDeclaredField("tail"));
            waitStatusOffset = unsafe.objectFieldOffset(Node.class.getDeclaredField("waitStatus"));
            nextOffset = unsafe.objectFieldOffset(Node.class.getDeclaredField("next"));
        } catch (NoSuchFieldException e) {
            e.printStackTrace();
        }
    }

    /**
     * 设置状态
     *
     * @param expect
     * @param update
     * @return
     */
    protected boolean compareAndSetState(int expect, int update) {
        //读取传入对象o在内存中偏移量为offset位置的值与期望值expected作比较
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }

    /**
     * 设置头指针
     *
     * @param update
     * @return
     */
    private final boolean compareAndSetHead(Node update) {
        return unsafe.compareAndSwapObject(this, headOffset, null, update);
    }

    /**
     * 如果pre节点得waitStatus值为ws,
     * 则把signal赋值给waitStatus
     *
     * @param pre
     * @param ws
     * @param signal
     * @return
     */
    private static boolean compareAndSetWaitStatus(Node pre, int ws, int signal) {
        return unsafe.compareAndSwapInt(pre, waitStatusOffset, ws, signal);
    }

    /**
     * 设置尾指针
     *
     * @param expect
     * @param update
     * @return
     */
    private final boolean compareAndSetTail(Node expect, Node update) {
        return unsafe.compareAndSwapObject(this, tailOffset, expect, update);
    }

    /**
     * 设置下一个节点
     *
     * @param node
     * @param expect
     * @param update
     * @return
     */
    private static final boolean compareAndSetNext(Node node,
                                                   Node expect,
                                                   Node update) {
        return unsafe.compareAndSwapObject(node, nextOffset, expect, update);
    }

    /**
     * 解锁方法
     *
     * @param arg
     */
    protected void release(int arg) throws Exception {
        //尝试去释放占用锁得线程
        boolean tryRelease = tryRelease(arg);

        if (tryRelease) {
            Node h = head;
            if (h != null && h.waitStatus != 0) {
                unparkSuccessor(h);
            }
        }
    }

    /**
     * 尝试去释放占用锁得线程
     *
     * @param arg
     * @return
     * @throws Exception
     */
    protected boolean tryRelease(int arg) throws Exception {
        if (Thread.currentThread() != getExclusiveOwnerThread()) {
            //如果当前线程不是占用锁得线程就抛出异常
            throw new Exception("解锁失败,当前线程不是占用锁得线程无法解锁");
        } else {
            setExclusiveOwnerThread(null);
            this.setState(0);
            return true;
        }
    }

    /**
     * 打断某个线程
     *
     * @return
     */
    protected boolean interruptThread(Thread thread) throws Exception {
        Thread ownerThread = getExclusiveOwnerThread();
        if (ownerThread == thread) {
            //如果是正在运行得线程
            compareAndSetState(1, 0);
            setExclusiveOwnerThread(null);
        } else if (head != null) {
            //再对类中查找当前线程,并且取消排队
            for (Node next1 = head.next; next1 != null; next1 = next1.next) {
                if (next1.thread == thread) {
                    compareAndSetWaitStatus(next1, next1.waitStatus, 1);
                }
            }
        }
        //解锁
        thread.interrupt();
        System.out.println(thread.getName()   " 已经中断了 ====> ");
        unparkSuccessor(head);
        System.out.println(thread.getName()   " 已经结束了 ====> ");
        return false;
    }

    /**
     * 自定义一个内部类Node节点
     */
    static final class Node {

        //共享模式标记
        static final Node shared = new Node();

        //独占锁标记
        static final Node excusive = null;

        //waitStatus值,指示线程已取消
        static final int cancelled = 1;

        //waitStatus值,用于指示后续线程需要解除等待状态
        static final int signal = -1;

        //waitStatus值,指示线程正在等待条件
        static final int condition = -2;

        //waitStatus值,指示下一个acquireShared应该 无条件传播
        static final int propagate = -3;

        //锁的等待状态
        volatile int waitStatus;

        //前指针
        volatile Node prev;

        //后指针
        volatile Node next;

        //线程
        volatile Thread thread;

        Node nextWaiter;

        //是否是共享锁
        final boolean isShared() {
            return nextWaiter == shared;
        }

        //无参构造
        public Node() {
        }

        public Node(Node nextWaiter, Thread thread) {
            this.nextWaiter = nextWaiter;
            this.thread = thread;
        }

        /**
         * 获取前节点
         *
         * @return
         */
        public Node getPrev() {
            Node p = prev;
            if (p == null) {
                throw new NullPointerException("前节点不能为空");
            }
            return p;
        }
    }


    /**
     * 获得
     *
     * @param arg
     */
    public void acquire(int arg) {
        //1.尝试去排队
        boolean tryAcquire = tryAcquire(arg);
        if (!tryAcquire) {
            //2.如果抢占锁失败,就去排队
            Node node = addWaiter(Node.excusive);

            //3.对已经再队列中的节点,进行休眠等侯
            acquireQueued(node, arg);
        }
    }

    /**
     * 先尝试去排队
     * 1.先获取锁得状态,如果状态为0,就尝试去占用一次锁
     * 否则返回占用失败
     *
     * @param arg
     * @return true:表示抢占锁成功
     * false:表示抢占所失败
     */
    public final boolean tryAcquire(int arg) {
        Thread current = Thread.currentThread();
        int state = getState();
        if (state == 0) {
            //如果空闲了,就尝试去占用一次锁
            if (compareAndSetState(0, arg)) {
                //抢占成功就返回true,并设置线程
                setExclusiveOwnerThread(current);
                return true;
            }
        } else if (getExclusiveOwnerThread() == current) {
            //如果当前当前线程多次抢占锁,就将状态 arg
            int nextState = state   arg;
            if (nextState < 0) {
                throw new Error("超过最大锁计数");
            }
            setState(nextState);
            return true;
        }
        return false;
    }

    /**
     * 添加等待队列
     *
     * @param mode
     */
    public Node addWaiter(Node mode) {
        Node node = new Node(mode, Thread.currentThread());
        Node temp = tail;
        if (temp == null) {
            //入队
            enQueue(node);
            return node;
        } else {
            //如果队列中不为空,就把当前节点添加到尾节点中
            node.prev = temp;
            if (compareAndSetTail(temp, node)) {
                temp.next = node;
                return node;
            }
        }
        return node;
    }

    /**
     * 入队
     * 把node节点添加到队列中,
     * 如果队列为null就初始化一个队列并且把node节点添加到尾节点中
     *
     * @param node 返回当前节点
     */
    public Node enQueue(Node node) {
        while (true) {
            Node temp = tail;
            if (temp == null) {
                //创建一个头指针
                compareAndSetHead(new Node());
                //让尾指针也指向头指针(空节点)
                tail = head;
            } else {
                node.prev = temp;
                if (compareAndSetTail(temp, node)) {
                    temp.next = node;
                    return node;
                }
            }
        }
    }

    /**
     * @param node 当前正在侯队中得节点
     * @param arg
     * @return
     */
    protected boolean acquireQueued(Node node, int arg) {
        boolean failed = true;

        try {
            //是否被打断,默认false
            boolean interrupted = false;
            while (true) {
                final Node p = node.getPrev();
                if (p == head && tryAcquire(arg)) {
                    //如果是他的头节点是head,并且尝试抢占锁成功就出队,让当前线程运行
                    setHead(node);
                    p.next = null;//利于gc回收
                    failed = false;
                    return interrupted;
                }

                if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) {
                    //pre节点得waitStatus 设置成-1,并且让当前线程阻塞,打断当前线程
                    interrupted = true;
                }
            }

        } finally {
            if (failed) cancelAcquire(node);
        }
    }

    protected final void cancelAcquire(Node node) {
        if (node == null) return;

        node.thread = null;
        Node pre = node.prev;

        while (pre.waitStatus > 0) {
            node.prev = pre = pre.prev;
        }
        Node predNext = pre.next;
        node.waitStatus = Node.cancelled;

        if (node == tail && compareAndSetTail(node, pre)) {
            compareAndSetNext(pre, predNext, null);
        } else {
            // If successor needs signal, try to set pred's next-link
            // so it will get one. Otherwise wake it up to propagate.
            int ws;
            if (pre != head &&
                    ((ws = pre.waitStatus) == Node.signal ||
                            (ws <= 0 && compareAndSetWaitStatus(pre, ws, Node.signal))) &&
                    pre.thread != null) {
                Node next = node.next;
                if (next != null && next.waitStatus <= 0)
                    compareAndSetNext(pre, predNext, next);
            } else {
                unparkSuccessor(node);
            }

            node.next = node; // help GC
        }

    }

    /**
     * 解锁必须成功
     *
     * @param node
     */
    protected final void unparkSuccessor(Node node) {
        int ws = node.waitStatus;
        if (ws < 0) {
            compareAndSetWaitStatus(node, ws, 0);
        }

        /**
         * AQS源码是这样实现得
         * 如果当前节点不为空,并且用户取消了,就从尾节点往前遍历一个,直到找到最前面得一个节点,解锁当前线程
         */
        Node next1 = node.next;
//        if (next1 == null || next1.waitStatus > 0) {
//            next1 = null;
//            for (Node t = tail; t != null && t != node; t = t.prev) {
//                if (t.waitStatus <= 0) {
//                    next1 = t;
//                }
//            }
//        }

        /**
         * 我自己实现,从前往后找
         */
        if (next1 != null && next1.waitStatus > 0) {
            for (next1 = next1.next; next1 != null; next1 = next1.next) {
                if (next1.waitStatus <= 0) {
                    break;
                }
            }
        }

        if (next1 != null) {
            //唤醒下一个线程
            System.out.println(next1.thread.getName()   " 开始唤醒了 ====> ");
            LockSupport.unpark(next1.thread);
            System.out.println(next1.thread.getName()   " 已经唤醒了 ====> ");

        }
    }

    /**
     * 将pre节点得waitStatus 设置成-1
     *
     * @param pre
     * @param node
     * @return
     */
    protected static boolean shouldParkAfterFailedAcquire(Node pre, Node node) {
        //获取node节点得前一个节点得状态
        int ws = pre.waitStatus;

        //如果是-1 就返回true
        if (ws == Node.signal) {
            return true;
        }

        if (ws > 0) {
            do {
                pre = pre.prev;
                node.prev = pre;
            } while (pre.waitStatus > 0);
            pre.next = node;
        } else {
            //设置成-1
            boolean flag = compareAndSetWaitStatus(pre, ws, Node.signal);
//            System.out.println("设置成-1是否成功:"   flag);
        }
        return false;
    }

    /**
     * 阻塞当前线程,并且返回当前线程得打断状态
     *
     * @return true: 打断线程成功
     */
    protected final boolean parkAndCheckInterrupt() {
        //打断线程,让线程阻塞
        LockSupport.park(this);
        return Thread.interrupted();
    }

    public Node getHead() {
        return head;
    }

    public void setHead(Node head) {
        this.head = head;
    }
}

3.非公平所实现

代码语言:javascript复制
package com.shi.flink.shilock;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;

/**
 * 自定义非公平锁
 * @author shiye
 * @create 2021-03-30 11:02
 */
public class ShiNonfairLock extends ShiAQS implements Lock, java.io.Serializable {

    private static final long serialVersionUID = 7373984872572414699L;

    @Override
    public void lock() {
        if(compareAndSetState(0, 1)){
            //如果抢到了锁,就把当前线程设置进去
            setExclusiveOwnerThread(Thread.currentThread());
        }else{
//            否则就去排队
            acquire(1);
        }
    }

    @Override
    public void lockInterruptibly() throws InterruptedException {

    }

    @Override
    public boolean tryLock() {
        return false;
    }

    @Override
    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        return false;
    }

    @Override
    public void unlock() {
        try {
            super.release(1);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    @Override
    public Condition newCondition() {
        return null;
    }


    /**
     *  打断某个线程(自己瞎写的,有bug)
     * @param thread
     * @return
     */
    public boolean interruptThread(Thread thread) throws Exception {
        return super.interruptThread(thread);
    }
}

4.测试

代码语言:javascript复制
package com.shi.flink.shilock;

import java.util.concurrent.TimeUnit;

/**
 * @author shiye
 * @create 2021-03-31 17:09
 */
public class MyLockTest {

    public static void main(String[] args) throws Exception {
        ShiNonfairLock lock = new ShiNonfairLock();

        new Thread(() -> {
            try {
                System.out.println("A 线程进入到...加锁过程");
                lock.lock();
                System.out.println("A 已经抢占到锁...休眠10s后运行......");
                TimeUnit.SECONDS.sleep(10);
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                System.out.println("A线程运行完成,开始解锁....");
                lock.unlock();
            }

        }, "A").start();


        TimeUnit.SECONDS.sleep(1);
        Thread B = new Thread(() -> {
            try {
                System.out.println("B 线程进入到...加锁过程");
                lock.lock();
                System.out.println("B 已经抢占到锁...休眠10s后运行......");
                TimeUnit.SECONDS.sleep(10);
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                System.out.println("B线程运行完成,开始解锁....");
                lock.unlock();
            }

        }, "B");
        B.start();

        TimeUnit.SECONDS.sleep(1);
        new Thread(() -> {
            try {
                System.out.println("C 线程进入到...加锁过程");
                lock.lock();
                System.out.println("C 已经抢占到锁...休眠10s后运行......");
                TimeUnit.SECONDS.sleep(10);
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                System.out.println("C线程运行完成,开始解锁....");
                lock.unlock();
            }

        }, "C").start();

        TimeUnit.SECONDS.sleep(1);
        new Thread(() -> {
            try {
                System.out.println("D 线程进入到...加锁过程");
                lock.lock();
                System.out.println("D 已经抢占到锁...休眠10s后运行......");
                TimeUnit.SECONDS.sleep(10);
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                System.out.println("D线程运行完成,开始解锁....");
                lock.unlock();
            }

        }, "D").start();

//        TimeUnit.SECONDS.sleep(1);
//        System.out.println("强制让 "   B.getName()   " 线程中断...");
//        lock.interruptThread(B);
    }
}

0 人点赞