聊聊PowerJob的AbstractSqlProcessor

2024-01-03 15:08:53 浏览数 (2)

本文主要研究一下PowerJob的AbstractSqlProcessor

AbstractSqlProcessor

tech/powerjob/official/processors/impl/sql/AbstractSqlProcessor.java

代码语言:java
/**
 * Base class for the official SQL processors.
 *
 * <p>{@code process0} runs a fixed pipeline: extract {@link SqlParams} from the task context,
 * validate them ({@code validateParams}), optionally rewrite the SQL with a custom
 * {@link SqlParser}, check the SQL against every registered validator ({@code validateSql}),
 * and finally execute it ({@code executeSql}) on a connection obtained from
 * {@link #getConnection}. Each phase is timed with a {@link StopWatch} and the timings are
 * written to the task's {@link OmsLogger}.
 */
@Slf4j
public abstract class AbstractSqlProcessor extends CommonBasicProcessor {

    /**
     * Default statement query timeout (passed to {@code Statement#setQueryTimeout}, i.e. seconds),
     * used when {@code SqlParams#timeout} is null.
     */
    protected static final int DEFAULT_TIMEOUT = 60;
    /**
     * name => SQL validator.
     * Note:
     * - returning true means the SQL passed validation
     * - returning false means the SQL is illegal and will be rejected
     */
    protected final Map<String, Predicate<String>> sqlValidatorMap = Maps.newConcurrentMap();
    /**
     * Optional custom SQL parser; when non-null it rewrites the SQL before validation.
     */
    protected SqlParser sqlParser;

    private static final Joiner JOINER = Joiner.on("|").useForNull("-");


    /**
     * Parses, validates, optionally rewrites, checks and executes the task's SQL.
     *
     * @param taskContext the task context the SQL parameters are extracted from
     * @return a successful {@link ProcessResult} whose message contains the total elapsed time;
     *         any failure is propagated as an exception
     */
    @Override
    public ProcessResult process0(TaskContext taskContext) {

        OmsLogger omsLogger = taskContext.getOmsLogger();
        // Extract parameters
        SqlParams sqlParams = extractParams(taskContext);
        omsLogger.info("origin sql params: {}", JSON.toJSON(sqlParams));
        // Validate parameters (implemented by subclasses)
        validateParams(sqlParams);

        StopWatch stopWatch = new StopWatch(this.getClass().getSimpleName());
        // Optional rewrite by the custom parser
        stopWatch.start("Parse SQL");
        if (sqlParser != null) {
            omsLogger.info("before parse sql: {}", sqlParams.getSql());
            String newSQL = sqlParser.parse(sqlParams.getSql(), taskContext);
            sqlParams.setSql(newSQL);
            omsLogger.info("after parse sql: {}", newSQL);
        }
        stopWatch.stop();

        // Validate the (possibly rewritten) SQL so parser output is also checked
        stopWatch.start("Validate SQL");
        validateSql(sqlParams.getSql(), omsLogger);
        stopWatch.stop();

        // Execute
        stopWatch.start("Execute SQL");
        omsLogger.info("final sql params: {}", JSON.toJSON(sqlParams));
        executeSql(sqlParams, taskContext);
        stopWatch.stop();

        omsLogger.info(stopWatch.prettyPrint());
        String message = String.format("execute successfully, used time: %s millisecond", stopWatch.getTotalTimeMillis());
        return new ProcessResult(true, message);
    }

    /**
     * Obtains the JDBC connection the SQL is executed on.
     *
     * @param sqlParams   parsed SQL parameters (data source name, jdbc url, ...)
     * @param taskContext the current task context
     * @return an open connection; the caller closes it
     * @throws SQLException if the connection cannot be obtained
     */
    abstract Connection getConnection(SqlParams sqlParams, TaskContext taskContext) throws SQLException;

    /**
     * Sets the custom SQL parser; {@code null} disables SQL rewriting.
     *
     * @param sqlParser parser applied to the SQL before validation
     */
    public void setSqlParser(SqlParser sqlParser) {
        this.sqlParser = sqlParser;
    }

    /**
     * Registers (or replaces) a named SQL validator.
     *
     * @param validatorName unique validator name, reported when validation fails
     * @param sqlValidator  predicate returning true when the SQL is acceptable
     * @throws NullPointerException if either argument is null
     */
    public void registerSqlValidator(String validatorName, Predicate<String> sqlValidator) {
        java.util.Objects.requireNonNull(validatorName, "SQL validator name must not be null");
        java.util.Objects.requireNonNull(sqlValidator, "SQL validator must not be null");
        sqlValidatorMap.put(validatorName, sqlValidator);
        log.info("register sql validator({}) successfully.", validatorName);
    }

    //......
}

AbstractSqlProcessor继承了CommonBasicProcessor,其process0先将入参解析为SqlParams,然后调用validateParams进行参数校验,针对sqlParser不为null的会通过sqlParser进行解析,接着通过validateSql校验sql,最后通过executeSql执行sql;它定义了getConnection抽象方法,提供了setSqlParser、registerSqlValidator方法

SqlParams

代码语言:java
    /**
     * Parameters of a single SQL task run, produced by {@code extractParams(taskContext)}.
     */
    @Data
    public static class SqlParams {
        /**
         * Name of the data source to use (looked up by SpringDatasourceSqlProcessor;
         * it falls back to "default" when empty).
         */
        private String dataSourceName;
        /**
         * The SQL to execute.
         */
        private String sql;
        /**
         * Query timeout passed to Statement#setQueryTimeout (seconds);
         * DEFAULT_TIMEOUT is used when null.
         */
        private Integer timeout;
        /**
         * JDBC url (required by DynamicDatasourceSqlProcessor).
         * For the format see https://www.baeldung.com/java-jdbc-url-format
         */
        private String jdbcUrl;
        /**
         * Whether to write the SQL execution result to the task log.
         */
        private boolean showResult;
    }

SqlParams定义了dataSourceName、sql、timeout、jdbcUrl、showResult属性

validateSql

代码语言:java
    private void validateSql(String sql, OmsLogger omsLogger) {
        if (sqlValidatorMap.isEmpty()) {
            return;
        }
        for (Map.Entry<String, Predicate<String>> entry : sqlValidatorMap.entrySet()) {
            Predicate<String> validator = entry.getValue();
            if (!validator.test(sql)) {
                omsLogger.error("validate sql by validator[{}] failed, skip to process!", entry.getKey());
                throw new IllegalArgumentException("illegal sql, can't pass the validation of "   entry.getKey());
            }
        }
    }

validateSql遍历sqlValidatorMap,挨个执行test方法,验证不通过抛出IllegalArgumentException

executeSql

代码语言:java
    /**
     * Executes the SQL in a single manual-commit transaction.
     *
     * <p>The connection's original auto-commit flag is restored afterwards. On any failure the
     * transaction is rolled back and the original exception is rethrown. If the rollback itself
     * fails, that failure is attached as a suppressed exception so the root cause is not lost.
     *
     * @param sqlParams SQL, timeout and showResult flag
     * @param ctx       task context, used for the connection and the task logger
     */
    @SneakyThrows
    private void executeSql(SqlParams sqlParams, TaskContext ctx) {

        OmsLogger omsLogger = ctx.getOmsLogger();

        try (Connection connection = getConnection(sqlParams, ctx)) {
            boolean originAutoCommitFlag = connection.getAutoCommit();
            connection.setAutoCommit(false);
            try (Statement statement = connection.createStatement()) {
                Integer timeout = sqlParams.getTimeout();
                statement.setQueryTimeout(timeout == null ? DEFAULT_TIMEOUT : timeout);
                statement.execute(sqlParams.getSql());

                connection.commit();

                if (sqlParams.showResult) {
                    outputSqlResult(statement, omsLogger);
                }
            } catch (Throwable e) {
                omsLogger.error("execute sql failed, try to rollback", e);
                try {
                    connection.rollback();
                } catch (SQLException rollbackException) {
                    // Keep the original failure as the primary cause; record the rollback failure alongside it.
                    e.addSuppressed(rollbackException);
                }
                throw e;
            } finally {
                connection.setAutoCommit(originAutoCommitFlag);
            }
        }
    }

executeSql通过getConnection获取连接,记录原有的autoCommit并设置为手动提交,然后创建Statement,设置queryTimeout(未指定时默认为60秒),执行并提交;若showResult为true,则调用outputSqlResult输出执行结果。执行过程中出现异常时会回滚并重新抛出异常,最后在finally中恢复连接原有的autoCommit设置

outputSqlResult

代码语言:java
    private void outputSqlResult(Statement statement, OmsLogger omsLogger) throws SQLException {
        omsLogger.info("====== SQL EXECUTE RESULT ======");

        for (int index = 0; index < Integer.MAX_VALUE; index  ) {

            // 某一个结果集
            ResultSet resultSet = statement.getResultSet();
            if (resultSet != null) {
                try (ResultSet rs = resultSet) {
                    int columnCount = rs.getMetaData().getColumnCount();
                    List<String> columnNames = Lists.newLinkedList();
                    //column – the first column is 1, the second is 2, ...
                    for (int i = 1; i <= columnCount; i  ) {
                        columnNames.add(rs.getMetaData().getColumnName(i));
                    }
                    omsLogger.info("[Result-{}] [Columns] {}"   System.lineSeparator(), index, JOINER.join(columnNames));
                    int rowIndex = 0;
                    List<Object> row = Lists.newLinkedList();
                    while (rs.next()) {
                        for (int i = 1; i <= columnCount; i  ) {
                            row.add(rs.getObject(i));
                        }
                        omsLogger.info("[Result-{}] [Row-{}] {}"   System.lineSeparator(), index, rowIndex  , JOINER.join(row));
                    }
                }
            } else {
                int updateCount = statement.getUpdateCount();
                if (updateCount != -1) {
                    omsLogger.info("[Result-{}] update count: {}", index, updateCount);
                }
            }
            if (((!statement.getMoreResults()) && (statement.getUpdateCount() == -1))) {
                break;
            }
        }
        omsLogger.info("====== SQL EXECUTE RESULT ======");
    }

outputSqlResult循环遍历statement返回的所有结果:对于查询结果集,先打印列名,再逐行打印数据;对于更新操作则打印updateCount;直到getMoreResults()返回false且getUpdateCount()为-1时结束遍历

SqlParser

代码语言:java
    /**
     * Hook for rewriting the task's SQL. When one is set via setSqlParser, process0 calls it
     * before SQL validation and replaces the SQL with the returned value.
     */
    @FunctionalInterface
    public interface SqlParser {
        /**
         * Custom SQL parsing logic.
         *
         * @param sql         the original SQL statement
         * @param taskContext the task context
         * @return the parsed (rewritten) SQL that will be validated and executed
         */
        String parse(String sql, TaskContext taskContext);
    }

SqlParser接口定义了parse方法

DynamicDatasourceSqlProcessor

tech/powerjob/official/processors/impl/sql/DynamicDatasourceSqlProcessor.java

代码语言:java
/**
 * SQL processor that opens a new JDBC connection for each run, using the task's
 * {@code jdbcUrl} and {@link DriverManager}.
 */
public class DynamicDatasourceSqlProcessor extends AbstractSqlProcessor {

    /**
     * Requires a non-empty jdbcUrl.
     *
     * @param sqlParams parsed SQL parameters
     * @throws IllegalArgumentException if jdbcUrl is empty
     */
    @Override
    protected void validateParams(SqlParams sqlParams) {
        if (StringUtils.isEmpty(sqlParams.getJdbcUrl())) {
            throw new IllegalArgumentException("jdbcUrl can't be empty in DynamicDatasourceSqlProcessor!");
        }
    }

    /**
     * Opens a connection to {@code sqlParams.getJdbcUrl()}.
     *
     * <p>Every top-level entry of the task's JSON params with a non-null value is copied into
     * the driver {@link Properties}. This is how credentials such as "user"/"password" are supplied.
     *
     * @param sqlParams   parsed SQL parameters (jdbcUrl is used)
     * @param taskContext task context whose params are turned into driver properties
     * @return a new connection; the caller closes it
     * @throws SQLException if the driver cannot connect
     */
    @Override
    Connection getConnection(SqlParams sqlParams, TaskContext taskContext) throws SQLException {

        JSONObject params = JSONObject.parseObject(CommonUtils.parseParams(taskContext));
        Properties properties = new Properties();

        // normally at least a "user" and "password" property should be included
        // NOTE(review): all task params (including sql/jdbcUrl/timeout) are forwarded to the driver as-is.
        if (params != null) {
            params.forEach((k, v) -> {
                // Skip JSON nulls: String.valueOf(null) would otherwise send the literal "null" (e.g. as a password).
                if (v != null) {
                    properties.setProperty(k, String.valueOf(v));
                }
            });
        }

        return DriverManager.getConnection(sqlParams.getJdbcUrl(), properties);
    }

    /**
     * @return the security switch key that enables this processor
     */
    @Override
    protected String getSecurityDKey() {
        return SecurityUtils.ENABLE_DYNAMIC_SQL_PROCESSOR;
    }
}

DynamicDatasourceSqlProcessor继承了AbstractSqlProcessor,其validateParams要求jdbcUrl不能为空;其getConnection方法会将任务参数(JSON)中的所有键值对转换为Properties(通常应包含user和password),作为DriverManager.getConnection的连接属性;其getSecurityDKey返回的是powerjob.official-processor.dynamic-datasource.enable配置

SpringDatasourceSqlProcessor

tech/powerjob/official/processors/impl/sql/SpringDatasourceSqlProcessor.java

代码语言:java
@Slf4j
public class SpringDatasourceSqlProcessor extends AbstractSqlProcessor {
    /**
     * 默认的数据源名称
     */
    private static final String DEFAULT_DATASOURCE_NAME = "default";
    /**
     * name => data source
     */
    private final Map<String, DataSource> dataSourceMap;

    /**
     * 指定默认的数据源
     *
     * @param defaultDataSource 默认数据源
     */
    public SpringDatasourceSqlProcessor(DataSource defaultDataSource) {
        dataSourceMap = Maps.newConcurrentMap();
        registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
    }

    @Override
    Connection getConnection(SqlParams sqlParams, TaskContext taskContext) throws SQLException {
        return dataSourceMap.get(sqlParams.getDataSourceName()).getConnection();
    }

    /**
     * 校验参数,如果校验不通过直接抛异常
     *
     * @param sqlParams SQL 参数信息
     */
    @Override
    protected void validateParams(SqlParams sqlParams) {
        // 检查数据源
        if (StringUtils.isEmpty(sqlParams.getDataSourceName())) {
            // use the default data source when current data source name is empty
            sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME);
        }
        dataSourceMap.computeIfAbsent(sqlParams.getDataSourceName(), dataSourceName -> {
            throw new IllegalArgumentException("can't find data source with name "   dataSourceName);
        });
    }

    /**
     * 注册数据源
     *
     * @param dataSourceName 数据源名称
     * @param dataSource     数据源
     */
    public void registerDataSource(String dataSourceName, DataSource dataSource) {
        Objects.requireNonNull(dataSourceName, "DataSource name must not be null");
        Objects.requireNonNull(dataSource, "DataSource must not be null");
        dataSourceMap.put(dataSourceName, dataSource);
        log.info("register data source({})' successfully.", dataSourceName);
    }

    /**
     * 移除数据源
     *
     * @param dataSourceName 数据源名称
     */
    public void removeDataSource(String dataSourceName) {
        DataSource remove = dataSourceMap.remove(dataSourceName);
        if (remove != null) {
            log.warn("remove data source({})' successfully.", dataSourceName);
        }
    }
}

SpringDatasourceSqlProcessor继承了AbstractSqlProcessor,其构造器将传入的DataSource注册为名为default的数据源;其getConnection根据sqlParams的dataSourceName获取连接;validateParams在dataSourceName为空时改用default数据源,并校验指定的数据源是否已注册;它还提供了registerDataSource、removeDataSource方法

小结

AbstractSqlProcessor继承了CommonBasicProcessor,其process0先将入参解析为SqlParams,然后调用validateParams进行参数校验,针对sqlParser不为null的会通过sqlParser进行解析,接着通过validateSql校验sql,最后通过executeSql执行sql;它定义了getConnection抽象方法,提供了setSqlParser、registerSqlValidator方法。它有两个实现类分别是DynamicDatasourceSqlProcessor(通过jdbcUrl来构造连接)、SpringDatasourceSqlProcessor(通过给定的dataSource获取连接)。

0 人点赞