Error when creating Universal Sentence Encoder embeddings with Beam & tf.Transform

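I have a simple Beam pipeline that uses the Universal Sentence Encoder together with tf.Transform to take some text and produce embeddings. It is very similar to the demo that was built with TF 1: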

import tensorflow as tf
import apache_beam as beam
import tensorflow_transform.beam as tft_beam
import tensorflow_transform.coders as tft_coders
from apache_beam.options.pipeline_options import PipelineOptions
import tempfile

model = None

def embed_text(text):
    import tensorflow_hub as hub
    global model
    if model is None:
        model = hub.load(
            'https://tfhub.dev/google/universal-sentence-encoder/4')
    embedding = model(text)
    return embedding


def get_metadata():
    from tensorflow_transform.tf_metadata import dataset_schema
    from tensorflow_transform.tf_metadata import dataset_metadata

    metadata = dataset_metadata.DatasetMetadata(dataset_schema.Schema({
        'id': dataset_schema.ColumnSchema(
            tf.string, [], dataset_schema.FixedColumnRepresentation()),
        'text': dataset_schema.ColumnSchema(
            tf.string, [], dataset_schema.FixedColumnRepresentation())
    }))
    return metadata


def preprocess_fn(input_features):

    text_integerized = embed_text(input_features['text'])
    output_features = {
        'id': input_features['id'],
        'embedding': text_integerized
    }
    return output_features


def run(pipeline_options, known_args):
    argv = None  # if None, uses sys.argv
    pipeline_options = PipelineOptions(argv)

    pipeline = beam.Pipeline(options=pipeline_options)
    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
        articles = (
            pipeline
            | beam.Create([
                {'id': '01', 'text': 'To be, or not to be: that is the question: '},
                {'id': '02', 'text': "Whether 'tis nobler in the mind to suffer "},
                {'id': '03', 'text': 'The slings and arrows of outrageous fortune, '},
                {'id': '04', 'text': 'Or to take arms against a sea of troubles, '},
            ]))

        articles_dataset = (articles, get_metadata())

        transformed_dataset, transform_fn = (
                articles_dataset
                | 'Extract embeddings' >> tft_beam.AnalyzeAndTransformDataset(preprocess_fn)
        )

        transformed_data, transformed_metadata = transformed_dataset

        _ = (
            transformed_data
            | 'Write embeddings to TFRecords' >> beam.io.tfrecordio.WriteToTFRecord(
                file_path_prefix='{0}'.format(known_args.output_dir),
                file_name_suffix='.tfrecords',
                coder=tft_coders.example_proto_coder.ExampleProtoCoder(
                    transformed_metadata.schema),
                num_shards=1))
    result = pipeline.run()
    result.wait_until_finished()

Python 3.6.8, tf==2.0, tf_transform==0.15, apache-beam[gcp]==2.16 (I have tried various compatible combinations listed at https://github.com/tensorflow/transform).

The error occurs when tf_transform invokes the graph analyzer:

...
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tensorflow_transform/beam/impl.py", line 462, in process
    lambda: self._make_graph_state(saved_model_dir))
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tfx_bsl/beam/shared.py", line 221, in acquire
    return _shared_map.acquire(self._key, constructor_fn)
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tfx_bsl/beam/shared.py", line 184, in acquire
    result = control_block.acquire(constructor_fn)
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tfx_bsl/beam/shared.py", line 87, in acquire
    result = constructor_fn()
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tensorflow_transform/beam/impl.py", line 462, in <lambda>
    lambda: self._make_graph_state(saved_model_dir))
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tensorflow_transform/beam/impl.py", line 438, in _make_graph_state
    self._exclude_outputs, self._tf_config)
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tensorflow_transform/beam/impl.py", line 357, in __init__
    tensor_inputs = graph_tools.get_dependent_inputs(graph, inputs, fetches)
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tensorflow_transform/graph_tools.py", line 686, in get_dependent_inputs
    sink_tensors_ready)
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tensorflow_transform/graph_tools.py", line 499, in __init__
    table_init_op, graph_analyzer_for_table_init, translate_path_fn)
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tensorflow_transform/graph_tools.py", line 560, in _get_table_init_op_source_info
    if table_init_op.type not in _TABLE_INIT_OP_TYPES:
AttributeError: 'Tensor' object has no attribute 'type' [while running 'Extract embeddings/TransformDataset/Transform']
Exception ignored in: <bound method CapturableResourceDeleter.__del__ of <tensorflow.python.training.tracking.tracking.CapturableResourceDeleter object at 0x14152fbe0>>
Traceback (most recent call last):
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/tracking.py", line 190, in __del__
  File "/Users/justingrace/.pyenv/versions/hlx36/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 3872, in as_default
  File "/Users/justingrace/.pyenv/versions/3.6.8/lib/python3.6/contextlib.py", line 159, in helper
TypeError: 'NoneType' object is not callable

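It seems the graph analyzer expects a list of ops that have a `type` attribute, but it is receiving a tensor instead. Apart from a bug in the graph analyzer or a compatibility issue with tfx_bsl, I can't see why this error would occur. (pyarrow 0.14 appeared to cause problems, so I downgraded it to 0.13.)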
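To make this hypothesis concrete: in TensorFlow a `tf.Operation` has a string `type` (for example `'InitializeTableV2'`), while a `tf.Tensor` exposes the operation that produced it through `.op` and has no `type` attribute, which is exactly what the traceback complains about. The pure-Python sketch below uses hypothetical stand-in classes (`FakeOperation`, `FakeTensor`, and an assumed `TABLE_INIT_OP_TYPES` set; none of this is TensorFlow or tf.Transform code) to reproduce the same kind of `AttributeError` and to show how resolving a tensor to its producing op before reading `.type` would avoid it. It illustrates the failure mode only; it is not a patch for tf.Transform.

```python
# Hypothetical stand-ins; real code would receive tf.Operation / tf.Tensor objects.
class FakeOperation:
    """Mimics tf.Operation: carries a string op type."""

    def __init__(self, op_type):
        self.type = op_type


class FakeTensor:
    """Mimics tf.Tensor: points at its producing op via `.op`, has no `.type`."""

    def __init__(self, op, dtype="resource"):
        self.op = op
        self.dtype = dtype


# Example table-initializer op types (an assumed list, not copied from tf.Transform).
TABLE_INIT_OP_TYPES = {"InitializeTableV2", "InitializeTableFromTextFileV2"}


def is_table_init(candidate):
    """Return True if `candidate` (an op, or a tensor produced by one) initializes a table."""
    # A tensor is resolved to the op that produced it; an op is used as-is.
    op = getattr(candidate, "op", candidate)
    return op.type in TABLE_INIT_OP_TYPES


if __name__ == "__main__":
    init_tensor = FakeTensor(FakeOperation("InitializeTableV2"))
    try:
        # The naive check from the traceback, applied to a tensor.
        _ = init_tensor.type not in TABLE_INIT_OP_TYPES
    except AttributeError as err:
        print("naive check fails:", err)
    print("resolved check:", is_table_init(init_tensor))
```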

Output of pip freeze:

absl-py==0.8.1
annoy==1.12.0
apache-beam==2.16.0
appnope==0.1.0
astor==0.8.1
astunparse==1.6.3
attrs==19.1.0
avro-python3==1.9.1
backcall==0.1.0
bleach==3.1.0
cachetools==3.1.1
certifi==2019.11.28
chardet==3.0.4
crcmod==1.7
cymem==1.31.2
cytoolz==0.9.0.1
decorator==4.4.1
defusedxml==0.6.0
dill==0.3.0
docopt==0.6.2
en-core-web-lg==2.0.0
en-coref-lg==3.0.0
en-ner-trained==2.0.0
entrypoints==0.3
fastavro==0.21.24
fasteners==0.15
flashtext==2.7
future==0.18.2
fuzzywuzzy==0.16.0
gast==0.2.2
google-api-core==1.16.0
google-apitools==0.5.28
google-auth==1.11.0
google-auth-oauthlib==0.4.1
google-cloud-bigquery==1.17.1
google-cloud-bigtable==1.0.0
google-cloud-core==1.3.0
google-cloud-datastore==1.7.4
google-cloud-pubsub==1.0.2
google-pasta==0.1.8
google-resumable-media==0.4.1
googleapis-common-protos==1.51.0
grpc-google-iam-v1==0.12.3
grpcio==1.24.0
h5py==2.10.0
hdfs==2.5.8
httplib2==0.12.0
idna==2.8
importlib-metadata==1.5.0
ipykernel==5.1.4
ipython==7.12.0
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.16.0
Jinja2==2.11.1
jsonpickle==1.2
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==5.3.4
jupyter-console==6.1.0
jupyter-core==4.6.2
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
lxml==4.2.1
Markdown==3.2.1
MarkupSafe==1.1.1
mistune==0.8.4
mock==2.0.0
monotonic==1.5
more-itertools==8.2.0
msgpack==0.6.2
msgpack-numpy==0.4.4
murmurhash==0.28.0
nbconvert==5.6.1
nbformat==5.0.4
networkx==2.1
nltk==3.4.5
notebook==6.0.3
numpy==1.18.1
oauth2client==3.0.0
oauthlib==3.1.0
opt-einsum==3.1.0
packaging==20.1
pandas==0.23.0
pandocfilters==1.4.2
parso==0.6.1
pathlib2==2.3.5
pbr==5.4.4
pexpect==4.8.0
pickleshare==0.7.5
plac==0.9.6
pluggy==0.13.1
preshed==1.0.1
prometheus-client==0.7.1
prompt-toolkit==3.0.3
proto-google-cloud-datastore-v1==0.90.4
protobuf==3.11.3
psutil==5.6.7
ptyprocess==0.6.0
py==1.8.1
pyahocorasick==1.4.0
pyarrow==0.13.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pydot==1.4.1
Pygments==2.5.2
PyHamcrest==1.9.0
pymongo==3.10.1
pyparsing==2.4.6
pyrsistent==0.15.7
pytest==5.3.5
python-dateutil==2.8.0
python-Levenshtein==0.12.0
pytz==2019.3
PyYAML==3.13
pyzmq==18.1.1
qtconsole==4.6.0
regex==2017.4.5
repoze.lru==0.7
requests==2.22.0
requests-oauthlib==1.3.0
rsa==4.0
scikit-learn==0.19.1
scipy==1.4.1
Send2Trash==1.5.0
six==1.14.0
spacy==2.0.12
tb-nightly==2.2.0a20200217
tensorboard==2.0.2
tensorflow==2.0.0
tensorflow-estimator==2.0.1
tensorflow-hub==0.6.0
tensorflow-metadata==0.15.2
tensorflow-serving-api==2.1.0
tensorflow-transform==0.15.0
termcolor==1.1.0
terminado==0.8.3
testpath==0.4.4
textblob==0.15.1
tf-estimator-nightly==2.1.0.dev2020012309
tf-nightly==2.2.0.dev20200217
tfx-bsl==0.15.0
thinc==6.10.3
toolz==0.10.0
tornado==6.0.3
tqdm==4.23.3
traitlets==4.3.3
typing==3.7.4.1
typing-extensions==3.7.4.1
ujson==1.35
Unidecode==1.0.22
urllib3==1.25.8
wcwidth==0.1.8
webencodings==0.5.1
Werkzeug==1.0.0
Whoosh==2.7.4
widgetsnbextension==3.5.1
wrapt==1.11.2
zipp==2.2.0
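The freeze output lists both `tensorflow==2.0.0` and `tf-nightly==2.2.0.dev20200217` (as well as `tensorboard`/`tb-nightly` and `tensorflow-estimator`/`tf-estimator-nightly`). These pairs install overlapping Python packages, so the TensorFlow that actually gets imported may not be the version stated above. A minimal sketch for reporting what is installed, assuming the standard `importlib.metadata` API (Python 3.8+) or, on Python 3.6, the `importlib-metadata` backport that also appears in the freeze list; the `installed_versions` helper and the package list are illustrative choices, not part of the original setup.

```python
try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.6/3.7: the importlib-metadata backport
    import importlib_metadata as metadata


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    # Distributions relevant to the Beam / tf.Transform / TF Hub stack.
    stack = ["tensorflow", "tf-nightly", "tensorflow-transform", "tfx-bsl",
             "apache-beam", "pyarrow", "tensorflow-hub"]
    for name, version in installed_versions(stack).items():
        print(f"{name:22} {version or 'not installed'}")
```

Running it in the same environment as the pipeline (for example, the Dataflow worker image or the local virtualenv) shows whether more than one TensorFlow distribution is present.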
Answer

According to this GitHub post, this may be an underlying issue. Try a newer version of TensorFlow (2.1.0), or even newer versions of the Keras package.
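The suggestion amounts to moving TensorFlow from 2.0.0 to at least 2.1.0. A minimal, dependency-free sketch for checking that kind of lower bound follows; `release_tuple` and `at_least` are hypothetical helper names, and the comparison deliberately looks only at the numeric release segment, so a pre-release such as `2.1.0rc0` counts as `2.1.0`. If that distinction matters, `packaging.version.Version` (the `packaging` distribution is also in the freeze list) implements full PEP 440 ordering.

```python
import re


def release_tuple(version):
    """Parse the leading numeric release segment, e.g. '2.2.0.dev20200217' -> (2, 2, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError("unrecognised version string: %r" % (version,))
    return tuple(int(part) for part in match.group(0).split("."))


def at_least(installed, minimum):
    """True if `installed` >= `minimum`, comparing release numbers only."""
    have, need = release_tuple(installed), release_tuple(minimum)
    width = max(len(have), len(need))
    # Pad with zeros so that '2.1' and '2.1.0' compare as equal.
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


if __name__ == "__main__":
    # Versions taken from the question's pip freeze output.
    print("tensorflow 2.0.0 >= 2.1.0:", at_least("2.0.0", "2.1.0"))
    print("tf-nightly 2.2.0.dev20200217 >= 2.1.0:",
          at_least("2.2.0.dev20200217", "2.1.0"))
```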
