tf.keras Dense source code walkthrough
Published: 2019-04-30


Environment

package      version
tensorflow   2.3.0
keras        2.4.3
Source code
class Dense(Layer):

  def __init__(self,
               units,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(Dense, self).__init__(
        activity_regularizer=activity_regularizer, **kwargs)

    self.units = int(units) if not isinstance(units, int) else units
    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

    self.input_spec = InputSpec(min_ndim=2)
    self.supports_masking = True

  def build(self, input_shape):
    dtype = dtypes.as_dtype(self.dtype or K.floatx())
    if not (dtype.is_floating or dtype.is_complex):
      raise TypeError('Unable to build `Dense` layer with non-floating point '
                      'dtype %s' % (dtype,))
    input_shape = tensor_shape.TensorShape(input_shape)
    last_dim = tensor_shape.dimension_value(input_shape[-1])
    if last_dim is None:
      raise ValueError('The last dimension of the inputs to `Dense` '
                       'should be defined. Found `None`.')
    self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
    self.kernel = self.add_weight(
        'kernel',
        shape=[last_dim, self.units],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    if self.use_bias:
      self.bias = self.add_weight(
          'bias',
          shape=[self.units,],
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    return core_ops.dense(
        inputs,
        self.kernel,
        self.bias,
        self.activation,
        dtype=self._compute_dtype_object)

  def compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    input_shape = input_shape.with_rank_at_least(2)
    if tensor_shape.dimension_value(input_shape[-1]) is None:
      raise ValueError(
          'The innermost dimension of input_shape must be defined, but saw: %s'
          % input_shape)
    return input_shape[:-1].concatenate(self.units)

  def get_config(self):
    config = super(Dense, self).get_config()
    config.update({
        'units': self.units,
        'activation': activations.serialize(self.activation),
        'use_bias': self.use_bias,
        'kernel_initializer': initializers.serialize(self.kernel_initializer),
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint': constraints.serialize(self.kernel_constraint),
        'bias_constraint': constraints.serialize(self.bias_constraint)
    })
    return config

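Looking at the source, even a layer as simple as Dense defines five methods:

  1. __init__ initializes the layer's hyperparameters
  2. build creates the weight (kernel) and bias
  3. call performs the computation
  4. compute_output_shape computes the output shape
  5. get_config returns the layer's config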
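To see how these methods fit together, here is a minimal NumPy sketch of a layer with the same lifecycle. TinyDense is a hypothetical name, and this is not the TF implementation. It mirrors the pattern where __init__ only stores hyperparameters and build is deferred until the first call reveals the input's last dimension. After that, call does the math.

```python
import numpy as np


class TinyDense:
    """Hypothetical NumPy stand-in for the Dense lifecycle (not TF code)."""

    def __init__(self, units, use_bias=True):
        # __init__ only records hyperparameters; no weights exist yet.
        self.units = int(units)
        self.use_bias = use_bias
        self.built = False

    def build(self, input_shape):
        # Weights are sized from the last input dimension, known only now.
        last_dim = input_shape[-1]
        if last_dim is None:
            raise ValueError("The last dimension of the inputs should be defined.")
        rng = np.random.default_rng(0)
        # Placeholder random init for illustration (Dense defaults to glorot_uniform).
        self.kernel = rng.standard_normal((last_dim, self.units))
        self.bias = np.zeros(self.units) if self.use_bias else None
        self.built = True

    def call(self, inputs):
        outputs = inputs @ self.kernel
        if self.bias is not None:
            outputs = outputs + self.bias
        return outputs

    def __call__(self, inputs):
        # Mirrors Keras: build lazily on the first call, then compute.
        if not self.built:
            self.build(inputs.shape)
        return self.call(inputs)

    def compute_output_shape(self, input_shape):
        # Keep every axis but the last, which becomes `units`.
        return tuple(input_shape[:-1]) + (self.units,)

    def get_config(self):
        return {"units": self.units, "use_bias": self.use_bias}


layer = TinyDense(3)
x = np.ones((2, 5))
y = layer(x)
print(layer.kernel.shape, y.shape)  # (5, 3) (2, 3)
```

The point of the deferred build is that the caller never has to pass the input size to the constructor. The kernel shape is fixed the first time data flows through the layer.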
__init__

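Meaning of each constructor argument:

parameter    meaning
units        number of output units, i.e. the size of the output's last dimension
activation   activation function applied to the output (None means linear, i.e. no activation)
use_bias     whether to add a bias vector
initializer  how the kernel and bias are initialized (kernel_initializer, default 'glorot_uniform'; bias_initializer, default 'zeros')
regularizer  regularization penalty on the kernel, the bias, or the layer's output (kernel_regularizer, bias_regularizer, activity_regularizer)
constraint   constraint function applied to the kernel or bias after each optimizer update (kernel_constraint, bias_constraint)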
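String arguments such as activation='relu' are turned into callables in __init__ via activations.get. The same goes for initializers.get, regularizers.get and constraints.get. The sketch below imitates that lookup for activations only. get_activation and its two-entry registry are hypothetical; the real Keras functions accept many more identifiers, including serialized config dicts.

```python
import numpy as np


def linear(x):
    return x


def relu(x):
    return np.maximum(x, 0)


# Hypothetical registry standing in for the tf.keras activations module.
_ACTIVATIONS = {"linear": linear, "relu": relu}


def get_activation(identifier):
    """Resolve an activation argument: None -> identity,
    string -> registered function, callable -> returned unchanged."""
    if identifier is None:
        return linear
    if isinstance(identifier, str):
        try:
            return _ACTIVATIONS[identifier]
        except KeyError:
            raise ValueError(f"Unknown activation: {identifier!r}") from None
    if callable(identifier):
        return identifier
    raise TypeError(f"Could not interpret activation identifier: {identifier!r}")


print(get_activation(None)(np.array([-1.0, 2.0])))    # [-1.  2.]
print(get_activation("relu")(np.array([-1.0, 2.0])))  # [0. 2.]
```

Resolving identifiers once in the constructor means call never has to inspect strings; it just invokes whatever callable was stored.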
build

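Once the input shape is known, build creates the weight matrix (kernel, shape [last_dim, units]) and the bias vector (shape [units]) with add_weight, where last_dim is the last dimension of the input. It raises an error if the layer's dtype is neither floating point nor complex, or if last_dim is undefined (None). It also updates input_spec so that later inputs must have the same last dimension.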
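The default kernel_initializer='glorot_uniform' samples uniformly in [-limit, limit], with limit = sqrt(6 / (fan_in + fan_out)). Below is a NumPy sketch of the arrays build ends up with, assuming a 2-D kernel (fan_in = last_dim, fan_out = units). build_weights is a hypothetical helper. The real layer calls add_weight, which also creates a tf.Variable and registers the regularizer and constraint.

```python
import numpy as np


def glorot_uniform(shape, rng):
    # Glorot/Xavier uniform: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)).
    fan_in, fan_out = shape
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=shape)


def build_weights(input_shape, units, use_bias=True, seed=0):
    """Hypothetical helper mirroring the shapes Dense.build passes to add_weight."""
    last_dim = input_shape[-1]  # only the last axis matters
    if last_dim is None:
        raise ValueError("The last dimension of the inputs to `Dense` should be defined.")
    rng = np.random.default_rng(seed)
    kernel = glorot_uniform((last_dim, units), rng)   # shape [last_dim, units]
    bias = np.zeros((units,)) if use_bias else None   # 'zeros' initializer, shape [units]
    return kernel, bias


kernel, bias = build_weights((None, 10, 4), units=8)
print(kernel.shape, bias.shape)  # (4, 8) (8,)
```

Note that a rank-3 input of shape (None, 10, 4) still yields a (4, 8) kernel. The middle axis does not affect the weights at all.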

call

The computation is delegated to core_ops.dense; here is its source:

def dense(inputs, kernel, bias=None, activation=None, dtype=None):
  if dtype:
    if inputs.dtype.base_dtype != dtype.base_dtype:
      inputs = math_ops.cast(inputs, dtype=dtype)

  rank = inputs.shape.rank
  if rank == 2 or rank is None:
    if isinstance(inputs, sparse_tensor.SparseTensor):
      outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, kernel)
    else:
      outputs = gen_math_ops.mat_mul(inputs, kernel)
  # Broadcast kernel to inputs.
  else:
    outputs = standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]])
    # Reshape the output back to the original ndim of the input.
    if not context.executing_eagerly():
      shape = inputs.shape.as_list()
      output_shape = shape[:-1] + [kernel.shape[-1]]
      outputs.set_shape(output_shape)

  if bias is not None:
    outputs = nn_ops.bias_add(outputs, bias)

  if activation is not None:
    outputs = activation(outputs)

  return outputs

Differences between the multiplications

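Here inputs is a tensor, so it has a rank (its number of dimensions), and the rank decides which multiplication is used.

One path, for rank 2 (or unknown rank), uses sparse_ops.sparse_tensor_dense_matmul if the input is a SparseTensor, otherwise gen_math_ops.mat_mul, an ordinary matrix product.
The other path, for rank greater than 2, uses standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]]). This contracts the last axis of the input with the first axis of the kernel, so the same kernel is applied at every position along the leading axes. In graph mode the static output shape is then set to the input shape with its last dimension replaced by kernel.shape[-1].

In both cases nn_ops.bias_add then adds the bias along the last axis, and the activation, if any, is applied last.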
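To make the two branches concrete, here is a NumPy sketch of the same control flow, assuming dense (non-sparse) inputs. np.matmul stands in for gen_math_ops.mat_mul, and np.tensordot stands in for standard_ops.tensordot. The dtype cast and the SparseTensor branch are omitted, so this is an illustration rather than the TF function. It also checks that the tensordot path matches a rank-2 matmul on the input after flattening its leading axes and reshaping the result back.

```python
import numpy as np


def dense(inputs, kernel, bias=None, activation=None):
    """NumPy sketch of the two paths in core_ops.dense (dense inputs only)."""
    rank = inputs.ndim
    if rank == 2:
        # Plain matrix multiply: [batch, in] @ [in, units] -> [batch, units].
        outputs = np.matmul(inputs, kernel)
    else:
        # Contract the last axis of inputs with axis 0 of kernel;
        # leading axes are kept, e.g. [b, t, in] -> [b, t, units].
        outputs = np.tensordot(inputs, kernel, axes=[[rank - 1], [0]])
    if bias is not None:
        outputs = outputs + bias  # broadcasts over the last axis, like bias_add
    if activation is not None:
        outputs = activation(outputs)
    return outputs


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))  # rank 3: batch=2, steps=3, features=4
w = rng.standard_normal((4, 5))
b = np.arange(5.0)

y = dense(x, w, b)
# Same result as flattening to rank 2, doing matmul, and reshaping back.
y_ref = dense(x.reshape(-1, 4), w, b).reshape(2, 3, 5)
print(y.shape, np.allclose(y, y_ref))  # (2, 3, 5) True
```

The equivalence explains why a single [in, units] kernel works for inputs of any rank: the layer is applied independently to every vector along the last axis.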

compute_output_shape

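Computes output_shape from the input shape and units. The input must have rank at least 2 and a defined last dimension. The output shape is the input shape with its last dimension replaced by units. For example, (None, 10, 16) with units=8 gives (None, 10, 8).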
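The rule is simple enough to express on plain tuples. This sketch uses None for unknown dimensions instead of TensorShape. It approximates with_rank_at_least(2) plus the last-dimension check, and is not the TF code.

```python
def compute_output_shape(input_shape, units):
    """Sketch of Dense.compute_output_shape on plain tuples (None = unknown dim)."""
    input_shape = tuple(input_shape)
    if len(input_shape) < 2:
        raise ValueError(f"Input must have rank >= 2, got shape {input_shape}")
    if input_shape[-1] is None:
        raise ValueError(
            f"The innermost dimension of input_shape must be defined, but saw: {input_shape}")
    # Every axis except the last is kept; the last becomes `units`.
    return input_shape[:-1] + (units,)


print(compute_output_shape((None, 16), 8))      # (None, 8)
print(compute_output_shape((None, 10, 16), 8))  # (None, 10, 8)
```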

get_config

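Returns the config dict. It contains the base Layer config plus every constructor argument, with activation, initializers, regularizers, and constraints in serialized form, so the layer can be recreated from it.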
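Because every constructor argument ends up in the config, a layer can be rebuilt from it. The sketch below shows that round trip on a hypothetical ConfigurableDense class. Its from_config follows the same cls(**config) idea as the Keras base Layer.from_config. The real base config also carries fields such as trainable and dtype, which are left out here.

```python
class ConfigurableDense:
    """Hypothetical layer showing the get_config / from_config round trip."""

    def __init__(self, units, activation=None, use_bias=True, name="dense"):
        self.units = units
        self.activation = activation  # kept as a string name, i.e. already serialized
        self.use_bias = use_bias
        self.name = name

    def get_config(self):
        # Base fields first, then the layer's own constructor arguments.
        config = {"name": self.name}
        config.update({
            "units": self.units,
            "activation": self.activation,
            "use_bias": self.use_bias,
        })
        return config

    @classmethod
    def from_config(cls, config):
        # Feed the dict straight back into __init__.
        return cls(**config)


original = ConfigurableDense(8, activation="relu", name="fc1")
config = original.get_config()
clone = ConfigurableDense.from_config(config)
print(config)  # {'name': 'fc1', 'units': 8, 'activation': 'relu', 'use_bias': True}
```

This round trip is why Dense serializes its activation and initializers instead of storing the raw function objects: the config must contain only plain, reconstructible values.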

Reposted from: http://ilugf.baihongyu.com/
