关于自定义模型的一个问题

来源:1-8 Google_cloud_远程jupyter_notebook配置

兔儿月饼

2019-06-24

老师您好,在第3章中我们写了一个继承了keras.layers.Layer 的类,模拟实现了一个全连接层,那么能不能在自定义模型中自由组合keras.layers下一些现有的层次呢?比如搭建一个ResNet的残差连接块,如果能的话老师能不能给点提示- ̗̀(๑ᵔ_ᵔ๑)

写回答

1回答

正十七

2019-06-27

自定义模型中依然可以用keas.layers下的层次。比如

# customized dense layer.
class CustomizedDenseLayer(keras.layers.Layer):
    def __init__(self, units, activation=None, **kwargs):
        self.units = units
        self.activation = keras.layers.Activation(activation)
        super(CustomizedDenseLayer, self).__init__(**kwargs)
    
    def build(self, input_shape):
        """构建所需要的参数"""
        self.layer2 = tf.keras.layers.Dense(self.units)
        self.layer1 = tf.keras.layers.Dense(100, input_shape=input_shape)
        """

        # x * w + b. input_shape:[None, a] w:[a,b]output_shape: [None, b]
        self.kernel = self.add_weight(name = 'kernel',
                                      shape = (input_shape[1], self.units),
                                      initializer = 'uniform',
                                      trainable = True)
        self.bias = self.add_weight(name = 'bias',
                                    shape = (self.units, ),
                                    initializer = 'zeros',
                                    trainable = True)
        """
        super(CustomizedDenseLayer, self).build(input_shape)
    
    def call(self, x):
        """完成正向计算"""
        hidden1 = self.layer1(x)
        hidden2 = self.layer2(hidden1)
        return self.activation(hidden2)

model = keras.models.Sequential([
    CustomizedDenseLayer(30, activation='relu',
                         input_shape=x_train.shape[1:]),
    CustomizedDenseLayer(1),
    customized_softplus,
    # keras.layers.Dense(1, activation="softplus"),
    # keras.layers.Dense(1), keras.layers.Activation('softplus'),
])
model.summary()
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(
    patience=5, min_delta=1e-2)]

如果想实现resnet的话,建议改造https://keras.io/examples/cifar10_resnet/的模型实现。

0
1
兔儿月饼
非常感谢!
2019-06-29
共1条回复

Google老师亲授 TensorFlow2.0 入门到进阶

Tensorflow2.0实战—以实战促理论的方式学习深度学习

1849 学习 · 896 问题

查看课程