自定义损失函数除了 y_true、y_pred 之外,如何添加自己需要的参数呢?

来源:3-5 实战自定义损失函数与DenseLayer回顾

BloodHound_swh

2019-06-18

例如

def rmse(matrix_true, matrix_pre, BI_train=tf_UIBI_train):
    """Masked squared-error loss for use as ``model.compile(loss=rmse)``.

    Despite the name, this returns the *sum* of squared errors over the
    entries weighted by ``BI_train``; no mean and no square root are taken.

    Keras only calls ``loss(y_true, y_pred)``, so any extra parameter must
    carry a default. The default here is bound once, when this ``def`` runs.

    Args:
        matrix_true: Ground-truth tensor (Keras' ``y_true``).
        matrix_pre: Predicted tensor (Keras' ``y_pred``).
        BI_train: Element-wise weight applied to the error. Defaults to the
            module-level ``tf_UIBI_train``. Presumably a 0/1 indicator of
            observed entries -- TODO confirm against where it is built.

    Returns:
        A scalar tensor with the same dtype as ``matrix_pre``.
    """
    # Align dtypes with the prediction. A tensor created outside the Keras
    # graph (e.g. from a float64 NumPy array) keeps its own dtype when it is
    # captured. tf.multiply then fails with the opaque
    # "AssertionError: Unreachable" rather than a clear dtype-mismatch error.
    mask = tf.cast(BI_train, matrix_pre.dtype)
    target = tf.cast(matrix_true, matrix_pre.dtype)
    # (pre - true) * mask is algebraically pre * mask - true * mask,
    # but needs one multiply instead of two.
    return tf.reduce_sum(tf.square(tf.multiply(matrix_pre - target, mask)))

我希望 BI_train 的默认值是我之前定义好的变量 tf_UIBI_train,但是这样写会报错:

AssertionError                            Traceback (most recent call last)
 in ()
     23 
     24 model.summary()
---> 25 model.compile(loss=rmse, optimizer=op)

~\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\tracking\base.py in _method_wrapper(self, *args, **kwargs)
    454     self._setattr_tracking = False  # pylint: disable=protected-access
    455     try:
--> 456       result = method(self, *args, **kwargs)
    457     finally:
    458       self._setattr_tracking = previous_value  # pylint: disable=protected-access

~\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\engine\training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, distribute, **kwargs)
    428       #                   loss_weight_2 * output_2_loss_fn(...) +
    429       #                   layer losses.
--> 430       self.total_loss = self._prepare_total_loss(skip_target_indices, masks)
    431 
    432       # Functions for train, test and predict will

~\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\engine\training.py in _prepare_total_loss(self, skip_target_indices, masks)
   1684             loss_fn.reduction = losses_utils.ReductionV2.NONE
   1685             weighted_losses = loss_fn(
-> 1686                 y_true, y_pred, sample_weight=sample_weight)
   1687             loss_fn.reduction = current_loss_reduction
   1688 

~\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\losses.py in __call__(self, y_true, y_pred, sample_weight)
     94     with ops.name_scope(scope_name, format(self.__class__.__name__),
     95                         (y_pred, y_true, sample_weight)):
---> 96       losses = self.call(y_true, y_pred)
     97       return losses_utils.compute_weighted_loss(
     98           losses, sample_weight, reduction=self.reduction)

~\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\keras\losses.py in call(self, y_true, y_pred)
    156       Loss values per sample.
    157     """
--> 158     return self.fn(y_true, y_pred, **self._fn_kwargs)
    159 
    160   def get_config(self):

 in rmse(matrix_true, matrix_pre, BI_train)
      2 
      3 def rmse(matrix_true, matrix_pre, BI_train=tf_UIBI_train):
----> 4     loss = tf.reduce_sum(tf.square(tf.multiply(matrix_pre, BI_train) - tf.multiply(matrix_true, BI_train)))
      5     return loss
      6 

~\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\util\dispatch.py in wrapper(*args, **kwargs)
    178     """Call target, and fall back on dispatchers if there is a TypeError."""
    179     try:
--> 180       return target(*args, **kwargs)
    181     except (TypeError, ValueError):
    182       # Note: convert_to_eager_tensor currently raises a ValueError, not a

~\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\math_ops.py in multiply(x, y, name)
    297 @dispatch.add_dispatch_support
    298 def multiply(x, y, name=None):
--> 299   return gen_math_ops.mul(x, y, name)
    300 
    301 

~\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\gen_math_ops.py in mul(x, y, name)
   6646   # Add nodes to the TensorFlow graph.
   6647   _, _, _op = _op_def_lib._apply_op_helper(
-> 6648         "Mul", x=x, y=y, name=name)
   6649   _result = _op.outputs[:]
   6650   _inputs_flat = _op.inputs

~\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    619           if input_arg.type_attr in attrs:
    620             if attrs[input_arg.type_attr] != attr_value:
--> 621               assert False, "Unreachable"
    622           else:
    623             for base_type in base_types:

AssertionError: Unreachable
[5]

写回答

1回答

正十七

2019-08-01

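同学你好,我这样写是可以运行的,不知道是不是版本的问题?我用的是 TensorFlow 2.0.0-beta1。注意我的 matrix_pre 显式指定了 dtype=tf.float32,与模型输出的 dtype 一致;如果你的 tf_UIBI_train 是由 float64 的 numpy 数组创建的,这个 `Unreachable` 报错很可能就是 dtype 不一致引起的,可以先用 tf.cast 把它转成 float32 再试。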

# Default element-wise weight for customized_mse: a 1x1 float32 constant.
# Its explicit dtype matches the model's float32 output.
matrix_pre = tf.constant([[1]], dtype=tf.float32)


def customized_mse(y_true, y_pred, a=matrix_pre):
    """Weighted mean-squared-error loss usable directly in ``model.compile``.

    Keras calls the loss as ``fn(y_true, y_pred)`` only, so extra parameters
    such as ``a`` need a default. That default is bound once, when this
    ``def`` runs.

    Args:
        y_true: Ground-truth tensor supplied by Keras.
        y_pred: Model output tensor supplied by Keras.
        a: Element-wise weight applied to both tensors before differencing.

    Returns:
        Scalar tensor: the mean of ``(a * y_pred - a * y_true) ** 2``.
    """
    # Cast to the model output's dtype. A float64 weight (e.g. built from a
    # NumPy array) is the likely trigger of "AssertionError: Unreachable"
    # inside tf.multiply.
    weight = tf.cast(a, y_pred.dtype)
    target = tf.cast(y_true, y_pred.dtype)
    return tf.reduce_mean(tf.square(tf.multiply(weight, y_pred - target)))

# Regression MLP: a 30-unit ReLU hidden layer feeding a single output unit.
hidden = keras.layers.Dense(30, activation="relu",
                            input_shape=x_train_scaled.shape[1:])
output = keras.layers.Dense(1)
model = keras.models.Sequential([hidden, output])

# Train with the custom loss, tracking built-in MSE alongside it so the two
# can be compared.
model.compile(
    optimizer="sgd",
    loss=customized_mse,
    metrics=["mean_squared_error"],
)
model.fit(
    x_train_scaled,
    y_train,
    epochs=2,
    validation_data=(x_valid_scaled, y_valid),
)
model.evaluate(x_test_scaled, y_test)


0
0

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程