// Code language: Kotlin
import com.mysql.jdbc.jdbc2.optional.MysqlDataSource
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Service
import java.sql.*
import java.util.*
import javax.sql.DataSource
@Service
class Mysql2OdpsService {

    /**
     * Generates a MaxCompute (ODPS) `CREATE TABLE` statement mirroring the columns of a MySQL table.
     * Every column is mapped to `STRING`, and the table is partitioned by `pt`.
     *
     * @param table name of the MySQL table to mirror
     * @param dataSource MySQL data source; its `databaseName` is passed to JDBC metadata as the schema filter
     * @return the DDL text, or `null` if no database connection could be obtained
     */
    fun generateddl(table: String, dataSource: MysqlDataSource): String? {
        // getTableFields opens and closes its own connection, and returns null when connecting fails,
        // so no separate connectivity probe (which previously leaked a connection) is needed here.
        val fields = getTableFields(table, dataSource) ?: return null
        return ddl(table, fields)
    }

    /**
     * Lists the tables and views of the data source's database.
     *
     * Metadata errors are logged. In that case the (possibly partial) list collected so far is returned.
     *
     * @param dataSource MySQL data source to inspect
     * @return table names, or `null` if no database connection could be obtained
     */
    fun getAllTables(dataSource: MysqlDataSource): List<String>? {
        val conn = getConnection(dataSource) ?: return null
        val result = ArrayList<String>()
        var rs: ResultSet? = null
        try {
            // Arguments: catalog name, schema pattern (database name), table name pattern, table types.
            val tables = conn.metaData.getTables(catalog(), dataSource.databaseName, tableNamePattern(), types())
            rs = tables
            while (tables.next()) {
                result.add(tables.getString("TABLE_NAME"))
            }
        } catch (e: Exception) {
            logger.error("获取数据库全部表:", e)
        } finally {
            close(conn, null, rs)
        }
        return result
    }

    /**
     * Lists the columns (name, remark, SQL type name) of one table.
     *
     * Metadata errors are logged. In that case the (possibly partial) list collected so far is returned.
     *
     * @param table name of the table to inspect
     * @param dataSource MySQL data source; its `databaseName` is passed to JDBC metadata as the schema filter
     * @return the table's columns, or `null` if no database connection could be obtained
     */
    fun getTableFields(table: String, dataSource: MysqlDataSource): List<FieldInfo>? {
        val conn = getConnection(dataSource) ?: return null
        val result = ArrayList<FieldInfo>()
        var rs: ResultSet? = null
        try {
            val columns = conn.metaData.getColumns(catalog(), dataSource.databaseName, table, null)
            rs = columns
            while (columns.next()) {
                // The table argument of getColumns is a LIKE pattern ('_' and '%' are wildcards), so
                // "user_info" would also match "userXinfo". Keep only rows of the requested table.
                if (!table.equals(columns.getString("TABLE_NAME"), ignoreCase = true)) continue
                result.add(
                    FieldInfo(
                        columns.getString("COLUMN_NAME"),
                        // REMARKS may be null per the JDBC contract, but FieldInfo.comment is non-null.
                        columns.getString("REMARKS") ?: "",
                        columns.getString("TYPE_NAME")
                    )
                )
            }
        } catch (e: Exception) {
            logger.error("获取数据库表所包含的字段:", e)
        } finally {
            close(conn, null, rs)
        }
        return result
    }

    /**
     * One column of a source table.
     *
     * @property fieldName column name
     * @property comment column remark (empty when the database reports none)
     * @property type database-specific SQL type name
     */
    data class FieldInfo(var fieldName: String, var comment: String, var type: String)

    /**
     * Opens a connection from [dataSource].
     *
     * @return the connection, or `null` (after logging) if connecting failed with an [SQLException]
     */
    fun getConnection(dataSource: DataSource): Connection? = try {
        dataSource.connection
    } catch (e: SQLException) {
        logger.error("数据库连接失败", e)
        null
    }

    /**
     * Releases JDBC resources in the order ResultSet, Statement, Connection.
     *
     * A failure while closing one resource is logged and does not prevent closing the remaining ones.
     *
     * @param conn connection to close, may be null
     * @param ps statement to close, may be null
     * @param rs result set to close, may be null
     */
    fun close(conn: Connection?, ps: Statement? = null, rs: ResultSet? = null) {
        closeQuietly(rs, "ResultSet")
        closeQuietly(ps, "Statement")
        closeQuietly(conn, "Connection")
    }

    /** Closes [resource] if non-null, logging (not propagating) an [SQLException]. */
    private fun closeQuietly(resource: AutoCloseable?, description: String) {
        if (resource == null) return
        try {
            resource.close()
        } catch (e: SQLException) {
            logger.warn("Failed to close {}", description, e)
        }
    }

    /**
     * Catalog filter for metadata queries. It must match the catalog name as stored in the database.
     * `""` retrieves objects without a catalog, and `null` means the catalog is not used to narrow the search.
     */
    fun catalog(): String? {
        return null
    }

    /**
     * Table name pattern for [getAllTables]. It must match the table name as stored in the database.
     * `%` matches every table.
     */
    fun tableNamePattern(): String {
        return "%"
    }

    /**
     * Table types included by [getAllTables]. Each type must be one of those reported by
     * [DatabaseMetaData.getTableTypes]; passing `null` to JDBC would include all types.
     */
    fun types(): Array<String> {
        return arrayOf("TABLE", "VIEW")
    }

    /**
     * Builds a MaxCompute `CREATE TABLE IF NOT EXISTS` statement for [table].
     *
     * Each field becomes a `STRING` column carrying its remark as `COMMENT`, written in leading-comma style.
     * The table is partitioned by `pt` and has a lifecycle of 750 days.
     * A `null` field list is treated as empty.
     *
     * @param table target table name, inserted verbatim
     * @param fields columns to declare
     * @return the DDL text, one clause per line
     */
    fun ddl(table: String, fields: List<FieldInfo>?): String {
        // Leading-comma style: the first column is bare, and every following column starts with ','.
        val columnLines = fields.orEmpty().mapIndexed { index, field ->
            val separator = if (index == 0) "" else ","
            "$separator${field.fieldName} STRING COMMENT '${escapeStringLiteral(field.comment)}'"
        }
        val header = listOf("CREATE TABLE IF NOT EXISTS $table(")
        val footer = listOf(
            ")",
            "COMMENT '' PARTITIONED BY",
            "(",
            "pt STRING COMMENT '时间分区键-yyyymmdd'",
            ")",
            "LIFECYCLE 750;"
        )
        // Lines are joined explicitly rather than interpolated into a trimIndent template, so the
        // multi-line column list cannot disturb indentation stripping.
        return (header + columnLines + footer).joinToString("\n")
    }

    /**
     * Escapes backslashes and single quotes so [value] can sit inside a single-quoted SQL literal.
     * NOTE(review): assumes the target dialect accepts backslash escapes (as Hive/MaxCompute do) — confirm.
     */
    private fun escapeStringLiteral(value: String): String =
        value.replace("\\", "\\\\").replace("'", "\\'")

    val logger = LoggerFactory.getLogger(this.javaClass)
}