莫烦theano学习自修第三天【共享变量】

时间:2021-12-13 04:02:55

1. 代码实现

#!/usr/bin/env python
#! _*_ coding:UTF-8 _*_

import numpy as np
import theano.tensor as T
import theano

if __name__ == "__main__":

    # 用一个累加器来测试共享变量
    state = theano.shared(np.array(0, dtype=np.float64), 'state')

    inc = T.scalar('inc', dtype=state.dtype)

    accumulator = theano.function([inc], state, updates=[
        (state, state + inc)
    ])

    print state.get_value()
    accumulator(10)
    print state.get_value()
    # 这里不宜直接用print accumulaot(10)进行取值

    # 设置共享变量的值
    state.set_value(-1)
    accumulator(3)
    print state.get_value()

    # 使用另一个变量暂时代替共享变量进行赋值
    tmp_function = state * 2 + inc
    a = T.scalar(dtype=state.dtype)
    skip_shared = theano.function([inc, a], tmp_function, givens=[
        (state, a)
    ])

    print skip_shared(2, 3)
    print state.get_value()

结果:

/Users/liudaoqiang/PycharmProjects/numpy/venv/bin/python /Users/liudaoqiang/Project/python_project/theano_day3/shared_value.py
0.0
10.0
2.0
8.0
2.0

Process finished with exit code 0