Flink UDX
1. UDF: user-defined scalar function (User Defined Scalar Function) — one row in, one value out.
2. UDAF: user-defined aggregate function — multiple rows in, one value out.
3. UDTF: user-defined table function — one row in, multiple rows out, or one column in, multiple columns out.
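Since the rest of this post focuses on the UDAF case, here is a minimal sketch of the other two kinds for contrast (ToUpper and Split and their contents are illustrative, not from this post; they are shown in one snippet for brevity, but in practice each would be a public class in its own file):

import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.functions.TableFunction;

// UDF: one row in, one value out.
class ToUpper extends ScalarFunction {
    public String eval(String s) {
        return s == null ? null : s.toUpperCase();
    }
}

// UDTF: one row in, multiple rows out (one per collect() call).
class Split extends TableFunction<String> {
    public void eval(String s, String separator) {
        if (s == null) {
            return;
        }
        for (String field : s.split(separator)) {
            collect(field); // emit one output row per field
        }
    }
}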
SQL statement
select
    first_non_null(businessId) as id
from
    test_new
where
    eventType = '1'
group by
    businessId
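For completeness, a sketch of how the UDAF might be registered and the query executed (assuming the blink planner, consistent with the org.apache.flink.table.dataformat classes that appear later; FirstNonNull is the class defined in the next section, and the table test_new is assumed to be registered elsewhere):

import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

public class FirstNonNullJob {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        EnvironmentSettings settings =
                EnvironmentSettings.newInstance().useBlinkPlanner().inStreamingMode().build();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env, settings);

        // Register the UDAF under the name used in the SQL above.
        tEnv.registerFunction("first_non_null", new FirstNonNull());

        Table result = tEnv.sqlQuery(
                "select first_non_null(businessId) as id "
                        + "from test_new where eventType = '1' group by businessId");

        // A group-by on an unbounded stream produces updates, hence a retract stream.
        tEnv.toRetractStream(result, Row.class).print();
        env.execute("first_non_null demo");
    }
}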
Execution flow:
Custom UDAF
import org.apache.flink.table.functions.AggregateFunction;

import java.util.ArrayList;

public class FirstNonNull extends AggregateFunction<String[], ArrayList<String>> {

    @Override
    public ArrayList<String> createAccumulator() {
        return new ArrayList<>();
    }

    @Override
    public String[] getValue(ArrayList<String> data) {
        if (data == null || data.size() == 0) {
            return null;
        }
        return data.toArray(new String[data.size()]);
    }

    public void accumulate(ArrayList<String> src, String... input) {
        if (src.size() == 0) {
            // First row for this key: copy it into the accumulator.
            addAll(src, input);
        } else {
            String curr_order_by_value = String.valueOf(input[0]);
            String src_order_by_value = String.valueOf(src.get(0));
            if (src_order_by_value.compareTo(curr_order_by_value) > 0) {
                // The new row sorts first: take it as the base row.
                addAll(src, input);
            } else if (src.contains(null)) {
                // Keep the base row, but fill its null slots from the new row.
                fillNull(src, input);
            }
        }
    }

    public void fillNull(ArrayList<String> src, String[] input) {
        int size = src.size();
        for (int i = 0; i < size; i++) {
            if (src.get(i) == null) {
                src.set(i, input[i] == null ? null : String.valueOf(input[i]));
            }
        }
    }

    public void addAll(ArrayList<String> src, String[] input) {
        for (int i = 0; i < input.length; i++) {
            Object value = input[i];
            if (i >= src.size()) {
                src.add(i, value == null ? null : String.valueOf(value));
            } else {
                if (value != null) {
                    src.set(i, String.valueOf(value));
                }
            }
        }
    }
}
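A quick local check of the accumulator semantics (a sketch, not part of the original post; the sample values are made up):

public class FirstNonNullDemo {
    public static void main(String[] args) {
        FirstNonNull f = new FirstNonNull();
        java.util.ArrayList<String> acc = f.createAccumulator();

        f.accumulate(acc, "2021-01-02", null); // first row: copied as-is -> [2021-01-02, null]
        f.accumulate(acc, "2021-01-03", "b");  // later order-by value: only fills nulls -> [2021-01-02, b]
        f.accumulate(acc, "2021-01-01", "a");  // earlier order-by value: replaces the base row -> [2021-01-01, a]

        // Prints [2021-01-01, a]: the row with the smallest first column wins,
        // and its null slots are filled from rows that arrive afterwards.
        System.out.println(java.util.Arrays.toString(f.getValue(acc)));
    }
}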
An AggregateFunction must implement the following methods:
- createAccumulator: creates the accumulator
- accumulate(ACC accumulator, [user defined inputs]): accumulates input rows
- getValue: returns the final result
An AggregateFunction may optionally implement:
• retract: needed for OVER-window aggregation;
• merge: required when using window operations (e.g. SessionWindow); also used to optimize the HOP case, see https://www.zhihu.com/question/346639699 (a merge sketch follows this list);
• resetAccumulator: used for DataSet grouping aggregates.
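To make merge concrete, a hedged sketch of a method that could be added to the FirstNonNull class above (the real class does not implement it; the semantics here are illustrative only):

// Illustrative only: merging per-pane accumulators for FirstNonNull.
// Flink calls merge when it combines window panes (e.g. session or hop windows).
public void merge(ArrayList<String> acc, Iterable<ArrayList<String>> others) {
    for (ArrayList<String> other : others) {
        if (other == null || other.isEmpty()) {
            continue;
        }
        // Reuse the accumulate logic by treating the other accumulator as an input row.
        accumulate(acc, other.toArray(new String[0]));
    }
}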
Let's take a closer look at the accumulate and retract methods.
• The accumulate method
/**
 * param: accumulator the accumulator which contains the current aggregated results
 * param: [user defined inputs] the input value (usually obtained from a new arrived data).
 */
public void accumulate(ACC accumulator, [user defined inputs])
Its first parameter is the accumulator, which holds the currently aggregated results; the second is the newly arrived input. This is where the accumulation logic goes.
• The retract method
/**
 * param: accumulator the accumulator which contains the current aggregated results
 * param: [user defined inputs] the input value (usually obtained from a new arrived data).
 */
public void retract(ACC accumulator, [user defined inputs])
Its first parameter is the accumulator, which holds the currently aggregated results; the second is the newly arrived input. This is where the retraction logic goes.
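The FirstNonNull UDAF above deliberately omits retract. For illustration, a minimal sum-style UDAF sketch (the names are made up) showing how retract mirrors accumulate:

import org.apache.flink.table.functions.AggregateFunction;

// Illustrative only: retract must undo exactly what accumulate did,
// so that upstream updates/deletes keep the aggregate consistent.
public class RetractableSum extends AggregateFunction<Long, RetractableSum.SumAcc> {

    public static class SumAcc {
        public long sum = 0L;
    }

    @Override
    public SumAcc createAccumulator() {
        return new SumAcc();
    }

    @Override
    public Long getValue(SumAcc acc) {
        return acc.sum;
    }

    public void accumulate(SumAcc acc, Long value) {
        if (value != null) {
            acc.sum += value;
        }
    }

    public void retract(SumAcc acc, Long value) {
        if (value != null) {
            acc.sum -= value;
        }
    }
}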
Task level
org.apache.flink.streaming.runtime.tasks.StreamTask#processInput:
protected void processInput(ActionContext context) throws Exception {
    if (!inputProcessor.processInput()) {
        context.allActionsCompleted();
    }
}
Here the inputProcessor carries processInput downstream. Since there is a single source, the inputProcessor is a StreamOneInputProcessor, and the corresponding method is org.apache.flink.streaming.runtime.io.StreamOneInputProcessor#processInput:
@Override
public boolean processInput() throws Exception {
    // initialize the input record counter
    initializeNumRecordsIn();
    StreamElement recordOrMark = input.pollNextNullable();
    if (recordOrMark == null) {
        input.isAvailable().get();
        return !checkFinished();
    }
    // get the channel the record came from
    int channel = input.getLastChannel();
    checkState(channel != StreamTaskInput.UNSPECIFIED);
    // process the record from that channel
    processElement(recordOrMark, channel);
    return true;
}
This is where the input is consumed; the part we care about is the processElement method.
Operator level
We focus on the org.apache.flink.streaming.runtime.io.StreamOneInputProcessor#processElement method:
private void processElement(StreamElement recordOrMark, int channel) throws Exception {
    if (recordOrMark.isRecord()) { // the input is a record
        // now we can do the actual processing
        StreamRecord<IN> record = recordOrMark.asRecord();
        synchronized (lock) {
            // increment the input record counter
            numRecordsIn.inc();
            streamOperator.setKeyContextElement1(record);
            // let the operator process the record
            streamOperator.processElement(record);
        }
    }
    else if (recordOrMark.isWatermark()) { // the input is a watermark
        // handle watermark
        statusWatermarkValve.inputWatermark(recordOrMark.asWatermark(), channel);
    } else if (recordOrMark.isStreamStatus()) { // the input is a stream status
        // handle stream status
        statusWatermarkValve.inputStreamStatus(recordOrMark.asStreamStatus(), channel);
    } else if (recordOrMark.isLatencyMarker()) { // the input is a latency marker
        // handle latency marker
        synchronized (lock) {
            streamOperator.processLatencyMarker(recordOrMark.asLatencyMarker());
        }
    } else { // unknown StreamElement type: throw
        throw new UnsupportedOperationException("Unknown type of StreamElement");
    }
}
This method handles quite a few StreamElement types; we mainly look at how an ordinary record is handled, i.e. streamOperator.processElement(record). Because the query uses GROUP BY, the operator here is a KeyedProcessOperator, and the method is org.apache.flink.streaming.api.operators.KeyedProcessOperator#processElement:
@Override
public void processElement(StreamRecord<IN> element) throws Exception {
    // set the timestamp
    collector.setTimestamp(element);
    context.element = element;
    // process the element with the user-defined function
    userFunction.processElement(element.getValue(), context, collector);
    context.element = null;
}
One thing to note: this userFunction is essentially a proxy for our UDF. A GroupAggsHandler class is generated and compiled dynamically, and its methods call back into the methods we implemented in our UDF (the ones fixed by the contract).
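The contract the generated class implements can be read off the generated source shown below. Roughly (a sketch reconstructed from that code, not copied from Flink's sources; the interface name here is abbreviated):

import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.runtime.dataview.StateDataViewStore;

// Shape of org.apache.flink.table.runtime.generated.AggsHandleFunction as visible
// in the generated GroupAggsHandler$39 below; the comments mark where the
// generated implementation calls back into our UDF.
public interface AggsHandleFunctionSketch {
    void open(StateDataViewStore store) throws Exception; // -> udf.open(...)
    void accumulate(BaseRow input) throws Exception;      // -> udf.accumulate(...)
    void retract(BaseRow input) throws Exception;         // -> udf.retract(...), if implemented
    void merge(BaseRow otherAcc) throws Exception;        // -> udf.merge(...), if implemented
    void setAccumulators(BaseRow acc) throws Exception;   // internal -> external conversion
    void resetAccumulators() throws Exception;            // -> udf.createAccumulator()
    BaseRow getAccumulators() throws Exception;           // external -> internal conversion
    BaseRow createAccumulators() throws Exception;        // -> udf.createAccumulator()
    BaseRow getValue() throws Exception;                  // -> udf.getValue(...)
    void cleanup() throws Exception;
    void close() throws Exception;                        // -> udf.close()
}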
Aggregation level
Let's look at some fields of the userFunction above. Where is this genAggHandler generated? Briefly: while Flink translates the SQL into a streamGraph, org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecGroupAggregate#translateToPlanInternal is invoked, and it is there that the aggsHandler object is created.
In other words, the part of GroupAggFunction that actually does the work is a GroupAggsHandler object; what genAggHandler compiles dynamically is exactly this GroupAggsHandler, and the generated code looks like this:
public final class GroupAggsHandler$39 implements org.apache.flink.table.runtime.generated.AggsHandleFunction {
private transient com.test.dream.flink.udf.aggfunctions.FirstNonNull function_com$test$dream$flink$udf$aggfunctions$FirstNonNull$53ab4e4c4415303976432217433a2633;
private transient org.apache.flink.table.dataformat.DataFormatConverters.GenericConverter converter$21;
private transient org.apache.flink.table.dataformat.DataFormatConverters.GenericConverter converter$24;
private org.apache.flink.table.dataformat.BinaryGeneric agg0_acc_internal;
private java.util.ArrayList agg0_acc_external;
private transient org.apache.flink.table.dataformat.DataFormatConverters.GenericConverter converter$27;
private transient org.apache.flink.table.dataformat.DataFormatConverters.GenericConverter converter$28;
private transient org.apache.flink.table.runtime.typeutils.BinaryStringSerializer typeSerializer$31;
private transient org.apache.flink.table.dataformat.DataFormatConverters.StringConverter converter$33;
private transient org.apache.flink.table.dataformat.DataFormatConverters.ObjectArrayConverter converter$37;
public GroupAggsHandler$39(java.lang.Object[] references) throws Exception {
function_com$test$dream$flink$udf$aggfunctions$FirstNonNull$53ab4e4c4415303976432217433a2633 = (((com.test.dream.flink.udf.aggfunctions.FirstNonNull) references[0]));
converter$21 = (((org.apache.flink.table.dataformat.DataFormatConverters.GenericConverter) references[1]));
converter$24 = (((org.apache.flink.table.dataformat.DataFormatConverters.GenericConverter) references[2]));
converter$27 = (((org.apache.flink.table.dataformat.DataFormatConverters.GenericConverter) references[3]));
converter$28 = (((org.apache.flink.table.dataformat.DataFormatConverters.GenericConverter) references[4]));
typeSerializer$31 = (((org.apache.flink.table.runtime.typeutils.BinaryStringSerializer) references[5]));
converter$33 = (((org.apache.flink.table.dataformat.DataFormatConverters.StringConverter) references[6]));
converter$37 = (((org.apache.flink.table.dataformat.DataFormatConverters.ObjectArrayConverter) references[7]));
}
@Override
public void open(org.apache.flink.table.runtime.dataview.StateDataViewStore store) throws Exception {
function_com$test$dream$flink$udf$aggfunctions$FirstNonNull$53ab4e4c4415303976432217433a2633.open(new org.apache.flink.table.functions.FunctionContext(store.getRuntimeContext()));
}
@Override
public void accumulate(org.apache.flink.table.dataformat.BaseRow accInput) throws Exception {
org.apache.flink.table.dataformat.BinaryString field$29;
boolean isNull$29;
isNull$29 = accInput.isNullAt(0);
field$29 = org.apache.flink.table.dataformat.BinaryString.EMPTY_UTF8;
if (!isNull$29) {
field$29 = accInput.getString(0);
}
org.apache.flink.table.dataformat.BinaryString field$30 = field$29;
if (!isNull$29) {
field$30 = (org.apache.flink.table.dataformat.BinaryString) (typeSerializer$31.copy(field$30));
}
org.apache.flink.table.dataformat.BinaryString field$32 = field$30;
if (!isNull$29) {
field$32 = (org.apache.flink.table.dataformat.BinaryString) (typeSerializer$31.copy(field$32));
}
function_com$test$dream$flink$udf$aggfunctions$FirstNonNull$53ab4e4c4415303976432217433a2633.accumulate(agg0_acc_external, isNull$29 ? null : (java.lang.String) converter$33.toExternal((org.apache.flink.table.dataformat.BinaryString) field$32));
}
@Override
public void retract(org.apache.flink.table.dataformat.BaseRow retractInput) throws Exception {
throw new java.lang.RuntimeException("This function not require retract method, but the retract method is called.");
}
@Override
public void merge(org.apache.flink.table.dataformat.BaseRow otherAcc) throws Exception {
throw new java.lang.RuntimeException("This function not require merge method, but the merge method is called.");
}
@Override
public void setAccumulators(org.apache.flink.table.dataformat.BaseRow acc) throws Exception {
org.apache.flink.table.dataformat.BinaryGeneric field$26;
boolean isNull$26;
isNull$26 = acc.isNullAt(0);
field$26 = null;
if (!isNull$26) {
field$26 = acc.getGeneric(0);
}
agg0_acc_internal = field$26;
agg0_acc_external = (java.util.ArrayList) converter$27.toExternal((org.apache.flink.table.dataformat.BinaryGeneric) agg0_acc_internal);
}
@Override
public void resetAccumulators() throws Exception {
agg0_acc_external = (java.util.ArrayList) function_com$test$dream$flink$udf$aggfunctions$FirstNonNull$53ab4e4c4415303976432217433a2633.createAccumulator();
agg0_acc_internal = (org.apache.flink.table.dataformat.BinaryGeneric) converter$28.toInternal((java.util.ArrayList) agg0_acc_external);
}
@Override
public org.apache.flink.table.dataformat.BaseRow getAccumulators() throws Exception {
final org.apache.flink.table.dataformat.GenericRow acc$25 = new org.apache.flink.table.dataformat.GenericRow(1);
agg0_acc_internal = (org.apache.flink.table.dataformat.BinaryGeneric) converter$24.toInternal((java.util.ArrayList) agg0_acc_external);
if (false) {
acc$25.setNullAt(0);
} else {
acc$25.setField(0, agg0_acc_internal);;
}
return acc$25;
}
@Override
public org.apache.flink.table.dataformat.BaseRow createAccumulators() throws Exception {
final org.apache.flink.table.dataformat.GenericRow acc$23 = new org.apache.flink.table.dataformat.GenericRow(1);
org.apache.flink.table.dataformat.BinaryGeneric acc_internal$22 = (org.apache.flink.table.dataformat.BinaryGeneric) (org.apache.flink.table.dataformat.BinaryGeneric) converter$21.toInternal((java.util.ArrayList) function_com$test$dream$flink$udf$aggfunctions$FirstNonNull$53ab4e4c4415303976432217433a2633.createAccumulator());
if (false) {
acc$23.setNullAt(0);
} else {
acc$23.setField(0, acc_internal$22);;
}
return acc$23;
}
@Override
public org.apache.flink.table.dataformat.BaseRow getValue() throws Exception {
final org.apache.flink.table.dataformat.GenericRow aggValue$38 = new org.apache.flink.table.dataformat.GenericRow(1);
java.lang.String[] value_external$34 = (java.lang.String[]) function_com$test$dream$flink$udf$aggfunctions$FirstNonNull$53ab4e4c4415303976432217433a2633.getValue(agg0_acc_external);
org.apache.flink.table.dataformat.BaseArray value_internal$35 =
(org.apache.flink.table.dataformat.BaseArray) converter$37.toInternal((java.lang.String[]) value_external$34);
boolean valueIsNull$36 = value_internal$35 == null;
if (valueIsNull$36) {
aggValue$38.setNullAt(0);
} else {
aggValue$38.setField(0, value_internal$35);;
}
return aggValue$38;
}
@Override
public void cleanup() throws Exception {
}
@Override
public void close() throws Exception {
function_com$test$dream$flink$udf$aggfunctions$FirstNonNull$53ab4e4c4415303976432217433a2633.close();
}
}
As the code shows, the generated class calls into the methods implemented by the FirstNonNull function.
Next, consider the org.apache.flink.table.runtime.operators.aggregate.GroupAggFunction#processElement method:
@Override
public void processElement(BaseRow input, Context ctx, Collector<BaseRow> out) throws Exception {
    long currentTime = ctx.timerService().currentProcessingTime();
    // register state-cleanup timer
    registerProcessingCleanupTimer(ctx, currentTime);
    BaseRow currentKey = ctx.getCurrentKey();
    boolean firstRow;
    BaseRow accumulators = accState.value();
    if (null == accumulators) {
        firstRow = true;
        // `function` is the GroupAggsHandler$39 object shown above; its
        // createAccumulators method calls back into our UDF's createAccumulator()
        accumulators = function.createAccumulators();
    } else {
        firstRow = false;
    }
    // set accumulators to handler first
    function.setAccumulators(accumulators);
    // get previous aggregate result
    BaseRow prevAggValue = function.getValue();
    // update aggregate result and set to the newRow
    if (isAccumulateMsg(input)) {
        // accumulate input:
        // GroupAggsHandler$39.accumulate calls back into our UDF's accumulate()
        function.accumulate(input);
    } else {
        // retract input
        function.retract(input);
    }
    // get current aggregate result:
    // GroupAggsHandler$39.getValue calls back into our UDF's getValue()
    BaseRow newAggValue = function.getValue();
    // get accumulator:
    // GroupAggsHandler$39.getAccumulators converts the external accumulator
    // back into Flink's internal format (see the generated code above)
    accumulators = function.getAccumulators();
    // ------------- retraction handling and state cleanup omitted ----------------
}
The key points here:
• The function in the code above is the GroupAggsHandler$39 object shown earlier; its createAccumulators method calls back into our UDF's createAccumulator() method.
• GroupAggsHandler$39's accumulate method calls back into our UDF's accumulate() method.
• GroupAggsHandler$39's getValue method calls back into our UDF's getValue() method.
• GroupAggsHandler$39's getAccumulators method converts the external accumulator back into Flink's internal row format before it is written to state (see the generated code above).
Once this step completes, the record moves on to the next operator in the graph, whose processElement is invoked in turn, and so on down to the sink operator, which completes the sink.