Flink SQL JSON Format 源码解析

Posted JasonLee-后厂村程序员

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Flink SQL JSON Format 源码解析相关的知识,希望对你有一定的参考价值。

用 Flink SQL 解析 JSON 格式的数据是非常简单的,只需要在 DDL 语句中设置 Format 为 json 即可,像下面这样:

CREATE TABLE kafka_source (
    funcName STRING,
    data ROW<snapshots ARRAY<ROW<content_type STRING,url STRING>>,audio ARRAY<ROW<content_type STRING,url STRING>>>,
    resultMap ROW<`result` MAP<STRING,STRING>,isSuccess BOOLEAN>,
    meta  MAP<STRING,STRING>,
    `type` INT,
    `timestamp` BIGINT,
    arr ARRAY<ROW<address STRING,city STRING>>,
    map MAP<STRING,INT>,
    doublemap MAP<STRING,MAP<STRING,INT>>,
    proctime as PROCTIME()
) WITH (
    'connector' = 'kafka', -- 使用 kafka connector
    'topic' = 'test',  -- kafka topic
    'properties.bootstrap.servers' = 'master:9092,storm1:9092,storm2:9092',  -- broker连接信息
    'properties.group.id' = 'jason_flink_test', -- 消费kafka的group_id
    'scan.startup.mode' = 'latest-offset',  -- 读取数据的位置
    'format' = 'json',  -- 数据源格式为 json
    'json.fail-on-missing-field' = 'true', -- 字段丢失任务不失败
    'json.ignore-parse-errors' = 'false'  -- 解析失败跳过
)

那么你有没有想过它的底层是怎么实现的呢? 今天这篇文章就带你深入浅出,了解其实现细节.

当你输入一条 SQL 的时候在 Flink 里面会经过解析,验证,优化,转换等几个重要的步骤,因为前面的几个过程比较繁琐,这里暂时不展开说明,我们直接来到比较关键的源码处,在把 sqlNode 转换成 relNode 的过程中,会来到 CatalogSourceTable#createDynamicTableSource 该类的作用是把 Calcite 的 RelOptTable 翻译成 Flink 的 TableSourceTable 对象.

createDynamicTableSource  源码

private DynamicTableSource createDynamicTableSource(
        FlinkContext context, ResolvedCatalogTable catalogTable) {
    final ReadableConfig config = context.getTableConfig().getConfiguration();
    return FactoryUtil.createTableSource(
            schemaTable.getCatalog(),
            schemaTable.getTableIdentifier(),
            catalogTable,
            config,
            Thread.currentThread().getContextClassLoader(),
            schemaTable.isTemporary());
}

其实这个就是要创建 Kafka Source 的流表,然后会调用 FactoryUtil#createTableSource 这个方法

createTableSource 源码

public static DynamicTableSource createTableSource(
        @Nullable Catalog catalog,
        ObjectIdentifier objectIdentifier,
        ResolvedCatalogTable catalogTable,
        ReadableConfig configuration,
        ClassLoader classLoader,
        boolean isTemporary) {
    final DefaultDynamicTableContext context =
            new DefaultDynamicTableContext(
                    objectIdentifier, catalogTable, configuration, classLoader, isTemporary);
    try {
        // 获取对应的 factory 这里其实就是 KafkaDynamicTableFactory
        final DynamicTableSourceFactory factory =
                getDynamicTableFactory(DynamicTableSourceFactory.class, catalog, context);
        // 创建动态表
        return factory.createDynamicTableSource(context);
    } catch (Throwable t) {
        throw new ValidationException(
                String.format(
                        "Unable to create a source for reading table '%s'.\\n\\n"
                                + "Table options are:\\n\\n"
                                + "%s",
                        objectIdentifier.asSummaryString(),
                        catalogTable.getOptions().entrySet().stream()
                                .map(e -> stringifyOption(e.getKey(), e.getValue()))
                                .sorted()
                                .collect(Collectors.joining("\\n"))),
                t);
    }
}

在这个方法里面,有两个重要的过程,首先是获取对应的 factory 对象,然后创建 DynamicTableSource 实例.在 getDynamicTableFactory 中实际调用的是 discoverFactory 方法,顾名思义就是发现工厂.

discoverFactory 源码

public static <T extends Factory> T discoverFactory(
        ClassLoader classLoader, Class<T> factoryClass, String factoryIdentifier) {
    final List<Factory> factories = discoverFactories(classLoader);

    final List<Factory> foundFactories =
            factories.stream()
                    .filter(f -> factoryClass.isAssignableFrom(f.getClass()))
                    .collect(Collectors.toList());

    if (foundFactories.isEmpty()) {
        throw new ValidationException(
                String.format(
                        "Could not find any factories that implement '%s' in the classpath.",
                        factoryClass.getName()));
    }

    final List<Factory> matchingFactories =
            foundFactories.stream()
                    .filter(f -> f.factoryIdentifier().equals(factoryIdentifier))
                    .collect(Collectors.toList());

    if (matchingFactories.isEmpty()) {
        throw new ValidationException(
                String.format(
                        "Could not find any factory for identifier '%s' that implements '%s' in the classpath.\\n\\n"
                                + "Available factory identifiers are:\\n\\n"
                                + "%s",
                        factoryIdentifier,
                        factoryClass.getName(),
                        foundFactories.stream()
                                .map(Factory::factoryIdentifier)
                                .distinct()
                                .sorted()
                                .collect(Collectors.joining("\\n"))));
    }
    if (matchingFactories.size() > 1) {
        throw new ValidationException(
                String.format(
                        "Multiple factories for identifier '%s' that implement '%s' found in the classpath.\\n\\n"
                                + "Ambiguous factory classes are:\\n\\n"
                                + "%s",
                        factoryIdentifier,
                        factoryClass.getName(),
                        matchingFactories.stream()
                                .map(f -> f.getClass().getName())
                                .sorted()
                                .collect(Collectors.joining("\\n"))));
    }

    return (T) matchingFactories.get(0);
}

这个代码相对简单,就不加注释了,逻辑也非常的清晰,就是获取对应的 factory ,先是通过 SPI 机制加载所有的 factory 然后根据 factoryIdentifier 过滤出满足条件的,这里其实就是 kafka connector 了.最后还有一些异常的判断.

discoverFactories 源码

private static List<Factory> discoverFactories(ClassLoader classLoader) {
    try {
        final List<Factory> result = new LinkedList<>();
        ServiceLoader.load(Factory.class, classLoader).iterator().forEachRemaining(result::add);
        return result;
    } catch (ServiceConfigurationError e) {
        LOG.error("Could not load service provider for factories.", e);
        throw new TableException("Could not load service provider for factories.", e);
    }
}

这个代码大家应该比较熟悉了,前面也有文章介绍过了.加载所有的 Factory 返回一个 Factory 的集合.

下面才是今天的重点.

createDynamicTableSource 源码

public DynamicTableSource createDynamicTableSource(Context context) {
    TableFactoryHelper helper = FactoryUtil.createTableFactoryHelper(this, context);
    ReadableConfig tableOptions = helper.getOptions();
    Optional<DecodingFormat<DeserializationSchema<RowData>>> keyDecodingFormat = getKeyDecodingFormat(helper);
    // format 的逻辑
    DecodingFormat<DeserializationSchema<RowData>> valueDecodingFormat = getValueDecodingFormat(helper);
    helper.validateExcept(new String[]{"properties."});
    KafkaOptions.validateTableSourceOptions(tableOptions);
    validatePKConstraints(context.getObjectIdentifier(), context.getCatalogTable(), valueDecodingFormat);
    StartupOptions startupOptions = KafkaOptions.getStartupOptions(tableOptions);
    Properties properties = KafkaOptions.getKafkaProperties(context.getCatalogTable().getOptions());
    properties.setProperty("flink.partition-discovery.interval-millis", String.valueOf(tableOptions.getOptional(KafkaOptions.SCAN_TOPIC_PARTITION_DISCOVERY).map(Duration::toMillis).orElse(-9223372036854775808L)));
    DataType physicalDataType = context.getCatalogTable().getSchema().toPhysicalRowDataType();
    int[] keyProjection = KafkaOptions.createKeyFormatProjection(tableOptions, physicalDataType);
    int[] valueProjection = KafkaOptions.createValueFormatProjection(tableOptions, physicalDataType);
    String keyPrefix = (String)tableOptions.getOptional(KafkaOptions.KEY_FIELDS_PREFIX).orElse((Object)null);
    return this.createKafkaTableSource(physicalDataType, (DecodingFormat)keyDecodingFormat.orElse((Object)null), valueDecodingFormat, keyProjection, valueProjection, keyPrefix, KafkaOptions.getSourceTopics(tableOptions), KafkaOptions.getSourceTopicPattern(tableOptions), properties, startupOptions.startupMode, startupOptions.specificOffsets, startupOptions.startupTimestampMillis);
}

getValueDecodingFormat 方法最终会调用 discoverOptionalFormatFactory 方法

discoverOptionalDecodingFormat 和 discoverOptionalFormatFactory 源码

public <I, F extends DecodingFormatFactory<I>>
                Optional<DecodingFormat<I>> discoverOptionalDecodingFormat(
                        Class<F> formatFactoryClass, ConfigOption<String> formatOption) {
            return discoverOptionalFormatFactory(formatFactoryClass, formatOption)
                    .map(
                            formatFactory -> {
                                String formatPrefix = formatPrefix(formatFactory, formatOption);
                                try {
                                    return formatFactory.createDecodingFormat(
                                            context, projectOptions(formatPrefix));
                                } catch (Throwable t) {
                                    throw new ValidationException(
                                            String.format(
                                                    "Error creating scan format '%s' in option space '%s'.",
                                                    formatFactory.factoryIdentifier(),
                                                    formatPrefix),
                                            t);
                                }
                            });
        }

private <F extends Factory> Optional<F> discoverOptionalFormatFactory(
        Class<F> formatFactoryClass, ConfigOption<String> formatOption) {
    final String identifier = allOptions.get(formatOption);
    if (identifier == null) {
        return Optional.empty();
    }
    final F factory =
            discoverFactory(context.getClassLoader(), formatFactoryClass, identifier);
    String formatPrefix = formatPrefix(factory, formatOption);
    // log all used options of other factories
    consumedOptionKeys.addAll(
            factory.requiredOptions().stream()
                    .map(ConfigOption::key)
                    .map(k -> formatPrefix + k)
                    .collect(Collectors.toSet()));
    consumedOptionKeys.addAll(
            factory.optionalOptions().stream()
                    .map(ConfigOption::key)
                    .map(k -> formatPrefix + k)
                    .collect(Collectors.toSet()));
    return Optional.of(factory);
}

// 直接过滤出满足条件的 format 
public static <T extends Factory> T discoverFactory(
            ClassLoader classLoader, Class<T> factoryClass, String factoryIdentifier) {
        final List<Factory> factories = discoverFactories(classLoader);

        final List<Factory> foundFactories =
                factories.stream()
                        .filter(f -> factoryClass.isAssignableFrom(f.getClass()))
                        .collect(Collectors.toList());

        if (foundFactories.isEmpty()) {
            throw new ValidationException(
                    String.format(
                            "Could not find any factories that implement '%s' in the classpath.",
                            factoryClass.getName()));
        }

        final List<Factory> matchingFactories =
                foundFactories.stream()
                        .filter(f -> f.factoryIdentifier().equals(factoryIdentifier))
                        .collect(Collectors.toList());

        if (matchingFactories.isEmpty()) {
            throw new ValidationException(
                    String.format(
                            "Could not find any factory for identifier '%s' that implements '%s' in the classpath.\\n\\n"
                                    + "Available factory identifiers are:\\n\\n"
                                    + "%s",
                            factoryIdentifier,
                            factoryClass.getName(),
                            foundFactories.stream()
                                    .map(Factory::factoryIdentifier)
                                    .distinct()
                                    .sorted()
                                    .collect(Collectors.joining("\\n"))));
        }
        if (matchingFactories.size() > 1) {
            throw new ValidationException(
                    String.format(
                            "Multiple factories for identifier '%s' that implement '%s' found in the classpath.\\n\\n"
                                    + "Ambiguous factory classes are:\\n\\n"
                                    + "%s",
                            factoryIdentifier,
                            factoryClass.getName(),
                            matchingFactories.stream()
                                    .map(f -> f.getClass().getName())
                                    .sorted()
                                    .collect(Collectors.joining("\\n"))));
        }

        return (T) matchingFactories.get(0);
    }

这里的逻辑和上面加载 connector 的逻辑是一样的,同样通过 SPI 先加载所有的 format 然后根据 factoryIdentifier 过滤出满足条件的 format 这里其实就是 json 了. 返回 formatFactory 后开始创建 format 这个时候就会走到 JsonFormatFactory#createDecodingFormat 这个方法里面.真正的创建一个 DecodingFormat 对象.

createDecodingFormat 源码

@Override
    public DecodingFormat<DeserializationSchema<RowData>> createDecodingFormat(
            DynamicTableFactory.Context context, ReadableConfig formatOptions) {
        // 验证相关的参数
        FactoryUtil.validateFactoryOptions(this, formatOptions);
        // 验证 json.fail-on-missing-field 和 json.ignore-parse-errors
        validateDecodingFormatOptions(formatOptions);
  // 获取 json.fail-on-missing-field 和 json.ignore-parse-errors
        final boolean failOnMissingField = formatOptions.get(FAIL_ON_MISSING_FIELD);
        final boolean ignoreParseErrors = formatOptions.get(IGNORE_PARSE_ERRORS);
        // 获取 timestamp-format.standard
        TimestampFormat timestampOption = JsonOptions.getTimestampFormat(formatOptions);

        return new DecodingFormat<DeserializationSchema<RowData>>() {
            @Override
            public DeserializationSchema<RowData> createRuntimeDecoder(
                    DynamicTableSource.Context context, DataType producedDataType) {
                final RowType rowType = (RowType) producedDataType.getLogicalType();
                final TypeInformation<RowData> rowDataTypeInfo =
                        context.createTypeInformation(producedDataType);
                return new JsonRowDataDeserializationSchema(
                        rowType,
                        rowDataTypeInfo,
                        failOnMissingField,
                        ignoreParseErrors,
                        timestampOption);
            }

            @Override
            public ChangelogMode getChangelogMode() {
                return ChangelogMode.insertOnly();
            }
        };
    }

这里的逻辑也非常简单,首先会对 format 相关的参数进行验证, 然后验证 json.fail-on-missing-field 和 json.ignore-parse-errors 这两个参数.之后就开始创建 JsonRowDataDeserializationSchema 对象

JsonRowDataDeserializationSchema 源码

public JsonRowDataDeserializationSchema(
        RowType rowType,
        TypeInformation<RowData> resultTypeInfo,
        boolean failOnMissingField,
        boolean ignoreParseErrors,
        TimestampFormat timestampFormat) {
    if (ignoreParseErrors && failOnMissingField) {
        throw new IllegalArgumentException(
                "JSON format doesn't support failOnMissingField and ignoreParseErrors are both enabled.");
    }
    this.resultTypeInfo = checkNotNull(resultTypeInfo);
    this.failOnMissingField = failOnMissingField;
    this.ignoreParseErrors = ignoreParseErrors;
    this.runtimeConverter =
            new JsonToRowDataConverters(failOnMissingField, ignoreParseErrors, timestampFormat)
                    .createConverter(checkNotNull(rowType));
    this.timestampFormat = timestampFormat;
    boolean hasDecimalType =
            LogicalTypeChecks.hasNested(rowType, t -> t instanceof DecimalType);
    if (hasDecimalType) {
        objectMapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS);
    }
    objectMapper.configure(JsonReadFeature.ALLOW_UNESCAPED_CONTROL_CHARS.mappedFeature(), true);
}

在构造方法里面最重要的是创建 JsonToRowDataConverter 对象,这里面方法的调用比较多,这里只重要的方法进行说明

createRowConverter 源码

public JsonToRowDataConverter createRowConverter(RowType rowType) {
    final JsonToRowDataConverter[] fieldConverters =
            rowType.getFields().stream()
                    .map(RowType.RowField::getType)
                    .map(this::createConverter)
                    .toArray(JsonToRowDataConverter[]::new);
    final String[] fieldNames = rowType.getFieldNames().toArray(new String[0]);

    return jsonNode -> {
        ObjectNode node = (ObjectNode) jsonNode;
        int arity = fieldNames.length;
        GenericRowData row = new GenericRowData(arity);
        for (int i = 0; i < arity; i++) {
            String fieldName = fieldNames[i];
            JsonNode field = node.get(fieldName);
            Object convertedField = convertField(fieldConverters[i], fieldName, field);
            row.setField(i, convertedField);
        }
        return row;
    };
}

因为是 JSON 格式的数据,所以是一个 ROW 类型,所以要先创建 JsonToRowDataConverter 对象,然后在这里会对每一个字段创建一个 fieldConverter 根据你在 DDL 里面定义的字段类型走不同的转换方法,比如 String 类型的数据会调用 convertToString 方法

convertToString 源码

private StringData convertToString(JsonNode jsonNode) {
    if (jsonNode.isContainerNode()) {
        return StringData.fromString(jsonNode.toString());
    } else {
        return StringData.fromString(jsonNode.asText());
    }
}

这里需要注意的是 string 类型的数据需要返回 StringData 类型不然会报类型转换异常的错.感兴趣的朋友可以看下其他类型是如何处理的.

到这里 JsonRowDataDeserializationSchema 对象就构造完成了.那后面其实就是优化,转换到翻译成 streamGraph 再后面的过程就和 datastream api 开发的任务一样了.

然后真正开始消费数据的时候,会走到 JsonRowDataDeserializationSchema#deserialize 方法对数据进行反序列化.

deserialize 源码

@Override
public RowData deserialize(@Nullable byte[] message) throws IOException {
    if (message == null) {
        return null;
    }
    try {
        return convertToRowData(deserializeToJsonNode(message));
    } catch (Throwable t) {
        if (ignoreParseErrors) {
            return null;
        }
        throw new IOException(
                format("Failed to deserialize JSON '%s'.", new String(message)), t);
    }
}

先会把数据反序列成 JsonNode 对象.

deserializeToJsonNode 源码

public JsonNode deserializeToJsonNode(byte[] message) throws IOException {
    return objectMapper.readTree(message);
}

可以看到 Flink 的内部是用 jackson 解析数据的.接着把 jsonNode 格式的数据转换成 RowData 格式的数据

convertToRowData 源码

public RowData convertToRowData(JsonNode message) {
    return (RowData) runtimeConverter.convert(message);
}

然后这里的调用其实和上面构造 JsonRowDataDeserializationSchema 的时候是一样的

return jsonNode -> {
    ObjectNode node = (ObjectNode) jsonNode;
    int arity = fieldNames.length;
    GenericRowData row = new GenericRowData(arity);
    for (int i = 0; i < arity; i++) {
        String fieldName = fieldNames[i];
        JsonNode field = node.get(fieldName);
        Object convertedField = convertField(fieldConverters[i], fieldName, field);
        row.setField(i, convertedField);
    }
    return row;
};

最终返回的是 GenericRowData 类型的数据,其实就是 RowData 类型的,因为是 RowData 的实现类.然后就会把反序列后的数据发送到下游了.

总结

这篇文章主要分析了 Flink SQL JSON Format 的相关源码,从构建 JsonRowDataDeserializationSchema 到反序列化数据 deserialize.因为篇幅原因,只展示每个环节最重要的代码,其实很多细节都直接跳过了.感兴趣的朋友也可以自己去调试一下代码.有时间的话会更新更多的实现细节.

推荐阅读

Flink 任务实时监控最佳实践

Flink on yarn 实时日志收集最佳实践

Flink 1.14.0 全新的 Kafka Connector

Flink 1.14.0 消费 kafka 数据自定义反序列化类

如果你觉得文章对你有帮助,麻烦点一下在看吧,你的支持是我创作的最大动力.

以上是关于Flink SQL JSON Format 源码解析的主要内容,如果未能解决你的问题,请参考以下文章

Flink SQL实战演练之自定义Table Format

Flink SQL 知其所以然(五) 自定义 protobuf format

Flink SQL 解析复杂(嵌套)JSON

源码Flink sql 流式去重源码解析

源码Flink sql 流式去重源码解析

Flink 1.17 Flink-SQL-Gateway HiveServer2 源码分析