我想使用以下嵌入模型将文本分为 2 类（二分类）：https://tfhub.dev/google/universal-sentence-encoder-multilingual/3
我还想在句子嵌入向量之后拼接其他特征，因此模型有两个输入：
import tensorflow as tf
import tensorflow_hub as tfh
import tensorflow_datasets as tfds
import tensorflow_text as tft
# Hyper-parameters for the two-input (text + numeric features) binary classifier.
# NOTE(review): the original entry mixed a URL literal with EMBEDDINGS['senm']
# (a syntax error); the hub URL is used directly here.
hp = {
    'embedding': 'https://tfhub.dev/google/universal-sentence-encoder-multilingual/3',
    'units': 64,
    'learning_rate': 1e-3,
    'dropout': 0.2,
    'layers': 2,
}

# Each example's text is a scalar string (shape=()), so the hub layer receives a
# 1-D batch of strings -- the same as `input_shape=[]` in the working Sequential
# model. With shape=(1,) the layer instead got a (batch, 1) tensor.
# NOTE(review): that shape mismatch is the most likely trigger of the
# AssertionError in the traceback -- confirm on the TF version in use.
textInput = tf.keras.Input(shape=(), name='text', dtype=tf.string)
featuresInput = tf.keras.Input(shape=(36,), name='features')

x = tfh.KerasLayer(hp.get('embedding'), dtype=tf.string, trainable=False)(textInput)
# Append the extra numeric features to the sentence embedding.
x = tf.keras.layers.concatenate([x, featuresInput])
for _ in range(hp.get('layers')):
    x = tf.keras.layers.Dense(hp.get('units'), activation='relu')(x)
    x = tf.keras.layers.Dropout(hp.get('dropout'))(x)

# bias_initializer=None does NOT mean "use the default": tf.keras Layer.add_weight
# falls back to glorot_uniform for a None initializer, so pass 'zeros' explicitly.
# `is not None` also avoids truth-testing an array-valued INITIAL_BIAS.
output = tf.keras.layers.Dense(
    1,
    activation='sigmoid',
    bias_initializer=(
        tf.keras.initializers.Constant(INITIAL_BIAS)
        if INITIAL_BIAS is not None
        else 'zeros'
    ),
)(x)

model = tf.keras.Model(inputs=[textInput, featuresInput], outputs=output)
model.compile(
    # `lr` is a deprecated alias; `learning_rate` is the supported argument name.
    optimizer=tf.keras.optimizers.Adam(learning_rate=hp.get('learning_rate')),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=METRICS,
)
代码运行失败，报错如下：
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-17-61aed6f885c9> in <module>
10 featuresInput = tf.keras.Input(shape=(36,name = 'features')
11
---> 12 x = tfh.KerasLayer(hp.get('embedding'),trainable = False)(textInput)
13 x = tf.keras.layers.concatenate([x,featuresInput])
14
~/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self,*args,**kwargs)
920 not base_layer_utils.is_in_eager_or_tf_function()):
921 with auto_control_deps.AutomaticControlDependencies() as acd:
--> 922 outputs = call_fn(cast_inputs,**kwargs)
923 # Wrap Tensors in `outputs` in `tf.identity` to avoid
924 # circular dependencies.
~/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args,**kwargs)
263 except Exception as e: # pylint:disable=broad-except
264 if hasattr(e,'ag_error_metadata'):
--> 265 raise e.ag_error_metadata.to_exception(e)
266 else:
267 raise
AssertionError: in user code:
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow_hub/keras_layer.py:222 call *
result = f()
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py:486 _call_attribute **
return instance.__call__(*args,**kwargs)
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:580 __call__
result = self._call(*args,**kwds)
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:650 _call
return self._concrete_stateful_fn._filtered_call(canon_args,canon_kwds) # pylint: disable=protected-access
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1665 _filtered_call
self.captured_inputs)
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/eager/function.py:1759 _call_flat
"StatefulPartitionedCall": self._get_gradient_function()}):
/usr/lib/python3.6/contextlib.py:81 __enter__
return next(self.gen)
/home/e/.virtualenvs/python3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:4735 _override_gradient_function
assert not self._gradient_function_map
AssertionError:
但如果改用 Sequential 模型，就能正常运行：
# Working Sequential baseline (text only, no extra features).
# NOTE(review): `embedding`, `output_bias` and `metrics` are defined elsewhere
# in the notebook.
model = tf.keras.Sequential([
    # input_shape=[] -> each example is a scalar string (a 1-D batch of strings).
    # Uses the `tfh` alias this file actually imports (not `hub`).
    tfh.KerasLayer(embedding, input_shape=[], trainable=True),
    # input_shape on a non-first Sequential layer is ignored by Keras (and the
    # real input here is the embedding, not train_features), so it is omitted.
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1, activation='sigmoid', bias_initializer=output_bias),
])
model.compile(
    # `lr` is a deprecated alias; `learning_rate` is the supported argument name.
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=metrics,
)
我在使用函数式 API 时哪里做错了？能帮我解决这个错误吗？