回收站拦截器

2022-10-27 11:43:16 浏览数 (1)

阿谀奉承者的喉咙是一座敞开的坟墓——佚名

本拦截器基于 mybatis-plus 的租户拦截器 TenantLineInnerInterceptor 复制并扩展而来。

kotlin代码如下:

代码语言:kotlin
package com.ruben.simpleboot.interceptor

import com.baomidou.mybatisplus.core.metadata.TableInfoHelper
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper
import com.baomidou.mybatisplus.core.toolkit.*
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor
import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper
import net.sf.jsqlparser.expression.*
import net.sf.jsqlparser.expression.Function
import net.sf.jsqlparser.expression.operators.conditional.AndExpression
import net.sf.jsqlparser.expression.operators.conditional.OrExpression
import net.sf.jsqlparser.expression.operators.relational.*
import net.sf.jsqlparser.schema.Column
import net.sf.jsqlparser.schema.Table
import net.sf.jsqlparser.statement.Statement
import net.sf.jsqlparser.statement.delete.Delete
import net.sf.jsqlparser.statement.insert.Insert
import net.sf.jsqlparser.statement.select.*
import net.sf.jsqlparser.statement.update.Update
import org.apache.ibatis.executor.Executor
import org.apache.ibatis.executor.statement.StatementHandler
import org.apache.ibatis.mapping.BoundSql
import org.apache.ibatis.mapping.MappedStatement
import org.apache.ibatis.mapping.SqlCommandType
import org.apache.ibatis.session.ResultHandler
import org.apache.ibatis.session.RowBounds
import java.sql.Connection
import java.sql.SQLException
import java.util.*
import java.util.function.Consumer
import java.util.stream.Collectors


/**
 * Recycle-bin plugin.
 *
 * Adapted from MyBatis-Plus' `TenantLineInnerInterceptor`. For every table that the [RecycleBinLineHandler] does not
 * ignore:
 * - queries/updates/deletes are filtered on the recycle column (`= recycleBin` normally, `<> recycleBin` while
 *   [RecycleThreadLocal.ONLY_SHOW_RECYCLE] is `true`);
 * - inserts get the recycle column and the handler's "not recycled" value appended;
 * - DELETE statements and logic-delete UPDATEs are rewritten to `SET <recycleColumn> = NOW()` unless
 *   [RecycleThreadLocal.REAL_DELETE] is `true`.
 *
 * @author VampireAchao
 * @since 2022/9/30 16:20
 */
class RecycleBinInterceptor(private var recycleBinLineHandler: RecycleBinLineHandler? = null) : JsqlParserSupport(),
    InnerInterceptor {

    /**
     * The configured handler. The backing property is nullable only because [setProperties] may supply it after
     * construction; fail with a descriptive message instead of a bare NPE when it was never configured.
     */
    private val lineHandler: RecycleBinLineHandler
        get() = checkNotNull(recycleBinLineHandler) { "RecycleBinLineHandler has not been configured" }

    /**
     * Rewrites the SQL of a query before execution, unless the mapped statement is configured to ignore the
     * `recycleBin` interceptor key.
     */
    @Throws(SQLException::class)
    override fun beforeQuery(
        executor: Executor?,
        ms: MappedStatement,
        parameter: Any?,
        rowBounds: RowBounds?,
        resultHandler: ResultHandler<*>?,
        boundSql: BoundSql?
    ) {
        if (InterceptorIgnoreHelper.willIgnoreOthersByKey(ms.id, "recycleBin")) return
        val mpBs = PluginUtils.mpBoundSql(boundSql)
        mpBs.sql(parserSingle(mpBs.sql(), null))
    }

    /**
     * Dispatches a parsed statement to the matching `process*` method and returns the SQL to execute.
     *
     * A DELETE on a recycle-enabled table is not modified in place: unless [RecycleThreadLocal.REAL_DELETE] is set,
     * it is replaced by an UPDATE that stamps the recycle column with `NOW()`.
     *
     * @param statement JsqlParser statement
     * @param sql the original SQL text (only used for logging)
     * @return the resulting SQL
     */
    override fun processParser(statement: Statement?, index: Int, sql: String?, obj: Any?): String {
        if (logger.isDebugEnabled) {
            logger.debug("SQL to parse, SQL: $sql")
        }
        // Non-null only when the statement is replaced by a different one (DELETE -> recycle UPDATE)
        var rewrittenSql: String? = null
        when (statement) {
            is Insert -> processInsert(statement, index, sql, obj)
            is Select -> processSelect(statement, index, sql, obj)
            is Update -> processUpdate(statement, index, sql, obj)
            is Delete -> {
                if (!RecycleThreadLocal.REAL_DELETE.get() && !lineHandler.ignoreTable(statement.table.name)) {
                    // Soft delete: move the rows into the recycle bin instead of deleting them
                    rewrittenSql = toRecycleUpdate(statement).toString()
                } else {
                    processDelete(statement, index, sql, obj)
                }
            }
        }
        val finalSql = rewrittenSql ?: statement.toString()
        if (logger.isDebugEnabled) {
            logger.debug("parse the finished SQL: $finalSql")
        }
        return finalSql
    }

    /**
     * Builds `UPDATE <table> SET <recycleColumn> = NOW() WHERE <original where>` from a DELETE statement.
     */
    private fun toRecycleUpdate(delete: Delete): Update = Update().apply {
        table = delete.table
        where = delete.where
        addUpdateSet(Column(lineHandler.getRecycleBinColumn()), Column("NOW()"))
    }

    /**
     * Rewrites INSERT/DELETE/UPDATE/SELECT statements at prepare time, unless the mapped statement is configured
     * to ignore the `recycleBin` interceptor key.
     *
     * NOTE(review): SELECT is rewritten here as well as in [beforeQuery]. Presumably this is so that queries which
     * never pass through `beforeQuery` (e.g. pagination count queries) are still filtered. A query that passes
     * through both hooks may get the recycle predicate twice, which is redundant but equivalent. Confirm this is
     * intended.
     */
    override fun beforePrepare(sh: StatementHandler?, connection: Connection?, transactionTimeout: Int?) {
        val mpSh = PluginUtils.mpStatementHandler(sh)
        val ms = mpSh.mappedStatement()
        val handled = when (ms.sqlCommandType) {
            SqlCommandType.INSERT, SqlCommandType.DELETE, SqlCommandType.UPDATE, SqlCommandType.SELECT -> true
            else -> false
        }
        if (!handled || InterceptorIgnoreHelper.willIgnoreOthersByKey(ms.id, "recycleBin")) {
            return
        }
        val mpBs = mpSh.mPBoundSql()
        mpBs.sql(parserMulti(mpBs.sql(), null))
    }

    /**
     * Applies the recycle filter to the main select body and to every `WITH` item.
     */
    override fun processSelect(select: Select, index: Int, sql: String?, obj: Any?) {
        processSelectBody(select.selectBody)
        select.withItemsList?.forEach { withItem -> processSelectBody(withItem) }
    }

    /**
     * Recursively processes a select body: plain selects, `WITH` items and set operations (UNION etc.).
     * Other body types are left untouched.
     */
    protected fun processSelectBody(selectBody: SelectBody?) {
        when (selectBody) {
            null -> return
            is PlainSelect -> processPlainSelect(selectBody)
            is WithItem -> processSelectBody(selectBody.subSelect.selectBody)
            is SetOperationList -> selectBody.selects?.forEach { body -> processSelectBody(body) }
        }
    }

    /**
     * Adds the recycle column and its "not recycled" value to an INSERT on a recycle-enabled table.
     *
     * INSERTs without an explicit column list, or that already list the recycle column, are left untouched.
     *
     * @throws com.baomidou.mybatisplus.core.exceptions.MybatisPlusException when the INSERT has neither a SELECT
     * nor a VALUES list
     */
    override fun processInsert(insert: Insert, index: Int, sql: String?, obj: Any?) {
        if (lineHandler.ignoreTable(insert.table.name)) {
            // The table does not take part in the recycle bin
            return
        }
        val columns = insert.columns
        if (CollectionUtils.isEmpty(columns)) {
            // INSERT without an explicit column list is not handled
            return
        }
        val recycleBinColumn = lineHandler.getRecycleBinColumn()
        if (lineHandler.ignoreInsert(columns, recycleBinColumn)) {
            // The statement already provides the recycle column
            return
        }
        columns.add(Column(recycleBinColumn))

        // fixed gitee pulls/141 duplicate update
        // NOTE(review): this reuses the query filter (including the ONLY_SHOW_RECYCLE flip to `<>`) inside the
        // ON DUPLICATE KEY UPDATE expression list, and does not add a matching entry to the duplicate-update column
        // list. Confirm that the jsqlparser version in use renders it and that this is the intended assignment.
        val duplicateUpdateColumns = insert.duplicateUpdateExpressionList
        if (CollectionUtils.isNotEmpty(duplicateUpdateColumns)) {
            duplicateUpdateColumns.add(recycleCondition(StringValue(recycleBinColumn), lineHandler.getRecycleBin()))
        }

        val select = insert.select
        val itemsList = insert.itemsList
        when {
            // INSERT ... SELECT: filter the selected tables and select their recycle column
            select != null -> processInsertSelect(select.selectBody)
            // fixed github pull/295: multi-row VALUES
            itemsList is MultiExpressionList -> itemsList.expressionLists.forEach { row ->
                row.expressions.add(lineHandler.getRecycleBin())
            }
            itemsList != null -> (itemsList as ExpressionList).expressions.add(lineHandler.getRecycleBin())
            else -> throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId")
        }
    }

    /**
     * Processes an UPDATE statement.
     *
     * The recycle filter is always ANDed into the WHERE clause. If the UPDATE is a MyBatis-Plus logic delete
     * (it sets the entity's logic-delete column) and [RecycleThreadLocal.REAL_DELETE] is not set, its SET list is
     * replaced by `<recycleColumn> = NOW()`. Entities without a logic-delete field, or without table metadata, are
     * never converted. Previously, reading their logic-delete column threw a NullPointerException.
     */
    override fun processUpdate(update: Update, index: Int, sql: String?, obj: Any?) {
        val table = update.table
        if (lineHandler.ignoreTable(table.name)) {
            // The table does not take part in the recycle bin
            return
        }
        update.where = andExpression(table, update.where)
        if (RecycleThreadLocal.REAL_DELETE.get()) {
            return
        }
        val logicDeleteColumn = TableInfoHelper.getTableInfo(table.name)?.logicDeleteFieldInfo?.column ?: return
        val isLogicDelete = update.updateSets.any { updateSet ->
            updateSet.columns.any { column -> column.columnName == logicDeleteColumn }
        }
        if (isLogicDelete) {
            // Move the rows into the recycle bin instead of flagging them as logically deleted
            update.updateSets.clear()
            update.addUpdateSet(
                Column(lineHandler.getRecycleBinColumn()),
                Column("NOW()")
            )
        }
    }

    /**
     * Processes a DELETE statement that is really executed (physical delete or ignored table):
     * ANDs the recycle filter into its WHERE clause for recycle-enabled tables.
     */
    override fun processDelete(delete: Delete, index: Int, sql: String?, obj: Any?) {
        if (lineHandler.ignoreTable(delete.table.name)) {
            // The table does not take part in the recycle bin
            return
        }
        delete.where = andExpression(delete.table, delete.where)
    }

    /**
     * Builds the WHERE clause of a DELETE/UPDATE: the recycle predicate for [table] ANDed with [where].
     * An OR condition is wrapped in parentheses so the predicate applies to the whole condition.
     */
    protected fun andExpression(table: Table?, where: Expression?): BinaryExpression {
        val condition = recycleCondition(getAliasColumn(table), lineHandler.getRecycleBin())
        return when (where) {
            null -> condition
            is OrExpression -> AndExpression(condition, Parenthesis(where))
            else -> AndExpression(condition, where)
        }
    }

    /**
     * The recycle predicate `left = recycleBin`, or `left <> recycleBin` while
     * [RecycleThreadLocal.ONLY_SHOW_RECYCLE] is set on the current thread.
     */
    private fun recycleCondition(left: Expression, recycleBin: Expression?): ComparisonOperator =
        if (RecycleThreadLocal.ONLY_SHOW_RECYCLE.get()) NotEqualsTo(left, recycleBin) else EqualsTo(left, recycleBin)

    /**
     * Processes the SELECT part of `INSERT INTO ... SELECT`.
     *
     * Reaching this point means the target table takes part in the recycle bin. The selected tables are therefore
     * filtered, and the recycle column is appended to the select list.
     *
     * @param selectBody SelectBody; only a [PlainSelect] is supported (anything else throws ClassCastException)
     */
    protected fun processInsertSelect(selectBody: SelectBody) {
        val plainSelect = selectBody as PlainSelect
        when (val fromItem = plainSelect.fromItem) {
            is Table -> {
                // fixed gitee pulls/141 duplicate update
                processPlainSelect(plainSelect)
                appendSelectItem(plainSelect.selectItems)
            }
            is SubSelect -> {
                appendSelectItem(plainSelect.selectItems)
                processInsertSelect(fromItem.selectBody)
            }
        }
    }

    /**
     * Appends the recycle column to a select list, unless the list is empty or is a single `*` / `t.*`.
     *
     * @param selectItems SelectItem list, modified in place
     */
    protected fun appendSelectItem(selectItems: MutableList<SelectItem?>) {
        if (selectItems.isEmpty()) {
            return
        }
        val onlyItem = selectItems.singleOrNull()
        if (onlyItem is AllColumns || onlyItem is AllTableColumns) {
            return
        }
        selectItems.add(SelectExpressionItem(Column(lineHandler.getRecycleBinColumn())))
    }

    /**
     * Processes a plain SELECT: sub-queries in the select list and WHERE clause, the FROM item and joins.
     * The recycle predicate for the resulting main tables is then added to the WHERE clause.
     */
    protected fun processPlainSelect(plainSelect: PlainSelect) {
        // #3087 github: sub-queries inside the select list
        plainSelect.selectItems?.forEach { selectItem -> processSelectItem(selectItem) }

        // Sub-queries inside the WHERE clause
        val where = plainSelect.where
        processWhereSubSelect(where)

        // FROM item (may be absent, e.g. `SELECT 1`)
        var mainTables: MutableList<Table?> = ArrayList(processFromItem(plainSelect.fromItem))

        // Joins
        val joins = plainSelect.joins
        if (!joins.isNullOrEmpty()) {
            mainTables = processJoins(mainTables, joins)
        }

        // Append the WHERE condition when there are main tables
        if (mainTables.isNotEmpty()) {
            plainSelect.where = builderExpression(where, mainTables)
        }
    }

    /**
     * Resolves the main tables of a FROM item, unwrapping parentheses.
     * Sub-queries and other item types are processed recursively and contribute no main table.
     *
     * @param fromItem the FROM item, or `null` for a select without FROM
     */
    private fun processFromItem(fromItem: FromItem?): List<Table?> {
        var item = fromItem
        while (item is ParenthesisFromItem) {
            item = item.fromItem
        }
        return when (item) {
            is Table -> listOf(item)
            // A sub-join contributes its own main tables; its ON conditions are handled by processJoins
            is SubJoin -> processSubJoin(item)
            else -> {
                processOtherFromItem(item)
                emptyList()
            }
        }
    }

    /**
     * Processes sub-queries inside a WHERE condition.
     *
     * Supported forms: `IN`, `=`, `>`, `<`, `>=`, `<=`, `<>`, `EXISTS`, `NOT EXISTS`
     * (and AND/OR combinations of them).
     *
     * Prerequisites:
     * 1. the sub-query is enclosed in parentheses;
     * 2. the sub-query is usually on the right-hand side of the comparison operator.
     *
     * @param where where condition
     */
    protected fun processWhereSubSelect(where: Expression?) {
        if (where == null) {
            return
        }
        if (where is FromItem) {
            processOtherFromItem(where)
            return
        }
        // Cheap pre-check: only descend when the rendered condition contains a sub-query
        if (where.toString().indexOf("SELECT") <= 0) {
            return
        }
        when (where) {
            // comparison operators, AND, OR, ...
            is BinaryExpression -> {
                processWhereSubSelect(where.leftExpression)
                processWhereSubSelect(where.rightExpression)
            }
            is InExpression -> (where.rightExpression as? SubSelect)?.let { subSelect ->
                processSelectBody(subSelect.selectBody)
            }
            is ExistsExpression -> processWhereSubSelect(where.rightExpression)
            // NOT EXISTS
            is NotExpression -> processWhereSubSelect(where.expression)
            is Parenthesis -> processWhereSubSelect(where.expression)
        }
    }

    /**
     * Processes sub-queries and functions inside a select-list item.
     */
    protected fun processSelectItem(selectItem: SelectItem?) {
        if (selectItem !is SelectExpressionItem) {
            return
        }
        when (val expression = selectItem.expression) {
            is SubSelect -> processSelectBody(expression.selectBody)
            is Function -> processFunction(expression)
        }
    }

    /**
     * Processes sub-queries inside function arguments, recursing into nested functions.
     *
     * Supports: 1. `select fun(args..)` 2. `select fun1(fun2(args..),args..)`
     *
     * fixed gitee pulls/141
     *
     * @param function function call expression
     */
    protected fun processFunction(function: Function) {
        function.parameters?.expressions?.forEach { expression ->
            when (expression) {
                is SubSelect -> processSelectBody(expression.selectBody)
                is Function -> processFunction(expression)
            }
        }
    }

    /**
     * Processes non-table FROM items (sub-queries, lateral sub-queries), unwrapping parentheses first.
     */
    protected fun processOtherFromItem(fromItem: FromItem?) {
        var item = fromItem
        while (item is ParenthesisFromItem) {
            item = item.fromItem
        }
        when (item) {
            is SubSelect -> item.selectBody?.let { body -> processSelectBody(body) }
            is ValuesList -> logger.debug("Perform a subQuery, if you do not give us feedback")
            is LateralSubSelect -> item.subSelect?.selectBody?.let { body -> processSelectBody(body) }
        }
    }

    /**
     * Processes a sub join.
     *
     * @param subJoin subJoin
     * @return the main tables of the sub join (empty when it has no join list)
     */
    private fun processSubJoin(subJoin: SubJoin): MutableList<Table?> {
        val joinList = subJoin.joinList ?: return ArrayList()
        return processJoins(ArrayList(processFromItem(subJoin.left)), joinList)
    }

    /**
     * Processes joins and injects the recycle predicate into their ON conditions.
     *
     * @param mainTables current main tables, may be `null`; mutated for implicit (comma) joins
     * @param joins      join list
     * @return the main tables remaining after the joins (e.g. the right table of a RIGHT JOIN)
     */
    private fun processJoins(mainTables: MutableList<Table?>?, joins: List<Join?>): MutableList<Table?> {
        // Main tables of the whole join expression
        var currentMainTables: MutableList<Table?> = mainTables ?: ArrayList()
        var mainTable: Table? = null
        // Left-hand table of the join currently being processed
        var leftTable: Table? = null
        if (currentMainTables.size == 1) {
            mainTable = currentMainTables[0]
            leftTable = mainTable
        }

        // For joins whose ON clauses are all written at the end, remember the tables of each pending ON.
        // LinkedList is required: it accepts the null entries pushed below (ArrayDeque would not).
        val onTableDeque: Deque<List<Table?>?> = LinkedList()
        for (join in joins) {
            // jsqlparser does not put null entries into a join list
            val joinItem = join!!.rightItem

            // Tables of the current join; a sub join counts as one (multi-)table
            val joinTables: MutableList<Table?>? = when (joinItem) {
                is Table -> mutableListOf<Table?>(joinItem)
                is SubJoin -> processSubJoin(joinItem)
                else -> null
            }
            if (joinTables == null) {
                processOtherFromItem(joinItem)
                leftTable = null
                continue
            }

            // Implicit inner join (`FROM a, b`): the table simply becomes another main table
            if (join.isSimple) {
                currentMainTables.addAll(joinTables)
                continue
            }

            val joinTable = joinTables[0]
            var onTables: List<Table?>? = null
            if (join.isRight) {
                // Right join: the joined table becomes the main table, the ON clause filters the left table
                mainTable = joinTable
                if (leftTable != null) {
                    onTables = listOf(leftTable)
                }
            } else if (join.isLeft) {
                onTables = listOf(joinTable)
            } else if (join.isInner) {
                onTables = if (mainTable == null) listOf(joinTable) else listOf(mainTable, joinTable)
                mainTable = null
            }
            currentMainTables = ArrayList()
            if (mainTable != null) {
                currentMainTables.add(mainTable)
            }

            // ON expressions attached to this join
            val originOnExpressions = join.onExpressions
            // Usual case: exactly one ON expression, handle it right away
            if (originOnExpressions.size == 1 && onTables != null) {
                val onExpressions: MutableList<Expression?> = LinkedList()
                onExpressions.add(builderExpression(originOnExpressions.first(), onTables))
                join.onExpressions = onExpressions
                leftTable = joinTable
                continue
            }
            // Push this join's tables (null when nothing is to be appended) for the trailing ON clauses
            onTableDeque.push(onTables)
            // Several ON expressions written at the end: process them all at once
            if (originOnExpressions.size > 1) {
                val onExpressions: MutableCollection<Expression?> = LinkedList()
                for (originOnExpression in originOnExpressions) {
                    val currentTableList = onTableDeque.poll()
                    if (currentTableList.isNullOrEmpty()) {
                        onExpressions.add(originOnExpression)
                    } else {
                        onExpressions.add(builderExpression(originOnExpression, currentTableList))
                    }
                }
                join.onExpressions = onExpressions
            }
            leftTable = joinTable
        }
        return currentMainTables
    }

    /**
     * ANDs the recycle predicates of every non-ignored table in [tables] to [currentExpression].
     *
     * @return [currentExpression] unchanged when no table needs a predicate, otherwise the combined condition
     * (an OR condition is parenthesized first)
     */
    protected fun builderExpression(currentExpression: Expression?, tables: List<Table?>?): Expression? {
        // Only tables taking part in the recycle bin get a predicate
        val recycledTables = tables.orEmpty().filter { table -> !lineHandler.ignoreTable(table!!.name) }
        if (recycledTables.isEmpty()) {
            return currentExpression
        }
        val recycleBin = lineHandler.getRecycleBin()
        val conditions: List<Expression> = recycledTables.map { table ->
            recycleCondition(getAliasColumn(table), recycleBin)
        }
        // Several tables: AND the predicates together, left to right
        val injectExpression = conditions.drop(1).fold(conditions.first()) { acc, next -> AndExpression(acc, next) }
        return when (currentExpression) {
            null -> injectExpression
            is OrExpression -> AndExpression(Parenthesis(currentExpression), injectExpression)
            else -> AndExpression(currentExpression, injectExpression)
        }
    }

    /**
     * Qualifies the recycle column with the table alias, or with the table name when there is no alias.
     * A qualifier is always added so implicit inner joins (`FROM a, b`) stay unambiguous.
     *
     * @param table table object, must not be null
     * @return `alias.recycleColumn` or `table.recycleColumn`
     */
    protected fun getAliasColumn(table: Table?): Column {
        val owner = requireNotNull(table) { "table must not be null" }
        val qualifier = owner.alias?.name ?: owner.name
        return Column("$qualifier${StringPool.DOT}${lineHandler.getRecycleBinColumn()}")
    }

    /**
     * Instantiates the handler class named by the `recycleBinLineHandler` property, when it is not blank.
     */
    override fun setProperties(properties: Properties?) {
        PropertyMapper.newInstance(properties).whenNotBlank("recycleBinLineHandler",
            { clazzName: String? ->
                ClassUtils.newInstance(
                    clazzName
                )
            }
        ) { recycleBinLineHandler: RecycleBinLineHandler? ->
            this.recycleBinLineHandler = recycleBinLineHandler
        }
    }

}

/**
 * Strategy used by `RecycleBinInterceptor` that provides three things:
 * - which tables take part in the recycle bin;
 * - which column marks a recycled row;
 * - which value means "not recycled".
 */
interface RecycleBinLineHandler {

    companion object {
        /** Value stored in the recycle column of rows that are not in the recycle bin. */
        const val NOT_RECYCLE_VALUE = "2001-01-01 00:00:00"
    }

    /**
     * Returns the expression compared against the recycle column.
     *
     * Normal queries keep rows whose column equals it, recycle-bin queries keep rows whose column differs from it,
     * and it is the value written by INSERTs. Only a single value is supported.
     *
     * @return by default the string literal [NOT_RECYCLE_VALUE]
     */
    fun getRecycleBin(): Expression? = StringValue(NOT_RECYCLE_VALUE)

    /**
     * Returns the name of the recycle column.
     *
     * @return by default `gmt_recycled`
     */
    fun getRecycleBinColumn(): String? = "gmt_recycled"

    /**
     * Decides whether the recycle condition is skipped for a table.
     *
     * By default, only tables whose MyBatis-Plus entity maps a field to [getRecycleBinColumn] take part. A table
     * without registered table metadata (no entity) is ignored instead of failing with a NullPointerException.
     *
     * @param tableName table name
     * @return `true` to ignore the table, `false` to parse the statement and add the recycle condition
     */
    fun ignoreTable(tableName: String?): Boolean {
        val tableInfo = TableInfoHelper.getTableInfo(tableName) ?: return true
        val recycleBinColumn = getRecycleBinColumn()
        return tableInfo.fieldList.none { field -> field.column == recycleBinColumn }
    }

    /**
     * Decides whether the recycle column must not be appended to an INSERT.
     *
     * @param columns          columns listed by the INSERT
     * @param recycleBinColumn recycle column name
     * @return `true` when the INSERT already lists the recycle column (compared case-insensitively)
     */
    fun ignoreInsert(columns: List<Column>, recycleBinColumn: String?): Boolean =
        columns.any { column -> column.columnName.equals(recycleBinColumn, ignoreCase = true) }
}

/**
 * Per-thread switches read by `RecycleBinInterceptor`.
 *
 * Both flags default to `false` and are plain (non-inheritable) thread locals. The previous
 * `InheritableThreadLocal.withInitial` call was the static `ThreadLocal.withInitial` reached through the subclass
 * name, so it never created an inheritable local; this declaration states that behavior explicitly. Values are not
 * propagated to threads spawned by the current thread. Call [clear] when a pooled thread finishes its work.
 */
class RecycleThreadLocal {

    companion object {
        /** When `true`, queries on this thread match only recycled rows (`<>` instead of `=`). */
        val ONLY_SHOW_RECYCLE: ThreadLocal<Boolean> = ThreadLocal.withInitial { false }

        /** When `true`, deletes on this thread are physical instead of being moved into the recycle bin. */
        val REAL_DELETE: ThreadLocal<Boolean> = ThreadLocal.withInitial { false }

        /** Resets both flags of the current thread to their default (`false`). */
        @JvmStatic
        fun clear() {
            ONLY_SHOW_RECYCLE.remove()
            REAL_DELETE.remove()
        }
    }
}

配置代码如下:

代码语言:kotlin
package com.ruben.simpleboot.config

import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor
import com.ruben.simpleboot.interceptor.RecycleBinInterceptor
import com.ruben.simpleboot.interceptor.RecycleBinLineHandler
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration

@Configuration
class MybatisPlusConfig {

    /**
     * Builds the MyBatis-Plus interceptor chain.
     *
     * Pagination is registered first, followed by the recycle-bin interceptor, which uses the default
     * [RecycleBinLineHandler] implementation (an anonymous object overriding nothing).
     */
    @Bean
    fun mybatisPlusInterceptor(): MybatisPlusInterceptor {
        val interceptorChain = MybatisPlusInterceptor()
        interceptorChain.addInnerInterceptor(PaginationInnerInterceptor())
        val defaultLineHandler = object : RecycleBinLineHandler {}
        interceptorChain.addInnerInterceptor(RecycleBinInterceptor(defaultLineHandler))
        return interceptorChain
    }
}

使用:

代码语言:kotlin
package com.ruben.simpleboot

import com.ruben.simpleboot.interceptor.RecycleBinLineHandler
import com.ruben.simpleboot.interceptor.RecycleThreadLocal
import com.ruben.simpleboot.po.RoleInfo
import com.ruben.simpleboot.po.UserInfo
import com.ruben.simpleboot.po.UserRole
import io.github.vampireachao.stream.plugin.mybatisplus.Database
import io.github.vampireachao.stream.plugin.mybatisplus.One
import io.github.vampireachao.stream.plugin.mybatisplus.OneToManyToOne
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.boot.test.context.SpringBootTest
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter

/**
 * Integration tests for the recycle-bin interceptor.
 *
 * NOTE(review): these tests modify shared database rows (user `1`, role `"1"`) without rolling back. Their outcome
 * may therefore depend on execution order and on the seed data; consider running each one in a rolled-back
 * transaction.
 */
@SpringBootTest
class SimpleBootApplicationTests {

    /** Resets both recycle flags so every test starts in the default mode (non-recycled rows, soft delete). */
    @BeforeEach
    fun clear() {
        RecycleThreadLocal.ONLY_SHOW_RECYCLE.set(false)
        RecycleThreadLocal.REAL_DELETE.set(false)
    }

    /** By default only non-recycled rows are visible, and user 1 is expected to have roles. */
    @Test
    fun onlyShowNotRecycle() {
        val userIdRolesMap = OneToManyToOne
            .of<UserRole, Long, String, RoleInfo>(UserRole::getUserId).eq(1L)
            .value(UserRole::getRoleId).attachKey(RoleInfo::getId).query()
        // A missing entry must fail too: the former `?.isEmpty() ?: false` made this assertion pass vacuously
        Assertions.assertFalse(userIdRolesMap[1L].isNullOrEmpty())
    }

    /** While only recycled rows are shown, user 1 is expected to have no visible roles. */
    @Test
    fun onlyShowRecycle() {
        RecycleThreadLocal.ONLY_SHOW_RECYCLE.set(true)
        val userIdRolesMap = OneToManyToOne
            .of<UserRole, Long, String, RoleInfo>(UserRole::getUserId).eq(1L)
            .value(UserRole::getRoleId).attachKey(RoleInfo::getId).query()
        // "No entry" and "empty entry" both mean no roles; the former `?: false` failed when the key was absent
        Assertions.assertTrue(userIdRolesMap[1L].isNullOrEmpty())
    }

    /** A logic delete on UserInfo moves the row into the recycle bin instead of hiding it permanently. */
    @Test
    fun recycleLogicDeleteTest() {
        // By default a delete becomes a recycle
        Database.removeById(1L, UserInfo::class.java)
        Assertions.assertNull(One.of(UserInfo::getId).eq(1L).query())
        // Query the recycle bin
        RecycleThreadLocal.ONLY_SHOW_RECYCLE.set(true)
        Assertions.assertNotNull(One.of(UserInfo::getId).eq(1L).query())
    }

    /** A delete on RoleInfo is rewritten into a recycle update. */
    @Test
    fun recycleTest() {
        // By default a delete becomes a recycle
        Database.removeById("1", RoleInfo::class.java)
        Assertions.assertNull(One.of(RoleInfo::getId).eq("1").query())
        // Query the recycle bin
        RecycleThreadLocal.ONLY_SHOW_RECYCLE.set(true)
        Assertions.assertNotNull(One.of(RoleInfo::getId).eq("1").query())
    }

    /** A recycled row is restored by resetting its recycle column to the "not recycled" value. */
    @Test
    fun restoreTest() {
        Database.removeById(1L, UserInfo::class.java)
        Assertions.assertNull(One.of(UserInfo::getId).eq(1L).query())
        // Target the recycle bin so the update matches the recycled row
        RecycleThreadLocal.ONLY_SHOW_RECYCLE.set(true)
        // Restore the row
        Database.updateById(UserInfo().apply {
            id = 1L
            gmtRecycled = LocalDateTime.parse(
                RecycleBinLineHandler.NOT_RECYCLE_VALUE,
                DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
            )
        })
        RecycleThreadLocal.ONLY_SHOW_RECYCLE.set(false)
        Assertions.assertNotNull(One.of(UserInfo::getId).eq(1L).query())
    }

    /** With REAL_DELETE set, the row is gone from both the normal view and the recycle bin. */
    @Test
    fun deleteTest() {
        // Physical delete
        RecycleThreadLocal.REAL_DELETE.set(true)
        Database.removeById(1L, UserInfo::class.java)
        // Query non-recycled rows
        Assertions.assertNull(One.of(UserInfo::getId).eq(1L).query())
        // Query the recycle bin
        RecycleThreadLocal.ONLY_SHOW_RECYCLE.set(true)
        Assertions.assertNull(One.of(UserInfo::getId).eq(1L).query())
    }

}

完整源码地址:https://gitee.com/VampireAchao/simple-boot.git

0 人点赞